Adding stream to 1 kernel. #590

Open · wants to merge 4 commits into base: master
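Note: every hunk below makes the same change. Each raw <<<grid, block>>> launch gains a fourth launch-configuration argument so the kernel is queued on PyTorch's current CUDA stream instead of the legacy default stream. A minimal sketch of the pattern, assuming the ATen header is available; my_kernel, my_kernel_cuda, data and n are placeholders, not symbols from this repository:

#include <ATen/cuda/CUDAContext.h>  // makes at::cuda::getCurrentCUDAStream() visible

__global__ void my_kernel(float* data, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) data[i] *= 2.0f;
}

void my_kernel_cuda(float* data, int n)
{
    dim3 blockDim, gridDim;
    blockDim.x = 256;
    gridDim.x = (n + 255) / 256;

    // Before: implicit default stream, unordered with respect to torch.cuda stream scopes
    // my_kernel<<<gridDim, blockDim>>>(data, n);

    // After: queue on whatever stream PyTorch considers current for this device
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    my_kernel<<<gridDim, blockDim, 0, stream>>>(data, n);
}

The third launch argument (0) keeps the default dynamic shared-memory size; only the stream changes.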
6 changes: 4 additions & 2 deletions exllamav2/exllamav2_ext/cuda/cache.cu
@@ -97,7 +97,8 @@ void array_fp16_to_fp8_cuda(const half* pIn, unsigned char *pOut, int stride, in
gridDim.x = DIVIDE((max - min) / 8, THREADS);
gridDim.y = height;

fp16_to_fp8_kernel<<<gridDim, blockDim>>>(pIn, pOut, stride, height, min, max);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
fp16_to_fp8_kernel<<<gridDim, blockDim, 0, stream>>>(pIn, pOut, stride, height, min, max);
// cuda_check( cudaPeekAtLastError() );
}

@@ -113,7 +114,8 @@ void array_fp8_to_fp16_cuda(const unsigned char* pIn, half* pOut, int stride, in
gridDim.x = DIVIDE((max - min) / 8, THREADS);
gridDim.y = height;

fp8_to_fp16_kernel<<<gridDim, blockDim>>>(pIn, pOut, stride, height, min, max);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
fp8_to_fp16_kernel<<<gridDim, blockDim, 0, stream>>>(pIn, pOut, stride, height, min, max);
// cuda_check( cudaPeekAtLastError() );
}

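None of the hunks add an #include, so the declaration of at::cuda::getCurrentCUDAStream() is presumably already reachable from these translation units through a shared header. If it were not, each .cu file would need something along these lines; the exact header is an assumption here, not taken from the diff:

#include <ATen/cuda/CUDAContext.h>  // declares at::cuda::getCurrentCUDAStream()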
10 changes: 6 additions & 4 deletions exllamav2/exllamav2_ext/cuda/h_add.cu
@@ -137,14 +137,15 @@ void cuda_vector_add_
int width
)
{
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (width % 8 == 0)
{
dim3 blockDim, gridDim;
blockDim.x = min(NUM_THREADS_INT4, width / 8);
gridDim.x = DIVIDE(width, NUM_EL_INT4);
gridDim.y = DIVIDE(height, NUM_THREADS_Y_INT4);

cuda_vector_add_int4_kernel<<<gridDim, blockDim>>>(dest, source, height, width);
cuda_vector_add_int4_kernel<<<gridDim, blockDim, 0, stream>>>(dest, source, height, width);
}
else
{
@@ -153,7 +154,7 @@ void cuda_vector_add_
gridDim.x = DIVIDE(width, NUM_THREADS_X * 2);
gridDim.y = DIVIDE(height, NUM_THREADS_Y);

cuda_vector_add_kernel<<<gridDim, blockDim>>>(dest, source, height, width);
cuda_vector_add_kernel<<<gridDim, blockDim, 0, stream>>>(dest, source, height, width);
}
}

@@ -165,14 +166,15 @@ void cuda_vector_set_
int width
)
{
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (width % 8 == 0)
{
dim3 blockDim, gridDim;
blockDim.x = min(NUM_THREADS_INT4, width / 8);
gridDim.x = DIVIDE(width, NUM_EL_INT4);
gridDim.y = DIVIDE(height, NUM_THREADS_Y_INT4);

cuda_vector_set_int4_kernel<<<gridDim, blockDim>>>(dest, source, height, width);
cuda_vector_set_int4_kernel<<<gridDim, blockDim, 0, stream>>>(dest, source, height, width);
}
else
{
@@ -181,7 +183,7 @@ void cuda_vector_set_
gridDim.x = DIVIDE(width, NUM_THREADS_X * 2);
gridDim.y = DIVIDE(height, NUM_THREADS_Y);

cuda_vector_set_kernel<<<gridDim, blockDim>>>(dest, source, height, width);
cuda_vector_set_kernel<<<gridDim, blockDim, 0, stream>>>(dest, source, height, width);
}
}

7 changes: 4 additions & 3 deletions exllamav2/exllamav2_ext/cuda/h_gemm.cu
@@ -220,6 +220,7 @@ void h_gemm_cuda
const float beta
)
{
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if ((beta == 1.0f || beta == 0.0f) && (alpha == 1.0f))
{
bool clear = (beta == 0.0f);
@@ -241,7 +242,7 @@ void h_gemm_cuda
// DBGI3(blockDim.x, blockDim.y, blockDim.z);
// DBGI3(gridDim.x, gridDim.y, gridDim.z);

h_gemm_tall_kernel<<<gridDim, blockDim>>>(size_m, size_n, size_k, a, b, c, clear);
h_gemm_tall_kernel<<<gridDim, blockDim, 0, stream>>>(size_m, size_n, size_k, a, b, c, clear);
cuda_check( cudaPeekAtLastError() );
return;
}
@@ -261,7 +262,7 @@ void h_gemm_cuda
// DBGI3(blockDim.x, blockDim.y, blockDim.z);
// DBGI3(gridDim.x, gridDim.y, gridDim.z);

h_gemm_wide_kernel<<<gridDim, blockDim>>>(size_m, size_n, size_k, a, b, c, clear);
h_gemm_wide_kernel<<<gridDim, blockDim, 0, stream>>>(size_m, size_n, size_k, a, b, c, clear);
cuda_check( cudaPeekAtLastError() );
return;
}
@@ -271,4 +272,4 @@ void h_gemm_cuda
// DBGI3(size_m, size_n, size_k);
cuda_check( cudaPeekAtLastError() );

}
}
3 changes: 2 additions & 1 deletion exllamav2/exllamav2_ext/cuda/head_norm.cu
@@ -114,5 +114,6 @@ void head_norm_cuda

float r_dim = 1.0f / (float) head_dim;

head_norm_kernel<<<gridDim, blockDim>>>(x, w, b, y, epsilon, r_dim, rows, num_heads, head_dim);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
head_norm_kernel<<<gridDim, blockDim, 0, stream>>>(x, w, b, y, epsilon, r_dim, rows, num_heads, head_dim);
}
3 changes: 2 additions & 1 deletion exllamav2/exllamav2_ext/cuda/layer_norm.cu
@@ -204,5 +204,6 @@ void layer_norm_cuda

int blocks_per_warp = DIVIDE(dim, NUM_THREADS * 2);
fp_layer_norm_kernel kernel = pick_layer_norm_kernel(blocks_per_warp);
kernel<<<gridDim, blockDim>>>(x, w, b, y, epsilon, r_dim, rows, dim, add_residual);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>(x, w, b, y, epsilon, r_dim, rows, dim, add_residual);
}
6 changes: 4 additions & 2 deletions exllamav2/exllamav2_ext/cuda/pack_tensor.cu
@@ -47,7 +47,8 @@ void pack_rows_4_cuda
dim3 threads(BLOCKSIZE_X, BLOCKSIZE_Y);
dim3 blocks(DIVIDE(out_columns, BLOCKSIZE_X), DIVIDE(rows, BLOCKSIZE_Y));

pack_rows_4_kernel<<<blocks, threads>>>(input, output, rows, out_columns);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
pack_rows_4_kernel<<<blocks, threads, 0, stream>>>(input, output, rows, out_columns);
}

// Pack rows:
@@ -93,7 +94,8 @@ void pack_rows_6_cuda
dim3 threads(BLOCKSIZE_X, BLOCKSIZE_Y);
dim3 blocks(DIVIDE(out_columns, BLOCKSIZE_X), DIVIDE(rows, BLOCKSIZE_Y));

pack_rows_6_kernel<<<blocks, threads>>>(input, output, rows, out_columns);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
pack_rows_6_kernel<<<blocks, threads, 0, stream>>>(input, output, rows, out_columns);
}

// Pack columns
6 changes: 4 additions & 2 deletions exllamav2/exllamav2_ext/cuda/q_gemm.cu
@@ -104,7 +104,8 @@ void gemm_half_q_half_cuda_part

// Launch kernel

kernel<<<gridDim, blockDim>>>
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>
(
a,
b->cuda_q_weight,
@@ -165,7 +166,8 @@ void gemm_half_q_half_cuda_part
// print_global_mem(r_weights, 1, 1, 1);
// DBGI(r_weights_stride);

kernel<<<gridDim, blockDim>>>
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>
(
a,
b->cuda_q_weight,
17 changes: 11 additions & 6 deletions exllamav2/exllamav2_ext/cuda/q_matrix.cu
@@ -178,7 +178,8 @@ QMatrix::QMatrix
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = 1;

shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
}

QMatrix::~QMatrix()
@@ -491,10 +492,11 @@ void QMatrix::reconstruct(half* out, int row_a, int row_b)

gridDim.y = DIVIDE(row_b - row_a, BLOCK_KN_SIZE);

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (!is_gptq)
{
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
reconstruct_kernel<<<gridDim, blockDim>>>
reconstruct_kernel<<<gridDim, blockDim, 0, stream>>>
(
cuda_q_weight,
cuda_q_perm,
@@ -519,7 +521,7 @@ void QMatrix::reconstruct(half* out, int row_a, int row_b)
else
{
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE * 4);
reconstruct_gptq_kernel<<<gridDim, blockDim>>>
reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
(
cuda_q_weight,
cuda_q_perm,
@@ -640,7 +642,8 @@ bool QMatrix::make_sequential(const uint32_t* cpu_g_idx)
gridDim.x = DIVIDE(width, THREADS_X);
gridDim.y = height / 8;

make_sequential_kernel<<<gridDim, blockDim>>>
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
(
cuda_q_weight,
cuda_new_qweight,
@@ -722,7 +725,8 @@ void matrix_fp8_to_fp16_cuda
dim3 blockDim, gridDim;
blockDim.x = THREADS_F;
gridDim.x = numel / (BLOCKSIZE_F * THREADS_F);
matrix_fp8_to_fp16_kernel<<<gridDim, blockDim>>>(in_ptr, out_ptr);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
matrix_fp8_to_fp16_kernel<<<gridDim, blockDim, 0, stream>>>(in_ptr, out_ptr);
}

void matrix_fp16_to_fp8_cuda
@@ -738,7 +742,8 @@ void matrix_fp16_to_fp8_cuda
dim3 blockDim, gridDim;
blockDim.x = THREADS_F;
gridDim.x = numel / (BLOCKSIZE_F * THREADS_F);
matrix_fp16_to_fp8_kernel<<<gridDim, blockDim>>>(in_ptr, out_ptr);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
matrix_fp16_to_fp8_kernel<<<gridDim, blockDim, 0, stream>>>(in_ptr, out_ptr);
}

// Q4/FP16 convert funcs
14 changes: 8 additions & 6 deletions exllamav2/exllamav2_ext/cuda/q_mlp.cu
@@ -96,6 +96,7 @@ void QMLP::forward_

// Up proj with gate

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (gate)
{
gemm_half_q_half_cuda(cublas_handle, norm_state, gate, temp_a, rows, intermediate_size, columns, true, temp_dq);
@@ -105,7 +106,7 @@ void QMLP::forward_
apply_loras_cuda(cublas_handle, up_proj_lora, loras, up, norm_state, temp_b, lora_temp, rows);

fp_act_mul_kernel kernel = pick_act_mul_kernel(use_half2, false, act_gelu);
kernel<<<gridDim, blockDim>>>(temp_a, temp_b, rows, intermediate_size, NULL, 0);
kernel<<<gridDim, blockDim, 0, stream>>>(temp_a, temp_b, rows, intermediate_size, NULL, 0);
}

// Up proj without gate
@@ -117,7 +118,7 @@ void QMLP::forward_
apply_loras_cuda(cublas_handle, up_proj_lora, loras, up, norm_state, temp_a, lora_temp, rows);

fp_act_kernel kernel = pick_act_kernel(use_half2, false, act_gelu);
kernel<<<gridDim, blockDim>>>(temp_a, rows, intermediate_size, NULL, 0);
kernel<<<gridDim, blockDim, 0, stream>>>(temp_a, rows, intermediate_size, NULL, 0);
}

// Down proj without post_layernorm
@@ -244,12 +245,13 @@ void QMoEMLP::forward_
blockDim.y = 1;
gridDim.x = 1;
gridDim.y = DIVIDE(rows, WARPS);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (num_experts == 4)
softmax4_topk_norm_kernel<<<gridDim, blockDim>>>(temp_logits, rows, num_experts_per_token);
softmax4_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
else if (num_experts == 8)
softmax8_topk_norm_kernel<<<gridDim, blockDim>>>(temp_logits, rows, num_experts_per_token);
softmax8_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
else if (num_experts == 16)
softmax16_topk_norm_kernel<<<gridDim, blockDim>>>(temp_logits, rows, num_experts_per_token);
softmax16_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);

// For small no. rows, execute all kernels but pass the routing weights. Rows with a weight of zero will skip dot
// product accum and kernels launched with only zero-weights will exit prematurely.
@@ -271,7 +273,7 @@ void QMoEMLP::forward_
blockDim.y = THREADS_Y;
gridDim.x = DIVIDE(intermediate_size, THREADS_X) / (use_half2 ? 2 : 1);
gridDim.y = DIVIDE(rows, THREADS_Y);
kernel<<<gridDim, blockDim>>>(temp_a, temp_b, rows, intermediate_size, temp_logits + i, num_experts);
kernel<<<gridDim, blockDim, 0, stream>>>(temp_a, temp_b, rows, intermediate_size, temp_logits + i, num_experts);

gemm_half_q_half_cuda(cublas_handle, temp_a, w2[i], x, rows, columns, intermediate_size, false, temp_dq, true, temp_logits + i, num_experts, true);

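One ordering caveat for this file: QMLP::forward_ and QMoEMLP::forward_ interleave these launches with GEMMs issued through a cuBLAS handle. If that handle comes from at::cuda::getCurrentCUDABlasHandle() it is already bound to the current stream and everything stays ordered; if the extension manages its own handle, this is a sketch of what would keep cuBLAS on the same stream (the function name is hypothetical):

#include <cublas_v2.h>
#include <ATen/cuda/CUDAContext.h>

// Bind a caller-owned cuBLAS handle to PyTorch's current stream so cuBLAS GEMMs
// and the raw kernel launches above are serialized on the same stream.
void bind_handle_to_current_stream(cublasHandle_t handle)
{
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    cublasSetStream(handle, stream);
}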
18 changes: 12 additions & 6 deletions exllamav2/exllamav2_ext/cuda/quantize.cu
@@ -66,7 +66,8 @@ void quantize_rtn_cuda
dim3 threads(BLOCKSIZE_X, 1);
dim3 blocks(DIVIDE(columns, BLOCKSIZE_X), 1);

quantize_rtn_kernel<<<blocks, threads>>>
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
quantize_rtn_kernel<<<blocks, threads, 0, stream>>>
(
weights,
scale,
@@ -151,7 +152,8 @@ void fused_quantize_adjust_cuda
dim3 threads(BLOCKSIZE_X, 1);
dim3 blocks(DIVIDE(columns, BLOCKSIZE_X), 1);

fused_quantize_adjust_kernel<<<blocks, threads>>>
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
fused_quantize_adjust_kernel<<<blocks, threads, 0, stream>>>
(
weights,
quant,
@@ -232,7 +234,8 @@ void quantize_cuda
// DBGI2(rows, columns);
// DBGF2(qzero, maxq);

quantize_kernel<<<blocks, threads>>>
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
quantize_kernel<<<blocks, threads, 0, stream>>>
(
input,
output,
@@ -281,7 +284,8 @@ void adjust_error_row_cuda
dim3 threads(BLOCKSIZE_X, 1);
dim3 blocks(DIVIDE(columns, BLOCKSIZE_X), 1);

adjust_error_row_kernel<<<blocks, threads>>>(hessian_inv, error, weights, quant, c, columns, hcolumns);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
adjust_error_row_kernel<<<blocks, threads, 0, stream>>>(hessian_inv, error, weights, quant, c, columns, hcolumns);
}

__global__ void quantize_err_kernel
@@ -353,7 +357,8 @@ void quantize_err_cuda
// DBGI2(rows, columns);
// DBGF2(qzero, maxq);

quantize_err_kernel<<<blocks, threads>>>
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
quantize_err_kernel<<<blocks, threads, 0, stream>>>
(
input,
output,
@@ -414,5 +419,6 @@ void vv_mul_sub_cuda
gridDim.y = DIVIDE(x_size, BLOCKSIZE_Y);
gridDim.z = 1;

vv_mul_sub_kernel<<<gridDim, blockDim>>>(x, y, z, x_size, y_size);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vv_mul_sub_kernel<<<gridDim, blockDim, 0, stream>>>(x, y, z, x_size, y_size);
}
3 changes: 2 additions & 1 deletion exllamav2/exllamav2_ext/cuda/rms_norm.cu
@@ -220,5 +220,6 @@ void rms_norm_cuda

int blocks_per_warp = DIVIDE(dim, NUM_THREADS * 2);
fp_rms_norm_kernel kernel = pick_rms_norm_kernel(blocks_per_warp);
kernel<<<gridDim, blockDim>>>(x, w, y, epsilon, r_dim, rows, dim, add_residual, input_fp32, output_fp32);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
kernel<<<gridDim, blockDim, 0, stream>>>(x, w, y, epsilon, r_dim, rows, dim, add_residual, input_fp32, output_fp32);
}
6 changes: 4 additions & 2 deletions exllamav2/exllamav2_ext/cuda/rope.cu
@@ -195,7 +195,8 @@ void rope_cuda
gridDim.y = DIVIDE(rows_per_batch, threads_y);
gridDim.z = batch_size;

rope_cuda_kernel<<<gridDim, blockDim>>>
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
rope_cuda_kernel<<<gridDim, blockDim, 0, stream>>>
(
x,
sin,
@@ -240,7 +241,8 @@ void rope_cuda_qk
gridDim.y = DIVIDE(rows_per_batch, threads_y);
gridDim.z = batch_size;

rope_cuda_qk_kernel<<<gridDim, blockDim>>>
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
rope_cuda_qk_kernel<<<gridDim, blockDim, 0, stream>>>
(
x_q,
x_k,
6 changes: 4 additions & 2 deletions exllamav2/exllamav2_ext/cuda/softcap.cu
@@ -33,7 +33,8 @@ void softcap_cuda_
blockDim.x = NUM_THREADS;
gridDim.x = DIVIDE(numel, NUM_THREADS);

cuda_softcap_kernel<<<gridDim, blockDim>>>(x, numel, scale);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
cuda_softcap_kernel<<<gridDim, blockDim, 0, stream>>>(x, numel, scale);
}

// TODO: Profile
@@ -73,6 +74,7 @@ void h_softcap_cuda_
blockDim.x = NUM_THREADS;
gridDim.x = DIVIDE(numel / 2, NUM_THREADS);

h_cuda_softcap_kernel<<<gridDim, blockDim>>>(x, numel, scale);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
h_cuda_softcap_kernel<<<gridDim, blockDim, 0, stream>>>(x, numel, scale);
}
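Taken together, these changes make the extension's kernels follow PyTorch's stream context. A sketch of the caller-side effect using the C++ stream guard API; some_extension_op stands in for any of the functions patched above and is hypothetical:

#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>

void run_on_side_stream()
{
    // Take a stream from PyTorch's pool and make it current for this scope.
    c10::cuda::CUDAStream side = c10::cuda::getStreamFromPool();
    {
        c10::cuda::CUDAStreamGuard guard(side);
        // Each patched launch site queries at::cuda::getCurrentCUDAStream(), so any
        // extension call made here queues its kernels on `side`, ordered with the
        // rest of the work submitted to that stream.
        // some_extension_op(...);
    }
    side.synchronize();  // wait for the side-stream work before reusing its outputs
}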