**Current UX**
- We use a `saved_tensors_hooks` context manager, which should wrap `module.forward`. The context lets us override the pack and unpack hooks that are called when saving an activation for backward and when using that activation in backward, respectively. See the [tutorial](https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html) for more info.
- The context exposes two main methods to the user: `wait_for_d2h` and `copy_h2d_async`.
- By default, the D2H copies for offloading are async and use pinned memory. The user must call `wait_for_d2h` to wait on the D2H copies and free the device memory. This should be done after the compute that we want to overlap with has been issued.
- By default, the H2D copies are sync. The user must call `copy_h2d_async` to instead prefetch the H2D copies asynchronously. This should be done before the compute that we want to overlap with is issued.
- We show an example of this in `apply_ac` in `parallelize_llama.py` using module hooks; a simplified sketch of the overall idea is included below.
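
For illustration only, here is a minimal sketch of how such a context could be put together. `ActivationOffloadContext` and all of its internals are hypothetical and much simpler than the actual implementation; the point is just to show how `saved_tensors_hooks`, a side stream, pinned CPU buffers, `wait_for_d2h`, and `copy_h2d_async` fit together:

```python
import torch
from torch.autograd.graph import saved_tensors_hooks


class ActivationOffloadContext:
    """Hypothetical sketch of an activation-offloading context (not the actual implementation)."""

    def __init__(self):
        self.offload_stream = torch.cuda.Stream()
        self._saved = {}       # handle -> (pinned CPU copy, original GPU tensor or None)
        self._prefetched = {}  # handle -> GPU tensor prefetched by copy_h2d_async
        self._next_handle = 0

    def _pack(self, tensor):
        if not tensor.is_cuda:
            return tensor  # only offload CUDA tensors
        handle = self._next_handle
        self._next_handle += 1
        cpu_copy = torch.empty(
            tensor.shape, dtype=tensor.dtype, device="cpu", pin_memory=True
        )
        # Async D2H copy on the offload stream, ordered after the producer
        # kernel that was queued on the current (compute) stream.
        self.offload_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.offload_stream):
            cpu_copy.copy_(tensor, non_blocking=True)
        # Keep the GPU tensor alive until wait_for_d2h so the copy reads valid memory.
        self._saved[handle] = (cpu_copy, tensor)
        return handle

    def wait_for_d2h(self):
        # Call after the compute to overlap with has been issued: wait on the
        # D2H copies and drop the GPU references so that memory can be freed.
        self.offload_stream.synchronize()
        for handle, (cpu_copy, _) in self._saved.items():
            self._saved[handle] = (cpu_copy, None)

    def copy_h2d_async(self):
        # Call before the compute to overlap with is issued: prefetch the
        # offloaded tensors back to the GPU asynchronously.
        with torch.cuda.stream(self.offload_stream):
            for handle, (cpu_copy, _) in self._saved.items():
                self._prefetched[handle] = cpu_copy.to("cuda", non_blocking=True)

    def _unpack(self, packed):
        if isinstance(packed, torch.Tensor):
            return packed  # this tensor was never offloaded
        if packed not in self._prefetched:
            # No prefetch was requested: fall back to a blocking H2D copy.
            return self._saved[packed][0].to("cuda")
        # Make the compute stream wait for the prefetch copy, and tell the
        # caching allocator that the buffer (allocated on the offload stream)
        # is now used by the compute stream.
        torch.cuda.current_stream().wait_stream(self.offload_stream)
        gpu_copy = self._prefetched.pop(packed)
        gpu_copy.record_stream(torch.cuda.current_stream())
        return gpu_copy

    def hooks(self):
        return saved_tensors_hooks(self._pack, self._unpack)
```

A possible (again hypothetical) usage for a single transformer block; the PR wires the equivalent calls up via module hooks in `apply_ac`:

```python
ctx = ActivationOffloadContext()
with ctx.hooks():
    out = block(inp)        # saved activations are offloaded as they are packed
mid = next_block(out)       # issue the compute we want the D2H copies to overlap with
ctx.wait_for_d2h()          # then wait on the copies and free the device memory

loss = mid.sum()
ctx.copy_h2d_async()        # prefetch before the backward compute is issued
loss.backward()
```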
**Known Problems**
- **!** Conflict with `split_with_sizes_copy`'s H2D copy (specific to FSDP2):
- FSDP2's all-gather copy-out uses `split_with_sizes_copy`, which first issues a mini-H2D copy to send metadata needed for the main copy.
- When the CPU issue order is `copy_h2d_async` for layer `i-1` -> `split_with_sizes_copy` for layer `i` -> layer `i` backward compute, the mini-H2D copy for `split_with_sizes_copy` for layer `i` can get serialized to run _after_ the H2D copies from `copy_h2d_async` for layer `i-1`, even though they run in different streams. This prevents the `copy_h2d_async` for layer `i-1` from overlapping with layer `i`'s backward compute.
- For now, this can be worked around with `reshard_after_forward=False`.
- Trick/hack from yifuwang: sleep for 1 ms in the `offload_stream` before un-offloading (https://fburl.com/perfdoctor/ms47gqvp), which allows the `split_with_sizes_copy` H2D copy to be prioritized (see the sketch below).
- The CPU issue order of `copy_h2d_async` for layer `i-1` -> `split_with_sizes_copy` for layer `i` -> layer `i` backward compute comes from running the `copy_h2d_async` for layer `i-1` using a module full pre-backward hook.
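
A rough sketch of that sleep workaround, reusing the hypothetical `ActivationOffloadContext` from above. Note that `torch.cuda._sleep` is a private helper that enqueues a spin kernel for a given number of cycles, so the count corresponding to ~1 ms depends on the GPU clock rate:

```python
import torch

def copy_h2d_async_with_delay(ctx, spin_cycles: int = 1_000_000):
    # Enqueue a short spin kernel on the offload stream before the un-offloading
    # H2D copies so that FSDP2's split_with_sizes_copy mini-H2D copy (issued on
    # another stream) is not serialized behind them.
    with torch.cuda.stream(ctx.offload_stream):
        torch.cuda._sleep(spin_cycles)  # cycle count is illustrative, not calibrated to 1 ms
    ctx.copy_h2d_async()
```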
- **!** If the user offloads too many activations, the program can become slow and/or freeze. Further, the first few iterations are slow due to `cudaHostAlloc` calls warming up the CPU caching allocator, which might be brittle if other parts of the program (e.g. checkpointing) also use pinned memory. If we do not call `gc.collect()` every iteration, the pinned memory does not seem to be freed, so the allocator does not reuse it in subsequent iterations.
- **!** We do not have a good way to apply a predicate to decide which activation tensors to offload. With the pack hook API, we only see the saved tensor itself, not any other information such as which op constructed it (see the sketch below).
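
For now, the closest thing to a predicate is filtering on properties of the saved tensor inside the pack hook, e.g. a size threshold. A hypothetical variant of the `_pack` sketch above (`_offload_to_cpu` stands in for the D2H path shown earlier):

```python
MIN_OFFLOAD_BYTES = 1 << 20  # illustrative threshold: skip tensors smaller than 1 MiB

def _pack(self, tensor):
    # The pack hook only receives the saved tensor, so the predicate can only
    # look at tensor properties (size, dtype, device, ...), not at which op or
    # module produced it.
    if not tensor.is_cuda or tensor.numel() * tensor.element_size() < MIN_OFFLOAD_BYTES:
        return tensor  # keep small / non-CUDA tensors where they are
    return self._offload_to_cpu(tensor)  # hypothetical: same D2H path as in the sketch above
```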
**Examples**
- Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, `reshard_after_forward=False`:
- Trace: https://fburl.com/perfdoctor/r1yf0lqf
- Reserved memory: 65.67 GiB (69.10%)
- WPS: 5,294, MFU: 31.00%
- Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, `reshard_after_forward=True`:
- Trace: https://fburl.com/perfdoctor/qbhr98az
- Reserved memory: 54.01 GiB (56.83%)
- WPS: 4,085, MFU: 23.92% (mainly because the H2D copies in backward are not overlapped)
- If we use yifuwang's trick, we can get WPS: 5,073, MFU: 29.70% without changing reserved memory
- Baseline Llama3-8B: DP=8, local batch size 1, sequence length 8192, no AC, `reshard_after_forward=True`:
- Reserved memory: 78.38 GiB (82.48%)
- WPS: 6,341, MFU: 37.13%