Commit

Inc on vLLM - Fix CR comments
nirda7 committed Aug 13, 2024
1 parent 3b34893 commit 2ba2b5b
Showing 12 changed files with 40 additions and 39 deletions.
4 changes: 2 additions & 2 deletions setup.py
@@ -238,7 +238,7 @@ def _is_hpu() -> bool:
is_hpu_available = True
try:
subprocess.run(["hl-smi"], capture_output=True, check=True)
except (FileNotFoundError, NotADirectoryError, PermissionError, subprocess.CalledProcessError):
except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
if not os.path.exists('/dev/accel/accel0') and not os.path.exists(
'/dev/accel/accel_controlD0'):
# last resort...
@@ -267,7 +267,7 @@ def _is_neuron() -> bool:
torch_neuronx_installed = True
try:
subprocess.run(["neuron-ls"], capture_output=True, check=True)
except (FileNotFoundError, NotADirectoryError, PermissionError, subprocess.CalledProcessError):
except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
torch_neuronx_installed = False
return torch_neuronx_installed or VLLM_TARGET_DEVICE == "neuron"

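For reference, a minimal standalone sketch of the HPU-detection flow after this change (an illustrative simplification, not the file verbatim): probe `hl-smi` and, if that fails, fall back to checking the accelerator device nodes.

import os
import subprocess

def _is_hpu() -> bool:
    # Illustrative sketch: try the `hl-smi` management tool first; if it is
    # missing, not executable, or exits non-zero, fall back to checking for
    # the Gaudi accelerator device nodes.
    try:
        subprocess.run(["hl-smi"], capture_output=True, check=True)
        return True
    except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
        return (os.path.exists('/dev/accel/accel0')
                or os.path.exists('/dev/accel/accel_controlD0'))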
20 changes: 10 additions & 10 deletions vllm/attention/backends/habana_attn.py
@@ -144,11 +144,11 @@ def __init__(
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.qk_matmul = Matmul()
self.matmul_qk = Matmul()
self.softmax = Softmax()
self.av_matmul = Matmul()
self.key_cache = VLLMKVCache()
self.value_cache = VLLMKVCache()
self.matmul_av = Matmul()
self.k_cache = VLLMKVCache()
self.v_cache = VLLMKVCache()
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
self.position_bias = None
@@ -213,8 +213,8 @@ def forward(
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
num_kv_cache_passes, num_slots_available, indices, offsets = cache_ops.prepare_to_cache(key_cache, attn_metadata.slot_mapping)
key_cache = self.key_cache(key, key_cache, num_kv_cache_passes, num_slots_available, indices, offsets)
value_cache = self.value_cache(value, value_cache, num_kv_cache_passes, num_slots_available, indices, offsets)
key_cache = self.k_cache(key, key_cache, num_kv_cache_passes, num_slots_available, indices, offsets)
value_cache = self.v_cache(value, value_cache, num_kv_cache_passes, num_slots_available, indices, offsets)

if attn_metadata.is_prompt:
# Prompt run.
@@ -240,9 +240,9 @@ def forward(
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
qk_matmul_op=self.qk_matmul,
matmul_qk_op=self.matmul_qk,
softmax_op=self.softmax,
av_matmul_op=self.av_matmul,
matmul_av_op=self.matmul_av,
)
output = out.reshape(batch_size, seq_len, hidden_size)
else:
@@ -266,8 +266,8 @@ def forward(
query, key_cache, value_cache, attn_metadata.block_tables,
attn_metadata.seq_lens_tensor, self.kv_cache_dtype,
self.num_kv_heads, self.scale, self.position_bias, k_scale,
v_scale, self.qk_matmul, self.softmax, self.av_matmul,
self.key_cache, self.value_cache)
v_scale, self.matmul_qk, self.softmax, self.matmul_av,
self.k_cache, self.v_cache)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)

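The renamed attributes (matmul_qk, matmul_av, k_cache, v_cache) are instances of small nn.Module wrappers around plain tensor ops defined elsewhere in the HPU support code; registering them as named submodules is what allows quantization tooling such as INC to discover and replace them. A minimal sketch of that pattern, using hypothetical simplified wrappers rather than the real Matmul/Softmax/VLLMKVCache classes:

import torch
import torch.nn as nn

class Matmul(nn.Module):
    # Hypothetical stand-in: routing torch.matmul through an nn.Module lets
    # quantization tooling hook or swap it by module name.
    def forward(self, a, b):
        return torch.matmul(a, b)

class Softmax(nn.Module):
    def forward(self, x, dim=-1):
        return torch.softmax(x, dim=dim)

class AttentionImplSkeleton(nn.Module):
    # Mirrors only the attribute names changed above; the kv-cache wrappers
    # (k_cache / v_cache) are omitted from this sketch.
    def __init__(self):
        super().__init__()
        self.matmul_qk = Matmul()
        self.softmax = Softmax()
        self.matmul_av = Matmul()

print(dict(AttentionImplSkeleton().named_modules()).keys())
# dict_keys(['', 'matmul_qk', 'softmax', 'matmul_av'])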
8 changes: 4 additions & 4 deletions vllm/attention/ops/habana_paged_attn.py
@@ -75,9 +75,9 @@ def forward_decode(
alibi_slopes: Optional[torch.Tensor],
k_scale: float,
v_scale: float,
qk_matmul_op,
matmul_qk_op,
softmax_op,
av_matmul_op,
matmul_av_op,
k_cache_cls,
v_cache_cls,
) -> torch.Tensor:
@@ -93,9 +93,9 @@ def forward_decode(
block_size,
alibi_slopes,
kv_cache_dtype,
qk_matmul_op,
matmul_qk_op,
softmax_op,
av_matmul_op,
matmul_av_op,
k_cache_cls,
v_cache_cls,
)
4 changes: 2 additions & 2 deletions vllm/config.py
@@ -474,13 +474,13 @@ def _verify_args(self) -> None:
def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto":
pass
elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2", "hf8"):
elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"):
logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"Meanwhile, it may cause accuracy drop without a proper "
"scaling factor. "
"FP8_E4M3 is also supported on hpu (hf8).")
"Intel Gaudi (HPU) supports fp8 (using fp8_inc).")
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

4 changes: 2 additions & 2 deletions vllm/engine/arg_utils.py
@@ -229,12 +229,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3', 'hf8'],
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3', 'fp8_inc'],
default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3). '
'FP8_E4M3 is also supported on hpu (hf8).')
'Intel Gaudi (HPU) supports fp8 (using fp8_inc).')
parser.add_argument(
'--quantization-param-path',
type=nullable_str,
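Together with the config.py hunk above, the new kv-cache dtype name becomes selectable from the Python API as well as the CLI. A hedged usage sketch (the model name is a placeholder; "inc" quantization and the "fp8_inc" kv-cache dtype only apply on an Intel Gaudi host running this fork):

from vllm import LLM, SamplingParams

# Illustrative only: kv_cache_dtype is forwarded to the engine arguments,
# matching the --kv-cache-dtype choice added above.
llm = LLM(model="meta-llama/Llama-2-7b-hf",
          quantization="inc",
          kv_cache_dtype="fp8_inc")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)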
18 changes: 9 additions & 9 deletions vllm/hpu/ops.py
@@ -43,9 +43,9 @@ def paged_attention_v1(query,
block_size,
alibi_slopes=None,
kv_cache_dtype=None,
qk_matmul_op=torch.matmul,
matmul_qk_op=torch.matmul,
softmax_op=torch.softmax,
av_matmul_op=torch.matmul,
matmul_av_op=torch.matmul,
k_cache_cls=None,
v_cache_cls=None) -> None:
seq_len = block_tables.size(1)
@@ -67,13 +67,13 @@ def paged_attention_v1(query,
keys = [k.unflatten(1, (kv_heads, 1)) for k in keys]
mask = mask.unsqueeze(2)

attn_weights = torch.cat([qk_matmul_op(query, k) for k in keys], dim=-1)
attn_weights = torch.cat([matmul_qk_op(query, k) for k in keys], dim=-1)
if alibi_slopes is not None:
attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):,
-attn_weights.size(3):])
attn_weights = softmax_op(attn_weights.masked_fill(mask, min_inf), dim=-1)

fetch_values = fetch_from_cache if v_cache_cls is None else k_cache_cls.fetch_from_cache
fetch_values = fetch_from_cache if v_cache_cls is None else v_cache_cls.fetch_from_cache
values = fetch_values(value_cache, block_tables, (0, 2, 1, 3))
if PA_SPLIT_VALUE:
attn_weights = attn_weights.split(block_size, dim=-1)
@@ -82,7 +82,7 @@ def paged_attention_v1(query,
attn_weights = [attn_weights]
if query_heads != kv_heads:
values = [v.unflatten(1, (kv_heads, 1)) for v in values]
attn_weights = [av_matmul_op(a, v) for a, v in zip(attn_weights, values)]
attn_weights = [matmul_av_op(a, v) for a, v in zip(attn_weights, values)]
if query_heads != kv_heads:
attn_weights = [a.flatten(1, 2) for a in attn_weights]
attn_weights = sum(attn_weights)
@@ -132,9 +132,9 @@ def prompt_attention(
attn_bias: Optional[torch.Tensor] = None,
p: float = 0.0,
scale: Optional[float] = None,
qk_matmul_op = torch.matmul,
matmul_qk_op = torch.matmul,
softmax_op = torch.softmax,
av_matmul_op = torch.matmul,
matmul_av_op = torch.matmul,
) -> torch.Tensor:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
@@ -147,11 +147,11 @@ def prompt_attention(
value = value.unflatten(1, (kv_heads, 1))
if attn_bias is not None:
attn_bias = attn_bias.unsqueeze(2)
attn_weights = qk_matmul_op(query * scale, key.transpose(-1, -2))
attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2))
if attn_bias is not None:
attn_weights.add_(attn_bias)
attn_weights = softmax_op(attn_weights, dim=-1)
attn_weights = av_matmul_op(attn_weights, value)
attn_weights = matmul_av_op(attn_weights, value)
if query_heads != kv_heads:
attn_weights = attn_weights.flatten(1, 2)
attn_weights = attn_weights.transpose(1, 2)
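The *_op parameters above default to the plain torch ops, so the kernels run unchanged without INC, while a caller can inject the wrapped modules instead. A small self-contained sketch of that dependency-injection pattern (toy_prompt_attention is illustrative, not the actual paged-attention code, and skips the GQA reshapes):

import math
import torch

def toy_prompt_attention(query, key, value,
                         matmul_qk_op=torch.matmul,
                         softmax_op=torch.softmax,
                         matmul_av_op=torch.matmul):
    # Same injection idea as prompt_attention above, on
    # [batch, heads, seq, head_dim] tensors.
    scale = 1.0 / math.sqrt(query.size(-1))
    attn = matmul_qk_op(query * scale, key.transpose(-1, -2))
    attn = softmax_op(attn, dim=-1)
    return matmul_av_op(attn, value)

q = k = v = torch.randn(1, 8, 16, 64)
out = toy_prompt_attention(q, k, v)  # default torch ops
# Injected wrappers (e.g. Matmul()/Softmax() modules) would be passed as
# matmul_qk_op=..., softmax_op=..., matmul_av_op=... instead.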
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/quantization/inc.py
@@ -59,7 +59,8 @@ def get_quant_method(self, layer: torch.nn.Module,
def get_scaled_act_names(self) -> List[str]:
return []

def get_min_capability(self) -> int:
@classmethod
def get_min_capability(cls) -> int:
# The AWQ kernel only supports Turing or newer GPUs.
return 75

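Making get_min_capability a classmethod lets callers query the minimum capability without constructing a config object first. A tiny hedged illustration (INCConfig here is a stand-in class, not an import from vLLM):

class INCConfig:
    # Stand-in sketch: only the changed method is shown.
    @classmethod
    def get_min_capability(cls) -> int:
        return 75

# Callable on the class itself, no instance required:
assert INCConfig.get_min_capability() == 75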
7 changes: 4 additions & 3 deletions vllm/model_executor/models/llama.py
@@ -49,7 +49,8 @@
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import is_hip, is_hpu
from vllm.utils import is_hip
from vllm.platforms import current_platform

from .interfaces import SupportsLoRA
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
@@ -317,7 +318,7 @@ def forward(
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

if is_hpu():
if current_platform.is_hpu():
import habana_frameworks.torch as htorch
htorch.core.mark_step()
for i in range(self.start_layer, self.end_layer):
@@ -329,7 +330,7 @@ def forward(
attn_metadata,
residual,
)
if is_hpu():
if current_platform.is_hpu():
htorch.core.mark_step()

if not get_pp_group().is_last_rank:
2 changes: 1 addition & 1 deletion vllm/utils.py
@@ -39,7 +39,7 @@
"fp8": torch.uint8,
"fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8,
"hf8": torch.float8_e4m3fn,
"fp8_inc": torch.float8_e4m3fn,
}

TORCH_DTYPE_TO_NUMPY_DTYPE = {
2 changes: 1 addition & 1 deletion vllm/worker/cache_engine.py
@@ -91,7 +91,7 @@ def _allocate_kv_cache(
# null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out.
# We zero-out everything for simplicity.
dtype = torch.int8 if self.dtype == torch.float8_e4m3fn else self.dtype
dtype = torch.uint8 if self.dtype == torch.float8_e4m3fn else self.dtype
kv_cache.append(
torch.zeros(kv_cache_shape,
dtype=dtype,
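Connecting the two previous hunks: "fp8_inc" resolves to torch.float8_e4m3fn, and the cache buffer is now zero-filled as uint8 (the unsigned 8-bit dtype of the same width) rather than int8. A small hedged sketch of that allocation path (the helper name and shape are illustrative):

import torch

STR_DTYPE_TO_TORCH_DTYPE = {"fp8_inc": torch.float8_e4m3fn}  # excerpted mapping

def allocate_kv_cache(kv_cache_shape, cache_dtype_str):
    # Illustrative helper: allocate fp8 caches as uint8 zeros, mirroring the
    # change from torch.int8 to torch.uint8 above.
    dtype = STR_DTYPE_TO_TORCH_DTYPE.get(cache_dtype_str, torch.float16)
    alloc_dtype = torch.uint8 if dtype == torch.float8_e4m3fn else dtype
    return torch.zeros(kv_cache_shape, dtype=alloc_dtype)

cache = allocate_kv_cache((2, 1024, 16, 128), "fp8_inc")
print(cache.dtype)  # torch.uint8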
4 changes: 3 additions & 1 deletion vllm/worker/habana_model_runner.py
@@ -413,6 +413,9 @@ def __init__(
self._setup_buckets()

def load_model(self) -> None:
import habana_frameworks.torch.core as htcore
if self.model_config.quantization == 'inc':
htcore.hpu_set_env()
with HabanaMemoryProfiler() as m:
with HabanaMemoryProfiler() as m_getmodel:
self.model = get_model(
@@ -429,7 +432,6 @@ def load_model(self) -> None:
f"took {m_getmodel.get_summary_string()}")
logger.info(msg)

import habana_frameworks.torch.core as htcore
if self.model_config.quantization == 'inc':
logger.info("Preparing model with INC..")
with HabanaMemoryProfiler() as m_inc:
3 changes: 0 additions & 3 deletions vllm/worker/habana_worker.py
@@ -118,9 +118,6 @@ def init_device(self) -> None:
set_random_seed(self.model_config.seed)

def load_model(self):
if self.model_config.quantization == 'inc':
import habana_frameworks.torch.core as htcore
htcore.hpu_set_env()
self.model_runner.load_model()

@torch.inference_mode()
