From f8be0c58a187620cb2f9961d24a97fe4a63ed59e Mon Sep 17 00:00:00 2001
From: depenglee1707 <154987252+depenglee1707@users.noreply.github.com>
Date: Wed, 10 Apr 2024 13:57:44 +0800
Subject: [PATCH] fix issue: pipeline without streaming support cannot work
 when called as streaming (#84)

---
 llmserve/backend/llm/engines/generic.py | 4 ++--
 llmserve/backend/llm/predictor.py       | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/llmserve/backend/llm/engines/generic.py b/llmserve/backend/llm/engines/generic.py
index b9cb7cf..00b6ff6 100644
--- a/llmserve/backend/llm/engines/generic.py
+++ b/llmserve/backend/llm/engines/generic.py
@@ -376,7 +376,7 @@ async def launch_engine(
         )
 
         self.base_worker_group = worker_group
-        self.can_stream = await asyncio.gather(*[worker_group[0].can_stream.remote()])
+        self.can_stream = ray.get(worker_group[0].can_stream.remote())
         return worker_group
 
     async def predict(
@@ -503,5 +503,5 @@ async def stream(
                 f"Pipeline {self.args.model_config.initialization.pipeline} does not support streaming. Ignoring queue."
             )
             yield await self.predict(
-                prompts, timeout_s=timeout_s, start_timestamp=start_timestamp
+                prompts, timeout_s=timeout_s, start_timestamp=start_timestamp, lock=lock
             )
\ No newline at end of file
diff --git a/llmserve/backend/llm/predictor.py b/llmserve/backend/llm/predictor.py
index 84f0fc9..b2dc369 100644
--- a/llmserve/backend/llm/predictor.py
+++ b/llmserve/backend/llm/predictor.py
@@ -174,7 +174,7 @@ async def _predict_async(
         Returns:
             A list of generated texts.
         """
-        prediction = await self.engine.predict(prompts, timeout_s=timeout_s, start_timestamp=start_timestamp, lock = self._base_worker_group_lock)
+        prediction = await self.engine.predict(prompts, timeout_s=timeout_s, start_timestamp=start_timestamp, lock=self._base_worker_group_lock)
         return prediction
 
     async def _stream_async(
@@ -197,7 +197,7 @@ async def _stream_async(
         Returns:
             A list of generated texts.
         """
-        async for s in self.engine.stream(prompts, timeout_s=timeout_s, start_timestamp=start_timestamp, lock = self._base_worker_group_lock):
+        async for s in self.engine.stream(prompts, timeout_s=timeout_s, start_timestamp=start_timestamp, lock=self._base_worker_group_lock):
             yield s
 
     # Called by Serve to check the replica's health.
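
Note on the first hunk: `await asyncio.gather(...)` over a single Ray ObjectRef resolves to a one-element list (e.g. `[False]`), and a non-empty list is always truthy, so any truthiness check on `self.can_stream` would pass even when the worker reports that its pipeline cannot stream; `ray.get(...)` returns the plain bool instead. Below is a minimal standalone sketch of that difference, using a hypothetical `Worker` actor (not part of llmserve) whose `can_stream` method returns False:

    import asyncio
    import ray

    ray.init()

    @ray.remote
    class Worker:
        # Hypothetical stand-in for a pipeline worker without streaming support.
        def can_stream(self) -> bool:
            return False

    async def main():
        worker = Worker.remote()
        # Pre-fix form: gather wraps the result in a list, which is truthy even for [False].
        gathered = await asyncio.gather(worker.can_stream.remote())
        print(gathered, bool(gathered))   # [False] True
        # Post-fix form: ray.get blocks and unwraps the ObjectRef to the plain bool.
        flag = ray.get(worker.can_stream.remote())
        print(flag, bool(flag))           # False False

    asyncio.run(main())

The second hunk adds the missing `lock=lock` argument so the non-streaming fallback calls `self.predict` the same way `predictor.py` does, where `engine.predict` and `engine.stream` are always passed the worker-group lock.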