[WIP] Config Continuous Integration (CCI) #1717
base: main
Conversation
Helpful Links: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1717
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 240ef5e with merge base 3fddc56.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@           Coverage Diff            @@
##            main    #1717     +/-  ##
========================================
- Coverage   70.67%   67.48%   -3.20%
========================================
  Files         299      305       +6
  Lines       15251    15683     +432
========================================
- Hits        10778    10583     -195
- Misses       4473     5100     +627

Flags with carried forward coverage won't be shown.
View full report in Codecov by Sentry.
@@ -83,7 +83,7 @@ dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: True
enable_activation_offloading: False
This is the only recipe where it was enabled.
We only support it in LoRA and single-device recipes currently. Also, if AC is true, then offloading can be set to True for almost free. I think it's a good default, but we can make a broader change after offloading comes to all recipes.
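A minimal sketch of how this pairing could be checked in a config-driven test. The flag names come from the diff above; the rule that offloading should only be enabled together with activation checkpointing is an assumption used here for illustration, not something this PR states.

```python
from omegaconf import OmegaConf

# Hedged sketch: flag names come from the config diff above; the rule that
# offloading requires checkpointing is an assumption for illustration only.
cfg = OmegaConf.create(
    {
        "enable_activation_checkpointing": True,
        "enable_activation_offloading": True,
    }
)

if cfg.enable_activation_offloading and not cfg.enable_activation_checkpointing:
    raise ValueError(
        "enable_activation_offloading=True requires enable_activation_checkpointing=True"
    )
```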
self.validate_checkpointer(cfg.checkpointer)
self.validate_tokenizer(cfg.tokenizer)
if "fused" in cfg.optimizer:
    cfg.optimizer.pop("fused")
The recipe will error out with the meta device and a fused optimizer, so we validate here instead.
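A rough sketch of the workaround described here, assuming the config is an OmegaConf DictConfig as in the rest of the test; the optimizer settings shown are illustrative, not from a real recipe config.

```python
from omegaconf import OmegaConf

# Illustrative config; real recipe configs carry many more fields.
cfg = OmegaConf.create(
    {"optimizer": {"_component_": "torch.optim.AdamW", "lr": 2e-5, "fused": True}}
)

# Fused optimizers need real parameter storage, which meta tensors do not have,
# so the flag is stripped before instantiating under torch.device("meta").
if "fused" in cfg.optimizer:
    cfg.optimizer.pop("fused")
```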
class TestRecipeConfigs:
    def validate_tokenizer(self, tokenizer_cfg):
        with pytest.raises(OSError, match="No such file or directory"):
Actually, maybe this test is stupid because it doesn't validate anything beyond the given OSError being raised. I should instead mock the load function and just make sure it instantiates OK.
    return_value="boo",
):
    if checkpointer_class.__name__ == "FullModelHFCheckpointer":
        with pytest.raises(OSError, match="No such file or directory"):
same as above
with torch.device("meta"):
    if "lora" in cfg.model._component_:
        with patch.object(
We need to do this because register_state_dict_hook attempts to copy parameters, and you can't do anything state-dependent on meta tensors.
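A tiny self-contained illustration of the constraint being described (a toy module, not torchtune code): parameters created on the meta device have no backing storage, so anything that needs their actual values fails.

```python
import torch
import torch.nn as nn

# Toy example: build a module on the meta device, then try a state-dependent op.
with torch.device("meta"):
    linear = nn.Linear(4, 4)

try:
    linear.weight.cpu()  # copying out of a meta tensor needs real data
except (NotImplementedError, RuntimeError) as err:
    print(f"state-dependent op on meta tensor failed as expected: {err}")
```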
Thanks for this @SalmanMohammadi! Let me see if I'm getting this right: you are basically loading the config and instantiating the recipe.
I understand having to validate the checkpointer independently, since you patch load_checkpoint, but I'm not sure why you have to do it for the tokenizer, for example, if recipe.setup() will call it. Am I misunderstanding it?
Overall, I think the test is very useful, but it's not clear to me whether it's expensive. It was a bit hard to understand. I think that with more structure and comments, it should be easier to review.
Please don't spend a long time making changes. Can you do a 10-20 minute restructuring/commenting pass, and then I can come back to this?
config_file_path = CONFIG_ROOT / config_file_path
module = load_module_from_path("recipe_module", recipe_file_path)
recipe_class = None
for name, obj in inspect.getmembers(module, inspect.isclass):
Not familiar with inspect. Would you mind adding a small comment explaining the intent of this code?
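For reference, a hedged sketch of what the loop is doing: it scans the dynamically loaded recipe module for the class to instantiate. The helper name and the selection criterion (keep only classes defined in the module itself) are assumptions for illustration; the actual test may pick the class differently.

```python
import inspect

def find_recipe_class(module):
    # inspect.getmembers(module, inspect.isclass) yields (name, class) pairs for
    # every class visible in the module, including imported ones, so we keep
    # only classes whose __module__ matches the module we just loaded.
    for name, obj in inspect.getmembers(module, inspect.isclass):
        if obj.__module__ == module.__name__:
            return obj
    return None
```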
def load_module_from_path(module_name, path):
    spec = importlib.util.spec_from_file_location(module_name, path)
Do you mind adding a TL;DR docstring? No need for anything super complex, just a bit of context or a short example.
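Something along these lines, perhaps: a short docstring plus the standard importlib pattern the helper appears to use. The example path and class name in the docstring are made up for illustration.

```python
import importlib.util

def load_module_from_path(module_name, path):
    """Import a Python file that is not on ``sys.path`` as a module.

    Recipes live as standalone scripts rather than installed packages, so we
    load them directly from their file path. Example (illustrative names):

        module = load_module_from_path("recipe_module", "recipes/full_finetune_single_device.py")
        recipe_cls = getattr(module, "FullFinetuneRecipeSingleDevice")
    """
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
```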
cfg.tokenizer = OmegaConf.create(
    {
        "_component_": "torchtune.models.llama3.llama3_tokenizer",
        "path": TOKENIZER_PATHS["llama3"],
    }
)
I don't follow this part. Why do we have this specifically for llama3?
else:
    state_dict = config.instantiate(cfg.model).state_dict()

load_dataset.return_value = [0]
I don't know what this is doing.
This mocks the HF load_dataset function, which we don't want to make an expensive call to. I thought a bit about using the HF API to make sure the dataset exists, but I felt that was a bit overkill.
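For context, a minimal sketch of the kind of mock being described. The patch target "datasets.load_dataset" is an assumption shown for illustration; the test may patch the function where torchtune imports it instead.

```python
from unittest.mock import patch

# Replace Hugging Face's load_dataset so recipe setup never hits the network;
# a single dummy sample, as in the snippet above, is returned instead.
with patch("datasets.load_dataset") as load_dataset:
    load_dataset.return_value = [0]
    # ... instantiate the recipe and call recipe.setup(cfg) here ...
    pass
```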
RE computational cost:
============================= slowest 20 durations =============================
487.49s call tests/recipes/test_lora_finetune_single_device.py::TestLoRAFinetuneSingleDeviceRecipe::test_loss_qlora[True-bf16]
377.84s call tests/recipes/test_lora_finetune_single_device.py::TestLoRAFinetuneSingleDeviceRecipe::test_loss_qlora[False-bf16]
95.78s call tests/recipes/test_lora_finetune_single_device.py::TestLoRAFinetuneSingleDeviceRecipe::test_loss_qlora[True-fp32]
68.11s call tests/recipes/test_full_finetune_single_device.py::TestFullFinetuneSingleDeviceRecipe::test_loss[llama2/7B_full_low_memory-llama2-meta-True]
67.96s call tests/recipes/test_knowledge_distillation_single_device.py::TestKDSingleDeviceRecipe::test_loss[qwen2/knowledge_distillation_single_device-llama3-tune-True]
61.68s call tests/recipes/test_lora_finetune_single_device.py::TestLoRAFinetuneSingleDeviceRecipe::test_loss[llama2/7B_lora_single_device-llama2-meta-True]
61.40s call tests/recipes/test_eleuther_eval.py::TestEleutherEval::test_torchtune_checkpoint_eval_results[truthfulqa_gen-0.1-1]
51.10s call tests/recipes/test_eleuther_eval.py::TestEleutherEval::test_torchtune_checkpoint_eval_results[truthfulqa_gen-0.1-4]
45.85s call tests/recipes/test_full_finetune_single_device.py::TestFullFinetuneSingleDeviceRecipe::test_loss[llama3/8B_full_single_device-llama3-tune-True]
29.40s call tests/recipes/test_lora_finetune_single_device.py::TestLoRAFinetuneSingleDeviceRecipe::test_loss[llama3/8B_lora_single_device-llama3-tune-True]
21.03s call tests/recipes/test_lora_dpo_single_device.py::TestLoRADPOSingleDeviceRecipe::test_training_state_on_resume[False]
20.77s call tests/recipes/test_lora_dpo_single_device.py::TestLoRADPOSingleDeviceRecipe::test_training_state_on_resume[True]
18.60s call tests/recipes/test_lora_finetune_distributed.py::TestLoRAFinetuneDistributedRecipe::test_training_state_on_resume[llama3/8B_lora-llama3-tune-False]
18.12s call tests/recipes/test_knowledge_distillation_single_device.py::TestKDSingleDeviceRecipe::test_training_state_on_resume
17.10s call tests/recipes/test_lora_finetune_distributed.py::TestLoRAFinetuneDistributedRecipe::test_training_state_on_resume[llama2/7B_lora-llama2-hf-True]
17.10s call tests/recipes/test_lora_finetune_distributed.py::TestLoRAFinetuneDistributedRecipe::test_training_state_on_resume[llama2/7B_lora-llama2-hf-False]
16.38s call tests/recipes/test_ppo_full_finetune_single_device.py::TestPPOFullFinetuneSingleDeviceRecipe::test_training_state_on_resume_with_optimizer_in_bwd
16.16s call tests/recipes/test_ppo_full_finetune_single_device.py::TestPPOFullFinetuneSingleDeviceRecipe::test_training_state_on_resume
15.96s call tests/torchtune/training/test_distributed.py::TestFullyShardState::test_qlora_state_dict
15.81s call tests/recipes/test_eleuther_eval.py::TestEleutherEval::test_torchtune_checkpoint_eval_results[truthfulqa_mc2-0.4-4]

So it doesn't look like any of these tests make it into the top 20 longest runs.
Thanks so much for looking at this @felipemello1. Sorry for the very WIP state; I've polished it up a bit now. I'll also move it into a sensible namespace in a bit. If this approach generally looks good, I'd appreciate any advice on testing distributed recipes. For the DPO/KD/PPO single-device recipes I think I might need separate tests, but I can leave that as a follow-up.
I will take a closer look soon. Thanks for the changes! However, I found 3 configs that were broken. I think we will need to beef this PR up a bit, or maybe have a second PR addressing other config issues. I will link the config fix here soon.
Very curious to see the cases these tests didn't catch.
Context
What is the purpose of this PR?
Please link to any issues this PR addresses.

Changelog
What are the changes made in this PR?
*

Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- run pre-commit hooks and linters (make sure you've first installed them via pre-commit install)
- run unit tests (pytest tests)
- run recipe/integration tests (pytest tests -m integration_test)

UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.