From 4c322d15b91fd6afcc3842deaafe0ee9e831dca2 Mon Sep 17 00:00:00 2001
From: George Smyrnis
Date: Fri, 2 Feb 2024 17:25:57 -0600
Subject: [PATCH] Trying to debug.

---
 open_lm/attention.py |  2 +-
 open_lm/model.py     |  3 ++-
 open_lm/train.py     | 12 ++++++------
 3 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/open_lm/attention.py b/open_lm/attention.py
index 783fbb63..1f0d8693 100644
--- a/open_lm/attention.py
+++ b/open_lm/attention.py
@@ -22,7 +22,7 @@ def xformers_attn(queries, keys, values, is_causal, document_seqlens=None):
     # see (https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask)
     # we would like to replace the mask generation with: mask = xops.fmha.attn_bias.LowerTriangularFromBottomRightMask()
     # sadly we cannot us this because it needs xformers>=0.0.23 and this is not compatible with torch<2.1.1 while llm-foundry requires torch<2.1.1
-
+    print("attention called")
     if document_seqlens is None or all(len(ds) == 1 for ds in document_seqlens):
         # In this case, all the tokens inside the sequence (are considered to) come from the same document.
         # The attention mask is constructed as a simple causal mask
diff --git a/open_lm/model.py b/open_lm/model.py
index 767bd2ad..511fd867 100644
--- a/open_lm/model.py
+++ b/open_lm/model.py
@@ -155,6 +155,7 @@ def reset_parameters(self):
         torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std)
 
     def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cache=False, document_seqlens=None):
+        print("attention called")
         batchsize, q_len, _ = x.shape
         queries, keys, vals = self.in_proj(x).chunk(3, dim=-1)
 
@@ -247,6 +248,7 @@ def reset_parameters(self):
         torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std)
 
     def forward(self, x, past_key_value=None, use_cache=False, document_seqlens=None):
+        print("block called")
         h, past_key_value = self.attention(
             self.attention_norm(x),
             is_causal=True,
@@ -320,7 +322,6 @@ def set_grad_checkpointing(self, enable=True):
     def forward(self, input, past_key_values=None, use_cache=False, document_seqlens=None):
         x = self.tok_embeddings(input)
         x = self.post_embed_norm(x)
-
         if past_key_values is None:
             past_key_values = [None] * self.n_layers
         elif isinstance(past_key_values, tuple):
diff --git a/open_lm/train.py b/open_lm/train.py
index 277830aa..63f3c4e6 100644
--- a/open_lm/train.py
+++ b/open_lm/train.py
@@ -53,16 +53,16 @@ def get_document_seqlens(inputs, args):
         document_seqlens = []
         for idx in range(inputs.shape[0]):
             eot_idx = torch.nonzero(inputs[idx] == SpecialTokens.END_OF_TEXT.value)
-            if len(eot_idx.shape) == 0:
-                # Fallback case - an eot token should appear at the end.
-                document_seqlens.append([args.seq_len + 1])
+            if eot_idx.shape[0] == 0:
+                # All tokens come from the same document.
+                document_seqlens.append([args.seq_len])
             else:
                 start_idx = 0
                 seqlens = []
                 for k in range(eot_idx.shape[0]):
-                    seqlens.append(eot_idx[k] - start_idx + 1)
-                    start_idx = eot_idx[k] + 1
-                    if start_idx < args.seq_len + 1:
+                    seqlens.append(eot_idx[k].item() - start_idx + 1)
+                    start_idx = eot_idx[k].item() + 1
+                    if start_idx < args.seq_len:
                         seqlens.append(args.seq_len - start_idx)
                 document_seqlens.append(seqlens)
     else:
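
A minimal standalone sketch (not part of the patch) of the per-document sequence-length logic that the open_lm/train.py hunk above moves toward. The EOT token id, seq_len, and example inputs below are made up for illustration; in open_lm these come from SpecialTokens and args.

    # Sketch of the revised get_document_seqlens behavior, under assumed inputs.
    import torch

    EOT = 0       # assumed end-of-text token id (open_lm uses SpecialTokens.END_OF_TEXT.value)
    seq_len = 8   # assumed args.seq_len

    inputs = torch.tensor([
        [5, 7, EOT, 9, 4, EOT, 6, 2],  # two EOT tokens -> three documents
        [3, 1, 4, 1, 5, 9, 2, 6],      # no EOT token   -> one document
    ])

    document_seqlens = []
    for idx in range(inputs.shape[0]):
        eot_idx = torch.nonzero(inputs[idx] == EOT)
        if eot_idx.shape[0] == 0:
            # No EOT found: all tokens come from the same document.
            document_seqlens.append([seq_len])
        else:
            start_idx = 0
            seqlens = []
            for k in range(eot_idx.shape[0]):
                # Each document runs up to and including its EOT token.
                seqlens.append(eot_idx[k].item() - start_idx + 1)
                start_idx = eot_idx[k].item() + 1
            if start_idx < seq_len:
                # Trailing tokens after the last EOT form a final partial document.
                seqlens.append(seq_len - start_idx)
            document_seqlens.append(seqlens)

    print(document_seqlens)  # [[3, 3, 2], [8]]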