fix one_hot bug in torch compile mode
Signed-off-by: yuwenzho <[email protected]>
yuwenzho committed Oct 25, 2024
1 parent 3af4b6c commit 9747a8a
Showing 1 changed file with 13 additions and 10 deletions.
vllm/worker/hpu_model_runner.py (13 additions, 10 deletions)
@@ -282,6 +282,7 @@ def __init__(self, model, block_size, dtype, enforce_eager):
                                                 '0').lower() in ['1', 'true']
         self.block_size = block_size
         self.dtype = dtype
+        self.enforce_eager = enforce_eager
         if not is_fake_hpu() and not htorch.utils.internal.is_lazy(
         ) and not enforce_eager:
             self.model = torch.compile(self.model,
@@ -309,18 +310,20 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
         attn_metadata = prefill_metadata._replace(attn_bias=attn_bias)
         return attn_metadata

-    def _set_block_mapping(self, metadata, batch_size, device, dtype):
+    def _set_block_mapping(self, metadata, batch_size, device, dtype,
+                           warmup_mode):
         mask = torch.arange(0,
                             self.block_size,
                             device=device,
                             dtype=torch.int32).unsqueeze(0)
         mask = mask >= metadata.block_usage.unsqueeze(-1)
         attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
             mask, -math.inf))
-        if is_fake_hpu():
-            # Unfortunately one_hot on CPU doesn't handle
-            # out of bounds classes. We need to mask those
-            # values manually
+        if is_fake_hpu() or (not htorch.utils.internal.is_lazy()
+                             and not self.enforce_eager and not warmup_mode):
+            # Unfortunately one_hot on CPU or in torch compile
+            # mode doesn't handle out of bounds classes.
+            # We need to mask those values manually
             oob_values = metadata.block_mapping.lt(0)
             block_mapping = metadata.block_mapping.masked_fill(oob_values, 0)
             block_mapping = torch.nn.functional.one_hot(block_mapping,
@@ -335,26 +338,25 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype):
         return metadata

     def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
-                         dtype):
+                         dtype, warmup_mode):
         if attn_metadata.is_prompt:
             meta = attn_metadata
             attn_metadata = self._set_attn_bias(meta, batch_size, seq_len,
                                                 device, dtype)
         else:
             meta = attn_metadata
             attn_metadata = self._set_block_mapping(meta, batch_size, device,
-                                                    dtype)
+                                                    dtype, warmup_mode)
         return attn_metadata

     def forward(self, *args, **kwargs):
         kwargs = kwargs.copy()
         selected_token_indices = kwargs.pop('selected_token_indices')
-        if 'warmup_mode' in kwargs:
-            kwargs.pop('warmup_mode')
+        warmup_mode = kwargs.pop('warmup_mode') if 'warmup_mode' in kwargs else None
         input_ids = kwargs['input_ids']
         kwargs['attn_metadata'] = self._update_metadata(
             kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
-            input_ids.device, self.dtype)
+            input_ids.device, self.dtype, warmup_mode)
         LoraMask.setLoraMask(kwargs.pop('lora_mask'))
         hidden_states = self.model(*args, **kwargs)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
@@ -1969,6 +1971,7 @@ def execute_model(
             with self.profiler.record_event('internal', model_event_name):
                 hidden_states = self.model.forward(
                     **execute_model_kwargs,
+                    warmup_mode=warmup_mode,
                     selected_token_indices=sampling_metadata.selected_token_indices
                 )

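For context, the workaround this commit extends to torch compile mode can be reproduced in isolation. The sketch below is a standalone illustration with made-up values, not vLLM code: padded block-mapping slots are marked with -1, which torch.nn.functional.one_hot cannot encode on CPU (and, per this fix, in torch compile mode), so they are clamped to a valid class first and their one-hot rows are zeroed out afterwards.

    import torch
    import torch.nn.functional as F

    # Made-up example: -1 marks padded block-mapping slots.
    block_mapping = torch.tensor([0, 2, -1, 1])
    batch_size = 3

    oob_values = block_mapping.lt(0)                    # locate the padded entries
    safe = block_mapping.masked_fill(oob_values, 0)     # clamp them to a valid class
    one_hot = F.one_hot(safe, num_classes=batch_size)   # now one_hot is safe to call
    one_hot.masked_fill_(oob_values.unsqueeze(-1), 0)   # zero out the padded rows
    print(one_hot)
    # tensor([[1, 0, 0],
    #         [0, 0, 1],
    #         [0, 0, 0],
    #         [0, 1, 0]])

As the diff shows, this masked path now runs when is_fake_hpu() is true or when the model is wrapped by torch.compile (HPU not in lazy mode, enforce_eager disabled) outside of warmup, which is why warmup_mode is threaded from execute_model through forward and _update_metadata into _set_block_mapping.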
