
Supporting memory efficient dropout in flash attention #23

Merged 10 commits on Jun 5, 2024
2 changes: 1 addition & 1 deletion README.md
@@ -234,11 +234,11 @@ The performance of piecewise_attention has improved compared to that in v0.1. In
- supports computing the total attention each `k` gets from all `q`'s;
- supports returning the accumulative attention of each key.
- supports [MQA](https://arxiv.org/abs/1911.02150) and [GQA](https://arxiv.org/pdf/2305.13245).
- supports dropout of attention weights.
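
A hypothetical usage sketch for the new dropout option follows. The `flash_attention` entry point, the tensor layout `(batch, heads, seq_len, head_dim)`, and the `dropout_p` keyword are assumed names for illustration and are not confirmed by this diff; check the repository's README for the exact API.

```python
import torch
from flag_attn import flash_attention  # entry-point name assumed

B, H, T, D = 2, 8, 1024, 64
q = torch.randn(B, H, T, D, dtype=torch.float16, device="cuda")
k = torch.randn(B, H, T, D, dtype=torch.float16, device="cuda")
v = torch.randn(B, H, T, D, dtype=torch.float16, device="cuda")

# dropout_p is assumed to be the argument this PR adds for attention-weight dropout.
o = flash_attention(q, k, v, causal=True, dropout_p=0.1)
```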

#### Limitations

- `headdim` should be in `[16, 32, 64, 128]`.
- dropout of attention weights is not supported yet.

## TODOs

3 changes: 2 additions & 1 deletion README_cn.md
@@ -224,11 +224,12 @@ print(gq)
- supports forward and backward computation;
- the sequence length of K/V can differ from that of Q;
- supports computing the total attention each `k` gets from all `q`'s;
- supports [MQA](https://arxiv.org/abs/1911.02150) and [GQA](https://arxiv.org/pdf/2305.13245);
- supports dropout of attention weights.

#### Limitations

- `headdim` must be one of `[16, 32, 64, 128]`;
- dropout of attention weights is not supported yet.

## TODOs

15 changes: 15 additions & 0 deletions src/flag_attn/dropout.py
@@ -0,0 +1,15 @@
import torch
import triton
import triton.language as tl

def philox_cuda_seed_offset(increment, device=None):
    device = device or torch.cuda.current_device()
    gen = torch.cuda.default_generators[device]
    state_copy = gen.get_state()
    # The CUDA generator state holds the Philox seed and offset as two int64 values.
    c0, c1 = state_copy.view(torch.int64)
    seed, offset = int(c0), int(c1)
    # Round the requested number of random states up to a multiple of 4,
    # since Philox produces 4 random numbers per counter.
    increment = (increment + 3) // 4 * 4
    c1 += increment
    # get_state returns a new tensor, so it needs set_state to update the actual generator state.
    gen.set_state(state_copy)
    return seed, offset
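
For context, here is a minimal sketch (not the PR's attention kernel) of how a Triton kernel can consume the `(seed, offset)` pair returned by `philox_cuda_seed_offset`. Because `tl.rand(seed, counter)` is a counter-based generator, the same seed and counters reproduce the same random numbers, so the backward pass can rebuild the dropout mask from two integers instead of storing it. The `_dropout_kernel` and `dropout` names below are illustrative.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _dropout_kernel(x_ptr, out_ptr, n_elements, p, seed, offset, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    idx = pid * BLOCK + tl.arange(0, BLOCK)
    mask = idx < n_elements
    x = tl.load(x_ptr + idx, mask=mask)
    # The same (seed, offset + idx) always yields the same random value,
    # so the mask can be regenerated later without being stored.
    r = tl.rand(seed, offset + idx)
    keep = r > p
    y = tl.where(keep, x / (1.0 - p), 0.0)
    tl.store(out_ptr + idx, y, mask=mask)

def dropout(x, p):
    out = torch.empty_like(x)
    n = x.numel()
    # Reserve n Philox states and advance the global CUDA generator offset.
    seed, offset = philox_cuda_seed_offset(n)
    grid = (triton.cdiv(n, 1024),)
    _dropout_kernel[grid](x, out, n, p, seed, offset, BLOCK=1024)
    return out
```

Storing only `(seed, offset)` is what makes the dropout memory efficient: the mask is never materialized in global memory.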