Skip to content

Commit

Permalink
Document-attention masking: compare against the EOT enum value (`SpecialTokens.END_OF_TEXT.value`), not the enum member
Browse files Browse the repository at this point in the history
  • Loading branch information
afang-story authored Feb 2, 2024
1 parent 9c7fbe1 commit 22cd4eb
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions open_lm/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
if args.mask_across_documents:
# Some input samples contain EOT as the final token. The prediction after that is meaningless, so it
# should not contribute to the loss.
ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT, as_tuple=True)
ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT.value, as_tuple=True)
targets[ignore_indices] = loss.ignore_index

out, _, _ = model(inputs, document_seqlens=document_seqlens)
Expand All @@ -175,7 +175,7 @@ def train_one_epoch(model, data, loss, epoch, step, optimizer, scaler, scheduler
if args.mask_across_documents:
# Some input samples contain EOT as the final token. The prediction after that is meaningless, so it
# should not contribute to the loss.
ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT, as_tuple=True)
ignore_indices = torch.nonzero(inputs == SpecialTokens.END_OF_TEXT.value, as_tuple=True)
targets[ignore_indices] = loss.ignore_index

for ii in range(args.accum_freq):
Expand Down

0 comments on commit 22cd4eb

Please sign in to comment.