support grouped query attention(MQA & GQA) for flash_attn (#22)
* support grouped query attention (GQA) for flash_attn (fwd, bwd, split_kv, total_attention)
* add MQA/GQA to the feature list; update documentation and tests for flash attention.
iclementine authored May 27, 2024
1 parent b0045fb commit 13664fc
Showing 6 changed files with 165 additions and 81 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -233,6 +233,7 @@ The performance of piecewise_attention has improved compared to that in v0.1. In
- the sequence length of k/v can be different from that of q;
- supports computing the total attention that each `k` gets from all `q`'s;
- supports returning the accumulative attention of each key.
- supports [MQA](https://arxiv.org/abs/1911.02150) and [GQA](https://arxiv.org/pdf/2305.13245).

#### Limitations

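For reference, a minimal sketch of the GQA call that the new feature bullet describes: q carries more heads than k and v, and the head counts only need to satisfy the multiple constraint. The import path is assumed from the file layout in this commit (src/flag_attn/flash.py exports `attention` via `__all__`).

import torch
from flag_attn.flash import attention  # path assumed from src/flag_attn/flash.py below

B, M, N, D = 2, 1024, 1024, 64
H_q, H_kv = 32, 8  # GQA: 32 query heads share 8 k/v heads (4 query heads per group)
q = torch.randn(B, H_q, M, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H_kv, N, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H_kv, N, D, device="cuda", dtype=torch.float16)
o = attention(q, k, v, causal=True)  # -> (B, H_q, M, D); MQA is the H_kv == 1 special case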
73 changes: 53 additions & 20 deletions src/flag_attn/flash.py
@@ -8,22 +8,40 @@

__all__ = ["attention"]


def maybe_contiguous(x):
# only when the innermost dimension is contiguous can LDGSTS be used
# so inner-dimension contiguity is enforced.
return x.contiguous() if x.stride(-1) != 1 else x

def rounded_multiple(a, b):
return (a + b - 1) // b * b

# --------------------------- public API ---------------------------
class FlashAttention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_attention):
# size, stride, dtype checking
Dq, Dk, Dv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Dq == Dk == Dv
assert Dq == Dk == Dv, "feature size of q, k, v should be equal"
assert Dk in {16, 32, 64, 128}

B, H, M, D = q.shape
N = k.shape[2]
Hk, Hv = k.shape[1], v.shape[1]
assert Hk == Hv, "num of heads in k and v should be equal"
assert H % Hk == 0, "number of heads in q must be a multiple of that in k & v"
num_groups = H // Hk

P_SEQ = N - M
larger_m = M > N

if sm_scale is None:
sm_scale = 1. / math.sqrt(D)

# contiguity
q, k, v = maybe_contiguous(q), maybe_contiguous(k), maybe_contiguous(v)

# to work around https://github.com/openai/triton/issues/2441
device = torch.cuda.device_of(q)
num_sms = torch.cuda.get_device_properties(device).multi_processor_count
@@ -32,6 +50,7 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_
config_for_split_kv = get_fwd_config_kv_split(B, H, M, N, D, causal)
S = num_splits_herustic(B, H, M, N, config_for_split_kv[0], config_for_split_kv[1], num_sms, 128)
split_kv: bool = S > 1
# print(f"flag_attn choose {S} splits")

if not split_kv:
config = get_fwd_config(B, H, M, N, D, causal)
@@ -50,7 +69,7 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
B, H, M, N, P_SEQ,
B, H, M, N, P_SEQ, num_groups,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D,
IS_CAUSAL=causal, LARGER_M=larger_m,
DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
@@ -61,7 +80,6 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_

divisible_m = M % BLOCK_M == 0
divisible_n = N % BLOCK_N == 0

# consider using 3d grid to avoid div & rem
multiple_l = torch.empty((B, H, S, M), dtype=torch.float32, device="cuda")
multiple_o = torch.empty((B, H, S, M, D), dtype=torch.float16, device="cuda")
@@ -74,7 +92,7 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
multiple_o.stride(0), multiple_o.stride(1), multiple_o.stride(2), multiple_o.stride(3), multiple_o.stride(4),
B, H, M, N, P_SEQ, N_SPLIT_SIZE, S,
B, H, M, N, P_SEQ, N_SPLIT_SIZE, S, num_groups,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
IS_CAUSAL=causal, LARGER_M=larger_m,
DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
@@ -103,7 +121,7 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_
q, k, L, tot_attn, sm_scale,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
B, H, M, N, P_SEQ,
B, H, M, N, P_SEQ, num_groups,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
CAUSAL=causal,
DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
@@ -133,6 +151,8 @@ def backward(ctx, do, *ignored):

B, H, M, D = q.shape
N = k.shape[2]
Hk = k.shape[1]
num_groups = H // Hk
P_SEQ = N - M
larger_m = M > N

@@ -161,8 +181,9 @@ def backward(ctx, do, *ignored):
DIVISIBLE_M=divisible_m,
)

dk = torch.empty_like(k)
dv = torch.empty_like(v)
# NOTE that dk & dv always have the same number of heads as q, instead of k & v.
dk = torch.empty((B, H, N, D), dtype=k.dtype, device=q.device)
dv = torch.empty((B, H, N, D), dtype=v.dtype, device=q.device)
grid = (triton.cdiv(N, BLOCK_N), H, B)
_bwd_kv_kernel[grid](
q, k, v, sm_scale, do,
@@ -175,6 +196,7 @@ def backward(ctx, do, *ignored):
dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),
dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3),
B, H, M, N, P_SEQ,
num_groups,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, CAUSAL=causal,
DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
num_stages=num_stages, num_warps=num_warps,
@@ -192,12 +214,14 @@ def backward(ctx, do, *ignored):
do.stride(0), do.stride(1), do.stride(2), do.stride(3),
dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3),
B, H, M, N, P_SEQ,
num_groups,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
CAUSAL=causal, LARGER_M=larger_m,
DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
num_stages=num_stages, num_warps = num_warps,
)

dk = dk.reshape((B, Hk, num_groups, N, D)).sum(2)
dv = dv.reshape((B, Hk, num_groups, N, D)).sum(2)
return dq, dk, dv, None, None, None, None
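The reshape-and-sum above is the adjoint of the `repeat_interleave` expansion used by the reference implementation (src/flag_attn/testing/flash.py below): all query heads in a group read the same k/v head in the forward pass, so their gradients must be accumulated back onto that single k/v head in the backward pass. A small autograd sketch of this equivalence, with illustrative shapes:

import torch

B, H_q, H_kv, N, D = 1, 4, 2, 8, 16
num_groups = H_q // H_kv
kv = torch.randn(B, H_kv, N, D, requires_grad=True)
expanded = torch.repeat_interleave(kv, repeats=num_groups, dim=1)  # (B, H_q, N, D), as in the reference path
g = torch.randn(B, H_q, N, D)  # stand-in for the per-query-head gradient reaching the expanded k (or v)
expanded.backward(g)
manual = g.reshape(B, H_kv, num_groups, N, D).sum(2)  # same reshape-and-sum as dk/dv above
print(torch.allclose(kv.grad, manual))  # True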


@@ -208,21 +232,24 @@ def attention(q, k, v, causal=False, sm_scale=None,
An implementation of FlashAttention v2 (https://arxiv.org/abs/2307.08691).
Arguments:
q(torch.Tensor): The first queries. The shape is (batch_size, nheads, seqlen_q, headdim).
k(torch.Tensor): The first keys. The shape is (batch_size, nheads, seqlen_k, headdim).
v(torch.Tensor): The values. The shape is (batch_size, nheads, seqlen_k, headdim).
q(torch.Tensor): The first queries. The shape is (batch_size, num_heads_q, seqlen_q, headdim).
k(torch.Tensor): The first keys. The shape is (batch_size, num_heads_k, seqlen_k, headdim).
v(torch.Tensor): The values. The shape is (batch_size, num_heads_k, seqlen_k, headdim).
causal(bool): Whether causal masking is applied to attention scores before applying softmax.
sm_scale(float): The scaling of attention scores before applying softmax.
return_log_normalizer(bool): Whether to return the log normalizer of softmax inside attention.
return_total_attention(bool): Whether to return the sum of attention along q's sequence dimension.
Returns:
out(torch.Tensor): The output. The shape is (batch_size, nheads, seqlen_q, headdim).
out(torch.Tensor): The output. The shape is (batch_size, num_heads_q, seqlen_q, headdim).
If `return_log_normalizer` or `return_total_attention`, return the following results in addition.
log_normalizer(torch.Tensor): The log normalizer. The shape is (batch_size, nheads, seqlen_q).
total_attention(torch.Tensor): The total attention. The shape is (batch_size, nheads, seqlen_k).
log_normalizer(torch.Tensor): The log normalizer. The shape is (batch_size, num_heads_q, seqlen_q).
total_attention(torch.Tensor): The total attention. The shape is (batch_size, num_heads_q, seqlen_k).
Notes:
`num_heads_q` must be a multiple of `num_heads_k`.
"""
return FlashAttention.apply(q, k, v, causal, sm_scale, return_log_normalizer, return_total_attention)
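In the kernels below, each program instance maps its query head index to a k/v head with integer division (`off_hk = off_h // num_groups`), the same layout that `repeat_interleave` produces in the reference implementation. A plain-Python sketch of that mapping, with illustrative head counts:

H_q, H_kv = 8, 2                                       # illustrative head counts
num_groups = H_q // H_kv                               # 4 query heads per k/v head
mapping = {off_h: off_h // num_groups for off_h in range(H_q)}
# {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 1}
# query heads 0-3 read k/v head 0, heads 4-7 read k/v head 1; MQA (H_kv == 1) maps every head to 0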

@@ -272,6 +299,7 @@ def _fwd_kernel(
stride_vz, stride_vh, stride_vn, stride_vk,
stride_oz, stride_oh, stride_om, stride_ok,
Z, H, M, N, P_SEQ,
num_groups,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr,
DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
@@ -289,9 +317,10 @@
qk_scale = sm_scale * log2e

# offset pointers for (batch, head)
off_hk = off_h // num_groups
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_kz + off_h * stride_kh
V += off_z * stride_vz + off_h * stride_vh
K += off_z * stride_kz + off_hk * stride_kh
V += off_z * stride_vz + off_hk * stride_vh
O += off_z * stride_oz + off_h * stride_oh
L += (off_z * H + off_h) * M # l's shape is (B, H, M)

@@ -491,6 +520,7 @@ def _bwd_kv_kernel(
stride_dkz, stride_dkh, stride_dkn, stride_dkk,
stride_dvz, stride_dvh, stride_dvn, stride_dvk,
Z, H, M, N, P_SEQ,
num_groups,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
CAUSAL: tl.constexpr,
DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
@@ -504,9 +534,10 @@
qk_scale = sm_scale * log2e

# offset pointers for (batch, head)
off_hk = off_h // num_groups
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_kz + off_h * stride_kh
V += off_z * stride_vz + off_h * stride_vh
K += off_z * stride_kz + off_hk * stride_kh
V += off_z * stride_vz + off_hk * stride_vh
DO += off_z * stride_doz + off_h * stride_doh

# offset pointers for batch/head
@@ -637,6 +668,7 @@ def _bwd_q_kernel(
stride_doz, stride_doh, stride_dom, stride_dok,
stride_dqz, stride_dqh, stride_dqm, stride_dqk,
Z, H, M, N, P_SEQ,
num_groups,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
CAUSAL: tl.constexpr, LARGER_M: tl.constexpr,
DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
@@ -654,9 +686,10 @@ def _bwd_q_kernel(
qk_scale = sm_scale * log2e

# offset pointers for (batch, head)
off_hk = off_h // num_groups
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_kz + off_h * stride_kh
V += off_z * stride_vz + off_h * stride_vh
K += off_z * stride_kz + off_hk * stride_kh
V += off_z * stride_vz + off_hk * stride_vh
DO += off_z * stride_doz + off_h * stride_doh
D += (off_z * H + off_h) * M
L += (off_z * H + off_h) * M
13 changes: 9 additions & 4 deletions src/flag_attn/split_kv.py
@@ -19,7 +19,7 @@ def _fwd_split_kv_kernel(
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vn, stride_vk,
stride_oz, stride_oh, stride_os, stride_om, stride_ok,
Z, H, M, N, P_SEQ, N_SPLIT_SIZE, S,
Z, H, M, N, P_SEQ, N_SPLIT_SIZE, S, num_groups,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr,
DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
@@ -31,6 +31,7 @@ def _fwd_split_kv_kernel(
off_zh = tl.program_id(2)
off_h = off_zh % H
off_z = off_zh // H
off_hk = off_h // num_groups

# scale sm_scale by log_2(e) and use
# 2^x instead of exp in the loop because CSE and LICM
@@ -40,8 +41,8 @@

# offset pointers for (batch & head)
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_kz + off_h * stride_kh
V += off_z * stride_vz + off_h * stride_vh
K += off_z * stride_kz + off_hk * stride_kh
V += off_z * stride_vz + off_hk * stride_vh

# offset pointers for (batch & head, split)
O += off_z * stride_oz + off_h * stride_oh + n_split_id * stride_os # o's shape is (B, H, S, M, D)
@@ -269,6 +270,10 @@ def attention(q, k, v, causal=False, sm_scale=None):

B, H, M, D = q.shape
N = k.shape[2]
Hk, Hv = k.shape[1], v.shape[1]
assert Hk == Hv, "num of heads in k and v should be equal"
assert H % Hk == 0, "number of heads in q must be a multiple of that in k & v"
num_groups = H // Hk
P_SEQ = N - M
larger_m = M > N

@@ -299,7 +304,7 @@ def attention(q, k, v, causal=False, sm_scale=None):
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
multiple_o.stride(0), multiple_o.stride(1), multiple_o.stride(2), multiple_o.stride(3), multiple_o.stride(4),
B, H, M, N, P_SEQ, N_SPLIT_SIZE, S,
B, H, M, N, P_SEQ, N_SPLIT_SIZE, S, num_groups,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
IS_CAUSAL=causal, LARGER_M=larger_m,
DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
9 changes: 9 additions & 0 deletions src/flag_attn/testing/flash.py
@@ -16,6 +16,15 @@ def attention(q,
D = q.shape[-1]
if sm_scale is None:
sm_scale = 1. / math.sqrt(D)

num_heads_q = q.shape[1]
num_heads_k = k.shape[1]
assert num_heads_q % num_heads_k == 0
num_groups = num_heads_q // num_heads_k

if num_groups > 1:
k = torch.repeat_interleave(k, repeats=num_groups, dim=1)
v = torch.repeat_interleave(v, repeats=num_groups, dim=1)
kv_seq_len = k.shape[-2]
q_seq_len = q.shape[-2]
p_seq = kv_seq_len - q_seq_len
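Expanding k/v with `repeat_interleave`, as the reference path above does, also gives a simple end-to-end check of the Triton kernels against plain attention. A sketch of such a check for the non-causal case; the import path is assumed from this commit's file layout, and the tolerances are loose fp16 ones:

import math
import torch
from flag_attn.flash import attention  # path assumed from src/flag_attn/flash.py above

B, H_q, H_kv, M, N, D = 2, 16, 4, 128, 128, 64
num_groups = H_q // H_kv
q = torch.randn(B, H_q, M, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H_kv, N, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H_kv, N, D, device="cuda", dtype=torch.float16)

# inline fp32 reference: expand k/v to the query head count, then run plain softmax attention
k_e = torch.repeat_interleave(k, num_groups, dim=1).float()
v_e = torch.repeat_interleave(v, num_groups, dim=1).float()
p = torch.softmax(q.float() @ k_e.transpose(-1, -2) / math.sqrt(D), dim=-1)
ref = (p @ v_e).to(q.dtype)

out = attention(q, k, v, causal=False)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)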
11 changes: 8 additions & 3 deletions src/flag_attn/total.py
@@ -14,6 +14,10 @@ def total_attention(q, k, l, causal=False, sm_scale=None):

B, H, M, D = q.shape
N = k.shape[2]
Hk = k.shape[1]
assert H % Hk == 0, "number of heads in q must be a multiple of that in k"
num_groups = H // Hk

P_SEQ = N - M

if sm_scale is None:
@@ -34,7 +38,7 @@
q, k, l, tot_attn, sm_scale,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
B, H, M, N, P_SEQ,
B, H, M, N, P_SEQ, num_groups,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
CAUSAL=causal,
DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
@@ -48,7 +52,7 @@ def _total_attention_kernel(
Q, K, L, TA, sm_scale,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
Z, H, M, N, P_SEQ,
Z, H, M, N, P_SEQ, num_groups,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
CAUSAL: tl.constexpr,
DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
@@ -61,8 +65,9 @@ def _total_attention_kernel(
qk_scale = sm_scale * log2e

# offset pointers for (batch, head)
off_hk = off_h // num_groups
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_kz + off_h * stride_kh
K += off_z * stride_kz + off_hk * stride_kh
L += (off_z * H + off_h) * M
TA += (off_z * H + off_h) * N # (b, h, n)

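For the non-causal case, the per-query-head quantity this kernel produces — how much attention each key receives, summed over all query positions — can be written directly against the softmax matrix; the kernel instead reconstructs it from q, k and the saved log-normalizer L without materializing that matrix. A conceptual sketch with illustrative shapes:

import math
import torch

B, H_q, H_kv, M, N, D = 1, 4, 2, 32, 32, 16
num_groups = H_q // H_kv
q = torch.randn(B, H_q, M, D)
k = torch.randn(B, H_kv, N, D)

k_e = torch.repeat_interleave(k, num_groups, dim=1)                  # (B, H_q, N, D)
p = torch.softmax(q @ k_e.transpose(-1, -2) / math.sqrt(D), dim=-1)  # (B, H_q, M, N)
tot_attn_ref = p.sum(dim=-2)                                         # (B, H_q, N): (batch, num_heads_q, seqlen_k)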