diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index f2875194e93a0..4d717858eddb6 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -317,18 +317,19 @@ def _set_block_mapping(self, metadata, batch_size, device, dtype): mask = mask >= metadata.block_usage.unsqueeze(-1) attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( mask, -math.inf)) - if is_fake_hpu(): - # Unfortunately one_hot on CPU doesn't handle - # out of bounds classes. We need to mask those - # values manually - oob_values = metadata.block_mapping.lt(0) - block_mapping = metadata.block_mapping.masked_fill(oob_values, 0) - block_mapping = torch.nn.functional.one_hot(block_mapping, + + if not is_fake_hpu() and htorch.utils.internal.is_lazy(): + block_mapping = torch.nn.functional.one_hot(metadata.block_mapping, num_classes=batch_size) - block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0) else: - block_mapping = torch.nn.functional.one_hot(metadata.block_mapping, + # one_hot on CPU / torch.compile / eager mode doesn't handle + # out-of-bounds (negative) classes, so clamp them to 0 via relu, + # then zero the one-hot rows of the originally negative entries. + block_mapping = torch.nn.functional.relu(metadata.block_mapping) + block_mapping = torch.nn.functional.one_hot(block_mapping, num_classes=batch_size) + oob_values = metadata.block_mapping.lt(0) + block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0) block_mapping = block_mapping.to(dtype) metadata = metadata._replace(block_mapping=block_mapping, attn_bias=attn_bias)