From ee91638dec6da8c00c4113d179f469e0ffcd5852 Mon Sep 17 00:00:00 2001
From: Tongxin Bai
Date: Wed, 5 Jun 2024 17:41:11 +0800
Subject: [PATCH] Supporting memory-efficient dropout in flash attention (#23)

1. Add dropout to regular flash attention.
2. Add philox_cuda_seed_offset to increment the offset of PyTorch's Philox
   random generator state.

---------

Co-authored-by: Clement Chan
---
 README.md                               |   2 +-
 README_cn.md                            |   3 +-
 src/flag_attn/dropout.py                |  15 +++
 src/flag_attn/flash.py                  | 140 +++++++++++++++++++++---
 src/flag_attn/testing/__init__.py       |   2 +-
 src/flag_attn/testing/dropout.py        |  25 +++++
 src/flag_attn/testing/flash.py          |  15 ++-
 tests/flag_attn/test_dropout.py         |  30 +++++
 tests/flag_attn/test_flash_attention.py |  99 ++++++++++++++++-
 9 files changed, 308 insertions(+), 23 deletions(-)
 create mode 100644 src/flag_attn/dropout.py
 create mode 100644 src/flag_attn/testing/dropout.py
 create mode 100644 tests/flag_attn/test_dropout.py

diff --git a/README.md b/README.md
index c8414cb..157dc57 100644
--- a/README.md
+++ b/README.md
@@ -234,11 +234,11 @@ The performance of piecewise_attention has improved compared to that in v0.1. In
 - support computation of total attention of each `k` gets from all `q`'s;
 - supports returning accumulative attention of each keys.
 - supports [MQA](https://arxiv.org/abs/1911.02150) and [GQA](https://arxiv.org/pdf/2305.13245).
+- supports dropout of attention weights.

 #### Limitations

 - `headdim` should be in `[16, 32, 64, 128]`.
-- dropout of attention weights is not supported yet.

 ## TODOs

diff --git a/README_cn.md b/README_cn.md
index 3075b53..8bdab2e 100644
--- a/README_cn.md
+++ b/README_cn.md
@@ -224,11 +224,12 @@ print(gq)
 - 支持前向和反向计算;
 - K/V 的序列长度可以不等于 Q 的序列长度;
 - 支持计算每个 k 从所有 q 得到的 attention 总和。
+- 支持 [MQA](https://arxiv.org/abs/1911.02150) 和 [GQA](https://arxiv.org/pdf/2305.13245);
+- 支持对 attention weights 进行 dropout。

 #### 限制

 - `headdim` 必须为 `[16, 32, 64, 128]` 之一;
-- 尚未支持对 attention weight 使用 dropout。

 ## TODOs

diff --git a/src/flag_attn/dropout.py b/src/flag_attn/dropout.py
new file mode 100644
index 0000000..62ac39c
--- /dev/null
+++ b/src/flag_attn/dropout.py
@@ -0,0 +1,15 @@
+import torch
+import triton
+import triton.language as tl
+
+def philox_cuda_seed_offset(increment, device=None):
+    device = device or torch.cuda.current_device()
+    gen = torch.cuda.default_generators[device]
+    state_copy = gen.get_state()
+    c0, c1 = state_copy.view(torch.int64)
+    seed, offset = int(c0), int(c1)
+    increment = (increment + 3) // 4 * 4
+    c1 += increment
+    # get_state returns a copy of the state, so set_state is needed to write the advanced offset back to the actual generator.
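+    # Note: the increment is rounded up to a multiple of 4 above because Philox
+    # produces 4 random numbers per counter step, which is how PyTorch keeps its
+    # own offset aligned.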
+ gen.set_state(state_copy) + return seed, offset diff --git a/src/flag_attn/flash.py b/src/flag_attn/flash.py index 73c4bac..e4968ee 100644 --- a/src/flag_attn/flash.py +++ b/src/flag_attn/flash.py @@ -6,6 +6,8 @@ from flag_attn.split_kv import _fwd_split_kv_kernel, _fwd_combine_kv_splits, num_splits_herustic from flag_attn.split_kv import get_fwd_config as get_fwd_config_kv_split +from .dropout import philox_cuda_seed_offset + __all__ = ["attention"] @@ -20,8 +22,7 @@ def rounded_multiple(a, b): # --------------------------- public API --------------------------- class FlashAttention(torch.autograd.Function): @staticmethod - def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_attention): - # size, stride, dtype checking + def forward(ctx, q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, return_total_attention, return_seed_offset): Dq, Dk, Dv = q.shape[-1], k.shape[-1], v.shape[-1] assert Dq == Dk == Dv, "feature size of q, k, v should be equal" assert Dk in {16, 32, 64, 128} @@ -47,6 +48,14 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_ num_sms = torch.cuda.get_device_properties(device).multi_processor_count with torch.cuda.device(device): + # Dropout preparation. + is_dropout = dropout_p > 0 + if is_dropout: + offset_increment = B * H * M * N + seed, offset = philox_cuda_seed_offset(offset_increment) + else: + seed, offset = 0, 0 + config_for_split_kv = get_fwd_config_kv_split(B, H, M, N, D, causal) S = num_splits_herustic(B, H, M, N, config_for_split_kv[0], config_for_split_kv[1], num_sms, 128) split_kv: bool = S > 1 @@ -64,6 +73,7 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_ L = torch.empty((B, H, M), device=q.device, dtype=torch.float32) _fwd_kernel[grid]( q, k, v, sm_scale, + dropout_p, seed, offset, L, o, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), @@ -71,11 +81,12 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_ o.stride(0), o.stride(1), o.stride(2), o.stride(3), B, H, M, N, P_SEQ, num_groups, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D, - IS_CAUSAL=causal, LARGER_M=larger_m, + IS_CAUSAL=causal, IS_DROPOUT=is_dropout, LARGER_M=larger_m, DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, num_warps=num_warps, num_stages=num_stages, ) else: # split kv + assert not is_dropout, "Cannot apply dropout with splitkv." 
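+                # Dropout is only implemented in the non-split kernels, so the
+                # split-kv path rejects it with the assertion above.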
BLOCK_M, BLOCK_N, num_stages, num_warps = config_for_split_kv divisible_m = M % BLOCK_M == 0 @@ -132,13 +143,18 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_ ctx.save_for_backward(q, k, v, o, L) ctx.sm_scale = sm_scale ctx.causal = causal + ctx.dropout_p = dropout_p + ctx.seed = seed + ctx.offset = offset - has_extra_return = return_log_normalizer or return_total_attention + has_extra_return = True in (return_log_normalizer, return_total_attention, return_seed_offset) if has_extra_return: outs = ( o, L if return_log_normalizer else None, - tot_attn if return_total_attention else None + tot_attn if return_total_attention else None, + seed if is_dropout and return_seed_offset else None, + offset if is_dropout and return_seed_offset else None ) return outs return o @@ -148,6 +164,10 @@ def backward(ctx, do, *ignored): q, k, v, o, L = ctx.saved_tensors sm_scale = ctx.sm_scale causal = ctx.causal + dropout_p = ctx.dropout_p + is_dropout = ctx.dropout_p > 0 + seed = ctx.seed + offset = ctx.offset B, H, M, D = q.shape N = k.shape[2] @@ -189,6 +209,9 @@ def backward(ctx, do, *ignored): q, k, v, sm_scale, do, dk, dv, L, delta, + dropout_p, + seed, + offset, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), v.stride(0), v.stride(1), v.stride(2), v.stride(3), @@ -198,6 +221,7 @@ def backward(ctx, do, *ignored): B, H, M, N, P_SEQ, num_groups, BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, CAUSAL=causal, + IS_DROPOUT=is_dropout, DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, num_stages=num_stages, num_warps=num_warps, ) @@ -208,6 +232,9 @@ def backward(ctx, do, *ignored): q, k, v, sm_scale, do, dq, L, delta, + dropout_p, + seed, + offset, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), v.stride(0), v.stride(1), v.stride(2), v.stride(3), @@ -216,17 +243,17 @@ def backward(ctx, do, *ignored): B, H, M, N, P_SEQ, num_groups, BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, - CAUSAL=causal, LARGER_M=larger_m, + CAUSAL=causal, IS_DROPOUT=is_dropout, LARGER_M=larger_m, DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n, num_stages=num_stages, num_warps = num_warps, ) dk = dk.reshape((B, Hk, num_groups, N, D)).sum(2) dv = dv.reshape((B, Hk, num_groups, N, D)).sum(2) - return dq, dk, dv, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None -def attention(q, k, v, causal=False, sm_scale=None, - return_log_normalizer=False, return_total_attention=False, +def attention(q, k, v, causal=False, sm_scale=None, dropout_p=0.0, + return_log_normalizer=False, return_total_attention=False, return_seed_offset=False ): """ An implementation of FlashAttention v2(https://arxiv.org/abs/2307.08691). @@ -237,21 +264,26 @@ def attention(q, k, v, causal=False, sm_scale=None, v(torch.Tensor): The values. The shape is (batch_size, num_heads_k, seqlen_k, headdim). causal(bool): Whether causal masking is applied to attention scores before applying softmax. sm_scale(float): The scaling of attention scores before applying softmax. + dropout_p(float): Dropout probability. return_log_normalizer(bool): Whether to return the log normalizer of softmax inside attention. return_total_attention(bool): Whether to return the sum of attention along q's sequence dimendion. + return_seed_offset(bool): Whether to return dropout seed and offset Returns: out(torch.Tensor): The output. The shape is (batch_size, num_heads_q, seqlen_q, headdim). 
- If `return_log_normalizer` or `return_total_attention`, return the following results in addition. + If `return_log_normalizer` or `return_total_attention` or `return_seed_offset` is True, + return the following results in addition. log_normalizer(torch.Tensor): The log normalizer. The shape is (batch_size, num_heads_q, seqlen_q). total_attention(torch.Tensor): The total attention. The shape is (batch_size, num_heads_q, seqlen_k). + seed(int): The Philox seed used in dropout. + offset(int): The starting Philox offset used in dropout. Notes: `num_heads_q` must be a multiple of `num_heads_k`. """ - return FlashAttention.apply(q, k, v, causal, sm_scale, return_log_normalizer, return_total_attention) + return FlashAttention.apply(q, k, v, causal, sm_scale, dropout_p, return_log_normalizer, return_total_attention, return_seed_offset) # --------------------------- Forward --------------------------- @@ -293,6 +325,9 @@ def get_fwd_config(B, H, M, N, D, causal): @triton.jit def _fwd_kernel( Q, K, V, sm_scale, + dropout_p, + seed, + offset, L, O, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, @@ -301,7 +336,7 @@ def _fwd_kernel( Z, H, M, N, P_SEQ, num_groups, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, + IS_CAUSAL: tl.constexpr, IS_DROPOUT: tl.constexpr, LARGER_M: tl.constexpr, DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, ): input_dtype = Q.dtype.element_ty @@ -329,6 +364,12 @@ def _fwd_kernel( offs_n_base = tl.arange(0, BLOCK_N) offs_k = tl.arange(0, BLOCK_DMODEL) + if IS_DROPOUT: + rowblock_base = off_z * H * M * N + off_h * M * N + start_m * BLOCK_M * N + offs_rng_base = offset + rowblock_base + offs_rng_base += tl.arange(0, BLOCK_M)[:, None] * N + offs_rng_base += tl.arange(0, BLOCK_N)[None, :] + # initialize pointers to value-like data q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) # (BLOCK_M, BLOCK_DMODEL) o_ptrs = O + (offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok) # (BLOCK_M, BLOCK_DMODEL) @@ -406,12 +447,21 @@ def _fwd_kernel( alpha = tl.math.exp2((m_i - m_i_new) * qk_scale) p = tl.math.exp2(s * qk_scale - m_i_new[:, None] * qk_scale) + # -- compute partial sumexpn before applying dropout + p_sum = tl.sum(p, 1) + + # -- apply dropout -- + if IS_DROPOUT: + offs_rng = start_n + offs_rng_base + pmask = tl.rand(seed, offs_rng, n_rounds=6) > dropout_p + p *= pmask.to(tl.float32) + # -- scale and update acc: acc *= alpha[:, None]-- acc *= alpha[:, None] acc += tl.dot(p.to(input_dtype), v) # -- update m_i and l_i -- - l_i = l_i * alpha + tl.sum(p, 1) + l_i = l_i * alpha + p_sum m_i = m_i_new # update pointers k_ptrs += BLOCK_N * stride_kn @@ -426,6 +476,11 @@ def _fwd_kernel( acc = acc * (1.0 / l_i[:, None]) l = m_i * sm_scale + tl.log(l_i) # log(normalizer) + # -- scale o due to dropout + if IS_DROPOUT: + scale = 1.0 / (1.0 - dropout_p) + acc *= scale + if DIVISIBLE_M: tl.store(l_ptrs, l, cache_modifier=".cg") tl.store(o_ptrs, acc.to(input_dtype), cache_modifier=".cg") @@ -499,6 +554,12 @@ def _bwd_preprocess( # compute delta = tl.sum(o * do, axis=1) + + # (NOTE) dropout scaling doesn't affect delta's value + # when dropout is applied, o and do are actually scaled. + # original_o equals o times reverse scale while original_do is do times scale, + # and thus delta remains unchanged. 
+ # write-back d_ptrs = Delta + off_m * stride_dm if DIVISIBLE_M: @@ -513,6 +574,9 @@ def _bwd_kv_kernel( DK, DV, L, D, + dropout_p, + seed, + offset, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vn, stride_vk, @@ -522,7 +586,7 @@ def _bwd_kv_kernel( Z, H, M, N, P_SEQ, num_groups, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - CAUSAL: tl.constexpr, + CAUSAL: tl.constexpr, IS_DROPOUT: tl.constexpr, DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, ): input_dtype = Q.dtype.element_ty @@ -577,6 +641,14 @@ def _bwd_kv_kernel( v = tl.load(v_ptrs, mask=mask_n[:, None]) k = tl.load(k_ptrs, mask=mask_n[:, None]) + # dropout + if IS_DROPOUT: + colblock_base = off_z * H * M * N + off_h * M * N + start_n * BLOCK_N + offs_rng_base = offset + colblock_base + offs_rng_base += tl.arange(0, BLOCK_M)[:, None] * N + offs_rng_base += tl.arange(0, BLOCK_N)[None, :] + rp = 1. / (1. - dropout_p) + # initialize dk amd dv dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) @@ -621,7 +693,19 @@ def _bwd_kv_kernel( do = tl.load(do_ptrs) else: do = tl.load(do_ptrs, mask=mask_m[:, None]) # (BLOCK_M, BLOCK_DMODEL) - dv += tl.dot(tl.trans(p.to(do.dtype)), do) # (BLOCK_N, BLOCK_DMODEL) # still correct + + if IS_DROPOUT: + # do *= rp + offs_rng = offs_rng_base + start_m * N + pmask = tl.rand(seed, offs_rng, n_rounds=6) > dropout_p + p_masked = p * pmask + p_masked = p_masked.to(input_dtype) + + # -- apply dropout -- + if IS_DROPOUT: + dv += tl.dot(tl.trans(p_masked), do) * rp # (BLOCK_N, BLOCK_DMODEL) # still correct + else: + dv += tl.dot(tl.trans(p).to(input_dtype), do) # (BLOCK_N, BLOCK_DMODEL) # still correct # compute dp = dot(v, do) if DIVISIBLE_M: @@ -631,6 +715,11 @@ def _bwd_kv_kernel( dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) dp += tl.dot(do, tl.trans(v)) + # -- apply dropout -- + if IS_DROPOUT: + dp *= rp + dp *= pmask + # compute ds = p * (dp - delta[:, None]) ds = p * (dp - delta[:, None]) # (BLOCK_M, BLOCK_N) @@ -662,6 +751,9 @@ def _bwd_q_kernel( DQ, L, D, + dropout_p, + seed, + offset, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vn, stride_vk, @@ -670,7 +762,7 @@ def _bwd_q_kernel( Z, H, M, N, P_SEQ, num_groups, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, + CAUSAL: tl.constexpr, IS_DROPOUT: tl.constexpr, LARGER_M: tl.constexpr, DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, ): input_dtype = Q.dtype.element_ty @@ -739,6 +831,15 @@ def _bwd_q_kernel( else: hi = N + # dropout + if IS_DROPOUT: + rowblock_base = off_z * H * M * N + off_h * M * N + start_m * BLOCK_M * N + offs_rng_base = offset + rowblock_base + offs_rng_base += tl.arange(0, BLOCK_M)[:, None] * N + offs_rng_base += tl.arange(0, BLOCK_N)[None, :] + rp = 1. / (1. 
- dropout_p) + do *= rp.to(do.dtype) + # loop over a row for start_n in range(0, hi, BLOCK_N): offs_n = start_n + offs_n_base @@ -772,6 +873,13 @@ def _bwd_q_kernel( # compute dp = dot(v, do) dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) dp += tl.dot(do.to(input_dtype), tl.trans(v)) + + if IS_DROPOUT: + offs_rng = start_n + offs_rng_base + pmask = tl.rand(seed, offs_rng, n_rounds=6) > dropout_p + dp *= pmask + # p_dropout = p * pmask.to(tl.float32) + # no need to mask dp # if CAUSAL: # dp = tl.where(causal_mask & valid_mask, dp, 0.0) diff --git a/src/flag_attn/testing/__init__.py b/src/flag_attn/testing/__init__.py index da03cff..02da1b1 100644 --- a/src/flag_attn/testing/__init__.py +++ b/src/flag_attn/testing/__init__.py @@ -1,4 +1,4 @@ from flag_attn.testing.flash import attention as flash_attention # noqa: F401 from flag_attn.testing.piecewise import attention as piecewise_attention # noqa: F401 from flag_attn.testing.paged import attention as paged_attention # noqa: F401 - +from flag_attn.testing.dropout import recompute_mask \ No newline at end of file diff --git a/src/flag_attn/testing/dropout.py b/src/flag_attn/testing/dropout.py new file mode 100644 index 0000000..fbf5106 --- /dev/null +++ b/src/flag_attn/testing/dropout.py @@ -0,0 +1,25 @@ +import torch +import triton +import triton.language as tl + +@triton.jit +def recompute_mask_kernel(mask, B, H, M, N, dropout_p, seed, offset): + row, b, h = tl.program_id(0), tl.program_id(1), tl.program_id(2) + offs_base = b * H * M * N + h * M * N + row * N + BLOCK: tl.constexpr = 1024 + offs_base += tl.arange(0, BLOCK) + for start_n in range(0, N, BLOCK): + offs = start_n + offs_base + rng_offs = offset + offs + pmask = tl.rand(seed, rng_offs, n_rounds=6) > dropout_p + row_mask = start_n + tl.arange(0, BLOCK) < N + tl.store(mask + offs, pmask, mask=row_mask) + +def recompute_mask(B, H, M, N, dropout_p, seed, offset, device): + mask = torch.full((B, H, M, N), True, dtype=torch.bool, device=device) + if dropout_p == 0: + return mask + grid = (M, B, H) + with torch.cuda.device(device): + recompute_mask_kernel[grid](mask, B, H, M, N, dropout_p, seed, offset) + return mask diff --git a/src/flag_attn/testing/flash.py b/src/flag_attn/testing/flash.py index 4314019..5c092dd 100644 --- a/src/flag_attn/testing/flash.py +++ b/src/flag_attn/testing/flash.py @@ -1,10 +1,13 @@ import math import torch + def attention(q, k, v, causal, + dropout_p=0.0, + dropout_mask=None, sm_scale=None, return_log_normalizer=False, return_total_attention=False, @@ -49,13 +52,19 @@ def attention(q, if return_total_attention: tot_attn = torch.sum(P, dim=-2) - attn_output = torch.matmul(P.to(v.dtype), v).to(input_dtype) + # Applies dropout + dropout_scaling = 1.0 / (1 - dropout_p) + if dropout_mask is not None: + P = P.masked_fill(~dropout_mask, 0.0) + + attn_output = torch.matmul(P.to(v.dtype), v) * dropout_scaling + attn_output = attn_output.to(input_dtype) has_extra_return = return_log_normalizer or return_total_attention if has_extra_return: outs = (attn_output, - log_normalizer if return_log_normalizer else None, - tot_attn if return_total_attention else None) + log_normalizer if return_log_normalizer else None, + tot_attn if return_total_attention else None) return outs else: return attn_output diff --git a/tests/flag_attn/test_dropout.py b/tests/flag_attn/test_dropout.py new file mode 100644 index 0000000..3128c29 --- /dev/null +++ b/tests/flag_attn/test_dropout.py @@ -0,0 +1,30 @@ +import torch +import pytest +from flag_attn.testing import recompute_mask + + 
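+# recompute_mask rebuilds the Boolean keep-mask (True = keep) from (seed, offset)
+# using the same Philox indexing as the fused kernels, so tests can apply it in
+# the reference implementation.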
+@pytest.mark.parametrize('B, H, M, N', [ + (2, 4, 512, 612), + (2, 4, 1024, 1034), + (2, 4, 2048, 2048), + (2, 4, 4096, 4096), + (2, 4, 4001, 4001), + (2, 4, 4001, 4096), + (2, 4, 4096, 4000), + (1, 2, 8192, 8202), + (1, 2, 8192, 8192), +]) +@pytest.mark.parametrize('p', [0.5, 0.8]) +def test_recompute_mask(B, H, M, N, p): + import math + seed = 123456789 + offset = 123456789123456789 + device = torch.cuda.current_device() + mask = recompute_mask(B, H, M, N, p, seed, offset, device) + # zeros indicate to drop + # k follows Binomial distributio B(k; n, p) + n = mask.numel() + k = torch.sum(mask == 0) + p_cap = k / n + tol = 0.01 + assert math.fabs(p_cap - p) < tol * p diff --git a/tests/flag_attn/test_flash_attention.py b/tests/flag_attn/test_flash_attention.py index df6ca20..3f512c0 100644 --- a/tests/flag_attn/test_flash_attention.py +++ b/tests/flag_attn/test_flash_attention.py @@ -206,7 +206,7 @@ def test_attention_with_aux_outs(B, H, M, N, D, causal, stride_order, dtype, sca o_ref, log_norm_ref, tot_attn_ref = flag_attn.testing.flash_attention(q, k, v, causal, return_log_normalizer=True, return_total_attention=True, upcast=True) o_torch, log_norm_torch, tot_attn_torch = flag_attn.testing.flash_attention(q, k, v, causal, return_log_normalizer=True, return_total_attention=True, upcast=False) - o_hyp, log_norm_hyp, tot_attn_hyp = flag_attn.flash_attention(q, k, v, causal, return_log_normalizer=True, return_total_attention=True) + o_hyp, log_norm_hyp, tot_attn_hyp, *_ = flag_attn.flash_attention(q, k, v, causal, return_log_normalizer=True, return_total_attention=True) torch_max_diff = max_diff(o_torch, o_ref) @@ -221,3 +221,100 @@ def test_attention_with_aux_outs(B, H, M, N, D, causal, stride_order, dtype, sca triton_max_diff = max_diff(tot_attn_hyp, tot_attn_ref) assert triton_max_diff <= 2 * torch_max_diff + 1e-5 + +@pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count()))) +@pytest.mark.parametrize('scale', [1.0, 2.0]) +@pytest.mark.parametrize('B, H, M, N, D', [ + (2, 4, 512, 612, 128), + (2, 4, 1024, 1034, 64), + (2, 4, 2048, 2048, 32), + (2, 4, 4096, 4096, 16), + (2, 4, 4001, 4001, 32), + (2, 4, 4001, 4096, 64), + (2, 4, 4096, 4000, 128), + (1, 2, 8192, 8202, 16), + (1, 2, 8192, 8192, 32), +]) +@pytest.mark.parametrize('causal', [False, True]) +@pytest.mark.parametrize('dropout_p', [0.5, 0.8]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD']) +def test_attention_fwd_dropout(B, H, M, N, D, causal, dropout_p, stride_order, dtype, scale, device_id): + device = f"cuda:{device_id}" + if stride_order == "BHTD": + q = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale) + k = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale) + v = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale) + else: + q = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) + k = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) + v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2) + + o_hyp, _, _, seed, offset = flag_attn.flash_attention(q, k, v, causal, dropout_p=dropout_p, return_seed_offset=True) + mask = flag_attn.testing.recompute_mask(B, H, M, N, dropout_p, seed, offset, device) + o_ref = flag_attn.testing.flash_attention(q, k, v, causal, dropout_p=dropout_p, dropout_mask=mask, 
upcast=True) + o_torch = flag_attn.testing.flash_attention(q, k, v, causal, dropout_p=dropout_p, dropout_mask=mask, upcast=False) + + torch_max_diff = max_diff(o_torch, o_ref) + triton_max_diff = max_diff(o_hyp, o_ref) + report("o hyp", o_hyp, o_ref) + report("o torch", o_torch, o_ref) + assert triton_max_diff <= 2 * torch_max_diff + 1e-5 + + +import random +# @pytest.mark.parametrize('increment', [random.randint(0, 1000000000) for i in range(100)]) +@pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count()))) +@pytest.mark.parametrize('scale', [1.0, 2.0]) +@pytest.mark.parametrize('B, H, M, N, D', [ + (2, 4, 512, 612, 128), + (2, 4, 1024, 1034, 64), + (2, 4, 2048, 2048, 32), + (2, 4, 4096, 4096, 16), + (2, 4, 4001, 4001, 32), + (2, 4, 4001, 4096, 64), + (2, 4, 4096, 4000, 128), + (1, 2, 8192, 8202, 16), + (1, 2, 8192, 8192, 32), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('dropout_p', [0.5, 0.8]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD']) +def test_attention_bwd_dropout(B, H, M, N, D, causal, dropout_p, stride_order, dtype, scale, device_id): + device = f"cuda:{device_id}" + if stride_order == "BHTD": + q = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() + k = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() + v = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_() + do = torch.randn((B, H, M, D), dtype=dtype, device=device) + else: + q = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() + k = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() + v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_() + do = torch.randn((B, M, H, D), dtype=dtype, device=device).transpose(1, 2) + + # from flag_attn.dropout import philox_cuda_seed_offset + o_hyp, _, _, seed, offset = flag_attn.flash_attention(q, k, v, causal=causal, dropout_p=dropout_p, return_seed_offset=True) + mask = flag_attn.testing.recompute_mask(B, H, M, N, dropout_p, seed, offset, device) + o_ref = flag_attn.testing.flash_attention(q, k, v, causal=causal, dropout_p=dropout_p, dropout_mask=mask, upcast=True) + o_torch = flag_attn.testing.flash_attention(q, k, v, causal=causal, dropout_p=dropout_p, dropout_mask=mask, upcast=False) + + gq_ref, gk_ref, gv_ref = torch.autograd.grad(o_ref, (q, k, v), do) + gq_torch, gk_torch, gv_torch = torch.autograd.grad(o_torch, (q, k, v), do) + gq_hyp, gk_hyp, gv_hyp = torch.autograd.grad(o_hyp, (q, k, v), do) + + o_torch_max_diff = max_diff(o_torch, o_ref) + gq_torch_max_diff = max_diff(gq_torch, gq_ref) + gk_torch_max_diff = max_diff(gk_torch, gk_ref) + gv_torch_max_diff = max_diff(gv_torch, gv_ref) + + o_triton_max_diff = max_diff(o_hyp, o_ref) + gq_triton_max_diff = max_diff(gq_hyp, gq_ref) + gk_triton_max_diff = max_diff(gk_hyp, gk_ref) + gv_triton_max_diff = max_diff(gv_hyp, gv_ref) + + assert o_triton_max_diff < 2 * o_torch_max_diff + 1e-5 + assert gq_triton_max_diff < 2 * gq_torch_max_diff + 1e-5 + assert gk_triton_max_diff < 2 * gk_torch_max_diff + 1e-5 + assert gv_triton_max_diff < 2 * gv_torch_max_diff + 1e-5 \ No newline at end of file
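
A minimal usage sketch of the dropout-enabled API added by this patch, following the calls exercised in the tests above (the shapes, dtype, causal flag, and dropout_p value are illustrative only):

    import torch
    import flag_attn

    B, H, M, N, D = 2, 4, 1024, 1024, 64
    q = torch.randn(B, H, M, D, dtype=torch.float16, device="cuda")
    k = torch.randn(B, H, N, D, dtype=torch.float16, device="cuda")
    v = torch.randn(B, H, N, D, dtype=torch.float16, device="cuda")

    # Forward with 10% attention-weight dropout; also request the Philox
    # seed/offset so the dropout mask can be reproduced outside the kernel.
    o, log_norm, tot_attn, seed, offset = flag_attn.flash_attention(
        q, k, v, causal=True, dropout_p=0.1,
        return_log_normalizer=True, return_total_attention=True,
        return_seed_offset=True)

    # Rebuild the same keep-mask and feed it to the reference implementation
    # to check the fused kernel's output.
    mask = flag_attn.testing.recompute_mask(B, H, M, N, 0.1, seed, offset, q.device)
    o_ref = flag_attn.testing.flash_attention(q, k, v, causal=True, dropout_p=0.1,
                                              dropout_mask=mask, upcast=True)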