MPS support (#1706)

SalmanMohammadi authored Sep 28, 2024
1 parent 7c4c629 commit 3fddc56
Showing 14 changed files with 54 additions and 41 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -119,7 +119,7 @@ pip install torchtune

```bash
# Install PyTorch, torchvision, torchao nightlies
pip install --pre --upgrade torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu121
pip install --pre --upgrade torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
pip install --pre --upgrade torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```

8 changes: 1 addition & 7 deletions recipes/knowledge_distillation_single_device.py
@@ -114,13 +114,7 @@ def __init__(self, cfg: DictConfig) -> None:
raise ValueError(
"fp16 precision is not supported in this recipe. Please use fp32 or bf16."
)
# For CUDA devices, check if the HW supports bf16 if bf16 is specified.
if (
self._dtype == torch.bfloat16
and self._device != torch.device("cpu")
and not torch.cuda.is_bf16_supported()
):
raise RuntimeError("Full bf16 training is not supported on this hardware.")

# logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
8 changes: 1 addition & 7 deletions recipes/lora_dpo_single_device.py
@@ -89,13 +89,7 @@ def __init__(self, cfg: DictConfig) -> None:
raise ValueError(
"fp16 precision is not supported in this recipe. Please use fp32 or bf16."
)
# For CUDA devices, check if the HW supports bf16 if bf16 is specified.
if (
self._dtype == torch.bfloat16
and self._device != torch.device("cpu")
and not torch.cuda.is_bf16_supported()
):
raise RuntimeError("Full bf16 training is not supported on this hardware.")

# logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
24 changes: 10 additions & 14 deletions recipes/lora_finetune_single_device.py
@@ -135,13 +135,7 @@ def __init__(self, cfg: DictConfig) -> None:
raise ValueError(
"fp16 precision is not supported in this recipe. Please use fp32 or bf16."
)
# For CUDA devices, check if the HW supports bf16 if bf16 is specified.
if (
self._dtype == torch.bfloat16
and self._device != torch.device("cpu")
and not torch.cuda.is_bf16_supported()
):
raise RuntimeError("Full bf16 training is not supported on this hardware.")

# logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
@@ -542,13 +536,15 @@ def _setup_data(
batch_size=batch_size,
# dropping last avoids shape issues with compile + flex attention
drop_last=cfg_dataset.get("drop_last", True),
collate_fn=partial(
collate_fn,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else padded_collate_packed,
collate_fn=(
partial(
collate_fn,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else padded_collate_packed
),
)

log.info("Dataset and Sampler are initialized.")
3 changes: 2 additions & 1 deletion tests/recipes/dev/test_generate_v2.py
@@ -12,13 +12,14 @@

from tests.common import TUNE_PATH
from tests.recipes.utils import MODEL_TEST_CONFIGS, write_hf_ckpt_config
from tests.test_utils import CKPT_MODEL_PATHS, TOKENIZER_PATHS
from tests.test_utils import CKPT_MODEL_PATHS, mps_ignored_test, TOKENIZER_PATHS


class TestGenerateV2:
"""Recipe test suite for the generate_v2 recipe."""

@pytest.mark.integration_test
@mps_ignored_test()
def test_llama2_generate_results(self, caplog, monkeypatch, tmpdir):
ckpt = "llama2_tune"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
2 changes: 2 additions & 0 deletions tests/recipes/test_ppo_full_finetune_single_device.py
@@ -24,6 +24,7 @@
CKPT_MODEL_PATHS,
gen_log_file_name,
get_loss_values_from_metric_logger,
mps_ignored_test,
)


@@ -52,6 +53,7 @@ def _get_test_config_overrides(self):
] + dummy_text_completion_alpaca_dataset_config()

@pytest.mark.integration_test
@mps_ignored_test()
def test_loss(self, tmpdir, monkeypatch):

reward_ckpt = "llama2_reward_hf"
8 changes: 8 additions & 0 deletions tests/test_utils.py
@@ -362,3 +362,11 @@ def assert_dialogue_equal(actual, expected):
for i in range(len(actual)):
assert actual[i].role == expected[i].role
assert actual[i].text_content == expected[i].text_content


def mps_ignored_test() -> "pytest.MarkDecorator":
    return pytest.mark.skipif(
        torch.backends.mps.is_available() and torch.backends.mps.is_built(),
        reason="Test skipped due to torch being compiled with MPS; "
        "see https://github.com/pytorch/torchtune/issues/1707 for more information",
    )
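The helper returns a standard `pytest.mark.skipif` marker, so it can be stacked with the repository's existing marks on any test known to diverge on the MPS backend. A minimal usage sketch (the test class and test name below are illustrative, not part of this commit):

```python
import pytest

from tests.test_utils import mps_ignored_test


class TestExampleRecipe:
    @pytest.mark.integration_test
    @mps_ignored_test()  # skipped when torch is built with MPS support
    def test_loss_matches_expected(self):
        # Assertions known to diverge on MPS would go here.
        ...
```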
6 changes: 5 additions & 1 deletion tests/torchtune/generation/test_generation.py
@@ -7,7 +7,7 @@
import pytest

import torch
from tests.test_utils import fixed_init_model
from tests.test_utils import fixed_init_model, mps_ignored_test

from torchtune.generation._generation import (
generate,
@@ -331,6 +331,7 @@ def test_reproducibility_batched(self, request, model1, model2, prompt1, prompt2
@pytest.mark.parametrize(
"prompt", ["prompt_tokens_batched", "prompt_tokens_batched_left_padded"]
)
@mps_ignored_test()
def test_stop_tokens_batched(self, request, model, prompt, expected_tokens_batched):
"""
Test to check if the `generate` function produces the right output when stop tokens are
@@ -362,6 +363,7 @@ def test_stop_tokens_batched(self, request, model, prompt, expected_tokens_batch
"model",
["generation_model_no_kv_cache", "generation_model_kv_cache"],
)
@mps_ignored_test()
def test_stop_tokens(self, request, model, prompt_tokens, expected_tokens):
"""
Test to check if the `generate` function produces the right output when stop tokens are
@@ -392,6 +394,7 @@ def test_stop_tokens(self, request, model, prompt_tokens, expected_tokens):
"model",
["generation_model_no_kv_cache", "generation_model_kv_cache_batched"],
)
@mps_ignored_test()
def test_stop_tokens_batched_uneven_stopping(
self, request, model, prompt_tokens_batched
):
@@ -430,6 +433,7 @@ def test_stop_tokens_batched_uneven_stopping(
"model",
["generation_model_no_kv_cache", "generation_model_kv_cache_batched"],
)
@mps_ignored_test()
def test_stop_tokens_batched_uneven_stopping_left_padded(
self, request, model, prompt_tokens_batched_left_padded
):
5 changes: 4 additions & 1 deletion tests/torchtune/models/llama3_1/test_position_embeddings.py
@@ -7,7 +7,7 @@
import pytest
import torch

from tests.test_utils import assert_expected
from tests.test_utils import assert_expected, mps_ignored_test
from torch import tensor

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
@@ -68,6 +68,7 @@ def test_cache_equality(self, input, rope) -> None:
assert_expected(cache.sum(), self.EXPECTED_FREQS_CIS_SUM, atol=1e-4)
assert_expected(cache.max(), self.EXPECTED_FREQS_CIS_MAX)

@mps_ignored_test()
def test_forward(self, input, rope) -> None:
x_out = rope(input)

@@ -79,6 +80,7 @@ def test_forward(self, input, rope) -> None:
# check shapes
assert_expected(x_out.shape, input.shape)

@mps_ignored_test()
def test_forward_with_curr_pos(self, input, rope) -> None:
(
_,
@@ -99,6 +101,7 @@ def test_forward_with_2d_pos_ids(self, input, rope) -> None:
# check shapes
assert_expected(x_out.shape, input.shape)

@mps_ignored_test()
def test_forward_with_2d_pos_ids(self, input, rope) -> None:
"""
Use input_pos to indicate positions of each token relative to its sequence
6 changes: 5 additions & 1 deletion tests/torchtune/modules/test_position_embeddings.py
@@ -9,7 +9,7 @@
import pytest
import torch

from tests.test_utils import assert_expected
from tests.test_utils import assert_expected, mps_ignored_test
from torch import tensor
from torchtune.models.phi3 import Phi3RotaryPositionalEmbeddings

@@ -56,6 +56,7 @@ def rope(
_, _, head_dim, _, max_seq_len = input_params
return RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)

@mps_ignored_test()
def test_forward(self, input: tensor, rope: RotaryPositionalEmbeddings) -> None:
x_out = rope(input)

@@ -67,6 +68,7 @@ def test_forward(self, input: tensor, rope: RotaryPositionalEmbeddings) -> None:
# check shapes
assert_expected(x_out.shape, input.shape)

@mps_ignored_test()
def test_forward_with_curr_pos(
self, input: tensor, rope: RotaryPositionalEmbeddings
) -> None:
@@ -89,6 +91,7 @@ def test_forward_with_curr_pos(
# check shapes
assert_expected(x_out.shape, input.shape)

@mps_ignored_test()
def test_forward_with_packed_pos(
self, input: tensor, rope: RotaryPositionalEmbeddings
) -> None:
@@ -162,6 +165,7 @@ def rope_phi3(
_, _, head_dim, _, max_seq_len = input_params
return Phi3RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)

@mps_ignored_test()
def test_forward(
self, input: tensor, rope_phi3: Phi3RotaryPositionalEmbeddings
) -> None:
5 changes: 4 additions & 1 deletion tests/torchtune/modules/test_transformer_decoder.py
@@ -9,7 +9,7 @@
import pytest

import torch
from tests.test_utils import assert_expected
from tests.test_utils import assert_expected, mps_ignored_test

from torch import nn

@@ -98,6 +98,7 @@ def transformer_layer(
transformer_layer.eval()
return transformer_layer

@mps_ignored_test()
def test_forward(
self, input: torch.Tensor, transformer_layer: TransformerSelfAttentionLayer
) -> None:
@@ -182,6 +183,7 @@ def transformer_layer(
transformer_layer.eval()
return transformer_layer

@mps_ignored_test()
def test_forward(
self,
input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
@@ -317,6 +319,7 @@ def decoder_with_kv_cache_enabled(
decoder.setup_caches(batch_size=4, dtype=torch.float32)
return decoder

@mps_ignored_test()
def test_forward(
self,
input: torch.Tensor,
9 changes: 5 additions & 4 deletions torchtune/generation/_generation.py
@@ -245,10 +245,6 @@ def generate(
"""
prompt = prompt.view(1, -1) if prompt.ndim == 1 else prompt

stop_tokens = (
torch.tensor(stop_tokens, device=prompt.device) if stop_tokens else None
)

if custom_generate_next_token is None:
custom_generate_next_token = generate_next_token

@@ -325,6 +321,11 @@ def generate(

# keeps track at a high level if we've already hit a stop token in a sequence so we can early stop
stop_token_reached = torch.zeros(bsz, dtype=torch.bool, device=prompt.device)
stop_tokens = (
torch.tensor(stop_tokens, device=prompt.device, dtype=tokens.dtype)
if stop_tokens
else None
)

# everything in stop_token_mask starts as 1s, and we'll set them to 0 for sequences
# that already hit a stop token
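The `stop_tokens` tensor is now built next to the `tokens` buffer and inherits its dtype, so later comparisons against generated ids operate on matching integer types across backends, including MPS. A self-contained sketch of the idea (illustrative only, not the recipe's exact code):

```python
import torch

# Generated token ids are typically int64 ("long") tensors.
tokens = torch.tensor([[5, 7, 2, 0]], dtype=torch.long)

# Build the stop-token tensor on the same device and with the same dtype.
stop_tokens = torch.tensor([2, 0], device=tokens.device, dtype=tokens.dtype)

# Positions where a stop token was produced.
stop_hit = torch.isin(tokens, stop_tokens)
print(stop_hit)  # tensor([[False, False,  True,  True]])
```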
4 changes: 2 additions & 2 deletions torchtune/modules/kv_cache.py
@@ -102,7 +102,7 @@ def update(
k_out = self.k_cache
v_out = self.v_cache

k_out.index_copy_(2, cache_pos, k_val)
v_out.index_copy_(2, cache_pos, v_val)
k_out[:, :, cache_pos] = k_val
v_out[:, :, cache_pos] = v_val

return k_out, v_out
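The cache update switches from `Tensor.index_copy_` to plain indexed assignment along the sequence dimension; both forms write the same cache slots, and the latter avoids the in-place op this MPS-focused commit replaces, presumably for MPS compatibility. A minimal sketch of the equivalence on dummy tensors (shapes below are illustrative):

```python
import torch

bsz, num_heads, max_seq_len, head_dim = 2, 4, 16, 8
k_cache = torch.zeros(bsz, num_heads, max_seq_len, head_dim)
k_val = torch.randn(bsz, num_heads, 3, head_dim)  # three new positions
cache_pos = torch.tensor([5, 6, 7])                # cache slots to fill

# Old form: in-place scatter along dim 2.
expected = k_cache.clone().index_copy_(2, cache_pos, k_val)

# New form: indexed assignment along the sequence dimension.
k_cache[:, :, cache_pos] = k_val

assert torch.equal(k_cache, expected)
```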
5 changes: 4 additions & 1 deletion torchtune/training/precision.py
@@ -49,17 +49,20 @@ def verify_bf16_support() -> bool:
- CUDA version >= 11
- CUDA compute capability >= 8
- NCCL is available and version >= 2.10
- MPS is available and torch was built with MPS
Returns:
bool: True if bf16 is available, False otherwise.
"""
return (
cuda_support = (
torch.cuda.is_available()
and torch.cuda.is_bf16_supported()
and torch.distributed.is_nccl_available()
and torch.cuda.nccl.version() >= (2, 10)
)
mps_support = torch.backends.mps.is_available() and torch.backends.mps.is_built()
return cuda_support or mps_support


def get_dtype(
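With this change a single capability check covers both the CUDA path and Apple-silicon MPS builds. A hedged sketch of how a caller might gate a bf16 request on it, assuming `verify_bf16_support` is importable from `torchtune.training.precision` where the diff defines it (the `resolve_dtype` helper below is hypothetical, not part of torchtune):

```python
import torch

from torchtune.training.precision import verify_bf16_support


def resolve_dtype(requested: torch.dtype) -> torch.dtype:
    # Hypothetical guard: fall back to fp32 when bf16 is requested but unsupported.
    if requested == torch.bfloat16 and not verify_bf16_support():
        return torch.float32
    return requested
```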
