Supporting memory-efficient dropout in flash attention (#23)
1. Add dropout to regular flash attention.
2. Add `philox_cuda_seed_offset` to increment the offset of PyTorch's Philox random generator state.
---------

Co-authored-by: Clement Chan <[email protected]>
tongxin and iclementine authored Jun 5, 2024
1 parent 13664fc commit ee91638
Showing 9 changed files with 308 additions and 23 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -234,11 +234,11 @@ The performance of piecewise_attention has improved compared to that in v0.1. In
- supports computing the total attention each `k` gets from all `q`'s;
- supports returning the accumulated attention of each key.
- supports [MQA](https://arxiv.org/abs/1911.02150) and [GQA](https://arxiv.org/pdf/2305.13245).
- supports dropout of attention weights.

#### Limitations

- `headdim` should be in `[16, 32, 64, 128]`.
- dropout of attention weights is not supported yet.

## TODOs

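For context on the new feature bullet above, usage from the Python side would look roughly like the sketch below. This is a hedged example: the entry point `flag_attn.flash_attention` and the argument name `dropout_p` are assumptions for illustration, not confirmed by the excerpt shown here.

```python
import torch
import flag_attn  # hypothetical usage; `flash_attention` and `dropout_p` are assumed names

q = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16)

# Dropout is applied to the attention weights inside the fused kernel; the mask
# is derived from PyTorch's Philox (seed, offset) state rather than stored.
o = flag_attn.flash_attention(q, k, v, causal=True, dropout_p=0.1)
```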
3 changes: 2 additions & 1 deletion README_cn.md
@@ -224,11 +224,12 @@ print(gq)
- supports forward and backward computation;
- the sequence length of K/V may differ from the sequence length of Q;
- supports computing the total attention each k gets from all q's.
- supports [MQA](https://arxiv.org/abs/1911.02150) and [GQA](https://arxiv.org/pdf/2305.13245).
- supports dropout of attention weights.

#### Limitations

- `headdim` must be one of `[16, 32, 64, 128]`;
- dropout of attention weights is not supported yet.

## TODOs

Expand Down
15 changes: 15 additions & 0 deletions src/flag_attn/dropout.py
@@ -0,0 +1,15 @@
import torch
import triton
import triton.language as tl

def philox_cuda_seed_offset(increment, device=None):
    # Read the (seed, offset) pair of PyTorch's default Philox CUDA generator and
    # advance its offset by `increment`, reserving that many counter values for
    # the caller so that later draws do not overlap with them.
    device = device or torch.cuda.current_device()
    gen = torch.cuda.default_generators[device]
    state_copy = gen.get_state()
    c0, c1 = state_copy.view(torch.int64)
    seed, offset = int(c0), int(c1)
    # Round the increment up to a multiple of 4: Philox produces 4 random
    # numbers per counter value.
    increment = (increment + 3) // 4 * 4
    c1 += increment
    # get_state returns a copy, so set_state is needed to write the advanced
    # offset back into the actual generator state.
    gen.set_state(state_copy)
    return seed, offset
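To show how this helper is meant to be consumed, here is a minimal sketch of a seed-based Triton dropout kernel driven by the returned `(seed, offset)` pair, in the spirit of this commit. The kernel `_dropout_kernel`, the `dropout` wrapper, and the block size are illustrative assumptions, not code from this repository.

```python
import torch
import triton
import triton.language as tl

from flag_attn.dropout import philox_cuda_seed_offset


@triton.jit
def _dropout_kernel(x_ptr, out_ptr, n_elements, p, seed, offset, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    idx = pid * BLOCK + tl.arange(0, BLOCK)
    mask = idx < n_elements
    x = tl.load(x_ptr + idx, mask=mask)
    # One Philox draw per element: offsetting by the element index keeps the
    # draws of this launch inside the range reserved by philox_cuda_seed_offset.
    r = tl.rand(seed, offset + idx)
    keep = r > p
    y = tl.where(keep, x / (1.0 - p), 0.0)
    tl.store(out_ptr + idx, y, mask=mask)


def dropout(x, p=0.1):
    # x is assumed to be a contiguous CUDA tensor.
    out = torch.empty_like(x)
    n = x.numel()
    # Reserve n counter values from PyTorch's Philox state; the returned offset
    # is the value *before* the increment, so this launch owns [offset, offset + n).
    seed, offset = philox_cuda_seed_offset(n)
    BLOCK = 1024
    grid = (triton.cdiv(n, BLOCK),)
    _dropout_kernel[grid](x, out, n, p, seed, offset, BLOCK=BLOCK)
    return out
```

Because the dropout mask is a pure function of `(seed, offset)`, the backward pass can regenerate exactly the same pattern instead of materializing the mask in memory, which is what makes this style of dropout memory-efficient.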