Commit f4c4406

fix spacing and fast rms not for training (#2558)
vince62s authored Feb 1, 2024
1 parent 43c3300 commit f4c4406
Showing 2 changed files with 6 additions and 5 deletions.
onmt/inputters/text_utils.py (2 changes: 1 addition, 1 deletion)

@@ -14,7 +14,7 @@ def parse_features(line, n_feats=0, defaults=None):
     text, feats = [], [[] for _ in range(n_feats)]
     check, count = 0, 0
     for token in line.split(" "):
-        tok, *fts = token.strip().split("│")
+        tok, *fts = token.strip("\n").split("│")
         check += len(fts)
         count += 1
         if not fts and defaults is not None:
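Why strip("\n") instead of strip(): a bare strip() removes every kind of leading and trailing whitespace, so a token whose text is itself a space-like character (for example a non-breaking space) was emptied before the │ feature separator could be split off; restricting the strip to "\n" drops only the line terminator. Presumably this is the "spacing" fix named in the commit title. A minimal sketch of the difference, using a hypothetical token:

    token = "\u00a0│0\n"                 # token text is a non-breaking space, with one feature
    print(token.strip().split("│"))      # ['', '0']     token text destroyed
    print(token.strip("\n").split("│"))  # ['\xa0', '0'] spacing preserved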
onmt/modules/rmsnorm.py (9 changes: 5 additions, 4 deletions)

@@ -24,16 +24,17 @@ def __init__(self, hidden_size: int, eps: float = 1e-6):
         self.weight = nn.Parameter(torch.ones(hidden_size))

     def forward(self, hidden_states):
-        if AWQ_INFERENCE_ENGINE:
-            output = torch.empty_like(hidden_states)
+        if AWQ_INFERENCE_ENGINE and not self.training:
+            inp_type = hidden_states.dtype
+            output = torch.empty_like(hidden_states).to(inp_type)
             if hidden_states.dim() == 2:  # patch for multi experts
                 hidden_states = hidden_states.unsqueeze(0)
             awq_inference_engine.layernorm_forward_cuda(
-                hidden_states, self.weight, output, self.eps
+                hidden_states.half(), self.weight.half(), output.half(), self.eps
             )
             if hidden_states.dim() == 2:  # patch for multi experts
                 output = output.unsqueeze(0)
-            return output
+            return output.to(inp_type)
         else:
             hidden_states = hidden_states.to(torch.float32)
             variance = hidden_states.pow(2).mean(-1, keepdim=True)
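The rmsnorm.py change does two things: the fused AWQ kernel is now taken only when the module is not in training mode (the kernel runs in half precision and bypasses autograd, matching the "fast rms not for training" part of the commit title), and the input dtype is saved so the fp16 result can be cast back before returning. Below is a minimal self-contained sketch of the same pattern; fast_rmsnorm_kernel and the fast flag are hypothetical stand-ins, not the real awq_inference_engine API:

    import torch
    import torch.nn as nn


    def fast_rmsnorm_kernel(x, weight, out, eps):
        # Hypothetical stand-in for a fused fp16 CUDA kernel such as
        # awq_inference_engine.layernorm_forward_cuda: writes RMSNorm(x) into out.
        variance = x.float().pow(2).mean(-1, keepdim=True)
        out.copy_((x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight)


    class RMSNorm(nn.Module):
        def __init__(self, hidden_size: int, eps: float = 1e-6, fast: bool = False):
            super().__init__()
            self.eps = eps
            self.fast = fast  # hypothetical switch standing in for AWQ_INFERENCE_ENGINE
            self.weight = nn.Parameter(torch.ones(hidden_size))

        def forward(self, hidden_states):
            # Fast path: inference only. It runs in fp16 and builds no autograd
            # graph, so it must be skipped while self.training is True.
            if self.fast and not self.training:
                inp_type = hidden_states.dtype
                out = torch.empty(hidden_states.shape, dtype=torch.float16,
                                  device=hidden_states.device)
                fast_rmsnorm_kernel(hidden_states.half(), self.weight.half(),
                                    out, self.eps)
                return out.to(inp_type)  # cast back to the caller's dtype
            # Reference path: fp32 math for stability, cast back at the end.
            inp_type = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
            return self.weight * hidden_states.to(inp_type)

Usage, showing how train()/eval() select the path:

    norm = RMSNorm(8, fast=True)
    norm.eval()
    y = norm(torch.randn(2, 8))   # fused fp16 path
    norm.train()
    y = norm(torch.randn(2, 8))   # reference fp32 path, gradients flow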
