From 62e18346972fea8cfa25e7d72976a9f37da71ada Mon Sep 17 00:00:00 2001 From: Christopher Harris Date: Thu, 2 Nov 2023 17:33:43 -0500 Subject: [PATCH] Add AsyncioRunnable (#411) Moves the CoroutineRunnable from Morpheus' Sherlock feature branch to MRC and renames it to AsyncioRunnable as it is heavily dependent on asyncio. Adjustments were made such that the Scheduler would no longer own a task container and/or tasks, leaving the scheduler interface simpler. Instead, the runnable is responsible for the lifetime of the tasks it creates. This leaves the scheduler with a single responsibility. Much of the code could be moved to MRC proper from PyMRC, but it's not immediately obvious where the code should live or whether it would be reused, so keeping it colocated with the AsyncioRunnable makes the most sense for now, imo. Authors: - Christopher Harris (https://github.com/cwharris) Approvers: - Devin Robison (https://github.com/drobison00) URL: https://github.com/nv-morpheus/MRC/pull/411 --- cpp/mrc/CMakeLists.txt | 2 +- cpp/mrc/include/mrc/coroutines/scheduler.hpp | 91 +---- .../mrc/exceptions/exception_catcher.hpp | 53 +++ cpp/mrc/src/public/coroutines/scheduler.cpp | 85 ---- .../public/exceptions/exception_catcher.cpp | 50 +++ .../_pymrc/include/pymrc/asyncio_runnable.hpp | 370 ++++++++++++++++++ .../include/pymrc/asyncio_scheduler.hpp | 105 +++++ python/mrc/_pymrc/include/pymrc/coro.hpp | 15 +- python/mrc/_pymrc/tests/CMakeLists.txt | 1 + .../_pymrc/tests/test_asyncio_runnable.cpp | 331 ++++++++++++++++ 10 files changed, 929 insertions(+), 174 deletions(-) create mode 100644 cpp/mrc/include/mrc/exceptions/exception_catcher.hpp delete mode 100644 cpp/mrc/src/public/coroutines/scheduler.cpp create mode 100644 cpp/mrc/src/public/exceptions/exception_catcher.cpp create mode 100644 python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp create mode 100644 python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp create mode 100644 python/mrc/_pymrc/tests/test_asyncio_runnable.cpp diff --git a/cpp/mrc/CMakeLists.txt b/cpp/mrc/CMakeLists.txt index 93909e8c6..f2f1e63cc 100644 --- a/cpp/mrc/CMakeLists.txt +++ b/cpp/mrc/CMakeLists.txt @@ -115,7 +115,6 @@ add_library(libmrc src/public/core/logging.cpp src/public/core/thread.cpp src/public/coroutines/event.cpp - src/public/coroutines/scheduler.cpp src/public/coroutines/sync_wait.cpp src/public/coroutines/task_container.cpp src/public/coroutines/thread_local_context.cpp @@ -124,6 +123,7 @@ add_library(libmrc src/public/cuda/sync.cpp src/public/edge/edge_adapter_registry.cpp src/public/edge/edge_builder.cpp + src/public/exceptions/exception_catcher.cpp src/public/manifold/manifold.cpp src/public/memory/buffer_view.cpp src/public/memory/codable/buffer.cpp diff --git a/cpp/mrc/include/mrc/coroutines/scheduler.hpp b/cpp/mrc/include/mrc/coroutines/scheduler.hpp index 1b0aac502..0e296924a 100644 --- a/cpp/mrc/include/mrc/coroutines/scheduler.hpp +++ b/cpp/mrc/include/mrc/coroutines/scheduler.hpp @@ -25,109 +25,30 @@ #include #include -// IWYU thinks this is needed, but it's not -// IWYU pragma: no_include "mrc/coroutines/task_container.hpp" - namespace mrc::coroutines { -class TaskContainer; // IWYU pragma: keep - /** * @brief Scheduler base class - * - * Allows all schedulers to be discovered via the mrc::this_thread::current_scheduler() */ class Scheduler : public std::enable_shared_from_this { public: - struct Operation - { - Operation(Scheduler& scheduler); - - constexpr static auto await_ready() noexcept -> bool - { - return false; - } - - std::coroutine_handle<> await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept; - - constexpr static auto await_resume() noexcept -> void {} - - Scheduler& m_scheduler; - std::coroutine_handle<> m_awaiting_coroutine; - Operation* m_next{nullptr}; - }; - - Scheduler(); virtual ~Scheduler() = default; /** - * @brief Description of Scheduler - */ - virtual std::string description() const = 0; - - /** - * Schedules the currently executing coroutine to be run on this thread pool. This must be - * called from within the coroutines function body to schedule the coroutine on the thread pool. - * @throw std::runtime_error If the thread pool is `shutdown()` scheduling new tasks is not permitted. - * @return The operation to switch from the calling scheduling thread to the executor thread - * pool thread. - */ - [[nodiscard]] virtual auto schedule() -> Operation; - - // Enqueues a message without waiting for it. Must return void since the caller will not get the return value - virtual void schedule(Task&& task); - - /** - * Schedules any coroutine handle that is ready to be resumed. - * @param handle The coroutine handle to schedule. - */ - virtual auto resume(std::coroutine_handle<> coroutine) -> void = 0; - - /** - * Yields the current task to the end of the queue of waiting tasks. - */ - [[nodiscard]] auto yield() -> Operation; - - /** - * If the calling thread controlled by a Scheduler, return a pointer to the Scheduler + * @brief Resumes a coroutine according to the scheduler's implementation. */ - static auto from_current_thread() noexcept -> Scheduler*; + virtual void resume(std::coroutine_handle<> handle) noexcept = 0; /** - * If the calling thread is owned by a thread_pool, return the thread index (rank) of the current thread with - * respect the threads in the pool; otherwise, return the std::hash of std::this_thread::get_id + * @brief Suspends the current function and resumes it according to the scheduler's implementation. */ - static auto get_thread_id() noexcept -> std::size_t; + [[nodiscard]] virtual Task<> schedule() = 0; - protected: - virtual auto on_thread_start(std::size_t) -> void; - - /** - * @brief Get the task container object - * - * @return TaskContainer& - */ - TaskContainer& get_task_container() const; - - private: /** - * @brief When co_await schedule() is called, this function will be executed by the awaiter. Each scheduler - * implementation should determine how and when to execute the operation. - * - * @param operation The schedule() awaitable pointer - * @return std::coroutine_handle<> Return a coroutine handle to which will be - * used as the return value for await_suspend(). + * @brief Suspends the current function and resumes it according to the scheduler's implementation. */ - virtual std::coroutine_handle<> schedule_operation(Operation* operation) = 0; - - mutable std::mutex m_mutex; - - // Maintains the lifetime of fire-and-forget tasks scheduled with schedule(Task&& task) - std::unique_ptr m_task_container; - - thread_local static Scheduler* m_thread_local_scheduler; - thread_local static std::size_t m_thread_id; + [[nodiscard]] virtual Task<> yield() = 0; }; } // namespace mrc::coroutines diff --git a/cpp/mrc/include/mrc/exceptions/exception_catcher.hpp b/cpp/mrc/include/mrc/exceptions/exception_catcher.hpp new file mode 100644 index 000000000..98c4a7d6d --- /dev/null +++ b/cpp/mrc/include/mrc/exceptions/exception_catcher.hpp @@ -0,0 +1,53 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#pragma once + +namespace mrc { + +/** + * @brief A utility for catching out-of-stack exceptions in a thread-safe manner such that they + * can be checked and throw from a parent thread. + */ +class ExceptionCatcher +{ + public: + /** + * @brief "catches" an exception to the catcher + */ + void push_exception(std::exception_ptr ex); + + /** + * @brief checks to see if any exceptions have been "caught" by the catcher. + */ + bool has_exception(); + + /** + * @brief rethrows the next exception (in the order in which it was "caught"). + */ + void rethrow_next_exception(); + + private: + std::mutex m_mutex{}; + std::queue m_exceptions{}; +}; + +} // namespace mrc diff --git a/cpp/mrc/src/public/coroutines/scheduler.cpp b/cpp/mrc/src/public/coroutines/scheduler.cpp deleted file mode 100644 index af2e70294..000000000 --- a/cpp/mrc/src/public/coroutines/scheduler.cpp +++ /dev/null @@ -1,85 +0,0 @@ -/** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "mrc/coroutines/scheduler.hpp" - -#include "mrc/coroutines/task_container.hpp" - -#include - -#include -#include -#include -#include - -namespace mrc::coroutines { - -thread_local Scheduler* Scheduler::m_thread_local_scheduler{nullptr}; -thread_local std::size_t Scheduler::m_thread_id{0}; - -Scheduler::Operation::Operation(Scheduler& scheduler) : m_scheduler(scheduler) {} - -std::coroutine_handle<> Scheduler::Operation::await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -{ - m_awaiting_coroutine = awaiting_coroutine; - return m_scheduler.schedule_operation(this); -} - -Scheduler::Scheduler() : m_task_container(new TaskContainer(*this)) {} - -auto Scheduler::schedule() -> Operation -{ - return Operation{*this}; -} - -void Scheduler::schedule(Task&& task) -{ - return m_task_container->start(std::move(task)); -} - -auto Scheduler::yield() -> Operation -{ - return schedule(); -} - -auto Scheduler::from_current_thread() noexcept -> Scheduler* -{ - return m_thread_local_scheduler; -} - -auto Scheduler::get_thread_id() noexcept -> std::size_t -{ - if (m_thread_local_scheduler == nullptr) - { - return std::hash()(std::this_thread::get_id()); - } - return m_thread_id; -} - -auto Scheduler::on_thread_start(std::size_t thread_id) -> void -{ - DVLOG(10) << "scheduler: " << description() << " initializing"; - m_thread_id = thread_id; - m_thread_local_scheduler = this; -} - -TaskContainer& Scheduler::get_task_container() const -{ - return *m_task_container; -} - -} // namespace mrc::coroutines diff --git a/cpp/mrc/src/public/exceptions/exception_catcher.cpp b/cpp/mrc/src/public/exceptions/exception_catcher.cpp new file mode 100644 index 000000000..c139436f7 --- /dev/null +++ b/cpp/mrc/src/public/exceptions/exception_catcher.cpp @@ -0,0 +1,50 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace mrc { + +void ExceptionCatcher::push_exception(std::exception_ptr ex) +{ + auto lock = std::lock_guard(m_mutex); + m_exceptions.push(ex); +} + +bool ExceptionCatcher::has_exception() +{ + auto lock = std::lock_guard(m_mutex); + return not m_exceptions.empty(); +} + +void ExceptionCatcher::rethrow_next_exception() +{ + auto lock = std::lock_guard(m_mutex); + + if (m_exceptions.empty()) + { + return; + } + + auto ex = m_exceptions.front(); + + m_exceptions.pop(); + + std::rethrow_exception(ex); +} + +} // namespace mrc diff --git a/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp b/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp new file mode 100644 index 000000000..965cf551c --- /dev/null +++ b/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp @@ -0,0 +1,370 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "pymrc/asyncio_scheduler.hpp" +#include "pymrc/utilities/object_wrappers.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace mrc::pymrc { + +/** + * @brief A wrapper for executing a function as an async boost fiber, the result of which is a + * C++20 coroutine awaiter. + */ +template +class BoostFutureAwaitableOperation +{ + class Awaiter; + + public: + BoostFutureAwaitableOperation(std::function fn) : m_fn(std::move(fn)) {} + + /** + * @brief Calls the wrapped function as an asyncboost fiber and returns a C++20 coroutine awaiter. + */ + template + auto operator()(ArgsT&&... args) -> Awaiter + { + // Make a copy of m_fn here so we can call this operator again + return Awaiter(m_fn, std::forward(args)...); + } + + private: + class Awaiter + { + public: + using return_t = typename std::function::result_type; + + template + Awaiter(std::function fn, ArgsT&&... args) + { + m_future = boost::fibers::async(boost::fibers::launch::post, fn, std::forward(args)...); + } + + bool await_ready() noexcept + { + return false; + } + + void await_suspend(std::coroutine_handle<> continuation) noexcept + { + // Launch a new fiber that waits on the future and then resumes the coroutine + boost::fibers::async( + boost::fibers::launch::post, + [this](std::coroutine_handle<> continuation) { + // Wait on the future + m_future.wait(); + + // Resume the coroutine + continuation.resume(); + }, + std::move(continuation)); + } + + auto await_resume() + { + return m_future.get(); + } + + private: + boost::fibers::future m_future; + std::function)> m_inner_fn; + }; + + std::function m_fn; +}; + +/** + * @brief A MRC Sink which receives from a channel using an awaitable interface. + */ +template +class AsyncSink : public mrc::node::WritableProvider, + public mrc::node::ReadableAcceptor, + public mrc::node::SinkChannelOwner +{ + protected: + AsyncSink() : + m_read_async([this](T& value) { + return this->get_readable_edge()->await_read(value); + }) + { + // Set the default channel + this->set_channel(std::make_unique>()); + } + + /** + * @brief Asynchronously reads a value from the sink's channel + */ + coroutines::Task read_async(T& value) + { + co_return co_await m_read_async(std::ref(value)); + } + + private: + BoostFutureAwaitableOperation m_read_async; +}; + +/** + * @brief A MRC Source which produces to a channel using an awaitable interface. + */ +template +class AsyncSource : public mrc::node::WritableAcceptor, + public mrc::node::ReadableProvider, + public mrc::node::SourceChannelOwner +{ + protected: + AsyncSource() : + m_write_async([this](T&& value) { + return this->get_writable_edge()->await_write(std::move(value)); + }) + { + // Set the default channel + this->set_channel(std::make_unique>()); + } + + /** + * @brief Asynchronously writes a value to the source's channel + */ + coroutines::Task write_async(T&& value) + { + co_return co_await m_write_async(std::move(value)); + } + + private: + BoostFutureAwaitableOperation m_write_async; +}; + +/** + * @brief A MRC Runnable base class which hosts it's own asyncio loop and exposes a flatmap hook + */ +template +class AsyncioRunnable : public AsyncSink, + public AsyncSource, + public mrc::runnable::RunnableWithContext<> +{ + using state_t = mrc::runnable::Runnable::State; + using task_buffer_t = mrc::coroutines::ClosableRingBuffer; + + public: + AsyncioRunnable(size_t concurrency = 8) : m_concurrency(concurrency){}; + ~AsyncioRunnable() override = default; + + private: + /** + * @brief Runnable's entrypoint. + */ + void run(mrc::runnable::Context& ctx) override; + + /** + * @brief Runnable's state control, for stopping from MRC. + */ + void on_state_update(const state_t& state) final; + + /** + * @brief The top-level coroutine which is run while the asyncio event loop is running. + */ + coroutines::Task<> main_task(std::shared_ptr scheduler); + + /** + * @brief The per-value coroutine run asynchronously alongside other calls. + */ + coroutines::Task<> process_one(InputT value, + task_buffer_t& task_buffer, + std::shared_ptr on, + ExceptionCatcher& catcher); + + /** + * @brief Value's read from the sink's channel are fed to this function and yields from the + * resulting generator are written to the source's channel. + */ + virtual mrc::coroutines::AsyncGenerator on_data(InputT&& value) = 0; + + std::stop_source m_stop_source; + + size_t m_concurrency{8}; +}; + +template +void AsyncioRunnable::run(mrc::runnable::Context& ctx) +{ + std::exception_ptr exception; + + { + py::gil_scoped_acquire gil; + + auto asyncio = py::module_::import("asyncio"); + + auto loop = [](auto& asyncio) -> PyObjectHolder { + try + { + return asyncio.attr("get_running_loop")(); + } catch (...) + { + return py::none(); + } + }(asyncio); + + if (not loop.is_none()) + { + throw std::runtime_error("asyncio loop already running, but runnable is expected to create it."); + } + + // Need to create a loop + LOG(INFO) << "AsyncioRunnable::run() > Creating new event loop"; + + // Gets (or more likely, creates) an event loop and runs it forever until stop is called + loop = asyncio.attr("new_event_loop")(); + + // Set the event loop as the current event loop + asyncio.attr("set_event_loop")(loop); + + // TODO(MDD): Eventually we should get this from the context object. For now, just create it directly + auto scheduler = std::make_shared(loop); + + auto py_awaitable = coro::BoostFibersMainPyAwaitable(this->main_task(scheduler)); + + LOG(INFO) << "AsyncioRunnable::run() > Calling run_until_complete() on main_task()"; + + try + { + loop.attr("run_until_complete")(std::move(py_awaitable)); + } catch (...) + { + exception = std::current_exception(); + } + + loop.attr("close")(); + } + + // Need to drop the output edges + mrc::node::SourceProperties::release_edge_connection(); + mrc::node::SinkProperties::release_edge_connection(); + + if (exception != nullptr) + { + std::rethrow_exception(exception); + } +} + +template +coroutines::Task<> AsyncioRunnable::main_task(std::shared_ptr scheduler) +{ + // Create the task buffer to limit the number of running tasks + task_buffer_t task_buffer{{.capacity = m_concurrency}}; + + coroutines::TaskContainer outstanding_tasks(scheduler); + + ExceptionCatcher catcher{}; + + while (not m_stop_source.stop_requested() and not catcher.has_exception()) + { + InputT data; + + auto read_status = co_await this->read_async(data); + + if (read_status != mrc::channel::Status::success) + { + break; + } + + // Wait for an available slot in the task buffer + co_await task_buffer.write(0); + + outstanding_tasks.start(this->process_one(std::move(data), task_buffer, scheduler, catcher)); + } + + // Close the buffer + task_buffer.close(); + + // Now block until all tasks are complete + co_await task_buffer.completed(); + + co_await outstanding_tasks.garbage_collect_and_yield_until_empty(); + + catcher.rethrow_next_exception(); +} + +template +coroutines::Task<> AsyncioRunnable::process_one(InputT value, + task_buffer_t& task_buffer, + std::shared_ptr on, + ExceptionCatcher& catcher) +{ + co_await on->yield(); + + try + { + // Call the on_data function + auto on_data_gen = this->on_data(std::move(value)); + + auto iter = co_await on_data_gen.begin(); + + while (iter != on_data_gen.end()) + { + // Weird bug, cant directly move the value into the async_write call + auto data = std::move(*iter); + + co_await this->write_async(std::move(data)); + + // Advance the iterator + co_await ++iter; + } + } catch (...) + { + catcher.push_exception(std::current_exception()); + } + + // Return the slot to the task buffer + co_await task_buffer.read(); +} + +template +void AsyncioRunnable::on_state_update(const state_t& state) +{ + switch (state) + { + case state_t::Stop: + // Do nothing, we wait for the upstream channel to return closed + // m_stop_source.request_stop(); + break; + + case state_t::Kill: + m_stop_source.request_stop(); + break; + + default: + break; + } +} + +} // namespace mrc::pymrc diff --git a/python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp b/python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp new file mode 100644 index 000000000..3d9e563b9 --- /dev/null +++ b/python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp @@ -0,0 +1,105 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "pymrc/coro.hpp" +#include "pymrc/utilities/acquire_gil.hpp" +#include "pymrc/utilities/object_wrappers.hpp" + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace py = pybind11; + +namespace mrc::pymrc { + +/** + * @brief A MRC Scheduler which allows resuming C++20 coroutines on an Asyncio event loop. + */ +class AsyncioScheduler : public mrc::coroutines::Scheduler +{ + private: + class ContinueOnLoopOperation + { + public: + ContinueOnLoopOperation(PyObjectHolder loop) : m_loop(std::move(loop)) {} + + static bool await_ready() noexcept + { + return false; + } + + void await_suspend(std::coroutine_handle<> handle) noexcept + { + AsyncioScheduler::resume(m_loop, handle); + } + + static void await_resume() noexcept {} + + private: + PyObjectHolder m_loop; + }; + + static void resume(PyObjectHolder loop, std::coroutine_handle<> handle) noexcept + { + pybind11::gil_scoped_acquire acquire; + loop.attr("call_soon_threadsafe")(pybind11::cpp_function([handle]() { + pybind11::gil_scoped_release release; + handle.resume(); + })); + } + + public: + AsyncioScheduler(PyObjectHolder loop) : m_loop(std::move(loop)) {} + + /** + * @brief Resumes a coroutine on the scheduler's Asyncio event loop + */ + void resume(std::coroutine_handle<> handle) noexcept override + { + AsyncioScheduler::resume(m_loop, handle); + } + + /** + * @brief Suspends the current function and resumes it on the scheduler's Asyncio event loop + */ + [[nodiscard]] coroutines::Task<> schedule() override + { + co_await ContinueOnLoopOperation(m_loop); + } + + /** + * @brief Suspends the current function and resumes it on the scheduler's Asyncio event loop + */ + [[nodiscard]] coroutines::Task<> yield() override + { + co_await ContinueOnLoopOperation(m_loop); + } + + private: + mrc::pymrc::PyHolder m_loop; +}; + +} // namespace mrc::pymrc diff --git a/python/mrc/_pymrc/include/pymrc/coro.hpp b/python/mrc/_pymrc/include/pymrc/coro.hpp index 5c80398cc..5b50f14a0 100644 --- a/python/mrc/_pymrc/include/pymrc/coro.hpp +++ b/python/mrc/_pymrc/include/pymrc/coro.hpp @@ -174,13 +174,22 @@ class PYBIND11_EXPORT PyTaskToCppAwaitable PyTaskToCppAwaitable(mrc::pymrc::PyObjectHolder&& task) : m_task(std::move(task)) { pybind11::gil_scoped_acquire acquire; - if (pybind11::module_::import("inspect").attr("iscoroutine")(m_task).cast()) + + auto asyncio = pybind11::module_::import("asyncio"); + + if (not asyncio.attr("isfuture")(m_task).cast()) { - m_task = pybind11::module_::import("asyncio").attr("create_task")(m_task); + if (not asyncio.attr("iscoroutine")(m_task).cast()) + { + throw std::runtime_error(MRC_CONCAT_STR("PyTaskToCppAwaitable expected task or coroutine but got " + << pybind11::repr(m_task).cast())); + } + + m_task = asyncio.attr("create_task")(m_task); } } - static bool await_ready() noexcept // NOLINT(readability-convert-member-functions-to-static) + static bool await_ready() noexcept { // Always suspend return false; diff --git a/python/mrc/_pymrc/tests/CMakeLists.txt b/python/mrc/_pymrc/tests/CMakeLists.txt index f40e20d72..02186de90 100644 --- a/python/mrc/_pymrc/tests/CMakeLists.txt +++ b/python/mrc/_pymrc/tests/CMakeLists.txt @@ -21,6 +21,7 @@ add_subdirectory(coro) # Keep all source files sorted!!! add_executable(test_pymrc + test_asyncio_runnable.cpp test_codable_pyobject.cpp test_executor.cpp test_main.cpp diff --git a/python/mrc/_pymrc/tests/test_asyncio_runnable.cpp b/python/mrc/_pymrc/tests/test_asyncio_runnable.cpp new file mode 100644 index 000000000..46a139a04 --- /dev/null +++ b/python/mrc/_pymrc/tests/test_asyncio_runnable.cpp @@ -0,0 +1,331 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pymrc/asyncio_runnable.hpp" +#include "pymrc/coro.hpp" +#include "pymrc/executor.hpp" +#include "pymrc/pipeline.hpp" +#include "pymrc/utilities/object_wrappers.hpp" + +#include "mrc/coroutines/async_generator.hpp" +#include "mrc/coroutines/sync_wait.hpp" +#include "mrc/coroutines/task.hpp" +#include "mrc/node/rx_sink.hpp" +#include "mrc/node/rx_source.hpp" +#include "mrc/options/engine_groups.hpp" +#include "mrc/options/options.hpp" +#include "mrc/options/topology.hpp" +#include "mrc/runnable/types.hpp" +#include "mrc/segment/builder.hpp" +#include "mrc/segment/object.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; +namespace pymrc = mrc::pymrc; +using namespace std::string_literals; +using namespace py::literals; + +class __attribute__((visibility("default"))) TestAsyncioRunnable : public ::testing::Test +{ + public: + static void SetUpTestSuite() + { + m_interpreter = std::make_unique(); + pybind11::gil_scoped_acquire acquire; + pybind11::module_::import("mrc.core.coro"); + } + + static void TearDownTestSuite() + { + m_interpreter.reset(); + } + + private: + static std::unique_ptr m_interpreter; +}; + +std::unique_ptr TestAsyncioRunnable::m_interpreter; + +class PythonCallbackAsyncioRunnable : public pymrc::AsyncioRunnable +{ + public: + PythonCallbackAsyncioRunnable(pymrc::PyObjectHolder operation) : m_operation(std::move(operation)) {} + + mrc::coroutines::AsyncGenerator on_data(int&& value) override + { + py::gil_scoped_acquire acquire; + + auto coroutine = m_operation(py::cast(value)); + + pymrc::PyObjectHolder result; + { + py::gil_scoped_release release; + + result = co_await pymrc::coro::PyTaskToCppAwaitable(std::move(coroutine)); + } + + auto result_casted = py::cast(result); + + py::gil_scoped_release release; + + co_yield result_casted; + }; + + private: + pymrc::PyObjectHolder m_operation; +}; + +TEST_F(TestAsyncioRunnable, UseAsyncioTasks) +{ + py::object globals = py::globals(); + py::exec( + R"( + async def fn(value): + import asyncio + await asyncio.sleep(0) + return value * 2 + )", + globals); + + pymrc::PyObjectHolder fn = static_cast(globals["fn"]); + + ASSERT_FALSE(fn.is_none()); + + std::atomic counter = 0; + pymrc::Pipeline p; + + auto init = [&counter, &fn](mrc::segment::IBuilder& seg) { + auto src = seg.make_source("src", [](rxcpp::subscriber& s) { + if (s.is_subscribed()) + { + s.on_next(5); + s.on_next(10); + } + + s.on_completed(); + }); + + auto internal = seg.construct_object("internal", fn); + + auto sink = seg.make_sink("sink", [&counter](int x) { + counter.fetch_add(x, std::memory_order_relaxed); + }); + + seg.make_edge(src, internal); + seg.make_edge(internal, sink); + }; + + p.make_segment("seg1"s, init); + p.make_segment("seg2"s, init); + + auto options = std::make_shared(); + options->topology().user_cpuset("0"); + // AsyncioRunnable only works with the Thread engine due to asyncio loops being thread-specific. + options->engine_factories().set_default_engine_type(mrc::runnable::EngineType::Thread); + + pymrc::Executor exec{options}; + exec.register_pipeline(p); + + exec.start(); + exec.join(); + + EXPECT_EQ(counter, 60); +} + +TEST_F(TestAsyncioRunnable, UseAsyncioGeneratorThrows) +{ + // pybind11::module_::import("mrc.core.coro"); + + py::object globals = py::globals(); + py::exec( + R"( + async def fn(value): + yield value + )", + globals); + + pymrc::PyObjectHolder fn = static_cast(globals["fn"]); + + ASSERT_FALSE(fn.is_none()); + + std::atomic counter = 0; + pymrc::Pipeline p; + + auto init = [&counter, &fn](mrc::segment::IBuilder& seg) { + auto src = seg.make_source("src", [](rxcpp::subscriber& s) { + if (s.is_subscribed()) + { + s.on_next(5); + s.on_next(10); + } + + s.on_completed(); + }); + + auto internal = seg.construct_object("internal", fn); + + auto sink = seg.make_sink("sink", [&counter](int x) { + counter.fetch_add(x, std::memory_order_relaxed); + }); + + seg.make_edge(src, internal); + seg.make_edge(internal, sink); + }; + + p.make_segment("seg1"s, init); + p.make_segment("seg2"s, init); + + auto options = std::make_shared(); + options->topology().user_cpuset("0"); + // AsyncioRunnable only works with the Thread engine due to asyncio loops being thread-specific. + options->engine_factories().set_default_engine_type(mrc::runnable::EngineType::Thread); + + pymrc::Executor exec{options}; + exec.register_pipeline(p); + + exec.start(); + + ASSERT_THROW(exec.join(), std::runtime_error); +} + +TEST_F(TestAsyncioRunnable, UseAsyncioTasksThrows) +{ + // pybind11::module_::import("mrc.core.coro"); + + py::object globals = py::globals(); + py::exec( + R"( + async def fn(value): + raise RuntimeError("oops") + )", + globals); + + pymrc::PyObjectHolder fn = static_cast(globals["fn"]); + + ASSERT_FALSE(fn.is_none()); + + std::atomic counter = 0; + pymrc::Pipeline p; + + auto init = [&counter, &fn](mrc::segment::IBuilder& seg) { + auto src = seg.make_source("src", [](rxcpp::subscriber& s) { + if (s.is_subscribed()) + { + s.on_next(5); + s.on_next(10); + } + + s.on_completed(); + }); + + auto internal = seg.construct_object("internal", fn); + + auto sink = seg.make_sink("sink", [&counter](int x) { + counter.fetch_add(x, std::memory_order_relaxed); + }); + + seg.make_edge(src, internal); + seg.make_edge(internal, sink); + }; + + p.make_segment("seg1"s, init); + p.make_segment("seg2"s, init); + + auto options = std::make_shared(); + options->topology().user_cpuset("0"); + // AsyncioRunnable only works with the Thread engine due to asyncio loops being thread-specific. + options->engine_factories().set_default_engine_type(mrc::runnable::EngineType::Thread); + + pymrc::Executor exec{options}; + exec.register_pipeline(p); + + exec.start(); + + ASSERT_THROW(exec.join(), std::runtime_error); +} + +template +auto run_operation(OperationT& operation) -> mrc::coroutines::Task +{ + auto stop_source = std::stop_source(); + + auto coro = [](auto& operation, auto stop_source) -> mrc::coroutines::Task { + try + { + auto value = co_await operation(); + stop_source.request_stop(); + co_return value; + } catch (...) + { + stop_source.request_stop(); + throw; + } + }(operation, stop_source); + + coro.resume(); + + while (not stop_source.stop_requested()) + { + if (boost::fibers::has_ready_fibers()) + { + boost::this_fiber::yield(); + } + } + + co_return co_await coro; +} + +TEST_F(TestAsyncioRunnable, BoostFutureAwaitableOperationCanReturn) +{ + auto operation = mrc::pymrc::BoostFutureAwaitableOperation([]() { + using namespace std::chrono_literals; + boost::this_fiber::sleep_for(10ms); + return 5; + }); + + ASSERT_EQ(mrc::coroutines::sync_wait(run_operation(operation)), 5); +} + +TEST_F(TestAsyncioRunnable, BoostFutureAwaitableOperationCanThrow) +{ + auto operation = mrc::pymrc::BoostFutureAwaitableOperation([]() { + throw std::runtime_error("oops"); + return 5; + }); + + ASSERT_THROW(mrc::coroutines::sync_wait(run_operation(operation)), std::runtime_error); +}