Skip to content

Commit

Permalink
refactored dynamic resizing of rotary embeddings
Browse files: browse the repository at this point in the history
  • Loading branch information
l-k-11235 committed Feb 21, 2024
1 parent b7e8b8f commit 23c94bb
Showing 1 changed file with 20 additions and 26 deletions.
46 changes: 20 additions & 26 deletions onmt/modules/multi_headed_attn.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -25,7 +25,9 @@ def rotaryembeddings(dim: int, maxseqlen=2048, base=10000):
# rope is now matrix [maxseqlen, dim/2]
rope = torch.polar(torch.ones_like(rope), rope)
rope = torch.cat((rope, rope), dim=1)
return rope
cos = rope[:, : rope.size(1) // 2].real.contiguous().half()
sin = rope[:, : rope.size(1) // 2].imag.contiguous().half()
return rope, cos, sin


def rotate_half(x):
Expand Down Expand Up @@ -369,12 +371,8 @@ def __init__(
self.rotary_dim = self.dim_per_head
else:
self.rotary_dim = rotary_dim
self.rope = rotaryembeddings(self.rotary_dim, base=rotary_theta)
self.cos = (
self.rope[:, : self.rope.size(1) // 2].real.contiguous().half()
)
self.sin = (
self.rope[:, : self.rope.size(1) // 2].imag.contiguous().half()
self.rope, self.cos, self.sin = rotaryembeddings(
self.rotary_dim, base=rotary_theta
)
self.rotary_interleave = rotary_interleave
self.rotary_theta = rotary_theta
Expand Down Expand Up @@ -465,11 +463,13 @@ def forward(
):
if self.max_relative_positions == -1: # Rotary Embeddings
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(

self.rope, self.cos, self.sin = rotaryembeddings(
self.rotary_dim,
maxseqlen=(seqlen + 2048),
base=self.rotary_theta,
).to(self.rope.device)
)
self.rope = self.rope.to(self.rope.device)
rope = self.rope[start_pos : start_pos + seqlen]
query, key = apply_rotary_emb(
query, key, rope, interleave=self.rotary_interleave
Expand Down Expand Up @@ -511,31 +511,24 @@ def forward(
],
dim=-2,
)
seqlen = self.layer_cache[1]["keys"].size(2)
if self.max_relative_positions == -1: # Rotary Embeddings
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
if self.max_relative_positions == -1: # Rotary
# Resize rotary embeddings.
self.rope, self.cos, self.sin = rotaryembeddings(
self.rotary_dim,
maxseqlen=(seqlen + 2048),
maxseqlen=(start_pos + 2048),
base=self.rotary_theta,
).to(self.rope.device)
self.cos = (
self.rope[:, : self.rope.size(1) // 2]
.real.contiguous()
.half()
)
self.sin = (
self.rope[:, : self.rope.size(1) // 2]
.imag.contiguous()
.half()
)
self.rope = self.rope.to(self.rope.device)

if sliding_window > 0 and key.size(2) > sliding_window:
self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][
:, :, 1:, :
]
self.layer_cache[1]["values"] = self.layer_cache[1]["values"][
:, :, 1:, :
]
self.cos = self.cos.to(query.device)
self.sin = self.sin.to(query.device)
context = self.flash_attn_with_kvcache(
query.transpose(1, 2),
self.layer_cache[1]["keys"].transpose(1, 2),
Expand Down Expand Up @@ -594,11 +587,12 @@ def forward(
start_pos = 0
seqlen = query.size(2)
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
# Resize rotary embeddings.
self.rope, self.cos, self.sin = rotaryembeddings(
self.rotary_dim,
maxseqlen=(seqlen + 2048),
base=self.rotary_theta,
).to(self.rope.device)
)
rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
query, key = apply_rotary_emb(
query, key, rope, interleave=self.rotary_interleave
Expand Down

0 comments on commit 23c94bb

Please sign in to comment.