diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index e33d5fb2dc247..1e759c9616061 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -1,19 +1,36 @@ """Benchmark the latency of processing a single batch of requests.""" import argparse import time -from pathlib import Path -from typing import Optional +import os import numpy as np import torch from tqdm import tqdm from vllm import LLM, SamplingParams +from vllm.anyscale.lora.utils import LoRARequest + +SAMPLE_PROMPTS = [ + "The president of the United States is", + "Hello, my name is", + "The capital of France is", + "The future of AI is", +] + + +def add_lora(llm, batch_size): + LORA_FILE1 = "/mnt/local_storage/lora/" + for i in range(batch_size): + lora_request = LoRARequest(lora_id=f"lora_{i + 1}", + lora_int_id=i + 1, + lora_local_path=LORA_FILE1) + assert llm.llm_engine.add_lora(lora_request) def main(args: argparse.Namespace): print(args) + # Process all the requests in a single batch if possible. # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. llm = LLM( @@ -21,14 +38,35 @@ def main(args: argparse.Namespace): tokenizer=args.tokenizer, quantization=args.quantization, tensor_parallel_size=args.tensor_parallel_size, + max_num_seqs=args.batch_size, + max_num_batched_tokens=40960, trust_remote_code=args.trust_remote_code, + load_format="dummy" if args.use_dummy_weights else "auto", + enable_lora=args.enable_lora, + enable_cuda_graph=args.enable_cuda_graph, + cuda_graph_cache_size=args.cuda_graph_cache_size, dtype=args.dtype, - enforce_eager=args.enforce_eager, + flash_style=args.flash_style, + max_chunked_prefill_len=args.max_chunked_prefill_len, + max_num_prompt_seqs=args.max_num_prompt_seqs, + block_size=32 if args.flash_style else args.block_size, + speculative_model=args.speculative_model, + num_speculative_tokens=args.num_speculative_tokens, + speculative_model_uses_tp_1=args.speculative_model_uses_tp_1, + ray_workers_use_nsight=args.run_with_nsight, + disable_shared_memory=args.disable_shared_memory, + worker_use_ray=args.worker_use_ray, + disable_log_stats=not args.log_engine_stats, ) + if args.enable_lora: + lora_request = add_lora(llm, args.batch_size) + else: + lora_request = None + sampling_params = SamplingParams( n=args.n, - temperature=0.0 if args.use_beam_search else 1.0, + temperature=0 if args.use_sample else 1.0, top_p=1.0, use_beam_search=args.use_beam_search, ignore_eos=True, @@ -37,44 +75,75 @@ def main(args: argparse.Namespace): print(sampling_params) dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size - def run_to_completion(profile_dir: Optional[str] = None): - if profile_dir: - with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - on_trace_ready=torch.profiler.tensorboard_trace_handler( - str(profile_dir))) as p: - llm.generate(prompt_token_ids=dummy_prompt_token_ids, - sampling_params=sampling_params, - use_tqdm=False) - print(p.key_averages()) + def run_to_completion(): + start_time = time.perf_counter() + + if args.use_sample: + batch = ( + SAMPLE_PROMPTS * + (args.batch_size // len(SAMPLE_PROMPTS) + 1))[:args.batch_size] + outputs = llm.generate(prompts=batch, + sampling_params=sampling_params, + use_tqdm=False, + lora_request=lora_request) else: - start_time = time.perf_counter() - llm.generate(prompt_token_ids=dummy_prompt_token_ids, - 
sampling_params=sampling_params, - use_tqdm=False) - end_time = time.perf_counter() - latency = end_time - start_time - return latency + outputs = llm.generate(prompt_token_ids=dummy_prompt_token_ids, + sampling_params=sampling_params, + use_tqdm=False, + lora_request=lora_request) + + end_time = time.perf_counter() + + if args.verbose: + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print( + f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + latency = end_time - start_time + return latency + + if args.profile and args.enable_cuda_graph: + # Workaround to enable profiling cuda graphs. + # https://github.com/pytorch/pytorch/issues/75504#issuecomment-1467065935 + llm.llm_engine.start_profile( + profile_ray_workers=args.profile_ray_workers) + llm.llm_engine.stop_profile( + profile_ray_workers=args.profile_ray_workers) print("Warming up...") - run_to_completion(profile_dir=None) + run_to_completion() if args.profile: - profile_dir = args.profile_result_dir - if not profile_dir: - profile_dir = Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}" - print(f"Profiling (results will be saved to '{profile_dir}')...") - run_to_completion(profile_dir=args.profile_result_dir) - return + model_name = args.model.replace("/", "-") + profile_logdir_name = os.path.join( + args.profile_logdir, + f"{model_name}_tp-{args.tensor_parallel_size}_input-len{args.input_len}_output-len{args.output_len}_batch-size{args.batch_size}" + .lstrip("-")) + llm.llm_engine.start_profile( + profile_ray_workers=args.profile_ray_workers, + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA + ], + on_trace_ready=torch.profiler.tensorboard_trace_handler( + profile_logdir_name), + with_stack=True) # Benchmark. 
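For reference, the engine-level start_profile/stop_profile calls above replace the inline torch.profiler usage that the removed run_to_completion contained; a minimal standalone sketch of that stock PyTorch pattern (with a placeholder trace directory, not a path used by this script) looks like this:

    # Sketch only: the stock torch.profiler pattern the removed code used.
    # "./trace_dir" is a placeholder output directory, not part of this script.
    import torch

    profiler = torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./trace_dir"),
        with_stack=True,
    )
    profiler.start()
    # ... run the generation calls to be traced ...
    profiler.stop()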
     latencies = []
     for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
-        latencies.append(run_to_completion(profile_dir=None))
+        latencies.append(run_to_completion())
     print(f'Avg latency: {np.mean(latencies)} seconds')
+    print(
+        f'Avg ITL: {1000*np.mean(latencies)/args.output_len:.02f} milliseconds'
+    )
+    print(f'Peak CUDA memory: {torch.cuda.max_memory_allocated()}')
+
+    if args.profile:
+        llm.llm_engine.stop_profile(
+            profile_ray_workers=args.profile_ray_workers, )


 if __name__ == '__main__':
@@ -85,12 +154,12 @@ def run_to_completion(profile_dir: Optional[str] = None):
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'gptq', 'squeezellm', None],
+                        choices=['awq', 'squeezellm', None],
                         default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
-    parser.add_argument('--batch-size', type=int, default=8)
+    parser.add_argument('--batch-size', type=int, default=32)
     parser.add_argument('--n',
                         type=int,
                         default=1,
@@ -103,6 +172,24 @@ def run_to_completion(profile_dir: Optional[str] = None):
     parser.add_argument('--trust-remote-code',
                         action='store_true',
                         help='trust remote code from huggingface')
+    parser.add_argument('--enable-lora',
+                        action='store_true',
+                        help='enable LoRA adapters')
+    parser.add_argument('--enable-cuda-graph',
+                        action='store_true',
+                        help='enable cuda graph for decoding')
+    parser.add_argument('--cuda-graph-cache-size',
+                        type=int,
+                        default=200,
+                        help='number of cuda graphs to cache')
+    parser.add_argument('--use-dummy-weights',
+                        action='store_true',
+                        help='use dummy (randomly initialized) weights')
+    parser.add_argument('--speculative-model', type=str, default=None)
+    parser.add_argument('--num-speculative-tokens', type=int, default=None)
+    parser.add_argument('--speculative-model-uses-tp-1',
+                        action='store_true',
+                        help='speculative model uses tp1')
     parser.add_argument(
         '--dtype',
         type=str,
@@ -112,20 +199,33 @@ def run_to_completion(profile_dir: Optional[str] = None):
         'The "auto" option will use FP16 precision '
         'for FP32 and FP16 models, and BF16 precision '
         'for BF16 models.')
-    parser.add_argument('--enforce-eager',
+    parser.add_argument('--run-with-nsight', action='store_true')
+    parser.add_argument('--profile', action='store_true')
+    parser.add_argument('--profile-logdir', type=str, default=None)
+    parser.add_argument('--profile-ray-workers', action='store_true')
+    parser.add_argument('--max-chunked-prefill-len', type=int, default=-1)
+    parser.add_argument('--max-num-prompt-seqs', type=int, default=1000)
+    parser.add_argument('--flash-style',
                         action='store_true',
-                        help='enforce eager mode and disable CUDA graph')
-    parser.add_argument(
-        '--profile',
-        action='store_true',
-        help='profile the generation process of a single batch')
-    parser.add_argument(
-        '--profile-result-dir',
-        type=str,
-        default=None,
-        help=(
-            'path to save the pytorch profiler output. Can be visualized '
-            'with ui.perfetto.dev or Tensorboard.'
- )) + help='enable flash attention') + parser.add_argument('--block-size', + type=int, + default=16, + help='block size of key/value cache') + parser.add_argument('--use-sample', + action='store_true', + help='use sample input instead of dummy input') + parser.add_argument('--disable-shared-memory', + action='store_true', + help='disable shared memory') + parser.add_argument('--verbose', + action='store_true', + help='print generated text') + parser.add_argument('--log-engine-stats', + action='store_true', + help='log engine stats') + parser.add_argument('--worker-use-ray', + action='store_true', + help='use Ray worker') args = parser.parse_args() main(args) diff --git a/tests/anyscale/__init__.py b/tests/anyscale/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/anyscale/utils.py b/tests/anyscale/utils.py new file mode 100644 index 0000000000000..9362f692ee39b --- /dev/null +++ b/tests/anyscale/utils.py @@ -0,0 +1,67 @@ + +import gc +import json +import logging +import os + +import boto3 +import ray +import torch + +from vllm.model_executor.parallel_utils.parallel_state import \ + destroy_model_parallel + +ENV_TOKEN_OVERRIDES = os.getenv("AVIARY_ENV_AWS_SECRET_NAME", + "aviary/env_overrides") + +logger = logging.getLogger(__name__) + + +def cleanup(): + # Revert to torch default after vllm modifications + torch.backends.cuda.matmul.allow_tf32 = False + torch.set_default_dtype(torch.float32) + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + ray.shutdown() + + +# Copied from aviary +class SecretManager: + + def __init__(self, secret_name: str = ENV_TOKEN_OVERRIDES): + self.secret_overrides = self.get_all_secrets(secret_name) + + def get_all_secrets(self, secret_name: str): + try: + aws_region_name = os.getenv("AWS_REGION", "us-west-2") + + # Create a Secrets Manager client + session = boto3.session.Session() + client = session.client(service_name="secretsmanager", + region_name=aws_region_name) + get_secret_value_response = client.get_secret_value( + SecretId=secret_name) + + # Decrypts secret using the associated KMS key. + secret = get_secret_value_response["SecretString"] + + secret_dict = json.loads(secret) + return secret_dict + except Exception as e: + print( + f"Unable to load env override secrets from {secret_name}. Using default secrets from env. 
{e}" + ) + return {} + + def override_secret(self, env_var_name: str, set_in_env=True): + # First read from env var, then from aws secrets + secret = os.getenv(env_var_name, + self.secret_overrides.get(env_var_name)) + if secret is None: + print(f"Secret {env_var_name} was not found.") + elif set_in_env: + os.environ[env_var_name] = secret + print(f"Secret {env_var_name} was set in the env.") + return secret diff --git a/tests/async_engine/api_server_async_engine.py b/tests/async_engine/api_server_async_engine.py index 1be76fdc8d868..309e1ae890551 100644 --- a/tests/async_engine/api_server_async_engine.py +++ b/tests/async_engine/api_server_async_engine.py @@ -14,13 +14,14 @@ class AsyncLLMEngineWithStats(AsyncLLMEngine): + # pylint: disable=redefined-outer-name def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._num_aborts = 0 async def abort(self, request_id: str) -> None: - await super().abort(request_id) self._num_aborts += 1 + await super().abort(request_id) def testing_stats(self) -> Dict[str, Any]: return {"num_aborted_requests": self._num_aborts} diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py index d90ba37b27bb9..af6c4ec20a338 100644 --- a/tests/async_engine/test_api_server.py +++ b/tests/async_engine/test_api_server.py @@ -1,3 +1,4 @@ + import subprocess import sys import time @@ -24,14 +25,16 @@ def _query_server(prompt: str) -> dict: def api_server(): script_path = Path(__file__).parent.joinpath( "api_server_async_engine.py").absolute() + # pylint: disable=consider-using-with uvicorn_process = subprocess.Popen([ sys.executable, "-u", - str(script_path), "--model", "facebook/opt-125m" + str(script_path), "--model", "facebook/opt-125m", "--worker-use-ray" ]) yield uvicorn_process.terminate() +# pylint: disable=redefined-outer-name, unused-argument def test_api_server(api_server): """ Run the API server and test it. 
@@ -47,10 +50,11 @@ def test_api_server(api_server): prompts = ["Hello world"] * 1 result = None while not result: + # pylint: disable=bare-except try: - for _ in pool.map(_query_server, prompts): + for result in pool.map(_query_server, prompts): break - except Exception: + except: time.sleep(1) # Actual tests start here @@ -58,8 +62,9 @@ def test_api_server(api_server): for result in pool.map(_query_server, prompts): assert result - num_aborted_requests = requests.get( - "http://localhost:8000/stats").json()["num_aborted_requests"] + # check stats + metadata = requests.get("http://localhost:8000/stats").json() + num_aborted_requests = metadata["num_aborted_requests"] assert num_aborted_requests == 0 # Try with 100 prompts @@ -69,13 +74,15 @@ def test_api_server(api_server): # Cancel requests pool.map_async(_query_server, prompts) - time.sleep(0.01) + time.sleep(0.001) pool.terminate() pool.join() + time.sleep(0.1) - # check cancellation stats - num_aborted_requests = requests.get( - "http://localhost:8000/stats").json()["num_aborted_requests"] + # check stats + metadata = requests.get("http://localhost:8000/stats").json() + num_aborted_requests = metadata["num_aborted_requests"] + last_num_aborted_requests = num_aborted_requests assert num_aborted_requests > 0 # check that server still runs after cancellations @@ -84,3 +91,8 @@ def test_api_server(api_server): prompts = ["Hello world"] * 100 for result in pool.map(_query_server, prompts): assert result + + # check stats + metadata = requests.get("http://localhost:8000/stats").json() + num_aborted_requests = metadata["num_aborted_requests"] + assert num_aborted_requests == last_num_aborted_requests diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 174975802dc0d..977fe77f94fcf 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -1,3 +1,4 @@ + import asyncio from dataclasses import dataclass @@ -25,6 +26,9 @@ async def step_async(self): return [RequestOutput( request_id=self.request_id)] if self.request_id else [] + async def encode_request_async(self, *args, **kwargs): + pass + def generate(self, request_id): self.request_id = request_id @@ -35,6 +39,10 @@ def add_request(self, **kwargs): del kwargs # Unused self.add_request_calls += 1 + async def add_request_async(self, **kwargs): + self.add_request_calls += 1 + return + def abort_request(self, request_id): del request_id # Unused self.abort_request_calls += 1 diff --git a/tests/async_engine/test_request_tracker.py b/tests/async_engine/test_request_tracker.py index 3e4d53c5cbe23..41ee6181b22c8 100644 --- a/tests/async_engine/test_request_tracker.py +++ b/tests/async_engine/test_request_tracker.py @@ -1,3 +1,4 @@ + import pytest from vllm.engine.async_llm_engine import RequestTracker @@ -64,7 +65,13 @@ def test_request_tracker(): stream_5 = tracker.add_request("5") assert tracker.new_requests_event.flag tracker.process_request_output( - RequestOutput("2", "output", [], [], [], finished=True)) + RequestOutput("2", + "output", [], [], [], + finished=True, + arrival_time=0.0, + first_scheduled_time=0.0, + first_token_time=0.0, + time_in_queue=0.0)) new, finished = tracker.get_new_and_finished_requests() assert not tracker.new_requests_event.flag assert len(finished) == 1 diff --git a/tests/conftest.py b/tests/conftest.py index 16c04e01d703c..577b0764ee9f9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,39 +1,42 @@ -import os + +import gc from typing import List, 
Optional, Tuple import pytest +import ray import torch from transformers import AutoModelForCausalLM from vllm import LLM, SamplingParams +from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel from vllm.transformers_utils.tokenizer import get_tokenizer -_TEST_PROMPTS = ["prompts/example.txt"] -_LONG_PROMPTS = ["prompts/summary.txt"] - - -def _read_prompts(filename: str) -> str: - prompts = [] - with open(filename, "r") as f: - prompt = f.readline() - prompts.append(prompt) - return prompts +_TEST_PROMPTS = [ + # pylint: disable=line-too-long + "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.", + "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.", + "Compare and contrast artificial intelligence with human intelligence in terms of processing information.", + "Describe the basic components of a neural network and how it can be trained.", + "Write a short story about a robot that dreams for the first time.", + "Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.", + "Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.", + "Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'", +] + + +def cleanup(): + # Revert to torch default after vllm modifications + torch.backends.cuda.matmul.allow_tf32 = False + torch.set_default_dtype(torch.float32) + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + ray.shutdown() @pytest.fixture def example_prompts() -> List[str]: - prompts = [] - for filename in _TEST_PROMPTS: - prompts += _read_prompts(os.path.join("tests", filename)) - return prompts - - -@pytest.fixture -def example_long_prompts() -> List[str]: - prompts = [] - for filename in _LONG_PROMPTS: - prompts += _read_prompts(os.path.join("tests", filename)) - return prompts + return _TEST_PROMPTS _STR_DTYPE_TO_TORCH_DTYPE = { @@ -57,7 +60,8 @@ def __init__( model_name, torch_dtype=torch_dtype, trust_remote_code=True, - ).cuda() + device_map={"": 0}, + ) if tokenizer_name is None: tokenizer_name = model_name self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True) @@ -97,27 +101,6 @@ def generate_greedy( outputs[i] = (output_ids[0], output_str[0]) return outputs - def generate_beam_search( - self, - prompts: List[str], - beam_width: int, - max_tokens: int, - ) -> List[Tuple[List[int], str]]: - outputs = self.generate(prompts, - do_sample=False, - max_new_tokens=max_tokens, - num_beams=beam_width, - num_return_sequences=beam_width) - for i in range(len(outputs)): - output_ids, output_str = outputs[i] - for j in range(len(output_ids)): - output_ids[j] = [ - x for x in output_ids[j] - if x != self.tokenizer.pad_token_id - ] - outputs[i] = (output_ids, output_str) - return outputs - def generate_greedy_logprobs( self, prompts: List[str], @@ -154,7 +137,8 @@ def generate_greedy_logprobs( @pytest.fixture def hf_runner(): - return HfRunner + yield HfRunner + cleanup() class VllmRunner: @@ -164,6 +148,16 @@ def __init__( model_name: str, tokenizer_name: Optional[str] = None, dtype: str = "half", + enable_cuda_graph: bool = False, + cuda_graph_max_context_len: int = 5000, + cuda_graph_cache_size: int = 10, + tensor_parallel_size: int = 1, + flash_style: bool = False, + max_chunked_prefill_len: int = -1, + max_num_prompt_seqs: int = 1000, + max_num_batched_tokens: int = 4096, + 
worker_use_ray: bool = False, + input_padding_size: int = 8, ) -> None: self.model = LLM( model=model_name, @@ -171,12 +165,23 @@ def __init__( trust_remote_code=True, dtype=dtype, swap_space=0, - ) + enable_cuda_graph=enable_cuda_graph, + cuda_graph_max_context_len=cuda_graph_max_context_len, + cuda_graph_cache_size=cuda_graph_cache_size, + tensor_parallel_size=tensor_parallel_size, + flash_style=flash_style, + block_size=32, + max_chunked_prefill_len=max_chunked_prefill_len, + max_num_prompt_seqs=max_num_prompt_seqs, + max_num_batched_tokens=max_num_batched_tokens, + worker_use_ray=worker_use_ray, + input_padding_size=input_padding_size) def generate( self, prompts: List[str], sampling_params: SamplingParams, + return_output_only: bool = False, ) -> List[Tuple[List[int], str]]: req_outputs = self.model.generate(prompts, sampling_params=sampling_params) @@ -189,8 +194,12 @@ def generate( for sample in req_output.outputs: output_str = sample.text output_ids = sample.token_ids - req_sample_output_ids.append(prompt_ids + output_ids) - req_sample_output_strs.append(prompt_str + output_str) + if return_output_only: + req_sample_output_ids.append(output_ids) + req_sample_output_strs.append(output_str) + else: + req_sample_output_ids.append(prompt_ids + output_ids) + req_sample_output_strs.append(prompt_str + output_str) outputs.append((req_sample_output_ids, req_sample_output_strs)) return outputs @@ -204,20 +213,29 @@ def generate_greedy( return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] - def generate_beam_search( - self, - prompts: List[str], - beam_width: int, - max_tokens: int, - ) -> List[Tuple[List[int], str]]: - beam_search_params = SamplingParams(n=beam_width, - use_beam_search=True, - temperature=0.0, - max_tokens=max_tokens) - outputs = self.generate(prompts, beam_search_params) - return outputs - @pytest.fixture def vllm_runner(): - return VllmRunner + yield VllmRunner + cleanup() + + +@pytest.fixture +def setup_cuda_graph(model_name="facebook/opt-125m", cache_size=8): + vllm_model = VllmRunner(model_name, + dtype="half", + enable_cuda_graph=True, + cuda_graph_max_context_len=64, + cuda_graph_cache_size=cache_size) + nn_model = vllm_model.model.llm_engine.workers[0].model + cuda_graph = vllm_model.model.llm_engine.workers[0].captured_model + worker = vllm_model.model.llm_engine.workers[0] + yield vllm_model, nn_model, cuda_graph, worker + del vllm_model + del nn_model + del cuda_graph + del worker + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + ray.shutdown() diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py new file mode 100644 index 0000000000000..c77e7fdc7d2b8 --- /dev/null +++ b/tests/core/test_block_manager.py @@ -0,0 +1,309 @@ + +import pytest +import time +from typing import List + +from vllm import SamplingParams +from vllm.block import PhysicalTokenBlock +from vllm.core.block_manager import BlockAllocator, BlockSpaceManager, AllocStatus +from vllm.utils import Device +from vllm.sequence import Sequence, SequenceGroup, SequenceStatus +from tests.utils import round_up_to_next_block + +from .utils import create_dummy_prompt + + +def test_block_allocator_allocate(): + block_size = 4 + num_cpu_blocks = 4 + cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks) + + # Allocate all available cpu blocks. 
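As an aside before the allocator walk-through continues: the round_up_to_next_block helper imported above from tests.utils is not part of this diff, but from the way the tests below use it, it presumably amounts to ceiling division of a token count into blocks:

    # Assumed behavior of tests.utils.round_up_to_next_block (not in this diff).
    def round_up_to_next_block(seq_len: int, block_size: int) -> int:
        return (seq_len + block_size - 1) // block_size

    assert round_up_to_next_block(17, 16) == 2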
+ num_free = num_cpu_blocks + assert cpu_allocator.get_num_free_blocks() == num_free + for _ in range(num_cpu_blocks): + block = cpu_allocator.allocate() + num_free -= 1 + assert block not in cpu_allocator.free_blocks + assert cpu_allocator.get_num_free_blocks() == num_free + + with pytest.raises(ValueError): + cpu_allocator.allocate() + + +def test_block_allocator_free(): + block_size = 4 + num_cpu_blocks = 4 + cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks) + + # Allocate all available cpu blocks. + blocks: List[PhysicalTokenBlock] = [] + for _ in range(num_cpu_blocks): + block = cpu_allocator.allocate() + blocks.append(block) + assert block not in cpu_allocator.free_blocks + + # Free all allocated cpu blocks. + num_free = 0 + assert cpu_allocator.get_num_free_blocks() == num_free + for block in blocks: + cpu_allocator.free(block) + num_free += 1 + assert block in cpu_allocator.free_blocks + assert cpu_allocator.get_num_free_blocks() == num_free + + with pytest.raises(ValueError): + cpu_allocator.free(block) + + +def test_allocate(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + # Allocate same sequence group to all available gpu blocks. + for i in range(num_gpu_blocks): + _, seq_group = create_dummy_prompt(str(i), block_size) + assert block_manager.can_allocate(seq_group) + block_manager.allocate(seq_group) + assert block_manager.can_allocate(seq_group) != AllocStatus.OK + + # Allocate same sequence group to all available gpu blocks. + # Use watermark to reserve one gpu block. + block_manager = BlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=1 / num_gpu_blocks) + for i in range(num_gpu_blocks - 1): + _, seq_group = create_dummy_prompt(str(i), block_size) + assert block_manager.can_allocate(seq_group) + block_manager.allocate(seq_group) + assert block_manager.can_allocate(seq_group) != AllocStatus.OK + + +def test_append_slot_single_seq(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + # Allocate single seq to gpu block. + prompt, seq_group = create_dummy_prompt("1", block_size) + block_manager.allocate(seq_group) + + # Nothing to append. Sequence has no new logical blocks. + assert block_manager.can_append_slots(seq_group) + before_blocks = block_manager.get_num_free_gpu_blocks() + assert not block_manager.append_slots(prompt) + after_blocks = block_manager.get_num_free_gpu_blocks() + assert before_blocks == after_blocks + + # Add block_size number of new tokens and append slot. + for i in range(block_size): + prompt.append_token_id(i + 5, {i + 5: 0}) + + assert block_manager.can_append_slots(seq_group) + before_blocks = block_manager.get_num_free_gpu_blocks() + assert not block_manager.append_slots(prompt) + after_blocks = block_manager.get_num_free_gpu_blocks() + assert before_blocks - after_blocks == 1 + + +@pytest.mark.parametrize("prompt_len", [1, 10, 100]) +@pytest.mark.parametrize("num_unprocessed_tokens", [1, 10, 100]) +@pytest.mark.parametrize("block_size", [1, 8, 16, 32]) +def test_append_multiple_slot_single_seq(prompt_len: int, + num_unprocessed_tokens: int, + block_size: int): + """Verify correct allocation when multiple tokens need to be processed. 
+ """ + num_cpu_blocks = 0 + num_gpu_blocks = 8192 // block_size + block_manager = BlockSpaceManager(block_size=block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + watermark=0) + + # Allocate single seq to gpu block. + prompt, seq_group = create_dummy_prompt(request_id="1", + prompt_length=prompt_len, + block_size=block_size) + block_manager.allocate(seq_group) + + # Nothing to append. Sequence has no new logical blocks. + assert block_manager.can_append_slots(seq_group) + before_blocks = block_manager.get_num_free_gpu_blocks() + assert not block_manager.append_slots(prompt) + after_blocks = block_manager.get_num_free_gpu_blocks() + assert before_blocks == after_blocks + + # Append new tokens, expect correct number of new blocks + new_token_ids = list(range(num_unprocessed_tokens)) + prompt.append_token_ids(new_token_ids, [{ + token_id: 0 + } for token_id in new_token_ids]) + + old_seq_len_in_blocks = round_up_to_next_block(prompt_len, block_size) + new_seq_len_in_blocks = round_up_to_next_block( + prompt_len + num_unprocessed_tokens, block_size) + num_expected_new_blocks = new_seq_len_in_blocks - old_seq_len_in_blocks + + assert block_manager.can_append_slots(seq_group) + before_blocks = block_manager.get_num_free_gpu_blocks() + assert not block_manager.append_slots(prompt) + after_blocks = block_manager.get_num_free_gpu_blocks() + assert before_blocks - after_blocks == num_expected_new_blocks + + +def test_append_slot_cow(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + # Allocate prompt to gpu block. + prompt = Sequence(1, "one two three", [1, 2, 3], block_size) + child = prompt.fork(2) + child.append_token_id(4, {4: 0}) + seq_group = SequenceGroup("1", [prompt, child], SamplingParams(), + time.time(), time.perf_counter) + block_manager.allocate(seq_group) + + # Append slot for child token. + # Last block being modified is shared. Copy on write occurs. + assert block_manager.can_append_slots(seq_group) + before_blocks = block_manager.get_num_free_gpu_blocks() + cow_src_dst = block_manager.append_slots(child) + + assert len(cow_src_dst) > 0 + for src_block, dst_block in cow_src_dst.items(): + assert src_block != dst_block + + after_blocks = block_manager.get_num_free_gpu_blocks() + assert before_blocks - after_blocks == 1 + + +def test_fork(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + prompt, seq_group = create_dummy_prompt("1", + block_size - 1, + block_size=block_size) + block_manager.allocate(seq_group) + + # Fork prompt and copy block tables. + child = prompt.fork(2) + block_manager.fork(prompt, child) + assert block_manager.get_block_table( + prompt) == block_manager.get_block_table(child) + + # Append token to child. Block is shared so copy on write occurs. + child.append_token_id(4, {4: 0}) + block_manager.append_slots(child) + assert block_manager.get_block_table( + prompt) != block_manager.get_block_table(child) + + +def test_swap(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1) + prompt.status = SequenceStatus.RUNNING + block_manager.allocate(seq_group) + + # Emulate a forward pass by appending a single token. 
+ # The block manager then knows how many unprocessed + # tokens will be written in the next forward pass. + prompt.append_token_id(0, {0: 0.0}) + + # Swap seq group from GPU -> CPU. + gpu_blocks = block_manager.get_block_table(prompt) + assert block_manager.can_swap_out(seq_group) + before_cpu_blocks = block_manager.get_num_free_cpu_blocks() + before_gpu_blocks = block_manager.get_num_free_gpu_blocks() + mapping = block_manager.swap_out(seq_group) + assert list(mapping.keys()) == gpu_blocks + after_cpu_blocks = block_manager.get_num_free_cpu_blocks() + after_gpu_blocks = block_manager.get_num_free_gpu_blocks() + assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) + assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks + prompt.status = SequenceStatus.SWAPPED + + # Swap seq group from CPU -> GPU. + cpu_blocks = block_manager.get_block_table(prompt) + assert block_manager.can_swap_in(seq_group) + before_cpu_blocks = block_manager.get_num_free_cpu_blocks() + before_gpu_blocks = block_manager.get_num_free_gpu_blocks() + mapping = block_manager.swap_in(seq_group) + assert list(mapping.keys()) == cpu_blocks + after_cpu_blocks = block_manager.get_num_free_cpu_blocks() + after_gpu_blocks = block_manager.get_num_free_gpu_blocks() + assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks + assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) + + +def test_free(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + prompt, seq_group = create_dummy_prompt("1", block_size) + block_manager.allocate(seq_group) + + # Free allocated seq. + prompt_blocks = len(block_manager.get_block_table(prompt)) + before_blocks = block_manager.get_num_free_gpu_blocks() + block_manager.free(prompt) + after_blocks = block_manager.get_num_free_gpu_blocks() + assert after_blocks == before_blocks + prompt_blocks + + # Block table for freed seq is deleted. + with pytest.raises(KeyError): + block_manager.get_block_table(prompt) + + +def test_reset(): + block_size = 4 + num_cpu_blocks = 4 + num_gpu_blocks = 4 + block_manager = BlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0) + + # Allocate same seq group on all available gpu blocks. + original_blocks = block_manager.get_num_free_gpu_blocks() + for i in range(num_gpu_blocks): + _, seq_group = create_dummy_prompt(str(i), block_size) + block_manager.allocate(seq_group) + assert block_manager.get_num_free_gpu_blocks() == 0 + + # Resetting block manager frees all allocated blocks. + block_manager.reset() + assert block_manager.get_num_free_gpu_blocks() == original_blocks diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py new file mode 100644 index 0000000000000..edc3da8adf02b --- /dev/null +++ b/tests/core/test_scheduler.py @@ -0,0 +1,382 @@ + +from typing import List +import pytest + +from vllm.config import CacheConfig, SchedulerConfig +from vllm.core.scheduler import Scheduler +from vllm.sequence import SequenceGroup +from tests.utils import round_up_to_next_block + +from .utils import create_dummy_prompt + + +def test_scheduler_add_seq_group(): + block_size = 4 + scheduler_config = SchedulerConfig(100, 64, 1) + cache_config = CacheConfig(block_size, 1.0, 1) + cache_config.num_cpu_blocks = 4 + cache_config.num_gpu_blocks = 4 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add seq group to scheduler. 
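The positional SchedulerConfig and CacheConfig arguments used throughout these scheduler tests are easier to read in the keyword form that appears later in this file; assuming the positional order matches those keyword calls, the configuration above is equivalent to:

    # Keyword-style equivalent of SchedulerConfig(100, 64, 1) and
    # CacheConfig(block_size, 1.0, 1), assuming the same parameter order as the
    # keyword calls used further down in this file.
    from vllm.config import CacheConfig, SchedulerConfig

    scheduler_config = SchedulerConfig(max_num_batched_tokens=100,
                                       max_num_seqs=64,
                                       max_model_len=1)
    cache_config = CacheConfig(block_size=4,
                               gpu_memory_utilization=1.0,
                               swap_space=1)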
+ num_seq_group = 4 + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(str(i), block_size) + scheduler.add_seq_group(seq_group) + assert scheduler.get_num_unfinished_seq_groups() == i + 1 + + +def test_scheduler_abort_seq_group(): + block_size = 4 + scheduler_config = SchedulerConfig(100, 64, 1) + cache_config = CacheConfig(block_size, 1.0, 1) + cache_config.num_cpu_blocks = 4 + cache_config.num_gpu_blocks = 4 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add multiple seq groups to scheduler. + num_seq_group = 4 + request_ids = set() + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(str(i), block_size) + scheduler.add_seq_group(seq_group) + request_ids.add(str(i)) + + # Abort all added seq groups. + assert scheduler.get_num_unfinished_seq_groups() == num_seq_group + scheduler.abort_seq_group(request_ids) + assert scheduler.get_num_unfinished_seq_groups() == 0 + + +def test_scheduler_schedule_simple(): + block_size = 4 + num_seq_group = 4 + max_model_len = 16 + scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len) + cache_config = CacheConfig(block_size, 1.0, 1) + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add seq groups to scheduler. + running: List[SequenceGroup] = [] + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=block_size, + num_processed_token_ids=block_size - + 1) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Schedule seq groups prompts. + seq_group_meta, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set(running) + assert out.num_batched_tokens == num_seq_group * seq_group.get_seqs( + )[0].get_len() + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == num_seq_group + + # Schedule seq groups generation. + seq_group_meta, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set(running) + assert out.num_batched_tokens == num_seq_group + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == num_seq_group + + +def test_scheduler_schedule_preempt_abort(): + block_size = 4 + max_model_len = 16 + scheduler_config = SchedulerConfig(64, 2, max_model_len) + cache_config = CacheConfig(block_size, 1.0, 1) + cache_config.num_cpu_blocks = 2 + cache_config.num_gpu_blocks = 2 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add seq groups to scheduler. + seq_a, seq_group_a = create_dummy_prompt("1", block_size) + seq_b, seq_group_b = create_dummy_prompt("2", block_size) + scheduler.add_seq_group(seq_group_a) + scheduler.add_seq_group(seq_group_b) + + # Schedule seq groups prompts. + seq_group_meta, out = scheduler.schedule() + assert out.scheduled_seq_groups == [seq_group_a, seq_group_b] + assert out.num_batched_tokens == seq_group_a.get_seqs()[0].get_len() * 2 + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == 2 + assert scheduler.get_num_unfinished_seq_groups() == 2 + + # Append "generated" tokens, allowing the sequence to mark prompt tokens as + # processed. + seq_a.append_token_id(0, {0: 0.0}) + seq_b.append_token_id(0, {0: 0.0}) + + # Schedule seq groups generation and preempt seq group b. 
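The preemption asserted in the next few lines follows directly from the block arithmetic: each 4-token prompt fills one block, and the first generated token forces a second block per sequence, which the 2-block cache cannot hold. A quick check of that arithmetic:

    # Why seq_group_b gets preempted: with block_size=4 and num_gpu_blocks=2,
    # both prompts fit (one block each), but after one generated token each
    # sequence needs two blocks, so only one sequence can keep running.
    block_size, num_gpu_blocks = 4, 2
    blocks_per_seq_after_one_token = (4 + 1 + block_size - 1) // block_size
    assert blocks_per_seq_after_one_token == 2
    assert 2 * blocks_per_seq_after_one_token > num_gpu_blocks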
+ seq_group_meta, out = scheduler.schedule() + assert out.scheduled_seq_groups == [seq_group_a] + assert out.num_batched_tokens == 1 + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == 1 + assert scheduler.get_num_unfinished_seq_groups() == 2 + + # Abort seq group a. Re-schedule seq group b prompt with recomputation. + scheduler.abort_seq_group("1") + seq_group_meta, out = scheduler.schedule() + assert out.scheduled_seq_groups == [seq_group_b] + assert out.num_batched_tokens == seq_group_b.get_seqs()[0].get_len() + assert (not out.blocks_to_copy and not out.blocks_to_swap_in + and not out.blocks_to_swap_out) + assert len(seq_group_meta) == 1 + assert scheduler.get_num_unfinished_seq_groups() == 1 + + +@pytest.mark.parametrize("block_size", [1, 8, 16, 32]) +@pytest.mark.parametrize("prompt_len", [1, 128]) +@pytest.mark.parametrize("num_unprocessed_tokens", [2, 17, 128]) +@pytest.mark.parametrize("num_seq_group", [1, 4, 16]) +def test_can_schedule_seqs_with_multiple_unprocessed_tokens( + block_size: int, prompt_len: int, num_unprocessed_tokens: int, + num_seq_group: int): + """Verify scheduler can schedule sequences with more than one unprocessed + tokens. This occurs when the worker emits more than one token. + """ + max_model_len = 2048 + scheduler_config = SchedulerConfig(max_num_batched_tokens=max_model_len, + max_num_seqs=num_seq_group, + max_model_len=max_model_len) + cache_config = CacheConfig(block_size=block_size, + gpu_memory_utilization=1.0, + swap_space=0) + cache_config.num_cpu_blocks = 0 + cache_config.num_gpu_blocks = 8192 // block_size + scheduler = Scheduler(scheduler_config, cache_config, None) + + prompt_lens = [prompt_len for _ in range(num_seq_group)] + + token_ids_to_append = [ + list(range(num_unprocessed_tokens)) for _ in range(num_seq_group) + ] + + # Add seq groups to scheduler. + running: List[SequenceGroup] = [] + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(request_id=str(i), + prompt_length=prompt_lens[i], + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Schedule seq groups prompts. + _, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set(running) + + # Add tokens to sequences + for seq_group in out.scheduled_seq_groups: + for i, seq in enumerate(seq_group.get_seqs()): + seq.append_token_ids(token_ids_to_append[i], + logprobs=[{ + token_id: 0.0 + } for token_id in token_ids_to_append[i]]) + + # Schedule seq groups generation. + seq_group_meta, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set(running) + + for seq_group_metadata in seq_group_meta: + # Only one seq per group in this test. 
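The block-count check just below is ordinary ceiling division of the sequence length by the block size; a quick sanity check of the formula:

    # (seq_len - 1 + block_size) // block_size == ceil(seq_len / block_size),
    # e.g. 17 tokens with a block size of 8 need 3 blocks.
    assert (17 - 1 + 8) // 8 == 3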
+ seq_id = next(iter(seq_group_metadata.seq_data.keys())) + + block_table = seq_group_metadata.block_tables[seq_id] + blocks_required = (seq_group_metadata.seq_data[seq_id].get_len() - 1 + + block_size) // block_size + assert len(block_table) == blocks_required + + +@pytest.mark.parametrize("block_size", [1, 8, 16, 32]) +@pytest.mark.parametrize("prompt_len", [1, 128]) +@pytest.mark.parametrize("num_unprocessed_tokens", [1, 9]) +@pytest.mark.parametrize("num_preallocated_slots_per_step", [1, 9]) +@pytest.mark.parametrize("num_seq_group", [1, 4]) +def test_can_schedule_multiple_steps(block_size: int, prompt_len: int, + num_preallocated_slots_per_step: int, + num_unprocessed_tokens: int, + num_seq_group: int): + """Verify correct scheduling when the model runs more than one step per + scheduler iteration. + """ + max_model_len = 2048 + scheduler_config = SchedulerConfig( + max_num_batched_tokens=max_model_len, + max_num_seqs=num_seq_group, + max_model_len=max_model_len, + num_preallocated_slots_per_step=num_preallocated_slots_per_step) + cache_config = CacheConfig(block_size=block_size, + gpu_memory_utilization=1.0, + swap_space=0) + cache_config.num_cpu_blocks = 0 + cache_config.num_gpu_blocks = 8192 // block_size + scheduler = Scheduler(scheduler_config, cache_config, None) + + prompt_lens = [prompt_len for _ in range(num_seq_group)] + + token_ids_to_append = [ + list(range(num_unprocessed_tokens)) for _ in range(num_seq_group) + ] + + # Add seq groups to scheduler. + running: List[SequenceGroup] = [] + for i in range(num_seq_group): + _, seq_group = create_dummy_prompt(request_id=str(i), + prompt_length=prompt_lens[i], + block_size=block_size) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + # Schedule seq groups prompts. + _, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set(running) + + # Add tokens to sequences + for seq_group in out.scheduled_seq_groups: + for i, seq in enumerate(seq_group.get_seqs()): + seq.append_token_ids(token_ids_to_append[i], + logprobs=[{ + token_id: 0.0 + } for token_id in token_ids_to_append[i]]) + + # Schedule seq groups generation. + seq_group_meta, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set(running) + + for seq_group_metadata in seq_group_meta: + # Only one seq per group in this test. + seq_id = next(iter(seq_group_metadata.seq_data.keys())) + + # The last slot is not required because it is for the last generated + # token, and will be stored in the next iteration. + slots_required = (seq_group_metadata.seq_data[seq_id].get_len() + + num_preallocated_slots_per_step) + blocks_required = round_up_to_next_block(slots_required, block_size) + + block_table = seq_group_metadata.block_tables[seq_id] + assert len(block_table) == blocks_required + + +def test_scheduler_schedule_chunked_prefill(): + block_size = 4 + num_seq_group = 2 + max_model_len = 16 + max_chunked_prefill_len = 2 + max_num_prompt_seqs = 1 + scheduler_config = SchedulerConfig( + 64, + num_seq_group, + max_model_len, + flash_style=True, + max_chunked_prefill_len=max_chunked_prefill_len, + max_num_prompt_seqs=max_num_prompt_seqs) + cache_config = CacheConfig(block_size, 1.0, 1) + cache_config.num_cpu_blocks = 8 + cache_config.num_gpu_blocks = 8 + scheduler = Scheduler(scheduler_config, cache_config, None) + + # Add seq groups to scheduler. 
+    seq_groups: List[SequenceGroup] = []
+    for i in range(num_seq_group):
+        _, seq_group = create_dummy_prompt(str(i),
+                                           prompt_length=block_size,
+                                           num_processed_token_ids=block_size -
+                                           1)
+        scheduler.add_seq_group(seq_group)
+        seq_groups.append(seq_group)
+
+    # Schedule chunk prefill. Only the first seq_group should be scheduled.
+    seq_group_meta, out = scheduler.schedule()
+    assert set(out.scheduled_seq_groups) == set(seq_groups[:1])
+    assert seq_groups[0].get_num_unprefilled() == 2
+    assert seq_groups[1].get_num_unprefilled() == 4
+    assert out.num_batched_tokens == 2
+    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
+            and not out.blocks_to_swap_out)
+    assert len(seq_group_meta) == 1
+    assert seq_group_meta[0].request_id == "0"
+    assert seq_group_meta[0].is_chunked_prefill
+    assert seq_group_meta[0].is_prompt
+
+    # Schedule chunk prefill. Still only the first seq_group should be scheduled.
+    seq_group_meta, out = scheduler.schedule()
+    assert set(out.scheduled_seq_groups) == set(seq_groups[:1])
+    assert seq_groups[0].get_num_unprefilled() == 0
+    assert seq_groups[1].get_num_unprefilled() == 4
+    assert out.num_batched_tokens == 2
+    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
+            and not out.blocks_to_swap_out)
+    assert len(seq_group_meta) == 1
+    assert seq_group_meta[0].request_id == "0"
+    assert not seq_group_meta[0].is_chunked_prefill
+    assert seq_group_meta[0].is_prompt
+
+    # Schedule chunk prefill. This time the second seq_group should be selected
+    # for chunk prefill, and the first seq_group should be selected for decoding.
+    seq_group_meta, out = scheduler.schedule()
+    assert set(out.scheduled_seq_groups) == set(seq_groups)
+    assert seq_groups[0].get_num_unprefilled() == 0
+    assert seq_groups[1].get_num_unprefilled() == 2
+    assert out.num_batched_tokens == 3
+    assert (not out.blocks_to_copy and not out.blocks_to_swap_in
+            and not out.blocks_to_swap_out)
+    assert len(seq_group_meta) == 2
+    assert seq_group_meta[0].request_id == "1"
+    assert seq_group_meta[0].is_chunked_prefill
+    assert seq_group_meta[0].is_prompt
+    assert seq_group_meta[1].request_id == "0"
+    assert not seq_group_meta[1].is_chunked_prefill
+    assert not seq_group_meta[1].is_prompt
+
+
+def test_scheduler_max_seqs():
+    block_size = 4
+    num_seq_group = 4
+    max_seq_group = 2
+    max_model_len = 16
+    scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len)
+    cache_config = CacheConfig(block_size, 1.0, 1)
+    cache_config.num_cpu_blocks = 8
+    cache_config.num_gpu_blocks = 8
+    scheduler = Scheduler(scheduler_config, cache_config, None)
+
+    all_seq_groups: List[SequenceGroup] = []
+    # Add seq groups to scheduler.
+    for i in range(num_seq_group):
+        _, seq_group = create_dummy_prompt(str(i),
+                                           prompt_length=block_size,
+                                           num_processed_token_ids=block_size -
+                                           1)
+        all_seq_groups.append(seq_group)
+
+    # Append 1 seq group
+    scheduler.add_seq_group(all_seq_groups[0])
+
+    # Schedule seq groups prompts.
+    seq_group_meta, out = scheduler.schedule()
+    assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]])
+
+    # Schedule seq groups generation.
+    seq_group_meta, out = scheduler.schedule()
+    assert set(out.scheduled_seq_groups) == set([all_seq_groups[0]])
+
+    # Append 2 more seq groups.
+    running: List[SequenceGroup] = []
+    scheduler.add_seq_group(all_seq_groups[1])
+    scheduler.add_seq_group(all_seq_groups[2])
+
+    # Schedule seq groups prompts.
+    # Only 1 seq group should be scheduled since max_seq_group is 2
+    # and one is prompting.
+ seq_group_meta, out = scheduler.schedule() + assert set(out.scheduled_seq_groups) == set([all_seq_groups[1]]) diff --git a/tests/core/utils.py b/tests/core/utils.py new file mode 100644 index 0000000000000..549bc8fc941d9 --- /dev/null +++ b/tests/core/utils.py @@ -0,0 +1,29 @@ + +import time +from typing import Tuple + +from vllm import SamplingParams +from vllm.sequence import Sequence, SequenceGroup + + +def create_dummy_prompt( + request_id: str, + prompt_length: int, + block_size: int = None, + num_processed_token_ids: int = 0) -> Tuple[Sequence, SequenceGroup]: + if not block_size: + block_size = prompt_length + + # Create dummy prompt sequence with tokens 0...block_size-1 + # and prompt "0 ... block_size". + prompt_tokens = list(range(prompt_length)) + prompt_str = " ".join([str(t) for t in prompt_tokens]) + prompt = Sequence(int(request_id), + prompt_str, + prompt_tokens, + block_size, + num_processed_token_ids=num_processed_token_ids) + seq_group = SequenceGroup(request_id, [prompt], SamplingParams(), + time.time(), time.perf_counter()) + + return prompt, seq_group diff --git a/tests/engine/test_detokenize.py b/tests/engine/test_detokenize.py index 4421739390e3b..d6c5871abfa90 100644 --- a/tests/engine/test_detokenize.py +++ b/tests/engine/test_detokenize.py @@ -1,13 +1,23 @@ + +import os import pytest +from functools import partial from transformers import AutoTokenizer +from typing import Callable, List, Dict +from unittest.mock import MagicMock +from vllm.anyscale.tokenization import TransformersTokenizer +from vllm.engine.llm_engine import LLMEngine +from vllm.sampling_params import SamplingParams +from vllm.sequence import Sequence from vllm.transformers_utils.tokenizer import detokenize_incrementally TRUTH = [ - "Hello here, this is a simple test", # noqa: E501 - "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving", # noqa: E501 - "我很感谢你的热情" # noqa: E501 + # pylint: disable=line-too-long + "Hello here, this is a simple test", + "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving", + "我很感谢你的热情" ] TOKENIZERS = [ "facebook/opt-125m", @@ -45,6 +55,8 @@ def _run_incremental_decode(tokenizer, all_input_ids, return decoded_text +@pytest.mark.skipif("HUGGING_FACE_HUB_TOKEN" not in os.environ, + reason="requires HF token") @pytest.mark.parametrize("truth", TRUTH) @pytest.mark.parametrize("tokenizer_id", TOKENIZERS) @pytest.mark.parametrize("skip_special_tokens", (True, False)) @@ -60,3 +72,76 @@ def test_decode_streaming(tokenizer_id, truth, skip_special_tokens): tokenizer, all_input_ids, skip_special_tokens=skip_special_tokens) assert decoded_text == truth + + +@pytest.mark.skipif("HUGGING_FACE_HUB_TOKEN" not in os.environ, + reason="requires HF token") +@pytest.mark.parametrize("complete_sequence", TRUTH) +@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) +@pytest.mark.parametrize("skip_special_tokens", [True, False]) +def test_decode_sequence_works_with_multiple_tokens( + complete_sequence_token_ids: List[int], + dummy_logprobs: List[Dict[int, float]], decode_sequence: Callable, + skip_special_tokens: bool): + """Verify LLMEngine can decode sequences with >1 new tokens per step. + """ + sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens) + + # Run sequentially. 
+ seq = create_empty_sequence() + for new_token, logprob in zip(complete_sequence_token_ids, dummy_logprobs): + seq.append_token_ids([new_token], [logprob]) + decode_sequence(seq, sampling_params) + sequential_result = seq.output_text + + # Run in batch. + seq = create_empty_sequence() + seq.append_token_ids(complete_sequence_token_ids, dummy_logprobs) + decode_sequence(seq, sampling_params) + batch_result = seq.output_text + + assert sequential_result == batch_result + + +@pytest.fixture(name="dummy_logprobs") +def create_dummy_logprobs( + complete_sequence_token_ids: List[int]) -> List[Dict[int, float]]: + return list({token_id: 0.0} for token_id in complete_sequence_token_ids) + + +@pytest.fixture(name="decode_sequence") +def create_decode_sequence(tokenizer_name: str) -> Callable: + init_kwargs = dict( + enable_lora=False, + max_num_seqs=100, + max_input_length=None, + tokenizer_mode="auto", + trust_remote_code=False, + revision=None, + ) + + self = MagicMock() + self.tokenizer = TransformersTokenizer( + tokenizer_name, + **init_kwargs, + ) + + decode_sequence = partial(LLMEngine._decode_sequence, self) # pylint: disable=protected-access + return decode_sequence + + +@pytest.fixture(name="complete_sequence_token_ids") +def create_complete_sequence_token_ids(complete_sequence: str, + tokenizer_name: str) -> List[int]: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + complete_sequence_token_ids = tokenizer(complete_sequence)["input_ids"] + return complete_sequence_token_ids + + +def create_empty_sequence(): + return Sequence( + seq_id=0, + prompt="", + prompt_token_ids=[], + block_size=16, + ) diff --git a/tests/kernels/conftest.py b/tests/kernels/conftest.py index 97516bd3052cf..b96db0994f3fd 100644 --- a/tests/kernels/conftest.py +++ b/tests/kernels/conftest.py @@ -1,3 +1,4 @@ + from typing import List, Tuple import pytest @@ -12,6 +13,7 @@ def create_kv_caches( head_size: int, dtype: torch.dtype, seed: int, + flash_style: bool = False, ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) @@ -19,6 +21,8 @@ def create_kv_caches( scale = head_size**-0.5 x = 16 // torch.tensor([], dtype=dtype).element_size() key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + if flash_style: + key_cache_shape = (num_blocks, block_size, num_heads, head_size) key_caches = [] for _ in range(num_layers): key_cache = torch.empty(size=key_cache_shape, @@ -28,6 +32,8 @@ def create_kv_caches( key_caches.append(key_cache) value_cache_shape = (num_blocks, num_heads, head_size, block_size) + if flash_style: + value_cache_shape = (num_blocks, block_size, num_heads, head_size) value_caches = [] for _ in range(num_layers): value_cache = torch.empty(size=value_cache_shape, diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index ba062054bf406..9667d5d5236e1 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -1,7 +1,6 @@ -import pytest -import torch +from transformers.activations import get_activation -from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul +from vllm._C import ops DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing @@ -9,6 +8,11 @@ SEEDS = [0] +def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor: + x1, x2 = x.chunk(chunks=2, dim=1) + return F.silu(x1) * x2 + + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("d", D) 
@pytest.mark.parametrize("dtype", DTYPES) @@ -23,9 +27,9 @@ def test_silu_and_mul( torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda") - layer = SiluAndMul() - out = layer(x) - ref_out = layer._forward(x) + out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") + ops.silu_and_mul(out, x) + ref_out = ref_silu_and_mul(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) @@ -43,9 +47,9 @@ def test_gelu_new( torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") - layer = NewGELU() - out = layer(x) - ref_out = layer._forward(x) + out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") + ops.gelu_new(out, x) + ref_out = get_activation("gelu_new")(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) @@ -62,7 +66,7 @@ def test_gelu_fast( torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") - layer = FastGELU() - out = layer(x) - ref_out = layer._forward(x) + out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") + ops.gelu_fast(out, x) + ref_out = get_activation("gelu_fast")(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 614b65f82ccbd..a32c0cfef8ae0 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -1,8 +1,11 @@ + +import gc import random from typing import List, Optional, Tuple import pytest import torch +import torch.nn.functional as F from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask @@ -13,7 +16,8 @@ # This will change depending on the compute capability. # - 512 as a buffer MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 -NUM_BLOCKS = 40000 # Arbitrary values for testing +# This is 40000 in upstream, but we don't have enough VRAM in CI. +NUM_BLOCKS = 128 # Arbitrary values for testing PARTITION_SIZE = 512 DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -24,6 +28,40 @@ BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] SEEDS = [0] +PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] + + +@pytest.fixture(autouse=True) +def garbage_collect(): + yield + gc.collect() + torch.cuda.empty_cache() + + +def pad_attention_inputs( + pad_config: Tuple[int, int], + block_size: int, + query: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + """Pad the attention inputs to the specified batch size and context length. 
+ """ + pad_batch_size, pad_max_context_len = pad_config + if pad_batch_size == 0: + return query, block_tables, context_lens, max_context_len + target_batch_size = ( + (query.shape[0] - 1) % pad_batch_size + 1) * pad_batch_size + target_block_size = pad_max_context_len // block_size + 1 + padded_query = F.pad(query, + (0, 0, 0, 0, 0, target_batch_size - query.shape[0])) + padded_block_table = F.pad(block_tables, + (0, target_block_size - block_tables.shape[1], + 0, target_batch_size - block_tables.shape[0])) + padded_context_lens = F.pad(context_lens, + (0, target_batch_size - context_lens.shape[0])) + return padded_query, padded_block_table, padded_context_lens, pad_max_context_len def ref_masked_attention( @@ -105,6 +143,8 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("pad_config", PAD_CONFIGS) +@torch.inference_mode() def test_paged_attention( kv_cache_factory, version: str, @@ -115,6 +155,7 @@ def test_paged_attention( block_size: int, dtype: torch.dtype, seed: int, + pad_config: Tuple[int, int], ) -> None: random.seed(seed) torch.random.manual_seed(seed) @@ -137,8 +178,9 @@ def test_paged_attention( dtype=torch.float, device="cuda") - context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - context_lens[-1] = MAX_SEQ_LEN + max_seq_len = MAX_SEQ_LEN if not pad_config[0] else (pad_config[1] - 1000) + context_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)] + context_lens[-1] = max_seq_len max_context_len = max(context_lens) context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") @@ -160,23 +202,32 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Call the paged attention kernel. + num_valid_tokens = torch.tensor([num_seqs], + dtype=torch.long, + device="cuda") output = torch.empty_like(query) + + padded_query, padded_block_table, padded_context_lens, pad_max_context_len = \ + pad_attention_inputs(pad_config, block_size, query, + block_tables, context_lens, max_context_len) + if version == "v1": ops.paged_attention_v1( output, - query, + padded_query, key_cache, value_cache, num_kv_heads, scale, - block_tables, - context_lens, + padded_block_table, + padded_context_lens, + num_valid_tokens, block_size, - max_context_len, + pad_max_context_len, alibi_slopes, ) elif version == "v2": - num_partitions = ((max_context_len + PARTITION_SIZE - 1) // + num_partitions = ((pad_max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape @@ -196,19 +247,20 @@ def test_paged_attention( exp_sums, max_logits, tmp_output, - query, + padded_query, key_cache, value_cache, num_kv_heads, scale, - block_tables, - context_lens, + padded_block_table, + padded_context_lens, + num_valid_tokens, block_size, - max_context_len, + pad_max_context_len, alibi_slopes, ) else: - raise AssertionError(f"Unknown version: {version}") + assert False, f"Unknown version: {version}" # Run the reference implementation. ref_output = torch.empty_like(query) @@ -263,12 +315,17 @@ def ref_multi_query_kv_attention( return ref_output +def is_a100(): + return torch.cuda.get_device_name().find("NVIDIA A100") >= 0 + + # TODO(woosuk): Add tests for USE_ALIBI=True. 
@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.skipif(not is_a100(), reason="OOMs without A100") @torch.inference_mode() def test_multi_query_kv_attention( num_seqs: int, @@ -327,3 +384,72 @@ def test_multi_query_kv_attention( dtype, ) assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + +NUM_HEADS_SMALL = [(16, 16), (16, 8)] +MAX_SEQ_LEN_SMALL = max(MAX_SEQ_LEN // 4, 8192) + + +@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS_SMALL) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@torch.inference_mode() +def test_multi_query_kv_attention_small_scale( + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + seed: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. + # As the xformers library is already tested with its own tests, we can use + # a smaller MAX_SEQ_LEN here. + max_len = min(MAX_SEQ_LEN, 4096) + seq_lens = random.sample(range(1, max_len), num_seqs) + num_tokens = sum(seq_lens) + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + qkv = torch.empty(num_tokens, + num_query_heads + 2 * num_kv_heads, + head_size, + dtype=dtype, + device="cuda") + qkv.uniform_(-scale, scale) + query, key, value = qkv.split( + [num_query_heads, num_kv_heads, num_kv_heads], dim=1) + + num_queries_per_kv = num_query_heads // num_kv_heads + if num_queries_per_kv > 1: + # Handle MQA and GQA + key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) + value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) + attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) + output = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + ) + output = output.squeeze(0) + + cu_seq_lens = [0] + for seq_len in seq_lens: + cu_seq_lens.append(cu_seq_lens[-1] + seq_len) + ref_output = ref_multi_query_kv_attention( + cu_seq_lens, + query, + key, + value, + scale, + dtype, + ) + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 9b5d7687a3fec..ab5c2720b8729 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -1,7 +1,11 @@ + import random +from typing import Tuple + import pytest import torch +import torch.nn.functional as F from vllm._C import cache_ops @@ -11,9 +15,19 @@ NUM_HEADS = [8] # Arbitrary values for testing HEAD_SIZES = [64, 80, 96, 112, 128, 256] BLOCK_SIZES = [8, 16, 32] -NUM_BLOCKS = [1024, 36000] # Arbitrary values for testing +# Upstream also have 36000, but we don't have enough VRAM in CI +NUM_BLOCKS = [1024] # Arbitrary values for testing NUM_MAPPINGS = [256] # Arbitrary values for testing SEEDS = [0] +PADDINGS = [8, 16, 0] + + +def pad_key_value(key: torch.Tensor, value: torch.Tensor, + pad_size: int) -> Tuple[torch.Tensor, torch.Tensor]: + if pad_size == 0: + return key, value + return F.pad(key, (0, 0, 0, 0, 0, pad_size)),\ + F.pad(value, (0, 0, 0, 0, 0, pad_size)) @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @@ -88,6 +102,7 @@ 
def test_copy_blocks( @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("padding", PADDINGS) @torch.inference_mode() def test_reshape_and_cache( kv_cache_factory, @@ -98,6 +113,7 @@ def test_reshape_and_cache( num_blocks: int, dtype: torch.dtype, seed: int, + padding: int, ) -> None: random.seed(seed) torch.random.manual_seed(seed) @@ -125,10 +141,12 @@ def test_reshape_and_cache( # Clone the KV caches. cloned_key_cache = key_cache.clone() cloned_value_cache = value_cache.clone() + num_tokens = torch.tensor([num_tokens], dtype=torch.long, device="cuda") + padded_key, padded_value = pad_key_value(key, value, padding) # Call the reshape_and_cache kernel. - cache_ops.reshape_and_cache(key, value, key_cache, value_cache, - slot_mapping) + cache_ops.reshape_and_cache(padded_key, padded_value, key_cache, + value_cache, slot_mapping, num_tokens) # Run the reference implementation. reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) @@ -144,3 +162,76 @@ def test_reshape_and_cache( assert torch.allclose(key_cache, cloned_key_cache) assert torch.allclose(value_cache, cloned_value_cache) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("padding", PADDINGS) +@torch.inference_mode() +def test_reshape_and_cache_flash( + kv_cache_factory, + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + padding: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Create a random slot mapping. + num_slots = block_size * num_blocks + slot_mapping = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') + + qkv = torch.randn(num_tokens, + 3, + num_heads, + head_size, + dtype=dtype, + device='cuda') + _, key, value = qkv.unbind(dim=1) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory(num_blocks, + block_size, + 1, + num_heads, + head_size, + dtype, + seed, + flash_style=True) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Clone the KV caches. + cloned_key_cache = key_cache.clone() + cloned_value_cache = value_cache.clone() + num_tokens = torch.tensor([num_tokens], dtype=torch.long, device="cuda") + + padded_key, padded_value = pad_key_value(key, value, padding) + # Call the reshape_and_cache kernel. + cache_ops.reshape_and_cache_flash(padded_key, padded_value, key_cache, + value_cache, slot_mapping, num_tokens) + + # Run the reference implementation. 
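The reference check that follows recovers each token's cache location by splitting the flat slot id into a block index and an in-block offset. A tiny self-contained illustration of that decomposition (the block size and slot values here are arbitrary examples, not taken from the test):

import torch

block_size = 16
slot_mapping = torch.tensor([5, 37, 64])
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_offsets = slot_mapping % block_size
# Slot 37 lives in block 2 at offset 5, since 37 == 2 * 16 + 5.
assert block_indices.tolist() == [0, 2, 4]
assert block_offsets.tolist() == [5, 5, 0]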
+ block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') + block_indicies = block_indicies.cpu().tolist() + block_offsets = slot_mapping % block_size + block_offsets = block_offsets.cpu().tolist() + for i in range(num_tokens): + block_idx = block_indicies[i] + block_offset = block_offsets[i] + cloned_key_cache[block_idx, block_offset, :, :] = key[i] + cloned_value_cache[block_idx, block_offset, :, :] = value[i] + + assert torch.allclose(key_cache, cloned_key_cache) + assert torch.allclose(value_cache, cloned_value_cache) diff --git a/tests/kernels/test_flash_attention.py b/tests/kernels/test_flash_attention.py new file mode 100644 index 0000000000000..438ef4ad9f406 --- /dev/null +++ b/tests/kernels/test_flash_attention.py @@ -0,0 +1,491 @@ + +import random +from typing import List, Optional, Tuple + +import pytest +import torch +import torch.nn.functional as F + +from vllm.anyscale.attention import ( + flash_single_query_cached_kv_attention, + flash_multi_query_cached_kv_attention_varlen, +) +from vllm.utils import get_max_shared_memory_bytes + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +# This will change depending on the compute capability. +# - 512 as a buffer +MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +NUM_BLOCKS = 128 # Arbitrary values for testing +PARTITION_SIZE = 512 + +DTYPES = [torch.half, torch.bfloat16] +NUM_GEN_SEQS = [3, 6, 17] # Arbitrary values for testing +NUM_PREFILL_SEQS = [3, 6, 17] # Arbitrary values for testing +NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing +NUM_HEADS_SMALL = NUM_HEADS +HEAD_SIZES = [32, 64, 128, 256] +BLOCK_SIZES = [32, 64, 256] +USE_ALIBI = [False, True] +SEEDS = [0] +PAD_CONFIGS = [(0, 0), (8, MAX_SEQ_LEN - 1000), (16, MAX_SEQ_LEN - 2000)] + + +def pad_attention_inputs( + pad_config: Tuple[int, int], + block_size: int, + query: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + """Pad the attention inputs to the specified batch size and context length. 
+ """ + pad_batch_size, pad_max_context_len = pad_config + if pad_batch_size == 0: + return query, block_tables, context_lens, max_context_len + target_batch_size = ( + (query.shape[0] - 1) % pad_batch_size + 1) * pad_batch_size + target_block_size = pad_max_context_len // block_size + 1 + padded_query = F.pad(query, + (0, 0, 0, 0, 0, target_batch_size - query.shape[0])) + padded_block_table = F.pad(block_tables, + (0, target_block_size - block_tables.shape[1], + 0, target_batch_size - block_tables.shape[0])) + padded_context_lens = F.pad(context_lens, + (0, target_batch_size - context_lens.shape[0])) + return padded_query, padded_block_table, padded_context_lens, pad_max_context_len + + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + + +def ref_single_query_cached_kv_attention( + output: torch.Tensor, + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + scale: float, + alibi_slopes: Optional[torch.Tensor], + flash_style: bool = False, +) -> None: + num_query_heads = query.shape[1] + num_kv_heads = value_cache.shape[-2] + head_size = value_cache.shape[-1] + block_size = value_cache.shape[-3] + num_seqs = query.shape[0] + + block_tables = block_tables.cpu().tolist() + context_lens = context_lens.cpu().tolist() + for i in range(num_seqs): + q = query[i].unsqueeze(0) + block_table = block_tables[i] + context_len = int(context_lens[i]) + + keys = [] + values = [] + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + if flash_style: + k = key_cache[block_number, block_offset, :, :] + else: + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_kv_heads, head_size) + keys.append(k) + + if flash_style: + v = value_cache[block_number, block_offset, :, :] + else: + v = value_cache[block_number, :, :, block_offset] + values.append(v) + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + alibi_bias = None + if alibi_slopes is not None: + # Create the ALiBi bias used in the paged attention kernel. 
+ position_ids = torch.arange(context_len, device="cuda").int() + alibi_bias = (position_ids - context_len + 1).float() + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( + 1, 1, -1) + + out = ref_masked_attention(q, keys, values, scale, alibi_bias) + out = out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + + +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", [False]) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("pad_config", [(0, 0)]) +@torch.inference_mode() +def test_flash_paged_attention( + kv_cache_factory, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + seed: int, + pad_config: Tuple[int, int], +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, + num_query_heads, + head_size, + dtype=dtype, + device="cuda") + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + head_mapping = torch.repeat_interleave( + torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), + num_queries_per_kv) + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, + dtype=torch.float, + device="cuda") + + max_seq_len = MAX_SEQ_LEN if not pad_config[0] else (pad_config[1] - 1000) + context_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)] + context_lens[-1] = max_seq_len + max_context_len = max(context_lens) + context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") + + # Create the block tables. + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + dtype, + seed, + flash_style=True) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Call the paged attention kernel. + num_valid_tokens = torch.cuda.IntTensor([num_seqs]) + output = torch.empty_like(query) + + padded_query, padded_block_table, padded_context_lens, pad_max_context_len = \ + pad_attention_inputs(pad_config, block_size, query, + block_tables, context_lens, max_context_len) + + flash_single_query_cached_kv_attention( + output, + padded_query, + key_cache, + value_cache, + scale, + padded_block_table, + padded_context_lens, + block_size, + alibi_slopes, + ) + + # Run the reference implementation. + ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + context_lens, + scale, + alibi_slopes, + flash_style=True, + ) + + # NOTE(woosuk): Due to the kernel-level differences in the two + # implementations, there is a small numerical difference in the two + # outputs. Thus, we use a relaxed tolerance for the test. 
+ assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + +def ref_multi_query_kv_attention( + cu_seq_lens: List[int], + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + dtype: torch.dtype, +) -> torch.Tensor: + num_seqs = len(cu_seq_lens) - 1 + ref_outputs = [] + for i in range(num_seqs): + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + seq_len = end_idx - start_idx + + # Create attention mask. + attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), + diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype, device="cuda") + + ref_output = ref_masked_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output) + ref_output = torch.cat(ref_outputs, dim=0) + return ref_output + + +def ref_multi_query_kv_attention_padded( + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + cu_seq_lens: List[int], + context_lens: List[int], + scale: float, + dtype: torch.dtype, +) -> torch.Tensor: + num_seqs = len(cu_seq_lens) - 1 + block_size = value_cache.shape[-3] + ref_outputs = [] + + for i in range(num_seqs): + q_start_idx = cu_seq_lens[i] + q_end_idx = cu_seq_lens[i + 1] + seq_len = q_end_idx - q_start_idx + + context_len = context_lens[i] + + block_table = block_tables[i] + keys = [] + values = [] + + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, block_offset, :, :] + keys.append(k) + + v = value_cache[block_number, block_offset, :, :] + values.append(v) + + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + q = query[q_start_idx:q_end_idx, :, :] + k = keys[:context_len, :, :] + v = values[:context_len, :, :] + + assert seq_len <= context_len + + # pad q if seq_len is less than context_len + # this is for correct calculation of attention. + if seq_len < context_len: + indices = [i % seq_len for i in range(context_len - seq_len)] + q_left_pad = q[indices, :, :] + q = torch.cat([q_left_pad, q], dim=0) + + # Create attention mask. 
+ attn_mask = torch.triu(torch.ones(context_len, + context_len, + dtype=dtype), + diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype, device="cuda") + + ref_output = ref_masked_attention( + q, + k, + v, + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output[-seq_len:, :, :]) + ref_output = torch.cat(ref_outputs, dim=0) + return ref_output + + +def is_a100(): + return torch.cuda.get_device_name().find("NVIDIA A100") >= 0 + + +if not is_a100(): + NUM_HEADS_SMALL = [(16, 16), (16, 8)] + MAX_SEQ_LEN_SMALL = max(MAX_SEQ_LEN // 4, 8192) + + +@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS_SMALL) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("version", ["flash"]) +@pytest.mark.parametrize("chunked_prefill", [False, True]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@torch.inference_mode() +def test_multi_query_kv_attention( + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + version: str, + seed: int, + chunked_prefill: bool, + block_size: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. + # As the xformers library is already tested with its own tests, we can use + # a smaller MAX_SEQ_LEN here. + max_len = min(MAX_SEQ_LEN, 4096) + + seq_lens = [random.randint(1, max_len // 2) for i in range(num_seqs)] + max_seq_len = max(seq_lens) + seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device="cuda") + + if chunked_prefill: + # context length will be different from seq_len if chunked_prefill is + # true. + context_lens = random.sample(range(max_seq_len, max_len), num_seqs) + else: + context_lens = seq_lens + max_context_len = max(context_lens) + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device="cuda") + + num_tokens = sum(seq_lens) + cu_seq_lens = [0] + for seq_len in seq_lens: + cu_seq_lens.append(cu_seq_lens[-1] + seq_len) + + cu_context_lens = [0] + for context_len in context_lens: + cu_context_lens.append(cu_context_lens[-1] + context_len) + + print(f"cu_seq_lens={cu_seq_lens}, cu_context_lens={cu_context_lens}") + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + num_queries_per_kv = num_query_heads // num_kv_heads + + value_cache = torch.empty(NUM_BLOCKS, + block_size, + num_kv_heads, + head_size, + dtype=dtype, + device="cuda") + key_cache = torch.empty(NUM_BLOCKS, + block_size, + num_kv_heads, + head_size, + dtype=dtype, + device="cuda") + query = torch.empty(num_tokens, + num_query_heads, + head_size, + dtype=dtype, + device="cuda") + value_cache.uniform_(-scale, scale) + key_cache.uniform_(-scale, scale) + query.uniform_(-scale, scale) + + # Create the block tables. 
+ max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + + output = torch.empty_like(query) + + if version == "flash": + flash_multi_query_cached_kv_attention_varlen( + output, + query, + key_cache, + value_cache, + scale, + block_tables, + torch.cuda.IntTensor(cu_seq_lens), + torch.cuda.IntTensor(cu_context_lens), + block_size, + max_seq_len, + max_context_len, + None, + ) + else: + assert False, f"{version=} is not supported" + + ref_output = ref_multi_query_kv_attention_padded( + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + cu_seq_lens, + context_lens, + scale, + dtype, + ) + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index b362e2c43f0da..f71dd65a90c11 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -1,47 +1,63 @@ + import pytest import torch +import torch.nn as nn -from vllm.model_executor.layers.layernorm import RMSNorm +from vllm._C import ops DTYPES = [torch.half, torch.bfloat16, torch.float] +HIDDEN_SIZES = [67, 768, 2048, 5120, 8192] # Arbitrary values for testing NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing -HIDDEN_SIZES = [768, 5120, 8192] # Arbitrary values for testing -ADD_RESIDUAL = [False, True] SEEDS = [0] +class RefRMSNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + weight = torch.empty(hidden_size) + weight.normal_(mean=1.0, std=0.1) + self.weight = nn.Parameter(weight) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) -@pytest.mark.parametrize("add_residual", ADD_RESIDUAL) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @torch.inference_mode() def test_rms_norm( num_tokens: int, hidden_size: int, - add_residual: bool, dtype: torch.dtype, seed: int, ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - layer = RMSNorm(hidden_size).to(dtype).cuda() - layer.weight.data.normal_(mean=1.0, std=0.1) - scale = 1 / (2 * hidden_size) - x = torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda") - x *= scale - residual = torch.randn_like(x) * scale if add_residual else None - - # NOTE(woosuk): The reference implementation should be executed first - # because the custom kernel is in-place. - ref_out = layer._forward(x, residual) - out = layer(x, residual) - # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger - # numerical errors than other operators because they involve reductions. - # Therefore, we use a larger tolerance. 
- if add_residual: - assert torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2) - assert torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2) + scale = float(hidden_size**-0.5) + x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda") + x.uniform_(-scale, scale) + ref = RefRMSNorm(hidden_size).to(dtype).cuda() + + out = torch.empty_like(x) + ops.rms_norm( + out, + x, + ref.weight.data, + ref.variance_epsilon, + ) + ref_out = ref(x) + # FIXME: A10G has slight larger difference than A100. + if torch.cuda.get_device_name().find("NVIDIA A100") >= 0: + assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5) else: - assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2) + assert torch.allclose(out, ref_out, atol=1e-1, rtol=1e-4) diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 25d6bf2378cad..68f1e07a7b3d4 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -1,23 +1,102 @@ -from typing import Optional - -import pytest import torch +import torch.nn as nn +import torch.nn.functional as F -from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm._C import ops IS_NEOX_STYLE = [True, False] DTYPES = [torch.half, torch.bfloat16, torch.float] HEAD_SIZES = [64, 80, 96, 112, 128, 256] ROTARY_DIMS = [None, 32] # None means rotary dim == head size -NUM_HEADS = [7, 17] # Arbitrary values for testing -BATCH_SIZES = [1, 5] # Arbitrary values for testing -SEQ_LENS = [11, 8192] # Arbitrary values for testing +NUM_HEADS = [7, 12, 40, 52] # Arbitrary values for testing +NUM_TOKENS = [11, 83, 2048] # Arbitrary values for testing SEEDS = [0] +def rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def apply_rope( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + rotate_fn = rotate_neox if is_neox_style else rotate_gptj + q_embed = (q * cos) + (rotate_fn(q) * sin) + k_embed = (k * cos) + (rotate_fn(k) * sin) + return q_embed, k_embed + + +class RefRotaryEmbedding(nn.Module): + """Reference implementation of rotary embedding.""" + + def __init__( + self, + dim: int, + is_neox_style: bool, + max_position_embeddings: int = 8192, + base: int = 10000, + ) -> None: + super().__init__() + self.rotary_dim = dim + self.is_neox_style = is_neox_style + self.max_position_embeddings = max_position_embeddings + + # Create cos and sin embeddings. 
+ inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim)) + t = torch.arange(max_position_embeddings).float() + freqs = torch.einsum("i,j->ij", t, inv_freq.float()) + if is_neox_style: + emb = torch.cat((freqs, freqs), dim=-1) + else: + emb = torch.repeat_interleave(freqs, 2, -1) + cos = emb.cos().to(dtype=inv_freq.dtype) + sin = emb.sin().to(dtype=inv_freq.dtype) + self.register_buffer("cos_cached", cos, persistent=False) + self.register_buffer("sin_cached", sin, persistent=False) + + def forward( + self, + positions: torch.Tensor, # [num_tokens] + query: torch.Tensor, # [num_tokens, num_heads, head_size] + key: torch.Tensor, # [num_tokens, num_heads, head_size] + ) -> Tuple[torch.Tensor, torch.Tensor]: + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + + query_rot = query_rot.transpose(0, 1) + key_rot = key_rot.transpose(0, 1) + cos = F.embedding(positions, self.cos_cached) + sin = F.embedding(positions, self.sin_cached) + + query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin, + self.is_neox_style) + query_rot = query_rot.transpose(0, 1).contiguous() + key_rot = key_rot.transpose(0, 1).contiguous() + + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + + # Output query/key shape: [num_tokens, num_tokens, head_size] + return query, key + + @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("seq_len", SEQ_LENS) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) @@ -26,8 +105,7 @@ @torch.inference_mode() def test_rotary_embedding( is_neox_style: bool, - batch_size: int, - seq_len: int, + num_tokens: int, num_heads: int, head_size: int, rotary_dim: Optional[int], @@ -41,25 +119,53 @@ def test_rotary_embedding( torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - if rotary_dim is None: - rotary_dim = head_size - rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) - rope = rope.to(dtype).cuda() - - positions = torch.randint(0, - max_position, (batch_size, seq_len), - device="cuda") - query = torch.randn(batch_size, - seq_len, + positions = torch.randint(0, max_position, (num_tokens, ), device="cuda") + query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda") - key = torch.randn_like(query) + key = torch.randn(num_tokens, + num_heads * head_size, + dtype=dtype, + device="cuda") + + # Create the rotary embedding. + inv_freq = 1.0 / (base**( + torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)) + t = torch.arange(max_position).float() + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cos_sin_cache = torch.cat((cos, sin), dim=-1) + cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda") + + # Run the kernel. The kernel is in-place, so we need to clone the inputs. + out_query = query.clone() + out_key = key.clone() + ops.rotary_embedding( + positions, + out_query, + out_key, + head_size, + cos_sin_cache, + is_neox_style, + ) + + # Run the reference implementation. 
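The reference run below hinges on which rotation convention is_neox_style selects: NeoX-style rotates the two contiguous halves of the rotary dimensions, while GPT-J-style rotates interleaved even/odd pairs. A small worked example mirroring the rotate_neox and rotate_gptj helpers defined above (the input values are arbitrary):

import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])

# NeoX style: split into two halves, negate the second half and swap.
x1, x2 = x[:2], x[2:]
assert torch.cat((-x2, x1)).tolist() == [-3.0, -4.0, 1.0, 2.0]

# GPT-J style: pair up even/odd lanes and rotate within each pair.
x1, x2 = x[::2], x[1::2]
assert torch.stack((-x2, x1), dim=-1).flatten(-2).tolist() == [-2.0, 1.0, -4.0, 3.0]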
+ ref_rotary_embedding = RefRotaryEmbedding( + dim=rotary_dim, + is_neox_style=is_neox_style, + max_position_embeddings=max_position, + base=base, + ).to(dtype=dtype, device="cuda") + ref_query, ref_key = ref_rotary_embedding( + positions, + query.view(num_tokens, num_heads, head_size), + key.view(num_tokens, num_heads, head_size), + ) + ref_query = ref_query.view(num_tokens, num_heads * head_size) + ref_key = ref_key.view(num_tokens, num_heads * head_size) - # NOTE(woosuk): The reference implementation should be executed first - # because the custom kernel is in-place. - ref_query, ref_key = rope._forward(positions, query, key) - out_query, out_key = rope.forward(positions, query, key) # Compare the results. assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5) assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5) diff --git a/tests/models/test_layer_difference.py b/tests/models/test_layer_difference.py new file mode 100644 index 0000000000000..6ec1b8d6a6fa3 --- /dev/null +++ b/tests/models/test_layer_difference.py @@ -0,0 +1,135 @@ +"""Compare the outputs of HF and vLLM when using greedy sampling. + +Run `pytest tests/models/test_models.py --forked`. +""" +import gc +import re +import time + +import pytest +import torch + +from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel + +MODELS = ["facebook/opt-125m", "/mnt/local_storage/llama-7b"] + + +class LLMExecutionTracer: + """Trace the execution of LLM model and capture the outputs of each layer. + + This class is used to compare the outputs of HF and vLLM. It's only + expected to run single request for each batch. + """ + + def __init__(self, model, is_hf: bool = False): + """Initialize the tracer. + + Args: + model: The torch.nn to trace. + is_hf: Whether the model is from huggingface transformer. + """ + self._captured = {} + self._layer_modules = {} + self._register_module(model) + self._start() + self.request_id = 0 + self.is_hf = is_hf + + def _hook(self, module, input_, output) -> None: + if module not in self._layer_modules: + return + layer = self._layer_modules[module] + + if self.is_hf: + # hf model has an extra dimension for batch size. 
+ output = output[0].view(-1, output[0].shape[-1]) + + # store output by request_id, layer, iteration + self._captured.setdefault(self.request_id, dict()).setdefault(layer, list())\ + .append((None, output.to("cpu"))) + + def start_new_request(self, request_id) -> None: + self.request_id = request_id + + @property + def captured(self) -> None: + return self._captured + + def stop(self) -> None: + self.handle.remove() + + def _start(self) -> None: + self.handle = torch.nn.modules.module.register_module_forward_hook( + self._hook) + + def _register_module(self, module) -> None: + for module_name, module in module.named_modules(): + layer = self._find_layer_module(module_name) + if layer is not None: + self._layer_modules[module] = layer + + def _find_layer_module(self, module_name: str): + if m := re.match("^.*layers\.([0-9]+)$", module_name): + print(f"matched {module_name}") + return int(m.group(1)) + return None + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + hf_outputs = [] + vllm_outputs = [] + example_prompts = example_prompts[:7] + + vllm_model = vllm_runner(model, dtype=dtype) + vllm_tracer = LLMExecutionTracer( + vllm_model.model.llm_engine.workers[0].model, is_hf=False) + for id, prompt in enumerate(example_prompts): + vllm_tracer.start_new_request(id) + vllm_outputs.extend(vllm_model.generate_greedy([prompt], max_tokens)) + vllm_captured = vllm_tracer.captured + vllm_tracer.stop() + del vllm_tracer + del vllm_model.model + del vllm_model + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + + hf_model = hf_runner(model, dtype=dtype) + hf_tracer = LLMExecutionTracer(hf_model.model, is_hf=True) + for id, prompt in enumerate(example_prompts): + hf_tracer.start_new_request(id) + hf_outputs.extend(hf_model.generate_greedy([prompt], max_tokens)) + hf_captured = hf_tracer.captured + hf_tracer.stop() + del hf_model + gc.collect() + torch.cuda.empty_cache() + + for request_id in range(len(example_prompts)): + diffs = list() + for iteration in range(max_tokens): + for layer in range(len(vllm_captured)): + hf_input, hf_output = hf_captured[request_id][layer][iteration] + vllm_input, vllm_output_padded = vllm_captured[request_id][ + layer][iteration] + vllm_output = vllm_output_padded[:hf_output.shape[0], :] + diff = torch.sum( + torch.abs(hf_output - vllm_output)) / torch.sum( + torch.abs(hf_output)) + diffs.append(diff.item()) + + print("request_id {} avg difference {:.2f}%".format( + request_id, + torch.mean(torch.FloatTensor(diffs)).item() * 100)) + assert torch.mean(torch.FloatTensor(diffs)).item() < 0.005 diff --git a/tests/models/test_models.py b/tests/models/test_models.py index e65c424c601a2..95eabaafec811 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -15,12 +15,12 @@ "EleutherAI/pythia-70m", "bigscience/bloom-560m", "mosaicml/mpt-7b", - "microsoft/phi-2", + "microsoft/phi-1_5", ] @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) def test_models( hf_runner, diff --git a/tests/samplers/__init__.py b/tests/samplers/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index 
a491ffa763505..020866c92e372 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -1,7 +1,3 @@ -"""Compare the outputs of HF and vLLM when using beam search. - -Run `pytest tests/samplers/test_beam_search.py --forked`. -""" import pytest # FIXME(zhuohan): The test can not pass if we: @@ -12,6 +8,9 @@ BEAM_WIDTHS = [4] MODELS = ["facebook/opt-125m"] +pytest.skip("ANYSCALE skip test as beam search is removed.", + allow_module_level=True) + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 1c67cc5bd7394..1c1b227a3f3bc 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -1,6 +1,3 @@ -import pytest -import torch - from vllm import SamplingParams MODELS = ["facebook/opt-125m"] diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py new file mode 100644 index 0000000000000..0c82f22d0359f --- /dev/null +++ b/tests/samplers/test_rejection_sampler.py @@ -0,0 +1,392 @@ +"""Tests for rejection sampling.""" +import pytest +from typing import List, Tuple + +import torch +import torch.nn.functional as F + +from vllm.model_executor.utils import set_random_seed + +from vllm.model_executor.layers.rejection_sampler import RejectionSampler + + +def mock_causal_accepted_tensor( + k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor: + """Generate an "accepted" tensor which should yield causally-accepted tokens + up to last accepted indices. + + Tokens after last_accepted_indices+1 may also be accepted, although they + will not be causally accepted. + """ + batch_size = last_accepted_indices.shape[0] + + accepted = (torch.arange(k).expand(batch_size, k) <= + last_accepted_indices.unsqueeze(-1).broadcast_to( + batch_size, k)).to(device="cuda") + + # Sprinkle accepted values after the contiguous initial accepted values. + # This replicates the behavior of rejection sampling, which may "accept" + # a token that cannot be accepted because of causality. + sprinkle_candidates = ( + torch.arange(k).expand(batch_size, k) > + last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + 1) + sprinkle = torch.rand(batch_size, k, device="cuda") > 0.5 + accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates] + return accepted + + +@pytest.mark.parametrize("seed", list(range(10))) +@pytest.mark.parametrize( + "which_tokens_accepted", + ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"]) +@torch.inference_mode() +def test_correct_output_format(which_tokens_accepted: str, seed: int): + """Verify the output has correct format given predetermined accepted matrix. 
+ """ + set_random_seed(seed) + + batch_size = 10 + k = 5 + vocab_size = 3000 + + if which_tokens_accepted == "all_tokens_accepted": + accepted = mock_causal_accepted_tensor( + k, -1 + k * torch.ones((batch_size, ), dtype=torch.long)) + elif which_tokens_accepted == "no_tokens_accepted": + accepted = mock_causal_accepted_tensor( + k, -torch.ones((batch_size, ), dtype=torch.long)) + elif which_tokens_accepted == "some_tokens_accepted": + last_accepted_indices = torch.randint(low=-1, + high=k, + size=(batch_size, )) + accepted = mock_causal_accepted_tensor(k, last_accepted_indices) + else: + assert False + + recovered_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64, + device="cuda") + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64, + device="cuda") + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64, + device="cuda") + + rejection_sampler = RejectionSampler() + rejection_sampler.init_gpu_tensors(rank=0) + output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access + accepted, + recovered_token_ids, + draft_token_ids, + bonus_token_ids, + ) + + if which_tokens_accepted == "all_tokens_accepted": + # Expect all tokens to be equal to draft tokens. + assert torch.equal(output_token_ids[:, :-1], draft_token_ids) + + # Expect all bonus tokens to be included. + assert torch.equal(output_token_ids[:, -1:], bonus_token_ids) + elif which_tokens_accepted == "no_tokens_accepted": + # Expect first token to be equal to recovered tokens. + assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0]) + + # Expect everything else to be -1. + assert torch.equal(output_token_ids[:, 1:], + torch.ones_like(output_token_ids[:, 1:]) * -1) + elif which_tokens_accepted == "some_tokens_accepted": + recovered_plus_bonus = torch.cat( + (recovered_token_ids, bonus_token_ids), dim=-1) + # Assert first rejected token is a recovered token or bonus token. + assert torch.equal( + recovered_plus_bonus[torch.arange(0, batch_size), + last_accepted_indices + 1], + output_token_ids[torch.arange(0, batch_size), + last_accepted_indices + 1]) + + # Assert every subsequent token is -1. 
+ subsequent_mask = torch.arange(0, k + 1).expand( + batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1) + assert torch.all(output_token_ids[subsequent_mask] == -1) + + +@pytest.mark.parametrize("k", list(range(1, 6))) +@pytest.mark.parametrize("vocab_size", [30_000, 50_000]) +@pytest.mark.parametrize("batch_size", list(range(1, 32))) +@torch.inference_mode() +def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int): + rejection_sampler = RejectionSampler() + rejection_sampler.init_gpu_tensors(rank=0) + + draft_probs = torch.rand(batch_size, + k, + vocab_size, + dtype=torch.float32, + device="cuda") + target_probs = torch.rand(batch_size, + k, + vocab_size, + dtype=torch.float32, + device="cuda") + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64, + device="cuda") + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64, + device="cuda") + + rejection_sampler(target_probs, bonus_token_ids, draft_probs, + draft_token_ids) + + +@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"]) +@pytest.mark.parametrize("which_token_ids", + ["bonus_token_ids", "draft_token_ids"]) +@torch.inference_mode() +def test_raises_when_vocab_oob(above_or_below_vocab_range: str, + which_token_ids: str): + k = 3 + batch_size = 5 + vocab_size = 30_000 + + rejection_sampler = RejectionSampler(strict_mode=True) + rejection_sampler.init_gpu_tensors(rank=0) + + draft_probs = torch.rand(batch_size, + k, + vocab_size, + dtype=torch.float32, + device="cuda") + target_probs = torch.rand(batch_size, + k, + vocab_size, + dtype=torch.float32, + device="cuda") + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64, + device="cuda") + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64, + device="cuda") + + oob_token_ids = None + if which_token_ids == "bonus_token_ids": + oob_token_ids = bonus_token_ids + elif which_token_ids == "draft_token_ids": + oob_token_ids = draft_token_ids + else: + assert False + + if above_or_below_vocab_range == "above": + rogue_token_id = vocab_size + 1 + elif above_or_below_vocab_range == "below": + rogue_token_id = -1 + else: + assert False + + oob_token_ids[0][0] = rogue_token_id + + with pytest.raises(AssertionError): + rejection_sampler(target_probs, bonus_token_ids, draft_probs, + draft_token_ids) + + +@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False]) +@pytest.mark.parametrize("seed", list(range(5))) +@torch.inference_mode() +def test_rejection_sampling_approximates_target_distribution( + seed: int, draft_and_target_probs_equal: bool): + """Verify rejection sampling approximates target distribution, + despite sampling from a potentially distinct draft distribution. + + This is done by first creating a random target probability + distribution and a random draft probability distribution. We then + sample token ids from the rejection sampler using these draft + and target distributions. The samples are used to estimate + the output probability distribution, which we expect to approximate + the target distribution. + + A basic distance metric is used to determine similarity between + distributions. + + We expect that as we increase the number of samples, + the distance between the observed distribution and the target + distribution decreases. 
To measure this, we compare the distance + of the observed distribution against both the target distribution + and a uniform random distribution. We expect the distance between + the observed distribution and the target distribution to improve + much more than the distance improvement between the observed + distribution and the random distribution. + + When draft_and_target_probs_equal=True, the draft and target + probabilities are exactly equal. Rejection sampling should + still work without any NaNs or exceptions. + """ + set_random_seed(seed) + + helper = _CorrectnessTestHelper( + vocab_size=10, + rejection_sampler=RejectionSampler(), + ) + + draft_probs, target_probs, reference_probs = helper.generate_probs_for_test( + draft_and_target_probs_equal) + + sample_sizes = [10, 100, 1_000, 10_000, 100_000] + distance_wrt_reference = [] + distance_wrt_target = [] + + for num_samples in sample_sizes: + (reference_vs_rejsample_dist, + target_vs_rejsample_dist) = helper.run_and_compare_distributions( + draft_probs, + target_probs, + reference_probs, + num_samples, + ) + + distance_wrt_reference.append(reference_vs_rejsample_dist) + distance_wrt_target.append(target_vs_rejsample_dist) + + relative_change_in_distance_wrt_target = get_ratio_first_to_last( + distance_wrt_target) + relative_change_in_distance_wrt_reference = get_ratio_first_to_last( + distance_wrt_reference) + + print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} " + f"{reference_vs_rejsample_dist=:.05f}") + print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} " + f"{relative_change_in_distance_wrt_reference=:.02f}") + + relative_change_in_distance_wrt_target = get_ratio_first_to_last( + distance_wrt_target) + relative_change_in_distance_wrt_reference = get_ratio_first_to_last( + distance_wrt_reference) + + expected_improvement_multiplier = 20 + assert (relative_change_in_distance_wrt_target > + relative_change_in_distance_wrt_reference * + expected_improvement_multiplier) + + +def get_ratio_first_to_last(elements: List[float]) -> float: + return elements[0] / elements[-1] + + +class _CorrectnessTestHelper: + """Class that packages together logic required for the unit-level + rejection sampling correctness test. + """ + + def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler): + self.rejection_sampler = rejection_sampler + self.vocab_size = vocab_size + self.vocab_range = (0, vocab_size) + + self.rejection_sampler.init_gpu_tensors(rank=0) + + # Keep test simple, use k=1 + self.k = 1 + + # Bonus tokens not used, but rejection sampler requires + # correct shape. + self.num_bonus_tokens = 1 + + def generate_probs_for_test( + self, draft_and_target_probs_equal: bool + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + draft_probs, target_probs = [ + F.softmax( + torch.rand(self.vocab_size, dtype=torch.float32), + dim=-1, + ) for _ in range(2) + ] + + num_reference_probs = 100 + reference_probs = F.softmax( + torch.rand(num_reference_probs, + self.vocab_size, + dtype=torch.float32), + dim=-1, + ) + + if draft_and_target_probs_equal: + target_probs = draft_probs.clone() + + return draft_probs, target_probs, reference_probs + + def run_and_compare_distributions(self, draft_probs: torch.Tensor, + target_probs: torch.Tensor, + reference_probs: torch.Tensor, + num_samples: int) -> Tuple[float, float]: + # Sample using rejection sampling. + rej_sample_probs = self._estimate_rejection_sampling_pdf( + draft_probs, target_probs, num_samples) + + # Average distance from reference probs. 
+ reference_vs_rejsample_dist = torch.dist( + reference_probs, + rej_sample_probs).item() / reference_probs.shape[0] + target_vs_rejsample_dist = torch.dist(target_probs, + rej_sample_probs).item() + + return reference_vs_rejsample_dist, target_vs_rejsample_dist + + def _estimate_rejection_sampling_pdf( + self, + draft_probs: torch.Tensor, + target_probs: torch.Tensor, + num_samples: int, + ) -> torch.Tensor: + # Repeat draft probs num_samples times. + draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat( + num_samples, 1, 1) + + # Repeat target probs num_samples * k times. + # Rejection sampler requires bonus token probs, but they aren't used. + target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat( + num_samples, self.k, 1) + + # Randomly sample draft token ids from draft probs. + draft_token_ids = torch.multinomial(draft_probs[:, 0, :], + num_samples=1, + replacement=True).reshape( + num_samples, self.k) + + # Bonus tokens not used but required. + bonus_token_ids = torch.zeros((1, self.num_bonus_tokens), + dtype=torch.int64, + device="cuda").repeat(num_samples, 1) + + # Get output tokens via rejection sampling. + output_token_ids = self.rejection_sampler(target_probs.to("cuda"), + bonus_token_ids.to("cuda"), + draft_probs.to("cuda"), + draft_token_ids.to("cuda")) + + # Remove bonus tokens + output_token_ids = output_token_ids[:, :-1].flatten() + + # Estimate probability density function + hist = torch.histogram(output_token_ids.to(dtype=torch.float, + device="cpu"), + bins=self.vocab_size, + range=self.vocab_range, + density=True) + + return hist.hist diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 3ad2d4608fbd5..8135055535f58 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -1,33 +1,47 @@ + +# pylint: disable=protected-access import random -from typing import Tuple from unittest.mock import patch +from typing import Optional, Tuple import pytest import torch -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, pythonize_sampler_output +from vllm.config import ParallelConfig, SchedulerConfig from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.worker.model_runner import ModelRunner +from vllm.model_executor.input_metadata import InputMetadata +from vllm.sequence import SamplerOutput +from vllm.worker.worker import Worker class MockLogitsSampler(Sampler): def __init__(self, vocab_size: int, fake_logits: torch.Tensor): - super().__init__(vocab_size=vocab_size) + super().__init__(vocab_size=vocab_size, org_vocab_size=vocab_size) self.fake_logits = fake_logits - def forward(self, *args, **kwargs): + def _get_logits(self, *args, **kwargs) -> torch.Tensor: + del args + del kwargs + return self.fake_logits + + def forward( + self, + embedding: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + embedding_bias: Optional[torch.Tensor] = None) -> SamplerOutput: with patch("vllm.model_executor.layers.sampler._prune_hidden_states", - lambda x, y: x), patch( - "vllm.model_executor.layers.sampler._get_logits", - lambda *args, **kwargs: self.fake_logits): - return super().forward(*args, **kwargs) + lambda x, y: x): + return super().forward(embedding, hidden_states, input_metadata, + embedding_bias) def _prepare_test( batch_size: int -) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]: +) -> Tuple[torch.Tensor, 
torch.Tensor, MockLogitsSampler, Worker]: vocab_size = 32000 input_tensor = torch.rand((batch_size, 1024), device="cuda", @@ -37,8 +51,10 @@ def _prepare_test( device=input_tensor.device, dtype=input_tensor.dtype) sampler = MockLogitsSampler(32000, fake_logits) - model_runner = ModelRunner(None, None, None) - return input_tensor, fake_logits, sampler, model_runner + scheduler_config = SchedulerConfig(2048, 2048, 2048) + worker = Worker(None, ParallelConfig(1, 1, False), scheduler_config) + worker.block_size = 16 + return input_tensor, fake_logits, sampler, worker RANDOM_SEEDS = list(range(128)) @@ -48,27 +64,27 @@ def _prepare_test( def test_sampler_all_greedy(seed: int): set_random_seed(seed) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size) seq_group_metadata_list = [] - prompt_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, + is_chunked_prefill=False, seq_data={0: SequenceData([1, 2, 3])}, sampling_params=SamplingParams(temperature=0, ), block_tables={0: [1]}, + lora_request=None, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) - sampler_output = sampler(embedding=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) + _, _, input_metadata, _, _ = worker._prepare_inputs( + seq_group_metadata_list) + sampler_output = pythonize_sampler_output( + sampler(embedding=None, + hidden_states=input_tensor, + input_metadata=input_metadata), input_metadata) expected = torch.argmax(fake_logits, dim=-1) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: @@ -79,85 +95,84 @@ def test_sampler_all_greedy(seed: int): def test_sampler_all_random(seed: int): set_random_seed(seed) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size) for i in range(batch_size): fake_logits[i, i] = 1e2 seq_group_metadata_list = [] - prompt_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, + is_chunked_prefill=False, seq_data={0: SequenceData([1, 2, 3])}, sampling_params=SamplingParams( temperature=1.0, n=random.randint(1, 10), ), block_tables={0: [1]}, + lora_request=None, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) - sampler_output = sampler(embedding=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) + _, _, input_metadata, _, _ = worker._prepare_inputs( + seq_group_metadata_list) + sampler_output = pythonize_sampler_output( + sampler(embedding=None, + hidden_states=input_tensor, + input_metadata=input_metadata), input_metadata) for i, sequence_output in enumerate(sampler_output): for nth_output in sequence_output.samples: assert nth_output.output_token == i -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -def test_sampler_all_beam(seed: int): - set_random_seed(seed) - batch_size = random.randint(1, 256) - input_tensor, _, sampler, model_runner = _prepare_test(batch_size) - - seq_group_metadata_list = [] - prompt_lens = [] - for i in range(batch_size): - 
seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=f"test_{i}", - is_prompt=True, - seq_data={0: SequenceData([1, 2, 3])}, - sampling_params=SamplingParams( - temperature=0, - best_of=2, - use_beam_search=True, - ), - block_tables={0: [1]}, - )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) - sampler(embedding=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) - # no assertion here as I am not sure how to determine whether - # the outputs are expected - in other words, this just tests - # whether there are no exceptions in the sampler - # when handling an all-beam search case. +# @pytest.mark.parametrize("seed", RANDOM_SEEDS) +# def test_sampler_all_beam(seed: int): +# set_random_seed(seed) +# batch_size = random.randint(1, 256) +# input_tensor, _, sampler, worker = _prepare_test(batch_size) + +# seq_group_metadata_list = [] +# for i in range(batch_size): +# seq_group_metadata_list.append( +# SequenceGroupMetadata( +# request_id=f"test_{i}", +# is_prompt=True, +# is_chunked_prefill=False, +# seq_data={0: SequenceData([1, 2, 3])}, +# sampling_params=SamplingParams( +# temperature=0, +# best_of=2, +# use_beam_search=True, +# ), +# block_tables={0: [1]}, +# lora_request=None, +# )) + +# _, _, input_metadata, _, _ = worker._prepare_inputs( +# seq_group_metadata_list) +# pythonize_sampler_output( +# sampler(embedding=None, +# hidden_states=input_tensor, +# input_metadata=input_metadata), input_metadata) +# # no assertion here as I am not sure how to determine whether +# # the outputs are expected - in other words, this just tests +# # whether there are no exceptions in the sampler +# # when handling an all-beam search case. 
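The remaining sampler tests all lean on the same trick: plant one dominant logit per row so that, under temperature-0 (greedy) sampling, the expected token is simply the row's argmax. A minimal standalone version of that expectation (the batch and vocabulary sizes are arbitrary for the sketch):

import torch

batch_size, vocab_size = 4, 32000
fake_logits = torch.zeros(batch_size, vocab_size)
for i in range(batch_size):
    fake_logits[i, i] = 1e2          # token i dominates row i
expected = torch.argmax(fake_logits, dim=-1)
assert expected.tolist() == [0, 1, 2, 3]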
@pytest.mark.parametrize("seed", RANDOM_SEEDS) def test_sampler_mixed(seed: int): set_random_seed(seed) batch_size = random.randint(1, 256) - input_tensor, fake_logits, sampler, model_runner = _prepare_test( - batch_size) + input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size) seq_group_metadata_list = [] expected_tokens = [] - prompt_lens = [] for i in range(batch_size): n = 1 - sampling_type = random.randint(0, 2) + sampling_type = random.randint(0, 1) if sampling_type == 0: sampling_params = SamplingParams(temperature=0) elif sampling_type == 1: @@ -169,10 +184,10 @@ def test_sampler_mixed(seed: int): n=n, presence_penalty=random.randint(0, 1), ) - else: - sampling_params = SamplingParams(temperature=0, - use_beam_search=True, - best_of=2) + # else: + # sampling_params = SamplingParams(temperature=0, + # use_beam_search=True, + # best_of=2) for idx in range(n): fake_logits[i, i + idx] = 1e2 expected_tokens.append(i + idx) @@ -180,20 +195,22 @@ def test_sampler_mixed(seed: int): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, + is_chunked_prefill=False, seq_data={0: SequenceData([1, 2, 3])}, sampling_params=sampling_params, block_tables={0: [1]}, + lora_request=None, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) - sampler_output = sampler(embedding=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) + _, _, input_metadata, _, _ = worker._prepare_inputs( + seq_group_metadata_list) + sampler_output = pythonize_sampler_output( + sampler(embedding=None, + hidden_states=input_tensor, + input_metadata=input_metadata), input_metadata) for i, sequence_output in enumerate(sampler_output): - if seq_group_metadata_list[i].sampling_params.use_beam_search: - continue + # if seq_group_metadata_list[i].sampling_params.use_beam_search: + # continue for nth_output in sequence_output.samples: assert nth_output.output_token in expected_tokens @@ -202,7 +219,7 @@ def test_sampler_mixed(seed: int): def test_sampler_logits_processors(seed: int): set_random_seed(seed) batch_size = random.randint(1, 256) - input_tensor, _, sampler, model_runner = _prepare_test(batch_size) + input_tensor, _, sampler, worker = _prepare_test(batch_size) # This sample logits processor gives infinite score to the i-th token, # where i is the length of the input sequence. 
@@ -212,24 +229,25 @@ def pick_ith(token_ids, logits): return logits seq_group_metadata_list = [] - prompt_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, + is_chunked_prefill=False, seq_data={0: SequenceData([1, 2, 3])}, sampling_params=SamplingParams(temperature=0, logits_processors=[pick_ith]), block_tables={0: [1]}, + lora_request=None, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) - - sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) - sampler_output = sampler(embedding=None, - hidden_states=input_tensor, - sampling_metadata=sampling_metadata) - for _, sequence_output in enumerate(sampler_output): + + _, _, input_metadata, _, _ = worker._prepare_inputs( + seq_group_metadata_list) + sampler_output = pythonize_sampler_output( + sampler(embedding=None, + hidden_states=input_tensor, + input_metadata=input_metadata), input_metadata) + for i, sequence_output in enumerate(sampler_output): for idx, nth_output in enumerate(sequence_output.samples): assert nth_output.output_token == idx diff --git a/tests/spec_decode/conftest.py b/tests/spec_decode/conftest.py new file mode 100644 index 0000000000000..31ff360414d44 --- /dev/null +++ b/tests/spec_decode/conftest.py @@ -0,0 +1,133 @@ + +import pytest +from typing import Generator, Optional + +from vllm import LLM +from tests.anyscale.utils import SecretManager, cleanup + + +@pytest.fixture(scope="session", autouse=True) +def load_secrets(): + secrets = SecretManager() + secrets.override_secret("HUGGING_FACE_HUB_TOKEN") + + +# pylint: disable=redefined-outer-name +@pytest.fixture(name="spec_decode_llm") +def create_spec_decode_llm( + spec_decode_llm_generator: Generator[LLM, None, None]) -> LLM: + for spec_decode_llm in spec_decode_llm_generator: + yield spec_decode_llm + del spec_decode_llm + + +@pytest.fixture +def spec_decode_llm_generator( + target_model: str, draft_model: str, num_speculative_tokens: str, + tensor_parallel_size: int, with_cuda_graph: bool, + speculative_model_uses_tp_1: bool, + disable_shared_memory: bool) -> Generator[LLM, None, None]: + return create_spec_decode_llm_generator(target_model, draft_model, + num_speculative_tokens, + tensor_parallel_size, + with_cuda_graph, + speculative_model_uses_tp_1, + disable_shared_memory) + + +@pytest.fixture +def max_model_len_spec_decode_generator( + target_model: str, draft_model: str, num_speculative_tokens: str, + tensor_parallel_size: int, with_cuda_graph: bool, + disable_shared_memory: bool, + max_model_len: int) -> Generator[LLM, None, None]: + return create_spec_decode_llm_generator( + target_model, + draft_model, + num_speculative_tokens, + tensor_parallel_size, + with_cuda_graph, + speculative_model_uses_tp_1=False, + disable_shared_memory=disable_shared_memory, + max_model_len=max_model_len) + + +def create_spec_decode_llm_generator( + target_model: str, + draft_model: str, + num_speculative_tokens: str, + tensor_parallel_size: int, + with_cuda_graph: bool, + speculative_model_uses_tp_1: bool, + disable_shared_memory: bool, + max_model_len: Optional[int] = None) -> Generator[LLM, None, None]: + + def generator(): + addl_kwargs = {} + if max_model_len is not None: + addl_kwargs["max_model_len"] = max_model_len + + spec_decode_llm = LLM( + model=target_model, + speculative_model=draft_model, + num_speculative_tokens=num_speculative_tokens, + tensor_parallel_size=tensor_parallel_size, + enable_cuda_graph=with_cuda_graph, + 
disable_shared_memory=disable_shared_memory, + speculative_model_uses_tp_1=speculative_model_uses_tp_1, + worker_use_ray=True, + **addl_kwargs, + ) + + yield spec_decode_llm + + del spec_decode_llm + cleanup() + + return generator() + + +@pytest.fixture +def non_spec_decode_llm_generator( + target_model: str, tensor_parallel_size: int, + with_cuda_graph: bool) -> Generator[LLM, None, None]: + + return create_non_spec_decode_llm_generator(target_model, + tensor_parallel_size, + with_cuda_graph) + + +@pytest.fixture +def max_model_len_llm_generator( + target_model: str, tensor_parallel_size: int, with_cuda_graph: bool, + max_model_len: int) -> Generator[LLM, None, None]: + return create_non_spec_decode_llm_generator(target_model, + tensor_parallel_size, + with_cuda_graph, max_model_len) + + +def create_non_spec_decode_llm_generator( + target_model: str, + tensor_parallel_size: int, + with_cuda_graph: bool, + max_model_len: Optional[int] = None) -> Generator[LLM, None, None]: + + def generator(): + addl_kwargs = {} + if max_model_len is not None: + addl_kwargs["max_model_len"] = max_model_len + + llm = LLM( + model=target_model, + tensor_parallel_size=tensor_parallel_size, + enable_cuda_graph=with_cuda_graph, + worker_use_ray=True, + **addl_kwargs, + ) + + yield llm + + del llm + cleanup() + + return generator() diff --git a/tests/spec_decode/test_integration.py b/tests/spec_decode/test_integration.py new file mode 100644 index 0000000000000..b06621a410a4b --- /dev/null +++ b/tests/spec_decode/test_integration.py @@ -0,0 +1,166 @@ + +import pytest +from typing import Generator +import torch + +from vllm import LLM, SamplingParams +from tests.spec_decode.utils import (get_outputs, get_tokens_and_text, + wait_for_gpu_memory_to_clear) +from tests.anyscale.utils import cleanup + + +@pytest.mark.parametrize("draft_model", ["JackFram/llama-68m"]) +@pytest.mark.parametrize("target_model", ["JackFram/llama-160m"]) +@pytest.mark.parametrize("num_speculative_tokens", [5]) +@pytest.mark.parametrize("output_len", [128]) +@pytest.mark.parametrize("temperature", [1.0]) +@pytest.mark.parametrize("tensor_parallel_size", [1, 2]) +@pytest.mark.parametrize("with_cuda_graph", [True, False]) +@pytest.mark.parametrize("speculative_model_uses_tp_1", [False]) +@pytest.mark.parametrize("disable_shared_memory", [False]) +def test_integration_tp_and_cuda_graph( + spec_decode_llm_generator: Generator[LLM, None, None], output_len: int, + temperature: float, tensor_parallel_size: int): + """Test integration with cuda graphs and different TP degrees. + """ + run_test(spec_decode_llm_generator, output_len, temperature, + tensor_parallel_size) + + +@pytest.mark.parametrize("draft_model", ["JackFram/llama-68m"]) +@pytest.mark.parametrize("target_model", ["JackFram/llama-160m"]) +@pytest.mark.parametrize("num_speculative_tokens", [5]) +@pytest.mark.parametrize("output_len", [128]) +@pytest.mark.parametrize("temperature", [1.0]) +@pytest.mark.parametrize("tensor_parallel_size", [2]) +@pytest.mark.parametrize("with_cuda_graph", [True]) +@pytest.mark.parametrize("speculative_model_uses_tp_1", [False]) +@pytest.mark.parametrize("disable_shared_memory", [False, True]) +def test_integration_shm(spec_decode_llm_generator: Generator[LLM, None, None], + output_len: int, temperature: float, + tensor_parallel_size: int): + """Test integration with cuda graphs and shared memory. 
+ """ + run_test(spec_decode_llm_generator, output_len, temperature, + tensor_parallel_size) + + +@pytest.mark.parametrize("draft_model", ["JackFram/llama-68m"]) +@pytest.mark.parametrize("target_model", ["JackFram/llama-160m"]) +@pytest.mark.parametrize("num_speculative_tokens", [5]) +@pytest.mark.parametrize("output_len", [128]) +@pytest.mark.parametrize("temperature", [1.0]) +@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) +@pytest.mark.parametrize("with_cuda_graph", [True, False]) +@pytest.mark.parametrize("speculative_model_uses_tp_1", [True]) +@pytest.mark.parametrize("disable_shared_memory", [False]) +def test_integration_draft_tp1(spec_decode_llm_generator: Generator[LLM, None, + None], + output_len: int, temperature: float, + tensor_parallel_size: int): + """Test integration with different draft/target TP degrees + """ + run_test(spec_decode_llm_generator, output_len, temperature, + tensor_parallel_size) + + +def run_test(spec_decode_llm_generator: Generator[LLM, None, None], + output_len: int, temperature: float, tensor_parallel_size: int): + if torch.cuda.device_count() < tensor_parallel_size: + pytest.skip(f"Expected {tensor_parallel_size=} devices") + + print("waiting for free memory before test start") + wait_for_gpu_memory_to_clear(list(range(torch.cuda.device_count())), + 1000 * 2**20) + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + output_len = 128 + sampling_params = SamplingParams( + temperature=temperature, + max_tokens=output_len, + ignore_eos=True, + ) + + spec_outputs = get_outputs(spec_decode_llm_generator, prompts, + sampling_params) + cleanup() + _, spec_output_token_ids = get_tokens_and_text(spec_outputs) + + # Assert enough expected tokens were returned. + for token_ids in spec_output_token_ids: + assert len(token_ids) == output_len + + +@pytest.mark.parametrize("draft_model", ["meta-llama/Llama-2-7b-chat-hf"]) +@pytest.mark.parametrize("target_model", ["meta-llama/Llama-2-7b-chat-hf"]) +@pytest.mark.parametrize("num_speculative_tokens", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [4]) +@pytest.mark.parametrize("with_cuda_graph", [False]) +@pytest.mark.parametrize("disable_shared_memory", [False]) +@pytest.mark.parametrize("speculative_model_uses_tp_1", [False]) +def test_truncates_after_eos(spec_decode_llm_generator: Generator[LLM, None, + None], + non_spec_decode_llm_generator: Generator[LLM, + None, + None], + tensor_parallel_size: int): + """Since speculative decoding generates tokens in blocks, we verify that + the engine truncates any tokens after EOS when EOS is to be respected. + """ + # This test requires 26GB GPU memory (7B+7B), so it lives with the + # distributed integration tests despite being a unit test. + # + # I couldn't reproduce the error case with any publicly available model + # under 7B params. + + if torch.cuda.device_count() < tensor_parallel_size: + pytest.skip(f"Expected {tensor_parallel_size=} devices") + + print("waiting for free memory before test start") + wait_for_gpu_memory_to_clear(list(range(torch.cuda.device_count())), + 1000 * 2**20) + + sampling_params = SamplingParams( + max_tokens=100, + ignore_eos=False, + temperature=0.0, + ) + prompts = [ + ("[INST] <>\nYou repeat the prompt exactly. 
You do not add " + "additional thoughts or words\n<>\n\n Repeat this exactly: " + "'Hello world.'[/INST]"), + ] + + print("Starting generation") + spec_outputs = get_outputs(spec_decode_llm_generator, prompts, + sampling_params) + spec_output_text, spec_output_token_ids = get_tokens_and_text(spec_outputs) + + non_spec_outputs = get_outputs(non_spec_decode_llm_generator, prompts, + sampling_params) + non_spec_output_text, non_spec_output_token_ids = get_tokens_and_text( + non_spec_outputs) + + for i, prompt in enumerate(prompts): + non_spec_text = non_spec_output_text[i] + non_spec_token_ids = non_spec_output_token_ids[i] + + spec_text = spec_output_text[i] + spec_token_ids = spec_output_token_ids[i] + + print(f"{i=} {prompt=}") + print(f"{i=} {non_spec_text=}") + print(f"{i=} {spec_text=}") + print(f"{i=} {non_spec_token_ids=}") + print(f"{i=} {spec_token_ids=}") + + for i, prompt in enumerate(prompts): + non_spec_token_ids = non_spec_output_token_ids[i] + spec_token_ids = spec_output_token_ids[i] + assert non_spec_token_ids == spec_token_ids, f"{i=}" diff --git a/tests/spec_decode/test_smoke.py b/tests/spec_decode/test_smoke.py new file mode 100644 index 0000000000000..be006841cfd2c --- /dev/null +++ b/tests/spec_decode/test_smoke.py @@ -0,0 +1,259 @@ +"""High-level speculative decoding tests.""" +import time +import pytest +from typing import Generator +from itertools import cycle + +from vllm import LLM, SamplingParams +from tests.spec_decode.utils import get_outputs, get_tokens_and_text + + +@pytest.mark.parametrize("draft_model", ["JackFram/llama-68m"]) +@pytest.mark.parametrize("target_model", ["JackFram/llama-160m"]) +@pytest.mark.parametrize("num_speculative_tokens", [1, 2, 3, 4, 5]) +@pytest.mark.parametrize("output_len", [1024]) +@pytest.mark.parametrize("temperature", [0.0, 0.5, 1.0]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +@pytest.mark.parametrize("with_cuda_graph", [False]) +@pytest.mark.parametrize("speculative_model_uses_tp_1", [False]) +@pytest.mark.parametrize("disable_shared_memory", [True]) +def test_smoke_no_crash(spec_decode_llm: LLM, output_len: int, + temperature: float): + """Validate that speculative decoding does not crash while generating a non- + trivial number of tokens. This is a high-level test that validates different + values of K and temperatures work over many generated tokens. 
+ """ + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + print("Starting generation") + start_time = time.time() + outputs = spec_decode_llm.generate(prompts, + sampling_params, + use_tqdm=False) + dur_ms = 1000 * (time.time() - start_time) + num_output_tokens = len(outputs[0].outputs[0].token_ids) + print(f"generated {num_output_tokens} tokens in {dur_ms=:.02f}") + print(f"ms/tok {dur_ms/num_output_tokens:.02f}") + + +@pytest.mark.parametrize("draft_model", ["JackFram/llama-68m"]) +@pytest.mark.parametrize("target_model", ["meta-llama/Llama-2-7b-chat-hf"]) +@pytest.mark.parametrize("num_speculative_tokens", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +@pytest.mark.parametrize("with_cuda_graph", [False]) +@pytest.mark.parametrize("speculative_model_uses_tp_1", [False]) +@pytest.mark.parametrize("disable_shared_memory", [True]) +def test_correctness_bs_1(spec_decode_llm_generator: Generator[LLM, None, + None], + non_spec_decode_llm_generator: Generator[LLM, None, + None]): + """High-level test that validates exact equality between normal decoding and + speculative decoding. This is done via greedy sampling. + + Note that speculative decoding guarantees exact equality up to hardware + numerics. The configuration tested here does not encounter numeric + limitations, but may if ran on different hardware. + """ + prompt = "The president of the United States is" + output_len = 128 + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=output_len, + ignore_eos=True, + ) + + def evaluate(generator): + for llm in generator: + outputs = llm.generate([prompt], sampling_params, use_tqdm=False) + token_ids = outputs[0].outputs[0].token_ids + del llm + return token_ids + + spec_token_ids = evaluate(spec_decode_llm_generator) + non_spec_token_ids = evaluate(non_spec_decode_llm_generator) + + print(f"{len(spec_token_ids)=} {spec_token_ids=}") + print(f"{len(non_spec_token_ids)=} {non_spec_token_ids=}") + assert spec_token_ids == non_spec_token_ids + + +@pytest.mark.parametrize("draft_model", ["JackFram/llama-68m"]) +@pytest.mark.parametrize("target_model", ["meta-llama/Llama-2-7b-chat-hf"]) +@pytest.mark.parametrize("num_speculative_tokens", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +@pytest.mark.parametrize("with_cuda_graph", [False]) +@pytest.mark.parametrize("speculative_model_uses_tp_1", [False, True]) +@pytest.mark.parametrize("disable_shared_memory", [True]) +def test_correctness_bs_gt_1(spec_decode_llm_generator: Generator[LLM, None, + None], + non_spec_decode_llm_generator: Generator[LLM, + None, + None]): + """High-level test that validates exact correctness on a large batch size. + Each sequence is compared with normal decoding and speculative decoding, and + output tokens are compared one-by-one. + + See test_correctness_bs_1 for note on speculative decoding exact equality + and hardware numerics. 
+ """ + base_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + batch_size = 64 + prompts = [ + prompt for prompt, _ in zip(cycle(base_prompts), range(batch_size)) + ] + + output_len = 32 + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=output_len, + ignore_eos=True, + ) + + spec_outputs = get_outputs(spec_decode_llm_generator, prompts, + sampling_params) + non_spec_outputs = get_outputs(non_spec_decode_llm_generator, prompts, + sampling_params) + + spec_text_outputs, spec_output_token_ids = get_tokens_and_text( + spec_outputs) + non_spec_text_outputs, non_spec_output_token_ids = get_tokens_and_text( + non_spec_outputs) + + for i, (prompt, spec_text, spec_token_ids, non_spec_text, + non_spec_token_ids) in enumerate( + zip(prompts, spec_text_outputs, spec_output_token_ids, + non_spec_text_outputs, non_spec_output_token_ids)): + print(f"{i=} {prompt=}") + print(f" {spec_text=}") + print(f"{non_spec_text=}") + print(f"{spec_token_ids=}") + print(f"{non_spec_token_ids=}") + + for i, (prompt, spec_text, spec_token_ids, non_spec_text, + non_spec_token_ids) in enumerate( + zip(prompts, spec_text_outputs, spec_output_token_ids, + non_spec_text_outputs, non_spec_output_token_ids)): + assert spec_token_ids == non_spec_token_ids, f"{i=}" + + +@pytest.mark.parametrize("draft_model", ["JackFram/llama-68m"]) +@pytest.mark.parametrize("target_model", ["JackFram/llama-160m"]) +@pytest.mark.parametrize("num_speculative_tokens", [5]) +@pytest.mark.parametrize("tensor_parallel_size", [1]) +@pytest.mark.parametrize("with_cuda_graph", [False]) +@pytest.mark.parametrize("disable_shared_memory", [False]) +@pytest.mark.parametrize("max_model_len", [200]) +def test_correctness_model_truncation( + max_model_len_spec_decode_generator: Generator[LLM, None, None], + max_model_len_llm_generator: Generator[LLM, None, + None], max_model_len: int): + """Test correct generation when output must be truncated by max model len. + """ + + sampling_params = SamplingParams( + max_tokens=max_model_len + 50, + ignore_eos=True, + temperature=0.0, + ) + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + print("Starting generation") + spec_outputs = get_outputs(max_model_len_spec_decode_generator, prompts, + sampling_params) + spec_output_text, spec_output_token_ids = get_tokens_and_text(spec_outputs) + + non_spec_outputs = get_outputs(max_model_len_llm_generator, prompts, + sampling_params) + non_spec_output_text, non_spec_output_token_ids = get_tokens_and_text( + non_spec_outputs) + + for i, prompt in enumerate(prompts): + non_spec_text = non_spec_output_text[i] + non_spec_token_ids = non_spec_output_token_ids[i] + + spec_text = spec_output_text[i] + spec_token_ids = spec_output_token_ids[i] + + print(f"{i=} {prompt=}") + print(f"{i=} {non_spec_text=}") + print(f"{i=} {spec_text=}") + print(f"{i=} {non_spec_token_ids=}") + print(f"{i=} {spec_token_ids=}") + + for i, prompt in enumerate(prompts): + non_spec_text = non_spec_output_text[i] + non_spec_token_ids = non_spec_output_token_ids[i] + + spec_text = spec_output_text[i] + spec_token_ids = spec_output_token_ids[i] + + assert non_spec_token_ids == spec_token_ids, f"{i=}" + + +def test_large_enough_cuda_graph_input(): + """Verify no crash when using CUDA graphs, particularly when the number of + decode tokens is configured to exceed the nominal batch size. 
This happens + when many tokens are accepted; the draft model must process up to bs*(k+1) + tokens, the target model must process up to bs*2*(k+1). + """ + draft_padding_size = 8 + target_padding_size = draft_padding_size + + batch_size = draft_padding_size * 3 + + llm = LLM( + # By setting the draft model and target model to the same model, we + # should get a 100% acceptance rate. + # This will maximize the number of decode tokens in the draft model + # and target model forward passes. + # Since the batch size is set to a multiple of the padding size, this + # guarantees that we'll exceed the cuda graph input size unless it + # accounts for extra speculative decode tokens. + model="JackFram/llama-68m", + speculative_model="JackFram/llama-68m", + tensor_parallel_size=1, + num_speculative_tokens=3, + worker_use_ray=True, + enable_cuda_graph=True, + max_num_seqs=batch_size, + target_model_input_padding_size=target_padding_size, + draft_model_input_padding_size=draft_padding_size, + ) + + base_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + prompts = [ + prompt for prompt, _ in zip(cycle(base_prompts), range(batch_size)) + ] + + output_len = 32 + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=output_len, + ignore_eos=True, + ) + llm.generate(prompts, sampling_params, use_tqdm=False) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py new file mode 100644 index 0000000000000..09d96b541d0ab --- /dev/null +++ b/tests/spec_decode/utils.py @@ -0,0 +1,53 @@ + +from typing import List +from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, + nvmlInit) +import time + + +def get_outputs(generator, prompts, sampling_params): + for llm in generator: + outputs = llm.generate(prompts, sampling_params, use_tqdm=False) + del llm + return outputs + + +def get_tokens_and_text(outputs): + all_text = [] + all_token_ids = [] + for request_output in outputs: + for completion in request_output.outputs: + all_text.append(completion.text) + all_token_ids.append(completion.token_ids) + return all_text, all_token_ids + + +def wait_for_gpu_memory_to_clear(devices: List[int], + threshold_bytes: int, + timeout_s: float = 120) -> None: + # Use nvml instead of pytorch to reduce measurement error from torch cuda + # context. 
+ nvmlInit() + start_time = time.time() + while True: + output = {} + output_raw = {} + for device in devices: + dev_handle = nvmlDeviceGetHandleByIndex(device) + mem_info = nvmlDeviceGetMemoryInfo(dev_handle) + gb_used = mem_info.used / 2**30 + output_raw[device] = gb_used + output[device] = f'{gb_used:.02f}' + + print('gpu memory used (GB): ', end='') + for k, v in output.items(): + print(f'{k}={v}; ', end='') + print('') + if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()): + break + + if time.time() - start_time >= timeout_s: + raise ValueError(f'Memory of devices {devices=} not free after ' + f'{timeout_s=} ({threshold_bytes/2**30=})') + + time.sleep(5) diff --git a/tests/test_sequence.py b/tests/test_sequence.py new file mode 100644 index 0000000000000..08c13a8e6a352 --- /dev/null +++ b/tests/test_sequence.py @@ -0,0 +1,164 @@ + +import pytest +from vllm.sequence import SequenceData, Sequence + + +@pytest.fixture(name="sequence") +def create_sequence(seq_len: int, block_size: int) -> Sequence: + return Sequence( + seq_id=0, + prompt="", + prompt_token_ids=list(range(seq_len)), + block_size=block_size, + ) + + +@pytest.mark.parametrize("block_size", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_empty_slots", list(range(8))) +@pytest.mark.parametrize("seq_len", [0, 1, 100]) +def test_ensure_num_empty_slots(block_size: int, seq_len: int, + num_empty_slots: int, sequence: Sequence): + """Verify ensure_num_empty_slots correctly ensures empty slots. + """ + sequence.ensure_num_empty_slots(num_empty_slots) + + num_total_slots = block_size * len(sequence.logical_token_blocks) + measured_num_empty_slots = sum(block.get_num_empty_slots() + for block in sequence.logical_token_blocks) + num_full_slots = num_total_slots - measured_num_empty_slots + + assert measured_num_empty_slots >= num_empty_slots + assert num_full_slots == seq_len + + +@pytest.fixture(name="sequence_with_extra_blocks") +def add_blocks_to_sequence(sequence: Sequence, + num_extra_blocks: int) -> Sequence: + for _ in range(num_extra_blocks): + sequence._append_logical_block() # pylint: disable=protected-access + return sequence + + +@pytest.mark.parametrize("num_tokens_to_append", [1, 10]) +@pytest.mark.parametrize("seq_len", [0, 1, 100]) +@pytest.mark.parametrize("block_size", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_extra_blocks", [0, 1, 100]) +def test_append_tokens_correct_placement_in_blocks( + num_tokens_to_append: int, sequence_with_extra_blocks: Sequence, + block_size: int, seq_len: int): + """Verify new tokens are appended at the end of the sequence, instead of the + last block. This enables preallocated empty slots, which requires empty + blocks after the sequence. + """ + token_ids = list(range(num_tokens_to_append)) + logprobs = [{token_id: 0.0} for token_id in token_ids] + seq_len_before_append = seq_len + seq_len_after_append = seq_len_before_append + num_tokens_to_append + + sequence_with_extra_blocks.append_token_ids(token_ids, logprobs) + + # Assert number of full slots equal to total sequence length. + assert sum(block_size - block.get_num_empty_slots() + for block in sequence_with_extra_blocks.logical_token_blocks + ) == seq_len_after_append + + # Assert each appended token is immediately after the original sequence. 
+ for i, token_id in enumerate(token_ids): + index = seq_len_before_append + i + block_token_ids = sequence_with_extra_blocks.logical_token_blocks[ + index // block_size].get_token_ids() + assert block_token_ids[index % block_size] == token_id + + +@pytest.mark.parametrize("generation_or_prefill", ["generation", "prefill"]) +@pytest.mark.parametrize("num_output_tokens", [0, 1, 10]) +@pytest.mark.parametrize("num_prompt_tokens", [5, 50]) +def test_get_unprocessed_tokens(generation_or_prefill: str, + num_output_tokens: int, + num_prompt_tokens: int): + """Verify sequence data correctly tracks the number of processed tokens. + """ + is_generation = generation_or_prefill == "generation" + + if is_generation: + generated_token_id = 1337 + + prompt_token_ids = list(range(num_prompt_tokens)) + output_token_ids = list(range(num_output_tokens)) + data = SequenceData( + token_ids=prompt_token_ids[:] + output_token_ids[:], + num_prompt_tokens=len(prompt_token_ids[:]), + ) + + if is_generation: + data.append_token_ids([generated_token_id], logprobs=[0.0]) + + unprocessed_token_ids = data.get_unprocessed_token_ids() + unprocessed_token_positions = data.get_unprocessed_token_positions() + + if is_generation: + assert unprocessed_token_ids == [generated_token_id] + assert unprocessed_token_positions == [ + num_prompt_tokens + num_output_tokens + ] + else: + assert unprocessed_token_ids == prompt_token_ids + output_token_ids + assert unprocessed_token_positions == list( + range(num_prompt_tokens + num_output_tokens)) + + # Reset processed tokens. Everything should behave like a prompt run now. + data.reset_processed_tokens() + + unprocessed_token_ids = data.get_unprocessed_token_ids() + unprocessed_token_positions = data.get_unprocessed_token_positions() + + if is_generation: + assert unprocessed_token_ids == (prompt_token_ids + output_token_ids + + [generated_token_id]) + assert unprocessed_token_positions == list( + range(num_prompt_tokens + num_output_tokens + 1)) + if not is_generation: + assert unprocessed_token_ids == prompt_token_ids + output_token_ids + assert unprocessed_token_positions == list( + range(num_prompt_tokens + num_output_tokens)) + + +def test_sequence_data_prefill(): + seq_data = SequenceData(prompt_token_ids=[1, 2, 3, 4], output_token_ids=[]) + assert seq_data.get_prefill_range() == (0, 0) + assert seq_data.get_num_unprefilled() == 4 + + # advance by 2 + assert seq_data.advance_prefill_range(2) == 2 + assert seq_data.get_num_unprefilled() == 2 + assert seq_data.get_prefill_range() == (0, 2) + + # advance range by 3 even though there are only 2 unprefilled tokens + assert seq_data.advance_prefill_range(3) == 2 + assert seq_data.get_num_unprefilled() == 0 + assert seq_data.get_prefill_range() == (2, 4) + + # following advances should not change anything + assert seq_data.advance_prefill_range(2) == 0 + assert seq_data.get_num_unprefilled() == 0 + assert seq_data.get_prefill_range() == (4, 4) + + # append tokens and reset, simulating recompute + seq_data.append_token_ids([1], logprobs=[0.0]) + seq_data.reset_processed_tokens() + + # after reset, the prefill range should be reset to 0 + # but the num_unprefilled should include. 
+ # output tokens + assert seq_data.get_prefill_range() == (0, 0) + assert seq_data.get_num_unprefilled() == 5 + + # advance by 2 + assert seq_data.advance_prefill_range(2) == 2 + assert seq_data.get_num_unprefilled() == 3 + assert seq_data.get_prefill_range() == (0, 2) + + # advance range by 3 even though there are only 2 unprefilled tokens + assert seq_data.advance_prefill_range(3) == 3 + assert seq_data.get_num_unprefilled() == 0 + assert seq_data.get_prefill_range() == (2, 5) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000000000..b95865dd25bcf --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,4 @@ + + +def round_up_to_next_block(seq_len: int, block_size: int) -> int: + return (seq_len + block_size - 1) // block_size diff --git a/tests/worker/__init__.py b/tests/worker/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/worker/test_draft_target_worker.py b/tests/worker/test_draft_target_worker.py new file mode 100644 index 0000000000000..e3c296421ee5a --- /dev/null +++ b/tests/worker/test_draft_target_worker.py @@ -0,0 +1,407 @@ + +import torch +import random +import pytest + +from vllm.worker.draft_target_worker import DraftTargetWorker +from vllm.model_executor.utils import set_random_seed +from vllm.sequence import SequenceGroupMetadata + +from .utils import (mock_worker, create_seq_group_metadata_from_prompts, + create_batch, create_sampler_output_list) + +from unittest.mock import MagicMock + + +def test_get_all_seq_ids(): + """Verify get_all_seq_ids extracts all seq ids. + """ + worker = DraftTargetWorker(mock_worker(), mock_worker(), MagicMock()) + + expected_seq_ids = list(range(10)) + list(range(100, 110)) + + seq_group_metadata_list = [ + SequenceGroupMetadata( + request_id=str(seq_id), + is_prompt=True, + is_chunked_prefill=False, + seq_data={ + seq_id: MagicMock(), + }, + sampling_params=MagicMock(), + block_tables={ + seq_id: MagicMock(), + }, + lora_request=None, + ) for seq_id in expected_seq_ids + ] + + actual_seq_ids = worker._get_all_seq_ids(seq_group_metadata_list) # pylint: disable=protected-access + assert actual_seq_ids == expected_seq_ids + + +@pytest.mark.parametrize('num_target_seq_ids', [100]) +def test_create_target_seq_id_iterator(num_target_seq_ids: int): + """Assert all target seq ids are greater than input seq ids. + """ + worker = DraftTargetWorker(mock_worker(), mock_worker(), MagicMock()) + + all_seq_ids = [ + [1, 3, 5, 7], + list(range(100)) + [0], + [100], + ] + + for seq_ids in all_seq_ids: + max_seq_id = max(seq_ids) + iterator = worker._create_target_seq_id_iterator(seq_ids) # pylint: disable=protected-access + for _ in range(num_target_seq_ids): + assert next(iterator) > max_seq_id + + +@pytest.mark.parametrize('k', [1, 2, 6]) +def test_get_token_ids_to_score(k: int): + """Verify DraftTargetWorker correctly determines which token ids need + to be scored. 
+ """ + proposal_token_ids = torch.tensor( + list(range(k)), + dtype=torch.int64, + device='cuda', + ) + + expected_output = [ + [], + ] + for i in range(proposal_token_ids.shape[0]): + expected_output.append(proposal_token_ids[:i + 1].tolist()) + + worker = DraftTargetWorker(mock_worker(), mock_worker(), MagicMock()) + actual_output = worker._get_token_ids_to_score(proposal_token_ids) # pylint: disable=protected-access + + actual_output = [ + x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output + ] + + assert actual_output == expected_output + + +@pytest.mark.parametrize('k', [1, 2, 6]) +def test_create_single_target_seq_group_metadata(k: int): + """Verify correct creation of a target seq group metadata. + """ + + prompt_tokens = [1, 2, 3] + prev_output_tokens = [4, 5, 6] + + token_ids = list(range(k)) + + num_tokens_processed = len(prompt_tokens) + len(prev_output_tokens) - 1 + + final_seq_len = len(prompt_tokens) + len(prev_output_tokens) + len( + token_ids) + + block_size = 32 + input_seq_group_metadata = create_seq_group_metadata_from_prompts( + [prompt_tokens], 2048 // block_size, block_size, [final_seq_len], + [prev_output_tokens], [num_tokens_processed])[0] + + input_seq_id = list(input_seq_group_metadata.seq_data.keys())[0] + target_seq_id = 100 + + worker = DraftTargetWorker(mock_worker(), mock_worker(), MagicMock()) + output = worker._create_single_target_seq_group_metadata( # pylint: disable=protected-access + input_seq_group_metadata, + input_seq_id, + target_seq_id, + token_ids, + ) + + assert output.request_id == input_seq_group_metadata.request_id + assert len(output.seq_data) == 1 + assert output.seq_data[target_seq_id].get_prompt_token_ids( + ) == prompt_tokens + assert output.seq_data[target_seq_id].get_output_token_ids( + ) == prev_output_tokens + token_ids + + assert output.seq_data[target_seq_id].get_num_processed_token_ids( + ) == num_tokens_processed + k + + assert len(output.block_tables) == 1 + assert output.block_tables[ + target_seq_id] == input_seq_group_metadata.block_tables[input_seq_id] + + +@pytest.mark.parametrize('k', [1, 2, 6]) +@pytest.mark.parametrize('batch_size', [1, 2, 32]) +@torch.inference_mode() +def test_correctly_calls_draft_model(k: int, batch_size: int): + """Verify that the DraftTargetWorker calls the draft model with correct + inputs. Everything else is mocked out. + """ + + draft_worker = mock_worker() + target_worker = mock_worker() + rejection_sampler = MagicMock() + worker = DraftTargetWorker(draft_worker, target_worker, rejection_sampler) + + exception_secret = 'artifical stop' + draft_worker.execute_model.side_effect = ValueError(exception_secret) + + execute_model_data, _, _ = create_batch(batch_size, k) + + with pytest.raises(ValueError, match=exception_secret): + worker.execute_model(execute_model_data) + + call_args_list = draft_worker.execute_model.call_args_list + assert len(call_args_list) == 1 + + for args, _ in call_args_list: + (actual_execute_model_data, ) = args + assert actual_execute_model_data == execute_model_data + + +@pytest.mark.parametrize('k', [1, 2, 6]) +@pytest.mark.parametrize('batch_size', [1, 2, 32]) +@torch.inference_mode() +def test_correctly_calls_target_model(k: int, batch_size: int): + """Verify that the DraftTargetWorker calls the target model with correct + inputs. Everything else is mocked out. 
+ """ + draft_worker = mock_worker() + target_worker = mock_worker() + rejection_sampler = MagicMock() + rejection_sampler.token_id_dtype = torch.int64 + draft_worker.device = 'cuda' + target_worker.device = 'cuda' + + set_random_seed(1) + + worker = DraftTargetWorker(draft_worker, target_worker, rejection_sampler) + + vocab_size = 32_000 + + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(k, batch_size), + dtype=torch.int64, + device='cuda') + draft_token_probs = torch.rand(k, + batch_size, + vocab_size, + dtype=torch.float32, + device='cuda') + + draft_output = create_sampler_output_list(draft_token_ids, + draft_token_probs) + draft_worker.execute_model.return_value = draft_output + execute_model_data, prompts, prev_output_tokens = create_batch( + batch_size, k) + + exception_secret = 'artifical stop' + target_worker.execute_model.side_effect = ValueError(exception_secret) + + with pytest.raises(ValueError, match=exception_secret): + worker.execute_model(execute_model_data) + + seen_contexts = [] + + call_args_list = target_worker.execute_model.call_args_list + assert len(call_args_list) == 1 + for args, _ in call_args_list: + (target_execute_model_data, ) = args + + assert len(target_execute_model_data.seq_group_metadata_list) == ( + k + 1) * batch_size + for seq_group_metadata in ( + target_execute_model_data.seq_group_metadata_list): + for seq_data in seq_group_metadata.seq_data.values(): + seen_contexts.append(seq_data.get_token_ids()) + + expected_seen_contexts = [] + + for prompt, prev_generated, draft_tokens in zip( + prompts, prev_output_tokens, + draft_token_ids.transpose(0, 1).tolist()): + + for i in range(len(draft_tokens) + 1): + expected_seen_contexts.append(prompt + prev_generated + + draft_tokens[:i]) + + seen_contexts.sort() + expected_seen_contexts.sort() + assert expected_seen_contexts == seen_contexts + + +@pytest.mark.parametrize('k', [1, 2, 6]) +@pytest.mark.parametrize('batch_size', [1, 2, 32]) +@torch.inference_mode() +def test_correctly_calls_rejection_sampler(k: int, batch_size: int): + """Verify that the DraftTargetWorker calls the rejection sampler with + correct inputs. Everything else is mocked out. 
+ """ + vocab_size = 32_000 + + draft_worker = mock_worker(vocab_size) + target_worker = mock_worker(vocab_size) + rejection_sampler = MagicMock() + rejection_sampler.token_id_dtype = torch.int64 + draft_worker.device = 'cuda' + target_worker.device = 'cuda' + + set_random_seed(1) + + worker = DraftTargetWorker(draft_worker, target_worker, rejection_sampler) + + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(k, batch_size), + dtype=torch.int64, + device='cuda') + draft_token_probs = torch.rand(k, + batch_size, + vocab_size, + dtype=torch.float32, + device='cuda') + + draft_output = create_sampler_output_list(draft_token_ids, + draft_token_probs) + draft_worker.execute_model.return_value = draft_output + execute_model_data, _, _ = create_batch(batch_size, k) + + target_token_ids = torch.randint(low=0, + high=vocab_size, + size=(1, batch_size * (k + 1)), + dtype=torch.int64, + device='cuda') + target_token_probs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') + target_output = create_sampler_output_list(target_token_ids, + target_token_probs) + + target_worker.execute_model.return_value = target_output[0] + + exception_secret = 'artifical stop' + rejection_sampler.side_effect = ValueError(exception_secret) + + with pytest.raises(ValueError, match=exception_secret): + worker.execute_model(execute_model_data) + + assert len(rejection_sampler.call_args_list) == 1 + args, _ = rejection_sampler.call_args_list[0] + (actual_proposal_scores, actual_bonus_token_ids, actual_proposal_probs, + actual_proposal_token_ids) = args + + assert torch.equal(actual_bonus_token_ids, + target_token_ids.reshape(batch_size, k + 1)[:, -1:]) + assert torch.equal( + actual_proposal_scores, + target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1]) + assert torch.equal(actual_proposal_token_ids, + draft_token_ids.transpose(0, 1)) + assert torch.equal(actual_proposal_probs, + draft_token_probs.transpose(0, 1)) + + +@pytest.mark.parametrize('k', [1, 2, 6]) +@pytest.mark.parametrize('batch_size', [1, 2, 32]) +@torch.inference_mode() +def test_correctly_formats_output(k: int, batch_size: int): + """Verify that the DraftTargetWorker formats rejection sampler output + correctly. Everything else is mocked out. 
+ """ + vocab_size = 32_000 + + draft_worker = mock_worker(vocab_size) + target_worker = mock_worker(vocab_size) + rejection_sampler = MagicMock() + rejection_sampler.token_id_dtype = torch.int64 + draft_worker.device = 'cuda' + target_worker.device = 'cuda' + + set_random_seed(1) + + worker = DraftTargetWorker(draft_worker, target_worker, rejection_sampler) + + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(k, batch_size), + dtype=torch.int64, + device='cuda') + draft_token_probs = torch.rand(k, + batch_size, + vocab_size, + dtype=torch.float32, + device='cuda') + + draft_output = create_sampler_output_list(draft_token_ids, + draft_token_probs) + draft_worker.execute_model.return_value = draft_output + execute_model_data, _, _ = create_batch(batch_size, k) + + target_token_ids = torch.randint(low=0, + high=vocab_size, + size=(1, batch_size * (k + 1)), + dtype=torch.int64, + device='cuda') + target_token_probs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') + target_output = create_sampler_output_list(target_token_ids, + target_token_probs) + + target_worker.execute_model.return_value = target_output[0] + + rejection_sampler_output = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k + 1), + dtype=torch.int64, + device='cuda') + for i in range(batch_size): + rejection_sampler_output[i][-random.randint(0, k + 1):] = -1 + + rejection_sampler.return_value = rejection_sampler_output + + output = worker.execute_model(execute_model_data) + + expected_output = create_sampler_output_list( + rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)]) + + seq_ids = [ + next(iter(seq_group_metadata.seq_data.keys())) + for seq_group_metadata in execute_model_data.seq_group_metadata_list + ] + actual_output_by_seq = {seq_id: [] for seq_id in seq_ids} + expected_output_by_seq = {seq_id: [] for seq_id in seq_ids} + + for step in output: + for seq_group in step: + for sample in seq_group.samples: + seq_id = sample.parent_seq_id + actual_output_by_seq[seq_id].append(sample) + + for step in expected_output: + for seq_group in step: + for sample in seq_group.samples: + seq_id = sample.parent_seq_id + expected_output_by_seq[seq_id].append(sample) + + all_seen_seq_ids = set( + list(actual_output_by_seq.keys()) + + list(expected_output_by_seq.keys())) + for seq_id in all_seen_seq_ids: + actual_by_step = actual_output_by_seq[seq_id] + expected_by_step = expected_output_by_seq[seq_id] + + for i in range(k + 1): + if i >= len(actual_by_step): + assert expected_by_step[i].output_token == -1 + continue + assert actual_by_step[i].output_token == expected_by_step[ + i].output_token + assert actual_by_step[i].logprobs == expected_by_step[i].logprobs diff --git a/tests/worker/test_multi_step_worker.py b/tests/worker/test_multi_step_worker.py new file mode 100644 index 0000000000000..b271b7534948c --- /dev/null +++ b/tests/worker/test_multi_step_worker.py @@ -0,0 +1,300 @@ + +import torch +import random +import pytest +from unittest.mock import MagicMock + +from vllm.worker.multi_step_worker import MultiStepWorker +from vllm.worker.worker import Worker +from vllm.model_executor.utils import set_random_seed +from vllm.sequence import ExecuteModelData + +from .utils import (create_execute_model_data, create_worker, + create_seq_group_metadata_from_prompts, zero_kv_cache, + patch_execute_model_with_seeds, + assert_logprobs_dict_allclose) + + +@pytest.mark.parametrize('num_steps', list(range(1, 17))) +def 
test_assert_enough_kv_space(num_steps: int): + """Test that the multi step worker checks for sufficient space in the KV + cache. It should throw if it cannot run all the steps. + """ + block_size = 16 + num_gpu_blocks = 2048 // block_size + + prompts = [ + list(range(block_size * 3)), + list(range(block_size * 2)), + ] + + prev_output_tokens = [ + list(range(block_size * 1)), + list(range(block_size * 2)), + ] + + final_seq_lens = [ + len(prompt + output) + num_steps + for prompt, output in zip(prompts, prev_output_tokens) + ] + + inputs = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + final_seq_lens, + continuations=prev_output_tokens) + + assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access + worker = MagicMock() + worker.block_size = block_size + + for seq_group_metadata in inputs: + original_block_tables = seq_group_metadata.block_tables + + # No exception. + assert_enough_kv_space(worker, inputs, num_steps) + + seq_group_metadata.block_tables = { + seq_id: [] + for seq_id, physical_blocks in original_block_tables.items() + } + + # Expect exception. + with pytest.raises(ValueError, + match='times but found insufficient KV space for'): + assert_enough_kv_space(worker, inputs, num_steps) + + seq_group_metadata.block_tables = original_block_tables + + +@torch.inference_mode() +def test_same_output_for_single_step(): + """Verify the multi step worker produces the same output as the normal + worker for num_steps=1. + """ + seed = 100 + model_name = 'JackFram/llama-68m' + + block_size = 32 + num_gpu_blocks = 2048 // block_size + multi_step_worker = create_worker( + MultiStepWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + worker = create_worker( + Worker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + multi_step_worker.model = worker.model + multi_step_worker.cache_engine = worker.cache_engine + + num_steps = 1 + + prompts = [ + [1, 2, 3, 4, 5], + [6, 7, 8, 9, 10], + ] + + final_seq_lens = [len(prompt) + num_steps for prompt in prompts] + + multi_step_execute_model_data = create_execute_model_data( + seq_group_metadata_list=create_seq_group_metadata_from_prompts( + prompts, num_gpu_blocks, block_size, + final_seq_lens=final_seq_lens), + num_preallocated_slots=num_steps - 1) + + single_step_execute_model_data = create_execute_model_data( + seq_group_metadata_list=create_seq_group_metadata_from_prompts( + prompts, num_gpu_blocks, block_size, + final_seq_lens=final_seq_lens)) + + zero_kv_cache(multi_step_worker.cache_engine) + set_random_seed(seed) + actual_output = multi_step_worker.execute_model( + multi_step_execute_model_data) + assert len(actual_output) == num_steps + actual_output = actual_output[0] + + zero_kv_cache(worker.cache_engine) + set_random_seed(seed) + expected_output, = worker.execute_model(single_step_execute_model_data) + + actual_token_ids = [ + output.samples[0].output_token for output in actual_output.outputs + ] + actual_logprobs = [ + output.samples[0].logprobs for output in actual_output.outputs + ] + + expected_token_ids = [ + output.samples[0].output_token for output in expected_output.outputs + ] + expected_logprobs = [ + output.samples[0].logprobs for output in expected_output.outputs + ] + + assert actual_token_ids == expected_token_ids + + print(f'{actual_logprobs=}') + print(f'{expected_logprobs=}') + assert_logprobs_dict_allclose(actual_logprobs, expected_logprobs) + + +@torch.inference_mode() +def test_same_output_for_multi_step(): + """Verify 
the multi-step worker produces the same output as the normal + worker when num_steps > 1. This test runs the multi-step worker once, and + then runs the worker num_steps times, and compares the output. + """ + seed = 100 + model_name = 'JackFram/llama-68m' + + block_size = 16 + num_gpu_blocks = 2048 // block_size + multi_step_worker = create_worker( + MultiStepWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + + worker = create_worker( + Worker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + + # Make sure we go over the block boundary. + num_steps = block_size + 1 + + random.seed(seed) + prompts = [[ + random.randint(0, 1000) for _ in range(random.randint(10, 20)) + ] for _ in range(10)] + + final_seq_lens = [len(prompt) + num_steps for prompt in prompts] + + rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) + multi_step_worker.execute_model = patch_execute_model_with_seeds( + multi_step_worker, rand_seeds) + worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds) + + continuations = [[1] for _ in prompts] + execute_model_data = create_execute_model_data( + create_seq_group_metadata_from_prompts(prompts, + num_gpu_blocks, + block_size, + continuations=continuations, + final_seq_lens=final_seq_lens), + num_preallocated_slots=num_steps - 1) + + # Run multi-step. + zero_kv_cache(multi_step_worker.cache_engine) + set_random_seed(seed) + multi_step_output = multi_step_worker.execute_model(execute_model_data) + + # Run single-step repeatedly. + zero_kv_cache(worker.cache_engine) + single_step_output = [] + continuations = [[1] for _ in prompts] + set_random_seed(seed) + + for _ in multi_step_output: + + execute_model_data = create_execute_model_data( + create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=continuations, + final_seq_lens=final_seq_lens)) + + single_step_output.append(worker.execute_model(execute_model_data)[0]) + + # Append output tokens to new sequence data. + for i, seq_group_output in enumerate(single_step_output[-1]): + continuations[i].append(seq_group_output.samples[0].output_token) + + # Get token ids and logprobs for comparison. + multi_step_output_logprobs = [[] for _ in prompts] + single_step_output_logprobs = [[] for _ in prompts] + + multi_step_output_token_ids = [[] for _ in prompts] + single_step_output_token_ids = [[] for _ in prompts] + for i, _ in enumerate(prompts): + for multi_step, single_step in zip(multi_step_output, + single_step_output): + multi_step_output_token_ids[i].append( + multi_step[i].samples[0].output_token) + single_step_output_token_ids[i].append( + single_step[i].samples[0].output_token) + + multi_step_output_logprobs[i].append( + multi_step[i].samples[0].logprobs) + single_step_output_logprobs[i].append( + single_step[i].samples[0].logprobs) + + # Print per-sequence token ids + for i, (multi_step_tokens, single_step_tokens) in enumerate( + zip(multi_step_output_token_ids, single_step_output_token_ids)): + print(f'{multi_step_tokens=}') + print(f'{single_step_tokens=}') + print(f'equal {multi_step_tokens == single_step_tokens}') + + # Assert token ids are equal. + for multi_step_tokens, single_step_tokens in zip( + multi_step_output_token_ids, single_step_output_token_ids): + assert multi_step_tokens == single_step_tokens + + # Assert logprobs are equal. 
+ for multi_step_logprobs, single_step_logprobs in zip( + multi_step_output_logprobs, single_step_output_logprobs): + assert_logprobs_dict_allclose(multi_step_logprobs, + single_step_logprobs) + + +@torch.inference_mode() +def test_handles_empty_batch(): + """Verify an empty input batch (but with finished_request_ids populated) + is handled correctly. + """ + seed = 100 + model_name = 'JackFram/llama-68m' + + block_size = 16 + num_gpu_blocks = 2048 // block_size + multi_step_worker = create_worker( + MultiStepWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + + execute_model_data = ExecuteModelData( + seq_group_metadata_list=[], + finished_request_ids_list=['123'], + blocks_to_swap_in={}, + blocks_to_swap_out={}, + blocks_to_copy={}, + num_preallocated_slots=0, + ) + + multi_step_output = multi_step_worker.execute_model(execute_model_data) + + # A dummy output should be returned, with no actual contents. + assert len(multi_step_output) == 1 + sampler_output = multi_step_output[0] + assert not sampler_output.probs + assert not sampler_output.sampled_tokens + assert not sampler_output.outputs diff --git a/tests/worker/test_worker.py b/tests/worker/test_worker.py new file mode 100644 index 0000000000000..4a5667c254f0c --- /dev/null +++ b/tests/worker/test_worker.py @@ -0,0 +1,151 @@ + +# pylint: disable=protected-access +import math +import random + +import pytest +import torch +from types import SimpleNamespace +from unittest.mock import MagicMock + +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.worker.worker import Worker + +from .utils import (create_execute_model_data, create_worker, + create_seq_group_metadata_from_prompts) + + +@torch.inference_mode() +def test_prepare_inputs_can_save_multiple_tokens_per_sequence(): + """Verify prepare_inputs correctly encodes input such that + the model forward pass will save >1 token from the previous + iteration in the KV cache. + + This mocks out the actual model call. 
+ """ + seed = 100 + block_size = 32 + num_gpu_blocks = 2048 // block_size + worker = create_worker(Worker, + model_name='JackFram/llama-68m', + seed=seed, + block_size=block_size, + num_gpu_blocks=num_gpu_blocks) + + prompts = [list(range(4)), list(range(10))] + prev_output_tokens = [list(range(2)), list(range(5))] + num_tokens_processed = [len(prompt) + 1 for prompt in prompts] + final_seq_lens = [ + len(prompt + output_tokens) + 1 + for prompt, output_tokens in zip(prompts, prev_output_tokens) + ] + num_missing_from_kv_cache = [ + len(prompt) + len(output_tokens) - num_processed + for prompt, output_tokens, num_processed in zip( + prompts, prev_output_tokens, num_tokens_processed) + ] + + print(f'{prompts=}') + print(f'{prev_output_tokens=}') + print(f'{num_missing_from_kv_cache=}') + + execute_model_data = create_execute_model_data( + create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=prev_output_tokens, + final_seq_lens=final_seq_lens, + num_tokens_processed=num_tokens_processed)) + + worker.captured_model = MagicMock() + worker.execute_model(execute_model_data) + + print(f'{worker.captured_model.execute_if_capturable.call_count=}') + + call_args_list = worker.captured_model.execute_if_capturable.call_args_list + assert len(call_args_list) == 1 + kwargs = SimpleNamespace(**call_args_list[0].kwargs) + + num_new_tokens = sum(num_missing_from_kv_cache) + padded_num_new_tokens = math.ceil(num_new_tokens / 8) * 8 + # Expect the number of tokens being saved to KV cache to equal the total + # number of tokens missing from KV cache. + assert kwargs.input_metadata.num_valid_tokens == padded_num_new_tokens + assert kwargs.input_metadata.num_prompt_tokens == 0 + assert kwargs.input_metadata.slot_mapping.shape[0] == padded_num_new_tokens + assert kwargs.input_metadata.num_generation_tokens == num_new_tokens + + expected_positions = [] + expected_input_ids = [] + for prompt_token_ids, output_tokens, num_tok_missing_from_kv_cache in zip( + prompts, prev_output_tokens, num_missing_from_kv_cache): + seq = prompt_token_ids + output_tokens + total_seq_len = len(seq) + for i in range(num_tok_missing_from_kv_cache): + position = total_seq_len - num_tok_missing_from_kv_cache + i + expected_positions.append(position) + expected_input_ids.append(seq[position]) + + print(f'{expected_positions=}') + print(f'{expected_input_ids=}') + + print(f'{kwargs.input_ids=}') + print(f'{kwargs.positions=}') + + # Assert input ids and positions (sans padding) equal to expected. 
+ assert kwargs.input_ids[:sum(num_missing_from_kv_cache)].tolist( + ) == expected_input_ids + assert kwargs.positions[:sum(num_missing_from_kv_cache)].tolist( + ) == expected_positions + + +# @pytest.mark.skip("Skip for now") +def test_worker_prepare_inputs_for_prompt(): + seed = 100 + block_size = 16 + num_gpu_blocks = 2048 // block_size + worker = create_worker(Worker, + model_name='JackFram/llama-68m', + seed=seed, + block_size=block_size, + num_gpu_blocks=num_gpu_blocks) + for batch_size in range(256): + prompt_lens = [] + seq_group_metadata_list = [] + for i in range(batch_size): + # make sure all tokens fit into one block + prompt_len = i % (worker.block_size - 1) + 1 + prompt_lens.append(prompt_len) + seq_data = list(range(prompt_len)) + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={ + 0: + SequenceData(seq_data, + prefill_start=0, + prefill_end=prompt_len) + }, + sampling_params=SamplingParams(temperature=0), + block_tables={0: [1]}, + is_chunked_prefill=False, + lora_request=None, + )) + expected_selected_token_indices = [] + selected_token_start_idx = 0 + for prompt_len in prompt_lens: + expected_selected_token_indices.append(selected_token_start_idx + + prompt_len - 1) + selected_token_start_idx += prompt_len + input_tokens, input_positions, input_metadata, _, _ = worker._prepare_inputs( + seq_group_metadata_list) + assert input_tokens.shape == input_positions.shape == ( + math.ceil(sum(prompt_lens) / 8) * 8, ) + torch.testing.assert_close(input_tokens, input_positions) + actual = input_metadata.selected_token_indices + expected = torch.tensor(expected_selected_token_indices, + device=actual.device, + dtype=actual.dtype) + torch.testing.assert_close(actual, expected) diff --git a/tests/worker/utils.py b/tests/worker/utils.py new file mode 100644 index 0000000000000..bc8ccc96d3e2a --- /dev/null +++ b/tests/worker/utils.py @@ -0,0 +1,260 @@ + +import torch +from typing import List, Optional, Dict, Iterable +from unittest.mock import MagicMock +from itertools import count + +from vllm.worker.worker import Worker +from vllm.worker.base_worker import BaseWorker +from vllm.engine.ray_utils import initialize_cluster +from vllm.engine.arg_utils import EngineArgs +from vllm.sequence import ExecuteModelData, SequenceGroupMetadata, SequenceData, SamplerOutput, SequenceGroupOutputs, SequenceOutputs +from vllm.sampling_params import SamplingParams +from vllm.worker.cache_engine import CacheEngine +from vllm.model_executor.utils import set_random_seed +from tests.utils import round_up_to_next_block + + +def create_execute_model_data( + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_request_ids_list: Optional[List[str]] = None, + blocks_to_swap_in: Optional[Dict[int, int]] = None, + blocks_to_swap_out: Optional[Dict[int, int]] = None, + blocks_to_copy: Optional[Dict[int, int]] = None, + num_preallocated_slots: Optional[int] = None, +) -> ExecuteModelData: + + if finished_request_ids_list is None: + finished_request_ids_list = [] + if blocks_to_swap_in is None: + blocks_to_swap_in = {} + if blocks_to_swap_out is None: + blocks_to_swap_out = {} + if blocks_to_copy is None: + blocks_to_copy = {} + if num_preallocated_slots is None: + num_preallocated_slots = 0 + + return ExecuteModelData( + seq_group_metadata_list=seq_group_metadata_list, + finished_request_ids_list=finished_request_ids_list, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + 
num_preallocated_slots=num_preallocated_slots, + ) + + +def mock_worker(vocab_size: int = 30_000, + max_model_len: int = 2048) -> MagicMock: + worker = MagicMock() + worker.model.config.vocab_size = vocab_size + worker.model_config.max_model_len = max_model_len + return worker + + +def patch_execute_model_with_seeds(worker: BaseWorker, rand_seeds: List[int]): + seed_iter = iter(rand_seeds) + original_execute_model = worker.execute_model + + def new_execute_model(execute_model_data): + result = original_execute_model(execute_model_data) + set_random_seed(next(seed_iter)) + return result + + return new_execute_model + + +def zero_kv_cache(cache_engine: CacheEngine): + assert cache_engine.gpu_cache + for key_blocks, value_blocks in cache_engine.gpu_cache: + key_blocks.zero_() + value_blocks.zero_() + + +def create_worker(cls: type, model_name: str, block_size: int, + num_gpu_blocks: int, seed: int): + engine_args = EngineArgs( + model=model_name, + seed=seed, + block_size=block_size, + ) + + (model_config, cache_config, parallel_config, scheduler_config, _, _, + _) = engine_args.create_engine_configs() + + distributed_init_method, _ = initialize_cluster(parallel_config) + + worker = cls( + model_config=model_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + rank=0, + distributed_init_method=distributed_init_method, + ) + + worker.init_model() + + cache_config.num_gpu_blocks = num_gpu_blocks + cache_config.num_cpu_blocks = 0 + worker.init_cache_engine(cache_config) + + return worker + + +def create_seq_group_metadata_from_prompts( + prompts: List[List[int]], + num_gpu_blocks: int, + block_size: int, + final_seq_lens: List[int], + continuations: Optional[List[List[int]]] = None, + num_tokens_processed: Optional[List[int]] = None, + seq_ids: Optional[List[int]] = None, +) -> List[SequenceGroupMetadata]: + + if continuations is None: + continuations = [[] for _ in prompts] + + if num_tokens_processed is None: + # Default to 1 token missing from kv cache for generation sequences. + num_tokens_processed = [] + for continuation, prompt in zip(continuations, prompts): + # If prefill, then default to zero tokens processed. + if not continuation: + num_tokens_processed.append(0) + else: + # If generation, then default to all but one tokens processed. 
+ num_tokens_processed.append( + len(continuation) + len(prompt) - 1) + + if seq_ids is None: + seq_ids = list(i for i, _ in enumerate(prompts)) + + free_gpu_blocks = list(range(num_gpu_blocks)) + + block_allocations = { + i: [ + free_gpu_blocks.pop() + for _ in range(round_up_to_next_block(final_len, block_size)) + ] + for i, final_len in enumerate(final_seq_lens) + } + + return [ + SequenceGroupMetadata( + request_id=str(i), + is_prompt=len(cont_token_ids) == 0, + is_chunked_prefill=False, + seq_data={ + i: + SequenceData(token_ids=prompt_token_ids[:] + cont_token_ids[:], + num_prompt_tokens=len(prompt_token_ids[:]), + num_processed_token_ids=num_tokens_saved, + prefill_start=0, + prefill_end=len(prompt_token_ids)), + }, + sampling_params=SamplingParams(temperature=0.0, ), + block_tables={i: block_allocations[i][:]}, + lora_request=None, + ) for i, (prompt_token_ids, cont_token_ids, num_tokens_saved) in + enumerate(zip(prompts, continuations, num_tokens_processed)) + ] + + +def create_workers(test_type: type, + reference_type: type = Worker, + seed: int = 100, + block_size: int = 32, + num_gpu_blocks: int = 2048 // 32, + model_name: str = 'JackFram/llama-68m'): + test_worker = create_worker( + test_type, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + reference_worker = create_worker( + reference_type, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + + return test_worker, reference_worker + + +def get_output_tokens(outputs): + return [output[0].output_token for output in outputs.outputs] + + +def get_output_logprobs(outputs): + return [output[0].logprobs for output in outputs.outputs] + + +def assert_logprobs_dict_allclose( + actual_logprobs: List[Dict[int, float]], + expected_logprobs: List[Dict[int, float]]) -> None: + for single_step_actual_logprobs, single_step_expected_logprobs in zip( + actual_logprobs, expected_logprobs): + assert set(single_step_actual_logprobs.keys()) == set( + single_step_expected_logprobs.keys()) + for token_id in single_step_actual_logprobs.keys(): + actual = torch.tensor(single_step_actual_logprobs[token_id]) + expected = torch.tensor(single_step_expected_logprobs[token_id]) + assert torch.allclose(actual, expected) + + +def create_sampler_output_list( + token_ids: torch.Tensor, + probs: Iterable[Optional[torch.Tensor]], + seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]: + num_steps, batch_size = token_ids.shape + token_ids_by_step = token_ids.tolist() + + if seq_ids is None: + seq_ids = list(range(batch_size)) + + return [ + SamplerOutput(outputs=[ + SequenceGroupOutputs( + samples=[ + SequenceOutputs( + output_token=token_id, + parent_seq_id=seq_ids[seq_index], + logprobs={token_id: 0}, + ) + ], + prompt_logprobs=None, + ) for seq_index, token_id in enumerate(token_ids_by_step[step]) + ], + probs=probs[step], + sampled_tokens=token_ids[step]) + for step in range(num_steps) + ] + + +def create_batch(batch_size, + k, + prompt_len: int = 10, + prev_output_token_len: int = 10, + seq_ids: Optional[List[int]] = None): + block_size = 8 + num_gpu_blocks = 2048 // block_size + iterator = count() + prompts = [[next(iterator) for _ in range(prompt_len)] + for _ in range(batch_size)] + prev_output_tokens = [[ + next(iterator) for _ in range(prev_output_token_len) + ] for _ in range(batch_size)] + final_seq_lens = [ + len(prompt) + len(prev_output_token) + k + 1 + for prompt, prev_output_token in zip(prompts, prev_output_tokens) + ] + execute_model_data = create_execute_model_data( + create_seq_group_metadata_from_prompts(prompts, 
num_gpu_blocks, + block_size, final_seq_lens, + prev_output_tokens, seq_ids), + num_preallocated_slots=k) + return execute_model_data, prompts, prev_output_tokens diff --git a/vllm/config.py b/vllm/config.py index ff9a1308a5c88..95349488f5cd6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,4 +1,6 @@ from typing import Optional, Union +from dataclasses import dataclass +import copy import os import torch @@ -6,7 +8,7 @@ from vllm.logger import init_logger from vllm.transformers_utils.config import get_config -from vllm.utils import get_cpu_memory, is_hip +from vllm.utils import get_cpu_memory logger = init_logger(__name__) @@ -49,31 +51,32 @@ class ModelConfig: output). If None, will be derived from the model. quantization: Quantization method that was used to quantize the model weights. If None, we assume the model weights are not quantized. - enforce_eager: Whether to enforce eager execution. If True, we will - disable CUDA graph and always execute the model in eager mode. - If False, we will use CUDA graph and eager execution in hybrid. - max_context_len_to_capture: Maximum context len covered by CUDA graphs. - When a sequence has context length larger than this, we fall back - to eager mode. + flash_style: Enable flash style page attention. This is only supported + by llama models. + max_chunked_prefill_len: The maximum length of tokens for prefill + requests. Longer requests will be chunked into multiple chunks. + -1 means no chunking (disabled). This features is only supported + for flash style attention. """ - def __init__( - self, - model: str, - tokenizer: str, - tokenizer_mode: str, - trust_remote_code: bool, - download_dir: Optional[str], - load_format: str, - dtype: Union[str, torch.dtype], - seed: int, - revision: Optional[str] = None, - tokenizer_revision: Optional[str] = None, - max_model_len: Optional[int] = None, - quantization: Optional[str] = None, - enforce_eager: bool = False, - max_context_len_to_capture: Optional[int] = None, - ) -> None: + def __init__(self, + model: str, + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + download_dir: Optional[str], + load_format: str, + dtype: Union[str, torch.dtype], + seed: int, + revision: Optional[str], + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + enable_cuda_graph: bool = False, + cuda_graph_max_context_len: int = 5000, + cuda_graph_cache_size: int = 10, + flash_style: bool = False, + max_chunked_prefill_len: int = -1, self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -84,8 +87,6 @@ def __init__( self.revision = revision self.tokenizer_revision = tokenizer_revision self.quantization = quantization - self.enforce_eager = enforce_eager - self.max_context_len_to_capture = max_context_len_to_capture if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": # download model from ModelScope hub, @@ -105,34 +106,24 @@ def __init__( self._verify_load_format() self._verify_tokenizer_mode() self._verify_quantization() - self._verify_cuda_graph() + + self.enable_cuda_graph = enable_cuda_graph + self.cuda_graph_max_context_len = cuda_graph_max_context_len + self.cuda_graph_cache_size = cuda_graph_cache_size + self.flash_style = flash_style + self.max_chunked_prefill_len = max_chunked_prefill_len + + self._verify_chunk_prefill() def _verify_load_format(self) -> None: load_format = self.load_format.lower() - supported_load_format = [ - "auto", "pt", "safetensors", "npcache", "dummy" - ] - 
rocm_not_supported_load_format = [] - if load_format not in supported_load_format: + if load_format not in [ + "auto", "pt", "safetensors", "npcache", "dummy" + ]: raise ValueError( f"Unknown load format: {self.load_format}. Must be one of " "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.") - if is_hip() and load_format in rocm_not_supported_load_format: - rocm_supported_load_format = [ - f for f in supported_load_format - if (f not in rocm_not_supported_load_format) - ] - raise ValueError( - f"load format \'{load_format}\' is not supported in ROCm. " - f"Supported load format are " - f"{rocm_supported_load_format}") - # TODO: Remove this check once HF updates the pt weights of Mixtral. - architectures = getattr(self.hf_config, "architectures", []) - if "MixtralForCausalLM" in architectures and load_format == "pt": - raise ValueError( - "Currently, the 'pt' format is not supported for Mixtral. " - "Please use the 'safetensors' format instead. ") self.load_format = load_format def _verify_tokenizer_mode(self) -> None: @@ -141,11 +132,11 @@ def _verify_tokenizer_mode(self) -> None: raise ValueError( f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " "either 'auto' or 'slow'.") + self.tokenizer_mode = tokenizer_mode def _verify_quantization(self) -> None: - supported_quantization = ["awq", "gptq", "squeezellm"] - rocm_not_supported_quantization = ["awq"] + supported_quantization = ["awq", "squeezellm"] if self.quantization is not None: self.quantization = self.quantization.lower() @@ -167,27 +158,10 @@ def _verify_quantization(self) -> None: raise ValueError( f"Unknown quantization method: {self.quantization}. Must " f"be one of {supported_quantization}.") - if is_hip( - ) and self.quantization in rocm_not_supported_quantization: - raise ValueError( - f"{self.quantization} quantization is currently not supported " - f"in ROCm.") logger.warning(f"{self.quantization} quantization is not fully " "optimized yet. The speed can be slower than " "non-quantized models.") - def _verify_cuda_graph(self) -> None: - if self.max_context_len_to_capture is None: - self.max_context_len_to_capture = self.max_model_len - self.max_context_len_to_capture = min(self.max_context_len_to_capture, - self.max_model_len) - if (self.quantization in ["gptq", "squeezellm"] - and not self.enforce_eager): - # Related issue: https://github.com/vllm-project/vllm/issues/2147 - logger.warning(f"{self.quantization} does not support CUDA graph " - "yet. 
Disabling CUDA graph.") - self.enforce_eager = True - def verify_with_parallel_config( self, parallel_config: "ParallelConfig", @@ -208,11 +182,12 @@ def verify_with_parallel_config( "must be divisible by pipeline parallel size " f"({pipeline_parallel_size}).") - def get_sliding_window(self) -> Optional[int]: - return getattr(self.hf_config, "sliding_window", None) - - def get_vocab_size(self) -> int: - return self.hf_config.vocab_size + def _verify_chunk_prefill(self) -> None: + if self.max_chunked_prefill_len == 0: + raise ValueError("max_chunked_prefill_len can't be 0") + if self.max_chunked_prefill_len > 0 and not self.flash_style: + raise ValueError( + "chunked prefill is only supported for flash style") def get_hidden_size(self) -> int: return self.hf_config.hidden_size @@ -286,11 +261,13 @@ def __init__( gpu_memory_utilization: float, swap_space: int, sliding_window: Optional[int] = None, + flash_style: bool = False, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization self.swap_space_bytes = swap_space * _GB self.sliding_window = sliding_window + self.flash_style = flash_style self._verify_args() # Will be set after profiling. @@ -303,6 +280,18 @@ def _verify_args(self) -> None: "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") + if self.flash_style and self.block_size < 32: + raise ValueError( + "Flash style attention only supports block size >= 32. Got" + f"{self.block_size}.") + if not self.flash_style and self.block_size > 32: + raise ValueError( + "vLLM Page attention only supports block size <= 32. Got" + f"{self.block_size}.") + + if self.flash_style: + logger.info("Flash attention enabled.") + def verify_with_parallel_config( self, parallel_config: "ParallelConfig", @@ -331,30 +320,44 @@ class ParallelConfig: worker_use_ray: Whether to use Ray for model workers. Will be set to True if either pipeline_parallel_size or tensor_parallel_size is greater than 1. + disable_shared_memory: Whether to not use shared memory for + engine<->worker communication. Not used if Ray isn't used. + ray_workers_use_nsight: Whether to profile Ray workers with nvidia + nsight (See https://github.com/ray-project/ray/pull/39998 ). 
""" - def __init__( - self, - pipeline_parallel_size: int, - tensor_parallel_size: int, - worker_use_ray: bool, - max_parallel_loading_workers: Optional[int] = None, - ) -> None: + def __init__(self, + pipeline_parallel_size: int, + tensor_parallel_size: int, + worker_use_ray: bool, + disable_shared_memory: bool = False, + num_tokenizer_actors: int = 0, + tokenizer_actor_options: Optional[dict] = None, + ray_workers_use_nsight: bool = False) -> None: self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size self.worker_use_ray = worker_use_ray - self.max_parallel_loading_workers = max_parallel_loading_workers + self.num_tokenizer_actors = num_tokenizer_actors + self.tokenizer_actor_options = tokenizer_actor_options + self.ray_workers_use_nsight = ray_workers_use_nsight self.world_size = pipeline_parallel_size * tensor_parallel_size if self.world_size > 1: self.worker_use_ray = True self._verify_args() + self.disable_shared_memory = (not self.worker_use_ray + or disable_shared_memory) + def _verify_args(self) -> None: if self.pipeline_parallel_size > 1: raise NotImplementedError( "Pipeline parallelism is not supported yet.") + if self.ray_workers_use_nsight and not self.worker_use_ray: + raise ValueError("Unable to use nsight profiling unless workers " + "run with Ray.") + class SchedulerConfig: """Scheduler configuration. @@ -366,7 +369,22 @@ class SchedulerConfig: iteration. max_model_len: Maximum length of a sequence (including prompt and generated text). - max_paddings: Maximum number of paddings to be added to a batch. + num_preallocated_slots_per_step: The number of slots the scheduler will + allocate per model step, in addition to the slots allocated for + every logical token. Defaults to 0. + use_deltas: Whether scheduler output is emitted as a "delta" or update. + Deltas are smaller and incur less overhead over IPC. + max_chunked_prefill_len: The maximum length of tokens for prefill + requests. Longer requests will be chunked into multiple chunks. + -1 means no chunking (disabled). This features is only supported + for flash style attention. + max_num_prompt_seqs: The maximum number of prompt sequences to be + processed in a single iteration. + flash_style: Whether to use flash style attention. Only support + LLaMA models. + input_padding_size: The padding size for input tokens. This is used + to better support CUDAGRAPH and ultize TENSOR CORES. Has to be + a multiple of 8. """ def __init__( @@ -374,7 +392,12 @@ def __init__( max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, - max_paddings: int, + num_preallocated_slots_per_step: int = 0, + use_deltas: bool = False, + max_chunked_prefill_len: int = -1, + max_num_prompt_seqs: int = 1024, + flash_style: bool = False, + input_padding_size: int = 8, ) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens @@ -382,13 +405,23 @@ def __init__( # If max_model_len is too short, use 2048 as the default value for # higher throughput. 
             self.max_num_batched_tokens = max(max_model_len, 2048)
+        self.max_num_seqs = max_num_seqs
+        self.max_num_decoding_tokens = max_num_seqs
         self.max_model_len = max_model_len
-        self.max_paddings = max_paddings
+        self.num_preallocated_slots_per_step = num_preallocated_slots_per_step
+        self.use_deltas = use_deltas
+        # We pad the prompt and generation tokens with padding size 8
+        # to better support CUDA graphs and utilize Tensor Cores.
+        self.input_padding_size = input_padding_size
+        self.max_chunked_prefill_len = max_chunked_prefill_len
+        self.max_num_prompt_seqs = max_num_prompt_seqs
+        self.flash_style = flash_style
         self._verify_args()

     def _verify_args(self) -> None:
-        if self.max_num_batched_tokens < self.max_model_len:
+        if self.max_num_batched_tokens < self.max_model_len and \
+                self.max_chunked_prefill_len == -1:
             raise ValueError(
                 f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                 f"smaller than max_model_len ({self.max_model_len}). "
@@ -402,6 +435,223 @@ def _verify_args(self) -> None:
                 "be greater than or equal to max_num_seqs "
                 f"({self.max_num_seqs}).")

+        if self.max_num_batched_tokens < self.max_model_len:
+            logger.warning(
+                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
+                f"smaller than max_model_len ({self.max_model_len}). "
+                "This means that the user will not be able to use the full "
+                "model context length.")
+
+        if self.num_preallocated_slots_per_step < 0:
+            raise ValueError(
+                f"num_preallocated_slots_per_step "
+                f"({self.num_preallocated_slots_per_step}) must be greater "
+                "than or equal to 0.")
+
+        if self.max_chunked_prefill_len >= 0 and not self.flash_style:
+            raise ValueError(
+                "chunked prefill is only supported for flash style")
+
+        if self.input_padding_size % 8 != 0 or self.input_padding_size == 0:
+            raise ValueError(
+                f"input_padding_size ({self.input_padding_size}) must be a "
+                "non-zero multiple of 8.")
+
+
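# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the patch. It shows the
# arithmetic that input_padding_size is meant to drive: the scheduler rounds
# its per-step token budget *down* to a multiple of the padding size (compare
# Scheduler._round_down_by_padding used in Scheduler._schedule later in this
# diff), while batched inputs are padded *up* to the same multiple. The helper
# names here are the sketch's own.

def _round_down_to_padding(n: int, padding: int = 8) -> int:
    return (n // padding) * padding

def _round_up_to_padding(n: int, padding: int = 8) -> int:
    return ((n + padding - 1) // padding) * padding

assert _round_down_to_padding(4100, 8) == 4096   # budget is never over-counted
assert _round_down_to_padding(4096, 8) == 4096   # already aligned
assert _round_up_to_padding(5, 8) == 8           # 5 real tokens -> 8 padded
assert _round_up_to_padding(16, 8) == 16
# ------------------------------- end of sketch ------------------------------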
+class SpeculativeConfig:
+    """Configuration for speculative decoding.
+
+    Args:
+        draft_model_config: ModelConfig for the draft model.
+        draft_parallel_config: ParallelConfig for the draft model.
+        num_speculative_tokens: The number of tokens to sample from the draft
+            model before scoring with the target model.
+    """
+
+    @staticmethod
+    def maybe_create_spec_config(
+        target_model_config: ModelConfig,
+        target_parallel_config: ParallelConfig,
+        dtype: str,
+        speculative_model: Optional[str],
+        num_speculative_tokens: Optional[int],
+        speculative_model_uses_tp_1: bool,
+        target_model_input_padding_size: Optional[int] = None,
+        draft_model_input_padding_size: Optional[int] = None,
+    ) -> Optional["SpeculativeConfig"]:
+        """Create a SpeculativeConfig if all the required fields are not None.
+        """
+
+        if (speculative_model is None and num_speculative_tokens is None
+                and not speculative_model_uses_tp_1):
+            return None
+
+        if (speculative_model is None and num_speculative_tokens
+                is not None) or (speculative_model is not None
+                                 and num_speculative_tokens is None):
+            raise ValueError(
+                "Expected both speculative_model and "
+                "num_speculative_tokens to be provided, but found "
+                f"{speculative_model=} and {num_speculative_tokens=}.")
+
+        # TODO: these should be provided as a top-level draft model config.
+        revision = None
+        quantization = None
+        max_model_len = None
+        draft_model_config = ModelConfig(
+            speculative_model, target_model_config.tokenizer,
+            target_model_config.tokenizer_mode,
+            target_model_config.trust_remote_code,
+            target_model_config.download_dir, target_model_config.load_format,
+            dtype, target_model_config.seed, revision,
+            target_model_config.tokenizer_revision, max_model_len,
+            quantization, target_model_config.enable_cuda_graph,
+            target_model_config.cuda_graph_max_context_len,
+            target_model_config.cuda_graph_cache_size)
+
+        draft_parallel_config = SpeculativeConfig.create_draft_parallel_config(
+            target_parallel_config, speculative_model_uses_tp_1)
+
+        return SpeculativeConfig(
+            draft_model_config,
+            draft_parallel_config,
+            num_speculative_tokens,
+            target_model_input_padding_size,
+            draft_model_input_padding_size,
+        )
+
+    @staticmethod
+    def create_draft_parallel_config(
+            target_parallel_config: ParallelConfig,
+            speculative_model_uses_tp_1: bool) -> ParallelConfig:
+        """Create a parallel config for use by the draft worker.
+        """
+        tp_size = target_parallel_config.tensor_parallel_size
+
+        if speculative_model_uses_tp_1:
+            tp_size = 1
+
+        draft_parallel_config = ParallelConfig(
+            pipeline_parallel_size=target_parallel_config.
+            pipeline_parallel_size,
+            tensor_parallel_size=tp_size,
+            worker_use_ray=target_parallel_config.worker_use_ray,
+            disable_shared_memory=target_parallel_config.disable_shared_memory,
+            num_tokenizer_actors=target_parallel_config.num_tokenizer_actors,
+            tokenizer_actor_options=target_parallel_config.
+            tokenizer_actor_options,
+            ray_workers_use_nsight=target_parallel_config.
+            ray_workers_use_nsight,
+        )
+
+        return draft_parallel_config
+
+    def __init__(
+        self,
+        draft_model_config: ModelConfig,
+        draft_parallel_config: ParallelConfig,
+        num_speculative_tokens: int,
+        target_model_input_padding_size: Optional[int],
+        draft_model_input_padding_size: Optional[int],
+    ):
+        self.draft_model_config = draft_model_config
+        self.draft_parallel_config = draft_parallel_config
+        self.num_speculative_tokens = num_speculative_tokens
+        self.target_model_input_padding_size = target_model_input_padding_size
+        self.draft_model_input_padding_size = draft_model_input_padding_size
+
+        self._verify_args()
+
+    def _verify_args(self) -> None:
+        if self.num_speculative_tokens < 0:
+            raise ValueError("Expected num_speculative_tokens to be greater "
+                             "than or equal to zero "
+                             f"({self.num_speculative_tokens}).")
+
+        self.draft_model_config.verify_with_parallel_config(
+            self.draft_parallel_config)
+
+    def create_target_scheduler_config(
+        self,
+        scheduler_config: SchedulerConfig,
+    ) -> SchedulerConfig:
+        """Create a SchedulerConfig for the target model.
+        """
+        config = copy.deepcopy(scheduler_config)
+        # In the worst case, the target model processes
+        # batch_size * (num_speculative_tokens + 1) * 2 tokens,
+        # so we increase max_num_decoding_tokens accordingly.
+        config.max_num_decoding_tokens = config.max_num_decoding_tokens * 2 * (
+            self.num_speculative_tokens + 1)
+        if self.target_model_input_padding_size is not None:
+            config.input_padding_size = self.target_model_input_padding_size
+        return config
+
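# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the patch. Per the comments
# in the two scheduler-config methods above and below, the worst-case number
# of decoding tokens per step is batch_size * (num_speculative_tokens + 1) * 2
# for the target model and batch_size * (num_speculative_tokens + 1) for the
# draft model. A small worked example of those bounds:

def _draft_worst_case_tokens(batch_size: int, num_spec_tokens: int) -> int:
    return batch_size * (num_spec_tokens + 1)

def _target_worst_case_tokens(batch_size: int, num_spec_tokens: int) -> int:
    return batch_size * (num_spec_tokens + 1) * 2

assert _draft_worst_case_tokens(batch_size=8, num_spec_tokens=4) == 40
assert _target_worst_case_tokens(batch_size=8, num_spec_tokens=4) == 80
# ------------------------------- end of sketch ------------------------------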
+    def create_draft_scheduler_config(
+        self,
+        scheduler_config: SchedulerConfig,
+    ) -> SchedulerConfig:
+        """Create a SchedulerConfig for the draft model.
+        """
+        config = copy.deepcopy(scheduler_config)
+        # In the worst case, the draft model processes
+        # batch_size * (num_speculative_tokens + 1) tokens,
+        # so we increase max_num_decoding_tokens accordingly.
+        config.max_num_decoding_tokens = config.max_num_decoding_tokens * (
+            self.num_speculative_tokens + 1)
+        if self.draft_model_input_padding_size is not None:
+            config.input_padding_size = self.draft_model_input_padding_size
+        return config
+
+    @property
+    def num_preallocated_slots_per_step(self) -> int:
+        """The number of slots the scheduler should allocate per step, in
+        addition to the slots allocated for each logical token.
+
+        This is equal to the number of speculative tokens, as each speculative
+        token must be scored.
+        """
+        return self.num_speculative_tokens
+
+
+@dataclass
+class LoRAConfig:
+    max_lora_rank: int
+    max_loras: int
+    max_cpu_loras: Optional[int] = None
+    lora_dtype: Optional[torch.dtype] = None
+    lora_extra_vocab_size: int = 256
+
+    def __post_init__(self):
+        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
+        possible_max_ranks = (8, 16, 32, 64)
+        if self.max_lora_rank not in possible_max_ranks:
+            raise ValueError(
+                f"max_lora_rank ({self.max_lora_rank}) must be one of "
+                f"{possible_max_ranks}.")
+        if self.max_loras < 1:
+            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
+        if self.max_cpu_loras is None:
+            self.max_cpu_loras = self.max_loras
+        elif self.max_cpu_loras < self.max_loras:
+            raise ValueError(
+                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
+                f"max_loras ({self.max_loras})")
+
+    def verify_with_model_config(self, model_config: ModelConfig):
+        if self.lora_dtype in (None, "auto"):
+            self.lora_dtype = model_config.dtype
+        elif isinstance(self.lora_dtype, str):
+            self.lora_dtype = getattr(torch, self.lora_dtype)
+        if model_config.max_chunked_prefill_len > 0:
+            raise ValueError("chunked prefill is not supported for lora")
+
+    def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
+        if scheduler_config.max_num_batched_tokens > 65528:
+            raise ValueError(
+                "Due to limitations of the custom LoRA CUDA kernel, "
+                "max_num_batched_tokens must be <= 65528 when "
+                "LoRA is enabled.")
+

 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,
@@ -411,8 +661,6 @@ def _verify_args(self) -> None:
     "bfloat16": torch.bfloat16,
 }

-_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
-

 def _get_and_verify_dtype(
     config: PretrainedConfig,
@@ -442,14 +690,6 @@ def _get_and_verify_dtype(
     else:
         raise ValueError(f"Unknown dtype: {dtype}")

-    if is_hip() and torch_dtype == torch.float32:
-        rocm_supported_dtypes = [
-            k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
-            if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
-        ]
-        raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
-                         f"Supported dtypes are {rocm_supported_dtypes}")
-
     # Verify the dtype.
     if torch_dtype != config_dtype:
         if torch_dtype == torch.float32:
@@ -521,3 +761,13 @@ def _get_and_verify_max_len(
         "outputs or CUDA errors. 
Make sure the value is correct and " "within the model context size.") return int(max_model_len) + + +@dataclass +class LoadConfig: + s3_bucket: str + s3_prefix: str + region: str + target_throughput_gbps: float = 100.0 + part_size = 4 * 1024 * 1024 + upload_if_not_exist: bool = True diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 8b26319b88cd3..cbdc0e14c974a 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -1,14 +1,12 @@ """A block manager that manages token blocks.""" +from typing import Dict, List, Optional, Set +from collections import defaultdict import enum -from typing import Dict, List, Optional, Set, Tuple from vllm.block import PhysicalTokenBlock from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device -# Mapping: logical block number -> physical block. -BlockTable = List[PhysicalTokenBlock] - class BlockAllocator: """Manages free physical token blocks for a device. @@ -29,7 +27,7 @@ def __init__( self.num_blocks = num_blocks # Initialize the free blocks. - self.free_blocks: BlockTable = [] + self.free_blocks: List[PhysicalTokenBlock] = [] for i in range(num_blocks): block = PhysicalTokenBlock(device=device, block_number=i, @@ -54,6 +52,10 @@ def get_num_free_blocks(self) -> int: return len(self.free_blocks) +# Mapping: logical block number -> physical block. +BlockTable = List[PhysicalTokenBlock] + + class AllocStatus(enum.Enum): """Result for BlockSpaceManager.can_allocate @@ -140,44 +142,115 @@ def allocate(self, seq_group: SequenceGroup) -> None: for seq in seq_group.get_seqs(): self.block_tables[seq.seq_id] = block_table.copy() - def can_append_slot(self, seq_group: SequenceGroup) -> bool: - # Simple heuristic: If there is at least one free block - # for each sequence, we can append. - num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() - num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) - return num_seqs <= num_free_gpu_blocks + def can_append_slots(self, + seq_group: SequenceGroup, + num_preallocated_slots: int = 0) -> bool: + """Determine whether there is enough space to append new slots to the + running sequences in the sequence group. + + Args: + seq_group: The sequence group whose running sequences will be used + in the determination. + num_preallocated_slots: The number of slots beyond the sequence + length that will be allocated. Used when a worker emits more + than one token per scheduler invocation. + """ + running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + max_num_new_blocks = self._get_num_new_blocks_required_to_append( + running_seqs, num_preallocated_slots) + return max_num_new_blocks <= self.gpu_allocator.get_num_free_blocks() + + def _get_num_new_blocks_required_to_append( + self, seqs: List[Sequence], num_preallocated_slots: int) -> int: + """Calculate the number of new blocks required to append new tokens. + + Args: + seqs: The list of sequences to be used in the calculation. + num_preallocated_slots: The number of slots beyond the sequence + length that will be allocated. Used when a worker emits more + than one token per scheduler invocation. + + """ + max_num_new_slots_per_seq = [ + seq.get_num_unprocessed_token_ids() + num_preallocated_slots + for seq in seqs + ] + + # For simplicity, we assume each new slot consumes a new block (either + # by COW or new allocation). This is the worst case -- a better + # heuristic could be used. 
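# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the patch. It contrasts the
# conservative bound used by _get_num_new_blocks_required_to_append (one new
# block per new slot, which also covers the copy-on-write case) with the exact
# block count an unshared sequence would need. The helper names are the
# sketch's own.
import math

def _conservative_new_blocks(num_new_slots: int) -> int:
    # Worst case assumed above: every new slot may need its own block.
    return num_new_slots

def _exact_new_blocks(seq_len: int, num_new_slots: int, block_size: int) -> int:
    # Blocks needed to grow an unshared sequence from seq_len to
    # seq_len + num_new_slots tokens.
    return (math.ceil((seq_len + num_new_slots) / block_size) -
            math.ceil(seq_len / block_size))

assert _conservative_new_blocks(4) == 4
assert _exact_new_blocks(96, 4, block_size=32) == 1   # 96 -> 100 crosses a block
assert _exact_new_blocks(100, 4, block_size=32) == 0  # 100 -> 104 fits in block 4
# ------------------------------- end of sketch ------------------------------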
+ max_num_new_blocks = sum(max_num_new_slots_per_seq) + return max_num_new_blocks + + def append_slots(self, + seq: Sequence, + num_preallocated_slots: int = 0) -> Dict[int, List[int]]: + """Allocate physical slots for new tokens. + + Args: + seq: The sequence that needs allocation to store new tokens. + num_preallocated_slots: The number of slots beyond the sequence + length that will be allocated. Used when a worker emits more + than one token per scheduler invocation. + """ + seq.ensure_num_empty_slots(num_preallocated_slots) + + num_new_blocks = 0 + while len(self.block_tables[seq.seq_id]) < len( + seq.logical_token_blocks): + self._append_block(self.block_tables[seq.seq_id]) + num_new_blocks += 1 + + # Even if no new blocks were added, make sure the last block is + # appendable. + num_blocks_to_check_appendable = max(num_new_blocks, 1) + return self._ensure_last_blocks_are_appendable( + self.block_tables[seq.seq_id], num_blocks_to_check_appendable) + + def _append_block(self, block_table: BlockTable) -> None: + """Append a block to the block table. May allocate a new block or re-use + a block when configured with a sliding window. + """ + if (self.block_sliding_window + and len(block_table) >= self.block_sliding_window): + # re-use a block + block_table.append(block_table[len(block_table) % + self.block_sliding_window]) + return - def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]: - """Allocate a physical slot for a new token.""" - logical_blocks = seq.logical_token_blocks - block_table = self.block_tables[seq.seq_id] + # The sequence has a new logical block. + # Allocate a new physical block. + block = self.gpu_allocator.allocate() + block_table.append(block) + + def _ensure_last_blocks_are_appendable( + self, block_table: BlockTable, + num_blocks_to_check: int) -> Dict[int, List[int]]: + """Ensure the last blocks in the block table are appendable, e.g. if the + blocks are owned by a single sequence. + + The blocks which are not appendable are replaced with new blocks. The + copy-on-write source and destination block numbers are then returned. + """ + cow_src_dst = defaultdict(list) + for i in range( + len(block_table) - num_blocks_to_check, len(block_table)): + # We want to check if a token can be appended to this block. + block = block_table[i] + assert block.device == Device.GPU + if block.ref_count == 1: + # Not shared with other sequences. Appendable. + continue - if len(block_table) < len(logical_blocks): - if (self.block_sliding_window - and len(block_table) >= self.block_sliding_window): - # re-use a block - block_table.append(block_table[len(block_table) % - self.block_sliding_window]) - else: - # The sequence has a new logical block. - # Allocate a new physical block. - block = self.gpu_allocator.allocate() - block_table.append(block) - return None - - # We want to append the token to the last physical block. - last_block = block_table[-1] - assert last_block.device == Device.GPU - if last_block.ref_count == 1: - # Not shared with other sequences. Appendable. - return None - else: - # The last block is shared with other sequences. + # The block is shared with other sequences. # Copy on Write: Allocate a new block and copy the tokens. 
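# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the patch. A toy version of
# the copy-on-write bookkeeping performed above: a block shared by more than
# one sequence (ref_count > 1) is replaced by a fresh block, and the
# old -> new block numbers are collected so the engine can copy the data.
from collections import defaultdict
from dataclasses import dataclass
from itertools import count

@dataclass
class _Block:
    block_number: int
    ref_count: int = 1

def _cow_last_blocks(block_table, num_blocks_to_check, fresh_block_numbers):
    cow_src_dst = defaultdict(list)
    for i in range(len(block_table) - num_blocks_to_check, len(block_table)):
        block = block_table[i]
        if block.ref_count == 1:
            continue  # exclusively owned, appendable in place
        new_block = _Block(next(fresh_block_numbers))
        block_table[i] = new_block
        block.ref_count -= 1  # stand-in for gpu_allocator.free(block)
        cow_src_dst[block.block_number].append(new_block.block_number)
    return dict(cow_src_dst)

_table = [_Block(0), _Block(1, ref_count=2)]  # last block shared after a fork
assert _cow_last_blocks(_table, 1, count(100)) == {1: [100]}
assert _table[-1].block_number == 100
# ------------------------------- end of sketch ------------------------------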
new_block = self.gpu_allocator.allocate() - block_table[-1] = new_block - self.gpu_allocator.free(last_block) - return last_block.block_number, new_block.block_number + block_table[i] = new_block + self.gpu_allocator.free(block) + + cow_src_dst[block.block_number].append(new_block.block_number) + + return dict(cow_src_dst) def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: # NOTE: fork does not allocate a new physical block. @@ -198,14 +271,20 @@ def _get_physical_blocks( blocks.update(self.block_tables[seq.seq_id]) return list(blocks) - def can_swap_in(self, seq_group: SequenceGroup) -> bool: - blocks = self._get_physical_blocks(seq_group) - num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) - num_free_blocks = self.gpu_allocator.get_num_free_blocks() + def can_swap_in(self, + seq_group: SequenceGroup, + num_preallocated_slots: int = 0) -> bool: # NOTE: Conservatively, we assume that every sequence will allocate - # at least one free block right after the swap-in. + # at least one free block per new slot right after the swap-in. # NOTE: This should match the logic in can_append_slot(). - num_required_blocks = len(blocks) + num_swapped_seqs + swapped_seqs = seq_group.get_seqs(status=SequenceStatus.SWAPPED) + max_num_new_blocks = self._get_num_new_blocks_required_to_append( + swapped_seqs, num_preallocated_slots) + + blocks = self._get_physical_blocks(seq_group) + + num_required_blocks = len(blocks) + max_num_new_blocks + num_free_blocks = self.gpu_allocator.get_num_free_blocks() return num_free_blocks - num_required_blocks >= self.watermark_blocks def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: @@ -277,6 +356,10 @@ def free(self, seq: Sequence) -> None: self._free_block_table(block_table) del self.block_tables[seq.seq_id] + # Sequence tracks which tokens have been saved to KV. + # Clear it as the physical block data may be overwritten. 
+ seq.reset_processed_tokens() + def reset(self) -> None: for block_table in self.block_tables.values(): self._free_block_table(block_table) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ca28bbdc2fb95..7ba6d5fc5ff5f 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1,13 +1,15 @@ import enum import time -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union, Iterable, Set -from vllm.config import CacheConfig, SchedulerConfig +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.block_manager import AllocStatus, BlockSpaceManager +from vllm.anyscale.lora.utils import LoRARequest from vllm.core.policy import PolicyFactory from vllm.logger import init_logger from vllm.sequence import (Sequence, SequenceData, SequenceGroup, - SequenceGroupMetadata, SequenceStatus) + SequenceGroupMetadata, SequenceGroupMetadataDelta, + SequenceStatus) logger = init_logger(__name__) @@ -30,27 +32,129 @@ class SchedulerOutputs: def __init__( self, scheduled_seq_groups: List[SequenceGroup], - prompt_run: bool, + num_chunked_prefill_groups: int, + num_prompt_groups: int, num_batched_tokens: int, blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], ignored_seq_groups: List[SequenceGroup], + done_seq_group_ids: Set[str], + num_preallocated_slots: int = 0, + num_preempted_seqs: int = 0, + lora_enabled: bool = False, ) -> None: self.scheduled_seq_groups = scheduled_seq_groups - self.prompt_run = prompt_run + self.num_chunked_prefill_groups = num_chunked_prefill_groups + self.num_prompt_groups = num_prompt_groups self.num_batched_tokens = num_batched_tokens self.blocks_to_swap_in = blocks_to_swap_in self.blocks_to_swap_out = blocks_to_swap_out self.blocks_to_copy = blocks_to_copy + self.done_seq_group_ids = done_seq_group_ids + self.num_preempted_seqs = num_preempted_seqs + + # The number of preallocated slots per sequence in the KV cache. + # This is normally zero, but is greater than zero when multiple + # tokens are generated per scheduling iteration + self.num_preallocated_slots = num_preallocated_slots + assert self.num_preallocated_slots >= 0 + # Swap in and swap out should never happen at the same time. assert not (blocks_to_swap_in and blocks_to_swap_out) self.ignored_seq_groups = ignored_seq_groups + if lora_enabled: + self.num_loras = len(set(self.lora_requests)) + self._sort_by_lora_ids() def is_empty(self) -> bool: # NOTE: We do not consider the ignored sequence groups. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in - and not self.blocks_to_swap_out and not self.blocks_to_copy) + and not self.blocks_to_swap_out and not self.blocks_to_copy + and not self.done_seq_group_ids) + + def _sort_by_lora_ids(self) -> bool: + self.scheduled_seq_groups.sort(key=lambda g: ( + g.lora_request.lora_int_id if g.lora_request else 0, g.request_id)) + + @property + def lora_requests(self) -> Set[LoRARequest]: + return {g.lora_request for g in self.scheduled_seq_groups} + + +class SchedulerDecodeOutputs: + """Outputs of the decoding phase of the scheduler. + + Attributes: + token_budget: The number of available token slots after scheduling. + decoding_seq_groups: Selected sequence groups for decoding. + num_preempted_seqs: The number of preempted sequences. + blocks_to_swap_in: The blocks to swap in. + blocks_to_swap_out: The blocks to swap out. + blocks_to_copy: The blocks to copy. 
+ """ + + def __init__(self, token_budget: int, + decoding_seq_groups: List[SequenceGroup], + num_preempted_seqs: int, blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]]) -> None: + self.token_budget = token_budget + self.decoding_seq_groups = decoding_seq_groups + self.num_preempted_seqs = num_preempted_seqs + self.blocks_to_swap_in = blocks_to_swap_in + self.blocks_to_swap_out = blocks_to_swap_out + self.blocks_to_copy = blocks_to_copy + + @staticmethod + def create_empty() -> "SchedulerDecodeOutputs": + return SchedulerDecodeOutputs(0, [], 0, {}, {}, {}) + + def num_decoding_seqs(self): + return sum( + seq_group.num_seqs(status=SequenceStatus.RUNNING) + for seq_group in self.decoding_seq_groups) + + def curr_loras(self): + return set(seq_group.lora_int_id + for seq_group in self.decoding_seq_groups) + + +class SchedulePrefillOutputs: + """Outputs of the prefilling phase of the scheduler. + + Attributes: + token_budget: The number of available token slots after scheduling. + num_batched_tokens: The number of batched tokens. + chunk_prefilling_seq_groups: Selected sequence groups for chunked + prefilling. + prompting_seq_groups: Selected sequence groups for prompting. + ignored_seq_groups: Ignored sequence groups. + """ + + def __init__( + self, + token_budget: int, + num_batched_tokens: int, + chunk_prefilling_seq_groups: List[SequenceGroup], + prompting_seq_groups: List[SequenceGroup], + ignored_seq_groups: List[SequenceGroup], + ) -> None: + self.token_budget = token_budget + self.num_batched_tokens = num_batched_tokens + self.chunk_prefilling_seq_groups = chunk_prefilling_seq_groups + self.prompting_seq_groups = prompting_seq_groups + self.ignored_seq_groups = ignored_seq_groups + + def num_prompting_groups(self): + return len(self.prompting_seq_groups) + + def num_chunk_prefilling_groups(self): + return len(self.chunk_prefilling_seq_groups) + + def num_selected_groups(self): + return len(self.chunk_prefilling_seq_groups) + len( + self.prompting_seq_groups) class Scheduler: @@ -59,12 +163,28 @@ def __init__( self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, + lora_config: Optional[LoRAConfig], ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config - - self.prompt_limit = min(self.scheduler_config.max_model_len, - self.scheduler_config.max_num_batched_tokens) + # Note for LoRA scheduling: the current policy is extremely + # simple and NOT fair. It can lead to starvation of some + # LoRAs. This should be improved in the future. + self.lora_config = lora_config + self.prompt_limit = self.scheduler_config.max_model_len + self.chunked_prefill_enabled = \ + self.scheduler_config.max_chunked_prefill_len >= 0 + if self.chunked_prefill_enabled: + self.max_chunked_prefill_len = \ + scheduler_config.max_chunked_prefill_len + logger.info( + f"chunked prefill enabled, {self.max_chunked_prefill_len=}" + f", {self.scheduler_config.max_num_prompt_seqs=}" + f", { self.scheduler_config.max_num_batched_tokens=}") + assert not self.lora_enabled, \ + "chunked prefilling is not supported with LoRA" + else: + self.max_chunked_prefill_len = 1000_000_000 # Instantiate the scheduling policy. self.policy = PolicyFactory.get_policy(policy_name="fcfs") @@ -78,20 +198,55 @@ def __init__( # TODO(zhuohan): Use deque instead of list for better performance. # Sequence groups in the WAITING state. self.waiting: List[SequenceGroup] = [] + # Sequence groups in the CHUNKED PREFILLING state. 
+ self.chunked_prefilling: List[SequenceGroup] = [] # Sequence groups in the RUNNING state. self.running: List[SequenceGroup] = [] # Sequence groups in the SWAPPED state. self.swapped: List[SequenceGroup] = [] + # IDs of aborted & finished seq groups before + # the current scheduling iteration. + self.done_ids: Set[str] = set() + + @property + def lora_enabled(self): + return bool(self.lora_config) + + @property + def _use_deltas(self): + return self.scheduler_config.use_deltas + + @property + def _num_preallocated_slots(self) -> int: + """The number of slots to preallocate per decode step. + + This is greater than zero when the worker runs more than one step per + scheduler invocation. + """ + return self.scheduler_config.num_preallocated_slots_per_step + + @property + def num_decoding_tokens_per_seq(self) -> int: + """The number of new tokens will be generated.""" + return self._num_preallocated_slots + 1 + def add_seq_group(self, seq_group: SequenceGroup) -> None: # Add sequence groups to the waiting queue. + # logger.debug(f"add_seq_group {seq_group.request_id}") self.waiting.append(seq_group) - def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: + def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> int: + """Returns the number of actually aborted seq groups.""" if isinstance(request_id, str): request_id = (request_id, ) request_ids = set(request_id) - for state_queue in [self.waiting, self.running, self.swapped]: + self.done_ids.update(request_ids) + aborted = 0 + for state_queue in [ + self.waiting, self.running, self.swapped, + self.chunked_prefilling + ]: # We need to reverse the list as we are removing elements # from it as we iterate over it. If we don't do it, # indices will get messed up and we will skip over elements. @@ -99,6 +254,7 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: if seq_group.request_id in request_ids: # Remove the sequence group from the state queue. state_queue.remove(seq_group) + aborted += 1 for seq in seq_group.get_seqs(): if seq.is_finished(): continue @@ -106,117 +262,49 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: self.free_seq(seq) request_ids.remove(seq_group.request_id) if not request_ids: - return + return aborted + return aborted def has_unfinished_seqs(self) -> bool: - return self.waiting or self.running or self.swapped + return self.waiting or self.running or self.swapped \ + or self.chunked_prefilling def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) - def _schedule(self) -> SchedulerOutputs: - # Blocks that need to be swaped or copied before model execution. + def _schedule_decoding(self, token_budget: int) -> SchedulerDecodeOutputs: + """Schedule sequence groups for decoding. + First schedule the sequence groups in the RUNNING state. + Then schedule the sequence groups in the SWAPPED state. + + Args: + token_budget: The number of available token slots. + """ blocks_to_swap_in: Dict[int, int] = {} blocks_to_swap_out: Dict[int, int] = {} blocks_to_copy: Dict[int, List[int]] = {} + decoding_seq_groups: List[SequenceGroup] = [] + preempted: List[SequenceGroup] = [] # Fix the current time. now = time.monotonic() - # Join waiting sequences if possible. - if not self.swapped: - ignored_seq_groups: List[SequenceGroup] = [] - scheduled: List[SequenceGroup] = [] - # The total number of sequences on the fly, including the - # requests in the generation phase. 
- num_curr_seqs = sum(seq_group.get_max_num_running_seqs() - for seq_group in self.running) - seq_lens: List[int] = [] - - # Optimization: We do not sort the waiting queue since the preempted - # sequence groups are added to the front and the new sequence groups - # are added to the back. - while self.waiting: - seq_group = self.waiting[0] - - assert seq_group.num_seqs() == 1, ( - "Waiting sequence group should have only one prompt " - "sequence.") - num_prompt_tokens = seq_group.get_seqs()[0].get_len() - if num_prompt_tokens > self.prompt_limit: - logger.warning( - f"Input prompt ({num_prompt_tokens} tokens) is too long" - f" and exceeds limit of {self.prompt_limit}") - for seq in seq_group.get_seqs(): - seq.status = SequenceStatus.FINISHED_IGNORED - ignored_seq_groups.append(seq_group) - self.waiting.pop(0) - continue - - # If the sequence group cannot be allocated, stop. - can_allocate = self.block_manager.can_allocate(seq_group) - if can_allocate == AllocStatus.LATER: - break - elif can_allocate == AllocStatus.NEVER: - logger.warning( - f"Input prompt ({num_prompt_tokens} tokens) is too long" - f" and exceeds the capacity of block_manager") - for seq in seq_group.get_seqs(): - seq.status = SequenceStatus.FINISHED_IGNORED - ignored_seq_groups.append(seq_group) - self.waiting.pop(0) - continue - - # If the number of batched tokens exceeds the limit, stop. - new_seq_lens = seq_lens + [num_prompt_tokens] - num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) - if (num_batched_tokens > - self.scheduler_config.max_num_batched_tokens): - break - - # The total number of sequences in the RUNNING state should not - # exceed the maximum number of sequences. - num_new_seqs = seq_group.get_max_num_running_seqs() - if (num_curr_seqs + num_new_seqs > - self.scheduler_config.max_num_seqs): - break - - num_paddings = num_batched_tokens - sum(new_seq_lens) - if num_paddings > self.scheduler_config.max_paddings: - break - seq_lens = new_seq_lens - - seq_group = self.waiting.pop(0) - self._allocate(seq_group) - self.running.append(seq_group) - num_curr_seqs += num_new_seqs - scheduled.append(seq_group) - - if scheduled or ignored_seq_groups: - scheduler_outputs = SchedulerOutputs( - scheduled_seq_groups=scheduled, - prompt_run=True, - num_batched_tokens=len(seq_lens) * - max(seq_lens) if seq_lens else 0, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ignored_seq_groups=ignored_seq_groups, - ) - return scheduler_outputs - # NOTE(woosuk): Preemption happens only when there is no available slot # to keep all the sequence groups in the RUNNING state. # In this case, the policy is responsible for deciding which sequence # groups to preempt. self.running = self.policy.sort_by_priority(now, self.running) - # Reserve new token slots for the running sequence groups. - running: List[SequenceGroup] = [] - preempted: List[SequenceGroup] = [] + # Step 1: Schedule as many decoding requests as possible. + # If we run out of token budget, stop. + # If we run out of available slots, try to preempt + # the lowest-priority sequence groups. while self.running: + if token_budget < self.running[0].num_unfinished_seqs( + ) * self.num_decoding_tokens_per_seq: + break seq_group = self.running.pop(0) - while not self.block_manager.can_append_slot(seq_group): + while not self._can_append_slots(seq_group): if self.running: # Preempt the lowest-priority sequence groups. 
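# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the patch. The decode loop
# above charges each running sequence num_decoding_tokens_per_seq slots
# (1 + num_preallocated_slots_per_step) against the token budget and stops as
# soon as the next group no longer fits. A toy version of that accounting,
# ignoring preemption and swap-in:

def _schedule_decodes(group_sizes, token_budget, tokens_per_seq):
    """group_sizes: running-sequence counts per group, highest priority first."""
    scheduled = []
    for num_seqs in group_sizes:
        cost = num_seqs * tokens_per_seq
        if token_budget < cost:
            break
        token_budget -= cost
        scheduled.append(num_seqs)
    return scheduled, token_budget

# Budget of 40 tokens, 5 tokens per sequence per step (4 speculative tokens):
assert _schedule_decodes([2, 3, 4], 40, 5) == ([2, 3], 15)
# ------------------------------- end of sketch ------------------------------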
victim_seq_group = self.running.pop(-1) @@ -229,77 +317,360 @@ def _schedule(self) -> SchedulerOutputs: preempted.append(seq_group) break else: + # logger.debug(f"append slot for {seq_group}") # Append new slots to the sequence group. - self._append_slot(seq_group, blocks_to_copy) - running.append(seq_group) - self.running = running - - # Swap in the sequence groups in the SWAPPED state if possible. + self._append_slots(seq_group, blocks_to_copy) + # logger.debug(f"scheduled r -> r {seq_group.request_id}") + decoding_seq_groups.append(seq_group) + token_budget -= seq_group.num_seqs( + status=SequenceStatus.RUNNING + ) * self.num_decoding_tokens_per_seq + + # If any sequence group is preempted, do not swap in any sequence group. + if preempted: + return SchedulerDecodeOutputs(token_budget, decoding_seq_groups, + len(preempted), blocks_to_swap_in, + blocks_to_swap_out, blocks_to_copy) + + # Step 2: Swap in the sequence groups in the SWAPPED state if possible. self.swapped = self.policy.sort_by_priority(now, self.swapped) - if not preempted: - num_curr_seqs = sum(seq_group.get_max_num_running_seqs() - for seq_group in self.running) - - while self.swapped: - seq_group = self.swapped[0] - # If the sequence group cannot be swapped in, stop. - if not self.block_manager.can_swap_in(seq_group): - break + num_curr_seqs = sum( + seq_group.num_seqs(status=SequenceStatus.RUNNING) + for seq_group in decoding_seq_groups) + curr_loras = set( + seq_group.lora_int_id + for seq_group in self.running) if self.lora_enabled else None + + swapped_indices_to_remove = [] + for i, seq_group in enumerate(self.swapped): + if token_budget < self.swapped[0].num_unfinished_seqs( + ) * self.num_decoding_tokens_per_seq: + break + + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + if lora_int_id > 0 and lora_int_id not in curr_loras and ( + len(curr_loras) >= self.lora_config.max_loras): + # We don't have a space for another LoRA, so + # we ignore this request for now. + continue - # The total number of sequences in the RUNNING state should not - # exceed the maximum number of sequences. - num_new_seqs = seq_group.get_max_num_running_seqs() - if (num_curr_seqs + num_new_seqs > - self.scheduler_config.max_num_seqs): - break + # If the sequence group cannot be swapped in, stop. + if not self._can_swap_in(seq_group): + break + + # The total number of sequences in the RUNNING state should not + # exceed the maximum number of sequences. 
+ num_new_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) + if (num_curr_seqs + num_new_seqs > + self.scheduler_config.max_num_seqs): + break + + swapped_indices_to_remove.append(i) + if lora_int_id > 0: + curr_loras.add(lora_int_id) + self._swap_in(seq_group, blocks_to_swap_in) + self._append_slots(seq_group, blocks_to_copy) + num_curr_seqs += num_new_seqs + # logger.debug(f"scheduled s -> r {seq_group.request_id}") + decoding_seq_groups.append(seq_group) + token_budget -= seq_group.num_seqs( + status=SequenceStatus.RUNNING + ) * self.num_decoding_tokens_per_seq + + for i in reversed(swapped_indices_to_remove): + self.swapped.pop(i) + + return SchedulerDecodeOutputs(token_budget, decoding_seq_groups, + len(preempted), blocks_to_swap_in, + blocks_to_swap_out, blocks_to_copy) + + def _chunk_prefill_sequence_group( + self, seq_group: SequenceGroup, token_budget: int, + chunk_prefilling_seq_groups: List[SequenceGroup], + prompting_seq_groups: List[SequenceGroup]) -> int: + """Chunked prefilling one sequence_group + + Args: + token_budget: The number of available token slots. + seq_group: The sequence to be chunk prefilled. + chunk_prefilling_seq_groups: (output) if the sequence group has more + to prefill after this step, it will be added to this list. + prompting_seq_groups: (output) The prompting sequence groups. If + the sequence group finishes prefilling after this step, it will + be added to this list. + + Returns: + num_tokens: The number of tokens to be prefilled from + the sequence group. + """ + num_unprefilled_tokens = seq_group.get_num_unprefilled() + to_advance = min(num_unprefilled_tokens, token_budget, + self.max_chunked_prefill_len) + + seq_group.advance_prefill_range(to_advance) + + # If the sequence group is not fully prefilled, put it into the + # chunked prefilling queue. + if seq_group.get_num_unprefilled() > 0: + # logger.debug(f"scheduled p -> p {seq_group.request_id}") + chunk_prefilling_seq_groups.append(seq_group) + else: + # logger.debug(f"scheduled p -> r {seq_group.request_id}") + prompting_seq_groups.append(seq_group) + + return to_advance + + def _schedule_prefilling( + self, + token_budget: int, + num_curr_seqs: int, + curr_loras: Optional[Set[int]] = None) -> SchedulePrefillOutputs: + """Schedule sequence groups for (chunked) prefilling. + + Args: + token_budget: The number of available token slots. + num_curr_seqs: The number of sequences already scheduled. + curr_loras: The set of LoRA IDs already scheduled. + + Returns: + SchedulePrefillOutputs: The outputs of the prefilling phase. + """ + ignored_seq_groups: List[SequenceGroup] = [] + num_batched_tokens: int = 0 + prompting_seq_groups: List[SequenceGroup] = [] + chunk_prefilling_seq_groups: List[SequenceGroup] = [] + num_prompting_seqs: int = 0 + + # If any request in swapped state, try not schedule any prefilling. + if self.swapped: + return SchedulePrefillOutputs(token_budget, num_batched_tokens, + chunk_prefilling_seq_groups, + prompting_seq_groups, + ignored_seq_groups) + + # Step 1: Continue schedule those requests are in chunked prefilling. + # This is called only if chunked prefilling is enabled. 
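# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the patch. It simulates how
# a long prompt is consumed by repeated scheduling steps when each step
# advances the prefill range by
# min(unprefilled, token_budget, max_chunked_prefill_len), the rule applied by
# _chunk_prefill_sequence_group above. In reality the per-step budget is
# shared with decode requests; a fixed budget is assumed here for simplicity.

def _chunk_schedule(prompt_len, token_budget_per_step, max_chunked_prefill_len):
    chunks = []
    unprefilled = prompt_len
    while unprefilled > 0:
        to_advance = min(unprefilled, token_budget_per_step,
                         max_chunked_prefill_len)
        chunks.append(to_advance)
        unprefilled -= to_advance
    return chunks

# A 2500-token prompt with a 2048-token budget and 1024-token chunks:
assert _chunk_schedule(2500, 2048, 1024) == [1024, 1024, 452]
# ------------------------------- end of sketch ------------------------------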
+ while self.chunked_prefilling and token_budget > 0 \ + and num_prompting_seqs < self.scheduler_config.max_num_prompt_seqs: + + if not self.chunked_prefill_enabled: + assert False, "can't reach here since chunk prefill is disabled" + + seq_group = self.chunked_prefilling.pop(0) + + num_prefilled_tokens = self._chunk_prefill_sequence_group( + seq_group, token_budget, chunk_prefilling_seq_groups, + prompting_seq_groups) + + token_budget -= num_prefilled_tokens + num_batched_tokens += num_prefilled_tokens + num_curr_seqs += seq_group.get_max_num_running_seqs() + num_prompting_seqs += 1 + + # Step 2: Schedule the waiting requests for (chunked) prefilling. + + # Optimization: We do not sort the waiting queue since the preempted + # sequence groups are added to the front and the new sequence groups + # are added to the back. + waiting_indices_to_remove = [] + for i, seq_group in enumerate(self.waiting): + if not (token_budget > 0 and num_prompting_seqs < + self.scheduler_config.max_num_prompt_seqs): + break + + assert seq_group.num_seqs() == 1, ( + "Waiting sequence group should have only one prompt " + "sequence.") + + # If the sequence group cannot be allocated, put into the ignored. + num_prompt_tokens = seq_group.get_seqs()[0].get_len() + if num_prompt_tokens > self.prompt_limit: + logger.warning( + f"Input prompt ({num_prompt_tokens} tokens) is too long" + f" and exceeds limit of {self.prompt_limit}") + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + waiting_indices_to_remove.append(i) + continue + + # If the sequence group cannot be allocated, stop. + can_allocate = self.block_manager.can_allocate(seq_group) + if can_allocate == AllocStatus.LATER: + break + elif can_allocate == AllocStatus.NEVER: + logger.warning( + f"Input prompt ({num_prompt_tokens} tokens) is too long" + f" and exceeds the capacity of block_manager") + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + waiting_indices_to_remove.append(i) + continue + + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + if lora_int_id > 0 and lora_int_id not in curr_loras and len( + curr_loras) >= self.lora_config.max_loras: + # We don't have a space for another LoRA, so + # we ignore this request for now. + continue - seq_group = self.swapped.pop(0) - self._swap_in(seq_group, blocks_to_swap_in) - self._append_slot(seq_group, blocks_to_copy) - num_curr_seqs += num_new_seqs - self.running.append(seq_group) + # If the number of batched tokens exceeds the limit and + # chunked prefill is disabled, stop. + if num_prompt_tokens > token_budget and \ + not self.chunked_prefill_enabled: + break - # Each sequence in the generation phase only takes one token slot. - # Therefore, the number of batched tokens is equal to the number of - # sequences in the RUNNING state. - num_batched_tokens = sum( - seq_group.num_seqs(status=SequenceStatus.RUNNING) - for seq_group in self.running) + # The total number of sequences in the RUNNING state should not + # exceed the maximum number of sequences. 
+ num_new_seqs = seq_group.get_max_num_running_seqs() + if (num_curr_seqs + num_new_seqs > + self.scheduler_config.max_num_seqs): + break + waiting_indices_to_remove.append(i) + if lora_int_id > 0: + curr_loras.add(lora_int_id) + self._allocate(seq_group) + + num_prefilled_tokens = self._chunk_prefill_sequence_group( + seq_group, token_budget, chunk_prefilling_seq_groups, + prompting_seq_groups) + + token_budget -= num_prefilled_tokens + num_batched_tokens += num_prefilled_tokens + num_curr_seqs += seq_group.get_max_num_running_seqs() + num_prompting_seqs += 1 + + for i in reversed(waiting_indices_to_remove): + self.waiting.pop(i) + + return SchedulePrefillOutputs(token_budget, num_batched_tokens, + chunk_prefilling_seq_groups, + prompting_seq_groups, ignored_seq_groups) + + def _schedule(self) -> SchedulerOutputs: + token_budget = self._round_down_by_padding( + self.scheduler_config.max_num_batched_tokens) + + if self.chunked_prefill_enabled: + # Chunked prefilling is enabled. + # We first schedule as many decoding requests as possible, + # and then schedule chunked prefilling requests. + decoding_outputs = self._schedule_decoding(token_budget) + + token_budget = self._round_down_by_padding( + decoding_outputs.token_budget) + + prefilling_outputs = self._schedule_prefilling( + token_budget, decoding_outputs.num_decoding_seqs(), + decoding_outputs.curr_loras() if self.lora_enabled else None) + else: + # Default behavior + # First schedule as many prefilling requests as possible, + # then schedule decoding requests. + + num_curr_seqs = sum( + seq_group.num_seqs(status=SequenceStatus.RUNNING) + for seq_group in self.running) + curr_loras = set( + seq_group.lora_int_id + for seq_group in self.running) if self.lora_enabled else None + + prefilling_outputs = self._schedule_prefilling( + token_budget, num_curr_seqs, curr_loras) + + assert len(prefilling_outputs.chunk_prefilling_seq_groups + ) == 0, "Chunked prefill is disabled" + + if len(prefilling_outputs.prompting_seq_groups) > 0: + decoding_outputs = SchedulerDecodeOutputs.create_empty() + else: + decoding_outputs = self._schedule_decoding(token_budget) + + num_batched_tokens = prefilling_outputs.num_batched_tokens + \ + decoding_outputs.num_decoding_seqs() * \ + self.num_decoding_tokens_per_seq + + is_decoding_only = prefilling_outputs.num_selected_groups() == 0 scheduler_outputs = SchedulerOutputs( - scheduled_seq_groups=self.running, - prompt_run=False, + scheduled_seq_groups=prefilling_outputs.chunk_prefilling_seq_groups + + prefilling_outputs.prompting_seq_groups + + decoding_outputs.decoding_seq_groups, + num_chunked_prefill_groups=prefilling_outputs. 
+ num_chunk_prefilling_groups(), + num_prompt_groups=prefilling_outputs.num_selected_groups(), num_batched_tokens=num_batched_tokens, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ignored_seq_groups=[], + blocks_to_swap_in=decoding_outputs.blocks_to_swap_in, + blocks_to_swap_out=decoding_outputs.blocks_to_swap_out, + blocks_to_copy=decoding_outputs.blocks_to_copy, + ignored_seq_groups=prefilling_outputs.ignored_seq_groups, + num_preempted_seqs=decoding_outputs.num_preempted_seqs, + done_seq_group_ids=self.done_ids.copy(), + num_preallocated_slots=self._num_preallocated_slots + if is_decoding_only else 0, + lora_enabled=self.lora_enabled, ) + + self.done_ids.clear() + + self.chunked_prefilling = \ + prefilling_outputs.chunk_prefilling_seq_groups + \ + self.chunked_prefilling + self.running = self.running + \ + prefilling_outputs.prompting_seq_groups + \ + decoding_outputs.decoding_seq_groups return scheduler_outputs - def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: + def schedule( + self + ) -> Tuple[List[Union[SequenceGroupMetadata, SequenceGroupMetadataDelta]], + SchedulerOutputs]: + now_perf_counter = time.perf_counter() + now = time.time() # Schedule sequence groups. # This function call changes the internal states of the scheduler # such as self.running, self.swapped, and self.waiting. scheduler_outputs = self._schedule() # Create input data structures. - seq_group_metadata_list: List[SequenceGroupMetadata] = [] - for seq_group in scheduler_outputs.scheduled_seq_groups: - seq_data: Dict[int, SequenceData] = {} + seq_group_metadata_list: List[Union[SequenceGroupMetadata, + SequenceGroupMetadataDelta]] = [] + for i, seq_group in enumerate(scheduler_outputs.scheduled_seq_groups): + if seq_group.first_scheduled_time is None: + seq_group.first_scheduled_time = now + seq_group.time_in_queue = (now_perf_counter - + seq_group.arrival_time_perf_counter) + seq_data: Dict[int, List[SequenceData]] = {} block_tables: Dict[int, List[int]] = {} for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq_id = seq.seq_id seq_data[seq_id] = seq.data block_tables[seq_id] = self.block_manager.get_block_table(seq) - seq_group_metadata = SequenceGroupMetadata( - request_id=seq_group.request_id, - is_prompt=scheduler_outputs.prompt_run, - seq_data=seq_data, - sampling_params=seq_group.sampling_params, - block_tables=block_tables, - ) + is_prompt = i < scheduler_outputs.num_prompt_groups + is_chunked_prefill = \ + i < scheduler_outputs.num_chunked_prefill_groups + + if not self._use_deltas or is_prompt: + seq_group_metadata = SequenceGroupMetadata( + request_id=seq_group.request_id, + is_chunked_prefill=is_chunked_prefill, + is_prompt=is_prompt, + seq_data=seq_data, + sampling_params=seq_group.sampling_params, + block_tables=block_tables, + lora_request=seq_group.lora_request, + ) + else: + seq_group_metadata = SequenceGroupMetadataDelta( + request_id=seq_group.request_id, + block_tables=block_tables, + ) seq_group_metadata_list.append(seq_group_metadata) return seq_group_metadata_list, scheduler_outputs @@ -310,29 +681,32 @@ def free_seq(self, seq: Sequence) -> None: self.block_manager.free(seq) def free_finished_seq_groups(self) -> None: - self.running = [ - seq_group for seq_group in self.running - if not seq_group.is_finished() - ] + new_running = [] + for seq_group in self.running: + if seq_group.is_finished(): + self.done_ids.add(seq_group.request_id) + else: + new_running.append(seq_group) + 
self.running = new_running def _allocate(self, seq_group: SequenceGroup) -> None: self.block_manager.allocate(seq_group) for seq in seq_group.get_seqs(): seq.status = SequenceStatus.RUNNING - def _append_slot( + def _append_slots( self, seq_group: SequenceGroup, blocks_to_copy: Dict[int, List[int]], ) -> None: for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - ret = self.block_manager.append_slot(seq) - if ret is not None: - src_block, dst_block = ret - if src_block in blocks_to_copy: - blocks_to_copy[src_block].append(dst_block) - else: - blocks_to_copy[src_block] = [dst_block] + additional_blocks_to_copy = self.block_manager.append_slots( + seq, self._num_preallocated_slots) + + for src_block, dst_blocks in additional_blocks_to_copy.items(): + if src_block not in blocks_to_copy: + blocks_to_copy[src_block] = [] + blocks_to_copy[src_block].extend(dst_blocks) def _preempt( self, @@ -356,12 +730,13 @@ def _preempt( preemption_mode = PreemptionMode.RECOMPUTE else: preemption_mode = PreemptionMode.SWAP + if preemption_mode == PreemptionMode.RECOMPUTE: self._preempt_by_recompute(seq_group) elif preemption_mode == PreemptionMode.SWAP: self._preempt_by_swap(seq_group, blocks_to_swap_out) else: - raise AssertionError("Invalid preemption mode.") + assert False, "Invalid preemption mode." def _preempt_by_recompute( self, @@ -409,3 +784,15 @@ def _swap_out( blocks_to_swap_out.update(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq.status = SequenceStatus.SWAPPED + + def _can_swap_in(self, seq_group: SequenceGroup) -> bool: + return self.block_manager.can_swap_in(seq_group, + self._num_preallocated_slots) + + def _can_append_slots(self, seq_group: SequenceGroup) -> bool: + return self.block_manager.can_append_slots( + seq_group, self._num_preallocated_slots) + + def _round_down_by_padding(self, x: int) -> int: + return x // self.scheduler_config.input_padding_size \ + * self.scheduler_config.input_padding_size diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7e58069e2c22d..83bda0d8b5e12 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,10 +1,12 @@ import argparse import dataclasses +import json from dataclasses import dataclass from typing import Optional, Tuple from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, LoadConfig, SpeculativeConfig, + LoRAConfig) @dataclass @@ -22,19 +24,38 @@ class EngineArgs: worker_use_ray: bool = False pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 - max_parallel_loading_workers: Optional[int] = None block_size: int = 16 swap_space: int = 4 # GiB gpu_memory_utilization: float = 0.90 max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 - max_paddings: int = 256 disable_log_stats: bool = False revision: Optional[str] = None tokenizer_revision: Optional[str] = None quantization: Optional[str] = None - enforce_eager: bool = False - max_context_len_to_capture: int = 8192 + load_s3_path: str = None + load_s3_region: str = 'us-west-2' + enable_cuda_graph: bool = False + cuda_graph_max_context_len: int = 5000 + cuda_graph_cache_size: int = 10 + disable_shared_memory: bool = False + num_tokenizer_actors: int = 0 + speculative_model: Optional[str] = None + speculative_model_uses_tp_1: bool = False + num_speculative_tokens: Optional[int] = None + target_model_input_padding_size: Optional[int] = None + draft_model_input_padding_size: Optional[int] = None + enable_lora: bool = False + max_loras: int = 1 + 
max_lora_rank: int = 16 + lora_extra_vocab_size: int = 256 + lora_dtype = 'auto' + max_cpu_loras: int = -1 + flash_style: bool = False + max_chunked_prefill_len: int = -1 + max_num_prompt_seqs: int = 256 + input_padding_size: int = 8 + ray_workers_use_nsight: bool = False def __post_init__(self): if self.tokenizer is None: @@ -89,6 +110,58 @@ def add_cli_args( help='directory to download and load the weights, ' 'default to the default cache dir of ' 'huggingface') + + # LoRA related configs + parser.add_argument('--enable-lora', + action='store_true', + help='enable lora adapters') + parser.add_argument('--max-loras', + type=int, + default=EngineArgs.max_loras, + help='max number of LoRAs in a single batch') + parser.add_argument('--max-lora-rank', + type=int, + default=EngineArgs.max_lora_rank, + help='max LoRA rank') + parser.add_argument('--lora-extra-vocab-size', + type=int, + default=EngineArgs.lora_extra_vocab_size, + help='LoRA extra vocab size') + parser.add_argument('--lora-dtype', + type=str, + default=EngineArgs.lora_dtype, + choices=['auto', 'float16', 'bfloat16', 'float32'], + help='data type for lora') + parser.add_argument( + '--max-cpu-loras', + type=int, + default=EngineArgs.max_cpu_loras, + help=('Maximum number of loras to store in CPU memory. ' + 'Must be >= than max_num_seqs. ' + 'Defaults to max_num_seqs.')) + # Cuda Graph related configs + parser.add_argument('--enable-cuda-graph', + action='store_true', + help='enable cuda graph for decoding') + parser.add_argument('--cuda-graph-max-context-len', + type=int, + default=5000, + help='max context length for cuda graph decoding.' + 'request with longer context will fallback to' + 'non-compiled decoding') + parser.add_argument('--cuda-graph-cache-size', + type=int, + default=10, + help='num of cached cuda graphs for decoding') + parser.add_argument( + '--disable-shared-memory', + action='store_true', + help='don\'t use shared memory for engine<->worker comms') + parser.add_argument( + '--num-tokenizer-actors', + type=int, + default=0, + help='num of Ray actors for tokenization (0 is no Ray)') parser.add_argument( '--load-format', type=str, @@ -135,17 +208,11 @@ def add_cli_args( type=int, default=EngineArgs.tensor_parallel_size, help='number of tensor parallel replicas') - parser.add_argument( - '--max-parallel-loading-workers', - type=int, - help='load model sequentially in multiple batches, ' - 'to avoid RAM OOM when using tensor ' - 'parallel and large models') # KV cache arguments parser.add_argument('--block-size', type=int, default=EngineArgs.block_size, - choices=[8, 16, 32], + choices=[8, 16, 32, 64, 128, 256, 512, 1024], help='token block size') # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). parser.add_argument('--seed', @@ -156,13 +223,11 @@ def add_cli_args( type=int, default=EngineArgs.swap_space, help='CPU swap space size (GiB) per GPU') - parser.add_argument( - '--gpu-memory-utilization', - type=float, - default=EngineArgs.gpu_memory_utilization, - help='the fraction of GPU memory to be used for ' - 'the model executor, which can range from 0 to 1.' 
- 'If unspecified, will use the default value of 0.9.') + parser.add_argument('--gpu-memory-utilization', + type=float, + default=EngineArgs.gpu_memory_utilization, + help='the percentage of GPU memory to be used for' + 'the model executor') parser.add_argument('--max-num-batched-tokens', type=int, default=EngineArgs.max_num_batched_tokens, @@ -172,10 +237,6 @@ def add_cli_args( type=int, default=EngineArgs.max_num_seqs, help='maximum number of sequences per iteration') - parser.add_argument('--max-paddings', - type=int, - default=EngineArgs.max_paddings, - help='maximum number of paddings in a batch') parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics') @@ -183,25 +244,76 @@ def add_cli_args( parser.add_argument('--quantization', '-q', type=str, - choices=['awq', 'gptq', 'squeezellm', None], + choices=['awq', 'squeezellm', None], + default=None, + help='Method used to quantize the weights') + parser.add_argument('--load-s3-path', + type=str, + default=None, + help='Fast loading s3 path') + parser.add_argument('--load-s3-region', + type=str, + default='us-west-2', + help='Fast loading s3 region') + parser.add_argument('--rope-scaling', + default=None, + type=json.loads, + help='RoPE scaling configuration') + parser.add_argument( + '--speculative-model', + type=str, + default=None, + help='name of the draft model to be used in speculative decoding.') + + parser.add_argument( + '--speculative-model-uses-tp-1', + action='store_true', + help='whether the speculative model should use the same tensor ' + 'parallel degree as the verifier model, or use tp=1') + + parser.add_argument('--num-speculative-tokens', + type=int, + default=None, + help='number of speculative tokens to sample from ' + 'the draft model in speculative decoding') + + parser.add_argument('--target-model-input-padding-size', + type=int, default=None, - help='Method used to quantize the weights. If ' - 'None, we first check the `quantization_config` ' - 'attribute in the model config file. If that is ' - 'None, we assume the model weights are not ' - 'quantized and use `dtype` to determine the data ' - 'type of the weights.') - parser.add_argument('--enforce-eager', + help='padding size for speculative decoding target' + ' model prompt/generation tokens.' + ' must be a multiple of 8') + + parser.add_argument('--draft-model-input-padding-size', + type=int, + default=None, + help='padding size for speculative decoding draft' + ' model prompt/generation tokens.' + ' must be a multiple of 8') + + parser.add_argument('--flash-style', action='store_true', - help='Always use eager-mode PyTorch. If False, ' - 'will use eager mode and CUDA graph in hybrid ' - 'for maximal performance and flexibility.') - parser.add_argument('--max-context-len-to-capture', + help='use flash attention') + parser.add_argument( + '--max-chunked-prefill-len', + type=int, + default=-1, + help='max number of prefill tokens allowed in chunked prefill' + ', -1 means no limit') + parser.add_argument( + '--max-num-prompt-seqs', + type=int, + default=1024, + help='max number of prompt sequences allowed in prefill') + parser.add_argument('--input-padding-size', type=int, - default=EngineArgs.max_context_len_to_capture, - help='maximum context length covered by CUDA ' - 'graphs. When a sequence has context length ' - 'larger than this, we fall back to eager mode.') + default=8, + help='padding size for prompt/generation tokens.' 
+ ' must be a multiple of 8') + parser.add_argument('--ray-workers-use-nsight', + type=bool, + default=False, + help='use nsight to profile ray workers') return parser @classmethod @@ -214,27 +326,72 @@ def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': def create_engine_configs( self, - ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: - model_config = ModelConfig(self.model, self.tokenizer, - self.tokenizer_mode, self.trust_remote_code, - self.download_dir, self.load_format, - self.dtype, self.seed, self.revision, - self.tokenizer_revision, self.max_model_len, - self.quantization, self.enforce_eager, - self.max_context_len_to_capture) - cache_config = CacheConfig(self.block_size, - self.gpu_memory_utilization, - self.swap_space, - model_config.get_sliding_window()) - parallel_config = ParallelConfig(self.pipeline_parallel_size, - self.tensor_parallel_size, - self.worker_use_ray, - self.max_parallel_loading_workers) - scheduler_config = SchedulerConfig(self.max_num_batched_tokens, - self.max_num_seqs, - model_config.max_model_len, - self.max_paddings) - return model_config, cache_config, parallel_config, scheduler_config + ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, + LoadConfig, SpeculativeConfig, Optional[LoRAConfig]]: + # Initialize the configs. + model_config = ModelConfig( + self.model, self.tokenizer, self.tokenizer_mode, + self.trust_remote_code, self.download_dir, self.load_format, + self.dtype, self.seed, self.revision, self.tokenizer_revision, + self.max_model_len, self.quantization, self.enable_cuda_graph, + self.cuda_graph_max_context_len, self.cuda_graph_cache_size, + self.flash_style, self.max_chunked_prefill_len) + + cache_config = CacheConfig( + self.block_size, self.gpu_memory_utilization, self.swap_space, + getattr(model_config.hf_config, 'sliding_window', None), + self.flash_style) + parallel_config = ParallelConfig( + self.pipeline_parallel_size, + self.tensor_parallel_size, + self.worker_use_ray, + self.disable_shared_memory, + self.num_tokenizer_actors, + ray_workers_use_nsight=self.ray_workers_use_nsight, + ) + + speculative_config = SpeculativeConfig.maybe_create_spec_config( + model_config, + parallel_config, + self.dtype, + speculative_model=self.speculative_model, + num_speculative_tokens=self.num_speculative_tokens, + speculative_model_uses_tp_1=self.speculative_model_uses_tp_1, + target_model_input_padding_size=self. 
+            target_model_input_padding_size,
+            draft_model_input_padding_size=self.draft_model_input_padding_size,
+        )
+
+        scheduler_config = SchedulerConfig(
+            self.max_num_batched_tokens,
+            self.max_num_seqs,
+            model_config.max_model_len,
+            use_deltas=parallel_config.worker_use_ray
+            and not speculative_config,
+            num_preallocated_slots_per_step=0 if not speculative_config else
+            speculative_config.num_preallocated_slots_per_step,
+            max_chunked_prefill_len=self.max_chunked_prefill_len,
+            max_num_prompt_seqs=self.max_num_prompt_seqs,
+            flash_style=self.flash_style,
+            input_padding_size=self.input_padding_size,
+        )
+
+        # Default to no fast-load config; only set it for a valid s3:// path
+        # so that load_config is always defined when returned below.
+        load_config = None
+        if self.load_s3_path is not None and self.load_s3_path.startswith(
+                's3://'):
+            _, _, bucket, key = self.load_s3_path.split('/', 3)
+            load_config = LoadConfig(bucket, key, self.load_s3_region)
+
+        lora_config = LoRAConfig(
+            max_lora_rank=self.max_lora_rank,
+            max_loras=self.max_loras,
+            lora_extra_vocab_size=self.lora_extra_vocab_size,
+            lora_dtype=self.lora_dtype,
+            max_cpu_loras=self.max_cpu_loras
+            if self.max_cpu_loras > 0 else None) if self.enable_lora else None
+
+        return (model_config, cache_config, parallel_config, scheduler_config,
+                load_config, speculative_config, lora_config)
 
 
 @dataclass
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index d854a20b8b95a..90659270c21c4 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -2,9 +2,11 @@
 import time
 from functools import partial
 from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
-                    Union, AsyncIterator)
+                    Union)
 
+from vllm.anyscale.lora.utils import LoRARequest
 from vllm.config import ModelConfig
+from vllm.sequence import ExecuteModelData
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.engine.ray_utils import initialize_cluster, ray
@@ -183,49 +185,152 @@ async def step_async(self) -> List[RequestOutput]:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
+        logger.debug("Running async step...")
         seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
         if scheduler_outputs.is_empty():
             return ignored
 
-        # Execute the model.
-        output = await self._run_workers_async(
-            "execute_model",
+        data = ExecuteModelData(
             seq_group_metadata_list=seq_group_metadata_list,
+            finished_request_ids_list=list(
+                scheduler_outputs.done_seq_group_ids),
             blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
             blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
             blocks_to_copy=scheduler_outputs.blocks_to_copy,
+            num_preallocated_slots=scheduler_outputs.num_preallocated_slots,
         )
-        return self._process_model_outputs(output, scheduler_outputs) + ignored
+        # Execute the model.
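+        # When the shared-memory path is active, _run_workers_async writes
+        # `data` into the engine->worker buffer and reads the sampler outputs
+        # back from the worker->engine buffer instead of issuing a Ray call
+        # per step (see _run_workers_async below).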
+ output = await self._run_workers_async( + "execute_model", + data, + use_shared_memory=self.shared_mem_engine_to_worker is not None) + + outputs = self._process_model_outputs(output, scheduler_outputs) + + if self.shared_mem_engine_to_worker is not None: + if not outputs or all(out.finished for out in outputs): + logger.debug("Putting shm event to sleep") + self.shared_mem_engine_to_worker.clear() + self.shared_mem_worker_to_engine.clear() + self.shared_mem_engine_to_worker.put_to_sleep(block=False) + self.shared_mem_worker_to_engine.put_to_sleep(block=False) + + logger.debug("Async step finished") + + return outputs + + async def encode_request_async( + self, + request_id: str, # pylint: disable=unused-argument + prompt: Optional[str], + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + ): + if prompt_token_ids is None: + assert prompt is not None + prompt_token_ids = await self.tokenizer.encode_async( + request_id=request_id, + prompt=prompt, + lora_request=lora_request) + return prompt_token_ids + + async def add_request_async( + self, + request_id: str, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + ) -> None: + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + if arrival_time is None: + arrival_time = time.time() + prompt_token_ids = await self.encode_request_async( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) + + return self.add_request( + request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + arrival_time=arrival_time, + lora_request=lora_request, + ) async def _run_workers_async( self, method: str, *args, get_all_outputs: bool = False, + wait_for_workers: bool = True, + use_shared_memory: bool = False, **kwargs, ) -> Any: """Runs the given method on all workers.""" - coros = [] - for worker in self.workers: - if self.parallel_config.worker_use_ray: - coros.append( - worker.execute_method.remote(method, *args, **kwargs)) - else: - executor = getattr(worker, method) - coros.append(asyncio.get_event_loop().run_in_executor( - None, partial(executor, *args, **kwargs))) - - all_outputs = await asyncio.gather(*coros) - - if get_all_outputs: - return all_outputs + if use_shared_memory: + try: + logger.debug(f"Set data to shared memory: {args[0]}") + self.shared_mem_engine_to_worker.set_data(args[0]) + except RuntimeError: + # Raise underlying exception + await asyncio.wait_for(self._exceute_model_futures, timeout=5) + raise + logger.debug("Waiting for incoming data...") + await self.shared_mem_worker_to_engine.wait_for_incoming_data_async( + ) + try: + output = self.shared_mem_worker_to_engine.get_data() + except RuntimeError: + # Raise underlying exception + await asyncio.wait_for(self._exceute_model_futures, timeout=5) + raise + logger.debug(f"Got data {output}") + self.shared_mem_worker_to_engine.clear() + return output + else: + coros = [] + for worker in self.workers: + if self.parallel_config.worker_use_ray: + coros.append( + worker.execute_method.remote(method, *args, **kwargs)) + else: + executor = getattr(worker, method) + coros.append(asyncio.get_event_loop().run_in_executor( + None, partial(executor, *args, **kwargs))) + + if wait_for_workers: + all_outputs = await asyncio.gather(*coros) + 
+ if get_all_outputs: + return all_outputs + + # Make sure all workers have the same results. + output = all_outputs[0] + if wait_for_workers: + for other_output in all_outputs[1:]: + assert output == other_output + return output + + async def check_health_async(self): + if not self.parallel_config.worker_use_ray: + return - # Make sure all workers have the same results. - output = all_outputs[0] - for other_output in all_outputs[1:]: - assert output == other_output - return output + self._check_if_any_actor_is_dead() + if self._exceute_model_futures: + ready, _ = await asyncio.wait(self._exceute_model_futures, + timeout=0, + return_when=asyncio.FIRST_COMPLETED) + if ready: + # Raise any exception + await asyncio.wait_for(ready, timeout=1) + raise RuntimeError("At least one Worker is dead.") class AsyncLLMEngine: @@ -281,6 +386,11 @@ def is_running(self) -> bool: return (self.background_loop is not None and not self.background_loop.done()) + @property + def is_stopped(self) -> bool: + return (self.background_loop is not None + and self.background_loop.done()) + def start_background_loop(self) -> None: """Start the background loop.""" if self.is_running: @@ -301,16 +411,7 @@ def _init_engine(self, *args, elif self.worker_use_ray: engine_class = ray.remote(num_cpus=0)(self._engine_class).remote else: - # FIXME(woosuk): This is a bit hacky. Be careful when changing the - # order of the arguments. - cache_config = args[1] - parallel_config = args[2] - if parallel_config.tensor_parallel_size == 1: - num_gpus = cache_config.gpu_memory_utilization - else: - num_gpus = 1 - engine_class = ray.remote(num_gpus=num_gpus)( - self._engine_class).remote + engine_class = ray.remote(num_gpus=1)(self._engine_class).remote return engine_class(*args, **kwargs) async def engine_step(self) -> bool: @@ -325,9 +426,14 @@ async def engine_step(self) -> bool: # Add the request into the vLLM engine's waiting queue. 
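+        # `new_request` is the kwargs dict handed over by the request tracker
+        # (request_id, prompt, sampling_params and, with this patch, an
+        # optional lora_request); add_request_async additionally tokenizes
+        # the prompt asynchronously before enqueueing.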
# TODO: Maybe add add_request_batch to reduce Ray overhead if self.engine_use_ray: - await self.engine.add_request.remote(**new_request) + resp = await self.engine.add_request.remote(**new_request) else: - self.engine.add_request(**new_request) + resp = await self.engine.add_request_async(**new_request) + if isinstance(resp, Exception): + request_id = new_request["request_id"] + self._request_tracker.propagate_exception( + resp, request_id=request_id) + self._abort(request_id) if finished_requests: await self._engine_abort(finished_requests) @@ -366,6 +472,7 @@ async def add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, ) -> AsyncStream: if self.log_requests: shortened_prompt = prompt @@ -379,7 +486,8 @@ async def add_request( logger.info(f"Received request {request_id}: " f"prompt: {shortened_prompt!r}, " f"sampling params: {sampling_params}, " - f"prompt token ids: {shortened_token_ids}.") + f"prompt token ids: {shortened_token_ids}, " + f"lora_request: {lora_request}.") if not self.is_running: if self.start_engine_loop: @@ -391,22 +499,32 @@ async def add_request( "error that caused the background loop to stop " "(AsyncEngineDeadError).") + if arrival_time is None: + arrival_time = time.time() + prompt_token_ids = await self.engine.encode_request_async( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) + stream = self._request_tracker.add_request( request_id, prompt=prompt, sampling_params=sampling_params, prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) + arrival_time=arrival_time, + lora_request=lora_request, + ) return stream async def generate( - self, - prompt: Optional[str], - sampling_params: SamplingParams, - request_id: str, - prompt_token_ids: Optional[List[int]] = None - ) -> AsyncIterator[RequestOutput]: + self, + prompt: Optional[str], + sampling_params: SamplingParams, + request_id: str, + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None) -> RequestOutput: """Generate outputs for a request. Generate outputs for a request. This method is a coroutine. It adds the @@ -430,11 +548,14 @@ async def generate( arrival_time = time.monotonic() try: - stream = await self.add_request(request_id, - prompt, - sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) + stream = await self.add_request( + request_id, + prompt, + sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time, + lora_request=lora_request, + ) async for request_output in stream: yield request_output @@ -462,7 +583,10 @@ async def abort(self, request_id: str) -> None: return self._abort(request_id) - def _abort(self, request_id: str) -> None: + def _abort(self, + request_id: str, + *, + verbose: Optional[bool] = None) -> None: """Abort a request. Abort a submitted request. If the request is finished or not found, @@ -471,8 +595,9 @@ def _abort(self, request_id: str) -> None: Args: request_id: The unique id of the request. 
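+            verbose: If given, overrides the engine-level ``log_requests``
+                setting when logging this abort.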
""" - self._request_tracker.abort_request(request_id, - verbose=self.log_requests) + self._request_tracker.abort_request( + request_id, + verbose=self.log_requests if verbose is None else verbose) async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" @@ -503,3 +628,18 @@ def from_engine_args(cls, max_log_len=engine_args.max_log_len, start_engine_loop=start_engine_loop) return engine + + async def check_health(self): + t = time.perf_counter() + logger.debug("Starting health check...") + if self.is_stopped: + raise RuntimeError("Background loop is stopped.") + + if self.engine_use_ray: + try: + await self.engine.check_health.remote() + except ray.exceptions.RayActorError as e: + raise RuntimeError("Engine is dead.") from e + else: + await self.engine.check_health_async() + logger.debug(f"Health check took {time.perf_counter()-t}s") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d6e388bf135b2..3506dc1004b32 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,34 +1,43 @@ -import copy import time from functools import partial from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union +import msgspec + +from vllm.anyscale.shm.msgspec_shm import RayEvent, SharedMsgspecBufferWithEvent, SharedMemoryManager +from vllm.anyscale.lora.utils import LoRARequest +from vllm.anyscale.tokenization import TransformersTokenizer, RayTokenizerPool from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, LoadConfig, SpeculativeConfig, + LoRAConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs -from vllm.engine.metrics import record_metrics -from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray +from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, - SequenceGroupMetadata, SequenceGroupOutput, - SequenceOutput, SequenceStatus) -from vllm.transformers_utils.tokenizer import (detokenize_incrementally, - get_tokenizer) + SequenceGroupMetadata, SequenceOutputs, + SequenceGroupOutputs, SequenceStatus, + ExecuteModelData, SequenceGroupMetadataDelta, + DraftTargetWorkerMetrics) +from vllm.transformers_utils.tokenizer import detokenize_incrementally from vllm.utils import Counter +from vllm.worker.base_worker import BaseLoraWorker +from vllm.anyscale.profiler_utils import TorchProfiler if ray: from ray.air.util.torch_dist import init_torch_dist_process_group - from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup + from vllm.worker.worker import Worker # pylint: disable=ungrouped-imports logger = init_logger(__name__) _LOGGING_INTERVAL_SEC = 5 +SHARED_MEMORY_BUFFER_SIZE = int(5e+7) # 50 MB class LLMEngine: @@ -66,6 +75,9 @@ def __init__( cache_config: CacheConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, + load_config: LoadConfig, + speculative_config: SpeculativeConfig, + lora_config: Optional[LoRAConfig], distributed_init_method: str, placement_group: Optional["PlacementGroup"], log_stats: bool, @@ -84,124 +96,254 @@ def __init__( 
             f"load_format={model_config.load_format}, "
             f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
             f"quantization={model_config.quantization}, "
-            f"enforce_eager={model_config.enforce_eager}, "
-            f"seed={model_config.seed})")
+            f"seed={model_config.seed}")
+
+        if load_config is not None:
+            logger.info(
+                "Trying to initialize the model with"
+                f" s3://{load_config.s3_bucket}/{load_config.s3_prefix}")
+
         # TODO(woosuk): Print more configs in debug mode.
         self.model_config = model_config
         self.cache_config = cache_config
+        assert self.cache_config.sliding_window == getattr(
+            self.model_config.hf_config, "sliding_window", None)
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.load_config = load_config
+        self.lora_config = lora_config
+        self.speculative_config = speculative_config
         self.log_stats = log_stats
         self._verify_args()
 
-        self.tokenizer = get_tokenizer(
-            model_config.tokenizer,
-            tokenizer_mode=model_config.tokenizer_mode,
-            trust_remote_code=model_config.trust_remote_code,
-            tokenizer_revision=model_config.tokenizer_revision,
-            revision=model_config.revision)
         self.seq_counter = Counter()
+        self._init_tokenizer()
+
+        self.shared_mem_manager = None
+        self.shared_mem_event = None
+        self.shared_mem_engine_to_worker = None
+        self.shared_mem_worker_to_engine = None
 
         # Create the parallel GPU workers.
         if self.parallel_config.worker_use_ray:
-            self._init_workers_ray(placement_group)
+            additional_ray_args = {}
+            if self.parallel_config.ray_workers_use_nsight:
+                logger.info("Configuring Ray workers to use nsight.")
+                additional_ray_args = {"runtime_env": {"nsight": "default"}}
+
+            self._init_workers_ray(placement_group, **additional_ray_args)
+            runtime_contexts = self._run_workers("get_runtime_context",
+                                                 get_all_outputs=True)
+            # If the engine and all workers are on the same node,
+            # we can use shared memory.
+            if (not self.parallel_config.disable_shared_memory
+                    and all(runtime_context["node_id"] ==
+                            ray.get_runtime_context().get_node_id()
+                            for runtime_context in runtime_contexts)):
+                logger.info("Using shared memory for communication between "
+                            "engine and workers.")
+                self.shared_mem_manager = SharedMemoryManager()
+                self.shared_mem_manager.start()  # pylint: disable=consider-using-with
+                # Reusing the same event for both buffers is fine, as there's
+                # no situation in which we'd only want to wake up one buffer.
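+                # Two one-directional buffers are created below: the engine
+                # writes model inputs into one and reads sampler outputs from
+                # the other; both are woken through the single RayEvent actor
+                # pinned to the engine's node.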
+ self.shared_mem_event = RayEvent.options( + num_cpus=0, + scheduling_strategy=NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().get_node_id(), + soft=False)).remote() + self.shared_mem_engine_to_worker = SharedMsgspecBufferWithEvent( + size=SHARED_MEMORY_BUFFER_SIZE, + manager=self.shared_mem_manager, + encoder_init_fn=msgspec.msgpack.Encoder, + decoder_init_fn=lambda: msgspec.msgpack.Decoder(type=List[ + SamplerOutput]), + ray_event=self.shared_mem_event, + ) + self.shared_mem_worker_to_engine = SharedMsgspecBufferWithEvent( + size=SHARED_MEMORY_BUFFER_SIZE, + manager=self.shared_mem_manager, + encoder_init_fn=msgspec.msgpack.Encoder, + decoder_init_fn=lambda: msgspec.msgpack.Decoder(type=List[ + SamplerOutput]), + ray_event=self.shared_mem_event, + ) + logger.info( + "Engine shared memory input buffer id: " + f"{self.shared_mem_engine_to_worker.participant_id}") + logger.info( + "Engine shared memory output buffer id: " + f"{self.shared_mem_worker_to_engine.participant_id}") else: self._init_workers(distributed_init_method) + # Make sure the tokenizer actors are alive + self.tokenizer.ping() + # Profile the memory usage and initialize the cache. self._init_cache() # Create the scheduler. - self.scheduler = Scheduler(scheduler_config, cache_config) + self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) + + self._exceute_model_futures = None + if self._uses_shared_memory: + self._exceute_model_futures = self._run_workers( + "execute_model_shared_memory", + get_all_outputs=True, + wait_for_workers=False, + shared_memory_input=self.shared_mem_engine_to_worker, + shared_memory_output=self.shared_mem_worker_to_engine, + participant_id=self.shared_mem_engine_to_worker.participant_id) # Logging. self.last_logging_time = 0.0 + self.last_stats: Tuple[float, dict] = None # List of (timestamp, num_tokens) self.num_prompt_tokens: List[Tuple[float, int]] = [] # List of (timestamp, num_tokens) self.num_generation_tokens: List[Tuple[float, int]] = [] + self.num_started_tasks = 0 + self.num_finished_tasks = 0 + self.num_aborted_tasks = 0 + self.num_iterations = 0 + + self._last_draft_target_worker_metrics: Optional[ + DraftTargetWorkerMetrics] = None + + self._profiler = TorchProfiler() + + @property + def _uses_shared_memory(self) -> bool: + return self.shared_mem_engine_to_worker is not None + + def _init_tokenizer(self, **kwargs): + init_kwargs = dict( + enable_lora=bool(self.lora_config), + max_num_seqs=self.scheduler_config.max_num_seqs, + max_input_length=None, + tokenizer_mode=self.model_config.tokenizer_mode, + trust_remote_code=self.model_config.trust_remote_code, + revision=self.model_config.tokenizer_revision) + init_kwargs.update(kwargs) + if self.parallel_config.num_tokenizer_actors > 0: + ray_actor_options = (self.parallel_config.tokenizer_actor_options + or { + "num_cpus": 0 + }) + ray_actor_options[ + "scheduling_strategy"] = NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().get_node_id(), + soft=False) + + self.tokenizer: RayTokenizerPool = RayTokenizerPool( + self.model_config.tokenizer, + num_actors=self.parallel_config.num_tokenizer_actors, + ray_actor_options=ray_actor_options, + **init_kwargs) + else: + self.tokenizer: TransformersTokenizer = TransformersTokenizer( + self.model_config.tokenizer, **init_kwargs) - def _init_workers(self, distributed_init_method: str): - # Lazy import the Worker to avoid importing torch.cuda/xformers + def _create_worker( + self, rank: Optional[int], + distributed_init_method: Optional[str]) -> 
BaseLoraWorker: + # Lazy import the Worker classes to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import Worker - - assert self.parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") - - self.workers: List[Worker] = [] - worker = Worker( + from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel + from vllm.worker.multi_step_worker import MultiStepWorker # pylint: disable=import-outside-toplevel + from vllm.worker.single_tp_worker import SingleTpWorker # pylint: disable=import-outside-toplevel + from vllm.worker.draft_target_worker import DraftTargetWorker # pylint: disable=import-outside-toplevel + + if not self.speculative_config: + return Worker( + self.model_config, + self.parallel_config, + self.scheduler_config, + rank, + distributed_init_method, + load_config=self.load_config, + lora_config=self.lora_config, + ) + + target_worker = Worker( self.model_config, self.parallel_config, - self.scheduler_config, - 0, + self.speculative_config.create_target_scheduler_config( + self.scheduler_config), + rank, distributed_init_method, + load_config=self.load_config, + lora_config=self.lora_config, ) - self.workers.append(worker) - self._run_workers( - "init_model", - get_all_outputs=True, + + draft_worker = MultiStepWorker( + self.speculative_config.draft_model_config, + self.speculative_config.draft_parallel_config, + self.speculative_config.create_draft_scheduler_config( + self.scheduler_config), + rank, + distributed_init_method, + load_config=self.load_config, + lora_config=self.lora_config, ) + draft_worker = SingleTpWorker.maybe_wrap_worker( + draft_worker, self.speculative_config.draft_parallel_config, + self.parallel_config) + return DraftTargetWorker.from_workers(draft_worker, target_worker) + + def _init_workers(self, distributed_init_method: str): + assert self.parallel_config.world_size == 1, ( + "Ray is required if parallel_config.world_size > 1.") + rank = 0 + + self.workers: List[BaseLoraWorker] = [ + self._create_worker(rank, distributed_init_method) + ] self._run_workers( - "load_model", + "init_model", get_all_outputs=True, - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers, ) def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): - # Lazy import the Worker to avoid importing torch.cuda/xformers - # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import Worker - self.workers: List[Worker] = [] + self.workers: List[BaseLoraWorker] = [] for bundle in placement_group.bundle_specs: if not bundle.get("GPU", 0): continue - if self.parallel_config.tensor_parallel_size == 1: - num_gpus = self.cache_config.gpu_memory_utilization - else: - num_gpus = 1 worker = ray.remote( num_cpus=0, - num_gpus=num_gpus, + num_gpus=1, scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=placement_group, placement_group_capture_child_tasks=True), **ray_remote_kwargs, - )(RayWorkerVllm).remote(self.model_config.trust_remote_code) + )(RayWorker).remote(self.model_config.trust_remote_code) self.workers.append(worker) # Initialize torch distributed process group for the workers. 
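+        # One RayWorker was created above for every GPU bundle in the
+        # placement group; they join an NCCL process group here before the
+        # per-rank Worker (or speculative DraftTargetWorker) objects are
+        # constructed via init_worker.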
init_torch_dist_process_group(self.workers, backend="nccl") - model_config = copy.deepcopy(self.model_config) - parallel_config = copy.deepcopy(self.parallel_config) - scheduler_config = copy.deepcopy(self.scheduler_config) - self._run_workers("init_worker", - get_all_outputs=True, - worker_init_fn=lambda: Worker( - model_config, - parallel_config, - scheduler_config, - None, - None, - )) + self._run_workers( - "init_model", + "init_worker", get_all_outputs=True, + worker_init_fn=partial(self._create_worker, + rank=None, + distributed_init_method=None), ) self._run_workers( - "load_model", + "init_model", get_all_outputs=True, - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers, ) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.lora_config: + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_with_scheduler_config( + self.scheduler_config) def _init_cache(self) -> None: """Profiles the memory usage and initializes the KV cache.""" @@ -241,9 +383,6 @@ def _init_cache(self) -> None: # Initialize the cache. self._run_workers("init_cache_engine", cache_config=self.cache_config) - # Warm up the model. This includes capturing the model into CUDA graph - # if enforce_eager is False. - self._run_workers("warm_up_model") @classmethod def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": @@ -261,6 +400,20 @@ def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": log_stats=not engine_args.disable_log_stats) return engine + def encode_request( + self, + request_id: str, # pylint: disable=unused-argument + prompt: Optional[str], + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + ): + if prompt_token_ids is None: + assert prompt is not None + prompt_token_ids = self.tokenizer.encode(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + return prompt_token_ids + def add_request( self, request_id: str, @@ -268,6 +421,7 @@ def add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, ) -> None: """Add a request to the engine's request pool. @@ -285,23 +439,31 @@ def add_request( arrival_time: The arrival time of the request. If None, we use the current monotonic time. """ + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") if arrival_time is None: arrival_time = time.monotonic() - if prompt_token_ids is None: - assert prompt is not None - prompt_token_ids = self.tokenizer.encode(prompt) + prompt_token_ids = self.encode_request( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) # Create the sequences. block_size = self.cache_config.block_size seq_id = next(self.seq_counter) - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) + seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, + lora_request) # Create the sequence group. - seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time) + seq_group = SequenceGroup(request_id, [seq], + sampling_params, arrival_time, + time.perf_counter(), lora_request) # Add the sequence group to the scheduler. 
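+        # At this point the request is a single-sequence group in the
+        # scheduler's waiting queue; it will be picked up for (chunked)
+        # prefilling on a later call to step().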
self.scheduler.add_seq_group(seq_group) + self.num_started_tasks += 1 def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: """Aborts a request(s) with the given ID. @@ -309,7 +471,7 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: Args: request_id: The ID(s) of the request to abort. """ - self.scheduler.abort_seq_group(request_id) + self.num_aborted_tasks += self.scheduler.abort_seq_group(request_id) def get_model_config(self) -> ModelConfig: """Gets the model configuration.""" @@ -325,60 +487,116 @@ def has_unfinished_requests(self) -> bool: def _schedule( self - ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, - List[RequestOutput]]: + ) -> Tuple[List[Union[SequenceGroupMetadata, SequenceGroupMetadataDelta]], + SchedulerOutputs, List[RequestOutput]]: seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() return seq_group_metadata_list, scheduler_outputs, [ RequestOutput.from_seq_group(seq_group) for seq_group in scheduler_outputs.ignored_seq_groups ] - def _check_beam_search_early_stopping( - self, - early_stopping: Union[bool, str], - sampling_params: SamplingParams, - best_running_seq: Sequence, - current_worst_seq: Sequence, - ) -> bool: - assert sampling_params.use_beam_search - length_penalty = sampling_params.length_penalty - if early_stopping is True: - return True - - current_worst_score = (current_worst_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) - if early_stopping is False: - highest_attainable_score = (best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) - else: - assert early_stopping == "never" - if length_penalty > 0.0: - # If length_penalty > 0.0, beam search will prefer longer - # sequences. The highest attainable score calculation is - # based on the longest possible sequence length in this case. - max_possible_length = max( - best_running_seq.get_prompt_len() + - sampling_params.max_tokens, - self.scheduler_config.max_model_len) - highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id, - seq_len=max_possible_length)) - else: - # Otherwise, beam search will prefer shorter sequences. The - # highest attainable score calculation is based on the current - # sequence length. 
- highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) - return current_worst_score >= highest_attainable_score + # def _check_beam_search_early_stopping( + # self, + # early_stopping: Union[bool, str], + # sampling_params: SamplingParams, + # best_running_seq: Sequence, + # current_worst_seq: Sequence, + # ) -> bool: + # assert sampling_params.use_beam_search + # length_penalty = sampling_params.length_penalty + # if early_stopping is True: + # return True + + # current_worst_score = (current_worst_seq.get_beam_search_score( + # length_penalty=length_penalty, + # eos_token_id=self.tokenizer.get_lora_tokenizer( + # current_worst_seq.lora_request).eos_token_id)) + # if early_stopping is False: + # highest_attainable_score = ( + # best_running_seq.get_beam_search_score( + # length_penalty=length_penalty, + # eos_token_id=self.tokenizer.get_lora_tokenizer( + # best_running_seq.lora_request).eos_token_id)) + # else: + # assert early_stopping == "never" + # if length_penalty > 0.0: + # # If length_penalty > 0.0, beam search will prefer longer + # # sequences. The highest attainable score calculation is + # # based on the longest possible sequence length in this case. + # max_possible_length = max( + # best_running_seq.get_prompt_len() + + # sampling_params.max_tokens, + # self.scheduler_config.max_model_len) + # highest_attainable_score = ( + # best_running_seq.get_beam_search_score( + # length_penalty=length_penalty, + # eos_token_id=self.tokenizer.get_lora_tokenizer( + # best_running_seq.lora_request).eos_token_id, + # seq_len=max_possible_length)) + # else: + # # Otherwise, beam search will prefer shorter sequences. The + # # highest attainable score calculation is based on the current + # # sequence length. + # highest_attainable_score = ( + # best_running_seq.get_beam_search_score( + # length_penalty=length_penalty, + # eos_token_id=self.tokenizer.get_lora_tokenizer( + # best_running_seq.lora_request).eos_token_id)) + # return current_worst_score >= highest_attainable_score + + def _process_spec_decode_sequence_group_outputs( + self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutputs]) -> None: + """Process sequence group outputs when speculative decoding is enabled. + + This serves the same purpose as _process_sequence_group_outputs except + without any of the beam search logic. + """ + seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + assert len(seqs) == 1, ("Beam search not supported in speculative " + "decoding.") + seq = seqs[0] + + # Since there's only one sequence per sequence group, we can take the + # first sample. + samples = [outputs[step].samples[0] for step in range(len(outputs))] + + # Draft target worker pads all outputs with -1 to have same length. + output_token_ids = [ + sample.output_token for sample in samples + if sample.output_token != -1 + ] + output_logprobs = [sample.logprobs for sample in samples] + + # Truncate to max_tokens if necessary. + remaining_tokens = seq_group.sampling_params.max_tokens - ( + seq.get_output_len() + len(output_token_ids)) + if remaining_tokens < 0: + output_token_ids = output_token_ids[:remaining_tokens] + output_logprobs = output_logprobs[:remaining_tokens] + + # Truncate any tokens after EOS. This is required as spec decode + # generates tokens in fixed blocks, which may go beyond the EOS token. 
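+        # For example (hypothetical values): with num_speculative_tokens=4
+        # the worker may return [t1, t2, EOS, t4]; everything after the EOS
+        # token is dropped here, while the -1 padding slots were already
+        # filtered out above.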
+ if not seq_group.sampling_params.ignore_eos: + eos_token_id = self.tokenizer.get_lora_tokenizer( + seq.lora_request).eos_token_id + # Avoiding .index calls as exception throwing in the happy path + # is expensive. + for i in range(len(output_token_ids)): + if output_token_ids[i] == eos_token_id: + output_token_ids = output_token_ids[:i + 1] + output_logprobs = output_logprobs[:i + 1] + break + + seq.append_token_ids(output_token_ids, output_logprobs) + + self._decode_sequence(seq, seq_group.sampling_params) + self._check_stop(seq, seq_group.sampling_params) + if seq.is_finished(): + self.scheduler.free_seq(seq) def _process_sequence_group_outputs(self, seq_group: SequenceGroup, - outputs: SequenceGroupOutput) -> None: + outputs: SequenceGroupOutputs) -> None: # Process prompt logprobs prompt_logprobs = outputs.prompt_logprobs if prompt_logprobs is not None: @@ -399,7 +617,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Process the child samples for each parent sequence for parent in parent_seqs: - child_samples: List[SequenceOutput] = parent_child_dict[ + child_samples: List[SequenceOutputs] = parent_child_dict[ parent.seq_id] if len(child_samples) == 0: # This parent sequence has no children samples. Remove @@ -451,7 +669,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Select the child sequences to keep in the sequence group. selected_child_seqs = [] unselected_child_seqs = [] - beam_width = seq_group.sampling_params.best_of + beam_width = seq_group.sampling_params.actual_best_of length_penalty = seq_group.sampling_params.length_penalty # Select the newly finished sequences with the highest scores @@ -465,7 +683,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Sort the finished sequences by their scores. all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id), + eos_token_id=self.tokenizer.get_lora_tokenizer(x[0].lora_request + ).eos_token_id), reverse=True) for seq, parent, is_new in all_finished_seqs[:beam_width]: if is_new: @@ -493,7 +712,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Sort the running sequences by their scores. running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id), + eos_token_id=self.tokenizer.get_lora_tokenizer(x[0].lora_request + ).eos_token_id), reverse=True) # Check if we can stop the beam search. @@ -549,12 +769,35 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, self.scheduler.free_seq(seq) def _process_model_outputs( - self, output: SamplerOutput, + self, output: List[SamplerOutput], scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: # Update the scheduled sequence groups with the model outputs. + now = time.time() scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups - for seq_group, outputs in zip(scheduled_seq_groups, output): - self._process_sequence_group_outputs(seq_group, outputs) + + # Organize list of sampler output by sequence group. + output_by_sequence_group = [[] for _ in scheduled_seq_groups] + for step in output: + for i, sequence_group_output in enumerate(step): + output_by_sequence_group[i].append(sequence_group_output) + + # combine all samples for zipping + for i, (seq_group, outputs) in enumerate( + zip(scheduled_seq_groups, output_by_sequence_group)): + # Chunked prefill groups are not generation tokens. 
Their + # outputs are ignored. For seq_group finished chunked + # prefilling, it will be considered as prompting. + if i < scheduler_outputs.num_chunked_prefill_groups: + continue + if seq_group.first_token_time is None: + seq_group.first_token_time = now + + if self.speculative_config: + self._process_spec_decode_sequence_group_outputs( + seq_group, outputs) + else: + assert len(outputs) == 1 + self._process_sequence_group_outputs(seq_group, outputs[0]) # Free the finished sequence groups. self.scheduler.free_finished_seq_groups() @@ -565,14 +808,27 @@ def _process_model_outputs( scheduler_outputs.ignored_seq_groups): request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) + self.num_finished_tasks += int(request_output.finished) + + # Write logits to request outputs if present in sampler outputs. + for i, step in enumerate(output): + if step and step.logits is not None: + request_outputs[i].logits = step.logits + + # If worker metrics are provided, store locally. + if (self.speculative_config and output + and output[0].draft_target_worker_metrics is not None): + self._last_draft_target_worker_metrics = output[ + 0].draft_target_worker_metrics if self.log_stats: # Log the system stats. - self._log_system_stats(scheduler_outputs.prompt_run, + self._log_system_stats(scheduler_outputs.num_prompt_groups, scheduler_outputs.num_batched_tokens) + self.num_iterations += 1 return request_outputs - def step(self) -> List[RequestOutput]: + def step(self, return_logits: bool = False) -> List[RequestOutput]: """Performs one decoding iteration and returns newly generated results. This function performs one decoding iteration of the engine. It first @@ -580,21 +836,42 @@ def step(self) -> List[RequestOutput]: token blocks to be swapped in/out/copy. Then, it executes the model and updates the scheduler with the model outputs. Finally, it decodes the sequences and returns the newly generated results. + + Args: + return_logits: Whether to return the logits from the model for + quality evaluation purposes. """ seq_group_metadata_list, scheduler_outputs, ignored = self._schedule() if scheduler_outputs.is_empty(): return ignored - # Execute the model. - output = self._run_workers( - "execute_model", + data = ExecuteModelData( seq_group_metadata_list=seq_group_metadata_list, + finished_request_ids_list=list( + scheduler_outputs.done_seq_group_ids), blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, blocks_to_copy=scheduler_outputs.blocks_to_copy, + num_preallocated_slots=scheduler_outputs.num_preallocated_slots, + return_logits=return_logits, ) + # Execute the model. 
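+        # Mirrors step_async: with shared memory enabled, ExecuteModelData
+        # travels through the shm buffers rather than a per-step Ray call;
+        # the wall-clock time of the call is logged at debug level below.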
+ now = time.perf_counter() + output = self._run_workers("execute_model", + data, + use_shared_memory=self._uses_shared_memory) + logger.debug(f"model execution takes {time.perf_counter() - now}") + + outputs = self._process_model_outputs(output, scheduler_outputs) - return self._process_model_outputs(output, scheduler_outputs) + if self._uses_shared_memory: + if not outputs or all(out.finished for out in outputs): + self.shared_mem_engine_to_worker.clear() + self.shared_mem_worker_to_engine.clear() + self.shared_mem_engine_to_worker.put_to_sleep(block=False) + self.shared_mem_worker_to_engine.put_to_sleep(block=False) + + return outputs def _log_system_stats( self, @@ -608,8 +885,8 @@ def _log_system_stats( else: self.num_generation_tokens.append((now, num_batched_tokens)) - should_log = now - self.last_logging_time >= _LOGGING_INTERVAL_SEC - if not should_log: + elapsed_time = now - self.last_logging_time + if elapsed_time < _LOGGING_INTERVAL_SEC: return # Discard the old stats. @@ -648,16 +925,6 @@ def _log_system_stats( else: cpu_cache_usage = 0.0 - record_metrics( - avg_prompt_throughput=avg_prompt_throughput, - avg_generation_throughput=avg_generation_throughput, - scheduler_running=len(self.scheduler.running), - scheduler_swapped=len(self.scheduler.swapped), - scheduler_waiting=len(self.scheduler.waiting), - gpu_cache_usage=gpu_cache_usage, - cpu_cache_usage=cpu_cache_usage, - ) - logger.info("Avg prompt throughput: " f"{avg_prompt_throughput:.1f} tokens/s, " "Avg generation throughput: " @@ -669,38 +936,73 @@ def _log_system_stats( f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%") self.last_logging_time = now - def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: - """Decodes the new token for a sequence.""" - (new_tokens, new_output_text, prefix_offset, - read_offset) = detokenize_incrementally( - self.tokenizer, - all_input_ids=seq.get_token_ids(), - prev_tokens=seq.tokens, - prefix_offset=seq.prefix_offset, - read_offset=seq.read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms.spaces_between_special_tokens, - ) - if seq.tokens is None: - seq.tokens = new_tokens - else: - seq.tokens.extend(new_tokens) - seq.prefix_offset = prefix_offset - seq.read_offset = read_offset - seq.output_text += new_output_text + self._record_system_stats(avg_prompt_throughput, + avg_generation_throughput, gpu_cache_usage, + cpu_cache_usage) + + if self._last_draft_target_worker_metrics is not None: + metrics = self._last_draft_target_worker_metrics + logger.info( + "Speculative metrics: " + f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, " + f"System efficiency: {metrics.system_efficiency:.3f}, " + f"Number of speculative tokens: {metrics.num_spec_tokens}, " + f"Number of accepted tokens: {metrics.accepted_tokens}, " + f"Number of draft tokens: {metrics.draft_tokens}, " + f"Number of emitted tokens: {metrics.emitted_tokens}.") + + def _record_system_stats(self, avg_prompt_throughput: float, + avg_generation_throughput: float, + gpu_cache_usage: float, + cpu_cache_usage: float) -> Tuple[float, dict]: + self.last_stats = (self.last_logging_time, { + "avg_prompt_throughput": avg_prompt_throughput, + "avg_generation_throughput": avg_generation_throughput, + "gpu_cache_usage": gpu_cache_usage, + "cpu_cache_usage": cpu_cache_usage, + }) + return self.last_stats + + def _decode_sequence(self, seq: Sequence, + sampling_params: SamplingParams) -> None: + """Decodes new token(s) for a sequence.""" + unseen_token_ids =
seq.get_new_token_ids() + token_ids = seq.get_token_ids()[:-len(unseen_token_ids)] + + for new_token_id in unseen_token_ids: + token_ids.append(new_token_id) + + (new_tokens, new_output_text, prefix_offset, + read_offset) = detokenize_incrementally( + self.tokenizer.get_lora_tokenizer(seq.lora_request), + all_input_ids=token_ids, + prev_tokens=seq.tokens, + prefix_offset=seq.prefix_offset, + read_offset=seq.read_offset, + skip_special_tokens=sampling_params.skip_special_tokens, + spaces_between_special_tokens=sampling_params. + spaces_between_special_tokens, + ) + if seq.tokens is None: + seq.tokens = new_tokens + else: + seq.tokens.extend(new_tokens) + seq.prefix_offset = prefix_offset + seq.read_offset = read_offset + seq.output_text += new_output_text def _check_stop(self, seq: Sequence, sampling_params: SamplingParams) -> None: """Stop the finished sequences.""" for stop_str in sampling_params.stop: if seq.output_text.endswith(stop_str): - if not sampling_params.include_stop_str_in_output: - # Truncate the output text so that the stop string is - # not included in the output. - seq.output_text = seq.output_text[:-len(stop_str)] + # Truncate the output text so that the stop string is + # not included in the output. + seq.output_text = seq.output_text[:-len(stop_str)] seq.status = SequenceStatus.FINISHED_STOPPED return - if seq.get_last_token_id() in sampling_params.stop_token_ids: + if set(seq.get_new_token_ids()).intersection( + sampling_params.stop_token_ids): seq.status = SequenceStatus.FINISHED_STOPPED return @@ -710,63 +1012,146 @@ def _check_stop(self, seq: Sequence, return # Check if the sequence has reached max_tokens. - if seq.get_output_len() == sampling_params.max_tokens: + if seq.get_output_len() >= sampling_params.max_tokens: seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return # Check if the sequence has generated the EOS token. 
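Because speculative decoding can append several tokens to a sequence in a single engine step, _check_stop now looks at every newly generated token id (and uses ">=" for the max_tokens cap) rather than only the last token. Below is a minimal, self-contained sketch of that logic for illustration only; the helper name check_stop and the StopReason enum are hypothetical and not part of this patch.

from enum import Enum
from typing import List, Optional, Set


class StopReason(Enum):
    STOP_TOKEN = "stop_token"
    MAX_TOKENS = "max_tokens"
    EOS = "eos"


def check_stop(new_token_ids: List[int],
               output_len: int,
               max_tokens: int,
               stop_token_ids: Set[int],
               eos_token_id: int,
               ignore_eos: bool = False) -> Optional[StopReason]:
    # Any newly appended token can be a stop token, not just the last one,
    # because a speculative step may emit several tokens at once.
    if stop_token_ids and set(new_token_ids) & stop_token_ids:
        return StopReason.STOP_TOKEN
    # A multi-token step can overshoot the cap, hence ">=" rather than "==".
    if output_len >= max_tokens:
        return StopReason.MAX_TOKENS
    if not ignore_eos and eos_token_id in new_token_ids:
        return StopReason.EOS
    return None


print(check_stop([5, 7, 2], output_len=18, max_tokens=16,
                 stop_token_ids={9}, eos_token_id=2))  # StopReason.MAX_TOKENS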
if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == self.tokenizer.eos_token_id): + and self.tokenizer.get_lora_tokenizer( + seq.lora_request).eos_token_id in seq.get_new_token_ids()): seq.status = SequenceStatus.FINISHED_STOPPED return - def _run_workers_in_batch( - self, - workers, - method: str, - *args, - **kwargs, - ): - all_outputs = [] - for worker in workers: - if self.parallel_config.worker_use_ray: - executor = partial(worker.execute_method.remote, method) - else: - executor = getattr(worker, method) - - output = executor(*args, **kwargs) - all_outputs.append(output) - if self.parallel_config.worker_use_ray: - all_outputs = ray.get(all_outputs) - return all_outputs - def _run_workers( self, method: str, *args, get_all_outputs: bool = False, - max_concurrent_workers: Optional[int] = None, + wait_for_workers: bool = True, + use_shared_memory: bool = False, **kwargs, ) -> Any: """Runs the given method on all workers.""" - all_outputs = [] - if max_concurrent_workers: - work_groups = [ - self.workers[i:i + max_concurrent_workers] - for i in range(0, len(self.workers), max_concurrent_workers) - ] + if use_shared_memory: + try: + logger.debug(f"Set data to shared memory: {args[0]}") + self.shared_mem_engine_to_worker.set_data(args[0]) + except RuntimeError: + # Raise underlying exception + ray.get(self._exceute_model_futures, timeout=5) + raise + logger.debug("Waiting for incoming data...") + self.shared_mem_worker_to_engine.wait_for_incoming_data() + try: + output = self.shared_mem_worker_to_engine.get_data() + except RuntimeError: + # Raise underlying exception + ray.get(self._exceute_model_futures, timeout=5) + raise + logger.debug(f"Got data {output}") + self.shared_mem_worker_to_engine.clear() + return output else: - work_groups = [self.workers] + all_outputs = [] + start = time.time() - for workers in work_groups: - all_outputs.extend( - self._run_workers_in_batch(workers, method, *args, **kwargs)) + for worker in self.workers: + if self.parallel_config.worker_use_ray: + executor = partial(worker.execute_method.remote, method) + else: + executor = getattr(worker, method) - if get_all_outputs: - return all_outputs + output = executor(*args, **kwargs) + all_outputs.append(output) - # Make sure all workers have the same results. - output = all_outputs[0] - for other_output in all_outputs[1:]: - assert output == other_output - return output + if self.parallel_config.worker_use_ray: + if wait_for_workers: + all_outputs = ray.get(all_outputs) + + end = time.time() + + if method == "init_model": + logger.info("{} used {:.3f} seconds".format( + method, end - start)) + + if get_all_outputs: + return all_outputs + + # Make sure all workers have the same results. + output = all_outputs[0] + if wait_for_workers: + for other_output in all_outputs[1:]: + assert output == other_output + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "add_lora", + lora_request=lora_request, + ) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." 
+ return self._run_workers( + "remove_lora", + lora_id=lora_id, + ) + + def list_loras(self) -> List[int]: + return self._run_workers("list_loras") + + def get_metadata_cache_len(self) -> int: + return self._run_workers("get_metadata_cache_len", ) + + def _check_if_any_actor_is_dead(self): + workers = (self.workers + or []) + (self.tokenizer.tokenizer_actors if isinstance( + self.tokenizer, RayTokenizerPool) else []) + if workers: + dead_actors = [] + for actor in workers: + actor_state = ray.state.actors(actor._ray_actor_id.hex()) # pylint: disable=protected-access + if actor_state["State"] == "DEAD": + dead_actors.append(actor) + if dead_actors: + raise RuntimeError("At least one Worker is dead. " + f"Dead Workers: {dead_actors}. ") + + def check_health(self) -> None: + if not self.parallel_config.worker_use_ray: + return + + self._check_if_any_actor_is_dead() + if self._exceute_model_futures: + ready, _ = ray.wait(self._exceute_model_futures, timeout=0) + if ready: + # Raise any exception + ray.get(ready, timeout=1) + raise RuntimeError("At least one Worker is dead.") + + def __del__(self): + if getattr(self, "shared_mem_manager", None) is not None: + self.shared_mem_manager.shutdown() + + def start_profile(self, profile_ray_workers: bool, **kwargs): + """Start profiling. Can optionally run profiling in Ray workers. + """ + self._profiler.start_profile(**kwargs) + if profile_ray_workers: + if not self.parallel_config.worker_use_ray: + raise ValueError( + "Cannot profile ray workers: " + f" worker_use_ray={self.parallel_config.worker_use_ray:}") + + if not self.parallel_config.disable_shared_memory: + raise ValueError("Cannot profile ray workers: shared memory " + "must be disabled") + self._run_workers("start_profile", **kwargs) + + def stop_profile(self, profile_ray_workers: bool): + if self.parallel_config.worker_use_ray and profile_ray_workers: + self._run_workers("stop_profile") + + self._profiler.stop_profile() diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py new file mode 100644 index 0000000000000..d34cbbed55c4d --- /dev/null +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -0,0 +1,388 @@ +from typing import Tuple, Optional +from functools import cached_property + +import torch +import torch.nn as nn +import torch.jit + + +# torch.multinomial forces a GPU<->CPU sync. +# Therefore, we use an optimized implementation instead that skips the sync. +# Note that we always sample with replacement. +# probs will be modified in place, but this is fine, as we pass +# in a copy already. +@torch.jit.script +def _multinomial( + probs: torch.Tensor, + num_samples: int, +) -> torch.Tensor: + if num_samples > 1: + # This is equivalent to torch.repeat_interleaved (which also + # forces a GPU<->CPU sync). + probs = probs[:, None, :].expand(probs.shape[0], num_samples, + probs.shape[1]).contiguous().view( + -1, probs.shape[1]) + q = torch.empty_like(probs).exponential_(1.0) + return probs.div_(q).argmax(dim=1).view(-1, num_samples) + + +class RejectionSampler(nn.Module): + """Apply modified rejection sampling as described in "Accelerating Large + Language Model Decoding with Speculative Sampling" + https://arxiv.org/pdf/2302.01318.pdf. + """ + + def __init__(self, strict_mode: bool = False): + """Create a rejection sampler. + + Args: + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. 
+ """ + super().__init__() + self.probs_dtype = torch.float32 + self.token_id_dtype = torch.int64 + self._num_bonus_tokens = 1 + self._strict_mode = strict_mode + + self.num_accepted_tokens: Optional[torch.Tensor] = None + self.num_emitted_tokens: Optional[torch.Tensor] = None + self.num_draft_tokens: int = 0 + + def init_gpu_tensors(self, rank: int) -> None: + assert self.num_accepted_tokens is None + device = f"cuda:{rank}" + self.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + self.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + + def forward( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> torch.Tensor: + """Sample token ids using rejection sampling. This accepts or rejects + tokens proposed by the draft model using the probability of each token + according to the draft and target models. + + In the worst case where all draft tokens are rejected, it is guaranteed + one correct token will be emitted. + + In the case where all draft tokens are accepted, a bonus token will be + accepted as its cheap to have the target model score this speculative + sequence. + + Args: + target_probs: The probability distribution over token ids given + context according to the target model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + bonus_token_ids: The "bonus" token ids that are accepted iff all + speculative tokens in a sequence are accepted. + shape = [batch_size, num_bonus_tokens] + + draft_probs: The probability distribution over token ids given + context according to the draft model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + draft_token_ids: The token ids that were sampled from the draft + probabilities. + shape = [batch_size, num_speculative_tokens] + + Returns: + output_token_ids: The token ids sampled via rejection sampling, + or -1 if unable to sample a token because the previous token + was rejected. + shape = [batch_size, num_speculative_tokens + num_bonus_tokens] + """ + # Only perform shape/dtype/device checking in strict mode, as it adds + # overhead. + if self._strict_mode: + self._raise_if_incorrect_shape(target_probs, bonus_token_ids, + draft_probs, draft_token_ids) + self._raise_if_incorrect_dtype(target_probs, bonus_token_ids, + draft_probs, draft_token_ids) + self._raise_if_inconsistent_device(target_probs, bonus_token_ids, + draft_probs, draft_token_ids) + self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], + bonus_token_ids, + draft_token_ids) + + accepted, recovered_token_ids = self._batch_modified_rejection_sampling( + target_probs, + draft_probs, + draft_token_ids, + ) + + output_token_ids = self._create_output( + accepted, + recovered_token_ids, + draft_token_ids, + bonus_token_ids, + ) + return output_token_ids + + def _batch_modified_rejection_sampling( + self, + target_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_token_ids: torch.Tensor, # [batch_size, k] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Perform modified rejection sampling on each sequence. + + Returns: + A tuple of two tensors: + 0: A bool tensor of which tokens in each sequence is accepted. + shape = [batch_size, k] + 1: Token ids sampled from a recovered distribution, to be used + when a token is rejected. 
+ shape = [batch_size, k] + """ + + batch_size, k, vocab_size = draft_probs.shape + + # shape [batch_size, k] + accepted = self._get_accepted(target_probs, draft_probs, + draft_token_ids) + + recovered_probs = self._get_recovered_probs( + target_probs, draft_probs).reshape(batch_size * k, vocab_size) + + recovered_token_ids = _multinomial(recovered_probs, + num_samples=1).reshape( + batch_size, k) + return accepted, recovered_token_ids + + def _get_accepted( + self, + target_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_token_ids: torch.Tensor, # [batch_size, k] + ) -> torch.Tensor: + r"""Create bool matrix over the proposed draft tokens. If + True, then a token can be accepted, else it should be + rejected. + + Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of + :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according + to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the + same conditional probability according to the draft model, the token + is accepted with probability: + + .. math:: + \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)} + {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right) + + This implementation does not apply causality. When using the output, + if a token is rejected, subsequent tokens should not be used. + + Returns a bool tensor of shape [batch_size, k] specifying which tokens + are accepted. + """ + batch_size, k, _ = draft_probs.shape + batch_indices = torch.arange(batch_size, + device=target_probs.device)[:, None] + probs_indicies = torch.arange(k, device=target_probs.device) + + # shape [batch_size, k] + selected_draft_probs = draft_probs[batch_indices, probs_indicies, + draft_token_ids] + + # shape [batch_size, k] + selected_target_probs = target_probs[batch_indices, probs_indicies, + draft_token_ids] + + uniform_rand = torch.rand(batch_size, + k, + dtype=self.probs_dtype, + device=target_probs.device) + capped_ratio = torch.minimum( + selected_target_probs / selected_draft_probs, + torch.full((1, ), 1, device=target_probs.device)) + accepted = uniform_rand < capped_ratio + + return accepted + + def _get_recovered_probs( + self, + target_probs: torch.Tensor, # [k, vocab_size] + draft_probs: torch.Tensor, # [k, vocab_size] + ) -> torch.Tensor: + r"""Create a probability distribution for each proposed token which can + be sampled if the proposed token is rejected. + + When this routine is applied sequentially, the true distribution of the + target model is recovered (within hardware numerics). + + The probability distribution used in this rejection case is constructed + as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of + :math:`x` given context :math:`x_1, \dots, x_n` according to the target + model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability + according to the draft model: + + .. math:: + x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+ + + where :math:`(f(x))_+` is defined as: + + .. math:: + (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))} + + See https://github.com/anyscale/vllm/pull/157 for a visualization of the + draft, target, and recovered probability distributions. + + Returns a tensor of shape [batch_size, k, vocab_size]. + + Note: This batches operations on GPU and thus constructs the recovered + distribution for all tokens, even if they are accepted. This causes + division-by-zero errors, so we use self._smallest_positive_value to + avoid that. 
This introduces some drift to the distribution. + """ + _, k, _ = draft_probs.shape + + # shape [batch_size, k, vocab_size] + difference = target_probs - draft_probs + + # TODO(cade): Can we use logprobs instead of probs, and avoid the + # division-by-zero errors without introducing distribution drift? + + # shape [batch_size, k, vocab_size] + f = torch.clamp(difference, min=self._smallest_positive_value) + + # shape [batch_size, k, vocab_size] + recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1) + + return recovered_probs + + @cached_property + def _smallest_positive_value(self) -> float: + """Return the smallest positive value representable by the probs dtype. + This value is used when constructing a distribution from which to sample + recovered tokens in the first rejection case. + + See _get_recovered_probs for more details + + Note that this isn't actually the smallest positive value representable + by float32, but the smallest positive normal value. + See https://en.wikipedia.org/wiki/Subnormal_number for more information. + """ + return torch.finfo(self.probs_dtype).tiny + + def _create_output( + self, + accepted: torch.Tensor, # [batch_size, k] + recovered_token_ids: torch.Tensor, # [batch_size, k] + draft_token_ids: torch.Tensor, # [batch_size, k] + bonus_token_ids: torch.Tensor, # [batch_size] + ) -> torch.Tensor: + """Format output. Returns a matrix of token ids. When + a token is rejected via rejection sampling, all subsequent + token ids are set to -1 for the sequence. + + shape = [batch_size, k + num_bonus_tokens] + """ + bonus_token_ids = bonus_token_ids.squeeze() + batch_size, k = recovered_token_ids.shape + + # Determine the index of the first False value for each row. + limits = (accepted == 0).max(1).indices + limits[~(accepted == 0).any(1)] = k + + # Create masks using the indices. + indices = torch.arange(k, device=accepted.device).unsqueeze(0) + accepted_mask = indices < limits.unsqueeze(1) + after_false_mask = indices == limits.unsqueeze(1) + + # Create an extended output tensor + output_with_bonus_tokens = -torch.ones( + (batch_size, k + self._num_bonus_tokens), + dtype=self.token_id_dtype, + device=accepted.device) + output = output_with_bonus_tokens[:, :k] + + # Fill in the first k columns of the output tensor using masks and data + # tensors. + output[:, :k] = torch.where(accepted_mask, draft_token_ids, + -torch.ones_like(draft_token_ids)) + + # Fill the last column. + # We check output directly as accepted may have True values inconsistent + # with causal acceptance. + output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, + bonus_token_ids, -1) + + # Fill the recovered token ids. 
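To make the mask bookkeeping in _create_output concrete, here is a small standalone example (editor's sketch, not part of the patch) showing how limits, accepted_mask and after_false_mask turn per-token accept/reject decisions into the final output matrix, including the bonus-token rule. All tensor values are made up.

import torch

accepted = torch.tensor([[True, True, False],
                         [True, True, True]])   # [batch_size=2, k=3]
draft = torch.tensor([[11, 12, 13],
                      [21, 22, 23]])
recovered = torch.tensor([[91, 92, 93],
                          [94, 95, 96]])
bonus = torch.tensor([7, 8])

batch_size, k = draft.shape
# Index of the first rejected token per row; rows with no rejection get k.
limits = (accepted == 0).max(1).indices
limits[~(accepted == 0).any(1)] = k

idx = torch.arange(k).unsqueeze(0)
accepted_mask = idx < limits.unsqueeze(1)       # draft tokens that survive
after_false_mask = idx == limits.unsqueeze(1)   # slot for the recovered token

out = -torch.ones((batch_size, k + 1), dtype=torch.long)
# Keep accepted draft tokens, mark everything else as -1.
out[:, :k] = torch.where(accepted_mask, draft, torch.full_like(draft, -1))
# Bonus token only if the last draft position survived, i.e. everything was
# accepted; this is checked before the recovered token is written in.
out[:, -1] = torch.where(out[:, k - 1] != -1, bonus, torch.full_like(bonus, -1))
# Place the recovered token at the first rejected position.
out[:, :k] = out[:, :k] * ~after_false_mask + recovered * after_false_mask

print(out)
# tensor([[11, 12, 93, -1],   last draft token rejected -> recovered id, no bonus
#         [21, 22, 23,  8]])  all accepted -> bonus token appended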
+ output.mul_(~after_false_mask).add_( + recovered_token_ids.mul(after_false_mask)) + + self.num_accepted_tokens += accepted.sum() + self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() + self.num_draft_tokens += batch_size * k + + return output_with_bonus_tokens + + def _raise_if_incorrect_shape( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> None: + (target_batch_size, num_target_probs, + target_vocab_size) = target_probs.shape + bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape + draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape + draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape + + assert draft_batch_size == target_batch_size + assert num_draft_probs == num_target_probs + assert (draft_vocab_size == target_vocab_size + ), f"{draft_vocab_size=} {target_vocab_size=}" + + assert draft_token_ids_batch_size == draft_batch_size + assert num_draft_token_ids == num_draft_probs + + assert bonus_batch_size == target_batch_size + assert num_bonus_tokens == self._num_bonus_tokens + + def _raise_if_incorrect_dtype( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> None: + assert all(probs.dtype == self.probs_dtype + for probs in [target_probs, draft_probs]) + assert all(token_ids.dtype == self.token_id_dtype + for token_ids in [bonus_token_ids, draft_token_ids]) + + def _raise_if_inconsistent_device( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> None: + devices = [ + t.device for t in + [target_probs, bonus_token_ids, draft_probs, draft_token_ids] + ] + assert [devices[0] == device for device in devices] + + def _raise_if_out_of_bounds_vocab( + self, + vocab_size: int, + bonus_token_ids: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> None: + assert torch.all(bonus_token_ids < vocab_size) + assert torch.all(bonus_token_ids >= 0) + assert torch.all(draft_token_ids < vocab_size) + assert torch.all(draft_token_ids >= 0) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index fe88b0ea42936..05be95ea5319c 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1,15 +1,341 @@ """A layer that samples the next tokens from the model's outputs.""" +import math +import logging +from dataclasses import dataclass +from collections import defaultdict from typing import Dict, List, Optional, Tuple +import msgspec +import ray +import numpy as np +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy import torch +import torch.jit import torch.nn as nn +from vllm.anyscale.shm.msgspec_shm import RayEvent, SharedMsgspecBufferWithEvent, SharedMemoryManager +from vllm.anyscale.shm.numpy import numpy_encode_hook, numpy_ext_hook +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.sampler_ops.penalty_triton import apply_penalty from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_gather) -from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors -from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sampling_params import SamplingType from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput, - SequenceData, SequenceGroupOutput, 
SequenceOutput) + SequenceGroupOutputs, SequenceOutputs) +from vllm.utils import in_wsl + +_SAMPLING_EPS = 1e-5 +SHARED_MEMORY_BUFFER_SIZE = int(2e+7) # 20 MB +logger = logging.getLogger(__name__) + + +@dataclass +class SamplingTokenTensors: + """Datastructure used to encode sampling inputs in GPU tensors. + + This enables a non-blocking sampler. + """ + unique_output_token_ids: torch.Tensor + output_token_counts: torch.Tensor + sample_indices: torch.Tensor + prompt_indices: torch.Tensor + categorized_sample_indices: Tuple[torch.Tensor, torch.Tensor] + cumsum_penalties_seq_lens: torch.Tensor + max_penalties_seq_len: int + + @classmethod + def from_lists( + cls, + unique_output_token_ids: List[int], + output_token_counts: List[int], + penalties_seq_lens: List[int], + sample_indices: List[int], + prompt_indices: List[int], + categorized_sample_indices: Dict[SamplingType, List[int]], + device: torch.device, + vocab_size: int # pylint: disable=unused-argument + ) -> "SamplingTokenTensors": + # WSL doesn't support pinned memory. + # Note that the performance will be very bad without + # pinned memory. + pin_memory = not in_wsl() + + max_penalties_seq_len = max(penalties_seq_lens) + # Must have length of batch_size+1 for cumsum used by triton + # penalty kernel + # Represents the number of unique token ids in output for + # each sequence + penalties_seq_lens = [0] + penalties_seq_lens + + sampling_categories = (SamplingType.GREEDY, SamplingType.RANDOM) + indicies_list = sample_indices + prompt_indices + offset = len(indicies_list) + + for indicies in categorized_sample_indices.values(): + indicies_list.extend(indicies) + + output_count_int_tensor = torch.tensor( + [unique_output_token_ids, output_token_counts], + device="cpu", + dtype=torch.int, + pin_memory=pin_memory, + ) + penalties_seq_len_tensor = torch.tensor( + penalties_seq_lens, + device="cpu", + dtype=torch.int, + pin_memory=pin_memory, + ) + indices_tensor = torch.tensor( + indicies_list, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ) + output_count_int_tensor_gpu = output_count_int_tensor.to( + device=device, non_blocking=True) + penalties_seq_len_tensor_gpu = penalties_seq_len_tensor.to( + device=device, non_blocking=True) + indices_tensor_gpu = indices_tensor.to(device=device, + non_blocking=True) + categorized_sample_indices_tensors = [None] * len(sampling_categories) + for category in sampling_categories: + sample_indices_len = len( + categorized_sample_indices.get(category, [])) + categorized_sample_indices_tensors[category] = indices_tensor_gpu[ + offset:offset + sample_indices_len] + offset += sample_indices_len + + return cls( + sample_indices=indices_tensor_gpu[:len(sample_indices)], + prompt_indices=indices_tensor_gpu[len(sample_indices + ):len(sample_indices) + + len(prompt_indices)], + categorized_sample_indices=tuple( + categorized_sample_indices_tensors), + unique_output_token_ids=output_count_int_tensor_gpu[0], + output_token_counts=output_count_int_tensor_gpu[1], + cumsum_penalties_seq_lens=penalties_seq_len_tensor_gpu.cumsum(0), + max_penalties_seq_len=max_penalties_seq_len) + + @classmethod + def from_input_metadata(cls, input_metadata: InputMetadata, + vocab_size: int, device: torch.device): + unique_output_token_ids: List[int] = [] + output_token_counts: List[int] = [] + penalties_seq_lens: List[int] = [] + sample_indices: List[int] = [] + prompt_indices: List[int] = [] + categorized_sample_indices: Dict[SamplingType, + List[int]] = defaultdict(list) + + sample_indices_start_idx = 0 + 
categorized_indices_start_idx = 0 + + for i, seq_group in enumerate(input_metadata.seq_groups): + seq_ids, sampling_params = seq_group + + is_prompt = i < input_metadata.num_prompts + if is_prompt: + prompt_len = input_metadata.prompt_lens[i] + + if sampling_params.prompt_logprobs is not None: + prompt_indices.extend( + range(sample_indices_start_idx, + sample_indices_start_idx + prompt_len - 1)) + # NOTE: prompt token positions do not need sample, skip + sample_indices_start_idx += prompt_len - 1 + + categorized_sample_indices[ + sampling_params.sampling_type].append( + categorized_indices_start_idx) + sample_indices.append(sample_indices_start_idx) + sample_indices_start_idx += 1 + categorized_indices_start_idx += 1 + + for seq_id in seq_ids: + if sampling_params.has_penalties: + seq_data = input_metadata.seq_data[seq_id] + id_to_counts = seq_data.output_token_id_count + for token_id, count in id_to_counts.items(): + unique_output_token_ids.append(token_id) + output_token_counts.append(count) + penalties_seq_lens.append(len(id_to_counts)) + else: + penalties_seq_lens.append(0) + + if not is_prompt: + categorized_sample_indices[ + sampling_params.sampling_type].append( + categorized_indices_start_idx) + sample_indices.append(sample_indices_start_idx) + sample_indices_start_idx += 1 + categorized_indices_start_idx += 1 + + return cls.from_lists(unique_output_token_ids, output_token_counts, + penalties_seq_lens, sample_indices, + prompt_indices, categorized_sample_indices, + device, vocab_size) + + +@dataclass +class SamplingParametersTensors: + """Datastructure used to encode sampling inputs in GPU tensors. + + This enables a non-blocking sampler. + """ + temperatures: torch.Tensor + top_ps: torch.Tensor + top_ks: torch.Tensor + min_ps: torch.Tensor + presence_penalties: torch.Tensor + frequency_penalties: torch.Tensor + repetition_penalties: torch.Tensor + max_top_k: int + max_prompt_best_of: int + do_penalties: bool + do_top_p_top_k: bool + do_min_p: bool + largest_num_logprobs: int + + @classmethod + def from_lists(cls, temperatures: List[float], top_ps: List[float], + top_ks: List[int], min_ps: List[float], + presence_penalties: List[float], + frequency_penalties: List[float], + repetition_penalties: List[float], + prompt_best_of: List[int], do_penalties: bool, + do_top_p_top_k: bool, do_min_p: bool, + largest_num_logprobs: int, device: torch.device, + dtype: torch.dtype) -> "SamplingParametersTensors": + + # WSL doesn't support pinned memory. + # Note that the performance will be very bad without + # pinned memory. 
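The pinned-memory note above is what makes these CPU-to-GPU copies asynchronous. A minimal sketch of the pattern, assuming a CUDA device is available; the helper name to_gpu_async is hypothetical and not part of the patch.

import torch


def to_gpu_async(values, device="cuda", dtype=torch.float32):
    # Page-locked (pinned) host memory is what allows the copy below to be
    # truly asynchronous; with pageable memory the copy silently blocks.
    cpu_tensor = torch.tensor(values, dtype=dtype, pin_memory=True)
    return cpu_tensor.to(device=device, non_blocking=True)


if torch.cuda.is_available():
    temperatures = to_gpu_async([1.0, 0.7, 0.7, 1.3])
    # The copy is only guaranteed complete once the current stream is
    # synchronized (or another stream waits on it).
    torch.cuda.current_stream().synchronize()
    print(temperatures)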
+ pin_memory = not in_wsl() + + max_top_k = max(top_ks) + max_prompt_best_of = max(prompt_best_of) if prompt_best_of else 1 + + float_tensor = torch.tensor( + [ + temperatures, top_ps, min_ps, presence_penalties, + frequency_penalties, repetition_penalties + ], + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + int_tensor = torch.tensor( + [top_ks], + device="cpu", + dtype=torch.int, + pin_memory=pin_memory, + ) + float_tensor_gpu = float_tensor.to(device=device, non_blocking=True) + int_tensor_gpu = int_tensor.to(device=device, non_blocking=True) + + return cls(temperatures=float_tensor_gpu[0], + top_ps=float_tensor_gpu[1], + top_ks=int_tensor_gpu[0], + min_ps=float_tensor_gpu[2], + presence_penalties=float_tensor_gpu[3], + frequency_penalties=float_tensor_gpu[4], + repetition_penalties=float_tensor_gpu[5], + max_top_k=max_top_k, + max_prompt_best_of=max_prompt_best_of, + do_penalties=do_penalties, + do_top_p_top_k=do_top_p_top_k, + do_min_p=do_min_p, + largest_num_logprobs=largest_num_logprobs) + + @classmethod + def from_input_metadata(cls, input_metadata: InputMetadata, + vocab_size: int, device: torch.device, + dtype: torch.dtype): + top_ks: List[int] = [] + temperatures: List[float] = [] + top_ps: List[float] = [] + min_ps: List[float] = [] + presence_penalties: List[float] = [] + frequency_penalties: List[float] = [] + repetition_penalties: List[float] = [] + prompt_best_of: List[int] = [] + + do_penalties = False + do_top_p_top_k = False + do_min_p = False + largest_num_logprobs = 0 + + for i, seq_group in enumerate(input_metadata.seq_groups): + seq_ids, sampling_params = seq_group + is_prompt = i < input_metadata.num_prompts + + temperature = sampling_params.temperature + # k should not be greater than the vocab size. + top_k = min(sampling_params.top_k, vocab_size) + top_k = vocab_size if top_k == -1 else top_k + + if temperature < _SAMPLING_EPS: + # NOTE: Zero temperature means deterministic sampling + # (i.e., greedy sampling or beam search). + # Set the temperature to 1 to avoid division by zero. 
+ temperature = 1.0 + if not do_top_p_top_k and ( + sampling_params.top_p < 1.0 - _SAMPLING_EPS + or top_k != vocab_size): + do_top_p_top_k = True + if not do_min_p and sampling_params.min_p > _SAMPLING_EPS: + do_min_p = True + if not do_penalties and sampling_params.has_penalties: + do_penalties = True + + if is_prompt: + prompt_best_of.append(sampling_params.actual_best_of) + if sampling_params.prompt_logprobs is not None: + prompt_len = input_metadata.prompt_lens[i] + temperatures += [temperature] * (prompt_len - 1) + top_ps += [sampling_params.top_p] * (prompt_len - 1) + top_ks += [top_k] * (prompt_len - 1) + min_ps += [sampling_params.min_p] * (prompt_len - 1) + presence_penalties += [0] * (prompt_len - 1) + frequency_penalties += [0] * (prompt_len - 1) + repetition_penalties += [1] * (prompt_len - 1) + + top_ks += [top_k] * len(seq_ids) + temperatures += [temperature] * len(seq_ids) + top_ps += [sampling_params.top_p] * len(seq_ids) + min_ps += [sampling_params.min_p] * len(seq_ids) + presence_penalties += [sampling_params.presence_penalty + ] * len(seq_ids) + frequency_penalties += [sampling_params.frequency_penalty + ] * len(seq_ids) + repetition_penalties += [sampling_params.repetition_penalty + ] * len(seq_ids) + if sampling_params.logprobs: + largest_num_logprobs = max(largest_num_logprobs, + sampling_params.logprobs) + + return cls.from_lists(temperatures, top_ps, top_ks, min_ps, + presence_penalties, frequency_penalties, + repetition_penalties, prompt_best_of, + do_penalties, do_top_p_top_k, do_min_p, + largest_num_logprobs, device, dtype) + + +@dataclass +class RawSamplerOutput: + """Class containing sampler output stored in torch tensors. + """ + sampled_tokens: torch.Tensor + sampled_logprobs: torch.Tensor + prompt_logprobs: torch.Tensor + probs: torch.Tensor + sampling_parameters_tensors: "SamplingParametersTensors" + sampling_token_tensors: "SamplingTokenTensors" + top_logprobs: Optional[torch.Tensor] + top_token_ids: Optional[torch.Tensor] + logits: Optional[torch.Tensor] class Sampler(nn.Module): @@ -27,57 +353,103 @@ class Sampler(nn.Module): parameters (e.g., sampling method, temperature, top-p, top-k, etc.). """ - def __init__(self, vocab_size: int) -> None: + _copy_stream: Optional[torch.cuda.Stream] = None + + def __init__( + self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + include_gpu_probs_tensor: bool = False, + ) -> None: super().__init__() self.vocab_size = vocab_size - self._copy_stream: torch.cuda.Stream = torch.cuda.Stream() + self.org_vocab_size = org_vocab_size or vocab_size + self.include_gpu_probs_tensor = include_gpu_probs_tensor + + def __del__(self): + if getattr(self, "_shared_mem_manager", None) is not None: + self._shared_mem_manager.shutdown() + + def _get_logits( + self, + embedding: torch.Tensor, + hidden_states: torch.Tensor, + embedding_bias: Optional[torch.Tensor] = None) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.nn.functional.linear(hidden_states, embedding, + embedding_bias) + logits = tensor_model_parallel_all_gather(logits) + # Remove paddings in vocab (if any). 
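For context on the padding removal noted above: the embedding matrix may be padded beyond the real vocabulary (for example so it shards evenly across tensor-parallel ranks), so the raw logits carry extra columns that the slice below discards. A short illustrative sketch with made-up sizes (editor's example, not part of the patch):

import torch

org_vocab_size = 32000      # real vocabulary size
padded_vocab_size = 32256   # padded e.g. for tensor-parallel sharding
hidden_size = 16

hidden_states = torch.randn(4, hidden_size)
embedding = torch.randn(padded_vocab_size, hidden_size)

logits = torch.nn.functional.linear(hidden_states, embedding)  # [4, 32256]
logits = logits[:, :org_vocab_size]                            # [4, 32000]
print(logits.shape)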
+ logits = logits[:, :self.org_vocab_size] + return logits def forward( self, embedding: torch.Tensor, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, + input_metadata: InputMetadata, embedding_bias: Optional[torch.Tensor] = None, - ) -> SamplerOutput: + sampling_parameters_tensors: Optional[ + SamplingParametersTensors] = None, + sampling_token_tensors: Optional[SamplingTokenTensors] = None, + ) -> RawSamplerOutput: + # Get logits for entire sequence before pruning hidden states + # for model quality evaluation. + batched_seq_logits = None + if input_metadata.return_logits: + batched_seq_logits = self._get_logits( + embedding, hidden_states, embedding_bias).to(torch.float) + # Get the hidden states that we use for sampling. - hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) + hidden_states = _prune_hidden_states(hidden_states, input_metadata) # Get the logits for the next tokens. - logits = _get_logits(hidden_states, embedding, embedding_bias, - self.vocab_size) + logits = self._get_logits(embedding, hidden_states, embedding_bias) _, vocab_size = logits.shape - # Apply logits processors (if any). - logits = _apply_logits_processors(logits, sampling_metadata) - # Prepare sampling tensors in another stream to overlap # CPU<->GPU data transfer with GPU computation in forward pass. - with torch.cuda.stream(self._copy_stream): - (sampling_tensors, do_penalties, do_top_p_top_k, - do_min_p) = SamplingTensors.from_sampling_metadata( - sampling_metadata, vocab_size, logits.device, logits.dtype) - - torch.cuda.current_stream().wait_stream(self._copy_stream) + if not sampling_parameters_tensors or not sampling_token_tensors: + if Sampler._copy_stream is None: + # Initialize stream here once to make sure it uses the + # correct device. + Sampler._copy_stream = torch.cuda.Stream() + with torch.cuda.stream(Sampler._copy_stream): + if not sampling_parameters_tensors: + sampling_parameters_tensors = ( + SamplingParametersTensors.from_input_metadata( + input_metadata, vocab_size, logits.device, + logits.dtype)) + if not sampling_token_tensors: + sampling_token_tensors = ( + SamplingTokenTensors.from_input_metadata( + input_metadata, vocab_size, logits.device)) + + torch.cuda.current_stream().wait_stream(Sampler._copy_stream) # Apply presence and frequency penalties. - if do_penalties: - logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, - sampling_tensors.output_tokens, - sampling_tensors.presence_penalties, - sampling_tensors.frequency_penalties, - sampling_tensors.repetition_penalties) - - # Apply temperature scaling. + if sampling_parameters_tensors.do_penalties: + logits = _apply_penalties_triton( + logits, sampling_token_tensors.unique_output_token_ids, + sampling_token_tensors.output_token_counts, + sampling_token_tensors.cumsum_penalties_seq_lens, + sampling_token_tensors.max_penalties_seq_len, + sampling_parameters_tensors.presence_penalties, + sampling_parameters_tensors.frequency_penalties, + sampling_parameters_tensors.repetition_penalties) + # Use in-place division to avoid creating a new tensor. - logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1)) + logits.div_(sampling_parameters_tensors.temperatures.unsqueeze(dim=1)) - if do_top_p_top_k: - logits = _apply_top_p_top_k(logits, sampling_tensors.top_ps, - sampling_tensors.top_ks) + # Apply top-p and top-k truncation. 
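As a reference for the truncation step below, here is a single-row sketch of the same top-p/top-k masking: sort descending, drop tokens whose preceding cumulative probability exceeds top_p or whose rank is at least top_k, then restore the original order. This is an editor's illustration with a hypothetical helper name, not the patch's implementation.

import torch


def top_p_top_k(logits: torch.Tensor, p: float, k: int) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    rank = torch.arange(logits.shape[-1])
    # (probs_sum - probs_sort) is the cumulative mass *before* each token,
    # so the most probable token is always kept, even for tiny p.
    mask = ((probs_sum - probs_sort) > p) | (rank >= k)
    logits_sort = logits_sort.masked_fill(mask, float("-inf"))
    # Scatter the filtered values back to the original vocabulary order.
    out = torch.empty_like(logits)
    out.scatter_(-1, logits_idx, logits_sort)
    return out


print(top_p_top_k(torch.tensor([2.0, 1.0, 0.5, -1.0]), p=0.8, k=3))
# -> tensor([2., 1., -inf, -inf])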
+ if sampling_parameters_tensors.do_top_p_top_k: + logits = _apply_top_p_top_k(logits, + sampling_parameters_tensors.top_ps, + sampling_parameters_tensors.top_ks) - if do_min_p: - logits = _apply_min_p(logits, sampling_tensors.min_ps) + if sampling_parameters_tensors.do_min_p: + logits = _apply_min_p(logits, sampling_parameters_tensors.min_ps) # We use float32 for probabilities and log probabilities. # Compute the probabilities. @@ -86,60 +458,142 @@ def forward( # Use log_softmax to ensure numerical stability. logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) + if sampling_parameters_tensors.largest_num_logprobs > 0: + top_logprobs, top_token_ids = torch.topk( + logprobs, + sampling_parameters_tensors.largest_num_logprobs, + dim=-1) + else: + top_logprobs, top_token_ids = None, None + # Sample the next tokens. - sample_results = _sample(probs, logprobs, sampling_metadata) - # Get the logprobs query results. - prompt_logprobs, sample_logprobs = _get_logprobs( - logprobs, sampling_metadata, sample_results) - return _build_sampler_output(sample_results, sampling_metadata, - prompt_logprobs, sample_logprobs) - - -def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor, - embedding_bias: Optional[torch.Tensor], - vocab_size: int) -> torch.Tensor: - # Get the logits for the next tokens. - logits = torch.matmul(hidden_states, embedding.t()) - if embedding_bias is not None: - logits += embedding_bias - logits = tensor_model_parallel_all_gather(logits) - # Remove paddings in vocab (if any). - logits = logits[:, :vocab_size] - return logits + sampled_tokens, sampled_logprobs, prompt_logprobs = _sample( + probs=probs, + logprobs=logprobs, + prompt_token_indices=sampling_token_tensors.prompt_indices, + sample_indices=sampling_token_tensors.sample_indices, + categorized_sample_indices=sampling_token_tensors. + categorized_sample_indices, + max_best_of=sampling_parameters_tensors.max_prompt_best_of, + modify_greedy_probs=self.include_gpu_probs_tensor, + ) + + return RawSamplerOutput(sampled_tokens, sampled_logprobs, + prompt_logprobs, probs, + sampling_parameters_tensors, + sampling_token_tensors, top_logprobs, + top_token_ids, batched_seq_logits) + + +def _flatten_list(lst): + return [item for sublist in lst for item in sublist] + + +def pythonize_sampler_output(raw_sampler_output: RawSamplerOutput, + input_metadata: InputMetadata) -> SamplerOutput: + """Convert sampling output stored in PyTorch tensors to sampling output + stored in Python datastructures. + + This blocks the CPU until the GPU catches up, so should only be used when + necessary. + """ + # GPU<->CPU sync happens below. 
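The sync comment above is the reason pythonization is kept out of the sampler's hot path: converting CUDA tensors to Python objects blocks the CPU until the GPU work that produced them has finished. A tiny illustration (assumes a CUDA device; editor's sketch, not part of the patch):

import torch

if torch.cuda.is_available():
    sampled = torch.randint(0, 32000, (8, 4), device="cuda")
    # More kernels could still be queued here without any waiting...
    token_ids = sampled.tolist()  # ...but this blocks until `sampled` is ready
    print(token_ids[0])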
+ + samples = raw_sampler_output.sampled_tokens.tolist() + logprobs = raw_sampler_output.sampled_logprobs.tolist() + prompt_logprobs = raw_sampler_output.prompt_logprobs.tolist() + if raw_sampler_output.top_logprobs is not None: + top_logprobs = raw_sampler_output.top_logprobs.tolist() + top_token_ids = raw_sampler_output.top_token_ids.tolist() + sample_idx = 0 + prompt_logprobs_idx = 0 + top_logprob_idx = 0 + sampler_output = [] + + for i, seq_group in enumerate(input_metadata.seq_groups): + seq_ids, sampling_params = seq_group + is_prompt = i < input_metadata.num_prompts + num_parent_seqs = len(seq_ids) + if sampling_params.sampling_type == SamplingType.GREEDY: + assert num_parent_seqs == 1, ( + "Greedy sampling should have only one seq.") + parent_ids = list(range(num_parent_seqs)) + token_ids = samples[sample_idx][0:1] + seq_logprobs = logprobs[sample_idx][0:1] + offset = 1 + elif is_prompt: + actual_best_of = sampling_params.actual_best_of + parent_ids = [0] * actual_best_of + token_ids = samples[sample_idx][:actual_best_of] + seq_logprobs = logprobs[sample_idx][:actual_best_of] + offset = 1 + else: + parent_ids = list(range(num_parent_seqs)) + token_ids = _flatten_list(samples[sample_idx:sample_idx + + num_parent_seqs]) + seq_logprobs = _flatten_list(logprobs[sample_idx:sample_idx + + num_parent_seqs]) + offset = num_parent_seqs + + if is_prompt and sampling_params.prompt_logprobs is not None: + group_prompt_logprobs: PromptLogprobs = [None] + prompt_tokens = input_metadata.seq_data[ + seq_ids[0]].get_prompt_token_ids() + for token_id in prompt_tokens[1:]: + prompt_logprobs_dict = { + token_id: prompt_logprobs[prompt_logprobs_idx][token_id] + } + if sampling_params.prompt_logprobs > 0: + prompt_logprobs_dict.update( + zip( + top_token_ids[top_logprob_idx] + [:sampling_params.prompt_logprobs], + top_logprobs[top_logprob_idx] + [:sampling_params.prompt_logprobs])) + group_prompt_logprobs.append(prompt_logprobs_dict) + top_logprob_idx += 1 + prompt_logprobs_idx += 1 + else: + group_prompt_logprobs = None + + num_logprobs = sampling_params.logprobs + if num_logprobs is None: + num_logprobs = 0 + group_sample_logprobs: SampleLogprobs = [] + for next_token_id, logprob, parent_id in zip(token_ids, seq_logprobs, + parent_ids): + sample_logprobs_dict = {next_token_id: logprob} + if num_logprobs > 0: + sample_logprobs_dict.update( + zip( + top_token_ids[top_logprob_idx + + parent_id][:num_logprobs], + top_logprobs[top_logprob_idx + + parent_id][:num_logprobs])) + group_sample_logprobs.append(sample_logprobs_dict) + + sample_idx += offset + top_logprob_idx += offset + sampler_output.append( + SequenceGroupOutputs([ + SequenceOutputs(seq_ids[parent_id], token_id, seq_logprobs) + for parent_id, token_id, seq_logprobs in zip( + parent_ids, token_ids, group_sample_logprobs) + ], group_prompt_logprobs)) + + return SamplerOutput(sampler_output) def _prune_hidden_states( hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, + input_metadata: InputMetadata, ) -> torch.Tensor: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - return hidden_states.index_select(0, - sampling_metadata.selected_token_indices) - - -def _get_prompt_and_output_tokens( - sampling_metadata: SamplingMetadata, -) -> Tuple[List[List[int]], List[List[int]]]: - prompt_tokens: List[List[int]] = [] - output_tokens: List[List[int]] = [] - for i, seq_group in enumerate(sampling_metadata.seq_groups): - seq_ids, sampling_params = seq_group - if (i < sampling_metadata.num_prompts - and 
sampling_params.prompt_logprobs is not None): - # NOTE: prompt token positions do not need output tokens to - # compute penalties. - prompt_len = sampling_metadata.prompt_lens[i] - prompt_tokens.extend([] for _ in range(prompt_len - 1)) - output_tokens.extend([] for _ in range(prompt_len - 1)) - for seq_id in seq_ids: - seq_data = sampling_metadata.seq_data[seq_id] - prompt_tokens.append(seq_data.prompt_token_ids) - output_tokens.append(seq_data.output_token_ids) - return prompt_tokens, output_tokens + return hidden_states.index_select(0, input_metadata.selected_token_indices) +@torch.jit.script def _get_bin_counts_and_mask( - tokens: torch.Tensor, + tokens_tensor: torch.Tensor, vocab_size: int, num_seqs: int, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -147,38 +601,14 @@ def _get_bin_counts_and_mask( # vocab_size + 1 for padding. bin_counts = torch.zeros((num_seqs, vocab_size + 1), dtype=torch.long, - device=tokens.device) - bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) + device=tokens_tensor.device) + bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor)) bin_counts = bin_counts[:, :vocab_size] mask = bin_counts > 0 return bin_counts, mask -def _apply_logits_processors( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> torch.Tensor: - logits_row_idx = 0 - found_logits_processors = False - for seq_ids, sampling_params in sampling_metadata.seq_groups: - logits_processors = sampling_params.logits_processors - if logits_processors: - found_logits_processors = True - for seq_id in seq_ids: - logits_row = logits[logits_row_idx] - token_ids = sampling_metadata.seq_data[seq_id].output_token_ids - for logits_processor in logits_processors: - logits_row = logits_processor(token_ids, logits_row) - logits[logits_row_idx] = logits_row - logits_row_idx += 1 - else: - logits_row_idx += len(seq_ids) - if found_logits_processors: - assert logits_row_idx == logits.shape[0] - return logits - - def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, output_tokens_tensor: torch.Tensor, presence_penalties: torch.Tensor, @@ -191,17 +621,32 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, output_tokens_tensor, vocab_size, num_seqs) repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) - repetition_penalties[~(prompt_mask | output_mask)] = 1.0 + repetition_penalties[prompt_mask.logical_or_( + output_mask).logical_not_()] = 1.0 logits = torch.where(logits > 0, logits / repetition_penalties, logits * repetition_penalties) # We follow the definition in OpenAI API. 
# Refer to https://platform.openai.com/docs/api-reference/parameter-details - logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts - logits -= presence_penalties.unsqueeze_(dim=1) * output_mask + logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts + logits -= presence_penalties.unsqueeze(dim=1) * output_mask return logits +def _apply_penalties_triton( + logits: torch.Tensor, unique_output_token_ids: torch.Tensor, + output_token_counts: torch.Tensor, + cumsum_penalties_seq_lens: torch.Tensor, max_penalties_seq_len: int, + presence_penalties: torch.Tensor, frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor) -> torch.Tensor: + apply_penalty(logits, presence_penalties, frequency_penalties, + repetition_penalties, unique_output_token_ids, + output_token_counts, cumsum_penalties_seq_lens, + max_penalties_seq_len) + return logits + + +@torch.jit.script def _apply_top_p_top_k( logits: torch.Tensor, p: torch.Tensor, @@ -211,17 +656,17 @@ def _apply_top_p_top_k( # Apply top-p. probs_sort = logits_sort.softmax(dim=-1) - probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort) - top_p_mask = probs_sum > p.unsqueeze_(dim=1) + probs_sum = probs_sort.cumsum(dim=-1) # Apply top-k. # Create a mask for the top-k elements. top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device) top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1) - top_k_mask = top_k_mask >= k.unsqueeze_(dim=1) + p = p.unsqueeze(dim=1) + k = k.unsqueeze(dim=1) # Final mask. - mask = (top_p_mask | top_k_mask) + mask = ((probs_sum - probs_sort) > p) | (top_k_mask >= k) logits_sort.masked_fill_(mask, -float("inf")) # Re-sort the probabilities. @@ -234,6 +679,7 @@ def _apply_top_p_top_k( return logits +@torch.jit.script def _apply_min_p( logits: torch.Tensor, min_p: torch.Tensor, @@ -244,334 +690,136 @@ def _apply_min_p( """ probs = torch.softmax(logits, dim=-1) top_probs, _ = probs.max(dim=-1, keepdim=True) - scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs - tokens_to_remove = probs < scaled_min_p + min_p = min_p.unsqueeze(dim=1) + tokens_to_remove = probs < (min_p * top_probs) logits = logits.masked_fill_(tokens_to_remove, -float("inf")) return logits -def _greedy_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], - samples: torch.Tensor, -) -> List[Tuple[List[int], List[int]]]: - samples = samples.tolist() - sample_idx = 0 - results = [] - for seq_group in selected_seq_groups: - seq_ids, _ = seq_group - num_parent_seqs = len(seq_ids) - assert num_parent_seqs == 1, ( - "Greedy sampling should have only one seq.") - parent_ids = list(range(num_parent_seqs)) - next_token_ids = [samples[sample_idx]] - results.append((next_token_ids, parent_ids)) - sample_idx += num_parent_seqs - return results - - -def _random_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], - is_prompts: List[bool], - random_samples: torch.Tensor, -) -> List[Tuple[List[int], List[int]]]: - # Find the maximum best_of value of the prompt phase requests. - random_samples = random_samples.cpu() - sample_idx = 0 - results = [] - for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): - seq_ids, sampling_params = seq_group - num_parent_seqs = len(seq_ids) - if is_prompt: - # Prompt phase. - parent_ids = [0] * sampling_params.best_of - next_token_ids = random_samples[ - sample_idx, :sampling_params.best_of].tolist() - else: - # Generation phase. 
- parent_ids = list(range(num_parent_seqs)) - next_token_ids = random_samples[sample_idx:sample_idx + - num_parent_seqs, 0].tolist() - results.append((next_token_ids, parent_ids)) - sample_idx += num_parent_seqs - return results +# def _beam_search_sample( +# selected_seq_groups: List[Tuple[List[int], SamplingParams]], +# is_prompts: List[bool], +# seq_data: Dict[int, SequenceData], +# logprobs: torch.Tensor, +# ) -> List[Tuple[List[int], List[int]]]: +# # We sample 2 * beam_width candidates to make sure that with high +# # probability we can get `beam_width` candidates in addition to +# # the finished sequences for the next iteration. See +# # https://github.com/tensorflow/tensor2tensor/blob/bafdc1 +# b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py +# #L557-L563 +# # for details. See also HF reference: +# # https://github.com/huggingface/transformers/blob/a4dd53d8 +# 8e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py +# #L3063-L3065 +# # +# # NOTE: Beam search is not vectorized, so its speed can be slower than +# # other sampling methods. +# sample_idx = 0 +# results = [] +# for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): +# seq_ids, sampling_params = seq_group +# num_parent_seqs = len(seq_ids) +# beam_width = sampling_params.actual_best_of +# seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs] +# if is_prompt: +# # Prompt phase. +# assert num_parent_seqs == 1, ( +# "Prompt input should have only one seq.") +# parent_ids = [0] * (2 * beam_width) +# _, next_token_ids = torch.topk(seq_group_logprobs[0], +# 2 * beam_width) +# next_token_ids = next_token_ids.tolist() +# else: +# # Generation phase. +# cumulative_logprobs = [ +# seq_data[seq_id].cumulative_logprob for seq_id in seq_ids +# ] +# cumulative_logprobs = torch.tensor( +# cumulative_logprobs, +# dtype=torch.float, +# device=seq_group_logprobs.device) +# seq_group_logprobs = (seq_group_logprobs + +# cumulative_logprobs.unsqueeze(dim=1)) +# _, topk_ids = torch.topk(seq_group_logprobs.flatten(), +# 2 * beam_width) +# topk_ids = topk_ids.tolist() +# vocab_size = seq_group_logprobs.size(-1) +# parent_ids = [i // vocab_size for i in topk_ids] +# next_token_ids = [i % vocab_size for i in topk_ids] +# results.append((next_token_ids, parent_ids)) +# sample_idx += num_parent_seqs +# assert sample_idx == logprobs.size(0) +# return results + + +@torch.jit.script +def _modify_greedy_probs(probs: torch.Tensor, sample_indices: torch.Tensor, + sampled_tokens: torch.Tensor) -> None: + """Set the probability of the sampled token to 1, all other tokens to zero. + This is used in speculative decoding where the sampling method must be + encoded within the sampled probability distributions. + """ + sample_indices = sample_indices.to(torch.long) + probs.index_fill_(0, sample_indices, 0) + probs.flatten().index_fill_(0, (sample_indices * probs.stride()[0]) + + sampled_tokens, 1) -def _beam_search_sample( - selected_seq_groups: List[Tuple[List[int], SamplingParams]], - is_prompts: List[bool], - seq_data: Dict[int, SequenceData], - logprobs: torch.Tensor, -) -> List[Tuple[List[int], List[int]]]: - # We sample 2 * beam_width candidates to make sure that with high - # probability we can get `beam_width` candidates in addition to - # the finished sequences for the next iteration. See - # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563 - # for details. 
See also HF reference: - # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065 - # - # NOTE: Beam search is not vectorized, so its speed can be slower than - # other sampling methods. - sample_idx = 0 - results = [] - for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): - seq_ids, sampling_params = seq_group - num_parent_seqs = len(seq_ids) - beam_width = sampling_params.best_of - seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs] - if is_prompt: - # Prompt phase. - assert num_parent_seqs == 1, ( - "Prompt input should have only one seq.") - parent_ids = [0] * (2 * beam_width) - _, next_token_ids = torch.topk(seq_group_logprobs[0], - 2 * beam_width) - next_token_ids = next_token_ids.tolist() - else: - # Generation phase. - cumulative_logprobs = [ - seq_data[seq_id].cumulative_logprob for seq_id in seq_ids - ] - cumulative_logprobs = torch.tensor( - cumulative_logprobs, - dtype=torch.float, - device=seq_group_logprobs.device) - seq_group_logprobs = (seq_group_logprobs + - cumulative_logprobs.unsqueeze(dim=1)) - _, topk_ids = torch.topk(seq_group_logprobs.flatten(), - 2 * beam_width) - topk_ids = topk_ids.tolist() - vocab_size = seq_group_logprobs.size(-1) - parent_ids = [i // vocab_size for i in topk_ids] - next_token_ids = [i % vocab_size for i in topk_ids] - results.append((next_token_ids, parent_ids)) - sample_idx += num_parent_seqs - assert sample_idx == logprobs.size(0) - return results - - -# torch.multinomial forces a GPU<->CPU sync. -# Therefore, we use an optimized implementation instead. -# Note that we always sample with replacement. -# probs will be modified in place, but this is fine, as we pass -# in a copy already. -def _multinomial( +@torch.jit.script +def _sample_tensor( probs: torch.Tensor, num_samples: int, -): + random_sample_indices: torch.Tensor, +) -> torch.Tensor: if num_samples > 1: # This is equivalent to torch.repeat_interleaved (which also # forces a GPU<->CPU sync). - # This allows us to do sampling with replacement by creating - # num_samples copies of each row in the tensor, and then - # batch sampling the resulting tensor. 
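The optimized replacement for torch.multinomial described in this comment relies on a standard identity: dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax selects index i with probability p_i (a form of the Gumbel-max trick), so sampling never forces a GPU<->CPU sync. A quick empirical check of that identity (editor's sketch, not part of the patch):

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.6, 0.3])

# Draw 20000 samples at once: divide by i.i.d. Exp(1) noise and take argmax.
noise = torch.empty(20000, probs.shape[0]).exponential_(1.0)
samples = (probs / noise).argmax(dim=-1)

print(torch.bincount(samples, minlength=3) / samples.numel())
# -> approximately tensor([0.1, 0.6, 0.3])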
probs = probs[:, None, :].expand(probs.shape[0], num_samples, - probs.shape[1]).contiguous().view( - -1, probs.shape[1]) - q = torch.empty_like(probs).exponential_(1) - return probs.div_(q).argmax(dim=1).view(-1, num_samples) - - + probs.shape[1]).contiguous() + else: + probs = probs.view(probs.shape[0], num_samples, probs.shape[1]) + + has_random_sample_indices = bool(random_sample_indices.numel()) + if has_random_sample_indices: + random_sample_probs = probs[random_sample_indices] + random_sample_probs_v = random_sample_probs.view( + random_sample_probs.shape[0] * num_samples, + random_sample_probs.shape[-1]) + q = torch.empty_like(random_sample_probs_v).exponential_(1.0).pow_(-1) + probs.index_reduce_(0, random_sample_indices, + q.view_as(random_sample_probs), "prod") + return probs.view(probs.shape[0] * num_samples, + -1).argmax(dim=1).view(-1, num_samples) + + +@torch.jit.script def _sample( probs: torch.Tensor, logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> List[Tuple[List[int], List[int]]]: - categorized_seq_group_ids = {t: [] for t in SamplingType} - categorized_sample_indices = sampling_metadata.categorized_sample_indices - for i, seq_group in enumerate(sampling_metadata.seq_groups): - _, sampling_params = seq_group - sampling_type = sampling_params.sampling_type - categorized_seq_group_ids[sampling_type].append(i) - - sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} - sample_metadata = {} - - # Counterintiutively, having two loops here is actually faster. - # The first loop can run without waiting on GPU<->CPU sync. - for sampling_type in SamplingType: - sample_indices = categorized_sample_indices[sampling_type] - num_tokens = len(sample_indices) - if num_tokens == 0: - continue - seq_group_ids = categorized_seq_group_ids[sampling_type] - seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] - is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] - sample_metadata[sampling_type] = (seq_group_ids, seq_groups, - is_prompts, sample_indices) - if sampling_type == SamplingType.GREEDY: - greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1) - elif sampling_type == SamplingType.RANDOM: - max_best_of = 1 - for seq_group, is_prompt in zip(seq_groups, is_prompts): - if is_prompt: - _, sampling_params = seq_group - max_best_of = max(max_best_of, sampling_params.best_of) - multinomial_samples = _multinomial(probs[sample_indices], - max_best_of) - elif sampling_type == SamplingType.BEAM: - beam_search_logprobs = logprobs[sample_indices] - else: - raise ValueError(f"Unsupported sampling type: {sampling_type}") - - # GPU<->CPU sync happens in the loop below. 
- - for sampling_type in SamplingType: - if sampling_type not in sample_metadata: - continue - seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[ - sampling_type] - if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample(seq_groups, greedy_samples) - elif sampling_type == SamplingType.RANDOM: - sample_results = _random_sample(seq_groups, is_prompts, - multinomial_samples) - elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, is_prompts, - sampling_metadata.seq_data, - beam_search_logprobs) - sample_results_dict.update(zip(seq_group_ids, sample_results)) - - sample_results = [ - sample_results_dict[i] - for i in range(len(sampling_metadata.seq_groups)) - ] - return sample_results - - -def _get_logprobs( - logprobs: torch.Tensor, - sampling_metadata: SamplingMetadata, - sample_results: List[Tuple[List[int], List[int]]], -) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[ - int, float]]]]: - # Prepare query indices - batched_logprobs_query_seq_indices: List[int] = [] - batched_logprobs_query_token_indices: List[int] = [] - largest_num_logprobs = 0 - sample_idx = 0 - for i, (seq_group, sample_result) in enumerate( - zip(sampling_metadata.seq_groups, sample_results)): - seq_ids, sampling_params = seq_group - next_token_ids, parent_ids = sample_result - num_parent_seqs = len(seq_ids) - if (i < sampling_metadata.num_prompts - and sampling_params.prompt_logprobs is not None): - largest_num_logprobs = max(largest_num_logprobs, - sampling_params.prompt_logprobs) - prompt_len = sampling_metadata.prompt_lens[i] - prompt_tokens = sampling_metadata.seq_data[ - seq_ids[0]].prompt_token_ids - batched_logprobs_query_seq_indices.extend( - sample_idx + j for j in range(prompt_len - 1)) - batched_logprobs_query_token_indices.extend( - token_id for token_id in prompt_tokens[1:]) - sample_idx += prompt_len - 1 - batched_logprobs_query_seq_indices.extend( - [sample_idx + parent_id for parent_id in parent_ids]) - batched_logprobs_query_token_indices.extend(next_token_ids) - if sampling_params.logprobs is not None: - largest_num_logprobs = max(largest_num_logprobs, - sampling_params.logprobs) - sample_idx += num_parent_seqs - assert sample_idx == logprobs.size(0) - - # Batched query for logprobs of selected token - batched_logprobs_query_result = logprobs[[ - batched_logprobs_query_seq_indices, - batched_logprobs_query_token_indices - ]] - - # Batched query for logprobs of topk tokens - if largest_num_logprobs > 0: - top_logprobs, top_token_ids = torch.topk(logprobs, - largest_num_logprobs, - dim=-1) - top_logprobs = top_logprobs.cpu() - top_token_ids = top_token_ids.cpu() - else: - top_logprobs, top_token_ids = None, None - - batched_logprobs_query_result = batched_logprobs_query_result.cpu() - - # Gather results - result_prompt_logprobs: List[Optional[PromptLogprobs]] = [] - result_sample_logprobs: List[SampleLogprobs] = [] - sample_idx = 0 - query_result_idx = 0 - for i, (seq_group, sample_result) in enumerate( - zip(sampling_metadata.seq_groups, sample_results)): - seq_ids, sampling_params = seq_group - next_token_ids, parent_ids = sample_result - - # Prompt logprobs - if (i < sampling_metadata.num_prompts - and sampling_params.prompt_logprobs is not None): - num_logprobs = sampling_params.prompt_logprobs - prompt_len = sampling_metadata.prompt_lens[i] - prompt_tokens = sampling_metadata.seq_data[ - seq_ids[0]].prompt_token_ids - group_prompt_logprobs: PromptLogprobs = [None] - for token_id in 
prompt_tokens[1:]: - prompt_logprobs_dict = { - token_id: - batched_logprobs_query_result[query_result_idx].item() - } - if num_logprobs > 0: - prompt_logprobs_dict.update( - zip(top_token_ids[sample_idx, :num_logprobs].tolist(), - top_logprobs[sample_idx, :num_logprobs].tolist())) - group_prompt_logprobs.append(prompt_logprobs_dict) - sample_idx += 1 - query_result_idx += 1 - result_prompt_logprobs.append(group_prompt_logprobs) - else: - result_prompt_logprobs.append(None) - - # Sample logprobs - num_logprobs = sampling_params.logprobs - if num_logprobs is None: - num_logprobs = 0 - group_sample_logprobs: SampleLogprobs = [] - for next_token_id, parent_id in zip(next_token_ids, parent_ids): - sample_logprobs_dict = { - next_token_id: - batched_logprobs_query_result[query_result_idx].item() - } - query_result_idx += 1 - if num_logprobs > 0: - sample_logprobs_dict.update( - zip( - top_token_ids[sample_idx + - parent_id, :num_logprobs].tolist(), - top_logprobs[sample_idx + - parent_id, :num_logprobs].tolist())) - group_sample_logprobs.append(sample_logprobs_dict) - result_sample_logprobs.append(group_sample_logprobs) - sample_idx += len(seq_ids) - - return result_prompt_logprobs, result_sample_logprobs - - -def _build_sampler_output( - sample_results: List[Tuple[List[int], List[int]]], - sampling_metadata: SamplingMetadata, - prompt_logprobs: List[Optional[PromptLogprobs]], - sample_logprobs: List[SampleLogprobs], -) -> SamplerOutput: - sampler_output = [] - for (seq_group, sample_result, group_prompt_logprobs, - group_sample_logprobs) in zip(sampling_metadata.seq_groups, - sample_results, prompt_logprobs, - sample_logprobs): - seq_ids, _ = seq_group - next_token_ids, parent_ids = sample_result - seq_outputs = [] - for parent_id, next_token_id, logprobs in zip(parent_ids, - next_token_ids, - group_sample_logprobs): - seq_outputs.append( - SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) - sampler_output.append( - SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) - return sampler_output + prompt_token_indices: torch.Tensor, + sample_indices: torch.Tensor, + categorized_sample_indices: Tuple[torch.Tensor, torch.Tensor], + max_best_of: int, + modify_greedy_probs: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + prompt_logprobs = logprobs[prompt_token_indices] + sample_probs = probs[sample_indices] + + sampled_tokens = _sample_tensor(sample_probs, max_best_of, + categorized_sample_indices[1]) + + has_greedy_indices = bool(categorized_sample_indices[0].numel()) + if modify_greedy_probs and has_greedy_indices: + # Note: in greedy sampling, there only one sample per sequence + # group. 
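+        # `_modify_greedy_probs` rewrites each greedy row of `probs` into a
+        # one-hot distribution over its sampled token (e.g. a row
+        # [0.1, 0.7, 0.2] whose sampled token index is 1 becomes
+        # [0.0, 1.0, 0.0]), so that consumers of the probability tensor,
+        # such as the rejection sampler in speculative decoding, see a
+        # distribution consistent with greedy sampling.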
+ greedy_sample_indices = sample_indices[categorized_sample_indices[0]] + _modify_greedy_probs( + probs, greedy_sample_indices, + sampled_tokens[categorized_sample_indices[0]][:, 0]) + + sampled_logprobs = torch.gather(logprobs[sample_indices], 1, + sampled_tokens) + return sampled_tokens, sampled_logprobs, prompt_logprobs diff --git a/vllm/model_executor/layers/sampler_ops/__init__.py b/vllm/model_executor/layers/sampler_ops/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/sampler_ops/penalty_triton.py b/vllm/model_executor/layers/sampler_ops/penalty_triton.py new file mode 100644 index 0000000000000..9187b4e46434c --- /dev/null +++ b/vllm/model_executor/layers/sampler_ops/penalty_triton.py @@ -0,0 +1,64 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_apply_penalty(logits, presence_penalty, freqency_penalty, + repetition_penalty, p_token_ids, p_token_counts, + p_cumsum_seq_len, stride_logit_b, + block_p: tl.constexpr): + cur_batch = tl.program_id(0) + cur_freqency = tl.load(freqency_penalty + cur_batch) + cur_presence = tl.load(presence_penalty + cur_batch) + cur_repetition = tl.load(repetition_penalty + cur_batch) + + cur_batch_start_index = tl.load(p_cumsum_seq_len + cur_batch) + cur_batch_end_index = tl.load(p_cumsum_seq_len + cur_batch + 1) + + cur_batch_id_offset = cur_batch_start_index + tl.arange(0, block_p) + batch_ids = tl.load(p_token_ids + cur_batch_id_offset, + mask=cur_batch_id_offset < cur_batch_end_index, + other=0) + batch_ids_count = tl.load(p_token_counts + cur_batch_id_offset, + mask=cur_batch_id_offset < cur_batch_end_index, + other=0) + + row_start_ptr = logits + cur_batch * stride_logit_b + cur_offset = row_start_ptr + batch_ids + cur_logits = tl.load(cur_offset, + mask=cur_batch_id_offset < cur_batch_end_index, + other=0.0) + rep_logits = tl.where(cur_logits > 0, cur_logits / cur_repetition, + cur_logits * cur_repetition) + freq_logits = rep_logits - batch_ids_count * cur_freqency + pre_logits = freq_logits - cur_presence + output_ptr = logits + cur_batch * stride_logit_b + batch_ids + tl.store(output_ptr, + pre_logits, + mask=cur_batch_id_offset < cur_batch_end_index) + + +@torch.no_grad() +def apply_penalty(logits, presence_penalty, freqency_penalty, + repetition_penalty, p_token_ids, p_token_counts, + p_cumsum_seq_len, p_max_len_in_batch): + if not logits.is_contiguous(): + logits = logits.contiguous() + block = triton.next_power_of_2(p_max_len_in_batch) + if block <= 512: + block = 512 + elif block <= 1024: + block = 1024 + num_warps = 8 + _fwd_kernel_apply_penalty[(logits.shape[0], )](logits, + presence_penalty, + freqency_penalty, + repetition_penalty, + p_token_ids, + p_token_counts, + p_cumsum_seq_len, + logits.stride(0), + num_warps=num_warps, + block_p=block) diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/model_executor/parallel_utils/parallel_state.py index 9a5e2889381d9..f44b970f6cec0 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/model_executor/parallel_utils/parallel_state.py @@ -2,7 +2,9 @@ # Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
-"""Tensor and pipeline parallel groups.""" +"""Model and data parallel groups.""" + +import contextlib import torch @@ -83,8 +85,19 @@ def initialize_model_parallel( _PIPELINE_GLOBAL_RANKS = ranks +@contextlib.contextmanager +def patch_tensor_parallel_group(group): + old_group = get_tensor_model_parallel_group() + global _TENSOR_MODEL_PARALLEL_GROUP + _TENSOR_MODEL_PARALLEL_GROUP = group + try: + yield + finally: + _TENSOR_MODEL_PARALLEL_GROUP = old_group + + def model_parallel_is_initialized(): - """Check if tensor and pipeline parallel groups are initialized.""" + """Check if model and data parallel groups are initialized.""" return (_TENSOR_MODEL_PARALLEL_GROUP is not None and _PIPELINE_MODEL_PARALLEL_GROUP is not None) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 30a8036a63fc9..66dfcdc886d22 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -1,10 +1,10 @@ """Sampling parameters for text generation.""" from enum import IntEnum -from functools import cached_property -from typing import Callable, List, Optional, Union - +from typing import Any, Callable, Dict, List, Optional, Union import torch +import msgspec + _SAMPLING_EPS = 1e-5 @@ -20,7 +20,7 @@ class SamplingType(IntEnum): tensor of logits to sample from.""" -class SamplingParams: +class SamplingParams(msgspec.Struct, array_like=True, omit_defaults=True): """Sampling parameters for text generation. Overall, we follow the sampling parameters from the OpenAI text completion @@ -70,9 +70,7 @@ class SamplingParams: The returned output will not contain the stop strings. stop_token_ids: List of tokens that stop the generation when they are generated. The returned output will contain the stop tokens unless - the stop tokens are special tokens. - include_stop_str_in_output: Whether to include the stop strings in output - text. Defaults to False. + the stop tokens are sepcial tokens. ignore_eos: Whether to ignore the EOS token and continue generating tokens after the EOS token is generated. max_tokens: Maximum number of tokens to generate per output sequence. @@ -88,81 +86,71 @@ class SamplingParams: tokens in the output. Defaults to True. logits_processors: List of functions that modify logits based on previously generated tokens. + response_format: Format to return the final response in. 
Can be for ex: + response_format={"type": "json", "schema": "{...}"} + """ - def __init__( - self, - n: int = 1, - best_of: Optional[int] = None, - presence_penalty: float = 0.0, - frequency_penalty: float = 0.0, - repetition_penalty: float = 1.0, - temperature: float = 1.0, - top_p: float = 1.0, - top_k: int = -1, - min_p: int = 0.0, - use_beam_search: bool = False, - length_penalty: float = 1.0, - early_stopping: Union[bool, str] = False, - stop: Optional[Union[str, List[str]]] = None, - stop_token_ids: Optional[List[int]] = None, - include_stop_str_in_output: bool = False, - ignore_eos: bool = False, - max_tokens: int = 16, - logprobs: Optional[int] = None, - prompt_logprobs: Optional[int] = None, - skip_special_tokens: bool = True, - spaces_between_special_tokens: bool = True, - logits_processors: Optional[List[LogitsProcessor]] = None, - ) -> None: - self.n = n - self.best_of = best_of if best_of is not None else n - self.presence_penalty = presence_penalty - self.frequency_penalty = frequency_penalty - self.repetition_penalty = repetition_penalty - self.temperature = temperature - self.top_p = top_p - self.top_k = top_k - self.min_p = min_p - self.use_beam_search = use_beam_search - self.length_penalty = length_penalty - self.early_stopping = early_stopping - if stop is None: - self.stop = [] - elif isinstance(stop, str): - self.stop = [stop] - else: - self.stop = list(stop) - if stop_token_ids is None: - self.stop_token_ids = [] - else: - self.stop_token_ids = list(stop_token_ids) - self.ignore_eos = ignore_eos - self.max_tokens = max_tokens - self.logprobs = logprobs - self.prompt_logprobs = prompt_logprobs - self.skip_special_tokens = skip_special_tokens - self.spaces_between_special_tokens = spaces_between_special_tokens - self.logits_processors = logits_processors - self.include_stop_str_in_output = include_stop_str_in_output - self._verify_args() + n: int = 1 + best_of: Optional[int] = 0 + presence_penalty: float = 0.0 + frequency_penalty: float = 0.0 + repetition_penalty: float = 1.0 + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = -1 + min_p: float = 0.0 + use_beam_search: bool = False + length_penalty: float = 1.0 + early_stopping: Union[bool, str] = False + stop: List[str] = [] + stop_token_ids: List[int] = [] + ignore_eos: bool = False + max_tokens: int = 16 + logprobs: Optional[int] = None + prompt_logprobs: Optional[int] = None + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + logits_processors: Optional[List[LogitsProcessor]] = None + response_format: Optional[Dict[str, Any]] = None + + @property + def actual_best_of(self) -> int: + return (self.best_of if + (self.best_of is not None and self.best_of > 0) else self.n) + + @property + def has_penalties(self) -> bool: + return (abs(self.presence_penalty) >= _SAMPLING_EPS + or abs(self.frequency_penalty) >= _SAMPLING_EPS + or abs(self.repetition_penalty - 1.0) >= _SAMPLING_EPS) + + @property + def sampling_type(self) -> SamplingType: if self.use_beam_search: - self._verify_beam_search() - else: - self._verify_non_beam_search() - if self.temperature < _SAMPLING_EPS: - # Zero temperature means greedy sampling. 
- self.top_p = 1.0 - self.top_k = -1 - self.min_p = 0.0 - self._verify_greedy_sampling() + return SamplingType.BEAM + if self.temperature < _SAMPLING_EPS: + return SamplingType.GREEDY + return SamplingType.RANDOM + + def __post_init__(self): + assert not self.use_beam_search, "beam search is disabled" + self._verify_args() + # if self.use_beam_search: + # self._verify_beam_search() + # else: + self._verify_non_beam_search() + if self.temperature < _SAMPLING_EPS: + # Zero temperature means greedy sampling. + self._verify_greedy_sampling() def _verify_args(self) -> None: if self.n < 1: raise ValueError(f"n must be at least 1, got {self.n}.") - if self.best_of < self.n: - raise ValueError(f"best_of must be greater than or equal to n, " - f"got n={self.n} and best_of={self.best_of}.") + if self.actual_best_of < self.n: + raise ValueError( + f"best_of must be greater than or equal to n, " + f"got n={self.n} and best_of={self.actual_best_of}.") if not -2.0 <= self.presence_penalty <= 2.0: raise ValueError("presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}.") @@ -193,20 +181,20 @@ def _verify_args(self) -> None: raise ValueError(f"prompt_logprobs must be non-negative, got " f"{self.prompt_logprobs}.") - def _verify_beam_search(self) -> None: - if self.best_of == 1: - raise ValueError("best_of must be greater than 1 when using beam " - f"search. Got {self.best_of}.") - if self.temperature > _SAMPLING_EPS: - raise ValueError("temperature must be 0 when using beam search.") - if self.top_p < 1.0 - _SAMPLING_EPS: - raise ValueError("top_p must be 1 when using beam search.") - if self.top_k != -1: - raise ValueError("top_k must be -1 when using beam search.") - if self.early_stopping not in [True, False, "never"]: - raise ValueError( - f"early_stopping must be True, False, or 'never', " - f"got {self.early_stopping}.") + # def _verify_beam_search(self) -> None: + # if self.actual_best_of == 1: + # raise ValueError("best_of must be greater than 1 when using beam " + # f"search. Got {self.actual_best_of}.") + # if self.temperature > _SAMPLING_EPS: + # raise ValueError("temperature must be 0 when using beam search.") + # if self.top_p < 1.0 - _SAMPLING_EPS: + # raise ValueError("top_p must be 1 when using beam search.") + # if self.top_k != -1: + # raise ValueError("top_k must be -1 when using beam search.") + # if self.early_stopping not in [True, False, "never"]: + # raise ValueError( + # f"early_stopping must be True, False, or 'never', " + # f"got {self.early_stopping}.") def _verify_non_beam_search(self) -> None: if self.early_stopping is not False: @@ -219,39 +207,33 @@ def _verify_non_beam_search(self) -> None: "default value of 1.0 when not using beam search.") def _verify_greedy_sampling(self) -> None: - if self.best_of > 1: + if self.actual_best_of > 1: raise ValueError("best_of must be 1 when using greedy sampling." 
- f"Got {self.best_of}.") - - @cached_property - def sampling_type(self) -> SamplingType: - if self.use_beam_search: - return SamplingType.BEAM - if self.temperature < _SAMPLING_EPS: - return SamplingType.GREEDY - return SamplingType.RANDOM + f"Got {self.actual_best_of}.") + if self.top_p < 1.0 - _SAMPLING_EPS: + raise ValueError("top_p must be 1 when using greedy sampling.") + if self.top_k != -1: + raise ValueError("top_k must be -1 when using greedy sampling.") def __repr__(self) -> str: - return ( - f"SamplingParams(n={self.n}, " - f"best_of={self.best_of}, " - f"presence_penalty={self.presence_penalty}, " - f"frequency_penalty={self.frequency_penalty}, " - f"repetition_penalty={self.repetition_penalty}, " - f"temperature={self.temperature}, " - f"top_p={self.top_p}, " - f"top_k={self.top_k}, " - f"min_p={self.min_p}, " - f"use_beam_search={self.use_beam_search}, " - f"length_penalty={self.length_penalty}, " - f"early_stopping={self.early_stopping}, " - f"stop={self.stop}, " - f"stop_token_ids={self.stop_token_ids}, " - f"include_stop_str_in_output={self.include_stop_str_in_output}, " - f"ignore_eos={self.ignore_eos}, " - f"max_tokens={self.max_tokens}, " - f"logprobs={self.logprobs}, " - f"prompt_logprobs={self.prompt_logprobs}, " - f"skip_special_tokens={self.skip_special_tokens}, " - "spaces_between_special_tokens=" - f"{self.spaces_between_special_tokens})") + return (f"SamplingParams(n={self.n}, " + f"best_of={self.actual_best_of}, " + f"presence_penalty={self.presence_penalty}, " + f"frequency_penalty={self.frequency_penalty}, " + f"repetition_penalty={self.repetition_penalty}, " + f"temperature={self.temperature}, " + f"top_p={self.top_p}, " + f"top_k={self.top_k}, " + f"min_p={self.min_p}, " + f"use_beam_search={self.use_beam_search}, " + f"length_penalty={self.length_penalty}, " + f"early_stopping={self.early_stopping}, " + f"stop={self.stop}, " + f"stop_token_ids={self.stop_token_ids}, " + f"ignore_eos={self.ignore_eos}, " + f"max_tokens={self.max_tokens}, " + f"logprobs={self.logprobs}, " + f"prompt_logprobs={self.prompt_logprobs}, " + f"skip_special_tokens={self.skip_special_tokens}, " + "spaces_between_special_tokens=" + f"{self.spaces_between_special_tokens})") diff --git a/vllm/sequence.py b/vllm/sequence.py index 7d36eeac0aa02..91560e49650e0 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1,9 +1,14 @@ """Sequence and its related classes.""" import copy import enum -from typing import Dict, List, Optional, Union +from collections import Counter +from typing import Dict, List, Optional, Tuple, Union +import torch + +import msgspec from vllm.block import LogicalTokenBlock +from vllm.anyscale.lora.utils import LoRARequest from vllm.sampling_params import SamplingParams PromptLogprobs = List[Optional[Dict[int, float]]] @@ -47,7 +52,7 @@ def get_finished_reason(status: "SequenceStatus") -> Union[str, None]: return finish_reason -class SequenceData: +class SequenceData(msgspec.Struct, array_like=True, omit_defaults=True): """Data associated with a sequence. @@ -55,45 +60,148 @@ class SequenceData: prompt_token_ids: The token IDs of the prompt. Attributes: - prompt_token_ids: The token IDs of the prompt. - output_token_ids: The token IDs of the output. + token_ids: The token IDs so far (prompt+output). + num_prompt_tokens: The number of prompt tokens. If not specified, + will be set to the length of token_ids. cumulative_logprob: The cumulative log probability of the output. + num_processed_token_ids: The number of token ids that have been + processed by the workers. 
+ prefill_start: The start index of the chunked prefill. + prefill_end: The end index of the chunked prefill. """ - def __init__( - self, - prompt_token_ids: List[int], - ) -> None: - self.prompt_token_ids = prompt_token_ids - self.output_token_ids: List[int] = [] - self.cumulative_logprob = 0.0 + token_ids: List[int] + num_prompt_tokens: int = -1 + cumulative_logprob: float = 0.0 + prefill_start: int = 0 + prefill_end: int = 0 + _prompt_token_id_count: Optional[Dict[int, int]] = None + _output_token_id_count: Optional[Dict[int, int]] = None + + # The number of tokens that have been processed. + # Processed means that the KV for the tokens has been computed and stored + # to the KV cache. + num_processed_token_ids: int = 0 + + @property + def prompt_token_id_count(self) -> Counter[int, int]: + if self._prompt_token_id_count is None: + self._prompt_token_id_count = Counter(self.get_prompt_token_ids()) + return self._prompt_token_id_count + + @property + def output_token_id_count(self) -> Counter[int, int]: + if self._output_token_id_count is None: + self._output_token_id_count = Counter(self.get_output_token_ids()) + return self._output_token_id_count + + def __post_init__(self): + if self.num_prompt_tokens < 0: + self.num_prompt_tokens = len(self.token_ids) + + if (self.num_processed_token_ids > + self.get_len()) or (self.num_processed_token_ids < 0): + raise ValueError(f"{self.num_processed_token_ids=} must be in the " + "interval [0, {self.get_len()=}]") + + def append_token_ids(self, token_ids: List[int], + logprobs: List[float]) -> None: + """Append token ids to the output token ids and update the cumulative + logprob. Also updates the number of processed token ids to the sequence + length before the new tokens. + """ + self.num_processed_token_ids = self.get_len() - def append_token_id(self, token_id: int, logprob: float) -> None: - self.output_token_ids.append(token_id) - self.cumulative_logprob += logprob + self.token_ids.extend(token_ids) + self.cumulative_logprob += sum(logprobs) + + for token_id in token_ids: + self.output_token_id_count[ + token_id] = self.output_token_id_count.get(token_id, 0) + 1 + + def reset_processed_tokens(self) -> None: + """Set the number of processed tokens to zero. Used when a sequence is + preempted by recomputation. This reset the prefill range as well. 
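+        After a reset, both the prompt and any previously generated tokens
+        will be prefilled (and their KV recomputed) again.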
+ """ + self.num_processed_token_ids = 0 + self.prefill_start = 0 + self.prefill_end = 0 + + def get_num_processed_token_ids(self) -> int: + return self.num_processed_token_ids + + def get_unprocessed_token_ids(self) -> List[int]: + return self.token_ids[self.get_unprocessed_token_start_idx():] + + def get_unprocessed_token_start_idx(self) -> int: + seq_len = self.get_len() + num_unprocessed_token_ids = seq_len - self.num_processed_token_ids + return seq_len - num_unprocessed_token_ids + + def get_unprocessed_token_positions(self) -> List[int]: + return list( + range(self.get_unprocessed_token_start_idx(), self.get_len())) def get_len(self) -> int: - return len(self.output_token_ids) + len(self.prompt_token_ids) + return len(self.token_ids) def get_prompt_len(self) -> int: - return len(self.prompt_token_ids) + return self.num_prompt_tokens + + def get_prompt_token_ids(self) -> int: + return self.token_ids[:self.num_prompt_tokens] def get_output_len(self) -> int: - return len(self.output_token_ids) + return len(self.token_ids) - self.num_prompt_tokens def get_token_ids(self) -> List[int]: - return self.prompt_token_ids + self.output_token_ids + return self.token_ids + + def advance_prefill_range(self, prefill_range: int) -> int: + """Advance the prefill range by the specified amount + + Args: + prefill_range: The amount to advance the prefill range. + Returns: + The actual number of advanced tokens. + """ + self.prefill_start = self.prefill_end + # The increased range could be larger than the seq length. + # Clamp it to the seq length. + # Note that we use prompt_len + output_len instead of + # prompt_len here. This is because during recompute + # we need to prefill for both prompt and output. + self.prefill_end = min(self.prefill_end + prefill_range, + self.get_len()) + return self.prefill_end - self.prefill_start + + def get_prefill_range(self) -> Tuple[int, int]: + """Returns the prefill range.""" + return self.prefill_start, self.prefill_end + + def get_num_unprefilled(self) -> int: + return self.get_len() - self.prefill_end + + def get_output_token_ids(self) -> List[int]: + return self.token_ids[self.num_prompt_tokens:] + + def get_token_id(self, index: int) -> int: + return self.token_ids[index] + + def get_new_token_ids(self) -> List[int]: + return self.get_unprocessed_token_ids() def get_last_token_id(self) -> int: - if not self.output_token_ids: - return self.prompt_token_ids[-1] - return self.output_token_ids[-1] + return self.token_ids[-1] def __repr__(self) -> str: return (f"SequenceData(" - f"prompt_token_ids={self.prompt_token_ids}, " - f"output_token_ids={self.output_token_ids}, " - f"cumulative_logprob={self.cumulative_logprob})") + f"prompt_token_ids={self.get_prompt_token_ids()}, " + f"output_token_ids={self.get_output_token_ids()}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"num_processed_token_ids={self.num_processed_token_ids}), " + f"prompt_token_id_count={self.prompt_token_id_count}, " + f"output_token_id_count={self.output_token_id_count}") class Sequence: @@ -105,6 +213,8 @@ class Sequence: prompt_token_ids: The token IDs of the prompt. block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. + num_processed_tokens: The number of prompt tokens to be considered + processed by SequenceData. 
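+        lora_request: The LoRA request associated with this sequence, if any.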
""" def __init__( @@ -113,16 +223,25 @@ def __init__( prompt: str, prompt_token_ids: List[int], block_size: int, + lora_request: Optional[LoRARequest] = None, + num_processed_token_ids: int = 0, ) -> None: self.seq_id = seq_id self.prompt = prompt self.block_size = block_size + self.lora_request = lora_request - self.data = SequenceData(prompt_token_ids) + self.data = SequenceData( + prompt_token_ids, num_processed_token_ids=num_processed_token_ids) self.output_logprobs: SampleLogprobs = [] self.output_text = "" self.logical_token_blocks: List[LogicalTokenBlock] = [] + + # Keep track of the first logical block index that has empty slots. + # Used to determine which block new tokens should be appended to. + self.block_index_for_new_tokens = 0 + # Initialize the logical token blocks with the prompt token ids. self._append_tokens_to_blocks(prompt_token_ids) self.status = SequenceStatus.WAITING @@ -133,6 +252,10 @@ def __init__( # Input + output tokens self.tokens: Optional[List[str]] = None + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + def _append_logical_block(self) -> None: block = LogicalTokenBlock( block_number=len(self.logical_token_blocks), @@ -146,25 +269,63 @@ def _append_tokens_to_blocks(self, token_ids: List[int]) -> None: if not self.logical_token_blocks: self._append_logical_block() - last_block = self.logical_token_blocks[-1] + last_block = self.logical_token_blocks[ + self.block_index_for_new_tokens] if last_block.is_full(): self._append_logical_block() - last_block = self.logical_token_blocks[-1] + self.block_index_for_new_tokens += 1 + last_block = self.logical_token_blocks[ + self.block_index_for_new_tokens] num_empty_slots = last_block.get_num_empty_slots() last_block.append_tokens(token_ids[cursor:cursor + num_empty_slots]) cursor += num_empty_slots + def ensure_num_empty_slots(self, num_desired_empty_slots: int) -> None: + """Ensure the specified number of empty slots are present in the logical + token blocks, allocating additional blocks if necessary. 
+ """ + if not self.logical_token_blocks: + self._append_logical_block() + + num_empty_slots = sum( + block.get_num_empty_slots() for block in + self.logical_token_blocks[self.block_index_for_new_tokens:]) + num_empty_remaining = num_desired_empty_slots - num_empty_slots + + while num_empty_remaining > 0: + self._append_logical_block() + last_block = self.logical_token_blocks[-1] + num_empty_remaining -= last_block.get_num_empty_slots() + def append_token_id( self, token_id: int, logprobs: Dict[int, float], ) -> None: - assert token_id in logprobs - self._append_tokens_to_blocks([token_id]) - self.output_logprobs.append(logprobs) - self.data.append_token_id(token_id, logprobs[token_id]) + return self.append_token_ids([token_id], [logprobs]) + + def append_token_ids( + self, + token_ids: List[int], + logprobs: List[Dict[int, float]], + ) -> None: + self._append_tokens_to_blocks(token_ids) + self.output_logprobs.extend(logprobs) + self.data.append_token_ids(token_ids, [ + logprob[token_id] + for logprob, token_id in zip(logprobs, token_ids) + ]) + + def reset_processed_tokens(self): + self.data.reset_processed_tokens() + + def get_num_processed_token_ids(self) -> int: + return self.data.get_num_processed_token_ids() + + def get_num_unprocessed_token_ids(self) -> int: + return self.get_len() - self.get_num_processed_token_ids() def get_len(self) -> int: return self.data.get_len() @@ -181,30 +342,34 @@ def get_token_ids(self) -> List[int]: def get_last_token_id(self) -> int: return self.data.get_last_token_id() + def get_new_token_ids(self) -> List[int]: + return self.data.get_new_token_ids() + def get_output_token_ids(self) -> List[int]: - return self.data.output_token_ids + return self.data.get_output_token_ids() def get_cumulative_logprob(self) -> float: return self.data.cumulative_logprob - def get_beam_search_score(self, - length_penalty: float = 0.0, - seq_len: Optional[int] = None, - eos_token_id: Optional[int] = None) -> float: - """Calculate the beam search score with length penalty. - - Adapted from - - https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938 - """ - if seq_len is None: - seq_len = self.get_len() - # NOTE: HF implementation does not count the EOS token - # towards the length, we align with that here for testing. - if (eos_token_id is not None - and self.get_last_token_id() == eos_token_id): - seq_len -= 1 - return self.get_cumulative_logprob() / (seq_len**length_penalty) + # def get_beam_search_score(self, + # length_penalty: float = 0.0, + # seq_len: Optional[int] = None, + # eos_token_id: Optional[int] = None) -> float: + # """Calculate the beam search score with length penalty. + + # Adapted from + + # https://github.com/huggingface/transformers/blob/ccb92be23def445f2af + # dea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938 + # """ + # if seq_len is None: + # seq_len = self.get_len() + # # NOTE: HF implementation does not count the EOS token + # # towards the length, we align with that here for testing. 
+ # if (eos_token_id is not None + # and self.get_last_token_id() == eos_token_id): + # seq_len -= 1 + # return self.get_cumulative_logprob() / (seq_len**length_penalty) def is_finished(self) -> bool: return SequenceStatus.is_finished(self.status) @@ -236,11 +401,18 @@ def __init__( seqs: List[Sequence], sampling_params: SamplingParams, arrival_time: float, + arrival_time_perf_counter: float, + lora_request: Optional[LoRARequest] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.sampling_params = sampling_params self.arrival_time = arrival_time + self.arrival_time_perf_counter = arrival_time_perf_counter + self.lora_request = lora_request + self.first_scheduled_time = None + self.first_token_time = None + self.time_in_queue = None self.prompt_logprobs: Optional[PromptLogprobs] = None @property @@ -253,7 +425,11 @@ def prompt(self) -> str: def prompt_token_ids(self) -> List[int]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. - return next(iter(self.seqs_dict.values())).data.prompt_token_ids + return next(iter(self.seqs_dict.values())).data.get_prompt_token_ids() + + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 def get_max_num_running_seqs(self) -> int: """The maximum number of sequences running in parallel in the remaining @@ -261,13 +437,13 @@ def get_max_num_running_seqs(self) -> int: if self.sampling_params.use_beam_search: # For beam search, maximally there will always be `best_of` beam # candidates running in the future. - return self.sampling_params.best_of + return self.sampling_params.actual_best_of else: - if self.sampling_params.best_of > self.num_seqs(): + if self.sampling_params.actual_best_of > self.num_seqs(): # At prompt stage, the sequence group is not yet filled up # and only have one sequence running. However, in the # generation stage, we will have `best_of` sequences running. - return self.sampling_params.best_of + return self.sampling_params.actual_best_of # At sampling stages, return the number of actual sequences # that are not finished yet. return self.num_unfinished_seqs() @@ -291,6 +467,23 @@ def get_unfinished_seqs(self) -> List[Sequence]: def get_finished_seqs(self) -> List[Sequence]: return [seq for seq in self.seqs_dict.values() if seq.is_finished()] + def advance_prefill_range(self, size: int) -> int: + """Advance the prefill range by the specified amount. + + Args: + size: The amount to advance the prefill range. + Returns: + The actual number of advanced tokens. + """ + return [ + seq.data.advance_prefill_range(size) + for seq in self.seqs_dict.values() + ][0] + + def get_num_unprefilled(self) -> int: + # All sequences in the group should have the same prompt. + return list(self.seqs_dict.values())[0].data.get_num_unprefilled() + def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: return len(self.get_seqs(status)) @@ -324,35 +517,60 @@ def __repr__(self) -> str: f"num_seqs={len(self.seqs_dict)})") -class SequenceGroupMetadata: +class SequenceGroupMetadataDelta(msgspec.Struct, + tag=True, + array_like=True, + omit_defaults=True): + request_id: str + block_tables: Optional[Dict[int, List[int]]] + + @property + def is_prompt(self): + return False + + @property + def is_chunked_prefill(self): + # A Delta should always be decoding (not chunk-prefiling). 
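+        # A delta only carries an updated block table for a sequence group
+        # that is already running, so it is never in (chunked) prefill.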
+ return False + + +class SequenceGroupMetadata(msgspec.Struct, + tag=True, + array_like=True, + omit_defaults=True): """Metadata for a sequence group. Used to create `InputMetadata`. Args: request_id: The ID of the request. is_prompt: Whether the request is at prompt stage. + is_chunked_prefill: Whether the request is at chunked prefill stage. + Note that chunked_prefill is also a prompt stage. seq_data: The sequence data. (Seq id -> sequence data) sampling_params: The sampling parameters used to generate the outputs. block_tables: The block tables. (Seq id -> list of physical block numbers) """ - def __init__( - self, - request_id: str, - is_prompt: bool, - seq_data: Dict[int, SequenceData], - sampling_params: SamplingParams, - block_tables: Dict[int, List[int]], - ) -> None: - self.request_id = request_id - self.is_prompt = is_prompt - self.seq_data = seq_data - self.sampling_params = sampling_params - self.block_tables = block_tables + request_id: str + is_chunked_prefill: bool + is_prompt: bool + seq_data: Dict[int, SequenceData] + sampling_params: SamplingParams + block_tables: Optional[Dict[int, List[int]]] + lora_request: Optional[LoRARequest] + + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + def update_from_delta(self, delta: "SequenceGroupMetadataDelta"): + self.block_tables = delta.block_tables + self.is_prompt = delta.is_prompt + return self -class SequenceOutput: + +class SequenceOutputs(msgspec.Struct, array_like=True, omit_defaults=True): """The model output associated with a sequence. Args: @@ -363,51 +581,91 @@ class SequenceOutput: (Token id -> logP(x_i+1 | x_0, ..., x_i)) """ - def __init__( - self, - parent_seq_id: int, - output_token: int, - logprobs: Dict[int, float], - ) -> None: - self.parent_seq_id = parent_seq_id - self.output_token = output_token - self.logprobs = logprobs + parent_seq_id: int + output_token: int + logprobs: Dict[int, float] def __repr__(self) -> str: - return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " + return (f"SequenceOutputs(parent_seq_id={self.parent_seq_id}, " f"output_token={self.output_token}, " f"logprobs={self.logprobs})") def __eq__(self, other: object) -> bool: - if not isinstance(other, SequenceOutput): + if not isinstance(other, SequenceOutputs): raise NotImplementedError() return (self.parent_seq_id == other.parent_seq_id and self.output_token == other.output_token and self.logprobs == other.logprobs) -class SequenceGroupOutput: - """The model output associated with a sequence group.""" - - def __init__( - self, - samples: List[SequenceOutput], - prompt_logprobs: Optional[PromptLogprobs], - ) -> None: - self.samples = samples - self.prompt_logprobs = prompt_logprobs +class SequenceGroupOutputs(msgspec.Struct, array_like=True, + omit_defaults=True): + """The model outputs associated with a sequence group.""" - def __repr__(self) -> str: - return (f"SequenceGroupOutput(samples={self.samples}, " - f"prompt_logprobs={self.prompt_logprobs})") + samples: List[SequenceOutputs] + prompt_logprobs: Optional[PromptLogprobs] def __eq__(self, other: object) -> bool: - if not isinstance(other, SequenceGroupOutput): - raise NotImplementedError() - return (self.samples == other.samples + return (isinstance(other, self.__class__) + and self.samples == other.samples and self.prompt_logprobs == other.prompt_logprobs) -# For each sequence group, we generate a list of SequenceOutput object, -# each of which contains one possible candidate for the next token. 
-SamplerOutput = List[SequenceGroupOutput] +class SamplerOutput(msgspec.Struct, array_like=True, omit_defaults=True): + outputs: List[SequenceGroupOutputs] + + # Used to store an on-GPU tensor containing the batch probabilities. + probs: Optional[torch.Tensor] = None + + # Used to store an on-GPU tensor containing the sampled token ids. + sampled_tokens: Optional[torch.Tensor] = None + + # Used to store an on-CPU tensor containing the batch logits + # for the full sequence. + logits: Optional[torch.Tensor] = None + + draft_target_worker_metrics: Optional["DraftTargetWorkerMetrics"] = None + + def __getitem__(self, idx: int): + return self.outputs[idx] + + def __setitem__(self, idx: int, value): + self.outputs[idx] = value + + def __len__(self): + return len(self.outputs) + + def __eq__(self, other: object): + return isinstance(other, + self.__class__) and self.outputs == other.outputs + + +class DraftTargetWorkerMetrics(msgspec.Struct, + array_like=True, + omit_defaults=True): + num_spec_tokens: int + draft_acceptance_rate: float + system_efficiency: float + accepted_tokens: int + draft_tokens: int + emitted_tokens: int + + def __repr__(self) -> str: + return ( + f"DraftTargetWorkerMetrics(num_spec_tokens={self.num_spec_tokens}," + f"draft_acceptance_rate={self.draft_acceptance_rate:.3f}, " + f"system_efficiency={self.system_efficiency:.3f}, " + f"accepted_tokens={self.accepted_tokens}, " + f"draft_tokens={self.draft_tokens}, " + f"emitted_tokens={self.emitted_tokens})") + + +class ExecuteModelData(msgspec.Struct, array_like=True, omit_defaults=True): + seq_group_metadata_list: List[Union[SequenceGroupMetadata, + SequenceGroupMetadataDelta]] + finished_request_ids_list: List[str] + blocks_to_swap_in: Dict[int, int] + blocks_to_swap_out: Dict[int, int] + blocks_to_copy: Dict[int, List[int]] + num_preallocated_slots: int + return_logits: bool = False diff --git a/vllm/worker/base_worker.py b/vllm/worker/base_worker.py new file mode 100644 index 0000000000000..7a89eac629533 --- /dev/null +++ b/vllm/worker/base_worker.py @@ -0,0 +1,64 @@ +from abc import ABC, abstractmethod + + +class BaseWorker(ABC): + """Base class for Workers. + See Worker implementation for details. + """ + + @abstractmethod + def init_model(self): + raise NotImplementedError + + @abstractmethod + def profile_num_available_blocks(self): + raise NotImplementedError + + @abstractmethod + def init_cache_engine(self): + raise NotImplementedError + + @abstractmethod + def execute_model(self): + raise NotImplementedError + + @abstractmethod + def get_metadata_cache_len(self): + raise NotImplementedError + + @abstractmethod + def get_runtime_context(self): + raise NotImplementedError + + +class BaseLoraWorker(BaseWorker): + """Base class for LoRA-enabled Workers. + See the Worker implememntation for details. + """ + + @abstractmethod + def add_lora(self): + raise NotImplementedError + + @abstractmethod + def remove_lora(self): + raise NotImplementedError + + @abstractmethod + def list_loras(self): + raise NotImplementedError + + +class LoraNotSupportedWorker(BaseLoraWorker): + """Implementation of BaseLoraWorker which raises + an error on LoRA calls. 
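+    Workers that cannot serve LoRA adapters should inherit from this class so
+    that LoRA management calls fail with an explicit error instead of being
+    silently ignored.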
+ """ + + def add_lora(self): + raise ValueError("LoRA not supported") + + def remove_lora(self): + raise ValueError("LoRA not supported") + + def list_loras(self): + raise ValueError("LoRA not supported") diff --git a/vllm/worker/draft_target_worker.py b/vllm/worker/draft_target_worker.py new file mode 100644 index 0000000000000..220d6bcc23fb9 --- /dev/null +++ b/vllm/worker/draft_target_worker.py @@ -0,0 +1,891 @@ +from typing import Iterator, List, Tuple, Optional, Union +from itertools import chain, count +from functools import cached_property +import logging +import time + +import msgspec +import torch +import traceback + +from vllm.anyscale.shm.msgspec_shm import SharedMsgspecBufferWithEvent +from vllm.sequence import (SamplerOutput, SequenceGroupMetadata, + ExecuteModelData, SequenceOutputs, SequenceData, + SequenceGroupOutputs, DraftTargetWorkerMetrics) +from vllm.worker.worker import Worker +from vllm.worker.multi_step_worker import MultiStepWorker +from vllm.worker.single_tp_worker import SingleTpWorker +from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.model_executor.parallel_utils.parallel_state import get_tensor_model_parallel_group +from vllm.config import CacheConfig +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.base_worker import BaseWorker +from vllm.anyscale.profiler_utils import TorchProfiler, nvtx_range, Profilable +from vllm.model_executor.layers.sampler import RawSamplerOutput +from vllm.utils import in_wsl + +SeqId = int +TargetSeqId = int +TokenId = int + +logger = logging.getLogger(__name__) + + +class DraftTargetWorker(Profilable, BaseWorker): + """Worker which implements speculative decoding via a draft model for + proposing tokens and a target model for verifying tokens. A modified form of + rejection sampling is applied to generate the output tokens. + + Scoring is done by querying the target model for a prefix context for each + speculative token, and an additional prefix for a bonus token. + + For example, given the previously-generated sequence: + [0, 1, 2, 3, 4, 5] + + and the continuation proposed by the draft model (k=3): + [6, 7, 8] + + the DraftTargetWorker will query the target model for the probability + distribution over token ids given the following contexts: + [0, 1, 2, 3, 4, 5] # used to score [6] + [0, 1, 2, 3, 4, 5, 6] # used to score [6, 7] + [0, 1, 2, 3, 4, 5, 6, 7] # used to score [6, 7, 8] + [0, 1, 2, 3, 4, 5, 6, 7, 8] # used to generate a bonus token + + The output of the first model for each context is then used to accept or + reject the proposed draft tokens. + """ + + @classmethod + def from_workers(cls, draft_worker: Union[MultiStepWorker, SingleTpWorker], + target_worker: Worker) -> "DraftTargetWorker": + return cls(draft_worker, target_worker, RejectionSampler()) + + def __init__( + self, + draft_worker: Union[MultiStepWorker, SingleTpWorker], + target_worker: Worker, + rejection_sampler: RejectionSampler, + ): + """ + Create a DraftTargetWorker. + + Args: + draft_worker: A draft worker that can run multiple steps + in a row. + target_worker: The normal worker that is used for scoring. + It should contain the target model. + rejection_sampler: A Torch module used to perform modified rejection + sampling for speculative decoding. + """ + self.draft_worker = draft_worker + self.target_worker = target_worker + self.rejection_sampler = rejection_sampler + + self.device = None + + # We don't have a device set yet. 
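+        # The copy stream is created in init_model(), once the target worker
+        # has set the device; it is used for asynchronous host-to-device
+        # copies (see _get_seqs_for_spec_decode).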
+ self._copy_stream: Optional[torch.cuda.Stream] = None + + self.probs_dtype = self.rejection_sampler.probs_dtype + self.token_id_dtype = self.rejection_sampler.token_id_dtype + + self._profiler = TorchProfiler() + + pin_memory = not in_wsl() + self._aggregate_num_accepted_tokens = torch.tensor( + 0, dtype=torch.long, device="cpu", pin_memory=pin_memory) + self._aggregate_num_emitted_tokens = torch.tensor( + 0, dtype=torch.long, device="cpu", pin_memory=pin_memory) + self._aggregate_num_draft_tokens = 0 + + self._rejsample_metrics_collect_interval_s = 5.0 + self._last_metrics_collect_time = 0 + + def _configure_samplers(self): + """Configure model samplers to return a probability tensor in the + SamplerOutput. This simplifies the data wrangling logic in speculative + decoding. + """ + self.draft_worker.model.sampler.include_gpu_probs_tensor = (True) + self.target_worker.model.sampler.include_gpu_probs_tensor = True + + def init_model(self): + # Intitialize the target model before the draft model. + # This allows the draft model to have a smaller TP degree than the + # larger model without refactors to parallel_state. + self.target_worker.init_model() + self.draft_worker.init_model() + self._configure_samplers() + + self.device = self.target_worker.device + self._copy_stream = torch.cuda.Stream() + + self.rejection_sampler.init_gpu_tensors(self.rank) + + def profile_num_available_blocks(self, block_size: int, + gpu_memory_utilization: float, + cpu_swap_space: int): + num_gpu_blocks, num_cpu_blocks = ( + self.target_worker.profile_num_available_blocks( + block_size, gpu_memory_utilization, cpu_swap_space)) + + new_num_gpu_blocks = self._calculate_gpu_blocks( + block_size, num_gpu_blocks) + return new_num_gpu_blocks, num_cpu_blocks + + def _calculate_gpu_blocks(self, block_size: int, + total_num_gpu_blocks: int) -> int: + """Given total_num_gpu_blocks, the number of GPU blocks that could be + allocate to the target model, this function calculates how many blocks + should be given to the draft and target model. + + Note that usually the block size, in bytes, of each model is different, + as it's a function of number of KV/layer, number of heads, and hidden + dimension size. + + Since the target and draft models allocate the same number of blocks, we + simply calculate the number of blocks where if allocated by both models, + the total memory usage from KV cache is no larger than the number of + blocks allocatable by the target model alone. + """ + target_kv_size_bytes = CacheEngine.get_cache_block_size( + block_size, + self.target_worker.model_config, + self.target_worker.parallel_config, + ) + + draft_kv_size_bytes = CacheEngine.get_cache_block_size( + block_size, + self.draft_worker.model_config, + self.draft_worker.parallel_config, + ) + + new_num_gpu_blocks = int(total_num_gpu_blocks * target_kv_size_bytes / + (draft_kv_size_bytes + target_kv_size_bytes)) + + return new_num_gpu_blocks + + def init_cache_engine(self, cache_config: CacheConfig): + self.target_worker.init_cache_engine(cache_config) + self.draft_worker.init_cache_engine(cache_config) + + @property + def rank(self): + return self.target_worker.rank + + def get_metadata_cache_len(self) -> int: + """Metadata cache not currently supported. 
+ """ + return 0 + + def get_runtime_context(self) -> Optional[dict]: + return self.target_worker.get_runtime_context() + + def _get_max_model_len(self) -> Tuple[int, int]: + draft_max_model_len = (self.draft_worker.model_config.max_model_len) + target_max_model_len = (self.target_worker.model_config.max_model_len) + + assert draft_max_model_len is not None + assert target_max_model_len is not None + + return draft_max_model_len, target_max_model_len + + @staticmethod + def _get_k_from_execute_model_data( + execute_model_data: ExecuteModelData) -> int: + """Given an ExecuteModelData, determine the number of speculative + tokens (k). This is equal to the number of preallocated slots as each + speculative token requires a KV slot. + """ + k = execute_model_data.num_preallocated_slots + assert k >= 0, f"Expected {k=} >= 0" + return k + + @staticmethod + def _get_draft_num_preallocated_slots_from_k(k: int) -> int: + """Given the number of speculative tokens, return the number of + preallocated slots to give to the draft model. This is equal to k - 1, + because the draft model will not store KV for the last of the k + generated tokens. + """ + num_preallocated_slots = k - 1 + assert num_preallocated_slots >= 0, ("Expected " + f"{num_preallocated_slots=} >= 0") + return num_preallocated_slots + + def execute_model_shared_memory( + self, + shared_memory_input: SharedMsgspecBufferWithEvent, + shared_memory_output: SharedMsgspecBufferWithEvent, + participant_id: int # pylint: disable=unused-argument + ): + shared_memory_input.decoder = msgspec.msgpack.Decoder(ExecuteModelData) + logger.info("Worker shared memory input buffer id: " + f"{shared_memory_input.participant_id}") + logger.info("Worker shared memory output buffer id: " + f"{shared_memory_input.participant_id}") + parallel_group = get_tensor_model_parallel_group() + try: + while True: + shared_memory_input.wait_for_incoming_data() + data = shared_memory_input.get_data() + torch.distributed.barrier(group=parallel_group) + shared_memory_input.clear() + outputs = self.execute_model(data) + if self.rank < 1: + shared_memory_output.set_data(outputs) + except Exception: + traceback.print_exc() + shared_memory_output.set_error() + raise + + @torch.inference_mode() + @nvtx_range("draft_target_worker.execute_model") + def execute_model( + self, + execute_model_data: ExecuteModelData, + *, + return_python_output: bool = True # pylint: disable=unused-argument + ) -> List[SamplerOutput]: + + k = self._get_k_from_execute_model_data(execute_model_data) + if k == 0: + return self._run_prefill(execute_model_data) + + if len(execute_model_data.seq_group_metadata_list) == 0: + return self._run_for_empty_input(execute_model_data) + + return self._run_speculative_decoding_step(execute_model_data, k) + + @nvtx_range("draft_target_worker._run_prefill") + def _run_prefill( + self, execute_model_data: ExecuteModelData) -> List[SamplerOutput]: + """Run a prefill step, without any speculation. The input is sent to the + draft and target model so that prompt KV are stored in both caches. 
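+        Only the target model's sampler output is returned; the draft model is
+        run purely to populate its KV cache with the prompt.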
+ """ + assert self._is_prefill(execute_model_data) + assert execute_model_data.num_preallocated_slots == 0, ( + "Expected " + f"{execute_model_data.num_preallocated_slots=} to be zero during " + "prefill.") + + logger.debug("draft prefill") + self.draft_worker.execute_model(execute_model_data, + return_python_output=False) + + logger.debug("target worker prefill") + sampler_output, = self.target_worker.execute_model(execute_model_data) + + # Do not want PyTorch tensors transferred back. + sampler_output.probs = None + sampler_output.sampled_tokens = None + return [sampler_output] + + def _is_prefill(self, execute_model_data: ExecuteModelData) -> bool: + """Returns whether or not the input ExecuteModelData is prefill or not. + """ + return any(seq_group_metadata.is_prompt for seq_group_metadata in + execute_model_data.seq_group_metadata_list) + + def _run_for_empty_input( + self, execute_model_data: ExecuteModelData) -> List[SamplerOutput]: + """If there are no sequences in the input, simply call the models with + the inpiut. This allows them to process metadata, such as cleaning up + after a request finishes. + """ + self.draft_worker.execute_model(execute_model_data, + return_python_output=False) + target_output, = self.target_worker.execute_model(execute_model_data) + return [target_output] + + @nvtx_range("draft_target_worker._run_speculative_decoding_step") + def _run_speculative_decoding_step( + self, + execute_model_data: ExecuteModelData, + k: int, + ) -> List[SamplerOutput]: + """Execute a single step of speculative decoding. + + This runs the draft model k times, then scores each token using the + target model. Rejection sampling is performed on the draft and target + outputs to determine which tokens can be accepted without modifying the + true distribution. + + Args: + execute_model_data: The input sequences that will be speculated + upon. + k: A hyperparameter integer dictating how many tokens to speculate. + Given some k, this will return a number of tokens per sequence + in the interval [1, k+1], depending on how many tokens are + accepted. + + Returns: + A List of SamplerOutput, as if the target worker were simply called + multiple times. + """ + logger.debug(f"running draft model for {k=} steps") + + (spec_seqs, non_spec_seqs, all_seqs, original_indices, + original_indices_ready) = self._get_seqs_for_spec_decode( + execute_model_data.seq_group_metadata_list, k) + + proposal_token_ids, proposal_probs = self._get_proposals( + execute_model_data, spec_seqs, k) + + should_collect_rejsample_metrics = ( + self._should_collect_rejsample_metrics(time.time())) + if should_collect_rejsample_metrics: + aggregate_metrics_ready = self._copy_rejsample_metrics_async() + + logger.debug("scoring draft tokens") + (proposal_scores, bonus_token_ids, + non_spec_token_ids) = self._score_proposals(execute_model_data, + proposal_token_ids, + spec_seqs, non_spec_seqs) + + with nvtx_range("draft_target_worker.rejection_sampler"): + accepted_token_ids = self.rejection_sampler( + proposal_scores, + bonus_token_ids, + proposal_probs, + proposal_token_ids, + ) + + # Append output tokens from non-speculative sequences to + # the accepted token ids tensor. + non_spec_token_ids = non_spec_token_ids.expand(-1, k + 1).clone() + non_spec_token_ids[:, 1:] = -1 + accepted_token_ids = torch.cat( + [accepted_token_ids, non_spec_token_ids]) + + # Rearrange so that results are in the order of the original seq group + # metadata. 
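+        # `original_indices` was copied to the GPU asynchronously on
+        # `self._copy_stream`; waiting on the recorded event guarantees the
+        # index tensor has finished transferring before it is used below.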
+ torch.cuda.current_stream().wait_event(original_indices_ready) + accepted_token_ids = accepted_token_ids[original_indices] + + # Construct output. + seq_ids = self._get_all_seq_ids(all_seqs) + sampler_output = self._create_output_sampler_list( + seq_ids, accepted_token_ids) + + if should_collect_rejsample_metrics: + self._last_metrics_collect_time = time.time() + metrics = self._collect_rejsample_metrics(k, + aggregate_metrics_ready) + sampler_output[0].draft_target_worker_metrics = metrics + + return sampler_output + + def _get_seqs_for_spec_decode( + self, seq_group_metadata_list: List[SequenceGroupMetadata], k: int + ) -> Tuple[List[SequenceGroupMetadata], List[SequenceGroupMetadata], + List[SequenceGroupMetadata], torch.Tensor, torch.cuda.Event]: + """Determine which sequences are eligible for speculative decoding. + + Any sequence which would go over the model max len is ineligible. + """ + + all_seqs: List[SequenceGroupMetadata] = seq_group_metadata_list + spec_seqs: List[SequenceGroupMetadata] = [] + non_spec_seqs: List[SequenceGroupMetadata] = [] + + # Indices of each seq group in the original seq_group_metadata_list. + non_spec_indices: List[int] = [] + spec_indices: List[int] = [] + + # Maximum number of new tokens in speculative decoding. + max_num_new_tokens = k + 1 + + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + # Only one sequence per group in speculative decoding. + seq_data = next(iter(seq_group_metadata.seq_data.values())) + draft_max_model_len, target_max_model_len = self._get_max_model_len( + ) + max_model_len = min(draft_max_model_len, target_max_model_len) + seq_len = seq_data.get_len() + + # If the sequence would go over the model limit in speculative + # decoding, then we should perform normal decoding. + if seq_len + max_num_new_tokens > max_model_len: + non_spec_seqs.append(seq_group_metadata) + non_spec_indices.append(i) + continue + + # Otherwise, we can do speculative decoding. + spec_seqs.append(seq_group_metadata) + spec_indices.append(i) + + # Return the original indices so that the output order can be preserved. + # During scoring, the speculative sequences are placed first, then the + # non-speculative sequences. + # + # This async copy is waited upon before constructing final output. + with torch.cuda.stream(self._copy_stream): + original_indices = torch.tensor(spec_indices + non_spec_indices, + dtype=torch.long, + pin_memory=True) + original_indices = original_indices.to(device=self.device, + non_blocking=True) + original_indices_ready = torch.cuda.Event() + original_indices_ready.record() + return (spec_seqs, non_spec_seqs, all_seqs, original_indices, + original_indices_ready) + + @nvtx_range("draft_target_worker._get_proposals") + def _get_proposals(self, execute_model_data: ExecuteModelData, + spec_seqs: List[SequenceGroupMetadata], + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Get proposal tokens from the draft model. + """ + + # If there are no sequences to generate draft tokens for, return empty + # tensors. 
+ if not spec_seqs: + return torch.zeros(0, + k, + dtype=self.token_id_dtype, + device=self.device), torch.zeros( + 0, + k, + self._vocab_size, + dtype=self.probs_dtype, + device=self.device) + + execute_model_data.num_preallocated_slots = ( + self._get_draft_num_preallocated_slots_from_k(k)) + execute_model_data.seq_group_metadata_list = spec_seqs + + sampler_output_list = self.draft_worker.execute_model( + execute_model_data, return_python_output=False) + + # vLLM currently stores results in Python datastructures. We convert to + # torch to use in rejection sampling and to simplify data + # transformations. + proposal_token_ids, proposal_probs = self._sampler_output_to_torch( + sampler_output_list) + + return proposal_token_ids, proposal_probs + + @nvtx_range("draft_target_worker._score_proposals") + def _score_proposals( + self, + execute_model_data: ExecuteModelData, + proposal_token_ids: torch.Tensor, # shape: [batch_size, k] + spec_seqs: List[SequenceGroupMetadata], + non_spec_seqs: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Score the proposed tokens via the target model. + + This converts each input sequence to a set of k+1 target sequences. The + target sequences have the unique continuations to be scored and a + unique sequence ID that is different from all input sequence ids. + + This adds overhead and should be removed. It is done because the sampler + currently operates on sequences instead of queries. + """ + # Convert to target sequence ids. + target_seq_group_metadata_list = self._create_scoring_model_input( + spec_seqs, proposal_token_ids) + + num_scoring_tokens = len(target_seq_group_metadata_list) + target_seq_group_metadata_list.extend(non_spec_seqs) + + # Score proposal token ids. + target_sampler_output = self.target_worker.execute_model( + ExecuteModelData( + target_seq_group_metadata_list, + execute_model_data.finished_request_ids_list, + execute_model_data.blocks_to_swap_in, + execute_model_data.blocks_to_swap_out, + execute_model_data.blocks_to_copy, + num_preallocated_slots=0, + ), + return_python_output=False) + + (target_token_ids, target_probs, + non_spec_target_token_ids) = self._split_scoring_output( + target_sampler_output, num_scoring_tokens) + + # Map distinct sequences used to score each token + # of shape [batch_size * k + 1] back to [batch_size, k + 1]. + batch_size, k = proposal_token_ids.shape + + target_token_ids = target_token_ids.squeeze().reshape( + batch_size, k + 1) + target_probs = target_probs.squeeze().reshape(batch_size, k + 1, + self._vocab_size) + + # shape: [batch_size, 1] + bonus_token_ids = target_token_ids[:, -1:] + + # shape: [batch_size, k, vocab_size] + proposal_scores = target_probs[:, :-1] + + return proposal_scores, bonus_token_ids, non_spec_target_token_ids + + def _create_output_sampler_list( + self, + seq_ids: List[SeqId], + accepted_token_ids: torch.Tensor # shape: [batch_size, k+1] + ) -> List[SamplerOutput]: + """Given the accepted token ids, create a list of SamplerOutput. + + The output is padded with -1 tokens such that each sequence has + the same number of outputs. 
+ """ + # shape: [k+1, batch_size] + accepted_token_ids_by_step = accepted_token_ids.transpose(0, + 1).tolist() + sampler_output_list = [] + for token_ids_by_step in accepted_token_ids_by_step: + if all(token_id == -1 for token_id in token_ids_by_step): + break + + step_output_token_ids = [] + for token_id, seq_id in zip(token_ids_by_step, seq_ids): + step_output_token_ids.append( + SequenceGroupOutputs( + samples=[ + SequenceOutputs( + parent_seq_id=seq_id, + output_token=token_id, + # TODO currently rejection sampling does not + # emit probs, so this value is meaningless. + logprobs={token_id: 0}, + ) + ], + prompt_logprobs=None, + )) + sampler_output_list.append( + SamplerOutput(outputs=step_output_token_ids)) + + return sampler_output_list + + def _sampler_output_to_torch( + self, + sampler_output_list: List[RawSamplerOutput], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Utility function which converts a list of SamplerOutput to tensors. + + Returns: + token_ids: torch.Tensor + shape: [batch_size, len(sampler_output_list)] + + probs: torch.Tensor + shape: [batch_size, len(sampler_output_list), vocab_size] + """ + + # shape: [batch_size, num_sampler_output, vocab_size] + probs = torch.stack( + [sampler_output.probs for sampler_output in sampler_output_list], + dim=0, + ).transpose(0, 1) + + # shape: [batch_size, num_sampler_output] + token_ids = torch.stack( + [ + sampler_output.sampled_tokens.flatten() + for sampler_output in sampler_output_list + ], + dim=0, + ).transpose(0, 1) + + return token_ids, probs + + def _should_collect_rejsample_metrics(self, now: float) -> bool: + """Return whether or not this iteration should print rejection sampling + metrics. + """ + if self.rank != 0: + return False + + if (now - self._last_metrics_collect_time < + self._rejsample_metrics_collect_interval_s): + return False + return True + + def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: + """Copy rejection sampling metrics (number of accepted tokens, etc) to + CPU asynchronously. + + Returns a CUDA event recording when the copy is complete. + """ + self._copy_stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(self._copy_stream): + self._aggregate_num_accepted_tokens.copy_( + self.rejection_sampler.num_accepted_tokens, non_blocking=True) + self._aggregate_num_emitted_tokens.copy_( + self.rejection_sampler.num_emitted_tokens, non_blocking=True) + # Number of draft tokens is calculated on CPU, so no copy is + # required. + self._aggregate_num_draft_tokens = ( + self.rejection_sampler.num_draft_tokens) + + aggregate_metrics_ready = torch.cuda.Event() + aggregate_metrics_ready.record(self._copy_stream) + + return aggregate_metrics_ready + + def _collect_rejsample_metrics( + self, k: int, + ready_event: torch.cuda.Event) -> DraftTargetWorkerMetrics: + """Create metrics object from statistics copied asynchronously. + + Args: + k: int. The number of speculative tokens; used to determine system + efficiency. + ready_event: torch.cuda.Event. The CUDA event recording when the + async GPU->CPU copy is complete. + """ + + ready_event.synchronize() + accepted_tokens = self._aggregate_num_accepted_tokens.item() + emitted_tokens = self._aggregate_num_emitted_tokens.item() + draft_tokens = self._aggregate_num_draft_tokens + + # Divide by k since batch size can be variable. 
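# --- Illustrative sketch (editor's note, not part of the patch) ---
# Worked example of the metrics computed below, assuming k = 3 and counters
# accumulated over 10 sequences (values are made up for illustration):
k_example = 3
draft_tokens, accepted_tokens, emitted_tokens = 30, 18, 25
num_possible_tokens = (draft_tokens / k_example) * (k_example + 1)  # 40.0
draft_acceptance_rate = accepted_tokens / draft_tokens              # 0.6
system_efficiency = emitted_tokens / num_possible_tokens            # 0.625
# -------------------------------------------------------------------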
+ num_possible_tokens = (draft_tokens / k) * (k + 1) + + if draft_tokens > 0: + draft_acceptance_rate = accepted_tokens / draft_tokens + else: + draft_acceptance_rate = float("nan") + + if num_possible_tokens > 0: + system_efficiency = emitted_tokens / num_possible_tokens + else: + system_efficiency = float("nan") + + return DraftTargetWorkerMetrics( + num_spec_tokens=k, + draft_acceptance_rate=draft_acceptance_rate, + system_efficiency=system_efficiency, + accepted_tokens=accepted_tokens, + draft_tokens=draft_tokens, + emitted_tokens=emitted_tokens, + ) + + def _create_scoring_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_token_ids: torch.Tensor, # shape: [batch_size, k] + ) -> List[SequenceGroupMetadata]: + """Given the original input sequences and proposed tokens from the draft + model, create a list of target sequences that can be used for scoring. + """ + + # TODO(cade) perform this on GPU to remove blocking call. + proposal_token_ids = proposal_token_ids.tolist() + + if not seq_group_metadata_list: + return [] + + target_seq_ids_iter = self._create_target_seq_id_iterator( + self._get_all_seq_ids(seq_group_metadata_list)) + + target_seq_group_metadata = list( + chain.from_iterable( + self._create_target_seq_group_metadata( + seq_group_metadata, + proposal_token_ids, + i, + target_seq_ids_iter, + ) for i, seq_group_metadata in enumerate( + seq_group_metadata_list))) + + return target_seq_group_metadata + + def _split_scoring_output( + self, sampler_output: RawSamplerOutput, num_scoring_tokens: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Split the target model output into speculative and non-speculative + output. + """ + + # First samples are from speculative scoring, latter samples are non- + # speculative samples. + split_sizes = [ + num_scoring_tokens, + sampler_output.sampled_tokens.numel() - num_scoring_tokens + ] + (spec_probs, non_spec_probs) = sampler_output.probs.split(split_sizes) + (spec_sampled_tokens, non_spec_sampled_tokens + ) = sampler_output.sampled_tokens.flatten().split(split_sizes) + + # Convert scores to tensors. + sampler_output.probs = spec_probs + sampler_output.sampled_tokens = spec_sampled_tokens + target_token_ids, target_probs = self._sampler_output_to_torch( + [sampler_output]) + + # Convert non-speculative output tokens to tensors. + sampler_output.probs = non_spec_probs + sampler_output.sampled_tokens = non_spec_sampled_tokens + non_spec_target_token_ids, _ = self._sampler_output_to_torch( + [sampler_output]) + + return target_token_ids, target_probs, non_spec_target_token_ids + + def _create_target_seq_group_metadata( + self, + input_seq_group_metadata: SequenceGroupMetadata, + proposal_token_ids: List[int], # shape: [batch_size, k] + batch_index: int, + target_seq_ids_iter: Iterator[TargetSeqId], + ) -> List[SequenceGroupMetadata]: + """Given an input sequence group metadata and a list of draft tokens, + create a list of target SequenceGroupMetadata, one for each + token id that needs to be scored. + + Naive speculative decoding requires K target model scores, one for each + draft model token. However one can add a bonus token such that if each + token is accepted, then a final token may be sampled from the model. + This function creates K+1 target SequenceGroupMetadata to take + advantage of the bonus token. 
+ """ + assert not input_seq_group_metadata.is_prompt, ( + "Speculating on " + "prompts not yet supported") + assert len(input_seq_group_metadata.seq_data) == 1, ( + "Beam search " + "not supported in speculative decoding") + input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys())) + + token_ids_to_score = self._get_token_ids_to_score( + proposal_token_ids[batch_index]) + + target_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + for token_ids in token_ids_to_score: + target_seq_group_metadata_list.append( + self._create_single_target_seq_group_metadata( + input_seq_group_metadata, + input_seq_id, + next(target_seq_ids_iter), + token_ids, + )) + + return target_seq_group_metadata_list + + def _create_single_target_seq_group_metadata( + self, + seq_group_metadata: SequenceGroupMetadata, + seq_id: SeqId, + target_seq_id: TargetSeqId, + token_ids: List[TokenId], + ) -> SequenceGroupMetadata: + """Create a single target SequenceGroupMetadata. + + Args: + seq_group_metadata: The metadata for the input sequence. + seq_id: The input sequence ID. + target_seq_id: The corresponding target sequence ID. + token_ids: The list of token ids that are to be appended to the + input sequence. + """ + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_token_ids = seq_data.get_prompt_token_ids() + new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids] + + # The first scoring seqeuence will include the normal number of + # processed tokens. This allows it to write KV from the previous + # iteration. + # + # Subsequent scoring sequences only include a single unprocessed token; + # the token they score. + if len(token_ids) == 0: + num_processed_token_ids = seq_data.get_num_processed_token_ids() + else: + num_processed_token_ids = seq_data.get_len() + len(token_ids) - 1 + + return SequenceGroupMetadata( + request_id=seq_group_metadata.request_id, + is_prompt=seq_group_metadata.is_prompt, + is_chunked_prefill=seq_group_metadata.is_chunked_prefill, + seq_data={ + target_seq_id: + SequenceData( + token_ids=prompt_token_ids + new_output_token_ids, + num_prompt_tokens=len(prompt_token_ids), + # Support for tracking cumulative logprob not yet + # implemented. + cumulative_logprob=0.0, + num_processed_token_ids=num_processed_token_ids, + ), + }, + sampling_params=seq_group_metadata.sampling_params, + block_tables={ + target_seq_id: seq_group_metadata.block_tables[seq_id], + }, + lora_request=None, + ) + + def _get_token_ids_to_score( + self, + full_spec_token_ids: List[int] # shape: [k] + ) -> List[List[TokenId]]: + """Given an int tensor of proposal token ids, return a list of + token ids that should be scored. + + Returns k+1 output lists. The additional one is used for generating the + bonus token. + + Example: + Input: [0, 1, 2, 3] (k=4) + Output: (k+1 lists) + [] + [0] + [0, 1] + [0, 1, 2] + [0, 1, 2, 3] + """ + empty_token_ids = [] + + token_ids_to_score = [empty_token_ids] + token_ids_to_score.extend([ + full_spec_token_ids[:i + 1] + for i in range(len(full_spec_token_ids)) + ]) + return token_ids_to_score + + def _get_all_seq_ids( + self, seq_group_metadata_list: List[SequenceGroupMetadata] + ) -> List[SeqId]: + """Given a list of SequenceGroupMetadata, create a list of all + sequence ids. 
+ """ + return list( + chain.from_iterable([ + seq_group_metadata.seq_data.keys() + for seq_group_metadata in seq_group_metadata_list + ])) + + def _create_target_seq_id_iterator( + self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: + """Create an iterator for creating target sequence ids. + Target sequence ids are distinct from sequence ids because we create a + distinct target sequence id for each proposal token to be scored. + + This implementation increments a counter starting at 1 + max of all + provided input sequence ids. + """ + return count(start=max(seq_ids) + 1) + + @cached_property + def _vocab_size(self) -> int: + """Get the vocab size of the model and make sure it's consistent between + draft and target workers. + """ + vocab_sizes = [ + worker.model.config.vocab_size + for worker in [self.draft_worker, self.target_worker] + ] + assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes) + return vocab_sizes[0] + + def start_profile(self, **kwargs) -> None: + self._profiler.start_profile(**kwargs) + + def stop_profile(self) -> None: + self._profiler.stop_profile() diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py new file mode 100644 index 0000000000000..4a5e4c0d1babf --- /dev/null +++ b/vllm/worker/multi_step_worker.py @@ -0,0 +1,203 @@ +from typing import List + +import torch + +from vllm.sequence import SamplerOutput, SequenceGroupMetadata, ExecuteModelData +from vllm.worker.worker import Worker +from vllm.worker.base_worker import LoraNotSupportedWorker +from vllm.anyscale.profiler_utils import nvtx_range +from vllm.model_executor.layers.sampler import RawSamplerOutput, pythonize_sampler_output +from vllm.model_executor.input_metadata import MultiStepInputMetadata, InputMetadata + + +class MultiStepWorker(Worker, LoraNotSupportedWorker): + """The MultiStepWorker is equivalent to a Worker except that it allows + multiple forward passes in a single call, assuming the scheduler has + allocated enough space to store the additional KV. This reduces overhead + by invoking the scheduler less and also by requiring less interprocess + communication. + + Currently, the MultiStepWorker does not support cache swap operations, delta + sequence group metadata updates, or beeam search. The first two can be + added, however adding beam search is more complicated as it requires memory + allocations during forks. + """ + + @staticmethod + def _get_num_steps_from_num_preallocated_slots( + num_preallocated_slots: int) -> int: + """Determine the number of steps the MultiStepWorker should run given + the number of slots preallocated by the scheduler. + + This is num_preallocated_slots plus one because the last generated token + will not have its KV generated yet. + """ + return num_preallocated_slots + 1 + + @torch.inference_mode() + @nvtx_range("multi_step_worker.execute_model") + def execute_model( + self, + execute_model_data: ExecuteModelData, + *, + return_python_output: bool = True) -> List[SamplerOutput]: + """Run the model forward pass num_steps times. Returns the list of + sampler output, one per model forward pass. + """ + (seq_group_metadata_list, _, _, _, _, + return_logits) = (execute_model_data.seq_group_metadata_list, + execute_model_data.finished_request_ids_list, + execute_model_data.blocks_to_swap_in, + execute_model_data.blocks_to_swap_out, + execute_model_data.blocks_to_copy, + execute_model_data.return_logits) + + # Return if there are no input sequences. 
+ # We can do nothing here since input metadata deltas and + # cache events are not supported. + if not seq_group_metadata_list: + return [SamplerOutput([])] + + num_steps = self._get_num_steps_from_num_preallocated_slots( + execute_model_data.num_preallocated_slots) + + # Set num_preallocated_slots to zero; the single step worker does not + # need to know about any other slots. + old_num_preallocated_slots = execute_model_data.num_preallocated_slots + execute_model_data.num_preallocated_slots = 0 + + # Assert enough KV space for num_steps tokens per sequence. + self._assert_enough_kv_space( + execute_model_data.seq_group_metadata_list, num_steps) + + # Prepare input tensors. + ( + input_tokens, + input_positions, + multi_step_input_metadata, + _, + _, + ) = self._prepare_inputs(seq_group_metadata_list, + return_logits=return_logits, + num_steps=num_steps) + + # Run model num_steps times. + model_outputs = [] + prev_parameters_tensors = (None, None) + for _ in range(num_steps): + # TODO(cade,antoni) This code breaks abstractions to improve + # latency. We should refactor this so that `advance_step` can be + # performed without blocking the GPU. Then this can become cuda + # graphable, and simpler! + with nvtx_range("multi_step_worker.run_single_step"): + if model_outputs: + (input_tokens, input_positions, + input_metadata) = (self._advance_step( + model_outputs[-1], + input_metadata.selected_token_indices, input_tokens, + input_positions, multi_step_input_metadata)) + else: + input_metadata = multi_step_input_metadata.get_next_step() + + output = self.captured_model.execute_if_capturable( + input_ids=input_tokens, + positions=input_positions, + input_metadata=input_metadata, + cache_events=None, + sampling_parameters_tensors=prev_parameters_tensors[0], + sampling_token_tensors=prev_parameters_tensors[1], + ) + prev_parameters_tensors = (output.sampling_parameters_tensors, + output.sampling_token_tensors) + + model_outputs.append(output) + + execute_model_data.num_preallocated_slots = old_num_preallocated_slots + + if return_python_output: + model_outputs = [ + pythonize_sampler_output(o, input_metadata) + for o in model_outputs + ] + + return model_outputs + + def _advance_step( + self, last_output: RawSamplerOutput, + last_selected_token_indices: torch.Tensor, + input_tokens: torch.Tensor, input_positions: torch.Tensor, + multi_step_input_metadata: MultiStepInputMetadata + ) -> InputMetadata: + sampled_tokens = last_output.sampled_tokens.flatten() + # Sampled tokens from last step become new input tokens. + input_tokens[:last_selected_token_indices.shape[0]] = sampled_tokens + input_tokens[last_selected_token_indices.shape[0]:] = 0 + new_input_positions = input_positions[ + last_selected_token_indices].add_(1) + input_positions[:last_selected_token_indices. + shape[0]] = new_input_positions + input_positions[last_selected_token_indices.shape[0]:] = 0 + input_metadata = multi_step_input_metadata.get_next_step() + return input_tokens, input_positions, input_metadata + + def _assert_enough_kv_space( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + num_steps: int) -> None: + """Assert there are enough physical blocks per sequence to store the + current KV plus additional KV from num_steps tokens. + """ + assert self.block_size is not None + for seq_group_metadata in seq_group_metadata_list: + # Only one seq_id is guaranteed because there is no beam search. 
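# --- Illustrative sketch (editor's note, not part of the patch) ---
# Worked example of the KV-space check below, assuming block_size = 16, a
# sequence of current length 30, num_steps = 4, and 2 allocated blocks:
block_size, seq_len, num_steps, num_blocks = 16, 30, 4, 2
final_seq_len = seq_len + num_steps           # 34
required_num_kv_slots = final_seq_len - 1     # 33
allocated_kv_slots = num_blocks * block_size  # 32
# 33 > 32, so this sequence group would trigger the ValueError below.
# -------------------------------------------------------------------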
+ seq_id = list(seq_group_metadata.seq_data.keys())[0] + seq = seq_group_metadata.seq_data[seq_id] + + # After num_steps, the seq len will be the current seq len + # plus one token per step. + final_seq_len = seq.get_len() + num_steps + + # We will have final_seq_len - 1 KV because vLLM saves KV for a + # token in the iteration after the token was generated. + required_num_kv_slots = final_seq_len - 1 + + # The allocated number of kv slots is the number of allocated blocks + # times the number of slots of block. + number_physical_blocks = len( + seq_group_metadata.block_tables[seq_id]) + allocated_kv_slots = number_physical_blocks * self.block_size + + if required_num_kv_slots > allocated_kv_slots: + request_id = seq_group_metadata.request_id + raise ValueError( + "The worker attempted to run " + f"{num_steps} times but found insufficient KV space for " + f"{request_id=} {seq_id=}. ({allocated_kv_slots=} " + f"{required_num_kv_slots=}).") + + def _raise_if_unsupported(self, + execute_model_data: ExecuteModelData) -> None: + """MultiStepWorker does not yet implement support for cache swap + operations, incremental seq group metadata, or beam search. + """ + (seq_group_metadata_list, _, blocks_to_swap_in, blocks_to_swap_out, + blocks_to_copy) = (execute_model_data.seq_group_metadata_list, + execute_model_data.finished_request_ids_list, + execute_model_data.blocks_to_swap_in, + execute_model_data.blocks_to_swap_out, + execute_model_data.blocks_to_copy) + + if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]): + raise NotImplementedError( + "MultiStepWorker does not support cache operations") + + if any(not isinstance(seq_group_metadata, SequenceGroupMetadata) + for seq_group_metadata in seq_group_metadata_list): + raise NotImplementedError( + "MultiStepWorker only supports SequenceGroupMetadata input " + "(not deltas).") + + if any( + len(seq_group_metadata.seq_data.keys()) != 1 + for seq_group_metadata in seq_group_metadata_list): + raise NotImplementedError( + "MultiStepWorker does not support beam search.") diff --git a/vllm/worker/single_tp_worker.py b/vllm/worker/single_tp_worker.py new file mode 100644 index 0000000000000..70b69911b8eda --- /dev/null +++ b/vllm/worker/single_tp_worker.py @@ -0,0 +1,111 @@ +from typing import List, Optional +import logging + +import torch + +from vllm.sequence import (SamplerOutput, ExecuteModelData) +from vllm.model_executor.parallel_utils.parallel_state import patch_tensor_parallel_group +from vllm.config import CacheConfig, ParallelConfig +from vllm.worker.base_worker import BaseWorker + +logger = logging.getLogger(__name__) + + +class SingleTpWorker(BaseWorker): + """Class which allows a speculative draft model to run with tensor parallel + degree of 1, while target model runs with larger tensor parallel degree. + This reduces the overhead of small draft models. + + This is implemented by changing vLLM's tensor parallel group to a group of + size 1 during forward passes. + """ + + @classmethod + def maybe_wrap_worker(cls, worker, draft_parallel_config: ParallelConfig, + target_parallel_config: ParallelConfig): + """Wrap the worker in a SingleTpWorker if necessary. 
+        """
+        draft_tp = draft_parallel_config.tensor_parallel_size
+        if draft_tp == target_parallel_config.tensor_parallel_size:
+            return worker
+
+        if draft_tp != 1:
+            raise ValueError(f"{cls} only supports tp=1, found "
+                             f"{draft_tp=}")
+
+        logger.info(f"Wrapping {type(worker)} in {cls}")
+        return cls(worker)
+
+    def __init__(
+        self,
+        worker: BaseWorker,
+    ):
+        self._worker = worker
+        self._single_tp_group = None
+
+    def init_model(self):
+        """Initialize the model on all ranks.
+
+        This also creates a single-rank process group containing only the
+        self process.
+        """
+        world_rank = torch.distributed.get_rank()
+        self._single_tp_group = torch.distributed.new_group([world_rank])
+
+        with patch_tensor_parallel_group(self._single_tp_group):
+            self._worker.init_model(should_init_distributed_env=False)
+
+    def profile_num_available_blocks(self, block_size: int,
+                                     gpu_memory_utilization: float,
+                                     cpu_swap_space: int):
+        """Profile the model on all ranks.
+        """
+        with patch_tensor_parallel_group(self._single_tp_group):
+            return self._worker.profile_num_available_blocks(
+                block_size, gpu_memory_utilization, cpu_swap_space)
+
+    def init_cache_engine(self, cache_config: CacheConfig):
+        """Initialize the cache engine on all ranks.
+        """
+        with patch_tensor_parallel_group(self._single_tp_group):
+            self._worker.init_cache_engine(cache_config)
+
+    @property
+    def model_config(self):
+        return self._worker.model_config
+
+    @property
+    def parallel_config(self):
+        return self._worker.parallel_config
+
+    @property
+    def model(self):
+        return self._worker.model
+
+    @property
+    def rank(self):
+        return self._worker.rank
+
+    def get_metadata_cache_len(self) -> int:
+        """Metadata cache not currently supported.
+        """
+        return 0
+
+    def get_runtime_context(self) -> Optional[dict]:
+        return self._worker.get_runtime_context()
+
+    @property
+    def _vocab_size(self) -> int:
+        return self.model.config.vocab_size
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        execute_model_data: ExecuteModelData,
+        *,
+        return_python_output: bool = True) -> List[SamplerOutput]:
+        """Execute the model separately on each rank.
+ """ + with patch_tensor_parallel_group(self._single_tp_group): + return self._worker.execute_model( + execute_model_data, return_python_output=return_python_output) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 8698b15721507..4170dfb591dd9 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,21 +1,52 @@ """A GPU worker class.""" +import gc import os -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple, Optional, Set, Union + +import msgspec import torch +import torch.backends import torch.distributed +import traceback -from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) -from vllm.model_executor import set_random_seed +from vllm.anyscale.shm.msgspec_shm import SharedMsgspecBufferWithEvent +from vllm.worker.base_worker import BaseLoraWorker +from vllm.anyscale.profiler_utils import TorchProfiler, Profilable +from vllm.anyscale.cuda_graph import CudaGraphCapturedModel +from vllm.anyscale.lora.utils import LoRARequest +from vllm.anyscale.lora.worker_manager import ( + DisabledWorkerLoRAManager, + LRUCacheWorkerLoRAManager, +) +from vllm.config import ( + CacheConfig, + LoadConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + LoRAConfig, +) +from vllm.logger import init_logger +from vllm.model_executor import InputMetadata, MultiStepInputMetadata, get_model, set_random_seed +from vllm.model_executor.layers.sampler import pythonize_sampler_output from vllm.model_executor.parallel_utils.parallel_state import ( - initialize_model_parallel) -from vllm.sequence import SamplerOutput, SequenceGroupMetadata + initialize_model_parallel, model_parallel_is_initialized, + get_tensor_model_parallel_world_size, + get_pipeline_model_parallel_world_size, get_tensor_model_parallel_group) +from vllm.sampling_params import SamplingParams +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata, ExecuteModelData, SequenceGroupMetadataDelta from vllm.worker.cache_engine import CacheEngine -from vllm.worker.model_runner import ModelRunner +from vllm.anyscale.lora.layers import LoRAMapping +from vllm.engine.ray_utils import ray + +logger = init_logger(__name__) +LORA_WARMUP_RANK = 8 +MAX_INT_32 = 2**31 - 1 -class Worker: + +class Worker(Profilable, BaseLoraWorker): """A worker class that executes (a partition of) the model on a GPU. Each worker is associated with a single GPU. The worker is responsible for @@ -30,23 +61,52 @@ def __init__( scheduler_config: SchedulerConfig, rank: Optional[int] = None, distributed_init_method: Optional[str] = None, + load_config: Optional[LoadConfig] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.rank = rank self.distributed_init_method = distributed_init_method + self.load_config = load_config + self.lora_config = lora_config - self.model_runner = ModelRunner(model_config, parallel_config, - scheduler_config) # Uninitialized cache engine. Will be initialized by # self.init_cache_engine(). 
self.cache_config = None + self.block_size = None + self.sliding_window = None self.cache_engine = None self.cache_events = None self.gpu_cache = None - def init_model(self) -> None: + # Stats, updated every iteration + self.num_input_tokens = 0 + self.num_seq_groups = 0 + + self.lora_manager = None + + self.seq_metadata_cache = None + self.input_padding_size = self.scheduler_config.input_padding_size + + # Enable small batch padding optimization for chunked prefill. + self.optimize_small_batch_padding = \ + self.scheduler_config.max_chunked_prefill_len > 0 + + self._profiler = TorchProfiler() + + def init_model(self, should_init_distributed_env: bool = True): + """Initialize the model. + + If should_init_distributed_env is False, do not initialize torch + distributed or other collective utilities. + """ + # Torch default: False + torch.backends.cuda.matmul.allow_tf32 = True + # Torch default: True + torch.backends.cudnn.allow_tf32 = True + # torch.distributed.all_reduce does not free the input tensor until # the synchronization point. This causes the memory usage to grow # as the number of all_reduce calls increases. This env var disables @@ -58,8 +118,8 @@ def init_model(self) -> None: # This env var set by Ray causes exceptions with graph building. os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) # Env vars will be set by Ray. - self.rank = self.rank if self.rank is not None else int( - os.getenv("RANK", "-1")) + self.rank = (self.rank if self.rank is not None else int( + os.getenv("RANK", "-1"))) local_rank = int(os.getenv("LOCAL_RANK", "0")) self.device = torch.device(f"cuda:{local_rank}") if self.rank < 0: @@ -68,15 +128,33 @@ def init_model(self) -> None: _check_if_gpu_supports_dtype(self.model_config.dtype) - # Initialize the distributed environment. - _init_distributed_environment(self.parallel_config, self.rank, - self.distributed_init_method) + if should_init_distributed_env: + # Initialize the distributed environment. + _init_distributed_environment(self.parallel_config, self.rank, + self.distributed_init_method) # Initialize the model. set_random_seed(self.model_config.seed) + self.model = get_model(self.model_config, self.load_config, + self.lora_config) + + vocab_size = self.model.config.vocab_size + + if self.lora_config: + logger.info("Creating LoRA adapter...") + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, vocab_size, + self.lora_config, self.device) + self.model = self.lora_manager.create_lora_adapter(self.model) + else: + self.lora_manager = DisabledWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, vocab_size, + self.lora_config, self.device) - def load_model(self): - self.model_runner.load_model() + if self.scheduler_config.use_deltas: + self.seq_metadata_cache: Dict[str, SequenceGroupMetadata] = {} @torch.inference_mode() def profile_num_available_blocks( @@ -89,10 +167,166 @@ def profile_num_available_blocks( # cache blocks that can be allocated with the remaining free memory. torch.cuda.empty_cache() - # Execute a forward pass with dummy inputs to profile the memory usage - # of the model. - self.model_runner.profile_run() + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. + + # Enable top-k sampling to reflect the accurate memory usage. 
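# --- Illustrative sketch (editor's note, not part of the patch) ---
# The prefill profiling helper defined below spreads max_num_batched_tokens as
# evenly as possible over max_num_seqs sequences. A small worked example,
# assuming max_num_batched_tokens = 10, max_num_seqs = 4, and a max_model_len
# large enough not to clip any sequence:
max_num_batched_tokens, max_num_seqs = 10, 4
seq_lens = [
    max_num_batched_tokens // max_num_seqs +
    (group_id < max_num_batched_tokens % max_num_seqs)
    for group_id in range(max_num_seqs)
]  # -> [3, 3, 2, 2]; the lengths sum to max_num_batched_tokens
# -------------------------------------------------------------------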
+ vocab_size = self.model.config.vocab_size + sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + + # This represents the maximum number of different requests + # that will have unique loras, an therefore the max amount of memory + # consumption create dummy lora request copies from the lora request + # passed in, which contains a lora from the lora warmup path. + dummy_lora_requests = [] + dummy_lora_requests_per_seq = [] + if self.lora_config: + for idx in range(self.lora_config.max_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_id=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_local_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] + + def run_prefill_max_tokens(): + """Run the prefill step with the maximum number of sequences. + + Attempt to fill the batch (total num tokens == + max_num_batched_tokens). This may not be possible if `max_num_seqs + * max_model_len < max_num_batched_tokens`. + This is to mimic running + the largest possible prefill step + + Apply the maximum number of loras if necessary (1 for every + sequence) + + """ + seqs = [] + input_tokens = [] + for group_id in range(max_num_seqs): + seq_len = min( + max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs), + self.model_config.max_model_len) + prompt_tokens = [0] * seq_len + seq_data = SequenceData(prompt_tokens) + seq_data.advance_prefill_range(seq_len) + input_tokens.extend(prompt_tokens) + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + is_chunked_prefill=False, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=None, + lora_request=dummy_lora_requests_per_seq[group_id] + if dummy_lora_requests_per_seq else None, + ) + seqs.append(seq) + + ( + input_tokens, + input_positions, + input_metadata, + lora_mapping, + prepared_lora_requests, + ) = self._prepare_inputs(seqs) + + if self.lora_config: + self.apply_loras(prepared_lora_requests, lora_mapping) + + # Execute the model. + num_layers = self.model_config.get_num_layers(self.parallel_config) + self.model( + input_ids=input_tokens, + positions=input_positions, + kv_caches=[(None, None)] * num_layers, + input_metadata=input_metadata, + cache_events=None, + ) + + def run_generation_max_seqs(): + """Run the generation step with maximum number of sequences. + + This is to mimic running the largest possible generation step. + each sequences has a length of 1 to mimic the generation step. + + Apply the maximum number of loras if necessary (1 for every + sequence) + + """ + seqs = [] + input_tokens = [] + for group_id in range(max_num_seqs): + # setting sequence length to 1 to mimic the generation/decode + # step, where we only are operating on sequences of 1 token at + # a time. 
+                seq_len = 1
+                prompt_tokens = [0] * seq_len
+                seq_data = SequenceData(prompt_tokens)
+                seq_data.advance_prefill_range(seq_len)
+                input_tokens.extend(prompt_tokens)
+                seq = SequenceGroupMetadata(
+                    request_id=str(group_id),
+                    # though this is not meant to be a prompt, we set this to
+                    # true because we don't have block tables / kv caches
+                    # initialized, and we still want to mimic the generation
+                    # with lora.
+                    is_prompt=True,
+                    is_chunked_prefill=False,
+                    seq_data={group_id: seq_data},
+                    sampling_params=sampling_params,
+                    block_tables=None,
+                    lora_request=dummy_lora_requests_per_seq[group_id]
+                    if dummy_lora_requests_per_seq else None,
+                )
+                seqs.append(seq)
+
+            (input_tokens, input_positions, input_metadata, lora_mapping,
+             prepared_lora_requests) = (self._prepare_inputs(seqs))
+
+            if self.lora_config:
+                self.apply_loras(prepared_lora_requests, lora_mapping)
+
+            # Execute the model.
+            num_layers = self.model_config.get_num_layers(self.parallel_config)
+            self.model(
+                input_ids=input_tokens,
+                positions=input_positions,
+                kv_caches=[(None, None)] * num_layers,
+                input_metadata=input_metadata,
+                cache_events=None,
+            )
+
+        # Run both prefill with the maximum number of tokens and generation
+        # with the maximum number of sequences. Apply any loras if necessary.
+        # If no loras are applied, then prefill will use more memory than
+        # generation. However, if loras are applied, then it is possible for
+        # generation to use more memory than prefill. This is because when
+        # applying loras during prefill, the loras are applied iteratively on
+        # the batch for each sequence/lora, whereas during generation the
+        # loras are stacked and a single forward pass is done. While this is
+        # more efficient in terms of computation, it is less memory efficient
+        # since all the loras need to be loaded in GPU memory at the same
+        # time.
+        run_prefill_max_tokens()
+        # Since memory consumption for generation is only potentially larger
+        # than prefill when loras are applied, we only run the generation step
+        # when loras are applied, to save time.
+
+        if dummy_lora_requests:
+            run_generation_max_seqs()

         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
         torch.cuda.synchronize()
@@ -107,32 +341,398 @@ def profile_num_available_blocks(
         num_cpu_blocks = int(cpu_swap_space // cache_block_size)
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
+        self.lora_manager.remove_all_loras()
+        if self.seq_metadata_cache:
+            self.seq_metadata_cache.clear()
+        gc.collect()
         torch.cuda.empty_cache()
+
+        # Reset the seed to ensure that the random state is not affected by
+        # the model initialization and profiling.
+ set_random_seed(self.model_config.seed) + return num_gpu_blocks, num_cpu_blocks def init_cache_engine(self, cache_config: CacheConfig) -> None: self.cache_config = cache_config + self.block_size = cache_config.block_size + self.sliding_window = cache_config.sliding_window + self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.parallel_config) self.cache_events = self.cache_engine.events self.gpu_cache = self.cache_engine.gpu_cache - self.model_runner.set_block_size(self.cache_engine.block_size) + self.captured_model = CudaGraphCapturedModel(self.model, + self.gpu_cache, + self.model_config, + self.scheduler_config, + self.block_size) - def warm_up_model(self) -> None: - if not self.model_config.enforce_eager: - self.model_runner.capture_model(self.gpu_cache) - # Reset the seed to ensure that the random state is not affected by - # the model initialization and profiling. - set_random_seed(self.model_config.seed) + def _prepare_inputs( + self, + seq_group_metadata_list: List[Union[SequenceGroupMetadata, + SequenceGroupMetadataDelta]], + return_logits: bool = False, + num_steps: int = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, LoRAMapping, + Set[LoRARequest]]: + seq_groups: List[Tuple[List[int], SamplingParams]] = [] + input_tokens: List[int] = [] + input_positions: List[int] = [] + selected_token_indices: List[int] = [] + selected_token_start_idx = 0 + lora_requests: Set[LoRARequest] = set() + + lora_index_mapping: List[int] = [] + lora_prompt_mapping: List[int] = [] + is_multi_step = num_steps > 0 + num_steps = 1 if not is_multi_step else num_steps + + # Add prompt tokens. + prompt_lens: List[int] = [] + block_tables: List[List[List[int]]] = [[] for _ in range(num_steps)] + + max_num_blocks_per_seq = 0 + slot_mapping: List[List[int]] = [[] for _ in range(num_steps)] + context_lens: List[int] = [] + num_chunked_prefill = 0 + + for seq_idx, seq_group_metadata in enumerate(seq_group_metadata_list): + if not seq_group_metadata.is_prompt: + continue + + assert num_steps == 1 + + if seq_group_metadata.is_chunked_prefill: + num_chunked_prefill += 1 + + if self.seq_metadata_cache is not None: + self.seq_metadata_cache[ + seq_group_metadata.request_id] = seq_group_metadata + + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + seq_groups.append((seq_ids, sampling_params)) + lora_id = seq_group_metadata.lora_int_id + + # Use any sequence in the group. + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prefill_start, prefill_end = seq_data.get_prefill_range() + prompt_tokens = seq_data.get_token_ids()[prefill_start:prefill_end] + prompt_len = len(prompt_tokens) + prompt_lens.append(prompt_len) + context_lens.append(prefill_end) + + input_tokens.extend(prompt_tokens) + + # Set the right input_position for positional encoding. + input_positions.extend(range(prefill_start, prefill_end)) + + assert len(seq_ids) == 1, "Prompt input should have only one seq." 
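# --- Illustrative sketch (editor's note, not part of the patch) ---
# Worked example of the sampler-index bookkeeping below, assuming a prompt of
# length 4 and selected_token_start_idx = 0 (local names are hypothetical):
prompt_len, start = 4, 0
with_prompt_logprobs = (list(range(start, start + prompt_len - 1)) +
                        [start + prompt_len - 1])
# -> [0, 1, 2, 3]: logits are kept for every prompt position.
without_prompt_logprobs = [start + prompt_len - 1]
# -> [3]: only the final prompt position feeds the sampler.
# -------------------------------------------------------------------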
+ if sampling_params.prompt_logprobs is not None: + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + prompt_len - 1)) + + selected_token_indices.append(selected_token_start_idx + + prompt_len - 1) + selected_token_start_idx += prompt_len + + if lora_id > 0: + # if we are preparing inputs for the warmup step, we want the + # lora computation to take up the maximum possible amount of + # memory that way we can get a tighter upper bound on the + # amount of memory we can use and therefore not oom. If + # for_warmup is true, we add the lora lora mapping that is used + # during generation. + lora_requests.add(seq_group_metadata.lora_request) + lora_index_mapping += [lora_id] * prompt_len + lora_prompt_mapping.extend( + [lora_id] * + (prompt_len if sampling_params.prompt_logprobs else 1)) + + if seq_group_metadata.block_tables is None: + # During memory profiling, the block tables are not initialized + # yet. In this case, we just use a dummy slot mapping. + slot_mapping[0].extend([0] * prompt_len) + continue + + # Compute the slot mapping. + block_table = seq_group_metadata.block_tables[seq_id] + for i in range(prefill_start, prefill_end): + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping[0].append(slot) + + block_tables[0].append(block_table) + + # pad prompt tokens. This is required for cuda-graph. + input_tokens = self._pad_to_alignment(input_tokens) + input_positions = self._pad_to_alignment(input_positions) + slot_mapping[0] = self._pad_to_alignment(slot_mapping[0], + padded_value=-1) + num_prompt_tokens = len(input_tokens) + selected_token_start_idx = len(input_tokens) + + # Add generation tokens. + num_generation_tokens = 0 + + for seq_idx, seq_group_metadata in enumerate(seq_group_metadata_list): + if seq_group_metadata.is_prompt or \ + seq_group_metadata.is_chunked_prefill: + continue + + if (self.seq_metadata_cache is not None and + seq_group_metadata.request_id in self.seq_metadata_cache): + seq_group_metadata = self.seq_metadata_cache[ + seq_group_metadata.request_id].update_from_delta( + seq_group_metadata) + seq_group_metadata_list[seq_idx] = seq_group_metadata + + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + seq_groups.append((seq_ids, sampling_params)) + lora_id = seq_group_metadata.lora_int_id + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + seq_block_table = seq_group_metadata.block_tables[seq_id] + + generation_token_positions = ( + seq_data.get_unprocessed_token_positions()) + generation_token_ids = seq_data.get_unprocessed_token_ids() + + # Only the output from the last token needs to be sampled from. + token_position_to_sample = generation_token_positions[-1] + + for input_token, input_position in zip( + generation_token_ids, generation_token_positions): + # Calculate metadata of generation token. + context_len = input_position + 1 + + block_table = seq_block_table + block_number = block_table[input_position // + self.block_size] + block_offset = input_position % self.block_size + slot = block_number * self.block_size + block_offset + + # If sliding window is enabled, truncate the context len and + # block table. 
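# --- Illustrative sketch (editor's note, not part of the patch) ---
# Worked example of the sliding-window truncation below, assuming
# sliding_window = 64, block_size = 16, a context of 100 tokens, and a
# 7-entry block table (block numbers are made up):
sliding_window, block_size, context_len = 64, 16, 100
block_table = [10, 11, 12, 13, 14, 15, 16]
context_len = min(context_len, sliding_window)        # 64
sliding_window_blocks = sliding_window // block_size  # 4
block_table = block_table[-sliding_window_blocks:]    # [13, 14, 15, 16]
# -------------------------------------------------------------------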
+ if self.sliding_window is not None: + context_len = min(context_len, self.sliding_window) + + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + + # Append metadata of generation token to input lists. + input_positions.append(input_position) + input_tokens.append(input_token) + lora_index_mapping.append(lora_id) + slot_mapping[0].append(slot) + context_lens.append(context_len) + block_tables[0].append(block_table) + + # If we should sample a token from the output, append + # sampling metadata to input lists. + if input_position == token_position_to_sample: + selected_token_indices.append(selected_token_start_idx) + selected_token_start_idx += 1 + else: + # Do not select this token for sampling. + selected_token_start_idx += 1 + num_generation_tokens += 1 + + # Update LoRA mapping. + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + lora_prompt_mapping.append(lora_id) + + # This couples concerns between the multi step worker and the worker. + # TODO(cade,antoni) Clean up when making multi step worker cuda + # graphable. + for step in range(1, num_steps): + for seq_idx, seq_group_metadata in enumerate( + seq_group_metadata_list): + if seq_group_metadata.is_prompt or \ + seq_group_metadata.is_chunked_prefill: + continue + + seq_ids = list(seq_group_metadata.seq_data.keys()) + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + seq_block_table = seq_group_metadata.block_tables[seq_id] + + generation_token_positions = ( + seq_data.get_unprocessed_token_positions()) + token_position_to_sample = generation_token_positions[-1] + for input_position in generation_token_positions: + # Calculate metadata of generation token. + if input_position != token_position_to_sample: + continue + input_position = input_position + step + context_len = input_position + 1 + + block_table = seq_block_table + block_number = block_table[input_position // + self.block_size] + block_offset = input_position % self.block_size + slot = block_number * self.block_size + block_offset + + # If sliding window is enabled, truncate the context len + # and block table. + if self.sliding_window is not None: + context_len = min(context_len, self.sliding_window) + + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + + # Append metadata of generation token to input lists. + slot_mapping[step].append(slot) + block_tables[step].append(block_table) + + max_num_blocks_per_seq = [ + max((len(b) for b in block_table), default=0) + for block_table in block_tables + ] + max_context_len = max(context_lens, default=0) + + self.num_input_tokens = len(input_tokens) + self.num_seq_groups = len(seq_groups) + + # Pad the input length to be a multiple of 8. + # This is required for utilizing the Tensor Cores in NVIDIA GPUs. + input_tokens = self._pad_to_alignment( + input_tokens, num_generation_tokens=num_generation_tokens) + input_positions = self._pad_to_alignment( + input_positions, num_generation_tokens=num_generation_tokens) + slot_mapping = [ + self._pad_to_alignment(s, + padded_value=-1, + num_generation_tokens=num_generation_tokens) + for s in slot_mapping + ] + + # Convert to tensors. 
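# --- Illustrative sketch (editor's note, not part of the patch) ---
# The conversion below right-pads every block table to the longest one for its
# step via the _pad_to_max helper defined at the end of this file. A
# standalone equivalent (the function name here is hypothetical):
def pad_to_max_example(x, max_len, padded_value=0):
    # Append padded_value until len(x) == max_len.
    return x + [padded_value] * (max_len - len(x))

padded = pad_to_max_example([3, 7], max_len=4)  # -> [3, 7, 0, 0]
# -------------------------------------------------------------------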
+ tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device="cuda") + positions_tensor = torch.tensor(input_positions, + dtype=torch.long, + device="cuda") + slot_mapping_tensors = [ + torch.tensor(sm, dtype=torch.long, device="cuda") + for sm in slot_mapping + ] + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device="cuda") + selected_token_indices = torch.tensor(selected_token_indices, + dtype=torch.long, + device="cuda") + padded_block_tables = [[ + _pad_to_max(b_inner, max_num_blocks_per_seq[i]) + for b_inner in b_outer + ] for i, b_outer in enumerate(block_tables)] + block_tables_tensors = [ + torch.tensor(bt, dtype=torch.int, device="cuda") + for bt in padded_block_tables + ] + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + lora_mapping = LoRAMapping( + self._pad_to_alignment(lora_index_mapping), + lora_prompt_mapping, + ) + + input_metadata = InputMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + num_chunked_prefill=num_chunked_prefill, + num_prompt_tokens=num_prompt_tokens, + num_generation_tokens=num_generation_tokens, + slot_mapping=slot_mapping_tensors[0], + context_lens=context_lens_tensor, + max_context_len=max_context_len, + block_tables=block_tables_tensors[0], + selected_token_indices=selected_token_indices, + sliding_window=self.sliding_window, + return_logits=return_logits, + flash_style=self.scheduler_config.flash_style, + ) + if is_multi_step: + input_metadata = MultiStepInputMetadata( + num_steps, + input_metadata, + extra_slot_mapping=slot_mapping_tensors[1:], + extra_block_tables=block_tables_tensors[1:]) + return ( + tokens_tensor, + positions_tensor, + input_metadata, + lora_mapping, + lora_requests, + ) + + def execute_model_shared_memory( + self, + shared_memory_input: SharedMsgspecBufferWithEvent, + shared_memory_output: SharedMsgspecBufferWithEvent, + participant_id: int # pylint: disable=unused-argument + ): + shared_memory_input.decoder = msgspec.msgpack.Decoder(ExecuteModelData) + logger.info("Worker shared memory input buffer id: " + f"{shared_memory_input.participant_id}") + logger.info("Worker shared memory output buffer id: " + f"{shared_memory_input.participant_id}") + parallel_group = get_tensor_model_parallel_group() + try: + while True: + logger.debug("Waiting for incoming data...") + shared_memory_input.wait_for_incoming_data() + data = shared_memory_input.get_data() + logger.debug(f"Received data {data}.") + torch.distributed.barrier(group=parallel_group) + shared_memory_input.clear() + logger.debug("Executing model...") + outputs = self.execute_model(data) + logger.debug(f"Execute output {outputs}.") + if self.rank < 1: + logger.debug("Setting output") + shared_memory_output.set_data(outputs) + except Exception: + traceback.print_exc() + shared_memory_output.set_error() + raise @torch.inference_mode() def execute_model( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: + self, + execute_model_data: ExecuteModelData, + *, + return_python_output: bool = True) -> List[SamplerOutput]: + (seq_group_metadata_list, finished_request_ids_list, blocks_to_swap_in, + blocks_to_swap_out, blocks_to_copy, + return_logits) = (execute_model_data.seq_group_metadata_list, + execute_model_data.finished_request_ids_list, + 
execute_model_data.blocks_to_swap_in, + execute_model_data.blocks_to_swap_out, + execute_model_data.blocks_to_copy, + execute_model_data.return_logits) + + # Clean up the cache + if self.seq_metadata_cache: + for finished_request_id in finished_request_ids_list: + self.seq_metadata_cache.pop(finished_request_id, None) + # Issue cache operations. issued_cache_op = False if blocks_to_swap_in: @@ -145,21 +745,114 @@ def execute_model( self.cache_engine.copy(blocks_to_copy) issued_cache_op = True - cache_events = self.cache_events if issued_cache_op else None + if issued_cache_op: + cache_events = self.cache_events + else: + cache_events = None - # Wait for cache operations to finish. - # TODO(woosuk): Profile swapping overhead and optimize if needed. - if cache_events is not None: - for event in cache_events: - event.wait() # If there is no input, we don't need to execute the model. if not seq_group_metadata_list: - return {} + if cache_events is not None: + for event in cache_events: + event.wait() + return [SamplerOutput([])] + + seq_group_request_ids = [ + seq_group_metadata.request_id + for seq_group_metadata in seq_group_metadata_list + ] + + # Prepare input tensors. + ( + input_tokens, + input_positions, + input_metadata, + lora_mapping, + lora_requests, + ) = self._prepare_inputs(seq_group_metadata_list, + return_logits=return_logits) + + if self.lora_config: + lora_requests = [ + seq_group_metadata.lora_request + for seq_group_metadata in seq_group_metadata_list + ] + self.apply_loras(lora_requests, lora_mapping) + + output = self.captured_model.execute_if_capturable( + input_ids=input_tokens, + positions=input_positions, + input_metadata=input_metadata, + cache_events=cache_events, + ) + if return_python_output: + output = pythonize_sampler_output(output, input_metadata) + + if self.seq_metadata_cache is not None: + for request_id, sampler_output in zip(seq_group_request_ids, + output): + cached_seq_metadata = self.seq_metadata_cache[request_id] + for sample in sampler_output.samples: + cached_seq_metadata.seq_data[ + sample.parent_seq_id].append_token_ids( + [sample.output_token], [0]) + output = [output] - output = self.model_runner.execute_model(seq_group_metadata_list, - self.gpu_cache) return output + def apply_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + self.lora_manager.apply_loras(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.lora_manager.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.lora_manager.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.lora_manager.list_loras() + + def get_metadata_cache_len(self) -> int: + return len(self.seq_metadata_cache + ) if self.seq_metadata_cache is not None else -1 + + def get_runtime_context(self) -> Optional[dict]: + if ray: + runtime_ctx = ray.get_runtime_context() + return { + "job_id": runtime_ctx.get_job_id(), + "node_id": runtime_ctx.get_node_id(), + "worker_id": runtime_ctx.get_worker_id(), + } + return None + + def start_profile(self, **kwargs) -> None: + self._profiler.start_profile(**kwargs) + + def stop_profile(self) -> None: + self._profiler.stop_profile() + + def _pad_to_alignment(self, + x: List[int], + padded_value: int = 0, + num_generation_tokens: int = 0) -> List[int]: + """Pad the input to be a multiple of the alignment size. + + Args: + x: The input list. + padded_value: The value to pad with. 
+            num_generation_tokens: The number of generation tokens in the input.
+                If this is greater than 0 and less than 32 and small batch
+                optimization is enabled, the input will be padded to a
+                multiple of 8.
+        Returns:
+            The padded list.
+        """
+        pad_size = self.input_padding_size
+        if self.optimize_small_batch_padding and 0 < num_generation_tokens < 32:
+            pad_size = 8
+        return x + [padded_value] * ((-len(x)) % pad_size)
+

def _init_distributed_environment(
    parallel_config: ParallelConfig,
@@ -188,8 +881,23 @@ def _init_distributed_environment(
     # A small all_reduce for warmup.
     torch.distributed.all_reduce(torch.zeros(1).cuda())
-    initialize_model_parallel(parallel_config.tensor_parallel_size,
-                              parallel_config.pipeline_parallel_size)
+
+    if model_parallel_is_initialized():
+        assert get_tensor_model_parallel_world_size(
+        ) == parallel_config.tensor_parallel_size
+        assert get_pipeline_model_parallel_world_size(
+        ) == parallel_config.pipeline_parallel_size
+    else:
+        initialize_model_parallel(
+            parallel_config.tensor_parallel_size,
+            parallel_config.pipeline_parallel_size,
+        )
+
+
+def _pad_to_max(x: List[int],
+                max_len: int,
+                padded_value: int = 0) -> List[int]:
+    return x + [padded_value] * (max_len - len(x))


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):