
Feature request: Add Dropout #7

Open
hypnopump opened this issue Dec 18, 2023 · 4 comments

hypnopump commented Dec 18, 2023

The PyTorch base implementation of scaled_dot_product_attention provides dropout as an argument (dropout_p). Fusing it into the Triton kernel would replicate that functionality, since dropout is applied to the attention scores, not to the output.

In the CUDA version, it is supported here.
There have been attempts at integrating it into Triton before.
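For reference, a minimal sketch of the PyTorch behaviour this request refers to (shapes and the dropout probability are arbitrary): dropout_p is applied inside scaled_dot_product_attention to the attention probabilities.

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 128, 64, device="cuda")
k = torch.randn(1, 8, 128, 64, device="cuda")
v = torch.randn(1, 8, 128, 64, device="cuda")

# dropout_p acts on softmax(Q K^T / sqrt(d)), i.e. the attention scores,
# before they are multiplied with V
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1)
```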

iclementine (Collaborator) commented Dec 27, 2023

Yes, it is in our plan and actually a work in progress (WIP).

Since Triton now provides a pseudo-random generator, we can implement a memory-efficient flash attention with dropout without having to save the dropout mask (which would require O(n^2) memory). The essence is to re-generate, in the backward pass, the same dropout mask that was used in the forward pass.
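A minimal sketch of that idea on a plain 1-D tensor (not this repo's attention kernel; the helper name and block size are illustrative): because the counter-based PRNG is a pure function of (seed, offsets), a backward kernel can regenerate exactly the same keep-mask instead of storing it.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def seeded_dropout_kernel(x_ptr, out_ptr, n_elements, p, seed, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # Counter-based PRNG: the same (seed, offsets) always yields the same
    # random numbers, so a backward kernel can rebuild the identical keep-mask
    # instead of materializing an O(n^2) attention dropout mask.
    random = tl.rand(seed, offsets)
    keep = random > p
    # Inverted dropout: scale the kept values by 1 / (1 - p).
    y = tl.where(keep, x / (1 - p), 0.0)
    tl.store(out_ptr + offsets, y, mask=mask)


def seeded_dropout(x, p, seed):
    # Assumes x is contiguous; the kernel treats it as a flat buffer.
    out = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK"]),)
    seeded_dropout_kernel[grid](x, out, n, p, seed, BLOCK=1024)
    return out
```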

hypnopump (Author) commented

I have implemented a prototype which seems to work here, but it's hard to test correctness without separately implementing the dropout layer and checking, as it uses a different random seed than torch.

iclementine (Collaborator) commented Dec 28, 2023

Yes, testing for randomness is tricky. There is no proper allclose test for operators with randomness. But I think dropout can be tested in the following ways (a rough sketch of both checks follows below):

  1. re-generation of the mask;
  2. a distribution test of the mask.

These methods can be applied to a standalone dropout operator. They can also be applied to the dropout part of a more complex operator, but the overall correctness testing is more complicated than for operators without randomness involved.
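A rough sketch of those two checks, assuming the illustrative seeded_dropout helper from the sketch above (the helper, sizes, and tolerance are assumptions, not this repo's test suite):

```python
import torch


def test_mask_regeneration(p=0.3, seed=42, n=4096):
    # Check (1): the same (seed, offsets) must reproduce exactly the same mask.
    x = torch.ones(n, device="cuda")
    assert torch.equal(seeded_dropout(x, p, seed), seeded_dropout(x, p, seed))


def test_keep_rate_distribution(p=0.3, seed=42, n=1_000_000, tol=3e-3):
    # Check (2): the empirical keep-rate of the mask should be close to 1 - p.
    x = torch.ones(n, device="cuda")
    y = seeded_dropout(x, p, seed)
    keep_rate = (y != 0).float().mean().item()
    assert abs(keep_rate - (1 - p)) < tol
```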

iclementine (Collaborator) commented

Finally done in #23.
