[Do not review] Activation offloading #467

Draft · wants to merge 3 commits into base: gh/awgu/8/base

Commits on Jul 18, 2024

  1. [Do not review] Activation offloading

    [ghstack-poisoned]
    awgu committed Jul 18, 2024
    Commit SHA: f64b8fa
  2. Update on "[Do not review] Activation offloading"

    [ghstack-poisoned]
    awgu committed Jul 18, 2024
    Commit SHA: d42fc6a
  3. Update on "[Do not review] Activation offloading"

    **Current UX**
    - We use a `saved_tensors_hooks` context manager, which should be wrapped around `module.forward`. The context lets us override pack and unpack hooks that are called when saving an activation for backward and using an activation in backward, respectively. See the [tutorial](https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html) for more info.
    - We expose two main methods for the user from the context: `wait_for_d2h` and `copy_h2d_async`.
        - By default, the D2H copies for offloading are async and use pinned memory. The user must call `wait_for_d2h` to wait on the D2H copies and free the device memory. This should be done after the compute that the copies are meant to overlap with has been issued.
        - By default, the H2D copies are sync. The user must call `copy_h2d_async` to prefetch the H2D copies asynchronously. This should be done before the compute that the copies are meant to overlap with is issued.
        - We show an example of this in `apply_ac` in `parallelize_llama.py` using module hooks (see the sketches after this list).
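
    Below is a minimal, hedged sketch of what such an offloading context could look like. It is **not** the code in this PR: the class name `OffloadActivations`, the per-tensor bookkeeping, and the `offload_stream` attribute are illustrative assumptions; only the `saved_tensors_hooks` mechanism and the `wait_for_d2h` / `copy_h2d_async` method names come from the description above.

    ```python
    import torch
    from torch.autograd.graph import saved_tensors_hooks


    class OffloadActivations(saved_tensors_hooks):
        """Sketch: offload saved activations to pinned CPU memory via pack/unpack hooks."""

        def __init__(self):
            self.offload_stream = torch.cuda.Stream()  # side stream for the D2H/H2D copies
            self.d2h_events = []   # events marking completion of the async D2H copies
            self.gpu_tensors = {}  # tensor_id -> original GPU tensor, held until wait_for_d2h
            self.cpu_tensors = {}  # tensor_id -> (pinned CPU copy, original device)
            self.prefetched = {}   # tensor_id -> GPU tensor brought back by copy_h2d_async
            self.next_id = 0
            super().__init__(self._pack, self._unpack)

        def _pack(self, tensor):
            if not tensor.is_cuda:
                return tensor  # only offload GPU activations
            tensor_id, self.next_id = self.next_id, self.next_id + 1
            cpu_tensor = torch.empty(tensor.size(), dtype=tensor.dtype, device="cpu", pin_memory=True)
            # Async D2H copy in the offload stream, ordered after the producing compute.
            self.offload_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self.offload_stream):
                cpu_tensor.copy_(tensor, non_blocking=True)
                event = torch.cuda.Event()
                event.record()
            self.d2h_events.append(event)
            self.gpu_tensors[tensor_id] = tensor  # keep device memory alive until the copy is done
            self.cpu_tensors[tensor_id] = (cpu_tensor, tensor.device)
            return tensor_id

        def wait_for_d2h(self):
            # CPU-wait on the outstanding D2H copies, then drop the GPU tensors to free device memory.
            for event in self.d2h_events:
                event.synchronize()
            self.d2h_events.clear()
            self.gpu_tensors.clear()

        def copy_h2d_async(self):
            # Prefetch all offloaded tensors back to their device in the offload stream.
            with torch.cuda.stream(self.offload_stream):
                for tensor_id, (cpu_tensor, device) in self.cpu_tensors.items():
                    if tensor_id not in self.prefetched:
                        self.prefetched[tensor_id] = cpu_tensor.to(device, non_blocking=True)

        def _unpack(self, packed):
            if isinstance(packed, torch.Tensor):
                return packed  # was never offloaded
            if packed in self.prefetched:
                # The compute stream must order after the async H2D copy.
                torch.cuda.current_stream().wait_stream(self.offload_stream)
                gpu_tensor = self.prefetched.pop(packed)
                gpu_tensor.record_stream(torch.cuda.current_stream())  # allocator safety across streams
                del self.cpu_tensors[packed]
                return gpu_tensor
            cpu_tensor, device = self.cpu_tensors.pop(packed)
            return cpu_tensor.to(device)  # sync H2D copy by default
    ```

    A similarly hedged sketch of the module-hook wiring (the real example lives in `apply_ac` in `parallelize_llama.py`; the helper below is hypothetical). Prefetching block `i-1` from block `i`'s full pre-backward hook is one plausible wiring that yields the `copy_h2d_async` for layer `i-1` -> layer `i` backward compute issue order discussed under "Known Problems":

    ```python
    def attach_offloading(blocks):
        # One context per transformer block (hypothetical helper, not torchtitan code).
        contexts = [OffloadActivations() for _ in blocks]

        for i, block in enumerate(blocks):
            def pre_forward(module, args, i=i):
                contexts[i].__enter__()                # install pack/unpack hooks for this block

            def post_forward(module, args, output, i=i):
                contexts[i].__exit__(None, None, None)
                contexts[i].wait_for_d2h()             # block i's forward compute is already issued

            def pre_backward(module, grad_output, i=i):
                if i > 0:
                    contexts[i - 1].copy_h2d_async()   # overlap block i-1's H2D with block i's backward

            block.register_forward_pre_hook(pre_forward)
            block.register_forward_hook(post_forward)
            block.register_full_backward_pre_hook(pre_backward)
        return contexts
    ```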
    
    **Known Problems**
    - **!** Conflict with `split_with_sizes_copy`'s H2D copy (specific to FSDP2):
        - FSDP2's all-gather copy-out uses `split_with_sizes_copy`, which first issues a mini-H2D copy to send metadata needed for the main copy.
        - When the CPU issue order is `copy_h2d_async` for layer `i-1` -> `split_with_sizes_copy` for layer `i` -> layer `i` backward compute, the mini-H2D copy for `split_with_sizes_copy` for layer `i` can get serialized to run _after_ the H2D copies issued by `copy_h2d_async` for layer `i-1`, even though they run in different streams. This prevents the `copy_h2d_async` copies for layer `i-1` from overlapping with layer `i` backward compute.
        - For now, this can be worked around with `reshard_after_forward=False`.
        - Trick/hack from yifuwang: sleep 1 ms in the `offload_stream` before un-offloading (https://fburl.com/perfdoctor/ms47gqvp), which lets the `split_with_sizes_copy` H2D copy be prioritized (see the sketch after this list).
        - The CPU issue order of `copy_h2d_async` for layer `i-1` -> `split_with_sizes_copy` for layer `i` -> layer `i` backward compute comes from running the `copy_h2d_async` for layer `i-1` using a module full pre-backward hook.
    - **!** If the user offloads too many activations, the program can become slow and/or freeze. Further, the first few iterations are slow due to `cudaHostAlloc` calls warming up the CPU caching allocator. This warmup might be brittle if other parts of the program (e.g. checkpointing) also use pinned memory. If we do not `gc.collect()` every iteration, the pinned memory does not seem to be freed, so the allocator does not reuse it in subsequent iterations.
    - **!** We do not have a good way to apply a predicate to decide which activation tensors to offload. With the pack hook API, we only see the tensor, not any other information like which op constructed the tensor.
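
    A hedged sketch of the yifuwang sleep hack referenced above, assuming the `OffloadActivations` sketch from the "Current UX" section. `torch.cuda._sleep` is a private PyTorch helper that enqueues a spin kernel for a given number of GPU clock cycles, so the cycle count below is an illustrative assumption (roughly 1 ms on a ~1 GHz clock) that must be re-calibrated per GPU:

    ```python
    import torch

    APPROX_ONE_MS_CYCLES = 1_000_000  # assumption: tune to the target GPU's clock rate

    def copy_h2d_async_with_delay(ctx):
        # Busy-wait ~1 ms in the offload stream before un-offloading so that the small
        # metadata H2D copy from FSDP2's split_with_sizes_copy can be scheduled first.
        with torch.cuda.stream(ctx.offload_stream):
            torch.cuda._sleep(APPROX_ONE_MS_CYCLES)
        ctx.copy_h2d_async()
    ```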
    
    **Examples**
    - Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, `reshard_after_forward=False`:
        - Trace: https://fburl.com/perfdoctor/r1yf0lqf
        - Reserved memory: 65.67 GiB (69.10%)
        - WPS: 5,294, MFU: 31.00%
    - Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, `reshard_after_forward=True`:
        - Trace: https://fburl.com/perfdoctor/qbhr98az
        - Reserved memory: 54.01 GiB (56.83%)
        - WPS: 4,085, MFU: 23.92% (mainly because the H2Ds in backward are not overlapped)
        - If we use yifuwang's trick, we can get WPS: 5,073, MFU: 29.70% without changing reserved memory.
    - Baseline Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, `reshard_after_forward=True`:
        - Reserved memory: 78.38 GiB (82.48%)
        - WPS: 6,341, MFU: 37.13%
    
    [ghstack-poisoned]
    awgu committed Jul 18, 2024
    Commit SHA: 9dc418a