Skip to content

Commit

Permalink
Fix one_hot bug in torch compile mode (#427)
Browse files Browse the repository at this point in the history
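Fix one_hot bug in torch compile mode: outside HPU lazy mode, `one_hot` rejects the negative entries in `block_mapping` and fails with the error below.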
```
>           block_mapping = torch.nn.functional.one_hot(metadata.block_mapping,
                                                        num_classes=batch_size)
E           RuntimeError: Class values must be non-negative.

../../vllm/worker/hpu_model_runner.py:311: RuntimeError
```
  • Loading branch information
yuwenzho authored Oct 29, 2024
1 parent 2a38e6f commit 3e135ae
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,18 +318,19 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype):
mask = mask >= metadata.block_usage.unsqueeze(-1)
attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
mask, -math.inf))
if is_fake_hpu():
# Unfortunately one_hot on CPU doesn't handle
# out of bounds classes. We need to mask those
# values manually
oob_values = metadata.block_mapping.lt(0)
block_mapping = metadata.block_mapping.masked_fill(oob_values, 0)
block_mapping = torch.nn.functional.one_hot(block_mapping,

if not is_fake_hpu() and htorch.utils.internal.is_lazy():
block_mapping = torch.nn.functional.one_hot(metadata.block_mapping,
num_classes=batch_size)
block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0)
else:
block_mapping = torch.nn.functional.one_hot(metadata.block_mapping,
# Unfortunately one_hot on CPU/torch.compile mode/eager mode
# doesn't handle out of bounds classes,
# so we convert all negative values to 0.
block_mapping = torch.nn.functional.relu(metadata.block_mapping)
block_mapping = torch.nn.functional.one_hot(block_mapping,
num_classes=batch_size)
oob_values = metadata.block_mapping.lt(0)
block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0)
block_mapping = block_mapping.to(dtype)
metadata = metadata._replace(block_mapping=block_mapping,
attn_bias=attn_bias)
Expand Down

0 comments on commit 3e135ae

Please sign in to comment.