diff --git a/CMakeLists.txt b/CMakeLists.txt index 2b3679a0548c9..38b67cd707be6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,7 +19,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") # Supported AMD GPU architectures. -set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100") +set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101") # # Supported/expected torch versions for CUDA/ROCm. diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index 6b7969f035d8d..f03d3da5a8f9c 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -1,8 +1,14 @@ + #include #include #include #include +#if defined(__HIPCC__) && \ + (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300__ +#endif + constexpr int WARP_SIZE = 64; template @@ -328,6 +334,8 @@ __device__ __forceinline__ T loadnt(T* addr) { #define M 1 #define DTYPE half +#if defined(__HIP__MI300__) // TODO: Add NAVI support + __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { @@ -360,23 +368,23 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, bigType bigA[M][UNRL]; bigType bigB0[UNRL]; -#if (YTILE >= 2) + #if (YTILE >= 2) bigType bigB1[UNRL]; -#endif + #endif for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; const half* B_ = &B[(n + 0) * K + k_]; bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); -#if (YTILE >= 2) + #if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); -#endif + #endif } // Fetch activation matrix from either just LDS or from both LDS / memory -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; @@ -389,16 +397,16 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, } // Do the matrix multiplication in interleaved manner -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; -#pragma unroll + #pragma unroll for (uint32_t m = 0; m < M; m++) { // Do the matrix multiplication of activation and weight matrix // - Remember the accumulation is happening for K-split of 64! 
-#pragma unroll + #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][0]) @@ -407,11 +415,11 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- -#if (YTILE >= 2) + #if (YTILE >= 2) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][1]) : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); -#endif + #endif } } } @@ -455,6 +463,18 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, } } +#else // !defined(__HIP__MI300__) TODO: Add NAVI support + +__global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + assert(false); +} + +#endif // defined(__HIP__MI300__) TODO: Add NAVI support + +#if defined(__HIP__MI300__) // TODO: Add NAVI support + __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { @@ -555,36 +575,36 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, bigType bigA[M][UNRL]; bigType bigB0[UNRL]; -#if (YTILE >= 2) + #if (YTILE >= 2) bigType bigB1[UNRL]; -#endif -#if (YTILE >= 3) + #endif + #if (YTILE >= 3) bigType bigB2[UNRL]; -#endif -#if (YTILE >= 4) + #endif + #if (YTILE >= 4) bigType bigB3[UNRL]; -#endif -#if (YTILE >= 5) + #endif + #if (YTILE >= 5) bigType bigB4[UNRL]; -#endif -#if (YTILE >= 6) + #endif + #if (YTILE >= 6) bigType bigB5[UNRL]; -#endif -#if (YTILE >= 7) + #endif + #if (YTILE >= 7) bigType bigB6[UNRL]; -#endif -#if (YTILE >= 8) + #endif + #if (YTILE >= 8) bigType bigB7[UNRL]; -#endif -#if (YTILE >= 9) + #endif + #if (YTILE >= 9) bigType bigB8[UNRL]; -#endif -#if (YTILE >= 10) + #endif + #if (YTILE >= 10) bigType bigB9[UNRL]; -#endif -#if (YTILE >= 11) + #endif + #if (YTILE >= 11) bigType bigB10[UNRL]; -#endif + #endif //---------------------------------------------------- // Fetch weight matrix B in interleaved K-split! // - Each thread (lane) is fetching 8 elements (A_Chunk) @@ -604,7 +624,7 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { // Fetch the weight matrix from memory! 
-#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; @@ -619,34 +639,34 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- -#if (YTILE >= 2) + #if (YTILE >= 2) // if (n+1>=N) continue; bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); -#endif -#if (YTILE >= 3) + #endif + #if (YTILE >= 3) // if (n+2>=N) continue; bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); -#endif -#if (YTILE >= 4) + #endif + #if (YTILE >= 4) // if (n+3>=N) continue; bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); -#endif -#if (YTILE >= 5) + #endif + #if (YTILE >= 5) // if (n+4>=N) continue; bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); -#endif -#if (YTILE >= 6) + #endif + #if (YTILE >= 6) // if (n+5>=N) continue; bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); -#endif -#if (YTILE >= 7) + #endif + #if (YTILE >= 7) // if (n+6>=N) continue; bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); -#endif -#if (YTILE >= 8) + #endif + #if (YTILE >= 8) // if (n+7>=N) continue; bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); -#endif + #endif /* #if (YTILE >= 9) if (n+8>=N) continue; bigB8[k2].h8 = @@ -658,7 +678,7 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, } // Fetch activation matrix from either just LDS or from both LDS / memory -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; @@ -675,16 +695,16 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, } // Do the matrix multiplication in interleaved manner -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; -#pragma unroll + #pragma unroll for (uint32_t m = 0; m < M; m++) { // Do the matrix multiplication of activation and weight matrix // - Remember the accumulation is happening for K-split of 64! 
-#pragma unroll + #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][0]) @@ -693,56 +713,56 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- -#if (YTILE >= 2) + #if (YTILE >= 2) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][1]) : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); -#endif -#if (YTILE >= 3) + #endif + #if (YTILE >= 3) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][2]) : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); -#endif -#if (YTILE >= 4) + #endif + #if (YTILE >= 4) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][3]) : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); -#endif -#if (YTILE >= 5) + #endif + #if (YTILE >= 5) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][4]) : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); -#endif -#if (YTILE >= 6) + #endif + #if (YTILE >= 6) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][5]) : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); -#endif -#if (YTILE >= 7) + #endif + #if (YTILE >= 7) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][6]) : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); -#endif -#if (YTILE >= 8) + #endif + #if (YTILE >= 8) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][7]) : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); -#endif -#if (YTILE >= 9) + #endif + #if (YTILE >= 9) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][8]) : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); -#endif -#if (YTILE >= 10) + #endif + #if (YTILE >= 10) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][9]) : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); -#endif -#if (YTILE >= 11) + #endif + #if (YTILE >= 11) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][10]) : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); -#endif + #endif } } } @@ -800,6 +820,16 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, } } +#else // !defined(__HIP__MI300__) TODO: Add NAVI support + +__global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + assert(false); +} + +#endif // defined(__HIP__MI300__) TODO: Add NAVI support + #undef YTILE #undef UNRL #undef M @@ -808,6 +838,8 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, #define UNRL 2 #define M 2 +#if defined(__HIP__MI300__) // TODO: Add NAVI support + __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { @@ -908,36 +940,36 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, bigType bigA[M][UNRL]; bigType bigB0[UNRL]; -#if (YTILE >= 2) + #if (YTILE >= 2) bigType bigB1[UNRL]; -#endif -#if (YTILE >= 3) + #endif + #if (YTILE >= 3) bigType bigB2[UNRL]; -#endif -#if (YTILE >= 4) + #endif + #if (YTILE >= 4) bigType bigB3[UNRL]; -#endif -#if (YTILE >= 5) + #endif + #if (YTILE >= 5) bigType bigB4[UNRL]; -#endif -#if (YTILE >= 6) + #endif + #if (YTILE >= 6) bigType bigB5[UNRL]; -#endif -#if (YTILE >= 7) + #endif + #if (YTILE >= 7) bigType bigB6[UNRL]; -#endif -#if (YTILE >= 8) + #endif + #if (YTILE >= 8) bigType bigB7[UNRL]; -#endif -#if (YTILE >= 9) + #endif + #if (YTILE >= 9) bigType bigB8[UNRL]; -#endif -#if (YTILE >= 10) + #endif + #if 
(YTILE >= 10) bigType bigB9[UNRL]; -#endif -#if (YTILE >= 11) + #endif + #if (YTILE >= 11) bigType bigB10[UNRL]; -#endif + #endif //---------------------------------------------------- // Fetch weight matrix B in interleaved K-split! // - Each thread (lane) is fetching 8 elements (A_Chunk) @@ -957,7 +989,7 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { // Fetch the weight matrix from memory! -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; @@ -972,34 +1004,34 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- -#if (YTILE >= 2) + #if (YTILE >= 2) // if (n+1>=N) continue; bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); -#endif -#if (YTILE >= 3) + #endif + #if (YTILE >= 3) // if (n+2>=N) continue; bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); -#endif -#if (YTILE >= 4) + #endif + #if (YTILE >= 4) // if (n+3>=N) continue; bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); -#endif -#if (YTILE >= 5) + #endif + #if (YTILE >= 5) // if (n+4>=N) continue; bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); -#endif -#if (YTILE >= 6) + #endif + #if (YTILE >= 6) // if (n+5>=N) continue; bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); -#endif -#if (YTILE >= 7) + #endif + #if (YTILE >= 7) // if (n+6>=N) continue; bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); -#endif -#if (YTILE >= 8) + #endif + #if (YTILE >= 8) // if (n+7>=N) continue; bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); -#endif + #endif /* #if (YTILE >= 9) if (n+8>=N) continue; bigB8[k2].h8 = @@ -1011,7 +1043,7 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, } // Fetch activation matrix from either just LDS or from both LDS / memory -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; @@ -1028,16 +1060,16 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, } // Do the matrix multiplication in interleaved manner -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; -#pragma unroll + #pragma unroll for (uint32_t m = 0; m < M; m++) { // Do the matrix multiplication of activation and weight matrix // - Remember the accumulation is happening for K-split of 64! 
-#pragma unroll + #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][0]) @@ -1046,56 +1078,56 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- -#if (YTILE >= 2) + #if (YTILE >= 2) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][1]) : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); -#endif -#if (YTILE >= 3) + #endif + #if (YTILE >= 3) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][2]) : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); -#endif -#if (YTILE >= 4) + #endif + #if (YTILE >= 4) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][3]) : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); -#endif -#if (YTILE >= 5) + #endif + #if (YTILE >= 5) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][4]) : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); -#endif -#if (YTILE >= 6) + #endif + #if (YTILE >= 6) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][5]) : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); -#endif -#if (YTILE >= 7) + #endif + #if (YTILE >= 7) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][6]) : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); -#endif -#if (YTILE >= 8) + #endif + #if (YTILE >= 8) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][7]) : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); -#endif -#if (YTILE >= 9) + #endif + #if (YTILE >= 9) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][8]) : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); -#endif -#if (YTILE >= 10) + #endif + #if (YTILE >= 10) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][9]) : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); -#endif -#if (YTILE >= 11) + #endif + #if (YTILE >= 11) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][10]) : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); -#endif + #endif } } } @@ -1153,6 +1185,16 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, } } +#else // !defined(__HIP__MI300__) TODO: Add NAVI support + +__global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + assert(false); +} + +#endif // defined(__HIP__MI300__) TODO: Add NAVI support + #undef YTILE #undef UNRL #undef M @@ -1161,6 +1203,8 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, #define UNRL 2 #define M 3 +#if defined(__HIP__MI300__) // TODO: Add NAVI support + __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { @@ -1261,36 +1305,36 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, bigType bigA[M][UNRL]; bigType bigB0[UNRL]; -#if (YTILE >= 2) + #if (YTILE >= 2) bigType bigB1[UNRL]; -#endif -#if (YTILE >= 3) + #endif + #if (YTILE >= 3) bigType bigB2[UNRL]; -#endif -#if (YTILE >= 4) + #endif + #if (YTILE >= 4) bigType bigB3[UNRL]; -#endif -#if (YTILE >= 5) + #endif + #if (YTILE >= 5) bigType bigB4[UNRL]; -#endif -#if (YTILE >= 6) + #endif + #if (YTILE >= 6) bigType bigB5[UNRL]; -#endif -#if (YTILE >= 7) + #endif + #if (YTILE >= 7) bigType bigB6[UNRL]; -#endif -#if (YTILE >= 8) + #endif + #if (YTILE >= 8) bigType bigB7[UNRL]; -#endif -#if (YTILE >= 9) + #endif + #if (YTILE >= 9) bigType bigB8[UNRL]; -#endif -#if (YTILE >= 10) + #endif + 
#if (YTILE >= 10) bigType bigB9[UNRL]; -#endif -#if (YTILE >= 11) + #endif + #if (YTILE >= 11) bigType bigB10[UNRL]; -#endif + #endif //---------------------------------------------------- // Fetch weight matrix B in interleaved K-split! // - Each thread (lane) is fetching 8 elements (A_Chunk) @@ -1310,7 +1354,7 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { // Fetch the weight matrix from memory! -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; @@ -1325,34 +1369,34 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- -#if (YTILE >= 2) + #if (YTILE >= 2) // if (n+1>=N) continue; bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); -#endif -#if (YTILE >= 3) + #endif + #if (YTILE >= 3) // if (n+2>=N) continue; bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); -#endif -#if (YTILE >= 4) + #endif + #if (YTILE >= 4) // if (n+3>=N) continue; bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); -#endif -#if (YTILE >= 5) + #endif + #if (YTILE >= 5) // if (n+4>=N) continue; bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); -#endif -#if (YTILE >= 6) + #endif + #if (YTILE >= 6) // if (n+5>=N) continue; bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); -#endif -#if (YTILE >= 7) + #endif + #if (YTILE >= 7) // if (n+6>=N) continue; bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); -#endif -#if (YTILE >= 8) + #endif + #if (YTILE >= 8) // if (n+7>=N) continue; bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); -#endif + #endif /* #if (YTILE >= 9) if (n+8>=N) continue; bigB8[k2].h8 = @@ -1364,7 +1408,7 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, } // Fetch activation matrix from either just LDS or from both LDS / memory -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; @@ -1381,16 +1425,16 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, } // Do the matrix multiplication in interleaved manner -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; -#pragma unroll + #pragma unroll for (uint32_t m = 0; m < M; m++) { // Do the matrix multiplication of activation and weight matrix // - Remember the accumulation is happening for K-split of 64! 
-#pragma unroll + #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][0]) @@ -1399,56 +1443,56 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- -#if (YTILE >= 2) + #if (YTILE >= 2) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][1]) : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); -#endif -#if (YTILE >= 3) + #endif + #if (YTILE >= 3) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][2]) : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); -#endif -#if (YTILE >= 4) + #endif + #if (YTILE >= 4) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][3]) : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); -#endif -#if (YTILE >= 5) + #endif + #if (YTILE >= 5) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][4]) : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); -#endif -#if (YTILE >= 6) + #endif + #if (YTILE >= 6) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][5]) : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); -#endif -#if (YTILE >= 7) + #endif + #if (YTILE >= 7) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][6]) : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); -#endif -#if (YTILE >= 8) + #endif + #if (YTILE >= 8) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][7]) : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); -#endif -#if (YTILE >= 9) + #endif + #if (YTILE >= 9) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][8]) : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); -#endif -#if (YTILE >= 10) + #endif + #if (YTILE >= 10) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][9]) : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); -#endif -#if (YTILE >= 11) + #endif + #if (YTILE >= 11) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][10]) : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); -#endif + #endif } } } @@ -1506,6 +1550,16 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, } } +#else // !defined(__HIP__MI300__) TODO: Add NAVI support + +__global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + assert(false); +} + +#endif // defined(__HIP__MI300__) TODO: Add NAVI support + #undef YTILE #undef UNRL #undef M @@ -1514,6 +1568,8 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, #define UNRL 1 #define M 4 +#if defined(__HIP__MI300__) // TODO: Add NAVI support + __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { @@ -1614,36 +1670,36 @@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, bigType bigA[M][UNRL]; bigType bigB0[UNRL]; -#if (YTILE >= 2) + #if (YTILE >= 2) bigType bigB1[UNRL]; -#endif -#if (YTILE >= 3) + #endif + #if (YTILE >= 3) bigType bigB2[UNRL]; -#endif -#if (YTILE >= 4) + #endif + #if (YTILE >= 4) bigType bigB3[UNRL]; -#endif -#if (YTILE >= 5) + #endif + #if (YTILE >= 5) bigType bigB4[UNRL]; -#endif -#if (YTILE >= 6) + #endif + #if (YTILE >= 6) bigType bigB5[UNRL]; -#endif -#if (YTILE >= 7) + #endif + #if (YTILE >= 7) bigType bigB6[UNRL]; -#endif -#if (YTILE >= 8) + #endif + #if (YTILE >= 8) bigType bigB7[UNRL]; -#endif -#if (YTILE >= 9) + #endif + #if (YTILE >= 9) bigType bigB8[UNRL]; -#endif -#if (YTILE >= 10) + #endif + 
#if (YTILE >= 10) bigType bigB9[UNRL]; -#endif -#if (YTILE >= 11) + #endif + #if (YTILE >= 11) bigType bigB10[UNRL]; -#endif + #endif //---------------------------------------------------- // Fetch weight matrix B in interleaved K-split! // - Each thread (lane) is fetching 8 elements (A_Chunk) @@ -1663,7 +1719,7 @@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { // Fetch the weight matrix from memory! -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; @@ -1678,34 +1734,34 @@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- -#if (YTILE >= 2) + #if (YTILE >= 2) // if (n+1>=N) continue; bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); -#endif -#if (YTILE >= 3) + #endif + #if (YTILE >= 3) // if (n+2>=N) continue; bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); -#endif -#if (YTILE >= 4) + #endif + #if (YTILE >= 4) // if (n+3>=N) continue; bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); -#endif -#if (YTILE >= 5) + #endif + #if (YTILE >= 5) // if (n+4>=N) continue; bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); -#endif -#if (YTILE >= 6) + #endif + #if (YTILE >= 6) // if (n+5>=N) continue; bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); -#endif -#if (YTILE >= 7) + #endif + #if (YTILE >= 7) // if (n+6>=N) continue; bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); -#endif -#if (YTILE >= 8) + #endif + #if (YTILE >= 8) // if (n+7>=N) continue; bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); -#endif + #endif /* #if (YTILE >= 9) if (n+8>=N) continue; bigB8[k2].h8 = @@ -1717,7 +1773,7 @@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, } // Fetch activation matrix from either just LDS or from both LDS / memory -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; @@ -1734,16 +1790,16 @@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, } // Do the matrix multiplication in interleaved manner -#pragma unroll + #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; -#pragma unroll + #pragma unroll for (uint32_t m = 0; m < M; m++) { // Do the matrix multiplication of activation and weight matrix // - Remember the accumulation is happening for K-split of 64! 
-#pragma unroll + #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][0]) @@ -1752,56 +1808,56 @@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, //---------------------------------------------------- // The following code with YTILE > 1 //---------------------------------------------------- -#if (YTILE >= 2) + #if (YTILE >= 2) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][1]) : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); -#endif -#if (YTILE >= 3) + #endif + #if (YTILE >= 3) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][2]) : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); -#endif -#if (YTILE >= 4) + #endif + #if (YTILE >= 4) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][3]) : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); -#endif -#if (YTILE >= 5) + #endif + #if (YTILE >= 5) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][4]) : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); -#endif -#if (YTILE >= 6) + #endif + #if (YTILE >= 6) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][5]) : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); -#endif -#if (YTILE >= 7) + #endif + #if (YTILE >= 7) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][6]) : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); -#endif -#if (YTILE >= 8) + #endif + #if (YTILE >= 8) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][7]) : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); -#endif -#if (YTILE >= 9) + #endif + #if (YTILE >= 9) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][8]) : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); -#endif -#if (YTILE >= 10) + #endif + #if (YTILE >= 10) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][9]) : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); -#endif -#if (YTILE >= 11) + #endif + #if (YTILE >= 11) asm("v_dot2c_f32_f16 %0, %2, %3" : "=v"(sum[m][10]) : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); -#endif + #endif } } } @@ -1859,6 +1915,16 @@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, } } +#else // !defined(__HIP__MI300__) TODO: Add NAVI support + +__global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + assert(false); +} + +#endif // defined(__HIP__MI300__) TODO: Add NAVI support + void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, const int K_in, const int N_in, cudaStream_t stream, const int CuCount = 0) { diff --git a/csrc/custom/paged_attention/attention_ll4mi.cu b/csrc/custom/paged_attention/attention_ll4mi.cu index dcabc7932cfd5..0401bd21f7784 100644 --- a/csrc/custom/paged_attention/attention_ll4mi.cu +++ b/csrc/custom/paged_attention/attention_ll4mi.cu @@ -5,13 +5,20 @@ #include +#if defined(__HIPCC__) && \ + (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300__ +#endif + #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) #define WARP_SIZE 64 -#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 -#define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 +#if defined(__HIP__MI300__) // TODO: Add NAVI support + + #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 + #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; using float16x4 = @@ -22,7 +29,7 @@ typedef struct _Half8 { } _Half8; ////// Non temporal load stores /////// -#if 1 + #if 1 template __device__ __forceinline__ T load(T* addr) { @@ -34,7 +41,7 @@ __device__ __forceinline__ void store(T value, T* addr) { addr[0] = value; } -#else + #else template __device__ __forceinline__ T load(const T* addr) { @@ -109,7 +116,7 @@ __device__ __forceinline__ void store(T value, T* addr) { return __builtin_nontemporal_store(value, addr); } -#endif + #endif /////////////////////////////////////// @@ -135,9 +142,9 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] -#if 0 + #if 0 scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] -#endif + #endif int max_ctx_blocks) { constexpr int NWARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; @@ -173,7 +180,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( _Half8 Vlocal[VHELOOP][VTLOOP]; floatx4 dout[QHLOOP]; float qk_max[QHLOOP]; -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP; h++) { dout[h] = {0}; qk_max[h] = -FLT_MAX; @@ -186,7 +193,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( partition_start_token_idx + warpid * WARP_SIZE; if (warp_start_token_idx >= context_len) { // warp out of context -#pragma unroll + #pragma unroll for (int h = 0; h < GQA_RATIO4; h++) { shared_qk_max[warpid][h] = -FLT_MAX; shared_exp_sum[warpid][h] = 0.0f; @@ -215,7 +222,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; const _Half8* q_ptrh8 = reinterpret_cast(q_ptr); const int qhead_elemh8 = laneid / 4; -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP - 1; h++) { const int qhead_idx = h * 4 + lane4id; Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; @@ -237,14 +244,14 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset // is already cast as _H8 -#pragma unroll + #pragma unroll for (int d = 0; d < KHELOOP; d++) { Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; } float alibi_slope[QHLOOP]; if (alibi_slopes != nullptr) { -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP; h++) { const int qhead_idx = h * 4 + lane4id; alibi_slope[h] = (qhead_idx < GQA_RATIO) @@ -257,8 +264,8 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( int vphysical_blocks[VBLOCKS]; const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE; -// fetch vphysical block numbers -#pragma unroll + // fetch vphysical block numbers + #pragma unroll for (int b = 0; b < VBLOCKS; b++) { const int vblock_idx = warp_start_block_idx + b; const int vblock_idx_ctx = @@ -268,8 +275,8 @@ __global__ __launch_bounds__(NUM_THREADS) 
void paged_attention_ll4mi_QKV_kernel( const scalar_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; const _Half8* v_ptrh8 = reinterpret_cast(v_ptr); -// iterate over each v block -#pragma unroll + // iterate over each v block + #pragma unroll for (int b = 0; b < VBLOCKS; b++) { // int32 physical_block_number leads to overflow when multiplied with // kv_block_stride @@ -277,20 +284,20 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( static_cast(vphysical_blocks[b]); const _Half8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; -// iterate over each head elem (within head_size) -#pragma unroll + // iterate over each head elem (within head_size) + #pragma unroll for (int h = 0; h < VHELOOP; h++) { const int head_size_elem = h * WARP_SIZE + laneid; const _Half8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; -// iterate over all velems within block -#pragma unroll + // iterate over all velems within block + #pragma unroll for (int d = 0; d < BLOCK_SIZE / 8; d++) { Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; } } } -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP; h++) { dout[h] = GCN_MFMA_INSTR(Qlocal[h].xy[0], Klocal[0].xy[0], dout[h], 4, 0, 0); @@ -360,12 +367,12 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } // KHELOOP>8 dout[h] *= scale; } -// transpose dout so that 4 token ids are in each lane, and 4 heads are across 4 -// lanes -#pragma unroll + // transpose dout so that 4 token ids are in each lane, and 4 heads are across + // 4 lanes + #pragma unroll for (int h = 0; h < QHLOOP; h++) { floatx4 tmp = {0}; -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { const float B = (lane4id == i) ? 1.0f : 0.0f; // const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f; @@ -378,48 +385,48 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const int lane4_token_idx = 4 * (global_token_idx >> 2); const int alibi_offset = lane4_token_idx - context_len + 1; if (alibi_slopes != nullptr) { -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP; h++) { -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { dout[h][i] += alibi_slope[h] * (alibi_offset + i); } } } -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP; h++) { qk_max[h] = -FLT_MAX; -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { qk_max[h] = (lane4_token_idx + i < context_len) ? fmaxf(qk_max[h], dout[h][i]) : qk_max[h]; } -#pragma unroll + #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); } } float exp_sum[QHLOOP]; -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP; h++) { exp_sum[h] = 0.0f; -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { dout[h][i] = (lane4_token_idx + i < context_len) ? 
__expf(dout[h][i] - qk_max[h]) : 0.0f; exp_sum[h] += dout[h][i]; } -#pragma unroll + #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { exp_sum[h] += __shfl_xor(exp_sum[h], mask); } } -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP; h++) { const int head_idx = 4 * h + lane4id; shared_qk_max[warpid][head_idx] = qk_max[h]; @@ -434,18 +441,18 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( max_logits + seq_idx * num_heads * max_num_partitions + partition_idx; float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx; -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP; h++) { float global_qk_max = -FLT_MAX; float warp_qk_max[NWARPS]; const int head_idx = 4 * h + lane4id; -#pragma unroll + #pragma unroll for (int w = 0; w < NWARPS; w++) { warp_qk_max[w] = shared_qk_max[w][head_idx]; global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]); } float global_exp_sum = 0.0f; -#pragma unroll + #pragma unroll for (int w = 0; w < NWARPS; w++) { global_exp_sum += shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max); @@ -463,9 +470,9 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there // are 4x16 tokens across warp float16x4 logits[QHLOOP]; -#pragma unroll + #pragma unroll for (int h = 0; h < QHLOOP; h++) { -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { logits[h][i] = (scalar_t)dout[h][i]; } @@ -474,19 +481,19 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( __shared__ float16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; if (warp_start_token_idx >= context_len) { // warp out of context -#pragma unroll + #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { -#pragma unroll + #pragma unroll for (int vh = 0; vh < VHELOOP; vh++) { vout_shared[qh][vh][laneid][warpid] = {0}; } } } else { // warp in context -// iterate across heads -#pragma unroll + // iterate across heads + #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { -// iterate over each v head elem (within head_size) -#pragma unroll + // iterate over each v head elem (within head_size) + #pragma unroll for (int vh = 0; vh < VHELOOP; vh++) { floatx4 acc = {0}; // iterate over tokens @@ -507,7 +514,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][7].xy[0], acc, 4, 14, 0); acc = GCN_MFMA_INSTR(logits[qh], Vlocal[vh][7].xy[1], acc, 4, 15, 0); float16x4 tmp; -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { tmp[i] = (scalar_t)acc[i]; } @@ -531,18 +538,18 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( out_num_partitions = 1; out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; } -#pragma unroll + #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { -// iterate over each v head elem (within head_size) -#pragma unroll + // iterate over each v head elem (within head_size) + #pragma unroll for (int vh = 0; vh < VHELOOP; vh++) { vout[qh][vh] = {0}; -#pragma unroll + #pragma unroll for (int w = 0; w < NWARPS; w++) { vout[qh][vh] += vout_shared[qh][vh][laneid][w]; } const int head_size_elem = vh * WARP_SIZE + laneid; -#pragma unroll + #pragma unroll for (int i = 0; i < 4; i++) { const int head_idx = 4 * qh + i; if (head_idx < GQA_RATIO) { @@ -557,12 +564,12 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } -#if 
0 + #if 0 const int num_seqs = gridDim.x; const int global_token4id = global_token_idx/4; - #pragma unroll + #pragma unroll for (int t=0;t<4;t++) { - #pragma unroll + #pragma unroll for (int h=0;h= 1; mask /= 2) { max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); } @@ -643,7 +650,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( shared_exp_sums[threadIdx.x] = rescaled_exp_sum; shared_exp_sums[threadIdx.x + WARP_SIZE] = rescaled_exp_sum2; -#pragma unroll + #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { global_exp_sum += __shfl_xor(global_exp_sum, mask); } @@ -656,7 +663,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; constexpr int MAX_NPAR = 64; scalar_t tmps[MAX_NPAR]; -#pragma unroll + #pragma unroll for (int j = 0; j < MAX_NPAR; j++) { tmps[j] = 0.0f; } @@ -666,7 +673,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( constexpr int JCHUNK = 16; -#pragma unroll + #pragma unroll for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) { // lastj is last valid partition const int lastj_offset = @@ -677,7 +684,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( __syncthreads(); if (num_partitions > JCHUNK) { -#pragma unroll + #pragma unroll for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE; j += HEAD_SIZE) { const int lastj_offset = @@ -687,7 +694,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( } if (num_partitions > 2 * JCHUNK) { -#pragma unroll + #pragma unroll for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE; j += HEAD_SIZE) { const int lastj_offset = @@ -700,17 +707,17 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( // Aggregate tmp_out to out. 
float acc = 0.0f; -#pragma unroll + #pragma unroll for (int j = 0; j < JCHUNK; j++) { acc += tmps[j] * shared_exp_sums[j]; } if (num_partitions > JCHUNK) { -#pragma unroll + #pragma unroll for (int j = JCHUNK; j < 2 * JCHUNK; j++) { acc += tmps[j] * shared_exp_sums[j]; } if (num_partitions > 2 * JCHUNK) { -#pragma unroll + #pragma unroll for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) { acc += tmps[j] * shared_exp_sums[j]; } @@ -719,7 +726,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( if (num_partitions > MAX_NPAR) { idx = 0; -#pragma unroll + #pragma unroll for (int j = MAX_NPAR * HEAD_SIZE; j < 2 * MAX_NPAR * HEAD_SIZE; j += HEAD_SIZE) { // lastj is last valid partition @@ -729,7 +736,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( idx++; } -#pragma unroll + #pragma unroll for (int j = 0; j < MAX_NPAR; j++) { acc += tmps[j] * shared_exp_sums[j + MAX_NPAR]; } @@ -744,6 +751,54 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( out_ptr[threadIdx.x] = (scalar_t)acc; } +#else // !defined(__HIP__MI300__) TODO: Add NAVI support + +template +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] + #if 0 + scalar_t* __restrict__ qk_out, // [num_heads, num_seqs, max_ctx_blocks,block_size] + #endif + int max_ctx_blocks) { + assert(false); +} + +// Grid: (num_heads, num_seqs). 
+template
+__global__
+__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
+    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
+    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
+                                           // max_num_partitions, head_size]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    const int max_num_partitions) {
+  assert(false);
+}
+
+#endif // defined(__HIP__MI300__) TODO: Add NAVI support
+
 #define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO)                                  \
   paged_attention_ll4mi_QKV_kernel                                          \
       <<>>(                                                                 \
diff --git a/tests/kernels/test_attention_custom.py b/tests/kernels/test_attention_custom.py
index 5bdbf126c22fa..d9b53ed7bd0d9 100644
--- a/tests/kernels/test_attention_custom.py
+++ b/tests/kernels/test_attention_custom.py
@@ -3,12 +3,13 @@
 
 import pytest
 import torch
-from allclose_default import get_default_atol, get_default_rtol
 
 from vllm._C import cache_ops, ops
 from vllm._custom_C import paged_attention_custom
 from vllm.utils import get_max_shared_memory_bytes, is_hip
 
+from .allclose_default import get_default_atol, get_default_rtol
+
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
 # - 512 as a buffer
diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py
index 0d3ee47193306..ee2e83f6b272c 100644
--- a/vllm/attention/ops/paged_attn.py
+++ b/vllm/attention/ops/paged_attn.py
@@ -1,4 +1,3 @@
-import os
 from dataclasses import dataclass
 from typing import List, Optional, Tuple
 
@@ -6,10 +5,11 @@
 
 from vllm import _custom_ops as ops
 from vllm.attention.ops.prefix_prefill import context_attention_fwd
+from vllm.envs import VLLM_USE_ROCM_CUSTOM_PAGED_ATTN
 from vllm.utils import is_hip
 
-custom_attn_available = is_hip() and \
-    (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "1") != "0")
+custom_attn_available = is_hip() and VLLM_USE_ROCM_CUSTOM_PAGED_ATTN and \
+    "gfx1" not in torch.cuda.get_device_properties('cuda').gcnArchName
 if custom_attn_available:
     from vllm._custom_C import paged_attention_custom
 
diff --git a/vllm/envs.py b/vllm/envs.py
index bd6abca2629ed..739a4792ce078 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -9,6 +9,8 @@
     VLLM_NCCL_SO_PATH: Optional[str] = None
     LD_LIBRARY_PATH: Optional[str] = None
     VLLM_USE_TRITON_FLASH_ATTN: bool = True
+    VLLM_USE_ROCM_SKINNY_GEMM: bool = True
+    VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True
     RANK: int = 0
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -138,6 +140,16 @@
     lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
              ("true", "1")),
 
+    # small gemms custom implementation for MI3* cards
+    "VLLM_USE_ROCM_SKINNY_GEMM":
+    lambda: (os.getenv("VLLM_USE_ROCM_SKINNY_GEMM", "True").lower() in
+             ("true", "1")),
+
+    # custom paged attention implemented for MI3* cards
+    "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN":
+    lambda: (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
+             ("true", "1")),
+
     # rank of the process in the distributed setting, used to determine
     # the driver worker
     "RANK":
diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py
index b3783cecfdce8..3cef301bfa442 100644
--- a/vllm/model_executor/layers/tuned_gemm.py
+++ b/vllm/model_executor/layers/tuned_gemm.py
@@ -8,6 +8,8 @@
 from rocsolidxgemm import rocb_create_extension, rocb_mm
 
 from vllm import _custom_C
+from vllm.envs import VLLM_USE_ROCM_SKINNY_GEMM
+from vllm.utils import is_hip
 
 
 class TunedGemm:
@@ -26,6 +28,9 @@ def __init__(self):
         self.cu_count = torch.cuda.get_device_properties(
             device='cuda').multi_processor_count
 
+        self.use_skinny = is_hip() and VLLM_USE_ROCM_SKINNY_GEMM and \
+            "gfx1" not in torch.cuda.get_device_properties('cuda').gcnArchName
+
         if (self.save_gemm == 1):
             self.tuned_df = pd.DataFrame(columns=['M', 'N', 'K'])
         else:
@@ -52,6 +57,8 @@ def query_sol(self, m, n, k):
         return self.solids.get((m, n, k), (0, 0))
 
     def apply_skinny(self, m, n, k, inp_view, weights):
+        if not self.use_skinny:
+            return None
         if inp_view.dtype != torch.float16 or k % 8 != 0:
             return None
         if m > 8 and n <= 4:
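The changes above all follow one gating pattern. On the kernel side, the MI300-only code paths are compiled only when hipcc targets gfx940/gfx941/gfx942 (the __HIP__MI300__ guard), with assert(false) stubs of identical signature on other architectures so the symbols still exist. On the Python side, the custom paths are enabled only on a ROCm build, only when the corresponding VLLM_USE_ROCM_* flag is truthy, and only on devices whose gcnArchName does not contain "gfx1" (i.e. not Navi). The snippet below is a minimal standalone sketch of that Python-side gating, not vLLM API: _flag_enabled and rocm_custom_path_available are hypothetical names, and torch.version.hip is used as a stand-in for vllm.utils.is_hip().

import os

import torch


def _flag_enabled(name: str, default: str = "True") -> bool:
    # envs.py convention: the flag is on when the variable is unset or set to
    # "true"/"1" (case-insensitive); anything else disables it.
    return os.getenv(name, default).lower() in ("true", "1")


def rocm_custom_path_available(flag_name: str) -> bool:
    # The custom kernels are built only for MI300-class GPUs (gfx94x), so
    # Navi ("gfx1*") devices fall back to the generic implementations.
    if torch.version.hip is None or not torch.cuda.is_available():
        return False
    arch = torch.cuda.get_device_properties("cuda").gcnArchName
    return _flag_enabled(flag_name) and "gfx1" not in arch


# Mirrors custom_attn_available in paged_attn.py and TunedGemm.use_skinny:
use_custom_paged_attn = rocm_custom_path_available(
    "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN")
use_skinny_gemm = rocm_custom_path_available("VLLM_USE_ROCM_SKINNY_GEMM")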