From f4f04c7768fe864a536c30cb3fd9c8d053f82a31 Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Tue, 15 Oct 2024 10:25:53 -0700 Subject: [PATCH 01/17] Add Set terminate Option for user --- src/generators.cpp | 9 +++++++++ src/generators.h | 1 + src/models/model.cpp | 20 ++++++++++++++++++-- src/models/model.h | 2 ++ src/ort_genai.h | 8 ++++++++ src/ort_genai_c.cpp | 16 ++++++++++++++++ src/ort_genai_c.h | 2 ++ 7 files changed, 56 insertions(+), 2 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index 36f12a2bd..5ed9dff76 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -154,6 +154,9 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_ } void Generator::ComputeLogits() { + if (state_->params_->session_terminated) { + throw std::runtime_error("Session in Terminated state, exiting!"); + } if (computed_logits_) throw std::runtime_error("ComputeLogits called again without calling GenerateNextToken first"); @@ -172,6 +175,9 @@ void Generator::ComputeLogits() { } bool Generator::IsDone() const { + if (state_->params_->session_terminated) { + throw std::runtime_error("Session in Terminated state, exiting!"); + } if (computed_logits_) throw std::runtime_error("IsDone() can't be called in the middle of processing logits"); @@ -184,6 +190,9 @@ bool Generator::IsDone() const { } void Generator::GenerateNextToken() { + if (state_->params_->session_terminated) { + throw std::runtime_error("Session in Terminated state, exiting!"); + } if (!computed_logits_) throw std::runtime_error("Must call ComputeLogits before GenerateNextToken"); computed_logits_ = false; diff --git a/src/generators.h b/src/generators.h index 488dd8fa9..e75b03ff1 100644 --- a/src/generators.h +++ b/src/generators.h @@ -66,6 +66,7 @@ struct GeneratorParams : std::enable_shared_from_this, LeakChec int batch_size{1}; int max_batch_size{0}; bool use_cuda_graph{}; + mutable bool session_terminated{false}; int sequence_length{}; int BatchBeamSize() const { return search.num_beams * batch_size; } diff --git a/src/models/model.cpp b/src/models/model.cpp index 33942223b..c6cdce93e 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -37,9 +37,9 @@ static std::string CurrentModulePath() { namespace Generators { -State::State(const GeneratorParams& params, const Model& model) +State::State(const GeneratorParams& params_, const Model& model) : model_{model}, - params_{params.shared_from_this()} {} + params_{params_.shared_from_this()} {} void State::Run(OrtSession& session, OrtRunOptions& run_options, int new_batch_size) { auto captured_graph_info = GetCapturedGraphInfo(); @@ -76,7 +76,20 @@ void State::Run(OrtSession& session, OrtRunOptions& run_options, int new_batch_s } } +void State::SetTerminate() { + params_->session_terminated = true; + model_.run_options_->SetTerminate(); +} + +void State::UnsetTerminate() { + params_->session_terminated = false; + model_.run_options_->UnsetTerminate(); +} + OrtValue* State::GetInput(const char* name) { + if (params_->session_terminated) { + throw std::runtime_error("Session in Terminated state, exiting!"); + } for (size_t i = 0; i < input_names_.size(); i++) { if (std::strcmp(input_names_[i], name) == 0) { return inputs_[i]; @@ -86,6 +99,9 @@ OrtValue* State::GetInput(const char* name) { } OrtValue* State::GetOutput(const char* name) { + if (params_->session_terminated) { + throw std::runtime_error("Session in Terminated state, exiting!"); + } for (size_t i = 0; i < output_names_.size(); i++) { if (std::strcmp(output_names_[i], name) == 0) { return outputs_[i]; diff --git a/src/models/model.h b/src/models/model.h index b7b7bdfb2..9049d1704 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -33,6 +33,8 @@ struct State { virtual const CapturedGraphInfo* GetCapturedGraphInfo() const { return nullptr; } virtual void Finalize() {} + void SetTerminate(); + void UnsetTerminate(); OrtValue* GetInput(const char* name); virtual OrtValue* GetOutput(const char* name); diff --git a/src/ort_genai.h b/src/ort_genai.h index 36b56d9bd..45a6a1a2c 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -244,6 +244,14 @@ struct OgaGenerator : OgaAbstract { OgaCheckResult(OgaGenerator_GenerateNextToken(this)); } + void SetTerminate() { + OgaCheckResult(OgaGenerator_SetTerminate(this)); + } + + void UnsetTerminate() { + OgaCheckResult(OgaGenerator_UnsetTerminate(this)); + } + size_t GetSequenceCount(size_t index) const { return OgaGenerator_GetSequenceCount(this, index); } diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index d392bd361..90ca34930 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -265,6 +265,22 @@ OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator) OGA_CATCH } +OgaResult* OGA_API_CALL OgaGenerator_SetTerminate(OgaGenerator* oga_generator) { + OGA_TRY + auto& generator = *reinterpret_cast(oga_generator); + generator.state_->SetTerminate(); + return nullptr; + OGA_CATCH +} + +OgaResult* OGA_API_CALL OgaGenerator_UnsetTerminate(OgaGenerator* oga_generator) { + OGA_TRY + auto& generator = *reinterpret_cast(oga_generator); + generator.state_->UnsetTerminate(); + return nullptr; + OGA_CATCH +} + OgaResult* OGA_API_CALL OgaGenerator_GetOutput(const OgaGenerator* oga_generator, const char* name, OgaTensor** out) { OGA_TRY auto& generator = *reinterpret_cast(oga_generator); diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index c1d03f8e1..ac2f53d61 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -253,6 +253,8 @@ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator); */ OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator); OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_GenerateNextToken(OgaGenerator* generator); +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_SetTerminate(OgaGenerator* generator); +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerator_UnsetTerminate(OgaGenerator* generator); /* * \brief Returns a copy of the model output identified by the given name as an OgaTensor on CPU. The buffer is owned by returned OgaTensor From 1b2ba608fbf6b22b2bd69f5aac5792e328eb4187 Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Thu, 17 Oct 2024 01:48:36 -0700 Subject: [PATCH 02/17] modify run options --- src/models/model.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/models/model.cpp b/src/models/model.cpp index 0a84f05ea..f38155aa4 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -65,12 +65,12 @@ void State::Run(OrtSession& session, int new_batch_size) { void State::SetTerminate() { params_->session_terminated = true; - model_.run_options_->SetTerminate(); + run_options_->SetTerminate(); } void State::UnsetTerminate() { params_->session_terminated = false; - model_.run_options_->UnsetTerminate(); + run_options_->UnsetTerminate(); } OrtValue* State::GetInput(const char* name) { From 6ab7d6890ef11896bba00b2ad2e739875a7a8210 Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Thu, 17 Oct 2024 09:04:11 -0700 Subject: [PATCH 03/17] use common function for throwing error --- src/generators.cpp | 17 ++++++++--------- src/generators.h | 2 ++ src/models/model.cpp | 8 ++------ 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index 2872a28f0..2453549df 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -30,6 +30,11 @@ std::string CurrentModulePath() { } #endif +void ThrowErrorIfSessionTerminated(bool is_session_terminated) { + if (is_session_terminated) + throw std::runtime_error("Session in Terminated state, exiting!"); +} + namespace Generators { #if USE_CUDA @@ -282,9 +287,7 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_ } void Generator::ComputeLogits() { - if (state_->params_->session_terminated) { - throw std::runtime_error("Session in Terminated state, exiting!"); - } + ThrowErrorIfSessionTerminated(state_->params_->session_terminated); if (computed_logits_) throw std::runtime_error("ComputeLogits called again without calling GenerateNextToken first"); @@ -303,9 +306,7 @@ void Generator::ComputeLogits() { } bool Generator::IsDone() const { - if (state_->params_->session_terminated) { - throw std::runtime_error("Session in Terminated state, exiting!"); - } + ThrowErrorIfSessionTerminated(state_->params_->session_terminated); if (computed_logits_) throw std::runtime_error("IsDone() can't be called in the middle of processing logits"); @@ -318,9 +319,7 @@ bool Generator::IsDone() const { } void Generator::GenerateNextToken() { - if (state_->params_->session_terminated) { - throw std::runtime_error("Session in Terminated state, exiting!"); - } + ThrowErrorIfSessionTerminated(state_->params_->session_terminated); if (!computed_logits_) throw std::runtime_error("Must call ComputeLogits before GenerateNextToken"); computed_logits_ = false; diff --git a/src/generators.h b/src/generators.h index 7cf2f4fda..7631f770c 100644 --- a/src/generators.h +++ b/src/generators.h @@ -39,6 +39,8 @@ using cudaStream_t = void*; #include "logging.h" #include "tensor.h" +void ThrowErrorIfSessionTerminated(bool is_session_terminated); + namespace Generators { struct Model; struct State; diff --git a/src/models/model.cpp b/src/models/model.cpp index f38155aa4..c2b64466b 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -74,9 +74,7 @@ void State::UnsetTerminate() { } OrtValue* State::GetInput(const char* name) { - if (params_->session_terminated) { - throw std::runtime_error("Session in Terminated state, exiting!"); - } + ThrowErrorIfSessionTerminated(params_->session_terminated); for (size_t i = 0; i < input_names_.size(); i++) { if (std::strcmp(input_names_[i], name) == 0) { return inputs_[i]; @@ -86,9 +84,7 @@ OrtValue* State::GetInput(const char* name) { } OrtValue* State::GetOutput(const char* name) { - if (params_->session_terminated) { - throw std::runtime_error("Session in Terminated state, exiting!"); - } + ThrowErrorIfSessionTerminated(params_->session_terminated); for (size_t i = 0; i < output_names_.size(); i++) { if (std::strcmp(output_names_[i], name) == 0) { return outputs_[i]; From 3b55b650982be7c774e07c33fc5faa6610a8a7fc Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Fri, 18 Oct 2024 09:52:55 -0700 Subject: [PATCH 04/17] move var to state params --- src/generators.cpp | 6 +++--- src/generators.h | 1 - src/models/model.cpp | 8 ++++---- src/models/model.h | 1 + 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index 2453549df..942be2957 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -287,7 +287,7 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_ } void Generator::ComputeLogits() { - ThrowErrorIfSessionTerminated(state_->params_->session_terminated); + ThrowErrorIfSessionTerminated(state_->session_terminated); if (computed_logits_) throw std::runtime_error("ComputeLogits called again without calling GenerateNextToken first"); @@ -306,7 +306,7 @@ void Generator::ComputeLogits() { } bool Generator::IsDone() const { - ThrowErrorIfSessionTerminated(state_->params_->session_terminated); + ThrowErrorIfSessionTerminated(state_->session_terminated); if (computed_logits_) throw std::runtime_error("IsDone() can't be called in the middle of processing logits"); @@ -319,7 +319,7 @@ bool Generator::IsDone() const { } void Generator::GenerateNextToken() { - ThrowErrorIfSessionTerminated(state_->params_->session_terminated); + ThrowErrorIfSessionTerminated(state_->session_terminated); if (!computed_logits_) throw std::runtime_error("Must call ComputeLogits before GenerateNextToken"); computed_logits_ = false; diff --git a/src/generators.h b/src/generators.h index 7631f770c..6249c9037 100644 --- a/src/generators.h +++ b/src/generators.h @@ -68,7 +68,6 @@ struct GeneratorParams : std::enable_shared_from_this, LeakChec int batch_size{1}; int max_batch_size{0}; bool use_cuda_graph{}; - mutable bool session_terminated{false}; int sequence_length{}; int BatchBeamSize() const { return search.num_beams * batch_size; } diff --git a/src/models/model.cpp b/src/models/model.cpp index d3bb6bc63..e58368874 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -64,17 +64,17 @@ void State::Run(OrtSession& session, int new_batch_size) { } void State::SetTerminate() { - params_->session_terminated = true; + session_terminated = true; run_options_->SetTerminate(); } void State::UnsetTerminate() { - params_->session_terminated = false; + session_terminated = false; run_options_->UnsetTerminate(); } OrtValue* State::GetInput(const char* name) { - ThrowErrorIfSessionTerminated(params_->session_terminated); + ThrowErrorIfSessionTerminated(session_terminated); for (size_t i = 0; i < input_names_.size(); i++) { if (std::strcmp(input_names_[i], name) == 0) { return inputs_[i]; @@ -84,7 +84,7 @@ OrtValue* State::GetInput(const char* name) { } OrtValue* State::GetOutput(const char* name) { - ThrowErrorIfSessionTerminated(params_->session_terminated); + ThrowErrorIfSessionTerminated(session_terminated); for (size_t i = 0; i < output_names_.size(); i++) { if (std::strcmp(output_names_[i], name) == 0) { return outputs_[i]; diff --git a/src/models/model.h b/src/models/model.h index 5dc6a3e17..fd9b86478 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -36,6 +36,7 @@ struct State { void SetTerminate(); void UnsetTerminate(); + mutable bool session_terminated{}; OrtValue* GetInput(const char* name); virtual OrtValue* GetOutput(const char* name); From c60cedebcdc19bacc03c61fac37b81e3b31e4252 Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Mon, 21 Oct 2024 00:03:26 -0700 Subject: [PATCH 05/17] add test for set terminate --- test/c_api_tests.cpp | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index c04ef284f..ecdda82b3 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -4,6 +4,8 @@ #include #include #include +#include +#include #ifndef MODEL_PATH #define MODEL_PATH "../../test/test_models/" #endif @@ -281,6 +283,43 @@ TEST(CAPITests, GetOutputCAPI) { generator->GenerateNextToken(); } +TEST(CAPITests, SetTerminate) { +#if TEST_PHI2 + void Generator_SetTerminate_Call(OgaGenerator* generator) { + generator->SetTerminate(); + } + + void Generate_Output(OgaGenerator* generator, std::unique_ptr tokenizer_stream) { + while (!generator->IsDone()) { + generator->ComputeLogits(); + generator->GenerateNextToken(); + } + } + + auto model = OgaModel::Create(PHI2_PATH); + auto tokenizer = OgaTokenizer::Create(*model); + auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer); + + const char* input_string = "She sells sea shells by the sea shore."; + auto input_sequences = OgaSequences::Create(); + tokenizer->Encode(input_string, *input_sequences); + auto params = OgaGeneratorParams::Create(*model); + params->SetInputSequences(*input_sequences); + params->SetSearchOption("max_length", 40); + + auto generator = OgaGenerator::Create(*model, *params); + threads.push_back(std::thread(Generate_Output, generator.get(), std::move(tokenizer_stream))); + threads.push_back(std::thread(Generator_SetTerminate_Call, generator.get())); + + std::vector threads; + + for (auto& th : threads) { + std::cout << "Waiting for threads completion" << std::endl; + th.join(); // Wait for each thread to finish + } +#endif +} + #if TEST_PHI2 struct Phi2Test { From a9b35dd8fcdd85e4639228df358d7107766198fa Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Mon, 21 Oct 2024 01:23:40 -0700 Subject: [PATCH 06/17] declare functions outside test --- test/c_api_tests.cpp | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index ecdda82b3..5dd71983e 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -283,18 +283,22 @@ TEST(CAPITests, GetOutputCAPI) { generator->GenerateNextToken(); } -TEST(CAPITests, SetTerminate) { #if TEST_PHI2 - void Generator_SetTerminate_Call(OgaGenerator* generator) { - generator->SetTerminate(); - } +void Generator_SetTerminate_Call(OgaGenerator* generator) { + generator->SetTerminate(); +} - void Generate_Output(OgaGenerator* generator, std::unique_ptr tokenizer_stream) { - while (!generator->IsDone()) { - generator->ComputeLogits(); - generator->GenerateNextToken(); - } +void Generate_Output(OgaGenerator* generator, std::unique_ptr tokenizer_stream) { + while (!generator->IsDone()) { + generator->ComputeLogits(); + generator->GenerateNextToken(); } +} + +#endif + +TEST(CAPITests, SetTerminate) { +#if TEST_PHI2 auto model = OgaModel::Create(PHI2_PATH); auto tokenizer = OgaTokenizer::Create(*model); From 80b0a79367a9375f4467a337e6dec5240ed3195f Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Mon, 21 Oct 2024 02:37:24 -0700 Subject: [PATCH 07/17] bug fix --- test/c_api_tests.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index 5dd71983e..eec685b6f 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -312,6 +312,7 @@ TEST(CAPITests, SetTerminate) { params->SetSearchOption("max_length", 40); auto generator = OgaGenerator::Create(*model, *params); + std::vector threads; threads.push_back(std::thread(Generate_Output, generator.get(), std::move(tokenizer_stream))); threads.push_back(std::thread(Generator_SetTerminate_Call, generator.get())); From 4de71a550d2a5e4a7af21e4c56df8da99f60c8fa Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Mon, 21 Oct 2024 02:55:56 -0700 Subject: [PATCH 08/17] bug fix --- test/c_api_tests.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index eec685b6f..a0c78d411 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -316,8 +316,6 @@ TEST(CAPITests, SetTerminate) { threads.push_back(std::thread(Generate_Output, generator.get(), std::move(tokenizer_stream))); threads.push_back(std::thread(Generator_SetTerminate_Call, generator.get())); - std::vector threads; - for (auto& th : threads) { std::cout << "Waiting for threads completion" << std::endl; th.join(); // Wait for each thread to finish From c1420eaf6681aaeae2d63c6dbc8fc25987f23a5c Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Mon, 21 Oct 2024 03:37:06 -0700 Subject: [PATCH 09/17] add checks for session terminated --- src/models/model.cpp | 4 ++++ src/models/model.h | 1 + test/c_api_tests.cpp | 16 ++++++++++++---- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/models/model.cpp b/src/models/model.cpp index cdd6a6954..58de9ab9f 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -75,6 +75,10 @@ void State::UnsetTerminate() { run_options_->UnsetTerminate(); } +bool State::IsSessionTerminated() { + return session_terminated; +} + OrtValue* State::GetInput(const char* name) { ThrowErrorIfSessionTerminated(session_terminated); for (size_t i = 0; i < input_names_.size(); i++) { diff --git a/src/models/model.h b/src/models/model.h index fd9b86478..6cf91fd3c 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -36,6 +36,7 @@ struct State { void SetTerminate(); void UnsetTerminate(); + bool IsSessionTerminated(); mutable bool session_terminated{}; OrtValue* GetInput(const char* name); diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index a0c78d411..55371b46d 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -289,12 +289,16 @@ void Generator_SetTerminate_Call(OgaGenerator* generator) { } void Generate_Output(OgaGenerator* generator, std::unique_ptr tokenizer_stream) { - while (!generator->IsDone()) { - generator->ComputeLogits(); - generator->GenerateNextToken(); + try{ + while (!generator->IsDone()) { + generator->ComputeLogits(); + generator->GenerateNextToken(); + } + } + catch (const std::exception& e) { + std::cout << "Session Terminated" << std::endl; } } - #endif TEST(CAPITests, SetTerminate) { @@ -312,6 +316,7 @@ TEST(CAPITests, SetTerminate) { params->SetSearchOption("max_length", 40); auto generator = OgaGenerator::Create(*model, *params); + EXPECT_EQ(generator->IsSessionTerminated(), false); std::vector threads; threads.push_back(std::thread(Generate_Output, generator.get(), std::move(tokenizer_stream))); threads.push_back(std::thread(Generator_SetTerminate_Call, generator.get())); @@ -320,6 +325,9 @@ TEST(CAPITests, SetTerminate) { std::cout << "Waiting for threads completion" << std::endl; th.join(); // Wait for each thread to finish } + EXPECT_EQ(generator->IsSessionTerminated(), true); + generator->UnsetTerminate(); + EXPECT_EQ(generator->IsSessionTerminated(), false); #endif } From 3ab84c5d226563e1b6ce850d803a9fe9930ba2a8 Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Mon, 21 Oct 2024 04:11:59 -0700 Subject: [PATCH 10/17] modify set terminate check --- src/generators.cpp | 4 ++++ src/generators.h | 1 + src/models/model.cpp | 4 ---- src/models/model.h | 1 - test/c_api_tests.cpp | 6 +++--- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/generators.cpp b/src/generators.cpp index 942be2957..c67602687 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -318,6 +318,10 @@ bool Generator::IsDone() const { return is_done; } +bool Generator::IsSessionTerminated() const { + return state_->session_terminated; +} + void Generator::GenerateNextToken() { ThrowErrorIfSessionTerminated(state_->session_terminated); if (!computed_logits_) diff --git a/src/generators.h b/src/generators.h index 6249c9037..2fe9c4cd8 100644 --- a/src/generators.h +++ b/src/generators.h @@ -108,6 +108,7 @@ struct Generator : LeakChecked { Generator(const Model& model, const GeneratorParams& params); bool IsDone() const; + bool IsSessionTerminated() const; void ComputeLogits(); void GenerateNextToken(); diff --git a/src/models/model.cpp b/src/models/model.cpp index 58de9ab9f..cdd6a6954 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -75,10 +75,6 @@ void State::UnsetTerminate() { run_options_->UnsetTerminate(); } -bool State::IsSessionTerminated() { - return session_terminated; -} - OrtValue* State::GetInput(const char* name) { ThrowErrorIfSessionTerminated(session_terminated); for (size_t i = 0; i < input_names_.size(); i++) { diff --git a/src/models/model.h b/src/models/model.h index 6cf91fd3c..fd9b86478 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -36,7 +36,6 @@ struct State { void SetTerminate(); void UnsetTerminate(); - bool IsSessionTerminated(); mutable bool session_terminated{}; OrtValue* GetInput(const char* name); diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index 55371b46d..7382bf181 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -316,7 +316,7 @@ TEST(CAPITests, SetTerminate) { params->SetSearchOption("max_length", 40); auto generator = OgaGenerator::Create(*model, *params); - EXPECT_EQ(generator->IsSessionTerminated(), false); + // EXPECT_EQ(generator->IsSessionTerminated(), false); std::vector threads; threads.push_back(std::thread(Generate_Output, generator.get(), std::move(tokenizer_stream))); threads.push_back(std::thread(Generator_SetTerminate_Call, generator.get())); @@ -325,9 +325,9 @@ TEST(CAPITests, SetTerminate) { std::cout << "Waiting for threads completion" << std::endl; th.join(); // Wait for each thread to finish } - EXPECT_EQ(generator->IsSessionTerminated(), true); + // EXPECT_EQ(generator->IsSessionTerminated(), true); generator->UnsetTerminate(); - EXPECT_EQ(generator->IsSessionTerminated(), false); + // EXPECT_EQ(generator->IsSessionTerminated(), false); #endif } From b5e917fcebec4eef9c1b791836455cd2eeef5064 Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Mon, 21 Oct 2024 04:32:28 -0700 Subject: [PATCH 11/17] check session terminated status --- test/c_api_tests.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index 7382bf181..d1147e7f0 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -296,7 +296,7 @@ void Generate_Output(OgaGenerator* generator, std::unique_ptrSetSearchOption("max_length", 40); auto generator = OgaGenerator::Create(*model, *params); - // EXPECT_EQ(generator->IsSessionTerminated(), false); + EXPECT_EQ(generator->IsSessionTerminated(), false); std::vector threads; threads.push_back(std::thread(Generate_Output, generator.get(), std::move(tokenizer_stream))); threads.push_back(std::thread(Generator_SetTerminate_Call, generator.get())); @@ -325,9 +325,9 @@ TEST(CAPITests, SetTerminate) { std::cout << "Waiting for threads completion" << std::endl; th.join(); // Wait for each thread to finish } - // EXPECT_EQ(generator->IsSessionTerminated(), true); + EXPECT_EQ(generator->IsSessionTerminated(), true); generator->UnsetTerminate(); - // EXPECT_EQ(generator->IsSessionTerminated(), false); + EXPECT_EQ(generator->IsSessionTerminated(), false); #endif } From b1c9d297910728c64a1726bd9794a6ac37e2c742 Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Mon, 21 Oct 2024 04:47:48 -0700 Subject: [PATCH 12/17] bug fix --- test/c_api_tests.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index d1147e7f0..a0a333774 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -316,7 +316,7 @@ TEST(CAPITests, SetTerminate) { params->SetSearchOption("max_length", 40); auto generator = OgaGenerator::Create(*model, *params); - EXPECT_EQ(generator->IsSessionTerminated(), false); + EXPECT_EQ(generator.IsSessionTerminated(), false); std::vector threads; threads.push_back(std::thread(Generate_Output, generator.get(), std::move(tokenizer_stream))); threads.push_back(std::thread(Generator_SetTerminate_Call, generator.get())); @@ -325,9 +325,9 @@ TEST(CAPITests, SetTerminate) { std::cout << "Waiting for threads completion" << std::endl; th.join(); // Wait for each thread to finish } - EXPECT_EQ(generator->IsSessionTerminated(), true); + EXPECT_EQ(generator.IsSessionTerminated(), true); generator->UnsetTerminate(); - EXPECT_EQ(generator->IsSessionTerminated(), false); + EXPECT_EQ(generator.IsSessionTerminated(), false); #endif } From 28cd0acd45f3438fa77136261c118ebb51b9557b Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Mon, 21 Oct 2024 05:19:02 -0700 Subject: [PATCH 13/17] bug fix --- test/c_api_tests.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index a0a333774..3847b1a79 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -316,7 +316,7 @@ TEST(CAPITests, SetTerminate) { params->SetSearchOption("max_length", 40); auto generator = OgaGenerator::Create(*model, *params); - EXPECT_EQ(generator.IsSessionTerminated(), false); + EXPECT_EQ(generator.get()->IsSessionTerminated(), false); std::vector threads; threads.push_back(std::thread(Generate_Output, generator.get(), std::move(tokenizer_stream))); threads.push_back(std::thread(Generator_SetTerminate_Call, generator.get())); @@ -325,9 +325,9 @@ TEST(CAPITests, SetTerminate) { std::cout << "Waiting for threads completion" << std::endl; th.join(); // Wait for each thread to finish } - EXPECT_EQ(generator.IsSessionTerminated(), true); + EXPECT_EQ(generator.get()->IsSessionTerminated(), true); generator->UnsetTerminate(); - EXPECT_EQ(generator.IsSessionTerminated(), false); + EXPECT_EQ(generator.get()->IsSessionTerminated(), false); #endif } From fd4f11c230b67b9c7715dd54c64213ff04c5af55 Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Mon, 21 Oct 2024 05:38:19 -0700 Subject: [PATCH 14/17] add additional functions --- src/ort_genai.h | 4 ++++ src/ort_genai_c.cpp | 4 ++++ src/ort_genai_c.h | 1 + test/c_api_tests.cpp | 6 +++--- 4 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/ort_genai.h b/src/ort_genai.h index 0588ff4fb..883993952 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -236,6 +236,10 @@ struct OgaGenerator : OgaAbstract { return OgaGenerator_IsDone(this); } + bool IsSessionTerminated() const { + return OgaGenerator_IsSessionTerminated(this); + } + void ComputeLogits() { OgaCheckResult(OgaGenerator_ComputeLogits(this)); } diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index acd37ea0f..b29a4e9d6 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -251,6 +251,10 @@ bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator) { return reinterpret_cast(generator)->IsDone(); } +bool OGA_API_CALL OgaGenerator_IsSessionTerminated(const OgaGenerator* generator) { + return reinterpret_cast(generator)->IsSessionTerminated(); +} + OgaResult* OGA_API_CALL OgaGenerator_ComputeLogits(OgaGenerator* generator) { OGA_TRY reinterpret_cast(generator)->ComputeLogits(); diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index e6453d3a3..4a3e88a61 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -246,6 +246,7 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyGenerator(OgaGenerator* generator); * \return True if the generator has finished generating all the sequences, false otherwise. */ OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsDone(const OgaGenerator* generator); +OGA_EXPORT bool OGA_API_CALL OgaGenerator_IsSessionTerminated(const OgaGenerator* generator); /* * \brief Computes the logits from the model based on the input ids and the past state. The computed logits are stored in the generator. diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index 3847b1a79..ebe9d6f55 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -316,7 +316,7 @@ TEST(CAPITests, SetTerminate) { params->SetSearchOption("max_length", 40); auto generator = OgaGenerator::Create(*model, *params); - EXPECT_EQ(generator.get()->IsSessionTerminated(), false); + EXPECT_EQ(generator->IsSessionTerminated(), false) std::vector threads; threads.push_back(std::thread(Generate_Output, generator.get(), std::move(tokenizer_stream))); threads.push_back(std::thread(Generator_SetTerminate_Call, generator.get())); @@ -325,9 +325,9 @@ TEST(CAPITests, SetTerminate) { std::cout << "Waiting for threads completion" << std::endl; th.join(); // Wait for each thread to finish } - EXPECT_EQ(generator.get()->IsSessionTerminated(), true); + EXPECT_EQ(generator->IsSessionTerminated(), true); generator->UnsetTerminate(); - EXPECT_EQ(generator.get()->IsSessionTerminated(), false); + EXPECT_EQ(generator->IsSessionTerminated(), false); #endif } From 9f81bb14a7a4744f28481c7912f1d74f4ad97fb7 Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Mon, 21 Oct 2024 06:01:04 -0700 Subject: [PATCH 15/17] bug fix --- test/c_api_tests.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index ebe9d6f55..d1147e7f0 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -316,7 +316,7 @@ TEST(CAPITests, SetTerminate) { params->SetSearchOption("max_length", 40); auto generator = OgaGenerator::Create(*model, *params); - EXPECT_EQ(generator->IsSessionTerminated(), false) + EXPECT_EQ(generator->IsSessionTerminated(), false); std::vector threads; threads.push_back(std::thread(Generate_Output, generator.get(), std::move(tokenizer_stream))); threads.push_back(std::thread(Generator_SetTerminate_Call, generator.get())); From 936e0227b9fcc2453794eb991052b3a0d81e6cad Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Wed, 23 Oct 2024 10:23:04 -0700 Subject: [PATCH 16/17] use lambda fn --- test/c_api_tests.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index d1147e7f0..3c635de87 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -284,9 +284,6 @@ TEST(CAPITests, GetOutputCAPI) { } #if TEST_PHI2 -void Generator_SetTerminate_Call(OgaGenerator* generator) { - generator->SetTerminate(); -} void Generate_Output(OgaGenerator* generator, std::unique_ptr tokenizer_stream) { try{ @@ -304,6 +301,10 @@ void Generate_Output(OgaGenerator* generator, std::unique_ptrSetTerminate(); + }; + auto model = OgaModel::Create(PHI2_PATH); auto tokenizer = OgaTokenizer::Create(*model); auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer); From f75e1802e2d536c1969d5caf4670fccfc1f1b46e Mon Sep 17 00:00:00 2001 From: Abhishek Jindal Date: Wed, 23 Oct 2024 11:03:06 -0700 Subject: [PATCH 17/17] use lambda fns --- test/c_api_tests.cpp | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/test/c_api_tests.cpp b/test/c_api_tests.cpp index 3c635de87..aad1660fc 100644 --- a/test/c_api_tests.cpp +++ b/test/c_api_tests.cpp @@ -301,10 +301,22 @@ void Generate_Output(OgaGenerator* generator, std::unique_ptrSetTerminate(); }; + auto GenerateOutput = [](OgaGenerator* generator, std::unique_ptr tokenizer_stream) { + try { + while (!generator->IsDone()) { + generator->ComputeLogits(); + generator->GenerateNextToken(); + } + } + catch (const std::exception& e) { + std::cout << "Session Terminated: " << e.what() << std::endl; + } + }; + auto model = OgaModel::Create(PHI2_PATH); auto tokenizer = OgaTokenizer::Create(*model); auto tokenizer_stream = OgaTokenizerStream::Create(*tokenizer); @@ -319,8 +331,8 @@ TEST(CAPITests, SetTerminate) { auto generator = OgaGenerator::Create(*model, *params); EXPECT_EQ(generator->IsSessionTerminated(), false); std::vector threads; - threads.push_back(std::thread(Generate_Output, generator.get(), std::move(tokenizer_stream))); - threads.push_back(std::thread(Generator_SetTerminate_Call, generator.get())); + threads.push_back(std::thread(GenerateOutput, generator.get(), std::move(tokenizer_stream))); + threads.push_back(std::thread(GeneratorSetTerminateCall, generator.get())); for (auto& th : threads) { std::cout << "Waiting for threads completion" << std::endl;