From cb862e91bd985f7f709a7fd12d791af47aad1fb3 Mon Sep 17 00:00:00 2001
From: willfengg
Date: Fri, 10 May 2024 12:51:49 -0700
Subject: [PATCH] rebase qlora

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 recipes/lora_finetune_distributed.py | 2 ++
 torchtune/utils/_distributed.py      | 1 +
 2 files changed, 3 insertions(+)

diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py
index 94ffeb5c8e..a8fd86d3a4 100644
--- a/recipes/lora_finetune_distributed.py
+++ b/recipes/lora_finetune_distributed.py
@@ -33,6 +33,7 @@
     get_adapter_params,
     get_merged_lora_ckpt,
     set_trainable_params,
+    validate_state_dict_for_lora,
 )
 from torchtune.recipe_interfaces import FTRecipeInterface

@@ -472,6 +473,7 @@ def save_checkpoint(
         # Now that we have the model and opt state dict, create the actual checkpoint dict
         # to be sent to the checkpointer and ultimately written to file
         if self._is_rank_zero:
+
             # Filter out the adapter keys and weights from the model state dict. These will
             # be saved separately
             adapter_key_filter = lambda x: x in self.adapter_params
diff --git a/torchtune/utils/_distributed.py b/torchtune/utils/_distributed.py
index 9500d5932d..bdc995234b 100644
--- a/torchtune/utils/_distributed.py
+++ b/torchtune/utils/_distributed.py
@@ -27,6 +27,7 @@
     _lora_b_init_params,
     LoRALinear,
 )
+
 from torchtune.utils._device import get_device
 from torchtune.utils.logging import get_logger