diff --git a/README_GAUDI.md b/README_GAUDI.md index a569d6314acf8..9ea30a2e43f69 100644 --- a/README_GAUDI.md +++ b/README_GAUDI.md @@ -26,7 +26,8 @@ To verify that the Intel Gaudi software was correctly installed, run: ``` {.console} $ hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible $ apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core and habanalabs-thunk are installed -$ pip list | habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml, habana-media-loader and habana_quantization_toolkit are installed +$ pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed +$ pip list | grep neural # verify that neural-compressor is installed ``` Refer to [Intel Gaudi Software Stack diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index 7af291d62efc6..ddbac022a8d9d 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -26,7 +26,8 @@ To verify that the Intel Gaudi software was correctly installed, run: $ hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible $ apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core and habanalabs-thunk are installed - $ pip list | habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml, habana-media-loader and habana_quantization_toolkit are installed + $ pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed + $ pip list | grep neural # verify that neural_compressor is installed Refer to `Intel Gaudi Software Stack Verification `__ diff --git a/vllm/attention/backends/habana_attn.py b/vllm/attention/backends/habana_attn.py index 33b6e2e538b13..7a867e79b203d 100644 --- a/vllm/attention/backends/habana_attn.py +++ b/vllm/attention/backends/habana_attn.py @@ -12,6 +12,8 @@ AttentionMetadata, AttentionType) from vllm.attention.ops.habana_paged_attn import (HabanaPagedAttention, HabanaPagedAttentionMetadata) +from vllm.hpu import cache_ops +from vllm.hpu.utils import Matmul, Softmax, VLLMKVCache from vllm.logger import init_logger logger = init_logger(__name__) @@ -108,7 +110,7 @@ def __post_init__(self): self.attn_bias: Optional[torch.Tensor] = None -class HabanaAttentionImpl(AttentionImpl): +class HabanaAttentionImpl(AttentionImpl, torch.nn.Module): """ If the input tensors contain prompt tokens, the layout is as follows: |<--------------- num_prefill_tokens ----------------->| @@ -137,10 +139,16 @@ def __init__( blocksparse_params: Optional[Dict[str, Any]] = None, max_seq_len: int = 4096, ) -> None: + super(AttentionImpl, self).__init__() self.kv_cache_dtype = kv_cache_dtype self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) + self.matmul_qk = Matmul() + self.softmax = Softmax() + self.matmul_av = Matmul() + self.k_cache = VLLMKVCache() + self.v_cache = VLLMKVCache() self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window self.position_bias = None @@ -204,9 +212,13 @@ def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - HabanaPagedAttention.write_to_paged_cache( - key, value, key_cache, value_cache, attn_metadata.slot_mapping, - self.kv_cache_dtype, attn_metadata.is_prompt) + num_kv_cache_passes, num_slots_available, indices, offsets = \ + cache_ops.prepare_to_cache(key_cache, + attn_metadata.slot_mapping) + key_cache = self.k_cache(key, key_cache, num_kv_cache_passes, + num_slots_available, indices, offsets) + value_cache = self.v_cache(value, value_cache, num_kv_cache_passes, + num_slots_available, indices, offsets) if attn_metadata.is_prompt: # Prompt run. @@ -232,6 +244,9 @@ def forward( attn_bias=attn_bias, p=0.0, scale=self.scale, + matmul_qk_op=self.matmul_qk, + softmax_op=self.softmax, + matmul_av_op=self.matmul_av, ) output = out.reshape(batch_size, seq_len, hidden_size) else: @@ -255,7 +270,8 @@ def forward( query, key_cache, value_cache, attn_metadata.block_tables, attn_metadata.seq_lens_tensor, self.kv_cache_dtype, self.num_kv_heads, self.scale, self.position_bias, k_scale, - v_scale) + v_scale, self.matmul_qk, self.softmax, self.matmul_av, + self.k_cache, self.v_cache) # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) diff --git a/vllm/attention/ops/habana_paged_attn.py b/vllm/attention/ops/habana_paged_attn.py index 7dd701c7a0cdf..9602886299c47 100644 --- a/vllm/attention/ops/habana_paged_attn.py +++ b/vllm/attention/ops/habana_paged_attn.py @@ -75,6 +75,11 @@ def forward_decode( alibi_slopes: Optional[torch.Tensor], k_scale: float, v_scale: float, + matmul_qk_op, + softmax_op, + matmul_av_op, + k_cache_cls, + v_cache_cls, ) -> torch.Tensor: block_size = value_cache.shape[1] return ops.paged_attention_v1( @@ -88,6 +93,11 @@ def forward_decode( block_size, alibi_slopes, kv_cache_dtype, + matmul_qk_op, + softmax_op, + matmul_av_op, + k_cache_cls, + v_cache_cls, ) @staticmethod diff --git a/vllm/config.py b/vllm/config.py index f16bea16fe646..6acb70ad047b2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -474,12 +474,13 @@ def _verify_args(self) -> None: def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass - elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"): + elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"): logger.info( "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " "Meanwhile, it may cause accuracy drop without a proper " - "scaling factor") + "scaling factor. " + "Intel Gaudi (HPU) supports fp8 (using fp8_inc).") else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") @@ -600,11 +601,12 @@ class LoadConfig: ignore_patterns: The list of patterns to ignore when loading the model. Default to "original/**/*" to avoid repeated loading of llama's checkpoints. - + device: Device on which weights are loaded. """ load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO download_dir: Optional[str] = None + device: Optional[str] = None model_loader_extra_config: Optional[Union[str, dict]] = field( default_factory=dict) ignore_patterns: Optional[Union[List[str], str]] = None diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e4b223a1b505f..d6c544750afea 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -38,6 +38,7 @@ class EngineArgs: trust_remote_code: bool = False download_dir: Optional[str] = None load_format: str = 'auto' + weights_load_device: Optional[str] = None dtype: str = 'auto' kv_cache_dtype: str = 'auto' quantization_param_path: Optional[str] = None @@ -205,6 +206,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'section for more information.\n' '* "bitsandbytes" will load the weights using bitsandbytes ' 'quantization.\n') + parser.add_argument("--weights-load-device", + type=str, + default=EngineArgs.weights_load_device, + choices=["cuda", "neuron", "hpu", "cpu"], + help='Device on which weights are loaded.') parser.add_argument( '--dtype', type=str, @@ -223,11 +229,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( '--kv-cache-dtype', type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3', 'fp8_inc'], default=EngineArgs.kv_cache_dtype, help='Data type for kv cache storage. If "auto", will use model ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' - 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3). ' + 'Intel Gaudi (HPU) supports fp8 (using fp8_inc).') parser.add_argument( '--quantization-param-path', type=nullable_str, @@ -835,9 +842,12 @@ def create_engine_config(self, ) -> EngineConfig: self.model_loader_extra_config[ "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path + device = device_config.device if self.weights_load_device is None else \ + self.weights_load_device load_config = LoadConfig( load_format=self.load_format, download_dir=self.download_dir, + device=device, model_loader_extra_config=self.model_loader_extra_config, ignore_patterns=self.ignore_patterns, ) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3f7e0a7a4dc53..f8b9c48bc9589 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -182,7 +182,7 @@ def __init__( "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " "pipeline_parallel_size=%d, " "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " + "weights_load_device=%s, enforce_eager=%s, kv_cache_dtype=%s, " "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " @@ -206,6 +206,7 @@ def __init__( parallel_config.pipeline_parallel_size, parallel_config.disable_custom_all_reduce, model_config.quantization, + load_config.device, model_config.enforce_eager, cache_config.cache_dtype, model_config.quantization_param_path, @@ -853,6 +854,9 @@ def _process_model_outputs( request_outputs.append(request_output) return request_outputs + def finish_measurements(self): + self.model_executor.finish_measurements() + def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 62309ed345b1d..fc9f118ff14b2 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -173,6 +173,9 @@ def set_tokenizer( self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer( tokenizer) + def finish_measurements(self): + self.llm_engine.finish_measurements() + @overload # LEGACY: single (prompt + optional token ids) def generate( self, diff --git a/vllm/executor/habana_executor.py b/vllm/executor/habana_executor.py index f5cf26b687053..80f8037a2d043 100644 --- a/vllm/executor/habana_executor.py +++ b/vllm/executor/habana_executor.py @@ -90,6 +90,9 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: msg = f"init_cache_engine took {cache_init_m.get_summary_string()}" logger.info(msg) + def finish_measurements(self): + self.driver_worker.finish_measurements() + def execute_model( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: @@ -180,6 +183,12 @@ def check_health(self) -> None: # it's running. return + def shutdown(self) -> None: + self.driver_worker.shutdown_inc() + + def __del__(self): + self.shutdown() + class HabanaExecutorAsync(HabanaExecutor, ExecutorAsyncBase): diff --git a/vllm/executor/ray_habana_executor.py b/vllm/executor/ray_habana_executor.py index 9e0a89cbeb8aa..17e3414a96b57 100644 --- a/vllm/executor/ray_habana_executor.py +++ b/vllm/executor/ray_habana_executor.py @@ -237,6 +237,9 @@ def _driver_execute_model( return self.driver_worker.execute_method("execute_model", execute_model_req) + def finish_measurements(self): + self._run_workers("finish_measurements") + def execute_model( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: diff --git a/vllm/hpu/cache_ops.py b/vllm/hpu/cache_ops.py index 14824945aa53a..98f109accea06 100644 --- a/vllm/hpu/cache_ops.py +++ b/vllm/hpu/cache_ops.py @@ -43,6 +43,37 @@ def reshape_and_cache(key, value[start_idx:end_idx]) +def prepare_to_cache(cache, slot_mapping): + num_blocks = cache.size(0) + block_size = cache.size(1) + slot_mapping = slot_mapping.flatten() + indices = torch.div(slot_mapping, block_size, rounding_mode="floor") + offsets = torch.fmod(slot_mapping, block_size) + num_slots_requested = slot_mapping.size(0) + num_slots_available = num_blocks * block_size + # NOTE(kzawora): HPU PT bridge crashes with + # RuntimeError: Invalid inputs for scatter_nd_onnx + # on index_put when num_slots_requested > num_slots_available. + # This case might occur when we have little kv cache blocks and + # lots of padding, or are doing warmup. + # This loop is a workaround for this issue. Please remove it + # once key_cache.index_put_(indices, offsets), key) works. + num_kv_cache_passes = torch.div(num_slots_requested, + num_slots_available).ceil().int().item() + + return num_kv_cache_passes, num_slots_available, indices, offsets + + +def insert_or_update_cache(input, cache, num_kv_cache_passes, + num_slots_available, block_indices, block_offsets): + for i in range(num_kv_cache_passes): + start_idx = i * num_slots_available + end_idx = (i + 1) * num_slots_available + cache.index_put_((block_indices[start_idx:end_idx], + block_offsets[start_idx:end_idx]), + input[start_idx:end_idx]) + + def swap_blocks(src, dst, block_mapping): index_src = torch.zeros((1, ), dtype=torch.int32, device=src.device) index_dst = torch.zeros((1, ), dtype=torch.int32, device=dst.device) diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py index c8f00c1cbd59d..23f6964723d3f 100644 --- a/vllm/hpu/ops.py +++ b/vllm/hpu/ops.py @@ -11,7 +11,6 @@ import torch import torch.nn.functional as F -import vllm.hpu.utils as hpu_utils from vllm.logger import init_logger logger = init_logger(__name__) @@ -33,7 +32,6 @@ def fetch_from_cache(cache, blocks, permutations): ] -@hpu_utils.with_mark_steps def paged_attention_v1(query, key_cache, value_cache, @@ -43,7 +41,12 @@ def paged_attention_v1(query, context_lens, block_size, alibi_slopes=None, - kv_cache_dtype=None) -> None: + kv_cache_dtype=None, + matmul_qk_op=torch.matmul, + softmax_op=torch.softmax, + matmul_av_op=torch.matmul, + k_cache_cls=None, + v_cache_cls=None) -> None: seq_len = block_tables.size(1) batch_size, query_heads, _ = query.shape _, _, kv_heads, _ = key_cache.shape @@ -56,19 +59,23 @@ def paged_attention_v1(query, batch_size, 1, 1, -1)) query.mul_(scale) query = query.unsqueeze(-2) - keys = fetch_from_cache(key_cache, block_tables, (0, 2, 3, 1)) + fetch_keys = fetch_from_cache if k_cache_cls is None else \ + k_cache_cls.fetch_from_cache + keys = fetch_keys(key_cache, block_tables, (0, 2, 3, 1)) if query_heads != kv_heads: query = query.unflatten(1, (kv_heads, -1)) keys = [k.unflatten(1, (kv_heads, 1)) for k in keys] mask = mask.unsqueeze(2) - attn_weights = torch.cat([torch.matmul(query, k) for k in keys], dim=-1) + attn_weights = torch.cat([matmul_qk_op(query, k) for k in keys], dim=-1) if alibi_slopes is not None: attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):, -attn_weights.size(3):]) - attn_weights = (attn_weights.masked_fill(mask, min_inf).softmax(dim=-1)) + attn_weights = softmax_op(attn_weights.masked_fill(mask, min_inf), dim=-1) - values = fetch_from_cache(value_cache, block_tables, (0, 2, 1, 3)) + fetch_values = fetch_from_cache if v_cache_cls is None else \ + v_cache_cls.fetch_from_cache + values = fetch_values(value_cache, block_tables, (0, 2, 1, 3)) if PA_SPLIT_VALUE: attn_weights = attn_weights.split(block_size, dim=-1) else: @@ -76,7 +83,7 @@ def paged_attention_v1(query, attn_weights = [attn_weights] if query_heads != kv_heads: values = [v.unflatten(1, (kv_heads, 1)) for v in values] - attn_weights = [torch.matmul(a, v) for a, v in zip(attn_weights, values)] + attn_weights = [matmul_av_op(a, v) for a, v in zip(attn_weights, values)] if query_heads != kv_heads: attn_weights = [a.flatten(1, 2) for a in attn_weights] attn_weights = sum(attn_weights) @@ -119,7 +126,6 @@ def static_fused_moe(hidden_states, w1, w2, score, topk): return final_hidden_states.view(-1, D) -@hpu_utils.with_mark_steps def prompt_attention( query: torch.Tensor, key: torch.Tensor, @@ -127,6 +133,9 @@ def prompt_attention( attn_bias: Optional[torch.Tensor] = None, p: float = 0.0, scale: Optional[float] = None, + matmul_qk_op=torch.matmul, + softmax_op=torch.softmax, + matmul_av_op=torch.matmul, ) -> torch.Tensor: query = query.transpose(1, 2) key = key.transpose(1, 2) @@ -139,11 +148,11 @@ def prompt_attention( value = value.unflatten(1, (kv_heads, 1)) if attn_bias is not None: attn_bias = attn_bias.unsqueeze(2) - attn_weights = torch.matmul(query * scale, key.transpose(-1, -2)) + attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2)) if attn_bias is not None: attn_weights.add_(attn_bias) - attn_weights = torch.softmax(attn_weights, dim=-1) - attn_weights = torch.matmul(attn_weights, value) + attn_weights = softmax_op(attn_weights, dim=-1) + attn_weights = matmul_av_op(attn_weights, value) if query_heads != kv_heads: attn_weights = attn_weights.flatten(1, 2) attn_weights = attn_weights.transpose(1, 2) diff --git a/vllm/hpu/utils.py b/vllm/hpu/utils.py index b7b435c50c295..3d9c7cb1c4c22 100644 --- a/vllm/hpu/utils.py +++ b/vllm/hpu/utils.py @@ -8,6 +8,9 @@ from functools import wraps import habana_frameworks.torch as htorch +import torch + +from vllm.hpu.cache_ops import insert_or_update_cache def with_mark_steps(fn): @@ -22,3 +25,40 @@ def wrapped(*args, **kwargs): return result return wrapped + + +class Matmul(torch.nn.Module): + + def __init__(self): + super(Matmul, self).__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + +class Softmax(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x, dim=None, inv_head=None): + return torch.softmax(x, dim) + + +class VLLMKVCache(torch.nn.Module): + + def __init__(self): + super(VLLMKVCache, self).__init__() + + def forward(self, input, cache, num_kv_cache_passes, num_slots_available, + block_indices, block_offset): + insert_or_update_cache(input, cache, num_kv_cache_passes, + num_slots_available, block_indices, + block_offset) + return cache + + def fetch_from_cache(self, cache, blocks, permutations): + return [ + cache.index_select(0, blocks[:, i]).permute(permutations) + for i in range(blocks.size(1)) + ] diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 55cbbabd7da44..c12668c14887d 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -79,18 +79,15 @@ def forward_hpu( if HPUFusedRMSNorm is None: return self.forward_native(x, residual) if residual is not None: - orig_dtype = x.dtype orig_shape = x.shape residual += x.view(residual.shape) # Note: HPUFusedRMSNorm requires 3D tensors as inputs - x = HPUFusedRMSNorm.apply(residual.float(), self.weight.float(), + x = HPUFusedRMSNorm.apply(residual, self.weight, self.variance_epsilon) - return x.to(orig_dtype).view(orig_shape), residual + return x.view(orig_shape), residual - orig_dtype = x.dtype - x = HPUFusedRMSNorm.apply(x.float(), self.weight.float(), - self.variance_epsilon) - return x.to(orig_dtype) + x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon) + return x def forward_xpu( self, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index b6e280ae65049..10c8a95f838da 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -273,6 +273,7 @@ def __init__(self, quant_config, prefix) self.gather_output = gather_output + self.collective_func = tensor_model_parallel_all_gather # Divide the weight matrix along the last dimension. tp_size = get_tensor_model_parallel_world_size() @@ -334,7 +335,7 @@ def forward(self, input_): output_parallel = self.quant_method.apply(self, input_, bias) if self.gather_output: # All-gather across the partitions. - output = tensor_model_parallel_all_gather(output_parallel) + output = self.collective_func(output_parallel) else: output = output_parallel output_bias = self.bias if self.skip_bias_add else None @@ -723,6 +724,7 @@ def __init__(self, self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results + self.collective_func = tensor_model_parallel_all_reduce # Divide the weight matrix along the last dimension. self.tp_rank = get_tensor_model_parallel_rank() @@ -770,7 +772,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def forward(self, input_): + def resolve_input(self, input_): if self.input_is_parallel: input_parallel = input_ else: @@ -778,6 +780,10 @@ def forward(self, input_): splitted_input = split_tensor_along_last_dim( input_, num_partitions=self.tp_size) input_parallel = splitted_input[tp_rank].contiguous() + return input_parallel + + def forward(self, input_): + input_parallel = self.resolve_input(input_) # Matrix multiply. assert self.quant_method is not None diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index bd574512e3431..7590d3e980275 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -18,6 +18,7 @@ GPTQMarlinConfig) from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQMarlin24Config) +from vllm.model_executor.layers.quantization.inc import INCConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig @@ -37,6 +38,7 @@ "squeezellm": SqueezeLLMConfig, "compressed-tensors": CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, + "inc": INCConfig, } diff --git a/vllm/model_executor/layers/quantization/inc.py b/vllm/model_executor/layers/quantization/inc.py new file mode 100644 index 0000000000000..f6718ec2ac9e7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/inc.py @@ -0,0 +1,115 @@ +from typing import Any, Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + + +class INCConfig(QuantizationConfig): + """Config class for FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + activation_scheme: str = "dynamic", + ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning("Detected fp8 checkpoint. Please note that the " + "format is experimental and subject to change.") + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError( + f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + + @classmethod + def get_name(cls) -> str: + return "inc" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "INCConfig": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = ("fp8" in quant_method) + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["INCLinearMethod"]: + if isinstance(layer, LinearBase): + return INCLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def get_min_capability(cls) -> int: + # The AWQ kernel only supports Turing or newer GPUs. + return 75 + + @staticmethod + def get_config_filenames() -> List[str]: + return [] + + +class INCLinearMethod(LinearMethodBase): + """Linear method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn data type due to the limitation of + torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) + + Args: + quant_config: The quantization config. + """ + + def __init__(self, + quant_config: INCConfig, + separate_bias_add: bool = False): + self.separate_bias_add = separate_bias_add + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + output_size_per_partition = sum(output_partition_sizes) + weight = Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = layer.weight + if self.separate_bias_add: + if bias is not None: + return F.linear(x, weight) + bias + return F.linear(x, weight) + return F.linear(x, weight, bias) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index bbe49655020da..06048d97088e1 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -37,7 +37,7 @@ supports_vision) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.utils import is_tpu +from vllm.utils import is_hpu, is_tpu logger = init_logger(__name__) @@ -48,14 +48,15 @@ def _get_quantization_config( """Get the quantization config.""" if model_config.quantization is not None: quant_config = get_quant_config(model_config, load_config) - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - if capability < quant_config.get_min_capability(): - raise ValueError( - f"The quantization method {model_config.quantization} is not " - "supported for the current GPU. " - f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}.") + if not is_hpu(): + capability = current_platform.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} " + "is not supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}.") supported_dtypes = quant_config.get_supported_act_dtypes() if model_config.dtype not in supported_dtypes: raise ValueError( @@ -276,10 +277,11 @@ def load_model(self, *, model_config: ModelConfig, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module: with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): + with torch.device(self.load_config.device): model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config, scheduler_config) + logger.info("Loading weights on %s ...", self.load_config.device) model.load_weights( self._get_weights_iterator(model_config.model, model_config.revision, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 306d22e42ed1d..676a51ce67f96 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -48,6 +48,7 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.utils import is_hip @@ -317,6 +318,9 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + if current_platform.is_hpu(): + import habana_frameworks.torch as htorch + htorch.core.mark_step() for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( @@ -326,6 +330,8 @@ def forward( attn_metadata, residual, ) + if current_platform.is_hpu(): + htorch.core.mark_step() if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/utils.py b/vllm/utils.py index 8a1bc5de03eb7..fe84253feb172 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -39,6 +39,7 @@ "fp8": torch.uint8, "fp8_e4m3": torch.uint8, "fp8_e5m2": torch.uint8, + "fp8_inc": torch.float8_e4m3fn, } TORCH_DTYPE_TO_NUMPY_DTYPE = { diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 93be2f4c321fe..ec0b8c2369210 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -91,9 +91,11 @@ def _allocate_kv_cache( # null block in CpuGpuBlockAllocator requires at least that # block to be zeroed-out. # We zero-out everything for simplicity. + dtype = torch.uint8 if self.dtype == torch.float8_e4m3fn else \ + self.dtype kv_cache.append( torch.zeros(kv_cache_shape, - dtype=self.dtype, + dtype=dtype, pin_memory=pin_memory, device=device)) return kv_cache diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py index cf91c69069ed6..72aba42ae8553 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/habana_model_runner.py @@ -182,8 +182,8 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, def forward(self, *args, **kwargs): kwargs = kwargs.copy() selected_token_indices = kwargs.pop('selected_token_indices') - if 'bypass_hpu_graphs' in kwargs: - kwargs.pop('bypass_hpu_graphs') # required for PT eager + if 'warmup_mode' in kwargs: + kwargs.pop('warmup_mode') input_ids = kwargs['input_ids'] kwargs['attn_metadata'] = self._set_attn_bias(kwargs['attn_metadata'], input_ids.size(0), @@ -413,6 +413,9 @@ def __init__( self._setup_buckets() def load_model(self) -> None: + import habana_frameworks.torch.core as htcore + if self.model_config.quantization == 'inc': + htcore.hpu_set_env() with HabanaMemoryProfiler() as m: with HabanaMemoryProfiler() as m_getmodel: self.model = get_model( @@ -429,6 +432,26 @@ def load_model(self) -> None: f"took {m_getmodel.get_summary_string()}") logger.info(msg) + if self.model_config.quantization == 'inc': + logger.info("Preparing model with INC..") + with HabanaMemoryProfiler() as m_inc: + from neural_compressor.torch.quantization import ( + FP8Config, convert, prepare) + config = FP8Config.from_json_file( + os.getenv("QUANT_CONFIG", "")) + if config.measure: + self.model = prepare(self.model, config) + elif config.quantize: + self.model = convert(self.model, config) + htcore.hpu_initialize(self.model, + mark_only_scales_as_const=True) + logger.info("Preparing model with INC took %s", + m_inc.get_summary_string()) + else: + self.model = self.model.to("hpu") + htcore.mark_step() + torch.hpu.synchronize() + # FIXME: Running with disable_tensor_cache=True causes # RuntimeErrors. This needs to be debugged with HabanaMemoryProfiler() as m_wrap: @@ -1051,7 +1074,7 @@ def warmup_scenario(self, batch_size, seq_len, is_prompt, torch.hpu.synchronize() for _ in range(times): inputs = self.prepare_model_input(seqs) - self.execute_model(inputs, kv_caches) + self.execute_model(inputs, kv_caches, warmup_mode=True) torch.hpu.synchronize() self.profiler.end() gc.collect() @@ -1362,6 +1385,10 @@ def prepare_model_input( is_prompt=is_prompt, virtual_engine=virtual_engine) + def finish_measurements(self): + from neural_compressor.torch.quantization import finalize_calibration + finalize_calibration(self.model.model) + @torch.inference_mode() def execute_model( self, @@ -1369,6 +1396,7 @@ def execute_model( kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, + warmup_mode=False, ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError( @@ -1402,6 +1430,11 @@ def execute_model( } if multi_modal_input is not None: execute_model_kwargs.update(multi_modal_input) + if htorch.utils.internal.is_lazy(): + execute_model_kwargs.update({ + "bypass_hpu_graphs": not use_graphs, + "warmup_mode": warmup_mode + }) htorch.core.mark_step() if self.is_driver_worker: @@ -1415,9 +1448,8 @@ def execute_model( with self.profiler.record_event('internal', model_event_name): hidden_states = self.model.forward( **execute_model_kwargs, - selected_token_indices=sampling_metadata. - selected_token_indices, - bypass_hpu_graphs=not use_graphs) + selected_token_indices=sampling_metadata.selected_token_indices + ) # Compute the logits. with self.profiler.record_event( @@ -1459,3 +1491,16 @@ def execute_model( is_prompt=is_prompt) self.profiler.record_counter(self.event_start, counters) return [output] + + def shutdown_inc(self): + print('inc shutdown') + if (model_config := getattr(self, "model_config", None)) and \ + getattr(model_config, "quantization", None) == 'inc': + print('inc shutdown start') + from neural_compressor.torch.quantization import ( + finalize_calibration) + finalize_calibration(self.model.model) + print('inc shutdown') + + def __del__(self): + self.shutdown_inc() diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py index f3fdc4dcc63c6..87122c03d3c8f 100644 --- a/vllm/worker/habana_worker.py +++ b/vllm/worker/habana_worker.py @@ -91,6 +91,16 @@ def __init__( # Initialize gpu_cache as embedding models don't initialize kv_caches self.hpu_cache: Optional[List[List[torch.tensor]]] = None + def _set_env_vars(self): + local_rank = self.local_rank + if self.parallel_config.world_size == 1: + local_rank = -1 + import os + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["ID"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(self.parallel_config.world_size) + os.environ["RANK"] = str(self.rank) + def init_device(self) -> None: if self.device_config.device.type == "hpu": self.device = torch.device("hpu") @@ -99,6 +109,8 @@ def init_device(self) -> None: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. + if self.model_config.quantization == 'inc': + self._set_env_vars() init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method, self.local_rank) @@ -211,6 +223,9 @@ def _warm_up_model(self) -> None: # the model initialization and profiling. set_random_seed(self.model_config.seed) + def finish_measurements(self): + self.model_runner.finish_measurements() + @property def do_metadata_broadcast(self) -> bool: return self.parallel_config.tensor_parallel_size > 1 @@ -288,6 +303,12 @@ def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: def list_prompt_adapters(self) -> Set[int]: raise NotImplementedError("LoRA is not implemented for HPU backend.") + def shutdown_inc(self): + self.model_runner.shutdown_inc() + + def __del__(self): + self.shutdown_inc() + @property def max_model_len(self) -> int: return self.model_config.max_model_len