Improve rope #1745

Merged (15 commits) on Oct 4, 2024
8 changes: 7 additions & 1 deletion litgpt/config.py
@@ -62,6 +62,7 @@ class Config:
intermediate_size: Optional[int] = None
rope_condense_ratio: int = 1
rope_base: int = 10000
rope_adjustments: Optional[dict] = None
n_expert: int = 0
n_expert_per_token: int = 0
attention_logit_softcapping: Optional[float] = None
@@ -893,6 +894,7 @@ def norm_class(self) -> Type:
mlp_class_name="LLaMAMLP",
intermediate_size=14336,
rope_base=500000,
rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
),
# https://huggingface.co/meta-llama/Meta-Llama-3-70B/blob/main/config.json
dict(
@@ -931,6 +933,7 @@ def norm_class(self) -> Type:
mlp_class_name="LLaMAMLP",
intermediate_size=28672,
rope_base=500000,
rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
),
# https://huggingface.co/meta-llama/Meta-Llama-3.1-405B/blob/main/config.json
dict(
@@ -950,8 +953,9 @@ def norm_class(self) -> Type:
mlp_class_name="LLaMAMLP",
intermediate_size=53248,
rope_base=500000,
rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
),
# https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json
# https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/config.json
dict(
name="Llama-3.2-1B{}",
hf_config=dict(org="meta-llama", name="Llama-3.2-1B{}"),
@@ -969,6 +973,7 @@ def norm_class(self) -> Type:
mlp_class_name="LLaMAMLP",
intermediate_size=8192,
rope_base=500000,
rope_adjustments=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
),
# https://huggingface.co/meta-llama/Llama-3.2-3B/blob/main/config.json
dict(
@@ -988,6 +993,7 @@ def norm_class(self) -> Type:
mlp_class_name="LLaMAMLP",
intermediate_size=8192,
rope_base=500000,
rope_adjustments=dict(factor=32.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
),
]
for c in llama_3:
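
To illustrate how the new field surfaces, a minimal sketch (assuming litgpt is importable and that the templated names above resolve through Config.from_name, e.g. "Llama-3.2-1B"):

    from litgpt.config import Config

    config = Config.from_name("Llama-3.2-1B")
    print(config.rope_base)         # 500000, per the entry above
    print(config.rope_adjustments)  # {'factor': 32.0, 'low_freq_factor': 1.0, 'high_freq_factor': 4.0, 'original_max_seq_len': 8192}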
81 changes: 79 additions & 2 deletions litgpt/model.py
@@ -108,12 +108,39 @@ def from_name(cls, name: str, **kwargs: Any) -> Self:
return cls(Config.from_name(name, **kwargs))

def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]:

if self.config.rope_adjustments is None:
extra_config = None

else:
adjusted_params_required = ["factor", "low_freq_factor", "high_freq_factor", "original_max_seq_len"]
params_present = [param in self.config.rope_adjustments for param in adjusted_params_required]
num_params_present = sum(params_present)

if num_params_present == 0:
extra_config = None # uses standard RoPE
elif num_params_present == 4:
# These parameters should always be used together so that we don't interfere with standard rope
extra_config = {
"original_max_seq_len": self.config.rope_adjustments["original_max_seq_len"],
"factor": self.config.rope_adjustments["factor"],
"low_freq_factor": self.config.rope_adjustments["low_freq_factor"],
"high_freq_factor": self.config.rope_adjustments["high_freq_factor"],
}
else:
# Some but not all parameters are specified; raise an error
raise ValueError(
"The following adjusted RoPE parameters are missing in rope_adjustments."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you forgot to add the list of missing parameters in the error message.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! I am addressing it here: #1781

"All adjusted RoPE parameters must be specified together."
)
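
For illustration only, a hypothetical sketch of how the missing keys could be listed in the message (the actual follow-up is in #1781 and may differ):

    missing = [param for param, present in zip(adjusted_params_required, params_present) if not present]
    raise ValueError(
        f"The following adjusted RoPE parameters are missing in rope_adjustments: {', '.join(missing)}. "
        "All adjusted RoPE parameters must be specified together."
    )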

return build_rope_cache(
seq_len=self.max_seq_length,
n_elem=self.config.rope_n_elem,
device=device,
condense_ratio=self.config.rope_condense_ratio,
base=self.config.rope_base,
extra_config=extra_config,
)

def set_kv_cache(
@@ -410,17 +437,67 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


def build_rope_cache(
seq_len: int, n_elem: int, device: Optional[torch.device] = None, base: int = 10000, condense_ratio: int = 1
seq_len: int,
n_elem: int,
device: Optional[torch.device] = None,
base: int = 10000,
condense_ratio: int = 1,
extra_config: Optional[dict] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Enhanced Transformer with Rotary Position Embedding.

Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
transformers/rope/__init__.py. MIT License:
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.

Args:
seq_len (int): Sequence length.
n_elem (int): Number of elements (head dimension).
device (torch.device, optional): Device for tensor allocations.
base (int, optional): Base for computing inverse frequencies.
condense_ratio (int, optional): Ratio to condense the position indices.
extra_config (dict, optional): Configuration parameters for frequency adjustments (used by Llama 3.1 and 3.2)

Returns:
Tuple[torch.Tensor, torch.Tensor]: Cosine and sine caches for RoPE.
"""
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
assert n_elem % 2 == 0, "n_elem (head dimension) must be even"
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))

if extra_config is not None:
# Extract configuration parameters
orig_context_len = extra_config["original_max_seq_len"]
factor = extra_config["factor"]
low_freq_factor = extra_config["low_freq_factor"]
high_freq_factor = extra_config["high_freq_factor"]

# Compute wavelength thresholds
low_freq_wavelen = orig_context_len / low_freq_factor
high_freq_wavelen = orig_context_len / high_freq_factor

# Compute wavelengths corresponding to the inverse frequencies
wavelen = 2 * torch.pi / theta

# Initialize adjusted inverse frequencies
adjusted_theta = theta.clone()

# Low Frequency Region: wavelen > low_freq_wavelen
mask_low_freq = wavelen > low_freq_wavelen
adjusted_theta[mask_low_freq] = theta[mask_low_freq] / factor

# Medium Frequency Region: high_freq_wavelen ≤ wavelen ≤ low_freq_wavelen
mask_medium_freq = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
# Compute smooth factor for medium frequencies
ratio = orig_context_len / wavelen[mask_medium_freq]
smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor)
# Interpolate inverse frequencies
adjusted_theta[mask_medium_freq] = (
(1 - smooth_factor) * (theta[mask_medium_freq] / factor)
+ smooth_factor * theta[mask_medium_freq]
)
theta = adjusted_theta

# Create position indexes `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, device=device) / condense_ratio

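
As a concrete reading of the adjustment above for the Llama 3.1 settings (factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192): low_freq_wavelen = 8192 / 1.0 = 8192 and high_freq_wavelen = 8192 / 4.0 = 2048, so inverse frequencies with wavelength above 8192 are divided by the factor 8, those with wavelength below 2048 are left unchanged, and those in between are blended via the smooth factor. A minimal usage sketch, relying only on the signature shown above (the printed cache shape is an assumption based on a (seq_len, n_elem) layout):

    from litgpt.model import build_rope_cache

    cos, sin = build_rope_cache(
        seq_len=16,
        n_elem=128,
        base=500_000,
        extra_config=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192),
    )
    print(cos.shape, sin.shape)  # assumed: torch.Size([16, 128]) each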
205 changes: 202 additions & 3 deletions tests/test_rope.py
@@ -1,13 +1,17 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import torch
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding, apply_rotary_pos_emb
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXRotaryEmbedding
from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb as apply_rotary_pos_emb_gptneo
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb as apply_rotary_pos_emb_llama
from transformers.models.llama.configuration_llama import LlamaConfig

from litgpt.model import apply_rope, build_rope_cache


@torch.inference_mode()
def test_rope():
def test_rope_gptneox():
bs, seq_len, n_head, n_embed = 1, 6, 2, 8
head_size = n_embed // n_head
x = torch.randint(0, 10000, size=(bs, n_head, seq_len, head_size)).float()
@@ -22,5 +26,200 @@ def test_rope():
torch.testing.assert_close(ours_sin_cached, theirs_sin.squeeze())

ours_x_rope = apply_rope(x, ours_cos_cached, ours_sin_cached)
theirs_x_rope, _ = apply_rotary_pos_emb(x, x, theirs_cos, theirs_sin, position_ids)
theirs_x_rope, _ = apply_rotary_pos_emb_gptneo(x, x, theirs_cos, theirs_sin, position_ids)
torch.testing.assert_close(ours_x_rope, theirs_x_rope)


@torch.inference_mode()
def test_rope_llama_2():
head_dim = 64
rope_theta = 10_000

##################################
# Compare cos and sin
##################################
# transformer rope
rot_emb = LlamaRotaryEmbedding(head_dim, scaling_factor=None, base=rope_theta)
batch_size, seq_len = 1, 10
qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)

# our rope
ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta)
torch.testing.assert_close(theirs_cos.squeeze(0), ours_cos)
torch.testing.assert_close(theirs_sin.squeeze(0), ours_sin)

##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4

# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
keys = torch.randn(batch_size, num_heads, seq_len, head_dim)

ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)
torch.testing.assert_close(theirs_q_rot, ours_q_rot)
torch.testing.assert_close(theirs_k_rot, ours_k_rot)


# See https://huggingface.co/meta-llama/Meta-Llama-3-8B/blob/main/config.json for settings
@torch.inference_mode()
def test_rope_llama_3():
head_dim = 64
rope_theta = 50_000

##################################
# Compare cos and sin
##################################
# transformer rope
rot_emb = LlamaRotaryEmbedding(head_dim, scaling_factor=None, base=rope_theta)
batch_size, seq_len = 1, 10
qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)

# our rope
ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta)
torch.testing.assert_close(theirs_cos.squeeze(0), ours_cos)
torch.testing.assert_close(theirs_sin.squeeze(0), ours_sin)

##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4

# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
keys = torch.randn(batch_size, num_heads, seq_len, head_dim)

ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)
torch.testing.assert_close(theirs_q_rot, ours_q_rot)
torch.testing.assert_close(theirs_k_rot, ours_k_rot)


# See https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json for settings
@torch.inference_mode()
def test_rope_llama_3_1():
head_dim = 128
rope_theta = 50_000

their_rope_config = {
"factor": 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
}

our_rope_config = {
"factor": 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_seq_len": 8192
}

config = LlamaConfig(
rope_theta=rope_theta,
rope_scaling=their_rope_config
)

##################################
# Compare cos and sin
##################################
# transformer rope
rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")
batch_size, seq_len = 1, 10
qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)

# our rope
ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config)
torch.testing.assert_close(theirs_cos.squeeze(0), ours_cos)
torch.testing.assert_close(theirs_sin.squeeze(0), ours_sin)

##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4

# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
keys = torch.randn(batch_size, num_heads, seq_len, head_dim)

ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)
torch.testing.assert_close(theirs_q_rot, ours_q_rot)
torch.testing.assert_close(theirs_k_rot, ours_k_rot)


# See https://huggingface.co/meta-llama/Llama-3.2-3B/blob/main/config.json for settings
@torch.inference_mode()
def test_rope_llama_3_2():
head_dim = 128
rope_theta = 50_000

their_rope_config = {
"factor": 32.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3"
}

our_rope_config = {
"factor": 32.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_seq_len": 8192
}

config = LlamaConfig(
rope_theta=rope_theta,
rope_scaling=their_rope_config
)

##################################
# Compare cos and sin
##################################
# transformer rope
rot_emb = LlamaRotaryEmbedding(head_dim, base=rope_theta, config=config, rope_type="llama3")
batch_size, seq_len = 1, 10
qk_tensor = torch.randn(batch_size, seq_len, head_dim)
position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)
theirs_cos, theirs_sin = rot_emb(qk_tensor, position_ids)

# our rope
ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config)
torch.testing.assert_close(theirs_cos.squeeze(0), ours_cos)
torch.testing.assert_close(theirs_sin.squeeze(0), ours_sin)

##################################
# Compare rotated tensors
##################################
# Settings
num_heads = 4

# Dummy query and key tensors
torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
keys = torch.randn(batch_size, num_heads, seq_len, head_dim)

ours_q_rot = apply_rope(queries, ours_cos, ours_sin)
ours_k_rot = apply_rope(keys, ours_cos, ours_sin)
theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin)
torch.testing.assert_close(theirs_q_rot, ours_q_rot)
torch.testing.assert_close(theirs_k_rot, ours_k_rot)
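
Assuming a standard pytest setup with transformers installed, the new Llama comparisons can be run selectively, for example:

    pytest tests/test_rope.py -k "llama_3_1 or llama_3_2"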