From 3b34893436ee67c2cd1e6f34f5b4d56bd154241c Mon Sep 17 00:00:00 2001
From: Nir David
Date: Thu, 1 Aug 2024 19:44:30 +0300
Subject: [PATCH] Inc on vLLM - Change RMS norm to BF16

---
 vllm/model_executor/layers/layernorm.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 55cbbabd7da44..7434d02b60ada 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -79,18 +79,16 @@ def forward_hpu(
         if HPUFusedRMSNorm is None:
             return self.forward_native(x, residual)
         if residual is not None:
-            orig_dtype = x.dtype
             orig_shape = x.shape
             residual += x.view(residual.shape)
             # Note: HPUFusedRMSNorm requires 3D tensors as inputs
-            x = HPUFusedRMSNorm.apply(residual.float(), self.weight.float(),
+            x = HPUFusedRMSNorm.apply(residual, self.weight,
                                       self.variance_epsilon)
-            return x.to(orig_dtype).view(orig_shape), residual
+            return x.view(orig_shape), residual
 
-        orig_dtype = x.dtype
-        x = HPUFusedRMSNorm.apply(x.float(), self.weight.float(),
+        x = HPUFusedRMSNorm.apply(x, self.weight,
                                   self.variance_epsilon)
-        return x.to(orig_dtype)
+        return x
 
     def forward_xpu(
         self,
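
For readability, the forward_hpu method as it reads after applying this patch is sketched below. It is reconstructed from the hunk above; the enclosing RMSNorm class, imports, and full method signature are assumed from the file being patched and are not reproduced here.

    # Sketch of forward_hpu after the patch (reconstructed from the hunk above).
    # Signature simplified for illustration; the real method lives on the
    # RMSNorm class in vllm/model_executor/layers/layernorm.py.
    def forward_hpu(self, x, residual=None):
        if HPUFusedRMSNorm is None:
            # Fall back to the pure-PyTorch implementation when the fused
            # HPU kernel is unavailable.
            return self.forward_native(x, residual)
        if residual is not None:
            orig_shape = x.shape
            residual += x.view(residual.shape)
            # Note: HPUFusedRMSNorm requires 3D tensors as inputs.
            # The fused kernel now operates directly on the incoming
            # (e.g. BF16) tensors instead of upcasting to FP32 and
            # casting the result back.
            x = HPUFusedRMSNorm.apply(residual, self.weight,
                                      self.variance_epsilon)
            return x.view(orig_shape), residual

        x = HPUFusedRMSNorm.apply(x, self.weight,
                                  self.variance_epsilon)
        return x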