Llama3-70b: Full Finetune w/CPU offload + fused optimizer (pytorch#993)
rohan-varma authored Jun 1, 2024
1 parent 135cf2e commit eac2dc5
Showing 8 changed files with 161 additions and 8 deletions.
110 changes: 110 additions & 0 deletions recipes/configs/llama3/70B_full.yaml
@@ -0,0 +1,110 @@
# Config for multi-device full finetuning in full_finetune_distributed.py
# using a Llama3 70B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3-70B-Instruct --output-dir /tmp/Meta-Llama-3-70B-Instruct --hf-token <HF_TOKEN> --ignore-patterns "original/consolidated*"
#
# To launch on 8 devices, run the following command from root:
# tune run --nproc_per_node 8 full_finetune_distributed --config llama3/70B_full
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nproc_per_node 8 full_finetune_distributed --config llama3/70B_full checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config is only tested on an 8xA100 machine.


# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3-70B-Instruct/original/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: True
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama3.llama3_70b

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct
  checkpoint_files: [
    model-00001-of-00030.safetensors,
    model-00002-of-00030.safetensors,
    model-00003-of-00030.safetensors,
    model-00004-of-00030.safetensors,
    model-00005-of-00030.safetensors,
    model-00006-of-00030.safetensors,
    model-00007-of-00030.safetensors,
    model-00008-of-00030.safetensors,
    model-00009-of-00030.safetensors,
    model-00010-of-00030.safetensors,
    model-00011-of-00030.safetensors,
    model-00012-of-00030.safetensors,
    model-00013-of-00030.safetensors,
    model-00014-of-00030.safetensors,
    model-00015-of-00030.safetensors,
    model-00016-of-00030.safetensors,
    model-00017-of-00030.safetensors,
    model-00018-of-00030.safetensors,
    model-00019-of-00030.safetensors,
    model-00020-of-00030.safetensors,
    model-00021-of-00030.safetensors,
    model-00022-of-00030.safetensors,
    model-00023-of-00030.safetensors,
    model-00024-of-00030.safetensors,
    model-00025-of-00030.safetensors,
    model-00026-of-00030.safetensors,
    model-00027-of-00030.safetensors,
    model-00028-of-00030.safetensors,
    model-00029-of-00030.safetensors,
    model-00030-of-00030.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3-70b
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 3

optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  foreach: False
  # Note: highly recommended to use fused=True optimizer flag
  # with CPU offload for faster optimizer step.
  fused: True

loss:
  _component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1


# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True
memory_efficient_fsdp_wrap: True
fsdp_cpu_offload: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/alpaca-llama3-finetune
log_every_n_steps: 1
log_peak_memory_stats: False
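
A quick way to sanity-check the new config before launching is to load it with OmegaConf, the config library torchtune builds on. This is a minimal sketch, not part of the commit; it assumes `omegaconf` is installed and that the YAML above has been saved locally as `70B_full.yaml` (a hypothetical local path).

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("70B_full.yaml")  # hypothetical local copy of the config above

# The checkpointer should reference all 30 safetensors shards listed above.
assert len(cfg.checkpointer.checkpoint_files) == 30

# CPU offload plus a fused AdamW step are the two levers this commit adds;
# both are enabled in the shipped config.
print(cfg.fsdp_cpu_offload, cfg.optimizer.fused)  # True True
```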
17 changes: 17 additions & 0 deletions recipes/full_finetune_distributed.py
@@ -17,6 +17,7 @@
from torch import nn
from torch.distributed import init_process_group
from torch.distributed.fsdp import (
CPUOffload,
FullOptimStateDictConfig,
FullStateDictConfig,
FullyShardedDataParallel as FSDP,
@@ -103,6 +104,15 @@ def __init__(self, cfg: DictConfig) -> None:
"full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
)

if (
cfg.get("fsdp_cpu_offload", False)
and cfg.get("fused", False)
and not utils.torch_version_ge("2.4.0")
):
raise RuntimeError(
"Using fused optimizer on CPU is only supported in PyTorch nightly."
)

# logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
@@ -186,6 +196,7 @@ def setup(self, cfg: DictConfig) -> None:
cfg_model=cfg.model,
enable_activation_checkpointing=cfg.enable_activation_checkpointing,
memory_efficient_fsdp_wrap=cfg.get("memory_efficient_fsdp_wrap", False),
fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
model_state_dict=ckpt_dict[utils.MODEL_KEY],
ac_mode=cfg.get("ac_mode", None),
ac_option=cfg.get("ac_option", None),
@@ -234,6 +245,7 @@ def _setup_model(
cfg_model: DictConfig,
enable_activation_checkpointing: bool,
memory_efficient_fsdp_wrap: bool,
fsdp_cpu_offload: bool,
model_state_dict: Dict[str, Any],
ac_mode: Optional[str] = None,
ac_option: Optional[int] = None,
@@ -296,6 +308,7 @@ def _setup_model(
memory_efficient_fsdp_wrap=memory_efficient_fsdp_wrap,
modules_to_wrap={modules.TransformerDecoderLayer},
),
cpu_offload=CPUOffload(offload_params=fsdp_cpu_offload),
sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD,
device_id=self._device,
# this recipe does not currently support mixed precision training
@@ -563,6 +576,10 @@ def recipe_main(cfg: DictConfig) -> None:
)

init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
if cfg.get("fsdp_cpu_offload", False):
# Utilize all available CPU cores for intra-op parallelism. This provides ~2x
# speed up when benchmarking fused AdamW on CPU
utils.set_torch_num_threads()

config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg)

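The recipe changes above do three things: gate the fused-optimizer-plus-CPU-offload combination on the PyTorch version, pass `CPUOffload(offload_params=...)` into the FSDP wrapper, and budget CPU threads when offload is enabled. The guard can be exercised in isolation; the following is a minimal sketch of the same logic, assuming it runs against this commit's `torchtune.utils` (the helper function name below is hypothetical, not part of the recipe).

```python
from torch.distributed.fsdp import CPUOffload
from torchtune.utils import torch_version_ge


def check_fused_cpu_offload(fsdp_cpu_offload: bool, fused: bool) -> None:
    # Mirrors the check added to FullFinetuneRecipeDistributed.__init__: a fused
    # AdamW step over CPU-resident (offloaded) parameters needs a new enough torch.
    if fsdp_cpu_offload and fused and not torch_version_ge("2.4.0"):
        raise RuntimeError(
            "Using fused optimizer on CPU is only supported in PyTorch nightly."
        )


check_fused_cpu_offload(fsdp_cpu_offload=True, fused=True)

# The offload itself is requested through FSDP's CPUOffload config, exactly as the
# recipe now does with cpu_offload=CPUOffload(offload_params=fsdp_cpu_offload).
cpu_offload = CPUOffload(offload_params=True)
```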
2 changes: 1 addition & 1 deletion tests/recipes/test_lora_finetune_single_device.py
@@ -24,9 +24,9 @@
gen_log_file_name,
get_loss_values_from_metric_logger,
TOKENIZER_PATHS,
torch_version_ge,
)
from torchtune import config
from torchtune.utils import torch_version_ge


class TestLoRAFinetuneSingleDeviceRecipe:
7 changes: 0 additions & 7 deletions tests/test_utils.py
@@ -38,13 +38,6 @@
}


def torch_version_ge(version: str) -> bool:
"""
Check if torch version is greater than or equal to the given version
"""
return version in torch.__version__ or torch.__version__ >= version


# Inherit from SentencePieceTokenizer class to reuse its tokenize_messages method
class DummyTokenizer(SentencePieceTokenizer):
def __init__(self):
1 change: 1 addition & 0 deletions torchtune/_recipe_registry.py
@@ -57,6 +57,7 @@ class Recipe:
Config(name="llama2/7B_full", file_path="llama2/7B_full.yaml"),
Config(name="llama2/13B_full", file_path="llama2/13B_full.yaml"),
Config(name="llama3/8B_full", file_path="llama3/8B_full.yaml"),
Config(name="llama3/70B_full", file_path="llama3/70B_full.yaml"),
Config(name="mistral/7B_full", file_path="mistral/7B_full.yaml"),
Config(name="gemma/2B_full", file_path="gemma/2B_full.yaml"),
Config(name="gemma/7B_full", file_path="gemma/7B_full.yaml"),
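Registering the config is what lets the CLI resolve `--config llama3/70B_full` to the packaged YAML; the lookup amounts to a name match over the registry entries. A minimal sketch follows, where the `Config` dataclass mirrors the one in `torchtune/_recipe_registry.py` and the list variable name is hypothetical (the real module-level name is not shown in this hunk).

```python
from dataclasses import dataclass


@dataclass
class Config:
    name: str       # what the user passes to --config
    file_path: str  # YAML shipped under recipes/configs/


registered_configs = [
    Config(name="llama3/8B_full", file_path="llama3/8B_full.yaml"),
    Config(name="llama3/70B_full", file_path="llama3/70B_full.yaml"),  # added by this commit
]

match = next(c for c in registered_configs if c.name == "llama3/70B_full")
print(match.file_path)  # llama3/70B_full.yaml
```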
3 changes: 3 additions & 0 deletions torchtune/utils/__init__.py
@@ -22,10 +22,12 @@
is_distributed,
lora_fsdp_wrap_policy,
prepare_model_for_fsdp_with_meta_device,
set_torch_num_threads,
validate_no_params_on_meta_device,
)
from ._generation import generate
from ._profiler import profiler
from ._version import torch_version_ge
from .argparse import TuneRecipeArgumentParser
from .collate import padded_collate, padded_collate_dpo
from .constants import ( # noqa
@@ -79,6 +81,7 @@
"set_seed",
"validate_expected_param_dtype",
"TuneRecipeArgumentParser",
"torch_version_ge",
"OptimizerInBackwardWrapper",
"create_optim_in_bwd_wrapper",
"register_optim_in_bwd_hooks",
16 changes: 16 additions & 0 deletions torchtune/utils/_distributed.py
@@ -117,6 +117,22 @@ def init_distributed(**kwargs: Dict) -> bool: # noqa: DOC106, DOC109
return False


def set_torch_num_threads() -> None:
"""
Sets the number of threads used by torch to utilize all physical CPU
cores for intra-op parallelism. Currently, this function sets num_threads
to be the number of physical CPU cores divided by the number of GPUs as we
use one process per GPU, and this avoids CPU oversubscription. Note that this is
currently a rough approximation, and doesn't take into account environments where
things like CPU affinity is set.
"""
num_threads = os.cpu_count() // (
torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
)
torch.set_num_threads(num_threads)
_log.info(f"Set intra op parallelism no. of threads to {num_threads}")


def get_world_size_and_rank() -> Tuple[int, int]:
"""Function that gets the current world size (aka total number
of ranks) and rank number of the current trainer.
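`set_torch_num_threads` exists because the offloaded optimizer step runs on the CPU, so each of the per-GPU processes should get an even share of the cores instead of all of them contending with the default thread count. A minimal standalone sketch of the same arithmetic, assuming an 8-rank launch (the resulting thread count depends on the machine):

```python
import os

import torch

world_size = 8  # one process per GPU, as in the 8-device launch command above
num_threads = os.cpu_count() // world_size  # even split to avoid CPU oversubscription

torch.set_num_threads(num_threads)
print(f"intra-op threads per rank: {num_threads}")
```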
13 changes: 13 additions & 0 deletions torchtune/utils/_version.py
@@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch


def torch_version_ge(version: str) -> bool:
"""
Check if torch version is greater than or equal to the given version
"""
return version in torch.__version__ or torch.__version__ >= version
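
For reference, the helper passes either when the target version appears as a substring of `torch.__version__` (which is how nightly/dev builds such as `2.4.0.devYYYYMMDD` qualify) or when the installed version compares greater than or equal as a string. A minimal usage sketch, with illustrative version strings rather than any particular install:

```python
from torchtune.utils import torch_version_ge

# "2.4.0.dev20240501" -> substring match on "2.4.0", so the guard passes (nightly)
# "2.5.1"             -> string comparison >= "2.4.0", so the guard passes
# "2.2.2"             -> neither branch matches, so the guard fails
if torch_version_ge("2.4.0"):
    print("fused AdamW with FSDP CPU offload is allowed by the recipe's guard")
```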
