diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index 77f6bdc331..c07b4fd953 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -26,7 +26,7 @@
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
 from torchtune.training.activations import apply_selective_activation_checkpointing
-from torchtune.utils import get_torch_device
+from torchtune.utils import DeviceSupport, get_torch_device
 
 from tqdm import tqdm
 
@@ -743,7 +743,8 @@ def recipe_main(cfg: DictConfig) -> None:
             "Distributed finetune recipe should be run via a distributed launcher."
             "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
         )
-    init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
+    device_support = DeviceSupport.from_type(cfg.device)
+    init_process_group(backend=device_support.communication_backend)
     if cfg.get("fsdp_cpu_offload", False):
         # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
         # speed up when benchmarking fused AdamW on CPU
diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py
index e903ab274a..c2515c33d9 100644
--- a/recipes/lora_dpo_distributed.py
+++ b/recipes/lora_dpo_distributed.py
@@ -33,6 +33,7 @@
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.rlhf.loss import SimPOLoss
+from torchtune.utils import DeviceSupport
 
 from tqdm import tqdm
 
 log = utils.get_logger("DEBUG")
@@ -759,7 +760,8 @@ def recipe_main(cfg: DictConfig) -> None:
         # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
         # speed up when benchmarking fused AdamW on CPU
         training.set_torch_num_threads()
-    init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
+    device_support = DeviceSupport.from_type(cfg.device)
+    init_process_group(backend=device_support.communication_backend)
 
     config.log_config(recipe_name="LoRADPORecipeDistributed", cfg=cfg)
diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py
index 44f3345079..2d9e47c116 100644
--- a/recipes/lora_finetune_distributed.py
+++ b/recipes/lora_finetune_distributed.py
@@ -41,7 +41,7 @@
     OffloadActivations,
     PROFILER_KEY,
 )
-from torchtune.utils import get_torch_device
+from torchtune.utils import DeviceSupport, get_torch_device
 
 from tqdm import tqdm
 
@@ -897,13 +897,8 @@ def recipe_main(cfg: DictConfig) -> None:
         # speed up when benchmarking fused AdamW on CPU
         training.set_torch_num_threads()
 
-    if cfg.device == "cpu":
-        backend = "gloo"
-    elif cfg.device == "npu":
-        backend = "hccl"
-    else:
-        backend = "nccl"
-    init_process_group(backend=backend)
+    device_support = DeviceSupport.from_type(cfg.device)
+    init_process_group(backend=device_support.communication_backend)
 
     config.log_config(recipe_name="LoRAFinetuneRecipeDistributed", cfg=cfg)
diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py
index 9302bea9ab..dca7ee547d 100644
--- a/recipes/qat_distributed.py
+++ b/recipes/qat_distributed.py
@@ -782,7 +782,9 @@ def recipe_main(cfg: DictConfig) -> None:
             "Distributed QAT recipe should be run via a distributed launcher."
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" ) - init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") + + device_support = DeviceSupport.from_type(cfg.device) + init_process_group(backend=device_support.communication_backend) if cfg.get("fsdp_cpu_offload", False): # Utilize all available CPU cores for intra-op parallelism. This provides ~2x # speed up when benchmarking fused AdamW on CPU diff --git a/tests/torchtune/utils/test_device.py b/tests/torchtune/utils/test_device.py index 6aa5ecca59..df54b0e980 100644 --- a/tests/torchtune/utils/test_device.py +++ b/tests/torchtune/utils/test_device.py @@ -16,7 +16,10 @@ _get_device_type_from_env, _setup_device, batch_to_device, + DeviceSupport, get_device, + get_device_support, + get_torch_device, ) @@ -87,3 +90,20 @@ def test_get_gpu_device(self) -> None: assert device.type == "cuda" assert device.index == 0 assert device.index == torch.cuda.current_device() + + @pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.") + @patch("torch.cuda.is_available", return_value=True) + def test_cuda_available(self, mock_cuda): + # Test if CUDA is available, get_device_support should return DeviceSupport.CUDA + device_support = get_device_support() + assert device_support == DeviceSupport.CUDA + assert device_support.device_type == "cuda" + assert device_support.device_name == "GPU" + assert device_support.communication_backend == "nccl" + + @pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.") + @patch("torch.cuda.is_available", return_value=True) + def test_get_torch_device_for_cuda(self, mock_cuda): + # Test if get_torch_device returns the correct torch.cuda module + torch_device = get_torch_device("cuda") + assert torch_device == torch.cuda diff --git a/tests/torchtune/utils/test_device_support.py b/tests/torchtune/utils/test_device_support.py deleted file mode 100644 index 321c9a1265..0000000000 --- a/tests/torchtune/utils/test_device_support.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
-
-from unittest.mock import patch
-
-import pytest
-
-import torch
-from torchtune.utils._device_support import (
-    DeviceSupport,
-    get_device_support,
-    get_torch_device,
-)
-
-
-class TestDevice:
-
-    cuda_available: bool = torch.cuda.is_available()
-
-    @pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.")
-    @patch("torch.cuda.is_available", return_value=True)
-    def test_cuda_available(self, mock_cuda):
-        # Test if CUDA is available, get_device_support should return DeviceSupport.CUDA
-        device_support = get_device_support()
-        assert device_support == DeviceSupport.CUDA
-        assert device_support.device_type == "cuda"
-        assert device_support.device_name == "GPU"
-        assert device_support.device_backend == "nccl"
-
-    @pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.")
-    @patch("torch.cuda.is_available", return_value=True)
-    def test_get_torch_device_for_cuda(self, mock_cuda):
-        # Test if get_torch_device returns the correct torch.cuda module
-        torch_device = get_torch_device("cuda")
-        assert torch_device == torch.cuda
diff --git a/torchtune/training/precision.py b/torchtune/training/precision.py
index 2e32e35096..08b311c04b 100644
--- a/torchtune/training/precision.py
+++ b/torchtune/training/precision.py
@@ -56,8 +56,6 @@ def verify_bf16_support() -> bool:
         bool: True if bf16 is available, False otherwise.
     """
-    if is_npu_available:
-        return torch.npu.is_bf16_supported()
     cuda_support = (
         torch.cuda.is_available()
         and torch.cuda.is_bf16_supported()
@@ -65,7 +63,8 @@
         and torch.cuda.nccl.version() >= (2, 10)
     )
     mps_support = torch.backends.mps.is_available() and torch.backends.mps.is_built()
-    return cuda_support or mps_support
+    npu_support = is_npu_available and torch.npu.is_bf16_supported()
+    return cuda_support or mps_support or npu_support
 
 
 def get_dtype(
diff --git a/torchtune/utils/__init__.py b/torchtune/utils/__init__.py
index 4649675308..5b515a388b 100644
--- a/torchtune/utils/__init__.py
+++ b/torchtune/utils/__init__.py
@@ -4,8 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from ._device import batch_to_device, get_device
-from ._device_support import get_device_support, get_torch_device, is_npu_available
+from ._device import (
+    batch_to_device,
+    DeviceSupport,
+    get_device,
+    get_device_support,
+    get_torch_device,
+    is_npu_available,
+)
 from ._logging import get_logger
 from ._version import torch_version_ge
 
@@ -18,4 +24,5 @@
     "is_npu_available",
     "get_device_support",
     "get_torch_device",
+    "DeviceSupport",
 ]
diff --git a/torchtune/utils/_device.py b/torchtune/utils/_device.py
index 6f97eb41a6..e01d235aa3 100644
--- a/torchtune/utils/_device.py
+++ b/torchtune/utils/_device.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import os
+from enum import Enum
 from typing import Optional
 
 import torch
@@ -16,11 +17,18 @@
 else:
     BlockMask = torch.Tensor
 
-from torchtune.utils._device_support import (
-    get_device_support,
-    get_torch_device,
-    is_npu_available,
-)
+
+def is_torch_npu_available() -> bool:
+    """Check the availability of NPU"""
+    try:
+        import torch_npu  # noqa: F401
+
+        return torch.npu.is_available()
+    except ImportError:
+        return False
+
+
+is_npu_available = is_torch_npu_available()
 
 
 def _get_local_rank() -> Optional[int]:
@@ -167,3 +175,59 @@ def batch_to_device(batch: dict, device: torch.device) -> None:
             f"""To use batch_to_device, all elements in the batch must be a dict or Tensor.
Got key "{k}" with value of type {type(v)}"""
         )
+
+
+class DeviceSupport(Enum):
+    """
+    A simple enum for compute devices.
+    Currently supports CPU, CUDA, and NPU.
+    """
+
+    CPU = ("cpu", "CPU", "gloo")
+    CUDA = ("cuda", "GPU", "nccl")
+    NPU = ("npu", "NPU", "hccl")
+
+    def __init__(self, device_type: str, device_name: str, communication_backend: str):
+        self.device_type = device_type
+        self.device_name = device_name
+        self.communication_backend = communication_backend
+
+    @staticmethod
+    def from_type(device_type: str):
+        for member in DeviceSupport:
+            if member.device_type == device_type:
+                return member
+        raise ValueError(f"Unknown device type: {device_type}.")
+
+
+def get_device_support() -> DeviceSupport:
+    """Get the DeviceSupport for the compute device detected on the current machine.
+
+    Currently supports CPU, CUDA, and NPU.
+
+    Returns:
+        device_support: DeviceSupport
+    """
+    device_type = _get_device_type_from_env()
+    return DeviceSupport.from_type(device_type)
+
+
+def get_torch_device(device_type: Optional[str] = None) -> any:
+    """Return the torch module corresponding to the device type string.
+
+    Args:
+        device_type (Optional[str]): The device type name, e.g. "cuda" or "npu".
+            If None, the device type is inferred from the current machine.
+
+    Returns:
+        module: The corresponding torch module, or torch.cuda if not found.
+    """
+    if device_type is None:
+        device_type = get_device_support().device_type
+    try:
+        return getattr(torch, device_type)
+    except AttributeError:
+        print(
+            f"Device module '{device_type}' not found in torch, falling back to torch.cuda."
+        )
+        return torch.cuda
diff --git a/torchtune/utils/_device_support.py b/torchtune/utils/_device_support.py
deleted file mode 100644
index 5de80abad2..0000000000
--- a/torchtune/utils/_device_support.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-from enum import Enum
-from typing import Optional
-
-import torch
-
-
-def is_torch_npu_available() -> bool:
-    """Check the availability of NPU"""
-    try:
-        import torch_npu  # noqa: F401
-
-        return torch.npu.is_available()
-    except ImportError:
-        return False
-
-
-is_npu_available = is_torch_npu_available()
-
-
-class DeviceSupport(Enum):
-    """
-    This is a simple enum for Non-CPU compute devices,
-    which enables custom backends that implement CUDA-like semantics.
- """ - - CUDA = ("cuda", "GPU", "nccl") - NPU = ("npu", "NPU", "hccl") - - def __init__(self, device_type: str, device_name: str, device_backend: str): - self.device_type = device_type - self.device_name = device_name - self.device_backend = device_backend - - @staticmethod - def from_type(device_type: str): - for member in DeviceSupport: - if member.device_type == device_type: - return member - raise ValueError(f"Unknown device type: {device_type}.") - - -def _get_device_support_from_env() -> DeviceSupport: - """function that gets the DeviceSupport with Non-CPU compute devices based on the current machine. - - This currently only supports CUDA, NPU. - - Raises: - RuntimeError: If Non-CPU compute devices is not available. - - Returns: - device_support: DeviceSupport - """ - if is_npu_available: - return DeviceSupport.NPU - elif torch.cuda.is_available(): - return DeviceSupport.CUDA - else: - raise RuntimeError("No available device found.") - - -def get_device_support(device_type: Optional[str] = None) -> DeviceSupport: - """Function that takes an optional device string, verifies it's correct and available given the machine and - distributed settings, and returns a enum:`DeviceSupport`. If device string is not provided, this function will - infer the device based on the environment. - - Args: - device_type (Optional[str]): The name of the device to use, e.g. "cuda" or "npu". - - Example: - >>> device_support = get_device_support("cuda") - >>> device_support - device_support(type='cuda', name='GPU') - - Returns: - device_support: DeviceSupport - """ - if device_type is not None: - device_support = DeviceSupport.from_type(device_type) - else: - device_support = _get_device_support_from_env() - return device_support - - -def get_torch_device(device_type: Optional[str] = None) -> any: - """Return the corresponding torch attribute based on the device type string. - - Args: - device_type(Optional[str]): The device type name, e.g., 'cuda', 'npu'. - - Returns: - module: The corresponding torch module, or None if not found. - """ - if device_type is None: - device_type = get_device_support().device_type - try: - return getattr(torch, device_type) - except AttributeError: - print(f"Device Module '{device_type}' not found in torch.") - return None