Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Config hidden layer number to run in 1 lazy graph #451

Open
wants to merge 2 commits into
base: habana_main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions vllm/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
attn_bias: Optional[torch.Tensor]
seq_lens_tensor: Optional[torch.Tensor]
context_lens_tensor: Optional[torch.Tensor]
config_hidden_layers: Optional[int]


class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
Expand Down Expand Up @@ -100,6 +101,7 @@ def __init__(
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]] = None,
max_seq_len: int = 4096,
config_hidden_layers: Optional[int] = None,
) -> None:
super(AttentionImpl, self).__init__()
self.kv_cache_dtype = kv_cache_dtype
Expand Down Expand Up @@ -133,6 +135,7 @@ def __init__(
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")
self.config_hidden_layers = config_hidden_layers

def forward(
self,
Expand Down Expand Up @@ -215,6 +218,7 @@ def forward(
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
matmul_av_op=self.matmul_av,
valid_seq_lengths=attn_metadata.seq_lens_tensor,
)
else:
# TODO: enable FusedSDPA
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/gpt_bigcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def forward(
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if is_hpu:
if is_hpu and i % attn_metadata.config_hidden_layers == 0:
htorch.core.mark_step()
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def forward(
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if is_hpu:
if is_hpu and i % attn_metadata.config_hidden_layers == 0:
htorch.core.mark_step()
if not get_pp_group().is_last_rank:
return IntermediateTensors({
Expand Down
5 changes: 4 additions & 1 deletion vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,7 @@ def __init__(
self.lora_manager: LRUCacheWorkerLoRAManager = None
self.model: torch.nn.Module = None
self.inc_initialized_successfully = False
self.config_hidden_layers = int(os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))

# Profiler stats
self.profiler = HabanaHighLevelProfiler()
Expand Down Expand Up @@ -991,6 +992,7 @@ def _prepare_prompt(
num_prefill_tokens=sum_query_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
config_hidden_layers=self.config_hidden_layers,
)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

Expand Down Expand Up @@ -1199,6 +1201,7 @@ def _prepare_decode(
num_prefill_tokens=0,
num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping,
config_hidden_layers=self.config_hidden_layers,
)
return PrepareDecodeMetadata(input_tokens=input_tokens,
input_positions=input_positions,
Expand Down Expand Up @@ -1405,7 +1408,7 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
'attn_bias', 'seq_lens_tensor', 'context_lens_tensor',
'block_list', 'block_mapping', 'block_usage', 'slot_mapping',
'is_prompt', 'block_indices', 'block_offsets', 'block_scales',
'block_groups'
'block_groups', 'config_hidden_layers'
])
return attention_metadata

Expand Down
Loading