From 520bf8f2a39ee106885da6e3214ef294032fbea3 Mon Sep 17 00:00:00 2001 From: Michael Demoret Date: Tue, 22 Oct 2024 00:09:36 -0400 Subject: [PATCH] Adding more tests --- .../mrc/node/operators/combine_latest.hpp | 10 +- .../mrc/node/operators/with_latest_from.hpp | 10 +- cpp/mrc/include/mrc/node/operators/zip.hpp | 9 +- cpp/mrc/include/mrc/segment/object.hpp | 108 --------- cpp/mrc/include/mrc/utils/tuple_utils.hpp | 31 +++ .../pymrc/utilities/function_wrappers.hpp | 9 +- python/mrc/core/node.cpp | 213 +++++++++++++++--- python/tests/test_edges.py | 87 ++++++- 8 files changed, 333 insertions(+), 144 deletions(-) diff --git a/cpp/mrc/include/mrc/node/operators/combine_latest.hpp b/cpp/mrc/include/mrc/node/operators/combine_latest.hpp index fae95dd93..5c6788cdd 100644 --- a/cpp/mrc/include/mrc/node/operators/combine_latest.hpp +++ b/cpp/mrc/include/mrc/node/operators/combine_latest.hpp @@ -34,12 +34,20 @@ namespace mrc::node { +class CombineLatestTypelessBase +{ + public: + virtual ~CombineLatestTypelessBase() = default; +}; + template class CombineLatestBase; template class CombineLatestBase, OutputT> - : public WritableAcceptor, public HeterogeneousNodeParent...> + : public CombineLatestTypelessBase, + public WritableAcceptor, + public HeterogeneousNodeParent...> { template static auto build_ingress(CombineLatestBase* self, std::index_sequence /*unused*/) diff --git a/cpp/mrc/include/mrc/node/operators/with_latest_from.hpp b/cpp/mrc/include/mrc/node/operators/with_latest_from.hpp index 25ae7fd15..305371186 100644 --- a/cpp/mrc/include/mrc/node/operators/with_latest_from.hpp +++ b/cpp/mrc/include/mrc/node/operators/with_latest_from.hpp @@ -37,13 +37,21 @@ namespace mrc::node { +class WithLatestFromTypelessBase +{ + public: + virtual ~WithLatestFromTypelessBase() = default; +}; + template class WithLatestFromBase {}; template class WithLatestFromBase, OutputT> - : public WritableAcceptor, public HeterogeneousNodeParent...> + : public WithLatestFromTypelessBase, + public WritableAcceptor, + public HeterogeneousNodeParent...> { public: using input_tuple_t = std::tuple; diff --git a/cpp/mrc/include/mrc/node/operators/zip.hpp b/cpp/mrc/include/mrc/node/operators/zip.hpp index ddd75736e..afc01350a 100644 --- a/cpp/mrc/include/mrc/node/operators/zip.hpp +++ b/cpp/mrc/include/mrc/node/operators/zip.hpp @@ -45,12 +45,19 @@ namespace mrc::node { +class ZipTypelessBase +{ + public: + virtual ~ZipTypelessBase() = default; +}; + template class ZipBase {}; template -class ZipBase, OutputT> : public WritableAcceptor, +class ZipBase, OutputT> : public ZipTypelessBase, + public WritableAcceptor, public HeterogeneousNodeParent...> { public: diff --git a/cpp/mrc/include/mrc/segment/object.hpp b/cpp/mrc/include/mrc/segment/object.hpp index aa20e5446..8ba06e7b6 100644 --- a/cpp/mrc/include/mrc/segment/object.hpp +++ b/cpp/mrc/include/mrc/segment/object.hpp @@ -57,11 +57,6 @@ struct ObjectPropertiesState const bool is_sink; const bool is_source; - // std::optional sink_type = std::nullopt; - // std::optional sink_type_no_holder = std::nullopt; - // std::optional source_type = std::nullopt; - // std::optional source_type_no_holder = std::nullopt; - const bool is_writable_acceptor; const bool is_writable_provider; const bool is_readable_acceptor; @@ -325,38 +320,17 @@ template class Object : public virtual ObjectProperties, public std::enable_shared_from_this> { public: - // Object(const Object& other) : m_name(other.m_name), m_launch_options(other.m_launch_options) {} - // Object(Object&&) = delete; - // Object& operator=(const Object&) = delete; - // Object& operator=(Object&&) = delete; - ObjectT& object(); const ObjectT& object() const; - // std::string name() const final; - // std::string type_name() const final; - - // bool is_source() const final; - // bool is_sink() const final; - std::type_index sink_type(bool ignore_holder) const final; std::type_index source_type(bool ignore_holder) const final; - // bool is_writable_acceptor() const final; - // bool is_writable_provider() const final; - // bool is_readable_acceptor() const final; - // bool is_readable_provider() const final; - edge::IWritableAcceptorBase& writable_acceptor_base() final; edge::IWritableProviderBase& writable_provider_base() final; edge::IReadableAcceptorBase& readable_acceptor_base() final; edge::IReadableProviderBase& readable_provider_base() final; - // bool is_runnable() const final - // { - // return static_cast(std::is_base_of_v); - // } - runnable::LaunchOptions& launch_options() final { if (!is_runnable()) @@ -467,9 +441,6 @@ class Object : public virtual ObjectProperties, public std::enable_shared_from_t return *m_state; } - // // Move to protected to allow only the IBuilder to set the name - // void set_name(const std::string& name) override; - private: virtual ObjectT* get_object() const = 0; @@ -515,7 +486,6 @@ class Object : public virtual ObjectProperties, public std::enable_shared_from_t runnable::LaunchOptions m_launch_options; std::map> m_children; - // std::map()>> m_create_children_fns; // Allows converting to base classes template @@ -548,36 +518,6 @@ const ObjectT& Object::object() const return *node; } -// template -// void Object::set_name(const std::string& name) -// { -// m_name = name; -// } - -// template -// std::string Object::name() const -// { -// return m_name; -// } - -// template -// std::string Object::type_name() const -// { -// return std::string(::mrc::type_name()); -// } - -// template -// bool Object::is_source() const -// { -// return std::is_base_of_v; -// } - -// template -// bool Object::is_sink() const -// { -// return std::is_base_of_v; -// } - template std::type_index Object::sink_type(bool ignore_holder) const { @@ -602,39 +542,9 @@ std::type_index Object::source_type(bool ignore_holder) const return base->source_type(ignore_holder); } -// template -// bool Object::is_writable_acceptor() const -// { -// return std::is_base_of_v; -// } - -// template -// bool Object::is_writable_provider() const -// { -// return std::is_base_of_v; -// } - -// template -// bool Object::is_readable_acceptor() const -// { -// return std::is_base_of_v; -// } - -// template -// bool Object::is_readable_provider() const -// { -// return std::is_base_of_v; -// } - template edge::IWritableAcceptorBase& Object::writable_acceptor_base() { - // if constexpr (!std::is_base_of_v) - // { - // LOG(ERROR) << type_name() << " is not a IIngressAcceptorBase"; - // throw exceptions::MrcRuntimeError("Object is not a IIngressAcceptorBase"); - // } - auto* base = dynamic_cast(get_object()); CHECK(base) << type_name() << " is not a IIngressAcceptorBase"; return *base; @@ -643,12 +553,6 @@ edge::IWritableAcceptorBase& Object::writable_acceptor_base() template edge::IWritableProviderBase& Object::writable_provider_base() { - // if constexpr (!std::is_base_of_v) - // { - // LOG(ERROR) << type_name() << " is not a IIngressProviderBase"; - // throw exceptions::MrcRuntimeError("Object is not a IIngressProviderBase"); - // } - auto* base = dynamic_cast(get_object()); CHECK(base) << type_name() << " is not a IWritableProviderBase"; return *base; @@ -657,12 +561,6 @@ edge::IWritableProviderBase& Object::writable_provider_base() template edge::IReadableAcceptorBase& Object::readable_acceptor_base() { - // if constexpr (!std::is_base_of_v) - // { - // LOG(ERROR) << type_name() << " is not a IEgressAcceptorBase"; - // throw exceptions::MrcRuntimeError("Object is not a IEgressAcceptorBase"); - // } - auto* base = dynamic_cast(get_object()); CHECK(base) << type_name() << " is not a IReadableAcceptorBase"; return *base; @@ -671,12 +569,6 @@ edge::IReadableAcceptorBase& Object::readable_acceptor_base() template edge::IReadableProviderBase& Object::readable_provider_base() { - // if constexpr (!std::is_base_of_v) - // { - // LOG(ERROR) << type_name() << " is not a IEgressProviderBase"; - // throw exceptions::MrcRuntimeError("Object is not a IEgressProviderBase"); - // } - auto* base = dynamic_cast(get_object()); CHECK(base) << type_name() << " is not a IReadableProviderBase"; return *base; diff --git a/cpp/mrc/include/mrc/utils/tuple_utils.hpp b/cpp/mrc/include/mrc/utils/tuple_utils.hpp index edf0f2e9d..14abaed91 100644 --- a/cpp/mrc/include/mrc/utils/tuple_utils.hpp +++ b/cpp/mrc/include/mrc/utils/tuple_utils.hpp @@ -17,6 +17,7 @@ #pragma once +#include #include namespace mrc::utils { @@ -64,4 +65,34 @@ void tuple_for_each(TupleT&& tuple, FuncT&& f) std::forward(f), std::make_index_sequence>::value>()); } + +/** + * @brief Creates a tuple of N elements of type T. For example, `repeat_tuple_type` would be `std::tuple` + * + * @tparam T The type of the tuple + * @tparam N The number of elements in the tuple + */ +template +class repeat_tuple_type +{ + template > + struct impl; + + template + struct impl> + { + template + using wrap = T; + + using type = std::tuple...>; + }; + + public: + using type = typename impl<>::type; +}; + +template +using repeat_tuple_type_t = typename repeat_tuple_type::type; + } // namespace mrc::utils diff --git a/python/mrc/_pymrc/include/pymrc/utilities/function_wrappers.hpp b/python/mrc/_pymrc/include/pymrc/utilities/function_wrappers.hpp index 8a68b4e8b..0c8972552 100644 --- a/python/mrc/_pymrc/include/pymrc/utilities/function_wrappers.hpp +++ b/python/mrc/_pymrc/include/pymrc/utilities/function_wrappers.hpp @@ -105,6 +105,11 @@ struct PyFuncHolder return m_cpp_fn(std::forward(args)...); } + operator bool() const + { + return !m_is_none; + } + static constexpr auto Signature = pybind11::detail::_("Callable[[") + pybind11::detail::concat(pybind11::detail::make_caster::name...) + pybind11::detail::_("], ") + pybind11::detail::make_caster::name + @@ -155,7 +160,8 @@ struct PyFuncHolder // Save the name of the function to help debugging if (py_fn) { - m_repr = pybind11::str(py_fn); + m_repr = pybind11::str(py_fn); + m_is_none = false; } m_cpp_fn = this->build_cpp_function(std::move(py_fn)); @@ -163,6 +169,7 @@ struct PyFuncHolder cpp_fn_t m_cpp_fn; std::string m_repr; + bool m_is_none{true}; }; struct OnNextFunction : public PyFuncHolder diff --git a/python/mrc/core/node.cpp b/python/mrc/core/node.cpp index 0653838f0..f93b0e308 100644 --- a/python/mrc/core/node.cpp +++ b/python/mrc/core/node.cpp @@ -22,11 +22,14 @@ #include "pymrc/utils.hpp" #include "mrc/node/operators/broadcast.hpp" +#include "mrc/node/operators/combine_latest.hpp" #include "mrc/node/operators/round_robin_router_typeless.hpp" +#include "mrc/node/operators/with_latest_from.hpp" #include "mrc/node/operators/zip.hpp" #include "mrc/segment/builder.hpp" #include "mrc/segment/object.hpp" #include "mrc/utils/string_utils.hpp" +#include "mrc/utils/tuple_utils.hpp" #include "mrc/version.hpp" #include @@ -74,33 +77,189 @@ PYBIND11_MODULE(node, py_mod) return node; })); - py::class_< - mrc::segment::Object, PyObjectHolder>>, - mrc::segment::ObjectProperties, - std::shared_ptr, PyObjectHolder>>>>(py_mod, "Zip") - .def(py::init<>([](mrc::segment::IBuilder& builder, std::string name, size_t count) { - if (count == 2) - { - return builder.construct_object< - node::ZipTransformComponent, PyObjectHolder>>( - name, - [](std::tuple&& input_data) { - py::gil_scoped_acquire gil; - - return PyObjectHolder(py::cast(std::move(input_data))); - }); - } - - py::print("Unsupported count!"); - throw std::runtime_error("Unsupported count!"); - })) - .def("get_sink", - [](mrc::segment::Object< - node::ZipTransformComponent, PyObjectHolder>>& self, - size_t index) { - return self.get_child(MRC_CONCAT_STR("sink[" << index << "]")); - }); + py::class_, + mrc::segment::ObjectProperties, + std::shared_ptr>>(py_mod, "ZipComponent") + .def( + py::init<>([](mrc::segment::IBuilder& builder, + std::string name, + size_t count, + PyFuncHolder convert_fn) { + std::function convert_fn_wrapped = + [convert_fn = std::move(convert_fn)](py::tuple input_data) { + if (convert_fn) + { + return PyObjectHolder(convert_fn(std::move(input_data))); + } + + return PyObjectHolder(std::move(input_data)); + }; + + auto make_node = [&builder, + convert_fn_wrapped = std::move(convert_fn_wrapped)](std::string name) { + return builder + .construct_object< + node::ZipTransformComponent, PyObjectHolder>>( + name, + [convert_fn_wrapped = std::move(convert_fn_wrapped)]( + utils::repeat_tuple_type_t&& input_data) { + py::gil_scoped_acquire gil; + + return convert_fn_wrapped(py::cast(std::move(input_data))); + }) + ->template as(); + }; + + if (count == 1) + { + return make_node.template operator()<1>(name); + } + else if (count == 2) + { + return make_node.template operator()<2>(name); + } + else if (count == 3) + { + return make_node.template operator()<3>(name); + } + else if (count == 4) + { + return make_node.template operator()<4>(name); + } + + throw std::runtime_error("Unsupported count!"); + }), + py::arg("builder"), + py::arg("name"), + py::kw_only(), + py::arg("count"), + py::arg("convert_fn") = py::none()) + .def("get_sink", [](mrc::segment::Object& self, size_t index) { + return self.get_child(MRC_CONCAT_STR("sink[" << index << "]")); + }); + + py::class_, + mrc::segment::ObjectProperties, + std::shared_ptr>>(py_mod, + "WithLatestFromComponent") + .def(py::init<>([](mrc::segment::IBuilder& builder, + std::string name, + size_t count, + PyFuncHolder convert_fn) { + std::function convert_fn_wrapped = + [convert_fn = std::move(convert_fn)](py::tuple input_data) { + if (convert_fn) + { + return PyObjectHolder(convert_fn(std::move(input_data))); + } + + return PyObjectHolder(std::move(input_data)); + }; + + auto make_node = [&builder, + convert_fn_wrapped = std::move(convert_fn_wrapped)](std::string name) { + return builder + .construct_object< + node::WithLatestFromTransformComponent, + PyObjectHolder>>( + name, + [convert_fn_wrapped = std::move(convert_fn_wrapped)]( + utils::repeat_tuple_type_t&& input_data) { + py::gil_scoped_acquire gil; + + return convert_fn_wrapped(py::cast(std::move(input_data))); + }) + ->template as(); + }; + + if (count == 1) + { + return make_node.template operator()<1>(name); + } + else if (count == 2) + { + return make_node.template operator()<2>(name); + } + else if (count == 3) + { + return make_node.template operator()<3>(name); + } + else if (count == 4) + { + return make_node.template operator()<4>(name); + } + + throw std::runtime_error("Unsupported count!"); + }), + py::arg("builder"), + py::arg("name"), + py::kw_only(), + py::arg("count"), + py::arg("convert_fn") = py::none()) + .def("get_sink", [](mrc::segment::Object& self, size_t index) { + return self.get_child(MRC_CONCAT_STR("sink[" << index << "]")); + }); + + py::class_, + mrc::segment::ObjectProperties, + std::shared_ptr>>(py_mod, "CombineLatestComponent") + .def(py::init<>([](mrc::segment::IBuilder& builder, + std::string name, + size_t count, + PyFuncHolder convert_fn) { + std::function convert_fn_wrapped = + [convert_fn = std::move(convert_fn)](py::tuple input_data) { + if (convert_fn) + { + return PyObjectHolder(convert_fn(std::move(input_data))); + } + + return PyObjectHolder(std::move(input_data)); + }; + + auto make_node = [&builder, + convert_fn_wrapped = std::move(convert_fn_wrapped)](std::string name) { + return builder + .construct_object< + node::CombineLatestTransformComponent, + PyObjectHolder>>( + name, + [convert_fn_wrapped = std::move(convert_fn_wrapped)]( + utils::repeat_tuple_type_t&& input_data) { + py::gil_scoped_acquire gil; + + return convert_fn_wrapped(py::cast(std::move(input_data))); + }) + ->template as(); + }; + + if (count == 1) + { + return make_node.template operator()<1>(name); + } + else if (count == 2) + { + return make_node.template operator()<2>(name); + } + else if (count == 3) + { + return make_node.template operator()<3>(name); + } + else if (count == 4) + { + return make_node.template operator()<4>(name); + } + + throw std::runtime_error("Unsupported count!"); + }), + py::arg("builder"), + py::arg("name"), + py::kw_only(), + py::arg("count"), + py::arg("convert_fn") = py::none()) + .def("get_sink", [](mrc::segment::Object& self, size_t index) { + return self.get_child(MRC_CONCAT_STR("sink[" << index << "]")); + }); py::class_>, mrc::segment::ObjectProperties, diff --git a/python/tests/test_edges.py b/python/tests/test_edges.py index 28c247c7e..2b5296b29 100644 --- a/python/tests/test_edges.py +++ b/python/tests/test_edges.py @@ -284,7 +284,51 @@ def add_round_robin_router(seg: mrc.Builder, *upstream: mrc.SegmentObject): def add_zip(seg: mrc.Builder, *upstream: mrc.SegmentObject): - node = mrc.core.node.Zip(seg, "Zip", len(upstream)) + node_name = "Zip" + + expected_node_counts.update({f"{node_name}.{k}": v for k, v in {"convert": 5}.items()}) + + def convert_fn(x): + increment_node_counter(f"{node_name}.convert") + return x + + node = mrc.core.node.ZipComponent(seg, node_name, count=len(upstream), convert_fn=convert_fn) + + for i, u in enumerate(upstream): + seg.make_edge(u, node.get_sink(i)) + + return node + + +def add_combine_latest(seg: mrc.Builder, *upstream: mrc.SegmentObject): + + node_name = "CombineLatest" + + expected_node_counts.update({f"{node_name}.{k}": v for k, v in {"convert": 5}.items()}) + + def convert_fn(x): + increment_node_counter(f"{node_name}.convert") + return x + + node = mrc.core.node.CombineLatestComponent(seg, node_name, count=len(upstream), convert_fn=convert_fn) + + for i, u in enumerate(upstream): + seg.make_edge(u, node.get_sink(i)) + + return node + + +def add_with_latest_from(seg: mrc.Builder, *upstream: mrc.SegmentObject): + + node_name = "WithLatestFrom" + + expected_node_counts.update({f"{node_name}.{k}": v for k, v in {"convert": 5}.items()}) + + def convert_fn(x): + increment_node_counter(f"{node_name}.convert") + return x + + node = mrc.core.node.WithLatestFromComponent(seg, node_name, count=len(upstream), convert_fn=convert_fn) for i, u in enumerate(upstream): seg.make_edge(u, node.get_sink(i)) @@ -672,13 +716,46 @@ def segment_init(seg: mrc.Builder): @pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"]) -def test_multi_source_to_zip_to_sink(run_segment, source_cpp: bool): +@pytest.mark.parametrize("upstream_count", range(1, 4), ids=[f"upstream_count_{i}" for i in range(1, 4)]) +def test_multi_source_to_zip_to_sink(run_segment, source_cpp: bool, upstream_count: int): def segment_init(seg: mrc.Builder): - source1 = add_source(seg, is_cpp=source_cpp, data_type=m.Base, is_component=False, suffix="1") - source2 = add_source(seg, is_cpp=source_cpp, data_type=m.Base, is_component=False, suffix="2") - zip = add_zip(seg, source1, source2) + sources = (add_source(seg, is_cpp=source_cpp, data_type=m.Base, is_component=False, suffix=str(i)) + for i in range(upstream_count)) + zip = add_zip(seg, *sources) + add_sink(seg, zip, is_cpp=False, data_type=tuple, is_component=False) + + results = run_segment(segment_init) + + assert results == expected_node_counts + + +@pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"]) +@pytest.mark.parametrize("upstream_count", range(1, 4), ids=[f"upstream_count_{i}" for i in range(1, 4)]) +def test_multi_source_to_combine_latest_to_sink(run_segment, source_cpp: bool, upstream_count: int): + + def segment_init(seg: mrc.Builder): + + sources = (add_source(seg, is_cpp=source_cpp, data_type=m.Base, is_component=False, suffix=str(i)) + for i in range(upstream_count)) + zip = add_combine_latest(seg, *sources) + add_sink(seg, zip, is_cpp=False, data_type=tuple, is_component=False) + + results = run_segment(segment_init) + + assert results == expected_node_counts + + +@pytest.mark.parametrize("source_cpp", [True, False], ids=["source_cpp", "source_py"]) +@pytest.mark.parametrize("upstream_count", range(1, 4), ids=[f"upstream_count_{i}" for i in range(1, 4)]) +def test_multi_source_to_with_latest_from_to_sink(run_segment, source_cpp: bool, upstream_count: int): + + def segment_init(seg: mrc.Builder): + + sources = (add_source(seg, is_cpp=source_cpp, data_type=m.Base, is_component=False, suffix=str(i)) + for i in range(upstream_count)) + zip = add_with_latest_from(seg, *sources) add_sink(seg, zip, is_cpp=False, data_type=tuple, is_component=False) results = run_segment(segment_init)