From f4c44063863a1f9bfc2c0bfcb811278628340342 Mon Sep 17 00:00:00 2001
From: Vincent Nguyen
Date: Thu, 1 Feb 2024 10:16:36 +0100
Subject: [PATCH] fix spacing and fast rms not for training (#2558)

---
 onmt/inputters/text_utils.py | 2 +-
 onmt/modules/rmsnorm.py      | 9 +++++----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/onmt/inputters/text_utils.py b/onmt/inputters/text_utils.py
index 83da07cc62..a0ed39c859 100644
--- a/onmt/inputters/text_utils.py
+++ b/onmt/inputters/text_utils.py
@@ -14,7 +14,7 @@ def parse_features(line, n_feats=0, defaults=None):
     text, feats = [], [[] for _ in range(n_feats)]
     check, count = 0, 0
     for token in line.split(" "):
-        tok, *fts = token.strip().split("│")
+        tok, *fts = token.strip("\n").split("│")
         check += len(fts)
         count += 1
         if not fts and defaults is not None:
diff --git a/onmt/modules/rmsnorm.py b/onmt/modules/rmsnorm.py
index 3d8515aa5f..a25d08b27e 100644
--- a/onmt/modules/rmsnorm.py
+++ b/onmt/modules/rmsnorm.py
@@ -24,16 +24,17 @@ def __init__(self, hidden_size: int, eps: float = 1e-6):
         self.weight = nn.Parameter(torch.ones(hidden_size))
 
     def forward(self, hidden_states):
-        if AWQ_INFERENCE_ENGINE:
-            output = torch.empty_like(hidden_states)
+        if AWQ_INFERENCE_ENGINE and not self.training:
+            inp_type = hidden_states.dtype
+            output = torch.empty_like(hidden_states).to(inp_type)
             if hidden_states.dim() == 2:  # patch for multi experts
                 hidden_states = hidden_states.unsqueeze(0)
             awq_inference_engine.layernorm_forward_cuda(
-                hidden_states, self.weight, output, self.eps
+                hidden_states.half(), self.weight.half(), output.half(), self.eps
             )
             if hidden_states.dim() == 2:  # patch for multi experts
                 output = output.unsqueeze(0)
-            return output
+            return output.to(inp_type)
         else:
             hidden_states = hidden_states.to(torch.float32)
             variance = hidden_states.pow(2).mean(-1, keepdim=True)