Fix the Gemma generation (#1094)
solitude-alive authored Jun 18, 2024
1 parent 66d1a9c commit ef6e196
Showing 1 changed file with 24 additions and 6 deletions.
torchtune/models/gemma/transformer.py: 30 changes (24 additions & 6 deletions)
@@ -17,7 +17,7 @@

class GemmaTransformerDecoder(nn.Module):
"""
- Transformer Decoder derived from the Gemma architecture. A key difference between
+ GemmaTransformer Decoder derived from Gemma architecture. A key difference between
the Gemma transformer decoder and :class:`~torchtune.modules.TransformerDecoder`
is that the output projection is replaced instead with a reverse projection
using the transposed token embedding weights from output dim to input dim
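
The "reverse projection" mentioned in the docstring replaces a separate output head with the transposed token-embedding weights. A minimal sketch of that weight-tying pattern, with purely illustrative names and dimensions (none of this comes from the commit):

```python
import torch
import torch.nn as nn

# Illustrative sizes only.
vocab_size, embed_dim = 1000, 64
tok_embeddings = nn.Embedding(vocab_size, embed_dim)

h = torch.randn(2, 8, embed_dim)                    # hidden states [b, s, d]
logits = torch.matmul(h, tok_embeddings.weight.T)   # reverse projection -> [b, s, v]
print(logits.shape)                                 # torch.Size([2, 8, 1000])
```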
@@ -68,6 +68,12 @@ def __init__(
self.norm_embeddings = norm_embeddings

def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None:
"""Setup key value caches for attention calculation.
Args:
batch_size (int): batch size for the caches.
dtype (torch.dtype): dtype for the caches.
"""
for layer in self.layers:
layer.attn.kv_cache = KVCache(
batch_size=batch_size,
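
As a hedged usage sketch (not part of this diff): ``setup_caches`` would typically be called once before incremental decoding so that every layer gets a ``KVCache``. The ``gemma_2b`` builder and the exact call pattern below are assumptions for illustration:

```python
import torch
from torchtune.models.gemma import gemma_2b  # assumed builder returning a GemmaTransformerDecoder

model = gemma_2b()
model.eval()

# Allocate per-layer key/value caches so forward() can decode one token at a time.
model.setup_caches(batch_size=1, dtype=torch.float32)
```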
@@ -93,10 +99,21 @@ def forward(
"""
Args:
tokens (Tensor): input tensor with shape [b x s]
- mask (Optional[Tensor]): Optional tensor which contains the attention mask.
- Default is None
- input_pos (Optional[Tensor]): Optional tensor which contains the position
- of the current token. This is only used during inference. Default is None
+ mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask
+ with shape [b x s x s]. This is applied after the query-key multiplication and
+ before the softmax. A value of True in row i and column j means token i attends
+ to token j. A value of False means token i does not attend to token j. If no
+ mask is specified, a causal mask is used by default. Default is None.
+ input_pos (Optional[Tensor]): Optional tensor which contains the position ids
+ of each token. During training, this is used to indicate the positions
+ of each token relative to its sample when packed, shape [b x s].
+ During inference, this indicates the position of the current token.
+ If none, assume the index of the token is its position id. Default is None.
+ Note: At the very first step of inference, when the model is provided with a prompt,
+ ``input_pos`` would contain the positions of all of the tokens in the prompt
+ (eg: ``torch.arange(prompt_length)``). This is because we will need to compute the
+ KV values for each position.
Returns:
Tensor: output tensor with shape [b x s x v]
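
A small sketch of the inputs described by the added docstring: a boolean [b x s x s] causal mask where True means token i attends to token j, and per-token position ids. Everything below is illustrative and not taken from the commit:

```python
import torch

b, s = 2, 8
tokens = torch.randint(0, 100, (b, s))

# Causal mask: True on and below the diagonal, broadcast over the batch.
mask = torch.tril(torch.ones(s, s, dtype=torch.bool)).expand(b, s, s)

# Without packing, each token's position id is simply its index in the sequence.
input_pos = torch.arange(s).unsqueeze(0).expand(b, s)

# At the first step of inference with a prompt, input_pos would instead be
# torch.arange(prompt_length), as the docstring notes, so KV values are
# computed for every prompt position.
```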
@@ -109,6 +126,7 @@ def forward(
- s: sequence length
- v: vocab size
- d: embed dim
+ - m_s: max seq len
"""
# input tensor of shape [b, s]
bsz, seq_len = tokens.shape
@@ -127,7 +145,7 @@
)
# shape: [1, input_pos_len, m_s]
# in most cases input_pos_len should be 1
- mask = self.causal_mask[None, None, input_pos]
+ mask = self.causal_mask[None, input_pos]

if self.norm_embeddings:
hidden_dim = h.size(-1)
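
To see why dropping the extra ``None`` matters, here is an illustrative shape check; ``causal_mask`` below stands in for ``self.causal_mask``, assumed to be a [max_seq_len, max_seq_len] boolean buffer:

```python
import torch

max_seq_len = 16
causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))

# Decoding a single token at position 5 during inference.
input_pos = torch.tensor([5])

old_mask = causal_mask[None, None, input_pos]  # [1, 1, 1, 16] -- 4-D
new_mask = causal_mask[None, input_pos]        # [1, 1, 16]    -- [1, input_pos_len, m_s]
print(old_mask.shape, new_mask.shape)
```

The attention layers expect the 3-D [b x s x s]-style mask described in the docstring, so indexing with a single ``None`` yields the intended [1, input_pos_len, m_s] mask during generation.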
