Skip to content

Commit

Permalink
accept comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Chao1Han committed Oct 17, 2024
1 parent 56a5e7f commit a062f9f
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 28 deletions.
4 changes: 2 additions & 2 deletions torch/csrc/distributed/c10d/ProcessGroup.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
NCCL = 2,
UCC = 3,
MPI = 4,
CUSTOM = 5,
XCCL = 6,
XCCL = 5,
CUSTOM = 6,
};

static std::string backendTypeToString(const BackendType& type) {
Expand Down
19 changes: 0 additions & 19 deletions torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,25 +253,6 @@ c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
auto tensor = tensors.back();
checkXPUTensor(tensor);

RECORD_PARAM_COMMS_DATA(
// static_cast<int>(
// this->getSequenceNumberForGroup() + 1), // seq + 1 to match
// collective
1,
std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
tensors, // inputTensors
tensors, // outputTensors
rank_, // rank
"allreduce", // collective name
tensor.numel(), // inNelems
tensor.numel(), // outNelems
tensor.scalar_type(), // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSizes
0, // globalRankStart
1, // globalRankStride
this->getSize()); // worldSize

return collective(
tensor,
tensor,
Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ class TORCH_API ProcessGroupXCCL : public Backend {
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;

void setSequenceNumberForGroup() override {}

protected:
std::unordered_map<std::string, at::xpu::XPUStream> xcclStreamsMap_;
std::unordered_map<std::string, at::xpu::XPUEvent> xcclEventsMap_;
Expand All @@ -151,7 +153,6 @@ class TORCH_API ProcessGroupXCCL : public Backend {
ccl::shared_ptr_class<ccl::kvs> kvs;

ccl::shared_ptr_class<ccl::kvs> get_kvs(int rank, c10d::Store& store) {
// todo: why do we need the mutex here?
std::lock_guard<std::mutex> lock(kvs_mutex);
if (kvs)
return kvs;
Expand Down
8 changes: 2 additions & 6 deletions torch/distributed/distributed_c10d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1675,13 +1675,9 @@ def _new_process_group_helper(
"created, please use a different group name"
)

if device_id is not None and (
device_id.index is None
or (device_id.type != "cuda" and device_id.type != "xpu")
):
if device_id is not None and device_id.index is None:
raise ValueError(
"init_process_group device_id parameter must be a cuda device with an "
"id, e.g. cuda:0, xpu, not just cuda or xpu or cpu"
"init_process_group device_id parameter must be a device with an index"
)

# Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value
Expand Down

0 comments on commit a062f9f

Please sign in to comment.