feat: support huggingface transformer style rope interface (#568)
Previously, our RoPE APIs assumed that the position indices of each request
are contiguous, which is not appropriate for applications such as
speculative decoding. This PR fixes the issue by supporting the HuggingFace
transformers-style API, which uses a `pos_ids` argument to specify
positions.
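
For illustration, a minimal usage sketch of the new interface (assuming the
Python wrapper added by this PR is exported as `flashinfer.apply_rope_pos_ids`;
the argument names and defaults are read off the bindings in the diff below):

```python
import torch
import flashinfer  # assumption: the pos_ids wrapper is exported at top level

nnz, num_heads, head_dim = 8, 32, 128
q = torch.randn(nnz, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(nnz, num_heads, head_dim, dtype=torch.float16, device="cuda")

# Positions are given per token and need not be contiguous within a request,
# e.g. sibling draft tokens of a speculative-decoding tree sharing a position.
pos_ids = torch.tensor([5, 6, 7, 7, 8, 8, 8, 9], dtype=torch.int32, device="cuda")

q_rope, k_rope = flashinfer.apply_rope_pos_ids(
    q, k, pos_ids, interleave=False, rope_scale=1.0, rope_theta=1e4
)
```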

This PR implements part of the features requested in #530; the remaining
pieces are coming in later PRs.

cc @dreaming-panda @abcdabcd987 @ByronHsu
yzh119 authored Oct 29, 2024
1 parent cdc12c3 commit 4f40420
Showing 7 changed files with 521 additions and 211 deletions.
39 changes: 23 additions & 16 deletions flashinfer-aot/csrc_aot/flashinfer_ops.cu
@@ -61,10 +61,11 @@ torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional<torch::Tensor>
torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tensor> maybe_top_k_arr,
                                unsigned int top_k_val);

torch::Tensor chain_speculative_sampling(
    torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
    torch::Tensor target_probs, torch::Tensor output_accepted_token_num,
    torch::Tensor output_emitted_token_num, bool deterministic);
torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
                                         torch::Tensor uniform_samples, torch::Tensor target_probs,
                                         torch::Tensor output_accepted_token_num,
                                         torch::Tensor output_emitted_token_num,
                                         bool deterministic);

void rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps);

@@ -82,24 +83,30 @@ void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
                        torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta);

void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
                                torch::Tensor offsets, bool interleave, float rope_scale,
                                float rope_theta, float low_freq_factor, float high_freq_factor,
                                float old_context_length);

std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
                                      torch::Tensor k_rope, torch::Tensor indptr,
                                      torch::Tensor offsets, bool interleave, float rope_scale,
                                      float rope_theta);

std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
                                              torch::Tensor q_rope, torch::Tensor k_rope,
                                              torch::Tensor indptr, torch::Tensor offsets,
                                              bool interleave, float rope_scale, float rope_theta,
                                              float low_freq_factor, float high_freq_factor,
                                              float old_context_length);

std::vector<torch::Tensor> apply_rope_pos_ids(torch::Tensor q, torch::Tensor k,
                                              torch::Tensor q_rope, torch::Tensor k_rope,
                                              torch::Tensor pos_ids, bool interleave,
                                              float rope_scale, float rope_theta);

std::vector<torch::Tensor> apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k,
                                                      torch::Tensor q_rope, torch::Tensor k_rope,
                                                      torch::Tensor pos_ids, bool interleave,
                                                      float rope_scale, float rope_theta,
                                                      float low_freq_factor, float high_freq_factor,
                                                      float old_context_length);

torch::Tensor packbits(torch::Tensor x, const std::string& bitorder);

torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
@@ -141,11 +148,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul");
m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul");
m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul");
m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place");
m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace,
"Apply Llama 3.1 style RoPE in-place");
m.def("apply_rope", &apply_rope, "Apply RoPE");
m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE");
m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids");
m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids,
"Apply Llama 3.1 style RoPE with positional ids");
m.def("packbits", &packbits, "GPU packbits operator");
m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator");
m.def("cutlass_segment_gemm", &CutlassSegmentGEMM, "Cutlass Segment GEMM operator");
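As a note on how the two interfaces relate: the `pos_ids` variants subsume the
`indptr`/`offsets` ones, since contiguous per-request positions can always be
expressed as an explicit `pos_ids` tensor. A hypothetical helper (not part of
this diff) illustrating the correspondence:

```python
import torch

def pos_ids_from_indptr_offsets(indptr: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    # Request i owns rows indptr[i]:indptr[i+1] of q/k; the old interface
    # assigned them the contiguous positions offsets[i], offsets[i] + 1, ...
    pos = [
        torch.arange(int(offsets[i]),
                     int(offsets[i]) + int(indptr[i + 1] - indptr[i]),
                     dtype=torch.int32)
        for i in range(indptr.numel() - 1)
    ]
    return torch.cat(pos)
```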
