From 53747d4e3297b34e5b8b4a68d44a858d51ddfdb6 Mon Sep 17 00:00:00 2001 From: Tongxin Bai Date: Fri, 24 May 2024 23:30:02 +0800 Subject: [PATCH 1/9] adding dropout to regular flash attention. --- src/flag_attn/dropout.py | 15 ++++++ src/flag_attn/flash.py | 109 +++++++++++++++++++++++++++++++++++---- 2 files changed, 113 insertions(+), 11 deletions(-) create mode 100644 src/flag_attn/dropout.py diff --git a/src/flag_attn/dropout.py b/src/flag_attn/dropout.py new file mode 100644 index 0000000..9a33fd1 --- /dev/null +++ b/src/flag_attn/dropout.py @@ -0,0 +1,15 @@ +import torch + +def philox_cuda_seed_offset(increment, device=None): + device = device or torch.cuda.current_device() + gen = torch.cuda.default_generators[device] + state_copy = gen.get_state() + c0, c1 = state_copy.view(torch.int64) + seed, offset = int(c0), int(c1) + increment = (increment + 3) // 4 * 4 + c1 += increment + # get_state returns a new tensor, so it needs set_state to update the actual generator state. + gen.set_state(state_copy) + return seed, offset + + diff --git a/src/flag_attn/flash.py b/src/flag_attn/flash.py index a32564b..bbb0b03 100644 --- a/src/flag_attn/flash.py +++ b/src/flag_attn/flash.py @@ -6,12 +6,14 @@ from flag_attn.split_kv import _fwd_split_kv_kernel, _fwd_combine_kv_splits, num_splits_herustic from flag_attn.split_kv import get_fwd_config as get_fwd_config_kv_split +from .dropout import philox_cuda_seed_offset + __all__ = ["attention"] # --------------------------- public API --------------------------- class FlashAttention(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_attention): + def forward(ctx, q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, return_total_attention): Dq, Dk, Dv = q.shape[-1], k.shape[-1], v.shape[-1] assert Dq == Dk == Dv assert Dk in {16, 32, 64, 128} @@ -24,6 +26,14 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_ if sm_scale is None: sm_scale = 1. / math.sqrt(D) + # Dropout preparation. 
+ is_dropout = dropout_p > 0 + if is_dropout: + n_dropouts = B * H * M * N + seed, offset = philox_cuda_seed_offset(n_dropouts) + else: + seed, offset = None, None + # to work around https://github.com/openai/triton/issues/2441 device = torch.cuda.device_of(q) num_sms = torch.cuda.get_device_properties(device).multi_processor_count @@ -44,7 +54,9 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_ o = torch.empty_like(q) L = torch.empty((B, H, M), device=q.device, dtype=torch.float32) _fwd_kernel[grid]( - q, k, v, sm_scale, + q, k, v, + sm_scale, + dropout_p, seed, offset, L, o, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), @@ -52,7 +64,7 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_ o.stride(0), o.stride(1), o.stride(2), o.stride(3), B, H, M, N, P_SEQ, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D, - IS_CAUSAL=causal, LARGER_M=larger_m, + IS_CAUSAL=causal, IS_DROPOUT=is_dropout, LARGER_M=larger_m, DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, num_warps=num_warps, num_stages=num_stages, ) @@ -68,7 +80,8 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_ grid = (triton.cdiv(M, BLOCK_M), S, H * B) N_SPLIT_SIZE = triton.cdiv(triton.cdiv(N, BLOCK_N), S) * BLOCK_N _fwd_split_kv_kernel[grid]( - q, k, v, sm_scale, + q, k, v, + sm_scale, multiple_l, multiple_o, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), @@ -114,6 +127,10 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_ ctx.save_for_backward(q, k, v, o, L) ctx.sm_scale = sm_scale ctx.causal = causal + ctx.is_dropout = is_dropout + ctx.dropout_p = dropout_p + ctx.seed = seed + ctx.offset = offset has_extra_return = return_log_normalizer or return_total_attention if has_extra_return: @@ -130,6 +147,10 @@ def backward(ctx, do, *ignored): q, k, v, o, L = ctx.saved_tensors sm_scale = ctx.sm_scale causal = ctx.causal + is_dropout = ctx.is_dropout + dropout_p = ctx.dropout_p + seed = ctx.seed + offset = ctx.offset B, H, M, D = q.shape N = k.shape[2] @@ -148,15 +169,19 @@ def backward(ctx, do, *ignored): divisible_m = M % BLOCK_M == 0 divisible_n = N % BLOCK_N == 0 + p = 1 / (1 - dropout_p) + delta = torch.empty_like(L) grid = (triton.cdiv(M, BLOCK_M), H, B) _bwd_preprocess[grid]( o, do, delta, + p, o.stride(0), o.stride(1), o.stride(2), o.stride(3), do.stride(0), do.stride(1), do.stride(2), do.stride(3), delta.stride(0), delta.stride(1), delta.stride(2), M, + IS_DROPOUT=is_dropout BLOCK_M=BLOCK_M, D_HEAD=D, DIVISIBLE_M=divisible_m, ) @@ -165,8 +190,10 @@ def backward(ctx, do, *ignored): dv = torch.empty_like(v) grid = (triton.cdiv(N, BLOCK_N), H, B) _bwd_kv_kernel[grid]( - q, k, v, sm_scale, do, - dk, dv, + q, k, v, + dropout_p, + sm_scale, + do, dk, dv, L, delta, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), @@ -175,6 +202,7 @@ def backward(ctx, do, *ignored): dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3), dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3), B, H, M, N, P_SEQ, + IS_DROPOUT=is_dropout, BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, CAUSAL=causal, DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, num_stages=num_stages, num_warps=num_warps, @@ -201,7 +229,7 @@ def backward(ctx, do, *ignored): return dq, dk, dv, None, None, None, None -def attention(q, k, v, causal=False, sm_scale=None, +def attention(q, k, v, 
causal=False, sm_scale=None, dropout_p=0.0, return_log_normalizer=False, return_total_attention=False, ): """ @@ -213,6 +241,7 @@ def attention(q, k, v, causal=False, sm_scale=None, v(torch.Tensor): The values. The shape is (batch_size, nheads, seqlen_k, headdim). causal(bool): Whether causal masking is applied to attention scores before applying softmax. sm_scale(float): The scaling of attention scores before applying softmax. + dropout_p(float): Dropout probability. return_log_normalizer(bool): Whether to return the log normalizer of softmax inside attention. return_total_attention(bool): Whether to return the sum of attention along q's sequence dimendion. @@ -224,7 +253,7 @@ def attention(q, k, v, causal=False, sm_scale=None, log_normalizer(torch.Tensor): The log normalizer. The shape is (batch_size, nheads, seqlen_q). total_attention(torch.Tensor): The total attention. The shape is (batch_size, nheads, seqlen_k). """ - return FlashAttention.apply(q, k, v, causal, sm_scale, return_log_normalizer, return_total_attention) + return FlashAttention.apply(q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, return_total_attention) # --------------------------- Forward --------------------------- @@ -265,7 +294,11 @@ def get_fwd_config(B, H, M, N, D, causal): @triton.jit def _fwd_kernel( - Q, K, V, sm_scale, + Q, K, V, + sm_scale, + dropout_p, + seed, + offset, L, O, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, @@ -273,7 +306,7 @@ def _fwd_kernel( stride_oz, stride_oh, stride_om, stride_ok, Z, H, M, N, P_SEQ, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, + IS_CAUSAL: tl.constexpr, IS_DROPOUT: tl.constexpr, LARGER_M: tl.constexpr, DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, ): input_dtype = Q.dtype.element_ty @@ -299,6 +332,12 @@ def _fwd_kernel( offs_m = start_m * BLOCK_M + offs_m_base offs_n_base = tl.arange(0, BLOCK_N) offs_k = tl.arange(0, BLOCK_DMODEL) + + if IS_DROPOUT: + rowblock_base = off_z * H * M * N + off_h * M * N + start_m * BLOCK_M * N + offs_rng_base = offset + rowblock_base + offs_rng_base += tl.arange(0, BLOCK_M)[:, None] * N + offs_rng_base += tl.arange(0, BLOCK_N)[None, :] # initialize pointers to value-like data q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) @@ -362,6 +401,8 @@ def _fwd_kernel( k = tl.load(k_ptrs, mask=mask_n[None, :], cache_modifier=".cg") v = tl.load(v_ptrs, mask=mask_n[:, None], cache_modifier=".cg") + # -- compute dropout mask -- + # -- compute qk --- s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) s += tl.dot(q, k) @@ -372,11 +413,18 @@ def _fwd_kernel( causal_mask = (P_SEQ + offs_m[:, None]) >= offs_n[None, :] s = tl.where(causal_mask, s, float("-inf")) + # -- compute scaling constant --- m_i_new = tl.maximum(m_i, tl.max(s, 1)) alpha = tl.math.exp2((m_i - m_i_new) * qk_scale) p = tl.math.exp2(s * qk_scale - m_i_new[:, None] * qk_scale) + # -- apply dropout -- + if IS_DROPOUT: + offs_rng = start_n + offs_rng_base + pmask = tl.rand(seed, offs_rng, n_rounds=6) > dropout_p + p *= pmask + # -- scale and update acc: acc *= alpha[:, None]-- acc *= alpha[:, None] acc += tl.dot(p.to(input_dtype), v) @@ -397,6 +445,11 @@ def _fwd_kernel( acc = acc * (1.0 / l_i[:, None]) l = m_i * sm_scale + tl.log(l_i) # log(normalizer) + # -- scale o due to dropout + if IS_DROPOUT: + scale = 1.0 / (1.0 - dropout_p) + acc *= scale + if DIVISIBLE_M: tl.store(l_ptrs, l, 
cache_modifier=".cg") tl.store(o_ptrs, acc.to(input_dtype), cache_modifier=".cg") @@ -439,12 +492,13 @@ def get_bwd_config(B, H, M, N, D, causal): def _bwd_preprocess( Out, DO, Delta, + p, stride_oz, stride_oh, stride_om, stride_ok, stride_doz, stride_doh, stride_dom, stride_dok, stride_dz, stride_dh, stride_dm, M, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, - DIVISIBLE_M: tl.constexpr, + DIVISIBLE_M: tl.constexpr, IS_DROPOUT ): off_h = tl.program_id(1) off_z = tl.program_id(2) @@ -468,6 +522,9 @@ def _bwd_preprocess( o = tl.load(o_ptrs, mask=mask_m[:, None]).to(tl.float32) do = tl.load(do_ptrs, mask=mask_m[:, None]).to(tl.float32) + if IS_DROPOUT: + do *= p + # compute delta = tl.sum(o * do, axis=1) # write-back @@ -484,6 +541,7 @@ def _bwd_kv_kernel( DK, DV, L, D, + dropout_p, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vn, stride_vk, @@ -491,6 +549,7 @@ def _bwd_kv_kernel( stride_dkz, stride_dkh, stride_dkn, stride_dkk, stride_dvz, stride_dvh, stride_dvn, stride_dvk, Z, H, M, N, P_SEQ, + IS_DROPOUT: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, CAUSAL: tl.constexpr, DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, @@ -546,6 +605,13 @@ def _bwd_kv_kernel( v = tl.load(v_ptrs, mask=mask_n[:, None]) k = tl.load(k_ptrs, mask=mask_n[:, None]) + # dropout + if IS_DROPOUT: + colblock_base = off_z * H * M * N + off_h * M * N + start_n * BLOCK_N + offs_rng_base = offset + colblock_base + offs_rng_base += tl.arange(0, BLOCK_M)[:, None] * N + offs_rng_base += tl.arange(0, BLOCK_N)[None, :] + # initialize dk amd dv dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) @@ -585,6 +651,12 @@ def _bwd_kv_kernel( if CAUSAL: p = tl.where(causal_mask, p, 0.0) + # -- apply dropout -- + if IS_DROPOUT: + offs_rng = offs_rng_base + start_m * N + pmask = tl.rand(seed, offs_rng, n_rounds=6) > dropout_p + p *= pmask + # compute dv = dot(p, do) if DIVISIBLE_M: do = tl.load(do_ptrs) @@ -631,12 +703,14 @@ def _bwd_q_kernel( DQ, L, D, + dropout_p, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vn, stride_vk, stride_doz, stride_doh, stride_dom, stride_dok, stride_dqz, stride_dqh, stride_dqm, stride_dqk, Z, H, M, N, P_SEQ, + IS_DROPOUT, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, @@ -706,6 +780,13 @@ def _bwd_q_kernel( else: hi = N + # dropout + if IS_DROPOUT: + rowblock_base = off_z * H * M * N + off_h * M * N + start_m * BLOCK_M * N + offs_rng_base = offset + rowblock_base + offs_rng_base += tl.arange(0, BLOCK_M)[:, None] * N + offs_rng_base += tl.arange(0, BLOCK_N)[None, :] + # loop over a row for start_n in range(0, hi, BLOCK_N): offs_n = start_n + offs_n_base @@ -736,6 +817,12 @@ def _bwd_q_kernel( # s = tl.where(valid_mask, s, float("-inf")) p = tl.math.exp2(s * qk_scale - l[:, None] * log2e) # (BLOCK_M, BLOCK_N) + # -- apply dropout -- + if IS_DROPOUT: + offs_rng = start_n + offs_rng_base + pmask = tl.rand(seed, offs_rng, n_rounds=6) > dropout_p + p *= pmask + # compute dp = dot(v, do) dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) dp += tl.dot(do.to(input_dtype), tl.trans(v)) From c23bf88fcb5cd2855f4e7647eec4afed5cc5f8ae Mon Sep 17 00:00:00 2001 From: Tongxin Bai Date: Sat, 25 May 2024 00:11:59 +0800 Subject: [PATCH 2/9] fixed 
syntax errors. --- src/flag_attn/flash.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/src/flag_attn/flash.py b/src/flag_attn/flash.py index bbb0b03..b0d01a5 100644 --- a/src/flag_attn/flash.py +++ b/src/flag_attn/flash.py @@ -32,7 +32,7 @@ def forward(ctx, q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, re n_dropouts = B * H * M * N seed, offset = philox_cuda_seed_offset(n_dropouts) else: - seed, offset = None, None + seed, offset = 0, 0 # to work around https://github.com/openai/triton/issues/2441 device = torch.cuda.device_of(q) @@ -181,7 +181,7 @@ def backward(ctx, do, *ignored): do.stride(0), do.stride(1), do.stride(2), do.stride(3), delta.stride(0), delta.stride(1), delta.stride(2), M, - IS_DROPOUT=is_dropout + IS_DROPOUT=is_dropout, BLOCK_M=BLOCK_M, D_HEAD=D, DIVISIBLE_M=divisible_m, ) @@ -190,11 +190,13 @@ def backward(ctx, do, *ignored): dv = torch.empty_like(v) grid = (triton.cdiv(N, BLOCK_N), H, B) _bwd_kv_kernel[grid]( - q, k, v, - dropout_p, + q, k, v, sm_scale, do, dk, dv, L, delta, + dropout_p, + seed, + offset, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), v.stride(0), v.stride(1), v.stride(2), v.stride(3), @@ -214,19 +216,23 @@ def backward(ctx, do, *ignored): q, k, v, sm_scale, do, dq, L, delta, + dropout_p, + seed, + offset, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), v.stride(0), v.stride(1), v.stride(2), v.stride(3), do.stride(0), do.stride(1), do.stride(2), do.stride(3), dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3), B, H, M, N, P_SEQ, + IS_DROPOUT=is_dropout, BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, CAUSAL=causal, LARGER_M=larger_m, DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, num_stages=num_stages, num_warps = num_warps, ) - return dq, dk, dv, None, None, None, None + return dq, dk, dv, None, None, None, None, None def attention(q, k, v, causal=False, sm_scale=None, dropout_p=0.0, @@ -332,7 +338,7 @@ def _fwd_kernel( offs_m = start_m * BLOCK_M + offs_m_base offs_n_base = tl.arange(0, BLOCK_N) offs_k = tl.arange(0, BLOCK_DMODEL) - + if IS_DROPOUT: rowblock_base = off_z * H * M * N + off_h * M * N + start_m * BLOCK_M * N offs_rng_base = offset + rowblock_base @@ -542,6 +548,8 @@ def _bwd_kv_kernel( L, D, dropout_p, + seed, + offset, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vn, stride_vk, @@ -549,9 +557,8 @@ def _bwd_kv_kernel( stride_dkz, stride_dkh, stride_dkn, stride_dkk, stride_dvz, stride_dvh, stride_dvn, stride_dvk, Z, H, M, N, P_SEQ, - IS_DROPOUT: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - CAUSAL: tl.constexpr, + CAUSAL: tl.constexpr, IS_DROPOUT: tl.constexpr, DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, ): input_dtype = Q.dtype.element_ty @@ -704,15 +711,16 @@ def _bwd_q_kernel( L, D, dropout_p, + seed, + offset, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vn, stride_vk, stride_doz, stride_doh, stride_dom, stride_dok, stride_dqz, stride_dqh, stride_dqm, stride_dqk, Z, H, M, N, P_SEQ, - IS_DROPOUT, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, + CAUSAL: tl.constexpr, IS_DROPOUT: tl.constexpr, LARGER_M: tl.constexpr, DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, ): input_dtype 
= Q.dtype.element_ty From 3fa61836c11d80e8b0bb9588fe3a2b354fe7c82a Mon Sep 17 00:00:00 2001 From: Tongxin Bai Date: Mon, 27 May 2024 22:17:27 +0800 Subject: [PATCH 3/9] working in progress, forward dropout passes tests. --- src/flag_attn/dropout.py | 20 +++++++++++++ src/flag_attn/flash.py | 39 +++++++++++++++++------- src/flag_attn/testing/flash.py | 20 ++++++++++--- tests/flag_attn/test_dropout.py | 29 ++++++++++++++++++ tests/flag_attn/test_flash_attention.py | 40 +++++++++++++++++++++++++ 5 files changed, 134 insertions(+), 14 deletions(-) create mode 100644 tests/flag_attn/test_dropout.py diff --git a/src/flag_attn/dropout.py b/src/flag_attn/dropout.py index 9a33fd1..2b92626 100644 --- a/src/flag_attn/dropout.py +++ b/src/flag_attn/dropout.py @@ -1,4 +1,6 @@ import torch +import triton +import triton.language as tl def philox_cuda_seed_offset(increment, device=None): device = device or torch.cuda.current_device() @@ -12,4 +14,22 @@ def philox_cuda_seed_offset(increment, device=None): gen.set_state(state_copy) return seed, offset +@triton.jit +def dropout_mask_kernel(dropout_mask, B, H, M, N, dropout_p, seed, offset): + row, b, h = tl.program_id(0), tl.program_id(1), tl.program_id(2) + offs_base = b * H * M * N + h * M * N + row * N + BLOCK: tl.constexpr = 1024 + offs_base += tl.arange(0, BLOCK) + for start_n in range(0, N, BLOCK): + offs = start_n + offs_base + rng_offs = offset + offs + pmask = tl.rand(seed, rng_offs, n_rounds=6) > dropout_p + row_mask = start_n + tl.arange(0, BLOCK) < N + tl.store(dropout_mask + offs, pmask, mask=row_mask) + +def dropout_mask(x, B, H, M, N, dropout_p, seed, offset): + dropout_mask = torch.empty((B, H, M, N), dtype=torch.bool, device=x.device) + grid = (M, B, H) + dropout_mask_kernel[grid](dropout_mask, B, H, M, N, dropout_p, seed, offset) + return dropout_mask diff --git a/src/flag_attn/flash.py b/src/flag_attn/flash.py index b0d01a5..a8b7c4b 100644 --- a/src/flag_attn/flash.py +++ b/src/flag_attn/flash.py @@ -13,7 +13,10 @@ # --------------------------- public API --------------------------- class FlashAttention(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, return_total_attention): + def forward(ctx, q, k, v, causal, sm_scale, dropout_p, + return_log_normalizer=False, + return_total_attention=False, + return_seed_offset=False): Dq, Dk, Dv = q.shape[-1], k.shape[-1], v.shape[-1] assert Dq == Dk == Dv assert Dk in {16, 32, 64, 128} @@ -53,8 +56,10 @@ def forward(ctx, q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, re grid = (triton.cdiv(M, BLOCK_M), H, B) o = torch.empty_like(q) L = torch.empty((B, H, M), device=q.device, dtype=torch.float32) + p = torch.empty((B, H, M, N), device=q.device, dtype=torch.float32) _fwd_kernel[grid]( q, k, v, + p, sm_scale, dropout_p, seed, offset, L, o, @@ -69,6 +74,7 @@ def forward(ctx, q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, re num_warps=num_warps, num_stages=num_stages, ) else: # split kv + assert not is_dropout, "Cannot apply dropout with splitkv yet." 
BLOCK_M, BLOCK_N, num_stages, num_warps = config_for_split_kv divisible_m = M % BLOCK_M == 0 @@ -132,15 +138,18 @@ def forward(ctx, q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, re ctx.seed = seed ctx.offset = offset - has_extra_return = return_log_normalizer or return_total_attention + has_extra_return = True in (return_log_normalizer, return_total_attention, return_seed_offset) if has_extra_return: outs = ( o, + p, L if return_log_normalizer else None, - tot_attn if return_total_attention else None + tot_attn if return_total_attention else None, + seed if return_seed_offset else None, + offset if return_seed_offset else None ) return outs - return o + return o, p @staticmethod def backward(ctx, do, *ignored): @@ -232,11 +241,11 @@ def backward(ctx, do, *ignored): num_stages=num_stages, num_warps = num_warps, ) - return dq, dk, dv, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None def attention(q, k, v, causal=False, sm_scale=None, dropout_p=0.0, - return_log_normalizer=False, return_total_attention=False, + return_log_normalizer=False, return_total_attention=False, return_seed_offset=False ): """ An implementation of FlashAttention v2(https://arxiv.org/abs/2307.08691). @@ -259,7 +268,7 @@ def attention(q, k, v, causal=False, sm_scale=None, dropout_p=0.0, log_normalizer(torch.Tensor): The log normalizer. The shape is (batch_size, nheads, seqlen_q). total_attention(torch.Tensor): The total attention. The shape is (batch_size, nheads, seqlen_k). """ - return FlashAttention.apply(q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, return_total_attention) + return FlashAttention.apply(q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, return_total_attention, return_seed_offset) # --------------------------- Forward --------------------------- @@ -300,7 +309,7 @@ def get_fwd_config(B, H, M, N, D, causal): @triton.jit def _fwd_kernel( - Q, K, V, + Q, K, V, P, sm_scale, dropout_p, seed, @@ -344,6 +353,7 @@ def _fwd_kernel( offs_rng_base = offset + rowblock_base offs_rng_base += tl.arange(0, BLOCK_M)[:, None] * N offs_rng_base += tl.arange(0, BLOCK_N)[None, :] + p_base = rowblock_base + tl.arange(0, BLOCK_M)[:, None] * N + tl.arange(0, BLOCK_N)[None, :] # initialize pointers to value-like data q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) @@ -394,6 +404,8 @@ def _fwd_kernel( offs_n_init = offs_n_base k_ptrs = K + (offs_k[:, None] * stride_vk + offs_n_init[None, :] * stride_vn) # (BLOCK_DMODEL, BLOCK_N) v_ptrs = V + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL) + p_ptrs = P + offs_n_init + p_out = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) for start_n in range(0, hi, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) offs_n = start_n + offs_n_base @@ -408,6 +420,7 @@ def _fwd_kernel( v = tl.load(v_ptrs, mask=mask_n[:, None], cache_modifier=".cg") # -- compute dropout mask -- + # -- compute qk --- s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) @@ -425,18 +438,22 @@ def _fwd_kernel( alpha = tl.math.exp2((m_i - m_i_new) * qk_scale) p = tl.math.exp2(s * qk_scale - m_i_new[:, None] * qk_scale) + p_sum = tl.sum(p, 1) # -- apply dropout -- if IS_DROPOUT: offs_rng = start_n + offs_rng_base pmask = tl.rand(seed, offs_rng, n_rounds=6) > dropout_p - p *= pmask + p *= pmask.to(tl.float32) + p_out = p + # mask_n = offs_n < N + # tl.store(P + p_base + start_n, p, mask=mask_n[None, :]) # -- scale and update acc: acc *= alpha[:, None]-- 
acc *= alpha[:, None] acc += tl.dot(p.to(input_dtype), v) # -- update m_i and l_i -- - l_i = l_i * alpha + tl.sum(p, 1) + l_i = l_i * alpha + p_sum m_i = m_i_new # update pointers k_ptrs += BLOCK_N * stride_kn @@ -453,6 +470,8 @@ def _fwd_kernel( # -- scale o due to dropout if IS_DROPOUT: + mask_n = offs_n_base < N + tl.store(P + p_base, p_out / l_i[:, None], mask=mask_n[None, :]) scale = 1.0 / (1.0 - dropout_p) acc *= scale diff --git a/src/flag_attn/testing/flash.py b/src/flag_attn/testing/flash.py index 7f272ed..64bcd75 100644 --- a/src/flag_attn/testing/flash.py +++ b/src/flag_attn/testing/flash.py @@ -1,10 +1,13 @@ import math import torch + def attention(q, k, v, causal, + dropout_p=0.0, + dropout_mask=None, sm_scale=None, return_log_normalizer=False, return_total_attention=False, @@ -40,13 +43,22 @@ def attention(q, if return_total_attention: tot_attn = torch.sum(P, dim=-2) - attn_output = torch.matmul(P.to(v.dtype), v).to(input_dtype) + # Applies dropout + dropout_scaling = 1.0 / (1 - dropout_p) + if dropout_mask is not None: + P_dropout = P.masked_fill(~dropout_mask, 0.0) + # P = P.masked_fill(~dropout_mask, 0.0) + + # attn_output = torch.matmul(P.to(v.dtype), v * dropout_scaling).to(input_dtype) + attn_output = torch.matmul(P_dropout.to(v.dtype), v * dropout_scaling).to(input_dtype) has_extra_return = return_log_normalizer or return_total_attention if has_extra_return: outs = (attn_output, - log_normalizer if return_log_normalizer else None, - tot_attn if return_total_attention else None) + P, + P_dropout, + log_normalizer if return_log_normalizer else None, + tot_attn if return_total_attention else None) return outs else: - return attn_output + return attn_output, P, P_dropout diff --git a/tests/flag_attn/test_dropout.py b/tests/flag_attn/test_dropout.py new file mode 100644 index 0000000..98b16df --- /dev/null +++ b/tests/flag_attn/test_dropout.py @@ -0,0 +1,29 @@ +import torch +import pytest +from flag_attn.dropout import dropout_mask + +@pytest.mark.parametrize('x', [torch.empty(1, device='cuda')]) +@pytest.mark.parametrize('B, H, M, N', [ + (2, 4, 512, 612), + (2, 4, 1024, 1034), + (2, 4, 2048, 2048), + (2, 4, 4096, 4096), + (2, 4, 4001, 4001), + (2, 4, 4001, 4096), + (2, 4, 4096, 4000), + (1, 2, 8192, 8202), + (1, 2, 8192, 8192), +]) +@pytest.mark.parametrize('p', [0.5, 0.8]) +def test_dropout_mask(x, B, H, M, N, p): + import math + seed = 123456789 + offset = 123456789123456789 + mask = dropout_mask(x, B, H, M, N, p, seed, offset) + # zeros indicate to drop + # k follows Binomial distributio B(k; n, p) + n = mask.numel() + k = torch.sum(mask == 0) + p_cap = k / n + tol = 0.01 + assert math.fabs(p_cap - p) < tol * p diff --git a/tests/flag_attn/test_flash_attention.py b/tests/flag_attn/test_flash_attention.py index 8c6d06b..91a5783 100644 --- a/tests/flag_attn/test_flash_attention.py +++ b/tests/flag_attn/test_flash_attention.py @@ -190,3 +190,43 @@ def test_attention_with_aux_outs(B, H, M, N, D, causal, stride_order, dtype, sca triton_max_diff = max_diff(tot_attn_hyp, tot_attn_ref) assert triton_max_diff <= 2 * torch_max_diff + 1e-5 + +@pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count()))) +@pytest.mark.parametrize('scale', [1.0, 2.0]) +@pytest.mark.parametrize('B, H, M, N, D', [ + (2, 4, 512, 612, 128), + (2, 4, 1024, 1034, 64), + (2, 4, 2048, 2048, 32), + # (2, 4, 4096, 4096, 16), + # (2, 4, 4001, 4001, 32), + # (2, 4, 4001, 4096, 64), + # (2, 4, 4096, 4000, 128), + # (1, 2, 8192, 8202, 16), + # (1, 2, 8192, 8192, 32), +]) 
+@pytest.mark.parametrize('causal', [False, True]) +@pytest.mark.parametrize('dropout_p', [0.5, 0.8]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('stride_order', ['BHTD']) +def test_attention_dropout(B, H, M, N, D, causal, dropout_p, stride_order, dtype, scale, device_id): + device = f"cuda:{device_id}" + if stride_order == "BHTD": + q = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale) + k = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale) + v = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale) + else: + q = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) + k = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) + v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) + + from flag_attn.dropout import dropout_mask + o_hyp, p_hyp, _, _, seed, offset = flag_attn.flash_attention(q, k, v, causal, dropout_p=dropout_p, return_seed_offset=True) + mask = dropout_mask(q, B, H, M, N, dropout_p, seed, offset) + o_ref, P_ref, P_dropout_ref = flag_attn.testing.flash_attention(q, k, v, causal, dropout_p=dropout_p, dropout_mask=mask, upcast=True) + o_torch, P_torch, P_dropout_torch = flag_attn.testing.flash_attention(q, k, v, causal, dropout_p=dropout_p, dropout_mask=mask, upcast=False) + + torch_max_diff = max_diff(o_torch, o_ref) + triton_max_diff = max_diff(o_hyp, o_ref) + report("o hyp", o_hyp, o_ref) + report("o torch", o_torch, o_ref) + assert triton_max_diff <= 2 * torch_max_diff + 1e-5 \ No newline at end of file From b60017dee24260ec7a06699e7acd4335b7bb0232 Mon Sep 17 00:00:00 2001 From: Tongxin Bai Date: Tue, 28 May 2024 12:00:40 +0800 Subject: [PATCH 4/9] Flash attn fwd with dropout done. 
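The forward kernel keeps the softmax statistics independent of dropout: the per-row partial sum `p_sum = tl.sum(p, 1)` is taken before the Philox keep-mask `tl.rand(seed, offs_rng, n_rounds=6) > dropout_p` zeroes entries of `p`, and the accumulated output is rescaled by `1/(1 - dropout_p)` at the end. The temporary code path that wrote the dropped probability matrix `P` back to global memory for debugging is removed, and the reference implementation in `testing/flash.py` follows the same convention by masking `P` and scaling `v`.

In plain PyTorch the behaviour being checked is roughly the following (an illustrative sketch only; `attention_dropout_ref` and `keep_mask` are not names from this patch, and the causal mask is omitted):

    import torch

    def attention_dropout_ref(q, k, v, keep_mask, dropout_p, sm_scale):
        # Inverted dropout on the attention weights: the softmax normalizer is
        # computed from the pre-dropout probabilities, dropped entries are
        # zeroed, and the output is rescaled by 1 / (1 - dropout_p).
        s = torch.matmul(q, k.transpose(-1, -2)) * sm_scale
        p = torch.softmax(s, dim=-1)
        p = p.masked_fill(~keep_mask, 0.0)
        return torch.matmul(p, v) * (1.0 / (1.0 - dropout_p))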
--- src/flag_attn/flash.py | 19 ++++--------------- src/flag_attn/testing/flash.py | 12 ++++-------- tests/flag_attn/test_flash_attention.py | 22 +++++++++++----------- 3 files changed, 19 insertions(+), 34 deletions(-) diff --git a/src/flag_attn/flash.py b/src/flag_attn/flash.py index a8b7c4b..f770f54 100644 --- a/src/flag_attn/flash.py +++ b/src/flag_attn/flash.py @@ -59,7 +59,6 @@ def forward(ctx, q, k, v, causal, sm_scale, dropout_p, p = torch.empty((B, H, M, N), device=q.device, dtype=torch.float32) _fwd_kernel[grid]( q, k, v, - p, sm_scale, dropout_p, seed, offset, L, o, @@ -142,14 +141,13 @@ def forward(ctx, q, k, v, causal, sm_scale, dropout_p, if has_extra_return: outs = ( o, - p, L if return_log_normalizer else None, tot_attn if return_total_attention else None, seed if return_seed_offset else None, offset if return_seed_offset else None ) return outs - return o, p + return o @staticmethod def backward(ctx, do, *ignored): @@ -309,7 +307,7 @@ def get_fwd_config(B, H, M, N, D, causal): @triton.jit def _fwd_kernel( - Q, K, V, P, + Q, K, V, sm_scale, dropout_p, seed, @@ -404,8 +402,6 @@ def _fwd_kernel( offs_n_init = offs_n_base k_ptrs = K + (offs_k[:, None] * stride_vk + offs_n_init[None, :] * stride_vn) # (BLOCK_DMODEL, BLOCK_N) v_ptrs = V + (offs_n_init[:, None] * stride_kn + offs_k[None, :] * stride_kk) # (BLOCK_N, BLOCK_DMODEL) - p_ptrs = P + offs_n_init - p_out = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) for start_n in range(0, hi, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) offs_n = start_n + offs_n_base @@ -419,9 +415,6 @@ def _fwd_kernel( k = tl.load(k_ptrs, mask=mask_n[None, :], cache_modifier=".cg") v = tl.load(v_ptrs, mask=mask_n[:, None], cache_modifier=".cg") - # -- compute dropout mask -- - - # -- compute qk --- s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) s += tl.dot(q, k) @@ -432,21 +425,19 @@ def _fwd_kernel( causal_mask = (P_SEQ + offs_m[:, None]) >= offs_n[None, :] s = tl.where(causal_mask, s, float("-inf")) - # -- compute scaling constant --- m_i_new = tl.maximum(m_i, tl.max(s, 1)) alpha = tl.math.exp2((m_i - m_i_new) * qk_scale) p = tl.math.exp2(s * qk_scale - m_i_new[:, None] * qk_scale) + # -- compute partial sumexpn before applying dropout p_sum = tl.sum(p, 1) + # -- apply dropout -- if IS_DROPOUT: offs_rng = start_n + offs_rng_base pmask = tl.rand(seed, offs_rng, n_rounds=6) > dropout_p p *= pmask.to(tl.float32) - p_out = p - # mask_n = offs_n < N - # tl.store(P + p_base + start_n, p, mask=mask_n[None, :]) # -- scale and update acc: acc *= alpha[:, None]-- acc *= alpha[:, None] @@ -470,8 +461,6 @@ def _fwd_kernel( # -- scale o due to dropout if IS_DROPOUT: - mask_n = offs_n_base < N - tl.store(P + p_base, p_out / l_i[:, None], mask=mask_n[None, :]) scale = 1.0 / (1.0 - dropout_p) acc *= scale diff --git a/src/flag_attn/testing/flash.py b/src/flag_attn/testing/flash.py index 64bcd75..fb83ce8 100644 --- a/src/flag_attn/testing/flash.py +++ b/src/flag_attn/testing/flash.py @@ -43,22 +43,18 @@ def attention(q, if return_total_attention: tot_attn = torch.sum(P, dim=-2) - # Applies dropout + # Applies dropout dropout_scaling = 1.0 / (1 - dropout_p) if dropout_mask is not None: - P_dropout = P.masked_fill(~dropout_mask, 0.0) - # P = P.masked_fill(~dropout_mask, 0.0) + P = P.masked_fill(~dropout_mask, 0.0) - # attn_output = torch.matmul(P.to(v.dtype), v * dropout_scaling).to(input_dtype) - attn_output = torch.matmul(P_dropout.to(v.dtype), v * dropout_scaling).to(input_dtype) + attn_output = torch.matmul(P.to(v.dtype), v * 
dropout_scaling).to(input_dtype) has_extra_return = return_log_normalizer or return_total_attention if has_extra_return: outs = (attn_output, - P, - P_dropout, log_normalizer if return_log_normalizer else None, tot_attn if return_total_attention else None) return outs else: - return attn_output, P, P_dropout + return attn_output diff --git a/tests/flag_attn/test_flash_attention.py b/tests/flag_attn/test_flash_attention.py index 91a5783..26e1076 100644 --- a/tests/flag_attn/test_flash_attention.py +++ b/tests/flag_attn/test_flash_attention.py @@ -197,18 +197,18 @@ def test_attention_with_aux_outs(B, H, M, N, D, causal, stride_order, dtype, sca (2, 4, 512, 612, 128), (2, 4, 1024, 1034, 64), (2, 4, 2048, 2048, 32), - # (2, 4, 4096, 4096, 16), - # (2, 4, 4001, 4001, 32), - # (2, 4, 4001, 4096, 64), - # (2, 4, 4096, 4000, 128), - # (1, 2, 8192, 8202, 16), - # (1, 2, 8192, 8192, 32), + (2, 4, 4096, 4096, 16), + (2, 4, 4001, 4001, 32), + (2, 4, 4001, 4096, 64), + (2, 4, 4096, 4000, 128), + (1, 2, 8192, 8202, 16), + (1, 2, 8192, 8192, 32), ]) @pytest.mark.parametrize('causal', [False, True]) @pytest.mark.parametrize('dropout_p', [0.5, 0.8]) @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize('stride_order', ['BHTD']) -def test_attention_dropout(B, H, M, N, D, causal, dropout_p, stride_order, dtype, scale, device_id): +@pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD']) +def test_attention_fwd_dropout(B, H, M, N, D, causal, dropout_p, stride_order, dtype, scale, device_id): device = f"cuda:{device_id}" if stride_order == "BHTD": q = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale) @@ -220,10 +220,10 @@ def test_attention_dropout(B, H, M, N, D, causal, dropout_p, stride_order, dtype v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) from flag_attn.dropout import dropout_mask - o_hyp, p_hyp, _, _, seed, offset = flag_attn.flash_attention(q, k, v, causal, dropout_p=dropout_p, return_seed_offset=True) + o_hyp, _, _, seed, offset = flag_attn.flash_attention(q, k, v, causal, dropout_p=dropout_p, return_seed_offset=True) mask = dropout_mask(q, B, H, M, N, dropout_p, seed, offset) - o_ref, P_ref, P_dropout_ref = flag_attn.testing.flash_attention(q, k, v, causal, dropout_p=dropout_p, dropout_mask=mask, upcast=True) - o_torch, P_torch, P_dropout_torch = flag_attn.testing.flash_attention(q, k, v, causal, dropout_p=dropout_p, dropout_mask=mask, upcast=False) + o_ref = flag_attn.testing.flash_attention(q, k, v, causal, dropout_p=dropout_p, dropout_mask=mask, upcast=True) + o_torch = flag_attn.testing.flash_attention(q, k, v, causal, dropout_p=dropout_p, dropout_mask=mask, upcast=False) torch_max_diff = max_diff(o_torch, o_ref) triton_max_diff = max_diff(o_hyp, o_ref) From 91a22f33f4407b4461e841cf0d8ce0bc304c215c Mon Sep 17 00:00:00 2001 From: Tongxin Bai Date: Wed, 29 May 2024 17:38:08 +0800 Subject: [PATCH 5/9] fixed extra return errors in tests. 
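Two small fixes: the non-dropout `dv` update casts `p` back to the input dtype before `tl.dot`, and the auxiliary-output test unpacks with `*_` because the forward pass now appends `seed`/`offset` slots to the extra outputs whenever any `return_*` flag is set.

For reference, the extra-output tuple has a fixed layout; a sketch of how a caller consumes it (shapes and dtype below are illustrative, not from the patch):

    import torch
    import flag_attn

    q = torch.randn(2, 4, 1024, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 4, 1024, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 4, 1024, 64, device="cuda", dtype=torch.float16)

    # With any return_* flag set, forward returns
    # (o, log_normalizer, total_attention, seed, offset), filling None into the
    # slots that were not requested, so older call sites absorb the tail with *_.
    o, log_norm, tot_attn, *_ = flag_attn.flash_attention(
        q, k, v, causal=True,
        return_log_normalizer=True, return_total_attention=True,
    )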
--- src/flag_attn/flash.py | 2 +- tests/flag_attn/test_flash_attention.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/flag_attn/flash.py b/src/flag_attn/flash.py index e6206da..cee7c05 100644 --- a/src/flag_attn/flash.py +++ b/src/flag_attn/flash.py @@ -717,7 +717,7 @@ def _bwd_kv_kernel( if IS_DROPOUT: dv += tl.dot(tl.trans(p_masked), do) * rp # (BLOCK_N, BLOCK_DMODEL) # still correct else: - dv += tl.dot(tl.trans(p), do) # (BLOCK_N, BLOCK_DMODEL) # still correct + dv += tl.dot(tl.trans(p).to(input_dtype), do) # (BLOCK_N, BLOCK_DMODEL) # still correct # compute dp = dot(v, do) if DIVISIBLE_M: diff --git a/tests/flag_attn/test_flash_attention.py b/tests/flag_attn/test_flash_attention.py index 218ebfb..3f31154 100644 --- a/tests/flag_attn/test_flash_attention.py +++ b/tests/flag_attn/test_flash_attention.py @@ -206,7 +206,7 @@ def test_attention_with_aux_outs(B, H, M, N, D, causal, stride_order, dtype, sca o_ref, log_norm_ref, tot_attn_ref = flag_attn.testing.flash_attention(q, k, v, causal, return_log_normalizer=True, return_total_attention=True, upcast=True) o_torch, log_norm_torch, tot_attn_torch = flag_attn.testing.flash_attention(q, k, v, causal, return_log_normalizer=True, return_total_attention=True, upcast=False) - o_hyp, log_norm_hyp, tot_attn_hyp = flag_attn.flash_attention(q, k, v, causal, return_log_normalizer=True, return_total_attention=True) + o_hyp, log_norm_hyp, tot_attn_hyp, *_ = flag_attn.flash_attention(q, k, v, causal, return_log_normalizer=True, return_total_attention=True) torch_max_diff = max_diff(o_torch, o_ref) From cda4b19e8d84bf5475a6a60ef4c043a8999d2b34 Mon Sep 17 00:00:00 2001 From: Clement Chan Date: Tue, 4 Jun 2024 15:09:32 +0800 Subject: [PATCH 6/9] move call to philox_cuda_seed_offset & kernel launch into device guard to ensure execution on device other than cuda:0 --- src/flag_attn/dropout.py | 4 ++-- src/flag_attn/flash.py | 19 +++++++++---------- tests/flag_attn/test_flash_attention.py | 2 +- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/flag_attn/dropout.py b/src/flag_attn/dropout.py index 9dec72c..d0c0ef5 100644 --- a/src/flag_attn/dropout.py +++ b/src/flag_attn/dropout.py @@ -32,6 +32,6 @@ def dropout_mask(x, B, H, M, N, dropout_p, seed, offset): if dropout_p == 0: return dropout_mask grid = (M, B, H) - dropout_mask_kernel[grid](dropout_mask, B, H, M, N, dropout_p, seed, offset) + with torch.cuda.device(x.device): + dropout_mask_kernel[grid](dropout_mask, B, H, M, N, dropout_p, seed, offset) return dropout_mask - diff --git a/src/flag_attn/flash.py b/src/flag_attn/flash.py index cee7c05..43a630e 100644 --- a/src/flag_attn/flash.py +++ b/src/flag_attn/flash.py @@ -43,19 +43,20 @@ def forward(ctx, q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, re # contiguity q, k, v = maybe_contiguous(q), maybe_contiguous(k), maybe_contiguous(v) - # Dropout preparation. - is_dropout = dropout_p > 0 - if is_dropout: - offset_increment = B * H * M * N - seed, offset = philox_cuda_seed_offset(offset_increment) - else: - seed, offset = 0, 0 # to work around https://github.com/openai/triton/issues/2441 device = torch.cuda.device_of(q) num_sms = torch.cuda.get_device_properties(device).multi_processor_count with torch.cuda.device(device): + # Dropout preparation. 
+ is_dropout = dropout_p > 0 + if is_dropout: + offset_increment = B * H * M * N + seed, offset = philox_cuda_seed_offset(offset_increment) + else: + seed, offset = 0, 0 + config_for_split_kv = get_fwd_config_kv_split(B, H, M, N, D, causal) S = num_splits_herustic(B, H, M, N, config_for_split_kv[0], config_for_split_kv[1], num_sms, 128) split_kv: bool = S > 1 @@ -71,7 +72,6 @@ def forward(ctx, q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, re grid = (triton.cdiv(M, BLOCK_M), H, B) o = torch.empty_like(q) L = torch.empty((B, H, M), device=q.device, dtype=torch.float32) - p = torch.empty((B, H, M, N), device=q.device, dtype=torch.float32) _fwd_kernel[grid]( q, k, v, sm_scale, @@ -378,7 +378,6 @@ def _fwd_kernel( offs_rng_base = offset + rowblock_base offs_rng_base += tl.arange(0, BLOCK_M)[:, None] * N offs_rng_base += tl.arange(0, BLOCK_N)[None, :] - p_base = rowblock_base + tl.arange(0, BLOCK_M)[:, None] * N + tl.arange(0, BLOCK_N)[None, :] # initialize pointers to value-like data q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) @@ -566,7 +565,7 @@ def _bwd_preprocess( # compute delta = tl.sum(o * do, axis=1) - + # (NOTE) dropout scaling doesn't affect delta's value # when dropout is applied, o and do are actually scaled. # original_o equals o times reverse scale while original_do is do times scale, diff --git a/tests/flag_attn/test_flash_attention.py b/tests/flag_attn/test_flash_attention.py index 3f31154..47f6545 100644 --- a/tests/flag_attn/test_flash_attention.py +++ b/tests/flag_attn/test_flash_attention.py @@ -295,7 +295,7 @@ def test_attention_bwd_dropout(B, H, M, N, D, causal, dropout_p, stride_order, d v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() do = torch.randn((B, M, H, D), dtype=dtype, device=device).transpose(1, 2) - + from flag_attn.dropout import dropout_mask # from flag_attn.dropout import philox_cuda_seed_offset # philox_cuda_seed_offset(increment) From 8b4fdbb93ffbc0dfce126e74e0f0c945736fd16f Mon Sep 17 00:00:00 2001 From: Clement Chan Date: Wed, 5 Jun 2024 10:24:23 +0800 Subject: [PATCH 7/9] remove unused parameters --- src/flag_attn/flash.py | 32 +++++++++++--------------------- 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/src/flag_attn/flash.py b/src/flag_attn/flash.py index 43a630e..e4968ee 100644 --- a/src/flag_attn/flash.py +++ b/src/flag_attn/flash.py @@ -43,7 +43,6 @@ def forward(ctx, q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, re # contiguity q, k, v = maybe_contiguous(q), maybe_contiguous(k), maybe_contiguous(v) - # to work around https://github.com/openai/triton/issues/2441 device = torch.cuda.device_of(q) num_sms = torch.cuda.get_device_properties(device).multi_processor_count @@ -73,8 +72,7 @@ def forward(ctx, q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, re o = torch.empty_like(q) L = torch.empty((B, H, M), device=q.device, dtype=torch.float32) _fwd_kernel[grid]( - q, k, v, - sm_scale, + q, k, v, sm_scale, dropout_p, seed, offset, L, o, q.stride(0), q.stride(1), q.stride(2), q.stride(3), @@ -99,8 +97,7 @@ def forward(ctx, q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, re grid = (triton.cdiv(M, BLOCK_M), S, H * B) N_SPLIT_SIZE = triton.cdiv(triton.cdiv(N, BLOCK_N), S) * BLOCK_N _fwd_split_kv_kernel[grid]( - q, k, v, - sm_scale, + q, k, v, sm_scale, multiple_l, multiple_o, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), 
k.stride(3), @@ -191,20 +188,15 @@ def backward(ctx, do, *ignored): divisible_m = M % BLOCK_M == 0 divisible_n = N % BLOCK_N == 0 - # reciprocal dropout scale - rp = 1 / (1 - dropout_p) - delta = torch.empty_like(L) grid = (triton.cdiv(M, BLOCK_M), H, B) _bwd_preprocess[grid]( o, do, delta, - rp, o.stride(0), o.stride(1), o.stride(2), o.stride(3), do.stride(0), do.stride(1), do.stride(2), do.stride(3), delta.stride(0), delta.stride(1), delta.stride(2), M, - IS_DROPOUT=is_dropout, BLOCK_M=BLOCK_M, D_HEAD=D, DIVISIBLE_M=divisible_m, ) @@ -214,9 +206,8 @@ def backward(ctx, do, *ignored): dv = torch.empty((B, H, N, D), dtype=v.dtype, device=q.device) grid = (triton.cdiv(N, BLOCK_N), H, B) _bwd_kv_kernel[grid]( - q, k, v, - sm_scale, - do, dk, dv, + q, k, v, sm_scale, do, + dk, dv, L, delta, dropout_p, seed, @@ -229,8 +220,8 @@ def backward(ctx, do, *ignored): dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3), B, H, M, N, P_SEQ, num_groups, - BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, - CAUSAL=causal, IS_DROPOUT=is_dropout, + BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, CAUSAL=causal, + IS_DROPOUT=is_dropout, DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, num_stages=num_stages, num_warps=num_warps, ) @@ -281,7 +272,8 @@ def attention(q, k, v, causal=False, sm_scale=None, dropout_p=0.0, Returns: out(torch.Tensor): The output. The shape is (batch_size, num_heads_q, seqlen_q, headdim). - If `return_log_normalizer` or `return_total_attention`, return the following results in addition. + If `return_log_normalizer` or `return_total_attention` or `return_seed_offset` is True, + return the following results in addition. log_normalizer(torch.Tensor): The log normalizer. The shape is (batch_size, num_heads_q, seqlen_q). total_attention(torch.Tensor): The total attention. The shape is (batch_size, num_heads_q, seqlen_k). @@ -332,8 +324,7 @@ def get_fwd_config(B, H, M, N, D, causal): @triton.jit def _fwd_kernel( - Q, K, V, - sm_scale, + Q, K, V, sm_scale, dropout_p, seed, offset, @@ -532,13 +523,12 @@ def get_bwd_config(B, H, M, N, D, causal): def _bwd_preprocess( Out, DO, Delta, - rp, stride_oz, stride_oh, stride_om, stride_ok, stride_doz, stride_doh, stride_dom, stride_dok, stride_dz, stride_dh, stride_dm, M, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, - DIVISIBLE_M: tl.constexpr, IS_DROPOUT + DIVISIBLE_M: tl.constexpr, ): off_h = tl.program_id(1) off_z = tl.program_id(2) @@ -562,7 +552,6 @@ def _bwd_preprocess( o = tl.load(o_ptrs, mask=mask_m[:, None]).to(tl.float32) do = tl.load(do_ptrs, mask=mask_m[:, None]).to(tl.float32) - # compute delta = tl.sum(o * do, axis=1) @@ -864,6 +853,7 @@ def _bwd_q_kernel( v = tl.load(v_ptrs, mask=mask_n[:, None]) k = tl.load(k_ptrs, mask=mask_n[:, None]) + # recompute p = softmax(qk * sm_scale, dim=-1) if not DIVISIBLE_N: valid_mask = mask_n # & mask_m[:, None] From e49958ef3efda5013d904d131d79687c7ebc2676 Mon Sep 17 00:00:00 2001 From: Tongxin Bai Date: Wed, 5 Jun 2024 16:53:40 +0800 Subject: [PATCH 8/9] Minor refactors. 
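The mask-recomputation helper is only needed for testing, since the production kernels regenerate the keep-mask on the fly from `(seed, offset)` and each element's `(b, h, m, n)` position. It therefore moves from `flag_attn.dropout` to `flag_attn.testing.dropout` as `recompute_mask`, taking an explicit device instead of a template tensor; `flag_attn.dropout` keeps only `philox_cuda_seed_offset`. Leftover debug prints and comments in the tests are dropped.

A sketch of how the recomputed mask ties the Triton kernel to the reference implementation (shapes, dtype and dropout probability are illustrative):

    import torch
    import flag_attn

    B, H, M, N, D, p = 2, 4, 1024, 1024, 64, 0.5
    q = torch.randn(B, H, M, D, device="cuda", dtype=torch.float16)
    k = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)
    v = torch.randn(B, H, N, D, device="cuda", dtype=torch.float16)

    o, _, _, seed, offset = flag_attn.flash_attention(
        q, k, v, causal=False, dropout_p=p, return_seed_offset=True)
    mask = flag_attn.testing.recompute_mask(B, H, M, N, p, seed, offset, q.device)
    o_ref = flag_attn.testing.flash_attention(
        q, k, v, causal=False, dropout_p=p, dropout_mask=mask, upcast=True)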
--- src/flag_attn/dropout.py | 22 ---------------------- src/flag_attn/testing/__init__.py | 2 +- src/flag_attn/testing/dropout.py | 25 +++++++++++++++++++++++++ tests/flag_attn/test_dropout.py | 9 +++++---- tests/flag_attn/test_flash_attention.py | 13 ++----------- 5 files changed, 33 insertions(+), 38 deletions(-) create mode 100644 src/flag_attn/testing/dropout.py diff --git a/src/flag_attn/dropout.py b/src/flag_attn/dropout.py index d0c0ef5..62ac39c 100644 --- a/src/flag_attn/dropout.py +++ b/src/flag_attn/dropout.py @@ -13,25 +13,3 @@ def philox_cuda_seed_offset(increment, device=None): # get_state returns a new tensor, so it needs set_state to update the actual generator state. gen.set_state(state_copy) return seed, offset - -@triton.jit -def dropout_mask_kernel(dropout_mask, B, H, M, N, dropout_p, seed, offset): - row, b, h = tl.program_id(0), tl.program_id(1), tl.program_id(2) - offs_base = b * H * M * N + h * M * N + row * N - BLOCK: tl.constexpr = 1024 - offs_base += tl.arange(0, BLOCK) - for start_n in range(0, N, BLOCK): - offs = start_n + offs_base - rng_offs = offset + offs - pmask = tl.rand(seed, rng_offs, n_rounds=6) > dropout_p - row_mask = start_n + tl.arange(0, BLOCK) < N - tl.store(dropout_mask + offs, pmask, mask=row_mask) - -def dropout_mask(x, B, H, M, N, dropout_p, seed, offset): - dropout_mask = torch.full((B, H, M, N), True, dtype=torch.bool, device=x.device) - if dropout_p == 0: - return dropout_mask - grid = (M, B, H) - with torch.cuda.device(x.device): - dropout_mask_kernel[grid](dropout_mask, B, H, M, N, dropout_p, seed, offset) - return dropout_mask diff --git a/src/flag_attn/testing/__init__.py b/src/flag_attn/testing/__init__.py index da03cff..02da1b1 100644 --- a/src/flag_attn/testing/__init__.py +++ b/src/flag_attn/testing/__init__.py @@ -1,4 +1,4 @@ from flag_attn.testing.flash import attention as flash_attention # noqa: F401 from flag_attn.testing.piecewise import attention as piecewise_attention # noqa: F401 from flag_attn.testing.paged import attention as paged_attention # noqa: F401 - +from flag_attn.testing.dropout import recompute_mask \ No newline at end of file diff --git a/src/flag_attn/testing/dropout.py b/src/flag_attn/testing/dropout.py new file mode 100644 index 0000000..fbf5106 --- /dev/null +++ b/src/flag_attn/testing/dropout.py @@ -0,0 +1,25 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def recompute_mask_kernel(mask, B, H, M, N, dropout_p, seed, offset): + row, b, h = tl.program_id(0), tl.program_id(1), tl.program_id(2) + offs_base = b * H * M * N + h * M * N + row * N + BLOCK: tl.constexpr = 1024 + offs_base += tl.arange(0, BLOCK) + for start_n in range(0, N, BLOCK): + offs = start_n + offs_base + rng_offs = offset + offs + pmask = tl.rand(seed, rng_offs, n_rounds=6) > dropout_p + row_mask = start_n + tl.arange(0, BLOCK) < N + tl.store(mask + offs, pmask, mask=row_mask) + +def recompute_mask(B, H, M, N, dropout_p, seed, offset, device): + mask = torch.full((B, H, M, N), True, dtype=torch.bool, device=device) + if dropout_p == 0: + return mask + grid = (M, B, H) + with torch.cuda.device(device): + recompute_mask_kernel[grid](mask, B, H, M, N, dropout_p, seed, offset) + return mask diff --git a/tests/flag_attn/test_dropout.py b/tests/flag_attn/test_dropout.py index 98b16df..3128c29 100644 --- a/tests/flag_attn/test_dropout.py +++ b/tests/flag_attn/test_dropout.py @@ -1,8 +1,8 @@ import torch import pytest -from flag_attn.dropout import dropout_mask +from flag_attn.testing import recompute_mask + 
-@pytest.mark.parametrize('x', [torch.empty(1, device='cuda')]) @pytest.mark.parametrize('B, H, M, N', [ (2, 4, 512, 612), (2, 4, 1024, 1034), @@ -15,11 +15,12 @@ (1, 2, 8192, 8192), ]) @pytest.mark.parametrize('p', [0.5, 0.8]) -def test_dropout_mask(x, B, H, M, N, p): +def test_recompute_mask(B, H, M, N, p): import math seed = 123456789 offset = 123456789123456789 - mask = dropout_mask(x, B, H, M, N, p, seed, offset) + device = torch.cuda.current_device() + mask = recompute_mask(B, H, M, N, p, seed, offset, device) # zeros indicate to drop # k follows Binomial distributio B(k; n, p) n = mask.numel() diff --git a/tests/flag_attn/test_flash_attention.py b/tests/flag_attn/test_flash_attention.py index 47f6545..3f512c0 100644 --- a/tests/flag_attn/test_flash_attention.py +++ b/tests/flag_attn/test_flash_attention.py @@ -250,9 +250,8 @@ def test_attention_fwd_dropout(B, H, M, N, D, causal, dropout_p, stride_order, d k = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) - from flag_attn.dropout import dropout_mask o_hyp, _, _, seed, offset = flag_attn.flash_attention(q, k, v, causal, dropout_p=dropout_p, return_seed_offset=True) - mask = dropout_mask(q, B, H, M, N, dropout_p, seed, offset) + mask = flag_attn.testing.recompute_mask(B, H, M, N, dropout_p, seed, offset, device) o_ref = flag_attn.testing.flash_attention(q, k, v, causal, dropout_p=dropout_p, dropout_mask=mask, upcast=True) o_torch = flag_attn.testing.flash_attention(q, k, v, causal, dropout_p=dropout_p, dropout_mask=mask, upcast=False) @@ -295,13 +294,9 @@ def test_attention_bwd_dropout(B, H, M, N, D, causal, dropout_p, stride_order, d v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() do = torch.randn((B, M, H, D), dtype=dtype, device=device).transpose(1, 2) - - from flag_attn.dropout import dropout_mask # from flag_attn.dropout import philox_cuda_seed_offset - # philox_cuda_seed_offset(increment) o_hyp, _, _, seed, offset = flag_attn.flash_attention(q, k, v, causal=causal, dropout_p=dropout_p, return_seed_offset=True) - mask = dropout_mask(q, B, H, M, N, dropout_p, seed, offset) - # print('mask', mask.to(torch.int)) + mask = flag_attn.testing.recompute_mask(B, H, M, N, dropout_p, seed, offset, device) o_ref = flag_attn.testing.flash_attention(q, k, v, causal=causal, dropout_p=dropout_p, dropout_mask=mask, upcast=True) o_torch = flag_attn.testing.flash_attention(q, k, v, causal=causal, dropout_p=dropout_p, dropout_mask=mask, upcast=False) @@ -309,10 +304,6 @@ def test_attention_bwd_dropout(B, H, M, N, D, causal, dropout_p, stride_order, d gq_torch, gk_torch, gv_torch = torch.autograd.grad(o_torch, (q, k, v), do) gq_hyp, gk_hyp, gv_hyp = torch.autograd.grad(o_hyp, (q, k, v), do) - # print('dv_ref', gv_ref) - # print('dv', gv_hyp) - # print('dv_torch', gv_torch) - o_torch_max_diff = max_diff(o_torch, o_ref) gq_torch_max_diff = max_diff(gq_torch, gq_ref) gk_torch_max_diff = max_diff(gk_torch, gk_ref) From e5c408a5563dedda65f85213d4e97317e5edfe3f Mon Sep 17 00:00:00 2001 From: Clement Chan Date: Wed, 5 Jun 2024 17:27:22 +0800 Subject: [PATCH 9/9] add dropout into feature list in README --- README.md | 2 +- README_cn.md | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c8414cb..157dc57 100644 --- a/README.md +++ b/README.md @@ -234,11 +234,11 @@ The performance of 
piecewise_attention has improved compared to that in v0.1. In - support computation of total attention of each `k` gets from all `q`'s; - supports returning accumulative attention of each keys. - supports [MQA](https://arxiv.org/abs/1911.02150) and [GQA](https://arxiv.org/pdf/2305.13245). +- supports dropout of attention weights. #### Limitations - `headdim` should be in `[16, 32, 64, 128]`. -- dropout of attention weights is not supported yet. ## TODOs diff --git a/README_cn.md b/README_cn.md index 3075b53..8bdab2e 100644 --- a/README_cn.md +++ b/README_cn.md @@ -224,11 +224,12 @@ print(gq) - 支持前向和反向计算; - K/V 的序列长度可以不等于 Q 的序列长度; - 支持计算每个 k 从所有 q 得到的 attention 总和。 +- 支持 [MQA](https://arxiv.org/abs/1911.02150) and [GQA](https://arxiv.org/pdf/2305.13245). +- 支持对 attention weights 进行 dropout. #### 限制 - `headdim` 必须为 `[16, 32, 64, 128]` 之一; -- 尚未支持对 attention weight 使用 dropout。 ## TODOs
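With the series applied, the dropout path can be exercised end to end roughly as below (a sketch; shapes, dtype and the dropout probability are illustrative, not taken from the patches):

    import torch
    import flag_attn

    q = torch.randn(2, 4, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    k = torch.randn(2, 4, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True)
    v = torch.randn(2, 4, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True)

    # dropout_p > 0 draws a fresh Philox (seed, offset) pair per call and applies
    # inverted dropout to the attention weights in both forward and backward.
    o = flag_attn.flash_attention(q, k, v, causal=True, dropout_p=0.1)
    o.sum().backward()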