From 18efc81dda1c537bb7c25058ff059b4623ccff58 Mon Sep 17 00:00:00 2001 From: Philip Bontrager Date: Tue, 24 Sep 2024 18:16:36 -0400 Subject: [PATCH] LoRA Builders for MM (#1661) --- recipes/lora_finetune_distributed.py | 56 +- recipes/lora_finetune_single_device.py | 45 +- torchtune/models/clip/_component_builders.py | 355 ++++++++++- .../models/flamingo/_component_builders.py | 574 ++++++++++++++++-- torchtune/models/flamingo/_encoder.py | 9 +- torchtune/models/flamingo/_transform.py | 1 + .../models/llama3_1/_component_builders.py | 15 +- 7 files changed, 949 insertions(+), 106 deletions(-) diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 908d50d0db..a1f768a590 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -20,7 +20,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler from torchtune import config, modules, training, utils -from torchtune.data import padded_collate_packed, padded_collate_sft +from torchtune.config._utils import _get_component_from_path +from torchtune.data import padded_collate_packed from torchtune.datasets import ConcatDataset from torchtune.modules.peft import ( DoRALinear, @@ -94,6 +95,10 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface): - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default, + ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set + ``clip_grad_norm='inf'``. + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config has example commands for how to kick-off training. @@ -104,6 +109,7 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface): ValueError: If ``dtype`` is set to fp16. ValueError: If world_size is 1 RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + RuntimeError: If ``left_pad_sequence`` is set as the data collator. """ def __init__(self, cfg: DictConfig) -> None: @@ -136,6 +142,7 @@ def __init__(self, cfg: DictConfig) -> None: self.total_epochs = cfg.epochs self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 + self._clip_grad_norm = cfg.get("clip_grad_norm", None) self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) self._resume_from_checkpoint = cfg.resume_from_checkpoint @@ -257,10 +264,12 @@ def setup(self, cfg: DictConfig) -> None: # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after all of these are setup + collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft") self._sampler, self._dataloader = self._setup_data( cfg_dataset=cfg.dataset, shuffle=cfg.shuffle, batch_size=cfg.batch_size, + collate_fn=collate_name, ) # Finally update the recipe state which can only be correctly set after all of the @@ -535,6 +544,7 @@ def _setup_data( cfg_dataset: DictConfig, shuffle: bool, batch_size: int, + collate_fn: str, ) -> Tuple[DistributedSampler, DataLoader]: """ All data related setup happens here. Currently this recipe only supports the @@ -545,15 +555,20 @@ def _setup_data( if isinstance(cfg_dataset, ListConfig): datasets = [ - config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + config.instantiate(single_cfg_dataset, self._tokenizer) for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) packed = False else: - ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + ds = config.instantiate(cfg_dataset, self._tokenizer) packed = cfg_dataset.get("packed", False) + # Instantiate collate_fn + if "left_pad_sequence" in collate_fn: + raise RuntimeError("left_pad_sequence collator is only for inference.") + collate_fn = _get_component_from_path(collate_fn) + sampler = DistributedSampler( ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 ) @@ -565,14 +580,12 @@ def _setup_data( # dropping last avoids shape issues with compile + flex attention drop_last=cfg_dataset.get("drop_last", True), collate_fn=partial( - padded_collate_sft, + collate_fn, padding_idx=self._tokenizer.pad_id, ignore_idx=self._loss_fn.ignore_index, ) if not packed - else partial( - padded_collate_packed, - ), + else padded_collate_packed, ) if self._is_rank_zero: @@ -714,21 +727,13 @@ def train(self) -> None: ): torch.cuda.memory._record_memory_history() - # Both are shape [b, s] - tokens, labels = batch["tokens"], batch["labels"] - # Get the attention mask and position ids from the dataset if they - # exist. Currently, only sample packing in PackedDataset returns these - mask = batch.get("mask", None) # shape [b, s, s] - input_pos = batch.get("input_pos", None) # shape [b, s] - - tokens = tokens.to(self._device) - num_tokens += tokens.numel() - labels = labels.to(self._device) - mask = mask.to(self._device) if mask is not None else None - input_pos = ( - input_pos.to(self._device) if input_pos is not None else None - ) - logits = self._model(tokens, mask=mask, input_pos=input_pos) + utils.batch_to_device(batch, self._device) + num_tokens += batch["tokens"].numel() + + # Shape [b, s], needed for the loss not the model + labels = batch.pop("labels") + + logits = self._model(**batch) # Shift labels to compute loss # equivalent to doing labels[..., 1:] and logits[..., :-1, :] @@ -752,6 +757,11 @@ def train(self) -> None: # Step with optimizer if (idx + 1) % self._gradient_accumulation_steps == 0: + if self._clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), + ) self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) self._lr_scheduler.step() @@ -780,6 +790,8 @@ def train(self) -> None: log_dict.update( training.get_memory_stats(device=self._device) ) + if self._clip_grad_norm is not None: + log_dict.update({"grad_norm": grad_norm}) self._metric_logger.log_dict( log_dict, step=self.global_step, diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index e2180dd078..74f7cfec3d 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -20,7 +20,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler from torchtune import config, modules, training, utils -from torchtune.data import padded_collate_packed, padded_collate_sft +from torchtune.config._utils import _get_component_from_path +from torchtune.data import padded_collate_packed from torchtune.datasets import ConcatDataset from torchtune.modules.peft import ( get_adapter_params, @@ -72,7 +73,7 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface): most cases this should halve the memory footprint of full precision (fp32) training, without loss in model quality (will depend on the model, training data and other settings). For GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 - precision are currently not supported.g + precision are currently not supported. - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is controlled using the ``gradient_accumulation_steps`` flag. @@ -119,6 +120,7 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface): ValueError: If ``dtype`` is set to fp16. RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``left_pad_sequence`` is set as the data collator """ @@ -282,10 +284,12 @@ def setup(self, cfg: DictConfig) -> None: # Dataloader depends on the tokenizer and loss_fn and should be # setup after all of these are setup + collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft") self._sampler, self._dataloader = self._setup_data( cfg_dataset=cfg.dataset, shuffle=cfg.shuffle, batch_size=cfg.batch_size, + collate_fn=collate_name, ) # Finally update the recipe state which can only be correctly set after all of the @@ -502,6 +506,7 @@ def _setup_data( cfg_dataset: DictConfig, shuffle: bool, batch_size: int, + collate_fn: str, ) -> Tuple[DistributedSampler, DataLoader]: """ All data related setup happens here. Currently this recipe only supports @@ -519,6 +524,11 @@ def _setup_data( ds = config.instantiate(cfg_dataset, self._tokenizer) packed = cfg_dataset.get("packed", False) + # Instantiate collate_fn + if "left_pad_sequence" in collate_fn: + raise RuntimeError("left_pad_sequence collator is only for inference.") + collate_fn = _get_component_from_path(collate_fn) + sampler = DistributedSampler( ds, num_replicas=1, @@ -532,17 +542,13 @@ def _setup_data( batch_size=batch_size, # dropping last avoids shape issues with compile + flex attention drop_last=cfg_dataset.get("drop_last", True), - collate_fn=( - partial( - padded_collate_sft, - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, - ) - if not packed - else partial( - padded_collate_packed, - ) - ), + collate_fn=partial( + collate_fn, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else padded_collate_packed, ) log.info("Dataset and Sampler are initialized.") @@ -623,17 +629,12 @@ def save_checkpoint(self, epoch: int) -> None: ) def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: - # Both are shape [b, s] - tokens, labels = batch["tokens"], batch["labels"] - - # Get the attention mask and position ids from the dataset if they - # exist. Currently, only sample packing in PackedDataset returns these - mask = batch.get("mask", None) # shape [b, s, s] - input_pos = batch.get("input_pos", None) # shape [b, s] + # Shape [b, s], needed for the loss not the model + labels = batch.pop("labels") # run model with self.activations_handling_ctx: - logits = self._model(tokens, mask=mask, input_pos=input_pos) + logits = self._model(**batch) # Shift labels to compute loss # equivalent to doing labels[..., 1:] and logits[..., :-1, :] @@ -692,7 +693,7 @@ def train(self) -> None: ): torch.cuda.memory._record_memory_history() - batch = {k: v.to(self._device) for k, v in batch.items()} + utils.batch_to_device(batch, self._device) num_tokens += batch["tokens"].numel() loss = self._loss_step(batch) diff --git a/torchtune/models/clip/_component_builders.py b/torchtune/models/clip/_component_builders.py index 47894728eb..0940d49359 100644 --- a/torchtune/models/clip/_component_builders.py +++ b/torchtune/models/clip/_component_builders.py @@ -1,3 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial from typing import Callable, List, Optional import torch @@ -12,7 +19,12 @@ TanhGate, FeedForward, Fp32LayerNorm -) +) + +from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook + +from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear + def clip_vision_encoder( tile_size: int, @@ -40,6 +52,11 @@ def clip_vision_encoder( :class:`torchtune.modules.vision_transformer.VisionTransformer`. Args: + tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise, + the size of the input image. In this case, the function will consider your image as a single tile. + patch_size (int): The size of each patch. Used to divide the tiles into patches. + E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches + with shape (40, 40) each. embed_dim (int): The dimensionality of each patch embedding (token). num_layers (int): The number of transformer layers. num_heads (int): The number of attention heads in each transformer layer. @@ -52,11 +69,6 @@ def clip_vision_encoder( return the tokens before they go through the first and fourth layers. output_cls_projection (bool): If True, only the CLS token projection will be outputted, instead of all tokens. Defaults to False. - tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise, - the size of the input image. In this case, the function will consider your image as a single tile. - patch_size (int): The size of each patch. Used to divide the tiles into patches. - E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches - with shape (40, 40) each. max_num_tiles (int): The maximum number of tiles that can be processed. This is used to determine the size of the positional embeddings. in_channels (int): The number of image input channels. @@ -140,3 +152,334 @@ def clip_mlp(in_dim: int, out_dim: int, hidden_dim: int, activation: nn.Module, gate_proj = nn.Linear(in_dim, hidden_dim) if not quantize_base else FrozenNF4Linear(in_dim, hidden_dim) down_proj = nn.Linear(hidden_dim, out_dim) if not quantize_base else FrozenNF4Linear(hidden_dim, out_dim) return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation) + + +# ------------------ LoRA CLIP ------------------ + + +def lora_clip_vision_encoder( + lora_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + *, + # clip encoder parameters + tile_size: int, + patch_size: int, + embed_dim: int, + num_layers: int, + num_heads: int, + activation: Callable = nn.SiLU, + cls_output_dim: int = 512, + attn_bias: bool = True, + out_indices: Optional[List[int]] = None, + output_cls_projection: bool = False, + max_num_tiles: int = 4, + in_channels: int = 3, + intermediate_act: torch.nn.Module = torch.nn.SiLU(), + # LoRA parameters + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.0, + use_dora: bool = False, + quantize_base: bool = False, +) -> VisionTransformer: + """ + Build a LoRA implementation of the CLIP vision encoder. + + Args: + lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise, + the size of the input image. In this case, the function will consider your image as a single tile. + patch_size (int): The size of each patch. Used to divide the tiles into patches. + E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches + with shape (40, 40) each. + embed_dim (int): The dimensionality of each patch embedding (token). + num_layers (int): The number of transformer layers. + num_heads (int): The number of attention heads in each transformer layer. + activation (Callable): The activation function to use in the MLP layer. + cls_output_dim (int): The dimensionality of the output tensor from the CLS projection module. + attn_bias (bool): Boolean for if to use bias in the attention module. Default True. + out_indices (Optional[List[int]]): The indices of hidden layers to return. + If provided, it will return the intermediate results of the transformer layers + before they go through a next layer. For example, ``out_indices=[0,3]`` will + return the tokens before they go through the first and fourth layers. + output_cls_projection (bool): If True, only the CLS token projection will be outputted, + instead of all tokens. Defaults to False. + max_num_tiles (int): The maximum number of tiles that can be processed. This is used to + determine the size of the positional embeddings. + in_channels (int): The number of image input channels. + intermediate_act (torch.nn.Module): The activation function used in the intermediate layers in the transformer encoder. + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + use_dora (bool): Whether to use DoRA layers instead of LoRA layers. Default is ``False``. + quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base + weights within linear layers LoRA is applied to. The final output linear projection is not + supported for quantization currently. + + + Returns: + VisionTransformer: Instantiation of VisionTransformer model. + """ + assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" + + # TODO: add support for quantizing and LoRA for the final output projection + cls_projection = CLSProjection(embed_dim=embed_dim, cls_output_dim=cls_output_dim) if output_cls_projection else None + + # transformer layer + self_attn = lora_clip_attention( + lora_modules=lora_modules, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_heads, + head_dim=embed_dim // num_heads, + attn_dropout=0.0, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + if apply_lora_to_mlp: + mlp = lora_clip_mlp( + in_dim=embed_dim, + hidden_dim=4 * embed_dim, + out_dim=embed_dim, + activation=activation(), + lora_rank=lora_rank, + lora_alpha=lora_alpha, + quantize_base=quantize_base, + lora_dropout=lora_dropout, + use_dora=use_dora, + ) + else: + mlp = clip_mlp( + in_dim=embed_dim, + hidden_dim=4 * embed_dim, + out_dim=embed_dim, + activation=activation(), + quantize_base=quantize_base, + ) + transformer_layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm= Fp32LayerNorm(embed_dim, eps=1e-5), + mlp_norm= Fp32LayerNorm(embed_dim, eps=1e-5), + sa_scale=None, + mlp_scale=None, + ) + + # position embeddings + if max_num_tiles == 1: + pre_tile_pos_embed = None + post_tile_pos_embed = None + token_pos_embedding = TokenPositionalEmbedding( + embed_dim=embed_dim, + patch_size=patch_size, + tile_size=tile_size) + else: + pre_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim) + post_tile_pos_embed = TilePositionalEmbedding(max_num_tiles=max_num_tiles, embed_dim=embed_dim) + token_pos_embedding = TiledTokenPositionalEmbedding( + max_num_tiles=max_num_tiles, + embed_dim=embed_dim, + patch_size=patch_size, + tile_size=tile_size) + + model = VisionTransformer( + num_layers=num_layers, + layer=transformer_layer, + token_pos_embedding=token_pos_embedding, + pre_tile_pos_embed=pre_tile_pos_embed, + post_tile_pos_embed=post_tile_pos_embed, + cls_projection=cls_projection, + out_indices=out_indices, + tile_size=tile_size, + patch_size=patch_size, + embed_dim=embed_dim, + in_channels=in_channels, + ) + + if quantize_base: + # For QLoRA, we reparametrize 4-bit tensors to bf16, and offload to CPU on the fly + # so as to not increase peak memory + model._register_state_dict_hook( + partial(reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True) + ) + + return model + + +def lora_clip_attention( + lora_modules: List[LORA_ATTN_MODULES], + *, + # MultiHeadAttention args + embed_dim: int, + head_dim: int, + num_heads: int, + num_kv_heads: int, + attn_dropout: float = 0.0, + # LoRA args + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + use_dora: bool = False, + quantize_base: bool = False, +) -> MultiHeadAttention: + """ + Return an instance of :func:`~torchtune.modules.MultiHeadAttention` with LoRA + applied to a subset of its linear layers + + Args: + lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to. Options are ``{"q_proj", "k_proj", "v_proj", + "output_proj"}``. + embed_dim (int): embedding dimension for self-attention + head_dim (int): dimension of each head in the multihead attention. Usually + computed as ``embed_dim // num_heads``. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + attn_dropout (float): dropout value passed onto scaled_dot_product_attention. + Default: 0.0 + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + use_dora (bool): Whether to use DoRA layers instead of LoRA layers. Default is ``False``. + quantize_base (bool): Whether to quantize base model parameters for linear layers + LoRA is being applied to. Default is ``False``. + + Returns: + MultiHeadAttention: instantiation of self-attention module with LoRA + applied to a subset of Q, K, V, output projections. + + Raises: + ValueError: If lora_modules arg is an empty list + """ + if not lora_modules: + raise ValueError( + f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules" + ) + + adapter_cls = DoRALinear if use_dora else LoRALinear + q_proj = ( + adapter_cls( + embed_dim, + num_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "q_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False) + ) + ) + k_proj = ( + adapter_cls( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "k_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) + ) + v_proj = ( + adapter_cls( + embed_dim, + num_kv_heads * head_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "v_proj" in lora_modules + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) + ) + output_proj = ( + adapter_cls( + embed_dim, + embed_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + if "output_proj" in lora_modules + else ( + nn.Linear(embed_dim, embed_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, embed_dim, bias=False) + ) + ) + + self_attn = MultiHeadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + q_proj=q_proj, + k_proj=k_proj, + v_proj=v_proj, + output_proj=output_proj, + pos_embeddings=None, + attn_dropout=attn_dropout, + ) + return self_attn + + +def lora_clip_mlp( + *, + in_dim: int, + out_dim: int, + hidden_dim: int, + activation: nn.Module, + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + use_dora: bool = False, + quantize_base: bool = False, +) -> FeedForward: + """ + Build the MLP layer with LoRA applied to the gate and down projections. + """ + adapter_cls = DoRALinear if use_dora else LoRALinear + gate_proj = adapter_cls( + in_dim=dim, + out_dim=hidden_dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + down_proj = adapter_cls( + in_dim=hidden_dim, + out_dim=dim, + rank=lora_rank, + alpha=lora_alpha, + dropout=lora_dropout, + quantize_base=quantize_base, + ) + return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation) diff --git a/torchtune/models/flamingo/_component_builders.py b/torchtune/models/flamingo/_component_builders.py index 870c028626..0c71bc3954 100644 --- a/torchtune/models/flamingo/_component_builders.py +++ b/torchtune/models/flamingo/_component_builders.py @@ -1,16 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +from enum import Enum from typing import Optional, List from torch import nn -from torchtune.models.llama3._component_builders import llama3_mlp from torchtune.models.llama3._model_utils import scale_hidden_dim_for_mlp -from torchtune.models.clip import clip_vision_encoder, clip_mlp +from torchtune.models.llama3_1._component_builders import llama3_mlp, lora_llama3_mlp, lora_llama3_attention +from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE +from torchtune.models.clip._component_builders import clip_vision_encoder, clip_mlp, lora_clip_attention, lora_clip_mlp, lora_clip_vision_encoder from torchtune.models.flamingo._encoder import FlamingoProjectionHead, FlamingoEncoder from torchtune.modules.model_fusion import FusionEmbedding, FusionLayer from torchtune.modules import ( RMSNorm, - RotaryPositionalEmbeddings, TanhGate, TransformerCrossAttentionLayer, MultiHeadAttention, @@ -19,6 +27,9 @@ Fp32LayerNorm ) +from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook + +from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear """ @@ -32,6 +43,7 @@ the building blocks simple. """ + def flamingo_vision_encoder( # clip encoder parameters *, @@ -96,51 +108,12 @@ def flamingo_vision_encoder( ) # Projection head - mlp_ratio = 4 - hidden_dim = int(mlp_ratio * clip_embed_dim) - head_dim = clip_embed_dim // num_heads - num_kv_heads = num_heads - - self_attn = MultiHeadAttention( - embed_dim=clip_embed_dim, - num_heads=num_heads, - num_kv_heads=num_heads, - head_dim=head_dim, - q_proj=nn.Linear(clip_embed_dim, num_heads * head_dim, bias=False), - k_proj=nn.Linear(clip_embed_dim, num_kv_heads * head_dim, bias=False), - v_proj=nn.Linear(clip_embed_dim, num_kv_heads * head_dim, bias=False), - output_proj=nn.Linear(clip_embed_dim, clip_embed_dim, bias=False), - pos_embeddings=None, - attn_dropout=0.0, - is_causal=False, - ) - - mlp = clip_mlp( - in_dim=clip_embed_dim, - hidden_dim=hidden_dim, - out_dim=clip_embed_dim, - activation=nn.GELU(), - ) - - layer = TransformerSelfAttentionLayer( - attn=self_attn, - mlp=mlp, - sa_norm=Fp32LayerNorm(clip_embed_dim, eps=1e-5), - mlp_norm=Fp32LayerNorm(clip_embed_dim, eps=1e-5), - sa_scale=TanhGate(), - mlp_scale=TanhGate(), - ) - - # we concatenate clip embeddings and hidden layers output - # and project it to embed_dim_out, which will be used for the - # cross encoding - num_hidden_inputs = len(clip_hidden_states) if clip_hidden_states is not None else 0 - proj_in = clip_embed_dim * (num_hidden_inputs + 1) - projection_head = FlamingoProjectionHead( - layer=layer, + projection_head = flamingo_projection_head( num_layers=num_layers_projection, - output=nn.Linear(proj_in, decoder_embed_dim), - num_hidden_inputs=num_hidden_inputs + num_heads=num_heads, + decoder_embed_dim=decoder_embed_dim, + clip_embed_dim=clip_embed_dim, + num_hidden_inputs=len(clip_hidden_states or []) ) return FlamingoEncoder(clip=clip, projection_head=projection_head) @@ -194,10 +167,11 @@ def flamingo_decoder( num_kv_heads = num_kv_heads if num_kv_heads else num_heads hidden_dim = intermediate_dim or scale_hidden_dim_for_mlp(embed_dim) layers = [] + + rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) for idx in range(1, num_layers + 1): # Self attention layers for text decoder - rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) self_attn = MultiHeadAttention( embed_dim=embed_dim, num_heads=num_heads, @@ -264,3 +238,507 @@ def flamingo_decoder( norm=RMSNorm(embed_dim, eps=1e-05), output=output_proj, ) + +def flamingo_projection_head( + *, + num_layers: int, + num_heads: int, + decoder_embed_dim: int, + clip_embed_dim: int, + num_hidden_inputs: int, +) -> FlamingoProjectionHead: + """ + Build the Flamingo Projection Head that maps the output of the CLIP encoder + to the decoder cross attention input. + + Args: + num_layers (int): number of layers in the projection head. + num_heads (int): number of heads in the projection head. + decoder_embed_dim (int): embedding dimension for the decoder. + clip_embed_dim (int): embedding dimension for the CLIP encoder. + num_hidden_inputs (int): number of hidden inputs to the projection head. + + Returns: + FlamingoProjectionHead: Instantiation of Flamingo projection head. + """ + mlp_ratio = 4 + hidden_dim = int(mlp_ratio * clip_embed_dim) + head_dim = clip_embed_dim // num_heads + num_kv_heads = num_heads + + layers = [] + for _ in range(num_layers): + self_attn = MultiHeadAttention( + embed_dim=clip_embed_dim, + num_heads=num_heads, + num_kv_heads=num_heads, + head_dim=head_dim, + q_proj=nn.Linear(clip_embed_dim, num_heads * head_dim, bias=False), + k_proj=nn.Linear(clip_embed_dim, num_kv_heads * head_dim, bias=False), + v_proj=nn.Linear(clip_embed_dim, num_kv_heads * head_dim, bias=False), + output_proj=nn.Linear(clip_embed_dim, clip_embed_dim, bias=False), + pos_embeddings=None, + attn_dropout=0.0, + is_causal=False, + ) + + mlp = clip_mlp( + in_dim=clip_embed_dim, + hidden_dim=hidden_dim, + out_dim=clip_embed_dim, + activation=nn.GELU(), + ) + + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=Fp32LayerNorm(clip_embed_dim, eps=1e-5), + mlp_norm=Fp32LayerNorm(clip_embed_dim, eps=1e-5), + sa_scale=TanhGate(), + mlp_scale=TanhGate(), + ) + layers.append(layer) + + # we concatenate clip embeddings and hidden layers output + # and project it to embed_dim_out, which will be used for the + # cross encoding + proj_in = clip_embed_dim * (num_hidden_inputs + 1) + return FlamingoProjectionHead( + layers=layers, + output=nn.Linear(proj_in, decoder_embed_dim), + num_hidden_inputs=num_hidden_inputs + ) + +# ------------------ LoRA Flamingo ------------------ + + +class LoRATrainable(Enum): + FULL = "full" + LORA = "lora" + FROZEN = "frozen" + + +def lora_flamingo_vision_encoder( + encoder_lora: bool, + fusion_lora: bool, + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + *, + # clip encoder parameters + patch_size: int, + num_heads: int, + clip_embed_dim: int, + clip_num_layers: int, + clip_hidden_states: Optional[List[int]], + # projection parameters + num_layers_projection: int, + decoder_embed_dim: int, + # image parameters + tile_size: int, + max_num_tiles: int = 4, + in_channels: int = 3, + # LoRA parameters + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.0, + use_dora: bool = False, + quantize_base: bool = False, + ) -> FlamingoEncoder: + """ + Build the Flamingo encoder by combining the CLIP image model with an additional + projection head fusion module. This includes: + - Spatial positional encodings + - CLIP model backbone + - Projection head on top of CLIP + - Final projection into token embedding dimension + + Args: + encoder_lora (bool): whether to apply LoRA to the CLIP encoder + fusion_lora (bool): whether to apply LoRA to the projection head + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + patch_size (int): The size of each patch. Used to divide the tiles into patches. + E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches + with shape (40, 40) each. + num_heads (int): The number of attention heads in each transformer layer. + clip_embed_dim (int): The dimensionality of each patch embedding in CLIP. + clip_num_layers (int): The number of transformer layers. + clip_hidden_states (Optional[List[int]]): The indices of CLIP hidden layers to return + to return to the encoder projection head. It will return the intermediate results + of the vision transformer layers which will be concatenated with the CLIP output + and input into the projection head. For example, ``clip_hidden_states=[0,3]`` will + return the embeddings before they go through the first and fourth layers. + num_layers_projection (int): The number of transformer layers in the projection head. + decoder_embed_dim (int): The dimensionality of the final output embeddings for the decoder. + tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise, + the size of the input image. In this case, the function will consider your image as a single tile. + max_num_tiles (int): The maximum number of tiles that can be processed. This is used to + determine the size of the positional embeddings. + in_channels (int): The number of image input channels. + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + use_dora (bool): Whether to use DoRA layers instead of LoRA layers. Default is ``False``. + quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base + weights within linear layers LoRA is applied to. The final output linear projection is not + supported for quantization currently. + + + Returns: + FlamingoEncoder: Instantiation of Flamingo encoder. + """ + lora_options = { + "lora_modules": lora_attn_modules, + "apply_lora_to_mlp": apply_lora_to_mlp, + "apply_lora_to_output": apply_lora_to_output, + "lora_rank": lora_rank, + "lora_alpha": lora_alpha, + "lora_dropout": lora_dropout, + "use_dora": use_dora, + "quantize_base": quantize_base, + } + + # clip encoder + clip_options = { + "tile_size": tile_size, + "patch_size": patch_size, + "embed_dim": clip_embed_dim, + "num_layers": clip_num_layers, + "num_heads": num_heads, + "activation": nn.GELU, + "out_indices": clip_hidden_states, + "max_num_tiles": max_num_tiles, + "in_channels": in_channels, + "attn_bias": False, + "output_cls_projection": False, + } + if encoder_lora: + clip = lora_clip_vision_encoder(**clip_options, **lora_options) + else: + clip = clip_vision_encoder(**clip_options) + + # Projection + projection_options = { + "num_layers": num_layers_projection, + "num_heads": num_heads, + "decoder_embed_dim": decoder_embed_dim, + "clip_embed_dim": clip_embed_dim, + "num_hidden_inputs": len(clip_hidden_states or []), + } + if fusion_lora: + projection_head = lora_flamingo_projection_head(**projection_options, **lora_options) + else: + projection_head = flamingo_projection_head(**projection_options) + + return FlamingoEncoder(clip=clip, projection_head=projection_head) + + +def lora_flamingo_decoder( + decoder_lora: bool, + fusion_lora: bool, + lora_attn_modules: List[LORA_ATTN_MODULES], + apply_lora_to_mlp: bool = False, + apply_lora_to_output: bool = False, + *, + # decoder params + vocab_size: int, + num_layers: int, + fusion_interval: int, + num_special_tokens: int, + num_heads: int, + num_kv_heads: int, + embed_dim: int, + max_seq_len: int, + encoder_max_seq_len: int, + rope_base: int = 500000.0, + intermediate_dim: Optional[int] = None, + # LoRA parameters + lora_rank: int = 8, + lora_alpha: float = 16, + lora_dropout: float = 0.0, + use_dora: bool = False, + quantize_base: bool = False, +) -> TransformerDecoder: + """ + Build the decoder associated with the Llama3 model with additional fused + cross attention layers. This includes: + - Token embeddings + - num_layers number of CausalSelfAttention blocks + - Fused cross attention layers every fusion_interval number of layers + - RMS Norm layer applied to the output of the transformer + - Final projection into token space + + Args: + decoder_lora (bool): whether to apply LoRA to the language decoder + fusion_lora (bool): whether to apply LoRA to the projection head + lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to in each self-attention block. Options are + ``{"q_proj", "k_proj", "v_proj", "output_proj"}``. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + vocab_size (int): number of tokens in vocabulary. + num_layers (int): number of layers in the transformer decoder. + fusion_interval (int): interval number of layers between fusion layers. + num_special_tokens (int): number of special tokens added for the fusion model. + num_heads (int): number of query heads. For MHA this is also the + number of heads for key and value. + num_kv_heads (int): number of key and value heads. User should ensure + `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, + for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + embed_dim (int): embedding dimension for self-attention. + max_seq_len (int): maximum sequence length the model will be run with, as used + by :func:`~torchtune.modules.KVCache`. + encoder_max_seq_len (int): maximum sequence length the encoder will be run with, as used + by :func:`~torchtune.modules.KVCache`. + intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified, + this is computed using :func:`~torchtune.modules.scale_hidden_dim_for_mlp`. + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + use_dora (bool): Whether to use DoRA layers instead of LoRA layers. Default is ``False``. + quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base + weights within linear layers LoRA is applied to. The final output linear projection is not + supported for quantization currently. + + Returns: + TransformerDecoder: Instantiation of Flamingo decoder. + """ + head_dim = embed_dim // num_heads + num_kv_heads = num_kv_heads if num_kv_heads else num_heads + hidden_dim = intermediate_dim or scale_hidden_dim_for_mlp(embed_dim) + layers = [] + + rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) + for idx in range(1, num_layers + 1): + + # Self attention layers for text decoder + self_attn = lora_llama3_attention( + lora_modules=lora_attn_modules, + pos_embeddings=rope, + head_dim=head_dim, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + max_seq_len=max_seq_len, + attn_dropout=0.0, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + if apply_lora_to_mlp: + mlp = lora_llama3_mlp( + dim=embed_dim, + hidden_dim=hidden_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + quantize_base=quantize_base, + lora_dropout=lora_dropout, + use_dora=use_dora, + ) + else: + mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base) + decoder_layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=RMSNorm(dim=embed_dim, eps=1e-5), + mlp_norm=RMSNorm(dim=embed_dim, eps=1e-5), + ) + + # cross attention layers, mixing text and vision, + # placed every `fusion_interval` layers + if idx % fusion_interval == 0: + attn = lora_llama3_attention( + lora_modules=lora_attn_modules, + pos_embeddings=None, + head_dim=head_dim, + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + q_norm=RMSNorm(dim=head_dim, eps=1e-05), + k_norm=RMSNorm(dim=head_dim, eps=1e-05), + max_seq_len=encoder_max_seq_len, + is_causal=False, + attn_dropout=0.0, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + if apply_lora_to_mlp: + mlp = lora_llama3_mlp( + dim=embed_dim, + hidden_dim=hidden_dim, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + quantize_base=quantize_base, + lora_dropout=lora_dropout, + use_dora=use_dora, + ) + else: + mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base) + xattn_layer = TransformerCrossAttentionLayer( + attn=attn, + mlp=mlp, + ca_norm=RMSNorm(dim=embed_dim), + mlp_norm=RMSNorm(dim=embed_dim), + ca_scale=TanhGate(), + mlp_scale=TanhGate(), + ) + fusion_layer = FusionLayer(layer=decoder_layer, fusion_layer=xattn_layer) + layers.append(fusion_layer) + else: + layers.append(decoder_layer) + + tok_embeddings = FusionEmbedding(vocab_size, num_special_tokens, embed_dim) + + # TODO: quantize_base is not applied to final output_proj currently. + adapter_cls = DoRALinear if use_dora else LoRALinear + output_proj = ( + adapter_cls(embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout) + if apply_lora_to_output + else nn.Linear(embed_dim, vocab_size, bias=False) + ) + + model = TransformerDecoder( + tok_embeddings=tok_embeddings, + layers=layers, + max_seq_len=max_seq_len, + num_heads=num_heads, + head_dim=head_dim, + norm=RMSNorm(embed_dim, eps=1e-05), + output=output_proj, + ) + + if quantize_base: + # For QLoRA, we reparametrize 4-bit tensors to bf16, and offload to CPU on the fly + # so as to not increase peak memory + model._register_state_dict_hook( + partial(reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True) + ) + + return model + + +def lora_flamingo_projection_head( + lora_modules: List[LORA_ATTN_MODULES], + *, + # projection head parameters + num_layers: int, + num_heads: int, + decoder_embed_dim: int, + clip_embed_dim: int, + num_hidden_inputs: int, + # LoRA args + apply_lora_to_mlp: bool, + apply_lora_to_output: bool, + lora_rank: int, + lora_alpha: float, + lora_dropout: float = 0.0, + use_dora: bool = False, + quantize_base: bool = False, +) -> FlamingoProjectionHead: + """ + Build the Flamingo Projection Head with LoRA applied to a subset of the layers. + + Args: + lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers + LoRA should be applied to. Options are ``{"q_proj", "k_proj", "v_proj", + "output_proj"}``. + num_layers (int): number of layers in the projection head. + num_heads (int): number of heads in the projection head. + decoder_embed_dim (int): embedding dimension for the decoder. + clip_embed_dim (int): embedding dimension for the CLIP encoder. + num_hidden_inputs (int): number of hidden inputs to the projection head. + apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer. + Default: False + apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection. + Default: False + lora_rank (int): rank of each low-rank approximation + lora_alpha (float): scaling factor for the low-rank approximation + lora_dropout (float): LoRA dropout probability. Default: 0.0 + use_dora (bool): Whether to use DoRA layers instead of LoRA layers. Default is ``False``. + quantize_base (bool): Whether to quantize base model parameters for linear layers + LoRA is being applied to. Default is ``False``. + + Returns: + FlamingoProjectionHead: Instantiation of Flamingo projection head. + """ + mlp_ratio = 4 + hidden_dim = int(mlp_ratio * clip_embed_dim) + head_dim = clip_embed_dim // num_heads + num_kv_heads = num_heads + + layers = [] + for _ in range(num_layers): + self_attn = lora_clip_attention( + lora_modules=lora_modules, + embed_dim=clip_embed_dim, + num_heads=num_heads, + num_kv_heads=num_heads, + head_dim=head_dim, + attn_dropout=0.0, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + use_dora=use_dora, + quantize_base=quantize_base, + ) + + if apply_lora_to_mlp: + mlp = lora_clip_mlp( + in_dim=clip_embed_dim, + hidden_dim=hidden_dim, + out_dim=clip_embed_dim, + activation=nn.GELU(), + lora_rank=lora_rank, + lora_alpha=lora_alpha, + quantize_base=quantize_base, + lora_dropout=lora_dropout, + use_dora=use_dora, + ) + else: + mlp = clip_mlp( + in_dim=clip_embed_dim, + hidden_dim=hidden_dim, + out_dim=clip_embed_dim, + activation=nn.GELU(), + quantize_base=quantize_base + ) + + layer = TransformerSelfAttentionLayer( + attn=self_attn, + mlp=mlp, + sa_norm=Fp32LayerNorm(clip_embed_dim, eps=1e-5), + mlp_norm=Fp32LayerNorm(clip_embed_dim, eps=1e-5), + sa_scale=TanhGate(), + mlp_scale=TanhGate(), + ) + layers.append(layer) + + # we concatenate clip embeddings and hidden layers output + # and project it to embed_dim_out, which will be used for the + # cross encoding + # TODO: quantize_base is not applied to final output_proj currently. + proj_in = clip_embed_dim * (num_hidden_inputs + 1) + adapter_cls = DoRALinear if use_dora else LoRALinear + output_proj = ( + adapter_cls(proj_in, decoder_embed_dim, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout, use_bias=True) + if apply_lora_to_output + else nn.Linear(proj_in, decoder_embed_dim) + ) + return FlamingoProjectionHead( + layers=layers, + output=output_proj, + num_hidden_inputs=num_hidden_inputs, + ) diff --git a/torchtune/models/flamingo/_encoder.py b/torchtune/models/flamingo/_encoder.py index ad20fef92d..8ef13ac5da 100644 --- a/torchtune/models/flamingo/_encoder.py +++ b/torchtune/models/flamingo/_encoder.py @@ -10,7 +10,6 @@ from torch import nn from torchtune.modules.model_fusion import register_fusion_module -from torchtune.modules.transformer import _get_clones class FlamingoProjectionHead(nn.Module): @@ -19,8 +18,7 @@ class FlamingoProjectionHead(nn.Module): For example, nn.Sequential(CLIP(), FlamingoProjectionHead()). Args: - layer (nn.Module): Transformer Decoder layer - num_layers (int): Number of Transformer Decoder layers + layers (nn.Module): Transformer Decoder layers output (nn.Module): Output linear layer. Input dim is (num_hidden + 1) * encoder_dim and output is decoder_dim. num_hidden_inputs (int): Number of expected hidden state inputs @@ -28,13 +26,12 @@ class FlamingoProjectionHead(nn.Module): def __init__( self, - layer: nn.Module, - num_layers: int, + layers: nn.Module, output: nn.Module, num_hidden_inputs: int = 0, ) -> None: super().__init__() - self.layers = _get_clones(layer, num_layers) + self.layers = nn.ModuleList(layers) self.output = output self.num_hidden = num_hidden_inputs diff --git a/torchtune/models/flamingo/_transform.py b/torchtune/models/flamingo/_transform.py index 00ecbac553..e75152a787 100644 --- a/torchtune/models/flamingo/_transform.py +++ b/torchtune/models/flamingo/_transform.py @@ -99,6 +99,7 @@ def __init__( self.stop_tokens = self.tokenizer.stop_tokens self.max_seq_len = max_seq_len + self.image_seq_len = max_num_tiles * (self.xattn_mask.patches_per_tile + 1) self.prompt_template = prompt_template self.pad_id = self.tokenizer.pad_id diff --git a/torchtune/models/llama3_1/_component_builders.py b/torchtune/models/llama3_1/_component_builders.py index afeebd910f..aa33a9fe0a 100644 --- a/torchtune/models/llama3_1/_component_builders.py +++ b/torchtune/models/llama3_1/_component_builders.py @@ -190,6 +190,7 @@ def lora_llama3_1( lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation lora_dropout (float): LoRA dropout probability. Default: 0.0 + use_dora (bool): Whether to use DoRA layers instead of LoRA layers. Default is ``False``. quantize_base: (bool): Whether to quantize base model weights or not. Only applied to base weights within linear layers LoRA is applied to. The final output linear projection is not supported for quantization currently. @@ -205,7 +206,7 @@ def lora_llama3_1( rope = Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) layers = [] for _ in range(num_layers): - self_attn = lora_llama3_1_self_attention( + self_attn = lora_llama3_attention( lora_modules=lora_attn_modules, pos_embeddings=rope, head_dim=head_dim, @@ -271,7 +272,7 @@ def lora_llama3_1( return model -def lora_llama3_1_self_attention( +def lora_llama3_attention( lora_modules: List[LORA_ATTN_MODULES], pos_embeddings: nn.Module, *, @@ -280,7 +281,10 @@ def lora_llama3_1_self_attention( embed_dim: int, num_heads: int, num_kv_heads: int, + q_norm: Optional[nn.Module] = None, + k_norm: Optional[nn.Module] = None, max_seq_len: int, + is_causal: bool = True, attn_dropout: float = 0.0, # LoRA args lora_rank: int, @@ -307,13 +311,17 @@ def lora_llama3_1_self_attention( num_kv_heads (int): number of key and value heads. User should ensure `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`, for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1. + q_norm (Optional[nn.Module]): normalization applied to query. Default: None + k_norm (Optional[nn.Module]): normalization applied to key. Default: None max_seq_len (int): maximum sequence length the model will be run with, as used by :func:`~torchtune.modules.KVCache` + is_causal (bool): whether to apply causal attention mask. Default: True attn_dropout (float): dropout value passed onto scaled_dot_product_attention. Default: 0.0 lora_rank (int): rank of each low-rank approximation lora_alpha (float): scaling factor for the low-rank approximation lora_dropout (float): LoRA dropout probability. Default: 0.0 + use_dora (bool): Whether to use DoRA layers instead of LoRA layers. Default is ``False``. quantize_base (bool): Whether to quantize base model parameters for linear layers LoRA is being applied to. Default is ``False``. @@ -405,8 +413,11 @@ def lora_llama3_1_self_attention( k_proj=k_proj, v_proj=v_proj, output_proj=output_proj, + q_norm=q_norm, + k_norm=k_norm, pos_embeddings=pos_embeddings, max_seq_len=max_seq_len, + is_causal=is_causal, attn_dropout=attn_dropout, ) return self_attn