[Do not review] Activation offloading #467

Draft · wants to merge 3 commits into base: gh/awgu/8/base
Conversation

awgu commented Jul 18, 2024

Stack from ghstack (oldest at bottom):

**Current UX**

- We use a `saved_tensors_hooks` context manager, which should be wrapped around `module.forward`. The context lets us override pack and unpack hooks that are called when saving an activation for backward and using an activation in backward, respectively. See the [tutorial](https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html) for more info.
- We expose two main methods for the user from the context: `wait_for_d2h` and `copy_h2d_async`.
    - By default, the D2H copies for offloading are async and use pinned memory. The user must call `wait_for_d2h` to wait on the D2H copies and free the device memory. This should be done after the compute to overlap with has been issued.
    - By default, the H2D copies are sync. The user must call `copy_h2d_async` to prefetch the H2D copies as async. This should be done before the compute to overlap with has been issued.
    - We show an example of this in `apply_ac` in `parallelize_llama.py` using module hooks.
- Together, this means that by default, no GPU memory is saved and the H2D copies are sync. Only by calling `wait_for_d2h` can we save GPU memory, and only by calling `copy_h2d_async` can we overlap the H2D copies in backward (a rough sketch of the idea follows this list).
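For orientation, here is a minimal sketch of how such a context could be built on `saved_tensors_hooks`. It is an illustrative assumption, not this PR's implementation: the class name `OffloadContext` is made up, the `copy_h2d_async` prefetch path is omitted, and error handling is skipped.

```python
# Minimal sketch (illustrative assumption, not this PR's code): offload saved
# activations to pinned CPU memory with async D2H copies on a side stream.
import torch
from torch.autograd.graph import saved_tensors_hooks


class OffloadContext(saved_tensors_hooks):
    def __init__(self):
        self.offload_stream = torch.cuda.Stream()
        self._d2h_events: list[torch.cuda.Event] = []
        self._gpu_refs: list[torch.Tensor] = []  # keep device memory alive until the D2H copies finish

        def pack(tensor: torch.Tensor):
            if not tensor.is_cuda:
                return ("as_is", tensor)  # leave CPU tensors alone
            cpu_tensor = torch.empty(
                tensor.shape, dtype=tensor.dtype, device="cpu", pin_memory=True
            )
            # Order the copy after the producer kernel, then issue it async on the side stream.
            self.offload_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self.offload_stream):
                cpu_tensor.copy_(tensor, non_blocking=True)
            event = torch.cuda.Event()
            event.record(self.offload_stream)
            self._d2h_events.append(event)
            self._gpu_refs.append(tensor)
            return ("offloaded", cpu_tensor)

        def unpack(packed):
            kind, tensor = packed
            if kind == "as_is":
                return tensor
            # Sync H2D copy by default; a real `copy_h2d_async` would prefetch this earlier.
            return tensor.cuda()

        super().__init__(pack, unpack)

    def wait_for_d2h(self):
        # Block until all pending D2H copies finish, then drop the device references
        # so the caching allocator can reuse that GPU memory.
        for event in self._d2h_events:
            event.synchronize()
        self._d2h_events.clear()
        self._gpu_refs.clear()
```

In this sketch, the context would be entered around `module.forward` (e.g. via module hooks, as `apply_ac` does), and `wait_for_d2h()` would be called once the compute to overlap with has been issued.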

**Known Problems**

- **!** Conflict with `split_with_sizes_copy`'s H2D copy (specific to FSDP2):
    - FSDP2's all-gather copy-out uses `split_with_sizes_copy`, which first issues a mini-H2D copy to send metadata needed for the main copy.
    - When the CPU issue order is `copy_h2d_async` for layer `i-1` -> `split_with_sizes_copy` for layer `i` -> layer `i` backward compute, the mini-H2D copy for `split_with_sizes_copy` for layer `i` can get serialized to run _after_ the `copy_h2d_async` H2D copies for layer `i-1`, even though they run in different streams. This prevents the `copy_h2d_async` copies for layer `i-1` from overlapping with layer `i`'s backward compute.
    - For now, this can be worked around with `reshard_after_forward=False`.
    - Trick/hack from @yifuwang: sleep 1 ms in the `offload_stream` before un-offloading (https://fburl.com/perfdoctor/ms47gqvp), which lets the `split_with_sizes_copy` H2D copy be prioritized.
    - The CPU issue order of `copy_h2d_async` for layer `i-1` -> `split_with_sizes_copy` for layer `i` -> layer `i` backward compute comes from running the `copy_h2d_async` for layer `i-1` in a module full pre-backward hook.
- **!** If the user offloads too many activations, the program can become slow and/or freeze. Further, the first few iterations are slow due to `cudaHostAlloc` calls warming up the CPU caching allocator. This might be brittle if other parts of the program (e.g. checkpointing) also use pinned memory. If we do not `gc.collect()` every iteration, the pinned memory does not seem to be freed, so the allocator does not reuse it in subsequent iterations.
- **!** We do not have a good way to apply a predicate to decide which activation tensors to offload. With the pack hook API, we only see the tensor, not any other information such as which op constructed the tensor (see the sketch after this list).
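To make the last point concrete, the only predicates available inside the pack hook are ones on the tensor itself (shape, dtype, size in bytes); the byte threshold and helper below are made-up placeholders, not part of this PR.

```python
# Sketch of the only kind of predicate the pack-hook API permits: we can look at
# the tensor's properties, but not at which op produced it.
import torch

MIN_OFFLOAD_BYTES = 1 << 20  # arbitrary 1 MiB threshold, for illustration only


def offload_to_pinned_cpu(tensor: torch.Tensor) -> torch.Tensor:
    # Placeholder for the real async D2H path; a blocking pinned-memory copy for brevity.
    cpu_tensor = torch.empty(tensor.shape, dtype=tensor.dtype, device="cpu", pin_memory=True)
    cpu_tensor.copy_(tensor)
    return cpu_tensor


def pack(tensor: torch.Tensor):
    if not tensor.is_cuda or tensor.numel() * tensor.element_size() < MIN_OFFLOAD_BYTES:
        return tensor  # keep small or CPU tensors where they are
    return offload_to_pinned_cpu(tensor)
```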

**Examples**

- Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, `reshard_after_forward=False`:
    - Trace: https://fburl.com/perfdoctor/r1yf0lqf
    - Reserved memory: 65.67GiB(69.10%)
    - WPS: 5,294  MFU: 31.00%
- Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, `reshard_after_forward=True`:
    - Trace: https://fburl.com/perfdoctor/qbhr98az
    - Reserved memory: 54.01GiB(56.83%)
    - WPS: 4,085  MFU: 23.92% (mainly because the H2D copies in backward are not overlapped)
    - If we use @yifuwang's trick, we can get WPS: 5,073, MFU: 29.70% without changing reserved memory (a sketch of the trick follows this list).
- Baseline Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, `reshard_after_forward=True`:
    - Reserved memory: 78.38GiB(82.48%)
    - WPS: 6,341  MFU: 37.13%
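For reference, a rough sketch of the sleep trick mentioned above. This is an assumption about how it could be wired up; `torch.cuda._sleep` is a private API that takes GPU clock cycles, so the cycles-per-millisecond value would need calibration per device, and the function name and arguments below are made up.

```python
# Rough sketch (assumption): spin the offload stream for ~1 ms before issuing the
# un-offloading H2D copies, so that FSDP2's split_with_sizes_copy mini-H2D copy,
# issued later on the CPU, is not serialized behind them.
import torch

CYCLES_PER_MS = 1_000_000  # device dependent; measure e.g. with torch.cuda.Event timing


def copy_h2d_async_with_delay(offload_stream: torch.cuda.Stream, cpu_tensors, gpu_outs):
    with torch.cuda.stream(offload_stream):
        torch.cuda._sleep(CYCLES_PER_MS)  # private API: busy-wait kernel for roughly 1 ms
        for cpu_t, gpu_t in zip(cpu_tensors, gpu_outs):
            gpu_t.copy_(cpu_t, non_blocking=True)  # async H2D from pinned memory
```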

awgu added a commit that referenced this pull request Jul 18, 2024
ghstack-source-id: 863f4bc18580e02e2c2dd4198374730429a97961
Pull Request resolved: #467
awgu added a commit that referenced this pull request Jul 18, 2024
ghstack-source-id: f749e6fc173efba59c859166d6dd9eae8917aab7
Pull Request resolved: #467
awgu added a commit that referenced this pull request Jul 18, 2024
ghstack-source-id: 1f53901f927b56c0ff58b81f853e6969cf348b84
Pull Request resolved: #467
@awgu commented on an inline diff:

```diff
@@ -400,6 +400,11 @@ def loss_fn(pred, labels):
         optimizers.step()
         lr_schedulers.step()
 
+        if job_config.experimental.offload_activations:
+            # NOTE: We need `gc.collect` to ensure that CPU memory is freed
```
This needs some more investigation. Maybe there is some ref cycle that I am not aware of.
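One hedged way to investigate this (not part of the PR) would be to look for pinned CPU tensors that survive an iteration and see what is still referencing them:

```python
# Hypothetical debugging aid: find pinned CPU tensors that are still alive after an
# iteration and print the types of the objects referencing them, to spot a ref cycle.
import gc
import torch


def report_lingering_pinned_tensors(limit: int = 5) -> int:
    candidates = [
        obj for obj in gc.get_objects()
        if isinstance(obj, torch.Tensor) and obj.device.type == "cpu" and obj.is_pinned()
    ]
    for tensor in candidates[:limit]:
        referrer_types = [type(ref).__name__ for ref in gc.get_referrers(tensor)]
        print(f"pinned tensor {tuple(tensor.shape)} held by: {referrer_types}")
    return len(candidates)
```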

awgu commented Jul 18, 2024

For fun:

- Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, `reshard_after_forward=True`, using @yifuwang's hack (0.5 ms sleep), offloading 2 large FFN activations per transformer block
- Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, `reshard_after_forward=False`, offloading 2 large FFN activations per transformer block

Note how `reshard_after_forward=False` with this 2-FFN-activation offloading dominates `reshard_after_forward=True` without offloading/AC, since it has higher WPS but lower memory (a sketch of how offloading could be scoped to the FFN follows below).
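As an illustration, offloading could be scoped to just the feed-forward part of each transformer block with module hooks, so that only the FFN's saved activations reach the pack hook. This reuses the hypothetical `OffloadContext` sketch from the description above, and the `layers`/`feed_forward` module names are assumptions about the Llama model definition.

```python
# Hypothetical sketch: enter the offload context only around each block's FFN forward,
# so only tensors saved there are candidates for offloading.
import torch


def apply_ffn_offloading(model: torch.nn.Module):
    contexts = []
    for block in model.layers.children():
        ctx = OffloadContext()  # from the sketch in the PR description above
        contexts.append(ctx)

        def enter(module, args, _ctx=ctx):
            _ctx.__enter__()  # push the pack/unpack hooks for this forward

        def leave(module, args, output, _ctx=ctx):
            _ctx.__exit__(None, None, None)  # pop them once the FFN forward is done

        block.feed_forward.register_forward_pre_hook(enter)
        block.feed_forward.register_forward_hook(leave)
    return contexts
```

Combined with a size predicate like the one sketched under Known Problems, this would approximate offloading only the large FFN activations per block.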

awgu commented Jul 19, 2024

@yifuwang made the good point that there may be interference with inter-node collectives, since they also use PCIe to send/recv data to/from the NIC, competing with the D2H/H2D activation copies for bandwidth. The testing so far was only intra-node due to lack of compute resources.

tianyu-l pushed a commit that referenced this pull request Aug 16, 2024
ghstack-source-id: 1f53901f927b56c0ff58b81f853e6969cf348b84
Pull Request resolved: #467