diff --git a/onnxruntime/cppbuild.sh b/onnxruntime/cppbuild.sh
index 6816f09e77..1838aa1ff2 100755
--- a/onnxruntime/cppbuild.sh
+++ b/onnxruntime/cppbuild.sh
@@ -84,6 +84,7 @@ sedinplace 's/MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()/false/g' onnxrunt
 
 # work around toolchain issues on Mac and Windows
 patch -p1 < ../../../onnxruntime.patch
+patch -p1 < ../../../onnxruntime-cuda.patch # https://github.com/microsoft/onnxruntime/pull/22316
 #patch -p1 < ../../../onnxruntime-windows.patch # https://github.com/microsoft/onnxruntime/pull/7883
 sedinplace '/--Werror/d' cmake/CMakeLists.txt
 sedinplace '/-DCMAKE_CUDA_COMPILER=/d' tools/ci_build/build.py
@@ -167,7 +168,7 @@ sedinplace 's/UTFChars(javaNameStrings/UTFChars((jstring)javaNameStrings/g' java
 sedinplace 's/initializers = allocarray/initializers = (const OrtValue**)allocarray/g' java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.cpp
 
 which ctest3 &> /dev/null && CTEST="ctest3" || CTEST="ctest"
-"$PYTHON_BIN_PATH" tools/ci_build/build.py --build_dir ../build --config Release --cmake_path "$CMAKE" --ctest_path "$CTEST" --build_shared_lib $ARCH_FLAGS $DNNL_FLAGS $OPENMP_FLAGS $GPU_FLAGS
+"$PYTHON_BIN_PATH" tools/ci_build/build.py --build_dir ../build --config Release --parallel $MAKEJ --cmake_path "$CMAKE" --ctest_path "$CTEST" --build_shared_lib $ARCH_FLAGS $DNNL_FLAGS $OPENMP_FLAGS $GPU_FLAGS
 
 # install headers and libraries in standard directories
 cp -r include/* ../include
diff --git a/onnxruntime/onnxruntime-cuda.patch b/onnxruntime/onnxruntime-cuda.patch
new file mode 100644
index 0000000000..4f8688ad0e
--- /dev/null
+++ b/onnxruntime/onnxruntime-cuda.patch
@@ -0,0 +1,1371 @@
+diff --git a/cmake/deps.txt b/cmake/deps.txt
+index 2487ea144227d..1900e88b1ece1 100644
+--- a/cmake/deps.txt
++++ b/cmake/deps.txt
+@@ -53,7 +53,7 @@ pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/ca678952a9a8eaa6de112
+ re2;https://github.com/google/re2/archive/refs/tags/2024-07-02.zip;646e1728269cde7fcef990bf4a8e87b047882e88
+ safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac
+ tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381
+-cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.5.0.zip;ae038931b9fc2c416c17d9cda91d9706b343f56d
++cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.5.1.zip;e49b2b964163d27765a5002d210a2f3c73771835
+ utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156
+ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c
+ composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557
+diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake
+index 1ece2e7a509ba..f04f4bec76cd5 100644
+--- a/cmake/external/cutlass.cmake
++++ b/cmake/external/cutlass.cmake
+@@ -3,7 +3,6 @@ FetchContent_Declare(
+   cutlass
+   URL ${DEP_URL_cutlass}
+   URL_HASH SHA1=${DEP_SHA1_cutlass}
+-  PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/cutlass/cutlass_3.5.0.patch
+ )
+ 
+ FetchContent_GetProperties(cutlass)
+diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
+index 5ffa63c54c8fb..175ef9f250a28 100644
+--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
++++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
+@@ -10,7 +10,7 @@
+ #endif
+ 
+ #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
+-#include "41_fused_multi_head_attention/kernel_forward.h"
++#include "contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h"
+ 
+ namespace onnxruntime {
+ namespace contrib {
+diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h
+new file mode 100644
+index 0000000000000..881d2384e4a5c
+--- /dev/null
++++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h
+@@ -0,0 +1,1327 @@
++/***************************************************************************************************
++ * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
++ * SPDX-License-Identifier: BSD-3-Clause
++ *
++ * Redistribution and use in source and binary forms, with or without
++ * modification, are permitted provided that the following conditions are met:
++ *
++ * 1. Redistributions of source code must retain the above copyright notice, this
++ * list of conditions and the following disclaimer.
++ *
++ * 2. Redistributions in binary form must reproduce the above copyright notice,
++ * this list of conditions and the following disclaimer in the documentation
++ * and/or other materials provided with the distribution.
++ *
++ * 3. Neither the name of the copyright holder nor the names of its
++ * contributors may be used to endorse or promote products derived from
++ * this software without specific prior written permission.
++ *
++ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
++ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
++ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
++ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
++ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
++ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
++ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
++ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
++ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
++ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
++ * ++ **************************************************************************************************/ ++ ++#pragma once ++ ++#ifdef HAS_PYTORCH ++#include ++#include ++#endif ++ ++#include ++#include ++#include ++ ++#include "cutlass/fast_math.h" ++#include "cutlass/gemm/gemm.h" ++#include "cutlass/layout/matrix.h" ++#include "cutlass/layout/vector.h" ++#include "cutlass/matrix.h" ++#include "cutlass/numeric_types.h" ++#include "cutlass/tensor_ref.h" ++ ++#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" ++#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" ++#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" ++#include "cutlass/gemm/device/default_gemm_configuration.h" ++#include "cutlass/gemm/kernel/default_gemm.h" ++#include "cutlass/gemm/threadblock/default_mma.h" ++#include "cutlass/gemm/threadblock/default_mma_core_simt.h" ++#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" ++#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" ++#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" ++#include "cutlass/gemm/threadblock/threadblock_swizzle.h" ++#include "cutlass/matrix_shape.h" ++#include "cutlass/platform/platform.h" ++#include "cutlass/transform/threadblock/predicated_tile_iterator.h" ++#include "41_fused_multi_head_attention/debug_utils.h" ++#include "41_fused_multi_head_attention/epilogue/epilogue_pipelined.h" ++#include "41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h" ++#include "41_fused_multi_head_attention/gemm/custom_mma.h" ++#include "41_fused_multi_head_attention/gemm/find_default_mma.h" ++#include "41_fused_multi_head_attention/gemm/mma_from_smem.h" ++#include "41_fused_multi_head_attention/gemm_kernel_utils.h" ++#include "41_fused_multi_head_attention/transform/tile_smem_loader.h" ++ ++#include ++ ++using namespace gemm_kernel_utils; ++ ++namespace { ++template ++constexpr int getWarpsPerSmFw() { ++ return ( ++ Arch::kMinComputeCapability >= 80 && ++ !cutlass::platform::is_same::value ++ ? 16 ++ : 12); ++} ++static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { ++ // source: https://stackoverflow.com/a/51549250 ++ return (value >= 0) ++ ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) ++ : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); ++} ++} // namespace ++ ++// If ToBatchHookType_ is supplied other than this default (which is ++// never the case in the xformers library) then the user is ++// defining the logic which each block uses to find its data to work on, ++// with the advance_to_batch function with the following signature. ++// It should return false if there is no work to do for this block. ++// In general this will not work with saving for backward due to fixed layout ++// for logsumexp and incompatible rngs for dropout, so is likely only useful for ++// custom inference. 
++struct DefaultToBatchHook { ++ template ++ CUTLASS_DEVICE static bool advance_to_batch( ++ Params&, ++ int64_t& /* q_start */, ++ int64_t& /* k_start */) { ++ return true; ++ } ++}; ++ ++template < ++ // The datatype of Q/K/V ++ typename scalar_t_, ++ // Architecture we are targeting (eg `cutlass::arch::Sm80`) ++ typename ArchTag, ++ // If Q/K/V are correctly aligned in memory and we can run a fast kernel ++ bool isAligned_, ++ int kQueriesPerBlock_, ++ int kKeysPerBlock_, ++ // upperbound on `max(value.shape[-1], query.shape[-1])` ++ int kMaxK_ = (int)cutlass::platform::numeric_limits::max(), ++ // This is quite slower on V100 for some reason ++ // Set to false if you know at compile-time you will never need dropout ++ bool kSupportsDropout_ = true, ++ bool kSupportsBias_ = true, ++ typename ToBatchHookType_ = DefaultToBatchHook> ++struct AttentionKernel { ++ enum CustomMaskType { ++ NoCustomMask = 0, ++ CausalFromTopLeft = 1, ++ CausalFromBottomRight = 2, ++ NumCustomMaskTypes, ++ }; ++ ++ using scalar_t = scalar_t_; ++ using accum_t = float; ++ using lse_scalar_t = float; ++ using output_t = scalar_t; ++ // Accumulator between 2 iterations ++ // Using `accum_t` improves perf on f16 at the cost of ++ // numerical errors ++ using output_accum_t = accum_t; ++ static constexpr bool kSupportsDropout = kSupportsDropout_; ++ static constexpr bool kSupportsBias = kSupportsBias_; ++ static constexpr int kKeysPerBlock = kKeysPerBlock_; ++ static constexpr int kQueriesPerBlock = kQueriesPerBlock_; ++ static constexpr int kMaxK = kMaxK_; ++ static constexpr bool kIsAligned = isAligned_; ++ static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock; ++ static constexpr int32_t kAlignLSE = 32; // block size of backward ++ static constexpr bool kIsHalf = cutlass::sizeof_bits::value == 16; ++ static constexpr bool kPreloadV = ++ ArchTag::kMinComputeCapability >= 80 && kIsHalf; ++ static constexpr bool kKeepOutputInRF = kSingleValueIteration; ++ static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && ++ !cutlass::platform::is_same::value; ++ ++ static_assert(kQueriesPerBlock % 32 == 0, ""); ++ static_assert(kKeysPerBlock % 32 == 0, ""); ++ static constexpr int kNumWarpsPerBlock = ++ kQueriesPerBlock * kKeysPerBlock / (32 * 32); ++ static constexpr int kWarpSize = 32; ++ ++ // Launch bounds ++ static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock; ++ static constexpr int kMinBlocksPerSm = ++ getWarpsPerSmFw() / kNumWarpsPerBlock; ++ ++ struct Params { ++ // Input tensors ++ scalar_t* query_ptr = nullptr; // [num_queries, num_heads, head_dim] ++ scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim] ++ scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value] ++ scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys] ++ int32_t* seqstart_q_ptr = nullptr; ++ int32_t* seqstart_k_ptr = nullptr; ++ ++ int32_t* seqlen_k_ptr = nullptr; ++ uint32_t causal_diagonal_offset = 0; ++ ++ // Output tensors ++ output_t* output_ptr = nullptr; // [num_queries, num_heads, head_dim_value] ++ // [num_queries, num_heads, head_dim_value] ++ output_accum_t* output_accum_ptr = nullptr; ++ // [num_heads, num_queries] - can be null ++ lse_scalar_t* logsumexp_ptr = nullptr; ++ ++ // Scale ++ accum_t scale = 0.0; ++ ++ // Dimensions/strides ++ int32_t head_dim = 0; ++ int32_t head_dim_value = 0; ++ int32_t num_queries = 0; ++ int32_t num_keys = 0; ++ int32_t num_keys_absolute = 0; ++ ++ uint8_t custom_mask_type = NoCustomMask; ++ ++ int32_t q_strideM = 
0; ++ int32_t k_strideM = 0; ++ int32_t v_strideM = 0; ++ int32_t bias_strideM = 0; ++ ++ int32_t o_strideM = 0; ++ ++ // Everything below is only used in `advance_to_block` ++ // and shouldn't use registers ++ int32_t q_strideH = 0; ++ int32_t k_strideH = 0; ++ int32_t v_strideH = 0; ++ int64_t bias_strideH = 0; ++ ++ int64_t q_strideB = 0; ++ int64_t k_strideB = 0; ++ int64_t v_strideB = 0; ++ int64_t bias_strideB = 0; ++ ++ int32_t num_batches = 0; ++ int32_t num_heads = 0; ++ ++ bool use_smooth_softmax = false; ++ ++ // dropout ++ bool use_dropout = false; ++ unsigned long long dropout_batch_head_rng_offset = 0; ++ float dropout_prob = 0.0f; ++#ifdef HAS_PYTORCH ++ at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0); ++#endif ++ ++ // Moves pointers to what we should process ++ // Returns "false" if there is no work to do ++ CUTLASS_DEVICE bool advance_to_block() { ++ auto batch_id = blockIdx.z; ++ auto head_id = blockIdx.y; ++ auto query_start = blockIdx.x * kQueriesPerBlock; ++ ++ auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; ++ ++ if (kSupportsDropout) { ++ dropout_batch_head_rng_offset = ++ batch_id * num_heads * num_queries * num_keys + ++ head_id * num_queries * num_keys; ++ } ++ ++ int64_t q_start = 0, k_start = 0; ++ // Advance to current batch - in case of different sequence lengths ++ constexpr bool kToBatchHook = ++ !cutlass::platform::is_same:: ++ value; ++ if (kToBatchHook) { ++ // Call out to a custom implementation. ++ if (!ToBatchHookType_::advance_to_batch(*this, q_start, k_start)) { ++ return false; ++ } ++ } else if (seqstart_q_ptr != nullptr) { ++ assert(seqstart_k_ptr != nullptr); ++ seqstart_q_ptr += batch_id; ++ ++ q_start = seqstart_q_ptr[0]; ++ int64_t q_next_start = seqstart_q_ptr[1]; ++ int64_t k_end; ++ seqstart_k_ptr += batch_id; ++ ++ if (seqlen_k_ptr) { ++ k_start = seqstart_k_ptr[0]; ++ k_end = k_start + seqlen_k_ptr[batch_id]; ++ } else { ++ k_start = seqstart_k_ptr[0]; ++ k_end = seqstart_k_ptr[1]; ++ } ++ ++ num_queries = q_next_start - q_start; ++ num_keys = k_end - k_start; ++ ++ if (query_start >= num_queries) { ++ return false; ++ } ++ } else { ++ query_ptr += batch_id * q_strideB; ++ key_ptr += batch_id * k_strideB; ++ value_ptr += batch_id * v_strideB; ++ output_ptr += int64_t(batch_id * num_queries) * o_strideM; ++ if (output_accum_ptr != nullptr) { ++ output_accum_ptr += ++ int64_t(batch_id * num_queries) * (head_dim_value * num_heads); ++ } ++ q_start = 0; ++ k_start = 0; ++ } ++ ++ // Advance to the current batch / head / query_start ++ query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH; ++ key_ptr += k_start * k_strideM + head_id * k_strideH; ++ ++ value_ptr += k_start * v_strideM + head_id * v_strideH; ++ output_ptr += ++ int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value; ++ ++ if (kSupportsBias && attn_bias_ptr != nullptr) { ++ attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH); ++ } ++ if (output_accum_ptr != nullptr) { ++ output_accum_ptr += ++ int64_t(q_start + query_start) * (head_dim_value * num_heads) + ++ head_id * head_dim_value; ++ } else { ++ // Accumulate directly in the destination buffer (eg for f32) ++ output_accum_ptr = (accum_t*)output_ptr; ++ } ++ ++ if (logsumexp_ptr != nullptr) { ++ // lse[batch_id, head_id, query_start] ++ logsumexp_ptr += ++ batch_id * lse_dim * num_heads + head_id * lse_dim + query_start; ++ } ++ ++ // Custom masking ++ if (custom_mask_type == CausalFromBottomRight) { ++ causal_diagonal_offset = num_keys 
- num_queries; ++ } ++ // We use num_keys_absolute to index into the rng_state ++ // We need this index to match between forward and backwards ++ num_keys_absolute = num_keys; ++ if (custom_mask_type == CausalFromTopLeft || ++ custom_mask_type == CausalFromBottomRight) { ++ // the bottom row of the current block is query_start + kQueriesPerBlock ++ // the last active key is then query_start + causal_diagonal_offset + ++ // kQueriesPerBlock so num_keys is the min between actual num_keys and ++ // this to avoid extra computations ++ num_keys = cutlass::fast_min( ++ int32_t(query_start + causal_diagonal_offset + kQueriesPerBlock), ++ num_keys); ++ } ++ ++ num_queries -= query_start; ++ num_batches = 0; // no longer used after ++ ++ // If num_queries == 1, and there is only one key head we're wasting ++ // 15/16th of tensor core compute In that case : ++ // - we only launch kernels for head_id % kQueriesPerBlock == 0 ++ // - we iterate over heads instead of queries (strideM = strideH) ++ if (num_queries == 1 && k_strideH == 0 && v_strideH == 0) { ++ if (head_id % kQueriesPerBlock != 0) ++ return false; ++ q_strideM = q_strideH; ++ num_queries = num_heads; ++ num_heads = 1; // unused but here for intent ++ // remove causal since n_query = 1 ++ // otherwise, offset would change with head ! ++ custom_mask_type = NoCustomMask; ++ o_strideM = head_dim_value; ++ } ++ ++ // Make sure the compiler knows these variables are the same on all ++ // the threads of the warp. ++ // Only worth doing if they could have been modified above. ++ query_ptr = warp_uniform(query_ptr); ++ key_ptr = warp_uniform(key_ptr); ++ value_ptr = warp_uniform(value_ptr); ++ if (kSupportsBias) { ++ attn_bias_ptr = warp_uniform(attn_bias_ptr); ++ } ++ output_ptr = warp_uniform(output_ptr); ++ output_accum_ptr = warp_uniform(output_accum_ptr); ++ logsumexp_ptr = warp_uniform(logsumexp_ptr); ++ num_queries = warp_uniform(num_queries); ++ num_keys = warp_uniform(num_keys); ++ num_heads = warp_uniform(num_heads); ++ o_strideM = warp_uniform(o_strideM); ++ custom_mask_type = warp_uniform(custom_mask_type); ++ return true; ++ } ++ ++ __host__ dim3 getBlocksGrid() const { ++ return dim3( ++ ceil_div(num_queries, (int32_t)kQueriesPerBlock), ++ num_heads, ++ num_batches); ++ } ++ ++ __host__ dim3 getThreadsGrid() const { ++ return dim3(kWarpSize, kNumWarpsPerBlock, 1); ++ } ++ }; ++ ++ struct MM0 { ++ /* ++ In this first matmul, we compute a block of `Q @ K.T`. ++ While the calculation result is still hot in registers, we update ++ `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value ++ into a shared-memory ("AccumulatorSharedStorage") that is used later as ++ operand A for the second matmul (see MM1) ++ */ ++ using GemmType = DefaultGemmType; ++ ++ using OpClass = typename GemmType::OpClass; ++ using DefaultConfig = ++ typename cutlass::gemm::device::DefaultGemmConfiguration< ++ OpClass, ++ ArchTag, ++ scalar_t, ++ scalar_t, ++ scalar_t, // ElementC ++ accum_t // ElementAccumulator ++ >; ++ static constexpr int kAlignmentA = ++ kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; ++ static constexpr int kAlignmentB = ++ kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; ++ using ThreadblockShape = cutlass::gemm:: ++ GemmShape; ++ using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; ++ using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< ++ scalar_t, // ElementA, ++ cutlass::layout::RowMajor, // LayoutA, ++ kAlignmentA, ++ scalar_t, // ElementB, ++ cutlass::layout::ColumnMajor, // LayoutB, ++ kAlignmentB, ++ accum_t, ++ cutlass::layout::RowMajor, // LayoutC, ++ OpClass, ++ ArchTag, // ArchTag ++ ThreadblockShape, // ThreadblockShape ++ WarpShape, // WarpShape ++ typename GemmType::InstructionShape, // InstructionShape ++ ArchTag::kMinComputeCapability >= 80 && kIsHalf ++ ? 4 ++ : DefaultConfig::kStages, ++ typename GemmType::Operator // Operator ++ >::DefaultMma; ++ using MmaCore = typename DefaultMma::MmaCore; ++ using IteratorA = typename DefaultMma::IteratorA; ++ using IteratorB = typename DefaultMma::IteratorB; ++ using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma; ++ using Mma = typename cutlass::platform::conditional< ++ kSingleValueIteration, ++ typename MakeCustomMma::Mma, ++ DefaultThreadblockMma>::type; ++ using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< ++ typename Mma::Operator::IteratorC, ++ accum_t, ++ kWarpSize>::Iterator; ++ static_assert( ++ MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * ++ MmaCore::WarpCount::kK == ++ kNumWarpsPerBlock, ++ ""); ++ ++ // used for efficient load of bias tile Bij from global to shared memory ++ using BiasLoader = TileSmemLoader< ++ scalar_t, ++ cutlass::MatrixShape, ++ MmaCore::kThreads, ++ // input restriction: kv_len has to be a multiple of this value ++ 128 / cutlass::sizeof_bits::value>; ++ ++ // Epilogue to store to shared-memory in a format that we can use later for ++ // the second matmul ++ using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< ++ typename Mma::Operator::IteratorC, ++ typename Mma::Operator, ++ scalar_t, ++ WarpShape, ++ ThreadblockShape>; ++ using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; ++ }; ++ ++ struct MM1 { ++ /** ++ Second matmul: perform `attn @ V` where `attn` is the attention (not ++ normalized) and stored in shared memory ++ */ ++ using GemmType = DefaultGemmType; ++ ++ using OpClass = typename GemmType::OpClass; ++ using DefaultConfig = ++ typename cutlass::gemm::device::DefaultGemmConfiguration< ++ OpClass, ++ ArchTag, ++ scalar_t, ++ scalar_t, ++ output_accum_t, // ElementC ++ accum_t // ElementAccumulator ++ >; ++ static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem ++ static constexpr int kAlignmentB = ++ kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; ++ using ThreadblockShape = cutlass::gemm:: ++ GemmShape; ++ using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; ++ using InstructionShape = typename GemmType::InstructionShape; ++ ++ using LayoutB = cutlass::layout::RowMajor; ++ using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< ++ scalar_t, // ElementA, ++ cutlass::layout::RowMajor, // LayoutA, ++ kAlignmentA, ++ scalar_t, // ElementB, ++ LayoutB, // LayoutB, ++ kAlignmentB, ++ output_accum_t, ++ cutlass::layout::RowMajor, // LayoutC, ++ accum_t, ++ OpClass, ++ ArchTag, ++ ThreadblockShape, ++ WarpShape, ++ typename GemmType::InstructionShape, ++ typename DefaultConfig::EpilogueOutputOp, ++ void, // ThreadblockSwizzle - not used ++ ArchTag::kMinComputeCapability >= 80 && kIsHalf ++ ? 
4 ++ : DefaultConfig::kStages, ++ false, // SplitKSerial ++ typename GemmType::Operator>; ++ ++ using WarpIteratorA = typename cutlass::gemm::threadblock:: ++ DefaultWarpIteratorAFromSharedMemory< ++ typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape ++ typename DefaultGemm::Mma::Policy::Operator::InstructionShape, ++ typename DefaultGemm::Mma::Policy::Operator::IteratorA, ++ typename DefaultGemm::Mma::Policy>::WarpIterator; ++ using DefaultMmaFromSmem = ++ typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< ++ typename DefaultGemm::Mma, ++ MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK ++ WarpIteratorA, ++ false>; // kScaleOperandA ++ using Mma = typename DefaultMmaFromSmem::Mma; ++ using IteratorB = typename Mma::IteratorB; ++ using WarpCount = typename Mma::WarpCount; ++ static_assert( ++ WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock, ++ ""); ++ ++ using DefaultEpilogue = typename DefaultGemm::Epilogue; ++ using OutputTileIterator = ++ typename cutlass::epilogue::threadblock::PredicatedTileIterator< ++ typename DefaultEpilogue::OutputTileIterator::ThreadMap, ++ output_t>; ++ using OutputTileIteratorAccum = ++ typename cutlass::epilogue::threadblock::PredicatedTileIterator< ++ typename DefaultEpilogue::OutputTileIterator::ThreadMap, ++ output_accum_t>; ++ }; ++ ++ static constexpr int64_t kAlignmentQ = MM0::kAlignmentA; ++ static constexpr int64_t kAlignmentK = MM0::kAlignmentB; ++ static constexpr int64_t kAlignmentV = 1; ++ ++ // Shared storage - depends on kernel params ++ struct ScalingCoefs { ++ cutlass::Array m_prime; ++ cutlass::Array s_prime; ++ cutlass::Array mi; ++ cutlass::Array out_rescale; ++ cutlass::Array ++ addition_storage; ++ }; ++ ++ struct SharedStorageEpilogueAtEnd : ScalingCoefs { ++ struct SharedStorageAfterMM0 { ++ // Everything here might be overwritten during MM0 ++ union { ++ typename MM0::BiasLoader::SmemTile bias; ++ typename MM0::AccumulatorSharedStorage si; ++ }; ++ typename MM1::Mma::SharedStorage mm1; ++ }; ++ ++ union { ++ typename MM0::Mma::SharedStorage mm0; ++ SharedStorageAfterMM0 after_mm0; ++ typename MM1::DefaultEpilogue::SharedStorage epilogue; ++ }; ++ ++ CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& ++ epilogue_shared_storage() { ++ return epilogue; ++ } ++ }; ++ ++ struct SharedStorageEpilogueInLoop : ScalingCoefs { ++ struct SharedStorageAfterMM0 { ++ // Everything here might be overwritten during MM0 ++ union { ++ typename MM0::BiasLoader::SmemTile bias; ++ typename MM0::AccumulatorSharedStorage si; ++ }; ++ typename MM1::Mma::SharedStorage mm1; ++ typename MM1::DefaultEpilogue::SharedStorage epilogue; ++ }; ++ ++ union { ++ typename MM0::Mma::SharedStorage mm0; ++ SharedStorageAfterMM0 after_mm0; ++ }; ++ ++ CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& ++ epilogue_shared_storage() { ++ return after_mm0.epilogue; ++ } ++ }; ++ ++ using SharedStorage = typename cutlass::platform::conditional< ++ kSingleValueIteration || kKeepOutputInRF, ++ SharedStorageEpilogueAtEnd, ++ SharedStorageEpilogueInLoop>::type; ++ ++ static bool __host__ check_supported(Params const& p) { ++ CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ); ++ CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK); ++ CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV); ++ if (kSupportsBias) { ++ CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ); ++ XFORMERS_CHECK( ++ p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0, ++ "attn_bias is not correctly aligned (strideB)"); ++ XFORMERS_CHECK( ++ p.num_heads <= 1 || 
p.bias_strideH % kAlignmentQ == 0, ++ "attn_bias is not correctly aligned (strideH)"); ++ XFORMERS_CHECK( ++ p.bias_strideM % kAlignmentQ == 0, ++ "attn_bias is not correctly aligned"); ++ } ++ XFORMERS_CHECK( ++ p.q_strideM % kAlignmentQ == 0, ++ "query is not correctly aligned (strideM)"); ++ XFORMERS_CHECK( ++ p.k_strideM % kAlignmentK == 0, ++ "key is not correctly aligned (strideM)"); ++ XFORMERS_CHECK( ++ p.v_strideM % kAlignmentV == 0, ++ "value is not correctly aligned (strideM)"); ++ XFORMERS_CHECK( ++ p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0, ++ "query is not correctly aligned (strideH)"); ++ XFORMERS_CHECK( ++ p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0, ++ "key is not correctly aligned (strideH)"); ++ XFORMERS_CHECK( ++ p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0, ++ "value is not correctly aligned (strideH)"); ++ XFORMERS_CHECK( ++ p.custom_mask_type < NumCustomMaskTypes, ++ "invalid value for `custom_mask_type`"); ++ return true; ++ } ++ ++ static void CUTLASS_DEVICE attention_kernel(Params& p) { ++ // In this block, we will only ever: ++ // - read query[query_start:query_end, :] ++ // - write to output[query_start:query_end, :] ++ ++ extern __shared__ char smem_buffer[]; ++ SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); ++ auto& m_prime = shared_storage.m_prime; ++ auto& s_prime = shared_storage.s_prime; ++ auto& mi = shared_storage.mi; ++ auto& out_rescale = shared_storage.out_rescale; ++ const uint32_t query_start = blockIdx.x * kQueriesPerBlock; ++ ++ static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); ++ if (thread_id() < kQueriesPerBlock) { ++ s_prime[thread_id()] = accum_t(0); ++ out_rescale[thread_id()] = accum_t(1.0); ++ m_prime[thread_id()] = ++ -cutlass::platform::numeric_limits::infinity(); ++ mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); ++ } ++ typename MM1::Mma::FragmentC accum_o; ++ accum_o.clear(); ++ ++ auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { ++ using OutputTileIterator = typename MM1::OutputTileIterator; ++ return OutputTileIterator( ++ typename OutputTileIterator::Params{(int32_t)p.o_strideM}, ++ p.output_ptr, ++ typename OutputTileIterator::TensorCoord{ ++ p.num_queries, p.head_dim_value}, ++ thread_id(), ++ {0, col}); ++ }; ++ ++ auto createOutputAccumIter = [&](int col) -> ++ typename MM1::OutputTileIteratorAccum { ++ using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; ++ return OutputTileIteratorAccum( ++ typename OutputTileIteratorAccum::Params{ ++ (int32_t)(p.head_dim_value * p.num_heads)}, ++ p.output_accum_ptr, ++ typename OutputTileIteratorAccum::TensorCoord{ ++ p.num_queries, p.head_dim_value}, ++ thread_id(), ++ {0, col}); ++ }; ++ ++#ifdef HAS_PYTORCH ++ curandStatePhilox4_32_10_t curand_state_init; ++ if (kSupportsDropout && p.use_dropout) { ++ const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs); ++ ++ // each element of the attention matrix P with shape ++ // (batch_sz, n_heads, n_queries, n_keys) is associated with a single ++ // offset in RNG sequence. we initialize the RNG state with offset that ++ // starts at the beginning of a (n_queries, n_keys) matrix for this ++ // block's batch_id and head_id ++ // initializing rng state is very expensive, so we run once per kernel, ++ // rather than once per iteration. each iteration takes a copy of the ++ // initialized RNG state and offsets it as needed. 
++ curand_init( ++ std::get<0>(seeds), ++ 0, ++ std::get<1>(seeds) + p.dropout_batch_head_rng_offset, ++ &curand_state_init); ++ } ++#endif ++ ++ // Iterate through keys ++ for (int32_t iter_key_start = 0; iter_key_start < p.num_keys; ++ iter_key_start += kKeysPerBlock) { ++ int32_t problem_size_0_m = ++ cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries); ++ int32_t problem_size_0_n = cutlass::fast_min( ++ int32_t(kKeysPerBlock), p.num_keys - iter_key_start); ++ int32_t const& problem_size_0_k = p.head_dim; ++ int32_t const& problem_size_1_n = p.head_dim_value; ++ int32_t const& problem_size_1_k = problem_size_0_n; ++ ++ auto prologueV = [&](int blockN) { ++ typename MM1::Mma::IteratorB iterator_V( ++ typename MM1::IteratorB::Params{typename MM1::LayoutB(p.v_strideM)}, ++ p.value_ptr + iter_key_start * p.v_strideM, ++ {problem_size_1_k, problem_size_1_n}, ++ thread_id(), ++ cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); ++ MM1::Mma::prologue( ++ shared_storage.after_mm0.mm1, ++ iterator_V, ++ thread_id(), ++ problem_size_1_k); ++ }; ++ ++ __syncthreads(); // Need to have shared memory initialized, and `m_prime` ++ // updated from end of prev iter ++ // ++ // MATMUL: Q.K_t ++ // ++ // Computes the block-matrix product of: ++ // (a) query[query_start:query_end, :] ++ // with ++ // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] ++ // and stores that into `shared_storage.si` ++ // ++ ++ // Compute threadblock location ++ cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0}; ++ ++ cutlass::MatrixCoord tb_offset_A{ ++ tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()}; ++ ++ cutlass::MatrixCoord tb_offset_B{ ++ tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN}; ++ ++ // Construct iterators to A and B operands ++ typename MM0::IteratorA iterator_A( ++ typename MM0::IteratorA::Params( ++ typename MM0::MmaCore::LayoutA(p.q_strideM)), ++ p.query_ptr, ++ {problem_size_0_m, problem_size_0_k}, ++ thread_id(), ++ tb_offset_A); ++ ++ typename MM0::IteratorB iterator_B( ++ typename MM0::IteratorB::Params( ++ typename MM0::MmaCore::LayoutB(p.k_strideM)), ++ p.key_ptr + iter_key_start * p.k_strideM, ++ {problem_size_0_k, problem_size_0_n}, ++ thread_id(), ++ tb_offset_B); ++ ++ auto my_warp_id = warp_uniform(warp_id()); ++ auto my_lane_id = lane_id(); ++ ++ // Construct thread-scoped matrix multiply ++ typename MM0::Mma mma( ++ shared_storage.mm0, thread_id(), my_warp_id, my_lane_id); ++ ++ typename MM0::Mma::FragmentC accum; ++ ++ accum.clear(); ++ ++ auto gemm_k_iterations = ++ (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; ++ ++ // Compute threadblock-scoped matrix multiply-add ++ mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); ++ __syncthreads(); ++ ++ if (kPreloadV) { ++ prologueV(0); ++ } else { ++ MM1::Mma::drain_cp_asyncs(); ++ } ++ ++ typename MM0::Mma::Operator::IteratorC::TensorCoord ++ iteratorC_tile_offset = { ++ (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) + ++ (my_warp_id % MM0::Mma::WarpCount::kM), ++ (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) + ++ (my_warp_id / MM0::Mma::WarpCount::kM)}; ++ ++ // multiply by scaling factor ++ if (kSupportsBias) { ++ accum = ++ cutlass::multiplies()(p.scale, accum); ++ } ++ ++ // apply attention bias if applicable ++ if (kSupportsBias && p.attn_bias_ptr != nullptr) { ++ // load bias tile Bij into shared memory ++ typename MM0::BiasLoader::GmemTileIterator bias_iter( ++ {cutlass::layout::RowMajor(p.bias_strideM)}, ++ // attn_bias_pointer points to matrix of size 
(n_queries, n_keys) ++ // for the relevant batch_id and head_id ++ p.attn_bias_ptr + query_start * p.bias_strideM + iter_key_start, ++ {problem_size_0_m, problem_size_0_n}, ++ thread_id()); ++ cutlass::TensorRef bias_tensor_ref( ++ shared_storage.after_mm0.bias.data(), ++ cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); ++ typename MM0::BiasLoader::SmemTileIterator smem_tile_iter( ++ bias_tensor_ref, thread_id()); ++ MM0::BiasLoader::load(bias_iter, smem_tile_iter); ++ ++ // Pij += Bij, Pij is in register fragment and Bij is in shared memory ++ auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( ++ my_lane_id, my_warp_id, iteratorC_tile_offset); ++ MM0::AccumLambdaIterator::iterateRows( ++ lane_offset, ++ [&](int accum_m) {}, ++ [&](int accum_m, int accum_n, int idx) { ++ if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { ++ accum[idx] += bias_tensor_ref.at({accum_m, accum_n}); ++ } ++ }, ++ [&](int accum_m) {}); ++ } ++ ++ // Mask out last if causal ++ // This is only needed if upper-right corner of current query / key block ++ // intersects the mask Coordinates of upper-right corner of current block ++ // is y=query_start x=min(iter_key_start + kKeysPerBlock, num_keys)) The ++ // first masked element is x = y + offset -> query_start + offset There is ++ // intersection (and we need to mask) if min(iter_key_start + ++ // kKeysPerBlock, num_keys)) >= query_start + offset ++ if (p.custom_mask_type && ++ cutlass::fast_min(iter_key_start + kKeysPerBlock, p.num_keys) >= ++ (query_start + p.causal_diagonal_offset)) { ++ auto query_start = blockIdx.x * kQueriesPerBlock; ++ auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( ++ my_lane_id, my_warp_id, iteratorC_tile_offset); ++ int32_t last_col; ++ MM0::AccumLambdaIterator::iterateRows( ++ lane_offset, ++ [&](int accum_m) { ++ // last absolute col is (last absolute query + offset) ++ // last local col is (last absolute query + offset - ++ // iter_key_start) ++ last_col = query_start + accum_m + p.causal_diagonal_offset - ++ iter_key_start; ++ }, ++ [&](int accum_m, int accum_n, int idx) { ++ if (accum_n > last_col) { ++ accum[idx] = ++ -cutlass::platform::numeric_limits::infinity(); ++ } ++ }, ++ [&](int accum_m) {}); ++ } ++ // Update `mi` from accum stored in registers ++ // Also does accum[i] <- exp(accum[i] - mi) ++ iterative_softmax( ++ accum_o, ++ accum, ++ mi, ++ m_prime, ++ s_prime, ++ out_rescale, ++ shared_storage.addition_storage, ++ my_lane_id, ++ thread_id(), ++ my_warp_id, ++ p.num_keys - iter_key_start, ++ iter_key_start == 0, ++ iteratorC_tile_offset, ++ kSupportsBias ? 1.0f : p.scale, ++ p.use_smooth_softmax); ++ ++ // Output results to shared-memory ++ int warp_idx_mn_0 = my_warp_id % ++ (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); ++ auto output_tile_coords = cutlass::MatrixCoord{ ++ warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, ++ warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; ++ ++ MM0::B2bGemm::accumToSmem( ++ shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords); ++ ++ __syncthreads(); ++ ++#ifdef HAS_PYTORCH ++ // apply dropout (if applicable) after we've written Pij to smem. 
++ // dropout is applied by multiplying each element of Pij by: ++ // - 0 with probability dropout_p ++ // - 1 / (1 - dropout_p) with probability 1 - dropout_p ++ // ++ // for backward purposes we want to be able to map each element of the ++ // attention matrix to the same random uniform number as the one we used ++ // in forward, without needing to use the same iteration order or having ++ // to store the dropout matrix. its possible to do this in registers but ++ // it ends up being very slow because each thread having noncontiguous ++ // strips of the Pij tile means we have to skip around a lot, and also ++ // have to generate a single random number at a time ++ if (kSupportsDropout && p.use_dropout) { ++ auto si = shared_storage.after_mm0.si.accum_ref(); ++ // each thread handles a contiguous sequence of elements from Sij, all ++ // coming from the same row. the reason they have to come from the same ++ // row is that the sampling random numbers from a contiguous random ++ // number sequence is much more efficient than jumping around, and the ++ // linear offset of each element of S (the global matrix) maps to an ++ // offset in a random number sequence. for S, the end of a row and the ++ // beginning of the next have adjacent offsets, but for Sij, this is not ++ // necessarily the case. ++ const int num_threads = blockDim.x * blockDim.y * blockDim.z; ++ const int threads_per_row = ++ cutlass::fast_min(num_threads / problem_size_0_m, problem_size_0_n); ++ const int elts_per_thread = cutlass::round_nearest( ++ cutlass::ceil_div(problem_size_0_n, threads_per_row), 4); ++ ++ const int thread_i = thread_id() / threads_per_row; ++ const int thread_start_j = ++ (thread_id() % threads_per_row) * elts_per_thread; ++ ++ if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) { ++ curandStatePhilox4_32_10_t curand_state = curand_state_init; ++ skipahead( ++ static_cast( ++ (query_start + thread_i) * p.num_keys_absolute + ++ (iter_key_start + thread_start_j)), ++ &curand_state); ++ const float dropout_scale = 1.0 / (1.0 - p.dropout_prob); ++ ++ // apply dropout scaling to elements this thread is responsible for, ++ // in chunks of 4 ++ for (int sij_start_col_idx = thread_start_j; sij_start_col_idx < ++ cutlass::fast_min(thread_start_j + elts_per_thread, ++ problem_size_0_n); ++ sij_start_col_idx += 4) { ++ const float4 rand_uniform_quad = curand_uniform4(&curand_state); ++ ++ CUTLASS_PRAGMA_UNROLL ++ for (int quad_idx = 0; quad_idx < 4; ++quad_idx) { ++ si.at({thread_i, sij_start_col_idx + quad_idx}) *= ++ static_cast( ++ dropout_scale * ++ ((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob)); ++ } ++ } ++ } ++ __syncthreads(); // p.use_dropout should have same value kernel-wide ++ } ++#endif ++ ++ // ++ // MATMUL: Attn . V ++ // Run the matmul `attn @ V` for a block of attn and V. ++ // `attn` is read from shared memory (in `shared_storage_si`) ++ // `V` is read from global memory (with iterator_B) ++ // ++ ++ const int64_t nBlockN = kSingleValueIteration ++ ? 
1 ++ : ceil_div( ++ (int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN)); ++ for (int blockN = 0; blockN < nBlockN; ++blockN) { ++ int gemm_k_iterations = ++ (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; ++ ++ // Compute threadblock-scoped matrix multiply-add and store it in accum ++ // (in registers) ++ if (!kPreloadV) { ++ __syncthreads(); // we share shmem between mma and epilogue ++ } ++ ++ typename MM1::Mma::IteratorB iterator_V( ++ typename MM1::IteratorB::Params{typename MM1::LayoutB(p.v_strideM)}, ++ p.value_ptr + iter_key_start * p.v_strideM, ++ {problem_size_1_k, problem_size_1_n}, ++ thread_id(), ++ cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); ++ typename MM1::Mma mma_pv( ++ // operand A: Pij_dropped in shared memory ++ shared_storage.after_mm0.si.accum_ref(), ++ // operand B: shared memory staging area for Vj, which is loaded ++ // from global memory ++ shared_storage.after_mm0.mm1.operand_B_ref(), ++ (int)thread_id(), ++ (int)my_warp_id, ++ (int)my_lane_id); ++ mma_pv.set_prologue_done(kPreloadV); ++ if (!kKeepOutputInRF) { ++ accum_o.clear(); ++ } ++ mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); ++ __syncthreads(); ++ ++ if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) { ++ prologueV(blockN + 1); ++ } ++ ++ if (!kKeepOutputInRF) { ++ MM1::Mma::drain_cp_asyncs(); ++ DISPATCH_BOOL( ++ iter_key_start == 0, kIsFirst, ([&] { ++ DISPATCH_BOOL( ++ (iter_key_start + kKeysPerBlock) >= p.num_keys, ++ kIsLast, ++ ([&] { ++ using DefaultEpilogue = typename MM1::DefaultEpilogue; ++ using DefaultOp = ++ typename MM1::DefaultConfig::EpilogueOutputOp; ++ using ElementCompute = typename DefaultOp::ElementCompute; ++ using EpilogueOutputOp = typename cutlass::epilogue:: ++ thread::MemoryEfficientAttentionNormalize< ++ typename cutlass::platform::conditional< ++ kIsLast, ++ output_t, ++ output_accum_t>::type, ++ output_accum_t, ++ DefaultOp::kCount, ++ typename DefaultOp::ElementAccumulator, ++ ElementCompute, ++ kIsFirst, ++ kIsLast, ++ cutlass::Array>; ++ using Epilogue = typename cutlass::epilogue::threadblock:: ++ EpiloguePipelined< ++ typename DefaultEpilogue::Shape, ++ typename MM1::Mma::Operator, ++ DefaultEpilogue::kPartitionsK, ++ typename cutlass::platform::conditional< ++ kIsLast, ++ typename MM1::OutputTileIterator, ++ typename MM1::OutputTileIteratorAccum>::type, ++ typename DefaultEpilogue:: ++ AccumulatorFragmentIterator, ++ typename DefaultEpilogue::WarpTileIterator, ++ typename DefaultEpilogue::SharedLoadIterator, ++ EpilogueOutputOp, ++ typename DefaultEpilogue::Padding, ++ DefaultEpilogue::kFragmentsPerIteration, ++ true, // IterationsUnroll ++ typename MM1::OutputTileIteratorAccum // Read ++ // iterator ++ >; ++ ++ int col = blockN * MM1::Mma::Shape::kN; ++ auto source_iter = createOutputAccumIter(col); ++ auto dest_iter = call_conditional< ++ kIsLast, ++ decltype(createOutputIter), ++ decltype(createOutputAccumIter)>:: ++ apply(createOutputIter, createOutputAccumIter, col); ++ EpilogueOutputOp rescale(s_prime, out_rescale); ++ Epilogue epilogue( ++ shared_storage.epilogue_shared_storage(), ++ thread_id(), ++ my_warp_id, ++ my_lane_id); ++ epilogue(rescale, dest_iter, accum_o, source_iter); ++ })); ++ })); ++ if (!kSingleValueIteration) { ++ __syncthreads(); ++ } ++ } ++ } ++ __syncthreads(); // we modify `m_prime` after ++ } ++ ++ if (kKeepOutputInRF) { ++ constexpr bool kIsFirst = true; ++ constexpr bool kIsLast = true; ++ using DefaultEpilogue = typename MM1::DefaultEpilogue; ++ using DefaultOp = 
typename MM1::DefaultConfig::EpilogueOutputOp; ++ using ElementCompute = typename DefaultOp::ElementCompute; ++ using EpilogueOutputOp = ++ typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< ++ output_t, // output ++ output_accum_t, // source ++ DefaultOp::kCount, ++ typename DefaultOp::ElementAccumulator, // accum ++ output_accum_t, // compute ++ kIsFirst, ++ kIsLast, ++ cutlass::Array>; ++ using Epilogue = ++ typename cutlass::epilogue::threadblock::EpiloguePipelined< ++ typename DefaultEpilogue::Shape, ++ typename MM1::Mma::Operator, ++ DefaultEpilogue::kPartitionsK, ++ typename MM1::OutputTileIterator, // destination ++ typename DefaultEpilogue::AccumulatorFragmentIterator, ++ typename DefaultEpilogue::WarpTileIterator, ++ typename DefaultEpilogue::SharedLoadIterator, ++ EpilogueOutputOp, ++ typename DefaultEpilogue::Padding, ++ DefaultEpilogue::kFragmentsPerIteration, ++ true, // IterationsUnroll ++ typename MM1::OutputTileIteratorAccum // source tile ++ >; ++ auto dest_iter = createOutputIter(0); ++ EpilogueOutputOp rescale(s_prime, out_rescale); ++ Epilogue epilogue( ++ shared_storage.epilogue_shared_storage(), ++ thread_id(), ++ warp_id(), ++ lane_id()); ++ MM1::Mma::drain_cp_asyncs(); ++ epilogue(rescale, dest_iter, accum_o); ++ } ++ ++ // 7. Calculate logsumexp ++ // To make the backward easier, we pad logsumexp with `inf` ++ // this avoids a few bound checks, and is not more expensive during fwd ++ static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); ++ if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) { ++ auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE; ++ constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E ++ if (thread_id() < p.num_queries) { ++ p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) + ++ cutlass::fast_log(accum_t(s_prime[thread_id()])); ++ } else if (thread_id() < lse_dim) { ++ p.logsumexp_ptr[thread_id()] = ++ cutlass::platform::numeric_limits::infinity(); ++ } ++ } ++ } ++ ++ template ++ CUTLASS_DEVICE static void iterative_softmax( ++ typename WarpIteratorC::Fragment& frag_o, // output so far ++ typename WarpIteratorC::Fragment& frag, ++ cutlass::Array& mi, ++ cutlass::Array& m_prime, ++ cutlass::Array& s_prime, ++ cutlass::Array& out_rescale, ++ cutlass::Array& ++ addition_storage, ++ int8_t lane_id, ++ int8_t thread_id, ++ int8_t warp_id, ++ int max_col, ++ bool is_first, ++ typename WarpIteratorC::TensorCoord const& tile_offset, ++ float scaling, ++ bool use_smooth_softmax) { ++ /* Iterates on the accumulator and corresponding position on result matrix ++ ++ (1) Update `mi[r]` to the max value of the row `r` ++ (2) In a second iteration do the following: ++ (a) accum <- exp(accum - mi) ++ (b) m_prime <- exp(m_prime - mi) ++ (c) s_prime <- s_prime * m_prime + sum(accum) ++ ++ All of this is done on registers, before we store all of this ++ on shared memory for the next matmul with Value. 
++ */ ++ using Fragment = typename WarpIteratorC::Fragment; ++ using LambdaIterator = typename DefaultMmaAccumLambdaIterator< ++ WarpIteratorC, ++ accum_t, ++ kWarpSize>::Iterator; ++ // Convert to `accum_t` (rather than double) ++ constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E ++ ++ static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, ""); ++ static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock; ++ ++ frag = cutlass::multiplies()(scaling * kLog2e, frag); ++ ++ auto lane_offset = ++ LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset); ++ ++ // First update `mi` to the max per-row ++ { ++ accum_t max; ++ LambdaIterator::iterateRows( ++ lane_offset, ++ [&](int accum_m) { ++ max = -cutlass::platform::numeric_limits::infinity(); ++ }, ++ [&](int accum_m, int accum_n, int idx) { ++ if (accum_n < max_col) { ++ max = cutlass::fast_max(max, frag[idx]); ++ } ++ }, ++ [&](int accum_m) { ++ // Having 4x atomicMax seems faster than reduce within warp ++ // first... ++ atomicMaxFloat(&mi[accum_m], max); ++ }); ++ } ++ ++ // Make sure we all share the update values for `mi` ++ __syncthreads(); ++ ++ // Doing this `exp` is quite expensive. Let's ++ // split it across the warps ++ bool restore_mi_to_minus_inf = false; ++ if (lane_id < kLinesPerWarp) { ++ int id = warp_id * kLinesPerWarp + lane_id; ++ auto m_prime_id = m_prime[id]; ++ auto mi_id = mi[id]; ++ bool changed = m_prime_id < mi_id; // `false` if both are -inf ++ if (changed) { ++ auto m_prime_exp = exp2f(m_prime_id - mi_id); ++ out_rescale[id] = m_prime_exp; ++ s_prime[id] *= m_prime_exp; ++ } else { ++ // Only when bias is enabled, it's possible that all the first values ++ // of attention are masked to `-inf`. In that case we want to avoid ++ // `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0 ++ if (kSupportsBias && ++ mi_id == -cutlass::platform::numeric_limits::infinity()) { ++ restore_mi_to_minus_inf = true; ++ mi[id] = 0.0f; ++ } ++ out_rescale[id] = 1.0f; ++ } ++ } ++ __syncthreads(); // Update output fragments ++ if (kKeepOutputInRF && !is_first) { ++ accum_t line_rescale; ++ LambdaIterator::iterateRows( ++ lane_offset, ++ [&](int accum_m) { line_rescale = out_rescale[accum_m]; }, ++ [&](int accum_m, int accum_n, int idx) { ++ frag_o[idx] = frag_o[idx] * line_rescale; ++ }, ++ [&](int accum_m) {}); ++ } ++ // Update accum_m, accum_n, ... ++ { ++ accum_t mi_row, total_row; ++ LambdaIterator::iterateRows( ++ lane_offset, ++ [&](int accum_m) { mi_row = mi[accum_m]; }, ++ [&](int accum_m, int accum_n, int idx) { ++ frag[idx] = ++ (accum_n < max_col) ? 
exp2f(frag[idx] - mi_row) : accum_t(0.0); ++ }, ++ [&](int accum_m) {}); ++ LambdaIterator::iterateRows( ++ lane_offset, ++ [&](int accum_m) { total_row = 0.0; }, ++ [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, ++ [&](int accum_m) { ++ if (LambdaIterator::reduceSameRow( ++ lane_id, total_row, [](accum_t a, accum_t b) { ++ return a + b; ++ })) { ++ // NOTE: we could atomically add `total_row` to `s_prime`, but ++ // it's faster (and deterministic) to avoid atomics here ++ addition_storage ++ [accum_m + kQueriesPerBlock * tile_offset.column()] = ++ total_row; ++ } ++ }); ++ } ++ __syncthreads(); ++ if (lane_id < kLinesPerWarp) { ++ int id = warp_id * kLinesPerWarp + lane_id; ++ accum_t total_row = s_prime[id]; ++ if (restore_mi_to_minus_inf) { ++ // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true` ++ mi[id] = -cutlass::platform::numeric_limits::infinity(); ++ } else { ++ m_prime[id] = mi[id]; ++ } ++ CUTLASS_PRAGMA_UNROLL ++ for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) { ++ total_row += addition_storage[id + kQueriesPerBlock * i]; ++ } ++ s_prime[id] = (use_smooth_softmax && (max_col <= kKeysPerBlock)) ? total_row + exp2f(-mi[id]) : total_row; ++ } ++ } ++ ++ static CUTLASS_DEVICE int8_t lane_id() { ++ return threadIdx.x; ++ } ++ static CUTLASS_DEVICE int8_t warp_id() { ++ return threadIdx.y; ++ } ++ static CUTLASS_DEVICE int16_t thread_id() { ++ return threadIdx.x + threadIdx.y * blockDim.x; ++ } ++}; ++ ++template ++__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) ++ attention_kernel_batched_impl(typename AK::Params p) { ++ if (!p.advance_to_block()) { ++ return; ++ } ++ AK::attention_kernel(p); ++} ++ ++template ++__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) ++ attention_kernel_batched(typename AK::Params params);
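
Note on the vendored kernel above: `iterative_softmax` implements the streaming ("online") softmax used by memory-efficient attention. Per attention row it keeps a running maximum (`mi`), a running denominator (`s_prime`), and a rescale factor (`out_rescale`) that corrects the partial `attn @ V` accumulator whenever the row maximum grows with a new key block. The standalone C++ sketch below is not part of the patch and uses hypothetical names (`stream_softmax_row`); it only illustrates that same per-row bookkeeping on plain scalars, with `expf` in place of the kernel's `exp2f`/`log2(e)` trick and without the smooth-softmax or dropout variants.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Softmax-weighted average of v under logits s, computed one key block at a
// time: the same mi / s_prime / rescale updates as the kernel, on scalars.
float stream_softmax_row(const std::vector<float>& s, const std::vector<float>& v,
                         std::size_t block = 4) {
  float mi = -std::numeric_limits<float>::infinity();  // running row max
  float s_prime = 0.0f;                                 // running sum of exp(s[i] - mi)
  float accum_o = 0.0f;                                 // running sum of exp(s[i] - mi) * v[i]
  for (std::size_t start = 0; start < s.size(); start += block) {
    std::size_t end = std::min(start + block, s.size());
    float m_new = mi;                                    // (1) update row max with this block
    for (std::size_t i = start; i < end; ++i) m_new = std::max(m_new, s[i]);
    float rescale = std::exp(mi - m_new);                // (2) rescale previous partial results
    s_prime *= rescale;
    accum_o *= rescale;
    mi = m_new;
    for (std::size_t i = start; i < end; ++i) {          // (3) accumulate this block
      float p = std::exp(s[i] - mi);
      s_prime += p;
      accum_o += p * v[i];
    }
  }
  return accum_o / s_prime;  // equals dot(softmax(s), v), in a single pass over s
}

int main() {
  std::vector<float> s = {0.1f, 2.0f, -1.0f, 3.5f, 0.7f, 1.2f};
  std::vector<float> v = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  std::printf("streamed softmax . v = %f\n", stream_softmax_row(s, v));
  return 0;
}

The rescale step is why the kernel can keep the partial output either in registers (`kKeepOutputInRF`) or in `output_accum_ptr` and still normalize only once at the end, in the epilogue's `MemoryEfficientAttentionNormalize`.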