diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 457450cda2ce6..e96be5665257f 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -168,6 +168,25 @@ def forward_tpu( topk=top_k, gating_output=router_logits, renormalize=renormalize) + def forward_native(self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None): + return self.forward_hpu(layer=layer, + x=x, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + router_logits=router_logits, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) class FusedMoE(torch.nn.Module): diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 85cd700c978ea..badc0ad51d84e 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -598,7 +598,7 @@ def __init__( def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: pos_freqs = self.base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / + 0, self.rotary_dim, 2, dtype=torch.float, device="hpu") / self.rotary_dim) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) @@ -617,7 +617,7 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.scaling_factor) t = torch.arange(self.max_position_embeddings * self.scaling_factor, - device="cuda", + device="hpu", dtype=torch.float32) freqs = torch.einsum("i,j -> ij", t, inv_freq) cos = (freqs.cos() * self.mscale) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 702be7b7f5ed9..383f3a6b2641d 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -118,9 +118,9 @@ def __init__( reduce_results=False, renormalize=config.norm_topk_prob, quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, + use_grouped_topk=False, + #num_expert_group=config.n_group, + #topk_group=config.topk_group, prefix=f"{prefix}.experts") self.gate = ReplicatedLinear(config.hidden_size, @@ -284,9 +284,15 @@ def forward( else: q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, self.qk_head_dim) + # need reshape from tensor(x0, y0) to tensor(x1) for hpu + _batch_size = positions.shape[0] + positions = positions.reshape(positions.shape[0] * positions.shape[1]) q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + # need reshape from tensor(x0, y0, z0) to tensor(x1, y1) for hpu + if len(latent_cache.shape) == 3: + latent_cache = latent_cache.reshape(latent_cache.shape[0] * latent_cache.shape[1], latent_cache.shape[2]) kv_a, _ = latent_cache.split( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) latent_cache = latent_cache.unsqueeze(1) @@ -310,7 +316,13 @@ def forward( v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(-1, self.num_local_heads * 256) + # need restore from tensor(x0, y0) to tensor(x1, y1, z1) for hpu + q = q.reshape(_batch_size, q.shape[0] // _batch_size, q.shape[1]) + k = k.reshape(_batch_size, k.shape[0] // _batch_size, k.shape[1]) + v = v.reshape(_batch_size, v.shape[0] // _batch_size, v.shape[1]) attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + # need restore from tensor(x0, y0, z0) to tensor(x1, y1) for hpu + attn_output = attn_output.reshape(attn_output.shape[0] * attn_output.shape[1], attn_output.shape[2]) attn_output = attn_output.view( -1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape( -1, self.num_local_heads * self.v_head_dim) @@ -328,6 +340,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() + self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) @@ -395,7 +408,9 @@ def forward( kv_cache=kv_cache, attn_metadata=attn_metadata, ) - + # need reshape from tensor(x0, y0, z0) to tensor(x1, y1) for hpu + if len(residual.shape) == 3: + residual = residual.reshape(residual.shape[0] * residual.shape[1], residual.shape[2]) # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual)