diff --git a/.github/copy-pr-bot.yaml b/.github/copy-pr-bot.yaml new file mode 100644 index 000000000..895ba83ee --- /dev/null +++ b/.github/copy-pr-bot.yaml @@ -0,0 +1,4 @@ +# Configuration file for `copy-pr-bot` GitHub App +# https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/ + +enabled: true diff --git a/.github/ops-bot.yaml b/.github/ops-bot.yaml index 2ef41b367..d2ca78924 100644 --- a/.github/ops-bot.yaml +++ b/.github/ops-bot.yaml @@ -5,5 +5,3 @@ auto_merger: true branch_checker: true label_checker: true release_drafter: true -copy_prs: true -rerun_tests: true diff --git a/.github/workflows/ci_pipe.yml b/.github/workflows/ci_pipe.yml index 0c6da79f4..6dd1b08c3 100644 --- a/.github/workflows/ci_pipe.yml +++ b/.github/workflows/ci_pipe.yml @@ -294,7 +294,7 @@ jobs: run: ./mrc/ci/scripts/github/benchmark.sh - name: post_benchmark shell: bash - run: ./mrc/ci/scripts/github/benchmark.sh + run: ./mrc/ci/scripts/github/post_benchmark.sh package: diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index f10b02fea..dd5b73dd3 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -47,7 +47,7 @@ jobs: name: Prepare runs-on: ubuntu-latest container: - image: rapidsai/ci:latest + image: rapidsai/ci-conda:latest steps: - name: Get PR Info id: get-pr-info @@ -71,9 +71,9 @@ jobs: # Update conda package only for non PR branches. Use 'main' for main branch and 'dev' for all other branches conda_upload_label: ${{ !fromJSON(needs.prepare.outputs.is_pr) && (fromJSON(needs.prepare.outputs.is_main_branch) && 'main' || 'dev') || '' }} # Build container - container: nvcr.io/ea-nvidia-morpheus/morpheus:mrc-ci-build-230711 + container: nvcr.io/ea-nvidia-morpheus/morpheus:mrc-ci-build-230920 # Test container - test_container: nvcr.io/ea-nvidia-morpheus/morpheus:mrc-ci-test-230711 + test_container: nvcr.io/ea-nvidia-morpheus/morpheus:mrc-ci-test-230920 # Info about the PR. Empty for non PR branches. Useful for extracting PR number, title, etc. 
pr_info: ${{ needs.prepare.outputs.pr_info }} secrets: diff --git a/.gitignore b/.gitignore index 1a20325a2..53c9f38e0 100755 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /build*/ +.tmp *.engine .Dockerfile .gitignore @@ -17,6 +18,9 @@ include/mrc/version.hpp .vscode/settings.json .vscode/tasks.json +# Ignore user-defined clangd settings +.clangd + # Created by https://www.gitignore.io/api/vim,c++,cmake,python,synology ### C++ ### diff --git a/.gitmodules b/.gitmodules index 76d78c90c..0180e5c60 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,4 @@ [submodule "morpheus_utils"] path = external/utilities url = https://github.com/nv-morpheus/utilities.git - branch = branch-23.07 + branch = branch-23.11 diff --git a/CHANGELOG.md b/CHANGELOG.md index c10b0f7d1..5499c5e19 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,37 @@ +# MRC 23.11.00 (30 Nov 2023) + +## 🐛 Bug Fixes + +- Use a traditional semaphore in AsyncioRunnable ([#412](https://github.com/nv-morpheus/MRC/pull/412)) [@cwharris](https://github.com/cwharris) +- Fix libhwloc & stubgen versions to match dev yaml ([#405](https://github.com/nv-morpheus/MRC/pull/405)) [@dagardner-nv](https://github.com/dagardner-nv) +- Update boost versions to match version used in dev env ([#404](https://github.com/nv-morpheus/MRC/pull/404)) [@dagardner-nv](https://github.com/dagardner-nv) +- Fix EdgeHolder from incorrectly reporting an active connection ([#402](https://github.com/nv-morpheus/MRC/pull/402)) [@dagardner-nv](https://github.com/dagardner-nv) +- Safe handling of control plane promises & fix CI ([#391](https://github.com/nv-morpheus/MRC/pull/391)) [@dagardner-nv](https://github.com/dagardner-nv) +- Revert boost upgrade, and update clang to v16 ([#382](https://github.com/nv-morpheus/MRC/pull/382)) [@dagardner-nv](https://github.com/dagardner-nv) +- Fixing an issue with `update-versions.sh` which always blocked CI ([#377](https://github.com/nv-morpheus/MRC/pull/377)) [@mdemoret-nv](https://github.com/mdemoret-nv) +- Add test for gc being invoked in a thread finalizer ([#365](https://github.com/nv-morpheus/MRC/pull/365)) [@dagardner-nv](https://github.com/dagardner-nv) +- Adopt patched pybind11 ([#364](https://github.com/nv-morpheus/MRC/pull/364)) [@dagardner-nv](https://github.com/dagardner-nv) + +## 📖 Documentation + +- Add missing flags to docker command to mount the working dir and set -cap-add=sys_nice ([#383](https://github.com/nv-morpheus/MRC/pull/383)) [@dagardner-nv](https://github.com/dagardner-nv) +- Make Quick Start Guide not use `make_node_full` ([#376](https://github.com/nv-morpheus/MRC/pull/376)) [@cwharris](https://github.com/cwharris) + +## 🚀 New Features + +- Add AsyncioRunnable ([#411](https://github.com/nv-morpheus/MRC/pull/411)) [@cwharris](https://github.com/cwharris) +- Adding more coroutine components to support async generators and task containers ([#408](https://github.com/nv-morpheus/MRC/pull/408)) [@mdemoret-nv](https://github.com/mdemoret-nv) +- Update ObservableProxy::pipe to support any number of operators ([#387](https://github.com/nv-morpheus/MRC/pull/387)) [@cwharris](https://github.com/cwharris) +- Updates for MRC/Morpheus to build in the same RAPIDS devcontainer environment ([#375](https://github.com/nv-morpheus/MRC/pull/375)) [@cwharris](https://github.com/cwharris) + +## 🛠️ Improvements + +- Move Pycoro from Morpheus to MRC ([#409](https://github.com/nv-morpheus/MRC/pull/409)) [@cwharris](https://github.com/cwharris) +- update rapidsai/ci to rapidsai/ci-conda 
([#396](https://github.com/nv-morpheus/MRC/pull/396)) [@AyodeAwe](https://github.com/AyodeAwe) +- Add local CI scripts & rebase docker image ([#394](https://github.com/nv-morpheus/MRC/pull/394)) [@dagardner-nv](https://github.com/dagardner-nv) +- Use `copy-pr-bot` ([#369](https://github.com/nv-morpheus/MRC/pull/369)) [@ajschmidt8](https://github.com/ajschmidt8) +- Update Versions for v23.11.00 ([#357](https://github.com/nv-morpheus/MRC/pull/357)) [@mdemoret-nv](https://github.com/mdemoret-nv) + # MRC 23.07.00 (19 Jul 2023) ## 🚨 Breaking Changes diff --git a/CMakeLists.txt b/CMakeLists.txt index 4f0f92c91..abaa790a3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -30,6 +30,7 @@ option(MRC_BUILD_PYTHON "Enable building the python bindings for MRC" ON) option(MRC_BUILD_TESTS "Whether or not to build MRC tests" ON) option(MRC_ENABLE_CODECOV "Enable gcov code coverage" OFF) option(MRC_ENABLE_DEBUG_INFO "Enable printing debug information" OFF) +option(MRC_PYTHON_INPLACE_BUILD "Whether or not to copy built python modules back to the source tree for debug purposes." OFF) option(MRC_USE_CCACHE "Enable caching compilation results with ccache" OFF) option(MRC_USE_CLANG_TIDY "Enable running clang-tidy as part of the build process" OFF) option(MRC_USE_CONDA "Enables finding dependencies via conda. All dependencies must be installed first in the conda @@ -78,7 +79,7 @@ morpheus_utils_initialize_package_manager( morpheus_utils_initialize_cuda_arch(mrc) project(mrc - VERSION 23.07.00 + VERSION 23.11.00 LANGUAGES C CXX ) diff --git a/Dockerfile b/Dockerfile index cae834533..e303f19f9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,7 +16,7 @@ # limitations under the License. -ARG FROM_IMAGE="rapidsai/ci" +ARG FROM_IMAGE="rapidsai/ci-conda" ARG CUDA_VER=11.8.0 ARG LINUX_DISTRO=ubuntu ARG LINUX_VER=20.04 diff --git a/README.md b/README.md index 0f05e754a..4bdb5c3a3 100644 --- a/README.md +++ b/README.md @@ -151,13 +151,17 @@ pytest $MRC_ROOT/python ### Docker Installation A Dockerfile is provided at `$MRC_ROOT` and can be built with ```bash -docker build -t mrc:latest . +DOCKER_BUILDKIT=1 docker build -t mrc:latest . ``` To run the container ```bash -docker run --gpus all --rm -it mrc:latest /bin/bash +docker run --gpus all --cap-add=sys_nice -v $PWD:/work --rm -it mrc:latest /bin/bash ``` +> **Note:** +> Users wishing to debug MRC in a Docker container should add the following to the `docker run` command: +> `--cap-add=SYS_PTRACE` + ## Quickstart Guide To quickly learn about both the C++ and Python MRC APIs, including following along with various complexity examples, we recommend following the MRC Quickstart Repository located [here](/docs/quickstart/README.md). 
This tutorial walks new users through topics like diff --git a/ci/conda/environments/clang_env.yml b/ci/conda/environments/clang_env.yml index 9c8867ae4..bebe11bfd 100644 --- a/ci/conda/environments/clang_env.yml +++ b/ci/conda/environments/clang_env.yml @@ -19,11 +19,11 @@ name: mrc channels: - conda-forge dependencies: - - clang=15 - - clang-tools=15 - - clangdev=15 - - clangxx=15 - - libclang=15 - - libclang-cpp=15 - - llvmdev=15 - - include-what-you-use=0.19 + - clang=16 + - clang-tools=16 + - clangdev=16 + - clangxx=16 + - libclang=16 + - libclang-cpp=16 + - llvmdev=16 + - include-what-you-use=0.20 diff --git a/ci/conda/environments/dev_env.yml b/ci/conda/environments/dev_env.yml index 58d83d9a7..ecd4003b6 100644 --- a/ci/conda/environments/dev_env.yml +++ b/ci/conda/environments/dev_env.yml @@ -25,9 +25,9 @@ dependencies: - autoconf>=2.69 - bash-completion - benchmark=1.6.0 - - boost-cpp=1.74 + - boost-cpp=1.82 - ccache - - cmake=3.24 + - cmake=3.25 - cuda-toolkit # Version comes from the channel above - cxx-compiler # Sets up the distro versions of our compilers - doxygen=1.9.2 @@ -46,7 +46,7 @@ dependencies: - isort - jinja2=3.0 - lcov=1.15 - - libhwloc=2.5 + - libhwloc=2.9.2 - libprotobuf=3.21 - librmm=23.06 - libtool @@ -59,6 +59,7 @@ dependencies: - pybind11-stubgen=0.10 - pytest - pytest-timeout + - pytest-asyncio - python=3.10 - scikit-build>=0.17 - sysroot_linux-64=2.17 diff --git a/ci/conda/recipes/libmrc/conda_build_config.yaml b/ci/conda/recipes/libmrc/conda_build_config.yaml index 008688e98..e674d6b6d 100644 --- a/ci/conda/recipes/libmrc/conda_build_config.yaml +++ b/ci/conda/recipes/libmrc/conda_build_config.yaml @@ -71,9 +71,9 @@ zip_keys: # The following mimic what is available in the pinning feedstock: # https://github.com/conda-forge/conda-forge-pinning-feedstock/blob/main/recipe/conda_build_config.yaml boost: - - 1.74.0 + - 1.82 boost_cpp: - - 1.74.0 + - 1.82 gflags: - 2.2 glog: diff --git a/ci/conda/recipes/libmrc/meta.yaml b/ci/conda/recipes/libmrc/meta.yaml index 6abbd7c19..68ec16ecd 100644 --- a/ci/conda/recipes/libmrc/meta.yaml +++ b/ci/conda/recipes/libmrc/meta.yaml @@ -62,12 +62,12 @@ requirements: - libgrpc - gtest 1.13.* - libabseil - - libhwloc 2.5.* + - libhwloc 2.9.2 - libprotobuf - librmm {{ rapids_version }} - nlohmann_json 3.9.1 - pybind11-abi # See: https://conda-forge.org/docs/maintainer/knowledge_base.html#pybind11-abi-constraints - - pybind11-stubgen 0.10.5 + - pybind11-stubgen 0.10 - python {{ python }} - scikit-build >=0.17 - ucx @@ -98,7 +98,7 @@ outputs: - glog - libgrpc - libabseil # Needed for transitive run_exports from libgrpc. Does not need a version - - libhwloc 2.5.* + - libhwloc 2.9.2 - libprotobuf # Needed for transitive run_exports from libgrpc. Does not need a version - librmm {{ rapids_version }} - nlohmann_json 3.9.* diff --git a/ci/release/pr_code_freeze_template.md b/ci/release/pr_code_freeze_template.md new file mode 100644 index 000000000..99642f7c9 --- /dev/null +++ b/ci/release/pr_code_freeze_template.md @@ -0,0 +1,11 @@ +## :snowflake: Code freeze for `branch-${VERSION}` and `v${VERSION}` release + +### What does this mean? +Only critical/hotfix level issues should be merged into `branch-${VERSION}` until release (merging of this PR). + +All other development PRs should be retargeted towards the next release branch: `branch-${NEXT_VERSION}`. + +### What is the purpose of this PR? 
+- Update documentation +- Allow testing for the new release +- Enable a means to merge `branch-${VERSION}` into `main` for the release diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh index 8e4895f23..31b541957 100755 --- a/ci/release/update-version.sh +++ b/ci/release/update-version.sh @@ -60,6 +60,10 @@ function sed_runner() { # .gitmodules git submodule set-branch -b branch-${NEXT_SHORT_TAG} morpheus_utils +if [[ "$(git diff --name-only | grep .gitmodules)" != "" ]]; then + # Only update the submodules if setting the branch changed .gitmodules + git submodule update --remote +fi # Root CMakeLists.txt sed_runner 's/'"VERSION ${CURRENT_FULL_VERSION}.*"'/'"VERSION ${NEXT_FULL_VERSION}"'/g' CMakeLists.txt diff --git a/ci/scripts/bootstrap_local_ci.sh b/ci/scripts/bootstrap_local_ci.sh new file mode 100755 index 000000000..f1ff55bb2 --- /dev/null +++ b/ci/scripts/bootstrap_local_ci.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +export WORKSPACE_TMP="$(pwd)/.tmp/local_ci_workspace" +mkdir -p ${WORKSPACE_TMP} +git clone ${GIT_URL} mrc +cd mrc/ +git checkout ${GIT_BRANCH} +git pull +git checkout ${GIT_COMMIT} + +export MRC_ROOT=$(pwd) +export WORKSPACE=${MRC_ROOT} +export LOCAL_CI=1 +GH_SCRIPT_DIR="${MRC_ROOT}/ci/scripts/github" + +unset CMAKE_CUDA_COMPILER_LAUNCHER +unset CMAKE_CXX_COMPILER_LAUNCHER +unset CMAKE_C_COMPILER_LAUNCHER + +if [[ "${STAGE}" != "bash" ]]; then + # benchmark & codecov are composite stages, the rest are composed of a single shell script + if [[ "${STAGE}" == "benchmark" || "${STAGE}" == "codecov" ]]; then + CI_SCRIPT="${WORKSPACE_TMP}/ci_script.sh" + echo "#!/bin/bash" > ${CI_SCRIPT} + if [[ "${STAGE}" == "benchmark" ]]; then + echo "${GH_SCRIPT_DIR}/pre_benchmark.sh" >> ${CI_SCRIPT} + echo "${GH_SCRIPT_DIR}/benchmark.sh" >> ${CI_SCRIPT} + echo "${GH_SCRIPT_DIR}/post_benchmark.sh" >> ${CI_SCRIPT} + else + echo "${GH_SCRIPT_DIR}/build.sh" >> ${CI_SCRIPT} + echo "${GH_SCRIPT_DIR}/test_codecov.sh" >> ${CI_SCRIPT} + fi + + chmod +x ${CI_SCRIPT} + else + if [[ "${STAGE}" =~ "build" ]]; then + CI_SCRIPT="${GH_SCRIPT_DIR}/build.sh" + elif [[ "${STAGE}" =~ "test" ]]; then + CI_SCRIPT="${GH_SCRIPT_DIR}/test.sh" + else + CI_SCRIPT="${GH_SCRIPT_DIR}/${STAGE}.sh" + fi + fi + + ${CI_SCRIPT} +fi diff --git a/ci/scripts/cpp_checks.sh b/ci/scripts/cpp_checks.sh index 416c92167..b83df0727 100755 --- a/ci/scripts/cpp_checks.sh +++ b/ci/scripts/cpp_checks.sh @@ -80,9 +80,22 @@ if [[ -n "${MRC_MODIFIED_FILES}" ]]; then # Include What You Use if [[ "${SKIP_IWYU}" == "" ]]; then - IWYU_DIRS="cpp python" + # Remove .h, .hpp, and .cu files from the modified list + shopt -s extglob + IWYU_MODIFIED_FILES=( "${MRC_MODIFIED_FILES[@]/*.@(h|hpp|cu)/}" ) + + # Get the list of compiled files relative to this directory + WORKING_PREFIX="${PWD}/" + COMPILED_FILES=( $(jq -r .[].file 
${BUILD_DIR}/compile_commands.json | sort -u ) ) + COMPILED_FILES=( "${COMPILED_FILES[@]/#$WORKING_PREFIX/}" ) + COMBINED_FILES=("${COMPILED_FILES[@]}") + COMBINED_FILES+=("${IWYU_MODIFIED_FILES[@]}") + + # Find the intersection between compiled files and modified files + IWYU_MODIFIED_FILES=( $(printf '%s\0' "${COMBINED_FILES[@]}" | sort -z | uniq -d -z | xargs -0n1) ) + NUM_PROC=$(get_num_proc) - IWYU_OUTPUT=`${IWYU_TOOL} -p ${BUILD_DIR} -j ${NUM_PROC} ${IWYU_DIRS} 2>&1` + IWYU_OUTPUT=`${IWYU_TOOL} -p ${BUILD_DIR} -j ${NUM_PROC} ${IWYU_MODIFIED_FILES[@]} 2>&1` IWYU_RETVAL=$? fi else diff --git a/ci/scripts/github/build.sh b/ci/scripts/github/build.sh index e63f04eef..300452c05 100755 --- a/ci/scripts/github/build.sh +++ b/ci/scripts/github/build.sh @@ -20,7 +20,12 @@ source ${WORKSPACE}/ci/scripts/github/common.sh update_conda_env -CMAKE_CACHE_FLAGS="-DCCACHE_PROGRAM_PATH=$(which sccache) -DMRC_USE_CCACHE=ON" +if [[ "${LOCAL_CI}" == "" ]]; then + CMAKE_CACHE_FLAGS="-DCCACHE_PROGRAM_PATH=$(which sccache) -DMRC_USE_CCACHE=ON" +else + CMAKE_CACHE_FLAGS="" +fi + rapids-logger "Check versions" python3 --version @@ -56,18 +61,20 @@ cmake -B build -G Ninja ${CMAKE_FLAGS} . rapids-logger "Building MRC" cmake --build build --parallel ${PARALLEL_LEVEL} -rapids-logger "sccache usage for MRC build:" -sccache --show-stats +if [[ "${LOCAL_CI}" == "" ]]; then + rapids-logger "sccache usage for MRC build:" + sccache --show-stats +fi -if [[ "${BUILD_CC}" != "gcc-coverage" ]]; then +if [[ "${BUILD_CC}" != "gcc-coverage" || ${LOCAL_CI} == "1" ]]; then rapids-logger "Archiving results" tar cfj "${WORKSPACE_TMP}/dot_cache.tar.bz" .cache tar cfj "${WORKSPACE_TMP}/build.tar.bz" build ls -lh ${WORKSPACE_TMP}/ rapids-logger "Pushing results to ${DISPLAY_ARTIFACT_URL}/" - aws s3 cp --no-progress "${WORKSPACE_TMP}/build.tar.bz" "${ARTIFACT_URL}/build.tar.bz" - aws s3 cp --no-progress "${WORKSPACE_TMP}/dot_cache.tar.bz" "${ARTIFACT_URL}/dot_cache.tar.bz" + upload_artifact "${WORKSPACE_TMP}/build.tar.bz" + upload_artifact "${WORKSPACE_TMP}/dot_cache.tar.bz" fi rapids-logger "Success" diff --git a/ci/scripts/github/checks.sh b/ci/scripts/github/checks.sh index 4ea5c5583..e64b36183 100755 --- a/ci/scripts/github/checks.sh +++ b/ci/scripts/github/checks.sh @@ -24,7 +24,8 @@ update_conda_env rapids-logger "Configuring CMake" git submodule update --init --recursive -cmake -B build -G Ninja ${CMAKE_BUILD_ALL_FEATURES} . +CMAKE_CLANG_OPTIONS="-DCMAKE_C_COMPILER:FILEPATH=$(which clang) -DCMAKE_CXX_COMPILER:FILEPATH=$(which clang++) -DCMAKE_CUDA_COMPILER:FILEPATH=$(which nvcc)" +cmake -B build -G Ninja ${CMAKE_CLANG_OPTIONS} ${CMAKE_BUILD_ALL_FEATURES} . 
rapids-logger "Building targets that generate source code" cmake --build build --target mrc_style_checks --parallel ${PARALLEL_LEVEL} diff --git a/ci/scripts/github/common.sh b/ci/scripts/github/common.sh index 02684da2f..17807bdce 100644 --- a/ci/scripts/github/common.sh +++ b/ci/scripts/github/common.sh @@ -56,7 +56,12 @@ export S3_URL="s3://rapids-downloads/ci/mrc" export DISPLAY_URL="https://downloads.rapids.ai/ci/mrc" export ARTIFACT_ENDPOINT="/pull-request/${PR_NUM}/${GIT_COMMIT}/${NVARCH}/${BUILD_CC}" export ARTIFACT_URL="${S3_URL}${ARTIFACT_ENDPOINT}" -export DISPLAY_ARTIFACT_URL="${DISPLAY_URL}${ARTIFACT_ENDPOINT}" + +if [[ "${LOCAL_CI}" == "1" ]]; then + export DISPLAY_ARTIFACT_URL="${LOCAL_CI_TMP}" +else + export DISPLAY_ARTIFACT_URL="${DISPLAY_URL}${ARTIFACT_ENDPOINT}" +fi # Set sccache env vars export SCCACHE_S3_KEY_PREFIX=mrc-${NVARCH}-${BUILD_CC} @@ -78,9 +83,11 @@ function update_conda_env() { # Deactivate the environment first before updating conda deactivate - # Make sure we have the conda-merge package installed - if [[ -z "$(conda list | grep conda-merge)" ]]; then - rapids-mamba-retry install -q -n mrc -c conda-forge "conda-merge>=0.2" + if [[ "${SKIP_CONDA_ENV_UPDATE}" == "" ]]; then + # Make sure we have the conda-merge package installed + if [[ -z "$(conda list | grep conda-merge)" ]]; then + rapids-mamba-retry install -q -n mrc -c conda-forge "conda-merge>=0.2" + fi fi # Create a temp directory which we store the combined environment file in @@ -90,8 +97,10 @@ function update_conda_env() { # will clobber the last env update conda run -n mrc --live-stream conda-merge ${CONDA_ENV_YML} ${CONDA_CLANG_ENV_YML} ${CONDA_CI_ENV_YML} > ${condatmpdir}/merged_env.yml - # Update the conda env with prune remove excess packages (in case one was removed from the env) - rapids-mamba-retry env update -n mrc --prune --file ${condatmpdir}/merged_env.yml + if [[ "${SKIP_CONDA_ENV_UPDATE}" == "" ]]; then + # Update the conda env with prune remove excess packages (in case one was removed from the env) + rapids-mamba-retry env update -n mrc --prune --file ${condatmpdir}/merged_env.yml + fi # Delete the temp directory rm -rf ${condatmpdir} @@ -105,7 +114,12 @@ function update_conda_env() { print_env_vars -function fetch_base_branch() { +function fetch_base_branch_gh_api() { + # For PRs, $GIT_BRANCH is like: pull-request/989 + REPO_NAME=$(basename "${GITHUB_REPOSITORY}") + ORG_NAME="${GITHUB_REPOSITORY_OWNER}" + PR_NUM="${GITHUB_REF_NAME##*/}" + rapids-logger "Retrieving base branch from GitHub API" [[ -n "$GH_TOKEN" ]] && CURL_HEADERS=('-H' "Authorization: token ${GH_TOKEN}") RESP=$( @@ -115,25 +129,31 @@ function fetch_base_branch() { "${GITHUB_API_URL}/repos/${ORG_NAME}/${REPO_NAME}/pulls/${PR_NUM}" ) - BASE_BRANCH=$(echo "${RESP}" | jq -r '.base.ref') + export BASE_BRANCH=$(echo "${RESP}" | jq -r '.base.ref') # Change target is the branch name we are merging into but due to the weird way jenkins does # the checkout it isn't recognized by git without the origin/ prefix export CHANGE_TARGET="origin/${BASE_BRANCH}" - git submodule update --init --recursive - rapids-logger "Base branch: ${BASE_BRANCH}" } -function fetch_s3() { - ENDPOINT=$1 - DESTINATION=$2 - if [[ "${USE_S3_CURL}" == "1" ]]; then - curl -f "${DISPLAY_URL}${ENDPOINT}" -o "${DESTINATION}" - FETCH_STATUS=$? 
+function fetch_base_branch_local() { + rapids-logger "Retrieving base branch from git" + git remote add upstream ${GIT_UPSTREAM_URL} + git fetch upstream --tags + source ${MRC_ROOT}/ci/scripts/common.sh + export BASE_BRANCH=$(get_base_branch) + export CHANGE_TARGET="upstream/${BASE_BRANCH}" +} + +function fetch_base_branch() { + if [[ "${LOCAL_CI}" == "1" ]]; then + fetch_base_branch_local else - aws s3 cp --no-progress "${S3_URL}${ENDPOINT}" "${DESTINATION}" - FETCH_STATUS=$? + fetch_base_branch_gh_api fi + + git submodule update --init --recursive + rapids-logger "Base branch: ${BASE_BRANCH}" } function show_conda_info() { @@ -143,3 +163,25 @@ function show_conda_info() { conda config --show-sources conda list --show-channel-urls } + +function upload_artifact() { + FILE_NAME=$1 + BASE_NAME=$(basename "${FILE_NAME}") + rapids-logger "Uploading artifact: ${BASE_NAME}" + if [[ "${LOCAL_CI}" == "1" ]]; then + cp ${FILE_NAME} "${LOCAL_CI_TMP}/${BASE_NAME}" + else + aws s3 cp --only-show-errors "${FILE_NAME}" "${ARTIFACT_URL}/${BASE_NAME}" + echo "- ${DISPLAY_ARTIFACT_URL}/${BASE_NAME}" >> ${GITHUB_STEP_SUMMARY} + fi +} + +function download_artifact() { + ARTIFACT=$1 + rapids-logger "Downloading ${ARTIFACT} from ${DISPLAY_ARTIFACT_URL}" + if [[ "${LOCAL_CI}" == "1" ]]; then + cp "${LOCAL_CI_TMP}/${ARTIFACT}" "${WORKSPACE_TMP}/${ARTIFACT}" + else + aws s3 cp --no-progress "${ARTIFACT_URL}/${ARTIFACT}" "${WORKSPACE_TMP}/${ARTIFACT}" + fi +} diff --git a/ci/scripts/github/conda.sh b/ci/scripts/github/conda.sh index 3b8104ad3..36a878528 100755 --- a/ci/scripts/github/conda.sh +++ b/ci/scripts/github/conda.sh @@ -16,6 +16,7 @@ set -e +CI_SCRIPT_ARGS="$@" source ${WORKSPACE}/ci/scripts/github/common.sh # Its important that we are in the base environment for the build @@ -39,4 +40,15 @@ conda info rapids-logger "Building Conda Package" # Run the conda build and upload -${MRC_ROOT}/ci/conda/recipes/run_conda_build.sh "$@" +${MRC_ROOT}/ci/conda/recipes/run_conda_build.sh "${CI_SCRIPT_ARGS}" + +if [[ " ${CI_SCRIPT_ARGS} " =~ " upload " ]]; then + rapids-logger "Building Conda Package... Done" +else + # if we didn't receive the upload argument, we can still upload the artifact to S3 + tar cfj "${WORKSPACE_TMP}/conda.tar.bz" "${RAPIDS_CONDA_BLD_OUTPUT_DIR}" + ls -lh ${WORKSPACE_TMP}/ + + rapids-logger "Pushing results to ${DISPLAY_ARTIFACT_URL}/" + upload_artifact "${WORKSPACE_TMP}/conda.tar.bz" +fi diff --git a/ci/scripts/github/docs.sh b/ci/scripts/github/docs.sh index 2e0a1f64c..c5f10a53a 100755 --- a/ci/scripts/github/docs.sh +++ b/ci/scripts/github/docs.sh @@ -39,6 +39,6 @@ rapids-logger "Tarring the docs" tar cfj "${WORKSPACE_TMP}/docs.tar.bz" build/docs/html rapids-logger "Pushing results to ${DISPLAY_ARTIFACT_URL}/" -aws s3 cp --no-progress "${WORKSPACE_TMP}/docs.tar.bz" "${ARTIFACT_URL}/docs.tar.bz" +upload_artifact "${WORKSPACE_TMP}/docs.tar.bz" rapids-logger "Success" diff --git a/ci/scripts/github/post_benchmark.sh b/ci/scripts/github/post_benchmark.sh index d08bce2b4..943abc7e0 100755 --- a/ci/scripts/github/post_benchmark.sh +++ b/ci/scripts/github/post_benchmark.sh @@ -1,5 +1,5 @@ #!/usr/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,6 +25,6 @@ cd $(dirname ${REPORTS_DIR}) tar cfj ${WORKSPACE_TMP}/benchmark_reports.tar.bz $(basename ${REPORTS_DIR}) rapids-logger "Pushing results to ${DISPLAY_ARTIFACT_URL}/" -aws s3 cp ${WORKSPACE_TMP}/benchmark_reports.tar.bz "${ARTIFACT_URL}/benchmark_reports.tar.bz" +upload_artifact ${WORKSPACE_TMP}/benchmark_reports.tar.bz exit $(cat ${WORKSPACE_TMP}/exit_status) diff --git a/ci/scripts/github/pre_benchmark.sh b/ci/scripts/github/pre_benchmark.sh index 419df25c2..c14a29144 100755 --- a/ci/scripts/github/pre_benchmark.sh +++ b/ci/scripts/github/pre_benchmark.sh @@ -1,5 +1,5 @@ #!/usr/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,7 +21,7 @@ source ${WORKSPACE}/ci/scripts/github/common.sh update_conda_env rapids-logger "Fetching Build artifacts from ${DISPLAY_ARTIFACT_URL}/" -fetch_s3 "${ARTIFACT_ENDPOINT}/build.tar.bz" "${WORKSPACE_TMP}/build.tar.bz" +download_artifact "build.tar.bz" tar xf "${WORKSPACE_TMP}/build.tar.bz" diff --git a/ci/scripts/github/test.sh b/ci/scripts/github/test.sh index 0aab525a0..40000a516 100755 --- a/ci/scripts/github/test.sh +++ b/ci/scripts/github/test.sh @@ -22,8 +22,8 @@ source ${WORKSPACE}/ci/scripts/github/common.sh update_conda_env rapids-logger "Fetching Build artifacts from ${DISPLAY_ARTIFACT_URL}/" -fetch_s3 "${ARTIFACT_ENDPOINT}/dot_cache.tar.bz" "${WORKSPACE_TMP}/dot_cache.tar.bz" -fetch_s3 "${ARTIFACT_ENDPOINT}/build.tar.bz" "${WORKSPACE_TMP}/build.tar.bz" +download_artifact "dot_cache.tar.bz" +download_artifact "build.tar.bz" tar xf "${WORKSPACE_TMP}/dot_cache.tar.bz" tar xf "${WORKSPACE_TMP}/build.tar.bz" @@ -60,7 +60,7 @@ cd $(dirname ${REPORTS_DIR}) tar cfj ${WORKSPACE_TMP}/test_reports.tar.bz $(basename ${REPORTS_DIR}) rapids-logger "Pushing results to ${DISPLAY_ARTIFACT_URL}/" -aws s3 cp ${WORKSPACE_TMP}/test_reports.tar.bz "${ARTIFACT_URL}/test_reports.tar.bz" +upload_artifact ${WORKSPACE_TMP}/test_reports.tar.bz TEST_RESULTS=$(($CTEST_RESULTS+$PYTEST_RESULTS)) exit ${TEST_RESULTS} diff --git a/ci/scripts/github/test_codecov.sh b/ci/scripts/github/test_codecov.sh index 4a0ef3ce8..97955859a 100755 --- a/ci/scripts/github/test_codecov.sh +++ b/ci/scripts/github/test_codecov.sh @@ -58,13 +58,16 @@ cd ${MRC_ROOT}/build # correctly and enabling relative only ignores system and conda files. find . -type f -name '*.gcda' -exec x86_64-conda_cos6-linux-gnu-gcov -pbc --source-prefix ${MRC_ROOT} --relative-only {} + 1> /dev/null -rapids-logger "Uploading codecov for C++ tests" -# Get the list of files that we are interested in (Keeps the upload small) -GCOV_FILES=$(find . -type f \( -iname "cpp#mrc#include#*.gcov" -or -iname "python#*.gcov" -or -iname "cpp#mrc#src#*.gcov" \)) +if [[ "${LOCAL_CI}" == "" ]]; then + rapids-logger "Uploading codecov for C++ tests" -# Upload the .gcov files directly to codecov. They do a good job at processing the partials -/opt/conda/envs/mrc/bin/codecov ${CODECOV_ARGS} -f ${GCOV_FILES} -F cpp + # Get the list of files that we are interested in (Keeps the upload small) + GCOV_FILES=$(find . 
-type f \( -iname "cpp#mrc#include#*.gcov" -or -iname "python#*.gcov" -or -iname "cpp#mrc#src#*.gcov" \)) + + # Upload the .gcov files directly to codecov. They do a good job at processing the partials + /opt/conda/envs/mrc/bin/codecov ${CODECOV_ARGS} -f ${GCOV_FILES} -F cpp +fi # Remove the gcov files and any gcda files to reset counters find . -type f \( -iname "*.gcov" -or -iname "*.gcda" \) -exec rm {} \; @@ -85,13 +88,15 @@ cd ${MRC_ROOT}/build # correctly and enabling relative only ignores system and conda files. find . -type f -name '*.gcda' -exec x86_64-conda_cos6-linux-gnu-gcov -pbc --source-prefix ${MRC_ROOT} --relative-only {} + 1> /dev/null -rapids-logger "Uploading codecov for Python tests" +if [[ "${LOCAL_CI}" == "" ]]; then + rapids-logger "Uploading codecov for Python tests" -# Get the list of files that we are interested in (Keeps the upload small) -GCOV_FILES=$(find . -type f \( -iname "cpp#mrc#include#*.gcov" -or -iname "python#*.gcov" -or -iname "cpp#mrc#src#*.gcov" \)) + # Get the list of files that we are interested in (Keeps the upload small) + GCOV_FILES=$(find . -type f \( -iname "cpp#mrc#include#*.gcov" -or -iname "python#*.gcov" -or -iname "cpp#mrc#src#*.gcov" \)) -# Upload the .gcov files directly to codecov. They do a good job at processing the partials -/opt/conda/envs/mrc/bin/codecov ${CODECOV_ARGS} -f ${GCOV_FILES} -F py + # Upload the .gcov files directly to codecov. They do a good job at processing the partials + /opt/conda/envs/mrc/bin/codecov ${CODECOV_ARGS} -f ${GCOV_FILES} -F py +fi # Remove the gcov files and any gcda files to reset counters find . -type f \( -iname "*.gcov" -or -iname "*.gcda" \) -exec rm {} \; @@ -101,7 +106,7 @@ cd $(dirname ${REPORTS_DIR}) tar cfj ${WORKSPACE_TMP}/test_reports.tar.bz $(basename ${REPORTS_DIR}) rapids-logger "Pushing results to ${DISPLAY_ARTIFACT_URL}/" -aws s3 cp ${WORKSPACE_TMP}/test_reports.tar.bz "${ARTIFACT_URL}/test_reports.tar.bz" +upload_artifact ${WORKSPACE_TMP}/test_reports.tar.bz TEST_RESULTS=$(($CTEST_RESULTS+$PYTEST_RESULTS)) exit ${TEST_RESULTS} diff --git a/ci/scripts/run_ci_local.sh b/ci/scripts/run_ci_local.sh new file mode 100755 index 000000000..bae506ccf --- /dev/null +++ b/ci/scripts/run_ci_local.sh @@ -0,0 +1,128 @@ +#!/bin/bash +# SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +case "$1" in + "" ) + STAGES=("bash") + ;; + "all" ) + STAGES=("checks" "build-clang" "build-gcc" "test-clang" "test-gcc" "codecov" "docs" "benchmark" "conda") + ;; + "build" ) + STAGES=("build-clang" "build-gcc") + ;; + "test" ) + STAGES=("test-clang" "test-gcc") + ;; + "checks" | "build-clang" | "build-gcc" | "test" | "test-clang" | "test-gcc" | "codecov" | "docs" | "benchmark" | \ + "conda" | "bash" ) + STAGES=("$1") + ;; + * ) + echo "Error: Invalid argument \"$1\" provided. 
Expected values: \"all\", \"checks\", \"build\", " \ + "\"build-clang\", \"build-gcc\", \"test\", \"test-clang\", \"test-gcc\", \"codecov\"," \ + "\"docs\", \"benchmark\", \"conda\" or \"bash\"" + exit 1 + ;; +esac + +# CI image doesn't contain ssh, need to use https +function git_ssh_to_https() +{ + local url=$1 + echo $url | sed -e 's|^git@github\.com:|https://github.com/|' +} + +MRC_ROOT=${MRC_ROOT:-$(git rev-parse --show-toplevel)} + +GIT_URL=$(git remote get-url origin) +GIT_URL=$(git_ssh_to_https ${GIT_URL}) + +GIT_UPSTREAM_URL=$(git remote get-url upstream) +GIT_UPSTREAM_URL=$(git_ssh_to_https ${GIT_UPSTREAM_URL}) + +GIT_BRANCH=$(git branch --show-current) +GIT_COMMIT=$(git log -n 1 --pretty=format:%H) + +BASE_LOCAL_CI_TMP=${BASE_LOCAL_CI_TMP:-${MRC_ROOT}/.tmp/local_ci_tmp} +CONTAINER_VER=${CONTAINER_VER:-230920} +CUDA_VER=${CUDA_VER:-11.8} +DOCKER_EXTRA_ARGS=${DOCKER_EXTRA_ARGS:-""} + +BUILD_CONTAINER="nvcr.io/ea-nvidia-morpheus/morpheus:mrc-ci-build-${CONTAINER_VER}" +TEST_CONTAINER="nvcr.io/ea-nvidia-morpheus/morpheus:mrc-ci-test-${CONTAINER_VER}" + +# These variables are common to all stages +BASE_ENV_LIST="--env LOCAL_CI_TMP=/ci_tmp" +BASE_ENV_LIST="${BASE_ENV_LIST} --env GIT_URL=${GIT_URL}" +BASE_ENV_LIST="${BASE_ENV_LIST} --env GIT_UPSTREAM_URL=${GIT_UPSTREAM_URL}" +BASE_ENV_LIST="${BASE_ENV_LIST} --env GIT_BRANCH=${GIT_BRANCH}" +BASE_ENV_LIST="${BASE_ENV_LIST} --env GIT_COMMIT=${GIT_COMMIT}" +BASE_ENV_LIST="${BASE_ENV_LIST} --env PARALLEL_LEVEL=$(nproc)" +BASE_ENV_LIST="${BASE_ENV_LIST} --env CUDA_VER=${CUDA_VER}" +BASE_ENV_LIST="${BASE_ENV_LIST} --env SKIP_CONDA_ENV_UPDATE=${SKIP_CONDA_ENV_UPDATE}" + +for STAGE in "${STAGES[@]}"; do + # Take a copy of the base env list, then make stage specific changes + ENV_LIST="${BASE_ENV_LIST}" + + if [[ "${STAGE}" =~ benchmark|clang|codecov|gcc ]]; then + if [[ "${STAGE}" =~ "clang" ]]; then + BUILD_CC="clang" + elif [[ "${STAGE}" == "codecov" ]]; then + BUILD_CC="gcc-coverage" + else + BUILD_CC="gcc" + fi + + ENV_LIST="${ENV_LIST} --env BUILD_CC=${BUILD_CC}" + LOCAL_CI_TMP="${BASE_LOCAL_CI_TMP}/${BUILD_CC}" + mkdir -p ${LOCAL_CI_TMP} + else + LOCAL_CI_TMP="${BASE_LOCAL_CI_TMP}" + fi + + mkdir -p ${LOCAL_CI_TMP} + cp ${MRC_ROOT}/ci/scripts/bootstrap_local_ci.sh ${LOCAL_CI_TMP} + + + DOCKER_RUN_ARGS="--rm -ti --net=host -v "${LOCAL_CI_TMP}":/ci_tmp ${ENV_LIST} --env STAGE=${STAGE}" + if [[ "${STAGE}" =~ "test" || "${STAGE}" =~ "codecov" || "${USE_GPU}" == "1" ]]; then + CONTAINER="${TEST_CONTAINER}" + DOCKER_RUN_ARGS="${DOCKER_RUN_ARGS} --runtime=nvidia --gpus all --cap-add=sys_nice --cap-add=sys_ptrace" + else + CONTAINER="${BUILD_CONTAINER}" + DOCKER_RUN_ARGS="${DOCKER_RUN_ARGS} --runtime=runc" + if [[ "${STAGE}" == "benchmark" ]]; then + DOCKER_RUN_ARGS="${DOCKER_RUN_ARGS} --cap-add=sys_nice --cap-add=sys_ptrace" + fi + fi + + if [[ "${STAGE}" == "bash" ]]; then + DOCKER_RUN_CMD="bash --init-file /ci_tmp/bootstrap_local_ci.sh" + else + DOCKER_RUN_CMD="/ci_tmp/bootstrap_local_ci.sh" + fi + + echo "Running ${STAGE} stage in ${CONTAINER}" + docker run ${DOCKER_RUN_ARGS} ${DOCKER_EXTRA_ARGS} ${CONTAINER} ${DOCKER_RUN_CMD} + + STATUS=$? 
+ if [[ ${STATUS} -ne 0 ]]; then + echo "Error: docker exited with a non-zero status code for ${STAGE} of ${STATUS}" + exit ${STATUS} + fi +done diff --git a/cpp/mrc/CMakeLists.txt b/cpp/mrc/CMakeLists.txt index a0af3cbcd..f2f1e63cc 100644 --- a/cpp/mrc/CMakeLists.txt +++ b/cpp/mrc/CMakeLists.txt @@ -38,6 +38,7 @@ add_library(libmrc src/internal/data_plane/server.cpp src/internal/executor/executor_definition.cpp src/internal/grpc/progress_engine.cpp + src/internal/grpc/promise_handler.cpp src/internal/grpc/server.cpp src/internal/memory/device_resources.cpp src/internal/memory/host_resources.cpp @@ -115,12 +116,14 @@ add_library(libmrc src/public/core/thread.cpp src/public/coroutines/event.cpp src/public/coroutines/sync_wait.cpp + src/public/coroutines/task_container.cpp src/public/coroutines/thread_local_context.cpp src/public/coroutines/thread_pool.cpp src/public/cuda/device_guard.cpp src/public/cuda/sync.cpp src/public/edge/edge_adapter_registry.cpp src/public/edge/edge_builder.cpp + src/public/exceptions/exception_catcher.cpp src/public/manifold/manifold.cpp src/public/memory/buffer_view.cpp src/public/memory/codable/buffer.cpp diff --git a/cpp/mrc/include/mrc/core/userspace_threads.hpp b/cpp/mrc/include/mrc/core/userspace_threads.hpp index 19e36c9c2..273b04b3a 100644 --- a/cpp/mrc/include/mrc/core/userspace_threads.hpp +++ b/cpp/mrc/include/mrc/core/userspace_threads.hpp @@ -19,44 +19,51 @@ #include -namespace mrc { +namespace mrc::userspace_threads { -struct userspace_threads // NOLINT -{ - using mutex = boost::fibers::mutex; // NOLINT +// Suppress naming conventions in this file to allow matching std and boost libraries +// NOLINTBEGIN(readability-identifier-naming) + +using mutex = boost::fibers::mutex; + +using recursive_mutex = boost::fibers::recursive_mutex; - using cv = boost::fibers::condition_variable; // NOLINT +using cv = boost::fibers::condition_variable; - using launch = boost::fibers::launch; // NOLINT +using cv_any = boost::fibers::condition_variable_any; - template - using promise = boost::fibers::promise; // NOLINT +using launch = boost::fibers::launch; - template - using future = boost::fibers::future; // NOLINT +template +using promise = boost::fibers::promise; - template - using shared_future = boost::fibers::shared_future; // NOLINT +template +using future = boost::fibers::future; - template // NOLINT - using packaged_task = boost::fibers::packaged_task; // NOLINT +template +using shared_future = boost::fibers::shared_future; - template // NOLINT - static auto async(Function&& f, Args&&... args) - { - return boost::fibers::async(f, std::forward(args)...); - } +template +using packaged_task = boost::fibers::packaged_task; + +template +static auto async(Function&& f, Args&&... 
args) +{ + return boost::fibers::async(f, std::forward(args)...); +} + +template +static void sleep_for(std::chrono::duration const& timeout_duration) +{ + boost::this_fiber::sleep_for(timeout_duration); +} + +template +static void sleep_until(std::chrono::time_point const& sleep_time_point) +{ + boost::this_fiber::sleep_until(sleep_time_point); +} - template // NOLINT - static void sleep_for(std::chrono::duration const& timeout_duration) - { - boost::this_fiber::sleep_for(timeout_duration); - } +// NOLINTEND(readability-identifier-naming) - template // NOLINT - static void sleep_until(std::chrono::time_point const& sleep_time_point) - { - boost::this_fiber::sleep_until(sleep_time_point); - } -}; -} // namespace mrc +} // namespace mrc::userspace_threads diff --git a/cpp/mrc/include/mrc/core/utils.hpp b/cpp/mrc/include/mrc/core/utils.hpp index 84e2f8e06..72d9089a7 100644 --- a/cpp/mrc/include/mrc/core/utils.hpp +++ b/cpp/mrc/include/mrc/core/utils.hpp @@ -60,9 +60,12 @@ std::set extract_keys(const std::map& stdmap) class Unwinder { public: - explicit Unwinder(std::function unwind_fn) : m_unwind_fn(std::move(unwind_fn)) {} + explicit Unwinder(std::function unwind_fn) : + m_unwind_fn(std::move(unwind_fn)), + m_ctor_exception_count(std::uncaught_exceptions()) + {} - ~Unwinder() + ~Unwinder() noexcept(false) { if (!!m_unwind_fn) { @@ -71,8 +74,14 @@ class Unwinder m_unwind_fn(); } catch (...) { - LOG(ERROR) << "Fatal error during unwinder function"; - std::terminate(); + if (std::uncaught_exceptions() > m_ctor_exception_count) + { + LOG(ERROR) << "Error occurred during unwinder function, but another exception is active."; + std::terminate(); + } + + LOG(ERROR) << "Error occurred during unwinder function. Rethrowing"; + throw; } } } @@ -92,6 +101,9 @@ class Unwinder } private: + // Stores the number of active exceptions during creation. If the number of active exceptions during destruction is + // greater, we do not throw and log error and terminate + int m_ctor_exception_count; std::function m_unwind_fn; }; diff --git a/cpp/mrc/include/mrc/coroutines/async_generator.hpp b/cpp/mrc/include/mrc/coroutines/async_generator.hpp new file mode 100644 index 000000000..22036c2e7 --- /dev/null +++ b/cpp/mrc/include/mrc/coroutines/async_generator.hpp @@ -0,0 +1,399 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
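The `userspace_threads` aliases refactored above remain thin wrappers over `boost::fibers`; only the spelling changed from struct members to free names in the `mrc::userspace_threads` namespace. A minimal usage sketch (the lambda and values here are illustrative, not taken from this PR):

```cpp
#include "mrc/core/userspace_threads.hpp"

#include <cassert>

int main()
{
    // promise/future are aliases for boost::fibers::promise/future
    mrc::userspace_threads::promise<int> p;
    auto f = p.get_future();

    // async launches a fiber-backed task and forwards its arguments
    auto done = mrc::userspace_threads::async([&p](int v) { p.set_value(v); }, 42);

    assert(f.get() == 42);  // suspends the calling fiber, not the OS thread
    done.get();
    return 0;
}
```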
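The `Unwinder` change above records `std::uncaught_exceptions()` at construction and makes the destructor `noexcept(false)`: an exception thrown by the unwind function now propagates when no other exception is in flight, and only falls back to `std::terminate()` when one is. A short sketch of the scope-guard pattern this enables, assuming (as the header path suggests) that `Unwinder` lives in the `mrc` namespace:

```cpp
#include "mrc/core/utils.hpp"  // assumption: Unwinder is declared in namespace mrc here

#include <iostream>
#include <stdexcept>

void example()
{
    try
    {
        // Runs when the scope exits; with the change above, an exception thrown
        // here propagates to the caller instead of always terminating.
        mrc::Unwinder cleanup([] {
            throw std::runtime_error("cleanup failed");
        });

        // ... normal work that does not throw ...
    } catch (const std::runtime_error& e)
    {
        std::cerr << "caught: " << e.what() << '\n';
    }
}
```

If the scope is already unwinding because of another exception, the destructor still logs and terminates rather than risking a double throw.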
+ */ + +/** + * Original Source: https://github.com/lewissbaker/cppcoro + * Original License: MIT; included below + */ + +// Copyright 2017 Lewis Baker + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is furnished +// to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#pragma once + +#include "mrc/utils/macros.hpp" + +#include + +#include +#include +#include +#include + +namespace mrc::coroutines { + +template +class AsyncGenerator; + +namespace detail { + +template +class AsyncGeneratorIterator; +class AsyncGeneratorYieldOperation; +class AsyncGeneratorAdvanceOperation; + +class AsyncGeneratorPromiseBase +{ + public: + AsyncGeneratorPromiseBase() noexcept : m_exception(nullptr) {} + + DELETE_COPYABILITY(AsyncGeneratorPromiseBase) + + constexpr static std::suspend_always initial_suspend() noexcept + { + return {}; + } + + AsyncGeneratorYieldOperation final_suspend() noexcept; + + void unhandled_exception() noexcept + { + m_exception = std::current_exception(); + } + + auto return_void() noexcept -> void {} + + auto finished() const noexcept -> bool + { + return m_value == nullptr; + } + + auto rethrow_on_unhandled_exception() -> void + { + if (m_exception) + { + std::rethrow_exception(m_exception); + } + } + + protected: + AsyncGeneratorYieldOperation internal_yield_value() noexcept; + void* m_value{nullptr}; + + private: + std::exception_ptr m_exception; + std::coroutine_handle<> m_consumer; + + friend class AsyncGeneratorYieldOperation; + friend class AsyncGeneratorAdvanceOperation; +}; + +class AsyncGeneratorYieldOperation final +{ + public: + AsyncGeneratorYieldOperation(std::coroutine_handle<> consumer) noexcept : m_consumer(consumer) {} + + constexpr static bool await_ready() noexcept + { + return false; + } + + std::coroutine_handle<> await_suspend([[maybe_unused]] std::coroutine_handle<> producer) const noexcept + { + return m_consumer; + } + + constexpr static void await_resume() noexcept {} + + private: + std::coroutine_handle<> m_consumer; +}; + +inline AsyncGeneratorYieldOperation AsyncGeneratorPromiseBase::final_suspend() noexcept +{ + m_value = nullptr; + return internal_yield_value(); +} + +inline AsyncGeneratorYieldOperation AsyncGeneratorPromiseBase::internal_yield_value() noexcept +{ + return AsyncGeneratorYieldOperation{m_consumer}; +} + +class AsyncGeneratorAdvanceOperation +{ + protected: + AsyncGeneratorAdvanceOperation(std::nullptr_t) noexcept : m_promise(nullptr), m_producer(nullptr) {} + + AsyncGeneratorAdvanceOperation(AsyncGeneratorPromiseBase& promise, std::coroutine_handle<> producer) noexcept : + 
m_promise(std::addressof(promise)), + m_producer(producer) + {} + + public: + constexpr static bool await_ready() noexcept + { + return false; + } + + std::coroutine_handle<> await_suspend(std::coroutine_handle<> consumer) noexcept + { + m_promise->m_consumer = consumer; + return m_producer; + } + + protected: + AsyncGeneratorPromiseBase* m_promise; + std::coroutine_handle<> m_producer; +}; + +template +class AsyncGeneratorPromise final : public AsyncGeneratorPromiseBase +{ + using value_t = std::remove_reference_t; + using reference_t = std::conditional_t, T, T&>; + using pointer_t = value_t*; + + public: + AsyncGeneratorPromise() noexcept = default; + + AsyncGenerator get_return_object() noexcept; + + template ::value, int> = 0> + auto yield_value(value_t& value) noexcept -> AsyncGeneratorYieldOperation + { + m_value = std::addressof(value); + return internal_yield_value(); + } + + auto yield_value(value_t&& value) noexcept -> AsyncGeneratorYieldOperation + { + m_value = std::addressof(value); + return internal_yield_value(); + } + + auto value() const noexcept -> reference_t + { + return *static_cast(m_value); + } +}; + +template +class AsyncGeneratorIncrementOperation final : public AsyncGeneratorAdvanceOperation +{ + public: + AsyncGeneratorIncrementOperation(AsyncGeneratorIterator& iterator) noexcept : + AsyncGeneratorAdvanceOperation(iterator.m_coroutine.promise(), iterator.m_coroutine), + m_iterator(iterator) + {} + + AsyncGeneratorIterator& await_resume(); + + private: + AsyncGeneratorIterator& m_iterator; +}; + +struct AsyncGeneratorSentinel +{}; + +template +class AsyncGeneratorIterator final +{ + using promise_t = AsyncGeneratorPromise; + using handle_t = std::coroutine_handle; + + public: + using iterator_category = std::input_iterator_tag; // NOLINT + // Not sure what type should be used for difference_type as we don't + // allow calculating difference between two iterators. 
+ using difference_t = std::ptrdiff_t; + using value_t = std::remove_reference_t; + using reference = std::add_lvalue_reference_t; // NOLINT + using pointer = std::add_pointer_t; // NOLINT + + AsyncGeneratorIterator(std::nullptr_t) noexcept : m_coroutine(nullptr) {} + + AsyncGeneratorIterator(handle_t coroutine) noexcept : m_coroutine(coroutine) {} + + AsyncGeneratorIncrementOperation operator++() noexcept + { + return AsyncGeneratorIncrementOperation{*this}; + } + + reference operator*() const noexcept + { + return m_coroutine.promise().value(); + } + + bool operator==(const AsyncGeneratorIterator& other) const noexcept + { + return m_coroutine == other.m_coroutine; + } + + bool operator!=(const AsyncGeneratorIterator& other) const noexcept + { + return !(*this == other); + } + + operator bool() const noexcept + { + return m_coroutine && !m_coroutine.promise().finished(); + } + + private: + friend class AsyncGeneratorIncrementOperation; + + handle_t m_coroutine; +}; + +template +inline AsyncGeneratorIterator& AsyncGeneratorIncrementOperation::await_resume() +{ + if (m_promise->finished()) + { + // Update iterator to end() + m_iterator = AsyncGeneratorIterator{nullptr}; + m_promise->rethrow_on_unhandled_exception(); + } + + return m_iterator; +} + +template +class AsyncGeneratorBeginOperation final : public AsyncGeneratorAdvanceOperation +{ + using promise_t = AsyncGeneratorPromise; + using handle_t = std::coroutine_handle; + + public: + AsyncGeneratorBeginOperation(std::nullptr_t) noexcept : AsyncGeneratorAdvanceOperation(nullptr) {} + + AsyncGeneratorBeginOperation(handle_t producer) noexcept : + AsyncGeneratorAdvanceOperation(producer.promise(), producer) + {} + + bool await_ready() const noexcept + { + return m_promise == nullptr || AsyncGeneratorAdvanceOperation::await_ready(); + } + + AsyncGeneratorIterator await_resume() + { + if (m_promise == nullptr) + { + // Called begin() on the empty generator. + return AsyncGeneratorIterator{nullptr}; + } + + if (m_promise->finished()) + { + // Completed without yielding any values. + m_promise->rethrow_on_unhandled_exception(); + return AsyncGeneratorIterator{nullptr}; + } + + return AsyncGeneratorIterator{handle_t::from_promise(*static_cast(m_promise))}; + } +}; + +} // namespace detail + +template +class [[nodiscard]] AsyncGenerator +{ + public: + // There must be a type called `promise_type` for coroutines to work. 
Skil linting + using promise_type = detail::AsyncGeneratorPromise; // NOLINT(readability-identifier-naming) + using iterator = detail::AsyncGeneratorIterator; // NOLINT(readability-identifier-naming) + + AsyncGenerator() noexcept : m_coroutine(nullptr) {} + + explicit AsyncGenerator(promise_type& promise) noexcept : + m_coroutine(std::coroutine_handle::from_promise(promise)) + {} + + AsyncGenerator(AsyncGenerator&& other) noexcept : m_coroutine(other.m_coroutine) + { + other.m_coroutine = nullptr; + } + + ~AsyncGenerator() + { + if (m_coroutine) + { + m_coroutine.destroy(); + } + } + + AsyncGenerator& operator=(AsyncGenerator&& other) noexcept + { + AsyncGenerator temp(std::move(other)); + swap(temp); + return *this; + } + + AsyncGenerator(const AsyncGenerator&) = delete; + AsyncGenerator& operator=(const AsyncGenerator&) = delete; + + auto begin() noexcept + { + if (!m_coroutine) + { + return detail::AsyncGeneratorBeginOperation{nullptr}; + } + + return detail::AsyncGeneratorBeginOperation{m_coroutine}; + } + + auto end() noexcept + { + return iterator{nullptr}; + } + + void swap(AsyncGenerator& other) noexcept + { + using std::swap; + swap(m_coroutine, other.m_coroutine); + } + + private: + std::coroutine_handle m_coroutine; +}; + +template +void swap(AsyncGenerator& a, AsyncGenerator& b) noexcept +{ + a.swap(b); +} + +namespace detail { +template +AsyncGenerator AsyncGeneratorPromise::get_return_object() noexcept +{ + return AsyncGenerator{*this}; +} + +} // namespace detail + +} // namespace mrc::coroutines diff --git a/cpp/mrc/include/mrc/coroutines/closable_ring_buffer.hpp b/cpp/mrc/include/mrc/coroutines/closable_ring_buffer.hpp new file mode 100644 index 000000000..386dd7d32 --- /dev/null +++ b/cpp/mrc/include/mrc/coroutines/closable_ring_buffer.hpp @@ -0,0 +1,703 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Original Source: https://github.com/jbaldwin/libcoro + * Original License: Apache License, Version 2.0; included below + */ + +/** + * Copyright 2021 Josh Baldwin + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
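`AsyncGenerator<T>` follows the cppcoro pattern: `begin()` is awaitable and resumes the producer until its first `co_yield`, and advancing the iterator is awaitable as well. A minimal consumption sketch; the `Task` type and its `task.hpp` include are assumptions based on the coroutine utilities referenced elsewhere in this PR, not something defined in this file:

```cpp
#include "mrc/coroutines/async_generator.hpp"
#include "mrc/coroutines/task.hpp"  // assumed: a lazy Task type from the MRC coroutine utilities

#include <iostream>

using namespace mrc::coroutines;

// Produce values lazily; each co_yield suspends until the consumer asks for more
AsyncGenerator<int> make_values()
{
    for (int i = 0; i < 3; ++i)
    {
        co_yield i;
    }
}

Task<void> consume()
{
    auto gen = make_values();

    // begin() and operator++ are awaitable; operator* returns the current value
    for (auto it = co_await gen.begin(); it != gen.end(); co_await ++it)
    {
        std::cout << *it << '\n';
    }
}
```

Driving `consume()` would typically go through something like `sync_wait` from the same utilities; the exact entry point is not shown in this diff.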
+ */ + +#pragma once + +#include "mrc/core/expected.hpp" +#include "mrc/coroutines/schedule_policy.hpp" +#include "mrc/coroutines/thread_local_context.hpp" +#include "mrc/coroutines/thread_pool.hpp" + +#include + +#include +#include +#include +#include +#include + +namespace mrc::coroutines { + +enum class RingBufferOpStatus +{ + Success, + Stopped, +}; + +/** + * @tparam ElementT The type of element the ring buffer will store. Note that this type should be + * cheap to move if possible as it is moved into and out of the buffer upon write and + * read operations. + */ +template +class ClosableRingBuffer +{ + using mutex_type = std::mutex; + + public: + struct Options + { + // capacity of ring buffer + std::size_t capacity{8}; + + // when there is an awaiting reader, the active execution context of the next writer will resume the awaiting + // reader, the schedule_policy_t dictates how that is accomplished. + SchedulePolicy reader_policy{SchedulePolicy::Reschedule}; + + // when there is an awaiting writer, the active execution context of the next reader will resume the awaiting + // writer, the producder_policy_t dictates how that is accomplished. + SchedulePolicy writer_policy{SchedulePolicy::Reschedule}; + + // when there is an awaiting writer, the active execution context of the next reader will resume the awaiting + // writer, the producder_policy_t dictates how that is accomplished. + SchedulePolicy completed_policy{SchedulePolicy::Reschedule}; + }; + + /** + * @throws std::runtime_error If `num_elements` == 0. + */ + explicit ClosableRingBuffer(Options opts = {}) : + m_elements(opts.capacity), // elements needs to be extended from just holding ElementT to include a TraceContext + m_num_elements(opts.capacity), + m_writer_policy(opts.writer_policy), + m_reader_policy(opts.reader_policy), + m_completed_policy(opts.completed_policy) + { + if (m_num_elements == 0) + { + throw std::runtime_error{"num_elements cannot be zero"}; + } + } + + ~ClosableRingBuffer() + { + // Wake up anyone still using the ring buffer. 
+ notify_waiters(); + } + + ClosableRingBuffer(const ClosableRingBuffer&) = delete; + ClosableRingBuffer(ClosableRingBuffer&&) = delete; + + auto operator=(const ClosableRingBuffer&) noexcept -> ClosableRingBuffer& = delete; + auto operator=(ClosableRingBuffer&&) noexcept -> ClosableRingBuffer& = delete; + + struct Operation + { + virtual void resume() = 0; + }; + + struct WriteOperation : ThreadLocalContext, Operation + { + WriteOperation(ClosableRingBuffer& rb, ElementT e) : + m_rb(rb), + m_e(std::move(e)), + m_policy(m_rb.m_writer_policy) + {} + + auto await_ready() noexcept -> bool + { + // return immediate if the buffer is closed + if (m_rb.m_stopped.load(std::memory_order::acquire)) + { + m_stopped = true; + return true; + } + + // start a span to time the write - this would include time suspended if the buffer is full + // m_write_span->AddEvent("start_on", {{"thead.id", mrc::this_thread::get_id()}}); + + // the lock is owned by the operation, not scoped to the await_ready function + m_lock = std::unique_lock(m_rb.m_mutex); + return m_rb.try_write_locked(m_lock, m_e); + } + + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + // m_lock was acquired as part of await_ready; await_suspend is responsible for releasing the lock + auto lock = std::move(m_lock); // use raii + + ThreadLocalContext::suspend_thread_local_context(); + + m_awaiting_coroutine = awaiting_coroutine; + m_next = m_rb.m_write_waiters; + m_rb.m_write_waiters = this; + return true; + } + + /** + * @return write_result + */ + auto await_resume() -> RingBufferOpStatus + { + ThreadLocalContext::resume_thread_local_context(); + return (!m_stopped ? RingBufferOpStatus::Success : RingBufferOpStatus::Stopped); + } + + WriteOperation& use_scheduling_policy(SchedulePolicy policy) & + { + m_policy = policy; + return *this; + } + + WriteOperation use_scheduling_policy(SchedulePolicy policy) && + { + m_policy = policy; + return std::move(*this); + } + + WriteOperation& resume_immediately() & + { + m_policy = SchedulePolicy::Immediate; + return *this; + } + + WriteOperation resume_immediately() && + { + m_policy = SchedulePolicy::Immediate; + return std::move(*this); + } + + WriteOperation& resume_on(ThreadPool* thread_pool) & + { + m_policy = SchedulePolicy::Reschedule; + set_resume_on_thread_pool(thread_pool); + return *this; + } + + WriteOperation resume_on(ThreadPool* thread_pool) && + { + m_policy = SchedulePolicy::Reschedule; + set_resume_on_thread_pool(thread_pool); + return std::move(*this); + } + + private: + friend ClosableRingBuffer; + + void resume() + { + if (m_policy == SchedulePolicy::Immediate) + { + set_resume_on_thread_pool(nullptr); + } + resume_coroutine(m_awaiting_coroutine); + } + + /// The lock is acquired in await_ready; if ready it is release; otherwise, await_suspend should release it + std::unique_lock m_lock; + /// The ring buffer the element is being written into. + ClosableRingBuffer& m_rb; + /// If the operation needs to suspend, the coroutine to resume when the element can be written. + std::coroutine_handle<> m_awaiting_coroutine; + /// Linked list of write operations that are awaiting to write their element. + WriteOperation* m_next{nullptr}; + /// The element this write operation is producing into the ring buffer. + ElementT m_e; + /// Was the operation stopped? 
+ bool m_stopped{false}; + /// Scheduling Policy - default provided by the ClosableRingBuffer, but can be overrided owner of the Awaiter + SchedulePolicy m_policy; + /// Span to measure the duration the writer spent writting data + // trace::Handle m_write_span{nullptr}; + }; + + struct ReadOperation : ThreadLocalContext, Operation + { + explicit ReadOperation(ClosableRingBuffer& rb) : m_rb(rb), m_policy(m_rb.m_reader_policy) {} + + auto await_ready() noexcept -> bool + { + // the lock is owned by the operation, not scoped to the await_ready function + m_lock = std::unique_lock(m_rb.m_mutex); + // m_read_span->AddEvent("start_on", {{"thead.id", mrc::this_thread::get_id()}}); + return m_rb.try_read_locked(m_lock, this); + } + + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + // m_lock was acquired as part of await_ready; await_suspend is responsible for releasing the lock + auto lock = std::move(m_lock); + + // the buffer is empty; don't suspend if the stop signal has been set. + if (m_rb.m_stopped.load(std::memory_order::acquire)) + { + m_stopped = true; + return false; + } + + // m_read_span->AddEvent("buffer_empty"); + ThreadLocalContext::suspend_thread_local_context(); + + m_awaiting_coroutine = awaiting_coroutine; + m_next = m_rb.m_read_waiters; + m_rb.m_read_waiters = this; + return true; + } + + /** + * @return The consumed element or std::nullopt if the read has failed. + */ + auto await_resume() -> mrc::expected + { + ThreadLocalContext::resume_thread_local_context(); + + if (m_stopped) + { + return mrc::unexpected(RingBufferOpStatus::Stopped); + } + + return std::move(m_e); + } + + ReadOperation& use_scheduling_policy(SchedulePolicy policy) + { + m_policy = policy; + return *this; + } + + ReadOperation& resume_immediately() + { + m_policy = SchedulePolicy::Immediate; + return *this; + } + + ReadOperation& resume_on(ThreadPool* thread_pool) + { + m_policy = SchedulePolicy::Reschedule; + set_resume_on_thread_pool(thread_pool); + return *this; + } + + private: + friend ClosableRingBuffer; + + void resume() + { + if (m_policy == SchedulePolicy::Immediate) + { + set_resume_on_thread_pool(nullptr); + } + resume_coroutine(m_awaiting_coroutine); + } + + /// The lock is acquired in await_ready; if ready it is release; otherwise, await_suspend should release it + std::unique_lock m_lock; + /// The ring buffer to read an element from. + ClosableRingBuffer& m_rb; + /// If the operation needs to suspend, the coroutine to resume when the element can be consumed. + std::coroutine_handle<> m_awaiting_coroutine; + /// Linked list of read operations that are awaiting to read an element. + ReadOperation* m_next{nullptr}; + /// The element this read operation will read. + ElementT m_e; + /// Was the operation stopped? 
+    struct CompletedOperation : ThreadLocalContext, Operation
+    {
+        explicit CompletedOperation(ClosableRingBuffer& rb) : m_rb(rb), m_policy(m_rb.m_completed_policy) {}
+
+        auto await_ready() noexcept -> bool
+        {
+            // the lock is owned by the operation, not scoped to the await_ready function
+            m_lock = std::unique_lock(m_rb.m_mutex);
+            // m_read_span->AddEvent("start_on", {{"thread.id", mrc::this_thread::get_id()}});
+            return m_rb.try_completed_locked(m_lock, this);
+        }
+
+        auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool
+        {
+            // m_lock was acquired as part of await_ready; await_suspend is responsible for releasing the lock
+            auto lock = std::move(m_lock);
+
+            // m_read_span->AddEvent("buffer_empty");
+            ThreadLocalContext::suspend_thread_local_context();
+
+            m_awaiting_coroutine = awaiting_coroutine;
+            m_next = m_rb.m_completed_waiters;
+            m_rb.m_completed_waiters = this;
+            return true;
+        }
+
+        /**
+         * Resumes once the buffer has been closed and all elements have been consumed.
+         */
+        auto await_resume()
+        {
+            ThreadLocalContext::resume_thread_local_context();
+        }
+
+        CompletedOperation& use_scheduling_policy(SchedulePolicy policy)
+        {
+            m_policy = policy;
+            return *this;
+        }
+
+        CompletedOperation& resume_immediately()
+        {
+            m_policy = SchedulePolicy::Immediate;
+            return *this;
+        }
+
+        CompletedOperation& resume_on(ThreadPool* thread_pool)
+        {
+            m_policy = SchedulePolicy::Reschedule;
+            set_resume_on_thread_pool(thread_pool);
+            return *this;
+        }
+
+      private:
+        friend ClosableRingBuffer;
+
+        void resume()
+        {
+            if (m_policy == SchedulePolicy::Immediate)
+            {
+                set_resume_on_thread_pool(nullptr);
+            }
+            resume_coroutine(m_awaiting_coroutine);
+        }
+
+        /// The lock is acquired in await_ready; if ready, it is released there; otherwise, await_suspend releases it
+        std::unique_lock<mutex_type> m_lock;
+        /// The ring buffer whose completion is being awaited.
+        ClosableRingBuffer& m_rb;
+        /// If the operation needs to suspend, the coroutine to resume once the buffer is closed and drained.
+        std::coroutine_handle<> m_awaiting_coroutine;
+        /// Linked list of completed operations that are awaiting the buffer to close and drain.
+        CompletedOperation* m_next{nullptr};
+        /// Was the operation stopped?
+        bool m_stopped{false};
+        /// Scheduling Policy - default provided by the ClosableRingBuffer, but can be overridden by the owner of the Awaiter
+        SchedulePolicy m_policy;
+        /// Span to measure the duration spent awaiting completion
+        // trace::Handle m_read_span;
+    };
+
+    /**
+     * Produces the given element into the ring buffer. This operation will suspend until a slot
+     * in the ring buffer becomes available.
+     * @param e The element to write.
+     */
+    [[nodiscard]] auto write(ElementT e) -> WriteOperation
+    {
+        return WriteOperation{*this, std::move(e)};
+    }
+
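+    // Illustrative only: the returned awaiter's scheduling policy can be adjusted before awaiting it, e.g.
+    //
+    //   co_await buffer.write(std::move(value)).resume_on(&thread_pool);  // resume the writer on `thread_pool`
+    //   co_await buffer.write(std::move(value)).resume_immediately();     // resume inline on the resuming thread
+    //
+    // (`buffer`, `value` and `thread_pool` are placeholder names.)
+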
+    /**
+     * Consumes an element from the ring buffer. This operation will suspend until an element in
+     * the ring buffer becomes available.
+     */
+    [[nodiscard]] auto read() -> ReadOperation
+    {
+        return ReadOperation{*this};
+    }
+
+    /**
+     * Blocks until `close()` has been called and all elements have been returned.
+     */
+    [[nodiscard]] auto completed() -> CompletedOperation
+    {
+        return CompletedOperation{*this};
+    }
+
+    void close()
+    {
+        // If there are awaiting readers, we must wake them up and signal that the buffer is closed;
+        // otherwise, mark the buffer as closed and fail all new writes immediately. Readers should be allowed
+        // to keep reading until the buffer is empty. When the buffer is empty, readers will fail to suspend and
+        // exit with a stopped status.
+
+        // Only wake up waiters once.
+        if (m_stopped.load(std::memory_order::acquire))
+        {
+            return;
+        }
+
+        std::unique_lock lk{m_mutex};
+        m_stopped.exchange(true, std::memory_order::release);
+
+        // the buffer is empty and no more items will be added
+        if (m_used == 0)
+        {
+            // there should be no awaiting writers
+            CHECK(m_write_waiters == nullptr);
+
+            // signal all awaiting readers that the buffer is stopped
+            while (m_read_waiters != nullptr)
+            {
+                auto* to_resume      = m_read_waiters;
+                to_resume->m_stopped = true;
+                m_read_waiters       = m_read_waiters->m_next;
+
+                lk.unlock();
+                to_resume->resume();
+                lk.lock();
+            }
+
+            // signal all awaiting completed operations that the buffer is completed
+            while (m_completed_waiters != nullptr)
+            {
+                auto* to_resume      = m_completed_waiters;
+                to_resume->m_stopped = true;
+                m_completed_waiters  = m_completed_waiters->m_next;
+
+                lk.unlock();
+                to_resume->resume();
+                lk.lock();
+            }
+        }
+    }
+
+    bool is_closed() const noexcept
+    {
+        return m_stopped.load(std::memory_order::acquire);
+    }
+
+    /**
+     * @return The current number of elements contained in the ring buffer.
+     */
+    auto size() const -> size_t
+    {
+        std::atomic_thread_fence(std::memory_order::acquire);
+        return m_used;
+    }
+
+    /**
+     * @return True if the ring buffer contains zero elements.
+     */
+    auto empty() const -> bool
+    {
+        return size() == 0;
+    }
+
+    /**
+     * Wakes up all currently awaiting writers, readers and completion waiters. Their await_resume()
+     * will report that the ring buffer has stopped.
+     */
+    auto notify_waiters() -> void
+    {
+        // Only wake up waiters once.
+        if (m_stopped.load(std::memory_order::acquire))
+        {
+            return;
+        }
+
+        std::unique_lock lk{m_mutex};
+        m_stopped.exchange(true, std::memory_order::release);
+
+        while (m_write_waiters != nullptr)
+        {
+            auto* to_resume      = m_write_waiters;
+            to_resume->m_stopped = true;
+            m_write_waiters      = m_write_waiters->m_next;
+
+            lk.unlock();
+            to_resume->resume();
+            lk.lock();
+        }
+
+        while (m_read_waiters != nullptr)
+        {
+            auto* to_resume      = m_read_waiters;
+            to_resume->m_stopped = true;
+            m_read_waiters       = m_read_waiters->m_next;
+
+            lk.unlock();
+            to_resume->resume();
+            lk.lock();
+        }
+
+        while (m_completed_waiters != nullptr)
+        {
+            auto* to_resume      = m_completed_waiters;
+            to_resume->m_stopped = true;
+            m_completed_waiters  = m_completed_waiters->m_next;
+
+            lk.unlock();
+            to_resume->resume();
+            lk.lock();
+        }
+    }
+
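+    // Illustrative shutdown sketch: the owner closes the buffer and then waits for it to drain.
+    // `Task<>` and the int element type are placeholders for brevity.
+    //
+    //   coroutines::Task<> shutdown(ClosableRingBuffer<int>& buffer)
+    //   {
+    //       buffer.close();               // fail new writes; readers may keep draining remaining elements
+    //       co_await buffer.completed();  // resumes once the buffer is closed and empty
+    //   }
+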
+  private:
+    friend WriteOperation;
+    friend ReadOperation;
+    friend CompletedOperation;
+
+    mutex_type m_mutex{};
+
+    std::vector<ElementT> m_elements;
+    const std::size_t m_num_elements;
+    const SchedulePolicy m_writer_policy;
+    const SchedulePolicy m_reader_policy;
+    const SchedulePolicy m_completed_policy;
+
+    /// The current front pointer to an open slot if not full.
+    size_t m_front{0};
+    /// The current back pointer to the oldest item in the buffer if not empty.
+    size_t m_back{0};
+    /// The number of items in the ring buffer.
+    size_t m_used{0};
+
+    /// The LIFO list of write waiters - a single writer will have its order preserved
+    // Note: with multiple writers, order cannot be guaranteed, so there is no need for FIFO
+    WriteOperation* m_write_waiters{nullptr};
+    /// The LIFO list of read waiters.
+    ReadOperation* m_read_waiters{nullptr};
+    /// The LIFO list of completed waiters.
+    CompletedOperation* m_completed_waiters{nullptr};
+
+    std::atomic<bool> m_stopped{false};
+
+    auto try_write_locked(std::unique_lock<mutex_type>& lk, ElementT& e) -> bool
+    {
+        if (m_used == m_num_elements)
+        {
+            DCHECK(m_read_waiters == nullptr);
+            return false;
+        }
+
+        // We will be able to write an element into the buffer.
+        m_elements[m_front] = std::move(e);
+        m_front             = (m_front + 1) % m_num_elements;
+        ++m_used;
+
+        ReadOperation* to_resume = nullptr;
+
+        if (m_read_waiters != nullptr)
+        {
+            to_resume      = m_read_waiters;
+            m_read_waiters = m_read_waiters->m_next;
+
+            // Since the read operation suspended, it needs to be provided an element to read.
+            to_resume->m_e = std::move(m_elements[m_back]);
+            m_back         = (m_back + 1) % m_num_elements;
+            --m_used;  // And we just consumed another item.
+        }
+
+        // After this point we will no longer be checking state objects on the buffer
+        lk.unlock();
+
+        if (to_resume != nullptr)
+        {
+            to_resume->resume();
+        }
+
+        return true;
+    }
+
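+    // Note on the hand-off above: when a suspended reader exists, try_write_locked() both appends the new element
+    // and immediately pops the oldest element into that reader, so m_used is left unchanged and the reader is
+    // resumed outside of the lock.
+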
+    auto try_read_locked(std::unique_lock<mutex_type>& lk, ReadOperation* op) -> bool
+    {
+        if (m_used == 0)
+        {
+            return false;
+        }
+
+        // We will be successful in reading an element from the buffer.
+        op->m_e = std::move(m_elements[m_back]);
+        m_back  = (m_back + 1) % m_num_elements;
+        --m_used;
+
+        WriteOperation* writer_to_resume = nullptr;
+
+        if (m_write_waiters != nullptr)
+        {
+            writer_to_resume = m_write_waiters;
+            m_write_waiters  = m_write_waiters->m_next;
+
+            // Since the write operation suspended, it needs to be provided a slot to place its element.
+            m_elements[m_front] = std::move(writer_to_resume->m_e);
+            m_front             = (m_front + 1) % m_num_elements;
+            ++m_used;  // And we just wrote another item.
+        }
+
+        CompletedOperation* completed_waiters = nullptr;
+
+        // Check if we are stopped and there are no more elements in the buffer.
+        if (m_used == 0 && m_stopped.load(std::memory_order::acquire))
+        {
+            completed_waiters   = m_completed_waiters;
+            m_completed_waiters = nullptr;
+        }
+
+        // After this point we will no longer be checking state objects on the buffer
+        lk.unlock();
+
+        // Resume any writer
+        if (writer_to_resume != nullptr)
+        {
+            DCHECK(completed_waiters == nullptr) << "Logic error. Wrote value but count is 0";
+
+            writer_to_resume->resume();
+        }
+
+        // Resume completed waiters if there are any. Save the next pointer before resuming, since the resumed
+        // coroutine may destroy the operation.
+        while (completed_waiters != nullptr)
+        {
+            auto* to_resume   = completed_waiters;
+            completed_waiters = completed_waiters->m_next;
+
+            to_resume->resume();
+        }
+
+        return true;
+    }
+
+    auto try_completed_locked(std::unique_lock<mutex_type>& lk, CompletedOperation* op) -> bool
+    {
+        // The condition is not yet met (still open, or elements remain); the awaiter must suspend
+        if (!m_stopped.load(std::memory_order::acquire) || m_used > 0)
+        {
+            return false;
+        }
+
+        DCHECK(m_write_waiters == nullptr) << "Should not have any writers with a closed buffer";
+
+        // release lock
+        lk.unlock();
+
+        return true;
+    }
+};
+
+}  // namespace mrc::coroutines
diff --git a/cpp/mrc/include/mrc/coroutines/schedule_on.hpp b/cpp/mrc/include/mrc/coroutines/schedule_on.hpp
new file mode 100644
index 000000000..73505a1bd
--- /dev/null
+++ b/cpp/mrc/include/mrc/coroutines/schedule_on.hpp
@@ -0,0 +1,98 @@
+/**
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Original Source: https://github.com/lewissbaker/cppcoro
+ * Original License: MIT; included below
+ */
+
+///////////////////////////////////////////////////////////////////////////////
+// Copyright (c) Lewis Baker
+// Licenced under MIT license. See LICENSE.txt for details.
+///////////////////////////////////////////////////////////////////////////////
+
+#pragma once
+
+#include "async_generator.hpp"
+
+#include
+#include
+#include
+
+#include
+
+namespace mrc::coroutines {
+
+/**
+ * @brief Schedules an awaitable to run on the supplied scheduler. Returns the value as if it were awaited on in the
+ * current thread.
+ */
+template <typename SchedulerT, typename AwaitableT>
+auto schedule_on(SchedulerT& scheduler, AwaitableT awaitable)
+    -> Task<typename boost::detail::remove_rvalue_ref<
+        typename mrc::coroutines::concepts::awaitable_traits<AwaitableT>::awaiter_return_type>::type>
+{
+    using return_t = typename boost::detail::remove_rvalue_ref<
+        typename mrc::coroutines::concepts::awaitable_traits<AwaitableT>::awaiter_return_type>::type;
+
+    co_await scheduler.schedule();
+
+    if constexpr (std::is_same_v<return_t, void>)
+    {
+        co_await std::move(awaitable);
+        VLOG(10) << "schedule_on completed";
+        co_return;
+    }
+    else
+    {
+        auto result = co_await std::move(awaitable);
+        VLOG(10) << "schedule_on completed";
+        co_return std::move(result);
+    }
+}
+
+/**
+ * @brief Schedules an async generator to run on the supplied scheduler. Each value in the generator is produced on
+ * the scheduler. The return value is the same as if the generator was run on the current thread.
+ *
+ * @tparam T
+ * @tparam SchedulerT
+ * @param scheduler
+ * @param source
+ * @return mrc::coroutines::AsyncGenerator<T>
+ */
+template <typename T, typename SchedulerT>
+mrc::coroutines::AsyncGenerator<T> schedule_on(SchedulerT& scheduler, mrc::coroutines::AsyncGenerator<T> source)
+{
+    // Transfer execution to the scheduler before the implicit calls to
+    // 'co_await begin()' or subsequent calls to `co_await iterator::operator++()`
+    // below.
This ensures that all calls to the generator's coroutine_handle<>::resume() + // are executed on the execution context of the scheduler. + co_await scheduler.schedule(); + + const auto iter_end = source.end(); + auto iter = co_await source.begin(); + while (iter != iter_end) + { + co_yield *iter; + + co_await scheduler.schedule(); + + (void)co_await ++iter; + } +} + +} // namespace mrc::coroutines diff --git a/cpp/mrc/include/mrc/coroutines/scheduler.hpp b/cpp/mrc/include/mrc/coroutines/scheduler.hpp new file mode 100644 index 000000000..0e296924a --- /dev/null +++ b/cpp/mrc/include/mrc/coroutines/scheduler.hpp @@ -0,0 +1,54 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "mrc/coroutines/task.hpp" + +#include +#include +#include +#include +#include + +namespace mrc::coroutines { + +/** + * @brief Scheduler base class + */ +class Scheduler : public std::enable_shared_from_this +{ + public: + virtual ~Scheduler() = default; + + /** + * @brief Resumes a coroutine according to the scheduler's implementation. + */ + virtual void resume(std::coroutine_handle<> handle) noexcept = 0; + + /** + * @brief Suspends the current function and resumes it according to the scheduler's implementation. + */ + [[nodiscard]] virtual Task<> schedule() = 0; + + /** + * @brief Suspends the current function and resumes it according to the scheduler's implementation. + */ + [[nodiscard]] virtual Task<> yield() = 0; +}; + +} // namespace mrc::coroutines diff --git a/cpp/mrc/include/mrc/coroutines/task_container.hpp b/cpp/mrc/include/mrc/coroutines/task_container.hpp new file mode 100644 index 000000000..20cab894e --- /dev/null +++ b/cpp/mrc/include/mrc/coroutines/task_container.hpp @@ -0,0 +1,173 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Original Source: https://github.com/jbaldwin/libcoro + * Original License: Apache License, Version 2.0; included below + */ + +/** + * Copyright 2021 Josh Baldwin + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "mrc/coroutines/task.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace mrc::coroutines { +class Scheduler; + +class TaskContainer +{ + public: + using task_position_t = std::list>>::iterator; + + /** + * @param e Tasks started in the container are scheduled onto this executor. For tasks created + * from a coro::io_scheduler, this would usually be that coro::io_scheduler instance. + */ + TaskContainer(std::shared_ptr e); + + TaskContainer(const TaskContainer&) = delete; + TaskContainer(TaskContainer&&) = delete; + auto operator=(const TaskContainer&) -> TaskContainer& = delete; + auto operator=(TaskContainer&&) -> TaskContainer& = delete; + + ~TaskContainer(); + + enum class GarbageCollectPolicy + { + /// Execute garbage collection. + yes, + /// Do not execute garbage collection. + no + }; + + /** + * Stores a user task and starts its execution on the container's thread pool. + * @param user_task The scheduled user's task to store in this task container and start its execution. + * @param cleanup Should the task container run garbage collect at the beginning of this store + * call? Calling at regular intervals will reduce memory usage of completed + * tasks and allow for the task container to re-use allocated space. + */ + auto start(Task&& user_task, GarbageCollectPolicy cleanup = GarbageCollectPolicy::yes) -> void; + + /** + * Garbage collects any tasks that are marked as deleted. This frees up space to be re-used by + * the task container for newly stored tasks. + * @return The number of tasks that were deleted. + */ + auto garbage_collect() -> std::size_t; + + /** + * @return The number of tasks that are awaiting deletion. + */ + auto delete_task_size() const -> std::size_t; + + /** + * @return True if there are no tasks awaiting deletion. + */ + auto delete_tasks_empty() const -> bool; + + /** + * @return The number of active tasks in the container. + */ + auto size() const -> std::size_t; + + /** + * @return True if there are no active tasks in the container. + */ + auto empty() const -> bool; + + /** + * @return The capacity of this task manager before it will need to grow in size. + */ + auto capacity() const -> std::size_t; + + /** + * Will continue to garbage collect and yield until all tasks are complete. This method can be + * co_await'ed to make it easier to wait for the task container to have all its tasks complete. + * + * This does not shut down the task container, but can be used when shutting down, or if your + * logic requires all the tasks contained within to complete, it is similar to coro::latch. + */ + auto garbage_collect_and_yield_until_empty() -> Task; + + private: + /** + * Special constructor for internal types to create their embeded task containers. + */ + TaskContainer(Scheduler& e); + + /** + * Interal GC call, expects the public function to lock. + */ + auto gc_internal() -> std::size_t; + + /** + * Encapsulate the users tasks in a cleanup task which marks itself for deletion upon + * completion. 
Simply co_await the users task until its completed and then mark the given + * position within the task manager as being deletable. The scheduler's next iteration + * in its event loop will then free that position up to be re-used. + * + * This function will also unconditionally catch all unhandled exceptions by the user's + * task to prevent the scheduler from throwing exceptions. + * @param user_task The user's task. + * @param pos The position where the task data will be stored in the task manager. + * @return The user's task wrapped in a self cleanup task. + */ + auto make_cleanup_task(Task user_task, task_position_t pos) -> Task; + + /// Mutex for safely mutating the task containers across threads, expected usage is within + /// thread pools for indeterminate lifetime requests. + std::mutex m_mutex{}; + /// The number of alive tasks. + std::atomic m_size{}; + /// Maintains the lifetime of the tasks until they are completed. + std::list>> m_tasks{}; + /// The set of tasks that have completed and need to be deleted. + std::vector m_tasks_to_delete{}; + /// The executor to schedule tasks that have just started. This is only used for lifetime management and may be + /// nullptr + std::shared_ptr m_scheduler_lifetime{nullptr}; + /// This is used internally since io_scheduler cannot pass itself in as a shared_ptr. + Scheduler* m_scheduler{nullptr}; + + friend Scheduler; +}; + +} // namespace mrc::coroutines diff --git a/cpp/mrc/include/mrc/edge/edge_holder.hpp b/cpp/mrc/include/mrc/edge/edge_holder.hpp index b3d801484..0262a7e71 100644 --- a/cpp/mrc/include/mrc/edge/edge_holder.hpp +++ b/cpp/mrc/include/mrc/edge/edge_holder.hpp @@ -152,7 +152,6 @@ class EdgeHolder void release_edge_connection() { - m_owned_edge_lifetime.reset(); m_connected_edge.reset(); } diff --git a/cpp/mrc/include/mrc/exceptions/exception_catcher.hpp b/cpp/mrc/include/mrc/exceptions/exception_catcher.hpp new file mode 100644 index 000000000..98c4a7d6d --- /dev/null +++ b/cpp/mrc/include/mrc/exceptions/exception_catcher.hpp @@ -0,0 +1,53 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#pragma once + +namespace mrc { + +/** + * @brief A utility for catching out-of-stack exceptions in a thread-safe manner such that they + * can be checked and throw from a parent thread. + */ +class ExceptionCatcher +{ + public: + /** + * @brief "catches" an exception to the catcher + */ + void push_exception(std::exception_ptr ex); + + /** + * @brief checks to see if any exceptions have been "caught" by the catcher. + */ + bool has_exception(); + + /** + * @brief rethrows the next exception (in the order in which it was "caught"). 
+ */ + void rethrow_next_exception(); + + private: + std::mutex m_mutex{}; + std::queue m_exceptions{}; +}; + +} // namespace mrc diff --git a/cpp/mrc/include/mrc/types.hpp b/cpp/mrc/include/mrc/types.hpp index 063e00831..4bbdc8171 100644 --- a/cpp/mrc/include/mrc/types.hpp +++ b/cpp/mrc/include/mrc/types.hpp @@ -24,33 +24,40 @@ namespace mrc { +// Suppress naming conventions in this file to allow matching std and boost libraries +// NOLINTBEGIN(readability-identifier-naming) + // Typedefs template -using Promise = userspace_threads::promise; // NOLINT(readability-identifier-naming) +using Promise = userspace_threads::promise; template -using Future = userspace_threads::future; // NOLINT(readability-identifier-naming) +using Future = userspace_threads::future; template -using SharedFuture = userspace_threads::shared_future; // NOLINT(readability-identifier-naming) +using SharedFuture = userspace_threads::shared_future; + +using Mutex = userspace_threads::mutex; -using Mutex = userspace_threads::mutex; // NOLINT(readability-identifier-naming) +using RecursiveMutex = userspace_threads::recursive_mutex; -using CondV = userspace_threads::cv; // NOLINT(readability-identifier-naming) +using CondV = userspace_threads::cv; -using MachineID = std::uint64_t; // NOLINT(readability-identifier-naming) -using InstanceID = std::uint64_t; // NOLINT(readability-identifier-naming) -using TagID = std::uint64_t; // NOLINT(readability-identifier-naming) +using MachineID = std::uint64_t; +using InstanceID = std::uint64_t; +using TagID = std::uint64_t; template -using Handle = std::shared_ptr; // NOLINT(readability-identifier-naming) +using Handle = std::shared_ptr; + +using SegmentID = std::uint16_t; +using SegmentRank = std::uint16_t; +using SegmentAddress = std::uint32_t; // id + rank -using SegmentID = std::uint16_t; // NOLINT(readability-identifier-naming) -using SegmentRank = std::uint16_t; // NOLINT(readability-identifier-naming) -using SegmentAddress = std::uint32_t; // NOLINT(readability-identifier-naming) // id + rank +using PortName = std::string; +using PortID = std::uint16_t; +using PortAddress = std::uint64_t; // id + rank + port -using PortName = std::string; // NOLINT(readability-identifier-naming) -using PortID = std::uint16_t; // NOLINT(readability-identifier-naming) -using PortAddress = std::uint64_t; // NOLINT(readability-identifier-naming) // id + rank + port +// NOLINTEND(readability-identifier-naming) } // namespace mrc diff --git a/cpp/mrc/src/internal/codable/decodable_storage_view.cpp b/cpp/mrc/src/internal/codable/decodable_storage_view.cpp index a4db24dac..5d29c7128 100644 --- a/cpp/mrc/src/internal/codable/decodable_storage_view.cpp +++ b/cpp/mrc/src/internal/codable/decodable_storage_view.cpp @@ -37,7 +37,6 @@ #include #include #include -#include namespace mrc::codable { diff --git a/cpp/mrc/src/internal/codable/storage_view.cpp b/cpp/mrc/src/internal/codable/storage_view.cpp index 3ae474ad7..834af06e1 100644 --- a/cpp/mrc/src/internal/codable/storage_view.cpp +++ b/cpp/mrc/src/internal/codable/storage_view.cpp @@ -19,7 +19,6 @@ #include -#include #include namespace mrc::codable { diff --git a/cpp/mrc/src/internal/control_plane/client.cpp b/cpp/mrc/src/internal/control_plane/client.cpp index 7a85adc2e..54f68a5da 100644 --- a/cpp/mrc/src/internal/control_plane/client.cpp +++ b/cpp/mrc/src/internal/control_plane/client.cpp @@ -19,8 +19,10 @@ #include "internal/control_plane/client/connections_manager.hpp" #include "internal/grpc/progress_engine.hpp" -#include 
"internal/grpc/promise_handler.hpp" +#include "internal/grpc/promise_handler.hpp" // for PromiseHandler +#include "internal/grpc/stream_writer.hpp" // for StreamWriter #include "internal/runnable/runnable_resources.hpp" +#include "internal/service.hpp" #include "internal/system/system.hpp" #include "mrc/channel/status.hpp" @@ -33,23 +35,42 @@ #include "mrc/runnable/launch_control.hpp" #include "mrc/runnable/launcher.hpp" #include "mrc/runnable/runner.hpp" +#include "mrc/types.hpp" +#include // for promise #include #include #include +#include #include namespace mrc::control_plane { +std::atomic_uint64_t AsyncEventStatus::s_request_id_counter; + +AsyncEventStatus::AsyncEventStatus() : m_request_id(++s_request_id_counter) {} + +size_t AsyncEventStatus::request_id() const +{ + return m_request_id; +} + +void AsyncEventStatus::set_future(Future future) +{ + m_future = std::move(future); +} + Client::Client(resources::PartitionResourceBase& base, std::shared_ptr cq) : resources::PartitionResourceBase(base), + Service("control_plane::Client"), m_cq(std::move(cq)), m_owns_progress_engine(false) {} Client::Client(resources::PartitionResourceBase& base) : resources::PartitionResourceBase(base), + Service("control_plane::Client"), m_cq(std::make_shared()), m_owns_progress_engine(true) {} @@ -73,13 +94,11 @@ void Client::do_service_start() if (m_owns_progress_engine) { CHECK(m_cq); - auto progress_engine = std::make_unique(m_cq); - auto progress_handler = std::make_unique(); + auto progress_engine = std::make_unique(m_cq); + m_progress_handler = std::make_unique(); - mrc::make_edge(*progress_engine, *progress_handler); + mrc::make_edge(*progress_engine, *m_progress_handler); - m_progress_handler = - runnable().launch_control().prepare_launcher(launch_options(), std::move(progress_handler))->ignition(); m_progress_engine = runnable().launch_control().prepare_launcher(launch_options(), std::move(progress_engine))->ignition(); } @@ -135,7 +154,6 @@ void Client::do_service_await_live() if (m_owns_progress_engine) { m_progress_engine->await_live(); - m_progress_handler->await_live(); } m_event_handler->await_live(); } @@ -150,7 +168,6 @@ void Client::do_service_await_join() { m_cq->Shutdown(); m_progress_engine->await_join(); - m_progress_handler->await_join(); } } @@ -161,10 +178,21 @@ void Client::do_handle_event(event_t&& event) // handle a subset of events directly on the event handler case protos::EventType::Response: { - auto* promise = reinterpret_cast*>(event.msg.tag()); - if (promise != nullptr) + auto event_tag = event.msg.tag(); + + if (event_tag != 0) { - promise->set_value(std::move(event.msg)); + // Lock to prevent multiple threads + std::unique_lock lock(m_mutex); + + // Find the promise associated with the event tag + auto promise = m_pending_events.extract(event_tag); + + // Unlock to allow other threads to continue as soon as possible + lock.unlock(); + + // Finally, set the value + promise.mapped().set_value(std::move(event.msg)); } } break; @@ -242,11 +270,11 @@ const mrc::runnable::LaunchOptions& Client::launch_options() const return m_launch_options; } -void Client::issue_event(const protos::EventType& event_type) +AsyncEventStatus Client::issue_event(const protos::EventType& event_type) { protos::Event event; event.set_event(event_type); - m_writer->await_write(std::move(event)); + return this->write_event(std::move(event), false); } void Client::request_update() @@ -260,4 +288,37 @@ void Client::request_update() // } } +AsyncEventStatus Client::write_event(protos::Event event, bool 
await_response) +{ + if (event.tag() != 0) + { + LOG(WARNING) << "event tag is set but this field should exclusively be used by the control plane client. " + "Clearing to avoid confusion"; + event.clear_tag(); + } + + AsyncEventStatus status; + + if (await_response) + { + // If we are supporting awaiting, create the promise now + Promise promise; + + // Set the future to the status + status.set_future(promise.get_future()); + + // Set the tag to the request ID to allow looking up the promise later + event.set_tag(status.request_id()); + + // Save the promise to the pending promises to be retrieved later + std::unique_lock lock(m_mutex); + + m_pending_events[status.request_id()] = std::move(promise); + } + + // Finally, write the event + m_writer->await_write(std::move(event)); + + return status; +} } // namespace mrc::control_plane diff --git a/cpp/mrc/src/internal/control_plane/client.hpp b/cpp/mrc/src/internal/control_plane/client.hpp index 0a07991a6..efda25db8 100644 --- a/cpp/mrc/src/internal/control_plane/client.hpp +++ b/cpp/mrc/src/internal/control_plane/client.hpp @@ -19,22 +19,22 @@ #include "internal/control_plane/client/instance.hpp" // IWYU pragma: keep #include "internal/grpc/client_streaming.hpp" -#include "internal/grpc/stream_writer.hpp" #include "internal/resources/partition_resources_base.hpp" #include "internal/service.hpp" #include "mrc/core/error.hpp" +#include "mrc/exceptions/runtime_error.hpp" #include "mrc/node/forward.hpp" #include "mrc/node/writable_entrypoint.hpp" #include "mrc/protos/architect.grpc.pb.h" #include "mrc/protos/architect.pb.h" #include "mrc/runnable/launch_options.hpp" #include "mrc/types.hpp" -#include "mrc/utils/macros.hpp" -#include #include +#include +#include // for size_t #include #include #include @@ -65,10 +65,56 @@ namespace mrc::runnable { class Runner; } // namespace mrc::runnable +namespace mrc::rpc { +class PromiseHandler; +template +struct StreamWriter; +} // namespace mrc::rpc + namespace mrc::control_plane { -template -class AsyncStatus; +class AsyncEventStatus +{ + public: + size_t request_id() const; + + template + Expected await_response() + { + if (!m_future.valid()) + { + throw exceptions::MrcRuntimeError( + "This AsyncEventStatus is not expecting a response or the response has already been awaited"); + } + + auto event = m_future.get(); + + if (event.has_error()) + { + return Error::create(event.error().message()); + } + + ResponseT response; + if (!event.message().UnpackTo(&response)) + { + throw Error::create("fatal error: unable to unpack message; server sent the wrong message type"); + } + + return response; + } + + private: + AsyncEventStatus(); + + void set_future(Future future); + + static std::atomic_size_t s_request_id_counter; + + size_t m_request_id; + Future m_future; + + friend class Client; +}; /** * @brief Primary Control Plane Client @@ -128,13 +174,13 @@ class Client final : public resources::PartitionResourceBase, public Service template Expected await_unary(const protos::EventType& event_type, RequestT&& request); - template - void async_unary(const protos::EventType& event_type, RequestT&& request, AsyncStatus& status); + template + AsyncEventStatus async_unary(const protos::EventType& event_type, RequestT&& request); template - void issue_event(const protos::EventType& event_type, MessageT&& message); + AsyncEventStatus issue_event(const protos::EventType& event_type, MessageT&& message); - void issue_event(const protos::EventType& event_type); + AsyncEventStatus issue_event(const protos::EventType& 
event_type); bool has_subscription_service(const std::string& name) const; @@ -150,6 +196,8 @@ class Client final : public resources::PartitionResourceBase, public Service void request_update(); private: + AsyncEventStatus write_event(protos::Event event, bool await_response = false); + void route_state_update(std::uint64_t tag, protos::StateUpdate&& update); void do_service_start() final; @@ -175,7 +223,7 @@ class Client final : public resources::PartitionResourceBase, public Service // if true, then the following runners should not be null // if false, then the following runners must be null const bool m_owns_progress_engine; - std::unique_ptr m_progress_handler; + std::unique_ptr m_progress_handler; std::unique_ptr m_progress_engine; std::unique_ptr m_event_handler; @@ -201,70 +249,39 @@ class Client final : public resources::PartitionResourceBase, public Service std::mutex m_mutex; + std::map> m_pending_events; + friend network::NetworkResources; }; // todo: create this object from the client which will own the stop_source // create this object with a stop_token associated with the client's stop_source -template -class AsyncStatus -{ - public: - AsyncStatus() = default; - - DELETE_COPYABILITY(AsyncStatus); - DELETE_MOVEABILITY(AsyncStatus); - - Expected await_response() - { - // todo(ryan): expand this into a wait_until with a deadline and a stop token - auto event = m_promise.get_future().get(); - - if (event.has_error()) - { - return Error::create(event.error().message()); - } - - ResponseT response; - if (!event.message().UnpackTo(&response)) - { - throw Error::create("fatal error: unable to unpack message; server sent the wrong message type"); - } - - return response; - } - - private: - Promise m_promise; - friend Client; -}; - template Expected Client::await_unary(const protos::EventType& event_type, RequestT&& request) { - AsyncStatus status; - async_unary(event_type, std::move(request), status); - return status.await_response(); + auto status = this->async_unary(event_type, std::move(request)); + return status.template await_response(); } -template -void Client::async_unary(const protos::EventType& event_type, RequestT&& request, AsyncStatus& status) +template +AsyncEventStatus Client::async_unary(const protos::EventType& event_type, RequestT&& request) { protos::Event event; event.set_event(event_type); - event.set_tag(reinterpret_cast(&status.m_promise)); CHECK(event.mutable_message()->PackFrom(request)); - m_writer->await_write(std::move(event)); + + return this->write_event(std::move(event), true); } template -void Client::issue_event(const protos::EventType& event_type, MessageT&& message) +AsyncEventStatus Client::issue_event(const protos::EventType& event_type, MessageT&& message) { protos::Event event; event.set_event(event_type); CHECK(event.mutable_message()->PackFrom(message)); - m_writer->await_write(std::move(event)); + + return this->write_event(std::move(event), false); } } // namespace mrc::control_plane diff --git a/cpp/mrc/src/internal/control_plane/client/connections_manager.cpp b/cpp/mrc/src/internal/control_plane/client/connections_manager.cpp index 76cc2477e..1cb40b953 100644 --- a/cpp/mrc/src/internal/control_plane/client/connections_manager.cpp +++ b/cpp/mrc/src/internal/control_plane/client/connections_manager.cpp @@ -31,7 +31,6 @@ #include #include -#include #include #include #include diff --git a/cpp/mrc/src/internal/control_plane/client/instance.cpp b/cpp/mrc/src/internal/control_plane/client/instance.cpp index 65c0040ad..5843c59a8 100644 --- 
a/cpp/mrc/src/internal/control_plane/client/instance.cpp +++ b/cpp/mrc/src/internal/control_plane/client/instance.cpp @@ -24,6 +24,7 @@ #include "internal/utils/contains.hpp" #include "mrc/edge/edge_builder.hpp" +#include "mrc/edge/edge_writable.hpp" #include "mrc/node/rx_sink.hpp" #include "mrc/protos/architect.pb.h" #include "mrc/runnable/launch_control.hpp" @@ -49,6 +50,7 @@ Instance::Instance(Client& client, resources::PartitionResourceBase& base, mrc::edge::IWritableAcceptor& update_channel) : resources::PartitionResourceBase(base), + Service("control_plane::client::Instance"), m_client(client), m_instance_id(instance_id) { diff --git a/cpp/mrc/src/internal/control_plane/client/state_manager.cpp b/cpp/mrc/src/internal/control_plane/client/state_manager.cpp index 1970e3574..e21fc6519 100644 --- a/cpp/mrc/src/internal/control_plane/client/state_manager.cpp +++ b/cpp/mrc/src/internal/control_plane/client/state_manager.cpp @@ -22,6 +22,7 @@ #include "mrc/core/error.hpp" #include "mrc/edge/edge_builder.hpp" +#include "mrc/edge/edge_writable.hpp" #include "mrc/node/rx_sink.hpp" #include "mrc/protos/architect.pb.h" #include "mrc/runnable/launch_control.hpp" diff --git a/cpp/mrc/src/internal/control_plane/client/subscription_service.cpp b/cpp/mrc/src/internal/control_plane/client/subscription_service.cpp index 50e6e2351..c190e3995 100644 --- a/cpp/mrc/src/internal/control_plane/client/subscription_service.cpp +++ b/cpp/mrc/src/internal/control_plane/client/subscription_service.cpp @@ -34,6 +34,7 @@ namespace mrc::control_plane::client { SubscriptionService::SubscriptionService(const std::string& service_name, Instance& instance) : + Service("control_plane::client::SubscriptionService"), m_service_name(std::move(service_name)), m_instance(instance) { diff --git a/cpp/mrc/src/internal/control_plane/server.cpp b/cpp/mrc/src/internal/control_plane/server.cpp index aa980aba8..afaee91c7 100644 --- a/cpp/mrc/src/internal/control_plane/server.cpp +++ b/cpp/mrc/src/internal/control_plane/server.cpp @@ -41,7 +41,6 @@ #include #include -#include #include #include #include @@ -86,9 +85,16 @@ static Expected<> unary_response(Server::event_t& event, Expected&& me return {}; } -Server::Server(runnable::RunnableResources& runnable) : m_runnable(runnable), m_server(m_runnable) {} +Server::Server(runnable::RunnableResources& runnable) : + Service("control_plane::Server"), + m_runnable(runnable), + m_server(m_runnable) +{} -Server::~Server() = default; +Server::~Server() +{ + Service::call_in_destructor(); +} void Server::do_service_start() { diff --git a/cpp/mrc/src/internal/control_plane/server.hpp b/cpp/mrc/src/internal/control_plane/server.hpp index d3d319502..6f7464de9 100644 --- a/cpp/mrc/src/internal/control_plane/server.hpp +++ b/cpp/mrc/src/internal/control_plane/server.hpp @@ -35,7 +35,7 @@ #include #include #include - +// IWYU pragma: no_include "internal/control_plane/server/subscription_manager.hpp" // IWYU pragma: no_forward_declare mrc::node::WritableEntrypoint namespace mrc::node { @@ -45,7 +45,7 @@ class Queue; namespace mrc::control_plane::server { class ClientInstance; -class SubscriptionService; +class SubscriptionService; // IWYU pragma: keep } // namespace mrc::control_plane::server namespace mrc::rpc { template diff --git a/cpp/mrc/src/internal/control_plane/server/connection_manager.cpp b/cpp/mrc/src/internal/control_plane/server/connection_manager.cpp index 617c3b4c6..2098f283b 100644 --- a/cpp/mrc/src/internal/control_plane/server/connection_manager.cpp +++ 
b/cpp/mrc/src/internal/control_plane/server/connection_manager.cpp @@ -27,7 +27,6 @@ #include #include -#include #include #include diff --git a/cpp/mrc/src/internal/data_plane/client.cpp b/cpp/mrc/src/internal/data_plane/client.cpp index 0f0a5ee4c..dc8709e43 100644 --- a/cpp/mrc/src/internal/data_plane/client.cpp +++ b/cpp/mrc/src/internal/data_plane/client.cpp @@ -25,7 +25,7 @@ #include "internal/memory/transient_pool.hpp" #include "internal/remote_descriptor/manager.hpp" #include "internal/runnable/runnable_resources.hpp" -#include "internal/ucx/common.hpp" +#include "internal/service.hpp" #include "internal/ucx/endpoint.hpp" #include "internal/ucx/ucx_resources.hpp" #include "internal/ucx/worker.hpp" @@ -53,7 +53,6 @@ #include #include #include -#include namespace mrc::data_plane { @@ -64,13 +63,17 @@ Client::Client(resources::PartitionResourceBase& base, control_plane::client::ConnectionsManager& connections_manager, memory::TransientPool& transient_pool) : resources::PartitionResourceBase(base), + Service("data_plane::Client"), m_ucx(ucx), m_connnection_manager(connections_manager), m_transient_pool(transient_pool), m_rd_channel(std::make_unique>()) {} -Client::~Client() = default; +Client::~Client() +{ + Service::call_in_destructor(); +} std::shared_ptr Client::endpoint_shared(const InstanceID& id) const { diff --git a/cpp/mrc/src/internal/data_plane/data_plane_resources.cpp b/cpp/mrc/src/internal/data_plane/data_plane_resources.cpp index 3ecf2d3f6..78cf64f7e 100644 --- a/cpp/mrc/src/internal/data_plane/data_plane_resources.cpp +++ b/cpp/mrc/src/internal/data_plane/data_plane_resources.cpp @@ -38,6 +38,7 @@ DataPlaneResources::DataPlaneResources(resources::PartitionResourceBase& base, const InstanceID& instance_id, control_plane::Client& control_plane_client) : resources::PartitionResourceBase(base), + Service("DataPlaneResources"), m_ucx(ucx), m_host(host), m_control_plane_client(control_plane_client), diff --git a/cpp/mrc/src/internal/data_plane/server.cpp b/cpp/mrc/src/internal/data_plane/server.cpp index a230ad934..d2f3974c9 100644 --- a/cpp/mrc/src/internal/data_plane/server.cpp +++ b/cpp/mrc/src/internal/data_plane/server.cpp @@ -36,7 +36,6 @@ #include "mrc/runnable/runner.hpp" #include "mrc/types.hpp" -#include #include #include #include @@ -47,7 +46,6 @@ #include #include #include -#include #include #include @@ -148,6 +146,7 @@ Server::Server(resources::PartitionResourceBase& provider, memory::TransientPool& transient_pool, InstanceID instance_id) : resources::PartitionResourceBase(provider), + Service("data_plane::Server"), m_ucx(ucx), m_host(host), m_instance_id(instance_id), diff --git a/cpp/mrc/src/internal/executor/executor_definition.cpp b/cpp/mrc/src/internal/executor/executor_definition.cpp index de630115d..a341f4434 100644 --- a/cpp/mrc/src/internal/executor/executor_definition.cpp +++ b/cpp/mrc/src/internal/executor/executor_definition.cpp @@ -76,6 +76,7 @@ static bool valid_pipeline(const pipeline::PipelineDefinition& pipeline) ExecutorDefinition::ExecutorDefinition(std::unique_ptr system) : SystemProvider(std::move(system)), + Service("ExecutorDefinition"), m_resources_manager(std::make_unique(*this)) {} @@ -128,7 +129,6 @@ void ExecutorDefinition::join() void ExecutorDefinition::do_service_start() { CHECK(m_pipeline_manager); - m_pipeline_manager->service_start(); pipeline::SegmentAddresses initial_segments; for (const auto& [id, segment] : m_pipeline_manager->pipeline().segments()) diff --git a/cpp/mrc/src/internal/grpc/client_streaming.hpp 
b/cpp/mrc/src/internal/grpc/client_streaming.hpp index 8ee6bd82e..ad2c82fb5 100644 --- a/cpp/mrc/src/internal/grpc/client_streaming.hpp +++ b/cpp/mrc/src/internal/grpc/client_streaming.hpp @@ -18,6 +18,7 @@ #pragma once #include "internal/grpc/progress_engine.hpp" +#include "internal/grpc/promise_handler.hpp" #include "internal/grpc/stream_writer.hpp" #include "internal/runnable/runnable_resources.hpp" #include "internal/service.hpp" @@ -152,6 +153,7 @@ class ClientStream : private Service, public std::enable_shared_from_this>(grpc::ClientContext* context)>; ClientStream(prepare_fn_t prepare_fn, runnable::RunnableResources& runnable) : + Service("rpc::ClientStream"), m_prepare_fn(prepare_fn), m_runnable(runnable), m_reader_source(std::make_unique>( @@ -195,10 +197,10 @@ class ClientStream : private Service, public std::enable_shared_from_this read; + auto* wrapper = new PromiseWrapper("Client::Read"); IncomingData data; - m_stream->Read(&data.msg, &read); - auto ok = read.get_future().get(); + m_stream->Read(&data.msg, wrapper); + auto ok = wrapper->get_future(); if (!ok) { m_write_channel.reset(); @@ -216,9 +218,9 @@ class ClientStream : private Service, public std::enable_shared_from_this promise; - m_stream->Write(request, &promise); - auto ok = promise.get_future().get(); + auto* wrapper = new PromiseWrapper("Client::Write"); + m_stream->Write(request, wrapper); + auto ok = wrapper->get_future(); if (!ok) { m_can_write = false; @@ -234,10 +236,20 @@ class ClientStream : private Service, public std::enable_shared_from_this writes_done; - m_stream->WritesDone(&writes_done); - writes_done.get_future().get(); - DVLOG(10) << "client issued writes done to server"; + { + auto* wrapper = new PromiseWrapper("Client::WritesDone"); + m_stream->WritesDone(wrapper); + wrapper->get_future(); + } + + { + // Now issue finish since this is OK at the client level + auto* wrapper = new PromiseWrapper("Client::Finish"); + m_stream->Finish(&m_status, wrapper); + wrapper->get_future(); + } + + // DVLOG(10) << "client issued writes done to server"; }; } @@ -284,9 +296,9 @@ class ClientStream : private Service, public std::enable_shared_from_this promise; - m_stream->StartCall(&promise); - auto ok = promise.get_future().get(); + auto* wrapper = new PromiseWrapper("Client::StartCall", false); + m_stream->StartCall(wrapper); + auto ok = wrapper->get_future(); if (!ok) { @@ -327,10 +339,6 @@ class ClientStream : private Service, public std::enable_shared_from_thisawait_join(); m_reader->await_join(); - - Promise finish; - m_stream->Finish(&m_status, &finish); - auto ok = finish.get_future().get(); } } diff --git a/cpp/mrc/src/internal/grpc/progress_engine.cpp b/cpp/mrc/src/internal/grpc/progress_engine.cpp index 68f157bf5..f540bf8b9 100644 --- a/cpp/mrc/src/internal/grpc/progress_engine.cpp +++ b/cpp/mrc/src/internal/grpc/progress_engine.cpp @@ -23,7 +23,6 @@ #include #include -#include #include #include @@ -40,6 +39,9 @@ void ProgressEngine::data_source(rxcpp::subscriber& s) while (s.is_subscribed()) { + event.ok = false; + event.tag = nullptr; + switch (m_cq->AsyncNext(&event.tag, &event.ok, gpr_time_0(GPR_CLOCK_REALTIME))) { case grpc::CompletionQueue::NextStatus::GOT_EVENT: { diff --git a/cpp/mrc/src/internal/grpc/progress_engine.hpp b/cpp/mrc/src/internal/grpc/progress_engine.hpp index 7bea6239e..23afa26f1 100644 --- a/cpp/mrc/src/internal/grpc/progress_engine.hpp +++ b/cpp/mrc/src/internal/grpc/progress_engine.hpp @@ -23,7 +23,6 @@ #include #include -#include namespace grpc { class CompletionQueue; diff 
--git a/cpp/mrc/src/internal/grpc/promise_handler.cpp b/cpp/mrc/src/internal/grpc/promise_handler.cpp new file mode 100644 index 000000000..444d69738 --- /dev/null +++ b/cpp/mrc/src/internal/grpc/promise_handler.cpp @@ -0,0 +1,67 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "internal/grpc/promise_handler.hpp" + +// MRC_CONCAT_STR is needed for debug builds, in CI IWYU is run with a release config +#include "mrc/utils/string_utils.hpp" // IWYU pragma: keep for MRC_CONCAT_STR + +#include // for future +#include // for COMPACT_GOOGLE_LOG_INFO + +#include +#include // for operator<<, basic_ostream +#include // for move + +namespace mrc::rpc { + +std::atomic_size_t PromiseWrapper::s_id_counter = 0; + +PromiseWrapper::PromiseWrapper(std::string method, bool in_runtime) : id(++s_id_counter), method(std::move(method)) +{ +#if (!defined(NDEBUG)) + this->prefix = MRC_CONCAT_STR("Promise[" << id << ", " << this << "](" << method << "): "); +#endif + VLOG(20) << this->to_string() << "#1 creating promise"; +} + +void PromiseWrapper::set_value(bool val) +{ + auto tmp_prefix = this->to_string(); + + VLOG(20) << tmp_prefix << "#2 setting promise to " << val; + this->promise.set_value(val); + VLOG(20) << tmp_prefix << "#3 setting promise to " << val << "... 
done";
+}
+
+bool PromiseWrapper::get_future()
+{
+    auto future = this->promise.get_future();
+
+    auto value = future.get();
+
+    VLOG(20) << this->to_string() << "#4 got future with value " << value;
+
+    return value;
+}
+
+std::string PromiseWrapper::to_string() const
+{
+    return this->prefix;
+}
+
+}  // namespace mrc::rpc
diff --git a/cpp/mrc/src/internal/grpc/promise_handler.hpp b/cpp/mrc/src/internal/grpc/promise_handler.hpp
index 437a22e69..0220eb685 100644
--- a/cpp/mrc/src/internal/grpc/promise_handler.hpp
+++ b/cpp/mrc/src/internal/grpc/promise_handler.hpp
@@ -20,21 +20,55 @@
 #include "internal/grpc/progress_engine.hpp"

 #include "mrc/node/generic_sink.hpp"
+#include "mrc/node/sink_properties.hpp"  // for SinkProperties, Status

-#include
+#include  // for promise
+
+#include  // for atomic_size_t
+#include  // for size_t
+#include

 namespace mrc::rpc {

+struct PromiseWrapper
+{
+    PromiseWrapper(std::string method, bool in_runtime = true);
+
+    ~PromiseWrapper() = default;
+
+    size_t id;
+    std::string method;
+    std::string prefix;
+    boost::fibers::promise<bool> promise;
+
+    void set_value(bool val);
+
+    bool get_future();
+
+    std::string to_string() const;
+
+  private:
+    static std::atomic_size_t s_id_counter;
+};
+
 /**
  * @brief MRC Sink to handle ProgressEvents which correspond to Promise tags
  */
-class PromiseHandler final : public mrc::node::GenericSink
+class PromiseHandler final : public mrc::node::GenericSinkComponent<ProgressEvent>
 {
-    void on_data(ProgressEvent&& event) final
+    mrc::channel::Status on_data(ProgressEvent&& event) final
     {
-        auto* promise = static_cast*>(event.tag);
+        auto* promise = static_cast<PromiseWrapper*>(event.tag);
+
         promise->set_value(event.ok);
-    }
+
+        delete promise;
+
+        return mrc::channel::Status::success;
+    };
+
+    void on_complete() override
+    {
+        SinkProperties<ProgressEvent>::release_edge_connection();
+    };
 };

 }  // namespace mrc::rpc
diff --git a/cpp/mrc/src/internal/grpc/server.cpp b/cpp/mrc/src/internal/grpc/server.cpp
index 9e0c0ecb4..e03293d15 100644
--- a/cpp/mrc/src/internal/grpc/server.cpp
+++ b/cpp/mrc/src/internal/grpc/server.cpp
@@ -18,7 +18,7 @@
 #include "internal/grpc/server.hpp"

 #include "internal/grpc/progress_engine.hpp"
-#include "internal/grpc/promise_handler.hpp"
+#include "internal/grpc/promise_handler.hpp"  // for PromiseHandler
 #include "internal/runnable/runnable_resources.hpp"

 #include "mrc/edge/edge_builder.hpp"
@@ -31,7 +31,7 @@

 namespace mrc::rpc {

-Server::Server(runnable::RunnableResources& runnable) : m_runnable(runnable)
+Server::Server(runnable::RunnableResources& runnable) : Service("rpc::Server"), m_runnable(runnable)
 {
     m_cq = m_builder.AddCompletionQueue();
     m_builder.AddListeningPort("0.0.0.0:13337", grpc::InsecureServerCredentials());
@@ -47,11 +47,10 @@ void Server::do_service_start()
     m_server = m_builder.BuildAndStart();

     auto progress_engine = std::make_unique<ProgressEngine>(m_cq);
-    auto event_handler   = std::make_unique();
-    mrc::make_edge(*progress_engine, *event_handler);
+    m_event_hander       = std::make_unique<PromiseHandler>();
+    mrc::make_edge(*progress_engine, *m_event_hander);

     m_progress_engine = m_runnable.launch_control().prepare_launcher(std::move(progress_engine))->ignition();
-    m_event_hander    = m_runnable.launch_control().prepare_launcher(std::move(event_handler))->ignition();
 }

 void Server::do_service_stop()
@@ -70,19 +69,17 @@ void Server::do_service_kill()

 void Server::do_service_await_live()
 {
-    if (m_progress_engine && m_event_hander)
+    if (m_progress_engine)
     {
         m_progress_engine->await_live();
-        m_event_hander->await_live();
     }
 }

 void Server::do_service_await_join()
 {
-    if
(m_progress_engine && m_event_hander) + if (m_progress_engine) { m_progress_engine->await_join(); - m_event_hander->await_join(); } } diff --git a/cpp/mrc/src/internal/grpc/server.hpp b/cpp/mrc/src/internal/grpc/server.hpp index cacd4602d..db9436d95 100644 --- a/cpp/mrc/src/internal/grpc/server.hpp +++ b/cpp/mrc/src/internal/grpc/server.hpp @@ -34,6 +34,10 @@ namespace mrc::runnable { class Runner; } // namespace mrc::runnable +namespace mrc::rpc { +class PromiseHandler; +} // namespace mrc::rpc + namespace mrc::rpc { class Server : public Service @@ -61,7 +65,7 @@ class Server : public Service std::shared_ptr m_cq; std::unique_ptr m_server; std::unique_ptr m_progress_engine; - std::unique_ptr m_event_hander; + std::unique_ptr m_event_hander; }; } // namespace mrc::rpc diff --git a/cpp/mrc/src/internal/grpc/server_streaming.hpp b/cpp/mrc/src/internal/grpc/server_streaming.hpp index 0d4da8b44..f2d50e1d4 100644 --- a/cpp/mrc/src/internal/grpc/server_streaming.hpp +++ b/cpp/mrc/src/internal/grpc/server_streaming.hpp @@ -18,6 +18,7 @@ #pragma once #include "internal/grpc/progress_engine.hpp" +#include "internal/grpc/promise_handler.hpp" #include "internal/grpc/stream_writer.hpp" #include "internal/runnable/runnable_resources.hpp" #include "internal/service.hpp" @@ -164,6 +165,7 @@ class ServerStream : private Service, public std::enable_shared_from_this* stream, void* tag)>; ServerStream(request_fn_t request_fn, runnable::RunnableResources& runnable) : + Service("rpc::ServerStream"), m_runnable(runnable), m_stream(std::make_unique>(&m_context)), m_reader_source(std::make_unique>( @@ -223,10 +225,11 @@ class ServerStream : private Service, public std::enable_shared_from_this read; + IncomingData data; - m_stream->Read(&data.msg, &read); - auto ok = read.get_future().get(); + auto* wrapper = new PromiseWrapper("Server::Read"); + m_stream->Read(&data.msg, wrapper); + auto ok = wrapper->get_future(); data.ok = ok; data.stream = writer(); s.on_next(std::move(data)); @@ -247,9 +250,9 @@ class ServerStream : private Service, public std::enable_shared_from_this promise; - m_stream->Write(request, &promise); - auto ok = promise.get_future().get(); + auto* wrapper = new PromiseWrapper("Server::Write"); + m_stream->Write(request, wrapper); + auto ok = wrapper->get_future(); if (!ok) { DVLOG(10) << "server failed to write to client; disabling writes and beginning shutdown"; @@ -272,10 +275,10 @@ class ServerStream : private Service, public std::enable_shared_from_this finish; - m_stream->Finish(*m_status, &finish); - auto ok = finish.get_future().get(); - DVLOG(10) << "server done with finish"; + auto* wrapper = new PromiseWrapper("Server::Finish"); + m_stream->Finish(*m_status, wrapper); + auto ok = wrapper->get_future(); + // DVLOG(10) << "server done with finish"; } } @@ -317,10 +320,9 @@ class ServerStream : private Service, public std::enable_shared_from_this promise; - m_init_fn(&promise); - auto ok = promise.get_future().get(); - + auto* wrapper = new PromiseWrapper("Server::m_init_fn"); + m_init_fn(wrapper); + auto ok = wrapper->get_future(); if (!ok) { DVLOG(10) << "server stream could not be initialized"; diff --git a/cpp/mrc/src/internal/memory/device_resources.cpp b/cpp/mrc/src/internal/memory/device_resources.cpp index 907eb1a4a..9ec0f5b04 100644 --- a/cpp/mrc/src/internal/memory/device_resources.cpp +++ b/cpp/mrc/src/internal/memory/device_resources.cpp @@ -35,16 +35,12 @@ #include "mrc/types.hpp" #include "mrc/utils/bytes_to_string.hpp" -#include #include -#include -#include #include 
#include #include #include -#include namespace mrc::memory { diff --git a/cpp/mrc/src/internal/memory/host_resources.cpp b/cpp/mrc/src/internal/memory/host_resources.cpp index c98c78618..42acfd32b 100644 --- a/cpp/mrc/src/internal/memory/host_resources.cpp +++ b/cpp/mrc/src/internal/memory/host_resources.cpp @@ -35,13 +35,10 @@ #include "mrc/types.hpp" #include "mrc/utils/bytes_to_string.hpp" -#include #include -#include #include #include -#include #include #include #include diff --git a/cpp/mrc/src/internal/network/network_resources.cpp b/cpp/mrc/src/internal/network/network_resources.cpp index b28a0d14f..ea078bee5 100644 --- a/cpp/mrc/src/internal/network/network_resources.cpp +++ b/cpp/mrc/src/internal/network/network_resources.cpp @@ -27,7 +27,6 @@ #include "mrc/core/task_queue.hpp" #include "mrc/types.hpp" -#include #include #include diff --git a/cpp/mrc/src/internal/pipeline/controller.cpp b/cpp/mrc/src/internal/pipeline/controller.cpp index 93946abbe..459817351 100644 --- a/cpp/mrc/src/internal/pipeline/controller.cpp +++ b/cpp/mrc/src/internal/pipeline/controller.cpp @@ -31,12 +31,10 @@ #include #include #include -#include #include #include #include #include -#include namespace mrc::pipeline { diff --git a/cpp/mrc/src/internal/pipeline/manager.cpp b/cpp/mrc/src/internal/pipeline/manager.cpp index 0487fdfb9..abec10d4a 100644 --- a/cpp/mrc/src/internal/pipeline/manager.cpp +++ b/cpp/mrc/src/internal/pipeline/manager.cpp @@ -34,16 +34,14 @@ #include #include -#include #include #include -#include #include -#include namespace mrc::pipeline { Manager::Manager(std::shared_ptr pipeline, resources::Manager& resources) : + Service("pipeline::Manager"), m_pipeline(std::move(pipeline)), m_resources(resources) { diff --git a/cpp/mrc/src/internal/pipeline/pipeline_instance.cpp b/cpp/mrc/src/internal/pipeline/pipeline_instance.cpp index 50e3abca1..dddd73a3c 100644 --- a/cpp/mrc/src/internal/pipeline/pipeline_instance.cpp +++ b/cpp/mrc/src/internal/pipeline/pipeline_instance.cpp @@ -24,6 +24,7 @@ #include "internal/runnable/runnable_resources.hpp" #include "internal/segment/segment_definition.hpp" #include "internal/segment/segment_instance.hpp" +#include "internal/service.hpp" #include "mrc/core/addresses.hpp" #include "mrc/core/task_queue.hpp" @@ -46,13 +47,17 @@ namespace mrc::pipeline { PipelineInstance::PipelineInstance(std::shared_ptr definition, resources::Manager& resources) : PipelineResources(resources), + Service("pipeline::PipelineInstance"), m_definition(std::move(definition)) { CHECK(m_definition); m_joinable_future = m_joinable_promise.get_future().share(); } -PipelineInstance::~PipelineInstance() = default; +PipelineInstance::~PipelineInstance() +{ + Service::call_in_destructor(); +} void PipelineInstance::update() { diff --git a/cpp/mrc/src/internal/pipeline/pipeline_instance.hpp b/cpp/mrc/src/internal/pipeline/pipeline_instance.hpp index d9f2489b8..7dc51e38e 100644 --- a/cpp/mrc/src/internal/pipeline/pipeline_instance.hpp +++ b/cpp/mrc/src/internal/pipeline/pipeline_instance.hpp @@ -25,12 +25,13 @@ #include #include #include +// IWYU pragma: no_include "internal/segment/segment_instance.hpp" namespace mrc::resources { class Manager; } // namespace mrc::resources namespace mrc::segment { -class SegmentInstance; +class SegmentInstance; // IWYU pragma: keep } // namespace mrc::segment namespace mrc::manifold { struct Interface; diff --git a/cpp/mrc/src/internal/pubsub/publisher_service.cpp b/cpp/mrc/src/internal/pubsub/publisher_service.cpp index 2ea517e44..5175e5315 100644 --- 
a/cpp/mrc/src/internal/pubsub/publisher_service.cpp +++ b/cpp/mrc/src/internal/pubsub/publisher_service.cpp @@ -39,10 +39,8 @@ #include #include -#include #include #include -#include namespace mrc::pubsub { diff --git a/cpp/mrc/src/internal/pubsub/subscriber_service.cpp b/cpp/mrc/src/internal/pubsub/subscriber_service.cpp index c53dac546..fba47135b 100644 --- a/cpp/mrc/src/internal/pubsub/subscriber_service.cpp +++ b/cpp/mrc/src/internal/pubsub/subscriber_service.cpp @@ -27,6 +27,7 @@ #include "internal/runtime/partition.hpp" #include "mrc/edge/edge_builder.hpp" +#include "mrc/edge/edge_writable.hpp" #include "mrc/node/operators/router.hpp" #include "mrc/node/rx_sink.hpp" #include "mrc/protos/codable.pb.h" @@ -41,7 +42,6 @@ #include #include #include -#include namespace mrc::pubsub { diff --git a/cpp/mrc/src/internal/remote_descriptor/manager.cpp b/cpp/mrc/src/internal/remote_descriptor/manager.cpp index fe73a61bc..b624b7c82 100644 --- a/cpp/mrc/src/internal/remote_descriptor/manager.cpp +++ b/cpp/mrc/src/internal/remote_descriptor/manager.cpp @@ -55,9 +55,7 @@ #include #include #include -#include #include -#include namespace mrc::remote_descriptor { @@ -86,6 +84,7 @@ ucs_status_t active_message_callback(void* arg, } // namespace Manager::Manager(const InstanceID& instance_id, resources::PartitionResources& resources) : + Service("remote_descriptor::Manager"), m_instance_id(instance_id), m_resources(resources) { diff --git a/cpp/mrc/src/internal/resources/manager.cpp b/cpp/mrc/src/internal/resources/manager.cpp index b47334c04..fab210109 100644 --- a/cpp/mrc/src/internal/resources/manager.cpp +++ b/cpp/mrc/src/internal/resources/manager.cpp @@ -26,6 +26,7 @@ #include "internal/network/network_resources.hpp" #include "internal/resources/partition_resources_base.hpp" #include "internal/runnable/runnable_resources.hpp" +#include "internal/system/device_partition.hpp" #include "internal/system/engine_factory_cpu_sets.hpp" #include "internal/system/host_partition.hpp" #include "internal/system/partition.hpp" @@ -45,6 +46,7 @@ #include #include +#include #include #include #include @@ -54,16 +56,18 @@ namespace mrc::resources { +std::atomic_size_t Manager::s_id_counter = 0; thread_local Manager* Manager::m_thread_resources{nullptr}; thread_local PartitionResources* Manager::m_thread_partition{nullptr}; Manager::Manager(const system::SystemProvider& system) : SystemProvider(system), + m_runtime_id(++s_id_counter), m_threading(std::make_unique(system)) { const auto& partitions = this->system().partitions().flattened(); const auto& host_partitions = this->system().partitions().host_partitions(); - const bool network_enabled = !this->system().options().architect_url().empty(); + bool network_enabled = !this->system().options().architect_url().empty(); // construct the runnable resources on each host_partition - launch control and main for (std::size_t i = 0; i < host_partitions.size(); ++i) @@ -197,6 +201,11 @@ Manager::~Manager() m_network.clear(); } +std::size_t Manager::runtime_id() const +{ + return m_runtime_id; +} + std::size_t Manager::partition_count() const { return system().partitions().flattened().size(); diff --git a/cpp/mrc/src/internal/resources/manager.hpp b/cpp/mrc/src/internal/resources/manager.hpp index a823bbe27..55e4af014 100644 --- a/cpp/mrc/src/internal/resources/manager.hpp +++ b/cpp/mrc/src/internal/resources/manager.hpp @@ -24,25 +24,29 @@ #include "mrc/types.hpp" +#include #include #include #include #include +// IWYU pragma: no_include "internal/memory/device_resources.hpp" 
+// IWYU pragma: no_include "internal/network/network_resources.hpp" +// IWYU pragma: no_include "internal/ucx/ucx_resources.hpp" namespace mrc::network { -class NetworkResources; +class NetworkResources; // IWYU pragma: keep } // namespace mrc::network namespace mrc::control_plane { class ControlPlaneResources; } // namespace mrc::control_plane namespace mrc::memory { -class DeviceResources; +class DeviceResources; // IWYU pragma: keep } // namespace mrc::memory namespace mrc::system { class ThreadingResources; } // namespace mrc::system namespace mrc::ucx { -class UcxResources; +class UcxResources; // IWYU pragma: keep } // namespace mrc::ucx namespace mrc::runtime { class Runtime; @@ -57,6 +61,8 @@ class Manager final : public system::SystemProvider // Manager(std::unique_ptr resources); ~Manager() override; + std::size_t runtime_id() const; + static Manager& get_resources(); static PartitionResources& get_partition(); @@ -68,6 +74,8 @@ class Manager final : public system::SystemProvider private: Future shutdown(); + const size_t m_runtime_id; // unique id for this runtime + const std::unique_ptr m_threading; std::vector m_runnable; // one per host partition std::vector> m_ucx; // one per flattened partition if network is enabled @@ -82,6 +90,7 @@ class Manager final : public system::SystemProvider // which must be destroyed before all other std::vector> m_network; // one per flattened partition + static std::atomic_size_t s_id_counter; static thread_local PartitionResources* m_thread_partition; static thread_local Manager* m_thread_resources; diff --git a/cpp/mrc/src/internal/runnable/fiber_engine.cpp b/cpp/mrc/src/internal/runnable/fiber_engine.cpp index 10dc1eb51..f208d5791 100644 --- a/cpp/mrc/src/internal/runnable/fiber_engine.cpp +++ b/cpp/mrc/src/internal/runnable/fiber_engine.cpp @@ -21,8 +21,6 @@ #include "mrc/runnable/types.hpp" #include "mrc/types.hpp" -#include - #include namespace mrc::runnable { diff --git a/cpp/mrc/src/internal/runnable/fiber_engines.cpp b/cpp/mrc/src/internal/runnable/fiber_engines.cpp index 87dfa5556..ed720803c 100644 --- a/cpp/mrc/src/internal/runnable/fiber_engines.cpp +++ b/cpp/mrc/src/internal/runnable/fiber_engines.cpp @@ -27,7 +27,6 @@ #include #include -#include #include namespace mrc::runnable { diff --git a/cpp/mrc/src/internal/runnable/runnable_resources.cpp b/cpp/mrc/src/internal/runnable/runnable_resources.cpp index 4fa98f1ce..9930c7778 100644 --- a/cpp/mrc/src/internal/runnable/runnable_resources.cpp +++ b/cpp/mrc/src/internal/runnable/runnable_resources.cpp @@ -27,7 +27,6 @@ #include "mrc/runnable/types.hpp" #include "mrc/types.hpp" -#include #include #include diff --git a/cpp/mrc/src/internal/runnable/thread_engine.cpp b/cpp/mrc/src/internal/runnable/thread_engine.cpp index fb18c3b60..b22edd730 100644 --- a/cpp/mrc/src/internal/runnable/thread_engine.cpp +++ b/cpp/mrc/src/internal/runnable/thread_engine.cpp @@ -24,7 +24,6 @@ #include "mrc/runnable/types.hpp" #include "mrc/types.hpp" -#include #include #include diff --git a/cpp/mrc/src/internal/runnable/thread_engines.cpp b/cpp/mrc/src/internal/runnable/thread_engines.cpp index 23f9c430a..92ea1a65e 100644 --- a/cpp/mrc/src/internal/runnable/thread_engines.cpp +++ b/cpp/mrc/src/internal/runnable/thread_engines.cpp @@ -28,7 +28,6 @@ #include #include #include -#include #include namespace mrc::runnable { diff --git a/cpp/mrc/src/internal/segment/builder_definition.cpp b/cpp/mrc/src/internal/segment/builder_definition.cpp index e631c3f1e..b11614328 100644 --- 
a/cpp/mrc/src/internal/segment/builder_definition.cpp +++ b/cpp/mrc/src/internal/segment/builder_definition.cpp @@ -28,9 +28,9 @@ #include "mrc/modules/properties/persistent.hpp" // IWYU pragma: keep #include "mrc/modules/segment_modules.hpp" #include "mrc/node/port_registry.hpp" +#include "mrc/runnable/launchable.hpp" #include "mrc/segment/egress_port.hpp" // IWYU pragma: keep #include "mrc/segment/ingress_port.hpp" // IWYU pragma: keep -#include "mrc/segment/initializers.hpp" #include "mrc/segment/object.hpp" #include "mrc/types.hpp" diff --git a/cpp/mrc/src/internal/segment/segment_instance.cpp b/cpp/mrc/src/internal/segment/segment_instance.cpp index 871b7a2ca..53f66b804 100644 --- a/cpp/mrc/src/internal/segment/segment_instance.cpp +++ b/cpp/mrc/src/internal/segment/segment_instance.cpp @@ -36,7 +36,6 @@ #include "mrc/segment/utils.hpp" #include "mrc/types.hpp" -#include #include #include @@ -54,6 +53,7 @@ SegmentInstance::SegmentInstance(std::shared_ptr defini SegmentRank rank, pipeline::PipelineResources& resources, std::size_t partition_id) : + Service("segment::SegmentInstance"), m_name(definition->name()), m_id(definition->id()), m_rank(rank), @@ -78,7 +78,10 @@ SegmentInstance::SegmentInstance(std::shared_ptr defini .get(); } -SegmentInstance::~SegmentInstance() = default; +SegmentInstance::~SegmentInstance() +{ + Service::call_in_destructor(); +} const std::string& SegmentInstance::name() const { diff --git a/cpp/mrc/src/internal/service.cpp b/cpp/mrc/src/internal/service.cpp index 01c51b014..3ea3f6b90 100644 --- a/cpp/mrc/src/internal/service.cpp +++ b/cpp/mrc/src/internal/service.cpp @@ -17,131 +17,293 @@ #include "internal/service.hpp" +#include "mrc/core/utils.hpp" +#include "mrc/exceptions/runtime_error.hpp" +#include "mrc/utils/string_utils.hpp" + #include -#include +#include +#include // for function +#include +#include // for operator<<, basic_ostream #include namespace mrc { +Service::Service(std::string service_name) : m_service_name(std::move(service_name)) {} + Service::~Service() { + if (!m_call_in_destructor_called) + { + LOG(ERROR) << "Must call Service::call_in_destructor to ensure service is cleaned up before being " + "destroyed"; + } + auto state = this->state(); CHECK(state == ServiceState::Initialized || state == ServiceState::Completed); } +const std::string& Service::service_name() const +{ + return m_service_name; +} + +bool Service::is_service_startable() const +{ + std::lock_guard lock(m_mutex); + return (m_state == ServiceState::Initialized); +} + +bool Service::is_running() const +{ + std::lock_guard lock(m_mutex); + return (m_state > ServiceState::Initialized && m_state < ServiceState::Completed); +} + +const ServiceState& Service::state() const +{ + std::lock_guard lock(m_mutex); + return m_state; +} + void Service::service_start() { - if (forward_state(ServiceState::Running)) + std::unique_lock lock(m_mutex); + + if (!this->is_service_startable()) { - do_service_start(); + throw exceptions::MrcRuntimeError(MRC_CONCAT_STR(this->debug_prefix() << " Service has already been started")); + } + + if (advance_state(ServiceState::Starting)) + { + // Unlock the mutex before calling start to avoid a deadlock + lock.unlock(); + + try + { + this->do_service_start(); + + // Use ensure_state here in case the service itself called stop or kill + this->ensure_state(ServiceState::Running); + } catch (...) 
+ { + // On error, set this to completed and rethrow the error to allow for cleanup + this->advance_state(ServiceState::Completed); + + throw; + } } } void Service::service_await_live() { - do_service_await_live(); + { + std::unique_lock lock(m_mutex); + + if (this->is_service_startable()) + { + throw exceptions::MrcRuntimeError(MRC_CONCAT_STR(this->debug_prefix() << " Service must be started before " + "awaiting live")); + } + + // Check if this is our first call to service_await_live + if (!m_service_await_live_called) + { + // Prevent reentry + m_service_await_live_called = true; + + // We now create a promise and a future to track the completion of this function + Promise live_promise; + + m_live_future = live_promise.get_future(); + + // Unlock the mutex before calling await to avoid a deadlock + lock.unlock(); + + try + { + // Now call the await live (this can throw!) + this->do_service_await_live(); + + // Set the value only if there was not an exception + live_promise.set_value(); + + } catch (...) + { + // Await live must have thrown, set the exception in the promise (it will be retrieved later) + live_promise.set_exception(std::current_exception()); + } + } + } + + // Wait for the future to be returned. This will rethrow any exception thrown in do_service_await_live + m_live_future.get(); } void Service::service_stop() { - bool execute = false; + std::unique_lock lock(m_mutex); + + if (this->is_service_startable()) { - std::lock_guard lock(m_mutex); - if (m_state < ServiceState::Stopping) - { - execute = (m_state < ServiceState::Stopping); - m_state = ServiceState::Stopping; - } + throw exceptions::MrcRuntimeError(MRC_CONCAT_STR(this->debug_prefix() << " Service must be started before " + "stopping")); } - if (execute) + + // Ensure we are at least in the stopping state. If so, execute the stop call + if (this->ensure_state(ServiceState::Stopping)) { - do_service_stop(); + lock.unlock(); + + this->do_service_stop(); } } void Service::service_kill() { - bool execute = false; + std::unique_lock lock(m_mutex); + + if (this->is_service_startable()) { - std::lock_guard lock(m_mutex); - if (m_state < ServiceState::Killing) - { - execute = (m_state < ServiceState::Killing); - m_state = ServiceState::Killing; - } + throw exceptions::MrcRuntimeError(MRC_CONCAT_STR(this->debug_prefix() << " Service must be started before " + "killing")); } - if (execute) + + // Ensure we are at least in the killing state. 
If so, execute the kill call + if (this->ensure_state(ServiceState::Killing)) { - do_service_kill(); + lock.unlock(); + + this->do_service_kill(); } } void Service::service_await_join() { - bool execute = false; { - std::lock_guard lock(m_mutex); - if (m_state < ServiceState::Completed) + std::unique_lock lock(m_mutex); + + if (this->is_service_startable()) { - execute = (m_state < ServiceState::Completed); - m_state = ServiceState::Awaiting; + throw exceptions::MrcRuntimeError(MRC_CONCAT_STR(this->debug_prefix() << " Service must be started before " + "awaiting join")); } - } - if (execute) - { - do_service_await_join(); - forward_state(ServiceState::Completed); - } -} -const ServiceState& Service::state() const -{ - std::lock_guard lock(m_mutex); - return m_state; -} + // Check if this is our first call to service_await_join + if (!m_service_await_join_called) + { + // Prevent reentry + m_service_await_join_called = true; -bool Service::is_service_startable() const -{ - std::lock_guard lock(m_mutex); - return (m_state == ServiceState::Initialized); + // We now create a promise and a future to track the completion of the service + Promise completed_promise; + + m_completed_future = completed_promise.get_future(); + + // Unlock the mutex before calling await join to avoid a deadlock + lock.unlock(); + + try + { + Unwinder ensure_completed_set([this]() { + // Always set the state to completed before releasing the future + this->advance_state(ServiceState::Completed); + }); + + // Now call the await join (this can throw!) + this->do_service_await_join(); + + // Set the value only if there was not an exception + completed_promise.set_value(); + + } catch (const std::exception& ex) + { + LOG(ERROR) << this->debug_prefix() << " caught exception in service_await_join: " << ex.what(); + // Join must have thrown, set the exception in the promise (it will be retrieved later) + completed_promise.set_exception(std::current_exception()); + } + } + } + + // Wait for the completed future to be returned. 
This will rethrow any exception thrown in do_service_await_join + m_completed_future.get(); } -bool Service::forward_state(ServiceState new_state) +std::string Service::debug_prefix() const { - std::lock_guard lock(m_mutex); - CHECK(m_state <= new_state) << m_description - << ": invalid ServiceState requested; ServiceState is only allowed to advance"; - if (m_state < new_state) - { - m_state = new_state; - return true; - } - return false; + return MRC_CONCAT_STR("Service[" << m_service_name << "]:"); } void Service::call_in_destructor() { + // Guarantee that we set the flag that this was called + Unwinder ensure_flag([this]() { + m_call_in_destructor_called = true; + }); + auto state = this->state(); if (state > ServiceState::Initialized) { if (state == ServiceState::Running) { - LOG(ERROR) << m_description << ": service was not stopped/killed before being destructed; issuing kill"; - service_kill(); + LOG(ERROR) << this->debug_prefix() + << ": service was not stopped/killed before being destructed; issuing kill"; + this->service_kill(); } if (state != ServiceState::Completed) { - LOG(ERROR) << m_description << ": service was not joined before being destructed; issuing join"; - service_await_join(); + LOG(ERROR) << this->debug_prefix() << ": service was not joined before being destructed; issuing join"; + this->service_await_join(); } } } void Service::service_set_description(std::string description) { - m_description = std::move(description); + m_service_name = std::move(description); +} + +bool Service::advance_state(ServiceState new_state, bool assert_state_change) +{ + std::lock_guard lock(m_mutex); + + // State needs to always be moving forward or staying the same + CHECK_GE(new_state, m_state) << this->debug_prefix() + << " invalid ServiceState requested; ServiceState is only allowed to advance. " + "Current: " + << m_state << ", Requested: " << new_state; + + if (m_state < new_state) + { + DVLOG(20) << this->debug_prefix() << " advancing state. 
From: " << m_state << " to " << new_state; + + m_state = new_state; + + return true; + } + + CHECK(!assert_state_change) << this->debug_prefix() + << " invalid ServiceState requested; ServiceState was required to move forward " + "but the state was already set to " + << m_state; + + return false; +} + +bool Service::ensure_state(ServiceState desired_state) +{ + std::lock_guard lock(m_mutex); + + if (desired_state > m_state) + { + return advance_state(desired_state); + } + + return false; } } // namespace mrc diff --git a/cpp/mrc/src/internal/service.hpp b/cpp/mrc/src/internal/service.hpp index f707321e2..d24e059c5 100644 --- a/cpp/mrc/src/internal/service.hpp +++ b/cpp/mrc/src/internal/service.hpp @@ -17,7 +17,10 @@ #pragma once -#include +#include "mrc/types.hpp" + +#include // for ostream +#include // for logic_error #include namespace mrc { @@ -25,44 +28,90 @@ namespace mrc { enum class ServiceState { Initialized, + Starting, Running, - Awaiting, Stopping, Killing, Completed, }; -// struct IService -// { -// virtual ~IService() = default; +/** + * @brief Converts a `AsyncServiceState` enum to a string + * + * @param f + * @return std::string + */ +inline std::string servicestate_to_str(const ServiceState& s) +{ + switch (s) + { + case ServiceState::Initialized: + return "Initialized"; + case ServiceState::Starting: + return "Starting"; + case ServiceState::Running: + return "Running"; + case ServiceState::Stopping: + return "Stopping"; + case ServiceState::Killing: + return "Killing"; + case ServiceState::Completed: + return "Completed"; + default: + throw std::logic_error("Unsupported ServiceState enum. Was a new value added recently?"); + } +} -// virtual void service_start() = 0; -// virtual void service_await_live() = 0; -// virtual void service_stop() = 0; -// virtual void service_kill() = 0; -// virtual void service_await_join() = 0; -// }; +/** + * @brief Stream operator for `AsyncServiceState` + * + * @param os + * @param f + * @return std::ostream& + */ +static inline std::ostream& operator<<(std::ostream& os, const ServiceState& f) +{ + os << servicestate_to_str(f); + return os; +} -class Service // : public IService +class Service { public: virtual ~Service(); + const std::string& service_name() const; + + bool is_service_startable() const; + + bool is_running() const; + + const ServiceState& state() const; + void service_start(); void service_await_live(); void service_stop(); void service_kill(); void service_await_join(); - bool is_service_startable() const; - const ServiceState& state() const; - protected: + Service(std::string service_name); + + // Prefix to use for debug messages. Contains useful information about the service + std::string debug_prefix() const; + void call_in_destructor(); void service_set_description(std::string description); private: - bool forward_state(ServiceState new_state); + // Advances the state. New state value must be greater than or equal to current state. Using a value less than the + // current state will generate an error. Use assert_forward = false to require that the state advances. Normally, + // same states are fine + bool advance_state(ServiceState new_state, bool assert_state_change = false); + + // Ensures the state is at least the current value or higher. 
Does not change the state if the value is less than or + // equal to the current state + bool ensure_state(ServiceState desired_state); virtual void do_service_start() = 0; virtual void do_service_await_live() = 0; @@ -71,8 +120,21 @@ class Service // : public IService virtual void do_service_await_join() = 0; ServiceState m_state{ServiceState::Initialized}; - std::string m_description{"mrc::service"}; - mutable std::mutex m_mutex; + std::string m_service_name{"mrc::Service"}; + + // This future is set in `service_await_live` and is used to wait for the service to be live. We use a future + // here in case it is called multiple times, so that all callers will be released when the service is live. + SharedFuture m_live_future; + + // This future is set in `service_await_join` and is used to wait for the service to complete. We use a future here + // in case join is called multiple times, so that all callers will be released when the service completes. + SharedFuture m_completed_future; + + bool m_service_await_live_called{false}; + bool m_service_await_join_called{false}; + bool m_call_in_destructor_called{false}; + + mutable RecursiveMutex m_mutex; }; } // namespace mrc diff --git a/cpp/mrc/src/internal/system/fiber_manager.cpp b/cpp/mrc/src/internal/system/fiber_manager.cpp index 2eec52f12..5a73dcab7 100644 --- a/cpp/mrc/src/internal/system/fiber_manager.cpp +++ b/cpp/mrc/src/internal/system/fiber_manager.cpp @@ -26,9 +26,11 @@ #include "mrc/exceptions/runtime_error.hpp" #include "mrc/options/fiber_pool.hpp" #include "mrc/options/options.hpp" +#include "mrc/utils/string_utils.hpp" #include #include +#include namespace mrc::system { @@ -44,7 +46,7 @@ FiberManager::FiberManager(const ThreadingResources& resources) : m_cpu_set(reso topology.cpu_set().for_each_bit([&](std::int32_t idx, std::int32_t cpu_id) { DVLOG(10) << "initializing fiber queue " << idx << " of " << cpu_count << " on cpu_id " << cpu_id; - m_queues[cpu_id] = std::make_unique(resources, cpu_id); + m_queues[cpu_id] = std::make_unique(resources, cpu_id, MRC_CONCAT_STR("fibq[" << idx << "]")); }); } diff --git a/cpp/mrc/src/internal/system/fiber_task_queue.cpp b/cpp/mrc/src/internal/system/fiber_task_queue.cpp index 709be264e..5af806d21 100644 --- a/cpp/mrc/src/internal/system/fiber_task_queue.cpp +++ b/cpp/mrc/src/internal/system/fiber_task_queue.cpp @@ -28,7 +28,6 @@ #include #include #include -#include #include #include @@ -39,12 +38,16 @@ namespace mrc::system { -FiberTaskQueue::FiberTaskQueue(const ThreadingResources& resources, CpuSet cpu_affinity, std::size_t channel_size) : +FiberTaskQueue::FiberTaskQueue(const ThreadingResources& resources, + CpuSet cpu_affinity, + std::string thread_name, + std::size_t channel_size) : m_queue(channel_size), m_cpu_affinity(std::move(cpu_affinity)), - m_thread(resources.make_thread("fiberq", m_cpu_affinity, [this] { + m_thread(resources.make_thread(std::move(thread_name), m_cpu_affinity, [this] { main(); })) + { DVLOG(10) << "awaiting fiber task queue worker thread running on cpus " << m_cpu_affinity; enqueue([] {}).get(); @@ -106,7 +109,7 @@ void FiberTaskQueue::launch(task_pkg_t&& pkg) const boost::fibers::fiber fiber(std::move(pkg.first)); auto& props(fiber.properties()); props.set_priority(pkg.second.priority); - DVLOG(10) << *this << ": created fiber " << fiber.get_id() << " with priority " << pkg.second.priority; + DVLOG(20) << *this << ": created fiber " << fiber.get_id() << " with priority " << pkg.second.priority; fiber.detach(); } diff --git 
a/cpp/mrc/src/internal/system/fiber_task_queue.hpp b/cpp/mrc/src/internal/system/fiber_task_queue.hpp index c58c8190b..ccd7499b5 100644 --- a/cpp/mrc/src/internal/system/fiber_task_queue.hpp +++ b/cpp/mrc/src/internal/system/fiber_task_queue.hpp @@ -27,6 +27,7 @@ #include #include +#include #include namespace mrc::system { @@ -36,7 +37,10 @@ class ThreadingResources; class FiberTaskQueue final : public core::FiberTaskQueue { public: - FiberTaskQueue(const ThreadingResources& resources, CpuSet cpu_affinity, std::size_t channel_size = 64); + FiberTaskQueue(const ThreadingResources& resources, + CpuSet cpu_affinity, + std::string thread_name, + std::size_t channel_size = 64); ~FiberTaskQueue() final; DELETE_COPYABILITY(FiberTaskQueue); diff --git a/cpp/mrc/src/internal/system/host_partition_provider.cpp b/cpp/mrc/src/internal/system/host_partition_provider.cpp index 953833435..42a579547 100644 --- a/cpp/mrc/src/internal/system/host_partition_provider.cpp +++ b/cpp/mrc/src/internal/system/host_partition_provider.cpp @@ -17,6 +17,7 @@ #include "internal/system/host_partition_provider.hpp" +#include "internal/system/host_partition.hpp" #include "internal/system/partitions.hpp" #include "internal/system/system.hpp" @@ -25,7 +26,6 @@ #include namespace mrc::system { -class HostPartition; HostPartitionProvider::HostPartitionProvider(const SystemProvider& _system, std::size_t _host_partition_id) : SystemProvider(_system), diff --git a/cpp/mrc/src/internal/system/partition_provider.cpp b/cpp/mrc/src/internal/system/partition_provider.cpp index 33feb2c77..7597da9cc 100644 --- a/cpp/mrc/src/internal/system/partition_provider.cpp +++ b/cpp/mrc/src/internal/system/partition_provider.cpp @@ -17,6 +17,7 @@ #include "internal/system/partition_provider.hpp" +#include "internal/system/partition.hpp" #include "internal/system/partitions.hpp" #include "internal/system/system.hpp" diff --git a/cpp/mrc/src/internal/system/thread.cpp b/cpp/mrc/src/internal/system/thread.cpp index 413e86f6c..04345006f 100644 --- a/cpp/mrc/src/internal/system/thread.cpp +++ b/cpp/mrc/src/internal/system/thread.cpp @@ -90,13 +90,13 @@ void ThreadResources::initialize_thread(const std::string& desc, const CpuSet& c { std::stringstream ss; ss << "cpu_id: " << cpu_affinity.first(); - affinity = ss.str(); + affinity = MRC_CONCAT_STR("cpu[" << cpu_affinity.str() << "]"); } else { std::stringstream ss; ss << "cpus: " << cpu_affinity.str(); - affinity = ss.str(); + affinity = MRC_CONCAT_STR("cpu[" << cpu_affinity.str() << "]"); auto numa_set = topology.numaset_for_cpuset(cpu_affinity); if (numa_set.weight() != 1) { @@ -110,13 +110,13 @@ void ThreadResources::initialize_thread(const std::string& desc, const CpuSet& c DVLOG(10) << "tid: " << std::this_thread::get_id() << "; setting cpu affinity to " << affinity; auto rc = hwloc_set_cpubind(topology.handle(), &cpu_affinity.bitmap(), HWLOC_CPUBIND_THREAD); CHECK_NE(rc, -1); - set_current_thread_name(MRC_CONCAT_STR("[" << desc << "; " << affinity << "]")); + set_current_thread_name(MRC_CONCAT_STR(desc << ";" << affinity)); } else { DVLOG(10) << "thread_binding is disabled; tid: " << std::this_thread::get_id() << " will use the affinity of caller"; - set_current_thread_name(MRC_CONCAT_STR("[" << desc << "; tid:" << std::this_thread::get_id() << "]")); + set_current_thread_name(MRC_CONCAT_STR(desc << ";tid[" << std::this_thread::get_id() << "]")); } // todo(ryan) - enable thread/memory binding should be a system option, not specifically a fiber_pool option diff --git 
a/cpp/mrc/src/internal/system/threading_resources.cpp b/cpp/mrc/src/internal/system/threading_resources.cpp index 27001092a..1e0f8c16b 100644 --- a/cpp/mrc/src/internal/system/threading_resources.cpp +++ b/cpp/mrc/src/internal/system/threading_resources.cpp @@ -19,9 +19,10 @@ #include "internal/system/fiber_manager.hpp" +#include "mrc/types.hpp" + #include -#include #include namespace mrc::system { diff --git a/cpp/mrc/src/internal/ucx/receive_manager.cpp b/cpp/mrc/src/internal/ucx/receive_manager.cpp index 2796bf84e..70cda928a 100644 --- a/cpp/mrc/src/internal/ucx/receive_manager.cpp +++ b/cpp/mrc/src/internal/ucx/receive_manager.cpp @@ -23,7 +23,6 @@ #include "mrc/types.hpp" #include -#include #include #include // for launch, launch::post #include // for ucp_tag_probe_nb, ucp_tag_recv_info diff --git a/cpp/mrc/src/internal/ucx/ucx_resources.cpp b/cpp/mrc/src/internal/ucx/ucx_resources.cpp index 458dd9814..1ce368662 100644 --- a/cpp/mrc/src/internal/ucx/ucx_resources.cpp +++ b/cpp/mrc/src/internal/ucx/ucx_resources.cpp @@ -30,7 +30,6 @@ #include "mrc/cuda/common.hpp" #include "mrc/types.hpp" -#include #include #include diff --git a/cpp/mrc/src/public/core/thread.cpp b/cpp/mrc/src/public/core/thread.cpp index 0553a8fe0..f81ecb38d 100644 --- a/cpp/mrc/src/public/core/thread.cpp +++ b/cpp/mrc/src/public/core/thread.cpp @@ -20,7 +20,6 @@ #include "mrc/coroutines/thread_pool.hpp" #include -#include #include #include #include diff --git a/cpp/mrc/src/public/coroutines/task_container.cpp b/cpp/mrc/src/public/coroutines/task_container.cpp new file mode 100644 index 000000000..e29b50fc2 --- /dev/null +++ b/cpp/mrc/src/public/coroutines/task_container.cpp @@ -0,0 +1,166 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mrc/coroutines/task_container.hpp" + +#include "mrc/coroutines/scheduler.hpp" + +#include + +#include +#include +#include +#include +#include +#include + +namespace mrc::coroutines { + +TaskContainer::TaskContainer(std::shared_ptr e) : + m_scheduler_lifetime(std::move(e)), + m_scheduler(m_scheduler_lifetime.get()) +{ + if (m_scheduler_lifetime == nullptr) + { + throw std::runtime_error{"TaskContainer cannot have a nullptr executor"}; + } +} + +TaskContainer::~TaskContainer() +{ + // This will hang the current thread... but if tasks are not complete, that's also pretty bad. + while (!this->empty()) + { + this->garbage_collect(); + } +} + +auto TaskContainer::start(Task&& user_task, GarbageCollectPolicy cleanup) -> void +{ + m_size.fetch_add(1, std::memory_order::relaxed); + + std::scoped_lock lk{m_mutex}; + + if (cleanup == GarbageCollectPolicy::yes) + { + gc_internal(); + } + + // Store the task inside a cleanup task for self deletion. 
+ auto pos = m_tasks.emplace(m_tasks.end(), std::nullopt); + auto task = make_cleanup_task(std::move(user_task), pos); + *pos = std::move(task); + + // Start executing from the cleanup task to schedule the user's task onto the thread pool. + pos->value().resume(); +} + +auto TaskContainer::garbage_collect() -> std::size_t +{ + std::scoped_lock lk{m_mutex}; + return gc_internal(); +} + +auto TaskContainer::delete_task_size() const -> std::size_t +{ + std::atomic_thread_fence(std::memory_order::acquire); + return m_tasks_to_delete.size(); +} + +auto TaskContainer::delete_tasks_empty() const -> bool +{ + std::atomic_thread_fence(std::memory_order::acquire); + return m_tasks_to_delete.empty(); +} + +auto TaskContainer::size() const -> std::size_t +{ + return m_size.load(std::memory_order::relaxed); +} + +auto TaskContainer::empty() const -> bool +{ + return size() == 0; +} + +auto TaskContainer::capacity() const -> std::size_t +{ + std::atomic_thread_fence(std::memory_order::acquire); + return m_tasks.size(); +} + +auto TaskContainer::garbage_collect_and_yield_until_empty() -> Task +{ + while (!empty()) + { + garbage_collect(); + co_await m_scheduler->yield(); + } +} + +TaskContainer::TaskContainer(Scheduler& e) : m_scheduler(&e) {} +auto TaskContainer::gc_internal() -> std::size_t +{ + std::size_t deleted{0}; + if (!m_tasks_to_delete.empty()) + { + for (const auto& pos : m_tasks_to_delete) + { + // Destroy the cleanup task and the user task. + if (pos->has_value()) + { + pos->value().destroy(); + } + m_tasks.erase(pos); + } + deleted = m_tasks_to_delete.size(); + m_tasks_to_delete.clear(); + } + return deleted; +} + +auto TaskContainer::make_cleanup_task(Task user_task, task_position_t pos) -> Task +{ + // Immediately move the task onto the executor. + co_await m_scheduler->schedule(); + + try + { + // Await the user's task to complete. + co_await user_task; + } catch (const std::exception& e) + { + // TODO(MDD): what would be a good way to report this to the user...? Catching here is required + // since the co_await will unwrap the unhandled exception on the task. + // The user's task should ideally be wrapped in a catch all and handle it themselves, but + // that cannot be guaranteed. + LOG(ERROR) << "coro::task_container user_task had an unhandled exception e.what()= " << e.what() << "\n"; + } catch (...) + { + // don't crash if they throw something that isn't derived from std::exception + LOG(ERROR) << "coro::task_container user_task had an unhandled exception, not derived from std::exception.\n"; + } + + std::scoped_lock lk{m_mutex}; + m_tasks_to_delete.push_back(pos); + // This has to be done within the scope lock to make sure this coroutine task completes before the + // task container object destructs -- if it was waiting on .empty() to become true. 
+ m_size.fetch_sub(1, std::memory_order::relaxed); + co_return; +} + +} // namespace mrc::coroutines diff --git a/cpp/mrc/src/public/coroutines/thread_pool.cpp b/cpp/mrc/src/public/coroutines/thread_pool.cpp index e2724409e..805a64d2a 100644 --- a/cpp/mrc/src/public/coroutines/thread_pool.cpp +++ b/cpp/mrc/src/public/coroutines/thread_pool.cpp @@ -39,7 +39,6 @@ #include "mrc/coroutines/thread_pool.hpp" #include -#include #include #include diff --git a/cpp/mrc/src/public/exceptions/exception_catcher.cpp b/cpp/mrc/src/public/exceptions/exception_catcher.cpp new file mode 100644 index 000000000..c139436f7 --- /dev/null +++ b/cpp/mrc/src/public/exceptions/exception_catcher.cpp @@ -0,0 +1,50 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace mrc { + +void ExceptionCatcher::push_exception(std::exception_ptr ex) +{ + auto lock = std::lock_guard(m_mutex); + m_exceptions.push(ex); +} + +bool ExceptionCatcher::has_exception() +{ + auto lock = std::lock_guard(m_mutex); + return not m_exceptions.empty(); +} + +void ExceptionCatcher::rethrow_next_exception() +{ + auto lock = std::lock_guard(m_mutex); + + if (m_exceptions.empty()) + { + return; + } + + auto ex = m_exceptions.front(); + + m_exceptions.pop(); + + std::rethrow_exception(ex); +} + +} // namespace mrc diff --git a/cpp/mrc/src/public/modules/sample_modules.cpp b/cpp/mrc/src/public/modules/sample_modules.cpp index fe850615c..405dcfe3c 100644 --- a/cpp/mrc/src/public/modules/sample_modules.cpp +++ b/cpp/mrc/src/public/modules/sample_modules.cpp @@ -26,10 +26,8 @@ #include -#include #include #include -#include namespace mrc::modules { diff --git a/cpp/mrc/src/tests/CMakeLists.txt b/cpp/mrc/src/tests/CMakeLists.txt index 9a746e718..8ef8676fe 100644 --- a/cpp/mrc/src/tests/CMakeLists.txt +++ b/cpp/mrc/src/tests/CMakeLists.txt @@ -61,6 +61,7 @@ add_executable(test_mrc_private test_resources.cpp test_reusable_pool.cpp test_runnable.cpp + test_service.cpp test_system.cpp test_topology.cpp test_ucx.cpp diff --git a/cpp/mrc/src/tests/nodes/common_nodes.cpp b/cpp/mrc/src/tests/nodes/common_nodes.cpp index f7432f670..1c7acd824 100644 --- a/cpp/mrc/src/tests/nodes/common_nodes.cpp +++ b/cpp/mrc/src/tests/nodes/common_nodes.cpp @@ -28,13 +28,11 @@ #include #include -#include #include #include #include #include #include -#include using namespace mrc; using namespace mrc::memory::literals; diff --git a/cpp/mrc/src/tests/nodes/common_nodes.hpp b/cpp/mrc/src/tests/nodes/common_nodes.hpp index aa1ff13d2..bb19235e3 100644 --- a/cpp/mrc/src/tests/nodes/common_nodes.hpp +++ b/cpp/mrc/src/tests/nodes/common_nodes.hpp @@ -30,7 +30,6 @@ #include #include #include -#include namespace test::nodes { diff --git a/cpp/mrc/src/tests/pipelines/multi_segment.cpp b/cpp/mrc/src/tests/pipelines/multi_segment.cpp index 5157ef5d6..05dabd28c 100644 --- a/cpp/mrc/src/tests/pipelines/multi_segment.cpp 
+++ b/cpp/mrc/src/tests/pipelines/multi_segment.cpp @@ -18,7 +18,9 @@ #include "common_pipelines.hpp" #include "mrc/node/rx_sink.hpp" +#include "mrc/node/rx_sink_base.hpp" #include "mrc/node/rx_source.hpp" +#include "mrc/node/rx_source_base.hpp" #include "mrc/pipeline/pipeline.hpp" #include "mrc/segment/builder.hpp" #include "mrc/segment/egress_ports.hpp" @@ -29,11 +31,8 @@ #include #include -#include #include #include -#include -#include using namespace mrc; diff --git a/cpp/mrc/src/tests/segments/common_segments.cpp b/cpp/mrc/src/tests/segments/common_segments.cpp index 9e0f6b61d..eb1d0126d 100644 --- a/cpp/mrc/src/tests/segments/common_segments.cpp +++ b/cpp/mrc/src/tests/segments/common_segments.cpp @@ -28,7 +28,6 @@ #include #include -#include using namespace mrc; diff --git a/cpp/mrc/src/tests/test_control_plane.cpp b/cpp/mrc/src/tests/test_control_plane.cpp index 96d85945c..b49e5ae0d 100644 --- a/cpp/mrc/src/tests/test_control_plane.cpp +++ b/cpp/mrc/src/tests/test_control_plane.cpp @@ -27,6 +27,7 @@ #include "internal/runnable/runnable_resources.hpp" #include "internal/runtime/partition.hpp" #include "internal/runtime/runtime.hpp" +#include "internal/system/partition.hpp" #include "internal/system/partitions.hpp" #include "internal/system/system.hpp" #include "internal/system/system_provider.hpp" @@ -43,7 +44,6 @@ #include "mrc/pubsub/subscriber.hpp" #include "mrc/types.hpp" -#include #include #include #include @@ -66,7 +66,7 @@ static auto make_runtime(std::function options_lambda = { auto resources = std::make_unique( system::SystemProvider(tests::make_system([&](Options& options) { - options.topology().user_cpuset("0-3"); + options.topology().user_cpuset("0"); options.topology().restrict_gpus(true); options.placement().resources_strategy(PlacementResources::Dedicated); options.placement().cpu_strategy(PlacementStrategy::PerMachine); @@ -85,7 +85,10 @@ class TestControlPlane : public ::testing::Test TEST_F(TestControlPlane, LifeCycle) { - auto sr = make_runtime(); + auto sr = make_runtime([](Options& options) { + options.enable_server(true); + options.architect_url("localhost:13337"); + }); auto server = std::make_unique(sr->partition(0).resources().runnable()); server->service_start(); @@ -121,6 +124,35 @@ TEST_F(TestControlPlane, SingleClientConnectDisconnect) server->service_await_join(); } +TEST_F(TestControlPlane, SingleClientConnectDisconnectSingleCore) +{ + // Similar to SingleClientConnectDisconnect except both client & server are locked to the same core + // making issue #379 easier to reproduce. + auto sr = make_runtime([](Options& options) { + options.topology().user_cpuset("0"); + }); + auto server = std::make_unique(sr->partition(0).resources().runnable()); + + server->service_start(); + server->service_await_live(); + + auto cr = make_runtime([](Options& options) { + options.topology().user_cpuset("0"); + options.architect_url("localhost:13337"); + }); + + // the total number of partitions is system dependent + auto expected_partitions = cr->resources().system().partitions().flattened().size(); + EXPECT_EQ(cr->partition(0).resources().network()->control_plane().client().connections().instance_ids().size(), + expected_partitions); + + // destroying the resources should gracefully shutdown the data plane and the control plane. 
+ cr.reset(); + + server->service_stop(); + server->service_await_join(); +} + TEST_F(TestControlPlane, DoubleClientConnectExchangeDisconnect) { auto sr = make_runtime(); diff --git a/cpp/mrc/src/tests/test_grpc.cpp b/cpp/mrc/src/tests/test_grpc.cpp index 68acc2913..95ef5801a 100644 --- a/cpp/mrc/src/tests/test_grpc.cpp +++ b/cpp/mrc/src/tests/test_grpc.cpp @@ -43,21 +43,16 @@ #include "mrc/runnable/runner.hpp" #include "mrc/types.hpp" -#include #include #include #include #include #include -#include -#include #include #include -#include #include #include -#include // Avoid forward declaring template specialization base classes // IWYU pragma: no_forward_declare grpc::ServerAsyncReaderWriter diff --git a/cpp/mrc/src/tests/test_memory.cpp b/cpp/mrc/src/tests/test_memory.cpp index 2544827d3..65059071d 100644 --- a/cpp/mrc/src/tests/test_memory.cpp +++ b/cpp/mrc/src/tests/test_memory.cpp @@ -38,11 +38,9 @@ #include #include #include -#include #include #include #include -#include #include #include #include diff --git a/cpp/mrc/src/tests/test_network.cpp b/cpp/mrc/src/tests/test_network.cpp index 1a14cebf4..509649eed 100644 --- a/cpp/mrc/src/tests/test_network.cpp +++ b/cpp/mrc/src/tests/test_network.cpp @@ -38,6 +38,7 @@ #include "internal/ucx/registration_cache.hpp" #include "mrc/edge/edge_builder.hpp" +#include "mrc/edge/edge_writable.hpp" #include "mrc/memory/adaptors.hpp" #include "mrc/memory/buffer.hpp" #include "mrc/memory/literals.hpp" @@ -62,15 +63,11 @@ #include #include #include -#include #include #include #include -#include -#include #include #include -#include using namespace mrc; using namespace mrc::memory::literals; diff --git a/cpp/mrc/src/tests/test_next.cpp b/cpp/mrc/src/tests/test_next.cpp index da54e0a3f..1886664f7 100644 --- a/cpp/mrc/src/tests/test_next.cpp +++ b/cpp/mrc/src/tests/test_next.cpp @@ -25,6 +25,7 @@ #include "mrc/channel/ingress.hpp" #include "mrc/data/reusable_pool.hpp" #include "mrc/edge/edge_builder.hpp" +#include "mrc/edge/edge_writable.hpp" #include "mrc/node/generic_node.hpp" #include "mrc/node/generic_sink.hpp" #include "mrc/node/generic_source.hpp" @@ -64,12 +65,10 @@ #include #include #include -#include #include #include #include #include -#include using namespace mrc; @@ -573,7 +572,7 @@ TEST_F(TestNext, RxWithReusableOnNextAndOnError) }); static_assert(rxcpp::detail::is_on_next_of>::value, " "); - static_assert(rxcpp::detail::is_on_next_of>::value, " "); + static_assert(rxcpp::detail::is_on_next_of>::value, " "); auto observer = rxcpp::make_observer_dynamic( [](data_t&& int_ptr) { diff --git a/cpp/mrc/src/tests/test_pipeline.cpp b/cpp/mrc/src/tests/test_pipeline.cpp index 1b6e9c85f..0f23a6fa2 100644 --- a/cpp/mrc/src/tests/test_pipeline.cpp +++ b/cpp/mrc/src/tests/test_pipeline.cpp @@ -35,7 +35,9 @@ #include "mrc/node/queue.hpp" #include "mrc/node/rx_node.hpp" #include "mrc/node/rx_sink.hpp" +#include "mrc/node/rx_sink_base.hpp" // for RxSinkBase #include "mrc/node/rx_source.hpp" +#include "mrc/node/rx_source_base.hpp" // for RxSourceBase #include "mrc/options/engine_groups.hpp" #include "mrc/options/options.hpp" #include "mrc/options/placement.hpp" @@ -67,7 +69,6 @@ #include #include #include -#include #include #include #include @@ -111,7 +112,6 @@ static void run_custom_manager(std::unique_ptr pipeline, } }); - manager->service_start(); manager->push_updates(std::move(update)); manager->service_await_join(); @@ -139,7 +139,6 @@ static void run_manager(std::unique_ptr pipeline, bool dela } }); - manager->service_start(); 
manager->push_updates(std::move(update)); manager->service_await_join(); diff --git a/cpp/mrc/src/tests/test_remote_descriptor.cpp b/cpp/mrc/src/tests/test_remote_descriptor.cpp index df4468897..33c85a440 100644 --- a/cpp/mrc/src/tests/test_remote_descriptor.cpp +++ b/cpp/mrc/src/tests/test_remote_descriptor.cpp @@ -39,7 +39,6 @@ #include "mrc/runtime/remote_descriptor_handle.hpp" #include "mrc/types.hpp" -#include #include #include diff --git a/cpp/mrc/src/tests/test_resources.cpp b/cpp/mrc/src/tests/test_resources.cpp index b6b4c953f..6f1abebd0 100644 --- a/cpp/mrc/src/tests/test_resources.cpp +++ b/cpp/mrc/src/tests/test_resources.cpp @@ -28,7 +28,6 @@ #include "mrc/options/placement.hpp" #include "mrc/types.hpp" -#include #include #include diff --git a/cpp/mrc/src/tests/test_runnable.cpp b/cpp/mrc/src/tests/test_runnable.cpp index c5bc0a048..6c303d8a2 100644 --- a/cpp/mrc/src/tests/test_runnable.cpp +++ b/cpp/mrc/src/tests/test_runnable.cpp @@ -47,14 +47,12 @@ #include #include #include -#include #include #include #include #include #include #include -#include using namespace mrc; diff --git a/cpp/mrc/src/tests/test_service.cpp b/cpp/mrc/src/tests/test_service.cpp new file mode 100644 index 000000000..39a6a6b95 --- /dev/null +++ b/cpp/mrc/src/tests/test_service.cpp @@ -0,0 +1,407 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tests/common.hpp" // IWYU pragma: associated + +#include "internal/service.hpp" + +#include "mrc/exceptions/runtime_error.hpp" + +#include + +#include +#include // for size_t +#include // for function +#include // for move + +namespace mrc { + +class SimpleService : public Service +{ + public: + SimpleService(bool do_call_in_destructor = true) : + Service("SimpleService"), + m_do_call_in_destructor(do_call_in_destructor) + {} + + ~SimpleService() override + { + if (m_do_call_in_destructor) + { + Service::call_in_destructor(); + } + } + + size_t start_call_count() const + { + return m_start_call_count.load(); + } + + size_t stop_call_count() const + { + return m_stop_call_count.load(); + } + + size_t kill_call_count() const + { + return m_kill_call_count.load(); + } + + size_t await_live_call_count() const + { + return m_await_live_call_count.load(); + } + + size_t await_join_call_count() const + { + return m_await_join_call_count.load(); + } + + void set_start_callback(std::function callback) + { + m_start_callback = std::move(callback); + } + + void set_stop_callback(std::function callback) + { + m_stop_callback = std::move(callback); + } + + void set_kill_callback(std::function callback) + { + m_kill_callback = std::move(callback); + } + + void set_await_live_callback(std::function callback) + { + m_await_live_callback = std::move(callback); + } + + void set_await_join_callback(std::function callback) + { + m_await_join_callback = std::move(callback); + } + + private: + void do_service_start() final + { + if (m_start_callback) + { + m_start_callback(); + } + + m_start_call_count++; + } + + void do_service_stop() final + { + if (m_stop_callback) + { + m_stop_callback(); + } + + m_stop_call_count++; + } + + void do_service_kill() final + { + if (m_kill_callback) + { + m_kill_callback(); + } + + m_kill_call_count++; + } + + void do_service_await_live() final + { + if (m_await_live_callback) + { + m_await_live_callback(); + } + + m_await_live_call_count++; + } + + void do_service_await_join() final + { + if (m_await_join_callback) + { + m_await_join_callback(); + } + + m_await_join_call_count++; + } + + bool m_do_call_in_destructor{true}; + + std::atomic_size_t m_start_call_count{0}; + std::atomic_size_t m_stop_call_count{0}; + std::atomic_size_t m_kill_call_count{0}; + std::atomic_size_t m_await_live_call_count{0}; + std::atomic_size_t m_await_join_call_count{0}; + + std::function m_start_callback; + std::function m_stop_callback; + std::function m_kill_callback; + std::function m_await_live_callback; + std::function m_await_join_callback; +}; + +class TestService : public ::testing::Test +{ + protected: +}; + +TEST_F(TestService, LifeCycle) +{ + SimpleService service; + + service.service_start(); + + EXPECT_EQ(service.state(), ServiceState::Running); + EXPECT_EQ(service.start_call_count(), 1); + + service.service_await_live(); + + EXPECT_EQ(service.await_live_call_count(), 1); + + service.service_await_join(); + + EXPECT_EQ(service.state(), ServiceState::Completed); + EXPECT_EQ(service.await_join_call_count(), 1); + + EXPECT_EQ(service.stop_call_count(), 0); + EXPECT_EQ(service.kill_call_count(), 0); +} + +TEST_F(TestService, ServiceNotStarted) +{ + SimpleService service; + + EXPECT_ANY_THROW(service.service_await_live()); + EXPECT_ANY_THROW(service.service_stop()); + EXPECT_ANY_THROW(service.service_kill()); + EXPECT_ANY_THROW(service.service_await_join()); +} + +TEST_F(TestService, ServiceStop) +{ + SimpleService service; + + service.service_start(); + + 
EXPECT_EQ(service.state(), ServiceState::Running); + + service.service_stop(); + + EXPECT_EQ(service.state(), ServiceState::Stopping); + + service.service_await_join(); + + EXPECT_EQ(service.state(), ServiceState::Completed); + + EXPECT_EQ(service.stop_call_count(), 1); +} + +TEST_F(TestService, ServiceKill) +{ + SimpleService service; + + service.service_start(); + + EXPECT_EQ(service.state(), ServiceState::Running); + + service.service_kill(); + + EXPECT_EQ(service.state(), ServiceState::Killing); + + service.service_await_join(); + + EXPECT_EQ(service.state(), ServiceState::Completed); + + EXPECT_EQ(service.kill_call_count(), 1); +} + +TEST_F(TestService, ServiceStopThenKill) +{ + SimpleService service; + + service.service_start(); + + EXPECT_EQ(service.state(), ServiceState::Running); + + service.service_stop(); + + EXPECT_EQ(service.state(), ServiceState::Stopping); + + service.service_kill(); + + EXPECT_EQ(service.state(), ServiceState::Killing); + + service.service_await_join(); + + EXPECT_EQ(service.state(), ServiceState::Completed); + + EXPECT_EQ(service.stop_call_count(), 1); + EXPECT_EQ(service.kill_call_count(), 1); +} + +TEST_F(TestService, ServiceKillThenStop) +{ + SimpleService service; + + service.service_start(); + + EXPECT_EQ(service.state(), ServiceState::Running); + + service.service_kill(); + + EXPECT_EQ(service.state(), ServiceState::Killing); + + service.service_stop(); + + EXPECT_EQ(service.state(), ServiceState::Killing); + + service.service_await_join(); + + EXPECT_EQ(service.state(), ServiceState::Completed); + + EXPECT_EQ(service.stop_call_count(), 0); + EXPECT_EQ(service.kill_call_count(), 1); +} + +TEST_F(TestService, MultipleStartCalls) +{ + SimpleService service; + + service.service_start(); + + // Call again (should be an error) + EXPECT_ANY_THROW(service.service_start()); + + EXPECT_EQ(service.start_call_count(), 1); +} + +TEST_F(TestService, MultipleStopCalls) +{ + SimpleService service; + + service.service_start(); + + // Multiple calls to stop are fine + service.service_stop(); + service.service_stop(); + + EXPECT_EQ(service.stop_call_count(), 1); +} + +TEST_F(TestService, MultipleKillCalls) +{ + SimpleService service; + + service.service_start(); + + // Multiple calls to kill are fine + service.service_kill(); + service.service_kill(); + + EXPECT_EQ(service.kill_call_count(), 1); +} + +TEST_F(TestService, MultipleJoinCalls) +{ + SimpleService service; + + service.service_start(); + + service.service_await_live(); + + service.service_await_join(); + service.service_await_join(); + + EXPECT_EQ(service.await_join_call_count(), 1); +} + +TEST_F(TestService, StartWithException) +{ + SimpleService service; + + service.set_start_callback([]() { + throw exceptions::MrcRuntimeError("Live Exception"); + }); + + EXPECT_ANY_THROW(service.service_start()); + + EXPECT_EQ(service.state(), ServiceState::Completed); +} + +TEST_F(TestService, LiveWithException) +{ + SimpleService service; + + service.set_await_join_callback([]() { + throw exceptions::MrcRuntimeError("Live Exception"); + }); + + service.service_start(); + + EXPECT_ANY_THROW(service.service_await_join()); +} + +TEST_F(TestService, MultipleLiveWithException) +{ + SimpleService service; + + service.set_await_live_callback([]() { + throw exceptions::MrcRuntimeError("Live Exception"); + }); + + service.service_start(); + + EXPECT_ANY_THROW(service.service_await_live()); + EXPECT_ANY_THROW(service.service_await_live()); +} + +TEST_F(TestService, JoinWithException) +{ + SimpleService service; + + 
service.set_await_join_callback([]() { + throw exceptions::MrcRuntimeError("Join Exception"); + }); + + service.service_start(); + + EXPECT_ANY_THROW(service.service_await_join()); +} + +TEST_F(TestService, MultipleJoinWithException) +{ + SimpleService service; + + service.set_await_join_callback([]() { + throw exceptions::MrcRuntimeError("Join Exception"); + }); + + service.service_start(); + + EXPECT_ANY_THROW(service.service_await_join()); + EXPECT_ANY_THROW(service.service_await_join()); +} + +} // namespace mrc diff --git a/cpp/mrc/src/tests/test_ucx.cpp b/cpp/mrc/src/tests/test_ucx.cpp index a80321017..8f65b6f34 100644 --- a/cpp/mrc/src/tests/test_ucx.cpp +++ b/cpp/mrc/src/tests/test_ucx.cpp @@ -39,7 +39,6 @@ #include #include #include -#include using namespace mrc; using namespace ucx; diff --git a/cpp/mrc/tests/CMakeLists.txt b/cpp/mrc/tests/CMakeLists.txt index 821e0d8a2..db193b455 100644 --- a/cpp/mrc/tests/CMakeLists.txt +++ b/cpp/mrc/tests/CMakeLists.txt @@ -15,9 +15,11 @@ # Keep all source files sorted!!! add_executable(test_mrc + coroutines/test_async_generator.cpp coroutines/test_event.cpp coroutines/test_latch.cpp coroutines/test_ring_buffer.cpp + coroutines/test_task_container.cpp coroutines/test_task.cpp modules/test_mirror_tap_module.cpp modules/test_mirror_tap_orchestrator.cpp diff --git a/cpp/mrc/tests/benchmarking/test_benchmarking.hpp b/cpp/mrc/tests/benchmarking/test_benchmarking.hpp index 99de4e475..c9f7e368d 100644 --- a/cpp/mrc/tests/benchmarking/test_benchmarking.hpp +++ b/cpp/mrc/tests/benchmarking/test_benchmarking.hpp @@ -31,13 +31,11 @@ #include #include -#include #include #include #include #include #include -#include namespace mrc { diff --git a/cpp/mrc/tests/benchmarking/test_stat_gather.hpp b/cpp/mrc/tests/benchmarking/test_stat_gather.hpp index 746be4356..0af0df8ca 100644 --- a/cpp/mrc/tests/benchmarking/test_stat_gather.hpp +++ b/cpp/mrc/tests/benchmarking/test_stat_gather.hpp @@ -29,14 +29,12 @@ #include #include -#include #include #include #include #include #include #include -#include namespace mrc { class TestSegmentResources; diff --git a/cpp/mrc/tests/coroutines/test_async_generator.cpp b/cpp/mrc/tests/coroutines/test_async_generator.cpp new file mode 100644 index 000000000..81626a28c --- /dev/null +++ b/cpp/mrc/tests/coroutines/test_async_generator.cpp @@ -0,0 +1,133 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "mrc/coroutines/async_generator.hpp" +#include "mrc/coroutines/sync_wait.hpp" +#include "mrc/coroutines/task.hpp" + +#include + +#include + +using namespace mrc; + +class TestCoroAsyncGenerator : public ::testing::Test +{}; + +TEST_F(TestCoroAsyncGenerator, Iterator) +{ + auto generator = []() -> coroutines::AsyncGenerator { + for (int i = 0; i < 2; i++) + { + co_yield i; + } + }(); + + auto task = [&]() -> coroutines::Task<> { + auto iter = co_await generator.begin(); + + EXPECT_TRUE(iter); + EXPECT_EQ(*iter, 0); + EXPECT_NE(iter, generator.end()); + + co_await ++iter; + + EXPECT_TRUE(iter); + EXPECT_EQ(*iter, 1); + EXPECT_NE(iter, generator.end()); + + co_await ++iter; + EXPECT_FALSE(iter); + EXPECT_EQ(iter, generator.end()); + + co_return; + }; + + coroutines::sync_wait(task()); +} + +TEST_F(TestCoroAsyncGenerator, LoopOnGenerator) +{ + auto generator = []() -> coroutines::AsyncGenerator { + for (int i = 0; i < 2; i++) + { + co_yield i; + } + }(); + + auto task = [&]() -> coroutines::Task<> { + for (int i = 0; i < 2; i++) + { + auto iter = co_await generator.begin(); + + EXPECT_TRUE(iter); + EXPECT_EQ(*iter, 0); + EXPECT_NE(iter, generator.end()); + + co_await ++iter; + + EXPECT_TRUE(iter); + EXPECT_EQ(*iter, 1); + EXPECT_NE(iter, generator.end()); + + co_await ++iter; + EXPECT_FALSE(iter); + EXPECT_EQ(iter, generator.end()); + + co_return; + } + }; + + coroutines::sync_wait(task()); +} + +TEST_F(TestCoroAsyncGenerator, MultipleBegins) +{ + auto generator = []() -> coroutines::AsyncGenerator { + for (int i = 0; i < 2; i++) + { + co_yield i; + } + }(); + + // this test shows that begin() and operator++() perform essentially the same function + // both advance the generator to the next state + // while a generator is an iterable, it doesn't hold the entire sequence in memory, it does + // what it suggests, it generates the next item from the previous + + auto task = [&]() -> coroutines::Task<> { + auto iter = co_await generator.begin(); + + EXPECT_TRUE(iter); + EXPECT_EQ(*iter, 0); + EXPECT_NE(iter, generator.end()); + + iter = co_await generator.begin(); + + EXPECT_TRUE(iter); + EXPECT_EQ(*iter, 1); + EXPECT_NE(iter, generator.end()); + + iter = co_await generator.begin(); + EXPECT_FALSE(iter); + EXPECT_EQ(iter, generator.end()); + + co_return; + }; + + coroutines::sync_wait(task()); +} diff --git a/cpp/mrc/tests/coroutines/test_event.cpp b/cpp/mrc/tests/coroutines/test_event.cpp index 68689637d..61326e0b3 100644 --- a/cpp/mrc/tests/coroutines/test_event.cpp +++ b/cpp/mrc/tests/coroutines/test_event.cpp @@ -48,7 +48,6 @@ #include #include #include -#include #include using namespace mrc; diff --git a/cpp/mrc/tests/coroutines/test_latch.cpp b/cpp/mrc/tests/coroutines/test_latch.cpp index 1136bf76e..5be3b31e7 100644 --- a/cpp/mrc/tests/coroutines/test_latch.cpp +++ b/cpp/mrc/tests/coroutines/test_latch.cpp @@ -44,7 +44,6 @@ #include #include -#include using namespace mrc; diff --git a/cpp/mrc/tests/coroutines/test_ring_buffer.cpp b/cpp/mrc/tests/coroutines/test_ring_buffer.cpp index fb9afa1c4..a5b0163a2 100644 --- a/cpp/mrc/tests/coroutines/test_ring_buffer.cpp +++ b/cpp/mrc/tests/coroutines/test_ring_buffer.cpp @@ -49,7 +49,6 @@ #include #include #include -#include #include #include #include diff --git a/cpp/mrc/tests/coroutines/test_task.cpp b/cpp/mrc/tests/coroutines/test_task.cpp index ffc40a3ef..60cbfafa5 100644 --- a/cpp/mrc/tests/coroutines/test_task.cpp +++ b/cpp/mrc/tests/coroutines/test_task.cpp @@ -49,9 +49,7 @@ #include #include #include -#include #include 
-#include using namespace mrc; diff --git a/cpp/mrc/tests/coroutines/test_task_container.cpp b/cpp/mrc/tests/coroutines/test_task_container.cpp new file mode 100644 index 000000000..a55f88039 --- /dev/null +++ b/cpp/mrc/tests/coroutines/test_task_container.cpp @@ -0,0 +1,23 @@ +/** + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +class TestCoroTaskContainer : public ::testing::Test +{}; + +TEST_F(TestCoroTaskContainer, LifeCycle) {} diff --git a/cpp/mrc/tests/logging/test_logging.cpp b/cpp/mrc/tests/logging/test_logging.cpp index f72cb113c..0d26a82bb 100644 --- a/cpp/mrc/tests/logging/test_logging.cpp +++ b/cpp/mrc/tests/logging/test_logging.cpp @@ -21,8 +21,6 @@ #include -#include - namespace mrc { TEST_CLASS(Logging); diff --git a/cpp/mrc/tests/modules/dynamic_module.cpp b/cpp/mrc/tests/modules/dynamic_module.cpp index 3db4e08cd..9538ed825 100644 --- a/cpp/mrc/tests/modules/dynamic_module.cpp +++ b/cpp/mrc/tests/modules/dynamic_module.cpp @@ -19,13 +19,13 @@ #include "mrc/modules/segment_modules.hpp" #include "mrc/node/rx_source.hpp" #include "mrc/segment/builder.hpp" +#include "mrc/segment/object.hpp" #include "mrc/utils/type_utils.hpp" #include "mrc/version.hpp" #include #include -#include #include #include #include diff --git a/cpp/mrc/tests/modules/test_mirror_tap_module.cpp b/cpp/mrc/tests/modules/test_mirror_tap_module.cpp index 7f68a354b..165382a94 100644 --- a/cpp/mrc/tests/modules/test_mirror_tap_module.cpp +++ b/cpp/mrc/tests/modules/test_mirror_tap_module.cpp @@ -20,10 +20,10 @@ #include "mrc/cuda/device_guard.hpp" #include "mrc/experimental/modules/mirror_tap/mirror_tap.hpp" #include "mrc/modules/properties/persistent.hpp" -#include "mrc/node/operators/broadcast.hpp" #include "mrc/node/rx_node.hpp" #include "mrc/node/rx_sink.hpp" #include "mrc/node/rx_source.hpp" +#include "mrc/node/rx_source_base.hpp" #include "mrc/options/options.hpp" #include "mrc/options/topology.hpp" #include "mrc/pipeline/executor.hpp" @@ -38,11 +38,9 @@ #include #include -#include #include #include #include -#include using namespace mrc; diff --git a/cpp/mrc/tests/modules/test_mirror_tap_orchestrator.cpp b/cpp/mrc/tests/modules/test_mirror_tap_orchestrator.cpp index ceeba44e2..2de1cf98c 100644 --- a/cpp/mrc/tests/modules/test_mirror_tap_orchestrator.cpp +++ b/cpp/mrc/tests/modules/test_mirror_tap_orchestrator.cpp @@ -20,9 +20,10 @@ #include "mrc/cuda/device_guard.hpp" #include "mrc/experimental/modules/mirror_tap/mirror_tap_orchestrator.hpp" #include "mrc/modules/properties/persistent.hpp" -#include "mrc/node/operators/broadcast.hpp" #include "mrc/node/rx_sink.hpp" +#include "mrc/node/rx_sink_base.hpp" #include "mrc/node/rx_source.hpp" +#include "mrc/node/rx_source_base.hpp" #include "mrc/options/options.hpp" #include "mrc/options/topology.hpp" #include "mrc/pipeline/executor.hpp" @@ -37,12 +38,10 @@ #include #include -#include 
#include #include #include #include -#include using namespace mrc; diff --git a/cpp/mrc/tests/modules/test_module_util.cpp b/cpp/mrc/tests/modules/test_module_util.cpp index 989ec4ed1..f064df81a 100644 --- a/cpp/mrc/tests/modules/test_module_util.cpp +++ b/cpp/mrc/tests/modules/test_module_util.cpp @@ -20,13 +20,11 @@ #include "mrc/modules/module_registry_util.hpp" #include "mrc/modules/properties/persistent.hpp" #include "mrc/modules/sample_modules.hpp" -#include "mrc/node/rx_source.hpp" #include "mrc/version.hpp" #include +#include -#include -#include #include #include #include diff --git a/cpp/mrc/tests/modules/test_segment_modules.cpp b/cpp/mrc/tests/modules/test_segment_modules.cpp index 6c23a930f..ac4f1ec79 100644 --- a/cpp/mrc/tests/modules/test_segment_modules.cpp +++ b/cpp/mrc/tests/modules/test_segment_modules.cpp @@ -67,6 +67,7 @@ TEST_F(TestSegmentModules, ModuleInitializationTest) { using namespace modules; + GTEST_SKIP() << "To be re-enabled by issue #390"; auto init_wrapper = [](segment::IBuilder& builder) { auto config_1 = nlohmann::json(); auto config_2 = nlohmann::json(); @@ -118,7 +119,7 @@ TEST_F(TestSegmentModules, ModuleInitializationTest) Executor executor(options); executor.register_pipeline(std::move(m_pipeline)); - executor.stop(); + executor.start(); executor.join(); } diff --git a/cpp/mrc/tests/modules/test_stream_buffer_modules.cpp b/cpp/mrc/tests/modules/test_stream_buffer_modules.cpp index c5cb376f8..cab4d21ac 100644 --- a/cpp/mrc/tests/modules/test_stream_buffer_modules.cpp +++ b/cpp/mrc/tests/modules/test_stream_buffer_modules.cpp @@ -39,13 +39,11 @@ #include #include -#include #include #include #include #include #include -#include using namespace mrc; using namespace mrc::modules::stream_buffers; @@ -57,6 +55,7 @@ TEST_F(TestStreamBufferModule, InitailizationTest) { using namespace modules; + GTEST_SKIP() << "To be re-enabled by issue #390"; auto init_wrapper = [](segment::IBuilder& builder) { auto config1 = nlohmann::json(); auto mirror_buffer1 = builder.make_module("mirror_tap", config1); @@ -70,7 +69,7 @@ TEST_F(TestStreamBufferModule, InitailizationTest) Executor executor(options); executor.register_pipeline(std::move(m_pipeline)); - executor.stop(); + executor.start(); executor.join(); } diff --git a/cpp/mrc/tests/test_channel.cpp b/cpp/mrc/tests/test_channel.cpp index 6d796dba6..1a5f8ef2e 100644 --- a/cpp/mrc/tests/test_channel.cpp +++ b/cpp/mrc/tests/test_channel.cpp @@ -27,7 +27,6 @@ #include #include -#include #include // for sleep_for #include // for duration, system_clock, milliseconds, time_point @@ -35,7 +34,6 @@ #include // for uint64_t #include // for ref, reference_wrapper #include -#include #include // IWYU thinks algorithm is needed for: auto channel = std::make_shared>(2); // IWYU pragma: no_include diff --git a/cpp/mrc/tests/test_edges.cpp b/cpp/mrc/tests/test_edges.cpp index 86e42dfb5..91c6d4e09 100644 --- a/cpp/mrc/tests/test_edges.cpp +++ b/cpp/mrc/tests/test_edges.cpp @@ -19,8 +19,10 @@ #include "mrc/channel/buffered_channel.hpp" // IWYU pragma: keep #include "mrc/channel/forward.hpp" +#include "mrc/edge/edge.hpp" // for Edge #include "mrc/edge/edge_builder.hpp" #include "mrc/edge/edge_channel.hpp" +#include "mrc/edge/edge_holder.hpp" // for EdgeHolder #include "mrc/edge/edge_readable.hpp" #include "mrc/edge/edge_writable.hpp" #include "mrc/node/generic_source.hpp" @@ -40,7 +42,6 @@ #include // for observable_member #include -#include #include #include #include @@ -996,4 +997,37 @@ TEST_F(TestEdges, EdgeTapWithSpliceRxComponent) 
EXPECT_TRUE(node->stream_fn_called); } + +template +class TestEdgeHolder : public edge::EdgeHolder +{ + public: + bool has_active_connection() const + { + return this->check_active_connection(false); + } + + void call_release_edge_connection() + { + this->release_edge_connection(); + } + + void call_init_owned_edge(std::shared_ptr> edge) + { + this->init_owned_edge(std::move(edge)); + } +}; + +TEST_F(TestEdges, EdgeHolderIsConnected) +{ + TestEdgeHolder edge_holder; + auto edge = std::make_shared>(); + EXPECT_FALSE(edge_holder.has_active_connection()); + + edge_holder.call_init_owned_edge(edge); + EXPECT_FALSE(edge_holder.has_active_connection()); + + edge_holder.call_release_edge_connection(); + EXPECT_FALSE(edge_holder.has_active_connection()); +} } // namespace mrc diff --git a/cpp/mrc/tests/test_executor.cpp b/cpp/mrc/tests/test_executor.cpp index e8da2fe0b..989dfe2f1 100644 --- a/cpp/mrc/tests/test_executor.cpp +++ b/cpp/mrc/tests/test_executor.cpp @@ -17,7 +17,9 @@ #include "mrc/node/rx_node.hpp" #include "mrc/node/rx_sink.hpp" +#include "mrc/node/rx_sink_base.hpp" #include "mrc/node/rx_source.hpp" +#include "mrc/node/rx_source_base.hpp" #include "mrc/options/engine_groups.hpp" #include "mrc/options/options.hpp" #include "mrc/options/topology.hpp" @@ -41,7 +43,6 @@ #include #include #include -#include #include #include #include @@ -49,7 +50,6 @@ #include #include #include -#include namespace mrc { diff --git a/cpp/mrc/tests/test_node.cpp b/cpp/mrc/tests/test_node.cpp index 428c41d2c..8305fc79d 100644 --- a/cpp/mrc/tests/test_node.cpp +++ b/cpp/mrc/tests/test_node.cpp @@ -40,7 +40,6 @@ #include #include #include -#include #include #include #include diff --git a/cpp/mrc/tests/test_pipeline.cpp b/cpp/mrc/tests/test_pipeline.cpp index c34731302..6d1bc4499 100644 --- a/cpp/mrc/tests/test_pipeline.cpp +++ b/cpp/mrc/tests/test_pipeline.cpp @@ -16,7 +16,9 @@ */ #include "mrc/node/rx_sink.hpp" +#include "mrc/node/rx_sink_base.hpp" #include "mrc/node/rx_source.hpp" +#include "mrc/node/rx_source_base.hpp" #include "mrc/options/options.hpp" #include "mrc/options/topology.hpp" #include "mrc/pipeline/executor.hpp" @@ -33,12 +35,9 @@ #include #include -#include #include #include -#include #include -#include namespace mrc { diff --git a/cpp/mrc/tests/test_segment.cpp b/cpp/mrc/tests/test_segment.cpp index be1bbc29c..bd3b09d78 100644 --- a/cpp/mrc/tests/test_segment.cpp +++ b/cpp/mrc/tests/test_segment.cpp @@ -23,7 +23,6 @@ #include "mrc/node/rx_node.hpp" #include "mrc/node/rx_sink.hpp" #include "mrc/node/rx_source.hpp" -#include "mrc/node/rx_source_base.hpp" #include "mrc/options/options.hpp" #include "mrc/options/topology.hpp" #include "mrc/pipeline/executor.hpp" @@ -40,9 +39,7 @@ #include #include -#include #include -#include #include #include #include diff --git a/cpp/mrc/tests/test_thread.cpp b/cpp/mrc/tests/test_thread.cpp index c19753734..88785379f 100644 --- a/cpp/mrc/tests/test_thread.cpp +++ b/cpp/mrc/tests/test_thread.cpp @@ -25,7 +25,6 @@ #include #include -#include using namespace mrc; diff --git a/dependencies.yaml b/dependencies.yaml new file mode 100644 index 000000000..966608a19 --- /dev/null +++ b/dependencies.yaml @@ -0,0 +1,60 @@ +# Dependency list for https://github.com/rapidsai/dependency-file-generator +files: + all: + output: conda + matrix: + cuda: ["11.8"] + arch: [x86_64] + includes: + - empty + - build_cpp + - cudatoolkit + +channels: + - rapidsai + - nvidia/label/cuda-11.8.0 + - nvidia + - rapidsai-nightly + - conda-forge + +dependencies: + + empty: + common: + - 
output_types: [conda] + packages: + - cxx-compiler + + build_cpp: + common: + - output_types: [conda] + packages: + - boost-cpp=1.82 + - ccache + - cmake=3.24 + - cuda-nvcc + - cxx-compiler + - glog=0.6 + - gxx=11.2 + - libgrpc=1.54.0 + - libhwloc=2.9.2 + - librmm=23.06 + - ninja=1.10 + - ucx=1.14 + - nlohmann_json=3.9 + - gtest=1.13 + - scikit-build>=0.17 + - pybind11-stubgen=0.10 + - python=3.10 + cudatoolkit: + specific: + - output_types: [conda] + matrices: + - matrix: + cuda: "11.8" + packages: + - cuda-cudart-dev=11.8 + - cuda-nvrtc-dev=11.8 + - cuda-version=11.8 + - cuda-nvml-dev=11.8 + - cuda-tools=11.8 diff --git a/docs/quickstart/CMakeLists.txt b/docs/quickstart/CMakeLists.txt index 1b87c4b87..3a6766c35 100644 --- a/docs/quickstart/CMakeLists.txt +++ b/docs/quickstart/CMakeLists.txt @@ -28,7 +28,7 @@ list(PREPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../../external/utili include(morpheus_utils/load) project(mrc-quickstart - VERSION 23.07.00 + VERSION 23.11.00 LANGUAGES C CXX ) diff --git a/docs/quickstart/environment_cpp.yml b/docs/quickstart/environment_cpp.yml index 379bf6477..775ab13ad 100644 --- a/docs/quickstart/environment_cpp.yml +++ b/docs/quickstart/environment_cpp.yml @@ -30,7 +30,7 @@ dependencies: - pkg-config=0.29 - python=3.10 - scikit-build>=0.12 - - mrc=23.07 + - mrc=23.11 - sysroot_linux-64=2.17 - pip: - cython diff --git a/docs/quickstart/python/mrc_qs_python/ex02_reactive_operators/README.md b/docs/quickstart/python/mrc_qs_python/ex02_reactive_operators/README.md index 705c79e15..95cb08d01 100644 --- a/docs/quickstart/python/mrc_qs_python/ex02_reactive_operators/README.md +++ b/docs/quickstart/python/mrc_qs_python/ex02_reactive_operators/README.md @@ -27,36 +27,33 @@ Lets look at a more complex example: value_count = 0 value_sum = 0 -def node_fn(src: mrc.Observable, dst: mrc.Subscriber): - def update_obj(x: MyCustomClass): - nonlocal value_count - nonlocal value_sum +def update_obj(x: MyCustomClass): + nonlocal value_count + nonlocal value_sum - # Alter the value property of the class - x.value = x.value * 2 + # Alter the value property of the class + x.value = x.value * 2 - # Update the sum values - value_count += 1 - value_sum += x.value + # Update the sum values + value_count += 1 + value_sum += x.value - return x + return x - def on_completed(): +def on_completed(): - # Prevent divide by 0. Just in case - if (value_count <= 0): - return + # Prevent divide by 0. Just in case + if (value_count <= 0): + return - return MyCustomClass(value_sum / value_count, "Mean") - - src.pipe( - ops.filter(lambda x: x.value % 2 == 0), - ops.map(update_obj), - ops.on_completed(on_completed) - ).subscribe(dst) + return MyCustomClass(value_sum / value_count, "Mean") # Make an intermediate node -node = seg.make_node_full("node", node_fn) +node = seg.make_node("node", + ops.filter(lambda x: x.value % 2 == 0), + ops.map(update_obj), + ops.on_completed(on_completed) +) ``` In this example, we are using 3 different operators: `filter`, `map`, and `on_completed`: @@ -66,7 +63,7 @@ In this example, we are using 3 different operators: `filter`, `map`, and `on_co - The `map` operator can transform the incoming value and return a new value - In our example, we are doubling the `value` property and recording the total count and total sum of this property - The `on_completed` function is only called once when there are no more messages to process. You can optionally return a value which will be passed on to the rest of the pipeline. 
- - In our example, we are calculating the average from the sum and count values and emitting a new obect with the value set to the mean + - In our example, we are calculating the average from the sum and count values and emitting a new object with the value set to the mean In combination, these operators perform a higher level functionality to modify the stream, record some information, and finally print an analysis of all emitted values. Let's see it in practice. diff --git a/external/utilities b/external/utilities index a5b9689e3..c642d23a8 160000 --- a/external/utilities +++ b/external/utilities @@ -1 +1 @@ -Subproject commit a5b9689e3a82fe5b49245b0a02c907ea70aed7b8 +Subproject commit c642d23a80871946bc8b17e98cf260958f531e3a diff --git a/python/mrc/_pymrc/CMakeLists.txt b/python/mrc/_pymrc/CMakeLists.txt index b1aa4eb77..ed385504f 100644 --- a/python/mrc/_pymrc/CMakeLists.txt +++ b/python/mrc/_pymrc/CMakeLists.txt @@ -18,6 +18,7 @@ find_package(prometheus-cpp REQUIRED) # Keep all source files sorted!!! add_library(pymrc + src/coro.cpp src/executor.cpp src/logging.cpp src/module_registry.cpp @@ -49,8 +50,9 @@ target_link_libraries(pymrc PUBLIC ${PROJECT_NAME}::libmrc ${Python_LIBRARIES} - prometheus-cpp::core pybind11::pybind11 + PRIVATE + prometheus-cpp::core ) target_include_directories(pymrc diff --git a/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp b/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp new file mode 100644 index 000000000..36dad7208 --- /dev/null +++ b/python/mrc/_pymrc/include/pymrc/asyncio_runnable.hpp @@ -0,0 +1,360 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "pymrc/asyncio_scheduler.hpp" +#include "pymrc/utilities/object_wrappers.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace mrc::pymrc { + +/** + * @brief A wrapper for executing a function as an async boost fiber, the result of which is a + * C++20 coroutine awaiter. + */ +template +class BoostFutureAwaitableOperation +{ + class Awaiter; + + public: + BoostFutureAwaitableOperation(std::function fn) : m_fn(std::move(fn)) {} + + /** + * @brief Calls the wrapped function as an asyncboost fiber and returns a C++20 coroutine awaiter. + */ + template + auto operator()(ArgsT&&... args) -> Awaiter + { + // Make a copy of m_fn here so we can call this operator again + return Awaiter(m_fn, std::forward(args)...); + } + + private: + class Awaiter + { + public: + using return_t = typename std::function::result_type; + + template + Awaiter(std::function fn, ArgsT&&... 
args) + { + m_future = boost::fibers::async(boost::fibers::launch::post, fn, std::forward(args)...); + } + + bool await_ready() noexcept + { + return false; + } + + void await_suspend(std::coroutine_handle<> continuation) noexcept + { + // Launch a new fiber that waits on the future and then resumes the coroutine + boost::fibers::async( + boost::fibers::launch::post, + [this](std::coroutine_handle<> continuation) { + // Wait on the future + m_future.wait(); + + // Resume the coroutine + continuation.resume(); + }, + std::move(continuation)); + } + + auto await_resume() + { + return m_future.get(); + } + + private: + boost::fibers::future m_future; + std::function)> m_inner_fn; + }; + + std::function m_fn; +}; + +/** + * @brief A MRC Sink which receives from a channel using an awaitable interface. + */ +template +class AsyncSink : public mrc::node::WritableProvider, + public mrc::node::ReadableAcceptor, + public mrc::node::SinkChannelOwner +{ + protected: + AsyncSink() : + m_read_async([this](T& value) { + return this->get_readable_edge()->await_read(value); + }) + { + // Set the default channel + this->set_channel(std::make_unique>()); + } + + /** + * @brief Asynchronously reads a value from the sink's channel + */ + coroutines::Task read_async(T& value) + { + co_return co_await m_read_async(std::ref(value)); + } + + private: + BoostFutureAwaitableOperation m_read_async; +}; + +/** + * @brief A MRC Source which produces to a channel using an awaitable interface. + */ +template +class AsyncSource : public mrc::node::WritableAcceptor, + public mrc::node::ReadableProvider, + public mrc::node::SourceChannelOwner +{ + protected: + AsyncSource() : + m_write_async([this](T&& value) { + return this->get_writable_edge()->await_write(std::move(value)); + }) + { + // Set the default channel + this->set_channel(std::make_unique>()); + } + + /** + * @brief Asynchronously writes a value to the source's channel + */ + coroutines::Task write_async(T&& value) + { + co_return co_await m_write_async(std::move(value)); + } + + private: + BoostFutureAwaitableOperation m_write_async; +}; + +/** + * @brief A MRC Runnable base class which hosts its own asyncio loop and exposes a flatmap hook + */ +template +class AsyncioRunnable : public AsyncSink, + public AsyncSource, + public mrc::runnable::RunnableWithContext<> +{ + using state_t = mrc::runnable::Runnable::State; + using task_buffer_t = mrc::coroutines::ClosableRingBuffer; + + public: + ~AsyncioRunnable() override = default; + + private: + /** + * @brief Runnable's entrypoint. + */ + void run(mrc::runnable::Context& ctx) override; + + /** + * @brief Runnable's state control, for stopping from MRC. + */ + void on_state_update(const state_t& state) final; + + /** + * @brief The top-level coroutine which is run while the asyncio event loop is running. + */ + coroutines::Task<> main_task(std::shared_ptr scheduler); + + /** + * @brief The per-value coroutine run asynchronously alongside other calls. + */ + coroutines::Task<> process_one(InputT value, + std::shared_ptr on, + ExceptionCatcher& catcher); + + /** + * @brief Values read from the sink's channel are fed to this function, and yields from the + * resulting generator are written to the source's channel. + */ + virtual mrc::coroutines::AsyncGenerator on_data(InputT&& value) = 0; + + std::stop_source m_stop_source; + + /** + * @brief A semaphore used to control the number of outstanding operations. Acquire one before + * beginning a task, and release it when finished.
+ */ + std::counting_semaphore<8> m_task_tickets{8}; +}; + +template +void AsyncioRunnable::run(mrc::runnable::Context& ctx) +{ + std::exception_ptr exception; + + { + py::gil_scoped_acquire gil; + + auto asyncio = py::module_::import("asyncio"); + + auto loop = [](auto& asyncio) -> PyObjectHolder { + try + { + return asyncio.attr("get_running_loop")(); + } catch (...) + { + return py::none(); + } + }(asyncio); + + if (not loop.is_none()) + { + throw std::runtime_error("asyncio loop already running, but runnable is expected to create it."); + } + + // Need to create a loop + LOG(INFO) << "AsyncioRunnable::run() > Creating new event loop"; + + // Gets (or more likely, creates) an event loop and runs it forever until stop is called + loop = asyncio.attr("new_event_loop")(); + + // Set the event loop as the current event loop + asyncio.attr("set_event_loop")(loop); + + // TODO(MDD): Eventually we should get this from the context object. For now, just create it directly + auto scheduler = std::make_shared(loop); + + auto py_awaitable = coro::BoostFibersMainPyAwaitable(this->main_task(scheduler)); + + LOG(INFO) << "AsyncioRunnable::run() > Calling run_until_complete() on main_task()"; + + try + { + loop.attr("run_until_complete")(std::move(py_awaitable)); + } catch (...) + { + exception = std::current_exception(); + } + + loop.attr("close")(); + } + + // Need to drop the output edges + mrc::node::SourceProperties::release_edge_connection(); + mrc::node::SinkProperties::release_edge_connection(); + + if (exception != nullptr) + { + std::rethrow_exception(exception); + } +} + +template +coroutines::Task<> AsyncioRunnable::main_task(std::shared_ptr scheduler) +{ + coroutines::TaskContainer outstanding_tasks(scheduler); + + ExceptionCatcher catcher{}; + + while (not m_stop_source.stop_requested() and not catcher.has_exception()) + { + m_task_tickets.acquire(); + + InputT data; + + auto read_status = co_await this->read_async(data); + + if (read_status != mrc::channel::Status::success) + { + break; + } + + outstanding_tasks.start(this->process_one(std::move(data), scheduler, catcher)); + } + + co_await outstanding_tasks.garbage_collect_and_yield_until_empty(); + + catcher.rethrow_next_exception(); +} + +template +coroutines::Task<> AsyncioRunnable::process_one(InputT value, + std::shared_ptr on, + ExceptionCatcher& catcher) +{ + co_await on->yield(); + + try + { + // Call the on_data function + auto on_data_gen = this->on_data(std::move(value)); + + auto iter = co_await on_data_gen.begin(); + + while (iter != on_data_gen.end()) + { + // Weird bug, cant directly move the value into the async_write call + auto data = std::move(*iter); + + co_await this->write_async(std::move(data)); + + // Advance the iterator + co_await ++iter; + } + } catch (...) 
+ { + catcher.push_exception(std::current_exception()); + } + + m_task_tickets.release(); +} + +template +void AsyncioRunnable::on_state_update(const state_t& state) +{ + switch (state) + { + case state_t::Stop: + // Do nothing, we wait for the upstream channel to return closed + // m_stop_source.request_stop(); + break; + + case state_t::Kill: + m_stop_source.request_stop(); + break; + + default: + break; + } +} + +} // namespace mrc::pymrc diff --git a/python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp b/python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp new file mode 100644 index 000000000..3d9e563b9 --- /dev/null +++ b/python/mrc/_pymrc/include/pymrc/asyncio_scheduler.hpp @@ -0,0 +1,105 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "pymrc/coro.hpp" +#include "pymrc/utilities/acquire_gil.hpp" +#include "pymrc/utilities/object_wrappers.hpp" + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace py = pybind11; + +namespace mrc::pymrc { + +/** + * @brief A MRC Scheduler which allows resuming C++20 coroutines on an Asyncio event loop. 
+ */ +class AsyncioScheduler : public mrc::coroutines::Scheduler +{ + private: + class ContinueOnLoopOperation + { + public: + ContinueOnLoopOperation(PyObjectHolder loop) : m_loop(std::move(loop)) {} + + static bool await_ready() noexcept + { + return false; + } + + void await_suspend(std::coroutine_handle<> handle) noexcept + { + AsyncioScheduler::resume(m_loop, handle); + } + + static void await_resume() noexcept {} + + private: + PyObjectHolder m_loop; + }; + + static void resume(PyObjectHolder loop, std::coroutine_handle<> handle) noexcept + { + pybind11::gil_scoped_acquire acquire; + loop.attr("call_soon_threadsafe")(pybind11::cpp_function([handle]() { + pybind11::gil_scoped_release release; + handle.resume(); + })); + } + + public: + AsyncioScheduler(PyObjectHolder loop) : m_loop(std::move(loop)) {} + + /** + * @brief Resumes a coroutine on the scheduler's Asyncio event loop + */ + void resume(std::coroutine_handle<> handle) noexcept override + { + AsyncioScheduler::resume(m_loop, handle); + } + + /** + * @brief Suspends the current function and resumes it on the scheduler's Asyncio event loop + */ + [[nodiscard]] coroutines::Task<> schedule() override + { + co_await ContinueOnLoopOperation(m_loop); + } + + /** + * @brief Suspends the current function and resumes it on the scheduler's Asyncio event loop + */ + [[nodiscard]] coroutines::Task<> yield() override + { + co_await ContinueOnLoopOperation(m_loop); + } + + private: + mrc::pymrc::PyHolder m_loop; +}; + +} // namespace mrc::pymrc diff --git a/python/mrc/_pymrc/include/pymrc/coro.hpp b/python/mrc/_pymrc/include/pymrc/coro.hpp new file mode 100644 index 000000000..ad8224a58 --- /dev/null +++ b/python/mrc/_pymrc/include/pymrc/coro.hpp @@ -0,0 +1,444 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include // for operator<<, basic_ostringstream +#include // for runtime_error +#include // for string +#include + +// Dont directly include python headers +// IWYU pragma: no_include + +namespace mrc::pymrc::coro { + +class PYBIND11_EXPORT StopIteration : public pybind11::stop_iteration +{ + public: + StopIteration(pybind11::object&& result) : stop_iteration("--"), m_result(std::move(result)){}; + ~StopIteration() override; + + void set_error() const override + { + PyErr_SetObject(PyExc_StopIteration, this->m_result.ptr()); + } + + private: + pybind11::object m_result; +}; + +class PYBIND11_EXPORT CppToPyAwaitable : public std::enable_shared_from_this +{ + public: + CppToPyAwaitable() = default; + + template + CppToPyAwaitable(mrc::coroutines::Task&& task) + { + auto converter = [](mrc::coroutines::Task incoming_task) -> mrc::coroutines::Task { + DCHECK_EQ(PyGILState_Check(), 0) << "Should not have the GIL when resuming a C++ coroutine"; + + mrc::pymrc::PyHolder holder; + + if constexpr (std::is_same_v) + { + co_await incoming_task; + + // Need the GIL to make the return object + pybind11::gil_scoped_acquire gil; + + holder = pybind11::none(); + } + else + { + auto result = co_await incoming_task; + + // Need the GIL to cast the return object + pybind11::gil_scoped_acquire gil; + + holder = pybind11::cast(std::move(result)); + } + + co_return holder; + }; + + m_task = converter(std::move(task)); + } + + CppToPyAwaitable(mrc::coroutines::Task&& task) : m_task(std::move(task)) {} + + std::shared_ptr iter() + { + return this->shared_from_this(); + } + + std::shared_ptr await() + { + return this->shared_from_this(); + } + + void next() + { + // Need to release the GIL before waiting + pybind11::gil_scoped_release nogil; + + // Run the tick function which will resume the coroutine + this->tick(); + + if (m_task.is_ready()) + { + pybind11::gil_scoped_acquire gil; + + // job done -> throw + auto exception = StopIteration(std::move(m_task.promise().result())); + + // Destroy the task now that we have the value + m_task.destroy(); + + throw exception; + } + } + + protected: + virtual void tick() + { + if (!m_has_resumed) + { + m_has_resumed = true; + + m_task.resume(); + } + } + + bool m_has_resumed{false}; + mrc::coroutines::Task m_task; +}; + +/** + * @brief Similar to CppToPyAwaitable but will yield to other fibers when waiting for the coroutine to finish. 
Use this + * once per loop at the main entry point for the asyncio loop + * + */ +class PYBIND11_EXPORT BoostFibersMainPyAwaitable : public CppToPyAwaitable +{ + public: + using CppToPyAwaitable::CppToPyAwaitable; + + protected: + void tick() override + { + // Call the base class and then see if any fibers need processing by calling yield + CppToPyAwaitable::tick(); + + bool has_fibers = boost::fibers::has_ready_fibers(); + + if (has_fibers) + { + // Yield to other fibers + boost::this_fiber::yield(); + } + } +}; + +class PYBIND11_EXPORT PyTaskToCppAwaitable +{ + public: + PyTaskToCppAwaitable() = default; + PyTaskToCppAwaitable(mrc::pymrc::PyObjectHolder&& task) : m_task(std::move(task)) + { + pybind11::gil_scoped_acquire acquire; + + auto asyncio = pybind11::module_::import("asyncio"); + + if (not asyncio.attr("isfuture")(m_task).cast()) + { + if (not asyncio.attr("iscoroutine")(m_task).cast()) + { + throw std::runtime_error(MRC_CONCAT_STR("PyTaskToCppAwaitable expected task or coroutine but got " + << pybind11::repr(m_task).cast())); + } + + m_task = asyncio.attr("create_task")(m_task); + } + } + + static bool await_ready() noexcept + { + // Always suspend + return false; + } + + void await_suspend(std::coroutine_handle<> caller) noexcept + { + pybind11::gil_scoped_acquire gil; + + auto done_callback = pybind11::cpp_function([this, caller](pybind11::object future) { + try + { + // Save the result value + m_result = future.attr("result")(); + } catch (pybind11::error_already_set) + { + m_exception_ptr = std::current_exception(); + } + + pybind11::gil_scoped_release nogil; + + // Resume the coroutine + caller.resume(); + }); + + m_task.attr("add_done_callback")(done_callback); + } + + mrc::pymrc::PyHolder await_resume() + { + if (m_exception_ptr) + { + std::rethrow_exception(m_exception_ptr); + } + + return std::move(m_result); + } + + private: + mrc::pymrc::PyObjectHolder m_task; + mrc::pymrc::PyHolder m_result; + std::exception_ptr m_exception_ptr; +}; + +// ====== HELPER MACROS ====== + +#define MRC_PYBIND11_FAIL_ABSTRACT(cname, fnname) \ + pybind11::pybind11_fail(MRC_CONCAT_STR("Tried to call pure virtual function \"" << PYBIND11_STRINGIFY(cname) \ + << "::" << fnname << "\"")); + +// ====== OVERRIDE PURE TEMPLATE ====== +#define MRC_PYBIND11_OVERRIDE_PURE_TEMPLATE_NAME(ret_type, abstract_cname, cname, name, fn, ...) \ + do \ + { \ + PYBIND11_OVERRIDE_IMPL(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, __VA_ARGS__); \ + if constexpr (std::is_same_v) \ + { \ + MRC_PYBIND11_FAIL_ABSTRACT(PYBIND11_TYPE(abstract_cname), name); \ + } \ + else \ + { \ + return cname::fn(__VA_ARGS__); \ + } \ + } while (false) + +#define MRC_PYBIND11_OVERRIDE_PURE_TEMPLATE(ret_type, abstract_cname, cname, fn, ...) \ + MRC_PYBIND11_OVERRIDE_PURE_TEMPLATE_NAME(PYBIND11_TYPE(ret_type), \ + PYBIND11_TYPE(abstract_cname), \ + PYBIND11_TYPE(cname), \ + #fn, \ + fn, \ + __VA_ARGS__) +// ====== OVERRIDE PURE TEMPLATE ====== + +// ====== OVERRIDE COROUTINE IMPL ====== +#define MRC_PYBIND11_OVERRIDE_CORO_IMPL(ret_type, cname, name, ...) 
\ + do \ + { \ + DCHECK_EQ(PyGILState_Check(), 0) << "Should not have the GIL when resuming a C++ coroutine"; \ + pybind11::gil_scoped_acquire gil; \ + pybind11::function override = pybind11::get_override(static_cast(this), name); \ + if (override) \ + { \ + auto o_coro = override(__VA_ARGS__); \ + auto asyncio_module = pybind11::module::import("asyncio"); \ + /* Return type must be a coroutine to allow calling asyncio.create_task() */ \ + if (!asyncio_module.attr("iscoroutine")(o_coro).cast()) \ + { \ + pybind11::pybind11_fail(MRC_CONCAT_STR("Return value from overriden async function " \ + << PYBIND11_STRINGIFY(cname) << "::" << name \ + << " did not return a coroutine. Returned: " \ + << pybind11::str(o_coro).cast())); \ + } \ + auto o_task = asyncio_module.attr("create_task")(o_coro); \ + mrc::pymrc::PyHolder o_result; \ + { \ + pybind11::gil_scoped_release nogil; \ + o_result = co_await mrc::pymrc::coro::PyTaskToCppAwaitable(std::move(o_task)); \ + DCHECK_EQ(PyGILState_Check(), 0) << "Should not have the GIL after returning from co_await"; \ + } \ + if (pybind11::detail::cast_is_temporary_value_reference::value) \ + { \ + static pybind11::detail::override_caster_t caster; \ + co_return pybind11::detail::cast_ref(std::move(o_result), caster); \ + } \ + co_return pybind11::detail::cast_safe(std::move(o_result)); \ + } \ + } while (false) +// ====== OVERRIDE COROUTINE IMPL====== + +// ====== OVERRIDE COROUTINE ====== +#define MRC_PYBIND11_OVERRIDE_CORO_NAME(ret_type, cname, name, fn, ...) \ + do \ + { \ + MRC_PYBIND11_OVERRIDE_CORO_IMPL(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, __VA_ARGS__); \ + return cname::fn(__VA_ARGS__); \ + } while (false) + +#define MRC_PYBIND11_OVERRIDE_CORO(ret_type, cname, fn, ...) \ + MRC_PYBIND11_OVERRIDE_CORO_NAME(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), #fn, fn, __VA_ARGS__) +// ====== OVERRIDE COROUTINE ====== + +// ====== OVERRIDE COROUTINE PURE====== +#define MRC_PYBIND11_OVERRIDE_CORO_PURE_NAME(ret_type, cname, name, fn, ...) \ + do \ + { \ + MRC_PYBIND11_OVERRIDE_CORO_IMPL(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, __VA_ARGS__); \ + MRC_PYBIND11_FAIL_ABSTRACT(PYBIND11_TYPE(cname), name); \ + } while (false) + +#define MRC_PYBIND11_OVERRIDE_CORO_PURE(ret_type, cname, fn, ...) \ + MRC_PYBIND11_OVERRIDE_CORO_PURE_NAME(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), #fn, fn, __VA_ARGS__) +// ====== OVERRIDE COROUTINE PURE====== + +// ====== OVERRIDE COROUTINE PURE TEMPLATE====== +#define MRC_PYBIND11_OVERRIDE_CORO_PURE_TEMPLATE_NAME(ret_type, abstract_cname, cname, name, fn, ...) \ + do \ + { \ + MRC_PYBIND11_OVERRIDE_CORO_IMPL(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, __VA_ARGS__); \ + if constexpr (std::is_same_v) \ + { \ + MRC_PYBIND11_FAIL_ABSTRACT(PYBIND11_TYPE(abstract_cname), name); \ + } \ + else \ + { \ + co_return co_await cname::fn(__VA_ARGS__); \ + } \ + } while (false) + +#define MRC_PYBIND11_OVERRIDE_CORO_PURE_TEMPLATE(ret_type, abstract_cname, cname, fn, ...) \ + MRC_PYBIND11_OVERRIDE_CORO_PURE_TEMPLATE_NAME(PYBIND11_TYPE(ret_type), \ + PYBIND11_TYPE(abstract_cname), \ + PYBIND11_TYPE(cname), \ + #fn, \ + fn, \ + __VA_ARGS__) +// ====== OVERRIDE COROUTINE PURE TEMPLATE====== + +} // namespace mrc::pymrc::coro + +// NOLINTNEXTLINE(modernize-concat-nested-namespaces) +namespace PYBIND11_NAMESPACE { +namespace detail { + +/** + * @brief Provides a type caster for converting a C++ coroutine to a python awaitable. Include this file in any pybind11 + * module to automatically convert the types. 
Allows for converting arguments and return values. + * + * @tparam ReturnT The return type of the coroutine + */ +template +struct type_caster> +{ + public: + /** + * This macro establishes the name 'inty' in + * function signatures and declares a local variable + * 'value' of type inty + */ + PYBIND11_TYPE_CASTER(mrc::coroutines::Task, _("typing.Awaitable[") + make_caster::name + _("]")); + + /** + * Conversion part 1 (Python->C++): convert a PyObject into a inty + * instance or return false upon failure. The second argument + * indicates whether implicit conversions should be applied. + */ + bool load(handle src, bool convert) + { + if (!src || src.is_none()) + { + return false; + } + + if (!PyCoro_CheckExact(src.ptr())) + { + return false; + } + + auto cpp_coro = [](mrc::pymrc::PyHolder py_task) -> mrc::coroutines::Task { + DCHECK_EQ(PyGILState_Check(), 0) << "Should not have the GIL when resuming a C++ coroutine"; + + // Always assume we are resuming without the GIL + pybind11::gil_scoped_acquire gil; + + auto asyncio_task = pybind11::module_::import("asyncio").attr("create_task")(py_task); + + mrc::pymrc::PyHolder py_result; + { + // Release the GIL before awaiting + pybind11::gil_scoped_release nogil; + + py_result = co_await mrc::pymrc::coro::PyTaskToCppAwaitable(std::move(asyncio_task)); + } + + // Now cast back to the C++ type + if (pybind11::detail::cast_is_temporary_value_reference::value) + { + static pybind11::detail::override_caster_t caster; + co_return pybind11::detail::cast_ref(std::move(py_result), caster); + } + co_return pybind11::detail::cast_safe(std::move(py_result)); + }; + + value = cpp_coro(pybind11::reinterpret_borrow(std::move(src))); + + return true; + } + + /** + * Conversion part 2 (C++ -> Python): convert an inty instance into + * a Python object. The second and third arguments are used to + * indicate the return value policy and parent object (for + * ``return_value_policy::reference_internal``) and are generally + * ignored by implicit casters. + */ + static handle cast(mrc::coroutines::Task src, return_value_policy policy, handle parent) + { + // Wrap the object in a CppToPyAwaitable + std::shared_ptr awaitable = + std::make_shared(std::move(src)); + + // Convert the object to a python object + auto py_awaitable = pybind11::cast(std::move(awaitable)); + + return py_awaitable.release(); + } +}; + +} // namespace detail +} // namespace PYBIND11_NAMESPACE diff --git a/python/mrc/_pymrc/include/pymrc/utilities/function_wrappers.hpp b/python/mrc/_pymrc/include/pymrc/utilities/function_wrappers.hpp index f6f5c3c30..83e243d63 100644 --- a/python/mrc/_pymrc/include/pymrc/utilities/function_wrappers.hpp +++ b/python/mrc/_pymrc/include/pymrc/utilities/function_wrappers.hpp @@ -27,7 +27,6 @@ #include #include -#include #include #include #include diff --git a/python/mrc/_pymrc/src/coro.cpp b/python/mrc/_pymrc/src/coro.cpp new file mode 100644 index 000000000..8bb57cb84 --- /dev/null +++ b/python/mrc/_pymrc/src/coro.cpp @@ -0,0 +1,26 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pymrc/coro.hpp" + +namespace mrc::pymrc::coro { + +namespace py = pybind11; + +StopIteration::~StopIteration() = default; + +} // namespace mrc::pymrc::coro diff --git a/python/mrc/_pymrc/src/executor.cpp b/python/mrc/_pymrc/src/executor.cpp index a62e2c1e7..8e1ad5c67 100644 --- a/python/mrc/_pymrc/src/executor.cpp +++ b/python/mrc/_pymrc/src/executor.cpp @@ -25,7 +25,6 @@ #include "mrc/types.hpp" #include -#include #include #include #include diff --git a/python/mrc/_pymrc/src/module_registry.cpp b/python/mrc/_pymrc/src/module_registry.cpp index 424eb2b68..bedcf7ebf 100644 --- a/python/mrc/_pymrc/src/module_registry.cpp +++ b/python/mrc/_pymrc/src/module_registry.cpp @@ -28,7 +28,6 @@ #include #include -#include #include #include #include diff --git a/python/mrc/_pymrc/src/module_wrappers/pickle.cpp b/python/mrc/_pymrc/src/module_wrappers/pickle.cpp index fd6e99290..378fa83e2 100644 --- a/python/mrc/_pymrc/src/module_wrappers/pickle.cpp +++ b/python/mrc/_pymrc/src/module_wrappers/pickle.cpp @@ -24,7 +24,6 @@ #include #include -#include #include #include diff --git a/python/mrc/_pymrc/src/module_wrappers/shared_memory.cpp b/python/mrc/_pymrc/src/module_wrappers/shared_memory.cpp index 9a4106f76..7eac9864f 100644 --- a/python/mrc/_pymrc/src/module_wrappers/shared_memory.cpp +++ b/python/mrc/_pymrc/src/module_wrappers/shared_memory.cpp @@ -20,10 +20,9 @@ #include "pymrc/utilities/object_cache.hpp" #include -#include +#include // IWYU pragma: keep #include -#include #include #include #include diff --git a/python/mrc/_pymrc/src/segment.cpp b/python/mrc/_pymrc/src/segment.cpp index 4e60e63e4..f5b931cf0 100644 --- a/python/mrc/_pymrc/src/segment.cpp +++ b/python/mrc/_pymrc/src/segment.cpp @@ -28,12 +28,9 @@ #include "mrc/channel/status.hpp" #include "mrc/edge/edge_builder.hpp" #include "mrc/node/port_registry.hpp" -#include "mrc/node/rx_sink_base.hpp" -#include "mrc/node/rx_source_base.hpp" #include "mrc/runnable/context.hpp" #include "mrc/segment/builder.hpp" #include "mrc/segment/object.hpp" -#include "mrc/types.hpp" #include #include @@ -44,7 +41,6 @@ #include #include #include -#include #include #include #include @@ -52,7 +48,6 @@ #include #include #include -#include // IWYU thinks we need array for py::print // IWYU pragma: no_include diff --git a/python/mrc/_pymrc/src/subscriber.cpp b/python/mrc/_pymrc/src/subscriber.cpp index 35f795175..c00aaa187 100644 --- a/python/mrc/_pymrc/src/subscriber.cpp +++ b/python/mrc/_pymrc/src/subscriber.cpp @@ -28,7 +28,6 @@ #include #include -#include #include #include #include @@ -128,12 +127,6 @@ PySubscription ObservableProxy::subscribe(PyObjectObservable* self, PyObjectSubs return self->subscribe(subscriber); } -template -PyObjectObservable pipe_ops(const PyObjectObservable* self, OpsT&&... ops) -{ - return (*self | ... 
| ops); -} - PyObjectObservable ObservableProxy::pipe(const PyObjectObservable* self, py::args args) { std::vector operators; @@ -150,66 +143,19 @@ PyObjectObservable ObservableProxy::pipe(const PyObjectObservable* self, py::arg operators.emplace_back(op.get_operate_fn()); } - switch (operators.size()) + if (operators.empty()) + { + throw std::runtime_error("pipe() must be given at least one argument"); + } + + auto result = *self | operators[0]; + + for (auto i = 1; i < operators.size(); i++) { - case 1: - return pipe_ops(self, operators[0]); - case 2: - return pipe_ops(self, operators[0], operators[1]); - case 3: - return pipe_ops(self, operators[0], operators[1], operators[2]); - case 4: - return pipe_ops(self, operators[0], operators[1], operators[2], operators[3]); - case 5: - return pipe_ops(self, operators[0], operators[1], operators[2], operators[3], operators[4]); - case 6: - return pipe_ops(self, operators[0], operators[1], operators[2], operators[3], operators[4], operators[5]); - case 7: - return pipe_ops(self, - operators[0], - operators[1], - operators[2], - operators[3], - operators[4], - operators[5], - operators[6]); - case 8: - return pipe_ops(self, - operators[0], - operators[1], - operators[2], - operators[3], - operators[4], - operators[5], - operators[6], - operators[7]); - case 9: - return pipe_ops(self, - operators[0], - operators[1], - operators[2], - operators[3], - operators[4], - operators[5], - operators[6], - operators[7], - operators[8]); - case 10: - return pipe_ops(self, - operators[0], - operators[1], - operators[2], - operators[3], - operators[4], - operators[5], - operators[6], - operators[7], - operators[8], - operators[9]); - default: - // Not supported error - throw std::runtime_error("pipe() only supports up 10 arguments. Please use another pipe() to use more"); + result = result | operators[i]; } + + return result; } } // namespace mrc::pymrc diff --git a/python/mrc/_pymrc/src/utilities/object_cache.cpp b/python/mrc/_pymrc/src/utilities/object_cache.cpp index 604a21200..574afc2a2 100644 --- a/python/mrc/_pymrc/src/utilities/object_cache.cpp +++ b/python/mrc/_pymrc/src/utilities/object_cache.cpp @@ -24,7 +24,6 @@ #include #include -#include #include #include #include diff --git a/python/mrc/_pymrc/src/watchers.cpp b/python/mrc/_pymrc/src/watchers.cpp index d474d7ae4..114bc6dac 100644 --- a/python/mrc/_pymrc/src/watchers.cpp +++ b/python/mrc/_pymrc/src/watchers.cpp @@ -24,8 +24,8 @@ #include "mrc/benchmarking/tracer.hpp" #include "mrc/node/rx_node.hpp" #include "mrc/node/rx_sink.hpp" -#include "mrc/node/rx_source.hpp" #include "mrc/segment/builder.hpp" +#include "mrc/segment/object.hpp" #include #include @@ -34,11 +34,9 @@ #include #include -#include #include #include #include -#include namespace mrc::pymrc { diff --git a/python/mrc/_pymrc/tests/CMakeLists.txt b/python/mrc/_pymrc/tests/CMakeLists.txt index 4ac354a78..02186de90 100644 --- a/python/mrc/_pymrc/tests/CMakeLists.txt +++ b/python/mrc/_pymrc/tests/CMakeLists.txt @@ -17,8 +17,11 @@ list(APPEND CMAKE_MESSAGE_CONTEXT "tests") find_package(pybind11 REQUIRED) +add_subdirectory(coro) + # Keep all source files sorted!!! 
add_executable(test_pymrc + test_asyncio_runnable.cpp test_codable_pyobject.cpp test_executor.cpp test_main.cpp diff --git a/python/mrc/_pymrc/tests/coro/CMakeLists.txt b/python/mrc/_pymrc/tests/coro/CMakeLists.txt new file mode 100644 index 000000000..788d04832 --- /dev/null +++ b/python/mrc/_pymrc/tests/coro/CMakeLists.txt @@ -0,0 +1,29 @@ +# ============================================================================= +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +list(APPEND CMAKE_MESSAGE_CONTEXT "coro") + +set(MODULE_SOURCE_FILES) + +# Add the module file +list(APPEND MODULE_SOURCE_FILES module.cpp) + +# Create the python module +mrc_add_pybind11_module(coro + INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/include + SOURCE_FILES ${MODULE_SOURCE_FILES} + LINK_TARGETS mrc::pymrc +) + +list(POP_BACK CMAKE_MESSAGE_CONTEXT) diff --git a/python/mrc/_pymrc/tests/coro/module.cpp b/python/mrc/_pymrc/tests/coro/module.cpp new file mode 100644 index 000000000..c5332c78c --- /dev/null +++ b/python/mrc/_pymrc/tests/coro/module.cpp @@ -0,0 +1,70 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include + +#include +#include + +mrc::coroutines::Task subtract(int a, int b) +{ + co_return a - b; +} + +mrc::coroutines::Task call_fib_async(mrc::pymrc::PyHolder fib, int value, int minus) +{ + auto result = co_await subtract(value, minus); + co_return co_await mrc::pymrc::coro::PyTaskToCppAwaitable([](auto fib, auto result) { + pybind11::gil_scoped_acquire acquire; + return fib(result); + }(fib, result)); +} + +mrc::coroutines::Task raise_at_depth_async(mrc::pymrc::PyHolder fn, int depth) +{ + if (depth <= 0) + { + throw std::runtime_error("depth reached zero in c++"); + } + + co_return co_await mrc::pymrc::coro::PyTaskToCppAwaitable([](auto fn, auto depth) { + pybind11::gil_scoped_acquire acquire; + return fn(depth - 1); + }(fn, depth)); +} + +mrc::coroutines::Task call_async(mrc::pymrc::PyHolder fn) +{ + co_return co_await mrc::pymrc::coro::PyTaskToCppAwaitable([](auto fn) { + pybind11::gil_scoped_acquire acquire; + return fn(); + }(fn)); +} + +PYBIND11_MODULE(coro, _module) +{ + pybind11::module_::import("mrc.core.coro"); // satisfies automatic type conversions for tasks + + _module.def("call_fib_async", &call_fib_async); + _module.def("raise_at_depth_async", &raise_at_depth_async); + _module.def("call_async", &call_async); +} diff --git a/python/mrc/_pymrc/tests/test_asyncio_runnable.cpp b/python/mrc/_pymrc/tests/test_asyncio_runnable.cpp new file mode 100644 index 000000000..a46bea824 --- /dev/null +++ b/python/mrc/_pymrc/tests/test_asyncio_runnable.cpp @@ -0,0 +1,331 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "pymrc/asyncio_runnable.hpp" +#include "pymrc/coro.hpp" +#include "pymrc/executor.hpp" +#include "pymrc/pipeline.hpp" +#include "pymrc/utilities/object_wrappers.hpp" + +#include "mrc/coroutines/async_generator.hpp" +#include "mrc/coroutines/sync_wait.hpp" +#include "mrc/coroutines/task.hpp" +#include "mrc/node/rx_sink.hpp" +#include "mrc/node/rx_source.hpp" +#include "mrc/options/engine_groups.hpp" +#include "mrc/options/options.hpp" +#include "mrc/options/topology.hpp" +#include "mrc/runnable/types.hpp" +#include "mrc/segment/builder.hpp" +#include "mrc/segment/object.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; +namespace pymrc = mrc::pymrc; +using namespace std::string_literals; +using namespace py::literals; + +class __attribute__((visibility("default"))) TestAsyncioRunnable : public ::testing::Test +{ + public: + static void SetUpTestSuite() + { + m_interpreter = std::make_unique(); + pybind11::gil_scoped_acquire acquire; + pybind11::module_::import("mrc.core.coro"); + } + + static void TearDownTestSuite() + { + m_interpreter.reset(); + } + + private: + static std::unique_ptr m_interpreter; +}; + +std::unique_ptr TestAsyncioRunnable::m_interpreter; + +class __attribute__((visibility("default"))) PythonCallbackAsyncioRunnable : public pymrc::AsyncioRunnable +{ + public: + PythonCallbackAsyncioRunnable(pymrc::PyObjectHolder operation) : m_operation(std::move(operation)) {} + + mrc::coroutines::AsyncGenerator on_data(int&& value) override + { + py::gil_scoped_acquire acquire; + + auto coroutine = m_operation(py::cast(value)); + + pymrc::PyObjectHolder result; + { + py::gil_scoped_release release; + + result = co_await pymrc::coro::PyTaskToCppAwaitable(std::move(coroutine)); + } + + auto result_casted = py::cast(result); + + py::gil_scoped_release release; + + co_yield result_casted; + }; + + private: + pymrc::PyObjectHolder m_operation; +}; + +TEST_F(TestAsyncioRunnable, UseAsyncioTasks) +{ + py::object globals = py::globals(); + py::exec( + R"( + async def fn(value): + import asyncio + await asyncio.sleep(0) + return value * 2 + )", + globals); + + pymrc::PyObjectHolder fn = static_cast(globals["fn"]); + + ASSERT_FALSE(fn.is_none()); + + std::atomic counter = 0; + pymrc::Pipeline p; + + auto init = [&counter, &fn](mrc::segment::IBuilder& seg) { + auto src = seg.make_source("src", [](rxcpp::subscriber& s) { + if (s.is_subscribed()) + { + s.on_next(5); + s.on_next(10); + } + + s.on_completed(); + }); + + auto internal = seg.construct_object("internal", fn); + + auto sink = seg.make_sink("sink", [&counter](int x) { + counter.fetch_add(x, std::memory_order_relaxed); + }); + + seg.make_edge(src, internal); + seg.make_edge(internal, sink); + }; + + p.make_segment("seg1"s, init); + p.make_segment("seg2"s, init); + + auto options = std::make_shared(); + options->topology().user_cpuset("0"); + // AsyncioRunnable only works with the Thread engine due to asyncio loops being thread-specific. 
+ options->engine_factories().set_default_engine_type(mrc::runnable::EngineType::Thread); + + pymrc::Executor exec{options}; + exec.register_pipeline(p); + + exec.start(); + exec.join(); + + EXPECT_EQ(counter, 60); +} + +TEST_F(TestAsyncioRunnable, UseAsyncioGeneratorThrows) +{ + // pybind11::module_::import("mrc.core.coro"); + + py::object globals = py::globals(); + py::exec( + R"( + async def fn(value): + yield value + )", + globals); + + pymrc::PyObjectHolder fn = static_cast(globals["fn"]); + + ASSERT_FALSE(fn.is_none()); + + std::atomic counter = 0; + pymrc::Pipeline p; + + auto init = [&counter, &fn](mrc::segment::IBuilder& seg) { + auto src = seg.make_source("src", [](rxcpp::subscriber& s) { + if (s.is_subscribed()) + { + s.on_next(5); + s.on_next(10); + } + + s.on_completed(); + }); + + auto internal = seg.construct_object("internal", fn); + + auto sink = seg.make_sink("sink", [&counter](int x) { + counter.fetch_add(x, std::memory_order_relaxed); + }); + + seg.make_edge(src, internal); + seg.make_edge(internal, sink); + }; + + p.make_segment("seg1"s, init); + p.make_segment("seg2"s, init); + + auto options = std::make_shared(); + options->topology().user_cpuset("0"); + // AsyncioRunnable only works with the Thread engine due to asyncio loops being thread-specific. + options->engine_factories().set_default_engine_type(mrc::runnable::EngineType::Thread); + + pymrc::Executor exec{options}; + exec.register_pipeline(p); + + exec.start(); + + ASSERT_THROW(exec.join(), std::runtime_error); +} + +TEST_F(TestAsyncioRunnable, UseAsyncioTasksThrows) +{ + // pybind11::module_::import("mrc.core.coro"); + + py::object globals = py::globals(); + py::exec( + R"( + async def fn(value): + raise RuntimeError("oops") + )", + globals); + + pymrc::PyObjectHolder fn = static_cast(globals["fn"]); + + ASSERT_FALSE(fn.is_none()); + + std::atomic counter = 0; + pymrc::Pipeline p; + + auto init = [&counter, &fn](mrc::segment::IBuilder& seg) { + auto src = seg.make_source("src", [](rxcpp::subscriber& s) { + if (s.is_subscribed()) + { + s.on_next(5); + s.on_next(10); + } + + s.on_completed(); + }); + + auto internal = seg.construct_object("internal", fn); + + auto sink = seg.make_sink("sink", [&counter](int x) { + counter.fetch_add(x, std::memory_order_relaxed); + }); + + seg.make_edge(src, internal); + seg.make_edge(internal, sink); + }; + + p.make_segment("seg1"s, init); + p.make_segment("seg2"s, init); + + auto options = std::make_shared(); + options->topology().user_cpuset("0"); + // AsyncioRunnable only works with the Thread engine due to asyncio loops being thread-specific. + options->engine_factories().set_default_engine_type(mrc::runnable::EngineType::Thread); + + pymrc::Executor exec{options}; + exec.register_pipeline(p); + + exec.start(); + + ASSERT_THROW(exec.join(), std::runtime_error); +} + +template +auto run_operation(OperationT& operation) -> mrc::coroutines::Task +{ + auto stop_source = std::stop_source(); + + auto coro = [](auto& operation, auto stop_source) -> mrc::coroutines::Task { + try + { + auto value = co_await operation(); + stop_source.request_stop(); + co_return value; + } catch (...) 
+ { + stop_source.request_stop(); + throw; + } + }(operation, stop_source); + + coro.resume(); + + while (not stop_source.stop_requested()) + { + if (boost::fibers::has_ready_fibers()) + { + boost::this_fiber::yield(); + } + } + + co_return co_await coro; +} + +TEST_F(TestAsyncioRunnable, BoostFutureAwaitableOperationCanReturn) +{ + auto operation = mrc::pymrc::BoostFutureAwaitableOperation<int()>([]() { + using namespace std::chrono_literals; + boost::this_fiber::sleep_for(10ms); + return 5; + }); + + ASSERT_EQ(mrc::coroutines::sync_wait(run_operation(operation)), 5); +} + +TEST_F(TestAsyncioRunnable, BoostFutureAwaitableOperationCanThrow) +{ + auto operation = mrc::pymrc::BoostFutureAwaitableOperation<int()>([]() { + throw std::runtime_error("oops"); + return 5; + }); + + ASSERT_THROW(mrc::coroutines::sync_wait(run_operation(operation)), std::runtime_error); +} diff --git a/python/mrc/_pymrc/tests/test_executor.cpp b/python/mrc/_pymrc/tests/test_executor.cpp index 41e284d91..20ea8b10d 100644 --- a/python/mrc/_pymrc/tests/test_executor.cpp +++ b/python/mrc/_pymrc/tests/test_executor.cpp @@ -33,11 +33,9 @@ #include #include -#include #include #include #include -#include namespace py = pybind11; namespace pymrc = mrc::pymrc; diff --git a/python/mrc/_pymrc/tests/test_pipeline.cpp b/python/mrc/_pymrc/tests/test_pipeline.cpp index 68091ba14..7b375d21a 100644 --- a/python/mrc/_pymrc/tests/test_pipeline.cpp +++ b/python/mrc/_pymrc/tests/test_pipeline.cpp @@ -31,9 +31,7 @@ #include "mrc/options/topology.hpp" #include "mrc/segment/builder.hpp" #include "mrc/segment/object.hpp" -#include "mrc/types.hpp" -#include #include #include #include @@ -46,7 +44,6 @@ #include #include #include -#include #include #include #include diff --git a/python/mrc/_pymrc/tests/test_serializers.cpp b/python/mrc/_pymrc/tests/test_serializers.cpp index cbf5147c5..e6c72e27c 100644 --- a/python/mrc/_pymrc/tests/test_serializers.cpp +++ b/python/mrc/_pymrc/tests/test_serializers.cpp @@ -28,7 +28,6 @@ #include #include // IWYU pragma: keep -#include #include #include #include diff --git a/python/mrc/_pymrc/tests/test_utils.cpp b/python/mrc/_pymrc/tests/test_utils.cpp index a802009fc..713bdc5f4 100644 --- a/python/mrc/_pymrc/tests/test_utils.cpp +++ b/python/mrc/_pymrc/tests/test_utils.cpp @@ -34,7 +34,6 @@ #include #include #include -#include #include #include #include diff --git a/python/mrc/benchmarking/watchers.cpp b/python/mrc/benchmarking/watchers.cpp index 2a4b3418f..920826239 100644 --- a/python/mrc/benchmarking/watchers.cpp +++ b/python/mrc/benchmarking/watchers.cpp @@ -26,11 +26,9 @@ #include // IWYU pragma: keep #include -#include #include #include #include -#include namespace mrc::pymrc { namespace py = pybind11; diff --git a/python/mrc/core/CMakeLists.txt b/python/mrc/core/CMakeLists.txt index d635e071f..f04b17f1f 100644 --- a/python/mrc/core/CMakeLists.txt +++ b/python/mrc/core/CMakeLists.txt @@ -16,6 +16,7 @@ list(APPEND CMAKE_MESSAGE_CONTEXT "core") mrc_add_pybind11_module(common SOURCE_FILES common.cpp) +mrc_add_pybind11_module(coro SOURCE_FILES coro.cpp) mrc_add_pybind11_module(executor SOURCE_FILES executor.cpp) mrc_add_pybind11_module(logging SOURCE_FILES logging.cpp) mrc_add_pybind11_module(node SOURCE_FILES node.cpp) diff --git a/python/mrc/core/common.cpp b/python/mrc/core/common.cpp index 741fec61b..7dde55b4b 100644 --- a/python/mrc/core/common.cpp +++ b/python/mrc/core/common.cpp @@ -18,21 +18,15 @@ #include "pymrc/port_builders.hpp" #include "pymrc/types.hpp" -#include "mrc/node/rx_sink_base.hpp" -#include "mrc/node/rx_source_base.hpp" -#include "mrc/types.hpp" #include "mrc/utils/string_utils.hpp" #include "mrc/version.hpp" -#include #include #include #include -#include #include #include -#include namespace mrc::pymrc { diff --git a/python/mrc/core/coro.cpp b/python/mrc/core/coro.cpp new file mode 100644 index 000000000..d647a7b11 --- /dev/null +++ b/python/mrc/core/coro.cpp @@ -0,0 +1,67 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pymrc/coro.hpp" + +#include +#include +#include +#include +#include +#include // IWYU pragma: keep + +#include +#include +#include +#include + +namespace mrc::pymrc::coro { + +namespace py = pybind11; + +PYBIND11_MODULE(coro, _module) +{ + _module.doc() = R"pbdoc( + ----------------------- + .. currentmodule:: morpheus.llm + .. autosummary:: + :toctree: _generate + + )pbdoc"; + + py::class_<CppToPyAwaitable, std::shared_ptr<CppToPyAwaitable>>(_module, "CppToPyAwaitable") + .def(py::init<>()) + .def("__iter__", &CppToPyAwaitable::iter) + .def("__await__", &CppToPyAwaitable::await) + .def("__next__", &CppToPyAwaitable::next); + + py::class_<BoostFibersMainPyAwaitable, CppToPyAwaitable, std::shared_ptr<BoostFibersMainPyAwaitable>>( // + _module, + "BoostFibersMainPyAwaitable") + .def(py::init<>()); + + _module.def("wrap_coroutine", [](coroutines::Task<std::vector<std::string>> fn) -> coroutines::Task<std::string> { + DCHECK_EQ(PyGILState_Check(), 0) << "Should not have the GIL when resuming a C++ coroutine"; + + auto strings = co_await fn; + + co_return strings[0]; + }); + + // _module.attr("__version__") = + // MRC_CONCAT_STR(morpheus_VERSION_MAJOR << "." << morpheus_VERSION_MINOR << "."
<< morpheus_VERSION_PATCH); +} +} // namespace mrc::pymrc::coro diff --git a/python/mrc/core/operators.cpp b/python/mrc/core/operators.cpp index b74ff96ec..be931fc27 100644 --- a/python/mrc/core/operators.cpp +++ b/python/mrc/core/operators.cpp @@ -28,7 +28,6 @@ #include #include // IWYU pragma: keep -#include #include namespace mrc::pymrc { diff --git a/python/mrc/core/pipeline.cpp b/python/mrc/core/pipeline.cpp index 2f1dcf970..a6e9f0b5e 100644 --- a/python/mrc/core/pipeline.cpp +++ b/python/mrc/core/pipeline.cpp @@ -27,7 +27,6 @@ #include #include // IWYU pragma: keep -#include #include namespace mrc::pymrc { diff --git a/python/mrc/core/segment.cpp b/python/mrc/core/segment.cpp index ed87f83f2..addba6813 100644 --- a/python/mrc/core/segment.cpp +++ b/python/mrc/core/segment.cpp @@ -38,12 +38,9 @@ #include #include -#include -#include #include #include #include -#include namespace mrc::pymrc { diff --git a/python/mrc/core/segment/module_definitions/mirror_tap_orchestrator.cpp b/python/mrc/core/segment/module_definitions/mirror_tap_orchestrator.cpp index 570bd3c69..49aee1f7b 100644 --- a/python/mrc/core/segment/module_definitions/mirror_tap_orchestrator.cpp +++ b/python/mrc/core/segment/module_definitions/mirror_tap_orchestrator.cpp @@ -26,18 +26,14 @@ #include "mrc/experimental/modules/stream_buffer/stream_buffer_module.hpp" #include "mrc/modules/module_registry.hpp" #include "mrc/modules/module_registry_util.hpp" -#include "mrc/node/operators/broadcast.hpp" -#include "mrc/node/rx_sink.hpp" -#include "mrc/node/rx_source.hpp" #include "mrc/version.hpp" #include #include // IWYU pragma: keep #include #include +#include -#include -#include #include #include #include diff --git a/python/mrc/core/segment/module_definitions/segment_module_registry.cpp b/python/mrc/core/segment/module_definitions/segment_module_registry.cpp index 0ae7b5728..86d21f65c 100644 --- a/python/mrc/core/segment/module_definitions/segment_module_registry.cpp +++ b/python/mrc/core/segment/module_definitions/segment_module_registry.cpp @@ -25,12 +25,9 @@ #include #include // IWYU pragma: keep #include -#include #include // IWYU pragma: keep -#include #include -#include #include #include diff --git a/python/mrc/core/segment/module_definitions/segment_modules.cpp b/python/mrc/core/segment/module_definitions/segment_modules.cpp index 08332dd40..5cc22f61d 100644 --- a/python/mrc/core/segment/module_definitions/segment_modules.cpp +++ b/python/mrc/core/segment/module_definitions/segment_modules.cpp @@ -25,9 +25,7 @@ #include #include -#include #include -#include namespace mrc::pymrc { diff --git a/python/mrc/core/subscriber.cpp b/python/mrc/core/subscriber.cpp index 656ff6884..0a9458f9f 100644 --- a/python/mrc/core/subscriber.cpp +++ b/python/mrc/core/subscriber.cpp @@ -27,8 +27,8 @@ #include // IWYU pragma: keep #include // IWYU pragma: keep(for call_guard) #include +#include -#include #include #include diff --git a/python/mrc/tests/sample_modules.cpp b/python/mrc/tests/sample_modules.cpp index 8bd6d354e..041d67a91 100644 --- a/python/mrc/tests/sample_modules.cpp +++ b/python/mrc/tests/sample_modules.cpp @@ -20,15 +20,12 @@ #include "pymrc/utils.hpp" #include "mrc/modules/module_registry_util.hpp" -#include "mrc/node/rx_source.hpp" #include "mrc/utils/string_utils.hpp" #include "mrc/version.hpp" #include #include -#include -#include #include #include diff --git a/python/mrc/tests/test_edges.cpp b/python/mrc/tests/test_edges.cpp index 1e9cc0359..ccac5a2d7 100644 --- a/python/mrc/tests/test_edges.cpp +++ 
b/python/mrc/tests/test_edges.cpp @@ -24,29 +24,22 @@ #include "mrc/channel/status.hpp" #include "mrc/edge/edge_connector.hpp" -#include "mrc/node/rx_sink_base.hpp" -#include "mrc/node/rx_source_base.hpp" #include "mrc/segment/builder.hpp" #include "mrc/segment/object.hpp" -#include "mrc/types.hpp" #include "mrc/utils/string_utils.hpp" #include "mrc/version.hpp" -#include #include #include #include #include -#include #include #include -#include #include #include #include #include -#include namespace mrc::pytests { diff --git a/python/mrc/tests/utils.cpp b/python/mrc/tests/utils.cpp index 35a64d6e5..d700df879 100644 --- a/python/mrc/tests/utils.cpp +++ b/python/mrc/tests/utils.cpp @@ -21,6 +21,7 @@ #include "mrc/version.hpp" #include +#include <pybind11/gil.h> // for gil_scoped_acquire #include #include @@ -30,6 +31,16 @@ namespace mrc::pytests { namespace py = pybind11; +// Simple test class which uses pybind11's `gil_scoped_acquire` class in the destructor. Needed to repro #362 +struct RequireGilInDestructor +{ + ~RequireGilInDestructor() + { + // Grab the GIL + py::gil_scoped_acquire gil; + } +}; + PYBIND11_MODULE(utils, py_mod) { py_mod.doc() = R"pbdoc()pbdoc"; @@ -48,6 +59,8 @@ PYBIND11_MODULE(utils, py_mod) }, py::arg("msg") = ""); + py::class_<RequireGilInDestructor>(py_mod, "RequireGilInDestructor").def(py::init<>()); + py_mod.attr("__version__") = MRC_CONCAT_STR(mrc_VERSION_MAJOR << "." << mrc_VERSION_MINOR << "." << mrc_VERSION_PATCH); } diff --git a/python/tests/test_coro.py b/python/tests/test_coro.py new file mode 100644 index 000000000..940160f18 --- /dev/null +++ b/python/tests/test_coro.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
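+ +# These tests exercise the new mrc.core.coro bindings together with the helpers exposed from mrc._pymrc.tests.coro.coro: wrap_coroutine round-trips a Python awaitable through a C++ coroutine, while call_async, call_fib_async and raise_at_depth_async invoke the Python coroutines passed to them from C++.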
+ +import asyncio + +import pytest + +from mrc._pymrc.tests.coro.coro import call_async +from mrc._pymrc.tests.coro.coro import call_fib_async +from mrc._pymrc.tests.coro.coro import raise_at_depth_async +from mrc.core import coro + + +@pytest.mark.asyncio +async def test_coro(): + + # hit_inside = False + + async def inner(): + + # nonlocal hit_inside + + result = await coro.wrap_coroutine(asyncio.sleep(1, result=['a', 'b', 'c'])) + + # hit_inside = True + + return [result] + + returned_val = await coro.wrap_coroutine(inner()) + + assert returned_val == 'a' + # assert hit_inside + + +@pytest.mark.asyncio +async def test_coro_many(): + + expected_count = 1000 + hit_count = 0 + + start_time = asyncio.get_running_loop().time() + + async def inner(): + + nonlocal hit_count + + await asyncio.sleep(0.1) + + hit_count += 1 + + return ['a', 'b', 'c'] + + coros = [coro.wrap_coroutine(inner()) for _ in range(expected_count)] + + returned_vals = await asyncio.gather(*coros) + + end_time = asyncio.get_running_loop().time() + + assert returned_vals == ['a'] * expected_count + assert hit_count == expected_count + assert (end_time - start_time) < 1.5 + + +@pytest.mark.asyncio +async def test_python_cpp_async_interleave(): + + def fib(n): + if n < 0: + raise ValueError() + + if n < 2: + return 1 + + return fib(n - 1) + fib(n - 2) + + async def fib_async(n): + if n < 0: + raise ValueError() + + if n < 2: + return 1 + + task_a = call_fib_async(fib_async, n, 1) + task_b = call_fib_async(fib_async, n, 2) + + [a, b] = await asyncio.gather(task_a, task_b) + + return a + b + + assert fib(15) == await fib_async(15) + + +@pytest.mark.asyncio +async def test_python_cpp_async_exception(): + + async def py_raise_at_depth_async(n: int): + if n <= 0: + raise RuntimeError("depth reached zero in python") + + await raise_at_depth_async(py_raise_at_depth_async, n - 1) + + depth = 100 + + with pytest.raises(RuntimeError) as ex: + await raise_at_depth_async(py_raise_at_depth_async, depth + 1) + assert "python" in str(ex.value) + + with pytest.raises(RuntimeError) as ex: + await raise_at_depth_async(py_raise_at_depth_async, depth) + assert "c++" in str(ex.value) + + +@pytest.mark.asyncio +async def test_can_cancel_coroutine_from_python(): + + counter = 0 + + async def increment_recursively(): + nonlocal counter + await asyncio.sleep(0) + counter += 1 + await call_async(increment_recursively) + + task = asyncio.ensure_future(call_async(increment_recursively)) + + await asyncio.sleep(0) + assert counter == 0 + await asyncio.sleep(0) + await asyncio.sleep(0) + assert counter == 1 + await asyncio.sleep(0) + await asyncio.sleep(0) + assert counter == 2 + + task.cancel() + + with pytest.raises(asyncio.exceptions.CancelledError): + await task + + assert counter == 3 diff --git a/python/tests/test_gil_tls.py b/python/tests/test_gil_tls.py new file mode 100644 index 000000000..eca5a23d7 --- /dev/null +++ b/python/tests/test_gil_tls.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import threading + +import mrc +from mrc.tests.utils import RequireGilInDestructor + +TLS = threading.local() + + +def test_gil_thread_local_storage(): + """ + Test to reproduce issue #362 + No asserts needed if it doesn't segfault, then we're good + """ + + def source_gen(): + x = RequireGilInDestructor() + TLS.x = x + yield x + + def init_seg(builder: mrc.Builder): + builder.make_source("source_gen", source_gen) + + pipe = mrc.Pipeline() + pipe.make_segment("seg1", init_seg) + + options = mrc.Options() + executor = mrc.Executor(options) + executor.register_pipeline(pipe) + executor.start() + executor.join()
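+ + # Storing the object in thread-local storage keeps it alive until the worker thread is finalized, so its destructor (which re-acquires the GIL) runs during thread teardown; if that path regresses, the join() above segfaults (issue #362).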