Skip to content

Commit

Permalink
Remove data race in compose event
Browse files Browse the repository at this point in the history
  • Loading branch information
guusw committed Apr 29, 2024
1 parent e4d553a commit 7278e55
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 58 deletions.
35 changes: 21 additions & 14 deletions shards/core/foundation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,6 @@ struct SHTableImpl : public SHAlignedMap<shards::OwnedVar, shards::OwnedVar> {

struct SHWire : public std::enable_shared_from_this<SHWire> {
enum State { Stopped, Prepared, Starting, Iterating, IterationEnded, Failed, Ended };

struct OnComposedEvent {
const SHWire *wire;
};

struct OnStartEvent {
const SHWire *wire;
};
Expand All @@ -272,6 +267,20 @@ struct SHWire : public std::enable_shared_from_this<SHWire> {
const SHWire *childWire;
};

// Storage of data used only during compose
struct ComposeData {
// List of output types used for this wire
std::vector<SHTypeInfo> outputTypes;
};
std::shared_ptr<ComposeData> composeData;

ComposeData &getComposeData() {
if (!composeData) {
composeData = std::make_shared<ComposeData>();
}
return *composeData;
}

// Attributes
bool looped{false};
bool unsafe{false};
Expand Down Expand Up @@ -1594,21 +1603,19 @@ inline void collectAllRequiredVariables(const SHExposedTypesInfo &exposed, Expos
}

inline SHStringWithLen swlDuplicate(SHStringWithLen in) {
if(in.len == 0) {
if (in.len == 0) {
return SHStringWithLen{};
}
SHStringWithLen cpy{
.string = new char[in.len],
.len = in.len,
.string = new char[in.len],
.len = in.len,
};
memcpy(const_cast<char*>(cpy.string), in.string, in.len);
memcpy(const_cast<char *>(cpy.string), in.string, in.len);
return cpy;
}
inline SHStringWithLen swlFromStringView(std::string_view in) {
return swlDuplicate(toSWL(in));
}
inline void swlFree(SHStringWithLen& in) {
if(in.len > 0) {
inline SHStringWithLen swlFromStringView(std::string_view in) { return swlDuplicate(toSWL(in)); }
inline void swlFree(SHStringWithLen &in) {
if (in.len > 0) {
delete[] in.string;
in.string = nullptr;
in.len = 0;
Expand Down
32 changes: 20 additions & 12 deletions shards/core/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "hash.inl"
#include "utils.hpp"
#include "trait.hpp"
#include "type_cache.hpp"

#ifdef SH_COMPRESSED_STRINGS
#include <shards/wire_dsl.hpp>
Expand Down Expand Up @@ -1033,14 +1034,8 @@ SHComposeResult composeWire(const std::vector<Shard *> &wire, SHValidationCallba
if (extVar.type) {
type = extVar.type;
} else {
auto hash = deriveTypeHash64(var);
TypeInfo *info = nullptr;
if (ctx.wire->typesCache.find(hash) == ctx.wire->typesCache.end()) {
info = &ctx.wire->typesCache.emplace(hash, TypeInfo(var, data)).first->second;
} else {
info = &ctx.wire->typesCache.at(hash);
}
type = &(const SHTypeInfo &)*info;
static TypeCache typeCache;
type = &typeCache.insertUnique(TypeInfo(var, data));
}

SHExposedTypeInfo expInfo{key.payload.stringValue, {}, *type, true /* mutable */};
Expand Down Expand Up @@ -1140,15 +1135,17 @@ void validateWireTraits(const SHWire *wire, const SHComposeResult &cr) {
}
}

SHComposeResult composeWire(const SHWire *wire, SHValidationCallback callback, void *userData, SHInstanceData data) {
SHComposeResult composeWire(const SHWire *wire_, SHValidationCallback callback, void *userData, SHInstanceData data) {
SHWire *wire = const_cast<SHWire *>(wire_);

// compare exchange and then shassert we were not composing
bool expected = false;
if (!const_cast<SHWire *>(wire)->composing.compare_exchange_strong(expected, true)) {
if (!wire->composing.compare_exchange_strong(expected, true)) {
SHLOG_ERROR("Wire {} is already being composed", wire->name);
throw ComposeError("Wire is already being composed");
}
// defer reset compose state
DEFER(const_cast<SHWire *>(wire)->composing.store(false));
DEFER(wire->composing.store(false));

// settle input type of wire before compose
if (wire->shards.size() > 0 && strncmp(wire->shards[0]->name(wire->shards[0]), "Expect", 6) == 0) {
Expand Down Expand Up @@ -1182,7 +1179,18 @@ SHComposeResult composeWire(const SHWire *wire, SHValidationCallback callback, v
// set output type
wire->outputType = res.outputType;

wire->mesh.lock()->dispatcher.trigger(SHWire::OnComposedEvent{wire});
// validate wire output types for additional return paths
if (wire->composeData) {
auto &cd = *wire->composeData.get();
DEFER({ wire->composeData.reset(); });
for (auto &type : cd.outputTypes) {
if (!matchTypes(res.outputType, type, true, true, true)) {
std::string err =
fmt::format("Possible output {} does not match main output type: {} for wire {}", type, res.outputType, wire->name);
throw ComposeError(err);
}
}
}

return res;
}
Expand Down
1 change: 1 addition & 0 deletions shards/core/type_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define D0E98483_F0A4_476F_81F3_8709C5D852D7

#include "hash.inl"
#include "foundation.hpp"
#include <map>
#include <shared_mutex>

Expand Down
34 changes: 2 additions & 32 deletions shards/modules/core/wires.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -555,29 +555,14 @@ struct StopWire : public WireBase {

SHTypeInfo _inputType{};

std::weak_ptr<SHMesh> _mesh;
entt::connection _onComposedConn;

void destroy() {
auto mesh = _mesh.lock();
if (mesh && _onComposedConn)
_onComposedConn.release();
}
std::weak_ptr<SHWire> _wire;

SHTypeInfo compose(SHInstanceData &data) {
assert(data.wire);

if (wireref->valueType == SHType::None) {
_inputType = data.inputType;

if (_onComposedConn)
_onComposedConn.release();

_mesh = data.wire->mesh;
auto mesh = _mesh.lock();
if (mesh) {
_onComposedConn = mesh->dispatcher.sink<SHWire::OnComposedEvent>().connect<&StopWire::composed>(this);
}
data.wire->getComposeData().outputTypes.push_back(data.inputType);
} else {
resolveWire();
if (wire) {
Expand All @@ -589,21 +574,6 @@ struct StopWire : public WireBase {
return data.inputType;
}

void composed(const SHWire::OnComposedEvent &e) {
if (e.wire != wire.get())
return;

// this check runs only when (Stop) is called without any params!
// meaning it's stopping the wire it is in
if (!wire && wireref->valueType == SHType::None && !matchTypes(_inputType, e.wire->outputType, false, true, true)) {
SHLOG_ERROR("Stop input and wire output type mismatch, Stop input must "
"be the same type of the wire's output (regular flow), "
"wire: {} expected: {}",
e.wire->name, e.wire->outputType);
throw ComposeError("Stop input and wire output type mismatch");
}
}

void cleanup(SHContext *context) {
if (wireref.isVariable())
wire = nullptr;
Expand Down

0 comments on commit 7278e55

Please sign in to comment.