Add Ascend NPU as a backend #1826

Open · wants to merge 2 commits into base: main
8 changes: 6 additions & 2 deletions recipes/full_finetune_single_device.py
@@ -252,6 +252,10 @@ def setup(self, cfg: DictConfig) -> None:
         # should be called before ``_setup_optimizer`` since transforming the optimizer
         # state dict requires the model
         self._compile = cfg.compile
+        if cfg.device == "npu" and cfg.compile:
+            raise ValueError(
+                "NPU does not support model compilation. Please set `compile: False` in the config."
+            )
         self._model = self._setup_model(
             cfg_model=cfg.model,
             enable_activation_checkpointing=self._enable_activation_checkpointing,
@@ -430,7 +434,7 @@ def _setup_model(

         log.info(f"Model is initialized with precision {self._dtype}.")

-        if self._device.type == "cuda":
+        if self._device.type != "cpu":
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(memory_stats)

@@ -728,7 +732,7 @@ def train(self) -> None:
                        ),
                        "tokens_per_second_per_gpu": num_tokens / time_per_step,
                    }
-                    if self._device.type == "cuda" and self._log_peak_memory_stats:
+                    if self._device.type != "cpu" and self._log_peak_memory_stats:
                        log_dict.update(
                            training.get_memory_stats(device=self._device)
                        )
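For reference, a minimal sketch of how the new guard behaves when a recipe is launched with an NPU config. The `device` and `compile` keys come from the diff above; building the config with OmegaConf here is purely illustrative, not how the recipe is normally invoked:

```python
# Illustrative only: reproduces the new fail-fast guard outside the recipe.
from omegaconf import DictConfig, OmegaConf

cfg: DictConfig = OmegaConf.create({"device": "npu", "compile": True})

try:
    # Same condition the recipes now check in setup(): torch.compile is not
    # supported on Ascend NPU, so the run aborts before any model is built.
    if cfg.device == "npu" and cfg.compile:
        raise ValueError(
            "NPU does not support model compilation. Please set `compile: False` in the config."
        )
except ValueError as err:
    print(err)  # NPU does not support model compilation. ...
```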
7 changes: 6 additions & 1 deletion recipes/lora_finetune_single_device.py
@@ -32,6 +32,7 @@
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
+
 from tqdm import tqdm

 log = utils.get_logger("DEBUG")
@@ -249,6 +250,10 @@ def setup(self, cfg: DictConfig) -> None:
         self._metric_logger.log_config(cfg)

         self._compile = cfg.compile
+        if cfg.device == "npu" and cfg.compile:
+            raise ValueError(
+                "NPU does not support model compilation. Please set `compile: False` in the config."
+            )
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

         # hack to toggle to the low cpu ram version of the reparametrize_as_dtype
@@ -468,7 +473,7 @@ def _setup_model(

         log.info(f"Model is initialized with precision {self._dtype}.")

-        if self._device.type == "cuda":
+        if self._device.type != "cpu":
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(memory_stats)
         return model
26 changes: 23 additions & 3 deletions tests/torchtune/utils/test_device.py
@@ -14,9 +14,12 @@
 import torch
 from torchtune.utils._device import (
     _get_device_type_from_env,
-    _setup_cuda_device,
+    _setup_device,
     batch_to_device,
+    DeviceSupport,
     get_device,
+    get_device_support,
+    get_torch_device,
 )


@@ -69,7 +72,7 @@ def test_get_gpu_device(self) -> None:
         if device_idx > 0:
             with pytest.raises(
                 RuntimeError,
-                match=f"Device specified is cuda:0 but was assigned cuda:{device_idx}",
+                match=f"Device specified is cuda:0 but local rank is:{device_idx}",
             ):
                 device = get_device("cuda:0")

@@ -83,7 +86,24 @@ def test_get_gpu_device(self) -> None:

         # Test that we fall back to 0 if LOCAL_RANK is not specified
         device = torch.device(_get_device_type_from_env())
-        device = _setup_cuda_device(device)
+        device = _setup_device(device)
         assert device.type == "cuda"
         assert device.index == 0
         assert device.index == torch.cuda.current_device()
+
+    @pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.")
+    @patch("torch.cuda.is_available", return_value=True)
+    def test_cuda_available(self, mock_cuda):
+        # Test if CUDA is available, get_device_support should return DeviceSupport.CUDA
+        device_support = get_device_support()
+        assert device_support == DeviceSupport.CUDA
+        assert device_support.device_type == "cuda"
+        assert device_support.device_name == "GPU"
+        assert device_support.communication_backend == "nccl"
+
+    @pytest.mark.skipif(not cuda_available, reason="The test requires GPUs to run.")
+    @patch("torch.cuda.is_available", return_value=True)
+    def test_get_torch_device_for_cuda(self, mock_cuda):
+        # Test if get_torch_device returns the correct torch.cuda module
+        torch_device = get_torch_device()
+        assert torch_device == torch.cuda
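These tests pin down the surface of the new device abstraction: `get_device_support()` returns a `DeviceSupport` member carrying a `device_type`, a human-readable `device_name`, and a `communication_backend`, and `get_torch_device()` resolves to the matching `torch.<backend>` module. Below is a minimal sketch consistent with the assertions above; the CPU/NPU member values and the implementation details are assumptions, not the actual `torchtune.utils._device` source:

```python
from enum import Enum

import torch


class DeviceSupport(Enum):
    # (device_type, device_name, communication_backend); only the CUDA values
    # are asserted in the tests above, the CPU/NPU rows are assumed.
    CPU = ("cpu", "CPU", "gloo")
    CUDA = ("cuda", "GPU", "nccl")
    NPU = ("npu", "NPU", "hccl")

    def __init__(self, device_type: str, device_name: str, communication_backend: str):
        self.device_type = device_type
        self.device_name = device_name
        self.communication_backend = communication_backend


def get_device_support() -> DeviceSupport:
    """Return the best available accelerator, falling back to CPU."""
    if torch.cuda.is_available():
        return DeviceSupport.CUDA
    # torch.npu only exists once the torch_npu extension has been imported.
    if hasattr(torch, "npu") and torch.npu.is_available():
        return DeviceSupport.NPU
    return DeviceSupport.CPU


def get_torch_device():
    """Return the torch backend module (torch.cuda, torch.npu, ...) for the active device."""
    return getattr(torch, get_device_support().device_type)
```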
31 changes: 15 additions & 16 deletions torchtune/training/memory.py
@@ -16,7 +16,7 @@
 )
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.optim.lr_scheduler import LRScheduler
-from torchtune.utils import get_logger
+from torchtune.utils import get_device_support, get_logger, get_torch_device

 _log: logging.Logger = get_logger()

@@ -45,11 +45,11 @@ def set_activation_checkpointing(

 def cleanup_before_training() -> None:
     """
-    Call gc collect, empty CUDA cache, and reset peak memory stats.
+    Call gc collect, empty device cache, and reset peak memory stats.
     """
     gc.collect()
-    torch.cuda.empty_cache()
-    torch.cuda.reset_peak_memory_stats()
+    get_torch_device().empty_cache()
+    get_torch_device().reset_peak_memory_stats()


 class OptimizerInBackwardWrapper:
@@ -260,19 +260,17 @@ def get_memory_stats(device: torch.device, reset_stats: bool = True) -> dict:
     Raises:
         ValueError: If the passed-in device is not CUDA.
     """
-    if device.type != "cuda":
-        raise ValueError(
-            f"Logging memory stats is only supported on CUDA devices, got {device}"
-        )
+    if device.type == "cpu":
+        raise ValueError("Logging memory stats is not supported on CPU devices")

-    peak_memory_active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / (
+    torch_device = get_torch_device()
+    peak_memory_active = torch_device.memory_stats().get("active_bytes.all.peak", 0) / (
         1024**3
     )
-    peak_mem_alloc = torch.cuda.max_memory_allocated(device) / (1024**3)
-    peak_mem_reserved = torch.cuda.max_memory_reserved(device) / (1024**3)
-
+    peak_mem_alloc = torch_device.max_memory_allocated(device) / (1024**3)
+    peak_mem_reserved = torch_device.max_memory_reserved(device) / (1024**3)
     if reset_stats:
-        torch.cuda.reset_peak_memory_stats(device)
+        torch_device.reset_peak_memory_stats(device)

     memory_stats = {
         "peak_memory_active": peak_memory_active,
@@ -292,9 +290,10 @@ def log_memory_stats(stats: Dict[str, float]) -> None:
         stats (Dict[str, float]): A dictionary containing the peak memory active, peak memory
             allocated, and peak memory reserved stats.
     """
+    device_support = get_device_support()
     _log.info(
         "Memory stats after model init:"
-        f"\n\tGPU peak memory allocation: {stats['peak_memory_alloc']:.2f} GiB"
-        f"\n\tGPU peak memory reserved: {stats['peak_memory_reserved']:.2f} GiB"
-        f"\n\tGPU peak memory active: {stats['peak_memory_active']:.2f} GiB"
+        f"\n\t{device_support.device_name} peak memory allocation: {stats['peak_memory_alloc']:.2f} GiB"
+        f"\n\t{device_support.device_name} peak memory reserved: {stats['peak_memory_reserved']:.2f} GiB"
+        f"\n\t{device_support.device_name} peak memory active: {stats['peak_memory_active']:.2f} GiB"
     )
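With these changes the memory helpers are device-agnostic. A usage sketch of the flow the recipes follow after this PR; only the calls visible in this diff are used, and the CPU guard mirrors the recipe code above:

```python
# The same two calls report GPU memory on CUDA hosts and NPU memory on Ascend
# hosts, because get_memory_stats() now goes through get_torch_device() and
# log_memory_stats() labels its output via DeviceSupport.device_name.
from torchtune import training, utils

device = utils.get_device()  # resolves to cuda:N, npu:N, or cpu
if device.type != "cpu":
    stats = training.get_memory_stats(device=device)
    training.log_memory_stats(stats)
```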
6 changes: 5 additions & 1 deletion torchtune/training/precision.py
@@ -10,9 +10,11 @@
 import torch

 from torchtune.utils import get_logger
+from torchtune.utils._device import is_npu_available

 log = get_logger()

+
 PRECISION_STR_TO_DTYPE: Dict[str, torch.dtype] = {
     "fp16": torch.float16,
     "bf16": torch.bfloat16,
@@ -50,6 +52,7 @@ def verify_bf16_support() -> bool:
         - CUDA compute capability >= 8
         - NCCL is available and version >= 2.10
         - MPS is available and torch was built with MPS
+        - NPU is available and supports bf16

Contributor review comment on the added NPU line: nit: this is a bit redundant. Do we know the exact requirements for bf16 support on NPUs?

     Returns:
         bool: True if bf16 is available, False otherwise.
@@ -62,7 +65,8 @@ def verify_bf16_support() -> bool:
         and torch.cuda.nccl.version() >= (2, 10)
     )
     mps_support = torch.backends.mps.is_available() and torch.backends.mps.is_built()
-    return cuda_support or mps_support
+    npu_support = is_npu_available and torch.npu.is_bf16_supported()
+    return cuda_support or mps_support or npu_support


 def get_dtype(
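`is_npu_available` is imported from `torchtune.utils._device` but its definition is not part of this diff; it is used above as a plain boolean flag. A plausible sketch, following the usual Ascend pattern where importing the `torch_npu` extension registers the `torch.npu` namespace (an assumption about the helper, not its actual source):

```python
import torch

try:
    import torch_npu  # noqa: F401  # Ascend adapter; importing it registers torch.npu

    is_npu_available: bool = torch.npu.is_available()
except ImportError:
    is_npu_available = False
```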
11 changes: 10 additions & 1 deletion torchtune/utils/__init__.py
@@ -4,7 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from ._device import batch_to_device, get_device
+from ._device import (
+    batch_to_device,
+    DeviceSupport,
+    get_device,
+    get_device_support,
+    get_torch_device,
+)
 from ._logging import get_logger

 from ._version import torch_version_ge
@@ -14,4 +20,7 @@
     "get_device",
     "get_logger",
     "torch_version_ge",
+    "get_device_support",
+    "get_torch_device",
+    "DeviceSupport",
 ]
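Assuming the PR lands as-is, the new helpers become part of the public `torchtune.utils` namespace. A short usage sketch; the printed values follow the CUDA expectations asserted in the tests above:

```python
from torchtune.utils import DeviceSupport, get_device_support, get_torch_device

support = get_device_support()            # e.g. DeviceSupport.CUDA on a GPU host
print(support.device_name)                # "GPU"
print(support.communication_backend)      # "nccl"

# On an accelerator host this dispatches to torch.cuda or torch.npu.
get_torch_device().empty_cache()
```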