Skip to content

Commit

Permalink
fixed the dynamic resizing of rotary embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
l-k-11235 committed Feb 13, 2024
1 parent 54514b8 commit b7e8b8f
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions onmt/modules/multi_headed_attn.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@
# are both < 2048 tokens.


-def rotaryembeddings(dim: int, maxseqlen=4096, base=10000):
+def rotaryembeddings(dim: int, maxseqlen=2048, base=10000):
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
tmax = torch.arange(maxseqlen, device=inv_freq.device)
rope = torch.outer(tmax, inv_freq).float()
@@ -467,7 +467,7 @@ def forward(
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
self.rotary_dim,
-maxseqlen=(seqlen + 4096),
+maxseqlen=(seqlen + 2048),
base=self.rotary_theta,
).to(self.rope.device)
rope = self.rope[start_pos : start_pos + seqlen]
@@ -486,23 +486,6 @@ def forward(
self.layer_cache[1]["values"] = value

else:
-if self.max_relative_positions == -1: # Rotary Embeddings
-if seqlen > self.rope.size(0):
-self.rope = rotaryembeddings(
-self.rotary_dim,
-maxseqlen=(seqlen + 4096),
-base=self.rotary_theta,
-).to(self.rope.device)
-self.cos = (
-self.rope[:, : self.rope.size(1) // 2]
-.real.contiguous()
-.half()
-)
-self.sin = (
-self.rope[:, : self.rope.size(1) // 2]
-.imag.contiguous()
-.half()
-)
if start_pos >= self.layer_cache[1]["keys"].size(2):
self.layer_cache[1]["keys"] = torch.cat(
[
@@ -528,6 +511,24 @@ def forward(
],
dim=-2,
)
+seqlen = self.layer_cache[1]["keys"].size(2)
+if self.max_relative_positions == -1: # Rotary Embeddings
+if seqlen > self.rope.size(0):
+self.rope = rotaryembeddings(
+self.rotary_dim,
+maxseqlen=(seqlen + 2048),
+base=self.rotary_theta,
+).to(self.rope.device)
+self.cos = (
+self.rope[:, : self.rope.size(1) // 2]
+.real.contiguous()
+.half()
+)
+self.sin = (
+self.rope[:, : self.rope.size(1) // 2]
+.imag.contiguous()
+.half()
+)
if sliding_window > 0 and key.size(2) > sliding_window:
self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][
:, :, 1:, :
Expand Down

0 comments on commit b7e8b8f

Please sign in to comment.