diff --git a/recipes/configs/llama3/70B_full.yaml b/recipes/configs/llama3/70B_full.yaml
new file mode 100644
index 0000000000..6bb0c0416f
--- /dev/null
+++ b/recipes/configs/llama3/70B_full.yaml
@@ -0,0 +1,110 @@
+# Config for multi-device full finetuning in full_finetune_distributed.py
+# using a Llama3 70B Instruct model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+#   tune download meta-llama/Meta-Llama-3-70B-Instruct --output-dir /tmp/Meta-Llama-3-70B-Instruct --hf-token --ignore-patterns "original/consolidated*"
+#
+# To launch on 8 devices, run the following command from root:
+#   tune run --nproc_per_node 8 full_finetune_distributed --config llama3/70B_full
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+#   tune run --nproc_per_node 8 full_finetune_distributed --config llama3/70B_full checkpointer.checkpoint_dir=
+#
+# This config is only tested on an 8xA100 machine.
+
+
+# Tokenizer
+tokenizer:
+  _component_: torchtune.models.llama3.llama3_tokenizer
+  path: /tmp/Meta-Llama-3-70B-Instruct/original/tokenizer.model
+
+# Dataset
+dataset:
+  _component_: torchtune.datasets.alpaca_dataset
+  train_on_input: True
+seed: null
+shuffle: True
+
+# Model Arguments
+model:
+  _component_: torchtune.models.llama3.llama3_70b
+
+checkpointer:
+  _component_: torchtune.utils.FullModelHFCheckpointer
+  checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct
+  checkpoint_files: [
+    model-00001-of-00030.safetensors,
+    model-00002-of-00030.safetensors,
+    model-00003-of-00030.safetensors,
+    model-00004-of-00030.safetensors,
+    model-00005-of-00030.safetensors,
+    model-00006-of-00030.safetensors,
+    model-00007-of-00030.safetensors,
+    model-00008-of-00030.safetensors,
+    model-00009-of-00030.safetensors,
+    model-00010-of-00030.safetensors,
+    model-00011-of-00030.safetensors,
+    model-00012-of-00030.safetensors,
+    model-00013-of-00030.safetensors,
+    model-00014-of-00030.safetensors,
+    model-00015-of-00030.safetensors,
+    model-00016-of-00030.safetensors,
+    model-00017-of-00030.safetensors,
+    model-00018-of-00030.safetensors,
+    model-00019-of-00030.safetensors,
+    model-00020-of-00030.safetensors,
+    model-00021-of-00030.safetensors,
+    model-00022-of-00030.safetensors,
+    model-00023-of-00030.safetensors,
+    model-00024-of-00030.safetensors,
+    model-00025-of-00030.safetensors,
+    model-00026-of-00030.safetensors,
+    model-00027-of-00030.safetensors,
+    model-00028-of-00030.safetensors,
+    model-00029-of-00030.safetensors,
+    model-00030-of-00030.safetensors,
+  ]
+  recipe_checkpoint: null
+  output_dir: /tmp/Meta-Llama-3-70b
+  model_type: LLAMA3
+resume_from_checkpoint: False
+
+# Fine-tuning arguments
+batch_size: 2
+epochs: 3
+
+optimizer:
+  _component_: torch.optim.AdamW
+  lr: 2e-5
+  foreach: False
+  # Note: highly recommended to use fused=True optimizer flag
+  # with CPU offload for faster optimizer step.
+  fused: True
+
+loss:
+  _component_: torch.nn.CrossEntropyLoss
+max_steps_per_epoch: null
+gradient_accumulation_steps: 1
+
+
+# Training env
+device: cuda
+
+# Memory management
+enable_activation_checkpointing: True
+memory_efficient_fsdp_wrap: True
+fsdp_cpu_offload: True
+
+# Reduced precision
+dtype: bf16
+
+# Logging
+metric_logger:
+  _component_: torchtune.utils.metric_logging.DiskLogger
+  log_dir: ${output_dir}
+output_dir: /tmp/alpaca-llama3-finetune
+log_every_n_steps: 1
+log_peak_memory_stats: False
diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index 7d1dd5b42a..1d0717a035 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -17,6 +17,7 @@
 from torch import nn
 from torch.distributed import init_process_group
 from torch.distributed.fsdp import (
+    CPUOffload,
     FullOptimStateDictConfig,
     FullStateDictConfig,
     FullyShardedDataParallel as FSDP,
@@ -103,6 +104,15 @@ def __init__(self, cfg: DictConfig) -> None:
                 "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
             )
 
+        if (
+            cfg.get("fsdp_cpu_offload", False)
+            and cfg.get("fused", False)
+            and not utils.torch_version_ge("2.4.0")
+        ):
+            raise RuntimeError(
+                "Using fused optimizer on CPU is only supported in PyTorch nightly."
+            )
+
         # logging attributes
         self._output_dir = cfg.output_dir
         self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
@@ -186,6 +196,7 @@ def setup(self, cfg: DictConfig) -> None:
             cfg_model=cfg.model,
             enable_activation_checkpointing=cfg.enable_activation_checkpointing,
             memory_efficient_fsdp_wrap=cfg.get("memory_efficient_fsdp_wrap", False),
+            fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
             model_state_dict=ckpt_dict[utils.MODEL_KEY],
             ac_mode=cfg.get("ac_mode", None),
             ac_option=cfg.get("ac_option", None),
@@ -234,6 +245,7 @@ def _setup_model(
         cfg_model: DictConfig,
         enable_activation_checkpointing: bool,
         memory_efficient_fsdp_wrap: bool,
+        fsdp_cpu_offload: bool,
         model_state_dict: Dict[str, Any],
         ac_mode: Optional[str] = None,
         ac_option: Optional[int] = None,
@@ -296,6 +308,7 @@ def _setup_model(
                 memory_efficient_fsdp_wrap=memory_efficient_fsdp_wrap,
                 modules_to_wrap={modules.TransformerDecoderLayer},
             ),
+            cpu_offload=CPUOffload(offload_params=fsdp_cpu_offload),
             sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
             device_id=self._device,
             # this recipe does not currently support mixed precision training
@@ -563,6 +576,10 @@ def recipe_main(cfg: DictConfig) -> None:
         )
 
     init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
+    if cfg.get("fsdp_cpu_offload", False):
+        # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
+        # speed up when benchmarking fused AdamW on CPU
+        utils.set_torch_num_threads()
 
     config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg)
 
diff --git a/tests/recipes/test_lora_finetune_single_device.py b/tests/recipes/test_lora_finetune_single_device.py
index f0c4bc42c1..fb9fdc21c4 100644
--- a/tests/recipes/test_lora_finetune_single_device.py
+++ b/tests/recipes/test_lora_finetune_single_device.py
@@ -24,9 +24,9 @@
     gen_log_file_name,
     get_loss_values_from_metric_logger,
     TOKENIZER_PATHS,
-    torch_version_ge,
 )
 from torchtune import config
+from torchtune.utils import torch_version_ge
 
 
 class TestLoRAFinetuneSingleDeviceRecipe:
diff --git a/tests/test_utils.py b/tests/test_utils.py
index cae78abed9..b769af0fec 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -38,13 +38,6 @@
 }
 
 
-def torch_version_ge(version: str) -> bool:
-    """
-    Check if torch version is greater than or equal to the given version
-    """
-    return version in torch.__version__ or torch.__version__ >= version
-
-
 # Inherit from SentencePieceTokenizer class to reuse its tokenize_messages method
 class DummyTokenizer(SentencePieceTokenizer):
     def __init__(self):
diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py
index 6fa60736fb..1230d4291f 100644
--- a/torchtune/_recipe_registry.py
+++ b/torchtune/_recipe_registry.py
@@ -57,6 +57,7 @@ class Recipe:
         Config(name="llama2/7B_full", file_path="llama2/7B_full.yaml"),
         Config(name="llama2/13B_full", file_path="llama2/13B_full.yaml"),
         Config(name="llama3/8B_full", file_path="llama3/8B_full.yaml"),
+        Config(name="llama3/70B_full", file_path="llama3/70B_full.yaml"),
         Config(name="mistral/7B_full", file_path="mistral/7B_full.yaml"),
         Config(name="gemma/2B_full", file_path="gemma/2B_full.yaml"),
         Config(name="gemma/7B_full", file_path="gemma/7B_full.yaml"),
diff --git a/torchtune/utils/__init__.py b/torchtune/utils/__init__.py
index f8964ed6ba..41c65e62ca 100644
--- a/torchtune/utils/__init__.py
+++ b/torchtune/utils/__init__.py
@@ -22,10 +22,12 @@
     is_distributed,
     lora_fsdp_wrap_policy,
     prepare_model_for_fsdp_with_meta_device,
+    set_torch_num_threads,
     validate_no_params_on_meta_device,
 )
 from ._generation import generate
 from ._profiler import profiler
+from ._version import torch_version_ge
 from .argparse import TuneRecipeArgumentParser
 from .collate import padded_collate, padded_collate_dpo
 from .constants import (  # noqa
@@ -79,6 +81,7 @@
     "set_seed",
     "validate_expected_param_dtype",
     "TuneRecipeArgumentParser",
+    "torch_version_ge",
     "OptimizerInBackwardWrapper",
     "create_optim_in_bwd_wrapper",
     "register_optim_in_bwd_hooks",
diff --git a/torchtune/utils/_distributed.py b/torchtune/utils/_distributed.py
index 85085fe694..fad42d1560 100644
--- a/torchtune/utils/_distributed.py
+++ b/torchtune/utils/_distributed.py
@@ -117,6 +117,22 @@ def init_distributed(**kwargs: Dict) -> bool:  # noqa: DOC106, DOC109
         return False
 
 
+def set_torch_num_threads() -> None:
+    """
+    Sets the number of threads used by torch to utilize all physical CPU
+    cores for intra-op parallelism. Currently, this function sets num_threads
+    to be the number of physical CPU cores divided by the number of GPUs as we
+    use one process per GPU, and this avoids CPU oversubscription. Note that this is
+    currently a rough approximation, and doesn't take into account environments where
+    things like CPU affinity is set.
+    """
+    num_threads = os.cpu_count() // (
+        torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
+    )
+    torch.set_num_threads(num_threads)
+    _log.info(f"Set intra op parallelism no. of threads to {num_threads}")
+
+
 def get_world_size_and_rank() -> Tuple[int, int]:
     """Function that gets the current world size (aka total number of ranks) and
     rank number of the current trainer.
diff --git a/torchtune/utils/_version.py b/torchtune/utils/_version.py
new file mode 100644
index 0000000000..530f74f805
--- /dev/null
+++ b/torchtune/utils/_version.py
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+
+
+def torch_version_ge(version: str) -> bool:
+    """
+    Check if torch version is greater than or equal to the given version
+    """
+    return version in torch.__version__ or torch.__version__ >= version