apply_rope_inplace will cause graphbreak due to mutated inputs #403

Open
jianc99 opened this issue Jul 28, 2024 · 4 comments

jianc99 commented Jul 28, 2024

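import torch
import flashinfer

rope = flashinfer.apply_rope_inplace

# q and k are both mutated in place. Each gets its own alias set,
# (a!) for q and (b!) for k, because they do not alias each other.
torch.library.define(
    "mylib::target_rope",
    "(Tensor(a!) q, Tensor(b!) k, Tensor indptr, Tensor offsets) -> None",
)

@torch.library.impl("mylib::target_rope", "cuda")
def target_rope(q, k, indptr, offsets):
    # Rotates q and k in place with interleaved (even/odd pair) RoPE.
    rope(q, k, indptr, offsets, interleave=True)

@torch.library.register_fake("mylib::target_rope")
def target_rope_abstract(q, k, indptr, offsets):
    # The op returns nothing, so the fake kernel has no outputs to describe.
    return None

# q: 4 tokens x 4 heads x 128 dims; k: 4 tokens x 1 head x 128 dims.
# indptr = [0, 1, 2, 3, 4] splits the 4 tokens into 4 sequences of length 1.
q = torch.randn(4, 4, 128, dtype=torch.bfloat16).to(0)
k = torch.randn(4, 1, 128, dtype=torch.bfloat16).to(0)
indptr = torch.arange(5, dtype=torch.int32).to(0)
offsets = torch.full((4,), 1, dtype=torch.int32).to(0)

torch.compile(torch.ops.mylib.target_rope, mode="reduce-overhead", fullgraph=True)(q, k, indptr, offsets)

This prints:

skipping cudagraphs due to mutated inputs (2 instances)
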
yzh119 (Collaborator) commented Jul 28, 2024

yzh119 (Collaborator) commented Jul 29, 2024

I noticed that you have already annotated the mutated inputs.
I think it's fine to expose a second set of operators, apply_rope and apply_llama31_rope, that are not in-place, for use with torch.compile.
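The idea behind a non-in-place operator can be sketched in plain Python, without torch or flashinfer: the out-of-place variant copies its input and runs the in-place kernel on the copy, so the caller's data is never mutated, which is the property the "mutated inputs" check in the report above is about. This is a minimal sketch, not flashinfer's implementation. The names apply_rope_inplace_ref and apply_rope_ref are hypothetical, each vector is a single head vector stored as a list, and the interleaved (even/odd pair) rotation with theta = 1e4 and a single absolute position pos is an assumption modeled on interleave=True.

```python
import math
from typing import List

Vector = List[float]


def apply_rope_inplace_ref(x: Vector, pos: int, theta: float = 1e4) -> None:
    """Hypothetical reference kernel: rotate adjacent pairs (x[2i], x[2i+1])
    of one head vector in place (interleaved RoPE).

    Pair i is rotated by angle pos * theta ** (-2i / d), where d = len(x).
    """
    d = len(x)
    for i in range(d // 2):
        angle = pos * theta ** (-2 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[2 * i], x[2 * i + 1]
        x[2 * i] = a * c - b * s
        x[2 * i + 1] = a * s + b * c


def apply_rope_ref(x: Vector, pos: int, theta: float = 1e4) -> Vector:
    """Hypothetical out-of-place variant: copy first, run the in-place
    kernel on the copy, and return it. The caller's input is untouched."""
    out = list(x)
    apply_rope_inplace_ref(out, pos, theta)
    return out


if __name__ == "__main__":
    x = [1.0, 0.0, 1.0, 0.0]
    y = apply_rope_ref(x, pos=1)
    print(x)  # unchanged input
    print(y)  # rotated copy
```

The trade-off in this design is one extra copy per call in exchange for inputs that are only ever read, never written.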

jianc99 (Author) commented Jul 29, 2024

Yes, I annotated them, but it still doesn't work. Exposing a non-in-place RoPE would be very helpful, thanks!

yzh119 added a commit that referenced this issue Jul 29, 2024
As requested in #403, this PR implements non-inplace rope operators.
yzh119 (Collaborator) commented Jul 29, 2024

Done in #405.
