diff --git a/recipes/configs/llama2/7B_qat_full.yaml b/recipes/configs/llama2/7B_qat_full.yaml
index d1a408aca..fc2699a2f 100644
--- a/recipes/configs/llama2/7B_qat_full.yaml
+++ b/recipes/configs/llama2/7B_qat_full.yaml
@@ -66,7 +66,7 @@ device: cuda

 # Memory management
 enable_activation_checkpointing: True
-memory_efficient_fsdp_wrap: False
+enable_activation_offloading: False  # True reduces memory

 # Reduced precision
 dtype: bf16
diff --git a/recipes/configs/llama3/8B_qat_full.yaml b/recipes/configs/llama3/8B_qat_full.yaml
index 07461e824..4878095d9 100644
--- a/recipes/configs/llama3/8B_qat_full.yaml
+++ b/recipes/configs/llama3/8B_qat_full.yaml
@@ -44,7 +44,6 @@ resume_from_checkpoint: False
 # Fine-tuning arguments
 batch_size: 2
 epochs: 3
-compile: False

 # QAT arguments
 quantizer:
@@ -59,13 +58,15 @@ loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
 max_steps_per_epoch: null
 gradient_accumulation_steps: 1
+compile: False

 # Training env
 device: cuda

 # Memory management
 enable_activation_checkpointing: True
-memory_efficient_fsdp_wrap: True
+enable_activation_offloading: False  # True reduces memory
+custom_sharded_layers: ['tok_embeddings', 'output']

 # Reduced precision
 dtype: bf16
@@ -74,6 +75,6 @@ dtype: bf16
 metric_logger:
   _component_: torchtune.training.metric_logging.DiskLogger
   log_dir: ${output_dir}
-output_dir: /tmp/alpaca-llama3-finetune
+output_dir: /tmp/full-llama3-finetune
 log_every_n_steps: 1
 log_peak_memory_stats: True
diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py
index 4126f95bd..2f424b488 100644
--- a/recipes/qat_distributed.py
+++ b/recipes/qat_distributed.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-import os
 import sys
 import time

@@ -21,11 +20,13 @@ from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler

 from torchtune import config, modules, training, utils
-from torchtune.data import padded_collate_packed, padded_collate_sft
+from torchtune.config._utils import _get_component_from_path
+from torchtune.data import padded_collate_packed
 from torchtune.datasets import ConcatDataset
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
 from torchtune.training.activations import apply_selective_activation_checkpointing
+from torchtune.training.lr_schedulers import get_lr

 from tqdm import tqdm

@@ -50,18 +51,30 @@ class QATRecipeDistributed(FTRecipeInterface):
         to improved quantized accuracy. This can be specified through ``fake_quant_after_n_steps``.

     - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
-        is supported via the ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
+        is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
         done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
         ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
         DDP is currently not supported. Training on CPU is not supported.

-    - Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
+    - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
         flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
         activations in memory and instead recompute them during the backward pass. This is especially
         helpful for larger batch sizes when you're memory constrained. But these savings in memory
         come at the cost of training performance. In most cases training can slow-down quite a bit as
         a result of this activation recomputation.

+    - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
+        flag. Activation offloading is a technique similar to activation checkpointing that helps
+        reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activation
+        checkpointing drops the activation in the forward to recompute it later in the backward,
+        activation offloading will drop the activation in the forward to the CPU and bring it
+        back during the backward pass. As always, there is a tradeoff--these savings in memory can
+        come at the cost of training performance and CPU resources. To recover some runtime cost,
+        we've added an option to enable offloading on a different stream to permit overlapping with
+        the computation. This option is currently only available on PyTorch 2.5 or later and will
+        be enabled by default if an acceptable torch version is found. Activation offloading can be
+        used in conjunction with activation checkpointing.
+
     - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
         flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
         most cases this should halve the memory footprint of full precision (fp32) training, without
@@ -93,6 +106,10 @@ class QATRecipeDistributed(FTRecipeInterface):

     - Logging. Terminal, Disk, WandB and TensorBoard are all supported.

+    - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
+        ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
+        ``clip_grad_norm='inf'``.
+
     For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
     has example commands for how to kick-off training.

@@ -102,6 +119,9 @@ class QATRecipeDistributed(FTRecipeInterface):
     Raises:
         ValueError: If ``dtype`` is set to fp16.
         RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
+        RuntimeError: If ``left_pad_sequence`` is set as the data collator.
+        RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
+        RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
     """

     def __init__(self, cfg: DictConfig) -> None:
@@ -141,12 +161,46 @@ def __init__(self, cfg: DictConfig) -> None:
         # Training cfg
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
-        self._fsdp_sharding_strategy = torch.distributed.fsdp.ShardingStrategy[
-            cfg.get("fsdp_sharding_strategy", "FULL_SHARD")
-        ]
+        self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
+        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
         self._fake_quant_after_n_steps = cfg.get("fake_quant_after_n_steps", None)
         self._quantizer_mode = None

+        # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
+        if self._optimizer_in_bwd:
+            if self._clip_grad_norm is not None:
+                raise RuntimeError(
+                    "Gradient clipping is not supported with optimizer in bwd. "
+ "Please set clip_grad_norm=None, or optimizer_in_bwd=False." + ) + if self._gradient_accumulation_steps > 1: + raise RuntimeError( + "Gradient accumulation is not supported with optimizer in bwd." + "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False." + ) + + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif self._enable_activation_checkpointing: + log.info( + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. " + "Enabling activation offloading should reduce memory further." + ) + # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests self.seed = training.set_seed(seed=cfg.seed) @@ -223,10 +277,11 @@ def setup(self, cfg: DictConfig) -> None: checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) - self._model_compile = cfg.get("compile", False) + self._compile = cfg.get("compile", False) self._model = self._setup_model( cfg_model=cfg.model, - enable_activation_checkpointing=cfg.enable_activation_checkpointing, + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, custom_sharded_layers=cfg.get("custom_sharded_layers", None), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), @@ -239,6 +294,7 @@ def setup(self, cfg: DictConfig) -> None: self._optimizer = self._setup_optimizer( cfg_optimizer=cfg.optimizer, + optimizer_in_bwd=self._optimizer_in_bwd, opt_state_dict=( checkpoint_dict[training.OPT_KEY] if self._resume_from_checkpoint @@ -248,30 +304,25 @@ def setup(self, cfg: DictConfig) -> None: # initialize loss self._loss_fn = config.instantiate(cfg.loss) - backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + + if self._compile: + training.compile_loss(self._loss_fn, verbose=self._is_rank_zero) + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": # set num_output_chunks for model self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) - if self._model_compile: - log.info("Compiling loss with torch.compile...") - # For CEWithChunkedOutputLoss, if we compile the entire class - # we lose the benefits from the chunked loss. 
-                # Therefore, we only compile the cross entropy function + upcasting
-                self._loss_fn.compute_cross_entropy = torch.compile(
-                    self._loss_fn.compute_cross_entropy, backend=backend
-                )
-        else:
-            if self._model_compile:
-                log.info("Compiling loss with torch.compile...")
-                self._loss_fn = torch.compile(self._loss_fn, backend=backend)
-        log.info("Loss is initialized.")
+
+        if self._is_rank_zero:
+            log.info("Loss is initialized.")

         # sampler and dataloader depend on the tokenizer and loss_fn and should be
         # setup after both of these are initialized
+        collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
         self._sampler, self._dataloader = self._setup_data(
             cfg_dataset=cfg.dataset,
             shuffle=cfg.shuffle,
             batch_size=cfg.batch_size,
+            collate_fn=collate_name,
         )

         # Finally update the recipe state which can only be correctly set after all of the
@@ -371,6 +422,7 @@ def _setup_model(
         self,
         cfg_model: DictConfig,
         enable_activation_checkpointing: bool,
+        enable_activation_offloading: bool,
         fsdp_cpu_offload: bool,
         reshard_after_forward: bool,
         model_state_dict: Dict[str, Any],
@@ -396,6 +448,9 @@ def _setup_model(
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)

+        if self._compile:
+            training.compile_model(model, verbose=self._is_rank_zero)
+
         # We currently have two versions of activation checkpointing in this recipe
         # for testing and BC purposes. ``enable_activation_checkpointing`` controls
         # the older version of AC and this behavior is unchanged
@@ -451,7 +506,17 @@ def _setup_model(
         # This method will convert the full model state dict into a sharded state
         # dict and load into the model
         training.load_from_full_model_state_dict(
-            model, model_state_dict, self._device, self._is_rank_zero, strict=True
+            model,
+            model_state_dict,
+            self._device,
+            self._is_rank_zero,
+            strict=True,
+            cpu_offload=fsdp_cpu_offload,
+        )
+
+        # activation offloading
+        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
+            model, enable_activation_offloading
         )

         # Ensure no params and buffers are on meta device
@@ -470,25 +535,64 @@ def _setup_model(

         return model

     def _setup_optimizer(
-        self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
-    ) -> Optimizer:
-        optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
-        if opt_state_dict:
-            training.load_from_full_optimizer_state_dict(
-                optimizer,
-                opt_state_dict,
-                self._device,
+        self,
+        cfg_optimizer: DictConfig,
+        optimizer_in_bwd: bool = False,
+        opt_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> Optional[Optimizer]:
+        if optimizer_in_bwd:
+            # Maintain a dict of optims for every parameter.
+            optim_dict = {
+                param: config.instantiate(cfg_optimizer, [param])
+                for param in self._model.parameters()
+            }
+
+            # Register optimizer step hooks on the model to run optimizer in backward.
+            training.register_optim_in_bwd_hooks(
+                model=self._model, optim_dict=optim_dict
             )
+            # Create a wrapper for checkpoint save/load of optimizer states when running in backward.
+            self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper(
+                model=self._model, optim_dict=optim_dict
+            )
+            # Load optimizer states for each param. If optimizer states are being restored in an optimizer in
+            # backward run, these need to have been saved with the same setting. Cannot restore from runs that
+            # did not use optimizer in backward.
+            if opt_state_dict is not None:
+                for param in opt_state_dict.keys():
+                    try:
+                        training.load_from_full_optimizer_state_dict(
+                            self._optim_ckpt_wrapper.state_dict()[param],
+                            opt_state_dict[param],
+                            self._device,
+                        )
+                    except BaseException as e:
+                        raise RuntimeError(
+                            "Failed loading in-backward optimizer checkpoints. "
+                            "Please make sure run being restored from was using in-backward optimizer."
+                        ) from e
+            if self._is_rank_zero:
+                log.info("In-backward optimizers are set up.")
+            return None
+        else:
+            optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
+            if opt_state_dict:
+                training.load_from_full_optimizer_state_dict(
+                    optimizer,
+                    opt_state_dict,
+                    self._device,
+                )

-        if self._is_rank_zero:
-            log.info("Optimizer is initialized.")
-        return optimizer
+            if self._is_rank_zero:
+                log.info("Optimizer is initialized.")
+            return optimizer

     def _setup_data(
         self,
         cfg_dataset: DictConfig,
         shuffle: bool,
         batch_size: int,
+        collate_fn: str,
     ) -> Tuple[DistributedSampler, DataLoader]:
         """
         All data related setup happens here. Currently this recipe only supports the
@@ -499,15 +603,20 @@ def _setup_data(

         if isinstance(cfg_dataset, ListConfig):
             datasets = [
-                config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
+                config.instantiate(single_cfg_dataset, self._tokenizer)
                 for single_cfg_dataset in cfg_dataset
             ]
             ds = ConcatDataset(datasets=datasets)
             packed = False
         else:
-            ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
+            ds = config.instantiate(cfg_dataset, self._tokenizer)
             packed = cfg_dataset.get("packed", False)

+        # Instantiate collate_fn
+        if "left_pad_sequence" in collate_fn:
+            raise RuntimeError("left_pad_sequence collator is only for inference.")
+        collate_fn = _get_component_from_path(collate_fn)
+
         sampler = DistributedSampler(
             ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
         )
@@ -519,14 +628,12 @@ def _setup_data(
             drop_last=True,
             collate_fn=(
                 partial(
-                    padded_collate_sft,
+                    collate_fn,
                     padding_idx=self._tokenizer.pad_id,
                     ignore_idx=self._loss_fn.ignore_index,
                 )
                 if not packed
-                else partial(
-                    padded_collate_packed,
-                )
+                else padded_collate_packed
             ),
         )

@@ -553,25 +660,54 @@ def save_checkpoint(
         checkpoint_dict = {}

         intermediate_checkpoint = epoch + 1 < self.total_epochs
+
+        if self._is_rank_zero:
+            log.info(
+                "Saving checkpoint. This may take some time. Retrieving full model state dict..."
+            )
+            start = time.perf_counter()
+
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
         cpu_state_dict = training.get_full_model_state_dict(
             self._model,
             self._is_rank_zero,
+            device=self._device,
         )

-        if intermediate_checkpoint:
-            opt_state_dict = training.get_full_optimizer_state_dict(
-                self._optimizer,
-                self._is_rank_zero,
+        if self._is_rank_zero:
+            log.info(
+                f"Getting full model state dict took {time.perf_counter() - start:.2f} secs"
             )
+
+        if intermediate_checkpoint:
+            start = time.perf_counter()
+            if self._is_rank_zero:
+                log.info("Getting optimizer state dict...")
+            if not self._optimizer_in_bwd:
+                opt_state_dict = training.get_full_optimizer_state_dict(
+                    self._optimizer,
+                    self._is_rank_zero,
+                    device=self._device,
+                )
+            else:
+                opt_state_dict = {}
+                for param, opt in self._optim_ckpt_wrapper.optim_map.items():
+                    opt_state_dict[param] = training.get_full_optimizer_state_dict(
+                        opt, self._is_rank_zero, device=self._device
+                    )
+            if self._is_rank_zero:
+                log.info(
+                    f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs"
+                )
         else:
             opt_state_dict = None

         # Now that we have the model and opt state dict, create the actual checkpoint dict
         # to be sent to the checkpointer and ultimately written to file
-        if self._is_rank_zero:

+        if self._is_rank_zero:
+            start = time.perf_counter()
             checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict})

             # if training is in-progress, checkpoint the optimizer state and recipe state
@@ -592,6 +728,9 @@ def save_checkpoint(
                 epoch=epoch,
                 intermediate_checkpoint=intermediate_checkpoint,
             )
+            log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs")
+
+        torch.distributed.barrier()

     def train(self) -> None:
         """
@@ -599,10 +738,15 @@ def train(self) -> None:
         """
         # clean up before training begins
         training.cleanup_before_training()
+        world_size, rank = training.get_world_size_and_rank()

         # zero out the gradients before starting training
-        self._optimizer.zero_grad()
+        if not self._optimizer_in_bwd:
+            self._optimizer.zero_grad()
+        else:
+            for opt in self._optim_ckpt_wrapper.optim_map.values():
+                opt.zero_grad()

         # Initialize tokens count and running loss (for grad accumulation)
         t0 = time.perf_counter()
@@ -612,7 +756,6 @@ def train(self) -> None:
         self._profiler.start()
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
-
             # Update the sampler to ensure data is correctly shuffled across epochs
             # in case shuffle is True
             self._sampler.set_epoch(curr_epoch)
@@ -635,13 +778,6 @@ def train(self) -> None:
                 ):
                     torch.cuda.memory._record_memory_history()

-                # Both are shape [b, s]
-                tokens, labels = batch["tokens"], batch["labels"]
-                # Get the attention mask and position ids from the dataset if they
-                # exist. Currently, only sample packing in PackedDataset returns these
-                mask = batch.get("mask", None)  # shape [b, s, s]
-                input_pos = batch.get("input_pos", None)  # shape [b, s]
-
                 # Optionally wait N steps before enabling fake quant
                 if self._fake_quant_after_n_steps is not None:
                     if self.global_step == 0:
@@ -663,20 +799,20 @@ def train(self) -> None:
                         )
                         self._model.apply(enable_fq)

-                tokens = tokens.to(self._device)
+                utils.batch_to_device(batch, self._device)

                 # Calculate the number of unmasked tokens in the current batch
                 # and increment the total number of tokens seen in the step
-
-                utils.batch_to_device(batch, self._device)
-
                 current_num_tokens = (
                     batch["labels"] != self._loss_fn.ignore_index
                 ).sum()
                 num_tokens += current_num_tokens
+
+                # Shape [b, s], needed for the loss not the model
                 labels = batch.pop("labels")

-                logits = self._model(**batch)
+                with self.activations_handling_ctx:
+                    logits = self._model(**batch)

                 # Shift labels to compute loss
                 # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
@@ -689,25 +825,40 @@ def train(self) -> None:
                     logits = logits.reshape(-1, logits.size(-1))

                 # Compute loss
+                # Loss is normalized by default so we multiply by the number of tokens
+                # This way we can normalize by the total number of tokens if we're accumulating gradients
                 current_loss = self._loss_fn(logits, labels) * current_num_tokens

                 # free logits otherwise it peaks backward memory
                 del logits

                 running_loss += current_loss
-                current_loss.backward()

-                # Step with optimizer
-                if (idx + 1) % self._gradient_accumulation_steps == 0:
-                    # Get total number of tokens across all ranks to normalize gradients
+                # For optimizer in backward, we need to normalize before calling backward
+                # This case and gradient accumulation are mutually exclusive
+                if self._optimizer_in_bwd:
                     torch.distributed.all_reduce(num_tokens)
-                    # This will ensure that the logged loss matches what we're optimizing
                     torch.distributed.all_reduce(running_loss)
-                    # Manually scale the gradients from unnormalized loss by total # of tokens
-                    training.scale_grads(self._model, 1 / num_tokens)
+                    current_loss = current_loss / num_tokens
+
+                current_loss.backward()

-                    self._optimizer.step()
-                    self._optimizer.zero_grad(set_to_none=True)
+                # Step with optimizer
+                if (idx + 1) % self._gradient_accumulation_steps == 0:
+                    if not self._optimizer_in_bwd:
+                        # Get total number of tokens across all ranks to normalize gradients
+                        torch.distributed.all_reduce(num_tokens)
+                        # This will ensure that the logged loss matches what we're optimizing
+                        torch.distributed.all_reduce(running_loss)
+                        # Manually scale the gradients from unnormalized loss by total # of tokens
+                        training.scale_grads(self._model, 1 / num_tokens)
+                        if self._clip_grad_norm is not None:
+                            grad_norm = torch.nn.utils.clip_grad_norm_(
+                                self._model.parameters(),
+                                max_norm=float(self._clip_grad_norm),
+                            )
+                        self._optimizer.step()
+                        self._optimizer.zero_grad(set_to_none=True)

                     # Update the number of steps when the weights are updated
                     self.global_step += 1
@@ -726,15 +877,22 @@ def train(self) -> None:
                         time_per_step = time.perf_counter() - t0
                         log_dict = {
                             "loss": loss_to_log,
-                            "lr": self._optimizer.param_groups[0]["lr"],
-                            "tokens_per_second_per_gpu": (
-                                num_tokens / time_per_step * world_size
+                            "lr": get_lr(
+                                (
+                                    self._optimizer
+                                    if not self._optimizer_in_bwd
+                                    else self._optim_ckpt_wrapper
+                                ),
                             ),
+                            "tokens_per_second_per_gpu": num_tokens
+                            / (time_per_step * world_size),
                         }
                         if self._log_peak_memory_stats:
                             log_dict.update(
                                 training.get_memory_stats(device=self._device)
                             )
+                        if self._clip_grad_norm is not None:
+                            log_dict.update({"grad_norm": grad_norm})
                         self._metric_logger.log_dict(
                             log_dict,
                             step=self.global_step,
@@ -784,7 +942,7 @@ def recipe_main(cfg: DictConfig) -> None:
     """
     if not training.is_distributed():
         raise RuntimeError(
-            "Distributed QAT recipe should be run via a distributed launcher."
+            "Distributed finetune recipe should be run via a distributed launcher. "
            "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
         )
     init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
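
The patch wires activation offloading into the recipe through `training.get_act_offloading_ctx_manager` and by running the forward pass under `self.activations_handling_ctx`. As a rough illustration of the underlying idea only (this is not torchtune's implementation), the sketch below uses `torch.autograd.graph.saved_tensors_hooks` to push activations saved for backward to the CPU during the forward pass and bring them back during backward. The `SimpleActivationOffloading` class name, the `min_numel` threshold, and the toy model are assumptions made up for the example; torchtune's real context manager additionally handles a separate CUDA stream, pinned memory, and PyTorch 2.5+ version checks.

```python
# Illustrative sketch only -- NOT torchtune's implementation of activation offloading.
import torch
import torch.nn as nn


class SimpleActivationOffloading:
    """Move activations saved for backward to CPU, fetch them back on demand."""

    def __init__(self, enabled: bool = True, min_numel: int = 1024):
        self.enabled = enabled
        # Skip tiny tensors: offloading them costs more than it saves (assumed threshold).
        self.min_numel = min_numel

    def _pack(self, t: torch.Tensor):
        # Called when autograd saves a tensor for backward: ship it to CPU.
        if t.is_cuda and t.numel() >= self.min_numel:
            return ("cpu", t.device, t.to("cpu", non_blocking=True))
        return ("keep", None, t)

    def _unpack(self, packed):
        # Called during backward: move the tensor back to its original device.
        kind, device, t = packed
        return t.to(device, non_blocking=True) if kind == "cpu" else t

    def __enter__(self):
        pack, unpack = (self._pack, self._unpack) if self.enabled else (lambda t: t, lambda t: t)
        self._ctx = torch.autograd.graph.saved_tensors_hooks(pack, unpack)
        self._ctx.__enter__()
        return self

    def __exit__(self, *args):
        self._ctx.__exit__(*args)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 1)).to(device)
    x = torch.randn(32, 256, device=device)
    # Mirrors the recipe's pattern: only the forward pass runs under the context;
    # the backward pass triggers the unpack hook even after the context has exited.
    with SimpleActivationOffloading(enabled=(device == "cuda")):
        out = model(x)
    out.sum().backward()
```

This also shows why offloading pairs naturally with activation checkpointing: with checkpointing enabled, fewer (but larger) activations get saved, so there is less traffic to move. To exercise the new recipe options end to end, the keys added in the YAML files above (for example `enable_activation_offloading=True` or `clip_grad_norm=1.0`) can be overridden on the command line when launching the recipe.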