
Add KD distributed recipe #1631

Merged: 50 commits into pytorch:main on Oct 29, 2024
Conversation

lindawangg (Contributor) commented Sep 20, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

To enable distributed training for knowledge distillation.

Changelog

What are the changes made in this PR?

  • Builds on top of Add single device KD recipe #1539
  • KD distributed recipe (knowledge_distillation_distributed.py) is similar to lora_finetune_distributed.py.
  • KD config: knowledge_distillation_distributed.yaml

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)
tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config qwen2/knowledge_distillation_distributed

(left) single device, (right) distributed; can also increase batch size. [loss-curve screenshots]
Similar eval results. [eval-results screenshot]

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings

pytorch-bot (bot) commented Sep 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1631

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit cf5f01a with merge base d3039da:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Sep 20, 2024
@lindawangg marked this pull request as ready for review on September 20, 2024 01:55
@felipemello1 (Contributor) commented:
Hey @lindawangg, thanks for the recipe!! We have been a bit busy, but we will get to this PR.

@ebsmothers (Contributor) left a comment:
Just a few minor comments, otherwise looks good!

# To launch on a single device, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config qwen2/knowledge_distillation_distributed
#
# This config works only for distilling on a single device.
Reviewer (Contributor):
nit: update this

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True
Reviewer (Contributor):
Not sure what the peak memory you're seeing is but with distributed you may be able to get away without this (especially for such small models) and get faster training

lindawangg (Author):
Changed to False. It isn't needed for Qwen2, and training time also went from 1h to 20 mins.

Comment on lines 56 to 62
@pytest.mark.parametrize(
"reshard_after_forward",
[
True,
False,
],
)
Reviewer (Contributor):
Not a big deal but you can probably get away without testing both of these cases. We already test it elsewhere and I don't expect it to change in KD vs in other recipes (lmk if you disagree though)

Comment on lines 75 to 78
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}] \
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA3 \
Reviewer (Contributor):
Honestly we should probably just bite the bullet and upload some small Qwen2-formatted checkpoints rather than overriding everything as Llama in these tests. (Btw you don't have to worry about this, I am just writing it down so we can hold ourselves accountable later 😃 )

lindawangg (Author):
Since Llama 3.2 has been released, I changed to the llama3_2 distributed config, which uses the same LLAMA3 model type.

Comment on lines 44 to 102
"""
Knowledge distillation recipe for dense transformer-based LLMs such as Llama3. This recipe is optimized
for single GPU training. Training on CPU is not supported.

Features:
- Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
activations in memory and instead recompute them during the backward pass. This is especially
helpful for larger batch sizes when you're memory constrained. But these savings in memory
come at the cost of training performance. In most cases training can slow-down quite a bit as
a result of this activation recomputation.

- Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
most cases this should halve the memory footprint of full precision (fp32) training, without
loss in model quality (will depend on the model, training data and other settings). For
GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16
precision are currently not supported.

- Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is
controlled using the ``gradient_accumulation_steps`` flag.

Total Batch Size = batch_size * gradient accumulation steps.

For example: with batch_size=1 and gradient_accumulation_steps=32 we get a total batch size of 32.

Gradient accumulation is especially useful when you are memory constrained. In this case,
accumulating gradients might give you better training speed than enabling activation
checkpointing.

- Lower precision optimizers. This recipe supports lower-precision optimizers from the bitsandbytes
library (https://huggingface.co/docs/bitsandbytes/main/en/index). We've tested the recipe with
8-bit AdamW and Paged AdamW.

- Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of
training. Currently we checkpoint both the adapter weights (trainable params only) and the
complete merged weights (adapter weights added back to the base model). For more details
please take a look at our LoRA tutorial
(https://pytorch.org/torchtune/main/tutorials/lora_finetune.html).

Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are
only saved at the end of a given epoch and used in case of resuming training. Resuming
training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is
currently not supported.

For more details on the checkpointer, please take a look at
our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html).

- Logging. Terminal, Disk, WandB and TensorBoard are all supported.

For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
has example commands for how to kick-off training.

Args:
cfg (DictConfig): OmegaConf object parsed from yaml file

Raises:
ValueError: If ``dtype`` is set to fp16.
RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
Reviewer (Contributor):
I think this whole docstring is from the single-device recipe. May wanna make sure it lines up with the features that are in here (e.g. FSDP, and I don't think we really advertise low-precision optimizers in our distributed recipes (though they should probably work))

is_dora = False
for m in model.modules():
if hasattr(m, "initialize_dora_magnitude"):
is_dora = (True,)
Reviewer (Contributor):
Why a tuple? Also might be useful to run with QLoRA and/or DoRA just as a sanity check that nothing breaks if you haven't already

lindawangg (Author):
I think that was a mistake; I don't remember why. Changed to is_dora = True and tested that it works with DoRA.
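For reference, the corrected form of the quoted snippet is just the flag fix described above (a minimal sketch; the surrounding loop is unchanged):

is_dora = False
for m in model.modules():
    if hasattr(m, "initialize_dora_magnitude"):
        is_dora = True  # plain bool instead of the accidental one-element tuple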

@joecummings mentioned this pull request on Oct 15, 2024.
@ebsmothers (Contributor) left a comment:
A few comments, but no huge concerns from my side. Looks like CI is red for unrelated reasons, tagging @joecummings who is looking into it

# This config assumes that you've run the following commands before launching KD:
# First download the student and teacher models
# tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
Reviewer (Contributor):
Do you think it's also worthwhile to add a config with 70B model size? (Doesn't necessarily have to be in this PR, but it'd be useful to have at least one config that strictly requires distributed)

Reviewer (Contributor):
As a follow-up to that, I wonder if we should include model sizes in the config names? I know it makes it a bit longer (and doesn't line up with what you did for the single-device configs), but otherwise we cannot really distinguish between configs for distilling 70B -> 1B vs 8B -> 1B. Similar to the other comment here, this is fine to save for a follow-up though

lindawangg (Author):
I tested this config on the 70B model and verified it works, but I think it needs more tuning. We can add the 70B model in a separate PR and figure out how to change the naming. There weren't many changes needed for the 70B model, just the model target and checkpoint, since the tokenizer has to be the same right now.


@pytest.mark.integration_test
@gpu_test(gpu_count=2)
def test_training_state_on_resume(self, tmpdir, monkeypatch):
Reviewer (Contributor):
Did we lose test_loss along the way here? (Just wanna make sure it was deliberate)

lindawangg (Author):
Oh, I think I read the previous comment wrong. I thought you meant that since test_loss is already tested, we didn't need to test it again, but now I realize you meant reshard_after_forward. Let me add test_loss back in.

Comment on lines 465 to 470
training.shard_model(
model=model,
shard_conditions=[_is_layer_name],
cpu_offload=fsdp_cpu_offload,
reshard_after_forward=reshard_after_forward,
)
Reviewer (Contributor):
Out of curiosity, did you try keeping the student model unsharded? I'm wondering what the tradeoff is here for perf vs memory.. if the model is small enough to not change the HW profile we're runnable on by fully replicating across all devices but we get speedups by saving on comms, might be worthwhile.

lindawangg (Author) commented Oct 22, 2024:
I could do it for the 1B model, but I got OOM when trying to load the 3B student and 70B teacher models. We could make sharding the student model an option.

Reviewer (Contributor):
Did you see nice speedups on 1B, or were they pretty minimal? If the latter let's just leave it as is, otherwise we can consider exposing the option as you mentioned

lindawangg (Author):
The speedup is pretty minimal, especially when using 8 GPUs. The number of devices influences the speed more.
1B w/o FSDP, 5 steps on 8 GPUs: 1:09
1B w/o FSDP, 5 steps on 4 GPUs: 3:57
1B w/ FSDP, 5 steps on 8 GPUs: 1:14
1B w/ FSDP, 5 steps on 4 GPUs: 4:59


class_loss, kd_loss = self._loss_step(batch)
loss = (1 - self._kd_ratio) * class_loss + self._kd_ratio * kd_loss
loss = loss / self._gradient_accumulation_steps
Reviewer (Contributor):
Don't need to actually do anything for this PR, but just FYI we are likely to be changing how we normalize loss when gradient accumulation is enabled (see #1875)
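Whatever the normalization ends up looking like after #1875, the weighted objective itself is simple. A standalone sketch of the arithmetic in the quoted snippet (the function name and signature here are illustrative, not the recipe's API):

import torch

def kd_objective(
    class_loss: torch.Tensor,
    kd_loss: torch.Tensor,
    kd_ratio: float,
    gradient_accumulation_steps: int,
) -> torch.Tensor:
    # Weighted sum of the cross-entropy (class) loss and the distillation loss,
    # scaled so gradients accumulated over several micro-batches average out.
    loss = (1 - kd_ratio) * class_loss + kd_ratio * kd_loss
    return loss / gradient_accumulation_steps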

codecov-commenter commented Oct 26, 2024

Codecov Report

Attention: Patch coverage is 5.65111% with 384 lines in your changes missing coverage. Please review.

Project coverage is 68.70%. Comparing base (23c8829) to head (cf5f01a).
Report is 3 commits behind head on main.

Files with missing lines | Patch % | Missing
recipes/knowledge_distillation_distributed.py | 0.00% | 313
...recipes/test_knowledge_distillation_distributed.py | 24.46% | 71
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1631      +/-   ##
==========================================
- Coverage   70.44%   68.70%   -1.74%     
==========================================
  Files         308      306       -2     
  Lines       16270    16596     +326     
==========================================
- Hits        11462    11403      -59     
- Misses       4808     5193     +385     

☔ View full report in Codecov by Sentry.

# You get better results using KD if the teacher model has already been fine-tuned on the target dataset:
# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama3_1/8B_lora
#
# To launch on a 2 devices, run the following command from root:
Reviewer (Contributor):
nit

Suggested change
# To launch on a 2 devices, run the following command from root:
# To launch on 2 devices, run the following command from root:

# You get better results using KD if the teacher model has already been fine-tuned on the target dataset:
# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/1.5B_lora
#
# To launch on a 2 devices, run the following command from root:
Reviewer (Contributor):
same nit here

Suggested change
# To launch on a 2 devices, run the following command from root:
# To launch on 2 devices, run the following command from root:

)

# Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
# if cfg is missing profiler key or if `cfg.profiler.enabled = False
Reviewer (Contributor):
nit (or just remove the backticks, this isn't rendering anywhere anyways)

Suggested change
# if cfg is missing profiler key or if `cfg.profiler.enabled = False
# if cfg is missing profiler key or if `cfg.profiler.enabled = False`

Comment on lines 446 to 463
# For FSDP sharding, we can condition on either the module or its name
# Shard conditions should be callables taking name (relative to model root)
# and the module itself and returning a bool on whether to shard the given module

# Shard transformer decoder layers (or AC-wrapped versions)
# Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
# But directly using the name is more concise
def _is_layer_name(name: str, module: nn.Module) -> bool:
"""
Return True for layers.i and False for all other module names
Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
"""
name_list = name.split(".")
return (
len(name_list) == 2
and name_list[0] == "layers"
and str.isdigit(name_list[1])
)
Reviewer (Contributor):
Sorry one more set of changes to merge.. we changed this in #1889 due to a bug in how multimodal models were being handled by this logic. Can you do something similar to this instead?

Comment on lines 565 to 584
# For FSDP sharding, we can condition on either the module or its name
# Shard conditions should be callables taking name (relative to model root)
# and the module itself and returning a bool on whether to shard the given module
fsdp_shard_conditions = []

# Shard transformer decoder layers (or AC-wrapped versions)
# Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
# But directly using the name is more concise
def _is_layer_fqn(s: str) -> bool:
"""
Return True for layers.i and False for all other module names
Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
"""
s_list = s.split(".")
return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1])

fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)]

if custom_sharded_layers:
fsdp_shard_conditions += [lambda n, m: n in custom_sharded_layers]
Reviewer (Contributor):
similar comment here
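As a rough illustration of "something similar", a shard condition keyed on module type rather than on the "layers.i" name could look like the sketch below (the class names are assumed to be importable from torchtune.modules, and the exact utility adopted in #1889 may differ):

from torch import nn
from torchtune.modules import TransformerCrossAttentionLayer, TransformerSelfAttentionLayer

def _is_decoder_layer(name: str, module: nn.Module) -> bool:
    # Shard any transformer decoder layer by type, so layers that live under a
    # different parent name (e.g. in multimodal models) are still matched.
    return isinstance(module, (TransformerSelfAttentionLayer, TransformerCrossAttentionLayer))

fsdp_shard_conditions = [_is_decoder_layer]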

self._profiler.step()

self.epochs_run += 1
self.save_checkpoint(epoch=curr_epoch)
Reviewer (Contributor):
Do we not need to call self._profiler.stop() somewhere?

# in case shuffle is True
self._sampler.set_epoch(curr_epoch)

pbar = tqdm(total=self._steps_per_epoch)
Reviewer (Contributor):
Do we not need to disable on nonzero ranks?
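A minimal sketch of one way to do that, written as a standalone helper (the rank flag is an assumption about what the recipe tracks):

from tqdm import tqdm

def make_progress_bar(steps_per_epoch: int, is_rank_zero: bool) -> tqdm:
    # Only rank zero renders the bar; other ranks get a silent, disabled bar.
    return tqdm(total=steps_per_epoch, disable=not is_rank_zero)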

from torchtune import config


class TestKDDistributedDeviceRecipe:
Reviewer (Contributor):
Suggested change
class TestKDDistributedDeviceRecipe:
class TestKDDistributedRecipe:

Comment on lines 94 to 95
print(loss_values)
print(expected_loss_values)
Reviewer (Contributor):
remove

_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
checkpoint_files: [
hf_model_0001_0.pt
Reviewer (Contributor):
Why is the default .pt? The assumption is that we finetuned the model on a target dataset first?

If so, the first example (Llama3.2) is wrong b/c the default files are safetensors, which are only saved if the checkpointer specifies safe_serialization: True. We should be consistent across these defaults.

lindawangg (Author):
I used the default LoRA distributed finetune configs for Qwen2 and Llama 3.1 8B. I'm not sure why qwen2/1.5B_lora outputs .pt whereas llama3_1/8B_lora outputs .safetensors.

@ebsmothers (Contributor) left a comment:
Thank you for working on this!

@ebsmothers merged commit 09c2619 into pytorch:main on Oct 29, 2024. 17 checks passed.