From 1af24b7617876d1d789d9deaddeb4010edea5477 Mon Sep 17 00:00:00 2001
From: Guenther Schmuelling
Date: Mon, 21 Oct 2024 08:53:39 -0700
Subject: [PATCH] initial webgpu support (#992)

Minimal changes to run genai on top of webgpu. There is no onnxruntime
webgpu build yet, so I can't turn on any tests that use webgpu or do a
release build. Will do that in a 2nd PR once webgpu has landed in
onnxruntime.
---
 CMakeLists.txt                 |  2 ++
 build.py                       |  3 +++
 cmake/check_webgpu.cmake       |  6 ++++++
 cmake/options.cmake            |  1 +
 src/generators.cpp             |  2 ++
 src/generators.h               |  1 +
 src/models/input_ids.cpp       |  3 ++-
 src/models/kv_cache.cpp        | 20 +++++++++---------
 src/models/logits.cpp          |  4 ++--
 src/models/model.cpp           | 26 +++++++++++++++++------
 src/models/model.h             | 10 +++++++--
 src/models/position_inputs.cpp | 38 ++++++++++++++++++----------------
 src/python/python.cpp          |  1 +
 13 files changed, 78 insertions(+), 39 deletions(-)
 create mode 100644 cmake/check_webgpu.cmake

diff --git a/CMakeLists.txt b/CMakeLists.txt
index db64f6958..90519fdec 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -36,6 +36,8 @@
 include(cmake/check_cuda.cmake)
 include(cmake/check_rocm.cmake)
 # Checking if DML is supported
 include(cmake/check_dml.cmake)
+# Checking if WebGpu is supported
+include(cmake/check_webgpu.cmake)

 include(cmake/cxx_standard.cmake)
diff --git a/build.py b/build.py
index b2087a46f..d8fb42867 100644
--- a/build.py
+++ b/build.py
@@ -126,6 +126,8 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescript
     parser.add_argument("--use_rocm", action="store_true", help="Whether to use ROCm. Default is to not use rocm.")

+    parser.add_argument("--use_webgpu", action="store_true", help="Whether to use WebGpu. Default is to not use WebGpu.")
+
     parser.add_argument("--use_dml", action="store_true", help="Whether to use DML. Default is to not use DML.")

     # The following options are mutually exclusive (cross compiling options such as android, ios, etc.)
@@ -471,6 +473,7 @@ def update(args: argparse.Namespace, env: dict[str, str]):
         "-DCMAKE_POSITION_INDEPENDENT_CODE=ON",
         f"-DUSE_CUDA={'ON' if args.use_cuda else 'OFF'}",
         f"-DUSE_ROCM={'ON' if args.use_rocm else 'OFF'}",
+        f"-DUSE_WEBGPU={'ON' if args.use_webgpu else 'OFF'}",
         f"-DUSE_DML={'ON' if args.use_dml else 'OFF'}",
         f"-DENABLE_JAVA={'ON' if args.build_java else 'OFF'}",
         f"-DBUILD_WHEEL={build_wheel}",
diff --git a/cmake/check_webgpu.cmake b/cmake/check_webgpu.cmake
new file mode 100644
index 000000000..6ef488052
--- /dev/null
+++ b/cmake/check_webgpu.cmake
@@ -0,0 +1,6 @@
+
+if(USE_WEBGPU)
+  add_compile_definitions(USE_WEBGPU=1)
+else()
+  add_compile_definitions(USE_WEBGPU=0)
+endif()
\ No newline at end of file
diff --git a/cmake/options.cmake b/cmake/options.cmake
index b5dc7c97c..fcb8454bb 100644
--- a/cmake/options.cmake
+++ b/cmake/options.cmake
@@ -4,6 +4,7 @@ include(CMakeDependentOption)
 option(USE_CUDA "Build with CUDA support" ON)
 option(USE_ROCM "Build with ROCm support" ON)
 option(USE_DML "Build with DML support" OFF)
+option(USE_WEBGPU "Build with WEBGPU support" ON)

 # bindings
 option(ENABLE_JAVA "Build the Java API." OFF)
diff --git a/src/generators.cpp b/src/generators.cpp
index 3ab83f988..1d5e74ff4 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -193,6 +193,8 @@ std::string to_string(DeviceType device_type) {
       return "CUDA";
     case DeviceType::DML:
       return "DirectML";
+    case DeviceType::WEBGPU:
+      return "WebGpu";
   }
   throw std::runtime_error("Unknown device type");
 }
diff --git a/src/generators.h b/src/generators.h
index 2dde56c32..4350d5f2d 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -52,6 +52,7 @@ enum struct DeviceType {
   CPU,
   CUDA,
   DML,
+  WEBGPU,
 };

 std::string to_string(DeviceType device_type);
diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp
index 9daa0a628..acf4bd5fe 100644
--- a/src/models/input_ids.cpp
+++ b/src/models/input_ids.cpp
@@ -145,7 +145,8 @@ void InputIDs::Update(RoamingArray next_tokens_unk) {
           input_ids_cast_command_list_state_);
 #endif
     } break;
-    case DeviceType::CPU: {
+    default: {
+      // CPU, WEBGPU
       auto* data = value_->GetTensorMutableData();
       auto next_tokens = next_tokens_unk.GetCPU();
       for (int i = 0; i < shape_[0]; i++) {
diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp
index 35f47ef26..381b2eda7 100644
--- a/src/models/kv_cache.cpp
+++ b/src/models/kv_cache.cpp
@@ -34,11 +34,11 @@ KV_Cache_Combined::KV_Cache_Combined(State& state)
   // Derive the KV data type from the KV input 0
   type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]);

-  empty_past_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+  empty_past_ = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_);
   shape_[3] = state_.params_->sequence_length;

   for (int i = 0; i < layer_count_; ++i) {
-    presents_.push_back(OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_));
+    presents_.push_back(OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_));
   }
 }
@@ -67,7 +67,7 @@ void KV_Cache_Combined::Update(std::span beam_indices, int curren
   shape_[3] = current_length;

   for (int i = 0; i < layer_count_; i++) {
-    presents_[i] = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+    presents_[i] = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_);
     state_.inputs_[input_index_ + i] = pasts_[i].get();
     state_.outputs_[output_index_ + i] = presents_[i].get();
   }
 }
@@ -81,7 +81,7 @@ void KV_Cache_Combined::PickPastState(std::span beam_indices, int
   auto element_count = shape_[0] * past_key_size;

   const OrtValue& present = *presents_[index];
-  std::unique_ptr past = OrtValue::CreateTensor(*model_.allocator_device_, shape_);
+  std::unique_ptr past = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_);
   auto past_span = std::span(past->GetTensorMutableData(), element_count);
   auto present_span = std::span(present.GetTensorData(), element_count);
@@ -149,7 +149,7 @@ KV_Cache::KV_Cache(State& state)
   // Derive the KV data type from the KV input 0
   type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]);

-  empty_past_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+  empty_past_ = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_);

   // Set the size after empty_past_ has been created with 0 for this field
   if (past_present_share_buffer_)
@@ -167,7 +167,7 @@ KV_Cache::KV_Cache(State& state)

   for (int i = 0; i < layer_count_ * 2; ++i) {
     presents_.push_back(
-        sb_kv_caches_.empty() ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_)
+        sb_kv_caches_.empty() ? OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_)
                              : sb_kv_caches_[i]->CreateTensorOnStaticBuffer(shape_, type_));
   }
 }
@@ -216,7 +216,7 @@ void KV_Cache::Update(std::span beam_indices, int current_length)
   shape_[2] = current_length;

   for (int i = 0; i < layer_count_ * 2; i++) {
-    presents_[i] = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
+    presents_[i] = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_);
     state_.outputs_[output_index_ + i] = presents_[i].get();
   }
 }
@@ -228,7 +228,7 @@ void KV_Cache::PickPastState(std::span beam_indices, int index) {
   auto element_count = shape_[0] * block_size_per_beam;

   const OrtValue& present_value = *presents_[index];
-  std::unique_ptr past_value = OrtValue::CreateTensor(*model_.allocator_device_, shape_);
+  std::unique_ptr past_value = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_);
   auto past_span = std::span(past_value->GetTensorMutableData(), element_count);
   auto present_span = std::span(present_value.GetTensorData(), element_count);
@@ -280,8 +280,8 @@ Cross_Cache::Cross_Cache(State& state)
   type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]);

   for (int i = 0; i < layer_count_; ++i) {
-    values_.push_back(OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_));
-    values_.push_back(OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_));
+    values_.push_back(OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_));
+    values_.push_back(OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_));
   }
 }
diff --git a/src/models/logits.cpp b/src/models/logits.cpp
index 0e333e150..b7800c89e 100644
--- a/src/models/logits.cpp
+++ b/src/models/logits.cpp
@@ -98,8 +98,8 @@ RoamingArray Logits::Get() {
 #endif
     } break;

-    case DeviceType::CPU:
-    case DeviceType::CUDA: {
+    default: {
+      // CPU, CUDA, WEBGPU
       auto logits_raw = std::span{output_raw_->GetTensorMutableData(), element_count * element_size};
       auto logits_last_tokens = std::span{logits_of_last_token->GetTensorMutableData(), element_count_last_token * element_size};
       auto target = logits_last_tokens.subspan(vocab_index * element_size, vocab_size * element_size);
diff --git a/src/models/model.cpp b/src/models/model.cpp
index 5459879ff..93338b7fa 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -282,14 +282,23 @@ void Model::InitDeviceAllocator([[maybe_unused]] OrtSession& session) {
   if (device_type_ == DeviceType::CUDA) {
     allocator_device_ = GetCudaAllocator(session);
   }
-#elif USE_DML
+#endif
+#if USE_DML
   if (device_type_ == DeviceType::DML) {
     memory_info_device_ = OrtMemoryInfo::Create("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
     dml_owned_allocator_ = Ort::Allocator::Create(session, *memory_info_device_);
     allocator_device_ = dml_owned_allocator_.get();
   }
 #endif
-
+  allocator_kvcache_ = allocator_device_;
+#if USE_WEBGPU
+  if (device_type_ == DeviceType::WEBGPU) {
+    // for webgpu we only use device memory for kv_cache
+    memory_info_device_ = OrtMemoryInfo::Create("WebGPU_Buffer", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
+    webgpu_owned_allocator_ = Ort::Allocator::Create(session, *memory_info_device_);
+    allocator_kvcache_ = webgpu_owned_allocator_.get();
+  }
+#endif
   session_info_ = std::make_unique(session);
   captured_graph_pool_ = std::make_shared(config_.get(), session_info_.get(), allocator_device_);
 }
@@ -461,8 +470,11 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
       for (auto& option : provider_options.options) {
         opts.emplace(option.first, option.second);
       }
-
       session_options.AppendExecutionProvider("QNN", opts);
+    } else if (provider_options.name == "web") {
+      device_type_ = DeviceType::WEBGPU;
+      std::unordered_map opts;
+      session_options.AppendExecutionProvider("WebGPU", opts);
     } else
       throw std::runtime_error("Unknown provider type: " + provider_options.name);
   }
@@ -544,8 +556,9 @@ void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr
   auto* fp32 = p_out->GetTensorMutableData();

   switch (device_type) {
+    case DeviceType::WEBGPU:
     case DeviceType::DML:
-      // DML doesn't currently support on-device scoring, so we fall back to the CPU
+      // DML and WebGpu don't currently support on-device scoring, so we fall back to the CPU
     case DeviceType::CPU:
       for (int i = 0; i < count; i++)
         fp32[i] = FastFloat16ToFloat32(fp16[i]);
@@ -605,7 +618,7 @@ std::unique_ptr Model::ExpandInputs(std::unique_ptr& input,
   // If we're on CUDA, we still want to do the copy to move the data over to CUDA memory where we will read from it later.
   // DML doesn't currently support on-device scoring, so we go the same route as the CPU
-  if (num_beams == 1 && (device_type_ == DeviceType::CPU || device_type_ == DeviceType::DML)) {
+  if (num_beams == 1 && (device_type_ == DeviceType::CPU || device_type_ == DeviceType::DML || device_type_ == DeviceType::WEBGPU)) {
     return std::move(input);
   }
@@ -625,8 +638,9 @@ std::unique_ptr Model::ExpandInputs(std::unique_ptr& input,
   auto* target = expanded_data;

   switch (device_type_) {
+    case DeviceType::WEBGPU:
     case DeviceType::DML:
-      // DML doesn't currently support on-device scoring, so we use the CPU for non-cache inputs/outputs
+      // DML and WebGpu don't currently support on-device scoring, so we use the CPU for non-cache inputs/outputs
     case DeviceType::CPU:
       for (int i = 0; i < batch_size; i++) {
         for (int j = 0; j < num_beams; j++) {
diff --git a/src/models/model.h b/src/models/model.h
index 1a6b1b0f0..9eed5df75 100644
--- a/src/models/model.h
+++ b/src/models/model.h
@@ -143,7 +143,8 @@ struct Model : std::enable_shared_from_this, LeakChecked {
   cuda_stream_holder cuda_stream_;
   DeviceType device_type_{DeviceType::CPU};
   Ort::Allocator& allocator_cpu_{Ort::Allocator::GetWithDefaultOptions()};
-  Ort::Allocator* allocator_device_{};  // Can be CUDA or CPU based on the DeviceType in the model
+  Ort::Allocator* allocator_device_{};   // Can be CUDA or CPU based on the DeviceType in the model
+  Ort::Allocator* allocator_kvcache_{};  // keep the kv_cache allocator separate so that only the kv_cache has to live on the device

   std::unique_ptr session_info_;
@@ -175,9 +176,14 @@ struct Model : std::enable_shared_from_this, LeakChecked {
   std::unique_ptr dml_readback_heap_;
   ComPtr dml_device_;
   std::unique_ptr dml_owned_allocator_;
+#endif
+#if USE_WEBGPU
+  std::unique_ptr webgpu_owned_allocator_;
+  std::unique_ptr webgpu_io_binding_;
+#endif
+#if USE_DML || USE_WEBGPU
   std::unique_ptr memory_info_device_;
 #endif
-
   std::shared_ptr captured_graph_pool_;
   std::map> pipeline_session_options_;
 };
diff --git a/src/models/position_inputs.cpp b/src/models/position_inputs.cpp
index 2666afc17..a14e36225 100644
--- a/src/models/position_inputs.cpp
+++ b/src/models/position_inputs.cpp
@@ -172,13 +172,6 @@ void PositionInputs::UpdatePositionIDs(int current_length) {
       model_.GetDmlExecutionContext()->ExecuteCommandList(dml_update_position_ids_kernel_->GetCommandList(), &fence, &completion_value);
     } break;
 #endif
-    case DeviceType::CPU: {
-      if (type_ == Ort::TypeToTensorType)
-        UpdatePositionIDsImpl();
-      else
-        UpdatePositionIDsImpl();
-      break;
-    }
 #if USE_CUDA
     case DeviceType::CUDA:
       if (type_ == Ort::TypeToTensorType)
         cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData(), static_cast(position_ids_shape_[0]), model_.cuda_stream_);
       break;
 #endif
+    case DeviceType::CPU:
+    case DeviceType::WEBGPU: {
+      if (type_ == Ort::TypeToTensorType)
+        UpdatePositionIDsImpl();
+      else
+        UpdatePositionIDsImpl();
+      break;
+    }
     default:
       throw std::runtime_error("PositionIDs::Update - Unsupported device type");
   }
@@ -269,17 +270,6 @@ void PositionInputs::UpdateAttentionMask(int current_length) {
       break;
     }
 #endif
-    case DeviceType::CPU: {
-      if (type_ == Ort::TypeToTensorType)
-        UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData(),
-                                attention_mask_->GetTensorData(),
-                                current_length);
-      else
-        UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData(),
-                                attention_mask_->GetTensorData(),
-                                current_length);
-      break;
-    }
 #if USE_CUDA
     case DeviceType::CUDA: {
       int max_seq_len = sb_attention_mask_ ? state_.params_->search.max_length : current_length;
       break;
     }
 #endif
+    case DeviceType::WEBGPU:
+    case DeviceType::CPU: {
+      if (type_ == Ort::TypeToTensorType)
+        UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData(),
+                                attention_mask_->GetTensorData(),
+                                current_length);
+      else
+        UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData(),
+                                attention_mask_->GetTensorData(),
+                                current_length);
+      break;
+    }
     default:
       throw std::runtime_error("PositionIDs::Update - Unsupported device type");
   }
diff --git a/src/python/python.cpp b/src/python/python.cpp
index 9bf4836f9..06e9b8e3c 100644
--- a/src/python/python.cpp
+++ b/src/python/python.cpp
@@ -565,6 +565,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
   m.def("is_cuda_available", []() { return USE_CUDA != 0; });
   m.def("is_dml_available", []() { return USE_DML != 0; });
   m.def("is_rocm_available", []() { return USE_ROCM != 0; });
+  m.def("is_webgpu_available", []() { return USE_WEBGPU != 0; });
   m.def("set_current_gpu_device_id", [](int device_id) { Ort::SetCurrentGpuDeviceId(device_id); });
   m.def("get_current_gpu_device_id", []() { return Ort::GetCurrentGpuDeviceId(); });
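
A quick way to sanity-check a build with this change is from the Python bindings. The sketch below is illustrative only and is not part of the diff above; it assumes a wheel produced with "python build.py --use_webgpu" (which passes -DUSE_WEBGPU=ON to CMake, as added in build.py) and a model folder whose genai_config.json selects the "web" provider so that CreateSessionOptionsFromConfig maps it to DeviceType::WEBGPU. The model path is a placeholder.

    import onnxruntime_genai as og

    # is_webgpu_available() is the binding added in src/python/python.cpp.
    # It only reflects the compile-time USE_WEBGPU flag; it does not probe
    # for a usable WebGPU adapter at runtime.
    if og.is_webgpu_available():
        # Placeholder path: a model whose genai_config.json lists the "web"
        # provider, which registers the WebGPU execution provider in model.cpp.
        model = og.Model("path/to/webgpu-model")
        print("Model loaded; with the web provider the kv_cache is allocated "
              "through the WebGPU_Buffer device allocator.")
    else:
        print("This build was compiled without WebGPU support.")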