Commit 6c648cd

hidden xccl specific

Chao1Han committed Oct 10, 2024
1 parent ef261c6 commit 6c648cd
Showing 3 changed files with 106 additions and 126 deletions.
104 changes: 18 additions & 86 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -1,27 +1,17 @@
#ifdef USE_C10D_XCCL

#include <torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp>
#include <fstream>
#include <mutex>
#include <sstream>

#ifdef USE_C10D_XCCL
#include <exception>
#include <map>
#include <sstream>
#include <stdexcept>
#include <tuple>
#include <unordered_set>
#include <utility>

#include <ATen/detail/FunctionTraits.h>
#include <c10/core/DeviceType.h>
#include <c10/util/CallOnce.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
#include <torch/csrc/distributed/c10d/TraceUtils.h>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/torch.h>

namespace c10d {

@@ -45,36 +35,6 @@ std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
{at::kBool, ccl::datatype::uint8},
};

XCCL_KVS kvs;
std::mutex kvs_mutex;

XCCL_KVS get_kvs(int rank, c10d::Store& store) {
std::lock_guard<std::mutex> lock(kvs_mutex);
if (kvs)
return kvs;
std::string storeKey = "xccl_kvs";

// Rank 0 broadcasts the bootstrap network information to the other ranks
if (rank == 0) {
kvs = ccl::create_main_kvs();
ccl::kvs::address_type main_addr = kvs->get_address();
auto ccl_kvs_addr =
std::vector<uint8_t>(main_addr.begin(), main_addr.end());
store.set(storeKey, ccl_kvs_addr);
} else {
auto ccl_kvs_addr = store.get(storeKey);
if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) {
throw std::runtime_error("Unexpected ccl kvs addr from the store\n");
}
ccl::kvs::address_type main_addr;
std::copy_n(
ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin());
kvs = ccl::create_kvs(main_addr);
}

return kvs;
}

void check_xpu_single_tensor(const at::Tensor& tensor) {
if (!tensor.is_xpu() || tensor.is_sparse()) {
C10_THROW_ERROR(ValueError, "Tensors must be XPU and dense");
@@ -106,23 +66,9 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
}
return xcclOps.at(reduceOp);
} catch (const std::out_of_range&) {
switch (reduceOp) {
case ReduceOp::AVG:
C10_THROW_ERROR(ValueError, "Cannot use ReduceOp AVG with XCCL");
break;
case ReduceOp::BAND:
C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BAND with XCCL");
break;
case ReduceOp::BOR:
C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BOR with XCCL");
break;
case ReduceOp::BXOR:
C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with XCCL");
break;
default:
C10_THROW_ERROR(ValueError, "Unhandled ReduceOp");
break;
}
C10_THROW_ERROR(
ValueError,
"Cannot use ReduceOp." + reduce_op_to_string(reduceOp) + " with XCCL");
}
}
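
With the per-op switch collapsed into a single throw, the error text is now derived from the op name via reduce_op_to_string (added to Utils.hpp below). An illustrative call, with a stand-in tensor that is not part of this diff:

// BXOR is not in xcclOps, so the map lookup throws std::out_of_range and
// the catch block above surfaces:
//   ValueError: Cannot use ReduceOp.BXOR with XCCL
at::Tensor input = at::ones({8}, at::kFloat); // illustrative tensor
ccl::reduction red = getXcclReduceOp(ReduceOp::BXOR, input);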

@@ -153,20 +99,6 @@ ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w)

ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default;

bool ProcessGroupXCCL::WorkXCCL::checkTimeout(
std::optional<std::chrono::milliseconds> timeout) {
auto currentTimepoint = std::chrono::steady_clock::now();
auto timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
currentTimepoint - workStartTime_);
std::chrono::milliseconds opTimeout = std::chrono::milliseconds(60000);

auto workTimeout = timeout ? *timeout : opTimeout;

if (timeElapsed < workTimeout)
return false;
return true;
}

bool ProcessGroupXCCL::WorkXCCL::isCompleted() {
if (xcclEndEvent_ && xcclEndEvent_->query()) {
return true;
@@ -178,23 +110,23 @@ void ProcessGroupXCCL::WorkXCCL::synchronize() {
synchronizeInternal(kNoTimeout);
}

void ProcessGroupXCCL::WorkXCCL::synchronizeStream() {
auto currentStream = at::xpu::getCurrentXPUStream(device_.index());
// Block the current stream on the XCCL stream
xcclEndEvent_->block(currentStream);
}

void ProcessGroupXCCL::WorkXCCL::synchronizeInternal(
std::chrono::milliseconds timeout) {
synchronizeStream();

auto currentStream = at::xpu::getCurrentXPUStream(device_.index());
xcclEndEvent_->block(currentStream);
if (blockingWait_) {
while (!isCompleted()) {
bool timedOut = checkTimeout(
timeout == kNoTimeout ? std::nullopt : std::make_optional(timeout));
if (timedOut) {
break;
auto currentTimepoint = std::chrono::steady_clock::now();
auto timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
currentTimepoint - workStartTime_);
if (timeElapsed >= timeout) {
std::string exceptionMsg = c10::str(
"Work ran for ",
timeElapsed.count(),
" milliseconds before timing out.");
TORCH_CHECK(false, exceptionMsg);
}

std::this_thread::sleep_for(
std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
}
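
The inlined check above is a standard bounded busy-wait. A self-contained sketch of the pattern, with the XPU event query replaced by a generic completion predicate (only the shape of the loop comes from this diff; the names and the 10 ms interval are placeholders):

#include <chrono>
#include <functional>
#include <stdexcept>
#include <thread>

// Poll a completion predicate, sleeping briefly between checks, and fail
// once the elapsed time exceeds the caller-supplied timeout.
void busyWaitWithTimeout(
    const std::function<bool()>& isCompleted,
    std::chrono::milliseconds timeout) {
  constexpr auto kBusyWaitInterval = std::chrono::milliseconds(10);
  const auto start = std::chrono::steady_clock::now();
  while (!isCompleted()) {
    const auto timeElapsed =
        std::chrono::duration_cast<std::chrono::milliseconds>(
            std::chrono::steady_clock::now() - start);
    if (timeElapsed >= timeout) {
      throw std::runtime_error("Work timed out");
    }
    std::this_thread::sleep_for(kBusyWaitInterval);
  }
}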
103 changes: 63 additions & 40 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
@@ -28,43 +28,8 @@
#include <c10/xpu/XPUCachingAllocator.h>
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
namespace c10d {

namespace {
int getXCCLEnvVar(std::string envVarName) {
char* stringValue = std::getenv(envVarName.c_str());
if (stringValue != nullptr) {
try {
int val = std::stoi(stringValue);
return val;
} catch (std::exception& e) {
TORCH_CHECK(
false,
"Invalid value for environment variable: " + std::string(envVarName));
}
} else {
return -1;
}
}

template <typename T>
void setXCCLEnvVar(const std::string& envVarName, T val) {
if constexpr (std::is_same_v<T, int>) {
setenv(envVarName.c_str(), std::to_string(val).c_str(), 1);
} else if constexpr (std::is_same_v<T, std::string>) {
setenv(envVarName.c_str(), val.c_str(), 1);
}
}

bool with_mpirun() {
return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") ||
getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK"))
? true
: false;
}
} // namespace

static std::vector<std::string> TORCH_XCCL_BLOCKING_WAIT = {
"TORCH_XCCL_BLOCKING_WAIT",
"XCCL_BLOCKING_WAIT"};
@@ -98,8 +63,6 @@ class TORCH_API ProcessGroupXCCL : public Backend {

void synchronize() override;

void synchronizeStream();

bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
@@ -110,9 +73,6 @@ class TORCH_API ProcessGroupXCCL : public Backend {
TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::result not implemented");
}

bool checkTimeout(
std::optional<std::chrono::milliseconds> timeout = std::nullopt);

protected:
at::Device device_;
std::shared_ptr<at::xpu::XPUEvent> xcclEndEvent_;
@@ -302,7 +262,70 @@ class TORCH_API ProcessGroupXCCL : public Backend {
c10::intrusive_ptr<Store> store_;
std::mutex mutex_;
bool blockingWait_ = false;

private:
XCCL_KVS kvs;
std::mutex kvs_mutex;
XCCL_KVS get_kvs(int rank, c10d::Store& store) {
std::lock_guard<std::mutex> lock(kvs_mutex);
if (kvs)
return kvs;
std::string storeKey = "xccl_kvs";
// Rank 0 broadcasts the bootstrap network information to the other ranks
if (rank == 0) {
kvs = ccl::create_main_kvs();
ccl::kvs::address_type main_addr = kvs->get_address();
auto ccl_kvs_addr =
std::vector<uint8_t>(main_addr.begin(), main_addr.end());
store.set(storeKey, ccl_kvs_addr);
} else {
auto ccl_kvs_addr = store.get(storeKey);
if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) {
throw std::runtime_error("Unexpected ccl kvs addr from the store\n");
}
ccl::kvs::address_type main_addr;
std::copy_n(
ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin());
kvs = ccl::create_kvs(main_addr);
}
return kvs;
}
};
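
For context, the KVS returned by get_kvs() is what bootstraps the communicator. A hedged sketch of that step, assuming oneCCL's create_communicator overload taking (size, rank, kvs); the call site itself is not part of this diff, and getSize()/getRank() come from the Backend base class:

// Inside a ProcessGroupXCCL method: exchange the KVS address through the
// c10d store, then create the communicator from it.
auto comm = ccl::create_communicator(
    getSize(), getRank(), get_kvs(getRank(), *store_));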

namespace {
int getXCCLEnvVar(std::string envVarName) {
char* stringValue = std::getenv(envVarName.c_str());
if (stringValue != nullptr) {
try {
int val = std::stoi(stringValue);
return val;
} catch (std::exception& e) {
TORCH_CHECK(
false,
"Invalid value for environment variable: " + std::string(envVarName));
}
} else {
return -1;
}
}

template <typename T>
void setXCCLEnvVar(const std::string& envVarName, T val) {
if constexpr (std::is_same_v<T, int>) {
setenv(envVarName.c_str(), std::to_string(val).c_str(), 1);
} else if constexpr (std::is_same_v<T, std::string>) {
setenv(envVarName.c_str(), val.c_str(), 1);
}
}

bool with_mpirun() {
return (getenv("MPI_LOCALRANKID") || getenv("MPI_LOCALNRANKS") ||
getenv("PMI_RANK") || getenv("PMI_SIZE") || getenv("PMIX_RANK"))
? true
: false;
}

} // namespace
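
These helpers are thin get/set wrappers around std::getenv and setenv. A hedged usage sketch; apart from MPI_LOCALRANKID, which with_mpirun() already probes, the variable names below are illustrative rather than taken from this commit:

// When launched under mpirun, mirror the launcher's local rank into a
// CCL-side variable (CCL_LOCAL_RANK is illustrative here).
if (with_mpirun()) {
  int localRank = getXCCLEnvVar("MPI_LOCALRANKID");
  if (localRank >= 0) {
    setXCCLEnvVar("CCL_LOCAL_RANK", localRank);
  }
}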
} // namespace c10d

#endif // USE_C10D_XCCL
25 changes: 25 additions & 0 deletions torch/csrc/distributed/c10d/Utils.hpp
@@ -557,6 +557,31 @@ size_t computeLengthsAndOffsets(
return offset;
}

inline std::string reduce_op_to_string(c10d::ReduceOp op) {
switch (op) {
case c10d::ReduceOp::SUM:
return "SUM";
case c10d::ReduceOp::PRODUCT:
return "PRODUCT";
case c10d::ReduceOp::MIN:
return "MIN";
case c10d::ReduceOp::MAX:
return "MAX";
case c10d::ReduceOp::BAND:
return "BAND";
case c10d::ReduceOp::BOR:
return "BOR";
case c10d::ReduceOp::BXOR:
return "BXOR";
case c10d::ReduceOp::AVG:
return "AVG";
case c10d::ReduceOp::PREMUL_SUM:
return "PREMUL_SUM";
default:
return "UNKNOWN";
}
}
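
Usage mirrors the ProcessGroupXCCL.cpp change above; for example:

// Builds the message "Cannot use ReduceOp.BAND with XCCL".
std::string msg = "Cannot use ReduceOp." +
    reduce_op_to_string(c10d::ReduceOp::BAND) + " with XCCL";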

using RankType = uint32_t;
using SizeType = uint64_t;
