Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix spacing and fast rms not for training #2558

Merged
merged 1 commit into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion onmt/inputters/text_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def parse_features(line, n_feats=0, defaults=None):
text, feats = [], [[] for _ in range(n_feats)]
check, count = 0, 0
for token in line.split(" "):
tok, *fts = token.strip().split("│")
tok, *fts = token.strip("\n").split("│")
check += len(fts)
count += 1
if not fts and defaults is not None:
Expand Down
9 changes: 5 additions & 4 deletions onmt/modules/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,17 @@ def __init__(self, hidden_size: int, eps: float = 1e-6):
self.weight = nn.Parameter(torch.ones(hidden_size))

def forward(self, hidden_states):
if AWQ_INFERENCE_ENGINE:
output = torch.empty_like(hidden_states)
if AWQ_INFERENCE_ENGINE and not self.training:
inp_type = hidden_states.dtype
output = torch.empty_like(hidden_states).to(inp_type)
if hidden_states.dim() == 2: # patch for multi experts
hidden_states = hidden_states.unsqueeze(0)
awq_inference_engine.layernorm_forward_cuda(
hidden_states, self.weight, output, self.eps
hidden_states.half(), self.weight.half(), output.half(), self.eps
)
if hidden_states.dim() == 2: # patch for multi experts
output = output.unsqueeze(0)
return output
return output.to(inp_type)
else:
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
Expand Down
Loading