From 0d333ff04f3301c541970ba46e2e92e13eaa83bb Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Sat, 29 Jun 2024 23:58:57 -0700
Subject: [PATCH] refactor: reduce the binary size of batch decode kernels (#343)

This PR refactors the batch decode related kernels and makes the following
breaking changes:

1. Remove the `batch_decode_with_padded_kv_cache` operator; users are
   encouraged to use `BatchDecodeWithPagedKVCacheWrapper` instead.
2. Remove redundant DTypeQ * DTypeKV combinations; only the following cases
   are supported now:
   1. DTypeQ == DTypeKV
   2. DTypeQ is float16 and DTypeKV is float8

The output data type follows the query data type.
---
 CMakeLists.txt | 42 +--
 include/flashinfer/attention/decode.cuh | 200 ------
 include/flashinfer/decode_attention_decl.cuh | 9 -
 include/flashinfer/utils.cuh | 3 +
 python/MANIFEST.in | 1 -
 python/csrc/batch_decode.cu | 318 ++++++-------------
 python/csrc/batch_prefill.cu | 12 +-
 python/csrc/cascade.cu | 6 +-
 python/csrc/flashinfer_ops.cu | 2 -
 python/csrc/flashinfer_ops.h | 5 -
 python/csrc/group_gemm.cu | 2 +-
 python/csrc/norm.cu | 2 +-
 python/csrc/pytorch_extension_utils.h | 12 +-
 python/csrc/single_decode.cu | 48 +--
 python/csrc/single_prefill.cu | 4 +-
 python/flashinfer/__init__.py | 3 -
 python/flashinfer/cascade.py | 117 -------
 python/flashinfer/decode.py | 243 +-------------
 python/generate_batch_padded_decode_inst.py | 72 -----
 python/setup.py | 33 +-
 python/tests/test_batch_decode_kernels.py | 29 +-
 python/tests/test_shared_prefix_kernels.py | 223 +++++--------
 src/bench_single_decode.cu | 28 +-
 src/cpu_reference.h | 6 +-
 src/flashinfer_ops.cuh | 31 --
 src/test_batch_decode.cu | 70 ++--
 src/test_batch_prefill.cu | 12 +-
 src/test_single_decode.cu | 49 ++-
 src/test_single_prefill.cu | 2 +-
 29 files changed, 359 insertions(+), 1225 deletions(-)
 delete mode 100644 python/generate_batch_padded_decode_inst.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6db5c334..266a21cb 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -136,9 +136,9 @@ foreach(head_dim IN LISTS HEAD_DIMS)
           list(APPEND single_decode_kernels_src ${generated_kernel_src})
         endforeach(dtype)
 
-        # fp8 in, fp16 out
-        foreach(dtype IN LISTS DECODE_FP8_DTYPES)
-          set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_f16.cu)
+        # fp8 kv-cache
+        foreach(dtype_kv IN LISTS DECODE_FP8_DTYPES)
+          set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16.cu)
           add_custom_command(
             OUTPUT ${generated_kernel_src}
             COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py ${generated_kernel_src}
@@ -147,7 +147,7 @@ foreach(head_dim IN LISTS HEAD_DIMS)
             VERBATIM
           )
           list(APPEND single_decode_kernels_src ${generated_kernel_src})
-        endforeach(dtype)
+        endforeach(dtype_kv)
       endforeach(pos_encoding_mode)
     endforeach(kv_layout)
   endforeach(logits_post_hook)
@@ -172,9 +172,9 @@ foreach(head_dim IN LISTS HEAD_DIMS)
             list(APPEND batch_decode_kernels_src ${generated_kernel_src})
           endforeach(dtype)
 
-          # fp8 in, fp16 out
-          foreach(dtype IN LISTS DECODE_FP8_DTYPES)
-            set(generated_kernel_src 
${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_f16_idtype_${idtype}.cu) + # fp8 kv-cache + foreach(dtype_kv IN LISTS DECODE_FP8_DTYPES) + set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu) add_custom_command( OUTPUT ${generated_kernel_src} COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py ${generated_kernel_src} @@ -183,34 +183,8 @@ foreach(head_dim IN LISTS HEAD_DIMS) VERBATIM ) list(APPEND batch_decode_kernels_src ${generated_kernel_src}) - endforeach() + endforeach(dtype_kv) endforeach(idtype) - - # padded kv-cache - foreach(dtype IN LISTS DECODE_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_padded_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu) - add_custom_command( - OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py - COMMENT "Generating additional source file ${generated_kernel_src}" - VERBATIM - ) - list(APPEND batch_decode_kernels_src ${generated_kernel_src}) - endforeach(dtype) - - # padded kv-cache, fp8 in, fp16 out - foreach(dtype IN LISTS DECODE_FP8_DTYPES) - set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_padded_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_f16.cu) - add_custom_command( - OUTPUT ${generated_kernel_src} - COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py ${generated_kernel_src} - DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py - COMMENT "Generating additional source file ${generated_kernel_src}" - VERBATIM - ) - list(APPEND batch_decode_kernels_src ${generated_kernel_src}) - endforeach() endforeach(pos_encoding_mode) endforeach(kv_layout) endforeach(logits_post_hook) diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index 3997a098..ece23db9 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -358,141 +358,6 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _ } } -template -__global__ void BatchDecodeWithPaddedKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* __restrict__ k, - DTypeKV* __restrict__ v, - DTypeOut* __restrict__ o, - float* __restrict__ lse, - tensor_info_t info, - float logits_soft_cap, float sm_scale, - float rope_rcp_scale, float rope_rcp_theta) { - auto block = cg::this_thread_block(); - sm_scale *= - (logits_post_hook == LogitsPostHook::kNone ? 
math::log2e : math::ptx_rcp(logits_soft_cap)); - - constexpr uint32_t head_dim = bdx * vec_size; - uint32_t kv_head_idx = blockIdx.y; - uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx.y; - uint32_t batch_idx = blockIdx.x; - uint32_t num_qo_heads = info.num_qo_heads; - uint32_t num_kv_heads = info.num_kv_heads; - const float alibi_slope = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e; - uint32_t seq_len = info.kv_len; - - extern __shared__ uint8_t smem[]; - DTypeKV* k_smem = (DTypeKV*)smem; - DTypeKV* v_smem = (DTypeKV*)(smem + num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeKV)); - float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeKV)); - - uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; - vec_t q_vec; - vec_t freq; - if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { -#pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - freq[i] = rope_rcp_scale * - __powf(rope_rcp_theta, - float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); - } - // apply rotary embedding to q matrix - q_vec = vec_apply_llama_rope( - q + batch_idx * num_qo_heads * head_dim + info.get_qo_elem_offset(0, qo_head_idx, 0), freq, - seq_len - 1); - } else { - // do not apply rotary embedding to q matrix - q_vec.cast_load(q + batch_idx * num_qo_heads * head_dim + - info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size)); - } -#pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - q_vec[i] *= sm_scale; - } - block.sync(); - - // preload k tiles and v tiles - uint32_t producer_kv_idx_base = 0; - constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size * 8; -#pragma unroll - for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { - cp_async::pred_load( - k_smem + ((iter * bdz + tz) * bdy + ty) * head_dim + tx * vec_size, - k + batch_idx * seq_len * num_kv_heads * head_dim + - info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx, - tx * vec_size), - producer_kv_idx_base + tz * bdy + ty < seq_len); - cp_async::commit_group(); - cp_async::pred_load( - v_smem + ((iter * bdz + tz) * bdy + ty) * head_dim + tx * vec_size, - v + batch_idx * seq_len * num_kv_heads * head_dim + - info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx, - tx * vec_size), - producer_kv_idx_base + tz * bdy + ty < seq_len); - cp_async::commit_group(); - producer_kv_idx_base += bdy * bdz; - } - - // pipelining k/v tiles loading and state updating - uint32_t consumer_kv_idx_base = 0, stage_idx = 0; - state_t st_local; - float s[bdy]; - -#pragma unroll 4 - for (uint32_t iter = 0; iter < ceil_div(seq_len, bdy * bdz); ++iter) { - // compute qk - cp_async::wait_group<2 * num_stages_smem - 1>(); - block.sync(); - compute_qk( - k_smem + (stage_idx * bdz + tz) * bdy * head_dim, stage_idx, q_vec, freq, - consumer_kv_idx_base, iter * bdy * bdz, seq_len, seq_len - 1, alibi_slope, s, st_local, - logits_soft_cap); - block.sync(); - // load k - cp_async::pred_load( - k_smem + ((stage_idx * bdz + tz) * bdy + ty) * head_dim + tx * vec_size, - k + batch_idx * seq_len * num_kv_heads * head_dim + - info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx, - tx * vec_size), - producer_kv_idx_base + tz * bdy + ty < seq_len); - cp_async::commit_group(); - - // update m/d/o state - cp_async::wait_group<2 * num_stages_smem - 1>(); - block.sync(); - update_local_state(v_smem + (stage_idx * bdz + tz) * bdy * head_dim, s, - stage_idx, st_local); - block.sync(); - - // load v - cp_async::pred_load( - v_smem + 
((stage_idx * bdz + tz) * bdy + ty) * head_dim + tx * vec_size, - v + batch_idx * seq_len * num_kv_heads * head_dim + - info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx, - tx * vec_size), - producer_kv_idx_base + tz * bdy + ty < seq_len); - cp_async::commit_group(); - - stage_idx = (stage_idx + 1) % num_stages_smem; - producer_kv_idx_base += bdy * bdz; - consumer_kv_idx_base += bdy * bdz; - } - cp_async::wait_group<0>(); - block.sync(); - - // sync local state of all warps inside a threadblock - sync_state(st_local, reinterpret_cast(smem), smem_md); - - st_local.normalize(); - st_local.o.cast_store(o + batch_idx * num_qo_heads * head_dim + - info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size)); - - // write lse - if (lse != nullptr) { - lse[batch_idx * num_qo_heads + qo_head_idx] = st_local.get_lse(); - } -} - /*! * \brief FlashAttention decoding cuda kernel with paged kv-cache for multiple requests * \tparam logits_post_hook The logits post hook used in the kernel @@ -937,71 +802,6 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( return cudaSuccess; } -/*! - * \brief FlashAttention decoding cuda kernel with paged kv-cache for batched requests - * \tparam page_storage Whether to store indices or pointers of each active page - * \tparam DTypeQ A template type indicates the query data type - * \tparam DTypeKV A template type indicates the key-value data type - * \tparam DTypeOut A template type indicates the output data type - * \tparam IdType A template type indicates the index data type used in paged kv-cache - * \param q [batch_size, num_qo_heads, head_dim] The query matrix - * \param paged_kv The paged kv cache data structure - * \param o [batch_size, num_qo_heads, head_dim] The output matrix - * \param tmp Used-allocated temporary buffer - * \param lse The logsumexp values. - * \param num_qo_heads A integer indicates the number of heads of query and output - * \param pos_encoding_mode The positional encoding mode - * \param rope_scale The scaling ratio used in RoPE Interpolation. 
- * \param rope_theta A floating point number indicate the "theta" used in RoPE - * \param stream The cuda stream to launch the kernel - * \return status Indicates whether CUDA calls are successful - */ -template -cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, - DTypeOut* tmp, float* lse, uint32_t batch_size, - uint32_t padded_kv_len, uint32_t num_qo_heads, - uint32_t num_kv_heads, float logits_soft_cap, - float sm_scale, float rope_scale, - float rope_theta, cudaStream_t stream) { - const float rope_rcp_scale = 1.f / rope_scale; - const float rope_rcp_theta = 1.f / rope_theta; - - constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); - constexpr uint32_t num_stages_smem = 2U; - constexpr uint32_t bdx = HEAD_DIM / vec_size; - static_assert(bdx <= 32); - DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { - constexpr uint32_t bdy = GROUP_SIZE; - constexpr uint32_t num_threads = std::max(128U, bdx * bdy); - constexpr uint32_t bdz = num_threads / (bdx * bdy); - - const uint32_t smem_size = 2 * num_stages_smem * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + - 2 * bdy * bdz * sizeof(float); - - dim3 nblks(batch_size, num_kv_heads); - dim3 nthrs(bdx, bdy, bdz); - auto kernel = BatchDecodeWithPaddedKVCacheKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - tensor_info_t info(1, padded_kv_len, num_qo_heads, num_kv_heads); - void* args[] = {(void*)&q, - (void*)&k, - (void*)&v, - (void*)&o, - (void*)&lse, - (void*)&info, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); - return cudaSuccess; -} - } // namespace flashinfer #endif // FLASHINFER_DECODE_CUH_ diff --git a/include/flashinfer/decode_attention_decl.cuh b/include/flashinfer/decode_attention_decl.cuh index b082bcda..3b77edfa 100644 --- a/include/flashinfer/decode_attention_decl.cuh +++ b/include/flashinfer/decode_attention_decl.cuh @@ -45,15 +45,6 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( float* lse, bool* block_valid_mask, uint32_t padded_batch_size, uint32_t num_qo_heads, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); -template -cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, - DTypeOut* tmp, float* lse, uint32_t batch_size, - uint32_t padded_kv_len, uint32_t num_qo_heads, - uint32_t num_kv_heads, float logits_soft_cap, - float sm_scale, float rope_scale, - float rope_theta, cudaStream_t stream); - template diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index cc9592ea..04890373 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -97,6 +97,9 @@ if (group_size == 1) { \ constexpr size_t GROUP_SIZE = 1; \ __VA_ARGS__ \ + } else if (group_size == 2) { \ + constexpr size_t GROUP_SIZE = 2; \ + __VA_ARGS__ \ } else if (group_size == 4) { \ constexpr size_t GROUP_SIZE = 4; \ __VA_ARGS__ \ diff --git a/python/MANIFEST.in b/python/MANIFEST.in index 854badc8..d7ad6177 100644 --- a/python/MANIFEST.in +++ b/python/MANIFEST.in @@ -1,6 +1,5 @@ # sdist & wheel include version.txt -include generate_batch_padded_decode_inst.py include generate_batch_paged_decode_inst.py include generate_batch_paged_prefill_inst.py include generate_batch_ragged_prefill_inst.py diff --git 
a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index 1062b07d..710a8c94 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -20,115 +20,6 @@ using namespace flashinfer; -std::vector batch_decode_with_padded_kv_cache( - torch::Tensor q, torch::Tensor k_padded, torch::Tensor v_padded, unsigned int layout, - unsigned int pos_encoding_mode, float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta, bool return_lse) { - CHECK_INPUT(q); - CHECK_INPUT(k_padded); - CHECK_INPUT(v_padded); - CHECK_DIM(3, q); - CHECK_DIM(4, k_padded); - CHECK_DIM(4, v_padded); - CHECK_SHAPE(k_padded, v_padded); - CHECK_EQ(q.size(0), k_padded.size(0)); - CHECK_EQ(q.size(2), k_padded.size(3)); - CHECK_EQ(v_padded.scalar_type(), k_padded.scalar_type()); - unsigned int batch_size = q.size(0); - unsigned int num_qo_heads = q.size(1); - unsigned int head_dim = q.size(2); - unsigned int padded_kv_len, num_kv_heads; - QKVLayout kv_layout = static_cast(layout); - if (kv_layout == QKVLayout::kNHD) { - padded_kv_len = k_padded.size(1); - num_kv_heads = k_padded.size(2); - } else { - padded_kv_len = k_padded.size(2); - num_kv_heads = k_padded.size(1); - } - CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); - - cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); - auto o = torch::empty_like( - q, q.options().dtype(is_float8_tensor(q) ? torch::kFloat16 : q.scalar_type())); - torch::Tensor lse = torch::empty({0}); - if (return_lse) { - lse = torch::empty({batch_size, num_qo_heads}, q.options()).to(torch::kFloat32); - } - - TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); - const LogitsPostHook logits_post_hook = - logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - - if (is_float8_tensor(q)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k_padded.scalar_type(), kv_type, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - nv_half* tmp = nullptr; - cudaError_t status = - BatchDecodeWithPaddedKVCacheDispatched( - static_cast(q.data_ptr()), - static_cast(k_padded.data_ptr()), - static_cast(v_padded.data_ptr()), - static_cast(o.data_ptr()), - /*tmp=*/tmp, - /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, - batch_size, padded_kv_len, num_qo_heads, num_kv_heads, logits_soft_cap, - sm_scale, rope_scale, rope_theta, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPaddedKVCache failed with error code ", status); - return true; - }); - }); - }); - }); - }); - }); - } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k_padded.scalar_type(), kv_type, [&] { - q_type* tmp = nullptr; - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - cudaError_t status = - BatchDecodeWithPaddedKVCacheDispatched( - static_cast(q.data_ptr()), - static_cast(k_padded.data_ptr()), - static_cast(v_padded.data_ptr()), - static_cast(o.data_ptr()), - /*tmp=*/tmp, - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - batch_size, padded_kv_len, num_qo_heads, num_kv_heads, logits_soft_cap, - sm_scale, rope_scale, rope_theta, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPaddedKVCache failed with error code ", status); - return true; - }); - }); - }); - }); - }); - }); - } - - if (return_lse) { - return {o, lse}; - } else { - return {o}; - } -} - void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, @@ -154,57 +45,56 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( const LogitsPostHook logits_post_hook = logits_soft_cap > 0.f ? 
LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - if (is_float8_tensor(empty_q_data)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(empty_q_data.scalar_type(), q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8( - empty_kv_data.scalar_type(), kv_type, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = handler_->BeginForwardDispatched< - HEAD_DIM, PageStorage::kIndices, LOGITS_POST_HOOK, KV_LAYOUT, - POS_ENCODING_MODE, q_type, kv_type, nv_half, int32_t>( - static_cast(workspace_buffer.data_ptr()), - workspace_size_in_bytes, static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), batch_size, - num_qo_heads, num_kv_heads, page_size); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); + auto q_scalar_type = empty_q_data.scalar_type(); + auto kv_scalar_type = empty_kv_data.scalar_type(); + + if (q_scalar_type == kv_scalar_type) { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q_scalar_type, qkv_type, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = handler_->BeginForwardDispatched< + HEAD_DIM, PageStorage::kIndices, LOGITS_POST_HOOK, KV_LAYOUT, + POS_ENCODING_MODE, qkv_type, qkv_type, qkv_type, int32_t>( + static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, + static_cast(indptr.data_ptr()), + static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, + num_kv_heads, page_size); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; }); - }); - }); }); + }); + }); }); } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_q_data.scalar_type(), q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8( - empty_kv_data.scalar_type(), kv_type, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = handler_->BeginForwardDispatched< - HEAD_DIM, PageStorage::kIndices, LOGITS_POST_HOOK, KV_LAYOUT, - POS_ENCODING_MODE, q_type, kv_type, q_type, int32_t>( - static_cast(workspace_buffer.data_ptr()), - workspace_size_in_bytes, static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), batch_size, - num_qo_heads, num_kv_heads, page_size); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); - }); - }); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(kv_scalar_type, kv_type, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] 
{ + cudaError_t status = handler_->BeginForwardDispatched< + HEAD_DIM, PageStorage::kIndices, LOGITS_POST_HOOK, KV_LAYOUT, + POS_ENCODING_MODE, q_type, kv_type, q_type, int32_t>( + static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, + static_cast(indptr.data_ptr()), + static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, + num_kv_heads, page_size); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); }); }); + }); + }); }); } } @@ -255,8 +145,7 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); - torch::Tensor o = torch::empty_like( - q, q.options().dtype(is_float8_tensor(q) ? torch::kFloat16 : q.scalar_type())); + torch::Tensor o = torch::empty_like(q); torch::Tensor lse; if (return_lse) { lse = torch::empty({batch_size, num_qo_heads}, q.options()).to(torch::kFloat32); @@ -266,71 +155,70 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( const LogitsPostHook logits_post_hook = logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - if (is_float8_tensor(q)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8( - paged_kv_data.scalar_type(), kv_type, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, - static_cast(paged_kv_data.data_ptr()), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, - POS_ENCODING_MODE, q_type, kv_type, nv_half, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), - /*lse=*/(return_lse ? 
static_cast(lse.data_ptr()) : nullptr), - num_qo_heads, logits_soft_cap, sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); + auto q_scalar_type = q.scalar_type(); + auto kv_scalar_type = paged_kv_data.scalar_type(); + + if (q_scalar_type == kv_scalar_type) { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q_scalar_type, qkv_type, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, + static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, + POS_ENCODING_MODE, qkv_type, qkv_type, qkv_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), + /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), + num_qo_heads, logits_soft_cap, sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; }); - }); - }); }); + }); + }); }); } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8( - paged_kv_data.scalar_type(), kv_type, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, - static_cast(paged_kv_data.data_ptr()), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, - POS_ENCODING_MODE, q_type, kv_type, q_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), - /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), - /*lse=*/(return_lse ? 
static_cast(lse.data_ptr()) : nullptr), - num_qo_heads, logits_soft_cap, sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); - }); - }); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(kv_scalar_type, kv_type, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, + static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, + POS_ENCODING_MODE, q_type, kv_type, q_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), + /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), + /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), + num_qo_heads, logits_soft_cap, sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); }); }); + }); + }); }); } diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 89cd38c6..09d2f3d6 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -37,7 +37,7 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward( cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); handler_->SetCUDAStream(torch_current_stream); - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_q_data.scalar_type(), q_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] { cudaError_t status = handler_->BeginForward( static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, static_cast(qo_indptr.data_ptr()), @@ -111,7 +111,7 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( const LogitsPostHook logits_post_hook = logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { paged_kv_t paged_kv( @@ -218,7 +218,7 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu const LogitsPostHook logits_post_hook = logits_soft_cap > 0.f ? 
LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { paged_kv_t paged_kv( @@ -280,7 +280,7 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward( cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); handler_->SetCUDAStream(torch_current_stream); - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_q_data.scalar_type(), q_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(empty_q_data.scalar_type(), q_type, [&] { cudaError_t status = handler_->BeginForward( static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, static_cast(qo_indptr.data_ptr()), static_cast(kv_indptr.data_ptr()), @@ -342,7 +342,7 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward( const LogitsPostHook logits_post_hook = logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { return DISPATCH_allow_fp16_qk_reduction( @@ -433,7 +433,7 @@ std::vector BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC const LogitsPostHook logits_post_hook = logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_allow_fp16_qk_reduction( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, [&] { diff --git a/python/csrc/cascade.cu b/python/csrc/cascade.cu index 5faeb0d8..9733edd8 100644 --- a/python/csrc/cascade.cu +++ b/python/csrc/cascade.cu @@ -43,7 +43,7 @@ std::vector merge_state(torch::Tensor v_a, torch::Tensor s_a, tor auto v_merged = torch::empty_like(v_a, v_a.options()); auto s_merged = torch::empty({seq_len, num_heads}, s_a.options()); - bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(v_a.scalar_type(), c_type, [&] { + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v_a.scalar_type(), c_type, [&] { cudaError_t status = MergeState( static_cast(v_a.data_ptr()), static_cast(s_a.data_ptr()), static_cast(v_b.data_ptr()), static_cast(s_b.data_ptr()), @@ -79,7 +79,7 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe unsigned int head_dim = v.size(2); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); - bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(v.scalar_type(), c_type, [&] { + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v.scalar_type(), c_type, [&] { cudaError_t status = MergeStateInPlace( static_cast(v.data_ptr()), static_cast(s.data_ptr()), static_cast(v_other.data_ptr()), static_cast(s_other.data_ptr()), seq_len, @@ -109,7 +109,7 @@ std::vector merge_states(torch::Tensor v, torch::Tensor s) { auto v_merged = torch::empty({seq_len, num_heads, head_dim}, v.options()); auto s_merged = torch::empty({seq_len, num_heads}, s.options()); - bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(v.scalar_type(), c_type, [&] { + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v.scalar_type(), c_type, [&] { cudaError_t status = MergeStates( static_cast(v.data_ptr()), static_cast(s.data_ptr()), 
static_cast(v_merged.data_ptr()), static_cast(s_merged.data_ptr()), diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu index f3a5f62d..e3155529 100644 --- a/python/csrc/flashinfer_ops.cu +++ b/python/csrc/flashinfer_ops.cu @@ -30,8 +30,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("merge_state_in_place", &merge_state_in_place, "Merge another self-attention state in-place."); m.def("merge_states", &merge_states, "Merge multiple self-attention states"); - m.def("batch_decode_with_padded_kv_cache", &batch_decode_with_padded_kv_cache, - "Multi-request batch decode with padded KV-Cache operator"); m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities"); m.def("top_k_sampling_from_probs", &top_k_sampling_from_probs, "Top-k sampling from probabilities"); diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 70893b91..ca4971ee 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -50,11 +50,6 @@ void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_othe std::vector merge_states(torch::Tensor v, torch::Tensor s); -std::vector batch_decode_with_padded_kv_cache( - torch::Tensor q, torch::Tensor k_padded, torch::Tensor v_padded, unsigned int layout, - unsigned int pos_encoding_mode, float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta, bool return_lse); - torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples); std::vector top_p_sampling_from_probs(torch::Tensor probs, diff --git a/python/csrc/group_gemm.cu b/python/csrc/group_gemm.cu index f8ee4388..79dba42f 100644 --- a/python/csrc/group_gemm.cu +++ b/python/csrc/group_gemm.cu @@ -51,7 +51,7 @@ torch::Tensor CutlassSegmentGEMMPyTorchWrapper::Forward(torch::Tensor seg_indptr weight_indices = weight_indices.to(torch::kInt64); } - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(x.scalar_type(), c_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(x.scalar_type(), c_type, [&] { using cutlass_t = typename cutlass_dtype::type; auto status = CutlassSegmentGEMMWrapper( handler_.get(), static_cast(x.data_ptr()), diff --git a/python/csrc/norm.cu b/python/csrc/norm.cu index 52a15eaa..859600ab 100644 --- a/python/csrc/norm.cu +++ b/python/csrc/norm.cu @@ -31,7 +31,7 @@ torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps) { cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); auto y = torch::empty_like(x); - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(x.scalar_type(), c_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(x.scalar_type(), c_type, [&] { cudaError_t status = norm::RMSNorm( static_cast(x.data_ptr()), static_cast(w.data_ptr()), static_cast(y.data_ptr()), batch_size, hidden_size, eps, torch_current_stream); diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index 5aac9196..c3e04185 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -32,7 +32,7 @@ using namespace flashinfer; #ifdef FLASHINFER_ENABLE_BF16 -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ switch (pytorch_dtype) { \ case at::ScalarType::Half: { \ @@ -51,7 +51,7 @@ using namespace flashinfer; } \ }() #else -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) 
\ [&]() -> bool { \ switch (pytorch_dtype) { \ case at::ScalarType::Half: { \ @@ -97,7 +97,7 @@ using namespace flashinfer; #endif #if defined(FLASHINFER_ENABLE_BF16) && defined(FLASHINFER_ENABLE_FP8) -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) \ +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ switch (pytorch_dtype) { \ case at::ScalarType::Half: { \ @@ -124,7 +124,7 @@ using namespace flashinfer; } \ }() #elif defined(FLASHINFER_ENABLE_BF16) -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) \ +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ switch (pytorch_dtype) { \ case at::ScalarType::Half: { \ @@ -143,7 +143,7 @@ using namespace flashinfer; } \ }() #elif defined(FLASHINFER_ENABLE_FP8) -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) \ +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ switch (pytorch_dtype) { \ case at::ScalarType::Float8_e4m3fn: { \ @@ -162,7 +162,7 @@ using namespace flashinfer; } \ }() #else -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) \ +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ switch (pytorch_dtype) { \ case at::ScalarType::Half: { \ diff --git a/python/csrc/single_decode.cu b/python/csrc/single_decode.cu index 52f246f4..07d0fb7b 100644 --- a/python/csrc/single_decode.cu +++ b/python/csrc/single_decode.cu @@ -46,40 +46,40 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc } CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); - auto o = torch::empty_like( - q, q.options().dtype(is_float8_tensor(q) ? torch::kFloat16 : q.scalar_type())); + auto o = torch::empty_like(q); TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative"); const LogitsPostHook logits_post_hook = logits_soft_cap > 0.f ? 
LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - if (is_float8_tensor(q)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k.scalar_type(), kv_type, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = SingleDecodeWithKVCacheDispatched< - HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE>( - static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(v.data_ptr()), static_cast(o.data_ptr()), - static_cast(tmp.data_ptr()), num_qo_heads, num_kv_heads, kv_len, - logits_soft_cap, sm_scale, rope_scale, rope_theta, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "SingleDecodeWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - return true; - }); - }); + auto q_scalar_type = q.scalar_type(); + auto kv_scalar_type = k.scalar_type(); + + if (q_scalar_type == kv_scalar_type) { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q_scalar_type, qkv_type, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = SingleDecodeWithKVCacheDispatched< + HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE>( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(v.data_ptr()), static_cast(o.data_ptr()), + static_cast(tmp.data_ptr()), num_qo_heads, num_kv_heads, kv_len, + logits_soft_cap, sm_scale, rope_scale, rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "SingleDecodeWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); }); }); }); }); } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k.scalar_type(), kv_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q_scalar_type, q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(kv_scalar_type, kv_type, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { diff --git a/python/csrc/single_prefill.cu b/python/csrc/single_prefill.cu index faf90fb2..a0fcfec8 100644 --- a/python/csrc/single_prefill.cu +++ b/python/csrc/single_prefill.cu @@ -59,7 +59,7 @@ std::vector single_prefill_with_kv_cache( const LogitsPostHook logits_post_hook = logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { @@ -142,7 +142,7 @@ std::vector single_prefill_with_kv_cache_custom_mask( const LogitsPostHook logits_post_hook = logits_soft_cap > 0.f ? 
LogitsPostHook::kSoftCap : LogitsPostHook::kNone; - bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py index 2a3b3c9c..d93097a9 100644 --- a/python/flashinfer/__init__.py +++ b/python/flashinfer/__init__.py @@ -16,8 +16,6 @@ from .decode import ( single_decode_with_kv_cache, - batch_decode_with_padded_kv_cache, - batch_decode_with_padded_kv_cache_return_lse, BatchDecodeWithPagedKVCacheWrapper, CUDAGraphBatchDecodeWithPagedKVCacheWrapper, ) @@ -31,7 +29,6 @@ merge_state, merge_state_in_place, merge_states, - batch_decode_with_shared_prefix_padded_kv_cache, BatchDecodeWithSharedPrefixPagedKVCacheWrapper, BatchPrefillWithSharedPrefixPagedKVCacheWrapper, ) diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py index 28809ebe..f26913a5 100644 --- a/python/flashinfer/cascade.py +++ b/python/flashinfer/cascade.py @@ -31,7 +31,6 @@ raise e from .decode import ( - batch_decode_with_padded_kv_cache_return_lse, BatchDecodeWithPagedKVCacheWrapper, ) from .prefill import ( @@ -175,122 +174,6 @@ def merge_states(v: torch.Tensor, s: torch.Tensor): return _kernels.merge_states(v, s) -def batch_decode_with_shared_prefix_padded_kv_cache( - q: torch.Tensor, - k_shared: torch.Tensor, - v_shared: torch.Tensor, - k_unique: torch.Tensor, - v_unique: torch.Tensor, - kv_layout: str = "NHD", - allow_fp16_qk_reduction=False, - sm_scale: Optional[float] = None, - rope_scale: Optional[float] = None, - rope_theta: Optional[float] = None, -): - r"""Decode attention between queries and shared prefix kv-cache for batch of - requests. - - Parameters - ---------- - q : torch.Tensor - The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``. - k_shared : torch.Tensor - The shared prefix key tensor, shape: - ``[shared_prefix_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, - or ``[num_kv_heads, shared_prefix_len, head_dim]`` if :attr:`kv_layout` is - ``HND``. - v_shared : torch.Tensor - The shared prefix value tensor, shape: - ``[shared_prefix_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, - or ``[num_kv_heads, shared_prefix_len, head_dim]`` if :attr:`kv_layout` is - ``HND``. - k_unique : torch.Tensor - The request-independent suffix key tensor, shape: - ``[batch_size, unique_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is - ``NHD``, or ``[batch_size, num_kv_heads, unique_len, head_dim]`` if - :attr:`kv_layout` is ``HND``. - v_unique : torch.Tensor - The request-independent suffix value tensor, shape: - ``[batch_size, unique_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is - ``NHD``, or ``[batch_size, num_kv_heads, unique_len, head_dim]`` if - :attr:`kv_layout` is ``HND``. - kv_layout : str - The layout of the kv-cache, could be either "NHD" or "HND". - allow_fp16_qk_reduction : bool - Whether to use f16 for qk reduction (faster at the cost of slight precision - loss). - sm_scale : Optional[float] - The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)`` - rope_scale : Optional[float] - The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. - rope_theta : Optional[float] - The theta used in RoPE, if not provided, will be set to ``1e4``. 
- - Returns - ------- - V : torch.Tensor - The attention output, shape: ``[batch_size, num_heads, head_dim]`` - - Example - ------- - >>> import torch - >>> import flashinfer - >>> shared_prefix_len = 16384 - >>> padded_unique_suffix_len = 2048 - >>> batch_size = 53 - >>> num_qo_heads = 32 - >>> num_kv_heads = 32 - >>> head_dim = 128 - >>> q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0") - >>> k_shared = torch.randn(shared_prefix_len, num_kv_heads, head_dim).half().to("cuda:0") - >>> v_shared = torch.randn(shared_prefix_len, num_kv_heads, head_dim).half().to("cuda:0") - >>> k_unique = torch.randn( - ... batch_size, - ... padded_unique_suffix_len, - ... num_kv_heads, - ... head_dim - ... ).half().to("cuda:0") - >>> v_unique = torch.randn( - ... batch_size, - ... padded_unique_suffix_len, - ... num_kv_heads, - ... head_dim - ... ).half().to("cuda:0") - >>> o = flashinfer.batch_decode_with_shared_prefix_padded_kv_cache( - ... q, k_shared, v_shared, k_unique, v_unique, kv_layout="NHD", - ... allow_fp16_qk_reduction=True - ... ) - >>> o.shape - torch.Size([53, 32, 128]) - """ - check_kv_layout(kv_layout) - V_shared, S_shared = single_prefill_with_kv_cache_return_lse( - q, - k_shared, - v_shared, - causal=False, - pos_encoding_mode="NONE", - kv_layout=kv_layout, - allow_fp16_qk_reduction=allow_fp16_qk_reduction, - sm_scale=sm_scale, - rope_scale=rope_scale, - rope_theta=rope_theta, - ) - V_unique, S_unique = batch_decode_with_padded_kv_cache_return_lse( - q, - k_unique, - v_unique, - kv_layout=kv_layout, - pos_encoding_mode="NONE", - sm_scale=sm_scale, - rope_scale=rope_scale, - rope_theta=rope_theta, - ) - - merge_state_in_place(V_shared, S_shared, V_unique, S_unique) - return V_shared - - class BatchDecodeWithSharedPrefixPagedKVCacheWrapper: r"""Wrapper class for decode attention with shared-prefix paged kv-cache for batch of requests. diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index bc4328c6..2d6ee5cc 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -52,7 +52,7 @@ def _get_cache_buf(name: str, bytes: int, device: torch.device): def _grouped_size_compiled_for_decode_kernels(num_qo_heads: int, num_kv_heads: int): - return (num_qo_heads // num_kv_heads) in [1, 4, 8] + return (num_qo_heads // num_kv_heads) in [1, 2, 4, 8] def single_decode_with_kv_cache( @@ -165,7 +165,6 @@ def single_decode_with_kv_cache( ) if use_tensor_cores: - print(q.shape, k.shape) out = _kernels.single_prefill_with_kv_cache( q.unsqueeze(0), k, @@ -199,244 +198,6 @@ def single_decode_with_kv_cache( return out -def batch_decode_with_padded_kv_cache( - q: torch.Tensor, - k_padded: torch.Tensor, - v_padded: torch.Tensor, - kv_layout: str = "NHD", - pos_encoding_mode: str = "NONE", - q_scale: Optional[float] = None, - k_scale: Optional[float] = None, - v_scale: Optional[float] = None, - logits_soft_cap: Optional[float] = None, - sm_scale: Optional[float] = None, - rope_scale: Optional[float] = None, - rope_theta: Optional[float] = None, -): - r"""Decode attention with padded KV cache for batch of requests, return attention - output. - - Parameters - ---------- - q : torch.Tensor - The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``. - k_padded : torch.Tensor - The padded key tensor, shape: - ``[batch_size, padded_seq_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` - is ``NHD`` or ``[batch_size, num_kv_heads, padded_seq_len, head_dim]`` if - :attr:`kv_layout` is ``HND``. 
- v_padded : torch.Tensor - The padded value tensor, shape: - ``[batch_size, padded_seq_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` - is ``NHD`` or ``[batch_size, num_kv_heads, padded_seq_len, head_dim]`` if - :attr:`kv_layout` is ``HND``. - kv_layout : str - The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. - pos_encoding_mode : str - The position encoding applied inside attention kernels, could be - ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. - Defaults to ``NONE``. - q_scale : Optional[float] - The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``. - k_scale : Optional[float] - The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. - v_scale : Optional[float] - The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. - logits_soft_cap : Optional[float] - The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not - provided, will be set to ``0``. If greater than 0, the logits will be capped according to - formula: - :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, - where :math:`x` is the input logits. - sm_scale : Optional[float] - The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``. - rope_scale : Optional[float] - The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. - rope_theta : Optional[float] - The theta used in RoPE, if not provided, will be set to ``1e4``. - - Returns - ------- - torch.Tensor - The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``. - - Examples - -------- - >>> import torch - >>> import flashinfer - >>> padded_kv_len = 4096 - >>> num_qo_heads = 32 - >>> num_kv_heads = 32 - >>> batch_size = 7 - >>> head_dim = 128 - >>> q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0") - >>> k_padded = torch.randn(batch_size, padded_kv_len, num_kv_heads, head_dim).half().to("cuda:0") - >>> v_padded = torch.randn(batch_size, padded_kv_len, num_kv_heads, head_dim).half().to("cuda:0") - >>> o = flashinfer.batch_decode_with_padded_kv_cache( - ... q, k_padded, v_padded, "NHD", "LLAMA" - ... ) - >>> o.shape - torch.Size([7, 32, 128]) - - Notes - ----- - The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is - not equal to ``num_kv_heads``, the function will use - `grouped query attention `_. 
- """ - if logits_soft_cap is None: - logits_soft_cap = 0.0 - if sm_scale is None: - head_dim = q.shape[-1] - sm_scale = 1.0 / math.sqrt(head_dim) - if q_scale is not None: - sm_scale *= q_scale - if k_scale is not None: - sm_scale *= k_scale - if rope_scale is None: - rope_scale = 1.0 - if rope_theta is None: - rope_theta = 1e4 - out = _kernels.batch_decode_with_padded_kv_cache( - q, - k_padded, - v_padded, - TensorLayout[kv_layout].value, - PosEncodingMode[pos_encoding_mode].value, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - False, # return_lse - )[0] - if v_scale is not None: - out *= v_scale - return out - - -def batch_decode_with_padded_kv_cache_return_lse( - q: torch.Tensor, - k_padded: torch.Tensor, - v_padded: torch.Tensor, - kv_layout: str = "NHD", - pos_encoding_mode: str = "NONE", - q_scale: Optional[float] = None, - k_scale: Optional[float] = None, - v_scale: Optional[float] = None, - logits_soft_cap: Optional[float] = None, - sm_scale: Optional[float] = None, - rope_scale: Optional[float] = None, - rope_theta: Optional[float] = None, -): - r"""Decode attention with padded KV cache for batch of requests, return attention - output and logsumexp of attention scores, return attention output and logsumexp of - attention scores. - - Parameters - ---------- - q : torch.Tensor - The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]``. - k_padded : torch.Tensor - The padded key tensor, shape: - ``[batch_size, padded_seq_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` - is ``NHD`` or ``[batch_size, num_kv_heads, padded_seq_len, head_dim]`` if - :attr:`kv_layout` is ``HND``. - v_padded : torch.Tensor - The padded value tensor, shape: - ``[batch_size, padded_seq_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` - is ``NHD`` or ``[batch_size, num_kv_heads, padded_seq_len, head_dim]`` if - :attr:`kv_layout` is ``HND``. - kv_layout : str - The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. - pos_encoding_mode : str - The position encoding applied inside attention kernels, could be - ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. - Defaults to ``NONE``. - logits_soft_cap : Optional[float] - The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not - provided, will be set to ``0``. If greater than 0, the logits will be capped according to - formula: - :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, - where :math:`x` is the input logits. - q_scale : Optional[float] - The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``. - k_scale : Optional[float] - The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. - v_scale : Optional[float] - The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. - sm_scale : Optional[float] - The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``. - rope_scale : Optional[float] - The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. - rope_theta : Optional[float] - The theta used in RoPE, if not provided, will be set to ``1e4``. 
- - Returns - ------- - V : torch.Tensor - The attention output, shape: [batch_size, num_qo_heads, head_dim] - S : torch.Tensor - The logsumexp of attention scores, Shape: [batch_size, num_qo_heads] - - Examples - -------- - >>> import torch - >>> import flashinfer - >>> padded_kv_len = 4096 - >>> num_qo_heads = 32 - >>> num_kv_heads = 32 - >>> batch_size = 7 - >>> head_dim = 128 - >>> q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0") - >>> k_padded = torch.randn(batch_size, padded_kv_len, num_kv_heads, head_dim).half().to("cuda:0") - >>> v_padded = torch.randn(batch_size, padded_kv_len, num_kv_heads, head_dim).half().to("cuda:0") - >>> v, s = flashinfer.batch_decode_with_padded_kv_cache_return_lse( - ... q, k_padded, v_padded, "NHD" - ... ) - >>> v.shape - torch.Size([7, 32, 128]) - >>> s.shape - torch.Size([7, 32]) - - Notes - ----- - Please refer to the :ref:`tutorial ` for a detailed - explanation of the log-sum-exp function and attention states. - - The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is - not equal to ``num_kv_heads``, the function will use - `grouped query attention `_. - """ - if logits_soft_cap is None: - logits_soft_cap = 0.0 - if sm_scale is None: - head_dim = q.shape[-1] - sm_scale = 1.0 / math.sqrt(head_dim) - if q_scale is not None: - sm_scale *= q_scale - if k_scale is not None: - sm_scale *= k_scale - if rope_scale is None: - rope_scale = 1.0 - if rope_theta is None: - rope_theta = 1e4 - V, s = _kernels.batch_decode_with_padded_kv_cache( - q, - k_padded, - v_padded, - TensorLayout[kv_layout].value, - PosEncodingMode[pos_encoding_mode].value, - logits_soft_cap, - sm_scale, - rope_scale, - rope_theta, - True, # return_lse - ) - if v_scale is not None: - V *= v_scale - return V, s - - class BatchDecodeWithPagedKVCacheWrapper: r"""Wrapper class for decode attention with paged kv-cache (first proposed in `vLLM `_) for batch of requests. @@ -982,8 +743,8 @@ def forward_return_lse( self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, PosEncodingMode[pos_encoding_mode].value, - sm_scale, logits_soft_cap, + sm_scale, rope_scale, rope_theta, True, # return_lse diff --git a/python/generate_batch_padded_decode_inst.py b/python/generate_batch_padded_decode_inst.py deleted file mode 100644 index 51635a64..00000000 --- a/python/generate_batch_padded_decode_inst.py +++ /dev/null @@ -1,72 +0,0 @@ -""" -Copyright (c) 2024 by FlashInfer team. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-""" - -import sys -import re -from literal_map import ( - kv_layout_literal, - pos_encoding_mode_literal, - dtype_literal, - logits_hook_literal, -) -from pathlib import Path - - -def get_cu_file_str( - head_dim, - logits_hook, - kv_layout, - pos_encoding_mode, - dtype_q, - dtype_kv, - dtype_out, -): - content = """#include - -namespace flashinfer {{ - -template cudaError_t BatchDecodeWithPaddedKVCacheDispatched<{head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}>( - {dtype_q}* q, {dtype_kv}* k, {dtype_kv}* v, - {dtype_out}* o, {dtype_out}* tmp, float* lse, - uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, uint32_t num_kv_heads, - float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta, cudaStream_t stream); - -}} - """.format( - logits_hook=logits_hook_literal[int(logits_hook)], - kv_layout=kv_layout_literal[int(kv_layout)], - head_dim=head_dim, - pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], - dtype_q=dtype_literal[dtype_q], - dtype_kv=dtype_literal[dtype_kv], - dtype_out=dtype_literal[dtype_out], - ) - return content - - -if __name__ == "__main__": - pattern = ( - r"batch_padded_decode_head_([0-9]+)_logitshook_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" - r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" - ) - - compiled_pattern = re.compile(pattern) - path = Path(sys.argv[1]) - fname = path.name - match = compiled_pattern.match(fname) - with open(path, "w") as f: - f.write(get_cu_file_str(*match.groups())) diff --git a/python/setup.py b/python/setup.py index 038b2737..45881d7b 100644 --- a/python/setup.py +++ b/python/setup.py @@ -28,7 +28,7 @@ import torch import torch.utils.cpp_extension as torch_cpp_ext -import generate_single_decode_inst, generate_single_prefill_inst, generate_batch_paged_decode_inst, generate_batch_padded_decode_inst, generate_batch_paged_prefill_inst, generate_batch_ragged_prefill_inst, generate_dispatch_inc +import generate_single_decode_inst, generate_single_prefill_inst, generate_batch_paged_decode_inst, generate_batch_paged_prefill_inst, generate_batch_ragged_prefill_inst, generate_dispatch_inc root = pathlib.Path(__name__).parent @@ -92,10 +92,12 @@ def get_instantiation_cu() -> List[str]: idtypes = ["i32"] prefill_dtypes = ["f16"] decode_dtypes = ["f16"] + fp16_dtypes = ["f16"] fp8_dtypes = ["e4m3", "e5m2"] if enable_bf16: prefill_dtypes.append("bf16") decode_dtypes.append("bf16") + fp16_dtypes.append("bf16") if enable_fp8: decode_dtypes.extend(fp8_dtypes) @@ -112,8 +114,10 @@ def get_instantiation_cu() -> List[str]: kv_layouts, pos_encoding_modes, ): - for dtype_q, dtype_kv in itertools.product(decode_dtypes, decode_dtypes): - dtype_out = dtype_q if dtype_q not in fp8_dtypes else "f16" + for dtype_q, dtype_kv in list(zip(decode_dtypes, decode_dtypes)) + list( + itertools.product(fp16_dtypes, fp8_dtypes) + ): + dtype_out = dtype_q fname = f"single_decode_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" files.append(prefix + "/" + fname) content = generate_single_decode_inst.get_cu_file_str( @@ -140,8 +144,10 @@ def get_instantiation_cu() -> List[str]: pos_encoding_modes, ): for idtype in idtypes: - for dtype_q, dtype_kv in itertools.product(decode_dtypes, decode_dtypes): - dtype_out = dtype_q if dtype_q not in fp8_dtypes else "f16" + for dtype_q, dtype_kv in list(zip(decode_dtypes, decode_dtypes)) + list( + itertools.product(fp16_dtypes, 
fp8_dtypes) + ): + dtype_out = dtype_q fname = f"batch_paged_decode_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu" files.append(prefix + "/" + fname) content = generate_batch_paged_decode_inst.get_cu_file_str( @@ -156,21 +162,6 @@ def get_instantiation_cu() -> List[str]: ) write_if_different(root / prefix / fname, content) - for dtype_q, dtype_kv in itertools.product(decode_dtypes, decode_dtypes): - dtype_out = dtype_q if dtype_q not in fp8_dtypes else "f16" - fname = f"batch_padded_decode_head_{head_dim}_logitshook_{logits_hook}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" - files.append(prefix + "/" + fname) - content = generate_batch_padded_decode_inst.get_cu_file_str( - head_dim, - logits_hook, - kv_layout, - pos_encoding_mode, - dtype_q, - dtype_kv, - dtype_out, - ) - write_if_different(root / prefix / fname, content) - # single prefill files for ( head_dim, @@ -369,7 +360,7 @@ def __init__(self, *args, **kwargs) -> None: "-Xcompiler", "-mcmodel=medium", "-Xcompiler", - "\"-Wl,--no-relax\"" + '"-Wl,--no-relax"', ], }, ) diff --git a/python/tests/test_batch_decode_kernels.py b/python/tests/test_batch_decode_kernels.py index c020e657..b588d5b7 100644 --- a/python/tests/test_batch_decode_kernels.py +++ b/python/tests/test_batch_decode_kernels.py @@ -29,9 +29,8 @@ @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) -@pytest.mark.parametrize( - "q_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] -) +@pytest.mark.parametrize("return_lse", [True, False]) +@pytest.mark.parametrize("q_dtype", [torch.float16]) @pytest.mark.parametrize( "kv_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] ) @@ -44,6 +43,7 @@ def test_batch_decode_with_paged_kv_cache( head_dim, kv_layout, pos_encoding_mode, + return_lse, q_dtype, kv_dtype, ): @@ -75,7 +75,14 @@ def test_batch_decode_with_paged_kv_cache( data_type=kv_dtype, q_data_type=q_dtype, ) - o = wrapper.forward(q, kv_data.to(kv_dtype), pos_encoding_mode=pos_encoding_mode) + if return_lse: + o, _ = wrapper.forward_return_lse( + q, kv_data.to(kv_dtype), pos_encoding_mode=pos_encoding_mode + ) + else: + o = wrapper.forward( + q, kv_data.to(kv_dtype), pos_encoding_mode=pos_encoding_mode + ) for i in range(batch_size): perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] @@ -127,9 +134,7 @@ def test_batch_decode_with_paged_kv_cache( @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) -@pytest.mark.parametrize( - "q_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] -) +@pytest.mark.parametrize("q_dtype", [torch.float16]) @pytest.mark.parametrize( "kv_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] ) @@ -288,13 +293,13 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( if __name__ == "__main__": test_batch_decode_with_paged_kv_cache( - 256, 54, 8, 8, 8, 128, "NHD", "NONE", torch.float16, torch.float16 + 256, 54, 8, 8, 8, 128, "NHD", "NONE", False, torch.float16, torch.float16 ) test_batch_decode_with_paged_kv_cache( - 12, 2048, 8, 8, 8, 128, "NHD", "NONE", torch.float16, torch.float16 + 12, 2048, 8, 8, 8, 128, "NHD", "NONE", False, torch.float16, 
torch.float16 ) test_batch_decode_with_paged_kv_cache( - 12, 54, 1, 8, 8, 128, "HND", "NONE", torch.float8_e5m2, torch.float8_e5m2 + 12, 54, 1, 8, 8, 128, "HND", "NONE", True, torch.float16, torch.float8_e5m2 ) test_cuda_graph_batch_decode_with_paged_kv_cache( 12, 2048, 8, 8, 8, 128, "NHD", "NONE", torch.float16, torch.float16 @@ -303,8 +308,8 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( 128, 54, 8, 8, 8, 128, "NHD", "NONE", torch.float16, torch.float16 ) test_batch_decode_with_paged_kv_cache( - 12, 54, 1, 8, 8, 128, "HND", "NONE", torch.float8_e5m2, torch.float16 + 12, 54, 1, 8, 8, 128, "HND", "NONE", True, torch.float16, torch.float8_e5m2 ) test_cuda_graph_batch_decode_with_paged_kv_cache( - 12, 54, 8, 8, 8, 128, "HND", "NONE", torch.float8_e5m2, torch.float16 + 12, 54, 8, 8, 8, 128, "HND", "NONE", torch.float16, torch.float8_e5m2 ) diff --git a/python/tests/test_shared_prefix_kernels.py b/python/tests/test_shared_prefix_kernels.py index ba4e752f..3911f696 100644 --- a/python/tests/test_shared_prefix_kernels.py +++ b/python/tests/test_shared_prefix_kernels.py @@ -25,106 +25,7 @@ def ceil_div(a, b): return (a + b - 1) // b -@pytest.mark.parametrize("batch_size", [12, 17]) -@pytest.mark.parametrize("unique_kv_len", [37, 17]) -@pytest.mark.parametrize("shared_kv_len", [54, 97, 1979]) -@pytest.mark.parametrize("num_heads", [8, 16]) -@pytest.mark.parametrize("head_dim", [128, 256]) -def test_batch_decode_with_shared_prefix_padded_kv_cache( - batch_size, unique_kv_len, shared_kv_len, num_heads, head_dim -): - q = torch.randn(batch_size, num_heads, head_dim).to(0).half() - k_shared = torch.randn(shared_kv_len, num_heads, head_dim).to(0).half() - v_shared = torch.randn(shared_kv_len, num_heads, head_dim).to(0).half() - k_unique = torch.randn(batch_size, unique_kv_len, num_heads, head_dim).to(0).half() - v_unique = torch.randn(batch_size, unique_kv_len, num_heads, head_dim).to(0).half() - - o = flashinfer.batch_decode_with_shared_prefix_padded_kv_cache( - q, k_shared, v_shared, k_unique, v_unique - ) - - for i in range(batch_size): - qi = q[i] - ki = torch.cat([k_shared, k_unique[i]], dim=0) - vi = torch.cat([v_shared, v_unique[i]], dim=0) - o_ref_i = flashinfer.single_decode_with_kv_cache(qi, ki, vi) - o_i_np = o[i].cpu().numpy() - o_ref_i_np = o_ref_i.cpu().numpy() - numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) - - -@pytest.mark.parametrize("batch_size", [12, 17]) -@pytest.mark.parametrize("unique_kv_len", [37, 17]) -@pytest.mark.parametrize("shared_kv_len", [54, 97, 1979]) -@pytest.mark.parametrize("num_heads", [8, 16]) -@pytest.mark.parametrize("head_dim", [128, 256]) -@pytest.mark.parametrize("page_size", [1, 16]) -def test_batch_decode_with_shared_prefix_paged_kv_cache( - batch_size, unique_kv_len, shared_kv_len, num_heads, head_dim, page_size -): - kv_layout = "NHD" - q = torch.randn(batch_size, num_heads, head_dim).to(0).half() - k_shared = torch.randn(shared_kv_len, num_heads, head_dim).to(0).half() - v_shared = torch.randn(shared_kv_len, num_heads, head_dim).to(0).half() - k_unique = torch.randn(batch_size, unique_kv_len, num_heads, head_dim).to(0).half() - v_unique = torch.randn(batch_size, unique_kv_len, num_heads, head_dim).to(0).half() - - kv_data = ( - torch.zeros( - batch_size * ceil_div(unique_kv_len, page_size), - 2, - page_size, - num_heads, - head_dim, - ) - .to(0) - .half() - ) - kv_indices = ( - torch.arange(0, batch_size * ceil_div(unique_kv_len, page_size)).to(0).int() - ) - kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * 
ceil_div( - unique_kv_len, page_size - ) - kv_last_page_len = torch.full( - (batch_size,), (unique_kv_len - 1) % page_size + 1, dtype=torch.int32 - ).to(0) - - workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) - wrapper = flashinfer.BatchDecodeWithSharedPrefixPagedKVCacheWrapper( - workspace_buffer, kv_layout - ) - wrapper.begin_forward( - kv_indptr, - kv_indices, - kv_last_page_len, - num_heads, - num_heads, - head_dim, - page_size, - kv_data.dtype, - ) - append_indptr = torch.arange(0, batch_size + 1).to(0).int() * unique_kv_len - flashinfer.append_paged_kv_cache( - k_unique.view(-1, num_heads, head_dim), - v_unique.view(-1, num_heads, head_dim), - append_indptr, - kv_data, - kv_indices, - kv_indptr, - kv_last_page_len, - kv_layout, - ) - - o_padded = flashinfer.batch_decode_with_shared_prefix_padded_kv_cache( - q, k_shared, v_shared, k_unique, v_unique - ) - o_paged = wrapper.forward(q, k_shared, v_shared, kv_data) - numpy.testing.assert_allclose( - o_padded.cpu().numpy(), o_paged.cpu().numpy(), rtol=1e-3, atol=1e-3 - ) - - +@pytest.mark.parametrize("stage", ["decode", "append"]) @pytest.mark.parametrize("batch_size", [12, 17]) @pytest.mark.parametrize("unique_kv_len", [37, 17]) @pytest.mark.parametrize("shared_kv_len", [128, 512, 2048]) @@ -132,13 +33,26 @@ def test_batch_decode_with_shared_prefix_paged_kv_cache( @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("page_size", [1, 16]) -def test_batch_prefill_with_shared_prefix_paged_kv_cache( - batch_size, unique_kv_len, shared_kv_len, num_heads, causal, head_dim, page_size +def test_batch_attention_with_shared_prefix_paged_kv_cache( + stage, + batch_size, + unique_kv_len, + shared_kv_len, + num_heads, + causal, + head_dim, + page_size, ): + if stage == "decode" and causal == True: + pytest.skip("Causal attention is not required in decode stage") assert shared_kv_len % page_size == 0 kv_layout = "NHD" - q = torch.randn(batch_size * unique_kv_len, num_heads, head_dim).to(0).half() - q_indptr = torch.arange(0, batch_size + 1).to(0).int() * unique_kv_len + if stage == "append": + q = torch.randn(batch_size * unique_kv_len, num_heads, head_dim).to(0).half() + q_indptr = torch.arange(0, batch_size + 1).to(0).int() * unique_kv_len + else: + q = torch.randn(batch_size, num_heads, head_dim).to(0).half() + q_indptr = torch.arange(0, batch_size + 1).to(0).int() k_shared = torch.randn(shared_kv_len, num_heads, head_dim).to(0).half() v_shared = torch.randn(shared_kv_len, num_heads, head_dim).to(0).half() k_unique = torch.randn(batch_size * unique_kv_len, num_heads, head_dim).to(0).half() @@ -195,12 +109,20 @@ def test_batch_prefill_with_shared_prefix_paged_kv_cache( kv_layout, ) - baseline_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout - ) - cascade_wrapper = flashinfer.BatchPrefillWithSharedPrefixPagedKVCacheWrapper( - torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout - ) + if stage == "decode": + baseline_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout + ) + cascade_wrapper = flashinfer.BatchDecodeWithSharedPrefixPagedKVCacheWrapper( + torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout + ) + else: + baseline_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout + ) + cascade_wrapper = 
flashinfer.BatchPrefillWithSharedPrefixPagedKVCacheWrapper( + torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0), kv_layout + ) baseline_kv_indices_arr = [] for i in range(batch_size): @@ -219,40 +141,69 @@ def test_batch_prefill_with_shared_prefix_paged_kv_cache( ceil_div(shared_kv_len, page_size) + ceil_div(unique_kv_len, page_size) ) baseline_kv_last_page_len = unique_last_page_len - baseline_wrapper.begin_forward( - q_indptr, - baseline_kv_indptr, - baseline_kv_indices, - baseline_kv_last_page_len, - num_heads, - num_heads, - head_dim, - page_size, - ) - - o_baseline = baseline_wrapper.forward(q, kv_data, causal=causal) + if stage == "decode": + baseline_wrapper.begin_forward( + baseline_kv_indptr, + baseline_kv_indices, + baseline_kv_last_page_len, + num_heads, + num_heads, + head_dim, + page_size, + ) + o_baseline = baseline_wrapper.forward(q, kv_data) + else: + baseline_wrapper.begin_forward( + q_indptr, + baseline_kv_indptr, + baseline_kv_indices, + baseline_kv_last_page_len, + num_heads, + num_heads, + head_dim, + page_size, + ) + o_baseline = baseline_wrapper.forward(q, kv_data, causal=causal) cascade_kv_indices = unique_kv_indices cascade_kv_indptr = unique_kv_indptr cascade_kv_last_page_len = unique_last_page_len - cascade_wrapper.begin_forward( - q_indptr, - cascade_kv_indptr, - cascade_kv_indices, - cascade_kv_last_page_len, - num_heads, - num_heads, - head_dim, - page_size, - ) - o_cascade = cascade_wrapper.forward(q, k_shared, v_shared, kv_data, causal=causal) + if stage == "decode": + cascade_wrapper.begin_forward( + cascade_kv_indptr, + cascade_kv_indices, + cascade_kv_last_page_len, + num_heads, + num_heads, + head_dim, + page_size, + ) + o_cascade = cascade_wrapper.forward(q, k_shared, v_shared, kv_data) + else: + cascade_wrapper.begin_forward( + q_indptr, + cascade_kv_indptr, + cascade_kv_indices, + cascade_kv_last_page_len, + num_heads, + num_heads, + head_dim, + page_size, + ) + o_cascade = cascade_wrapper.forward( + q, k_shared, v_shared, kv_data, causal=causal + ) + numpy.testing.assert_allclose( o_baseline.cpu().numpy(), o_cascade.cpu().numpy(), rtol=1e-3, atol=1e-3 ) if __name__ == "__main__": - test_batch_decode_with_shared_prefix_padded_kv_cache(12, 37, 54, 8, 128) - test_batch_decode_with_shared_prefix_paged_kv_cache(12, 37, 54, 8, 128, 16) - test_batch_prefill_with_shared_prefix_paged_kv_cache(12, 37, 256, 8, True, 128, 16) + test_batch_attention_with_shared_prefix_paged_kv_cache( + "decode", 12, 37, 128, 8, False, 128, 16 + ) + test_batch_attention_with_shared_prefix_paged_kv_cache( + "apppend", 12, 37, 128, 8, True, 128, 16 + ) diff --git a/src/bench_single_decode.cu b/src/bench_single_decode.cu index 642228ac..473205af 100644 --- a/src/bench_single_decode.cu +++ b/src/bench_single_decode.cu @@ -22,7 +22,7 @@ using flashinfer::PosEncodingMode; using flashinfer::QKVLayout; -template +template void bench_flashinfer_single_decode(nvbench::state& state) { size_t seq_len = state.get_int64("seq_len"); size_t num_qo_heads = state.get_int64("num_qo_heads"); @@ -32,16 +32,16 @@ void bench_flashinfer_single_decode(nvbench::state& state) { size_t kv_layout = state.get_int64("kv_layout"); bool cooperative = state.get_int64("cooperative"); // Allocate input data: - thrust::device_vector Q(num_qo_heads * head_dim); - thrust::device_vector K(seq_len * num_kv_heads * head_dim); - thrust::device_vector V(seq_len * num_kv_heads * head_dim); - thrust::device_vector O(num_qo_heads * head_dim); - thrust::device_vector tmp(16 * 1024 * 1024); + thrust::device_vector 
Q(num_qo_heads * head_dim); + thrust::device_vector K(seq_len * num_kv_heads * head_dim); + thrust::device_vector V(seq_len * num_kv_heads * head_dim); + thrust::device_vector O(num_qo_heads * head_dim); + thrust::device_vector tmp(16 * 1024 * 1024); // Provide throughput information: - state.add_global_memory_reads( + state.add_global_memory_reads( num_qo_heads * head_dim + 2 * seq_len * num_kv_heads * head_dim, "Read"); - state.add_global_memory_writes(num_qo_heads * head_dim, "Write"); + state.add_global_memory_writes(num_qo_heads * head_dim, "Write"); state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) { timer.start(); @@ -106,11 +106,11 @@ void bench_flashinfer_single_decode_with_prefill(nvbench::state& state) { #define STR_HELPER(x) #x #define STR(x) STR_HELPER(x) -#define BENCH_FLASHINFER_SINGLE_DECODE(dtype_in, dtype_out) \ - auto bench_flashinfer_single_decode_##dtype_in##_##dtype_out##_ = \ - bench_flashinfer_single_decode; \ - NVBENCH_BENCH(bench_flashinfer_single_decode_##dtype_in##_##dtype_out##_) \ - .set_name(("bench_flashinfer_single_decode_" STR(dtype_in) "_" STR(dtype_out))) \ +#define BENCH_FLASHINFER_SINGLE_DECODE(dtype_qo, dtype_kv) \ + auto bench_flashinfer_single_decode_##dtype_qo##_##dtype_kv##_ = \ + bench_flashinfer_single_decode; \ + NVBENCH_BENCH(bench_flashinfer_single_decode_##dtype_qo##_##dtype_kv##_) \ + .set_name(("bench_flashinfer_single_decode_" STR(dtype_qo) "_" STR(dtype_kv))) \ .add_int64_axis("seq_len", \ {32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536}) \ .add_int64_axis("num_qo_heads", {32}) \ @@ -135,7 +135,7 @@ void bench_flashinfer_single_decode_with_prefill(nvbench::state& state) { .add_int64_axis("cooperative", {1}) BENCH_FLASHINFER_SINGLE_DECODE(half, half); -BENCH_FLASHINFER_SINGLE_DECODE(__nv_fp8_e5m2, half); +BENCH_FLASHINFER_SINGLE_DECODE(half, __nv_fp8_e5m2); // Use prefill kernel for decoding, useful in GQA on GPUs with low non-tensor performance such as // A100 BENCH_FLASHINFER_SINGLE_DECODE_WITH_PREFILL(half, half); diff --git a/src/cpu_reference.h b/src/cpu_reference.h index 960653c2..1679e551 100644 --- a/src/cpu_reference.h +++ b/src/cpu_reference.h @@ -72,9 +72,9 @@ inline std::vector apply_llama_rope(const T* input, size_t D, size_t offs return std::move(rst); } -template -std::vector single_mha(const std::vector& q, const std::vector& k, - const std::vector& v, size_t qo_len, size_t kv_len, +template +std::vector single_mha(const std::vector& q, const std::vector& k, + const std::vector& v, size_t qo_len, size_t kv_len, size_t num_qo_heads, size_t num_kv_heads, size_t head_dim, bool causal = true, QKVLayout kv_layout = QKVLayout::kHND, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index d4361bdf..962792ee 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -192,37 +192,6 @@ cudaError_t SingleDecodeWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* return cudaSuccess; } -template -cudaError_t BatchDecodeWithPaddedKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, - DTypeOut* tmp, float* lse, uint32_t batch_size, - uint32_t padded_kv_len, uint32_t num_qo_heads, - uint32_t num_kv_heads, uint32_t head_dim, - QKVLayout kv_layout = QKVLayout::kNHD, - PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, - std::optional maybe_sm_scale = std::nullopt, - float rope_scale = 1.f, float rope_theta = 1e4, - cudaStream_t stream = nullptr) { - const float sm_scale = 
maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); - if (num_qo_heads % num_kv_heads != 0) { - std::ostringstream err_msg; - err_msg << "num_qo_heads " << num_qo_heads << " is not a multiple of num_kv_heads " - << num_kv_heads; - throw std::invalid_argument(err_msg.str()); - } - - DISPATCH_head_dim( - head_dim, HEAD_DIM, - {DISPATCH_pos_encoding_mode( - pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, { - return BatchDecodeWithPaddedKVCacheDispatched( - q, k, v, o, tmp, lse, batch_size, padded_kv_len, num_qo_heads, num_kv_heads, - /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta, stream); - })})}); - return cudaSuccess; -} - template cudaError_t BatchDecodeWithPagedKVCacheNoSplitKV( diff --git a/src/test_batch_decode.cu b/src/test_batch_decode.cu index 00bfe90a..5db6434b 100644 --- a/src/test_batch_decode.cu +++ b/src/test_batch_decode.cu @@ -25,7 +25,7 @@ using namespace flashinfer; constexpr QKVLayout kv_layout = QKVLayout::kNHD; -template +template void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, size_t num_qo_heads, size_t num_kv_heads, size_t head_dim, flashinfer::PosEncodingMode pos_encoding_mode, @@ -36,29 +36,30 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si for (size_t i = 0; i < batch_size; ++i) { append_indptr.push_back(append_indptr.back() + seq_lens[i]); } - std::vector q; - std::vector o_ref; - std::vector kv_data; + std::vector q; + std::vector o_ref; + std::vector kv_data; std::vector kv_indptr{0}; std::vector kv_indices; std::vector kv_last_page_len; size_t page_counter = 0; - std::vector> keys, values; + std::vector> keys, values; for (size_t i = 0; i < batch_size; ++i) { size_t seq_len = seq_lens[i]; size_t num_pages = (seq_len + page_size - 1) / page_size; size_t last_page_len = (seq_len - 1) % page_size + 1; - std::vector qi(num_qo_heads * head_dim), ki(seq_len * num_kv_heads * head_dim), + std::vector qi(num_qo_heads * head_dim); + std::vector ki(seq_len * num_kv_heads * head_dim), vi(seq_len * num_kv_heads * head_dim); utils::vec_normal_(qi); utils::vec_normal_(ki); utils::vec_normal_(vi); // compute reference output - std::vector o_ref_i = - cpu_reference::single_mha(qi, ki, vi, 1, seq_len, num_qo_heads, num_kv_heads, - head_dim, false, QKVLayout::kNHD, pos_encoding_mode); + std::vector o_ref_i = cpu_reference::single_mha( + qi, ki, vi, 1, seq_len, num_qo_heads, num_kv_heads, head_dim, false, QKVLayout::kNHD, + pos_encoding_mode); keys.push_back(ki); values.push_back(vi); // append new q and o_ref @@ -76,22 +77,22 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si assert(q.size() == batch_size * num_qo_heads * head_dim); assert(o_ref.size() == batch_size * num_qo_heads * head_dim); - flashinfer::paged_kv_t paged_kv_cpu( + flashinfer::paged_kv_t paged_kv_cpu( num_kv_heads, page_size, head_dim, batch_size, kv_data.data(), kv_indices.data(), kv_indptr.data(), kv_last_page_len.data()); - cpu_reference::append_paged_kv_cache(paged_kv_cpu, keys, values, - append_indptr); + cpu_reference::append_paged_kv_cache(paged_kv_cpu, keys, values, + append_indptr); // copy data to device - thrust::device_vector kv_data_device(kv_data); + thrust::device_vector kv_data_device(kv_data); thrust::device_vector kv_indptr_device(kv_indptr); thrust::device_vector kv_indices_device(kv_indices); thrust::device_vector kv_last_page_len_device(kv_last_page_len); - thrust::device_vector q_device(q); - thrust::device_vector o_device(o_ref.size()); + 
thrust::device_vector q_device(q); + thrust::device_vector o_device(o_ref.size()); // create paged_kv object - flashinfer::paged_kv_t paged_kv( + flashinfer::paged_kv_t paged_kv( num_kv_heads, page_size, head_dim, batch_size, thrust::raw_pointer_cast(kv_data_device.data()), thrust::raw_pointer_cast(kv_indices_device.data()), @@ -100,31 +101,32 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si flashinfer::BatchDecodeHandler handler; size_t workspace_size_in_bytes = 32 * 1024 * 1024; thrust::device_vector buffer(workspace_size_in_bytes); - BatchDecodeHandlerBeginForward( - &handler, (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, - kv_indptr.data(), kv_last_page_len.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, - page_size, pos_encoding_mode); + BatchDecodeHandlerBeginForward(&handler, (void*)thrust::raw_pointer_cast(buffer.data()), + workspace_size_in_bytes, kv_indptr.data(), + kv_last_page_len.data(), batch_size, num_qo_heads, + num_kv_heads, head_dim, page_size, pos_encoding_mode); if (!cooperative) { // use non-cooperative kernel cudaError_t status = - flashinfer::BatchDecodeWithPagedKVCacheNoSplitKV( + flashinfer::BatchDecodeWithPagedKVCacheNoSplitKV( thrust::raw_pointer_cast(q_device.data()), /*q_offset=*/nullptr, paged_kv, kv_partition_info_t(), thrust::raw_pointer_cast(o_device.data()), /*lse=*/nullptr, num_qo_heads, pos_encoding_mode); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); } else { cudaError_t status = - flashinfer::BatchDecodeWithPagedKVCacheWrapper( + flashinfer::BatchDecodeWithPagedKVCacheWrapper( &handler, thrust::raw_pointer_cast(q_device.data()), /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()), /*lse=*/nullptr, num_qo_heads, pos_encoding_mode); EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status)); } // compare result - thrust::host_vector o_host = o_device; + thrust::host_vector o_host = o_device; size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0; bool nan_detected = false; for (size_t i = 0; i < batch_size * num_qo_heads * head_dim; ++i) { @@ -145,7 +147,7 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si EXPECT_EQ(nan_detected, false) << "NaN detected."; } -template +template void TestBatchDecodeKernelCorrectness() { for (size_t page_size : {1, 3, 7, 16}) { for (size_t batch_size : {1, 7, 37, 61}) { @@ -153,7 +155,7 @@ void TestBatchDecodeKernelCorrectness() { for (size_t num_kv_heads : {32, 8, 4}) { for (size_t head_dim : {64, 128, 256}) { for (size_t pos_encoding_mode : {0U, 1U}) { - _TestBatchDecodingKernelCorrectness( + _TestBatchDecodingKernelCorrectness( page_size, batch_size, num_qo_heads, num_kv_heads, head_dim, flashinfer::PosEncodingMode(pos_encoding_mode), false); } @@ -164,7 +166,7 @@ void TestBatchDecodeKernelCorrectness() { } } -template +template void TestCooperativeBatchDecodeKernelCorrectness() { for (size_t page_size : {1, 3, 7, 16}) { for (size_t batch_size : {1, 2, 4, 8}) { @@ -172,7 +174,7 @@ void TestCooperativeBatchDecodeKernelCorrectness() { for (size_t num_kv_heads : {32, 8, 4}) { for (size_t head_dim : {64, 128, 256}) { for (size_t pos_encoding_mode : {0U, 1U}) { - _TestBatchDecodingKernelCorrectness( + _TestBatchDecodingKernelCorrectness( page_size, batch_size, num_qo_heads, num_kv_heads, head_dim, flashinfer::PosEncodingMode(pos_encoding_mode), true); } @@ -184,25 +186,25 @@ void TestCooperativeBatchDecodeKernelCorrectness() { } 
TEST(FlashInferCorrectnessTest, BatchDecodeKernelCorrectnessTestFP16) {
-  TestBatchDecodeKernelCorrectness<half>();
+  TestBatchDecodeKernelCorrectness<half, half>();
 }
 
 #ifdef FLASHINFER_ENABLE_BF16
 TEST(FlashInferCorrectnessTest, TestBatchDecodeKernelCorrectnessBF16) {
-  TestBatchDecodeKernelCorrectness<__nv_bfloat16>();
+  TestBatchDecodeKernelCorrectness<__nv_bfloat16, __nv_bfloat16>();
 }
 #endif
 
 #ifdef FLASHINFER_ENABLE_FP8
 TEST(FlashInferCorrectnessTest, TestBatchDecodeKernelCorrectnessE4M3) {
-  TestBatchDecodeKernelCorrectness<__nv_fp8_e4m3>();
+  TestBatchDecodeKernelCorrectness<half, __nv_fp8_e4m3>();
 }
 
 TEST(FlashInferCorrectnessTest, TestBatchDecodeKernelCorrectnessE5M2) {
-  TestBatchDecodeKernelCorrectness<__nv_fp8_e5m2>();
+  TestBatchDecodeKernelCorrectness<half, __nv_fp8_e5m2>();
 }
 #endif
 
 TEST(FlashInferCorrectnessTest, TestCooperativeBatchDecodeKernelCorrectnessTestFP16) {
-  TestCooperativeBatchDecodeKernelCorrectness<half>();
+  TestCooperativeBatchDecodeKernelCorrectness<half, half>();
 }
diff --git a/src/test_batch_prefill.cu b/src/test_batch_prefill.cu
index 76bea74d..a9e78468 100644
--- a/src/test_batch_prefill.cu
+++ b/src/test_batch_prefill.cu
@@ -94,7 +94,7 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n
     std::vector q(q_len * num_qo_heads * head_dim);
     utils::vec_normal_(q);
 
-    std::vector o_ref = cpu_reference::single_mha(
+    std::vector o_ref = cpu_reference::single_mha(
        q, key[request_idx], value[request_idx], q_len, kv_len, num_qo_heads, num_kv_heads,
        head_dim, causal, QKVLayout::kNHD, pos_encoding_mode);
 
@@ -173,8 +173,8 @@ void _TestBatchRaggedPrefillKernelCorrectness(size_t num_kv_heads, size_t num_qo
   utils::vec_normal_(k);
   utils::vec_normal_(v);
   std::vector o_ref =
-      cpu_reference::single_mha(q, k, v, q_len, kv_len, num_qo_heads, num_kv_heads,
-                                head_dim, causal, QKVLayout::kNHD, pos_encoding_mode);
+      cpu_reference::single_mha(q, k, v, q_len, kv_len, num_qo_heads, num_kv_heads,
+                                head_dim, causal, QKVLayout::kNHD, pos_encoding_mode);
   // NOTE(Zihao): The following code is only compatible with kv_layout = QKVLayout::kNHD
   std::copy(q.begin(), q.end(), std::back_inserter(queries));
   std::copy(k.begin(), k.end(), std::back_inserter(keys));
@@ -297,7 +297,7 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
   for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
     // create one-hot queries
     int32_t q_len = q_lens[request_idx], kv_len = kv_lens[request_idx];
-    std::vector o_ref_i = cpu_reference::single_mha(
+    std::vector o_ref_i = cpu_reference::single_mha(
        q[request_idx], key[request_idx], value[request_idx], q_len, kv_len, num_qo_heads,
        num_kv_heads, head_dim, causal, QKVLayout::kNHD, pos_encoding_mode);
     o_ref.push_back(o_ref_i);
@@ -402,8 +402,8 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz
   utils::vec_normal_(q);
 
   std::vector o_ref =
-      cpu_reference::single_mha(q, k, v, q_lens[0], kv_lens[0], num_qo_heads, num_kv_heads,
-                                head_dim, causal, QKVLayout::kNHD, pos_encoding_mode);
+      cpu_reference::single_mha(q, k, v, q_lens[0], kv_lens[0], num_qo_heads, num_kv_heads,
+                                head_dim, causal, QKVLayout::kNHD, pos_encoding_mode);
   thrust::device_vector q_indptr_device(q_indptr);
   thrust::device_vector q_device(q);
diff --git a/src/test_single_decode.cu b/src/test_single_decode.cu
index 1bdb123f..b316486e 100644
--- a/src/test_single_decode.cu
+++ b/src/test_single_decode.cu
@@ -23,32 +23,32 @@
 
 using namespace flashinfer;
 
-template
+template
 void _TestDecodingKernelCorrectness(size_t num_qo_heads, size_t num_kv_heads, size_t seq_len,
                                     size_t
head_dim, QKVLayout kv_layout, PosEncodingMode pos_encoding_mode) { - std::vector Q_host(num_qo_heads * head_dim); - std::vector K_host(seq_len * num_kv_heads * head_dim); - std::vector V_host(seq_len * num_kv_heads * head_dim); - std::vector O_host(num_qo_heads * head_dim); + std::vector Q_host(num_qo_heads * head_dim); + std::vector K_host(seq_len * num_kv_heads * head_dim); + std::vector V_host(seq_len * num_kv_heads * head_dim); + std::vector O_host(num_qo_heads * head_dim); utils::vec_normal_(Q_host); utils::vec_normal_(K_host); utils::vec_normal_(V_host); utils::vec_zero_(O_host); - thrust::device_vector Q(Q_host); - thrust::device_vector K(K_host); - thrust::device_vector V(V_host); - thrust::device_vector O(O_host); - thrust::device_vector tmp(16 * 1024 * 1024); - std::vector o_ref_host; + thrust::device_vector Q(Q_host); + thrust::device_vector K(K_host); + thrust::device_vector V(V_host); + thrust::device_vector O(O_host); + thrust::device_vector tmp(32 * 1024 * 1024); + std::vector o_ref_host; - o_ref_host = - cpu_reference::single_mha(Q_host, K_host, V_host, 1, seq_len, num_qo_heads, - num_kv_heads, head_dim, false, kv_layout, pos_encoding_mode); + o_ref_host = cpu_reference::single_mha( + Q_host, K_host, V_host, 1, seq_len, num_qo_heads, num_kv_heads, head_dim, false, kv_layout, + pos_encoding_mode); - cudaError_t status = SingleDecodeWithKVCache( + cudaError_t status = SingleDecodeWithKVCache( thrust::raw_pointer_cast(Q.data()), thrust::raw_pointer_cast(K.data()), thrust::raw_pointer_cast(V.data()), thrust::raw_pointer_cast(O.data()), thrust::raw_pointer_cast(tmp.data()), num_qo_heads, num_kv_heads, seq_len, head_dim, @@ -56,8 +56,7 @@ void _TestDecodingKernelCorrectness(size_t num_qo_heads, size_t num_kv_heads, si EXPECT_EQ(status, cudaSuccess) << "SingleDecodeWithKVCache kernel launch failed, error message: " << cudaGetErrorString(status); - thrust::host_vector o_host = O; - thrust::host_vector tmp_host = tmp; + thrust::host_vector o_host = O; size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0; bool nan_detected = false; @@ -79,7 +78,7 @@ void _TestDecodingKernelCorrectness(size_t num_qo_heads, size_t num_kv_heads, si EXPECT_FALSE(nan_detected) << "NaN detected."; } -template +template void TestSingleDecodeKernelCorrectness() { for (size_t num_qo_heads : {32}) { for (size_t num_kv_heads : {4, 8, 32}) { @@ -88,9 +87,9 @@ void TestSingleDecodeKernelCorrectness() { for (size_t head_dim : {64, 128, 256}) { for (unsigned int kv_layout : {0U, 1U}) { for (unsigned int pos_encoding_mode : {0U, 1U}) { - _TestDecodingKernelCorrectness(num_qo_heads, num_kv_heads, seq_len, head_dim, - QKVLayout(kv_layout), - PosEncodingMode(pos_encoding_mode)); + _TestDecodingKernelCorrectness(num_qo_heads, num_kv_heads, seq_len, + head_dim, QKVLayout(kv_layout), + PosEncodingMode(pos_encoding_mode)); } } } @@ -100,20 +99,20 @@ void TestSingleDecodeKernelCorrectness() { } TEST(FlashInferCorrectnessTest, SingleDecodeKernelCorrectnessTestFP16) { - TestSingleDecodeKernelCorrectness(); + TestSingleDecodeKernelCorrectness(); } #ifdef FLASHINFER_ENABLE_BF16 TEST(FlashInferCorrectnessTest, SingleDecodeKernelCorrectnessTestBF16) { - TestSingleDecodeKernelCorrectness(); + TestSingleDecodeKernelCorrectness(); } #endif #ifdef FLASHINFER_ENABLE_FP8 TEST(FlashInferCorrectnessTest, SingleDecodeKernelCorrectnessTestE4M3) { - TestSingleDecodeKernelCorrectness<__nv_fp8_e4m3>(); + TestSingleDecodeKernelCorrectness(); } TEST(FlashInferCorrectnessTest, SingleDecodeKernelCorrectnessTestE5M2) { - 
TestSingleDecodeKernelCorrectness<__nv_fp8_e5m2>();
+  TestSingleDecodeKernelCorrectness<half, __nv_fp8_e5m2>();
 }
 #endif
diff --git a/src/test_single_prefill.cu b/src/test_single_prefill.cu
index b37d1cff..963bcc4d 100644
--- a/src/test_single_prefill.cu
+++ b/src/test_single_prefill.cu
@@ -54,7 +54,7 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu
       << cudaGetErrorString(status);
   thrust::host_vector o_h(o_d);
-  std::vector o_ref = cpu_reference::single_mha(
+  std::vector o_ref = cpu_reference::single_mha(
       q, k, v, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, causal, kv_layout,
       pos_encoding_mode);
   size_t num_results_error_atol = 0;