diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml
index e9b6e28fa6bcb..9d40813a98d7a 100644
--- a/.github/workflows/clang-format.yml
+++ b/.github/workflows/clang-format.yml
@@ -2,13 +2,13 @@ name: clang-format
 
 on:
   # Trigger the workflow on push or pull request,
-  # but only for the main branch
+  # but only for the habana_main branch
   push:
     branches:
-      - main
+      - habana_main
   pull_request:
     branches:
-      - main
+      - habana_main
 
 jobs:
   clang-format:
diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml
index 5780f09a646cb..c2674b914f485 100644
--- a/.github/workflows/mypy.yaml
+++ b/.github/workflows/mypy.yaml
@@ -2,13 +2,13 @@ name: mypy
 
 on:
   # Trigger the workflow on push or pull request,
-  # but only for the main branch
+  # but only for the habana_main branch
   push:
     branches:
-      - main
+      - habana_main
   pull_request:
     branches:
-      - main
+      - habana_main
 
 jobs:
   ruff:
@@ -50,4 +50,6 @@ jobs:
         mypy vllm/transformers_utils --config-file pyproject.toml
         mypy vllm/usage --config-file pyproject.toml
         mypy vllm/worker --config-file pyproject.toml
+        mypy vllm/hpu --config-file pyproject.toml
+
 
diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml
index 773def58fd966..a2b7aa2549af9 100644
--- a/.github/workflows/ruff.yml
+++ b/.github/workflows/ruff.yml
@@ -2,13 +2,13 @@ name: ruff
 
 on:
   # Trigger the workflow on push or pull request,
-  # but only for the main branch
+  # but only for the habana_main branch
   push:
     branches:
-      - main
+      - habana_main
   pull_request:
     branches:
-      - main
+      - habana_main
 
 jobs:
   ruff:
diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml
index 04f307bcf8b0e..4e0d67c5b59d6 100644
--- a/.github/workflows/yapf.yml
+++ b/.github/workflows/yapf.yml
@@ -2,13 +2,13 @@ name: yapf
 
 on:
   # Trigger the workflow on push or pull request,
-  # but only for the main branch
+  # but only for the habana_main branch
   push:
     branches:
-      - main
+      - habana_main
   pull_request:
     branches:
-      - main
+      - habana_main
 jobs:
   yapf:
     runs-on: ubuntu-latest
diff --git a/format.sh b/format.sh
index 5ad6d6f2938bb..fbfc27a68bb3d 100755
--- a/format.sh
+++ b/format.sh
@@ -113,6 +113,7 @@ mypy vllm/spec_decode --config-file pyproject.toml
 mypy vllm/transformers_utils --config-file pyproject.toml
 mypy vllm/usage --config-file pyproject.toml
 mypy vllm/worker --config-file pyproject.toml
+mypy vllm/hpu --config-file pyproject.toml
 
 
 # If git diff returns a file that is in the skip list, the file may be checked anyway:
diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py
index e03357a0b32cc..ae936925b5001 100644
--- a/vllm/hpu/ops.py
+++ b/vllm/hpu/ops.py
@@ -15,15 +15,18 @@
 
 import vllm.hpu.utils as hpu_utils
 from vllm.worker.profiler import Profiler
+from vllm.logger import init_logger
 
-PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1')
-
+logger = init_logger(__name__)
+HPUFusedRMSNorm = None
+try:
+    from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
+    HPUFusedRMSNorm = FusedRMSNorm
+except ImportError:
+    logger.warning("Could not import HPU FusedRMSNorm kernel. "
+                   "vLLM will use forward_native implementation of RMSNorm.")
 
-def silu_and_mul(output, input):
-    d = input.shape[-1] // 2
-    silu = torch.nn.SiLU().to(input.device)
-    x, y = torch.split(input, d, dim=-1)
-    output.copy_(silu(x) * y)
+PA_SPLIT_VALUE = (os.environ.get('PA_SPLIT_VALUE', '1') == '1')
 
 
 def fetch_from_cache(cache, blocks, permutations):
@@ -66,8 +69,7 @@ def paged_attention_v1(query,
         keys = [k.unflatten(1, (kv_heads, 1)) for k in keys]
         mask = mask.unsqueeze(2)
 
-    attn_weights = [torch.matmul(query, k) for k in keys]
-    attn_weights = torch.cat(attn_weights, dim=-1)
+    attn_weights = torch.cat([torch.matmul(query, k) for k in keys], dim=-1)
     if alibi_slopes is not None:
         attn_weights.add_(alibi_slopes[:, :, -attn_weights.size(2):,
                                        -attn_weights.size(3):])
@@ -103,12 +105,9 @@ def paged_attention_v1(query,
     return attn_weights.squeeze(-2)
 
 
-def silu_and_mul_wrapper(x: torch.Tensor) -> torch.Tensor:
+def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
     d = x.shape[-1] // 2
-    output_shape = (x.shape[:-1] + (d, ))
-    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-    silu_and_mul(out, x)
-    return out
+    return F.silu(x[..., :d]) * x[..., d:]
 
 
 def static_fused_moe(hidden_states, w1, w2, score, topk):
@@ -133,13 +132,10 @@ def static_fused_moe(hidden_states, w1, w2, score, topk):
     htorch.core.mark_step()
 
     for expert_idx in range(num_experts):
-        padded_weight = padded_weights[expert_idx]
-        current_state_static = hidden_states.reshape(-1, D)
-        w_output = silu_and_mul_wrapper(
-            torch.matmul(current_state_static, w1[expert_idx].transpose(0, 1)))
+        w_output = torch.matmul(hidden_states, w1[expert_idx].transpose(0, 1))
+        w_output = silu_and_mul(w_output)
         w_output = torch.matmul(w_output, w2[expert_idx].transpose(0, 1))
-        current_hidden_states_static = w_output * padded_weight
-        final_hidden_states += current_hidden_states_static
+        final_hidden_states += w_output * padded_weights[expert_idx]
         htorch.core.mark_step()
 
     return final_hidden_states.view(-1, D)
@@ -166,7 +162,8 @@ def prompt_attention(
         query = query.unflatten(1, (kv_heads, -1))
         key = key.unflatten(1, (kv_heads, 1))
         value = value.unflatten(1, (kv_heads, 1))
-        attn_bias = attn_bias.unsqueeze(2)
+        if attn_bias is not None:
+            attn_bias = attn_bias.unsqueeze(2)
     attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
     if attn_bias is not None:
         attn_weights.add_(attn_bias)
diff --git a/vllm/hpu/rotary_embed.py b/vllm/hpu/rotary_embed.py
index e44bfa2f6210c..30a88d68a24af 100644
--- a/vllm/hpu/rotary_embed.py
+++ b/vllm/hpu/rotary_embed.py
@@ -20,7 +20,7 @@
 except ImportError:
     logger.warning("Could not import HPU FusedRoPE kernel. "
                    "vLLM will use forward_native implementation of RoPE.")
-    FusedRoPE = None
+    FusedRoPE = None
 else:
     FusedRoPE = None
 
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 01429d2fcbd17..55cbbabd7da44 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -6,19 +6,8 @@
 
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
-from vllm.utils import is_hpu
 
 logger = init_logger(__name__)
-if is_hpu():
-    try:
-        from habana_frameworks.torch.hpex.normalization import (FusedRMSNorm as
-                                                                 HPUFusedRMSNorm
-                                                                 )
-    except ImportError:
-        logger.warning(
-            "Could not import HPU FusedRMSNorm kernel. "
-            "vLLM will use forward_native implementation of RMSNorm.")
-        HPUFusedRMSNorm = None
 
 
 class RMSNorm(CustomOp):
@@ -86,6 +75,7 @@ def forward_hpu(
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        from vllm.hpu.ops import HPUFusedRMSNorm
         if HPUFusedRMSNorm is None:
             return self.forward_native(x, residual)
         if residual is not None:
diff --git a/vllm/utils.py b/vllm/utils.py
index c1d0f37eb154f..8a1bc5de03eb7 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -664,7 +664,7 @@ def get_summary_string(self):
         return (
             f"{format_bytes(self.consumed_device_memory)} of device memory "
             f"({format_bytes(self.final_device_memory)}/"
-            f"({format_bytes(HabanaMemoryProfiler.total_device_memory())} used)"
+            f"{format_bytes(HabanaMemoryProfiler.total_device_memory())} used)"
             f" and {format_bytes(self.consumed_host_memory)} of host memory "
             f"({format_bytes(self.final_host_memory)}/"
             f"{format_bytes(HabanaMemoryProfiler.total_host_memory())} used)")
diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/habana_model_runner.py
index 8a220e2ef0171..cf91c69069ed6 100644
--- a/vllm/worker/habana_model_runner.py
+++ b/vllm/worker/habana_model_runner.py
@@ -409,7 +409,7 @@ def __init__(
 
         # Profiler stats
         self.profiler_counter_helper = HabanaProfilerCounterHelper()
-
+        self._mem_margin: Optional[int] = None
         self._setup_buckets()
 
     def load_model(self) -> None:
@@ -1071,10 +1071,15 @@ def warmup_all_buckets(self, buckets, is_prompt, kv_caches):
                             len(buckets), batch_size, seq_len)
             self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
 
-    def warmup_graphs(self, strategy, buckets, is_prompt, kv_caches,
-                      available_mem):
-        total_batch_seq = 0.001
-        total_mem = 0
+    def warmup_graphs(self,
+                      strategy,
+                      buckets,
+                      is_prompt,
+                      kv_caches,
+                      available_mem,
+                      starting_mem=0,
+                      total_batch_seq=0.001):
+        total_mem = starting_mem
         idx = 0
         phase = f'Graph/{"Prompt" if is_prompt else "Decode"}'
         num_candidates = len(buckets)
@@ -1088,14 +1093,18 @@ def warmup_graphs(self, strategy, buckets, is_prompt, kv_caches,
             raise NotImplementedError(
                 f'Unsupported graph allocation strategy: {strategy}')
         buckets = list(sorted(buckets, key=ordering))
-
+        captured_all = True
         for idx, (batch_size, seq_len) in enumerate(buckets):
             # Graph memory usage is proportional to seq dimension in a batch
             batch_seq = batch_size * seq_len if is_prompt else batch_size
             mem_estimate = batch_seq / total_batch_seq * total_mem
             if mem_estimate >= available_mem:
+                captured_all = False
+                continue
+            graphed_bucket = (batch_size, seq_len, is_prompt)
+            if graphed_bucket in self.graphed_buckets:
                 continue
-            self.graphed_buckets.add((batch_size, seq_len, is_prompt))
+            self.graphed_buckets.add(graphed_bucket)
             self.log_warmup(phase, idx, num_candidates, batch_size, seq_len)
             with HabanaMemoryProfiler() as mem_prof:
                 self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
@@ -1104,6 +1113,12 @@ def warmup_graphs(self, strategy, buckets, is_prompt, kv_caches,
             available_mem -= used_mem
             total_mem += used_mem
             total_batch_seq += batch_seq
+
+        return total_mem, total_batch_seq, captured_all
+
+    def log_graph_warmup_summary(self, buckets, is_prompt, total_mem):
+        num_candidates = len(buckets)
+        phase = f'Graph/{"Prompt" if is_prompt else "Decode"}'
         graphed = list(c[:2] for c in self.graphed_buckets
                        if c[2] == is_prompt)
         msg = (f'{phase} captured:{len(graphed)} '
@@ -1124,22 +1139,63 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
         self.warmup_all_buckets(self.decode_buckets, False, kv_caches)
 
         if not self.enforce_eager and htorch.utils.internal.is_lazy():
-            mem_margin = 1.0 - float(
-                os.environ.get('VLLM_GRAPH_MEM_MARGIN', '0.02'))
-            free_mem = \
-                mem_margin * HabanaMemoryProfiler.current_free_device_memory()
-            free_mem = align_workers(free_mem, torch.distributed.ReduceOp.MIN)
+            assert self.mem_margin is not None, \
+                ("HabanaWorker.determine_num_available_blocks needs "
+                 "to be called before warming up the model.")
+            free_mem = HabanaMemoryProfiler.current_free_device_memory()
+            graph_free_mem = free_mem - self.mem_margin
+            graph_free_mem = align_workers(graph_free_mem,
+                                           torch.distributed.ReduceOp.MIN)
             prompt_graph_mem_ratio = float(
                 os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.5'))
-            prompt_available_memory = prompt_graph_mem_ratio * free_mem
-            decode_available_memory = free_mem - prompt_available_memory
-            prompt_strategy = 'min_tokens'
+            prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
+            decode_available_memory = graph_free_mem - prompt_available_memory
+            msg = (f"Using {format_bytes(graph_free_mem)}"
+                   f"/{format_bytes(free_mem)} "
+                   "of free device memory for HPUGraphs, "
+                   f"{format_bytes(prompt_available_memory)} for prompt and "
+                   f"{format_bytes(decode_available_memory)} for decode "
+                   f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})")
+            logger.info(msg)
+            prompt_strategy = os.environ.get('VLLM_GRAPH_PROMPT_STRATEGY',
+                                             'min_tokens')
             decode_strategy = os.environ.get('VLLM_GRAPH_DECODE_STRATEGY',
                                              'max_bs')
-            self.warmup_graphs(prompt_strategy, self.prompt_buckets, True,
-                               kv_caches, prompt_available_memory)
-            self.warmup_graphs(decode_strategy, self.decode_buckets, False,
-                               kv_caches, decode_available_memory)
+            mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
+                self.warmup_graphs(
+                    prompt_strategy, self.prompt_buckets, True, kv_caches,
+                    prompt_available_memory)
+            mem_post_decode, decode_batch_seq, decode_captured_all = \
+                self.warmup_graphs(
+                    decode_strategy, self.decode_buckets, False, kv_caches,
+                    decode_available_memory)
+
+            # Not all prompt buckets were captured, but all decode buckets were
+            # captured and we have some free graph-allocated space left.
+            # Let's try to use it for capturing more prompt buckets.
+            if mem_post_decode + mem_post_prompt < graph_free_mem \
+                    and not prompt_captured_all \
+                    and decode_captured_all:
+                mem_post_prompt, _, prompt_captured_all = self.warmup_graphs(
+                    prompt_strategy, self.prompt_buckets, True, kv_caches,
+                    graph_free_mem - mem_post_prompt - mem_post_decode,
+                    mem_post_prompt, prompt_batch_seq)
+
+            # Not all decode buckets were captured, but all prompt buckets were
+            # captured and we have some free graph-allocated space left.
+            # Let's try to use it for capturing more decode buckets.
+            if mem_post_decode + mem_post_prompt < graph_free_mem \
+                    and not decode_captured_all \
+                    and prompt_captured_all:
+                mem_post_decode, _, _ = self.warmup_graphs(
+                    decode_strategy, self.decode_buckets, False, kv_caches,
+                    graph_free_mem - mem_post_prompt - mem_post_decode,
+                    mem_post_decode, decode_batch_seq)
+
+            self.log_graph_warmup_summary(self.prompt_buckets, True,
+                                          mem_post_prompt)
+            self.log_graph_warmup_summary(self.decode_buckets, False,
+                                          mem_post_decode)
 
         end_time = time.perf_counter()
         end_mem = HabanaMemoryProfiler.current_device_memory_usage()
@@ -1154,6 +1210,14 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
     def vocab_size(self) -> int:
         return self.model_config.get_vocab_size()
 
+    @property
+    def mem_margin(self) -> Optional[int]:
+        return self._mem_margin
+
+    @mem_margin.setter
+    def mem_margin(self, value):
+        self._mem_margin = value
+
 
 def _maybe_wrap_in_hpu_graph(*args, **kwargs):
     return htorch.hpu.wrap_in_hpu_graph(HpuModelAdapter(
diff --git a/vllm/worker/habana_worker.py b/vllm/worker/habana_worker.py
index 6be229e037d06..f3fdc4dcc63c6 100644
--- a/vllm/worker/habana_worker.py
+++ b/vllm/worker/habana_worker.py
@@ -16,14 +16,18 @@
                          SpeculativeConfig)
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import ExecuteModelRequest
+from vllm.utils import HabanaMemoryProfiler, format_bytes
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.habana_model_runner import HabanaModelRunner
 from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
 
+logger = init_logger(__name__)
+
 
 class HabanaWorker(LocalOrDistributedWorkerBase):
     """A worker class that executes (a partition of) the model on a HPU.
@@ -122,20 +126,37 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
-        self.model_runner.profile_run()
-        torch.hpu.synchronize()
-
+        with HabanaMemoryProfiler() as m:
+            self.model_runner.profile_run()
+            torch.hpu.synchronize()
+        msg = ("Model profiling run "
+               f"took {m.get_summary_string()}")
+        logger.info(msg)
         # At this point we should've allocated the maximum workspace for all
         # recipes we will use the extra memory for graphs/blocks
         free_hpu_memory = torch.hpu.mem_get_info()[0]
 
         cache_block_size = self.get_cache_block_size_bytes()
-        graph_headroom = 1 - (float(
+        graph_reserved_mem = (float(
             os.environ.get('VLLM_GRAPH_RESERVED_MEM', '0.4'))
                               if not self.model_config.enforce_eager else 0)
-        num_hpu_blocks = int(free_hpu_memory * graph_headroom *
-                             self.cache_config.gpu_memory_utilization //
-                             cache_block_size)
+        graph_headroom = 1 - graph_reserved_mem
+        available_hpu_memory = free_hpu_memory * \
+            self.cache_config.gpu_memory_utilization
+        hpu_memory_margin = free_hpu_memory * (
+            1 - self.cache_config.gpu_memory_utilization)
+        self.model_runner.mem_margin = hpu_memory_margin
+        cache_size_bytes = available_hpu_memory * graph_headroom
+        graph_headroom_bytes = available_hpu_memory * (1 - graph_headroom)
+        msg = (
+            f"Free device memory: {format_bytes(free_hpu_memory)}, "
+            f"{format_bytes(available_hpu_memory)} usable "
+            f"(gpu_memory_utilization={self.cache_config.gpu_memory_utilization}),"
+            f" {format_bytes(graph_headroom_bytes)} reserved for HPUGraphs "
+            f"(VLLM_GRAPH_RESERVED_MEM={graph_reserved_mem}), "
+            f"{format_bytes(cache_size_bytes)} reserved for KV cache")
+        logger.info(msg)
+        num_hpu_blocks = int(cache_size_bytes // cache_block_size)
         num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                              cache_block_size)
         num_hpu_blocks = max(num_hpu_blocks, 0)
@@ -161,7 +182,12 @@ def initialize_cache(self, num_gpu_blocks: int,
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
 
-        self._init_cache_engine()
+        with HabanaMemoryProfiler() as m:
+            self._init_cache_engine()
+            torch.hpu.synchronize()
+        msg = ("Initializing cache engine "
+               f"took {m.get_summary_string()}")
+        logger.info(msg)
         self._warm_up_model()
 
     def _init_cache_engine(self):
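Illustrative sketch (not part of the patch): the memory split performed by the new determine_num_available_blocks code above can be reproduced with plain arithmetic. The numeric values below are hypothetical; in the patch they come from torch.hpu.mem_get_info(), the cache config, and the VLLM_GRAPH_RESERVED_MEM environment variable.

import os

# Hypothetical inputs (the real code queries the device and configs).
free_hpu_memory = 96 * 2**30            # bytes reported as free on the HPU
gpu_memory_utilization = 0.9            # cache_config.gpu_memory_utilization
cache_block_size = 2 * 2**20            # bytes per KV-cache block (example)
enforce_eager = False

graph_reserved_mem = (float(os.environ.get('VLLM_GRAPH_RESERVED_MEM', '0.4'))
                      if not enforce_eager else 0)
graph_headroom = 1 - graph_reserved_mem

# Memory vLLM may use at all; the remainder becomes the margin the worker
# hands to the model runner (model_runner.mem_margin) for HPUGraph warmup.
available_hpu_memory = free_hpu_memory * gpu_memory_utilization
hpu_memory_margin = free_hpu_memory * (1 - gpu_memory_utilization)

cache_size_bytes = available_hpu_memory * graph_headroom             # KV cache budget
graph_headroom_bytes = available_hpu_memory * (1 - graph_headroom)   # HPUGraph budget
num_hpu_blocks = int(cache_size_bytes // cache_block_size)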
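Worked example (hypothetical numbers) of the proportional estimate used inside warmup_graphs above: the memory a candidate bucket's HPUGraph will need is extrapolated from the memory already consumed per unit of batch_size * seq_len, and buckets whose estimate exceeds the remaining budget are skipped (which is what clears captured_all).

# Hypothetical numbers for warmup_graphs' mem_estimate heuristic.
total_mem = 512 * 2**20        # bytes used by graphs captured so far
total_batch_seq = 4096         # sum of batch_size * seq_len over captured prompt buckets
available_mem = 200 * 2**20    # graph memory still available for this phase

batch_size, seq_len = 2, 512   # next prompt bucket under consideration
batch_seq = batch_size * seq_len                        # 1024
mem_estimate = batch_seq / total_batch_seq * total_mem  # 0.25 * 512 MiB = 128 MiB
skip_bucket = mem_estimate >= available_mem             # False -> bucket gets captured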