Skip to content

Commit

Permalink
use flash_attn_with_kvcache for faster inference (#2539)
Browse files Browse the repository at this point in the history
* use flash_attn_with_kvcache
* patch rmsnorm for multiexperts
* rope theta as an option
  • Loading branch information
vince62s authored Dec 26, 2023
1 parent 05cde4d commit 0436cdd
Show file tree
Hide file tree
Showing 10 changed files with 183 additions and 39 deletions.
20 changes: 15 additions & 5 deletions onmt/decoders/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@
from onmt.modules.position_ffn import ActivationFunction
from onmt.modules.moe import MoE
from onmt.utils.misc import sequence_mask

try:
from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
from onmt.modules.rmsnorm import RMSNorm
from onmt.modules.rmsnorm import RMSNorm


class TransformerDecoderLayerBase(nn.Module):
Expand Down Expand Up @@ -44,6 +40,7 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
rotary_theta=1e4,
num_experts=0,
num_experts_per_tok=2,
):
Expand Down Expand Up @@ -89,6 +86,7 @@ def __init__(
sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
rotary_interleave (bool): Interleave the head dimensions when rotary
embeddings are applied
rotary_theta (int): rotary base theta
"""
super(TransformerDecoderLayerBase, self).__init__()

Expand All @@ -100,6 +98,7 @@ def __init__(
max_relative_positions=max_relative_positions,
relative_positions_buckets=relative_positions_buckets,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
attn_type="self",
self_attn_type=self_attn_type,
add_qkvbias=add_qkvbias,
Expand Down Expand Up @@ -280,6 +279,7 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
rotary_theta=1e4,
num_experts=0,
num_experts_per_tok=2,
):
Expand Down Expand Up @@ -311,6 +311,7 @@ def __init__(
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
Expand Down Expand Up @@ -473,6 +474,7 @@ def from_opt(cls, opt, embeddings):
else 1,
sliding_window=opt.sliding_window,
rotary_interleave=opt.rotary_interleave,
rotary_theta=opt.rotary_theta,
num_experts=opt.num_experts,
num_experts_per_tok=opt.num_experts_per_tok,
)
Expand Down Expand Up @@ -563,6 +565,7 @@ class TransformerDecoder(TransformerDecoderBase):
parallel_gpu (int): Number of gpu for tensor parallelism
sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
rotary_theta (int): rotary base theta
"""

def __init__(
Expand Down Expand Up @@ -594,6 +597,7 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
rotary_theta=1e4,
num_experts=0,
num_experts_per_tok=2,
):
Expand Down Expand Up @@ -627,6 +631,7 @@ def __init__(
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
Expand Down Expand Up @@ -834,6 +839,7 @@ class TransformerLMDecoder(TransformerDecoderBase):
parallel_gpu (int): Number of gpu for tensor parallelism
sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
rotary_theta (int): rotary base theta
"""

def __init__(
Expand Down Expand Up @@ -865,6 +871,7 @@ def __init__(
parallel_gpu=1,
sliding_window=0,
rotary_interleave=True,
rotary_theta=1e4,
num_experts=0,
num_experts_per_tok=2,
):
Expand Down Expand Up @@ -897,6 +904,7 @@ def __init__(
parallel_gpu=parallel_gpu,
sliding_window=sliding_window,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
num_experts=num_experts,
num_experts_per_tok=num_experts_per_tok,
)
Expand Down Expand Up @@ -976,3 +984,5 @@ def _init_cache(self, tgt=None):
)
if hasattr(layer.self_attn, "rope"):
layer.self_attn.rope = layer.self_attn.rope.to(tgt.device)
layer.self_attn.cos = layer.self_attn.cos.to(tgt.device)
layer.self_attn.sin = layer.self_attn.sin.to(tgt.device)
6 changes: 6 additions & 0 deletions onmt/encoders/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class TransformerEncoderLayer(nn.Module):
parallel_gpu (int): Number of gpu for tensor parallelism
rotary_interleave (bool): Interleave the head dimensions when rotary
embeddings are applied
rotary_theta (int): rotary base theta
"""

def __init__(
Expand All @@ -61,6 +62,7 @@ def __init__(
use_ckpting=[],
parallel_gpu=1,
rotary_interleave=True,
rotary_theta=1e4,
):
super(TransformerEncoderLayer, self).__init__()

Expand All @@ -72,6 +74,7 @@ def __init__(
max_relative_positions=max_relative_positions,
relative_positions_buckets=relative_positions_buckets,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
attn_type="self",
add_qkvbias=add_qkvbias,
num_kv=num_kv,
Expand Down Expand Up @@ -177,6 +180,7 @@ def __init__(
use_ckpting=[],
parallel_gpu=1,
rotary_interleave=True,
rotary_theta=1e4,
):
super(TransformerEncoder, self).__init__()

Expand All @@ -201,6 +205,7 @@ def __init__(
use_ckpting=use_ckpting,
parallel_gpu=parallel_gpu,
rotary_interleave=rotary_interleave,
rotary_theta=rotary_theta,
)
for i in range(num_layers)
]
Expand Down Expand Up @@ -239,6 +244,7 @@ def from_opt(cls, opt, embeddings):
if opt.parallel_mode == "tensor_parallel"
else 1,
rotary_interleave=opt.rotary_interleave,
rotary_theta=opt.rotary_theta,
)

def forward(self, src, src_len=None):
Expand Down
4 changes: 4 additions & 0 deletions onmt/modules/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
import torch.nn as nn
from onmt.modules.position_ffn import PositionwiseFeedForward
from torch.distributed import all_reduce


class MoE(nn.Module):
Expand Down Expand Up @@ -40,12 +41,15 @@ def __init__(
)
self.gate = nn.Linear(d_model, num_experts, bias=False)
self.num_experts_per_tok = num_experts_per_tok
self.parallel_gpu = parallel_gpu

def forward(self, x):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])

scores = self.gate(x)
if self.parallel_gpu > 1:
all_reduce(scores)
expert_weights, expert_indices = torch.topk(
scores, self.num_experts_per_tok, dim=-1
)
Expand Down
136 changes: 114 additions & 22 deletions onmt/modules/multi_headed_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from torch.distributed import all_reduce
from importlib import import_module


# Help functions for Rotary Embeddings
# https://arxiv.org/pdf/2104.09864.pdf
# too convoluted to make maxseqlen a parameter.
Expand Down Expand Up @@ -258,6 +257,7 @@ def __init__(
max_relative_positions: int = 0,
relative_positions_buckets: int = 0,
rotary_interleave: bool = True,
rotary_theta: int = 1e4,
attn_type: str = None,
self_attn_type: str = None,
add_qkvbias=False,
Expand Down Expand Up @@ -352,9 +352,19 @@ def __init__(
self.relative_attention_bias = None

if max_relative_positions == -1: # rotary embeddings
self.rope = rotaryembeddings(self.dim_per_head)
self.rope = rotaryembeddings(self.dim_per_head, base=rotary_theta)
self.cos = (
self.rope[:, : self.rope.size(1) // 2].real.contiguous().half()
)
self.sin = (
self.rope[:, : self.rope.size(1) // 2].imag.contiguous().half()
)
self.rotary_interleave = rotary_interleave

self.rotary_theta = rotary_theta
else:
self.cos = None
self.sin = None
self.rotary_interleave = None
if max_relative_positions == -2: # alibi positional bias
self.alibi = AlibiPositionalBias(head_count)

Expand All @@ -367,6 +377,9 @@ def __init__(
and torch.cuda.get_device_capability()[0] >= 8
):
self.flash_attn_func = getattr(flash_pack, "flash_attn_func")
self.flash_attn_with_kvcache = getattr(
flash_pack, "flash_attn_with_kvcache"
)
self.flash2 = True
except ImportError:
self.flash2 = False
Expand Down Expand Up @@ -420,27 +433,104 @@ def forward(
key = shape(key, self.dim_per_head)
value = shape(value, self.dim_per_head)

if self.max_relative_positions == -1: # Rotary Embeddings
start_pos = step
seqlen = query.size(2)
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
self.dim_per_head, maxseqlen=(seqlen + 2048)
).to(self.rope.device)
rope = self.rope[start_pos : start_pos + seqlen]
query, key = apply_rotary_emb(
query, key, rope, interleave=self.rotary_interleave
)
start_pos = step
seqlen = query.size(2)

if (
step == 0
or not self.flash2
or self.max_relative_positions not in [0, -1]
or query.size(0) > 128
or query.dtype != torch.float16
):
if self.max_relative_positions == -1: # Rotary Embeddings
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
self.dim_per_head,
maxseqlen=(seqlen + 2048),
base=self.rotary_theta,
).to(self.rope.device)
rope = self.rope[start_pos : start_pos + seqlen]
query, key = apply_rotary_emb(
query, key, rope, interleave=self.rotary_interleave
)

if self.layer_cache[1]["keys"].numel() != 0:
key = torch.cat((self.layer_cache[1]["keys"], key), dim=2)
value = torch.cat((self.layer_cache[1]["values"], value), dim=2)
if sliding_window > 0 and key.size(2) > sliding_window:
key = key[:, :, 1:, :]
value = value[:, :, 1:, :]

self.layer_cache[1]["keys"] = key
self.layer_cache[1]["values"] = value

if self.layer_cache[1]["keys"].numel() != 0:
key = torch.cat((self.layer_cache[1]["keys"], key), dim=2)
value = torch.cat((self.layer_cache[1]["values"], value), dim=2)
else:
if self.max_relative_positions == -1: # Rotary Embeddings
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
self.dim_per_head,
maxseqlen=(seqlen + 2048),
base=self.rotary_theta,
).to(self.rope.device)
self.cos = (
self.rope[:, : self.rope.size(1) // 2]
.real.contiguous()
.half()
)
self.sin = (
self.rope[:, : self.rope.size(1) // 2]
.imag.contiguous()
.half()
)
if start_pos >= self.layer_cache[1]["keys"].size(2):
self.layer_cache[1]["keys"] = torch.cat(
[
self.layer_cache[1]["keys"],
torch.zeros(
self.layer_cache[1]["keys"].shape[:-2]
+ (32,)
+ self.layer_cache[1]["keys"].shape[-1:],
device=query.device,
).half(),
],
dim=-2,
)
self.layer_cache[1]["values"] = torch.cat(
[
self.layer_cache[1]["values"],
torch.zeros(
self.layer_cache[1]["values"].shape[:-2]
+ (32,)
+ self.layer_cache[1]["values"].shape[-1:],
device=query.device,
).half(),
],
dim=-2,
)
if sliding_window > 0 and key.size(2) > sliding_window:
key = key[:, :, 1:, :]
value = value[:, :, 1:, :]
self.layer_cache[1]["keys"] = self.layer_cache[1]["keys"][
:, :, 1:, :
]
self.layer_cache[1]["values"] = self.layer_cache[1]["values"][
:, :, 1:, :
]
context = self.flash_attn_with_kvcache(
query.transpose(1, 2),
self.layer_cache[1]["keys"].transpose(1, 2),
self.layer_cache[1]["values"].transpose(1, 2),
key.transpose(1, 2),
value.transpose(1, 2),
rotary_cos=self.cos,
rotary_sin=self.sin,
cache_seqlens=step,
rotary_interleaved=self.rotary_interleave,
).transpose(1, 2)
attn_output = self.final_linear(unshape(context))
if self.parallel_gpu > 1:
all_reduce(attn_output)
return attn_output, None

self.layer_cache[1]["keys"] = key
self.layer_cache[1]["values"] = value
elif self.attn_type == "context":
query = self.linear_query(query)
query = shape(query, self.dim_per_head)
Expand Down Expand Up @@ -484,7 +574,9 @@ def forward(
seqlen = query.size(2)
if seqlen > self.rope.size(0):
self.rope = rotaryembeddings(
self.dim_per_head, maxseqlen=(seqlen + 2048)
self.dim_per_head,
maxseqlen=(seqlen + 2048),
base=self.rotary_theta,
).to(self.rope.device)
rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
query, key = apply_rotary_emb(
Expand Down
6 changes: 1 addition & 5 deletions onmt/modules/position_ffn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

try:
from apex.normalization import FusedRMSNorm as RMSNorm
except ImportError:
from onmt.modules.rmsnorm import RMSNorm
from onmt.modules.rmsnorm import RMSNorm
from torch.nn.utils import skip_init
from torch.distributed import all_reduce

Expand Down
Loading

0 comments on commit 0436cdd

Please sign in to comment.