diff --git a/CMakeLists.txt b/CMakeLists.txt
index 98593c2de9717..60fc8aae14173 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -275,6 +275,8 @@ option(USE_NATIVE_ARCH "Use -march=native" OFF)
 cmake_dependent_option(USE_MPS "Use MPS for macOS build" ON "MPS_FOUND" OFF)
 cmake_dependent_option(USE_NCCL "Use NCCL" ON
                        "USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF)
+cmake_dependent_option(USE_XCCL "Use XCCL" ON
+                       "USE_XPU;UNIX;NOT APPLE" OFF)
 cmake_dependent_option(USE_RCCL "Use RCCL" ON USE_NCCL OFF)
 cmake_dependent_option(USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF)
 cmake_dependent_option(USE_SYSTEM_NCCL "Use system-wide NCCL" OFF "USE_NCCL"
@@ -353,6 +355,8 @@ cmake_dependent_option(USE_C10D_GLOO "USE C10D GLOO" ON
                        "USE_DISTRIBUTED;USE_GLOO" OFF)
 cmake_dependent_option(USE_C10D_NCCL "USE C10D NCCL" ON
                        "USE_DISTRIBUTED;USE_NCCL" OFF)
+cmake_dependent_option(USE_C10D_XCCL "USE C10D XCCL" ON
+                       "USE_DISTRIBUTED;USE_XCCL" OFF)
 cmake_dependent_option(USE_C10D_MPI "USE C10D MPI" ON
                        "USE_DISTRIBUTED;USE_MPI" OFF)
 cmake_dependent_option(
diff --git a/build_variables.bzl b/build_variables.bzl
index 8417c1f53a72c..d11bba1ae1f37 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -703,6 +703,10 @@ libtorch_cuda_sources = libtorch_cuda_core_sources + libtorch_cuda_distributed_s
     "torch/csrc/cuda/nccl.cpp",
 ]
 
+libtorch_xpu_distributed_extra_sources = [
+    "torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp",
+]
+
 torch_cpp_srcs = [
     "torch/csrc/api/src/cuda.cpp",  # this just forwards stuff, no real CUDA
     "torch/csrc/api/src/data/datasets/mnist.cpp",
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index f25286d5a6fe4..25bd7f700f68a 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -1013,6 +1013,13 @@ elseif(USE_CUDA)
 endif()
 
 if(USE_XPU)
+  # If the SYCL runtime and the oneCCL runtime are both installed on the
+  # system, then USE_XPU=ON, USE_XCCL=ON, and USE_C10D_XCCL=ON by default,
+  # and the XCCL backend is built into libtorch_xpu.
+  # Manually set `USE_XCCL=OFF` to disable building the XCCL backend.
+  if(USE_XCCL)
+    append_filelist("libtorch_xpu_distributed_extra_sources" Caffe2_XPU_SRCS)
+  endif()
   add_library(torch_xpu ${Caffe2_XPU_SRCS})
   torch_compile_options(torch_xpu)  # see cmake/public/utils.cmake
   target_compile_definitions(torch_xpu PRIVATE USE_XPU)
@@ -1078,6 +1085,10 @@ if(USE_XPU)
     include_directories(SYSTEM ${ATen_XPU_INCLUDE_DIRS})
   endif()
 
+  if(USE_XCCL)
+    target_link_libraries(torch_xpu PRIVATE torch::xccl)
+    target_compile_definitions(torch_xpu PRIVATE USE_XCCL)
+  endif()
 endif()
 
 if(NOT MSVC AND USE_XNNPACK)
@@ -1363,6 +1374,9 @@ if(USE_DISTRIBUTED)
       target_compile_definitions(torch_cuda PUBLIC USE_C10D_NCCL)
     endif()
   endif()
+  if(USE_XPU AND USE_C10D_XCCL)
+    target_compile_definitions(torch_xpu PUBLIC USE_C10D_XCCL)
+  endif()
   if(USE_MPI AND USE_C10D_MPI)
     if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
       set_source_files_properties(
diff --git a/caffe2/core/macros.h.in b/caffe2/core/macros.h.in
index 2929f105b31fa..e5398a83cad94 100644
--- a/caffe2/core/macros.h.in
+++ b/caffe2/core/macros.h.in
@@ -45,6 +45,7 @@
   {"USE_CUDNN", "${USE_CUDNN}"}, \
   {"CUDNN_VERSION", "${CUDNN_VERSION}"}, \
   {"USE_NCCL", "${USE_NCCL}"}, \
+  {"USE_XCCL", "${USE_XCCL}"}, \
   {"USE_MPI", "${USE_MPI}"}, \
   {"USE_GFLAGS", "${USE_GFLAGS}"}, \
   {"USE_GLOG", "${USE_GLOG}"}, \
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index 1dc33efec7b87..f90846e89c754 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1151,6 +1151,24 @@ if(USE_CUDA)
   include_directories(SYSTEM ${CUB_INCLUDE_DIRS})
 endif()
 
+# ---[ XCCL
+if(USE_XCCL)
+  if(NOT USE_XPU)
+    message(WARNING
+      "Not using XPU, so disabling USE_XCCL. Suppress this warning with "
+      "-DUSE_XCCL=OFF.")
+    caffe2_update_option(USE_XCCL OFF)
+  elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Linux")
+    message(WARNING "USE_XCCL is currently only supported under Linux.")
+    caffe2_update_option(USE_XCCL OFF)
+  else()
+    include(${CMAKE_CURRENT_LIST_DIR}/External/xccl.cmake)
+    if(NOT XCCL_FOUND)
+      caffe2_update_option(USE_XCCL OFF)
+    endif()
+  endif()
+endif()
+
 if(USE_DISTRIBUTED AND USE_TENSORPIPE)
   if(MSVC)
     message(WARNING "Tensorpipe cannot be used on Windows.")
diff --git a/cmake/External/xccl.cmake b/cmake/External/xccl.cmake
new file mode 100644
index 0000000000000..acb7cee87593e
--- /dev/null
+++ b/cmake/External/xccl.cmake
@@ -0,0 +1,15 @@
+if(NOT __XCCL_INCLUDED)
+  set(__XCCL_INCLUDED TRUE)
+
+  # XCCL_ROOT, XCCL_LIBRARY_DIR, XCCL_INCLUDE_DIR are handled by FindXCCL.cmake.
+  find_package(XCCL REQUIRED)
+  if(XCCL_FOUND)
+    add_library(torch::xccl INTERFACE IMPORTED)
+    set_property(
+      TARGET torch::xccl PROPERTY INTERFACE_INCLUDE_DIRECTORIES
+      ${XCCL_INCLUDE_DIR})
+    set_property(
+      TARGET torch::xccl PROPERTY INTERFACE_LINK_LIBRARIES
+      ${XCCL_LIBRARY})
+  endif()
+endif()
diff --git a/cmake/Modules/FindXCCL.cmake b/cmake/Modules/FindXCCL.cmake
new file mode 100644
index 0000000000000..18f7ac642d54e
--- /dev/null
+++ b/cmake/Modules/FindXCCL.cmake
@@ -0,0 +1,69 @@
+# This will define the following variables:
+# XCCL_FOUND       : True if the system has the XCCL library.
+# XCCL_INCLUDE_DIR : Include directories needed to use XCCL.
+# XCCL_LIBRARY_DIR : The path to the XCCL library directory.
+# XCCL_LIBRARY     : The full path of the XCCL library.
+
+include(FindPackageHandleStandardArgs)
+
+set(XCCL_ROOT "/opt/intel/oneapi/ccl/latest")
+if (NOT EXISTS "${XCCL_ROOT}")
+  message(STATUS "Default oneCCL not found, using the oneCCL from the current oneAPI environment")
+  set(XCCL_ROOT $ENV{ONEAPI_ROOT}/ccl/latest)
+endif()
+
+string(COMPARE EQUAL "${XCCL_ROOT}" "" xccl_root_empty)
+if(xccl_root_empty)
+  set(XCCL_FOUND False)
+  set(XCCL_REASON_FAILURE "oneCCL library not found!")
+  set(XCCL_NOT_FOUND_MESSAGE "${XCCL_REASON_FAILURE}")
+  return()
+endif()
+
+# Find the include directory under XCCL_ROOT.
+find_file(
+  XCCL_INCLUDE_DIR
+  NAMES include
+  HINTS ${XCCL_ROOT}
+  NO_DEFAULT_PATH
+)
+
+# Find the include/oneapi directory under the include directory.
+find_file(
+  XCCL_INCLUDE_ONEAPI_DIR
+  NAMES oneapi
+  HINTS ${XCCL_ROOT}/include/
+  NO_DEFAULT_PATH
+)
+
+list(APPEND XCCL_INCLUDE_DIR ${XCCL_INCLUDE_ONEAPI_DIR})
+
+# Find the library directory under XCCL_ROOT.
+find_file(
+  XCCL_LIBRARY_DIR
+  NAMES lib
+  HINTS ${XCCL_ROOT}
+  NO_DEFAULT_PATH
+)
+
+# Find the full path of the XCCL library.
+find_library(
+  XCCL_LIBRARY
+  NAMES ccl
+  HINTS ${XCCL_LIBRARY_DIR}
+  NO_DEFAULT_PATH
+)
+
+if((NOT XCCL_INCLUDE_DIR) OR (NOT XCCL_LIBRARY_DIR) OR (NOT XCCL_LIBRARY))
+  set(XCCL_FOUND False)
+  set(XCCL_REASON_FAILURE "oneCCL library not found!")
+  set(XCCL_NOT_FOUND_MESSAGE "${XCCL_REASON_FAILURE}")
+  return()
+endif()
+
+find_package_handle_standard_args(
+  XCCL
+  FOUND_VAR XCCL_FOUND
+  REQUIRED_VARS XCCL_INCLUDE_DIR XCCL_LIBRARY_DIR XCCL_LIBRARY
+  REASON_FAILURE_MESSAGE "${XCCL_REASON_FAILURE}"
+)
diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake
index d51c451589c2c..229ff112ab318 100644
--- a/cmake/Summary.cmake
+++ b/cmake/Summary.cmake
@@ -153,6 +153,12 @@ function(caffe2_print_configuration_summary)
     message(STATUS "    USE_SYSTEM_UCC      : ${USE_SYSTEM_UCC}")
   endif()
   message(STATUS "  USE_ITT               : ${USE_ITT}")
+  message(STATUS "  USE_XCCL              : ${USE_XCCL}")
+  if(${USE_XCCL})
+    message(STATUS "    USE_C10D_XCCL       : ${USE_C10D_XCCL}")
+    message(STATUS "    XCCL include path   : ${XCCL_INCLUDE_DIR}")
+    message(STATUS "    XCCL library        : ${XCCL_LIBRARY}")
+  endif()
   message(STATUS "  USE_NCCL              : ${USE_NCCL}")
   if(${USE_NCCL})
     message(STATUS "    USE_SYSTEM_NCCL     : ${USE_SYSTEM_NCCL}")
diff --git a/setup.py b/setup.py
index f674d607dcdde..e9f5d2a579432 100644
--- a/setup.py
+++ b/setup.py
@@ -645,6 +645,10 @@ def run(self):
                 report("-- Building NCCL library")
             else:
                 report("-- Not using NCCL")
+            if cmake_cache_vars["USE_XCCL"]:
+                report("-- Building XCCL library")
+            else:
+                report("-- Not using XCCL")
             if cmake_cache_vars["USE_DISTRIBUTED"]:
                 if IS_WINDOWS:
                     report("-- Building without distributed package")
diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py
index 53f2753a48506..d3cb65f7befb1 100644
--- a/test/distributed/test_c10d_common.py
+++ b/test/distributed/test_c10d_common.py
@@ -31,6 +31,7 @@
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
     skip_if_lt_x_gpu,
+    get_device_count,
 )
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -60,14 +61,15 @@
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
-def gpus_for_rank(world_size):
+def gpus_for_rank(world_size, backend):
     """Multigpu tests are designed to simulate the multi nodes with
     multi GPUs on each node. Nccl backend requires equal #GPUs in each
     process. On a single node, all visible GPUs are evenly
    divided to subsets, each process only uses a subset.
""" - visible_devices = list(range(torch.cuda.device_count())) - gpus_per_process = torch.cuda.device_count() // world_size + device_count = get_device_count(backend) + visible_devices = list(range(device_count)) + gpus_per_process = device_count // world_size gpus_for_rank = [] for rank in range(world_size): gpus_for_rank.append( @@ -828,7 +830,7 @@ def update_parameters(model): def _gpu_model_with_ddp_comm_hook( self, process_group, hook=None, gradient_as_bucket_view=False, state=None ): - device_id = gpus_for_rank(self.world_size)[self.rank][0] + device_id = gpus_for_rank(self.world_size, process_group.name())[self.rank][0] gpu_model = DistributedDataParallel( ModuleForDdpCommHook().to(device_id), device_ids=[device_id], @@ -845,7 +847,7 @@ def _gpu_model_with_ddp_comm_hook( def _gpu_model_with_builtin_ddp_comm_hook( self, process_group, hook=None, gradient_as_bucket_view=False ): - device_id = gpus_for_rank(self.world_size)[self.rank][0] + device_id = gpus_for_rank(self.world_size, process_group.name())[self.rank][0] gpu_model = DistributedDataParallel( ModuleForDdpCommHook().to(device_id), device_ids=[device_id], @@ -1831,6 +1833,9 @@ def test_init_process_group_for_all_backends(self): elif backend == dist.Backend.UCC: if not dist.is_ucc_available(): continue + elif backend == dist.Backend.XCCL: + if not dist.is_xccl_available(): + continue # Multi-threaded PG is defined as a pure python class. # Its pg.name() does not going through Pybind, so its backend name # is still "threaded" instead of "custom". diff --git a/test/distributed/test_c10d_xccl.py b/test/distributed/test_c10d_xccl.py new file mode 100644 index 0000000000000..704cdd414e554 --- /dev/null +++ b/test/distributed/test_c10d_xccl.py @@ -0,0 +1,303 @@ +# Owner(s): ["oncall: distributed"] + +import math +import os +import sys +import time +from datetime import timedelta +from unittest import mock + +import torch +import torch.distributed as c10d + + +if not c10d.is_available() or not c10d.is_xccl_available(): + print("c10d XCCL not available, skipping tests", file=sys.stderr) + sys.exit(0) + +import test_c10d_common + +import torch.distributed as dist +import torch.testing._internal.common_utils as common +from torch.testing._internal.common_distributed import ( + init_multigpu_helper, + MultiProcessTestCase, + requires_xccl, +) +from torch.testing._internal.common_utils import ( + retry_on_connect_failures, + run_tests, + skip_but_pass_in_sandcastle_if, + TEST_XPU, + TestCase, +) + + +def simple_reduce_tests(rank, world_size): + tests = [ + ( + c10d.ReduceOp.SUM, + torch.tensor([rank + 1.0]), + torch.tensor([float(world_size * (world_size + 1) / 2)]), + ), + ( + c10d.ReduceOp.PRODUCT, + torch.tensor([rank + 1.0]), + torch.tensor([float(math.factorial(world_size))]), + ), + ( + c10d.ReduceOp.MIN, + torch.tensor([rank + 1.0]), + torch.tensor([1.0]), + ), + ( + c10d.ReduceOp.MAX, + torch.tensor([rank + 1.0]), + torch.tensor([world_size]), + ), + ] + + return tests + + +TEST_MULTIXPU = torch.xpu.device_count() > 1 + + +class RendezvousEnvTest(TestCase): + @retry_on_connect_failures + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_XPU, "No GPUs available, skipping test") + def test_common_errors(self): + vars = { + "WORLD_SIZE": "1", + "RANK": "0", + "MASTER_ADDR": "127.0.0.1", + "MASTER_PORT": str(common.find_free_port()), + } + + class Env: + def __init__(self, vars): + self.env_patcher = mock.patch.dict(os.environ, vars, clear=True) + + def __enter__(self): + self.env_patcher.start() + + def __exit__(self, type, 
value, traceback):
+                self.env_patcher.stop()
+
+        def without(d, key):
+            d = d.copy()
+            d.pop(key)
+            return d
+
+        def withouts(d, keys):
+            d = d.copy()
+            for key in keys:
+                d.pop(key)
+            return d
+
+        with Env(without(vars, "WORLD_SIZE")):
+            self.assertEqual(None, os.environ.get("WORLD_SIZE"))
+            with self.assertRaisesRegex(ValueError, "WORLD_SIZE expected"):
+                gen = c10d.rendezvous("env://")
+                next(gen)
+            c10d.init_process_group(backend="xccl", world_size=1)
+            self.assertEqual(c10d.get_rank(), 0)
+            self.assertEqual(c10d.get_world_size(), 1)
+            c10d.destroy_process_group()
+
+        with Env(without(vars, "RANK")):
+            self.assertEqual(None, os.environ.get("RANK"))
+            with self.assertRaisesRegex(ValueError, "RANK expected"):
+                gen = c10d.rendezvous("env://")
+                next(gen)
+            c10d.init_process_group(backend="xccl", rank=0)
+            self.assertEqual(c10d.get_rank(), 0)
+            self.assertEqual(c10d.get_world_size(), 1)
+            c10d.destroy_process_group()
+
+        with Env(withouts(vars, ["RANK", "WORLD_SIZE"])):
+            self.assertEqual(None, os.environ.get("RANK"))
+            self.assertEqual(None, os.environ.get("WORLD_SIZE"))
+            c10d.init_process_group(backend="xccl", rank=0, world_size=1)
+            self.assertEqual(c10d.get_rank(), 0)
+            self.assertEqual(c10d.get_world_size(), 1)
+            c10d.destroy_process_group()
+
+        with Env(vars):
+            c10d.init_process_group(backend="xccl")
+            self.assertEqual(c10d.get_rank(), 0)
+            self.assertEqual(c10d.get_world_size(), 1)
+            c10d.destroy_process_group()
+
+        with Env(without(vars, "MASTER_ADDR")):
+            self.assertEqual(None, os.environ.get("MASTER_ADDR"))
+            with self.assertRaisesRegex(ValueError, "MASTER_ADDR expected"):
+                gen = c10d.rendezvous("env://")
+                next(gen)
+
+        with Env(without(vars, "MASTER_PORT")):
+            self.assertEqual(None, os.environ.get("MASTER_PORT"))
+            with self.assertRaisesRegex(ValueError, "MASTER_PORT expected"):
+                gen = c10d.rendezvous("env://")
+                next(gen)
+
+        with Env(without(vars, "WORLD_SIZE")):
+            self.assertEqual(None, os.environ.get("WORLD_SIZE"))
+            gen = c10d.rendezvous(f"env://?world_size={1}")
+            _, _, size = next(gen)
+            self.assertEqual(size, 1)
+
+        with Env(without(vars, "RANK")):
+            self.assertEqual(None, os.environ.get("RANK"))
+            gen = c10d.rendezvous(f"env://?rank={0}")
+            _, rank, _ = next(gen)
+            self.assertEqual(rank, 0)
+
+        with Env(withouts(vars, ["RANK", "WORLD_SIZE"])):
+            self.assertEqual(None, os.environ.get("RANK"))
+            self.assertEqual(None, os.environ.get("WORLD_SIZE"))
+            gen = c10d.rendezvous(f"env://?rank={0}&world_size={1}")
+            _, rank, size = next(gen)
+            self.assertEqual(rank, 0)
+            self.assertEqual(size, 1)
+
+
+class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase):
+    @requires_xccl()
+    @retry_on_connect_failures
+    @skip_but_pass_in_sandcastle_if(not TEST_XPU, "No GPUs available, skipping test")
+    def test_default_store_timeout_xccl(self):
+        self._test_default_store_timeout("xccl")
+
+
+class ProcessGroupXCCLTest(MultiProcessTestCase):
+    def _create_process_group_xccl(
+        self, timeout=timedelta(seconds=600), device_id=None
+    ):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            "xccl",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+            timeout=timeout,
+            device_id=device_id,
+        )
+        pg = c10d.distributed_c10d._get_default_group()
+        return pg
+
+    def setUp(self):
+        super().setUp()
+        self._spawn_processes()
+
+    def tearDown(self):
+        super().tearDown()
+        try:
+            os.remove(self.file_name)
+        except OSError:
+            pass
+
+    @property
+    def world_size(self):
+        return 2
+
+    @property
+    def rank_to_GPU(self):
+        # return rank to GPU map
return init_multigpu_helper(self.world_size, "xccl") + + @requires_xccl() + @skip_but_pass_in_sandcastle_if( + torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs" + ) + def test_close_multi_pg_unordered(self): + pg = self._create_process_group_xccl() + device = self.rank_to_GPU[self.rank][0] + t = torch.rand(10, 10, device=device) + # First allreduce to initialize default PG's communicator. + pg.allreduce(t).wait() + new_pg1 = c10d.new_group([0, 1]) + new_pg2 = c10d.new_group([0, 1]) + if self.rank == 0 or self.rank == 1: + t1 = torch.rand(10, 10, device=device) + t2 = torch.rand(10, 10, device=device) + new_pg1.allreduce(t1).wait() + new_pg2.allreduce(t2).wait() + if self.rank == 0: + dist.destroy_process_group(new_pg2) + # force destruction of pg2 first + del new_pg2 + dist.destroy_process_group(new_pg1) + del new_pg1 + if self.rank == 1: + c10d.destroy_process_group(new_pg1) + # force destruction of pg1 first + del new_pg1 + dist.destroy_process_group(new_pg2) + del new_pg2 + dist.destroy_process_group() + + @requires_xccl() + @skip_but_pass_in_sandcastle_if( + torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs" + ) + def test_file_store_check(self): + # self.file_name is created using "delete=False" + # e.g., self.file_name = tempfile.NamedTemporaryFile(delete=False).name + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend="xccl", rank=self.rank, world_size=self.world_size, store=store + ) + pg = dist.distributed_c10d._get_default_group() + self.assertEqual(pg.rank(), self.rank) + self.assertEqual(pg.size(), self.world_size) + # give enough time for check() to be executed multiple times + time.sleep(2) + dist.destroy_process_group() + + @requires_xccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIXPU, "XCCL test requires 2+ GPUs") + def test_set_process_group_desc(self): + device = torch.device(f"xpu:{self.rank}") + pg_default = self._create_process_group_xccl(device_id=device) + self.assertEqual(pg_default.group_desc, "default_pg") + pg_1 = c10d.new_group([0, 1], group_desc="test_purpose") + self.assertEqual(pg_1.group_desc, "test_purpose") + pg_2 = c10d.new_group([0, 1]) + self.assertEqual(pg_2.group_desc, "undefined") + + def _test_allreduce_basics(self, fn): + pg = self._create_process_group_xccl() + device = torch.device("xpu:" + str(self.rank)) + # Single input tests + tests = simple_reduce_tests(self.rank, self.world_size) + for op, input, expected in tests: + opts = c10d.AllreduceOptions() + opts.reduceOp = op + tensor = fn(input.to(device)) + fut = pg.allreduce([tensor], opts).get_future() + fut.wait() + result = fut.value() + self.assertEqual(expected, result[0], exact_dtype=False) + + x = fn(torch.tensor([self.rank + 1.0], device=device)) + fut = pg.allreduce(x).get_future() + fut.wait() + result = fut.value() + self.assertEqual( + torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]), + result[0], + ) + + @requires_xccl() + def test_allreduce_basics(self): + self._test_allreduce_basics(lambda t: t.clone()) + + +if __name__ == "__main__": + assert ( + not torch.xpu._initialized + ), "test_distributed must not have initialized XPU context on main process" + + run_tests() diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index c74b45431c947..b8dfb8b706ba1 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -282,6 +282,9 @@ if(USE_DISTRIBUTED) if(USE_NCCL) list(APPEND TORCH_PYTHON_LINK_LIBRARIES __caffe2_nccl) endif() + if(USE_XCCL) + list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::xccl) 
+ endif() # Same for MPI. if(USE_MPI) list(APPEND TORCH_PYTHON_LINK_LIBRARIES MPI::MPI_CXX) @@ -351,6 +354,10 @@ if(BUILD_LIBTORCHLESS) target_compile_definitions(torch_python PRIVATE USE_C10D_NCCL) endif() + if(USE_XPU AND USE_C10D_XCCL) + target_compile_definitions(torch_python PRIVATE USE_C10D_XCCL) + endif() + if(USE_DISTRIBUTED) target_compile_definitions(torch_python PRIVATE USE_DISTRIBUTED) endif() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index f1cbf47ea0f3f..f89f0b50c8582 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -300,6 +300,7 @@ class ProcessGroup: UNDEFINED = ... GLOO = ... NCCL = ... + XCCL = ... UCC = ... MPI = ... CUSTOM = ... @@ -688,3 +689,11 @@ class ProcessGroupCudaP2P(Backend): storage_offset: Optional[int] = 0, ) -> torch.Tensor: ... def _shutdown(self) -> None: ... + +class ProcessGroupXCCL(Backend): + def __init__( + self, + store: Store, + rank: int, + size: int, + ): ... diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp index ae822ad397504..699c54236f641 100644 --- a/torch/csrc/distributed/c10d/Ops.cpp +++ b/torch/csrc/distributed/c10d/Ops.cpp @@ -79,6 +79,7 @@ namespace { } IMPL_SEND(CPU) +IMPL_SEND(XPU) IMPL_SEND(CUDA) IMPL_SEND(PrivateUse1) @@ -94,6 +95,7 @@ IMPL_SEND(PrivateUse1) } IMPL_RECV(CPU) +IMPL_RECV(XPU) IMPL_RECV(CUDA) IMPL_RECV(PrivateUse1) @@ -108,6 +110,7 @@ IMPL_RECV(PrivateUse1) } IMPL_RECV_ANY_SOURCE(CPU) +IMPL_RECV_ANY_SOURCE(XPU) IMPL_RECV_ANY_SOURCE(CUDA) IMPL_RECV_ANY_SOURCE(PrivateUse1) @@ -131,6 +134,7 @@ IMPL_RECV_ANY_SOURCE(PrivateUse1) } IMPL_REDUCE(CPU) +IMPL_REDUCE(XPU) IMPL_REDUCE(CUDA) IMPL_REDUCE(PrivateUse1) @@ -156,6 +160,7 @@ IMPL_REDUCE(PrivateUse1) } IMPL_BROADCAST(CPU) +IMPL_BROADCAST(XPU) IMPL_BROADCAST(CUDA) IMPL_BROADCAST(PrivateUse1) @@ -181,6 +186,7 @@ IMPL_BROADCAST(PrivateUse1) IMPL_ALLREDUCE(CPU) IMPL_ALLREDUCE(CUDA) +IMPL_ALLREDUCE(XPU) IMPL_ALLREDUCE(PrivateUse1) #define IMPL_ALLREDUCE_COALESCED(DEV) \ @@ -198,6 +204,7 @@ IMPL_ALLREDUCE(PrivateUse1) } IMPL_ALLREDUCE_COALESCED(CPU) +IMPL_ALLREDUCE_COALESCED(XPU) IMPL_ALLREDUCE_COALESCED(CUDA) IMPL_ALLREDUCE_COALESCED(PrivateUse1) @@ -222,6 +229,7 @@ IMPL_ALLREDUCE_COALESCED(PrivateUse1) // NOLINTBEGIN(cppcoreguidelines-pro-type-const-cast) IMPL_ALLGATHER(CPU) +IMPL_ALLGATHER(XPU) IMPL_ALLGATHER(CUDA) IMPL_ALLGATHER(PrivateUse1) @@ -242,6 +250,7 @@ IMPL_ALLGATHER(PrivateUse1) } IMPL__ALLGATHER_BASE(CPU) +IMPL__ALLGATHER_BASE(XPU) IMPL__ALLGATHER_BASE(CUDA) IMPL__ALLGATHER_BASE(PrivateUse1) @@ -258,6 +267,7 @@ IMPL__ALLGATHER_BASE(PrivateUse1) } IMPL_ALLGATHER_COALESCED(CPU) +IMPL_ALLGATHER_COALESCED(XPU) IMPL_ALLGATHER_COALESCED(CUDA) IMPL_ALLGATHER_COALESCED(PrivateUse1) @@ -273,6 +283,7 @@ IMPL_ALLGATHER_COALESCED(PrivateUse1) } IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CPU) +IMPL_ALLGATHER_INTO_TENSOR_COALESCED(XPU) IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CUDA) IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1) @@ -296,6 +307,7 @@ IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1) } IMPL_REDUCE_SCATTER(CPU) +IMPL_REDUCE_SCATTER(XPU) IMPL_REDUCE_SCATTER(CUDA) IMPL_REDUCE_SCATTER(PrivateUse1) @@ -320,6 +332,7 @@ IMPL_REDUCE_SCATTER(PrivateUse1) } IMPL__REDUCE_SCATTER_BASE(CPU) +IMPL__REDUCE_SCATTER_BASE(XPU) IMPL__REDUCE_SCATTER_BASE(CUDA) IMPL__REDUCE_SCATTER_BASE(PrivateUse1) @@ -341,6 +354,7 @@ IMPL__REDUCE_SCATTER_BASE(PrivateUse1) } IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CPU) +IMPL_REDUCE_SCATTER_TENSOR_COALESCED(XPU) IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CUDA) 
IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1) @@ -360,6 +374,7 @@ IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1) } IMPL_GATHER(CPU) +IMPL_GATHER(XPU) IMPL_GATHER(CUDA) IMPL_GATHER(PrivateUse1) @@ -382,6 +397,7 @@ IMPL_GATHER(PrivateUse1) } IMPL_SCATTER(CPU) +IMPL_SCATTER(XPU) IMPL_SCATTER(CUDA) IMPL_SCATTER(PrivateUse1) @@ -403,6 +419,7 @@ IMPL_SCATTER(PrivateUse1) } IMPL_ALLTOALL(CPU) +IMPL_ALLTOALL(XPU) IMPL_ALLTOALL(CUDA) IMPL_ALLTOALL(PrivateUse1) @@ -424,6 +441,7 @@ IMPL_ALLTOALL(PrivateUse1) } IMPL_ALLTOALL_BASE(CPU) +IMPL_ALLTOALL_BASE(XPU) IMPL_ALLTOALL_BASE(CUDA) IMPL_ALLTOALL_BASE(PrivateUse1) @@ -439,6 +457,7 @@ IMPL_ALLTOALL_BASE(PrivateUse1) } IMPL_BARRIER(CPU) +IMPL_BARRIER(XPU) IMPL_BARRIER(CUDA) IMPL_BARRIER(PrivateUse1) // NOLINTEND(cppcoreguidelines-pro-type-const-cast) @@ -491,6 +510,7 @@ namespace { #define REGISTER_C10D_OP(FUNC) \ REGISTER_C10D_OP1(FUNC, CPU) \ REGISTER_C10D_OP1(FUNC, CUDA) \ + REGISTER_C10D_OP1(FUNC, XPU) \ REGISTER_C10D_OP1(FUNC, PrivateUse1) // Now we start to register ops with the three device keys diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index 92b655f016eff..b3eac70e871bf 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -51,7 +51,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { NCCL = 2, UCC = 3, MPI = 4, - CUSTOM = 5, + XCCL = 5, + CUSTOM = 6, }; static std::string backendTypeToString(const BackendType& type) { @@ -60,6 +61,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { return "gloo"; case BackendType::NCCL: return "nccl"; + case BackendType::XCCL: + return "xccl"; case BackendType::UCC: return "ucc"; case BackendType::MPI: @@ -80,6 +83,8 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { return BackendType::GLOO; } else if (backend == "nccl") { return BackendType::NCCL; + } else if (backend == "xccl") { + return BackendType::XCCL; } else if (backend == "ucc") { return BackendType::UCC; } else if (backend == "mpi") { @@ -126,6 +131,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { return backendType_; }; + inline bool backendSupportsSequenceNumbers(BackendType backendType) { + if (backendType == BackendType::GLOO || backendType == BackendType::NCCL || + backendType == BackendType::XCCL || backendType == BackendType::UCC) + return true; + return false; + } + virtual void startCoalescing(c10::DeviceType deviceType) { // only nccl has implemented startCoalescing so only execute for nccl // backends @@ -503,9 +515,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { virtual void setSequenceNumberForGroup() { auto backendType = getBackendType(); // TODO: HACK for backend name to get sequence number for that backend. - if (backendType == ProcessGroup::BackendType::GLOO || - backendType == ProcessGroup::BackendType::NCCL || - backendType == ProcessGroup::BackendType::UCC) { + if (backendSupportsSequenceNumbers(backendType)) { getDefaultBackend()->setSequenceNumberForGroup(); } else { TORCH_CHECK( @@ -524,9 +534,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { auto backendType = getBackendType(); // TODO: HACK for backend name to get sequence number for that backend. 
-    if (backendType == ProcessGroup::BackendType::GLOO ||
-        backendType == ProcessGroup::BackendType::NCCL ||
-        backendType == ProcessGroup::BackendType::UCC) {
+    if (backendSupportsSequenceNumbers(backendType)) {
       return getDefaultBackend()->getSequenceNumberForGroup();
     } else {
       TORCH_CHECK(
diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
new file mode 100644
index 0000000000000..76d265ca5de28
--- /dev/null
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.cpp
@@ -0,0 +1,305 @@
+#ifdef USE_C10D_XCCL
+
+#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
+#include <torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp>
+#include <torch/csrc/distributed/c10d/Utils.hpp>
+
+namespace c10d {
+
+namespace {
+const std::map<ReduceOp, ccl::reduction> xcclOps = {
+    {ReduceOp::MIN, ccl::reduction::min},
+    {ReduceOp::MAX, ccl::reduction::max},
+    {ReduceOp::SUM, ccl::reduction::sum},
+    {ReduceOp::PRODUCT, ccl::reduction::prod},
+};
+
+const std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
+    {at::kByte, ccl::datatype::uint8},
+    {at::kChar, ccl::datatype::int8},
+    {at::kInt, ccl::datatype::int32},
+    {at::kLong, ccl::datatype::int64},
+    {at::kHalf, ccl::datatype::float16},
+    {at::kFloat, ccl::datatype::float32},
+    {at::kDouble, ccl::datatype::float64},
+    {at::kBFloat16, ccl::datatype::bfloat16},
+    {at::kBool, ccl::datatype::uint8},
+};
+
+void checkXPUTensor(at::Tensor& tensor) {
+  if (!tensor.is_xpu() || tensor.is_sparse() || tensor.is_complex()) {
+    C10_THROW_ERROR(
+        ValueError, "Tensors must be XPU and dense and non-complex");
+  }
+  if (!tensor.is_contiguous(tensor.suggest_memory_format())) {
+    C10_THROW_ERROR(ValueError, "Tensors must be contiguous");
+  }
+}
+
+ccl::datatype getXcclDataType(
+    at::ScalarType type,
+    bool is_reduction_op = false) {
+  TORCH_CHECK(
+      !(isFloat8Type(type) && is_reduction_op),
+      "Float8 dtypes are not currently supported for XCCL reductions");
+  auto it = xcclDatatypes.find(type);
+  TORCH_CHECK_WITH(
+      TypeError,
+      it != xcclDatatypes.end(),
+      "Input tensor data type is not supported for XCCL process group: ",
+      type);
+  return it->second;
+}
+
+ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
+  try {
+    if (input.scalar_type() == at::kBool && reduceOp == ReduceOp::SUM) {
+      // Map sum to max for bool tensors to avoid overflow issues with sum.
+      return ccl::reduction::max;
+    }
+    return xcclOps.at(reduceOp);
+  } catch (const std::out_of_range&) {
+    C10_THROW_ERROR(
+        ValueError,
+        "Cannot use ReduceOp."
+        reduceOpToString(reduceOp) + " with XCCL");
+  }
+}
+
+void syncStream(
+    at::Device& device,
+    at::xpu::XPUEvent& xcclEvent,
+    at::xpu::XPUStream& xcclStream) {
+  xcclEvent.record(at::xpu::getCurrentXPUStream(device.index()));
+  xcclEvent.block(xcclStream);
+}
+} // namespace
+
+constexpr int64_t kSynchronizeBusyWaitMillis = 10;
+
+ProcessGroupXCCL::WorkXCCL::WorkXCCL(
+    at::Device& device,
+    int rank,
+    OpType opType,
+    uint64_t seq,
+    const char* profilingTitle,
+    const std::optional<std::vector<at::Tensor>>& inputs)
+    : Work(rank, opType, profilingTitle, inputs),
+      device_(device),
+      workStartTime_(std::chrono::steady_clock::now()),
+      seq_(seq) {
+  xcclEndEvent_ = std::make_shared<at::xpu::XPUEvent>();
+}
+
+ProcessGroupXCCL::WorkXCCL::WorkXCCL(const WorkXCCL& w)
+    : Work(w.rank_, w.opType_),
+      device_(w.device_),
+      xcclEndEvent_(w.xcclEndEvent_),
+      blockingWait_(w.blockingWait_),
+      workStartTime_(w.workStartTime_),
+      seq_(w.seq_) {}
+
+ProcessGroupXCCL::WorkXCCL::~WorkXCCL() = default;
+
+bool ProcessGroupXCCL::WorkXCCL::isCompleted() {
+  if (xcclEndEvent_ && xcclEndEvent_->query()) {
+    return true;
+  }
+  return false;
+}
+
+void ProcessGroupXCCL::WorkXCCL::synchronize() {
+  synchronizeInternal(kNoTimeout);
+}
+
+void ProcessGroupXCCL::WorkXCCL::synchronizeInternal(
+    std::chrono::milliseconds timeout) {
+  auto currentStream = at::xpu::getCurrentXPUStream(device_.index());
+  xcclEndEvent_->block(currentStream);
+  if (blockingWait_) {
+    while (!isCompleted()) {
+      auto currentTimepoint = std::chrono::steady_clock::now();
+      auto timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
+          currentTimepoint - workStartTime_);
+      if (timeElapsed >= timeout) {
+        std::string exceptionMsg = c10::str(
+            "Work timed out after ", timeElapsed.count(), " milliseconds.");
+        TORCH_CHECK(false, exceptionMsg);
+      }
+      std::this_thread::sleep_for(
+          std::chrono::milliseconds(kSynchronizeBusyWaitMillis));
+    }
+  }
+}
+
+bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) {
+  synchronizeInternal(timeout);
+  return true;
+}
+
+ProcessGroupXCCL::ProcessGroupXCCL(
+    const c10::intrusive_ptr<Store>& store,
+    int rank,
+    int size)
+    : Backend(rank, size), store_(store) {
+  blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false);
+  init();
+}
+
+ProcessGroupXCCL::~ProcessGroupXCCL() = default;
+
+c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> ProcessGroupXCCL::initWork(
+    at::Device& device,
+    int rank,
+    OpType opType,
+    const char* profilingTitle,
+    const std::vector<at::Tensor>& inputs,
+    const std::vector<at::Tensor>& outputs) {
+  auto r = c10::make_intrusive<ProcessGroupXCCL::WorkXCCL>(
+      device,
+      rank,
+      opType,
+      seqCollective_,
+      profilingTitle,
+      std::optional<std::vector<at::Tensor>>(inputs));
+  return r;
+}
+
+std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm(
+    const std::string& deviceKey,
+    at::Device& device) {
+  TORCH_CHECK_WITH(
+      DistBackendError,
+      !deviceKey.empty(),
+      "Not able to create/get the XCCL Communicator since the device key is empty");
+  {
+    // Hold the lock so lookups cannot race with communicator insertion below.
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) {
+      return devXCCLCommMap_[deviceKey];
+    }
+  }
+
+  int numRanks = getSize();
+  int rank = getRank();
+
+  c10::impl::VirtualGuardImpl impl(device.type());
+  c10::Stream stream =
+      impl.getStreamFromGlobalPool(device, /*isHighPriority=*/false);
+  sycl::queue& q = c10::xpu::XPUStream(stream).queue();
+
+  auto ctx = ccl::create_context(q.get_context());
+  ccl::vector_class<ccl::pair_class<int, ccl::device>> devs_rank;
+  devs_rank.emplace_back(rank, ccl::create_device(q.get_device()));
+
+  auto xccl_kvs = get_kvs(rank_, *store_);
+  auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, xccl_kvs);
+  std::shared_ptr<xcclComm_t> XCCLComm =
+      std::make_shared<xcclComm_t>(std::move(comms[0]));
+
+  std::lock_guard<std::mutex> lock(mutex_);
+  devXCCLCommMap_.emplace(deviceKey, XCCLComm);
+  xcclStreamsMap_.emplace(deviceKey, std::move(stream));
+  xcclEventsMap_.emplace(deviceKey, at::xpu::XPUEvent());
+
+  return XCCLComm;
+}
+
+template <typename Fn, typename PreProcess, typename PostProcess>
+c10::intrusive_ptr<Work> ProcessGroupXCCL::collective(
+    std::vector<at::Tensor>& inputs,
+    std::vector<at::Tensor>& outputs,
+    Fn fn,
+    PreProcess pre,
+    PostProcess post,
+    OpType opType,
+    const char* profilingTitle) {
+  seqCollective_++;
+
+  auto device = inputs[0].device();
+  const auto key = std::to_string(device.index());
+  auto comm = getXCCLComm(key, device);
+
+  auto stream = xcclStreamsMap_.at(key);
+  syncStream(device, xcclEventsMap_[key], stream);
+
+  c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> work;
+  work = initWork(device, rank_, opType, profilingTitle);
+  work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
+
+  at::xpu::OptionalXPUGuard gpuGuard(device);
+  pre(stream, work);
+  for (const auto i : c10::irange(inputs.size())) {
+    c10::xpu::XPUCachingAllocator::recordStream(
+        inputs[i].storage().data_ptr(), stream);
+    fn(inputs[i], outputs[i], *comm, stream);
+  }
+  post(stream, work);
+
+  work->xcclEndEvent_->record(stream);
+  std::vector<c10::Stream> streams = {stream.unwrap()};
+  c10::MultiStreamGuard streamGuard(streams);
+  std::vector<at::Device> devices{device};
+  work->future_ = c10::make_intrusive<at::ivalue::Future>(
+      c10::ListType::create(c10::TensorType::get()), devices);
+  work->future_->markCompleted(at::IValue(*work->outputs_));
+  work->blockingWait_ = blockingWait_;
+
+  return work;
+}
+
+c10::intrusive_ptr<Work> ProcessGroupXCCL::allreduce(
+    std::vector<at::Tensor>& tensors,
+    const AllreduceOptions& opts) {
+  TORCH_CHECK(
+      tensors.size() == 1, "Expecting one tensor only but got multiple");
+  auto tensor = tensors.back();
+  checkXPUTensor(tensor);
+
+  RECORD_PARAM_COMMS_DATA(
+      // static_cast<int>(
+      //     this->getSequenceNumberForGroup() + 1), // seq + 1 to match
+      //     collective
+      1,
+      std::make_tuple(pg_uid_, pg_desc_), // PG name tuple
+      tensors, // inputTensors
+      tensors, // outputTensors
+      rank_, // rank
+      "allreduce", // collective name
+      tensor.numel(), // inNelems
+      tensor.numel(), // outNelems
+      tensor.scalar_type(), // dType
+      std::vector<int64_t>(), // inSplitSizes
+      std::vector<int64_t>(), // outSplitSizes
+      0, // globalRankStart
+      1, // globalRankStride
+      this->getSize()); // worldSize
+
+  return collective(
+      tensor,
+      tensor,
+      [&](at::Tensor& input,
+          at::Tensor& output,
+          xcclComm_t& comm,
+          at::xpu::XPUStream& stream) {
+        auto xcclDataType = getXcclDataType(input.scalar_type(), true);
+        auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input);
+        auto ccl_stream = ccl::create_stream(stream.queue());
+        ccl::allreduce(
+            input.data_ptr(),
+            output.data_ptr(),
+            (size_t)input.numel(),
+            xcclDataType,
+            xcclReduceOp,
+            comm,
+            ccl_stream);
+      },
+      OpType::ALLREDUCE,
+      "xccl:all_reduce");
+}
+
+} // namespace c10d
+
+#endif // USE_C10D_XCCL
diff --git a/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
new file mode 100644
index 0000000000000..f9761c652dc1a
--- /dev/null
+++ b/torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp
@@ -0,0 +1,187 @@
+#pragma once
+
+#ifdef USE_C10D_XCCL
+// We define these flags in the XCCL backend files instead of passing them to
+// the compiler.
+#define CCL_ENABLE_ZE
+#define CCL_ENABLE_SYCL
+
+#include <algorithm>
+#include <chrono>
+#include <mutex>
+#include <optional>
+#include <stdexcept>
+#include <unordered_map>
+#include <vector>
+
+#include <oneapi/ccl.hpp>
+#include <ATen/xpu/XPUEvent.h>
+#include <c10/xpu/XPUCachingAllocator.h>
+#include <torch/csrc/distributed/c10d/Backend.hpp>
+#include <torch/csrc/distributed/c10d/Store.hpp>
+namespace c10d {
+
+static std::vector<std::string> TORCH_XCCL_BLOCKING_WAIT = {
+    "TORCH_XCCL_BLOCKING_WAIT",
+    "XCCL_BLOCKING_WAIT"};
+
+using xcclComm_t = ccl::communicator;
+constexpr const char* XCCL_BACKEND_NAME = "xccl";
+
+class TORCH_API ProcessGroupXCCL : public Backend {
+ public:
+  class WorkXCCL : public Work {
+   public:
+    WorkXCCL(
+        at::Device& device,
+        int rank,
+        OpType opType,
+        uint64_t seq,
+        const char* profilingTitle = nullptr,
+        const std::optional<std::vector<at::Tensor>>& inputs = std::nullopt);
+    WorkXCCL(const WorkXCCL& w);
+    ~WorkXCCL() override;
+
+    bool isCompleted() override;
+
+    void abort() override {
+      TORCH_CHECK(false, "ProcessGroupXCCL::WorkXCCL::abort not implemented");
+    }
+
+    void synchronize() override;
+
+    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;
+
+    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
+      return future_;
+    }
+
+    uint64_t getSequencenumber() const override {
+      return seq_;
+    }
+
+    std::vector<at::Tensor> result() override {
+      return *outputs_;
+    }
+
+   protected:
+    at::Device device_;
+    std::shared_ptr<at::xpu::XPUEvent> xcclEndEvent_;
+    bool blockingWait_ = false;
+    std::chrono::time_point<std::chrono::steady_clock> workStartTime_;
+    uint64_t seq_;
+
+   private:
+    void synchronizeInternal(std::chrono::milliseconds timeout);
+    std::shared_ptr<std::vector<at::Tensor>> outputs_;
+    c10::intrusive_ptr<c10::ivalue::Future> future_;
+    friend class ProcessGroupXCCL;
+  };
+
+  ProcessGroupXCCL(const c10::intrusive_ptr<Store>& store, int rank, int size);
+
+  C10_DEPRECATED ProcessGroupXCCL(
+      const c10::intrusive_ptr<Store>& store,
+      int rank,
+      int size,
+      const std::string& groupName)
+      : ProcessGroupXCCL(store, rank, size) {}
+
+  ~ProcessGroupXCCL() override;
+
+  const std::string getBackendName() const override {
+    return std::string(XCCL_BACKEND_NAME);
+  }
+
+  std::shared_ptr<xcclComm_t> getXCCLComm(
+      const std::string& deviceKey,
+      at::Device& device);
+
+  virtual c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL> initWork(
+      at::Device& device,
+      int rank,
+      OpType opType,
+      const char* profilingTitle = nullptr,
+      const std::vector<at::Tensor>& inputs = {},
+      const std::vector<at::Tensor>& outputs = {});
+
+  template <typename Fn>
+  c10::intrusive_ptr<Work> collective(
+      at::Tensor& input,
+      at::Tensor& output,
+      Fn fn,
+      OpType opType,
+      const char* profilingTitle = nullptr) {
+    auto inputs = std::vector<at::Tensor>{input};
+    auto outputs = std::vector<at::Tensor>{output};
+    return collective(
+        inputs,
+        outputs,
+        fn,
+        [](at::xpu::XPUStream&,
+           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
+        [](at::xpu::XPUStream&,
+           c10::intrusive_ptr<ProcessGroupXCCL::WorkXCCL>&) {},
+        opType,
+        profilingTitle);
+  }
+
+  template <typename Fn, typename PreProcess, typename PostProcess>
+  c10::intrusive_ptr<Work> collective(
+      std::vector<at::Tensor>& inputs,
+      std::vector<at::Tensor>& outputs,
+      Fn fn,
+      PreProcess pre,
+      PostProcess post,
+      OpType opType,
+      const char* profilingTitle = nullptr);
+
+  c10::intrusive_ptr<Work> allreduce(
+      std::vector<at::Tensor>& tensors,
+      const AllreduceOptions& opts = AllreduceOptions()) override;
+
+  void setSequenceNumberForGroup() override {}
+  uint64_t getSequenceNumberForGroup() override {
+    return seqCollective_;
+  }
+
+ protected:
+  std::unordered_map<std::string, at::xpu::XPUStream> xcclStreamsMap_;
+  std::unordered_map<std::string, at::xpu::XPUEvent> xcclEventsMap_;
+  std::unordered_map<std::string, std::shared_ptr<xcclComm_t>> devXCCLCommMap_;
+  c10::intrusive_ptr<Store> store_;
+  std::mutex mutex_;
+  bool blockingWait_ = false;
+  uint64_t seqCollective_{0};
+
+ private:
+  std::mutex kvs_mutex;
+  ccl::shared_ptr_class<ccl::kvs> kvs;
+
+  ccl::shared_ptr_class<ccl::kvs> get_kvs(int rank, c10d::Store& store) {
+    std::lock_guard<std::mutex> lock(kvs_mutex);
+    if (kvs)
+      return kvs;
+    std::string storeKey = "xccl_kvs";
+    // Rank 0 broadcasts the bootstrap network information to the other ranks.
+    if (rank == 0) {
+      kvs = ccl::create_main_kvs();
+      ccl::kvs::address_type main_addr = kvs->get_address();
+      auto ccl_kvs_addr =
+          std::vector<uint8_t>(main_addr.begin(), main_addr.end());
+      store.set(storeKey, ccl_kvs_addr);
+    } else {
+      auto ccl_kvs_addr = store.get(storeKey);
+      if (ccl_kvs_addr.size() != ccl::kvs::address_max_size) {
+        throw std::runtime_error("Unexpected ccl kvs addr from the store\n");
+      }
+      ccl::kvs::address_type main_addr;
+      std::copy_n(
+          ccl_kvs_addr.begin(), ccl::kvs::address_max_size, main_addr.begin());
+      kvs = ccl::create_kvs(main_addr);
+    }
+    return kvs;
+  }
+};
+} // namespace c10d
+
+#endif // USE_C10D_XCCL
diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp
index ea4a4653bc35f..e27ec363ba1cc 100644
--- a/torch/csrc/distributed/c10d/Utils.hpp
+++ b/torch/csrc/distributed/c10d/Utils.hpp
@@ -557,6 +557,31 @@ size_t computeLengthsAndOffsets(
   return offset;
 }
 
+inline std::string reduceOpToString(c10d::ReduceOp op) {
+  switch (op) {
+    case c10d::ReduceOp::SUM:
+      return "SUM";
+    case c10d::ReduceOp::PRODUCT:
+      return "PRODUCT";
+    case c10d::ReduceOp::MIN:
+      return "MIN";
+    case c10d::ReduceOp::MAX:
+      return "MAX";
+    case c10d::ReduceOp::BAND:
+      return "BAND";
+    case c10d::ReduceOp::BOR:
+      return "BOR";
+    case c10d::ReduceOp::BXOR:
+      return "BXOR";
+    case c10d::ReduceOp::AVG:
+      return "AVG";
+    case c10d::ReduceOp::PREMUL_SUM:
+      return "PREMUL_SUM";
+    default:
+      return "UNKNOWN";
+  }
+}
+
 using RankType = uint32_t;
 using SizeType = uint64_t;
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index a1ef541ce0025..0f7792e64e5fa 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -37,6 +37,10 @@
 #include
 #endif
 
+#ifdef USE_C10D_XCCL
+#include <torch/csrc/distributed/c10d/ProcessGroupXCCL.hpp>
+#endif
+
 #include
 #include
 #include
@@ -2233,6 +2237,7 @@ The hook must have the following signature:
       .value("UNDEFINED", ::c10d::ProcessGroup::BackendType::UNDEFINED)
       .value("GLOO", ::c10d::ProcessGroup::BackendType::GLOO)
       .value("NCCL", ::c10d::ProcessGroup::BackendType::NCCL)
+      .value("XCCL", ::c10d::ProcessGroup::BackendType::XCCL)
       .value("UCC", ::c10d::ProcessGroup::BackendType::UCC)
       .value("MPI", ::c10d::ProcessGroup::BackendType::MPI)
       .value("CUSTOM", ::c10d::ProcessGroup::BackendType::CUSTOM)
@@ -2881,6 +2886,23 @@ Example::
           py::call_guard<py::gil_scoped_release>());
 #endif
 
+#ifdef USE_C10D_XCCL
+  auto processGroupXCCL =
+      intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupXCCL>(
+          module, "ProcessGroupXCCL", backend)
+          .def(
+              py::init([](const c10::intrusive_ptr<::c10d::Store>& store,
+                          int rank,
+                          int size) {
+                return c10::make_intrusive<::c10d::ProcessGroupXCCL>(
+                    store, rank, size);
+              }),
+              py::arg("store"),
+              py::arg("rank"),
+              py::arg("size"),
+              py::call_guard<py::gil_scoped_release>());
+#endif
+
   py::enum_<::c10d::OpType>(module, "OpType")
       .value("BROADCAST", ::c10d::OpType::BROADCAST)
       .value("ALLREDUCE", ::c10d::OpType::ALLREDUCE)
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index ea44a06df2d3e..fc4ca55dbd023 100644
--- a/torch/distributed/distributed_c10d.py
+++ 
b/torch/distributed/distributed_c10d.py @@ -87,6 +87,7 @@ "is_nccl_available", "is_torchelastic_launched", "is_ucc_available", + "is_xccl_available", "isend", "monitored_barrier", "new_group", @@ -130,6 +131,7 @@ _NCCL_AVAILABLE = True _GLOO_AVAILABLE = True _UCC_AVAILABLE = True +_XCCL_AVAILABLE = True _pickler = pickle.Pickler _unpickler = pickle.Unpickler @@ -193,6 +195,14 @@ def _export_c_types() -> None: except ImportError: _UCC_AVAILABLE = False +try: + from torch._C._distributed_c10d import ProcessGroupXCCL + + ProcessGroupXCCL.__module__ = "torch.distributed.distributed_c10d" + __all__ += ["ProcessGroupXCCL"] +except ImportError: + _XCCL_AVAILABLE = False + logger = logging.getLogger(__name__) PG_WRAPPER_STORE_PREFIX = "pg_wrapper" @@ -222,7 +232,7 @@ class Backend(str): """ An enum-like class for backends. - Available backends: GLOO, NCCL, UCC, MPI, and other registered backends. + Available backends: GLOO, NCCL, UCC, MPI, XCCL, and other registered backends. The values of this class are lowercase strings, e.g., ``"gloo"``. They can be accessed as attributes, e.g., ``Backend.NCCL``. @@ -242,21 +252,24 @@ class Backend(str): NCCL = "nccl" UCC = "ucc" MPI = "mpi" + XCCL = "xccl" _BackendPlugin = namedtuple("_BackendPlugin", ["creator_fn", "extended_api"]) _plugins: Dict[str, _BackendPlugin] = {} - backend_list = [UNDEFINED, GLOO, NCCL, UCC, MPI] + backend_list = [UNDEFINED, GLOO, NCCL, XCCL, UCC, MPI] default_device_backend_map: Dict[str, str] = { "cpu": GLOO, "cuda": NCCL, + "xpu": XCCL, } backend_capability: Dict[str, List[str]] = { GLOO: ["cpu", "cuda"], NCCL: ["cuda"], + XCCL: ["xpu"], UCC: ["cpu", "cuda"], MPI: ["cpu", "cuda"], } @@ -265,6 +278,7 @@ class Backend(str): UNDEFINED: ProcessGroup.BackendType.UNDEFINED, GLOO: ProcessGroup.BackendType.GLOO, NCCL: ProcessGroup.BackendType.NCCL, + XCCL: ProcessGroup.BackendType.XCCL, UCC: ProcessGroup.BackendType.UCC, MPI: ProcessGroup.BackendType.MPI, } @@ -1099,6 +1113,11 @@ def is_ucc_available() -> bool: return _UCC_AVAILABLE +def is_xccl_available() -> bool: + """Check if the XCCL backend is available.""" + return _XCCL_AVAILABLE + + def is_backend_available(backend: str) -> bool: """ Check backend availability. @@ -1351,6 +1370,10 @@ def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> backends.add(backend) # type: ignore[arg-type] elif is_gloo_available() and isinstance(backend, ProcessGroupGloo): backends.add(backend) # type: ignore[arg-type] + if torch.device("xpu") in devices and is_xccl_available(): + backend = group._get_backend(torch.device("xpu")) + if isinstance(backend, ProcessGroupXCCL): + backends.add(backend) # type: ignore[arg-type] if len(backends) == 0: warnings.warn("Set timeout is now only supported for either nccl or gloo.") for backend in backends: @@ -1386,7 +1409,7 @@ def init_process_group( Args: backend (str or Backend, optional): The backend to use. Depending on - build-time configurations, valid values include ``mpi``, ``gloo``, + build-time configurations, valid values include ``mpi``, ``gloo``, ``xccl``, ``nccl``, and ``ucc``. If the backend is not provided, then both a ``gloo`` and ``nccl`` backend will be created, see notes below for how multiple backends are managed. 
            This field can be given as a lowercase string
@@ -1652,10 +1675,9 @@ def _new_process_group_helper(
                 "created, please use a different group name"
             )
 
-    if device_id is not None and (device_id.index is None or device_id.type != "cuda"):
+    if device_id is not None and device_id.index is None:
         raise ValueError(
-            "init_process_group device_id parameter must be a cuda device with an "
-            "id, e.g. cuda:0, not just cuda or cpu"
+            "init_process_group device_id parameter must be a device with an index"
         )
 
     # Note: _new_process_group_helper is only called from init_process_group, which always provides a timeout value
@@ -1790,6 +1812,17 @@ def _new_process_group_helper(
             backend_prefix_store, group_rank, group_size, timeout=timeout
         )
         backend_type = ProcessGroup.BackendType.UCC
+    elif backend_str == Backend.XCCL:
+        if not is_xccl_available():
+            raise RuntimeError("Distributed package doesn't have XCCL built in")
+        if backend_options is not None:
+            assert isinstance(
+                backend_options, ProcessGroupXCCL.Options
+            ), "Expected backend_options argument to be of type ProcessGroupXCCL.Options"
+        backend_class = ProcessGroupXCCL(
+            backend_prefix_store, group_rank, group_size
+        )
+        backend_type = ProcessGroup.BackendType.XCCL
     else:
         assert (
             backend_str.upper() in Backend._plugins
diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py
index fb2a5c034b3e7..3e1664690b713 100644
--- a/torch/testing/_internal/common_distributed.py
+++ b/torch/testing/_internal/common_distributed.py
@@ -93,8 +93,9 @@ class DistTestCases:
     # Sets showing that something is implemented
     backend_feature = {}
-    backend_feature["gpu"] = {"nccl", "gloo", "ucc"}
+    backend_feature["gpu"] = {"nccl", "gloo", "ucc", "xccl"}
     backend_feature["cuda"] = {"nccl", "gloo", "ucc"}
+    backend_feature["xpu"] = {"xccl"}
     backend_feature["ddp"] = {"nccl", "gloo", "ucc"}
     backend_feature["subgroup"] = {"nccl", "gloo", "ucc"}
     backend_feature["plugin"] = set()
@@ -180,7 +181,8 @@ def skip_if_lt_x_gpu(x):
     def decorator(func):
         @wraps(func)
         def wrapper(*args, **kwargs):
-            if torch.cuda.is_available() and torch.cuda.device_count() >= x:
+            if (torch.cuda.is_available() and torch.cuda.device_count() >= x) or \
+               (torch.xpu.is_available() and torch.xpu.device_count() >= x):
                 return func(*args, **kwargs)
             sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
 
@@ -320,6 +322,12 @@ def requires_nccl():
         "c10d was not compiled with the NCCL backend",
     )
 
+def requires_xccl():
+    return skip_but_pass_in_sandcastle_if(
+        not c10d.is_xccl_available(),
+        "c10d was not compiled with the XCCL backend",
+    )
+
 def requires_ucc():
     return skip_but_pass_in_sandcastle_if(
         not c10d.is_ucc_available(),
@@ -455,6 +463,15 @@ def compute_sum(fn, world_size: int):
         ]
     ]
 
+# Returns the number of GPUs, currently only for CUDA and XPU.
+def get_device_count(backend: str):
+    assert c10d.is_backend_available(backend)
+    if backend in DistTestCases.backend_feature.get("cuda", set()):
+        return torch.cuda.device_count()
+    elif backend in DistTestCases.backend_feature.get("xpu", set()):
+        return torch.xpu.device_count()
+    else:
+        raise ValueError(f"Unsupported backend: {backend}")
 
 # HELPER FOR MULTIGPU TESTS
 def init_multigpu_helper(world_size: int, backend: str):
     """Multigpu tests are designed to simulate the multi nodes with
     multi GPUs on each node. On a single node, all visible GPUs are evenly
    divided to subsets, each process only uses a subset.
""" - nGPUs = torch.cuda.device_count() + nGPUs = get_device_count(backend) visible_devices = range(nGPUs) # If rank is less than or equal to number of available GPU's