use flash_attn_with_kvcache for faster inference #2539

Merged · 11 commits · Dec 26, 2023
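For readers unfamiliar with the kernel this PR switches to: below is a minimal, illustrative sketch of a single decode step with flash_attn_with_kvcache, not taken from the PR itself. It assumes flash-attn >= 2.2, an Ampere-or-newer GPU, fp16 tensors in (batch, seqlen, heads, head_dim) layout, and made-up sizes; the PR additionally passes rotary_cos/rotary_sin so the kernel applies rotary embeddings on the fly.

import torch
from flash_attn import flash_attn_with_kvcache

batch, heads, head_dim, max_cache = 2, 8, 64, 256
device, dtype = "cuda", torch.float16

# Preallocated KV cache; the kernel writes the new key/value into it
# in place at position cache_seqlens, so there is no torch.cat per step.
k_cache = torch.zeros(batch, max_cache, heads, head_dim, device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)

step = 5  # number of tokens already decoded
q = torch.randn(batch, 1, heads, head_dim, device=device, dtype=dtype)
k, v = torch.randn_like(q), torch.randn_like(q)

out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k, v,
    cache_seqlens=step,  # where the new token lands in the cache
    causal=True,
)
# out: (batch, 1, heads, head_dim), attention over the cached prefix plus the new token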
20 changes: 15 additions & 5 deletions onmt/decoders/transformer.py
@@ -11,11 +11,7 @@
from onmt.modules.position_ffn import ActivationFunction
from onmt.modules.moe import MoE
from onmt.utils.misc import sequence_mask

try:
from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
from onmt.modules.rmsnorm import RMSNorm
from onmt.modules.rmsnorm import RMSNorm


class TransformerDecoderLayerBase(nn.Module):
@@ -44,6 +40,7 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
rotary_theta=1e4,
num_experts=0,
num_experts_per_tok=2,
):
@@ -89,6 +86,7 @@ def __init__(
sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
rotary_interleave (bool): Interleave the head dimensions when rotary
embeddings are applied
rotary_theta (int): rotary base theta
"""
super(TransformerDecoderLayerBase, self).__init__()

@@ -100,6 +98,7 @@ def __init__(
max_relative_positions=max_relative_positions,
relative_positions_buckets=relative_positions_buckets,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
attn_type="self",
self_attn_type=self_attn_type,
add_qkvbias=add_qkvbias,
@@ -280,6 +279,7 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
rotary_theta=1e4,
num_experts=0,
num_experts_per_tok=2,
):
@@ -311,6 +311,7 @@ def __init__(
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
@@ -473,6 +474,7 @@ def from_opt(cls, opt, embeddings):
else 1,
sliding_window=opt.sliding_window,
rotary_interleave=opt.rotary_interleave,
rotary_theta=opt.rotary_theta,
num_experts=opt.num_experts,
num_experts_per_tok=opt.num_experts_per_tok,
)
@@ -563,6 +565,7 @@ class TransformerDecoder(TransformerDecoderBase):
parallel_gpu (int): Number of gpu for tensor parallelism
sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
rotary_theta (int): rotary base theta
"""

def __init__(
@@ -594,6 +597,7 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
rotary_theta=1e4,
num_experts=0,
num_experts_per_tok=2,
):
@@ -627,6 +631,7 @@ def __init__(
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
@@ -834,6 +839,7 @@ class TransformerLMDecoder(TransformerDecoderBase):
parallel_gpu (int): Number of gpu for tensor parallelism
sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
rotary_theta (int): rotary base theta
"""

def __init__(
@@ -865,6 +871,7 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
rotary_theta=1e4,
num_experts=0,
num_experts_per_tok=2,
):
@@ -897,6 +904,7 @@ def __init__(
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
@@ -976,3 +984,5 @@ def _init_cache(self, tgt=None):
)
if hasattr(layer.self_attn, "rope"):
layer.self_attn.rope = layer.self_attn.rope.to(tgt.device)
layer.self_attn.cos = layer.self_attn.cos.to(tgt.device)
layer.self_attn.sin = layer.self_attn.sin.to(tgt.device)
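The new rotary_theta option threads the RoPE base through the decoder layers down to multi-headed attention. As a quick refresher, an illustrative sketch (not the repo's rotaryembeddings helper): the base sets the inverse frequencies 1 / theta^(2i/d), and the resulting cos/sin tables are what the patched attention caches as self.cos / self.sin for flash_attn_with_kvcache.

import torch

def rotary_cos_sin(head_dim, maxseqlen=2048, theta=1e4):
    # Inverse frequencies 1 / theta^(2i/d) for i in [0, d/2).
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(maxseqlen).float()
    angles = torch.outer(positions, inv_freq)        # (maxseqlen, head_dim // 2)
    return angles.cos().half(), angles.sin().half()  # fp16, as the flash kernel expects

cos, sin = rotary_cos_sin(64, theta=1e6)  # a larger base, e.g. for long-context models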
6 changes: 6 additions & 0 deletions onmt/encoders/transformer.py
@@ -40,6 +40,7 @@ class TransformerEncoderLayer(nn.Module):
parallel_gpu (int): Number of gpu for tensor parallelism
rotary_interleave (bool): Interleave the head dimensions when rotary
embeddings are applied
rotary_theta (int): rotary base theta
"""

def __init__(
@@ -61,6 +62,7 @@ def __init__(
use_ckpting=[],
parallel_gpu=1,
rotary_interleave=True,
rotary_theta=1e4,
):
super(TransformerEncoderLayer, self).__init__()

@@ -72,6 +74,7 @@ def __init__(
max_relative_positions=max_relative_positions,
relative_positions_buckets=relative_positions_buckets,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
attn_type="self",
add_qkvbias=add_qkvbias,
num_kv=num_kv,
@@ -177,6 +180,7 @@ def __init__(
use_ckpting=[],
parallel_gpu=1,
rotary_interleave=True,
rotary_theta=1e4,
):
super(TransformerEncoder, self).__init__()

@@ -201,6 +205,7 @@ def __init__(
use_ckpting=use_ckpting,
parallel_gpu=parallel_gpu,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
)
for i in range(num_layers)
]
@@ -239,6 +244,7 @@ def from_opt(cls, opt, embeddings):
if opt.parallel_mode == "tensor_parallel"
else 1,
rotary_interleave=opt.rotary_interleave,
rotary_theta=opt.rotary_theta,
)

def forward(self, src, src_len=None):
4 changes: 4 additions & 0 deletions onmt/modules/moe.py
@@ -2,6 +2,7 @@
import torch
import torch.nn as nn
from onmt.modules.position_ffn import PositionwiseFeedForward
from torch.distributed import all_reduce


class MoE(nn.Module):
@@ -40,12 +41,15 @@ def __init__(
)
self.gate = nn.Linear(d_model, num_experts, bias=False)
self.num_experts_per_tok = num_experts_per_tok
self.parallel_gpu = parallel_gpu

def forward(self, x):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])

scores = self.gate(x)
if self.parallel_gpu > 1:
all_reduce(scores)
expert_weights, expert_indices = torch.topk(
scores, self.num_experts_per_tok, dim=-1
)
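Why the all_reduce on the gate scores: under tensor parallelism each rank apparently holds only a partial result of the gate projection, so the scores are summed across ranks before top-k, keeping every rank routing tokens to the same experts. A single-process sketch of that routing path (names and the final weighting are illustrative, not the repo's exact code):

import torch
import torch.nn as nn
import torch.nn.functional as F

def route_tokens(x, gate: nn.Linear, num_experts_per_tok=2, parallel_gpu=1):
    scores = gate(x)  # (tokens, num_experts); partial per rank when the gate is sharded
    if parallel_gpu > 1:
        # As in the diff: sum the partial scores across tensor-parallel ranks.
        torch.distributed.all_reduce(scores)
    weights, indices = torch.topk(scores, num_experts_per_tok, dim=-1)
    # A typical normalization of the selected scores; the exact weighting
    # used downstream is not shown in this hunk.
    weights = F.softmax(weights, dim=-1, dtype=torch.float).to(x.dtype)
    return weights, indices

w, idx = route_tokens(torch.randn(4, 16), nn.Linear(16, 8, bias=False))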
136 changes: 114 additions & 22 deletions onmt/modules/multi_headed_attn.py
@@ -11,7 +11,6 @@
from torch.distributed import all_reduce
from importlib import import_module


# Help functions for Rotary Embeddings
# https://arxiv.org/pdf/2104.09864.pdf
# too convoluted to make maxseqlen a parameter.
@@ -258,6 +257,7 @@ def __init__(
max_relative_positions: int = 0,
relative_positions_buckets: int = 0,
rotary_interleave: bool = True,
rotary_theta: int = 1e4,
attn_type: str = None,
self_attn_type: str = None,
add_qkvbias=False,
@@ -352,9 +352,19 @@ def __init__(
self.relative_attention_bias = None

if max_relative_positions == -1: # rotary embeddings
self.rope = rotaryembeddings(self.dim_per_head)
self.rope = rotaryembeddings(self.dim_per_head, base=rotary_theta)
self.cos = (
self.rope[:, : self.rope.size(1) // 2].real.contiguous().half()
)
self.sin = (
self.rope[:, : self.rope.size(1) // 2].imag.contiguous().half()
)
self.rotary_interleave = rotary_interleave

self.rotary_theta = rotary_theta
else:
self.cos = None
self.sin = None
self.rotary_interleave = None
if max_relative_positions == -2: # alibi positional bias
self.alibi = AlibiPositionalBias(head_count)

@@ -367,6 +377,9 @@ def __init__(
and torch.cuda.get_device_capability()[0] >= 8
):
self.flash_attn_func = getattr(flash_pack, "flash_attn_func")
self.flash_attn_with_kvcache = getattr(
flash_pack, "flash_attn_with_kvcache"
)
self.flash2 = True
except ImportError:
self.flash2 = False
@@ -420,27 +433,104 @@ def forward(
key = shape(key, self.dim_per_head)
value = shape(value, self.dim_per_head)

if self.max_relative_positions == -1: # Rotary Embeddings
start_pos = step
seqlen = query.size(2)
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
self.dim_per_head, maxseqlen=(seqlen + 2048)
).to(self.rope.device)
rope = self.rope[start_pos : start_pos + seqlen]
query, key = apply_rotary_emb(
query, key, rope, interleave=self.rotary_interleave
)
start_pos = step
seqlen = query.size(2)

if (
step == 0
or not self.flash2
or self.max_relative_positions not in [0, -1]
or query.size(0) > 128
or query.dtype != torch.float16
):
if self.max_relative_positions == -1: # Rotary Embeddings
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
self.dim_per_head,
maxseqlen=(seqlen + 2048),
base=self.rotary_theta,
).to(self.rope.device)
rope = self.rope[start_pos : start_pos + seqlen]
query, key = apply_rotary_emb(
query, key, rope, interleave=self.rotary_interleave
)

if self.layer_cache[1]["keys"].numel() != 0:
key = torch.cat((self.layer_cache[1]["keys"], key), dim=2)
value = torch.cat((self.layer_cache[1]["values"], value), dim=2)
if sliding_window > 0 and key.size(2) > sliding_window:
key = key[:, :, 1:, :]
value = value[:, :, 1:, :]

self.layer_cache[1]["keys"] = key
self.layer_cache[1]["values"] = value

if self.layer_cache[1]["keys"].numel() != 0:
key = torch.cat((self.layer_cache[1]["keys"], key), dim=2)
value = torch.cat((self.layer_cache[1]["values"], value), dim=2)
else:
if self.max_relative_positions == -1: # Rotary Embeddings
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
self.dim_per_head,
maxseqlen=(seqlen + 2048),
base=self.rotary_theta,
).to(self.rope.device)
self.cos = (
self.rope[:, : self.rope.size(1) // 2]
.real.contiguous()
.half()
)
self.sin = (
self.rope[:, : self.rope.size(1) // 2]
.imag.contiguous()
.half()
)
if start_pos >= self.layer_cache[1]["keys"].size(2):
self.layer_cache[1]["keys"] = torch.cat(
[
self.layer_cache[1]["keys"],
torch.zeros(
self.layer_cache[1]["keys"].shape[:-2]
+ (32,)
+ self.layer_cache[1]["keys"].shape[-1:],
device=query.device,
).half(),
],
dim=-2,
)
self.layer_cache[1]["values"] = torch.cat(
[
self.layer_cache[1]["values"],
torch.zeros(
self.layer_cache[1]["values"].shape[:-2]
+ (32,)
+ self.layer_cache[1]["values"].shape[-1:],
device=query.device,
).half(),
],
dim=-2,
)
if sliding_window > 0 and key.size(2) > sliding_window:
key = key[:, :, 1:, :]
value = value[:, :, 1:, :]
self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][
:, :, 1:, :
]
self.layer_cache[1]["values"] = self.layer_cache[1]["values"][
:, :, 1:, :
]
context = self.flash_attn_with_kvcache(
query.transpose(1, 2),
self.layer_cache[1]["keys"].transpose(1, 2),
self.layer_cache[1]["values"].transpose(1, 2),
key.transpose(1, 2),
value.transpose(1, 2),
rotary_cos=self.cos,
rotary_sin=self.sin,
cache_seqlens=step,
rotary_interleaved=self.rotary_interleave,
).transpose(1, 2)
attn_output = self.final_linear(unshape(context))
if self.parallel_gpu > 1:
all_reduce(attn_output)
return attn_output, None

self.layer_cache[1]["keys"] = key
self.layer_cache[1]["values"] = value
elif self.attn_type == "context":
query = self.linear_query(query)
query = shape(query, self.dim_per_head)
@@ -484,7 +574,9 @@ def forward(
seqlen = query.size(2)
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
self.dim_per_head, maxseqlen=(seqlen + 2048)
self.dim_per_head,
maxseqlen=(seqlen + 2048),
base=self.rotary_theta,
).to(self.rope.device)
rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
query, key = apply_rotary_emb(
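One detail of the new flash path worth calling out: the KV cache is preallocated and extended in 32-slot blocks of zeros whenever the decode position reaches its end, instead of concatenating one token per step. A standalone sketch of that growth rule (shapes are illustrative; the 32-slot block size mirrors the diff):

import torch

def maybe_grow(cache: torch.Tensor, start_pos: int, block: int = 32) -> torch.Tensor:
    # cache: (batch, heads, cached_len, head_dim)
    if start_pos >= cache.size(2):
        pad = torch.zeros(
            cache.shape[:-2] + (block,) + cache.shape[-1:],
            device=cache.device, dtype=cache.dtype,
        )
        cache = torch.cat([cache, pad], dim=-2)
    return cache

kv = torch.zeros(2, 8, 32, 64)
kv = maybe_grow(kv, start_pos=32)  # length grows from 32 to 64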
6 changes: 1 addition & 5 deletions onmt/modules/position_ffn.py
@@ -3,11 +3,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

try:
from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
from onmt.modules.rmsnorm import RMSNorm
from onmt.modules.rmsnorm import RMSNorm
from torch.nn.utils import skip_init
from torch.distributed import all_reduce

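Since the apex FusedRMSNorm fallback is removed here and in the decoder, everything now goes through onmt.modules.rmsnorm. For reference, a standard RMSNorm looks roughly like the sketch below (Zhang & Sennrich, 2019); the repo's implementation may differ in details such as dtype handling.

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale by the root-mean-square of the features; no mean-centering, no bias.
        variance = x.float().pow(2).mean(-1, keepdim=True)
        return self.weight * (x.float() * torch.rsqrt(variance + self.eps)).type_as(x)

y = RMSNorm(512)(torch.randn(2, 7, 512))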