From c948323e42daa2b197af84324e6394561fcec0e2 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Fri, 1 Nov 2024 23:54:36 +0200 Subject: [PATCH 1/2] Fix multistep deepcopy overhead Signed-off-by: Chendi Xue --- vllm/worker/hpu_model_runner.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index c50e4e244dffe..fc7414c3b032b 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -2109,6 +2109,7 @@ def execute_model( # we only want to pythonize in the last step sampling_metadata.skip_sampler_cpu_output = True self.model.model.sampler.include_gpu_probs_tensor = True + cache_orig_output_token_ids = [] for i in range(num_steps): with self.profiler.record_event('internal', model_event_name): hidden_states = self.model.forward( @@ -2159,8 +2160,11 @@ def execute_model( ctx = model_input.async_callback.keywords[ # type: ignore "ctx"] seq_group_metadata_list = ctx.seq_group_metadata_list - seq_group_metadata_list = copy.deepcopy( - seq_group_metadata_list) + # Cache the original output token ids + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + cache_orig_output_token_ids.append({}) + for seq_id, data in seq_group_metadata.seq_data.items(): + cache_orig_output_token_ids[i][seq_id] = copy.deepcopy(data.output_token_ids) for seq_group_metadata in seq_group_metadata_list: for data in seq_group_metadata.seq_data.values(): max_output_len = sampling_metadata.seq_groups[ @@ -2185,6 +2189,12 @@ def execute_model( "attn_metadata": self.trim_attn_metadata(result.attn_metadata) }) + else: + if len(cache_orig_output_token_ids) > 0: + # Reuse the original output token ids + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + for seq_id, data in seq_group_metadata.seq_data.items(): + data.output_token_ids = cache_orig_output_token_ids[i][seq_id] if self.is_driver_worker and self.profiler.enabled: # Stop recording 'execute_model' event From 50515b059d7707e89c7ac4cfb2b6b5dee6ab16b9 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Sat, 2 Nov 2024 00:36:49 +0200 Subject: [PATCH 2/2] Fix formatting issue detected by yapf, ruff, isort and mypy Signed-off-by: Chendi Xue --- vllm/worker/hpu_model_runner.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index fc7414c3b032b..f78be3f971fbc 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -2109,7 +2109,7 @@ def execute_model( # we only want to pythonize in the last step sampling_metadata.skip_sampler_cpu_output = True self.model.model.sampler.include_gpu_probs_tensor = True - cache_orig_output_token_ids = [] + cache_orig_output_token_ids: List[Dict] = [] for i in range(num_steps): with self.profiler.record_event('internal', model_event_name): hidden_states = self.model.forward( @@ -2161,10 +2161,12 @@ def execute_model( "ctx"] seq_group_metadata_list = ctx.seq_group_metadata_list # Cache the original output token ids - for i, seq_group_metadata in enumerate(seq_group_metadata_list): + for i, seq_group_metadata in enumerate( + seq_group_metadata_list): cache_orig_output_token_ids.append({}) - for seq_id, data in seq_group_metadata.seq_data.items(): - cache_orig_output_token_ids[i][seq_id] = copy.deepcopy(data.output_token_ids) + for j, data in seq_group_metadata.seq_data.items(): + cache_orig_output_token_ids[i][j] = \ + copy.deepcopy(data.output_token_ids) for seq_group_metadata in seq_group_metadata_list: for data in seq_group_metadata.seq_data.values(): max_output_len = sampling_metadata.seq_groups[ @@ -2192,9 +2194,11 @@ def execute_model( else: if len(cache_orig_output_token_ids) > 0: # Reuse the original output token ids - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - for seq_id, data in seq_group_metadata.seq_data.items(): - data.output_token_ids = cache_orig_output_token_ids[i][seq_id] + for i, seq_group_metadata in enumerate( + seq_group_metadata_list): + for j, data in seq_group_metadata.seq_data.items(): + data.output_token_ids = \ + cache_orig_output_token_ids[i][j] if self.is_driver_worker and self.profiler.enabled: # Stop recording 'execute_model' event