-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move Pycoro from Morpheus to MRC (#409)
Moves pycoro from Morpheus to MRC and incorperates tests from nv-morpheus/Morpheus#1286 Closes nv-morpheus/Morpheus#1268 Authors: - Christopher Harris (https://github.com/cwharris) Approvers: - Devin Robison (https://github.com/drobison00) URL: #409
- Loading branch information
Showing
10 changed files
with
783 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
/* | ||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "pymrc/coro.hpp" | ||
|
||
namespace mrc::pymrc::coro { | ||
|
||
namespace py = pybind11; | ||
|
||
StopIteration::~StopIteration() = default; | ||
|
||
} // namespace mrc::pymrc::coro |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# ============================================================================= | ||
# Copyright (c) 2022-2023, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except | ||
# in compliance with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software distributed under the License | ||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express | ||
# or implied. See the License for the specific language governing permissions and limitations under | ||
# the License. | ||
# ============================================================================= | ||
|
||
list(APPEND CMAKE_MESSAGE_CONTEXT "coro") | ||
|
||
set(MODULE_SOURCE_FILES) | ||
|
||
# Add the module file | ||
list(APPEND MODULE_SOURCE_FILES module.cpp) | ||
|
||
# Create the python module | ||
mrc_add_pybind11_module(coro | ||
INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/include | ||
SOURCE_FILES ${MODULE_SOURCE_FILES} | ||
LINK_TARGETS mrc::pymrc | ||
) | ||
|
||
list(POP_BACK CMAKE_MESSAGE_CONTEXT) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
/* | ||
* SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <mrc/coroutines/task.hpp> | ||
#include <pybind11/cast.h> | ||
#include <pybind11/gil.h> | ||
#include <pybind11/pybind11.h> | ||
#include <pymrc/coro.hpp> | ||
#include <pymrc/types.hpp> | ||
|
||
#include <coroutine> | ||
#include <stdexcept> | ||
|
||
mrc::coroutines::Task<int> subtract(int a, int b) | ||
{ | ||
co_return a - b; | ||
} | ||
|
||
mrc::coroutines::Task<mrc::pymrc::PyHolder> call_fib_async(mrc::pymrc::PyHolder fib, int value, int minus) | ||
{ | ||
auto result = co_await subtract(value, minus); | ||
co_return co_await mrc::pymrc::coro::PyTaskToCppAwaitable([](auto fib, auto result) { | ||
pybind11::gil_scoped_acquire acquire; | ||
return fib(result); | ||
}(fib, result)); | ||
} | ||
|
||
mrc::coroutines::Task<mrc::pymrc::PyHolder> raise_at_depth_async(mrc::pymrc::PyHolder fn, int depth) | ||
{ | ||
if (depth <= 0) | ||
{ | ||
throw std::runtime_error("depth reached zero in c++"); | ||
} | ||
|
||
co_return co_await mrc::pymrc::coro::PyTaskToCppAwaitable([](auto fn, auto depth) { | ||
pybind11::gil_scoped_acquire acquire; | ||
return fn(depth - 1); | ||
}(fn, depth)); | ||
} | ||
|
||
mrc::coroutines::Task<mrc::pymrc::PyHolder> call_async(mrc::pymrc::PyHolder fn) | ||
{ | ||
co_return co_await mrc::pymrc::coro::PyTaskToCppAwaitable([](auto fn) { | ||
pybind11::gil_scoped_acquire acquire; | ||
return fn(); | ||
}(fn)); | ||
} | ||
|
||
PYBIND11_MODULE(coro, _module) | ||
{ | ||
pybind11::module_::import("mrc.core.coro"); // satisfies automatic type conversions for tasks | ||
|
||
_module.def("call_fib_async", &call_fib_async); | ||
_module.def("raise_at_depth_async", &raise_at_depth_async); | ||
_module.def("call_async", &call_async); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
/* | ||
* SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#include "pymrc/coro.hpp" | ||
|
||
#include <glog/logging.h> | ||
#include <mrc/coroutines/task.hpp> | ||
#include <pybind11/gil.h> | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/pytypes.h> | ||
#include <pybind11/stl.h> // IWYU pragma: keep | ||
|
||
#include <coroutine> | ||
#include <memory> | ||
#include <ostream> | ||
#include <string> | ||
#include <vector> | ||
|
||
namespace mrc::pymrc::coro { | ||
|
||
namespace py = pybind11; | ||
|
||
PYBIND11_MODULE(coro, _module) | ||
{ | ||
_module.doc() = R"pbdoc( | ||
----------------------- | ||
.. currentmodule:: morpheus.llm | ||
.. autosummary:: | ||
:toctree: _generate | ||
)pbdoc"; | ||
|
||
py::class_<CppToPyAwaitable, std::shared_ptr<CppToPyAwaitable>>(_module, "CppToPyAwaitable") | ||
.def(py::init<>()) | ||
.def("__iter__", &CppToPyAwaitable::iter) | ||
.def("__await__", &CppToPyAwaitable::await) | ||
.def("__next__", &CppToPyAwaitable::next); | ||
|
||
py::class_<BoostFibersMainPyAwaitable, CppToPyAwaitable, std::shared_ptr<BoostFibersMainPyAwaitable>>( // | ||
_module, | ||
"BoostFibersMainPyAwaitable") | ||
.def(py::init<>()); | ||
|
||
_module.def("wrap_coroutine", [](coroutines::Task<std::vector<std::string>> fn) -> coroutines::Task<std::string> { | ||
DCHECK_EQ(PyGILState_Check(), 0) << "Should not have the GIL when resuming a C++ coroutine"; | ||
|
||
auto strings = co_await fn; | ||
|
||
co_return strings[0]; | ||
}); | ||
|
||
// _module.attr("__version__") = | ||
// MRC_CONCAT_STR(morpheus_VERSION_MAJOR << "." << morpheus_VERSION_MINOR << "." << morpheus_VERSION_PATCH); | ||
} | ||
} // namespace mrc::pymrc::coro |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import asyncio | ||
|
||
import pytest | ||
|
||
from mrc._pymrc.tests.coro.coro import call_async | ||
from mrc._pymrc.tests.coro.coro import call_fib_async | ||
from mrc._pymrc.tests.coro.coro import raise_at_depth_async | ||
from mrc.core import coro | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_coro(): | ||
|
||
# hit_inside = False | ||
|
||
async def inner(): | ||
|
||
# nonlocal hit_inside | ||
|
||
result = await coro.wrap_coroutine(asyncio.sleep(1, result=['a', 'b', 'c'])) | ||
|
||
# hit_inside = True | ||
|
||
return [result] | ||
|
||
returned_val = await coro.wrap_coroutine(inner()) | ||
|
||
assert returned_val == 'a' | ||
# assert hit_inside | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_coro_many(): | ||
|
||
expected_count = 1000 | ||
hit_count = 0 | ||
|
||
start_time = asyncio.get_running_loop().time() | ||
|
||
async def inner(): | ||
|
||
nonlocal hit_count | ||
|
||
await asyncio.sleep(0.1) | ||
|
||
hit_count += 1 | ||
|
||
return ['a', 'b', 'c'] | ||
|
||
coros = [coro.wrap_coroutine(inner()) for _ in range(expected_count)] | ||
|
||
returned_vals = await asyncio.gather(*coros) | ||
|
||
end_time = asyncio.get_running_loop().time() | ||
|
||
assert returned_vals == ['a'] * expected_count | ||
assert hit_count == expected_count | ||
assert (end_time - start_time) < 1.5 | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_python_cpp_async_interleave(): | ||
|
||
def fib(n): | ||
if n < 0: | ||
raise ValueError() | ||
|
||
if n < 2: | ||
return 1 | ||
|
||
return fib(n - 1) + fib(n - 2) | ||
|
||
async def fib_async(n): | ||
if n < 0: | ||
raise ValueError() | ||
|
||
if n < 2: | ||
return 1 | ||
|
||
task_a = call_fib_async(fib_async, n, 1) | ||
task_b = call_fib_async(fib_async, n, 2) | ||
|
||
[a, b] = await asyncio.gather(task_a, task_b) | ||
|
||
return a + b | ||
|
||
assert fib(15) == await fib_async(15) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_python_cpp_async_exception(): | ||
|
||
async def py_raise_at_depth_async(n: int): | ||
if n <= 0: | ||
raise RuntimeError("depth reached zero in python") | ||
|
||
await raise_at_depth_async(py_raise_at_depth_async, n - 1) | ||
|
||
depth = 100 | ||
|
||
with pytest.raises(RuntimeError) as ex: | ||
await raise_at_depth_async(py_raise_at_depth_async, depth + 1) | ||
assert "python" in str(ex.value) | ||
|
||
with pytest.raises(RuntimeError) as ex: | ||
await raise_at_depth_async(py_raise_at_depth_async, depth) | ||
assert "c++" in str(ex.value) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_can_cancel_coroutine_from_python(): | ||
|
||
counter = 0 | ||
|
||
async def increment_recursively(): | ||
nonlocal counter | ||
await asyncio.sleep(0) | ||
counter += 1 | ||
await call_async(increment_recursively) | ||
|
||
task = asyncio.ensure_future(call_async(increment_recursively)) | ||
|
||
await asyncio.sleep(0) | ||
assert counter == 0 | ||
await asyncio.sleep(0) | ||
await asyncio.sleep(0) | ||
assert counter == 1 | ||
await asyncio.sleep(0) | ||
await asyncio.sleep(0) | ||
assert counter == 2 | ||
|
||
task.cancel() | ||
|
||
with pytest.raises(asyncio.exceptions.CancelledError): | ||
await task | ||
|
||
assert counter == 3 |