
[WIP] Config Continuous Integration (CCI) #1717

Open
wants to merge 7 commits into main

Conversation

SalmanMohammadi (Collaborator)

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?
*

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Sep 30, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1717

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 240ef5e with merge base 3fddc56:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label on Sep 30, 2024
codecov-commenter commented Sep 30, 2024

Codecov Report

Attention: Patch coverage is 38.35616% with 45 lines in your changes missing coverage. Please review.

Project coverage is 67.48%. Comparing base (6bc143f) to head (e761303).
Report is 17 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| tests/test_config.py | 38.35% | 45 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1717      +/-   ##
==========================================
- Coverage   70.67%   67.48%   -3.20%     
==========================================
  Files         299      305       +6     
  Lines       15251    15683     +432     
==========================================
- Hits        10778    10583     -195     
- Misses       4473     5100     +627     
| Flag | Coverage Δ |
|---|---|
| | 67.48% <38.35%> (-3.20%) ⬇️ |

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@@ -83,7 +83,7 @@ dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
-enable_activation_offloading: True
+enable_activation_offloading: False

SalmanMohammadi (Collaborator, Author):

This is the only recipe where it was enabled.

felipemello1 (Contributor):

We only support it in LoRA and single-device recipes currently. Also, if AC is true, then offloading can be set to True for almost free. I think it's a good default, but we can make a broader change after offloading comes to all recipes.

self.validate_checkpointer(cfg.checkpointer)
self.validate_tokenizer(cfg.tokenizer)
if "fused" in cfg.optimizer:
    cfg.optimizer.pop("fused")

SalmanMohammadi (Collaborator, Author):

The recipe will error out with device=meta and a fused optimizer, so we validate here instead.
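
For anyone reading along, here is a minimal sketch (my own illustration, not code from this PR) of the failure mode: fused optimizers validate parameter devices at construction time and reject meta-device params, while the non-fused path constructs fine, so dropping the flag lets the config be instantiated on meta.

```python
# Illustrative sketch only; exact error text/behavior may vary by PyTorch version.
import torch

with torch.device("meta"):
    model = torch.nn.Linear(8, 8)  # parameters are meta tensors with no storage

try:
    # fused=True checks at init time that params live on a supported device
    torch.optim.AdamW(model.parameters(), lr=1e-4, fused=True)
except RuntimeError as err:
    print(f"fused optimizer rejected meta params: {err}")

# Without the fused flag, construction on meta params succeeds
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
```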


class TestRecipeConfigs:
    def validate_tokenizer(self, tokenizer_cfg):
        with pytest.raises(OSError, match="No such file or directory"):

SalmanMohammadi (Collaborator, Author):

Actually, maybe this test is stupid, because it doesn't validate anything beyond a given OSError being raised. I should instead mock the load function and just make sure it instantiates OK.
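
Rough sketch of the pattern (the patch target below is an assumption for illustration, not necessarily what this PR would use):

```python
# Sketch only: patch whatever actually opens the tokenizer file, so instantiation
# no longer depends on a real file existing on disk.
from unittest.mock import patch

from omegaconf import OmegaConf
from torchtune import config

tokenizer_cfg = OmegaConf.create(
    {
        "_component_": "torchtune.models.llama3.llama3_tokenizer",
        "path": "/tmp/does_not_exist.model",  # never read, since the load is mocked
    }
)

# Patch target is a guess -- the point is to skip the IO-bound load step and
# only assert that the config wires up and instantiates cleanly.
with patch(
    "torchtune.models.llama3._tokenizer.Llama3Tokenizer.__init__", return_value=None
):
    tokenizer = config.instantiate(tokenizer_cfg)
```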

return_value="boo",
):
if checkpointer_class.__name__ == "FullModelHFCheckpointer":
with pytest.raises(OSError, match="No such file or directory"):

SalmanMohammadi (Collaborator, Author):

same as above


with torch.device("meta"):
if "lora" in cfg.model._component_:
with patch.object(

SalmanMohammadi (Collaborator, Author):

We need to do this because register_state_dict_hook attempts to copy parameters, and you can't do anything state-dependent on meta tensors.
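
A quick illustration of the constraint (my own example, not from the PR): meta tensors carry only shape and dtype, so any op that needs actual values, like the copy triggered by the state-dict hook, blows up.

```python
# Illustrative sketch; the exact exception type/message can vary across versions.
import torch

with torch.device("meta"):
    weight = torch.nn.Linear(4, 4).weight  # shape and dtype only, no storage

dest = torch.empty(4, 4)
try:
    dest.copy_(weight)  # needs real values, which a meta tensor doesn't have
except (RuntimeError, NotImplementedError) as err:
    print(f"state-dependent op on a meta tensor failed: {err}")
```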

felipemello1 (Contributor) left a comment:

Thanks for this @SalmanMohammadi! Let me see if I am getting this right: you are basically loading the config and instantiating the recipe.

I understand having to validate the checkpointer independently, since you patch load_checkpoint, but I am not sure why you have to do it with the tokenizer, for example, if recipe.setup() will call it. Am I misunderstanding it?

Overall, I think the test is very useful, but it's not clear to me whether it's expensive. It was a bit hard to understand. I think that with more structure and comments, it should be easier to review.

Please don't spend a long time making changes. Can you do a 10-20 min restructuring/commenting, and then I can come back to this?


config_file_path = CONFIG_ROOT / config_file_path
module = load_module_from_path("recipe_module", recipe_file_path)
recipe_class = None
for name, obj in inspect.getmembers(module, inspect.isclass):

felipemello1 (Contributor):

Not familiar with inspect. Would you mind adding a small comment explaining this code's intent here?
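
For what it's worth, a commented sketch of what I read the loop's intent to be (a guess based on the snippet, not the PR's final code): enumerate the classes defined in the loaded recipe module and pick out the recipe class by convention rather than hard-coding a name per recipe file.

```python
# Sketch only; the name-matching heuristic here is an assumption.
import inspect
from types import ModuleType


def find_recipe_class(module: ModuleType):
    # inspect.getmembers(module, inspect.isclass) yields (name, class) pairs for
    # every class reachable from the module; filter to classes actually defined
    # in this file (not merely imported) and match the recipe naming convention.
    for name, obj in inspect.getmembers(module, inspect.isclass):
        if obj.__module__ == module.__name__ and "Recipe" in name:
            return obj
    return None
```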



def load_module_from_path(module_name, path):
    spec = importlib.util.spec_from_file_location(module_name, path)

felipemello1 (Contributor):

Do you mind adding a TL;DR docstring? No need for anything super complex, just a bit of context or a short example.
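
Something like this might do as a TL;DR docstring (a sketch; the body mirrors the snippet above and the standard importlib recipe, and the example path is just an illustration):

```python
import importlib.util


def load_module_from_path(module_name, path):
    """Import a Python file that is not on sys.path (e.g. a recipe script under
    recipes/) as a module object, so its recipe class can be found and
    instantiated directly in tests.

    Example:
        module = load_module_from_path(
            "recipe_module", "recipes/full_finetune_single_device.py"
        )
    """
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
```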

Comment on lines +110 to +115
cfg.tokenizer = OmegaConf.create(
    {
        "_component_": "torchtune.models.llama3.llama3_tokenizer",
        "path": TOKENIZER_PATHS["llama3"],
    }
)

felipemello1 (Contributor):

I don't follow this part. Why do we have this specifically for llama3?

else:
    state_dict = config.instantiate(cfg.model).state_dict()

load_dataset.return_value = [0]

felipemello1 (Contributor):

Don't know what this is doing.

SalmanMohammadi (Collaborator, Author):

This mocks the HF load_dataset function, which we don't want to make an expensive call to. I thought a bit about using the HF API to make sure the dataset exists, but I felt that was a bit overkill.
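
For reference, the shape of that mock (the patch target shown is an assumption; patch load_dataset in whichever torchtune module actually imports it):

```python
# Sketch only; "torchtune.datasets._sft.load_dataset" is a guess at where the
# dataset builder imports HF's load_dataset from -- patch it where it is used.
from unittest.mock import patch

with patch("torchtune.datasets._sft.load_dataset") as load_dataset:
    # Tiny in-memory stand-in so recipe setup never downloads from the Hub
    # (the PR uses [0], since the test never iterates the data).
    load_dataset.return_value = [0]
    # ... instantiate the recipe and run recipe.setup(cfg) here ...
```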

SalmanMohammadi commented Oct 1, 2024

RE computational cost:

============================= slowest 20 durations =============================
487.49s call     tests/recipes/test_lora_finetune_single_device.py::TestLoRAFinetuneSingleDeviceRecipe::test_loss_qlora[True-bf16]
377.84s call     tests/recipes/test_lora_finetune_single_device.py::TestLoRAFinetuneSingleDeviceRecipe::test_loss_qlora[False-bf16]
95.78s call     tests/recipes/test_lora_finetune_single_device.py::TestLoRAFinetuneSingleDeviceRecipe::test_loss_qlora[True-fp32]
68.11s call     tests/recipes/test_full_finetune_single_device.py::TestFullFinetuneSingleDeviceRecipe::test_loss[llama2/7B_full_low_memory-llama2-meta-True]
67.96s call     tests/recipes/test_knowledge_distillation_single_device.py::TestKDSingleDeviceRecipe::test_loss[qwen2/knowledge_distillation_single_device-llama3-tune-True]
61.68s call     tests/recipes/test_lora_finetune_single_device.py::TestLoRAFinetuneSingleDeviceRecipe::test_loss[llama2/7B_lora_single_device-llama2-meta-True]
61.40s call     tests/recipes/test_eleuther_eval.py::TestEleutherEval::test_torchtune_checkpoint_eval_results[truthfulqa_gen-0.1-1]
51.10s call     tests/recipes/test_eleuther_eval.py::TestEleutherEval::test_torchtune_checkpoint_eval_results[truthfulqa_gen-0.1-4]
45.85s call     tests/recipes/test_full_finetune_single_device.py::TestFullFinetuneSingleDeviceRecipe::test_loss[llama3/8B_full_single_device-llama3-tune-True]
29.40s call     tests/recipes/test_lora_finetune_single_device.py::TestLoRAFinetuneSingleDeviceRecipe::test_loss[llama3/8B_lora_single_device-llama3-tune-True]
21.03s call     tests/recipes/test_lora_dpo_single_device.py::TestLoRADPOSingleDeviceRecipe::test_training_state_on_resume[False]
20.77s call     tests/recipes/test_lora_dpo_single_device.py::TestLoRADPOSingleDeviceRecipe::test_training_state_on_resume[True]
18.60s call     tests/recipes/test_lora_finetune_distributed.py::TestLoRAFinetuneDistributedRecipe::test_training_state_on_resume[llama3/8B_lora-llama3-tune-False]
18.12s call     tests/recipes/test_knowledge_distillation_single_device.py::TestKDSingleDeviceRecipe::test_training_state_on_resume
17.10s call     tests/recipes/test_lora_finetune_distributed.py::TestLoRAFinetuneDistributedRecipe::test_training_state_on_resume[llama2/7B_lora-llama2-hf-True]
17.10s call     tests/recipes/test_lora_finetune_distributed.py::TestLoRAFinetuneDistributedRecipe::test_training_state_on_resume[llama2/7B_lora-llama2-hf-False]
16.38s call     tests/recipes/test_ppo_full_finetune_single_device.py::TestPPOFullFinetuneSingleDeviceRecipe::test_training_state_on_resume_with_optimizer_in_bwd
16.16s call     tests/recipes/test_ppo_full_finetune_single_device.py::TestPPOFullFinetuneSingleDeviceRecipe::test_training_state_on_resume
15.96s call     tests/torchtune/training/test_distributed.py::TestFullyShardState::test_qlora_state_dict
15.81s call     tests/recipes/test_eleuther_eval.py::TestEleutherEval::test_torchtune_checkpoint_eval_results[truthfulqa_mc2-0.4-4]

So it doesn't look like any of the new tests make it into the top 20 longest runs.

SalmanMohammadi commented Oct 1, 2024

Thanks so much for looking at this @felipemello1. Sorry for the very WIP state; I've polished it up a bit now. I'll also move it into a sensible namespace in a bit.

If this approach generally looks good, I'd appreciate any advice on testing distributed recipes. For the DPO/KD/PPO single-device recipes I think I might need separate tests, but I can leave that as a follow-up.

felipemello1 (Contributor):

I will take a closer look soon. Thanks for the changes! However, I found 3 configs that were broken. I think we will need to beef this PR up a bit, or maybe have a second PR addressing other config issues. I will link the config fix here soon.

SalmanMohammadi (Collaborator, Author):

> I will take a closer look soon. Thanks for the changes! However, I found 3 configs that were broken. I think we will need to beef this PR up a bit, or maybe have a second PR addressing other config issues. I will link the config fix here soon.

Very curious to see the cases these tests didn't catch.

felipemello1 mentioned this pull request on Oct 23, 2024.