diff --git a/.devcontainer/conda/Dockerfile b/.devcontainer/conda/Dockerfile index 62c801dd2..af03369a8 100644 --- a/.devcontainer/conda/Dockerfile +++ b/.devcontainer/conda/Dockerfile @@ -13,6 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -FROM rapidsai/devcontainers:23.04-cuda12.1-mambaforge-ubuntu22.04 AS base +FROM rapidsai/devcontainers:24.12-cuda12.1-mambaforge-ubuntu22.04 AS base ENV PATH="${PATH}:/workspaces/mrc/.devcontainer/bin" diff --git a/.devcontainer/conda/devcontainer.json b/.devcontainer/conda/devcontainer.json index 219f6748b..8f5937318 100644 --- a/.devcontainer/conda/devcontainer.json +++ b/.devcontainer/conda/devcontainer.json @@ -35,7 +35,7 @@ "MRC_ROOT": "${containerWorkspaceFolder}", "DEFAULT_CONDA_ENV": "mrc", "MAMBA_NO_BANNER": "1", - "VAULT_HOST": "https://vault.ops.k8s.rapids.ai" + "AWS_ROLE_ARN": "arn:aws:iam::279114543810:role/nv-gha-token-sccache-devs" }, "initializeCommand": [ "${localWorkspaceFolder}/.devcontainer/conda/initialize-command.sh" ], "remoteUser": "coder", diff --git a/.devcontainer/opt/mrc/conda/Dockerfile b/.devcontainer/opt/mrc/conda/Dockerfile index 62c801dd2..af03369a8 100644 --- a/.devcontainer/opt/mrc/conda/Dockerfile +++ b/.devcontainer/opt/mrc/conda/Dockerfile @@ -13,6 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -FROM rapidsai/devcontainers:23.04-cuda12.1-mambaforge-ubuntu22.04 AS base +FROM rapidsai/devcontainers:24.12-cuda12.1-mambaforge-ubuntu22.04 AS base ENV PATH="${PATH}:/workspaces/mrc/.devcontainer/bin" diff --git a/.devcontainer/opt/mrc/conda/devcontainer.json b/.devcontainer/opt/mrc/conda/devcontainer.json index 6046d60a1..3a9f80795 100644 --- a/.devcontainer/opt/mrc/conda/devcontainer.json +++ b/.devcontainer/opt/mrc/conda/devcontainer.json @@ -35,7 +35,7 @@ "MRC_ROOT": "${containerWorkspaceFolder}", "DEFAULT_CONDA_ENV": "mrc", "MAMBA_NO_BANNER": "1", - "VAULT_HOST": "https://vault.ops.k8s.rapids.ai" + "AWS_ROLE_ARN": "arn:aws:iam::279114543810:role/nv-gha-token-sccache-devs" }, "initializeCommand": [ "${localWorkspaceFolder}/.devcontainer/initialize-command.sh" ], "remoteUser": "coder", diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 6f36c3754..f63edd92c 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -58,7 +58,7 @@ jobs: steps: - name: Get PR Info id: get-pr-info - uses: rapidsai/shared-action-workflows/get-pr-info@branch-23.08 + uses: nv-gha-runners/get-pr-info@main if: ${{ startsWith(github.ref_name, 'pull-request/') }} outputs: is_pr: ${{ startsWith(github.ref_name, 'pull-request/') }} diff --git a/CMakeLists.txt b/CMakeLists.txt index 26bdb6ce3..2069257ee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,8 +85,6 @@ project(mrc morpheus_utils_initialize_install_prefix(MRC_USE_CONDA) -rapids_cmake_write_version_file(${CMAKE_BINARY_DIR}/autogenerated/include/mrc/version.hpp) - # Delay enabling CUDA until after we have determined our CXX compiler if(NOT DEFINED CMAKE_CUDA_HOST_COMPILER) message(STATUS "Setting CUDA host compiler to match CXX compiler: ${CMAKE_CXX_COMPILER}") @@ -180,6 +178,47 @@ if(MRC_BUILD_DOCS) add_subdirectory(docs) endif() +# ################################################################################################## +# - install export --------------------------------------------------------------------------------- + +set(doc_string + [=[ +Provide targets for mrc. +]=]) + +set(code_string "") + +set(rapids_project_version_compat SameMinorVersion) + +# Need to explicitly set VERSION ${PROJECT_VERSION} here since rapids_cmake gets +# confused with the `RAPIDS_VERSION` variable we use +rapids_export(INSTALL ${PROJECT_NAME} + EXPORT_SET ${PROJECT_NAME}-exports + GLOBAL_TARGETS libmrc pymrc + COMPONENTS python + COMPONENTS_EXPORT_SET ${PROJECT_NAME}-python-exports + VERSION ${PROJECT_VERSION} + NAMESPACE mrc:: + DOCUMENTATION doc_string + FINAL_CODE_BLOCK code_string +) + +# ################################################################################################## +# - build export ----------------------------------------------------------------------------------- +rapids_export(BUILD ${PROJECT_NAME} + EXPORT_SET ${PROJECT_NAME}-exports + GLOBAL_TARGETS libmrc pymrc + COMPONENTS python + COMPONENTS_EXPORT_SET ${PROJECT_NAME}-python-exports + VERSION ${PROJECT_VERSION} + LANGUAGES C CXX CUDA + NAMESPACE mrc:: + DOCUMENTATION doc_string + FINAL_CODE_BLOCK code_string +) + +# ################################################################################################## +# - debug info ------------------------------------------------------------------------------------- if (MRC_ENABLE_DEBUG_INFO) morpheus_utils_print_all_targets() diff --git a/ci/iwyu/mappings.imp b/ci/iwyu/mappings.imp index 627e20127..ff6a9f562 100644 --- a/ci/iwyu/mappings.imp +++ b/ci/iwyu/mappings.imp @@ -3,6 +3,7 @@ ## Include mappings # stdlib +{ "include": [ "", private, "", "public" ] }, { "include": [ "", private, "", "public" ] }, { "include": [ "", private, "", "public" ] }, { "include": [ "", private, "", "public" ] }, diff --git a/ci/scripts/cpp_checks.sh b/ci/scripts/cpp_checks.sh index c9127cc36..95cf68cf9 100755 --- a/ci/scripts/cpp_checks.sh +++ b/ci/scripts/cpp_checks.sh @@ -80,9 +80,8 @@ if [[ -n "${MRC_MODIFIED_FILES}" ]]; then # Include What You Use if [[ "${SKIP_IWYU}" == "" ]]; then - # Remove .h, .hpp, and .cu files from the modified list shopt -s extglob - IWYU_MODIFIED_FILES=( "${MRC_MODIFIED_FILES[@]/*.@(h|hpp|cu)/}" ) + IWYU_MODIFIED_FILES=( "${MRC_MODIFIED_FILES[@]}" ) if [[ -n "${IWYU_MODIFIED_FILES}" ]]; then # Get the list of compiled files relative to this directory diff --git a/cpp/mrc/CMakeLists.txt b/cpp/mrc/CMakeLists.txt index 88ac29a70..10d57f9e3 100644 --- a/cpp/mrc/CMakeLists.txt +++ b/cpp/mrc/CMakeLists.txt @@ -16,6 +16,11 @@ # ################################################################################################## # - libmrc ----------------------------------------------------------------------------------------- +include(GenerateExportHeader) + +# Generate the version header file +rapids_cmake_write_version_file(${CMAKE_CURRENT_BINARY_DIR}/autogenerated/include/mrc/version.hpp) + # Keep all source files sorted!!! add_library(libmrc src/internal/codable/codable_storage.cpp @@ -125,6 +130,7 @@ add_library(libmrc src/public/cuda/sync.cpp src/public/edge/edge_adapter_registry.cpp src/public/edge/edge_builder.cpp + src/public/exceptions/checks.cpp src/public/exceptions/exception_catcher.cpp src/public/manifold/manifold.cpp src/public/memory/buffer_view.cpp @@ -204,6 +210,47 @@ target_compile_features(libmrc PUBLIC cxx_std_20) set_target_properties(libmrc PROPERTIES OUTPUT_NAME ${PROJECT_NAME}) +# Generates an include file for specifying external linkage since everything is hidden by default +generate_export_header(libmrc + NO_EXPORT_MACRO_NAME + MRC_LOCAL + EXPORT_FILE_NAME + "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/include/mrc/export.h" +) + +# ################################################################################################## +# - source information ----------------------------------------------------------------------------- + +# Ideally, we dont use glob here. But there is no good way to guarantee you dont miss anything like *.cpp +file(GLOB_RECURSE libmrc_public_headers + LIST_DIRECTORIES FALSE + CONFIGURE_DEPENDS + "${CMAKE_CURRENT_SOURCE_DIR}/include/mrc/*" +) + +# Add headers to target sources file_set so they can be installed +# https://discourse.cmake.org/t/installing-headers-the-modern-way-regurgitated-and-revisited/3238/3 +target_sources(libmrc + PUBLIC + FILE_SET public_headers + TYPE HEADERS + BASE_DIRS "${CMAKE_CURRENT_SOURCE_DIR}/include" + FILES + ${libmrc_public_headers} +) + +# Add generated headers to fileset +target_sources(libmrc + PUBLIC + FILE_SET public_headers + TYPE HEADERS + BASE_DIRS + "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/include" + FILES + "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/include/mrc/version.hpp" + "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/include/mrc/export.h" +) + # ################################################################################################## # - install targets -------------------------------------------------------------------------------- rapids_cmake_install_lib_dir(lib_dir) @@ -215,12 +262,7 @@ install( DESTINATION ${lib_dir} EXPORT ${PROJECT_NAME}-exports COMPONENT Core -) - -install( - DIRECTORY include/ - DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} - COMPONENT Core + FILE_SET public_headers ) # ################################################################################################## @@ -234,37 +276,3 @@ endif() if(MRC_BUILD_BENCHMARKS) add_subdirectory(benchmarks) endif() - -# ################################################################################################## -# - install export --------------------------------------------------------------------------------- -set(doc_string - [=[ -Provide targets for mrc. -]=]) - -set(code_string "") - -set(rapids_project_version_compat SameMinorVersion) - -# Need to explicitly set VERSION ${PROJECT_VERSION} here since rapids_cmake gets -# confused with the `RAPIDS_VERSION` variable we use -rapids_export(INSTALL ${PROJECT_NAME} - EXPORT_SET ${PROJECT_NAME}-exports - GLOBAL_TARGETS libmrc - VERSION ${PROJECT_VERSION} - NAMESPACE mrc:: - DOCUMENTATION doc_string - FINAL_CODE_BLOCK code_string -) - -# ################################################################################################## -# - build export ---------------------------------------------------------------------------------- -rapids_export(BUILD ${PROJECT_NAME} - EXPORT_SET ${PROJECT_NAME}-exports - GLOBAL_TARGETS libmrc - VERSION ${PROJECT_VERSION} - LANGUAGES C CXX CUDA - NAMESPACE mrc:: - DOCUMENTATION doc_string - FINAL_CODE_BLOCK code_string -) diff --git a/cpp/mrc/include/mrc/api.hpp b/cpp/mrc/include/mrc/api.hpp deleted file mode 100644 index ede97c7e9..000000000 --- a/cpp/mrc/include/mrc/api.hpp +++ /dev/null @@ -1,77 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -// Generic helper definitions for shared library support from -// https://gcc.gnu.org/wiki/Visibility - -// For every non-templated non-static function definition in your library (both headers and source files), decide if it -// is publicly used or internally used: If it is publicly used, mark with MRC_API like this: extern MRC_API PublicFunc() - -// If it is only internally used, mark with MRC_LOCAL like this: extern MRC_LOCAL PublicFunc() Remember, static -// functions need no demarcation, nor does anything in an anonymous namespace, nor does anything which is templated. - -// For every non-templated class definition in your library (both headers and source files), decide if it is publicly -// used or internally used: If it is publicly used, mark with MRC_API like this: class MRC_API PublicClass - -// If it is only internally used, mark with MRC_LOCAL like this: class MRC_LOCAL PublicClass - -// Individual member functions of an exported class that are not part of the interface, in particular ones which are -// private, and are not used by friend code, should be marked individually with MRC_LOCAL. - -// In your build system (Makefile etc), you will probably wish to add the -fvisibility=hidden and -// -fvisibility-inlines-hidden options to the command line arguments of every GCC invocation. Remember to test your -// library thoroughly afterwards, including that all exceptions correctly traverse shared object boundaries. - -#if defined _WIN32 || defined __CYGWIN__ - #define MRC_HELPER_DLL_IMPORT __declspec(dllimport) - #define MRC_HELPER_DLL_EXPORT __declspec(dllexport) - #define MRC_HELPER_DLL_LOCAL -#else - #if __GNUC__ >= 4 - #define MRC_HELPER_DLL_IMPORT __attribute__((visibility("default"))) - #define MRC_HELPER_DLL_EXPORT __attribute__((visibility("default"))) - #define MRC_HELPER_DLL_LOCAL __attribute__((visibility("hidden"))) - #else - #define MRC_HELPER_DLL_IMPORT - #define MRC_HELPER_DLL_EXPORT - #define MRC_HELPER_DLL_LOCAL - #endif -#endif - -// Now we use the generic helper definitions above to define MRC_API and MRC_LOCAL. -// MRC_API is used for the public API symbols. It either DLL imports or DLL exports (or does nothing for static build) -// MRC_LOCAL is used for non-api symbols. - -#define MRC_DLL // we alway build the .so/.dll -#ifdef libmrc_EXPORTS - #define MRC_DLL_EXPORTS -#endif - -#ifdef MRC_DLL // defined if MRC is compiled as a DLL - #ifdef MRC_DLL_EXPORTS // defined if we are building the MRC DLL (instead of using it) - #define MRC_API MRC_HELPER_DLL_EXPORT - #else - #define MRC_API MRC_HELPER_DLL_IMPORT - #endif // MRC_DLL_EXPORTS - #define MRC_LOCAL MRC_HELPER_DLL_LOCAL -#else // MRC_DLL is not defined: this means MRC is a static lib. - #define MRC_API - #define MRC_LOCAL -static_assert(false, "always build the .so/.dll") -#endif // MRC_DLL diff --git a/cpp/mrc/include/mrc/channel/status.hpp b/cpp/mrc/include/mrc/channel/status.hpp index 91f6b2800..cf307b9b1 100644 --- a/cpp/mrc/include/mrc/channel/status.hpp +++ b/cpp/mrc/include/mrc/channel/status.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,6 +17,8 @@ #pragma once +#include + namespace mrc::channel { enum class Status @@ -29,4 +31,25 @@ enum class Status error }; +static inline std::ostream& operator<<(std::ostream& os, const Status& s) +{ + switch (s) + { + case Status::success: + return os << "success"; + case Status::empty: + return os << "empty"; + case Status::full: + return os << "full"; + case Status::closed: + return os << "closed"; + case Status::timeout: + return os << "timeout"; + case Status::error: + return os << "error"; + default: + throw std::logic_error("Unsupported channel::Status enum. Was a new value added recently?"); + } } + +} // namespace mrc::channel diff --git a/cpp/mrc/include/mrc/core/fiber_pool.hpp b/cpp/mrc/include/mrc/core/fiber_pool.hpp index 09838d473..94ff9b8ad 100644 --- a/cpp/mrc/include/mrc/core/fiber_pool.hpp +++ b/cpp/mrc/include/mrc/core/fiber_pool.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -43,14 +43,14 @@ class FiberPool [[nodiscard]] virtual std::size_t thread_count() const = 0; template - auto enqueue(std::uint32_t index, F&& f, ArgsT&&... args) -> Future::type> + auto enqueue(std::uint32_t index, F&& f, ArgsT&&... args) -> Future> { return task_queue(index).enqueue(f, std::forward(args)...); } template auto enqueue(std::uint32_t index, MetaDataT&& md, F&& f, ArgsT&&... args) - -> Future::type> + -> Future> { return task_queue(index).enqueue(std::forward(md), std::forward(f), std::forward(args)...); } diff --git a/cpp/mrc/include/mrc/core/task_queue.hpp b/cpp/mrc/include/mrc/core/task_queue.hpp index 492f0a165..b9db8127d 100644 --- a/cpp/mrc/include/mrc/core/task_queue.hpp +++ b/cpp/mrc/include/mrc/core/task_queue.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -40,7 +40,7 @@ class FiberTaskQueue virtual ~FiberTaskQueue() = default; template - auto enqueue(F&& f, ArgsT&&... args) -> Future::type> + auto enqueue(F&& f, ArgsT&&... args) -> Future> { FiberMetaData meta_data; return enqueue(meta_data, std::forward(f), std::forward(args)...); @@ -48,7 +48,7 @@ class FiberTaskQueue template auto enqueue(const FiberMetaData& meta_data, F&& f, ArgsT&&... args) - -> Future::type> + -> Future> { FiberMetaData copy = meta_data; return enqueue(std::move(copy), std::forward(f), std::forward(args)...); @@ -56,7 +56,7 @@ class FiberTaskQueue template auto enqueue(FiberMetaData&& meta_data, F&& f, ArgsT&&... args) - -> Future::type> + -> Future> { if (task_queue().is_closed()) { @@ -64,7 +64,7 @@ class FiberTaskQueue } using namespace boost::fibers; - using return_type_t = typename std::result_of::type; + using return_type_t = typename std::invoke_result_t; packaged_task task(std::bind(std::forward(f), std::forward(args)...)); future future = task.get_future(); diff --git a/cpp/mrc/include/mrc/edge/deferred_edge.hpp b/cpp/mrc/include/mrc/edge/deferred_edge.hpp index bf333b88b..c51aa2633 100644 --- a/cpp/mrc/include/mrc/edge/deferred_edge.hpp +++ b/cpp/mrc/include/mrc/edge/deferred_edge.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -30,7 +30,7 @@ namespace mrc::edge { -class DeferredWritableMultiEdgeBase : public IMultiWritableAcceptorBase, +class DeferredWritableMultiEdgeBase : public virtual IMultiWritableAcceptorBase, public virtual IEdgeWritableBase, public virtual EdgeBase { diff --git a/cpp/mrc/include/mrc/edge/edge.hpp b/cpp/mrc/include/mrc/edge/edge.hpp index ea2a35e07..66a40ac3d 100644 --- a/cpp/mrc/include/mrc/edge/edge.hpp +++ b/cpp/mrc/include/mrc/edge/edge.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,6 +24,7 @@ #include "mrc/exceptions/runtime_error.hpp" #include "mrc/type_traits.hpp" #include "mrc/utils/string_utils.hpp" +#include "mrc/utils/type_utils.hpp" #include #include @@ -257,13 +258,25 @@ class EdgeTypeInfo m_unwrapped_type(unwrapped_type), m_is_deferred(is_deferred) { + if (m_full_type.has_value()) + { + m_full_type_str = type_name(m_full_type.value()); + } + + if (m_unwrapped_type.has_value()) + { + m_unwrapped_type_str = type_name(m_unwrapped_type.value()); + } + CHECK((m_is_deferred && !m_full_type.has_value() && !m_unwrapped_type.has_value()) || (!m_is_deferred && m_full_type.has_value() && m_unwrapped_type.has_value())) << "Inconsistent deferred setting with concrete types"; } std::optional m_full_type; // Includes any wrappers like shared_ptr + std::string m_full_type_str; // For debugging purposes only std::optional m_unwrapped_type; // Excludes any wrappers like shared_ptr if they exist + std::string m_unwrapped_type_str; // For debugging purposes only bool m_is_deferred{false}; // Whether or not this type is deferred or concrete }; diff --git a/cpp/mrc/include/mrc/edge/edge_builder.hpp b/cpp/mrc/include/mrc/edge/edge_builder.hpp index 78e88b577..a4151dedc 100644 --- a/cpp/mrc/include/mrc/edge/edge_builder.hpp +++ b/cpp/mrc/include/mrc/edge/edge_builder.hpp @@ -363,6 +363,31 @@ class DeferredWritableMultiEdge : public MultiEdgeHolder, } protected: + bool has_writable_edge(const std::size_t& key) const override + { + return MultiEdgeHolder::has_edge_connection(key); + } + + void release_writable_edge(const std::size_t& key) override + { + return MultiEdgeHolder::release_edge_connection(key); + } + + void release_writable_edges() override + { + return MultiEdgeHolder::release_edge_connections(); + } + + size_t writable_edge_count() const override + { + return MultiEdgeHolder::edge_connection_count(); + } + + std::vector writable_edge_keys() const override + { + return MultiEdgeHolder::edge_connection_keys(); + } + std::shared_ptr> get_writable_edge(std::size_t edge_idx) const { return std::dynamic_pointer_cast>(this->get_connected_edge(edge_idx)); @@ -457,39 +482,79 @@ std::shared_ptr EdgeBuilder::adapt_readable_edge(std::shared // Put make edge in the mrc namespace since it is used so often namespace mrc { +/** + * @brief Since its not currently possible to dynamically build an error message for `static_assert`, provide all of the + * information that would be in the error message about the types as template parameters. This way the value of the + * template parameters is still displayed in the error message. + * + */ +namespace detail { +template +void display_make_edge_error_message() +{ + static_assert(!sizeof(SourceT), + "Arguments to make_edge were incorrect. Ensure you are providing either " + "WritableAcceptor->WritableProvider or ReadableProvider->ReadableAcceptor"); +} +} // namespace detail + template void make_edge(SourceT& source, SinkT& sink) { using source_full_t = SourceT; using sink_full_t = SinkT; - if constexpr (is_base_of_template::value && - is_base_of_template::value) + constexpr bool IsSourceIWritableAcceptor = is_base_of_template::value; + constexpr bool IsSourceIReadableProvider = is_base_of_template::value; + + constexpr bool IsSinkIWritableProvider = is_base_of_template::value; + constexpr bool IsSinkIReadableAcceptor = is_base_of_template::value; + + constexpr bool IsSourceIWritableAcceptorBase = std::is_base_of_v; + constexpr bool IsSourceIReadableProviderBase = std::is_base_of_v; + + constexpr bool IsSinkIWritableProviderBase = std::is_base_of_v; + constexpr bool IsSinkIReadableAcceptorBase = std::is_base_of_v; + + if constexpr (IsSourceIWritableAcceptor && IsSinkIWritableProvider) { // Call the typed version for ingress provider/acceptor edge::EdgeBuilder::make_edge_writable(source, sink); } - else if constexpr (is_base_of_template::value && - is_base_of_template::value) + else if constexpr (IsSourceIReadableProvider && IsSinkIReadableAcceptor) { // Call the typed version for egress provider/acceptor edge::EdgeBuilder::make_edge_readable(source, sink); } - else if constexpr (std::is_base_of_v && - std::is_base_of_v) + else if constexpr (IsSourceIWritableAcceptorBase && IsSinkIWritableProviderBase) { edge::EdgeBuilder::make_edge_writable_typeless(source, sink); } - else if constexpr (std::is_base_of_v && - std::is_base_of_v) + else if constexpr (IsSourceIReadableProviderBase && IsSinkIReadableAcceptorBase) { edge::EdgeBuilder::make_edge_readable_typeless(source, sink); } else { - static_assert(!sizeof(source_full_t), - "Arguments to make_edge were incorrect. Ensure you are providing either " - "WritableAcceptor->WritableProvider or ReadableProvider->ReadableAcceptor"); + detail::display_make_edge_error_message(); } } diff --git a/cpp/mrc/include/mrc/edge/edge_channel.hpp b/cpp/mrc/include/mrc/edge/edge_channel.hpp index 5da85d74c..d821d4112 100644 --- a/cpp/mrc/include/mrc/edge/edge_channel.hpp +++ b/cpp/mrc/include/mrc/edge/edge_channel.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +20,7 @@ #include "mrc/edge/edge_readable.hpp" #include "mrc/edge/edge_writable.hpp" #include "mrc/edge/forward.hpp" +#include "mrc/utils/macros.hpp" #include @@ -89,6 +90,24 @@ class EdgeChannel { CHECK(m_channel) << "Cannot create an EdgeChannel from an empty pointer"; } + + EdgeChannel(EdgeChannel&& other) : m_channel(std::move(other.m_channel)) {} + + EdgeChannel& operator=(EdgeChannel&& other) + { + if (this == &other) + { + return *this; + } + + m_channel = std::move(other.m_channel); + + return *this; + } + + // This should not be copyable because it requires passing in a unique_ptr + DELETE_COPYABILITY(EdgeChannel); + virtual ~EdgeChannel() = default; [[nodiscard]] std::shared_ptr> get_reader() const diff --git a/cpp/mrc/include/mrc/edge/edge_connector.hpp b/cpp/mrc/include/mrc/edge/edge_connector.hpp index 6df92b1f2..be75019a5 100644 --- a/cpp/mrc/include/mrc/edge/edge_connector.hpp +++ b/cpp/mrc/include/mrc/edge/edge_connector.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -99,7 +99,7 @@ struct EdgeConnector EdgeAdapterRegistry::register_egress_converter( typeid(InputT), typeid(OutputT), - [lambda_fn](std::shared_ptr channel) { + [lambda_fn](std::shared_ptr channel) { std::shared_ptr> egress = std::dynamic_pointer_cast>( channel); diff --git a/cpp/mrc/include/mrc/edge/edge_holder.hpp b/cpp/mrc/include/mrc/edge/edge_holder.hpp index 0262a7e71..ac920b12f 100644 --- a/cpp/mrc/include/mrc/edge/edge_holder.hpp +++ b/cpp/mrc/include/mrc/edge/edge_holder.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -211,6 +211,18 @@ class MultiEdgeHolder edge_pair.init_owned_edge(std::move(edge)); } + void init_connected_edge(KeyT key, std::shared_ptr> edge) + { + auto& edge_pair = this->get_edge_pair(key, true); + + edge_pair.init_connected_edge(std::move(edge)); + } + + bool has_edge_connection(const KeyT& key) const + { + return m_edges.contains(key); + } + std::shared_ptr get_edge_connection(const KeyT& key) const { auto& edge_pair = this->get_edge_pair(key); diff --git a/cpp/mrc/include/mrc/edge/edge_readable.hpp b/cpp/mrc/include/mrc/edge/edge_readable.hpp index 2495341fd..2f68d9550 100644 --- a/cpp/mrc/include/mrc/edge/edge_readable.hpp +++ b/cpp/mrc/include/mrc/edge/edge_readable.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -192,8 +192,32 @@ class IReadableAcceptorBase virtual EdgeTypeInfo readable_acceptor_type() const = 0; }; +template +class IMultiReadableProviderBase +{ + public: + virtual bool has_readable_edge(const KeyT& key) const = 0; + virtual void release_readable_edge(const KeyT& key) = 0; + virtual void release_readable_edges() = 0; + virtual size_t readable_edge_count() const = 0; + virtual std::vector readable_edge_keys() const = 0; + virtual std::shared_ptr get_readable_edge_handle(KeyT key) const = 0; +}; + +template +class IMultiReadableAcceptorBase +{ + public: + virtual bool has_readable_edge(const KeyT& key) const = 0; + virtual void release_readable_edge(const KeyT& key) = 0; + virtual void release_readable_edges() = 0; + virtual size_t readable_edge_count() const = 0; + virtual std::vector readable_edge_keys() const = 0; + virtual void set_readable_edge_handle(KeyT key, std::shared_ptr egress) = 0; +}; + template -class IReadableProvider : public IReadableProviderBase +class IReadableProvider : public virtual IReadableProviderBase { public: EdgeTypeInfo readable_provider_type() const override @@ -203,7 +227,7 @@ class IReadableProvider : public IReadableProviderBase }; template -class IReadableAcceptor : public IReadableAcceptorBase +class IReadableAcceptor : public virtual IReadableAcceptorBase { public: EdgeTypeInfo readable_acceptor_type() const override @@ -212,4 +236,11 @@ class IReadableAcceptor : public IReadableAcceptorBase } }; +template +class IMultiReadableProvider : public virtual IMultiReadableProviderBase +{}; + +template +class IMultiReadableAcceptor : public virtual IMultiReadableAcceptorBase +{}; } // namespace mrc::edge diff --git a/cpp/mrc/include/mrc/edge/edge_writable.hpp b/cpp/mrc/include/mrc/edge/edge_writable.hpp index 04db574f9..9d0824279 100644 --- a/cpp/mrc/include/mrc/edge/edge_writable.hpp +++ b/cpp/mrc/include/mrc/edge/edge_writable.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -203,15 +203,32 @@ class IWritableAcceptorBase virtual EdgeTypeInfo writable_acceptor_type() const = 0; }; +template +class IMultiWritableProviderBase +{ + public: + virtual bool has_writable_edge(const KeyT& key) const = 0; + virtual void release_writable_edge(const KeyT& key) = 0; + virtual void release_writable_edges() = 0; + virtual size_t writable_edge_count() const = 0; + virtual std::vector writable_edge_keys() const = 0; + virtual std::shared_ptr get_writable_edge_handle(KeyT key) const = 0; +}; + template class IMultiWritableAcceptorBase { public: + virtual bool has_writable_edge(const KeyT& key) const = 0; + virtual void release_writable_edge(const KeyT& key) = 0; + virtual void release_writable_edges() = 0; + virtual size_t writable_edge_count() const = 0; + virtual std::vector writable_edge_keys() const = 0; virtual void set_writable_edge_handle(KeyT key, std::shared_ptr ingress) = 0; }; template -class IWritableProvider : public IWritableProviderBase +class IWritableProvider : public virtual IWritableProviderBase { public: EdgeTypeInfo writable_provider_type() const override @@ -221,7 +238,7 @@ class IWritableProvider : public IWritableProviderBase }; template -class IWritableAcceptor : public IWritableAcceptorBase +class IWritableAcceptor : public virtual IWritableAcceptorBase { public: EdgeTypeInfo writable_acceptor_type() const override @@ -230,8 +247,12 @@ class IWritableAcceptor : public IWritableAcceptorBase } }; -template -class IMultiWritableAcceptor : public IMultiWritableAcceptorBase +template +class IMultiWritableProvider : public virtual IMultiWritableProviderBase +{}; + +template +class IMultiWritableAcceptor : public virtual IMultiWritableAcceptorBase {}; } // namespace mrc::edge diff --git a/cpp/mrc/include/mrc/exceptions/checks.hpp b/cpp/mrc/include/mrc/exceptions/checks.hpp new file mode 100644 index 000000000..daef0d411 --- /dev/null +++ b/cpp/mrc/include/mrc/exceptions/checks.hpp @@ -0,0 +1,34 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace mrc::exceptions { + +void throw_failed_check_exception(const std::string& file, + const std::string& function, + unsigned int line, + const std::string& msg = ""); + +#define MRC_CHECK_THROW(condition) \ + for (std::stringstream ss; !(condition); \ + ::mrc::exceptions::throw_failed_check_exception(__FILE__, __func__, __LINE__, ss.str())) \ + ss + +} // namespace mrc::exceptions diff --git a/cpp/mrc/include/mrc/exceptions/exception_catcher.hpp b/cpp/mrc/include/mrc/exceptions/exception_catcher.hpp index 98c4a7d6d..75a3e5906 100644 --- a/cpp/mrc/include/mrc/exceptions/exception_catcher.hpp +++ b/cpp/mrc/include/mrc/exceptions/exception_catcher.hpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,12 +15,12 @@ * limitations under the License. */ +#pragma once + #include #include #include -#pragma once - namespace mrc { /** diff --git a/cpp/mrc/include/mrc/experimental/modules/mirror_tap/mirror_tap.hpp b/cpp/mrc/include/mrc/experimental/modules/mirror_tap/mirror_tap.hpp index 8b0d12099..9124ca451 100644 --- a/cpp/mrc/include/mrc/experimental/modules/mirror_tap/mirror_tap.hpp +++ b/cpp/mrc/include/mrc/experimental/modules/mirror_tap/mirror_tap.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -88,8 +88,8 @@ void MirrorTapModule::initialize(segment::IBuilder& builder) builder.make_edge(bcast, builder.get_egress(m_egress_name)); // to mirror tap // Register the submodules output as one of this module's outputs - register_input_port("input", bcast); - register_output_port("output", bcast); + builder.register_module_input("input", bcast); + builder.register_module_output("output", bcast); } template diff --git a/cpp/mrc/include/mrc/experimental/modules/mirror_tap/mirror_tap_stream.hpp b/cpp/mrc/include/mrc/experimental/modules/mirror_tap/mirror_tap_stream.hpp index b053a4098..1a6d8b68d 100644 --- a/cpp/mrc/include/mrc/experimental/modules/mirror_tap/mirror_tap_stream.hpp +++ b/cpp/mrc/include/mrc/experimental/modules/mirror_tap/mirror_tap_stream.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -89,7 +89,7 @@ void MirrorTapStreamModule::initialize(segment::IBuilder& builder) builder.make_edge(mirror_ingress, m_stream_buffer->input_port("input")); - register_output_port("output", m_stream_buffer->output_port("output")); + builder.register_module_output("output", m_stream_buffer->output_port("output")); } template diff --git a/cpp/mrc/include/mrc/experimental/modules/stream_buffer/stream_buffer_module.hpp b/cpp/mrc/include/mrc/experimental/modules/stream_buffer/stream_buffer_module.hpp index d7eeece45..6feafe9ee 100644 --- a/cpp/mrc/include/mrc/experimental/modules/stream_buffer/stream_buffer_module.hpp +++ b/cpp/mrc/include/mrc/experimental/modules/stream_buffer/stream_buffer_module.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -138,8 +138,8 @@ void StreamBufferModule::initialize(segment::IBuil } }); - register_input_port("input", buffer_sink); - register_output_port("output", buffer_source); + builder.register_module_input("input", buffer_sink); + builder.register_module_output("output", buffer_source); } template class StreamBufferTypeT> diff --git a/cpp/mrc/include/mrc/manifold/egress.hpp b/cpp/mrc/include/mrc/manifold/egress.hpp index 781122d61..6b7b67168 100644 --- a/cpp/mrc/include/mrc/manifold/egress.hpp +++ b/cpp/mrc/include/mrc/manifold/egress.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -53,7 +53,7 @@ class TypedEgress : public EgressDelegate }; template -class RoundRobinEgress : public node::Router, public TypedEgress +class RoundRobinEgress : public node::DynamicRouterComponentBase, public TypedEgress { protected: SegmentAddress determine_key_for_value(const T& t) override diff --git a/cpp/mrc/include/mrc/modules/sample_modules.hpp b/cpp/mrc/include/mrc/modules/sample_modules.hpp index 23db4db57..3ba6e4bfd 100644 --- a/cpp/mrc/include/mrc/modules/sample_modules.hpp +++ b/cpp/mrc/include/mrc/modules/sample_modules.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -183,7 +183,7 @@ void TemplateModule::initialize(segment::IBuilder& builder) }); // Register the submodules output as one of this module's outputs - register_output_port("source", source); + builder.register_module_output("source", source); } template @@ -248,7 +248,7 @@ void TemplateWithInitModule::initialize(segment::IBuil }); // Register the submodules output as one of this module's outputs - register_output_port("source", source); + builder.register_module_output("source", source); } template diff --git a/cpp/mrc/include/mrc/modules/segment_modules.hpp b/cpp/mrc/include/mrc/modules/segment_modules.hpp index 3cabb0d03..508b6d5f9 100644 --- a/cpp/mrc/include/mrc/modules/segment_modules.hpp +++ b/cpp/mrc/include/mrc/modules/segment_modules.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -138,6 +138,23 @@ class SegmentModule */ virtual void initialize(segment::IBuilder& builder) = 0; + private: + /** + * @brief Registers an object with the module to keep it alive + * + * @param name The name of the object + * @param object The object to register + */ + void register_object(std::string name, std::shared_ptr object); + + /** + * @brief Find an object by name. Must be registered with the module + * + * @param name The name of the object + * @return segment::ObjectProperties& + */ + segment::ObjectProperties& find_object(const std::string& name) const; + /* Interface Functions */ /** * Register an input port that should be exposed for the module @@ -153,7 +170,6 @@ class SegmentModule */ void register_output_port(std::string output_name, std::shared_ptr object); - private: /** * Register an input port that should be exposed for the module, with explicit type index. This is * necessary for Objects that aren't explicit Source or Sink types (e.g. a custom object type) @@ -188,6 +204,9 @@ class SegmentModule segment_module_port_map_t m_input_ports{}; segment_module_port_map_t m_output_ports{}; + // Maintain a map of all objects to keep them alive. These are registered as internal names + std::map> m_objects; + const nlohmann::json m_config; friend class segment::BuilderDefinition; diff --git a/cpp/mrc/include/mrc/node/generic_source.hpp b/cpp/mrc/include/mrc/node/generic_source.hpp index 19956d422..7f16492d1 100644 --- a/cpp/mrc/include/mrc/node/generic_source.hpp +++ b/cpp/mrc/include/mrc/node/generic_source.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -68,7 +68,7 @@ GenericSource::GenericSource() : {} template -class GenericSourceComponent : public ForwardingEgressProvider +class GenericSourceComponent : public ForwardingReadableProvider { public: GenericSourceComponent() = default; diff --git a/cpp/mrc/include/mrc/node/node_parent.hpp b/cpp/mrc/include/mrc/node/node_parent.hpp new file mode 100644 index 000000000..280e5395d --- /dev/null +++ b/cpp/mrc/include/mrc/node/node_parent.hpp @@ -0,0 +1,47 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace mrc::node { + +template +class HomogeneousNodeParent +{ + public: + using child_node_t = ChildT; + + virtual std::map> get_children_refs( + std::optional child_name = std::nullopt) const = 0; +}; + +template +class HeterogeneousNodeParent +{ + public: + using child_types_t = std::tuple; + + virtual std::tuple>...> get_children_refs() const = 0; +}; + +} // namespace mrc::node diff --git a/cpp/mrc/include/mrc/node/operators/combine_latest.hpp b/cpp/mrc/include/mrc/node/operators/combine_latest.hpp index a5d50d217..cef6c0d3c 100644 --- a/cpp/mrc/include/mrc/node/operators/combine_latest.hpp +++ b/cpp/mrc/include/mrc/node/operators/combine_latest.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,8 +18,11 @@ #pragma once #include "mrc/channel/status.hpp" +#include "mrc/node/node_parent.hpp" #include "mrc/node/sink_properties.hpp" #include "mrc/node/source_properties.hpp" +#include "mrc/types.hpp" +#include "mrc/utils/tuple_utils.hpp" #include "mrc/utils/type_utils.hpp" #include @@ -29,152 +32,88 @@ #include #include +// IWYU pragma: begin_exports + namespace mrc::node { -// template -// class ParameterPackIndexer -// { -// public: -// ParameterPackIndexer(TypesT... ts) : ParameterPackIndexer(std::make_index_sequence{}, ts...) -// {} - -// std::tuple...> tup; - -// private: -// template -// ParameterPackIndexer(std::index_sequence const& /*unused*/, TypesT... ts) : tup{std::make_tuple(ts, -// Is)...} -// {} -// }; - -// template -// constexpr size_t getTypeIndexInTemplateList() -// { -// if constexpr (std::is_same::value) -// { -// return 0; -// } -// else -// { -// return 1 + getTypeIndexInTemplateList(); -// } -// } - -namespace detail { -struct Surely +class CombineLatestTypelessBase { - template - auto operator()(const T&... t) const -> decltype(std::make_tuple(t.value()...)) - { - return std::make_tuple(t.value()...); - } + public: + virtual ~CombineLatestTypelessBase() = default; }; -} // namespace detail -// template -// inline auto surely(const std::tuple& tpl) -> decltype(rxcpp::util::apply(tpl, detail::surely())) -// { -// return rxcpp::util::apply(tpl, detail::surely()); -// } +template +class CombineLatestBase; -template -inline auto surely2(const std::tuple& tpl) -{ - return std::apply([](auto... args) { - return std::make_tuple(args.value()...); - }); -} - -// template -// static auto surely2(const std::tuple& tpl, std::index_sequence) -// { -// return std::make_tuple(std::make_shared>(*self)...); -// } - -// template -// struct IndexTypePair -// { -// static constexpr size_t index{i}; -// using Type = T; -// }; - -// template -// struct make_index_type_tuple_helper -// { -// template -// struct idx; - -// template -// struct idx> -// { -// using tuple_type = std::tuple...>; -// }; - -// using tuple_type = typename idx>::tuple_type; -// }; - -// template -// using make_index_type_tuple = typename make_index_type_tuple_helper::tuple_type; - -template -class CombineLatest : public WritableAcceptor> +template +class CombineLatestBase, OutputT> + : public CombineLatestTypelessBase, + public WritableAcceptor, + public HeterogeneousNodeParent...> { template - static auto build_ingress(CombineLatest* self, std::index_sequence /*unused*/) + static auto build_ingress(CombineLatestBase* self, std::index_sequence /*unused*/) { return std::make_tuple(std::make_shared>(*self)...); } - public: - CombineLatest() : - m_upstream_holders(build_ingress(const_cast(this), std::index_sequence_for{})) + template + static std::tuple>>...> + build_child_pairs(CombineLatestBase* self, std::index_sequence /*unused*/) { - // auto a = build_ingress(const_cast(this), std::index_sequence_for{}); + return std::make_tuple( + std::make_pair(MRC_CONCAT_STR("sink[" << Is << "]"), std::ref(*self->get_sink()))...); } - virtual ~CombineLatest() = default; + public: + using input_tuple_t = std::tuple; + using output_t = OutputT; + + CombineLatestBase() : + m_upstream_holders(build_ingress(const_cast(this), std::index_sequence_for{})) + {} + + ~CombineLatestBase() override = default; template - std::shared_ptr>> get_sink() const + std::shared_ptr>> get_sink() const { return std::get(m_upstream_holders); } + std::tuple>>...> get_children_refs() + const override + { + return build_child_pairs(const_cast(this), std::index_sequence_for{}); + } + protected: template - class Upstream : public WritableProvider> + class Upstream : public ForwardingWritableProvider> { - using upstream_t = NthTypeOf; + using upstream_t = NthTypeOf; public: - Upstream(CombineLatest& parent) + Upstream(CombineLatestBase& parent) : m_parent(parent) {} + + protected: + channel::Status on_next(upstream_t&& data) override { - this->init_owned_edge(std::make_shared(parent)); + return m_parent.upstream_await_write(std::move(data)); } - private: - class InnerEdge : public edge::IEdgeWritable> + void on_complete() override { - public: - InnerEdge(CombineLatest& parent) : m_parent(parent) {} - ~InnerEdge() - { - m_parent.edge_complete(); - } - - virtual channel::Status await_write(upstream_t&& data) - { - return m_parent.set_upstream_value(std::move(data)); - } - - private: - CombineLatest& m_parent; - }; + m_parent.edge_complete(); + } + + private: + CombineLatestBase& m_parent; }; private: template - channel::Status set_upstream_value(NthTypeOf value) + channel::Status upstream_await_write(NthTypeOf value) { std::unique_lock lock(m_mutex); @@ -191,11 +130,11 @@ class CombineLatest : public WritableAcceptor> channel::Status status = channel::Status::success; // Check if we should push the new value - if (m_values_set == sizeof...(TypesT)) + if (m_values_set == sizeof...(InputT)) { - // std::tuple new_val = surely2(m_state); + std::tuple new_val = utils::tuple_surely(m_state); - // status = this->get_writable_edge()->await_write(std::move(new_val)); + status = this->get_writable_edge()->await_write(this->convert_value(std::move(new_val))); } return status; @@ -207,18 +146,87 @@ class CombineLatest : public WritableAcceptor> m_completions++; - if (m_completions == sizeof...(TypesT)) + if (m_completions == sizeof...(InputT)) { - WritableAcceptor>::release_edge_connection(); + // Clear the held tuple to remove any dangling values + m_state = std::tuple...>(); + + WritableAcceptor::release_edge_connection(); } } - boost::fibers::mutex m_mutex; + virtual output_t convert_value(input_tuple_t&& data) = 0; + + mutable Mutex m_mutex; + + // The number of elements that have been set. Can start emitting when m_values_set == sizeof...(TypesT) size_t m_values_set{0}; + + // Counts the number of upstream completions. When m_completions == sizeof...(TypesT), the downstream edges are + // released size_t m_completions{0}; - std::tuple...> m_state; - std::tuple>...> m_upstream_holders; + // Holds onto the latest values to eventually push when new ones are emitted + std::tuple...> m_state; + + // Upstream edges + std::tuple>...> m_upstream_holders; +}; + +template +class CombineLatestComponent; + +template +class CombineLatestComponent, OutputT> : public CombineLatestBase, OutputT> +{ + public: + using base_t = CombineLatestBase, std::tuple>; + using input_tuple_t = typename base_t::input_tuple_t; + using output_t = typename base_t::output_t; +}; + +// Specialization for CombineLatest with a default output type +template +class CombineLatestComponent> + : public CombineLatestBase, std::tuple> +{ + public: + using base_t = CombineLatestBase, std::tuple>; + using input_tuple_t = typename base_t::input_tuple_t; + using output_t = typename base_t::output_t; + + private: + output_t convert_value(input_tuple_t&& data) override + { + // No change to the output type + return std::move(data); + } +}; + +template +class CombineLatestTransformComponent; + +template +class CombineLatestTransformComponent, OutputT> + : public CombineLatestBase, OutputT> +{ + public: + using base_t = CombineLatestBase, OutputT>; + using input_tuple_t = typename base_t::input_tuple_t; + using output_t = typename base_t::output_t; + using transform_fn_t = std::function; + + CombineLatestTransformComponent(transform_fn_t transform_fn) : base_t(), m_transform_fn(std::move(transform_fn)) {} + + private: + output_t convert_value(input_tuple_t&& data) override + { + return m_transform_fn(std::move(data)); + } + + transform_fn_t m_transform_fn; }; } // namespace mrc::node + +// IWYU pragma: end_exports diff --git a/cpp/mrc/include/mrc/node/operators/conditional.hpp b/cpp/mrc/include/mrc/node/operators/conditional.hpp index 250942b7c..1c04e91a9 100644 --- a/cpp/mrc/include/mrc/node/operators/conditional.hpp +++ b/cpp/mrc/include/mrc/node/operators/conditional.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,19 +22,12 @@ namespace mrc::node { template -class Conditional : public Router +class Conditional : public LambdaDynamicRouterComponent { - public: - Conditional(std::function predicate) : m_predicate(std::move(predicate)) {} - - protected: - virtual CaseT determine_key_for_value(const T& t) - { - return m_predicate(t); - } + using base_t = LambdaDynamicRouterComponent; - private: - std::function m_predicate; + public: + using base_t::base_t; }; } // namespace mrc::node diff --git a/cpp/mrc/include/mrc/node/operators/router.hpp b/cpp/mrc/include/mrc/node/operators/router.hpp index 4261e8c4b..9230c5ca6 100644 --- a/cpp/mrc/include/mrc/node/operators/router.hpp +++ b/cpp/mrc/include/mrc/node/operators/router.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,146 +17,487 @@ #pragma once +#include "mrc/channel/buffered_channel.hpp" #include "mrc/channel/status.hpp" +#include "mrc/core/utils.hpp" +#include "mrc/edge/edge_writable.hpp" #include "mrc/exceptions/runtime_error.hpp" #include "mrc/node/forward.hpp" +#include "mrc/node/node_parent.hpp" +#include "mrc/node/sink_channel_owner.hpp" #include "mrc/node/sink_properties.hpp" #include "mrc/node/source_channel_owner.hpp" #include "mrc/node/source_properties.hpp" +#include "mrc/runnable/forward.hpp" +#include "mrc/runnable/runnable.hpp" +#include "mrc/utils/string_utils.hpp" + +#include #include #include +#include +#include #include +// IWYU pragma: begin_exports + namespace mrc::node { +template +class RouterDownstreamNode : public edge::IWritableAcceptor, + public edge::IReadableProvider, + public ISourceChannelOwner +{}; + template -class RouterBase : public ForwardingWritableProvider, public MultiSourceProperties +class RouterBase : public MultiWritableAcceptor, + public MultiReadableProvider, + public MultiSourceChannelOwner { public: - using input_data_t = InputT; - using output_data_t = OutputT; - - RouterBase() : ForwardingWritableProvider() {} - - std::shared_ptr> get_source(const KeyT& key) const - { - // Simply return an object that will set the message to upstream and go away - return std::make_shared(*const_cast*>(this), key); - } + virtual std::shared_ptr> get_source(const KeyT& key) const = 0; bool has_source(const KeyT& key) const { - return MultiSourceProperties::get_edge_pair(key).first; + return MultiSourceProperties::get_edge_pair(key).first; } - void drop_edge(const KeyT& key) + void drop_source(const KeyT& key) { - MultiSourceProperties::release_edge_connection(key); + MultiSourceProperties::release_edge_connection(key); } protected: - class DownstreamEdge : public edge::IWritableAcceptor + class Downstream : public RouterDownstreamNode { public: - DownstreamEdge(RouterBase& parent, KeyT key) : m_parent(parent), m_key(std::move(key)) {} + Downstream(RouterBase& parent, KeyT key) : m_parent(parent), m_key(std::move(key)) + { + this->set_channel(std::make_unique>()); + } + + void set_channel(std::unique_ptr> channel) override + { + m_parent.MultiSourceChannelOwner::set_channel(m_key, std::move(channel)); + } void set_writable_edge_handle(std::shared_ptr ingress) override { - // Make sure we do any type conversions as needed - auto adapted_ingress = edge::EdgeBuilder::adapt_writable_edge(std::move(ingress)); + m_parent.MultiWritableAcceptor::set_writable_edge_handle(m_key, std::move(ingress)); + } - m_parent.MultiSourceProperties::make_edge_connection(m_key, std::move(adapted_ingress)); + std::shared_ptr get_readable_edge_handle() const override + { + return m_parent.MultiReadableProvider::get_readable_edge_handle(m_key); } private: - RouterBase& m_parent; + RouterBase& m_parent; KeyT m_key; }; - void on_complete() override + virtual KeyT determine_key_for_value(const InputT& t) = 0; + + virtual OutputT convert_value(InputT&& data) = 0; + + channel::Status process_one(InputT&& data) { - MultiSourceProperties::release_edge_connections(); + try + { + KeyT key = this->determine_key_for_value(data); + + if constexpr (std::is_same_v || std::is_convertible_v) + { + return MultiSourceProperties::get_writable_edge(key)->await_write(std::move(data)); + } + else + { + OutputT output = this->convert_value(std::move(data)); + + return MultiSourceProperties::get_writable_edge(key)->await_write(std::move(output)); + } + + } catch (const std::exception& e) + { + LOG(ERROR) << "Caught exception: " << e.what() << std::endl; + return channel::Status::error; + } } }; template -class Router; +class ConvertingRouterBase; + +template +class ConvertingRouterBase && !std::is_convertible_v>> + : public RouterBase +{}; template -class Router && !std::is_convertible_v>> +class ConvertingRouterBase || std::is_convertible_v>> : public RouterBase { protected: - channel::Status on_next(InputT&& data) override + OutputT convert_value(InputT&& data) override { - KeyT key = this->determine_key_for_value(data); + // This is a no-op, we just return the data. This wont be used. + return std::move(data); + } +}; + +template +class LambdaRouterBase; - auto output = this->convert_value(std::move(data)); +template +class LambdaRouterBase && !std::is_convertible_v>> + : public virtual ConvertingRouterBase +{ + public: + using base_t = ConvertingRouterBase; + using key_fn_t = std::function; + using convert_fn_t = std::function; - return MultiSourceProperties::get_writable_edge(key)->await_write(std::move(output)); + LambdaRouterBase(key_fn_t key_fn, convert_fn_t convert_fn) : + base_t(), + m_key_fn(std::move(key_fn)), + m_convert_fn(std::move(convert_fn)) + {} + + protected: + KeyT determine_key_for_value(const InputT& t) override + { + return m_key_fn(t); } - virtual KeyT determine_key_for_value(const InputT& t) = 0; + OutputT convert_value(InputT&& data) override + { + return m_convert_fn(std::move(data)); + } - virtual OutputT convert_value(InputT&& data) = 0; + key_fn_t m_key_fn; + convert_fn_t m_convert_fn; }; template -class Router && std::is_convertible_v>> - : public RouterBase +class LambdaRouterBase || std::is_convertible_v>> + : public virtual ConvertingRouterBase { + public: + using base_t = ConvertingRouterBase; + using key_fn_t = std::function; + + LambdaRouterBase(key_fn_t key_fn) : base_t(), m_key_fn(std::move(key_fn)) {} + protected: - channel::Status on_next(InputT&& data) override + KeyT determine_key_for_value(const InputT& t) override + { + return m_key_fn(t); + } + + key_fn_t m_key_fn; +}; + +template +class StaticRouterBase : public virtual ConvertingRouterBase, + public HomogeneousNodeParent> +{ + public: + using base_t = ConvertingRouterBase; + using this_t = StaticRouterBase; + + StaticRouterBase(std::vector route_keys) { - KeyT key = this->determine_key_for_value(data); + // Create a downstream for each key + for (const auto& key : route_keys) + { + m_downstreams[key] = std::make_shared(*this, key); + } + } - return MultiSourceProperties::get_writable_edge(key)->await_write(std::move(data)); + std::shared_ptr> get_source(const KeyT& key) const override + { + if (!m_downstreams.contains(key)) + { + throw exceptions::MrcRuntimeError(MRC_CONCAT_STR("Key '" << key << "' found in router")); + } + + return m_downstreams.at(key); } - virtual KeyT determine_key_for_value(const InputT& t) = 0; + std::map> get_children_refs( + std::optional child_name = std::nullopt) const override + { + std::map> children; + + for (const auto& [key, downstream] : m_downstreams) + { + // Utilize MRC_CONCAT_STR to convert the type to a string as best we can + children.emplace(MRC_CONCAT_STR(key), std::ref(*downstream)); + } + + return children; + } + + protected: + std::map> m_downstreams; }; -template -class Router>> - : public RouterBase +template +class DynamicRouterBase : public virtual ConvertingRouterBase +{ + using this_t = DynamicRouterBase; + + public: + std::shared_ptr> get_source(const KeyT& key) const override + { + std::shared_ptr downstream; + + if (!m_downstreams.contains(key) || (downstream = m_downstreams.at(key).lock()) == nullptr) + { + // Cast away constness to create the downstream + auto non_const_this = const_cast(this); + + downstream = std::make_shared(*non_const_this, key); + + non_const_this->m_downstreams[key] = downstream; + + return downstream; + } + + return downstream; + } + + protected: + std::map> m_downstreams; +}; + +template +class ComponentRouterBase : public ForwardingWritableProvider, + public virtual ConvertingRouterBase { protected: channel::Status on_next(InputT&& data) override { - KeyT key = this->determine_key_for_value(data); + return this->process_one(std::move(data)); + } - return MultiSourceProperties::get_writable_edge(key)->await_write(std::move(data)); + void on_complete() override + { + MultiSourceProperties::release_edge_connections(); } +}; - virtual KeyT determine_key_for_value(const InputT& t) = 0; +template +class RunnableRouterBase : public WritableProvider, + public ReadableAcceptor, + public SinkChannelOwner, + public virtual ConvertingRouterBase, + public mrc::runnable::RunnableWithContext<> +{ + protected: + RunnableRouterBase() + { + SinkChannelOwner::set_channel(std::make_unique>()); + } + + // Allows for easier testing of this method + void do_run() + { + InputT data; + channel::Status read_status; + channel::Status write_status = channel::Status::success; // give an initial value + + // Loop until either the node has been killed or the upstream terminated + while (!m_stop_source.stop_requested() && + (read_status = this->get_readable_edge()->await_read(data)) == channel::Status::success && + write_status == channel::Status::success) + { + write_status = this->process_one(std::move(data)); + } + + // Drop all connections + + if (read_status == channel::Status::error) + { + throw exceptions::MrcRuntimeError("Failed to read from upstream"); + } + + if (write_status == channel::Status::error) + { + throw exceptions::MrcRuntimeError("Failed to write to downstream"); + } + } + + private: + /** + * @brief Runnable's entrypoint. + */ + void run(mrc::runnable::Context& ctx) override + { + Unwinder unwinder([&] { + ctx.barrier(); + + if (ctx.rank() == 0) + { + MultiSourceProperties::release_edge_connections(); + } + }); + + this->do_run(); + } + + /** + * @brief Runnable's state control, for stopping from MRC. + */ + void on_state_update(const mrc::runnable::Runnable::State& state) final + { + switch (state) + { + case mrc::runnable::Runnable::State::Stop: + // Do nothing, we wait for the upstream channel to return closed + // m_stop_source.request_stop(); + break; + + case mrc::runnable::Runnable::State::Kill: + m_stop_source.request_stop(); + break; + + default: + break; + } + } + + std::stop_source m_stop_source; +}; + +template +class StaticRouterComponentBase : public StaticRouterBase, + public ComponentRouterBase +{ + public: + StaticRouterComponentBase(std::vector route_keys) : + StaticRouterBase(std::move(route_keys)) + {} +}; + +template +class LambdaStaticRouterComponent; + +template +class LambdaStaticRouterComponent< + KeyT, + InputT, + OutputT, + std::enable_if_t && !std::is_convertible_v>> + : public LambdaRouterBase, public StaticRouterComponentBase +{ + public: + using key_fn_t = LambdaRouterBase::key_fn_t; + using convert_fn_t = LambdaRouterBase::convert_fn_t; + + LambdaStaticRouterComponent(std::vector route_keys, key_fn_t key_fn, convert_fn_t convert_fn) : + LambdaRouterBase(std::move(key_fn), std::move(convert_fn)), + StaticRouterComponentBase(std::move(route_keys)) + {} +}; + +template +class LambdaStaticRouterComponent< + KeyT, + InputT, + OutputT, + std::enable_if_t || std::is_convertible_v>> + : public LambdaRouterBase, public StaticRouterComponentBase +{ + public: + using key_fn_t = LambdaRouterBase::key_fn_t; + + LambdaStaticRouterComponent(std::vector route_keys, key_fn_t key_fn) : + LambdaRouterBase(std::move(key_fn)), + StaticRouterComponentBase(std::move(route_keys)) + {} +}; + +template +class StaticRouterRunnableBase : public StaticRouterBase, + public RunnableRouterBase +{ + public: + StaticRouterRunnableBase(std::vector route_keys) : + StaticRouterBase(std::move(route_keys)) + {} +}; + +template +class LambdaStaticRouterRunnable : public LambdaRouterBase, + public StaticRouterRunnableBase +{ + public: + using key_fn_t = std::function; + + LambdaStaticRouterRunnable(std::vector route_keys, key_fn_t key_fn) : + StaticRouterRunnableBase(std::move(route_keys)), + LambdaRouterBase(std::move(key_fn)) + {} +}; + +template +class DynamicRouterComponentBase : public DynamicRouterBase, + public ComponentRouterBase +{}; + +template +class LambdaDynamicRouterComponent : public LambdaRouterBase, + public DynamicRouterComponentBase +{ + public: + using LambdaRouterBase::LambdaRouterBase; +}; + +template +class DynamicRouterRunnableBase : public DynamicRouterBase, + public RunnableRouterBase +{}; + +template +class LambdaDynamicRouterRunnable : public LambdaRouterBase, + public DynamicRouterRunnableBase +{ + public: + using LambdaRouterBase::LambdaRouterBase; }; template -class TaggedRouter : public Router, T> +class TaggedRouter : public DynamicRouterComponentBase, T> { protected: - using typename RouterBase, T>::input_data_t; - using typename RouterBase, T>::output_data_t; - - KeyT determine_key_for_value(const input_data_t& data) override + KeyT determine_key_for_value(const std::pair& data) override { return data.first; } - output_data_t convert_value(input_data_t&& data) override + T convert_value(std::pair&& data) override { // TODO(MDD): Do we need to move the key too? - output_data_t tmp = std::move(data.second); + T tmp = std::move(data.second); return tmp; } }; } // namespace mrc::node + +// IWYU pragma: end_exports diff --git a/cpp/mrc/include/mrc/node/operators/with_latest_from.hpp b/cpp/mrc/include/mrc/node/operators/with_latest_from.hpp new file mode 100644 index 000000000..ca337373e --- /dev/null +++ b/cpp/mrc/include/mrc/node/operators/with_latest_from.hpp @@ -0,0 +1,309 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "mrc/channel/buffered_channel.hpp" +#include "mrc/channel/status.hpp" +#include "mrc/core/utils.hpp" +#include "mrc/node/node_parent.hpp" +#include "mrc/node/sink_properties.hpp" +#include "mrc/node/source_properties.hpp" +#include "mrc/types.hpp" +#include "mrc/utils/tuple_utils.hpp" +#include "mrc/utils/type_utils.hpp" + +#include +#include + +#include +#include +#include +#include + +// IWYU pragma: begin_exports + +namespace mrc::node { + +class WithLatestFromTypelessBase +{ + public: + virtual ~WithLatestFromTypelessBase() = default; +}; + +template +class WithLatestFromBase +{}; + +template +class WithLatestFromBase, OutputT> + : public WithLatestFromTypelessBase, + public WritableAcceptor, + public HeterogeneousNodeParent...> +{ + public: + using input_tuple_t = std::tuple; + using output_t = OutputT; + + private: + template + using queue_t = BufferedChannel; + template + using wrapped_queue_t = std::unique_ptr>; + using queues_tuple_type = std::tuple...>; + + template + static auto build_ingress(WithLatestFromBase* self, std::index_sequence /*unused*/) + { + return std::make_tuple(std::make_shared>(*self)...); + } + + static auto build_queues(size_t channel_size) + { + return std::make_tuple(std::make_unique>(channel_size)...); + } + + template + static std::tuple>>...> + build_child_pairs(WithLatestFromBase* self, std::index_sequence /*unused*/) + { + return std::make_tuple( + std::make_pair(MRC_CONCAT_STR("sink[" << Is << "]"), std::ref(*self->get_sink()))...); + } + + public: + WithLatestFromBase(size_t max_outstanding = channel::default_channel_size()) : + m_primary_queue(std::make_unique>>(max_outstanding)), + m_upstream_holders(build_ingress(const_cast(this), std::index_sequence_for{})) + {} + + ~WithLatestFromBase() override = default; + + template + std::shared_ptr>> get_sink() const + { + return std::get(m_upstream_holders); + } + + std::tuple>>...> get_children_refs() + const override + { + return build_child_pairs(const_cast(this), std::index_sequence_for{}); + } + + protected: + template + class Upstream : public ForwardingWritableProvider> + { + using upstream_t = NthTypeOf; + + public: + Upstream(WithLatestFromBase& parent) : m_parent(parent) {} + + protected: + channel::Status on_next(upstream_t&& data) override + { + return m_parent.upstream_await_write(std::move(data)); + } + + void on_complete() override + { + m_parent.edge_complete(); + } + + private: + WithLatestFromBase& m_parent; + }; + + private: + template + channel::Status upstream_await_write(NthTypeOf value) + { + std::unique_lock lock(m_mutex); + + // Get a reference to the current value + auto& nth_val = std::get(m_state); + + // Check if we have fully initialized + if (m_values_set < sizeof...(InputT)) + { + if (!nth_val.has_value()) + { + ++m_values_set; + } + + // Move the value into the state + nth_val = std::move(value); + + // For the primary upstream only, move the value onto a queue + if constexpr (N == 0) + { + // Temporarily unlock to prevent deadlock + lock.unlock(); + + Unwinder relock([&]() { + lock.lock(); + }); + + // Move it into the queue + CHECK_EQ(m_primary_queue->await_write(std::move(nth_val.value())), channel::Status::success); + } + + // Check if this put us over the edge + if (m_values_set == sizeof...(InputT)) + { + // Need to complete initialization. First close the primary channel + m_primary_queue->close_channel(); + + auto& primary_val = std::get<0>(m_state); + + // Loop over the values in the queue, pushing each one + while (m_primary_queue->await_read(primary_val.value()) == channel::Status::success) + { + std::tuple new_val = utils::tuple_surely(m_state); + + CHECK_EQ(this->get_writable_edge()->await_write(this->convert_value(std::move(new_val))), + channel::Status::success); + } + } + } + else + { + // Move the value into the state + nth_val = std::move(value); + + // Only when we are the primary, do we push a new value + if constexpr (N == 0) + { + std::tuple new_val = utils::tuple_surely(m_state); + + return this->get_writable_edge()->await_write(this->convert_value(std::move(new_val))); + } + } + + return channel::Status::success; + } + + void edge_complete() + { + std::unique_lock lock(m_mutex); + + m_completions++; + + if (m_completions == sizeof...(InputT)) + { + NthTypeOf<0, InputT...> tmp; + bool had_values = false; + + // Try to clear out any values left in the channel + while (m_primary_queue->await_read(tmp) == channel::Status::success) + { + had_values = true; + } + + LOG_IF(ERROR, had_values) << "The primary source values were never pushed downstream. Ensure all upstream " + "sources pushed at least 1 value"; + + // Clear the held tuple to remove any dangling values + m_state = std::tuple...>(); + + WritableAcceptor::release_edge_connection(); + } + } + + virtual output_t convert_value(input_tuple_t&& data) = 0; + + mutable Mutex m_mutex; + + // The number of elements that have been set. Can start emitting when m_values_set == sizeof...(TypesT) + size_t m_values_set{0}; + + // Counts the number of upstream completions. When m_completions == sizeof...(TypesT), the downstream edges are + // released + size_t m_completions{0}; + + // Holds onto the latest values to eventually push when new ones are emitted + std::tuple...> m_state; + + // Queue to allow backpressure to upstreams. Only 1 queue for the primary is needed + wrapped_queue_t> m_primary_queue; + + // Upstream edges + std::tuple>...> m_upstream_holders; +}; + +template +class WithLatestFromComponent; + +template +class WithLatestFromComponent, OutputT> + : public WithLatestFromBase, OutputT> +{ + public: + using base_t = WithLatestFromBase, std::tuple>; + using input_tuple_t = typename base_t::input_tuple_t; + using output_t = typename base_t::output_t; +}; + +// Specialization for WithLatestFromBase with a default output type +template +class WithLatestFromComponent> + : public WithLatestFromBase, std::tuple> +{ + public: + using base_t = WithLatestFromBase, std::tuple>; + using input_tuple_t = typename base_t::input_tuple_t; + using output_t = typename base_t::output_t; + + private: + output_t convert_value(input_tuple_t&& data) override + { + // No change to the output type + return std::move(data); + } +}; + +template +class WithLatestFromTransformComponent; + +template +class WithLatestFromTransformComponent, OutputT> + : public WithLatestFromBase, OutputT> +{ + public: + using base_t = WithLatestFromBase, OutputT>; + using input_tuple_t = typename base_t::input_tuple_t; + using output_t = typename base_t::output_t; + using transform_fn_t = std::function; + + WithLatestFromTransformComponent(transform_fn_t transform_fn, size_t max_outstanding = 64) : + base_t(max_outstanding), + m_transform_fn(std::move(transform_fn)) + {} + + private: + output_t convert_value(input_tuple_t&& data) override + { + return m_transform_fn(std::move(data)); + } + + transform_fn_t m_transform_fn; +}; + +} // namespace mrc::node + +// IWYU pragma: end_exports diff --git a/cpp/mrc/include/mrc/node/operators/zip.hpp b/cpp/mrc/include/mrc/node/operators/zip.hpp new file mode 100644 index 000000000..db6a359a3 --- /dev/null +++ b/cpp/mrc/include/mrc/node/operators/zip.hpp @@ -0,0 +1,353 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "mrc/channel/buffered_channel.hpp" +#include "mrc/channel/channel.hpp" +#include "mrc/channel/status.hpp" +#include "mrc/node/node_parent.hpp" +#include "mrc/node/sink_properties.hpp" +#include "mrc/node/source_properties.hpp" +#include "mrc/types.hpp" +#include "mrc/utils/string_utils.hpp" +#include "mrc/utils/tuple_utils.hpp" +#include "mrc/utils/type_utils.hpp" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// IWYU pragma: begin_exports + +namespace mrc::node { + +class ZipTypelessBase +{ + public: + virtual ~ZipTypelessBase() = default; +}; + +template +class ZipBase +{}; + +template +class ZipBase, OutputT> : public ZipTypelessBase, + public WritableAcceptor, + public HeterogeneousNodeParent...> +{ + public: + using input_tuple_t = std::tuple; + using output_t = OutputT; + + private: + template + using queue_t = BufferedChannel; + template + using wrapped_queue_t = std::unique_ptr>; + using queues_tuple_type = std::tuple...>; + + template + static auto build_ingress(ZipBase* self, std::index_sequence /*unused*/) + { + return std::make_tuple(std::make_shared>(*self)...); + } + + static auto build_queues(size_t channel_size) + { + return std::make_tuple(std::make_unique>(channel_size)...); + } + + template + static std::tuple>>...> + build_child_pairs(ZipBase* self, std::index_sequence /*unused*/) + { + return std::make_tuple( + std::make_pair(MRC_CONCAT_STR("sink[" << Is << "]"), std::ref(*self->get_sink()))...); + } + + template + channel::Status tuple_pop_each(queues_tuple_type& queues_tuple, input_tuple_t& output_tuple) + { + channel::Status status = std::get(queues_tuple)->await_read(std::get(output_tuple)); + + if constexpr (I + 1 < sizeof...(InputT)) + { + // Iterate to the next index + channel::Status inner_status = tuple_pop_each(queues_tuple, output_tuple); + + // If the inner status failed, return that, otherwise return our status + status = inner_status == channel::Status::success ? status : inner_status; + } + + return status; + } + + public: + ZipBase(size_t max_outstanding = channel::default_channel_size()) : + m_queues(build_queues(max_outstanding)), + m_upstream_holders(build_ingress(const_cast(this), std::index_sequence_for{})) + { + // Must be sure to set any array values + m_queue_counts.fill(0); + } + + ~ZipBase() override = default; + + template + std::shared_ptr>> get_sink() const + { + return std::get(m_upstream_holders); + } + + std::tuple>>...> get_children_refs() + const override + { + return build_child_pairs(const_cast(this), std::index_sequence_for{}); + } + + protected: + template + class Upstream : public ForwardingWritableProvider> + { + using upstream_t = NthTypeOf; + + public: + Upstream(ZipBase& parent) : m_parent(parent) {} + + protected: + channel::Status on_next(upstream_t&& data) override + { + return m_parent.upstream_await_write(std::move(data)); + } + + void on_complete() override + { + m_parent.edge_complete(); + } + + private: + ZipBase& m_parent; + }; + + private: + template + channel::Status upstream_await_write(NthTypeOf value) + { + // Push before locking so we dont deadlock + auto push_status = std::get(m_queues)->await_write(std::move(value)); + + if (push_status != channel::Status::success) + { + return push_status; + } + + std::unique_lock lock(m_mutex); + + // Update the counts array + m_queue_counts[N]++; + + if (m_queue_counts[N] == m_max_queue_count) + { + // Close the queue to prevent pushing more messages + std::get(m_queues)->close_channel(); + } + + DCHECK_LE(m_queue_counts[N], m_max_queue_count) << "Queue count has surpassed the max count"; + + // See if we have values in every queue + auto all_queues_have_value = std::transform_reduce(m_queue_counts.begin(), + m_queue_counts.end(), + true, + std::logical_and<>(), + [this](const size_t& v) { + return v > m_pull_count; + }); + + channel::Status status = channel::Status::success; + + if (all_queues_have_value) + { + // For each tuple, pop a value off + std::tuple new_val; + + auto channel_status = tuple_pop_each(m_queues, new_val); + + DCHECK_EQ(channel_status, channel::Status::success) << "Queues returned failed status"; + + // Push the new value + status = this->get_writable_edge()->await_write(this->convert_value(std::move(new_val))); + + m_pull_count++; + } + + return status; + } + + template + void edge_complete() + { + std::unique_lock lock(m_mutex); + + if (m_queue_counts[N] < m_max_queue_count) + { + // We are setting a new lower limit. Check to make sure this isnt an issue + m_max_queue_count = m_queue_counts[N]; + + utils::tuple_for_each(m_queues, + [this](std::unique_ptr>& q, size_t idx) { + if (m_queue_counts[idx] >= m_max_queue_count) + { + // Close the channel + q->close_channel(); + + if (m_queue_counts[idx] > m_max_queue_count) + { + LOG(ERROR) + << "Unbalanced count in upstream sources for Zip operator. Upstream '" + << N << "' ended with " << m_queue_counts[N] << " elements but " + << m_queue_counts[idx] + << " elements have already been pushed by upstream '" << idx << "'"; + } + } + }); + } + + m_completions++; + + if (m_completions == sizeof...(InputT)) + { + // Warn on any left over values + auto left_over_messages = std::transform_reduce(m_queue_counts.begin(), + m_queue_counts.end(), + 0, + std::plus<>(), + [this](const size_t& v) { + return v - m_pull_count; + }); + if (left_over_messages > 0) + { + LOG(ERROR) << "Unbalanced count in upstream sources for Zip operator. " << left_over_messages + << " messages were left in the queues"; + } + + // Finally, drain the queues of any remaining values + utils::tuple_for_each(m_queues, + [](std::unique_ptr>& q, size_t idx) { + QueueValueT value; + + while (q->await_read(value) == channel::Status::success) {} + }); + + WritableAcceptor::release_edge_connection(); + } + } + + virtual output_t convert_value(input_tuple_t&& data) = 0; + + mutable Mutex m_mutex; + + // Once an upstream is closed, this is set representing the max number of values in a queue before its closed + size_t m_max_queue_count{std::numeric_limits::max()}; + + // Counts the number of upstream completions. When m_completions == sizeof...(TypesT), the downstream edges are + // released + size_t m_completions{0}; + + // Holds the number of values pushed to each queue + std::array m_queue_counts; + + // The number of messages pulled off the queue + size_t m_pull_count{0}; + + // Queue used to allow backpressure to upstreams + queues_tuple_type m_queues; + + // Upstream edges + std::tuple>...> m_upstream_holders; +}; + +template +class ZipComponent; + +template +class ZipComponent, OutputT> : public ZipBase, OutputT> +{ + public: + using base_t = ZipBase, std::tuple>; + using input_tuple_t = typename base_t::input_tuple_t; + using output_t = typename base_t::output_t; +}; + +// Specialization for Zip with a default output type +template +class ZipComponent> : public ZipBase, std::tuple> +{ + public: + using base_t = ZipBase, std::tuple>; + using input_tuple_t = typename base_t::input_tuple_t; + using output_t = typename base_t::output_t; + + private: + output_t convert_value(input_tuple_t&& data) override + { + // No change to the output type + return std::move(data); + } +}; + +template +class ZipTransformComponent; + +template +class ZipTransformComponent, OutputT> : public ZipBase, OutputT> +{ + public: + using base_t = ZipBase, OutputT>; + using input_tuple_t = typename base_t::input_tuple_t; + using output_t = typename base_t::output_t; + using transform_fn_t = std::function; + + ZipTransformComponent(transform_fn_t transform_fn, size_t max_outstanding = 64) : + base_t(max_outstanding), + m_transform_fn(std::move(transform_fn)) + {} + + private: + output_t convert_value(input_tuple_t&& data) override + { + return m_transform_fn(std::move(data)); + } + + transform_fn_t m_transform_fn; +}; + +} // namespace mrc::node + +// IWYU pragma: end_exports diff --git a/cpp/mrc/include/mrc/node/sink_channel_owner.hpp b/cpp/mrc/include/mrc/node/sink_channel_owner.hpp index 8997e3a8d..909f9e710 100644 --- a/cpp/mrc/include/mrc/node/sink_channel_owner.hpp +++ b/cpp/mrc/include/mrc/node/sink_channel_owner.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -38,13 +38,13 @@ class SinkChannelOwner : public virtual SinkProperties { edge::EdgeChannel edge_channel(std::move(channel)); - this->do_set_channel(edge_channel); + this->do_set_channel(std::move(edge_channel)); } protected: SinkChannelOwner() = default; - void do_set_channel(edge::EdgeChannel& edge_channel) + void do_set_channel(edge::EdgeChannel edge_channel) { // Create 2 edges, one for reading and writing. On connection, persist the other to allow the node to still use // get_readable+edge diff --git a/cpp/mrc/include/mrc/node/sink_properties.hpp b/cpp/mrc/include/mrc/node/sink_properties.hpp index af56b0fe0..c71f70c1c 100644 --- a/cpp/mrc/include/mrc/node/sink_properties.hpp +++ b/cpp/mrc/include/mrc/node/sink_properties.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +20,7 @@ #include "mrc/channel/status.hpp" // IWYU pragma: export #include "mrc/edge/edge_builder.hpp" #include "mrc/edge/edge_readable.hpp" +#include "mrc/edge/edge_writable.hpp" #include "mrc/node/forward.hpp" #include "mrc/type_traits.hpp" #include "mrc/utils/type_utils.hpp" @@ -121,6 +122,36 @@ class SinkProperties : public edge::EdgeHolder, public SinkPropertiesBase } }; +template +class MultiSinkProperties : public edge::MultiEdgeHolder, public SinkPropertiesBase +{ + public: + using sink_type_t = T; + + std::type_index sink_type(bool ignore_holder = false) const final + { + if (ignore_holder) + { + if constexpr (is_smart_ptr::value) + { + return typeid(typename T::element_type); + } + } + return typeid(T); + } + + std::string sink_type_name() const final + { + return std::string(type_name()); + } + + protected: + std::shared_ptr> get_readable_edge(KeyT edge_key) const + { + return std::dynamic_pointer_cast>(this->get_connected_edge(edge_key)); + } +}; + template class ReadableAcceptor : public virtual SinkProperties, public edge::IReadableAcceptor { @@ -131,7 +162,7 @@ class ReadableAcceptor : public virtual SinkProperties, public edge::IReadabl SinkProperties::operator=(std::move(other)); } - private: + protected: void set_readable_edge_handle(std::shared_ptr egress) override { // Do any conversion to the correct type here @@ -151,13 +182,90 @@ class WritableProvider : public virtual SinkProperties, public edge::IWritabl SinkProperties::operator=(std::move(other)); } - private: + protected: std::shared_ptr get_writable_edge_handle() const override { return edge::WritableEdgeHandle::from_typeless(SinkProperties::get_edge_connection()); } }; +template +class ReadableWritableSink : public WritableProvider, public ReadableAcceptor +{}; + +template +class MultiReadableAcceptor : public virtual MultiSinkProperties, public edge::IMultiReadableAcceptor +{ + public: + protected: + bool has_readable_edge(const KeyT& key) const override + { + return MultiSinkProperties::has_edge_connection(key); + } + + void release_readable_edge(const KeyT& key) override + { + return MultiSinkProperties::release_edge_connection(key); + } + + void release_readable_edges() override + { + return MultiSinkProperties::release_edge_connections(); + } + + size_t readable_edge_count() const override + { + return MultiSinkProperties::edge_connection_count(); + } + + std::vector readable_edge_keys() const override + { + return MultiSinkProperties::edge_connection_keys(); + } + + void set_readable_edge_handle(KeyT key, std::shared_ptr egress) override + { + auto adapted_egress = edge::EdgeBuilder::adapt_readable_edge(egress); + MultiSinkProperties::make_edge_connection(key, adapted_egress); + } +}; + +template +class MultiWritableProvider : public virtual MultiSinkProperties, public edge::IMultiWritableProvider +{ + public: + protected: + bool has_writable_edge(const KeyT& key) const override + { + return MultiSinkProperties::has_edge_connection(key); + } + + void release_writable_edge(const KeyT& key) override + { + return MultiSinkProperties::release_edge_connection(key); + } + + void release_writable_edges() override + { + return MultiSinkProperties::release_edge_connections(); + } + + size_t writable_edge_count() const override + { + return MultiSinkProperties::edge_connection_count(); + } + + std::vector writable_edge_keys() const override + { + return MultiSinkProperties::edge_connection_keys(); + } + + std::shared_ptr get_writable_edge_handle(KeyT key) const override + { + return edge::WritableEdgeHandle::from_typeless(MultiSinkProperties::get_edge_connection(key)); + } +}; + template class ForwardingWritableProvider : public WritableProvider { diff --git a/cpp/mrc/include/mrc/node/source_channel_owner.hpp b/cpp/mrc/include/mrc/node/source_channel_owner.hpp index 226492e5e..b85fb25c7 100644 --- a/cpp/mrc/include/mrc/node/source_channel_owner.hpp +++ b/cpp/mrc/include/mrc/node/source_channel_owner.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,28 +25,37 @@ namespace mrc::node { +template +class ISourceChannelOwner +{ + public: + virtual ~ISourceChannelOwner() = default; + + virtual void set_channel(std::unique_ptr> channel) = 0; +}; + /** * @brief Extends SourceProperties to hold a channel ingress which is the writing interface to an edge. * * @tparam T */ template -class SourceChannelOwner : public virtual SourceProperties +class SourceChannelOwner : public ISourceChannelOwner, public virtual SourceProperties { public: ~SourceChannelOwner() override = default; - void set_channel(std::unique_ptr> channel) + void set_channel(std::unique_ptr> channel) override { edge::EdgeChannel edge_channel(std::move(channel)); - this->do_set_channel(edge_channel); + this->do_set_channel(std::move(edge_channel)); } protected: SourceChannelOwner() = default; - void do_set_channel(edge::EdgeChannel& edge_channel) + void do_set_channel(edge::EdgeChannel edge_channel) { // Create 2 edges, one for reading and writing. On connection, persist the other to allow the node to still use // get_writable_edge @@ -64,4 +73,38 @@ class SourceChannelOwner : public virtual SourceProperties } }; +template +class MultiSourceChannelOwner : public virtual MultiSourceProperties +{ + public: + ~MultiSourceChannelOwner() override = default; + + void set_channel(KeyT key, std::unique_ptr> channel) + { + edge::EdgeChannel edge_channel(std::move(channel)); + + this->do_set_channel(std::move(key), std::move(edge_channel)); + } + + protected: + MultiSourceChannelOwner() = default; + + void do_set_channel(KeyT key, edge::EdgeChannel edge_channel) + { + // Create 2 edges, one for reading and writing. On connection, persist the other to allow the node to still use + // get_writable_edge + auto channel_reader = edge_channel.get_reader(); + auto channel_writer = edge_channel.get_writer(); + + channel_reader->add_connector([this, channel_writer, key]() { + // Finally, set the other half as the connected edge to allow writers the ability to push to the channel. + // Only do this after a full connection has been made to avoid writing to a channel that will never be + // read from. + this->MultiSourceProperties::init_connected_edge(key, channel_writer); + }); + + MultiSourceProperties::init_owned_edge(key, channel_reader); + } +}; + } // namespace mrc::node diff --git a/cpp/mrc/include/mrc/node/source_properties.hpp b/cpp/mrc/include/mrc/node/source_properties.hpp index 12166eb43..58faf513b 100644 --- a/cpp/mrc/include/mrc/node/source_properties.hpp +++ b/cpp/mrc/include/mrc/node/source_properties.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,6 +20,7 @@ #include "mrc/channel/ingress.hpp" #include "mrc/channel/status.hpp" // IWYU pragma: export #include "mrc/edge/edge_builder.hpp" +#include "mrc/edge/edge_readable.hpp" #include "mrc/edge/edge_writable.hpp" #include "mrc/node/forward.hpp" #include "mrc/type_traits.hpp" @@ -163,7 +164,7 @@ class ReadableProvider : public virtual SourceProperties, public edge::IReada SourceProperties::operator=(std::move(other)); } - private: + protected: std::shared_ptr get_readable_edge_handle() const override { return edge::ReadableEdgeHandle::from_typeless(SourceProperties::get_edge_connection()); @@ -180,7 +181,7 @@ class WritableAcceptor : public virtual SourceProperties, public edge::IWrita SourceProperties::operator=(std::move(other)); } - private: + protected: void set_writable_edge_handle(std::shared_ptr ingress) override { // Do any conversion to the correct type here @@ -190,28 +191,95 @@ class WritableAcceptor : public virtual SourceProperties, public edge::IWrita } }; -template -class MultiIngressAcceptor : public virtual MultiSourceProperties, public edge::IMultiWritableAcceptor +template +class ReadableWritableSource : public ReadableProvider, public WritableAcceptor +{}; + +template +class MultiReadableProvider : public virtual MultiSourceProperties, + public edge::IMultiReadableProvider +{ + public: + protected: + bool has_readable_edge(const KeyT& key) const override + { + return MultiSourceProperties::has_edge_connection(key); + } + + void release_readable_edge(const KeyT& key) override + { + return MultiSourceProperties::release_edge_connection(key); + } + + void release_readable_edges() override + { + return MultiSourceProperties::release_edge_connections(); + } + + size_t readable_edge_count() const override + { + return MultiSourceProperties::edge_connection_count(); + } + + std::vector readable_edge_keys() const override + { + return MultiSourceProperties::edge_connection_keys(); + } + + std::shared_ptr get_readable_edge_handle(KeyT key) const override + { + return edge::ReadableEdgeHandle::from_typeless(MultiSourceProperties::get_edge_connection(key)); + } +}; + +template +class MultiWritableAcceptor : public virtual MultiSourceProperties, + public edge::IMultiWritableAcceptor { public: - private: + protected: + bool has_writable_edge(const KeyT& key) const override + { + return MultiSourceProperties::has_edge_connection(key); + } + + void release_writable_edge(const KeyT& key) override + { + return MultiSourceProperties::release_edge_connection(key); + } + + void release_writable_edges() override + { + return MultiSourceProperties::release_edge_connections(); + } + + size_t writable_edge_count() const override + { + return MultiSourceProperties::edge_connection_count(); + } + + std::vector writable_edge_keys() const override + { + return MultiSourceProperties::edge_connection_keys(); + } + void set_writable_edge_handle(KeyT key, std::shared_ptr ingress) override { // Do any conversion to the correct type here auto adapted_ingress = edge::EdgeBuilder::adapt_writable_edge(ingress); - MultiSourceProperties::make_edge_connection(key, adapted_ingress); + MultiSourceProperties::make_edge_connection(key, adapted_ingress); } }; template -class ForwardingEgressProvider : public ReadableProvider +class ForwardingReadableProvider : public ReadableProvider { protected: class ForwardingEdge : public edge::IEdgeReadable { public: - ForwardingEdge(ForwardingEgressProvider& parent) : m_parent(parent) {} + ForwardingEdge(ForwardingReadableProvider& parent) : m_parent(parent) {} ~ForwardingEdge() = default; @@ -221,10 +289,10 @@ class ForwardingEgressProvider : public ReadableProvider } private: - ForwardingEgressProvider& m_parent; + ForwardingReadableProvider& m_parent; }; - ForwardingEgressProvider() + ForwardingReadableProvider() { auto inner_edge = std::make_shared(*this); diff --git a/cpp/mrc/include/mrc/segment/builder.hpp b/cpp/mrc/include/mrc/segment/builder.hpp index a35f571c9..e7d93720f 100644 --- a/cpp/mrc/include/mrc/segment/builder.hpp +++ b/cpp/mrc/include/mrc/segment/builder.hpp @@ -464,6 +464,16 @@ void IBuilder::make_edge(SourceObjectT source, SinkObjectT sink) auto& source_object = to_object_properties(source); auto& sink_object = to_object_properties(sink); + if (source_object.owning_builder() != this) + { + throw exceptions::MrcRuntimeError("Source object does not belong to this builder"); + } + + if (sink_object.owning_builder() != this) + { + throw exceptions::MrcRuntimeError("Sink object does not belong to this builder"); + } + // If we can determine the type from the actual object, use that, then fall back to hints or defaults. using deduced_source_type_t = first_non_void_type_t() << std::endl; VLOG(2) << "Deduced sink type: " << mrc::type_name() << std::endl; - if (source_object.is_writable_acceptor() && sink_object.is_writable_provider()) + if constexpr (std::is_void_v || std::is_void_v) { - mrc::make_edge(source_object.template writable_acceptor_typed(), - sink_object.template writable_provider_typed()); - return; + // Try typeless edge creation + if (source_object.is_writable_acceptor() && sink_object.is_writable_provider()) + { + mrc::make_edge_typeless(source_object.writable_acceptor_base(), sink_object.writable_provider_base()); + } + else if (source_object.is_readable_provider() && sink_object.is_readable_acceptor()) + { + mrc::make_edge_typeless(source_object.readable_provider_base(), sink_object.readable_acceptor_base()); + } + else + { + throw std::runtime_error( + "Invalid edges. Arguments to make_edge were incorrect. Ensure you are providing either " + "WritableAcceptor->WritableProvider or ReadableProvider->ReadableAcceptor"); + } } - - if (source_object.is_readable_provider() && sink_object.is_readable_acceptor()) + else { - mrc::make_edge(source_object.template readable_provider_typed(), - sink_object.template readable_acceptor_typed()); - return; - } + if (source_object.is_writable_acceptor() && sink_object.is_writable_provider()) + { + mrc::make_edge(source_object.template writable_acceptor_typed(), + sink_object.template writable_provider_typed()); + } - LOG(ERROR) << "Incompatible node types"; + else if (source_object.is_readable_provider() && sink_object.is_readable_acceptor()) + { + mrc::make_edge(source_object.template readable_provider_typed(), + sink_object.template readable_acceptor_typed()); + } + else + { + throw std::runtime_error( + "Invalid edges. Arguments to make_edge were incorrect. Ensure you are providing either " + "WritableAcceptor->WritableProvider or ReadableProvider->ReadableAcceptor"); + } + } } template +#include #include +#include #include #include +#include +#include namespace mrc::segment { -struct ObjectProperties +template +class SharedObject; + +template +class ReferencedObject; + +struct ObjectPropertiesState { - virtual ~ObjectProperties() = 0; + const std::string type_name; + + const bool is_sink; + const bool is_source; + + const bool is_writable_acceptor; + const bool is_writable_provider; + const bool is_readable_acceptor; + const bool is_readable_provider; + + const bool is_runnable; + + bool is_initialized() const + { + return m_is_initialized; + } + + const std::string& name() const + { + return m_name; + } - virtual void set_name(const std::string& name) = 0; - virtual std::string name() const = 0; - virtual std::string type_name() const = 0; + IBuilder* owning_builder() const + { + return m_owning_builder; + } - virtual bool is_sink() const = 0; - virtual bool is_source() const = 0; + void initialize(std::string name, IBuilder* owning_builder) + { + MRC_CHECK_THROW(!m_is_initialized) << "Object '" << name << "' is already initialized."; + + m_name = std::move(name); + m_owning_builder = owning_builder; + m_is_initialized = true; + } + + template + static std::shared_ptr create() + { + auto state = std::shared_ptr(new ObjectPropertiesState( + /*.type_name = */ std::string(::mrc::type_name()), + /*.is_sink = */ std::is_base_of_v, + /*.is_source = */ std::is_base_of_v, + /*.is_writable_acceptor = */ std::is_base_of_v, + /*.is_writable_provider = */ std::is_base_of_v, + /*.is_readable_acceptor = */ std::is_base_of_v, + /*.is_readable_provider = */ std::is_base_of_v, + /*.is_runnable = */ std::is_base_of_v)); + + return state; + } + + private: + ObjectPropertiesState(std::string type_name, + bool is_sink, + bool is_source, + bool is_writable_acceptor, + bool is_writable_provider, + bool is_readable_acceptor, + bool is_readable_provider, + bool is_runnable) : + type_name(std::move(type_name)), + is_sink(is_sink), + is_source(is_source), + is_writable_acceptor(is_writable_acceptor), + is_writable_provider(is_writable_provider), + is_readable_acceptor(is_readable_acceptor), + is_readable_provider(is_readable_provider), + is_runnable(is_runnable) + {} + + // Will be set by the builder class when the object is added to a segment + bool m_is_initialized{false}; + + std::string m_name; + + // The owning builder. Once set, name cannot be changed + IBuilder* m_owning_builder{nullptr}; +}; + +class ObjectProperties +{ + public: + virtual ~ObjectProperties() = default; + + void initialize(std::string name, IBuilder* owning_builder) + { + // Set our name first + this->get_state().initialize(name, owning_builder); + + // Initialize the children + this->init_children(); + } + + virtual std::string name() const + { + return this->get_state().name(); + } + + virtual std::string type_name() const + { + return this->get_state().type_name; + } + + virtual bool is_sink() const + { + return this->get_state().is_sink; + } + + virtual bool is_source() const + { + return this->get_state().is_source; + } + + virtual std::type_index sink_type(bool ignore_holder = false) const = 0; - virtual std::type_index sink_type(bool ignore_holder = false) const = 0; virtual std::type_index source_type(bool ignore_holder = false) const = 0; - virtual bool is_writable_acceptor() const = 0; - virtual bool is_writable_provider() const = 0; - virtual bool is_readable_acceptor() const = 0; - virtual bool is_readable_provider() const = 0; + bool is_writable_acceptor() const + { + return this->get_state().is_writable_acceptor; + } + bool is_writable_provider() const + { + return this->get_state().is_writable_provider; + } + bool is_readable_acceptor() const + { + return this->get_state().is_readable_acceptor; + } + bool is_readable_provider() const + { + return this->get_state().is_readable_provider; + } virtual edge::IWritableAcceptorBase& writable_acceptor_base() = 0; virtual edge::IWritableProviderBase& writable_provider_base() = 0; @@ -70,13 +204,34 @@ struct ObjectProperties template edge::IReadableAcceptor& readable_acceptor_typed(); - virtual bool is_runnable() const = 0; + virtual bool is_runnable() const + { + return this->get_state().is_runnable; + } + + virtual IBuilder* owning_builder() const + { + return this->get_state().owning_builder(); + } virtual runnable::LaunchOptions& launch_options() = 0; virtual const runnable::LaunchOptions& launch_options() const = 0; -}; -inline ObjectProperties::~ObjectProperties() = default; + virtual bool has_child(const std::string& name) const = 0; + virtual std::shared_ptr get_child(const std::string& name) const = 0; + virtual const std::map>& get_children() const = 0; + + protected: + ObjectProperties() = default; + + virtual const ObjectPropertiesState& get_state() const = 0; + virtual ObjectPropertiesState& get_state() = 0; + + private: + virtual void init_children() = 0; + + friend class IBuilder; +}; template edge::IWritableAcceptor& ObjectProperties::writable_acceptor_typed() @@ -146,38 +301,36 @@ edge::IReadableProvider& ObjectProperties::readable_provider_typed() return *readable_provider; } -// Object +// template +// std::type_index deduce_type_index(bool ignore_holder) +// { +// if (ignore_holder) +// { +// if constexpr (is_smart_ptr_v) +// { +// return std::type_index(typeid(typename T::element_type)); +// } +// } + +// return std::type_index(typeid(T)); +// } +// Object template -class Object : public virtual ObjectProperties +class Object : public virtual ObjectProperties, public std::enable_shared_from_this> { public: ObjectT& object(); - - std::string name() const final; - std::string type_name() const final; - - bool is_source() const final; - bool is_sink() const final; + const ObjectT& object() const; std::type_index sink_type(bool ignore_holder) const final; std::type_index source_type(bool ignore_holder) const final; - bool is_writable_acceptor() const final; - bool is_writable_provider() const final; - bool is_readable_acceptor() const final; - bool is_readable_provider() const final; - edge::IWritableAcceptorBase& writable_acceptor_base() final; edge::IWritableProviderBase& writable_provider_base() final; edge::IReadableAcceptorBase& readable_acceptor_base() final; edge::IReadableProviderBase& readable_provider_base() final; - bool is_runnable() const final - { - return static_cast(std::is_base_of_v); - } - runnable::LaunchOptions& launch_options() final { if (!is_runnable()) @@ -198,15 +351,145 @@ class Object : public virtual ObjectProperties return m_launch_options; } + bool has_child(const std::string& name) const override + { + // First, split the name into the local and child names + auto child_name_start_idx = name.find("/"); + + if (child_name_start_idx != std::string::npos) + { + auto local_name = name.substr(0, child_name_start_idx); + auto child_name = name.substr(child_name_start_idx + 1); + + // Check if the local name matches + auto found = m_children.find(local_name); + + if (found == m_children.end()) + { + return false; + } + + // Now check if the child exists + return found->second->has_child(child_name); + } + + return m_children.contains(name); + } + + std::shared_ptr get_child(const std::string& name) const override + { + auto local_name = name; + std::string child_name; + + // First, split the name into the local and child names + auto child_name_start_idx = name.find("/"); + + if (child_name_start_idx != std::string::npos) + { + local_name = name.substr(0, child_name_start_idx); + child_name = name.substr(child_name_start_idx + 1); + } + + auto found = m_children.find(local_name); + + if (found == m_children.end()) + { + throw exceptions::MrcRuntimeError("Child " + local_name + " not found in " + this->name()); + } + + if (!child_name.empty()) + { + return found->second->get_child(child_name); + } + + return found->second; + } + + const std::map>& get_children() const override + { + return m_children; + } + + template + requires std::derived_from + std::shared_ptr> as() const + { + auto shared_object = std::make_shared>(*const_cast(this)); + + return shared_object; + } + protected: - // Move to protected to allow only the IBuilder to set the name - void set_name(const std::string& name) override; + Object() : m_state(ObjectPropertiesState::create()) {} + + template + requires std::derived_from + Object(const Object& other) : + ObjectProperties(other), + m_state(ObjectPropertiesState::create()), + m_launch_options(other.m_launch_options), + m_children(other.m_children) + {} + + const ObjectPropertiesState& get_state() const override + { + return *m_state; + } - private: - std::string m_name{}; + ObjectPropertiesState& get_state() override + { + return *m_state; + } + private: virtual ObjectT* get_object() const = 0; + + void init_children() override + { + if constexpr (is_base_of_template::value) + { + using child_node_t = typename ObjectT::child_node_t; + + // Get a map of the name/reference pairs from the NodeParent + auto children_ref_pairs = this->object().get_children_refs(); + + // Now loop and add any new children + for (const auto& [name, child_ref] : children_ref_pairs) + { + auto child_obj = std::make_shared>(this->shared_from_this(), child_ref); + + m_children.emplace(name, std::move(child_obj)); + } + } + + if constexpr (is_base_of_template::value) + { + using child_types_t = typename ObjectT::child_types_t; + + // Get the name/reference pairs from the NodeParent + auto children_ref_pairs = this->object().get_children_refs(); + + // Finally, convert the tuple of name/ChildObject pairs into a map + utils::tuple_for_each( + children_ref_pairs, + [this](std::pair>& pair, + size_t idx) { + auto child_obj = std::make_shared>(this->shared_from_this(), pair.second); + + m_children.emplace(pair.first, std::move(child_obj)); + }); + } + } + + std::shared_ptr m_state; + runnable::LaunchOptions m_launch_options; + + std::map> m_children; + + // Allows converting to base classes + template + friend class Object; }; template @@ -223,33 +506,16 @@ ObjectT& Object::object() } template -void Object::set_name(const std::string& name) -{ - m_name = name; -} - -template -std::string Object::name() const +const ObjectT& Object::object() const { - return m_name; -} - -template -std::string Object::type_name() const -{ - return std::string(::mrc::type_name()); -} - -template -bool Object::is_source() const -{ - return std::is_base_of_v; -} - -template -bool Object::is_sink() const -{ - return std::is_base_of_v; + auto* node = get_object(); + if (node == nullptr) + { + LOG(ERROR) << "Error accessing the Object API; Nodes are moved from the Segment API to the Executor " + "when the pipeline is started."; + throw exceptions::MrcRuntimeError("Object API is unavailable - expected if the Pipeline is running."); + } + return *node; } template @@ -277,82 +543,79 @@ std::type_index Object::source_type(bool ignore_holder) const } template -bool Object::is_writable_acceptor() const -{ - return std::is_base_of_v; -} - -template -bool Object::is_writable_provider() const +edge::IWritableAcceptorBase& Object::writable_acceptor_base() { - return std::is_base_of_v; + auto* base = dynamic_cast(get_object()); + CHECK(base) << type_name() << " is not a IIngressAcceptorBase"; + return *base; } template -bool Object::is_readable_acceptor() const +edge::IWritableProviderBase& Object::writable_provider_base() { - return std::is_base_of_v; + auto* base = dynamic_cast(get_object()); + CHECK(base) << type_name() << " is not a IWritableProviderBase"; + return *base; } template -bool Object::is_readable_provider() const +edge::IReadableAcceptorBase& Object::readable_acceptor_base() { - return std::is_base_of_v; + auto* base = dynamic_cast(get_object()); + CHECK(base) << type_name() << " is not a IReadableAcceptorBase"; + return *base; } template -edge::IWritableAcceptorBase& Object::writable_acceptor_base() +edge::IReadableProviderBase& Object::readable_provider_base() { - if constexpr (!std::is_base_of_v) - { - LOG(ERROR) << type_name() << " is not a IIngressAcceptorBase"; - throw exceptions::MrcRuntimeError("Object is not a IIngressAcceptorBase"); - } - - auto* base = dynamic_cast(get_object()); - CHECK(base); + auto* base = dynamic_cast(get_object()); + CHECK(base) << type_name() << " is not a IReadableProviderBase"; return *base; } template -edge::IWritableProviderBase& Object::writable_provider_base() +class SharedObject final : public Object { - if constexpr (!std::is_base_of_v) + public: + SharedObject(std::shared_ptr owner, std::reference_wrapper resource) : + m_owner(std::move(owner)), + m_resource(std::move(resource)) + {} + ~SharedObject() final = default; + + private: + ObjectT* get_object() const final { - LOG(ERROR) << type_name() << " is not a IIngressProviderBase"; - throw exceptions::MrcRuntimeError("Object is not a IIngressProviderBase"); + return &m_resource.get(); } - auto* base = dynamic_cast(get_object()); - CHECK(base); - return *base; -} + std::shared_ptr m_owner; + std::reference_wrapper m_resource; +}; template -edge::IReadableAcceptorBase& Object::readable_acceptor_base() +class ReferencedObject final : public Object { - if constexpr (!std::is_base_of_v) - { - LOG(ERROR) << type_name() << " is not a IEgressAcceptorBase"; - throw exceptions::MrcRuntimeError("Object is not a IEgressAcceptorBase"); - } + public: + template + requires std::derived_from + ReferencedObject(Object& other) : + Object(other), + m_owner(other.shared_from_this()), + m_resource(other.object()) + {} - auto* base = dynamic_cast(get_object()); - CHECK(base); - return *base; -} + ~ReferencedObject() final = default; -template -edge::IReadableProviderBase& Object::readable_provider_base() -{ - if constexpr (!std::is_base_of_v) + private: + ObjectT* get_object() const final { - LOG(ERROR) << type_name() << " is not a IEgressProviderBase"; - throw exceptions::MrcRuntimeError("Object is not a IEgressProviderBase"); + return &m_resource.get(); } - auto* base = dynamic_cast(get_object()); - CHECK(base); - return *base; -} + std::shared_ptr m_owner; + std::reference_wrapper m_resource; +}; + } // namespace mrc::segment diff --git a/cpp/mrc/include/mrc/segment/runnable.hpp b/cpp/mrc/include/mrc/segment/runnable.hpp index ab5b590ca..0121bed65 100644 --- a/cpp/mrc/include/mrc/segment/runnable.hpp +++ b/cpp/mrc/include/mrc/segment/runnable.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -37,15 +37,15 @@ template class Runnable : public Object, public runnable::Launchable { public: - template - Runnable(ArgsT&&... args) : m_node(std::make_unique(std::forward(args)...)) - {} - Runnable(std::unique_ptr node) : m_node(std::move(node)) { CHECK(m_node); } + template + Runnable(ArgsT&&... args) : Runnable(std::make_unique(std::forward(args)...)) + {} + private: NodeT* get_object() const final; std::unique_ptr prepare_launcher(runnable::LaunchControl& launch_control) final; diff --git a/cpp/mrc/include/mrc/type_traits.hpp b/cpp/mrc/include/mrc/type_traits.hpp index 4f1477abf..9e86f60d8 100644 --- a/cpp/mrc/include/mrc/type_traits.hpp +++ b/cpp/mrc/include/mrc/type_traits.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -115,6 +115,10 @@ template