diff --git a/vllm/hpu/ops.py b/vllm/hpu/ops.py
index 9e901d2ad0b7b..634ea270b58ec 100644
--- a/vllm/hpu/ops.py
+++ b/vllm/hpu/ops.py
@@ -231,6 +231,16 @@ def forward(self, state):
         return torch.matmul(state, self.weight)
 
 
+def calculate_routing_tensors(score, topk, hidden_states_dtype):
+    routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
+    routing_weights, selected_experts = torch.topk(routing_weights,
+                                                   topk,
+                                                   dim=-1)
+    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+    routing_weights = routing_weights.to(hidden_states_dtype)
+    return routing_weights, selected_experts
+
+
 class StaticFusedMOE(torch.nn.Module):
 
     def __init__(self, num_total_experts):
@@ -243,12 +253,8 @@ def __init__(self, num_total_experts):
 
     def forward(self, hidden_states, w1, w2, score, topk):
         B, D = hidden_states.shape
-        routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
-        routing_weights, selected_experts = torch.topk(routing_weights,
-                                                       topk,
-                                                       dim=-1)
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-        routing_weights = routing_weights.to(hidden_states.dtype)
+        routing_weights, selected_experts = calculate_routing_tensors(
+            score, topk, hidden_states.dtype)
         final_hidden_states = torch.zeros((1, B, D),
                                           dtype=hidden_states.dtype,
                                           device=hidden_states.device)
@@ -271,3 +277,33 @@ def forward(self, hidden_states, w1, w2, score, topk):
             final_hidden_states += current_hidden_states_static
 
         return final_hidden_states.view(-1, D)
+
+
+class DynamicFusedMOE(torch.nn.Module):
+
+    def __init__(self, num_total_experts):
+        super().__init__()
+        self.num_total_experts = num_total_experts
+
+    def forward(self, hidden_states, w1, w2, score, topk):
+        htorch.core.mark_step()
+        routing_weights, selected_experts = calculate_routing_tensors(
+            score, topk, hidden_states.dtype)
+        # pre-processing for custom op inputs
+        experts_range = range(self.num_total_experts)
+        w1_list = [w1[i, :, :].squeeze() for i in experts_range]
+        w2_list = [w2[i, :, :].squeeze() for i in experts_range]
+
+        final_hidden_states = torch.ops.hpu.mixture_of_experts(
+            hidden_states=hidden_states,
+            expert_routing_table=selected_experts,
+            router_weights=routing_weights,
+            w12=w1_list,
+            w3=w2_list,
+            permuted_weights=True,
+            activation="silu",
+            experts_min=0,
+            experts_max=7
+        )
+
+        return final_hidden_states.view(-1, hidden_states.shape[1])
\ No newline at end of file
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index cf0d5f98f1b01..589cd73871af8 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -202,8 +202,14 @@ def __init__(
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         if is_hpu():
-            from vllm.hpu.ops import StaticFusedMOE
-            self.hpu_static_fused_moe = StaticFusedMOE(self.num_experts)
+            from vllm.hpu.ops import StaticFusedMOE, DynamicFusedMOE
+            from vllm.model_executor.layers.quantization.inc import INCConfig
+            selected_fused_moe = (
+                StaticFusedMOE
+                if isinstance(quant_config, INCConfig)
+                else DynamicFusedMOE
+            )
+            self.hpu_static_fused_moe = selected_fused_moe(self.num_experts)
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -254,24 +260,25 @@ def weight_loader(self, param: torch.nn.Parameter,
             shard_size = self.intermediate_size_per_partition
             shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
 
+            from vllm.hpu.ops import StaticFusedMOE
             # w1, gate_proj case: Load into first shard of w13.
             if shard_id == 0:
                 param_data[expert_id,
                            0:shard_size, :] = loaded_weight[shard, :]
-                if is_hpu():
+                if is_hpu() and isinstance(self.hpu_static_fused_moe, StaticFusedMOE):
                     self.hpu_static_fused_moe.w13_list[expert_id].set_weight(
                         param_data[expert_id])
             # w3, up_proj case: Load into second shard of w13.
             elif shard_id == 2:
                 param_data[expert_id,
                            shard_size:2 * shard_size, :] = loaded_weight[shard, :]
-                if is_hpu():
+                if is_hpu() and isinstance(self.hpu_static_fused_moe, StaticFusedMOE):
                     self.hpu_static_fused_moe.w13_list[expert_id].set_weight(
                         param_data[expert_id])
             # w2, down_proj case: Load into only shard of w2.
             elif shard_id == 1:
                 param_data[expert_id, :, :] = loaded_weight[:, shard]
-                if is_hpu():
+                if is_hpu() and isinstance(self.hpu_static_fused_moe, StaticFusedMOE):
                     self.hpu_static_fused_moe.w2_list[expert_id].set_weight(
                         param_data[expert_id])
             else:
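
For reference, the routing math that the new `calculate_routing_tensors` helper factors out of `StaticFusedMOE.forward` (and reuses in `DynamicFusedMOE`) can be exercised standalone. Below is a minimal sketch in plain PyTorch, no HPU stack required; the tensor shapes and dtype are made up for illustration:

```python
import torch
import torch.nn.functional as F


def calculate_routing_tensors(score, topk, hidden_states_dtype):
    # Softmax over the expert dimension in float32 for numerical stability.
    routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
    # Keep only the top-k experts per token, together with their indices.
    routing_weights, selected_experts = torch.topk(routing_weights,
                                                   topk,
                                                   dim=-1)
    # Renormalize so the surviving weights sum to 1 for each token.
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    return routing_weights.to(hidden_states_dtype), selected_experts


# Hypothetical shapes: 4 tokens routed over 8 experts with top-2 routing.
score = torch.randn(4, 8)
weights, experts = calculate_routing_tensors(
    score, topk=2, hidden_states_dtype=torch.bfloat16)
print(weights.shape, experts.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
assert torch.allclose(weights.float().sum(dim=-1), torch.ones(4), atol=1e-2)
```

Doing the softmax in float32 before the top-k selection and renormalization avoids precision loss for bf16/fp16 activations; the weights are cast back to the activation dtype only at the end, which matches what both MoE classes consume.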