From 19451ab847fc48aa4fab311ba92ea811ff517c22 Mon Sep 17 00:00:00 2001 From: Jason Du <171721794+jadu-nv@users.noreply.github.com> Date: Sat, 17 Aug 2024 23:01:18 -0700 Subject: [PATCH] Revert changes made to CMake and add thorough commenting --- .../include/mrc/runtime/remote_descriptor.hpp | 42 ++++++ .../data_plane/data_plane_resources.cpp | 37 +++-- .../data_plane/data_plane_resources.hpp | 13 +- .../src/internal/ucx/registration_cache.cpp | 1 + .../src/internal/ucx/registration_cache.hpp | 14 +- cpp/mrc/src/public/codable/decode.cpp | 2 + cpp/mrc/src/public/codable/encode.cpp | 5 + .../src/public/runtime/remote_descriptor.cpp | 5 +- cpp/mrc/src/tests/CMakeLists.txt | 18 +++ cpp/mrc/src/tests/test_codable.cpp | 129 ++++-------------- cpp/mrc/src/tests/test_network.cpp | 54 ++++---- protos/mrc/protos/codable.proto | 8 +- 12 files changed, 180 insertions(+), 148 deletions(-) diff --git a/cpp/mrc/include/mrc/runtime/remote_descriptor.hpp b/cpp/mrc/include/mrc/runtime/remote_descriptor.hpp index 227cfb89c..143d7fd00 100644 --- a/cpp/mrc/include/mrc/runtime/remote_descriptor.hpp +++ b/cpp/mrc/include/mrc/runtime/remote_descriptor.hpp @@ -418,21 +418,58 @@ std::unique_ptr> TypedValueDescriptor::from_local( new TypedValueDescriptor(mrc::codable::decode2(local_descriptor->encoded_object()))); } +/** + * @brief Descriptor2 class used to facilitate communication between any arbitrary pair of machines. Supports multi-node, + * multi-gpu communication, and asynchronous data transfer. 
+ */ class Descriptor2 { public: + /** + * @brief Gets the protobuf object associated with this descriptor instance + * + * @return codable::protos::DescriptorObject& + */ virtual codable::protos::DescriptorObject& encoded_object(); + /** + * @brief Serialize the encoded object stored by this descriptor into a byte stream for remote communication + * + * @param mr Instance of memory_resource for allocating a memory_buffer to return + * @return memory::buffer + */ memory::buffer serialize(std::shared_ptr mr); + /** + * @brief Deserialize the encoded object stored by this descriptor into a class T instance + * + * @return T + */ template [[nodiscard]] const T deserialize(); + /** + * @brief Creates a Descriptor2 instance from a class T value + * + * @param value class T instance + * @param data_plane_resources reference to DataPlaneResources2 for remote communication + * @return std::shared_ptr + */ template static std::shared_ptr create_from_value(T value, data_plane::DataPlaneResources2& data_plane_resources); + /** + * @brief Creates a Descriptor2 instance from a byte stream + * + * @param view byte stream + * @param data_plane_resources reference to DataPlaneResources2 for remote communication + * @return std::shared_ptr + */ static std::shared_ptr create_from_bytes(memory::buffer_view&& view, data_plane::DataPlaneResources2& data_plane_resources); + /** + * @brief Fetches all deferred payloads from the sending remote machine + */ void fetch_remote_payloads(); protected: @@ -452,12 +489,16 @@ class Descriptor2 data_plane::DataPlaneResources2& m_data_plane_resources; }; +/** + * @brief Class used for type erasure of Descriptor2 when serialized with class T instance + */ template class TypedDescriptor : public Descriptor2 { public: codable::protos::DescriptorObject& encoded_object() { + // If the encoded object does not exist yet, lazily create it if (!m_encoded_object) { m_encoded_object = std::move(mrc::codable::encode2(std::any_cast(m_value))); @@ -470,6 +511,7 
@@ class TypedDescriptor : public Descriptor2 template friend std::shared_ptr Descriptor2::create_from_value(U value, data_plane::DataPlaneResources2& data_plane_resources); + // Private constructor to prohibit instantiation of this class outside of use in create_from_value TypedDescriptor(T value, data_plane::DataPlaneResources2& data_plane_resources): Descriptor2(std::move(value), data_plane_resources) {} }; diff --git a/cpp/mrc/src/internal/data_plane/data_plane_resources.cpp b/cpp/mrc/src/internal/data_plane/data_plane_resources.cpp index 81cf66c1c..5c452ad24 100644 --- a/cpp/mrc/src/internal/data_plane/data_plane_resources.cpp +++ b/cpp/mrc/src/internal/data_plane/data_plane_resources.cpp @@ -165,7 +165,10 @@ DataPlaneResources2::DataPlaneResources2() m_registration_cache3 = std::make_shared(m_context); // When DataPlanResources2 initializes, m_max_remote_descriptors is initialized to std::numeric_limits::max() - // By default, m_remote_descriptors_semaphore should have capacity = practical limit, 100000 + // By default, the following should have capacity = practical limit, 100000 + m_recv_descriptors = std::unique_ptr>>( + new coroutines::ClosableRingBuffer>({.capacity = uint64_t(100000)})); + m_remote_descriptors_semaphore = std::unique_ptr( new coroutines::Semaphore{{.capacity = static_cast(100000)}}); @@ -180,6 +183,7 @@ DataPlaneResources2::DataPlaneResources2() auto* message = reinterpret_cast( req->getRecvBuffer()->data()); + // Await for callback function to decrement shared_ptr reference count or signal end-of-life of a descriptor coroutines::sync_wait(complete_remote_pull(message)); }); m_worker->registerAmReceiverCallback( @@ -199,12 +203,16 @@ DataPlaneResources2::DataPlaneResources2() req->getRecvBuffer()->getSize(), mrc::memory::memory_kind::host); - std::shared_ptr recv_descriptor = runtime::Descriptor2::create_from_bytes(std::move(buffer_view), *this); + // Create the descriptor object from received data + // Note that we do not immediately 
call fetch_remote_payloads. This callback function runs on a UCXX thread and + // should be freed ASAP. Defer remote payload pulling until the descriptor object is actually consumed + std::shared_ptr recv_descriptor = + runtime::Descriptor2::create_from_bytes(std::move(buffer_view), *this); // Although ClosableRingBuffer::write is a coroutine, write always completes instantaneously without awaiting. // ClosablRingBuffer size is always >= m_max_remote_descriptors, so there is always an empty slot. auto write_descriptor = [this, recv_descriptor]() -> coroutines::Task { - co_await m_recv_descriptors.write(recv_descriptor); + co_await m_recv_descriptors->write(recv_descriptor); co_return; }; @@ -340,7 +348,6 @@ std::shared_ptr DataPlaneResources2::memory_recv_async(std::share uintptr_t remote_addr, const std::string& serialized_rkey) { - // Const cast away because UCXX only accepts void* auto rkey = ucxx::createRemoteKeyFromSerialized(endpoint, serialized_rkey); auto request = endpoint->memGet(addr, bytes, rkey); @@ -408,8 +415,9 @@ coroutines::Task> DataPlaneResources2::await_am_s { coroutines::Event event{}; - // Const cast away because UCXX only accepts void* - auto request = endpoint->amSend(const_cast(buffer_view.data()), + // Use AmReceiverCallbackInfo to handle receiving/processing message downstream and lambda callback function + // to signal send request completion + auto request = endpoint->amSend(const_cast(buffer_view.data()), // Const cast away, UCXX only accepts void* buffer_view.bytes(), ucx::to_ucs_memory_type(buffer_view.kind()), ucxx::AmReceiverCallbackInfo("MRC", 1), @@ -440,7 +448,6 @@ coroutines::Task> DataPlaneResources2::await_send std::shared_ptr endpoint) { // Wait until there is an empty slot to register remote descriptor - // Require register_remote_descriptor before descriptor serialize as serialize requires object_id to be in the protobuf uint64_t object_id = co_await this->register_remote_descriptor(send_descriptor); // Serialize the descriptor's 
protobuf into a byte stream for remote communication @@ -452,9 +459,11 @@ coroutines::Task> DataPlaneResources2::await_send coroutines::Task> DataPlaneResources2::await_recv_descriptor() { - auto read_element = co_await m_recv_descriptors.read(); + // Await and get descriptor object from shared buffer + auto read_element = co_await m_recv_descriptors->read(); std::shared_ptr recv_descriptor = std::move(*read_element); + // Now that user is consuming the descriptor object, pull deferred payloads from remote machine recv_descriptor->fetch_remote_payloads(); co_return recv_descriptor; @@ -462,7 +471,7 @@ coroutines::Task> DataPlaneResources2::awa coroutines::Task DataPlaneResources2::register_remote_descriptor(std::shared_ptr descriptor) { - // If the descriptor has an object_id > 0, the descriptor has already been registered and should not be re-registered + // If the descriptor has an object_id > 0, descriptor has already been registered and should not be re-registered auto object_id = descriptor->encoded_object().object_id(); if (object_id > 0) { @@ -477,8 +486,7 @@ coroutines::Task DataPlaneResources2::register_remote_descriptor(std:: descriptor->encoded_object().set_object_id(object_id); // Wait for semaphore to ensure that we have an empty slot to register the current descriptor - co_await m_remote_descriptors_semaphore->acquire(); // Directly await the semaphore - + co_await m_remote_descriptors_semaphore->acquire(); { std::unique_lock lock(m_remote_descriptors_mutex); m_descriptor_by_id[object_id].push_back(descriptor); @@ -494,7 +502,6 @@ coroutines::Task DataPlaneResources2::complete_remote_pull(remote_descript // Once we've completed pulling of a descriptor, we remove a descriptor shared ptr from the vector // When the vector becomes empty, there will be no more shared ptrs pointing to the descriptor object, // it will be destructed accordingly. - // We should also remove that mapping as the object_id corresponding to that mapping will not be reused. 
auto& descriptors = m_descriptor_by_id[message->object_id]; descriptors.pop_back(); if (descriptors.size() == 0) @@ -523,6 +530,12 @@ uint64_t DataPlaneResources2::registered_remote_descriptor_ptr_count(uint64_t ob void DataPlaneResources2::set_max_remote_descriptors(uint64_t max_remote_descriptors) { m_max_remote_descriptors = max_remote_descriptors; + + // Update the remote descriptor ClosableRingBuffer and Semaphore capacity + m_recv_descriptors = std::unique_ptr>>( + new coroutines::ClosableRingBuffer>( + {.capacity = std::min(m_max_remote_descriptors, uint64_t(100000))})); + m_remote_descriptors_semaphore = std::unique_ptr( new coroutines::Semaphore{{.capacity = std::min(m_max_remote_descriptors, static_cast(100000))}}); } diff --git a/cpp/mrc/src/internal/data_plane/data_plane_resources.hpp b/cpp/mrc/src/internal/data_plane/data_plane_resources.hpp index 71a470701..c174a3e88 100644 --- a/cpp/mrc/src/internal/data_plane/data_plane_resources.hpp +++ b/cpp/mrc/src/internal/data_plane/data_plane_resources.hpp @@ -131,6 +131,7 @@ class DataPlaneResources2 bool has_instance_id() const; uint64_t get_instance_id() const; + // Should only be called when there are no in-flight messages as m_recv_descriptors will be reset void set_max_remote_descriptors(uint64_t max_remote_descriptors); ucxx::Context& context() const; @@ -202,14 +203,19 @@ class DataPlaneResources2 std::size_t bytes, ucs_memory_type_t mem_type); + // Coroutine to asynchronously send message to remote machine coroutines::Task> await_am_send(std::shared_ptr endpoint, memory::const_buffer_view buffer_view); std::shared_ptr am_recv_async(std::shared_ptr endpoint); + // Coroutine to async register, serialize, and send a descriptor to the specified endpoint + // Relies on callback to receive the message. 
Must be used in tandem with await_recv_descriptor coroutines::Task> await_send_descriptor( std::shared_ptr send_descriptor, std::shared_ptr endpoint); + + // Coroutine to async await on new descriptor object in shared buffer, fetch deferred payloads from remote machine coroutines::Task> await_recv_descriptor(); coroutines::Task register_remote_descriptor(std::shared_ptr descriptor); @@ -238,6 +244,8 @@ class DataPlaneResources2 uint64_t get_next_object_id(); + // Callback function to decrement shared_ptr reference count or signal end-of-life of a descriptor object + // Requires awaiting on the release of coroutines::Semaphore coroutines::Task complete_remote_pull(remote_descriptor::DescriptorPullCompletionMessage* message); uint64_t m_max_remote_descriptors{std::numeric_limits::max()}; @@ -249,10 +257,11 @@ class DataPlaneResources2 boost::fibers::mutex m_remote_descriptors_mutex{}; // ClosableRingBuffer uses 100000 as a "practical" limit where the capacity is the minimum of the two values. 
- coroutines::ClosableRingBuffer> m_recv_descriptors{ - {.capacity = std::min(m_max_remote_descriptors, static_cast(100000))}}; + std::unique_ptr>> m_recv_descriptors; protected: + // Maps descriptor id to a vector of shared_ptr instances + // Uses std::shared_ptr reference counting for maintaining the lifetime of a descriptor object std::map>> m_descriptor_by_id; }; diff --git a/cpp/mrc/src/internal/ucx/registration_cache.cpp b/cpp/mrc/src/internal/ucx/registration_cache.cpp index f6fcf8a1b..8e8ca8560 100644 --- a/cpp/mrc/src/internal/ucx/registration_cache.cpp +++ b/cpp/mrc/src/internal/ucx/registration_cache.cpp @@ -153,6 +153,7 @@ std::optional> RegistrationCache3::lookup(ui { std::lock_guard lock(m_mutex); + // The descriptor obj_id and memory block addr must both be valid if (m_memory_handle_by_address.find(obj_id) != m_memory_handle_by_address.end()) { auto descriptor_handles = m_memory_handle_by_address.at(obj_id); diff --git a/cpp/mrc/src/internal/ucx/registration_cache.hpp b/cpp/mrc/src/internal/ucx/registration_cache.hpp index 04753e4c2..c72077406 100644 --- a/cpp/mrc/src/internal/ucx/registration_cache.hpp +++ b/cpp/mrc/src/internal/ucx/registration_cache.hpp @@ -181,7 +181,7 @@ class RegistrationCache2 final * @brief UCX Registration Cache * * UCX memory registration object that will both register/deregister memory. The cache can be queried for the original - * memory block by providing the starting address of the contiguous block. + * memory block by providing the id of the descriptor object and the starting address of the contiguous block. */ class RegistrationCache3 final { @@ -194,8 +194,11 @@ class RegistrationCache3 final * For each block of memory registered with the RegistrationCache, an entry containing the block information is * storage and can be queried. 
* + * @param obj_id ID of the descriptor object that owns the memory block being registered * @param addr * @param bytes + * @param memory_type + * @return std::shared_ptr */ std::shared_ptr add_block(uint64_t obj_id, void* addr, std::size_t bytes, memory::memory_kind memory_type); @@ -207,6 +210,7 @@ class RegistrationCache3 final * This method queries the registration cache to find the MemoryHanlde containing the original address and size as * well as the serialized remote keys associated with the memory block. * + * @param obj_id ID of the descriptor object that owns the memory block being registered * @param addr * @return std::shared_ptr */ @@ -214,6 +218,14 @@ class RegistrationCache3 final std::optional> lookup(uint64_t obj_id, uintptr_t addr) const noexcept; + /** + * @brief Deregistration of all memory blocks owned by the descriptor object with id obj_id + * + * This method deregisters all memory blocks owned by the descriptor object at the end of the descriptor's lifetime. + * Required so the system does not run into memory insufficiency errors. 
+ * + * @param obj_id ID of the descriptor object that owns the memory block being registered + */ void remove_descriptor(uint64_t obj_id); private: diff --git a/cpp/mrc/src/public/codable/decode.cpp b/cpp/mrc/src/public/codable/decode.cpp index 6c3d2b838..23479031c 100644 --- a/cpp/mrc/src/public/codable/decode.cpp +++ b/cpp/mrc/src/public/codable/decode.cpp @@ -38,6 +38,8 @@ void DecoderBase::read_descriptor(memory::buffer_view dst_view) const else { const auto& deferred_msg = payload.deferred_msg(); + + // Depending on the message memory type, we will use a different memcpy method to properly copy the data switch (payload.memory_kind()) { case protos::MemoryKind::Host: diff --git a/cpp/mrc/src/public/codable/encode.cpp b/cpp/mrc/src/public/codable/encode.cpp index abec5160f..ae5178e57 100644 --- a/cpp/mrc/src/public/codable/encode.cpp +++ b/cpp/mrc/src/public/codable/encode.cpp @@ -29,6 +29,8 @@ EncoderBase::EncoderBase(DescriptorObjectHandler& encoded_object) : void EncoderBase::write_descriptor(memory::const_buffer_view view) { + // Static check with arbitrary memory size to determine whether we should use eager or deferred protocol + // Thorough benchmarking and analysis should be done to derive protocol selection heuristic MessageKind kind = (view.bytes() < 64_KiB) ? 
MessageKind::Eager : MessageKind::Deferred; protos::Payload* payload = m_encoded_object.proto().add_payloads(); @@ -37,10 +39,12 @@ void EncoderBase::write_descriptor(memory::const_buffer_view view) switch (kind) { case MessageKind::Eager: { + // If the message is allocated on device memory, we should fall through and default to using deferred protocol if (view.kind() == memory::memory_kind::host) { auto* eager_msg = payload->mutable_eager_msg(); + // Directly set the data for eager payload eager_msg->set_data(view.data(), view.bytes()); return; @@ -49,6 +53,7 @@ void EncoderBase::write_descriptor(memory::const_buffer_view view) case MessageKind::Deferred: { auto* deferred_msg = payload->mutable_deferred_msg(); + // Set the payload address and number of bytes for later RDMA operation deferred_msg->set_address(reinterpret_cast(view.data())); deferred_msg->set_bytes(view.bytes()); diff --git a/cpp/mrc/src/public/runtime/remote_descriptor.cpp b/cpp/mrc/src/public/runtime/remote_descriptor.cpp index fe2cff61a..f2fc1d0fb 100644 --- a/cpp/mrc/src/public/runtime/remote_descriptor.cpp +++ b/cpp/mrc/src/public/runtime/remote_descriptor.cpp @@ -383,11 +383,12 @@ void Descriptor2::setup_remote_payloads() auto* deferred_msg = payload.mutable_deferred_msg(); + // Look for the memory block in the registration cache auto ucx_block = m_data_plane_resources.registration_cache3().lookup(remote_object.object_id(), deferred_msg->address()); if (!ucx_block.has_value()) { - // Need to register the memory + // Given that the memory block is not registered, we must register the memory ucx_block = m_data_plane_resources.registration_cache3().add_block(remote_object.object_id(), deferred_msg->address(), deferred_msg->bytes(), @@ -409,7 +410,7 @@ void Descriptor2::fetch_remote_payloads() // Loop over all remote payloads and convert them to local payloads for (auto& remote_payload : *m_encoded_object->proto().mutable_payloads()) { - // If payload is an EagerMessage, we do not need to do any 
pulling + // If payload is an EagerMessage, we do not need to do RDMA operations on remote sending machine if (remote_payload.has_eager_msg()) { continue; diff --git a/cpp/mrc/src/tests/CMakeLists.txt b/cpp/mrc/src/tests/CMakeLists.txt index f53e0ab05..ce78d1762 100644 --- a/cpp/mrc/src/tests/CMakeLists.txt +++ b/cpp/mrc/src/tests/CMakeLists.txt @@ -45,10 +45,28 @@ add_executable(test_mrc_private pipelines/multi_segment.cpp pipelines/single_segment.cpp segments/common_segments.cpp + test_codable.cpp # test_control_plane_components.cpp # test_control_plane.cpp + test_expected.cpp + test_grpc.cpp test_main.cpp + test_memory.cpp test_network.cpp + test_next.cpp + # test_partition_manager.cpp + # test_partitions.cpp + test_pipeline.cpp + test_ranges.cpp + # test_remote_descriptor.cpp + # test_resources.cpp + test_reusable_pool.cpp + # test_runnable.cpp + # test_runtime.cpp + test_service.cpp + test_system.cpp + test_topology.cpp + test_ucx.cpp ) target_link_libraries(test_mrc_private diff --git a/cpp/mrc/src/tests/test_codable.cpp b/cpp/mrc/src/tests/test_codable.cpp index f7ea8d9c5..25c974a87 100644 --- a/cpp/mrc/src/tests/test_codable.cpp +++ b/cpp/mrc/src/tests/test_codable.cpp @@ -70,105 +70,40 @@ class CodableObject CodableObject() = default; ~CodableObject() = default; - static CodableObject deserialize(const Decoder& buffer, std::size_t /*unused*/) + static CodableObject deserialize(const mrc::codable::Decoder2& decoder) { return {}; } - void serialize(Encoder& /*unused*/) const {} + void serialize(mrc::codable::Encoder2& encoder) const {} }; -class CodableObjectWithOptions -{ - public: - CodableObjectWithOptions() = default; - ~CodableObjectWithOptions() = default; - - static CodableObjectWithOptions deserialize(const Decoder& encoding, - std::size_t /*unused*/) - { - return {}; - } - - void serialize(Encoder& /*unused*/, const EncodingOptions& opts) const {} -}; - -class CodableViaExternalStruct -{}; - -namespace mrc::codable { - -template <> -struct 
codable_protocol -{ - void serialize(const CodableViaExternalStruct& /*unused*/, Encoder& /*unused*/) {} -}; - -}; // namespace mrc::codable - namespace mrc::codable {} struct NotCodableObject {}; -class TestCodable : public ::testing::Test -{ - protected: - void SetUp() override - { - m_runtime = std::make_unique(system::SystemProvider(tests::make_system([](Options& options) { - // todo(#114) - propose: remove this option entirely - options.enable_server(true); - options.architect_url("localhost:13337"); - options.placement().resources_strategy(PlacementResources::Dedicated); - }))); - - DVLOG(10) << "Setup Complete"; - } - - void TearDown() override - { - DVLOG(10) << "Start Teardown"; - m_runtime.reset(); - DVLOG(10) << "Teardown Complete"; - } - - std::unique_ptr m_runtime; -}; +class TestCodable : public ::testing::Test {}; TEST_F(TestCodable, Objects) { - static_assert(codable::encodable::value, "should be encodable"); - static_assert(codable::encodable::value, "should be encodable"); - static_assert(codable::encodable::value, "should be encodable"); - static_assert(!codable::encodable::value, "should NOT be encodable"); - - static_assert(codable::decodable::value, "should be decodable"); - static_assert(codable::decodable::value, "should be decodable"); - static_assert(!codable::decodable::value, "should NOT be decodable"); - static_assert(!codable::decodable::value, "should NOT be decodable"); - - static_assert(codable::value, "fully codable"); - static_assert(codable::value, "fully codable"); - static_assert(!codable::value, "half codable"); - static_assert(!codable::value, "not codable"); + static_assert(codable::encodable, "should be encodable"); + static_assert(!codable::encodable, "should NOT be encodable"); + + static_assert(codable::decodable, "should be decodable"); + static_assert(!codable::decodable, "should NOT be decodable"); } TEST_F(TestCodable, String) { - static_assert(codable::value, "should be codable"); + static_assert(codable::encodable, 
"should be encodable"); + static_assert(codable::decodable, "should be decodable"); std::string str = "Hello MRC"; - auto str_block = m_runtime->partition(0).resources().network()->data_plane().registration_cache().lookup( - str.data()); - EXPECT_FALSE(str_block); - auto encodable_storage = m_runtime->partition(0).make_codable_storage(); + std::unique_ptr encoded_obj = encode2(str); - encode(str, *encodable_storage); - EXPECT_EQ(encodable_storage->descriptor_count(), 1); - - auto decoded_str = decode(*encodable_storage); + auto decoded_str = decode2(*encoded_obj); EXPECT_STREQ(str.c_str(), decoded_str.c_str()); } @@ -185,7 +120,8 @@ void populate(int size, int* ptr) TEST_F(TestCodable, Buffer) { - static_assert(codable::value, "should be codable"); + // static_assert(codable::encodable, "should be encodable"); + // static_assert(codable::decodable, "should be decodable"); // Uncomment when local copy is working! // auto encodable_storage = m_runtime->partition(0).make_codable_storage(); @@ -209,45 +145,34 @@ TEST_F(TestCodable, Buffer) TEST_F(TestCodable, Double) { - static_assert(codable::value, "should be codable"); - - auto encodable_storage = m_runtime->partition(0).make_codable_storage(); + static_assert(codable::encodable, "should be encodable"); + static_assert(codable::decodable, "should be decodable"); double pi = 3.14159; - encode(pi, *encodable_storage); - EXPECT_EQ(encodable_storage->descriptor_count(), 1); + std::unique_ptr encoded_obj = encode2(pi); - auto decoding = decode(*encodable_storage); + auto decoding = decode2(*encoded_obj); EXPECT_DOUBLE_EQ(pi, decoding); } TEST_F(TestCodable, Composite) { - static_assert(codable::value, "should be codable"); - static_assert(codable::value, "should be codable"); + static_assert(codable::encodable, "should be encodable"); + static_assert(codable::decodable, "should be decodable"); + + static_assert(codable::encodable, "should be encodable"); + static_assert(codable::decodable, "should be decodable"); 
std::string str = "Hello Mrc"; std::uint64_t ans = 42; - auto encodable_storage = m_runtime->partition(0).make_codable_storage(); - - encode(str, *encodable_storage); - encode(ans, *encodable_storage); + std::unique_ptr encoded_obj1 = encode2(str); + std::unique_ptr encoded_obj2 = encode2(ans); - EXPECT_EQ(encodable_storage->object_count(), 2); - EXPECT_EQ(encodable_storage->descriptor_count(), 2); - - auto decoded_str = decode(*encodable_storage, 0); - auto decoded_ans = decode(*encodable_storage, 1); + auto decoded_str = decode2(*encoded_obj1); + auto decoded_ans = decode2(*encoded_obj2); EXPECT_STREQ(str.c_str(), decoded_str.c_str()); EXPECT_EQ(ans, decoded_ans); } - -TEST_F(TestCodable, EncodedObjectProto) -{ - static_assert(codable::encodable::value, "should be encodable"); - static_assert(codable::decodable::value, "should be decodable"); - static_assert(codable::value, "should be codable"); -} diff --git a/cpp/mrc/src/tests/test_network.cpp b/cpp/mrc/src/tests/test_network.cpp index a5a9388cc..4d0afaaf8 100644 --- a/cpp/mrc/src/tests/test_network.cpp +++ b/cpp/mrc/src/tests/test_network.cpp @@ -136,6 +136,9 @@ class TestNetwork : public ::testing::Test std::shared_ptr m_loopback_endpoint; }; +/** + * Serialization and deserialization methods for vector objects allocated on Host or Device memory. 
+ */ namespace mrc::codable { template struct codable_protocol> @@ -176,14 +179,14 @@ struct codable_protocol> } }; -// Serialization methods meant for testing device allocated memory template <> struct codable_protocol { static void serialize(const unsigned char* obj, mrc::codable::Encoder2& encoder) { - size_t size = 64_KiB; // Assuming a fixed size for this example + // Since unsigned char* does not carry indicator of size, specify a fixed size for testing purposes + size_t size = 64_KiB; mrc::codable::encode2(size, encoder); encoder.write_descriptor({obj, size * sizeof(unsigned char), memory::memory_kind::device}); @@ -612,7 +615,7 @@ TEST_F(TestNetwork, LocalDescriptorRoundTrip) auto send_data_copy = send_data; - // Create a descriptor that will pass throught the local path + // Create a descriptor that will pass through the local path auto descriptor = runtime::Descriptor2::create_from_value(std::move(send_data_copy), *m_resources); // deserialize the descriptor to get value @@ -630,12 +633,14 @@ TEST_F(TestNetwork, TransferFullDescriptors) auto send_data_copy = send_data; + // Create the descriptor object from value std::shared_ptr send_descriptor = runtime::Descriptor2::create_from_value(std::move(send_data_copy), *m_resources); // Check that no remote payloads are yet registered with `DataPlaneResources2`. 
EXPECT_EQ(m_resources->registered_remote_descriptor_count(), 0); + // Await for registering remote descriptor coroutine to complete uint64_t obj_id = coroutines::sync_wait(m_resources->register_remote_descriptor(send_descriptor)); // Get the serialized data @@ -666,17 +671,17 @@ TEST_F(TestNetwork, TransferFullDescriptors) receive_request->getRecvBuffer()->getSize(), mrc::memory::memory_kind::host); + // Create the descriptor object from received data std::shared_ptr recv_descriptor = runtime::Descriptor2::create_from_bytes(std::move(buffer_view), *m_resources); + // Pull the remaining deferred payloads from the remote machine recv_descriptor->fetch_remote_payloads(); uint64_t recv_descriptor_object_id = recv_descriptor->encoded_object().object_id(); EXPECT_EQ(send_descriptor_object_id, recv_descriptor_object_id); - // TODO(Peter): This is now completely async and we must progress the worker, we need a timeout in case it fails to - // complete. // Wait for remote decrement messages. while (registered_send_descriptor.lock() != nullptr) m_resources->progress(); @@ -706,12 +711,14 @@ TEST_F(TestNetwork, TransferFullDescriptorsDevice) cudaMalloc(&send_data_device, send_data_host.size() * sizeof(u_int8_t)); cudaMemcpy(send_data_device, send_data_host.data(), send_data_host.size() * sizeof(u_int8_t), cudaMemcpyHostToDevice); + // Create the descriptor object from value std::shared_ptr send_descriptor = runtime::Descriptor2::create_from_value(std::move(send_data_device), *m_resources); // Check that no remote payloads are yet registered with `DataPlaneResources2`. 
EXPECT_EQ(m_resources->registered_remote_descriptor_count(), 0); + // Await for registering remote descriptor coroutine to complete uint64_t obj_id = coroutines::sync_wait(m_resources->register_remote_descriptor(send_descriptor)); // Get the serialized data @@ -737,22 +744,21 @@ TEST_F(TestNetwork, TransferFullDescriptorsDevice) std::weak_ptr registered_send_descriptor = m_resources->get_descriptor(send_descriptor_object_id); EXPECT_NE(registered_send_descriptor.lock(), nullptr); - // Create a descriptor from the received data auto buffer_view = memory::buffer_view(receive_request->getRecvBuffer()->data(), receive_request->getRecvBuffer()->getSize(), mrc::memory::memory_kind::host); + // Create the descriptor object from received data std::shared_ptr recv_descriptor = runtime::Descriptor2::create_from_bytes(std::move(buffer_view), *m_resources); + // Pull the remaining deferred payloads from the remote machine recv_descriptor->fetch_remote_payloads(); uint64_t recv_descriptor_object_id = recv_descriptor->encoded_object().object_id(); EXPECT_EQ(send_descriptor_object_id, recv_descriptor_object_id); - // TODO(Peter): This is now completely async and we must progress the worker, we need a timeout in case it fails to - // complete. // Wait for remote decrement messages. 
while (registered_send_descriptor.lock() != nullptr) m_resources->progress(); @@ -767,6 +773,7 @@ TEST_F(TestNetwork, TransferFullDescriptorsDevice) // Finally, get the value auto recv_data_device = recv_descriptor->deserialize(); + // Copy the data into host memory to easily compare the results std::vector recv_data_host(data_size); cudaMemcpy(recv_data_host.data(), recv_data_device, data_size * sizeof(u_int8_t), cudaMemcpyDeviceToHost); @@ -798,13 +805,14 @@ TEST_F(TestNetwork, TransferFullDescriptorsBroadcast) auto send_data_copy = send_data; - // Create a descriptor + // Create the descriptor object from value std::shared_ptr send_descriptor = runtime::Descriptor2::create_from_value(std::move(send_data_copy), *m_resources); // Check that no remote payloads are yet registered with `DataPlaneResources2`. EXPECT_EQ(m_resources->registered_remote_descriptor_count(), 0); + // Await for registering remote descriptor coroutines to complete uint64_t obj_id1 = coroutines::sync_wait(m_resources->register_remote_descriptor(send_descriptor)); uint64_t obj_id2 = coroutines::sync_wait(m_resources->register_remote_descriptor(send_descriptor)); @@ -837,22 +845,22 @@ TEST_F(TestNetwork, TransferFullDescriptorsBroadcast) std::weak_ptr registered_send_descriptor = m_resources->get_descriptor(send_descriptor_object_id); EXPECT_NE(registered_send_descriptor.lock(), nullptr); - // Create a descriptor from the received data auto buffer_view = memory::buffer_view(receive_request->getRecvBuffer()->data(), receive_request->getRecvBuffer()->getSize(), mrc::memory::memory_kind::host); + // Create the descriptor object from received data std::shared_ptr recv_descriptor = runtime::Descriptor2::create_from_bytes(std::move(buffer_view), *m_resources); + // Pull the remaining deferred payloads from the remote machine recv_descriptor->fetch_remote_payloads(); - uint64_t recv_descriptor_object_id = recv_descriptor->encoded_object().object_id(); + uint64_t recv_descriptor_object_id = 
recv_descriptor->encoded_object().object_id(); if (expected_ptrs > 0) { - // TODO(Peter): This is now completely async and we must progress the worker, we need a timeout in case - // it fails to complete. Wait for remote decrement messages. + // Wait for remote decrement messages. while (m_resources->registered_remote_descriptor_ptr_count(send_descriptor_object_id) != expected_ptrs) { m_resources->progress(); @@ -864,9 +872,7 @@ TEST_F(TestNetwork, TransferFullDescriptorsBroadcast) } else { - // TODO(Peter): This is now completely async and we must progress the worker, we need a timeout in case - // it fails to complete. Wait for remote decrement messages. - + // Wait for remote decrement messages. while (registered_send_descriptor.lock() != nullptr) { m_resources->progress(); @@ -919,10 +925,11 @@ TEST_P(TestNetworkPressure, TransferPressureControl) { auto send_data_copy = send_data; + // Create the descriptor object from value std::shared_ptr send_descriptor = runtime::Descriptor2::create_from_value(std::move(send_data_copy), *m_resources); - // Register descriptor with the DataPlaneResource object + // Await for registering remote descriptor coroutines to complete uint64_t obj_id = coroutines::sync_wait(m_resources->register_remote_descriptor(send_descriptor)); // Get the serialized data and push to queue for consumption by request processing thread @@ -939,10 +946,7 @@ TEST_P(TestNetworkPressure, TransferPressureControl) auto increase_max_descriptors = [this, max_descriptors, total_descriptors, &registered_descriptors]() { // Wait until registration hits max number of descriptors and blocks - while (registered_descriptors < max_descriptors) - { - ; - } + while (registered_descriptors < max_descriptors) {} // Unblock `register_descriptors` immediately, even if requests are not being processed m_resources->set_max_remote_descriptors(total_descriptors); @@ -959,10 +963,7 @@ TEST_P(TestNetworkPressure, TransferPressureControl) // (`DataPlaneResources2` 
internal queue is full) and registrations are still ongoing or serialized data is not // available yet. while ((m_resources->registered_remote_descriptor_count() < max_descriptors && !registration_finished) || - serialized_data.empty()) - { - ; - } + serialized_data.empty()) {} auto local_serialized_data = std::move(serialized_data.front()); serialized_data.pop(); @@ -985,13 +986,14 @@ TEST_P(TestNetworkPressure, TransferPressureControl) EXPECT_NE(registered_send_descriptor.lock(), nullptr); - // Create a remote descriptor from the received data + // Create the descriptor object from received data auto recv_descriptor = runtime::Descriptor2::create_from_bytes( {receive_request->getRecvBuffer()->data(), receive_request->getRecvBuffer()->getSize(), mrc::memory::memory_kind::host}, *m_resources); + // Pull the remaining deferred payloads from the remote machine recv_descriptor->fetch_remote_payloads(); auto recv_descriptor_object_id = recv_descriptor->encoded_object().object_id(); diff --git a/protos/mrc/protos/codable.proto b/protos/mrc/protos/codable.proto index 69cacc5e4..0c9a46b3f 100644 --- a/protos/mrc/protos/codable.proto +++ b/protos/mrc/protos/codable.proto @@ -266,7 +266,9 @@ message DescriptorObject // Payloads associated with this object repeated Payload payloads = 3; - // The number of "lifetime" tokens associated with this object. When the number of tokens reaches 0, the object - // is deleted - uint64 tokens = 5; + // The number of "lifetime" tokens associated with this object. When number of tokens reaches 0, the object is deleted + + // NOTE: This object is deprecated in the new implementation but kept for backwards compatibility. + // The new implementation relies on shared_ptr built-in reference counting + uint64 tokens = 4; }