Skip to content

Commit

Permalink
Adding more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mdemoret-nv committed Oct 22, 2024
1 parent 2bc8d9e commit 520bf8f
Show file tree
Hide file tree
Showing 8 changed files with 333 additions and 144 deletions.
10 changes: 9 additions & 1 deletion cpp/mrc/include/mrc/node/operators/combine_latest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,20 @@

namespace mrc::node {

class CombineLatestTypelessBase
{
public:
virtual ~CombineLatestTypelessBase() = default;
};

template <typename...>
class CombineLatestBase;

template <typename... InputT, typename OutputT>
class CombineLatestBase<std::tuple<InputT...>, OutputT>
: public WritableAcceptor<OutputT>, public HeterogeneousNodeParent<edge::IWritableProvider<InputT>...>
: public CombineLatestTypelessBase,
public WritableAcceptor<OutputT>,
public HeterogeneousNodeParent<edge::IWritableProvider<InputT>...>
{
template <std::size_t... Is>
static auto build_ingress(CombineLatestBase* self, std::index_sequence<Is...> /*unused*/)
Expand Down
10 changes: 9 additions & 1 deletion cpp/mrc/include/mrc/node/operators/with_latest_from.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,21 @@

namespace mrc::node {

class WithLatestFromTypelessBase
{
public:
virtual ~WithLatestFromTypelessBase() = default;
};

template <typename... TypesT>
class WithLatestFromBase
{};

template <typename... InputT, typename OutputT>
class WithLatestFromBase<std::tuple<InputT...>, OutputT>
: public WritableAcceptor<OutputT>, public HeterogeneousNodeParent<edge::IWritableProvider<InputT>...>
: public WithLatestFromTypelessBase,
public WritableAcceptor<OutputT>,
public HeterogeneousNodeParent<edge::IWritableProvider<InputT>...>
{
public:
using input_tuple_t = std::tuple<InputT...>;
Expand Down
9 changes: 8 additions & 1 deletion cpp/mrc/include/mrc/node/operators/zip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,19 @@

namespace mrc::node {

class ZipTypelessBase
{
public:
virtual ~ZipTypelessBase() = default;
};

template <typename... TypesT>
class ZipBase
{};

template <typename... InputT, typename OutputT>
class ZipBase<std::tuple<InputT...>, OutputT> : public WritableAcceptor<OutputT>,
class ZipBase<std::tuple<InputT...>, OutputT> : public ZipTypelessBase,
public WritableAcceptor<OutputT>,
public HeterogeneousNodeParent<edge::IWritableProvider<InputT>...>
{
public:
Expand Down
108 changes: 0 additions & 108 deletions cpp/mrc/include/mrc/segment/object.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,6 @@ struct ObjectPropertiesState
const bool is_sink;
const bool is_source;

// std::optional<std::type_index> sink_type = std::nullopt;
// std::optional<std::type_index> sink_type_no_holder = std::nullopt;
// std::optional<std::type_index> source_type = std::nullopt;
// std::optional<std::type_index> source_type_no_holder = std::nullopt;

const bool is_writable_acceptor;
const bool is_writable_provider;
const bool is_readable_acceptor;
Expand Down Expand Up @@ -325,38 +320,17 @@ template <typename ObjectT>
class Object : public virtual ObjectProperties, public std::enable_shared_from_this<Object<ObjectT>>
{
public:
// Object(const Object& other) : m_name(other.m_name), m_launch_options(other.m_launch_options) {}
// Object(Object&&) = delete;
// Object& operator=(const Object&) = delete;
// Object& operator=(Object&&) = delete;

ObjectT& object();
const ObjectT& object() const;

// std::string name() const final;
// std::string type_name() const final;

// bool is_source() const final;
// bool is_sink() const final;

std::type_index sink_type(bool ignore_holder) const final;
std::type_index source_type(bool ignore_holder) const final;

// bool is_writable_acceptor() const final;
// bool is_writable_provider() const final;
// bool is_readable_acceptor() const final;
// bool is_readable_provider() const final;

edge::IWritableAcceptorBase& writable_acceptor_base() final;
edge::IWritableProviderBase& writable_provider_base() final;
edge::IReadableAcceptorBase& readable_acceptor_base() final;
edge::IReadableProviderBase& readable_provider_base() final;

// bool is_runnable() const final
// {
// return static_cast<bool>(std::is_base_of_v<runnable::Runnable, ObjectT>);
// }

runnable::LaunchOptions& launch_options() final
{
if (!is_runnable())
Expand Down Expand Up @@ -467,9 +441,6 @@ class Object : public virtual ObjectProperties, public std::enable_shared_from_t
return *m_state;
}

// // Move to protected to allow only the IBuilder to set the name
// void set_name(const std::string& name) override;

private:
virtual ObjectT* get_object() const = 0;

Expand Down Expand Up @@ -515,7 +486,6 @@ class Object : public virtual ObjectProperties, public std::enable_shared_from_t
runnable::LaunchOptions m_launch_options;

std::map<std::string, std::shared_ptr<ObjectProperties>> m_children;
// std::map<std::string, std::function<std::shared_ptr<ObjectProperties>()>> m_create_children_fns;

// Allows converting to base classes
template <typename U>
Expand Down Expand Up @@ -548,36 +518,6 @@ const ObjectT& Object<ObjectT>::object() const
return *node;
}

// template <typename ObjectT>
// void Object<ObjectT>::set_name(const std::string& name)
// {
// m_name = name;
// }

// template <typename ObjectT>
// std::string Object<ObjectT>::name() const
// {
// return m_name;
// }

// template <typename ObjectT>
// std::string Object<ObjectT>::type_name() const
// {
// return std::string(::mrc::type_name<ObjectT>());
// }

// template <typename ObjectT>
// bool Object<ObjectT>::is_source() const
// {
// return std::is_base_of_v<node::SourcePropertiesBase, ObjectT>;
// }

// template <typename ObjectT>
// bool Object<ObjectT>::is_sink() const
// {
// return std::is_base_of_v<node::SinkPropertiesBase, ObjectT>;
// }

template <typename ObjectT>
std::type_index Object<ObjectT>::sink_type(bool ignore_holder) const
{
Expand All @@ -602,39 +542,9 @@ std::type_index Object<ObjectT>::source_type(bool ignore_holder) const
return base->source_type(ignore_holder);
}

// template <typename ObjectT>
// bool Object<ObjectT>::is_writable_acceptor() const
// {
// return std::is_base_of_v<edge::IWritableAcceptorBase, ObjectT>;
// }

// template <typename ObjectT>
// bool Object<ObjectT>::is_writable_provider() const
// {
// return std::is_base_of_v<edge::IWritableProviderBase, ObjectT>;
// }

// template <typename ObjectT>
// bool Object<ObjectT>::is_readable_acceptor() const
// {
// return std::is_base_of_v<edge::IReadableAcceptorBase, ObjectT>;
// }

// template <typename ObjectT>
// bool Object<ObjectT>::is_readable_provider() const
// {
// return std::is_base_of_v<edge::IReadableProviderBase, ObjectT>;
// }

template <typename ObjectT>
edge::IWritableAcceptorBase& Object<ObjectT>::writable_acceptor_base()
{
// if constexpr (!std::is_base_of_v<edge::IWritableAcceptorBase, ObjectT>)
// {
// LOG(ERROR) << type_name() << " is not a IIngressAcceptorBase";
// throw exceptions::MrcRuntimeError("Object is not a IIngressAcceptorBase");
// }

auto* base = dynamic_cast<edge::IWritableAcceptorBase*>(get_object());
CHECK(base) << type_name() << " is not a IIngressAcceptorBase";
return *base;
Expand All @@ -643,12 +553,6 @@ edge::IWritableAcceptorBase& Object<ObjectT>::writable_acceptor_base()
template <typename ObjectT>
edge::IWritableProviderBase& Object<ObjectT>::writable_provider_base()
{
// if constexpr (!std::is_base_of_v<edge::IWritableProviderBase, ObjectT>)
// {
// LOG(ERROR) << type_name() << " is not a IIngressProviderBase";
// throw exceptions::MrcRuntimeError("Object is not a IIngressProviderBase");
// }

auto* base = dynamic_cast<edge::IWritableProviderBase*>(get_object());
CHECK(base) << type_name() << " is not a IWritableProviderBase";
return *base;
Expand All @@ -657,12 +561,6 @@ edge::IWritableProviderBase& Object<ObjectT>::writable_provider_base()
template <typename ObjectT>
edge::IReadableAcceptorBase& Object<ObjectT>::readable_acceptor_base()
{
// if constexpr (!std::is_base_of_v<edge::IReadableAcceptorBase, ObjectT>)
// {
// LOG(ERROR) << type_name() << " is not a IEgressAcceptorBase";
// throw exceptions::MrcRuntimeError("Object is not a IEgressAcceptorBase");
// }

auto* base = dynamic_cast<edge::IReadableAcceptorBase*>(get_object());
CHECK(base) << type_name() << " is not a IReadableAcceptorBase";
return *base;
Expand All @@ -671,12 +569,6 @@ edge::IReadableAcceptorBase& Object<ObjectT>::readable_acceptor_base()
template <typename ObjectT>
edge::IReadableProviderBase& Object<ObjectT>::readable_provider_base()
{
// if constexpr (!std::is_base_of_v<edge::IReadableProviderBase, ObjectT>)
// {
// LOG(ERROR) << type_name() << " is not a IEgressProviderBase";
// throw exceptions::MrcRuntimeError("Object is not a IEgressProviderBase");
// }

auto* base = dynamic_cast<edge::IReadableProviderBase*>(get_object());
CHECK(base) << type_name() << " is not a IReadableProviderBase";
return *base;
Expand Down
31 changes: 31 additions & 0 deletions cpp/mrc/include/mrc/utils/tuple_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#pragma once

#include <cstddef>
#include <tuple>

namespace mrc::utils {
Expand Down Expand Up @@ -64,4 +65,34 @@ void tuple_for_each(TupleT&& tuple, FuncT&& f)
std::forward<FuncT>(f),
std::make_index_sequence<std::tuple_size<std::decay_t<TupleT>>::value>());
}

/**
* @brief Creates a tuple of N elements of type T. For example, `repeat_tuple_type<int, 3>` would be `std::tuple<int,
* int, int>`
*
* @tparam T The type of the tuple
* @tparam N The number of elements in the tuple
*/
template <typename T, size_t N>
class repeat_tuple_type
{
template <typename = std::make_index_sequence<N>>
struct impl;

template <size_t... Is>
struct impl<std::index_sequence<Is...>>
{
template <size_t>
using wrap = T;

using type = std::tuple<wrap<Is>...>;
};

public:
using type = typename impl<>::type;
};

template <typename T, size_t N>
using repeat_tuple_type_t = typename repeat_tuple_type<T, N>::type;

} // namespace mrc::utils
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ struct PyFuncHolder<ReturnT(ArgsT...)>
return m_cpp_fn(std::forward<ArgsT>(args)...);
}

operator bool() const
{
return !m_is_none;
}

static constexpr auto Signature = pybind11::detail::_("Callable[[") +
pybind11::detail::concat(pybind11::detail::make_caster<ArgsT>::name...) +
pybind11::detail::_("], ") + pybind11::detail::make_caster<return_t>::name +
Expand Down Expand Up @@ -155,14 +160,16 @@ struct PyFuncHolder<ReturnT(ArgsT...)>
// Save the name of the function to help debugging
if (py_fn)
{
m_repr = pybind11::str(py_fn);
m_repr = pybind11::str(py_fn);
m_is_none = false;
}

m_cpp_fn = this->build_cpp_function(std::move(py_fn));
}

cpp_fn_t m_cpp_fn;
std::string m_repr;
bool m_is_none{true};
};

struct OnNextFunction : public PyFuncHolder<void(PyObjectHolder)>
Expand Down
Loading

0 comments on commit 520bf8f

Please sign in to comment.