Use more realistic RoPE tests #1785

Merged (1 commit, Oct 9, 2024)
tests/test_rope.py: 17 changes (10 additions, 7 deletions)
@@ -110,7 +110,7 @@ def test_rope_llama_3():
 # See https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json for settings
 @torch.inference_mode()
 def test_rope_llama_3_1():
-    head_dim = 128
+    head_dim = 32
     rope_theta = 50_000
 
     their_rope_config = {
@@ -130,15 +130,16 @@ def test_rope_llama_3_1():
 
     config = LlamaConfig(
         rope_theta=rope_theta,
-        rope_scaling=their_rope_config
+        rope_scaling=their_rope_config,
+        head_dim=head_dim
     )
 
     ##################################
     # Compare cos and sin
     ##################################
     # transformer rope
     rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")
-    batch_size, seq_len = 1, 10
+    batch_size, seq_len = 1, 131_072
     qk_tensor = torch.randn(batch_size, seq_len, head_dim)
     position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
     theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
@@ -169,7 +170,7 @@ def test_rope_llama_3_1():
 # See https://huggingface.co/meta-llama/Llama-3.2-3B/blob/main/config.json for settings
 @torch.inference_mode()
 def test_rope_llama_3_2():
-    head_dim = 128
+    head_dim = 32
     rope_theta = 50_000
 
     their_rope_config = {
@@ -189,15 +190,16 @@ def test_rope_llama_3_2():
 
     config = LlamaConfig(
         rope_theta=rope_theta,
-        rope_scaling=their_rope_config
+        rope_scaling=their_rope_config,
+        head_dim=head_dim
     )
 
     ##################################
     # Compare cos and sin
     ##################################
     # transformer rope
     rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")
-    batch_size, seq_len = 1, 10
+    batch_size, seq_len = 1, 131_072
     qk_tensor = torch.randn(batch_size, seq_len, head_dim)
     position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
     theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)
@@ -222,4 +224,5 @@ def test_rope_llama_3_2():
     ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
     theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)
     torch.testing.assert_close(theirs_q_rot, ours_q_rot)
-    torch.testing.assert_close(theirs_k_rot, ours_k_rot)
+    torch.testing.assert_close(theirs_k_rot, ours_k_rot)
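For context, here is a minimal standalone sketch of the Hugging Face side of the updated comparison at the new settings (head_dim = 32, seq_len = 131_072). The LlamaConfig and LlamaRotaryEmbedding calls mirror the visible diff lines; the rope_scaling dict is collapsed in this view, so the values below are an assumption taken from the linked Llama-3.1 config.json, and the litgpt half of the test (ours_cos, ours_sin, apply_rope) is omitted because it is not shown in these hunks.

```python
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

head_dim = 32
rope_theta = 50_000

# Assumed rope_scaling values: the dict body is collapsed in this diff view,
# so these are taken from the linked Llama-3.1 config.json.
their_rope_config = {
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3",
}

config = LlamaConfig(
    rope_theta=rope_theta,
    rope_scaling=their_rope_config,
    head_dim=head_dim,
)

# Same call as in the test; the exact signature depends on the transformers
# version the test targets (dim + config is accepted in transitional releases).
rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")

# Full 131k-token position grid instead of the previous 10 positions.
batch_size, seq_len = 1, 131_072
qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)

print(theirs_cos.shape, theirs_sin.shape)  # (1, 131072, 32) each: (batch, seq_len, head_dim)
```

Presumably the point of the change is that 131_072 matches Llama 3.1's 128K-token context window, so the llama3 frequency scaling is actually exercised across its low- and high-frequency bands, while the smaller head_dim keeps the cos/sin and query/key tensors small enough for CI.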
