Skip to content

Commit

Permalink
Fix CPU build
Browse files (browse the repository at this point in the history)
  • Loading branch information
hcho3 committed Sep 11, 2023
1 parent c6cff74 commit f015c8c
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 24 deletions.
14 changes: 4 additions & 10 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ else()
rapids_cuda_init_architectures(RAPIDS_TRITON_BACKEND)
project(RAPIDS_TRITON_BACKEND VERSION 22.10.00 LANGUAGES CXX CUDA)
else()
project(RAPIDS_TRITON_BACKEND VERSION 22.10.00 LANGUAGES CXX)
project(RAPIDS_TRITON_BACKEND VERSION 22.10.00 LANGUAGES CXX CUDA)
endif()

##############################################################################
Expand Down Expand Up @@ -166,11 +166,7 @@ else()
set(RAPIDS_TRITON_MIN_VERSION_rapids_projects "${RAPIDS_DEPENDENCIES_VERSION}.00")
set(RAPIDS_TRITON_BRANCH_VERSION_rapids_projects "${RAPIDS_DEPENDENCIES_VERSION}")

if(TRITON_ENABLE_GPU)
include(cmake/thirdparty/get_cuml.cmake)
else()
include(cmake/thirdparty/get_treelite.cmake)
endif()
include(cmake/thirdparty/get_cuml.cmake)
include(cmake/thirdparty/get_rapids-triton.cmake)

if(BUILD_TESTS)
Expand Down Expand Up @@ -230,7 +226,7 @@ else()

target_link_libraries(${BACKEND_TARGET}
PRIVATE
$<$<BOOL:${TRITON_ENABLE_GPU}>:cuml++>
cuml++
${TREELITE_LIBS}
rapids_triton::rapids_triton
triton-core-serverstub
Expand All @@ -240,9 +236,7 @@ else()
OpenMP::OpenMP_CXX
)

if(TRITON_ENABLE_GPU)
list(APPEND BACKEND_TARGET "cuml++")
endif()
list(APPEND BACKEND_TARGET "cuml++")

if(NOT TRITON_FIL_USE_TREELITE_STATIC)
list(APPEND BACKEND_TARGET ${TREELITE_LIBS_NO_PREFIX})
Expand Down
9 changes: 5 additions & 4 deletions cmake/thirdparty/get_cuml.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

function(find_and_configure_cuml)

set(oneValueArgs VERSION FORK PINNED_TAG USE_TREELITE_STATIC)
set(oneValueArgs VERSION FORK PINNED_TAG USE_TREELITE_STATIC TRITON_ENABLE_GPU)
cmake_parse_arguments(PKG "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN} )

Expand All @@ -43,6 +43,7 @@ function(find_and_configure_cuml)
"BUILD_CUML_STD_COMMS OFF"
"BUILD_SHARED_LIBS ON"
"CUML_USE_TREELITE_STATIC ${PKG_USE_TREELITE_STATIC}"
"CUML_ENABLE_GPU ${PKG_TRITON_ENABLE_GPU}"
"USE_CCACHE ON"
"RAFT_COMPILE_LIBRARIES OFF"
"RAFT_ENABLE_NN_DEPENDENCIES OFF"
Expand All @@ -56,7 +57,7 @@ endfunction()
# To use a different RAFT locally, set the CMake variable
# CPM_raft_SOURCE=/path/to/local/raft
find_and_configure_cuml(VERSION ${RAPIDS_TRITON_MIN_VERSION_rapids_projects}
FORK rapidsai
PINNED_TAG branch-23.08
FORK hcho3
PINNED_TAG fix_cpu_fil
USE_TREELITE_STATIC ${TRITON_FIL_USE_TREELITE_STATIC}
)
TRITON_ENABLE_GPU ${TRITON_ENABLE_GPU})
8 changes: 2 additions & 6 deletions src/cpu_forest_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,11 @@ namespace filex = ML::experimental::fil;
template <>
struct ForestModel<rapids::HostMemory> {
ForestModel() = default;
// TODO(hcho3): Add a filex::forest_model::predict() that does not require
// a RAFT handle. Currently, we need to pass a RAFT handle, which in turn
// requires a working CUDA stream.
using device_id_t = int;
ForestModel(
device_id_t device_id, cudaStream_t stream,
std::shared_ptr<TreeliteModel> tl_model, bool use_new_fil)
: device_id_{device_id}, raft_handle_{stream}, tl_model_{tl_model},
: device_id_{device_id}, tl_model_{tl_model},
new_fil_model_{[this, use_new_fil]() {
auto result = std::optional<filex::forest_model>{};
if (use_new_fil) {
Expand Down Expand Up @@ -100,7 +97,7 @@ struct ForestModel<rapids::HostMemory> {
// TODO(hcho3): Revise new FIL so that it takes in (const io_t*) type for
// input buffer
new_fil_model_->predict(
raft_proto::handle_t{raft_handle_}, output_buffer.data(),
raft_proto::handle_t{}, output_buffer.data(),
const_cast<float*>(input.data()), samples,
get_raft_proto_device_type(output.mem_type()),
get_raft_proto_device_type(input.mem_type()),
Expand All @@ -123,7 +120,6 @@ struct ForestModel<rapids::HostMemory> {


private:
raft::handle_t raft_handle_;
std::shared_ptr<TreeliteModel> tl_model_;
device_id_t device_id_;
// TODO(hcho3): Make filex::forest_model::predict() a const method
Expand Down
5 changes: 1 addition & 4 deletions src/forest_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@

#ifdef TRITON_ENABLE_GPU
#include <cuda_runtime_api.h>

#include <cuml/experimental/fil/detail/raft_proto/device_type.hpp>
#endif

#include <names.h>
#include <tl_model.h>

#include <cstddef>
#include <cuml/experimental/fil/detail/raft_proto/device_type.hpp>
#include <memory>
#include <rapids_triton/exceptions.hpp>
#include <rapids_triton/memory/buffer.hpp>
Expand Down Expand Up @@ -66,7 +65,6 @@ struct ForestModel {
}
};

#ifdef TRITON_ENABLE_GPU
// TODO(hcho3): Remove this once raft_proto becomes part of RAFT or
// Rapids-Triton
raft_proto::device_type
Expand All @@ -78,6 +76,5 @@ get_raft_proto_device_type(rapids::MemoryType mem_type)
return raft_proto::device_type::cpu;
}
}
#endif

}}} // namespace triton::backend::NAMESPACE

0 comments on commit f015c8c

Please sign in to comment.