From c0db1587c7f39e06334b2cf25ae49e9de23cba5a Mon Sep 17 00:00:00 2001 From: Giovanni Petrantoni <7008900+sinkingsugar@users.noreply.github.com> Date: Thu, 31 Oct 2024 11:40:09 +0800 Subject: [PATCH] Add atomic flag management with SetFlag and GetFlag shards Introduce AtomicFlag struct with atomic bool for flag state management. Implement `SetFlag` shard to set flag value based on input and `GetFlag` shard to retrieve the current flag state. Integrate shared flag access using getFlag function with thread-safe handling. Register new shards in channels module. --- shards/modules/channels/channels.cpp | 70 ++++++++++++++++++++++++++++ shards/modules/channels/channels.hpp | 8 ++++ shards/tests/hello.shs | 3 ++ 3 files changed, 81 insertions(+) diff --git a/shards/modules/channels/channels.cpp b/shards/modules/channels/channels.cpp index ced108c6ca..1495c3944a 100644 --- a/shards/modules/channels/channels.cpp +++ b/shards/modules/channels/channels.cpp @@ -511,8 +511,76 @@ struct Flush : public Base { return input; } }; + +FlagPtr getFlag(const std::string &name) { + static std::unordered_map> flags; + static std::shared_mutex mutex; + + std::shared_lock _l(mutex); + auto it = flags.find(name); + if (it == flags.end()) { + _l.unlock(); + std::scoped_lock _l1(mutex); + auto sp = std::make_shared(); + flags[name] = sp; + return sp; + } else { + std::shared_ptr sp = it->second.lock(); + if (!sp) { + _l.unlock(); + std::scoped_lock _l1(mutex); + sp = std::make_shared(); + flags[name] = sp; + } + return sp; + } +} + +// Set flag value shard +struct SetFlag : public Base { + FlagPtr _flag; + + static inline Parameters setFlagParams{ + {"Name", SHCCSTR("The name of the flag."), {CoreInfo::StringType}}, + }; + + static SHTypesInfo inputTypes() { return CoreInfo::BoolType; } + static SHTypesInfo outputTypes() { return CoreInfo::BoolType; } + static SHParametersInfo parameters() { return setFlagParams; } + + SHTypeInfo compose(const SHInstanceData &data) { + _flag = getFlag(_name); + return data.inputType; + } + + SHVar activate(SHContext *context, const SHVar &input) { + _flag->value.store(input.payload.boolValue); + return input; + } +}; + +// Get flag value shard +struct GetFlag : public Base { + FlagPtr _flag; + + static inline Parameters getFlagParams{ + {"Name", SHCCSTR("The name of the flag."), {CoreInfo::StringType}}, + }; + + static SHTypesInfo inputTypes() { return CoreInfo::AnyType; } + static SHTypesInfo outputTypes() { return CoreInfo::BoolType; } + static SHParametersInfo parameters() { return getFlagParams; } + + SHTypeInfo compose(const SHInstanceData &data) { + _flag = getFlag(_name); + return CoreInfo::BoolType; + } + + SHVar activate(SHContext *context, const SHVar &input) { return Var(_flag->value.load()); } +}; } // namespace channels } // namespace shards + SHARDS_REGISTER_FN(channels) { using namespace shards::channels; REGISTER_SHARD("Produce", Produce); @@ -521,4 +589,6 @@ SHARDS_REGISTER_FN(channels) { REGISTER_SHARD("Listen", Listen); REGISTER_SHARD("Complete", Complete); REGISTER_SHARD("Flush", Flush); + REGISTER_SHARD("SetFlag", SetFlag); + REGISTER_SHARD("GetFlag", GetFlag); } diff --git a/shards/modules/channels/channels.hpp b/shards/modules/channels/channels.hpp index c7fabaed3a..290202f2fb 100644 --- a/shards/modules/channels/channels.hpp +++ b/shards/modules/channels/channels.hpp @@ -98,6 +98,14 @@ class BroadcastChannel : public ChannelShared { using Channel = std::variant; std::shared_ptr get(const std::string &name); +// Atomic flag wrapper +struct AtomicFlag { + std::atomic_bool value{false}; +}; + +using FlagPtr = std::shared_ptr; +FlagPtr getFlag(const std::string &name); + } // namespace channels } // namespace shards diff --git a/shards/tests/hello.shs b/shards/tests/hello.shs index e51ae5aba4..fc2fe63737 100644 --- a/shards/tests/hello.shs +++ b/shards/tests/hello.shs @@ -247,3 +247,6 @@ ToJson | Yaml.FromJson | Log sorted-table-test | First | ExpectSeq | Assert.Is([["A" 1] 1]) sorted-table-test | Last | ExpectSeq | Assert.Is([["E"] 6]) + +true | SetFlag("test-flag") +GetFlag("test-flag") | Log | Assert.Is(true)