Skip to content

Commit

Permalink
rebase qlora
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
weifengpy committed May 10, 2024
1 parent 1a70498 commit cb862e9
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 0 deletions.
2 changes: 2 additions & 0 deletions recipes/lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
get_adapter_params,
get_merged_lora_ckpt,
set_trainable_params,
validate_state_dict_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface

Expand Down Expand Up @@ -472,6 +473,7 @@ def save_checkpoint(
# Now that we have the model and opt state dict, create the actual checkpoint dict
# to be sent to the checkpointer and ultimately written to file
if self._is_rank_zero:

# Filter out the adapter keys and weights from the model state dict. These will
# be saved separately
adapter_key_filter = lambda x: x in self.adapter_params
Expand Down
1 change: 1 addition & 0 deletions torchtune/utils/_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
_lora_b_init_params,
LoRALinear,
)

from torchtune.utils._device import get_device
from torchtune.utils.logging import get_logger

Expand Down

0 comments on commit cb862e9

Please sign in to comment.