Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix one_hot bug in torch compile mode #427

Merged
merged 4 commits into from
Oct 29, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def __init__(self, model, block_size, dtype, enforce_eager):
'0').lower() in ['1', 'true']
self.block_size = block_size
self.dtype = dtype
self.enforce_eager = enforce_eager
if not is_fake_hpu() and not htorch.utils.internal.is_lazy(
) and not enforce_eager:
self.model = torch.compile(self.model,
Expand Down Expand Up @@ -317,14 +318,15 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype):
mask = mask >= metadata.block_usage.unsqueeze(-1)
attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
mask, -math.inf))
if is_fake_hpu():
# Unfortunately one_hot on CPU doesn't handle
# out of bounds classes. We need to mask those
# values manually
oob_values = metadata.block_mapping.lt(0)
block_mapping = metadata.block_mapping.masked_fill(oob_values, 0)
if is_fake_hpu() or (not htorch.utils.internal.is_lazy()
and not self.enforce_eager):
madamczykhabana marked this conversation as resolved.
Show resolved Hide resolved
# Unfortunately one_hot on CPU or in torch compile mode
# doesn't handle out of bounds classes.
# We need to mask those values manually
yuwenzho marked this conversation as resolved.
Show resolved Hide resolved
block_mapping = torch.nn.functional.relu(metadata.block_mapping)
block_mapping = torch.nn.functional.one_hot(block_mapping,
num_classes=batch_size)
oob_values = metadata.block_mapping.lt(0)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this and next line were left accidentally ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and next line cannot be deleted. Negative values in metadata.block_mapping must be represented as zeros after one_hot, otherwise it will cause error in further processing.
Here is an example for better understanding:
metadata.block_mapping=[0, -1, -1] --relu()--> [0, 0, 0] --one_hot with num_classes=1--> [[1], [1], [1]] --oob value mask--> [[1], [0], [0]]

block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0)
else:
block_mapping = torch.nn.functional.one_hot(metadata.block_mapping,
Expand Down
Loading