Fix generation with large sequences (#2561)
* fixed a bug in generation when the sequence is larger than 2048 tokens
* fixed the dynamic resizing of rotary embeddings
l-k-11235 authored Feb 22, 2024
1 parent 5deb20e commit 0e72326
1 changed file: onmt/modules/multi_headed_attn.py (31 additions, 30 deletions)
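
In substance the patch makes two changes: rotaryembeddings() now builds the fp16 cos/sin tables together with the complex rope table (optionally on a given device), and the cached-decoding path regrows those tables based on the running cache position (start_pos) with a 32-token margin, rather than on the per-step query length. A minimal sketch of that trigger change, with simplified names and the assumption that the query length is 1 per step during cached decoding (this is not the module's code):

    # Simplified sketch of the resize trigger; ROPE_LEN stands for rope.size(0).
    ROPE_LEN = 2048  # positions covered by the precomputed rotary table

    def old_trigger(seqlen: int) -> bool:
        # Keyed off the per-step query length; with a kv-cache this is
        # typically 1, so the table never grew past 2048 positions.
        return seqlen > ROPE_LEN

    def new_trigger(start_pos: int) -> bool:
        # Keyed off the running position, with a 32-token margin because the
        # kv-cache (and start_pos) advance in 32-token steps.
        return start_pos + 32 >= ROPE_LEN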
@@ -18,14 +18,18 @@
 # are both < 2048 tokens.


-def rotaryembeddings(dim: int, maxseqlen=2048, base=10000):
+def rotaryembeddings(dim: int, maxseqlen=2048, base=10000, device=None):
     inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
     tmax = torch.arange(maxseqlen, device=inv_freq.device)
     rope = torch.outer(tmax, inv_freq).float()
     # rope is now matrix [maxseqlen, dim/2]
     rope = torch.polar(torch.ones_like(rope), rope)
     rope = torch.cat((rope, rope), dim=1)
-    return rope
+    if device is not None:
+        rope = rope.to(device)
+    cos = rope[:, : rope.size(1) // 2].real.contiguous().half()
+    sin = rope[:, : rope.size(1) // 2].imag.contiguous().half()
+    return rope, cos, sin


 def rotate_half(x):
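
The helper now returns all three tables in one call. A hedged usage sketch, assuming the function is imported from this module and behaves as shown in the hunk above:

    import torch
    from onmt.modules.multi_headed_attn import rotaryembeddings

    rope, cos, sin = rotaryembeddings(
        64, maxseqlen=4096, base=10000, device=torch.device("cpu")
    )
    assert rope.shape == (4096, 64)              # complex table, halves duplicated
    assert cos.shape == sin.shape == (4096, 32)  # fp16 real/imag halves
    assert cos.dtype == torch.float16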
@@ -369,12 +373,8 @@ def __init__(
                 self.rotary_dim = self.dim_per_head
             else:
                 self.rotary_dim = rotary_dim
-            self.rope = rotaryembeddings(self.rotary_dim, base=rotary_theta)
-            self.cos = (
-                self.rope[:, : self.rope.size(1) // 2].real.contiguous().half()
-            )
-            self.sin = (
-                self.rope[:, : self.rope.size(1) // 2].imag.contiguous().half()
-            )
+            self.rope, self.cos, self.sin = rotaryembeddings(
+                self.rotary_dim, base=rotary_theta
+            )
             self.rotary_interleave = rotary_interleave
             self.rotary_theta = rotary_theta
@@ -465,11 +465,13 @@ def forward(
                 ):
                     if self.max_relative_positions == -1:  # Rotary Embeddings
                         if seqlen > self.rope.size(0):
-                            self.rope = rotaryembeddings(
+
+                            self.rope, _, _ = rotaryembeddings(
                                 self.rotary_dim,
                                 maxseqlen=(seqlen + 2048),
                                 base=self.rotary_theta,
-                            ).to(self.rope.device)
+                                device=self.rope.device,
+                            )
                         rope = self.rope[start_pos : start_pos + seqlen]
                         query, key = apply_rotary_emb(
                             query, key, rope, interleave=self.rotary_interleave
@@ -486,23 +488,6 @@ def forward(
                     self.layer_cache[1]["values"] = value

                 else:
-                    if self.max_relative_positions == -1:  # Rotary Embeddings
-                        if seqlen > self.rope.size(0):
-                            self.rope = rotaryembeddings(
-                                self.rotary_dim,
-                                maxseqlen=(seqlen + 2048),
-                                base=self.rotary_theta,
-                            ).to(self.rope.device)
-                            self.cos = (
-                                self.rope[:, : self.rope.size(1) // 2]
-                                .real.contiguous()
-                                .half()
-                            )
-                            self.sin = (
-                                self.rope[:, : self.rope.size(1) // 2]
-                                .imag.contiguous()
-                                .half()
-                            )
                     if start_pos >= self.layer_cache[1]["keys"].size(2):
                         self.layer_cache[1]["keys"] = torch.cat(
                             [
@@ -528,6 +513,20 @@
                             ],
                             dim=-2,
                         )
+                        if (
+                            self.max_relative_positions == -1
+                            and start_pos + 32 >= self.rope.size(0)
+                        ):
+                            # Resize rotary embeddings.
+                            # We take a margin of 32 tokens as the kv_cache
+                            # is incremented by 32 tokens every 32 tokens.
+                            self.rope, self.cos, self.sin = rotaryembeddings(
+                                self.rotary_dim,
+                                maxseqlen=(start_pos + 2048),
+                                base=self.rotary_theta,
+                                device=self.rope.device,
+                            )
+
                     if sliding_window > 0 and key.size(2) > sliding_window:
                         self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][
                             :, :, 1:, :
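
The added guard regrows the table before the position pointer can run past it. An illustrative walk through the margin arithmetic, using the 32-token step described in the comment above (standalone sketch, not the module's code):

    # Walk the running position in 32-token steps, as the kv_cache comment
    # describes, regrowing with 2048 positions of headroom when the margin fires.
    CHUNK = 32
    rope_len = 2048

    for start_pos in range(0, 6000, CHUNK):
        if start_pos + CHUNK >= rope_len:
            rope_len = start_pos + 2048       # mirrors maxseqlen=(start_pos + 2048)
        assert start_pos + CHUNK <= rope_len  # table always covers the next chunk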
@@ -593,12 +592,14 @@ def forward(
                 start_pos = 0
                 seqlen = query.size(2)
                 if seqlen > self.rope.size(0):
-                    self.rope = rotaryembeddings(
+                    # Resize rotary embeddings.
+                    self.rope, self.cos, self.sin = rotaryembeddings(
                         self.rotary_dim,
                         maxseqlen=(seqlen + 2048),
                         base=self.rotary_theta,
-                    ).to(self.rope.device)
-                rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
+                        device=query.device,
+                    )
+                rope = self.rope[start_pos : start_pos + seqlen]
                 query, key = apply_rotary_emb(
                     query, key, rope, interleave=self.rotary_interleave
                 )
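
Both code paths end in apply_rotary_emb(query, key, rope, interleave=...). For orientation only, here is a generic rotate-half formulation of rotary position embedding; it is a standalone illustration, not onmt's apply_rotary_emb, which consumes the complex rope table sliced above:

    import torch

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rope(x, cos, sin):
        # x: [batch, heads, seqlen, dim]; cos/sin: [seqlen, dim/2], duplicated to
        # the full head dimension and broadcast over batch and heads.
        cos = torch.cat((cos, cos), dim=-1)
        sin = torch.cat((sin, sin), dim=-1)
        return x * cos + rotate_half(x) * sin

    dim, seqlen = 64, 8
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(torch.arange(seqlen).float(), inv_freq)  # [seqlen, dim/2]
    q = torch.randn(1, 4, seqlen, dim)
    q_rot = apply_rope(q, angles.cos(), angles.sin())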