Add DeepSeek-V2-Lite/DeepSeek-V2-Lite-Chat model support
export VLLM_TORCH_COMPILE_LEVEL=3 is required to work around an RMSNorm bug in
habana_frameworks/torch/hpex/normalization/FusedRMSNorm.py
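
A minimal offline-inference sketch of applying this workaround, assuming vLLM's standard LLM entry point on Gaudi; the model name, prompt, and sampling settings are illustrative and not part of this commit:

import os

# The commit message requires VLLM_TORCH_COMPILE_LEVEL=3; setting it before
# vLLM is imported approximates the exported shell variable.
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = "3"

from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite-Chat", trust_remote_code=True)
outputs = llm.generate(["Hello from Gaudi"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)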
hlin99 committed Oct 21, 2024
1 parent 07c98a5 commit db392a6
Showing 3 changed files with 40 additions and 6 deletions.
19 changes: 19 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -168,6 +168,25 @@ def forward_tpu(
topk=top_k,
gating_output=router_logits,
renormalize=renormalize)
def forward_native(self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None):
return self.forward_hpu(layer=layer,
x=x,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
router_logits=router_logits,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)


class FusedMoE(torch.nn.Module):
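
For context, vLLM's layer ops pick a backend-specific forward (forward_cuda, forward_tpu, forward_hpu, forward_native) at runtime, and the forward_native added above simply reuses the Gaudi path. A simplified sketch of that delegation pattern with an invented class name and trimmed signature, not the real dispatcher:

import torch

class MoEMethodSketch:  # hypothetical stand-in for the real MoE method class
    def forward_hpu(self, x: torch.Tensor, router_logits: torch.Tensor,
                    top_k: int, renormalize: bool) -> torch.Tensor:
        raise NotImplementedError("placeholder for the Gaudi MoE kernel call")

    def forward_native(self, x: torch.Tensor, router_logits: torch.Tensor,
                       top_k: int, renormalize: bool) -> torch.Tensor:
        # The generic "native" fallback delegates to the HPU implementation,
        # mirroring the forward_native -> forward_hpu forwarding added here.
        return self.forward_hpu(x, router_logits, top_k, renormalize)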
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -598,7 +598,7 @@ def __init__(

def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
0, self.rotary_dim, 2, dtype=torch.float, device="hpu") /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
@@ -617,7 +617,7 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device="cuda",
device="hpu",
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)
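
The only change in this file is the hard-coded device for the YaRN rotary tables, which are now built on hpu instead of cuda. A device-agnostic sketch of the same inverse-frequency computation, with illustrative names and defaults:

from typing import Tuple

import torch

def compute_yarn_inv_freq(base: float, rotary_dim: int, scaling_factor: float,
                          device: str = "hpu") -> Tuple[torch.Tensor, torch.Tensor]:
    # Per-dimension frequencies, matching the code above but parameterized on device.
    pos_freqs = base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float,
                                      device=device) / rotary_dim)
    inv_freq_extrapolation = 1.0 / pos_freqs                     # plain RoPE frequencies
    inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)  # position-interpolated
    return inv_freq_extrapolation, inv_freq_interpolation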
23 changes: 19 additions & 4 deletions vllm/model_executor/models/deepseek_v2.py
@@ -118,9 +118,9 @@ def __init__(
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
use_grouped_topk=False,
#num_expert_group=config.n_group,
#topk_group=config.topk_group,
prefix=f"{prefix}.experts")

self.gate = ReplicatedLinear(config.hidden_size,
@@ -284,9 +284,15 @@ def forward(
else:
q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads,
self.qk_head_dim)
# HPU: flatten positions from (batch, seq_len) to (batch * seq_len)
_batch_size = positions.shape[0]
positions = positions.reshape(positions.shape[0] * positions.shape[1])
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
dim=-1)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
# HPU: flatten latent_cache from (batch, seq_len, dim) to (batch * seq_len, dim)
if len(latent_cache.shape) == 3:
latent_cache = latent_cache.reshape(latent_cache.shape[0] * latent_cache.shape[1], latent_cache.shape[2])
kv_a, _ = latent_cache.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
@@ -310,7 +316,13 @@ def forward(
v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim],
value=0).view(-1,
self.num_local_heads * 256)
# HPU: restore q/k/v from (batch * seq_len, dim) to (batch, seq_len, dim) before attention
q = q.reshape(_batch_size, q.shape[0] // _batch_size, q.shape[1])
k = k.reshape(_batch_size, k.shape[0] // _batch_size, k.shape[1])
v = v.reshape(_batch_size, v.shape[0] // _batch_size, v.shape[1])
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
# HPU: flatten attn_output from (batch, seq_len, dim) back to (batch * seq_len, dim)
attn_output = attn_output.reshape(attn_output.shape[0] * attn_output.shape[1], attn_output.shape[2])
attn_output = attn_output.view(
-1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape(
-1, self.num_local_heads * self.v_head_dim)
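
The reshapes above flatten and restore the (batch, seq_len) layout around the HPU attention call. A standalone sketch of that round trip on dummy tensors; all shapes are illustrative:

import torch

batch, seq_len, num_heads, padded_head_dim = 2, 4, 16, 256
# q/k/v leave the projections flattened to (batch * seq_len, num_heads * padded_head_dim)
q = torch.randn(batch * seq_len, num_heads * padded_head_dim)

# restore to (batch, seq_len, num_heads * padded_head_dim) for the HPU attention kernel
q = q.reshape(batch, q.shape[0] // batch, q.shape[1])

attn_output = q  # stand-in for self.attn(q, k, v, kv_cache, attn_metadata)

# flatten back to (batch * seq_len, num_heads * padded_head_dim) for the output projection
attn_output = attn_output.reshape(
    attn_output.shape[0] * attn_output.shape[1], attn_output.shape[2])
assert attn_output.shape == (batch * seq_len, num_heads * padded_head_dim)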
@@ -328,6 +340,7 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()

self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
@@ -395,7 +408,9 @@ def forward(
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)

# HPU: flatten residual from (batch, seq_len, hidden) to (batch * seq_len, hidden)
if len(residual.shape) == 3:
residual = residual.reshape(residual.shape[0] * residual.shape[1], residual.shape[2])
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
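
The residual is flattened above so that it matches hidden_states when post_attention_layernorm fuses the residual add with RMS normalization. A rough eager-PyTorch reference for that fused add-and-normalize pattern, not the Habana FusedRMSNorm kernel mentioned in the commit message:

import torch

def rmsnorm_with_residual(x: torch.Tensor, residual: torch.Tensor,
                          weight: torch.Tensor, eps: float = 1e-6):
    # Add the residual first, then RMS-normalize the sum; return both the
    # normalized activations and the updated residual for the next sublayer.
    x = x + residual
    residual = x
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return x * weight, residual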
