Skip to content

Commit

Permalink
[HotFix] Fix final output truncation with stop string + streaming (vl…
Browse files Browse the repository at this point in the history
  • Loading branch information
njhill authored and dtrifiro committed Sep 16, 2024
1 parent bf7e710 commit 8d32eaf
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
26 changes: 21 additions & 5 deletions tests/async_engine/test_async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def should_do_global_cleanup_after_test(request) -> bool:


@pytest.mark.asyncio(scope="module")
async def test_asyncio_run(async_engine):
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_asyncio_run(async_engine, stop):

scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps
Expand All @@ -169,6 +170,7 @@ async def run(prompt: str):
temperature=0,
max_tokens=32,
min_tokens=32,
stop=stop,
)

output_count = 0
Expand Down Expand Up @@ -203,7 +205,8 @@ async def run(prompt: str):


@pytest.mark.asyncio(scope="module")
async def test_output_kinds(async_engine):
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_output_kinds(async_engine, stop):
"""Test that output_kind works as expected and that
results are equivalent across different kinds."""

Expand All @@ -214,6 +217,7 @@ async def test_output_kinds(async_engine):
temperature=0,
max_tokens=32,
min_tokens=32,
stop=stop,
)

async def run(prompt: str, kind: RequestOutputKind):
Expand All @@ -229,6 +233,8 @@ async def run(prompt: str, kind: RequestOutputKind):
final_output = output

assert final_output is not None
assert final_output.finished

return (final_output.prompt_token_ids,
final_output.outputs[0].token_ids,
final_output.outputs[0].text, output_count)
Expand All @@ -241,16 +247,18 @@ async def run_deltas(prompt: str):
output_tokens: List[int] = []
output_text = ""
output_count = 0
final_output = None
async for output in async_engine.generate(prompt,
params,
request_id=uid()):
token_ids = output.outputs[0].token_ids
text = output.outputs[0].text
final_output = output

# Ensure we get prompt ids iff we haven't yet received output tokens
if output_tokens:
assert 1 <= len(token_ids) <= num_scheduler_steps
assert text
assert stop or text
assert not output.prompt_token_ids
else:
assert output.prompt_token_ids
Expand All @@ -260,6 +268,10 @@ async def run_deltas(prompt: str):
output_text += text

output_count += 1

assert final_output is not None
assert final_output.finished

return prompt_tokens, output_tokens, output_text, output_count

results = await asyncio.gather(
Expand Down Expand Up @@ -291,14 +303,16 @@ async def run_deltas(prompt: str):


@pytest.mark.asyncio(scope="module")
async def test_cancellation(async_engine):
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_cancellation(async_engine, stop):
scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps

sampling_params = SamplingParams(
temperature=0,
min_tokens=13,
max_tokens=13,
stop=stop,
)

stop_at = 5 if num_scheduler_steps == 1 else 1
Expand All @@ -319,7 +333,8 @@ async def test_cancellation(async_engine):


@pytest.mark.asyncio(scope="module")
async def test_delayed_generator(async_engine):
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
async def test_delayed_generator(async_engine, stop):
scheduler_config = await async_engine.get_scheduler_config()

if scheduler_config.num_scheduler_steps != 1:
Expand All @@ -329,6 +344,7 @@ async def test_delayed_generator(async_engine):
temperature=0,
min_tokens=10,
max_tokens=10,
stop=stop,
)

stream = async_engine.generate("test3", sampling_params, request_id=uid())
Expand Down
4 changes: 3 additions & 1 deletion vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,9 @@ def get_output_text_to_return(self, buffer_length: int,
if not delta:
return self.output_text[:-buffer_length] if truncate else (
self.output_text)
length = len(self.output_text) - buffer_length
length = len(self.output_text)
if truncate:
length -= buffer_length
last_offset = self._last_output_text_offset
if last_offset < length:
self._last_output_text_offset = length
Expand Down

0 comments on commit 8d32eaf

Please sign in to comment.