diff --git a/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp b/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp index 965cf551c..2517acdd8 100644 --- a/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp +++ b/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp @@ -176,7 +176,6 @@ class AsyncioRunnable : public AsyncSink, using task_buffer_t = mrc::coroutines::ClosableRingBuffer; public: - AsyncioRunnable(size_t concurrency = 8) : m_concurrency(concurrency){}; ~AsyncioRunnable() override = default; private: @@ -199,7 +198,6 @@ class AsyncioRunnable : public AsyncSink, * @brief The per-value coroutine run asynchronously alongside other calls. */ coroutines::Task<> process_one(InputT value, - task_buffer_t& task_buffer, std::shared_ptr on, ExceptionCatcher& catcher); @@ -211,7 +209,11 @@ class AsyncioRunnable : public AsyncSink, std::stop_source m_stop_source; - size_t m_concurrency{8}; + /** + * @brief A semaphore used to control the number of outstanding operations. Acquire one before + * beginning a task, and release it when finished. + */ + std::counting_semaphore<8> m_task_tickets{8}; }; template @@ -279,15 +281,14 @@ void AsyncioRunnable::run(mrc::runnable::Context& ctx) template coroutines::Task<> AsyncioRunnable::main_task(std::shared_ptr scheduler) { - // Create the task buffer to limit the number of running tasks - task_buffer_t task_buffer{{.capacity = m_concurrency}}; - coroutines::TaskContainer outstanding_tasks(scheduler); ExceptionCatcher catcher{}; while (not m_stop_source.stop_requested() and not catcher.has_exception()) { + m_task_tickets.acquire(); + InputT data; auto read_status = co_await this->read_async(data); @@ -297,18 +298,9 @@ coroutines::Task<> AsyncioRunnable::main_task(std::shared_ptrprocess_one(std::move(data), task_buffer, scheduler, catcher)); + outstanding_tasks.start(this->process_one(std::move(data), scheduler, catcher)); } - // Close the buffer - task_buffer.close(); - - // Now block until all tasks are complete - co_await task_buffer.completed(); - co_await outstanding_tasks.garbage_collect_and_yield_until_empty(); catcher.rethrow_next_exception(); @@ -316,7 +308,6 @@ coroutines::Task<> AsyncioRunnable::main_task(std::shared_ptr coroutines::Task<> AsyncioRunnable::process_one(InputT value, - task_buffer_t& task_buffer, std::shared_ptr on, ExceptionCatcher& catcher) { @@ -344,8 +335,7 @@ coroutines::Task<> AsyncioRunnable::process_one(InputT value, catcher.push_exception(std::current_exception()); } - // Return the slot to the task buffer - co_await task_buffer.read(); + m_task_tickets.release(); } template