diff --git a/src/config.cpp b/src/config.cpp index 451fd5b64..5603ed4ba 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "generators.h" +#include "runtime_settings.h" #include "json.h" #include #include @@ -640,7 +641,7 @@ struct RootObject_Element : JSON::Element { JSON::Element& t_; }; -void ParseConfig(const fs::path& filename, Config& config) { +void ParseConfig(const fs::path& filename, std::string_view json_overlay, Config& config) { std::ifstream file = filename.open(std::ios::binary | std::ios::ate); if (!file.is_open()) { throw std::runtime_error("Error opening " + filename.string()); @@ -662,10 +663,20 @@ void ParseConfig(const fs::path& filename, Config& config) { oss << "Error encountered while parsing '" << filename.string() << "' " << message.what(); throw std::runtime_error(oss.str()); } + + if (!json_overlay.empty()) { + try { + JSON::Parse(root_object, json_overlay); + } catch (const std::exception& message) { + std::ostringstream oss; + oss << "Error encountered while parsing config overlay: " << message.what(); + throw std::runtime_error(oss.str()); + } + } } -Config::Config(const fs::path& path) : config_path{path} { - ParseConfig(path / "genai_config.json", *this); +Config::Config(const fs::path& path, std::string_view json_overlay) : config_path{path} { + ParseConfig(path / "genai_config.json", json_overlay, *this); if (model.context_length == 0) throw std::runtime_error("model context_length is 0 or was not set. It must be greater than 0"); diff --git a/src/config.h b/src/config.h index 4a14cdd6e..37e1837a2 100644 --- a/src/config.h +++ b/src/config.h @@ -4,9 +4,11 @@ namespace Generators { +struct RuntimeSettings; + struct Config { Config() = default; - Config(const fs::path& path); + Config(const fs::path& path, std::string_view json_overlay); struct Defaults { static constexpr std::string_view InputIdsName = "input_ids"; diff --git a/src/generators.h b/src/generators.h index 4350d5f2d..c36ed6b33 100644 --- a/src/generators.h +++ b/src/generators.h @@ -37,6 +37,7 @@ using cudaStream_t = void*; #include "models/debugging.h" #include "config.h" #include "logging.h" +#include "runtime_settings.h" #include "tensor.h" namespace Generators { @@ -135,7 +136,7 @@ std::unique_ptr& GetOrtGlobals(); void Shutdown(); // Do this once at exit, Ort code will fail after this call OrtEnv& GetOrtEnv(); -std::shared_ptr CreateModel(OrtEnv& ort_env, const char* config_path); +std::shared_ptr CreateModel(OrtEnv& ort_env, const char* config_path, const RuntimeSettings* settings = nullptr); std::shared_ptr CreateGeneratorParams(const Model& model); std::shared_ptr CreateGeneratorParams(const Config& config); // For benchmarking purposes only std::unique_ptr CreateGenerator(const Model& model, const GeneratorParams& params); diff --git a/src/models/model.cpp b/src/models/model.cpp index 93338b7fa..a5f549fb1 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -471,9 +471,12 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_ opts.emplace(option.first, option.second); } session_options.AppendExecutionProvider("QNN", opts); - } else if (provider_options.name == "web") { + } else if (provider_options.name == "webgpu") { device_type_ = DeviceType::WEBGPU; std::unordered_map opts; + for (auto& option : provider_options.options) { + opts.emplace(option.first, option.second); + } session_options.AppendExecutionProvider("WebGPU", opts); } else throw std::runtime_error("Unknown provider type: " + provider_options.name); @@ -510,8 +513,12 @@ std::shared_ptr Model::CreateMultiModalProcessor() const { return std::make_shared(*config_, *session_info_); } -std::shared_ptr CreateModel(OrtEnv& ort_env, const char* config_path) { - auto config = std::make_unique(fs::path(config_path)); +std::shared_ptr CreateModel(OrtEnv& ort_env, const char* config_path, const RuntimeSettings* settings /*= nullptr*/) { + std::string config_overlay; + if (settings) { + config_overlay = settings->GenerateConfigOverlay(); + } + auto config = std::make_unique(fs::path(config_path), config_overlay); if (config->model.type == "gpt2") return std::make_shared(std::move(config), ort_env); diff --git a/src/ort_genai.h b/src/ort_genai.h index beffd0730..424e7aad7 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -59,12 +59,34 @@ inline void OgaCheckResult(OgaResult* result) { } } +struct OgaRuntimeSettings : OgaAbstract { + static std::unique_ptr Create() { + OgaRuntimeSettings* p; + OgaCheckResult(OgaCreateRuntimeSettings(&p)); + return std::unique_ptr(p); + } + + void SetHandle(const char* name, void* handle) { + OgaCheckResult(OgaRuntimeSettingsSetHandle(this, name, handle)); + } + void SetHandle(const std::string& name, void* handle) { + SetHandle(name.c_str(), handle); + } + + static void operator delete(void* p) { OgaDestroyRuntimeSettings(reinterpret_cast(p)); } +}; + struct OgaModel : OgaAbstract { static std::unique_ptr Create(const char* config_path) { OgaModel* p; OgaCheckResult(OgaCreateModel(config_path, &p)); return std::unique_ptr(p); } + static std::unique_ptr Create(const char* config_path, const OgaRuntimeSettings& settings) { + OgaModel* p; + OgaCheckResult(OgaCreateModelWithRuntimeSettings(config_path, &settings, &p)); + return std::unique_ptr(p); + } std::unique_ptr Generate(const OgaGeneratorParams& params) const { OgaSequences* p; diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index 629d6e962..9483fc23b 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -8,6 +8,7 @@ #include "ort_genai_c.h" #include "generators.h" #include "models/model.h" +#include "runtime_settings.h" #include "search.h" namespace Generators { @@ -134,15 +135,26 @@ OgaResult* OGA_API_CALL OgaLoadAudios(const OgaStringArray* audio_paths, OgaAudi OGA_CATCH } -OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaModel** out) { +OgaResult* OGA_API_CALL OgaCreateRuntimeSettings(OgaRuntimeSettings** out) { + OGA_TRY + *out = reinterpret_cast(Generators::CreateRuntimeSettings().release()); + return nullptr; + OGA_CATCH +} + +OgaResult* OGA_API_CALL OgaCreateModelWithRuntimeSettings(const char* config_path, const OgaRuntimeSettings* settings, OgaModel** out) { OGA_TRY - auto model = Generators::CreateModel(Generators::GetOrtEnv(), config_path); + auto model = Generators::CreateModel(Generators::GetOrtEnv(), config_path, reinterpret_cast(settings)); model->external_owner_ = model; *out = reinterpret_cast(model.get()); return nullptr; OGA_CATCH } +OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaModel** out) { + return OgaCreateModelWithRuntimeSettings(config_path, nullptr, out); +} + OgaResult* OGA_API_CALL OgaCreateGeneratorParams(const OgaModel* model, OgaGeneratorParams** out) { OGA_TRY auto params = std::make_shared(*reinterpret_cast(model)); @@ -152,6 +164,14 @@ OgaResult* OGA_API_CALL OgaCreateGeneratorParams(const OgaModel* model, OgaGener OGA_CATCH } +OgaResult* OGA_API_CALL OgaRuntimeSettingsSetHandle(OgaRuntimeSettings* settings, const char* handle_name, void* handle) { + OGA_TRY + Generators::RuntimeSettings* settings_ = reinterpret_cast(settings); + settings_->handles_[handle_name] = handle; + return nullptr; + OGA_CATCH +} + OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchNumber(OgaGeneratorParams* generator_params, const char* name, double value) { OGA_TRY Generators::SetSearchNumber(reinterpret_cast(generator_params)->search, name, value); @@ -623,4 +643,8 @@ void OGA_API_CALL OgaDestroyNamedTensors(OgaNamedTensors* p) { void OGA_API_CALL OgaDestroyAdapters(OgaAdapters* p) { reinterpret_cast(p)->external_owner_ = nullptr; } + +void OGA_API_CALL OgaDestroyRuntimeSettings(OgaRuntimeSettings* p) { + delete reinterpret_cast(p); +} } diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index e4e72fe6f..b00494d7d 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -50,6 +50,7 @@ typedef enum OgaElementType { typedef struct OgaResult OgaResult; typedef struct OgaGeneratorParams OgaGeneratorParams; typedef struct OgaGenerator OgaGenerator; +typedef struct OgaRuntimeSettings OgaRuntimeSettings; typedef struct OgaModel OgaModel; // OgaSequences is an array of token arrays where the number of token arrays can be obtained using // OgaSequencesCount and the number of tokens in each token array can be obtained using OgaSequencesGetSequenceCount. @@ -149,6 +150,27 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaLoadAudios(const OgaStringArray* audio_pat OGA_EXPORT void OGA_API_CALL OgaDestroyAudios(OgaAudios* audios); +/* + * \brief Creates a runtime settings instance to be used to create a model. + * \param[out] out The created runtime settings. + * \return OgaResult containing the error message if the creation of the runtime settings failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateRuntimeSettings(OgaRuntimeSettings** out); +/* + * \brief Destroys the given runtime settings. + * \param[in] settings The runtime settings to be destroyed. + */ +OGA_EXPORT void OGA_API_CALL OgaDestroyRuntimeSettings(OgaRuntimeSettings* settings); + +/* + * \brief Sets a specific runtime handle for the runtime settings. + * \param[in] settings The runtime settings to set the device type. + * \param[in] handle_name The name of the handle to set for the runtime settings. + * \param[in] handle The value of handle to set for the runtime settings. + * \return OgaResult containing the error message if the setting of the device type failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaRuntimeSettingsSetHandle(OgaRuntimeSettings* settings, const char* handle_name, void* handle); + /* * \brief Creates a model from the given configuration directory and device type. * \param[in] config_path The path to the model configuration directory. The path is expected to be encoded in UTF-8. @@ -158,6 +180,16 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyAudios(OgaAudios* audios); */ OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateModel(const char* config_path, OgaModel** out); +/* + * \brief Creates a model from the given configuration directory, runtime settings and device type. + * \param[in] config_path The path to the model configuration directory. The path is expected to be encoded in UTF-8. + * \param[in] settings The runtime settings to use for the model. + * \param[in] device_type The device type to use for the model. + * \param[out] out The created model. + * \return OgaResult containing the error message if the model creation failed. + */ +OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateModelWithRuntimeSettings(const char* config_path, const OgaRuntimeSettings* settings, OgaModel** out); + /* * \brief Destroys the given model. * \param[in] model The model to be destroyed. diff --git a/src/runtime_settings.cpp b/src/runtime_settings.cpp new file mode 100644 index 000000000..c6a64c854 --- /dev/null +++ b/src/runtime_settings.cpp @@ -0,0 +1,43 @@ +#include "runtime_settings.h" + +namespace Generators { + +std::unique_ptr CreateRuntimeSettings() { + return std::make_unique(); +} + +std::string RuntimeSettings::GenerateConfigOverlay() const { + // #if USE_WEBGPU + constexpr std::string_view webgpu_overlay_pre = R"({ + "model": { + "decoder": { + "session_options": { + "provider_options": [ + { + "webgpu": { + "dawnProcTable": ")"; + constexpr std::string_view webgpu_overlay_post = R"(" + } + } + ] + } + } + } +} +)"; + + auto it = handles_.find("dawnProcTable"); + if (it != handles_.end()) { + void* dawn_proc_table_handle = it->second; + std::string overlay; + overlay.reserve(webgpu_overlay_pre.size() + webgpu_overlay_post.size() + 20); // Optional small optimization of buffer size + overlay += webgpu_overlay_pre; + overlay += std::to_string((size_t)(dawn_proc_table_handle)); + overlay += webgpu_overlay_post; + return overlay; + } + + return {}; +} + +} // namespace Generators \ No newline at end of file diff --git a/src/runtime_settings.h b/src/runtime_settings.h new file mode 100644 index 000000000..c5d76d526 --- /dev/null +++ b/src/runtime_settings.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include +#include + +namespace Generators { + +// This struct should only be used for runtime settings that are not able to be put into config. +struct RuntimeSettings { + RuntimeSettings() = default; + + std::string GenerateConfigOverlay() const; + + std::unordered_map handles_; +}; + +std::unique_ptr CreateRuntimeSettings(); + +} // namespace Generators \ No newline at end of file