From cbb5aebd85f514ccd6445afb14a1281c6f98ad67 Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Fri, 18 Oct 2024 14:31:24 -0700
Subject: [PATCH 1/6] C-API

---
 src/ort_genai_c.h | 32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h
index e4e72fe6f..331ed26a8 100644
--- a/src/ort_genai_c.h
+++ b/src/ort_genai_c.h
@@ -50,6 +50,7 @@ typedef enum OgaElementType {
 typedef struct OgaResult OgaResult;
 typedef struct OgaGeneratorParams OgaGeneratorParams;
 typedef struct OgaGenerator OgaGenerator;
+typedef struct OgaRuntimeSettings OgaRuntimeSettings;
 typedef struct OgaModel OgaModel;
 // OgaSequences is an array of token arrays where the number of token arrays can be obtained using
 // OgaSequencesCount and the number of tokens in each token array can be obtained using OgaSequencesGetSequenceCount.
@@ -149,6 +150,27 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaLoadAudios(const OgaStringArray* audio_pat
 OGA_EXPORT void OGA_API_CALL OgaDestroyAudios(OgaAudios* audios);
 
+/*
+ * \brief Creates a runtime settings instance to be used to create a model.
+ * \param[out] out The created runtime settings.
+ * \return OgaResult containing the error message if the creation of the runtime settings failed.
+ */
+OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateRuntimeSettings(OgaRuntimeSettings** out);
+/*
+ * \brief Destroys the given runtime settings.
+ * \param[in] settings The runtime settings to be destroyed.
+ */
+OGA_EXPORT void OGA_API_CALL OgaDestroyRuntimeSettings(OgaRuntimeSettings* settings);
+
+/*
+ * \brief Sets a specific runtime handle for the runtime settings.
+ * \param[in] settings The runtime settings on which to set the handle.
+ * \param[in] handle_name The name of the handle to set for the runtime settings.
+ * \param[in] handle The value of the handle to set for the runtime settings.
+ * \return OgaResult containing the error message if setting the handle failed.
+ */
+OGA_EXPORT OgaResult* OGA_API_CALL OgaRuntimeSettingsSetHandle(OgaRuntimeSettings* settings, const char* handle_name, void* handle);
+
 /*
  * \brief Creates a model from the given configuration directory and device type.
  * \param[in] config_path The path to the model configuration directory. The path is expected to be encoded in UTF-8.
@@ -158,6 +180,16 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyAudios(OgaAudios* audios);
  */
 OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaModel** out);
 
+/*
+ * \brief Creates a model from the given configuration directory and runtime settings.
+ * \param[in] config_path The path to the model configuration directory. The path is expected to be encoded in UTF-8.
+ * \param[in] settings The runtime settings to use for the model. Passing NULL is equivalent to
+ *                     calling OgaCreateModel without runtime settings.
+ * \param[out] out The created model.
+ * \return OgaResult containing the error message if the model creation failed.
+ */
+OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateModelWithRuntimeSettings(const char* config_path, OgaRuntimeSettings* settings, OgaModel** out);
+
 /*
  * \brief Destroys the given model.
  * \param[in] model The model to be destroyed.
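Usage sketch (not part of the patches): the snippet below shows how the three new functions are intended to chain together from C or C++. The "dawnProcTable" handle name is the one wired up by the later commits in this series; error handling is reduced to null checks, and a real caller would also inspect and release the returned OgaResult on failure.

#include "ort_genai_c.h"

// `dawn_proc_table` is assumed to be supplied by the embedding application,
// e.g. a pointer to a Dawn proc table in a WebGPU build.
OgaModel* CreateModelWithDawnProcTable(const char* config_path, void* dawn_proc_table) {
  OgaRuntimeSettings* settings = NULL;
  OgaModel* model = NULL;

  if (OgaCreateRuntimeSettings(&settings) != NULL)  // a non-NULL OgaResult* signals an error
    return NULL;

  if (OgaRuntimeSettingsSetHandle(settings, "dawnProcTable", dawn_proc_table) == NULL) {
    OgaCreateModelWithRuntimeSettings(config_path, settings, &model);  // model stays NULL on failure
  }

  OgaDestroyRuntimeSettings(settings);  // the settings only need to outlive model creation
  return model;
}

The destroy call at the end relies on the settings being consumed during model creation, which is how the implementation added in the following commits behaves.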
From 6cdd29fc602ec9285f16b143bb5c3e5bedb89867 Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Fri, 18 Oct 2024 17:04:34 -0700
Subject: [PATCH 2/6] impl

---
 src/config.cpp           | 24 +++++++++++++++++++++++-
 src/config.h             |  4 +++-
 src/generators.h         |  3 ++-
 src/models/model.cpp     |  4 ++--
 src/ort_genai.h          | 22 ++++++++++++++++++++++
 src/ort_genai_c.cpp      | 28 ++++++++++++++++++++++++++--
 src/ort_genai_c.h        |  2 +-
 src/runtime_settings.cpp |  9 +++++++++
 src/runtime_settings.h   | 18 ++++++++++++++++++
 9 files changed, 106 insertions(+), 8 deletions(-)
 create mode 100644 src/runtime_settings.cpp
 create mode 100644 src/runtime_settings.h

diff --git a/src/config.cpp b/src/config.cpp
index 1e07e7f9f..f054e7282 100644
--- a/src/config.cpp
+++ b/src/config.cpp
@@ -1,6 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 #include "generators.h"
+#include "runtime_settings.h"
 #include "json.h"
 #include <fstream>
 #include <sstream>
@@ -659,7 +660,7 @@ void ParseConfig(const fs::path& filename, Config& config) {
   }
 }
 
-Config::Config(const fs::path& path) : config_path{path} {
+Config::Config(const fs::path& path, const RuntimeSettings* settings) : config_path{path} {
   ParseConfig(path / "genai_config.json", *this);
 
   if (model.context_length == 0)
@@ -667,6 +668,27 @@ Config::Config(const fs::path& path) : config_path{path} {
 
   if (search.max_length == 0)
     search.max_length = model.context_length;
+
+  if (settings) {
+    // Enable the following code after #992 is merged
+    // https://github.com/microsoft/onnxruntime-genai/pull/992
+
+    // #if USE_WEBGPU
+    auto& provider_options = model.decoder.session_options.provider_options;
+    auto maybe_webgpu_options = std::find_if(provider_options.begin(),
+                                             provider_options.end(),
+                                             [](Generators::Config::ProviderOptions& po) {
+                                               return po.name == "webgpu";
+                                             });
+    if (maybe_webgpu_options != provider_options.end()) {
+      auto it = settings->handles_.find("dawnProcTable");
+      if (it != settings->handles_.end()) {
+        void* dawn_proc_table_handle = it->second;
+        maybe_webgpu_options->options.emplace_back("dawnProcTable", std::to_string((size_t)(dawn_proc_table_handle)));
+      }
+    }
+    // #endif
+  }
 }
 
 void Config::AddMapping(const std::string& nominal_name, const std::string& graph_name) {
diff --git a/src/config.h b/src/config.h
index 4a14cdd6e..16b590c10 100644
--- a/src/config.h
+++ b/src/config.h
@@ -4,9 +4,11 @@
 
 namespace Generators {
 
+struct RuntimeSettings;
+
 struct Config {
   Config() = default;
-  Config(const fs::path& path);
+  Config(const fs::path& path, const RuntimeSettings* settings);
 
   struct Defaults {
     static constexpr std::string_view InputIdsName = "input_ids";
diff --git a/src/generators.h b/src/generators.h
index 2dde56c32..db564319a 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -37,6 +37,7 @@ using cudaStream_t = void*;
 #include "models/debugging.h"
 #include "config.h"
 #include "logging.h"
+#include "runtime_settings.h"
 #include "tensor.h"
 
 namespace Generators {
@@ -134,7 +135,7 @@ std::unique_ptr<OrtGlobals>& GetOrtGlobals();
 void Shutdown();  // Do this once at exit, Ort code will fail after this call
 OrtEnv& GetOrtEnv();
 
-std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path);
+std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path, const RuntimeSettings* settings = nullptr);
 std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Model& model);
 std::shared_ptr<GeneratorParams> CreateGeneratorParams(const Config& config);  // For benchmarking purposes only
 std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params);
diff --git a/src/models/model.cpp b/src/models/model.cpp
index 25ad3f36f..8816e4a03 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -496,8 +496,8 @@ std::shared_ptr<MultiModalProcessor> Model::CreateMultiModalProcessor() const {
   return std::make_shared<MultiModalProcessor>(*config_, *session_info_);
 }
 
-std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path) {
-  auto config = std::make_unique<Config>(fs::path(config_path));
+std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path, const RuntimeSettings* settings /*= nullptr*/) {
+  auto config = std::make_unique<Config>(fs::path(config_path), settings);
 
   if (config->model.type == "gpt2")
     return std::make_shared<Gpt_Model>(std::move(config), ort_env);
diff --git a/src/ort_genai.h b/src/ort_genai.h
index beffd0730..424e7aad7 100644
--- a/src/ort_genai.h
+++ b/src/ort_genai.h
@@ -59,12 +59,34 @@ inline void OgaCheckResult(OgaResult* result) {
   }
 }
 
+struct OgaRuntimeSettings : OgaAbstract {
+  static std::unique_ptr<OgaRuntimeSettings> Create() {
+    OgaRuntimeSettings* p;
+    OgaCheckResult(OgaCreateRuntimeSettings(&p));
+    return std::unique_ptr<OgaRuntimeSettings>(p);
+  }
+
+  void SetHandle(const char* name, void* handle) {
+    OgaCheckResult(OgaRuntimeSettingsSetHandle(this, name, handle));
+  }
+  void SetHandle(const std::string& name, void* handle) {
+    SetHandle(name.c_str(), handle);
+  }
+
+  static void operator delete(void* p) { OgaDestroyRuntimeSettings(reinterpret_cast<OgaRuntimeSettings*>(p)); }
+};
+
 struct OgaModel : OgaAbstract {
   static std::unique_ptr<OgaModel> Create(const char* config_path) {
     OgaModel* p;
     OgaCheckResult(OgaCreateModel(config_path, &p));
     return std::unique_ptr<OgaModel>(p);
   }
+  static std::unique_ptr<OgaModel> Create(const char* config_path, const OgaRuntimeSettings& settings) {
+    OgaModel* p;
+    OgaCheckResult(OgaCreateModelWithRuntimeSettings(config_path, &settings, &p));
+    return std::unique_ptr<OgaModel>(p);
+  }
 
   std::unique_ptr<OgaSequences> Generate(const OgaGeneratorParams& params) const {
     OgaSequences* p;
diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp
index 629d6e962..9483fc23b 100644
--- a/src/ort_genai_c.cpp
+++ b/src/ort_genai_c.cpp
@@ -8,6 +8,7 @@
 #include "ort_genai_c.h"
 #include "generators.h"
 #include "models/model.h"
+#include "runtime_settings.h"
 #include "search.h"
 
 namespace Generators {
@@ -134,15 +135,26 @@ OgaResult* OGA_API_CALL OgaLoadAudios(const OgaStringArray* audio_paths, OgaAudi
   OGA_CATCH
 }
 
-OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaModel** out) {
+OgaResult* OGA_API_CALL OgaCreateRuntimeSettings(OgaRuntimeSettings** out) {
+  OGA_TRY
+  *out = reinterpret_cast<OgaRuntimeSettings*>(Generators::CreateRuntimeSettings().release());
+  return nullptr;
+  OGA_CATCH
+}
+
+OgaResult* OGA_API_CALL OgaCreateModelWithRuntimeSettings(const char* config_path, const OgaRuntimeSettings* settings, OgaModel** out) {
   OGA_TRY
-  auto model = Generators::CreateModel(Generators::GetOrtEnv(), config_path);
+  auto model = Generators::CreateModel(Generators::GetOrtEnv(), config_path, reinterpret_cast<const Generators::RuntimeSettings*>(settings));
   model->external_owner_ = model;
   *out = reinterpret_cast<OgaModel*>(model.get());
   return nullptr;
   OGA_CATCH
 }
 
+OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaModel** out) {
+  return OgaCreateModelWithRuntimeSettings(config_path, nullptr, out);
+}
+
 OgaResult* OGA_API_CALL OgaCreateGeneratorParams(const OgaModel* model, OgaGeneratorParams** out) {
   OGA_TRY
   auto params = std::make_shared<Generators::GeneratorParams>(*reinterpret_cast<const Generators::Model*>(model));
@@ -152,6 +164,14 @@ OgaResult* OGA_API_CALL OgaCreateGeneratorParams(const OgaModel* model, OgaGener
   OGA_CATCH
 }
 
+OgaResult* OGA_API_CALL OgaRuntimeSettingsSetHandle(OgaRuntimeSettings* settings, const char* handle_name, void* handle) {
+  OGA_TRY
+  Generators::RuntimeSettings* settings_ = reinterpret_cast<Generators::RuntimeSettings*>(settings);
+  settings_->handles_[handle_name] = handle;
+  return nullptr;
+  OGA_CATCH
+}
+
 OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchNumber(OgaGeneratorParams* generator_params, const char* name, double value) {
   OGA_TRY
   Generators::SetSearchNumber(reinterpret_cast<Generators::GeneratorParams*>(generator_params)->search, name, value);
@@ -623,4 +643,8 @@ void OGA_API_CALL OgaDestroyNamedTensors(OgaNamedTensors* p) {
 void OGA_API_CALL OgaDestroyAdapters(OgaAdapters* p) {
   reinterpret_cast<Generators::Adapters*>(p)->external_owner_ = nullptr;
 }
+
+void OGA_API_CALL OgaDestroyRuntimeSettings(OgaRuntimeSettings* p) {
+  delete reinterpret_cast<Generators::RuntimeSettings*>(p);
+}
 }
diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h
index 331ed26a8..b00494d7d 100644
--- a/src/ort_genai_c.h
+++ b/src/ort_genai_c.h
@@ -188,7 +188,7 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaMo
  * \param[out] out The created model.
  * \return OgaResult containing the error message if the model creation failed.
  */
-OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateModelWithRuntimeSettings(const char* config_path, OgaRuntimeSettings* settings, OgaModel** out);
+OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateModelWithRuntimeSettings(const char* config_path, const OgaRuntimeSettings* settings, OgaModel** out);
 
 /*
  * \brief Destroys the given model.
  * \param[in] model The model to be destroyed.
diff --git a/src/runtime_settings.cpp b/src/runtime_settings.cpp
new file mode 100644
index 000000000..7afc07ac5
--- /dev/null
+++ b/src/runtime_settings.cpp
@@ -0,0 +1,9 @@
+#include "runtime_settings.h"
+
+namespace Generators {
+
+std::unique_ptr<RuntimeSettings> CreateRuntimeSettings() {
+  return std::make_unique<RuntimeSettings>();
+}
+
+}  // namespace Generators
\ No newline at end of file
diff --git a/src/runtime_settings.h b/src/runtime_settings.h
new file mode 100644
index 000000000..5f08a1d8a
--- /dev/null
+++ b/src/runtime_settings.h
@@ -0,0 +1,18 @@
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+namespace Generators {
+
+// This struct should only be used for runtime settings that are not able to be put into config.
+struct RuntimeSettings {
+  RuntimeSettings() = default;
+
+  std::unordered_map<std::string, void*> handles_;
+};
+
+std::unique_ptr<RuntimeSettings> CreateRuntimeSettings();
+
+}  // namespace Generators
\ No newline at end of file

From 17cc9bba59f6352c8a8b072c8c9960255de08a04 Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Mon, 21 Oct 2024 18:54:31 -0700
Subject: [PATCH 3/6] alternative way to implement

---
 src/config.cpp           | 38 ++++++++++++++++----------------------
 src/config.h             |  2 +-
 src/models/model.cpp     |  9 ++++++++-
 src/runtime_settings.cpp | 34 ++++++++++++++++++++++++++++++++++
 src/runtime_settings.h   |  2 ++
 5 files changed, 61 insertions(+), 24 deletions(-)

diff --git a/src/config.cpp b/src/config.cpp
index f054e7282..84b31f30d 100644
--- a/src/config.cpp
+++ b/src/config.cpp
@@ -660,35 +660,29 @@ void ParseConfig(const fs::path& filename, Config& config) {
   }
 }
 
-Config::Config(const fs::path& path, const RuntimeSettings* settings) : config_path{path} {
+void ParseConfig(std::string_view json, Config& config) {
+  Root_Element root{config};
+  RootObject_Element root_object{root};
+  try {
+    JSON::Parse(root_object, json);
+  } catch (const std::exception& message) {
+    std::ostringstream oss;
+    oss << "Error encountered while parsing JSON " << message.what();
+    throw std::runtime_error(oss.str());
+  }
+}
+
+Config::Config(const fs::path& path, std::string_view json_overlay) : config_path{path} {
   ParseConfig(path / "genai_config.json", *this);
+  if (!json_overlay.empty()) {
+    ParseConfig(json_overlay, *this);
+  }
 
   if (model.context_length == 0)
     throw std::runtime_error("model context_length is 0 or was not set. It must be greater than 0");
 
   if (search.max_length == 0)
     search.max_length = model.context_length;
-
-  if (settings) {
-    // Enable the following code after #992 is merged
-    // https://github.com/microsoft/onnxruntime-genai/pull/992
-
-    // #if USE_WEBGPU
-    auto& provider_options = model.decoder.session_options.provider_options;
-    auto maybe_webgpu_options = std::find_if(provider_options.begin(),
-                                             provider_options.end(),
-                                             [](Generators::Config::ProviderOptions& po) {
-                                               return po.name == "webgpu";
-                                             });
-    if (maybe_webgpu_options != provider_options.end()) {
-      auto it = settings->handles_.find("dawnProcTable");
-      if (it != settings->handles_.end()) {
-        void* dawn_proc_table_handle = it->second;
-        maybe_webgpu_options->options.emplace_back("dawnProcTable", std::to_string((size_t)(dawn_proc_table_handle)));
-      }
-    }
-    // #endif
-  }
 }
 
 void Config::AddMapping(const std::string& nominal_name, const std::string& graph_name) {
diff --git a/src/config.h b/src/config.h
index 16b590c10..37e1837a2 100644
--- a/src/config.h
+++ b/src/config.h
@@ -8,7 +8,7 @@ struct RuntimeSettings;
 
 struct Config {
   Config() = default;
-  Config(const fs::path& path, const RuntimeSettings* settings);
+  Config(const fs::path& path, std::string_view json_overlay);
 
   struct Defaults {
     static constexpr std::string_view InputIdsName = "input_ids";
diff --git a/src/models/model.cpp b/src/models/model.cpp
index 30509d81b..fc3e09bb5 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -474,6 +474,9 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
   } else if (provider_options.name == "web") {
     device_type_ = DeviceType::WEBGPU;
     std::unordered_map<std::string, std::string> opts;
+    for (auto& option : provider_options.options) {
+      opts.emplace(option.first, option.second);
+    }
     session_options.AppendExecutionProvider("WebGPU", opts);
   } else
     throw std::runtime_error("Unknown provider type: " + provider_options.name);
@@ -511,7 +514,11 @@ std::shared_ptr<MultiModalProcessor> Model::CreateMultiModalProcessor() const {
 }
 
 std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path, const RuntimeSettings* settings /*= nullptr*/) {
-  auto config = std::make_unique<Config>(fs::path(config_path), settings);
+  std::string config_overlay;
+  if (settings) {
+    config_overlay = settings->GenerateConfigOverlay();
+  }
+  auto config = std::make_unique<Config>(fs::path(config_path), config_overlay);
 
   if (config->model.type == "gpt2")
     return std::make_shared<Gpt_Model>(std::move(config), ort_env);
diff --git a/src/runtime_settings.cpp b/src/runtime_settings.cpp
index 7afc07ac5..570863e2d 100644
--- a/src/runtime_settings.cpp
+++ b/src/runtime_settings.cpp
@@ -6,4 +6,38 @@ std::unique_ptr<RuntimeSettings> CreateRuntimeSettings() {
   return std::make_unique<RuntimeSettings>();
 }
 
+std::string RuntimeSettings::GenerateConfigOverlay() const {
+  // #if USE_WEBGPU
+  constexpr std::string_view webgpu_overlay_pre = R"({
+  "model": {
+    "decoder": {
+      "session_options": {
+        "provider_options": [
+          {
+            "webgpu": {
+              "dawnProcTable": ")";
+  constexpr std::string_view webgpu_overlay_post = R"("
+            }
+          }
+        ]
+      }
+    }
+  }
+}
+)";
+
+  auto it = handles_.find("dawnProcTable");
+  if (it != handles_.end()) {
+    void* dawn_proc_table_handle = it->second;
+    std::string overlay;
+    overlay.reserve(webgpu_overlay_pre.size() + webgpu_overlay_post.size() + 20);
+    overlay += webgpu_overlay_pre;
+    overlay += std::to_string((size_t)(dawn_proc_table_handle));
+    overlay += webgpu_overlay_post;
+    return overlay;
+  }
+
+  return {};
+}
+
 }  // namespace Generators
\ No newline at end of file
diff --git a/src/runtime_settings.h b/src/runtime_settings.h
index 5f08a1d8a..c5d76d526 100644
--- a/src/runtime_settings.h
+++ b/src/runtime_settings.h
@@ -10,6 +10,8 @@ namespace Generators {
 struct RuntimeSettings {
   RuntimeSettings() = default;
 
+  std::string GenerateConfigOverlay() const;
+
   std::unordered_map<std::string, void*> handles_;
 };
 

From eb4d20dfe21309befe9f11c2c4f8f2c9a7bb877e Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Tue, 22 Oct 2024 18:28:50 -0700
Subject: [PATCH 4/6] web -> webgpu

---
 src/models/model.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/models/model.cpp b/src/models/model.cpp
index fc3e09bb5..a5f549fb1 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -471,7 +471,7 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
       opts.emplace(option.first, option.second);
     }
     session_options.AppendExecutionProvider("QNN", opts);
-  } else if (provider_options.name == "web") {
+  } else if (provider_options.name == "webgpu") {
     device_type_ = DeviceType::WEBGPU;
     std::unordered_map<std::string, std::string> opts;
     for (auto& option : provider_options.options) {

From 2449d9a076321d55bbed10d3cb870333d63f1346 Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Thu, 24 Oct 2024 00:16:59 -0700
Subject: [PATCH 5/6] Update src/runtime_settings.cpp

Co-authored-by: Ryan Hill <38674843+RyanUnderhill@users.noreply.github.com>
---
 src/runtime_settings.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/runtime_settings.cpp b/src/runtime_settings.cpp
index 570863e2d..c6a64c854 100644
--- a/src/runtime_settings.cpp
+++ b/src/runtime_settings.cpp
@@ -30,7 +30,7 @@ std::string RuntimeSettings::GenerateConfigOverlay() const {
   if (it != handles_.end()) {
     void* dawn_proc_table_handle = it->second;
     std::string overlay;
-    overlay.reserve(webgpu_overlay_pre.size() + webgpu_overlay_post.size() + 20);
+    overlay.reserve(webgpu_overlay_pre.size() + webgpu_overlay_post.size() + 20);  // Optional small optimization of buffer size
     overlay += webgpu_overlay_pre;
     overlay += std::to_string((size_t)(dawn_proc_table_handle));
     overlay += webgpu_overlay_post;
     return overlay;

From d24a70cbfc0d5988ca9c98b241ba9af75aea8e6c Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Thu, 24 Oct 2024 00:24:49 -0700
Subject: [PATCH 6/6] resolve comments

---
 src/config.cpp | 25 ++++++++++---------------
 1 file changed, 10 insertions(+), 15 deletions(-)

diff --git a/src/config.cpp b/src/config.cpp
index 205716399..5603ed4ba 100644
--- a/src/config.cpp
+++ b/src/config.cpp
@@ -641,7 +641,7 @@ struct RootObject_Element : JSON::Element {
   JSON::Element& t_;
 };
 
-void ParseConfig(const fs::path& filename, Config& config) {
+void ParseConfig(const fs::path& filename, std::string_view json_overlay, Config& config) {
   std::ifstream file = filename.open(std::ios::binary | std::ios::ate);
   if (!file.is_open()) {
     throw std::runtime_error("Error opening " + filename.string());
@@ -663,25 +663,20 @@ void ParseConfig(const fs::path& filename, Config& config) {
     oss << "Error encountered while parsing '" << filename.string() << "' " << message.what();
     throw std::runtime_error(oss.str());
   }
-}
-
-void ParseConfig(std::string_view json, Config& config) {
-  Root_Element root{config};
-  RootObject_Element root_object{root};
-  try {
-    JSON::Parse(root_object, json);
-  } catch (const std::exception& message) {
-    std::ostringstream oss;
-    oss << "Error encountered while parsing JSON " << message.what();
-    throw std::runtime_error(oss.str());
+
+  if (!json_overlay.empty()) {
+    try {
+      JSON::Parse(root_object, json_overlay);
+    } catch (const std::exception& message) {
+      std::ostringstream oss;
+      oss << "Error encountered while parsing config overlay: " << message.what();
+      throw std::runtime_error(oss.str());
+    }
   }
 }
 
 Config::Config(const fs::path& path, std::string_view json_overlay) : config_path{path} {
-  ParseConfig(path / "genai_config.json", *this);
-  if (!json_overlay.empty()) {
-    ParseConfig(json_overlay, *this);
-  }
+  ParseConfig(path / "genai_config.json", json_overlay, *this);
 
   if (model.context_length == 0)
     throw std::runtime_error("model context_length is 0 or was not set. It must be greater than 0");
 
   if (search.max_length == 0)
     search.max_length = model.context_length;
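Where the series ends up: RuntimeSettings no longer reaches into Config directly; GenerateConfigOverlay() turns the registered handles into a small JSON overlay that is parsed on top of genai_config.json, adding the dawnProcTable pointer to the "webgpu" provider options as a decimal string. Below is a minimal sketch (not part of the patches) of the resulting call pattern through the C++ wrapper in ort_genai.h; the dawn_proc_table pointer is assumed to come from the embedding application, and OgaCheckResult inside the wrapper throws on failure rather than returning an error code.

#include <memory>
#include "ort_genai.h"

std::unique_ptr<OgaModel> CreateWebGpuModel(const char* config_path, void* dawn_proc_table) {
  auto settings = OgaRuntimeSettings::Create();
  settings->SetHandle("dawnProcTable", dawn_proc_table);  // ends up in the "webgpu" provider options via the JSON overlay
  return OgaModel::Create(config_path, *settings);        // settings can be destroyed once the model exists
}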