Add flash decoding (flash attention with split_kv) #17

Merged: 4 commits into FlagOpen:main on Feb 5, 2024

Conversation

iclementine (Collaborator) commented on Feb 1, 2024

Implement flash decoding (flash attention with split_kv): https://princeton-nlp.github.io/flash-decoding/. This algorithm is used when batch_size * num_heads * blocks_along_seqlen_q cannot saturate the GPU's SMs.
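For reference, a minimal PyTorch sketch of the split-KV idea (this is not the Triton kernel in this PR; shapes and names are illustrative): each split of the key/value sequence produces a partial softmax output plus its log-sum-exp, and the partials are combined with a numerically stable rescaling.

```python
# Minimal sketch of split-KV decoding, assuming seqlen_q == 1.
# NOT the Triton kernel from this PR; it only illustrates the
# split-and-combine math that the kernel parallelizes over the grid.
import math
import torch

def attention_split_kv(q, k, v, num_splits):
    # q: (num_heads, 1, d); k, v: (num_heads, seqlen_k, d)
    d = q.shape[-1]
    scale = 1.0 / math.sqrt(d)
    partial_out, partial_lse = [], []
    for idx in torch.chunk(torch.arange(k.shape[1]), num_splits):
        s = (q @ k[:, idx].transpose(-1, -2)) * scale       # (heads, 1, chunk)
        lse = torch.logsumexp(s, dim=-1, keepdim=True)      # per-split log-sum-exp
        partial_out.append(torch.exp(s - lse) @ v[:, idx])  # per-split softmax output
        partial_lse.append(lse)
    out = torch.stack(partial_out)                          # (splits, heads, 1, d)
    lse = torch.stack(partial_lse)                          # (splits, heads, 1, 1)
    # Reduce: weight each split by exp(lse_i - global_lse) and sum.
    weights = torch.exp(lse - torch.logsumexp(lse, dim=0))
    return (weights * out).sum(dim=0)
```

Each split is independent, so the splits can be mapped to extra thread blocks; only the final reduce step needs the per-split log-sum-exps.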

Benchmark results on RTX-3090.

batch_size=2, num_heads=32, seqlen_q=1, seqlen_k=N_CTX

attention_d-64_dtype-torch.float16 (ms):
       N_CTX  flag_attn      torch    flash-2
0      512.0   0.030585   0.037478   0.026489
1     1024.0   0.045577   0.047122   0.045771
2     2048.0   0.055814   0.068738   0.084344
3     4096.0   0.091942   0.109804   0.103477
4     8192.0   0.167185   0.192318   0.185000
5    16384.0   0.318114   0.358216   0.336406
6    32768.0   0.797097   0.725244   0.655855
7    65536.0   1.223194   1.454980   1.299972
8   131072.0   2.429410   2.982287   2.559437
9   262144.0   4.837376   6.334187   5.085689
10  524288.0   9.653147  13.354424  10.183790

batch_size=2, num_heads=16, seqlen_q=1, seqlen_k=N_CTX

attention_d-128_dtype-torch.float16 (ms):
       N_CTX  flag_attn      torch   flash-2
0      512.0   0.046705   0.033903  0.022163
1     1024.0   0.046658   0.046374  0.033584
2     2048.0   0.055688   0.067447  0.053302
3     4096.0   0.094723   0.107909  0.098266
4     8192.0   0.169925   0.189810  0.171835
5    16384.0   0.321480   0.348469  0.324517
6    32768.0   0.625961   0.676866  0.630376
7    65536.0   1.230001   1.345603  1.236284
8   131072.0   2.435531   2.695659  2.446050
9   262144.0   4.854816   5.393702  4.863354
10  524288.0   9.683682  10.757249  9.691238

@StrongSpoon (Collaborator) left a comment


👏

Inline comment on the diff:

return 1

num_n_blocks = triton.cdiv(N, BLOCK_N)
def num_split_avaiable(s):
Collaborator:

available?

Collaborator Author (iclementine):

When splitting num_n_blocks into s splits, some splits (the last split, for example) may end up with no valid workload.

For example, when splitting 64 blocks into 12 splits, with each split processing cdiv(64, 12) = 6 blocks, the last split is empty. So splitting 64 blocks into 12 splits is not available in this sense.
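To make that concrete, a hedged sketch of such a check in plain Python (the name follows the diff excerpt above, spelled correctly; the actual logic in the PR may differ):

```python
def cdiv(a, b):
    # ceiling division, the same operation as triton.cdiv
    return (a + b - 1) // b

def num_split_available(num_n_blocks, s):
    # s splits are only usable if none of them is empty: with each split
    # handling cdiv(num_n_blocks, s) blocks, the number of splits actually
    # needed at that size must equal s.
    blocks_per_split = cdiv(num_n_blocks, s)
    return cdiv(num_n_blocks, blocks_per_split) == s

# The example from the reply: cdiv(64, 12) = 6 blocks per split, but only
# 11 splits of 6 blocks are needed to cover 64 blocks, so the 12th split
# would be empty.
assert not num_split_available(64, 12)
assert num_split_available(64, 8)
```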

@iclementine iclementine merged commit 1641d0c into FlagOpen:main Feb 5, 2024
1 check passed