Add KD distributed recipe #1631
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1631
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit cf5f01a with merge base d3039da. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hey @lindawangg, thanks for the recipe!! We have been a bit busy, but we will get to this PR.
Just a few minor comments, otherwise looks good!
# To launch on a single device, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config qwen2/knowledge_distillation_distributed
#
# This config works only for distilling on a single device.
nit: update this
# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True
Not sure what the peak memory you're seeing is, but with distributed you may be able to get away without this (especially for such small models) and get faster training.
Changed to False. It isn't needed for Qwen2, and training time also went from 1h to 20 mins.
@pytest.mark.parametrize(
    "reshard_after_forward",
    [
        True,
        False,
    ],
)
Not a big deal but you can probably get away without testing both of these cases. We already test it elsewhere and I don't expect it to change in KD vs in other recipes (lmk if you disagree though)
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}] \
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA3 \
Honestly we should probably just bite the bullet and upload some small Qwen2-formatted checkpoints rather than overriding everything as Llama in these tests. (Btw you don't have to worry about this, I am just writing it down so we can hold ourselves accountable later 😃 )
Since llama3_2 is released, I changed to the llama3_2 distributed config, which uses the same LLAMA3 model type.
""" | ||
Knowledge distillation recipe for dense transformer-based LLMs such as Llama3. This recipe is optimized | ||
for single GPU training. Training on CPU is not supported. | ||
|
||
Features: | ||
- Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` | ||
flag. Activation checkpointing helps reduce the memory footprint since we no longer keep | ||
activations in memory and instead recompute them during the backward pass. This is especially | ||
helpful for larger batch sizes when you're memory constrained. But these savings in memory | ||
come at the cost of training performance. In most cases training can slow-down quite a bit as | ||
a result of this activation recomputation. | ||
|
||
- Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` | ||
flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In | ||
most cases this should halve the memory footprint of full precision (fp32) training, without | ||
loss in model quality (will depend on the model, training data and other settings). For | ||
GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 | ||
precision are currently not supported.g | ||
|
||
- Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is | ||
controlled using the ``gradient_accumulation_steps`` flag. | ||
|
||
Total Batch Size = batch_size * gradient accumulation steps. | ||
|
||
For example: with batch_size=1 and gradient_accumulation_steps=32 we get a total batch size of 32. | ||
|
||
Gradient accumulation is especially useful when you are memory constrained. In this case, | ||
accumulating gradients might give you better training speed than enabling activation | ||
checkpointing. | ||
|
||
- Lower precision optimizers. This recipe supports lower-precision optimizers from the bitsandbytes | ||
library (https://huggingface.co/docs/bitsandbytes/main/en/index). We've tested the recipe with | ||
8-bit AdamW and Paged AdamW. | ||
|
||
- Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of | ||
training. Currently we checkpoint both the adapter weights (trainable params only) and the | ||
complete merged weights (adapter weights added back to the base model). For more details | ||
please take a look at our LoRA tutorial | ||
(https://pytorch.org/torchtune/main/tutorials/lora_finetune.html). | ||
|
||
Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are | ||
only saved at the end of a given epoch and used in case of resuming training. Resuming | ||
training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is | ||
currently not supported. | ||
|
||
For more details on the checkpointer, please take a look at | ||
our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html). | ||
|
||
- Logging. Terminal, Disk, WandB and TensorBoard are all supported. | ||
|
||
For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config | ||
has example commands for how to kick-off training. | ||
|
||
Args: | ||
cfg (DictConfig): OmegaConf object parsed from yaml file | ||
|
||
Raises: | ||
ValueError: If ``dtype`` is set to fp16. | ||
RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. |
I think this whole docstring is from the single-device recipe. May wanna make sure it lines up with the features that are in here (e.g. FSDP, and I don't think we really advertise low-precision optimizers in our distributed recipes (though they should probably work))
is_dora = False
for m in model.modules():
    if hasattr(m, "initialize_dora_magnitude"):
        is_dora = (True,)
Why a tuple? Also might be useful to run with QLoRA and/or DoRA just as a sanity check that nothing breaks if you haven't already
I think that was a mistake. I don't remember why. Changed to is_dora = True and tested that it works with DoRA.
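For reference, a minimal standalone sketch of the corrected check (the helper function and its name are illustrative only; the recipe sets a local flag instead):

import torch.nn as nn

def model_uses_dora(model: nn.Module) -> bool:
    # Mirrors the hasattr check from the quoted hunk, minus the tuple bug:
    # True if any submodule exposes DoRA's magnitude-initialization hook.
    return any(hasattr(m, "initialize_dora_magnitude") for m in model.modules())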
A few comments, but no huge concerns from my side. Looks like CI is red for unrelated reasons, tagging @joecummings who is looking into it
# This config assumes that you've run the following commands before launching KD:
# First download the student and teacher models
# tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
Do you think it's also worthwhile to add a config with 70B model size? (Doesn't necessarily have to be in this PR, but it'd be useful to have at least one config that strictly requires distributed)
As a follow-up to that, I wonder if we should include model sizes in the config names? I know it makes it a bit longer (and doesn't line up with what you did for the single-device configs), but otherwise we cannot really distinguish between configs for distilling 70B -> 1B vs 8B -> 1B. Similar to the other comment here, this is fine to save for a follow-up though
I tested this config with the 70B model and verified it works, but I think it needs more tuning. We can add the 70B model in a separate PR and figure out how to change the naming. There weren't many changes needed to add the 70B model, just the model target and checkpoint, since the tokenizer has to be the same right now.
@pytest.mark.integration_test
@gpu_test(gpu_count=2)
def test_training_state_on_resume(self, tmpdir, monkeypatch):
Did we lose test_loss along the way here? (Just wanna make sure it was deliberate)
Oh, I think I read the previous comment wrong. I thought you meant that since test_loss is already tested, we didn't need to test it again. But now I realize you meant reshard_after_forward. Let me add test_loss back in.
training.shard_model(
    model=model,
    shard_conditions=[_is_layer_name],
    cpu_offload=fsdp_cpu_offload,
    reshard_after_forward=reshard_after_forward,
)
Out of curiosity, did you try keeping the student model unsharded? I'm wondering what the tradeoff is here for perf vs memory. If the model is small enough that fully replicating it across all devices doesn't change the HW profile we're runnable on, but we get speedups by saving on comms, it might be worthwhile.
I could do it for 1B, but I got OOM when trying to load the 3B student and 70B teacher models. We could make sharding the student model an option.
Did you see nice speedups on 1B, or were they pretty minimal? If the latter let's just leave it as is, otherwise we can consider exposing the option as you mentioned
The speedup is pretty minimal, especially when using 8 gpus. The number of devices influences the speed more.
1B w/o fsdp 5 steps on 8 gpus: 1:09
1B w/o fsdp 5 steps on 4 gpus: 3:57
1B w/ fsdp 5 steps on 8 gpus: 1:14
1B w/ fsdp 5 steps on 4 gpus: 4:59
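A rough sketch of what exposing that option could look like (the shard_student_model flag is hypothetical and not part of this PR; the remaining names come from the quoted training.shard_model hunk):

# Hypothetical config-gated sharding of the student model.
if cfg.get("shard_student_model", True):
    training.shard_model(
        model=model,
        shard_conditions=[_is_layer_name],
        cpu_offload=fsdp_cpu_offload,
        reshard_after_forward=reshard_after_forward,
    )
# else: the student is simply replicated on every rank, trading extra
# per-device memory for fewer FSDP collectives.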
class_loss, kd_loss = self._loss_step(batch)
loss = (1 - self._kd_ratio) * class_loss + self._kd_ratio * kd_loss
loss = loss / self._gradient_accumulation_steps
Don't need to actually do anything for this PR, but just FYI we are likely to be changing how we normalize loss when gradient accumulation is enabled (see #1875)
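For anyone skimming, a tiny self-contained illustration of the weighting in the quoted hunk (numbers are made up):

kd_ratio = 0.5
gradient_accumulation_steps = 4
class_loss, kd_loss = 2.0, 4.0
# kd_ratio interpolates between the hard-label loss and the distillation loss.
loss = (1 - kd_ratio) * class_loss + kd_ratio * kd_loss  # 0.5 * 2.0 + 0.5 * 4.0 = 3.0
loss = loss / gradient_accumulation_steps                # 3.0 / 4 = 0.75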
Codecov Report
Attention: Patch coverage is
Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #1631      +/-   ##
==========================================
- Coverage   70.44%   68.70%    -1.74%
==========================================
  Files         308      306        -2
  Lines       16270    16596      +326
==========================================
- Hits        11462    11403       -59
- Misses       4808     5193      +385

☔ View full report in Codecov by Sentry.
# You get better results using KD if the teacher model has already been fine-tuned on the target dataset:
# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama3_1/8B_lora
#
# To launch on a 2 devices, run the following command from root:
nit
Suggested change:
- # To launch on a 2 devices, run the following command from root:
+ # To launch on 2 devices, run the following command from root:
# You get better results using KD if the teacher model has already been fine-tuned on the target dataset:
# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/1.5B_lora
#
# To launch on a 2 devices, run the following command from root:
same nit here
Suggested change:
- # To launch on a 2 devices, run the following command from root:
+ # To launch on 2 devices, run the following command from root:
)

# Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
# if cfg is missing profiler key or if `cfg.profiler.enabled = False
nit (or just remove the backticks, this isn't rendering anywhere anyways)
Suggested change:
- # if cfg is missing profiler key or if `cfg.profiler.enabled = False
+ # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
# For FSDP sharding, we can condition on either the module or its name
# Shard conditions should be callables taking name (relative to model root)
# and the module itself and returning a bool on whether to shard the given module

# Shard transformer decoder layers (or AC-wrapped versions)
# Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
# But directly using the name is more concise
def _is_layer_name(name: str, module: nn.Module) -> bool:
    """
    Return True for layers.i and False for all other module names
    Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
    """
    name_list = name.split(".")
    return (
        len(name_list) == 2
        and name_list[0] == "layers"
        and str.isdigit(name_list[1])
    )
# For FSDP sharding, we can condition on either the module or its name
# Shard conditions should be callables taking name (relative to model root)
# and the module itself and returning a bool on whether to shard the given module
fsdp_shard_conditions = []

# Shard transformer decoder layers (or AC-wrapped versions)
# Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
# But directly using the name is more concise
def _is_layer_fqn(s: str) -> bool:
    """
    Return True for layers.i and False for all other module names
    Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
    """
    s_list = s.split(".")
    return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1])

fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)]

if custom_sharded_layers:
    fsdp_shard_conditions += [lambda n, m: n in custom_sharded_layers]
similar comment here
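To make the name-based condition concrete, here is a small self-contained sketch of what it matches (re-implemented for illustration; behavior mirrors the quoted helper):

def _is_layer_fqn(s: str) -> bool:
    # True only for top-level decoder layers like "layers.0" or "layers.17".
    parts = s.split(".")
    return len(parts) == 2 and parts[0] == "layers" and parts[1].isdigit()

assert _is_layer_fqn("layers.0")           # sharded as its own FSDP unit
assert not _is_layer_fqn("layers.0.attn")  # submodule: wrapped with its parent layer
assert not _is_layer_fqn("output")         # only sharded if listed in custom_sharded_layers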
        self._profiler.step()

    self.epochs_run += 1
    self.save_checkpoint(epoch=curr_epoch)
Do we not need to call self._profiler.stop() somewhere?
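For context, the lifecycle being asked about looks roughly like this (a sketch assuming the profiler object exposes start/step/stop, as the DummyProfiler comment above implies for step; exact placement is an assumption):

self._profiler.start()
for curr_epoch in range(self.epochs_run, self.total_epochs):
    for idx, batch in enumerate(self._dataloader):
        ...  # forward, loss, backward, optimizer step
        self._profiler.step()
self._profiler.stop()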
# in case shuffle is True
self._sampler.set_epoch(curr_epoch)

pbar = tqdm(total=self._steps_per_epoch)
Do we not need to disable on nonzero ranks?
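A one-line sketch of the rank-gating being suggested (self._is_rank_zero is assumed to be set during distributed init, as in other distributed recipes):

# Only rank 0 renders the progress bar; other ranks create a disabled tqdm.
pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)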
from torchtune import config


class TestKDDistributedDeviceRecipe:
Suggested change:
- class TestKDDistributedDeviceRecipe:
+ class TestKDDistributedRecipe:
print(loss_values)
print(expected_loss_values)
remove
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
checkpoint_files: [
    hf_model_0001_0.pt
Why is the default .pt? The assumption is that we finetuned the model on a target dataset first? If so, the first example (Llama3.2) is wrong because the default files are safetensors, which are only saved if the checkpointer specifies safe_serialization: True. We should be consistent across these defaults.
I used the default LoRA distributed finetune configs for qwen2 and llama3.1 8B. I'm not sure why qwen2/1.5B_lora outputs .pt whereas llama3_1/8B_lora outputs .safetensors.
Thank you for working on this!
Context
What is the purpose of this PR?
To enable distributed training for knowledge distillation.
Changelog
What are the changes made in this PR?
The KD distributed recipe (knowledge_distillation_distributed.py) is similar to lora_finetune_distributed.py.
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- run pre-commit hooks and linters (make sure you've first installed them via pre-commit install)
- run unit tests via pytest tests
- run recipe tests via pytest tests -m integration_test
(left) single device, (right) distributed; can also increase batch size
Similar eval results
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.