refactor: reduce the binary size of batch decode kernels (#343)
This PR refactors the batch-decode kernels and makes the following
breaking changes:
1. Remove the `batch_decode_with_padded_kv_cache` operator; users are
encouraged to use `BatchDecodeWithPagedKVCacheWrapper` instead (a
migration sketch follows below).
2. Delete the redundant DTypeQ × DTypeKV combinations; only the following
cases are now supported:
  1. DTypeQ == DTypeKV
  2. DTypeQ is float16 and DTypeKV is float8

The output data type follows the query data type.
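
For reference, here is a minimal migration sketch (not part of this commit): it drives `BatchDecodeWithPagedKVCacheWrapper` from Python with an fp16 query and an fp8 paged KV-cache, the combination this PR keeps. The tensor shapes, page-table layout, and argument names (`begin_forward`/`forward`/`end_forward`, the "NHD" layout string, the `data_type=` keyword) are assumptions about the Python API of this era and may differ across flashinfer versions.

# Minimal migration sketch; shapes, page-table layout, and keyword
# names below are assumptions and may vary by flashinfer version.
import torch
import flashinfer

batch_size, page_size, pages_per_req = 4, 16, 8
num_qo_heads, num_kv_heads, head_dim = 32, 8, 128
total_pages = batch_size * pages_per_req

# fp16 query; fp8 paged KV-cache (the DTypeQ = f16, DTypeKV = f8 case).
q = torch.randn(batch_size, num_qo_heads, head_dim,
                dtype=torch.float16, device="cuda")
kv_data = torch.randn(total_pages, 2, page_size, num_kv_heads, head_dim,
                      dtype=torch.float16, device="cuda").to(torch.float8_e4m3fn)

# Page table: request i owns pages [i * pages_per_req, (i + 1) * pages_per_req).
kv_indptr = torch.arange(0, batch_size + 1, dtype=torch.int32, device="cuda") * pages_per_req
kv_indices = torch.arange(0, total_pages, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")
wrapper.begin_forward(kv_indptr, kv_indices, kv_last_page_len,
                      num_qo_heads, num_kv_heads, head_dim, page_size,
                      data_type=torch.float8_e4m3fn)
o = wrapper.forward(q, kv_data)  # output follows the query dtype: [batch_size, num_qo_heads, head_dim], fp16
wrapper.end_forward()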
yzh119 authored Jun 30, 2024
1 parent e0a233a commit 0d333ff
Showing 29 changed files with 359 additions and 1,225 deletions.
42 changes: 8 additions & 34 deletions CMakeLists.txt
@@ -136,9 +136,9 @@ foreach(head_dim IN LISTS HEAD_DIMS)
list(APPEND single_decode_kernels_src ${generated_kernel_src})
endforeach(dtype)

# fp8 in, fp16 out
foreach(dtype IN LISTS DECODE_FP8_DTYPES)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_f16.cu)
# fp8 kv-cache
foreach(dtype_kv IN LISTS DECODE_FP8_DTYPES)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_decode_inst.py ${generated_kernel_src}
@@ -147,7 +147,7 @@ foreach(head_dim IN LISTS HEAD_DIMS)
VERBATIM
)
list(APPEND single_decode_kernels_src ${generated_kernel_src})
endforeach(dtype)
endforeach(dtype_kv)
endforeach(pos_encoding_mode)
endforeach(kv_layout)
endforeach(logits_post_hook)
@@ -172,9 +172,9 @@ foreach(head_dim IN LISTS HEAD_DIMS)
list(APPEND batch_decode_kernels_src ${generated_kernel_src})
endforeach(dtype)

# fp8 in, fp16 out
foreach(dtype IN LISTS DECODE_FP8_DTYPES)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_f16_idtype_${idtype}.cu)
# fp8 kv-cache
foreach(dtype_kv IN LISTS DECODE_FP8_DTYPES)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_f16_dtypekv_${dtype_kv}_dtypeout_f16_idtype_${idtype}.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_decode_inst.py ${generated_kernel_src}
@@ -183,34 +183,8 @@ foreach(head_dim IN LISTS HEAD_DIMS)
VERBATIM
)
list(APPEND batch_decode_kernels_src ${generated_kernel_src})
endforeach()
endforeach(dtype_kv)
endforeach(idtype)

# padded kv-cache
foreach(dtype IN LISTS DECODE_DTYPES)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_padded_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_${dtype}.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py
COMMENT "Generating additional source file ${generated_kernel_src}"
VERBATIM
)
list(APPEND batch_decode_kernels_src ${generated_kernel_src})
endforeach(dtype)

# padded kv-cache, fp8 in, fp16 out
foreach(dtype IN LISTS DECODE_FP8_DTYPES)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_padded_decode_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_dtypeq_${dtype}_dtypekv_${dtype}_dtypeout_f16.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_padded_decode_inst.py
COMMENT "Generating additional source file ${generated_kernel_src}"
VERBATIM
)
list(APPEND batch_decode_kernels_src ${generated_kernel_src})
endforeach()
endforeach(pos_encoding_mode)
endforeach(kv_layout)
endforeach(logits_post_hook)
200 changes: 0 additions & 200 deletions include/flashinfer/attention/decode.cuh
@@ -358,141 +358,6 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
}
}

template <LogitsPostHook logits_post_hook, QKVLayout kv_layout, PosEncodingMode pos_encoding_mode,
uint32_t num_stages_smem, uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz,
typename DTypeQ, typename DTypeKV, typename DTypeOut>
__global__ void BatchDecodeWithPaddedKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* __restrict__ k,
DTypeKV* __restrict__ v,
DTypeOut* __restrict__ o,
float* __restrict__ lse,
tensor_info_t<kv_layout, bdx * vec_size> info,
float logits_soft_cap, float sm_scale,
float rope_rcp_scale, float rope_rcp_theta) {
auto block = cg::this_thread_block();
sm_scale *=
(logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap));

constexpr uint32_t head_dim = bdx * vec_size;
uint32_t kv_head_idx = blockIdx.y;
uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx.y;
uint32_t batch_idx = blockIdx.x;
uint32_t num_qo_heads = info.num_qo_heads;
uint32_t num_kv_heads = info.num_kv_heads;
const float alibi_slope = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e;
uint32_t seq_len = info.kv_len;

extern __shared__ uint8_t smem[];
DTypeKV* k_smem = (DTypeKV*)smem;
DTypeKV* v_smem = (DTypeKV*)(smem + num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeKV));
float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeKV));

uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z;
vec_t<float, vec_size> q_vec;
vec_t<float, vec_size> freq;
if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) {
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
freq[i] = rope_rcp_scale *
__powf(rope_rcp_theta,
float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim));
}
// apply rotary embedding to q matrix
q_vec = vec_apply_llama_rope<vec_size, bdx>(
q + batch_idx * num_qo_heads * head_dim + info.get_qo_elem_offset(0, qo_head_idx, 0), freq,
seq_len - 1);
} else {
// do not apply rotary embedding to q matrix
q_vec.cast_load(q + batch_idx * num_qo_heads * head_dim +
info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size));
}
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
q_vec[i] *= sm_scale;
}
block.sync();

// preload k tiles and v tiles
uint32_t producer_kv_idx_base = 0;
constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size * 8;
#pragma unroll
for (uint32_t iter = 0; iter < num_stages_smem; ++iter) {
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
k_smem + ((iter * bdz + tz) * bdy + ty) * head_dim + tx * vec_size,
k + batch_idx * seq_len * num_kv_heads * head_dim +
info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx,
tx * vec_size),
producer_kv_idx_base + tz * bdy + ty < seq_len);
cp_async::commit_group();
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
v_smem + ((iter * bdz + tz) * bdy + ty) * head_dim + tx * vec_size,
v + batch_idx * seq_len * num_kv_heads * head_dim +
info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx,
tx * vec_size),
producer_kv_idx_base + tz * bdy + ty < seq_len);
cp_async::commit_group();
producer_kv_idx_base += bdy * bdz;
}

// pipelining k/v tiles loading and state updating
uint32_t consumer_kv_idx_base = 0, stage_idx = 0;
state_t<vec_size> st_local;
float s[bdy];

#pragma unroll 4
for (uint32_t iter = 0; iter < ceil_div(seq_len, bdy * bdz); ++iter) {
// compute qk
cp_async::wait_group<2 * num_stages_smem - 1>();
block.sync();
compute_qk<logits_post_hook, pos_encoding_mode, vec_size, bdx, bdy>(
k_smem + (stage_idx * bdz + tz) * bdy * head_dim, stage_idx, q_vec, freq,
consumer_kv_idx_base, iter * bdy * bdz, seq_len, seq_len - 1, alibi_slope, s, st_local,
logits_soft_cap);
block.sync();
// load k
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
k_smem + ((stage_idx * bdz + tz) * bdy + ty) * head_dim + tx * vec_size,
k + batch_idx * seq_len * num_kv_heads * head_dim +
info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx,
tx * vec_size),
producer_kv_idx_base + tz * bdy + ty < seq_len);
cp_async::commit_group();

// update m/d/o state
cp_async::wait_group<2 * num_stages_smem - 1>();
block.sync();
update_local_state<vec_size, bdx, bdy>(v_smem + (stage_idx * bdz + tz) * bdy * head_dim, s,
stage_idx, st_local);
block.sync();

// load v
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
v_smem + ((stage_idx * bdz + tz) * bdy + ty) * head_dim + tx * vec_size,
v + batch_idx * seq_len * num_kv_heads * head_dim +
info.get_kv_elem_offset(producer_kv_idx_base + tz * bdy + ty, kv_head_idx,
tx * vec_size),
producer_kv_idx_base + tz * bdy + ty < seq_len);
cp_async::commit_group();

stage_idx = (stage_idx + 1) % num_stages_smem;
producer_kv_idx_base += bdy * bdz;
consumer_kv_idx_base += bdy * bdz;
}
cp_async::wait_group<0>();
block.sync();

// sync local state of all warps inside a threadblock
sync_state<vec_size, bdx, bdy, bdz>(st_local, reinterpret_cast<float*>(smem), smem_md);

st_local.normalize();
st_local.o.cast_store(o + batch_idx * num_qo_heads * head_dim +
info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size));

// write lse
if (lse != nullptr) {
lse[batch_idx * num_qo_heads + qo_head_idx] = st_local.get_lse();
}
}

/*!
* \brief FlashAttention decoding cuda kernel with paged kv-cache for multiple requests
* \tparam logits_post_hook The logits post hook used in the kernel
Expand Down Expand Up @@ -937,71 +802,6 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
return cudaSuccess;
}

/*!
* \brief FlashAttention decoding cuda kernel with paged kv-cache for batched requests
* \tparam page_storage Whether to store indices or pointers of each active page
* \tparam DTypeQ A template type indicates the query data type
* \tparam DTypeKV A template type indicates the key-value data type
* \tparam DTypeOut A template type indicates the output data type
* \tparam IdType A template type indicates the index data type used in paged kv-cache
* \param q [batch_size, num_qo_heads, head_dim] The query matrix
* \param paged_kv The paged kv cache data structure
* \param o [batch_size, num_qo_heads, head_dim] The output matrix
* \param tmp Used-allocated temporary buffer
* \param lse The logsumexp values.
* \param num_qo_heads A integer indicates the number of heads of query and output
* \param pos_encoding_mode The positional encoding mode
* \param rope_scale The scaling ratio used in RoPE Interpolation.
* \param rope_theta A floating point number indicate the "theta" used in RoPE
* \param stream The cuda stream to launch the kernel
* \return status Indicates whether CUDA calls are successful
*/
template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOUT,
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut>
cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
DTypeOut* tmp, float* lse, uint32_t batch_size,
uint32_t padded_kv_len, uint32_t num_qo_heads,
uint32_t num_kv_heads, float logits_soft_cap,
float sm_scale, float rope_scale,
float rope_theta, cudaStream_t stream) {
const float rope_rcp_scale = 1.f / rope_scale;
const float rope_rcp_theta = 1.f / rope_theta;

constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL);
constexpr uint32_t num_stages_smem = 2U;
constexpr uint32_t bdx = HEAD_DIM / vec_size;
static_assert(bdx <= 32);
DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, {
constexpr uint32_t bdy = GROUP_SIZE;
constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
constexpr uint32_t bdz = num_threads / (bdx * bdy);

const uint32_t smem_size = 2 * num_stages_smem * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) +
2 * bdy * bdz * sizeof(float);

dim3 nblks(batch_size, num_kv_heads);
dim3 nthrs(bdx, bdy, bdz);
auto kernel = BatchDecodeWithPaddedKVCacheKernel<LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE,
num_stages_smem, vec_size, bdx, bdy, bdz,
DTypeQ, DTypeKV, DTypeOut>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
tensor_info_t<KV_LAYOUT, HEAD_DIM> info(1, padded_kv_len, num_qo_heads, num_kv_heads);
void* args[] = {(void*)&q,
(void*)&k,
(void*)&v,
(void*)&o,
(void*)&lse,
(void*)&info,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
});
return cudaSuccess;
}

} // namespace flashinfer

#endif // FLASHINFER_DECODE_CUH_
9 changes: 0 additions & 9 deletions include/flashinfer/decode_attention_decl.cuh
@@ -45,15 +45,6 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
float* lse, bool* block_valid_mask, uint32_t padded_batch_size, uint32_t num_qo_heads,
float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream);

template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOUT,
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut>
cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
DTypeOut* tmp, float* lse, uint32_t batch_size,
uint32_t padded_kv_len, uint32_t num_qo_heads,
uint32_t num_kv_heads, float logits_soft_cap,
float sm_scale, float rope_scale,
float rope_theta, cudaStream_t stream);

template <PageStorage PAGE_STORAGE, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV,
typename DTypeOut, typename IdType>
3 changes: 3 additions & 0 deletions include/flashinfer/utils.cuh
@@ -97,6 +97,9 @@
if (group_size == 1) { \
constexpr size_t GROUP_SIZE = 1; \
__VA_ARGS__ \
} else if (group_size == 2) { \
constexpr size_t GROUP_SIZE = 2; \
__VA_ARGS__ \
} else if (group_size == 4) { \
constexpr size_t GROUP_SIZE = 4; \
__VA_ARGS__ \
1 change: 0 additions & 1 deletion python/MANIFEST.in
@@ -1,6 +1,5 @@
# sdist & wheel
include version.txt
include generate_batch_padded_decode_inst.py
include generate_batch_paged_decode_inst.py
include generate_batch_paged_prefill_inst.py
include generate_batch_ragged_prefill_inst.py
