diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp index 5aeeb62bee1ec..d26d25ae03e39 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp @@ -1,11 +1,9 @@ +#ifdef USE_C10D_XCCL + #include #include -#include -#include - -#ifdef USE_C10D_XCCL -#include #include +#include #include #include #include @@ -13,15 +11,7 @@ #include #include -#include -#include -#include #include -#include -#include -#include -#include -#include namespace c10d { @@ -45,36 +35,6 @@ std::map xcclDatatypes = { {at::kBool, ccl::datatype::uint8}, }; -XCCL_KVS kvs; -std::mutex kvs_mutex; - -XCCL_KVS get_kvs(int rank, c10d::Store& store) { - std::lock_guard lock(kvs_mutex); - if (kvs) - return kvs; - std::string storeKey = "xccl_kvs"; - - // Rank 0 broadcast the bootstrap network information to other ranks - if (rank == 0) { - kvs = ccl::create_main_kvs(); - ccl::kvs::address_type main_addr = kvs->get_address(); - auto ccl_kvs_addr = - std::vector(main_addr.begin(), main_addr.end()); - store.set(storeKey, ccl_kvs_addr); - } else { - auto ccl_kvs_addr = store.get(storeKey); - if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) { - throw std::runtime_error("Unexpected ccl kvs addr from the store\n"); - } - ccl::kvs::address_type main_addr; - std::copy_n( - ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin()); - kvs = ccl::create_kvs(main_addr); - } - - return kvs; -} - void check_xpu_single_tensor(const at::Tensor& tensor) { if (!tensor.is_xpu() || tensor.is_sparse()) { C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense"); @@ -106,23 +66,9 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) { } return xcclOps.at(reduceOp); } catch (const std::out_of_range&) { - switch (reduceOp) { - case ReduceOp::AVG: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp AVG with XCCL"); - break; - case ReduceOp::BAND: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with XCCL"); - break; - case ReduceOp::BOR: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with XCCL"); - break; - case ReduceOp::BXOR: - C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with XCCL"); - break; - default: - C10_THROW_ERROR(ValueError, "Unhandled ReduceOp"); - break; - } + C10_THROW_ERROR( + ValueError, + "Cannot use ReduceOp." + reduce_op_to_string(reduceOp) + " with XCCL"); } } @@ -153,20 +99,6 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w) ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default; -bool ProcessGroupXCCL::WorkXCCL::checkTimeout( - std::optional timeout) { - auto currentTimepoint = std::chrono::steady_clock::now(); - auto timeElapsed = std::chrono::duration_cast( - currentTimepoint - workStartTime_); - std::chrono::milliseconds opTimeout = std::chrono::milliseconds(60000); - - auto workTimeout = timeout ? *timeout : opTimeout; - - if (timeElapsed < workTimeout) - return false; - return true; -} - bool ProcessGroupXCCL::WorkXCCL::isCompleted() { if (xcclEndEvent_ && xcclEndEvent_->query()) { return true; @@ -178,23 +110,23 @@ void ProcessGroupXCCL::WorkXCCL::synchronize() { synchronizeInternal(kNoTimeout); } -void ProcessGroupXCCL::WorkXCCL::synchronizeStream() { - auto currentStream = at::xpu::getCurrentXPUStream(device_.index()); - // Block the current stream on the XCCL stream - xcclEndEvent_->block(currentStream); -} - void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( std::chrono::milliseconds timeout) { - synchronizeStream(); - + auto currentStream = at::xpu::getCurrentXPUStream(device_.index()); + xcclEndEvent_->block(currentStream); if (blockingWait_) { while (!isCompleted()) { - bool timedOut = checkTimeout( - timeout == kNoTimeout ? std::nullopt : std::make_optional(timeout)); - if (timedOut) { - break; + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - workStartTime_); + if (timeElapsed >= timeout) { + std::string exceptionMsg = c10::str( + "Work ran for ", + timeElapsed.count(), + " milliseconds before timing out."); + TORCH_CHECK(false, exceptionMsg) } + std::this_thread::sleep_for( std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); } diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp index 14a9f398a8cbe..99b815f2138b4 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp @@ -28,43 +28,8 @@ #include #include #include -#include namespace c10d { -namespace { -int getXCCLEnvVar(std::string envVarName) { - char* stringValue = std::getenv(envVarName.c_str()); - if (stringValue != nullptr) { - try { - int val = std::stoi(stringValue); - return val; - } catch (std::exception& e) { - TORCH_CHECK( - false, - "Invalid value for environment variable: " + std::string(envVarName)); - } - } else { - return -1; - } -} - -template -void setXCCLEnvVar(const std::string& envVarName, T val) { - if constexpr (std::is_same_v) { - setenv(envVarName.c_str(), std::to_string(val).c_str(), 1); - } else if constexpr (std::is_same_v) { - setenv(envVarName.c_str(), val.c_str(), 1); - } -} - -bool with_mpirun() { - return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") || - getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK")) - ? true - : false; -} -} // namespace - static std::vector TORCH_XCCL_BLOCKING_WAIT = { "TORCH_XCCL_BLOCKING_WAIT", "XCCL_BLOCKING_WAIT"}; @@ -98,8 +63,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { void synchronize() override; - void synchronizeStream(); - bool wait(std::chrono::milliseconds timeout = kNoTimeout) override; c10::intrusive_ptr getFuture() override { @@ -110,9 +73,6 @@ class TORCH_API ProcessGroupXCCL : public Backend { TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::result not implemented"); } - bool checkTimeout( - std::optional timeout = std::nullopt); - protected: at::Device device_; std::shared_ptr xcclEndEvent_; @@ -302,7 +262,70 @@ class TORCH_API ProcessGroupXCCL : public Backend { c10::intrusive_ptr store_; std::mutex mutex_; bool blockingWait_ = false; + + private: + XCCL_KVS kvs; + std::mutex kvs_mutex; + XCCL_KVS get_kvs(int rank, c10d::Store& store) { + std::lock_guard lock(kvs_mutex); + if (kvs) + return kvs; + std::string storeKey = "xccl_kvs"; + // Rank 0 broadcast the bootstrap network information to other ranks + if (rank == 0) { + kvs = ccl::create_main_kvs(); + ccl::kvs::address_type main_addr = kvs->get_address(); + auto ccl_kvs_addr = + std::vector(main_addr.begin(), main_addr.end()); + store.set(storeKey, ccl_kvs_addr); + } else { + auto ccl_kvs_addr = store.get(storeKey); + if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) { + throw std::runtime_error("Unexpected ccl kvs addr from the store\n"); + } + ccl::kvs::address_type main_addr; + std::copy_n( + ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin()); + kvs = ccl::create_kvs(main_addr); + } + return kvs; + } }; + +namespace { +int getXCCLEnvVar(std::string envVarName) { + char* stringValue = std::getenv(envVarName.c_str()); + if (stringValue != nullptr) { + try { + int val = std::stoi(stringValue); + return val; + } catch (std::exception& e) { + TORCH_CHECK( + false, + "Invalid value for environment variable: " + std::string(envVarName)); + } + } else { + return -1; + } +} + +template +void setXCCLEnvVar(const std::string& envVarName, T val) { + if constexpr (std::is_same_v) { + setenv(envVarName.c_str(), std::to_string(val).c_str(), 1); + } else if constexpr (std::is_same_v) { + setenv(envVarName.c_str(), val.c_str(), 1); + } +} + +bool with_mpirun() { + return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") || + getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK")) + ? true + : false; +} + +} // namespace } // namespace c10d #endif // USE_C10D_XCCL diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp index ea4a4653bc35f..73e37e0437c45 100644 --- a/torch/csrc/distributed/c10d/Utils.hpp +++ b/torch/csrc/distributed/c10d/Utils.hpp @@ -557,6 +557,31 @@ size_t computeLengthsAndOffsets( return offset; } +inline std::string reduce_op_to_string(c10d::ReduceOp op) { + switch (op) { + case c10d::ReduceOp::SUM: + return "SUM"; + case c10d::ReduceOp::PRODUCT: + return "PRODUCT"; + case c10d::ReduceOp::MIN: + return "MIN"; + case c10d::ReduceOp::MAX: + return "MAX"; + case c10d::ReduceOp::BAND: + return "BAND"; + case c10d::ReduceOp::BOR: + return "BOR"; + case c10d::ReduceOp::BXOR: + return "BXOR"; + case c10d::ReduceOp::AVG: + return "AVG"; + case c10d::ReduceOp::PREMUL_SUM: + return "PREMUL_SUM"; + default: + return "UNKNOWN"; + } +} + using RankType = uint32_t; using SizeType = uint64_t;