initial webgpu support (#992)
Minimal changes to run GenAI on top of WebGPU.
There is no onnxruntime WebGPU build yet, so I can't turn on any tests
that use WebGPU or do a release build.
Will do that in a second PR once WebGPU has landed in onnxruntime.
guschmue authored Oct 21, 2024
1 parent 8e7f92c commit 1af24b7
Showing 13 changed files with 78 additions and 39 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -36,6 +36,8 @@ include(cmake/check_cuda.cmake)
include(cmake/check_rocm.cmake)
# Checking if DML is supported
include(cmake/check_dml.cmake)
# Checking if WebGpu is supported
include(cmake/check_webgpu.cmake)

include(cmake/cxx_standard.cmake)

3 changes: 3 additions & 0 deletions build.py
@@ -126,6 +126,8 @@ class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescript

parser.add_argument("--use_rocm", action="store_true", help="Whether to use ROCm. Default is to not use rocm.")

parser.add_argument("--use_webgpu", action="store_true", help="Whether to use WebGpu. Default is to not use WebGpu.")

parser.add_argument("--use_dml", action="store_true", help="Whether to use DML. Default is to not use DML.")

# The following options are mutually exclusive (cross compiling options such as android, ios, etc.)
@@ -471,6 +473,7 @@ def update(args: argparse.Namespace, env: dict[str, str]):
"-DCMAKE_POSITION_INDEPENDENT_CODE=ON",
f"-DUSE_CUDA={'ON' if args.use_cuda else 'OFF'}",
f"-DUSE_ROCM={'ON' if args.use_rocm else 'OFF'}",
f"-DUSE_WEBGPU={'ON' if args.use_webgpu else 'OFF'}",
f"-DUSE_DML={'ON' if args.use_dml else 'OFF'}",
f"-DENABLE_JAVA={'ON' if args.build_java else 'OFF'}",
f"-DBUILD_WHEEL={build_wheel}",
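A minimal sketch of enabling the new backend through the build script; only the `--use_webgpu` flag itself comes from this diff, and it simply maps to the `-DUSE_WEBGPU=ON` CMake definition shown above:

```sh
# Build GenAI with the WebGPU backend compiled in (USE_WEBGPU=ON).
# Illustrative only: there is no onnxruntime WebGPU build to run against yet,
# so tests and release builds wait for the follow-up PR.
python build.py --use_webgpu
```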
6 changes: 6 additions & 0 deletions cmake/check_webgpu.cmake
@@ -0,0 +1,6 @@

if(USE_WEBGPU)
add_compile_definitions(USE_WEBGPU=1)
else()
add_compile_definitions(USE_WEBGPU=0)
endif()
1 change: 1 addition & 0 deletions cmake/options.cmake
@@ -4,6 +4,7 @@ include(CMakeDependentOption)
option(USE_CUDA "Build with CUDA support" ON)
option(USE_ROCM "Build with ROCm support" ON)
option(USE_DML "Build with DML support" OFF)
option(USE_WEBGPU "Build with WEBGPU support" ON)

# bindings
option(ENABLE_JAVA "Build the Java API." OFF)
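Equivalently, the option can be toggled when configuring CMake directly rather than through build.py; a sketch with placeholder source and build directories:

```sh
# check_webgpu.cmake turns this option into the USE_WEBGPU compile definition.
cmake -S . -B build -DUSE_WEBGPU=ON
cmake --build build
```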
2 changes: 2 additions & 0 deletions src/generators.cpp
@@ -193,6 +193,8 @@ std::string to_string(DeviceType device_type) {
return "CUDA";
case DeviceType::DML:
return "DirectML";
case DeviceType::WEBGPU:
return "WebGpu";
}
throw std::runtime_error("Unknown device type");
}
1 change: 1 addition & 0 deletions src/generators.h
@@ -52,6 +52,7 @@ enum struct DeviceType {
CPU,
CUDA,
DML,
WEBGPU,
};

std::string to_string(DeviceType device_type);
3 changes: 2 additions & 1 deletion src/models/input_ids.cpp
@@ -145,7 +145,8 @@ void InputIDs::Update(RoamingArray<int32_t> next_tokens_unk) {
input_ids_cast_command_list_state_);
#endif
} break;
case DeviceType::CPU: {
default: {
// CPU, WEBGPU
auto* data = value_->GetTensorMutableData<int64_t>();
auto next_tokens = next_tokens_unk.GetCPU();
for (int i = 0; i < shape_[0]; i++) {
20 changes: 10 additions & 10 deletions src/models/kv_cache.cpp
@@ -34,11 +34,11 @@ KV_Cache_Combined::KV_Cache_Combined(State& state)
// Derive the KV data type from the KV input 0
type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]);

empty_past_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
empty_past_ = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_);
shape_[3] = state_.params_->sequence_length;

for (int i = 0; i < layer_count_; ++i) {
presents_.push_back(OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_));
presents_.push_back(OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_));
}
}

@@ -67,7 +67,7 @@ void KV_Cache_Combined::Update(std::span<const int32_t> beam_indices, int curren

shape_[3] = current_length;
for (int i = 0; i < layer_count_; i++) {
presents_[i] = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
presents_[i] = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_);
state_.inputs_[input_index_ + i] = pasts_[i].get();
state_.outputs_[output_index_ + i] = presents_[i].get();
}
@@ -81,7 +81,7 @@ void KV_Cache_Combined::PickPastState(std::span<const int32_t> beam_indices, int
auto element_count = shape_[0] * past_key_size;

const OrtValue& present = *presents_[index];
std::unique_ptr<OrtValue> past = OrtValue::CreateTensor<ScoreType>(*model_.allocator_device_, shape_);
std::unique_ptr<OrtValue> past = OrtValue::CreateTensor<ScoreType>(*model_.allocator_kvcache_, shape_);
auto past_span = std::span<ScoreType>(past->GetTensorMutableData<ScoreType>(), element_count);
auto present_span = std::span<const ScoreType>(present.GetTensorData<ScoreType>(), element_count);

@@ -149,7 +149,7 @@ KV_Cache::KV_Cache(State& state)
// Derive the KV data type from the KV input 0
type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]);

empty_past_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
empty_past_ = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_);

// Set the size after empty_past_ has been created with 0 for this field
if (past_present_share_buffer_)
@@ -167,7 +167,7 @@ KV_Cache::KV_Cache(State& state)

for (int i = 0; i < layer_count_ * 2; ++i) {
presents_.push_back(
sb_kv_caches_.empty() ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_)
sb_kv_caches_.empty() ? OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_)
: sb_kv_caches_[i]->CreateTensorOnStaticBuffer(shape_, type_));
}
}
@@ -216,7 +216,7 @@ void KV_Cache::Update(std::span<const int32_t> beam_indices, int current_length)

shape_[2] = current_length;
for (int i = 0; i < layer_count_ * 2; i++) {
presents_[i] = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
presents_[i] = OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_);
state_.outputs_[output_index_ + i] = presents_[i].get();
}
}
@@ -228,7 +228,7 @@ void KV_Cache::PickPastState(std::span<const int32_t> beam_indices, int index) {
auto element_count = shape_[0] * block_size_per_beam;

const OrtValue& present_value = *presents_[index];
std::unique_ptr<OrtValue> past_value = OrtValue::CreateTensor<ScoreType>(*model_.allocator_device_, shape_);
std::unique_ptr<OrtValue> past_value = OrtValue::CreateTensor<ScoreType>(*model_.allocator_kvcache_, shape_);
auto past_span = std::span<ScoreType>(past_value->GetTensorMutableData<ScoreType>(), element_count);
auto present_span = std::span<const ScoreType>(present_value.GetTensorData<ScoreType>(), element_count);

@@ -280,8 +280,8 @@ Cross_Cache::Cross_Cache(State& state)
type_ = model_.session_info_->GetInputDataType(input_name_strings_[0]);

for (int i = 0; i < layer_count_; ++i) {
values_.push_back(OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_));
values_.push_back(OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_));
values_.push_back(OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_));
values_.push_back(OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_));
}
}

4 changes: 2 additions & 2 deletions src/models/logits.cpp
@@ -98,8 +98,8 @@ RoamingArray<float> Logits::Get() {
#endif
} break;

case DeviceType::CPU:
case DeviceType::CUDA: {
default: {
// CPU, CUDA, WEBGPU
auto logits_raw = std::span<const uint8_t>{output_raw_->GetTensorMutableData<uint8_t>(), element_count * element_size};
auto logits_last_tokens = std::span<uint8_t>{logits_of_last_token->GetTensorMutableData<uint8_t>(), element_count_last_token * element_size};
auto target = logits_last_tokens.subspan(vocab_index * element_size, vocab_size * element_size);
26 changes: 20 additions & 6 deletions src/models/model.cpp
@@ -282,14 +282,23 @@ void Model::InitDeviceAllocator([[maybe_unused]] OrtSession& session) {
if (device_type_ == DeviceType::CUDA) {
allocator_device_ = GetCudaAllocator(session);
}
#elif USE_DML
#endif
#if USE_DML
if (device_type_ == DeviceType::DML) {
memory_info_device_ = OrtMemoryInfo::Create("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
dml_owned_allocator_ = Ort::Allocator::Create(session, *memory_info_device_);
allocator_device_ = dml_owned_allocator_.get();
}
#endif

allocator_kvcache_ = allocator_device_;
#if USE_WEBGPU
if (device_type_ == DeviceType::WEBGPU) {
// for webgpu we only use device memory for kv_cache
memory_info_device_ = OrtMemoryInfo::Create("WebGPU_Buffer", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault);
webgpu_owned_allocator_ = Ort::Allocator::Create(session, *memory_info_device_);
allocator_kvcache_ = webgpu_owned_allocator_.get();
}
#endif
session_info_ = std::make_unique<SessionInfo>(session);
captured_graph_pool_ = std::make_shared<CapturedGraphPool>(config_.get(), session_info_.get(), allocator_device_);
}
Expand Down Expand Up @@ -461,8 +470,11 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
for (auto& option : provider_options.options) {
opts.emplace(option.first, option.second);
}

session_options.AppendExecutionProvider("QNN", opts);
} else if (provider_options.name == "web") {
device_type_ = DeviceType::WEBGPU;
std::unordered_map<std::string, std::string> opts;
session_options.AppendExecutionProvider("WebGPU", opts);
} else
throw std::runtime_error("Unknown provider type: " + provider_options.name);
}
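Selecting the new provider at runtime would presumably follow the existing provider_options pattern in genai_config.json. A sketch: only the "web" provider name comes from this diff, and the enclosing keys are assumed to mirror how the cuda/dml providers are configured.

```json
{
  "model": {
    "decoder": {
      "session_options": {
        "provider_options": [
          { "web": {} }
        ]
      }
    }
  }
}
```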
@@ -544,8 +556,9 @@ void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr<Or
auto* fp32 = p_out->GetTensorMutableData<float>();

switch (device_type) {
case DeviceType::WEBGPU:
case DeviceType::DML:
// DML doesn't currently support on-device scoring, so we fall back to the CPU
// DML and WebGPU don't currently support on-device scoring, so we fall back to the CPU
case DeviceType::CPU:
for (int i = 0; i < count; i++)
fp32[i] = FastFloat16ToFloat32(fp16[i]);
@@ -605,7 +618,7 @@ std::unique_ptr<OrtValue> Model::ExpandInputs(std::unique_ptr<OrtValue>& input,

// If we're on CUDA, we still want to do the copy to move the data over to CUDA memory where we will read from it later.
// DML doesn't currently support on-device scoring, so we go the same route as the CPU
if (num_beams == 1 && (device_type_ == DeviceType::CPU || device_type_ == DeviceType::DML)) {
if (num_beams == 1 && (device_type_ == DeviceType::CPU || device_type_ == DeviceType::DML || device_type_ == DeviceType::WEBGPU)) {
return std::move(input);
}

@@ -625,8 +638,9 @@ std::unique_ptr<OrtValue> Model::ExpandInputs(std::unique_ptr<OrtValue>& input,
auto* target = expanded_data;

switch (device_type_) {
case DeviceType::WEBGPU:
case DeviceType::DML:
// DML doesn't currently support on-device scoring, so we use the CPU for non-cache inputs/outputs
// DML and WebGPU don't currently support on-device scoring, so we use the CPU for non-cache inputs/outputs
case DeviceType::CPU:
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < num_beams; j++) {
10 changes: 8 additions & 2 deletions src/models/model.h
@@ -143,7 +143,8 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model> {
cuda_stream_holder cuda_stream_;
DeviceType device_type_{DeviceType::CPU};
Ort::Allocator& allocator_cpu_{Ort::Allocator::GetWithDefaultOptions()};
Ort::Allocator* allocator_device_{}; // Can be CUDA or CPU based on the DeviceType in the model
Ort::Allocator* allocator_device_{}; // Can be CUDA or CPU based on the DeviceType in the model
Ort::Allocator* allocator_kvcache_{};  // Keep the kv_cache allocator separate so that only the kv_cache has to live in device memory

std::unique_ptr<SessionInfo> session_info_;

Expand Down Expand Up @@ -175,9 +176,14 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model> {
std::unique_ptr<DmlReadbackHeap> dml_readback_heap_;
ComPtr<IDMLDevice> dml_device_;
std::unique_ptr<Ort::Allocator> dml_owned_allocator_;
#endif
#if USE_WEBGPU
std::unique_ptr<Ort::Allocator> webgpu_owned_allocator_;
std::unique_ptr<OrtIoBinding> webgpu_io_binding_;
#endif
#if USE_DML || USE_WEBGPU
std::unique_ptr<OrtMemoryInfo> memory_info_device_;
#endif

std::shared_ptr<CapturedGraphPool> captured_graph_pool_;
std::map<std::string, std::unique_ptr<OrtSessionOptions>> pipeline_session_options_;
};
38 changes: 20 additions & 18 deletions src/models/position_inputs.cpp
@@ -172,13 +172,6 @@ void PositionInputs::UpdatePositionIDs(int current_length) {
model_.GetDmlExecutionContext()->ExecuteCommandList(dml_update_position_ids_kernel_->GetCommandList(), &fence, &completion_value);
} break;
#endif
case DeviceType::CPU: {
if (type_ == Ort::TypeToTensorType<int32_t>)
UpdatePositionIDsImpl<int32_t>();
else
UpdatePositionIDsImpl<int64_t>();
break;
}
#if USE_CUDA
case DeviceType::CUDA:
if (type_ == Ort::TypeToTensorType<int32_t>)
Expand All @@ -187,6 +180,14 @@ void PositionInputs::UpdatePositionIDs(int current_length) {
cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData<int64_t>(), static_cast<int>(position_ids_shape_[0]), model_.cuda_stream_);
break;
#endif
case DeviceType::CPU:
case DeviceType::WEBGPU: {
if (type_ == Ort::TypeToTensorType<int32_t>)
UpdatePositionIDsImpl<int32_t>();
else
UpdatePositionIDsImpl<int64_t>();
break;
}
default:
throw std::runtime_error("PositionIDs::Update - Unsupported device type");
}
@@ -269,17 +270,6 @@ void PositionInputs::UpdateAttentionMask(int current_length) {
break;
}
#endif
case DeviceType::CPU: {
if (type_ == Ort::TypeToTensorType<int32_t>)
UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData<int32_t>(),
attention_mask_->GetTensorData<int32_t>(),
current_length);
else
UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData<int64_t>(),
attention_mask_->GetTensorData<int64_t>(),
current_length);
break;
}
#if USE_CUDA
case DeviceType::CUDA: {
int max_seq_len = sb_attention_mask_ ? state_.params_->search.max_length : current_length;
Expand All @@ -304,6 +294,18 @@ void PositionInputs::UpdateAttentionMask(int current_length) {
break;
}
#endif
case DeviceType::WEBGPU:
case DeviceType::CPU: {
if (type_ == Ort::TypeToTensorType<int32_t>)
UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData<int32_t>(),
attention_mask_->GetTensorData<int32_t>(),
current_length);
else
UpdateAttentionMaskImpl(attention_mask_next_->GetTensorMutableData<int64_t>(),
attention_mask_->GetTensorData<int64_t>(),
current_length);
break;
}
default:
throw std::runtime_error("PositionIDs::Update - Unsupported device type");
}
1 change: 1 addition & 0 deletions src/python/python.cpp
@@ -565,6 +565,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
m.def("is_cuda_available", []() { return USE_CUDA != 0; });
m.def("is_dml_available", []() { return USE_DML != 0; });
m.def("is_rocm_available", []() { return USE_ROCM != 0; });
m.def("is_webgpu_available", []() { return USE_WEBGPU != 0; });

m.def("set_current_gpu_device_id", [](int device_id) { Ort::SetCurrentGpuDeviceId(device_id); });
m.def("get_current_gpu_device_id", []() { return Ort::GetCurrentGpuDeviceId(); });
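Once a WebGPU-enabled onnxruntime build exists and this can go into a release, the new binding allows a quick capability check from Python (a sketch; only is_webgpu_available is added by this diff):

```python
import onnxruntime_genai as og

# True only when the package was built with --use_webgpu (USE_WEBGPU=1).
print(og.is_webgpu_available())
```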
