Trying to debug.
GeorgiosSmyrnis committed Feb 2, 2024
1 parent 22cd4eb commit 4c322d1
Showing 3 changed files with 9 additions and 8 deletions.
open_lm/attention.py (1 addition, 1 deletion)
@@ -22,7 +22,7 @@ def xformers_attn(queries, keys, values, is_causal, document_seqlens=None):
    # see (https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask)
    # we would like to replace the mask generation with: mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask()
    # sadly we cannot use this because it needs xformers>=0.0.23 and this is not compatible with torch<2.1.1 while llm-foundry requires torch<2.1.1

+   print("attention called")
    if document_seqlens is None or all(len(ds) == 1 for ds in document_seqlens):
        # In this case, all the tokens inside the sequence (are considered to) come from the same document.
        # The attention mask is constructed as a simple causal mask.
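For context, the branch above chooses between a plain causal mask (one document per sequence) and per-document masking. The real xformers_attn builds a mask for xformers attention, but the sketch below only illustrates, in plain PyTorch and for a single sequence, the mask semantics the comments describe; in the diff, document_seqlens is a list of such per-document length lists, one per batch element. A minimal sketch under those assumptions, not the project's implementation:

import torch

def build_causal_mask(seq_len, document_seqlens=None):
    # Lower-triangular causal mask: position i may attend to positions <= i.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    if document_seqlens is None:
        return causal
    # Per-document masking: additionally restrict attention to tokens of the same document.
    same_doc = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    start = 0
    for length in document_seqlens:
        same_doc[start:start + length, start:start + length] = True
        start += length
    return causal & same_doc

# Example: a 6-token sequence made of two documents of lengths 4 and 2.
mask = build_causal_mask(6, document_seqlens=[4, 2])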
open_lm/model.py (2 additions, 1 deletion)
@@ -155,6 +155,7 @@ def reset_parameters(self):
        torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std)

    def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cache=False, document_seqlens=None):
+       print("attention called")
        batchsize, q_len, _ = x.shape
        queries, keys, vals = self.in_proj(x).chunk(3, dim=-1)

@@ -247,6 +248,7 @@ def reset_parameters(self):
        torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std)

    def forward(self, x, past_key_value=None, use_cache=False, document_seqlens=None):
+       print("block called")
        h, past_key_value = self.attention(
            self.attention_norm(x),
            is_causal=True,
@@ -320,7 +322,6 @@ def set_grad_checkpointing(self, enable=True):
    def forward(self, input, past_key_values=None, use_cache=False, document_seqlens=None):
        x = self.tok_embeddings(input)
        x = self.post_embed_norm(x)
-
        if past_key_values is None:
            past_key_values = [None] * self.n_layers
        elif isinstance(past_key_values, tuple):
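The model.py changes above only add debug prints, but the touched forward signatures show how document_seqlens is threaded from the model into each block and from the block into attention. A minimal sketch of that plumbing, with toy names that are not the project's:

import torch
import torch.nn as nn

def toy_attention(x, document_seqlens=None):
    # Placeholder for the real attention; it would use document_seqlens to build its mask.
    return x

class ToyBlock(nn.Module):
    def forward(self, x, document_seqlens=None):
        # The block passes document_seqlens straight through to attention.
        return toy_attention(x, document_seqlens=document_seqlens)

class ToyTransformer(nn.Module):
    def __init__(self, n_layers=2):
        super().__init__()
        self.layers = nn.ModuleList([ToyBlock() for _ in range(n_layers)])

    def forward(self, x, document_seqlens=None):
        # The model hands the same document_seqlens to every block.
        for layer in self.layers:
            x = layer(x, document_seqlens=document_seqlens)
        return x

out = ToyTransformer()(torch.randn(2, 6, 16), document_seqlens=[[4, 2], [6]])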
open_lm/train.py (6 additions, 6 deletions)
@@ -53,16 +53,16 @@ def get_document_seqlens(inputs, args):
        document_seqlens = []
        for idx in range(inputs.shape[0]):
            eot_idx = torch.nonzero(inputs[idx] == SpecialTokens.END_OF_TEXT.value)
-           if len(eot_idx.shape) == 0:
-               # Fallback case - an eot token should appear at the end.
-               document_seqlens.append([args.seq_len + 1])
+           if eot_idx.shape[0] == 0:
+               # All tokens come from the same document.
+               document_seqlens.append([args.seq_len])
            else:
                start_idx = 0
                seqlens = []
                for k in range(eot_idx.shape[0]):
-                   seqlens.append(eot_idx[k] - start_idx + 1)
-                   start_idx = eot_idx[k] + 1
-               if start_idx < args.seq_len + 1:
+                   seqlens.append(eot_idx[k].item() - start_idx + 1)
+                   start_idx = eot_idx[k].item() + 1
+               if start_idx < args.seq_len:
                    seqlens.append(args.seq_len - start_idx)
                document_seqlens.append(seqlens)
    else:
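The fix above replaces the check len(eot_idx.shape) == 0, which can never be true because torch.nonzero returns a 2-D tensor of shape (num_matches, 1), with eot_idx.shape[0] == 0, which correctly detects a sequence containing no end-of-text token; the added .item() calls make the appended lengths plain Python ints rather than one-element tensors. A self-contained sketch of the new per-document length computation with made-up values (the EOT id and seq_len here are illustrative; the real ones come from SpecialTokens.END_OF_TEXT and args.seq_len):

import torch

EOT = 0       # illustrative end-of-text id
seq_len = 8
tokens = torch.tensor([5, 7, EOT, 3, 9, EOT, 4, 2])  # two complete documents plus a trailing partial one

eot_idx = torch.nonzero(tokens == EOT)  # shape (num_matches, 1), so shape[0] == 0 means "no EOT found"

if eot_idx.shape[0] == 0:
    seqlens = [seq_len]  # the whole sequence is a single document
else:
    seqlens, start_idx = [], 0
    for k in range(eot_idx.shape[0]):
        seqlens.append(eot_idx[k].item() - start_idx + 1)  # a document ends at (and includes) its EOT token
        start_idx = eot_idx[k].item() + 1
    if start_idx < seq_len:
        seqlens.append(seq_len - start_idx)  # leftover tokens after the last EOT form a partial document

print(seqlens)  # [3, 3, 2]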
