LoRA Builders for MM (#1661)
pbontrager authored Sep 24, 2024
1 parent 30b8519 commit 18efc81
Showing 7 changed files with 949 additions and 106 deletions.
56 changes: 34 additions & 22 deletions recipes/lora_finetune_distributed.py
@@ -20,7 +20,8 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
from torchtune.data import padded_collate_packed, padded_collate_sft
from torchtune.config._utils import _get_component_from_path
from torchtune.data import padded_collate_packed
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft import (
DoRALinear,
@@ -94,6 +95,10 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
- Logging. Terminal, Disk, WandB and TensorBoard are all supported.
- Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
``clip_grad_norm='inf'``.
For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
has example commands for how to kick-off training.
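
The ``clip_grad_norm`` behavior documented above is implemented later in this diff with ``torch.nn.utils.clip_grad_norm_``. As a rough standalone sketch (not the recipe code itself): passing ``float('inf')`` as the max norm returns the total gradient norm without rescaling any gradients, which is why ``clip_grad_norm='inf'`` works as a log-only setting.

import torch

# Minimal sketch of the clip-vs-log distinction; the Linear model is a placeholder.
model = torch.nn.Linear(4, 4)
model(torch.randn(2, 4)).sum().backward()

# max_norm=float("inf") never triggers rescaling, so this call only measures the norm.
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float("inf"))
print(f"grad_norm: {grad_norm:.4f}")

# A finite max_norm (e.g. 1.0) rescales gradients whenever their total norm exceeds it.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
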
@@ -104,6 +109,7 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
ValueError: If ``dtype`` is set to fp16.
ValueError: If world_size is 1
RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
RuntimeError: If ``left_pad_sequence`` is set as the data collator.
"""

def __init__(self, cfg: DictConfig) -> None:
@@ -136,6 +142,7 @@ def __init__(self, cfg: DictConfig) -> None:
self.total_epochs = cfg.epochs
self.max_steps_per_epoch = cfg.max_steps_per_epoch
self.global_step = 0
self._clip_grad_norm = cfg.get("clip_grad_norm", None)

self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)
self._resume_from_checkpoint = cfg.resume_from_checkpoint
@@ -257,10 +264,12 @@ def setup(self, cfg: DictConfig) -> None:

# sampler and dataloader depend on the tokenizer and loss_fn and should be
# setup after all of these are setup
collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
self._sampler, self._dataloader = self._setup_data(
cfg_dataset=cfg.dataset,
shuffle=cfg.shuffle,
batch_size=cfg.batch_size,
collate_fn=collate_name,
)
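
Here ``collate_fn`` becomes a config-selectable dotted path (defaulting to ``torchtune.data.padded_collate_sft``) that ``_setup_data`` later resolves with ``_get_component_from_path``. A hedged sketch of what such a dotted-path resolver typically does, using ``importlib`` as an assumed stand-in rather than torchtune's private helper:

import importlib

def resolve_component(path: str):
    # Assumed stand-in for _get_component_from_path: import the module portion
    # of a dotted path and return the named attribute from it.
    module_path, _, attr_name = path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)

# With the default from this commit (requires torchtune installed):
# collate_fn = resolve_component("torchtune.data.padded_collate_sft")
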

# Finally update the recipe state which can only be correctly set after all of the
@@ -535,6 +544,7 @@ def _setup_data(
cfg_dataset: DictConfig,
shuffle: bool,
batch_size: int,
collate_fn: str,
) -> Tuple[DistributedSampler, DataLoader]:
"""
All data related setup happens here. Currently this recipe only supports the
@@ -545,15 +555,20 @@

if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
config.instantiate(single_cfg_dataset, self._tokenizer)
for single_cfg_dataset in cfg_dataset
]
ds = ConcatDataset(datasets=datasets)
packed = False
else:
ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
ds = config.instantiate(cfg_dataset, self._tokenizer)
packed = cfg_dataset.get("packed", False)

# Instantiate collate_fn
if "left_pad_sequence" in collate_fn:
raise RuntimeError("left_pad_sequence collator is only for inference.")
collate_fn = _get_component_from_path(collate_fn)

sampler = DistributedSampler(
ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
)
@@ -565,14 +580,12 @@
# dropping last avoids shape issues with compile + flex attention
drop_last=cfg_dataset.get("drop_last", True),
collate_fn=partial(
padded_collate_sft,
collate_fn,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else partial(
padded_collate_packed,
),
else padded_collate_packed,
)

if self._is_rank_zero:
@@ -714,21 +727,13 @@ def train(self) -> None:
):
torch.cuda.memory._record_memory_history()

# Both are shape [b, s]
tokens, labels = batch["tokens"], batch["labels"]
# Get the attention mask and position ids from the dataset if they
# exist. Currently, only sample packing in PackedDataset returns these
mask = batch.get("mask", None) # shape [b, s, s]
input_pos = batch.get("input_pos", None) # shape [b, s]

tokens = tokens.to(self._device)
num_tokens += tokens.numel()
labels = labels.to(self._device)
mask = mask.to(self._device) if mask is not None else None
input_pos = (
input_pos.to(self._device) if input_pos is not None else None
)
logits = self._model(tokens, mask=mask, input_pos=input_pos)
utils.batch_to_device(batch, self._device)
num_tokens += batch["tokens"].numel()

# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")

logits = self._model(**batch)

# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
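
For context on the shift comment above: position ``i`` of the logits predicts token ``i + 1``, so the last logit and the first label are dropped before computing cross-entropy. A minimal illustration with assumed shapes ``[b, s]`` for labels and ``[b, s, v]`` for logits:

import torch
import torch.nn.functional as F

b, s, v = 2, 8, 32                      # assumed batch, sequence, and vocab sizes
logits = torch.randn(b, s, v)
labels = torch.randint(0, v, (b, s))

# Drop the last logit and the first label so position i lines up with token i+1.
shifted_logits = logits[..., :-1, :].contiguous()
shifted_labels = labels[..., 1:].contiguous()

loss = F.cross_entropy(
    shifted_logits.reshape(-1, v),      # [b * (s - 1), v]
    shifted_labels.reshape(-1),         # [b * (s - 1)]
    ignore_index=-100,                  # matches the collate's ignore_idx convention
)
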
@@ -752,6 +757,11 @@

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
if self._clip_grad_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
max_norm=float(self._clip_grad_norm),
)
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)
self._lr_scheduler.step()
@@ -780,6 +790,8 @@ def train(self) -> None:
log_dict.update(
training.get_memory_stats(device=self._device)
)
if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm})
self._metric_logger.log_dict(
log_dict,
step=self.global_step,
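Both recipes now choose between the configured padded collate (given ``padding_idx`` and ``ignore_idx``) and ``padded_collate_packed`` for packed datasets. As a rough idea of what a right-padding SFT collate does, here is a simplified sketch rather than torchtune's ``padded_collate_sft``:

import torch
from torch.nn.utils.rnn import pad_sequence

def simple_padded_collate(batch, padding_idx=0, ignore_idx=-100):
    # Simplified sketch: right-pad token ids with padding_idx and labels with
    # ignore_idx so the padded positions are skipped by the loss.
    tokens = pad_sequence(
        [torch.tensor(sample["tokens"]) for sample in batch],
        batch_first=True,
        padding_value=padding_idx,
    )
    labels = pad_sequence(
        [torch.tensor(sample["labels"]) for sample in batch],
        batch_first=True,
        padding_value=ignore_idx,
    )
    return {"tokens": tokens, "labels": labels}

# Two samples of different lengths pad up to the longer one.
out = simple_padded_collate(
    [{"tokens": [1, 2, 3], "labels": [1, 2, 3]}, {"tokens": [4, 5], "labels": [4, 5]}]
)
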
45 changes: 23 additions & 22 deletions recipes/lora_finetune_single_device.py
@@ -20,7 +20,8 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
from torchtune.data import padded_collate_packed, padded_collate_sft
from torchtune.config._utils import _get_component_from_path
from torchtune.data import padded_collate_packed
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft import (
get_adapter_params,
@@ -72,7 +73,7 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface):
most cases this should halve the memory footprint of full precision (fp32) training, without
loss in model quality (will depend on the model, training data and other settings). For
GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16
precision are currently not supported.g
precision are currently not supported.
- Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is
controlled using the ``gradient_accumulation_steps`` flag.
@@ -119,6 +120,7 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface):
ValueError: If ``dtype`` is set to fp16.
RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
RuntimeError: If ``left_pad_sequence`` is set as the data collator
"""

@@ -282,10 +284,12 @@ def setup(self, cfg: DictConfig) -> None:

# Dataloader depends on the tokenizer and loss_fn and should be
# setup after all of these are setup
collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
self._sampler, self._dataloader = self._setup_data(
cfg_dataset=cfg.dataset,
shuffle=cfg.shuffle,
batch_size=cfg.batch_size,
collate_fn=collate_name,
)

# Finally update the recipe state which can only be correctly set after all of the
@@ -502,6 +506,7 @@ def _setup_data(
cfg_dataset: DictConfig,
shuffle: bool,
batch_size: int,
collate_fn: str,
) -> Tuple[DistributedSampler, DataLoader]:
"""
All data related setup happens here. Currently this recipe only supports
@@ -519,6 +524,11 @@
ds = config.instantiate(cfg_dataset, self._tokenizer)
packed = cfg_dataset.get("packed", False)

# Instantiate collate_fn
if "left_pad_sequence" in collate_fn:
raise RuntimeError("left_pad_sequence collator is only for inference.")
collate_fn = _get_component_from_path(collate_fn)

sampler = DistributedSampler(
ds,
num_replicas=1,
@@ -532,17 +542,13 @@
batch_size=batch_size,
# dropping last avoids shape issues with compile + flex attention
drop_last=cfg_dataset.get("drop_last", True),
collate_fn=(
partial(
padded_collate_sft,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else partial(
padded_collate_packed,
)
),
collate_fn=partial(
collate_fn,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else padded_collate_packed,
)

log.info("Dataset and Sampler are initialized.")
@@ -623,17 +629,12 @@ def save_checkpoint(self, epoch: int) -> None:
)

def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
# Both are shape [b, s]
tokens, labels = batch["tokens"], batch["labels"]

# Get the attention mask and position ids from the dataset if they
# exist. Currently, only sample packing in PackedDataset returns these
mask = batch.get("mask", None) # shape [b, s, s]
input_pos = batch.get("input_pos", None) # shape [b, s]
# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")

# run model
with self.activations_handling_ctx:
logits = self._model(tokens, mask=mask, input_pos=input_pos)
logits = self._model(**batch)

# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
@@ -692,7 +693,7 @@ def train(self) -> None:
):
torch.cuda.memory._record_memory_history()

batch = {k: v.to(self._device) for k, v in batch.items()}
utils.batch_to_device(batch, self._device)
num_tokens += batch["tokens"].numel()

loss = self._loss_step(batch)
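Both train loops now move the whole batch dict in one call with ``utils.batch_to_device`` and unpack it directly into the model via ``self._model(**batch)``, instead of plucking out ``tokens``, ``labels``, ``mask``, and ``input_pos`` by hand. A hedged sketch of what such a helper plausibly does (an assumption about its behavior, not torchtune's implementation):

import torch

def batch_to_device(batch: dict, device: torch.device) -> None:
    # Assumed behavior: move every tensor in the batch dict onto `device`
    # in place, leaving any non-tensor entries untouched.
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(device)

# Usage mirroring the new train loop: afterwards batch["tokens"], batch["labels"],
# and any optional "mask"/"input_pos" tensors all live on the training device.
batch = {"tokens": torch.randint(0, 100, (2, 8)), "labels": torch.randint(0, 100, (2, 8))}
batch_to_device(batch, torch.device("cpu"))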