From 0702ef3c0ca8f6e81619a85c54bf145148aa760a Mon Sep 17 00:00:00 2001
From: Jane Xu
Date: Fri, 1 Nov 2024 13:33:29 -0700
Subject: [PATCH] Use existing older API, add test case

---
 .../training/test_activation_offloading.py   | 99 ++++++++++++++++++-
 torchtune/training/_activation_offloading.py | 65 ++++++------
 2 files changed, 125 insertions(+), 39 deletions(-)

diff --git a/tests/torchtune/training/test_activation_offloading.py b/tests/torchtune/training/test_activation_offloading.py
index 5d4c968e96..8cb3c38bed 100644
--- a/tests/torchtune/training/test_activation_offloading.py
+++ b/tests/torchtune/training/test_activation_offloading.py
@@ -10,6 +10,8 @@
 from torch import nn
 from torchtune.training import OffloadActivations
 
+NUM_GPU_CYCLES_IN_ONE_SEC = 2000000000  # 2e9 is ~1s worth of GPU cycles
+
 
 @gpu_test(gpu_count=1)
 @pytest.mark.parametrize("use_streams", [True, False])
@@ -46,7 +48,8 @@ def test_offloading_is_same_as_without(use_streams) -> None:
 def test_offloading_works_with_view_outputs() -> None:
     """
     This test is quite contrived but tests against a very obscure situation where
-    any of the outputs of a backward node are a view of the unpacked tensor.
+    any of the outputs of a backward node are a view of the unpacked tensor. (See
+    the first line item under Note: [Track views of the unpacked]).
 
     We want to ensure that if an unpacked tensor may be used later that we do not
     free it too early.
@@ -98,7 +101,7 @@ def forward(ctx, activation):
 
         @staticmethod
         def backward(ctx, viewed_activation):
-            torch.cuda._sleep(2000000000)  # 2e9 is ~1s worth of GPU cycles
+            torch.cuda._sleep(NUM_GPU_CYCLES_IN_ONE_SEC)
             return viewed_activation == 1
 
     class InspectEarlierActivation(torch.autograd.Function):
@@ -129,3 +132,95 @@ def fwd(t):
     # delete the fwd stash to avoid our peek-in-fwd-stash heuristic in the bwd
     ctx.fwd_stash = {}
     loss_c.backward()
+
+
+def test_offloading_works_with_view_ac_cached_buffers() -> None:
+    """
+    Similar to test_offloading_works_with_view_outputs, but for when AC stashes
+    a view of the unpacked tensor. See the second line item under Note: [Track
+    views of the unpacked].
+
+    For details on how the following custom autograd function was contrived,
+    please see the image attached to the PR description in #1936. The visual
+    is more helpful than me trying to write a blob of text here.
+    """
+
+    class A(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, ones):
+            ctx.save_for_backward(ones * 5)  # corruptedly saving 5s
+            return ones
+
+        @staticmethod
+        def backward(ctx, activation_is_ones):
+            fives = ctx.saved_tensors[0]
+            assert torch.all(activation_is_ones)
+            return activation_is_ones
+
+    class B(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, ones):
+            ctx.save_for_backward(ones.clone())
+            return ones.clone()  # important, a view of 1s will be saved in C
+
+        @staticmethod
+        def backward(ctx, activation_is_ones):
+            saved_tensor = ctx.saved_tensors[0]
+            return activation_is_ones.clone()
+
+    class C(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, ones):
+            ctx.save_for_backward(ones.t().t())
+            return ones.clone()
+
+        @staticmethod
+        def backward(ctx, grad):
+            saved_tensor = ctx.saved_tensors[0]
+            return saved_tensor == 1
+
+    class D(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, ones):
+            ctx.save_for_backward(torch.rand_like(ones))
+            return torch.rand_like(ones)
+
+        @staticmethod
+        def backward(ctx, grad):
+            saved_tensor = ctx.saved_tensors[0]
+            torch.cuda._sleep(NUM_GPU_CYCLES_IN_ONE_SEC)
+            return torch.rand_like(grad)
+
+    class E(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, ones):
+            ctx.save_for_backward(torch.rand_like(ones))
+            return torch.rand_like(ones)
+
+        @staticmethod
+        def backward(ctx, grad):
+            # It doesn't matter what E saves, but it needs to save something
+            # just to trigger AC recompute to fill in this tensor.
+            saved_tensor = ctx.saved_tensors[0]
+            return torch.rand_like(grad)
+
+    def checkpointed_region(b):
+        c = C.apply(b)
+        d = D.apply(c)
+        return E.apply(d)
+
+    def fwd(t):
+        a = A.apply(t)
+        b = B.apply(a)
+        e = torch.utils.checkpoint.checkpoint(
+            checkpointed_region, b, use_reentrant=False
+        )
+        return e.sum()
+
+    tensor = torch.ones(256, 1024, device="cuda", requires_grad=True)
+    ctx = OffloadActivations(use_streams=True)
+    with ctx:
+        loss = fwd(tensor)
+    # delete the fwd stash to avoid our peek-in-fwd-stash heuristic in the bwd
+    ctx.fwd_stash = {}
+    loss.backward()
diff --git a/torchtune/training/_activation_offloading.py b/torchtune/training/_activation_offloading.py
index e2b0f705a6..d81a814c8d 100644
--- a/torchtune/training/_activation_offloading.py
+++ b/torchtune/training/_activation_offloading.py
@@ -283,49 +283,42 @@ def wait_and_del_remaining_references() -> None:
             self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor
 
             # Note: [Track views of the unpacked]
-            # Why do we get the use count of the unpacked tensor here? We want an initial
-            # count to compare to later, during the post-hook of the backward node, when we
-            # need to decide whether we're allowed to free the tensor yet. In what obscure
-            # cases must we delay freeing the tensor (and thus call record_stream)?
-            # 1. Any of the outputs of the backward node is a view of the unpacked tensor.
-            # 2. In the case that this unpacked tensor will be used in a checkpointed
-            #    region, if one of the recomputed saved tensors ends up as a view of the
-            #    unpacked tensor.
-            # 3. The user abuses the system somehow and manually relies on the unpacked
-            #    tensor to exist after the backward node has executed.
-            #
-            # Side note: the use_count() API is new, so this check is only valid for
-            # torch versions after the API has been introduced.
-            if torch.__version__ > "2.6.0.dev20241101":
-                storage_refcount = (
-                    maybe_gpu_tensor.untyped_storage().use_count()
-                )
+            # Why do we get the use count of the unpacked tensor here? We want an
+            # initial count to compare to later, during the post-hook of the
+            # backward node, when we need to decide whether we're allowed to free
+            # the tensor yet. In what obscure cases must we delay freeing the
+            # tensor (and thus call record_stream)?
+            # 1. Any of the outputs of the backward node is a view of the unpacked
+            #    tensor.
+            # 2. In the case that this unpacked tensor will be used in a
+            #    checkpointed region, if one of the recomputed saved tensors ends
+            #    up as a view of the unpacked tensor.
+            # 3. The user abuses the system somehow and manually relies on the
+            #    unpacked tensor to exist after the backward node has executed.
+            storage_refcount = torch._C._storage_Use_Count(
+                maybe_gpu_tensor.untyped_storage()._cdata
+            )
 
             def hook(outputs, inputs):
                 # create events for the current node inputs/outputs if they were streamed in
                 if brought_back_from_cpu:
                     # See Note: [Track views of the unpacked]
-                    # IF any of the outputs is a view of the tensor, OR if a view of the tensor has been saved
-                    # as a part of checkpoint's recompute process, OR the user has abusedly incurred a reference
-                    # on the unpacked tensor, THEN the tensor might be used later and we cannot presume to
-                    # delete it after only the current node is done! So we use our frenemy, record_stream, to
-                    # ensure the Tensor stays unmessed with until it's done getting used in the compute stream
-                    # (s0 here). Note that the con here is we introduce non-deterministic (thus higher) memory
-                    # usage, but this case should not happen often.
+                    # IF any of the outputs is a view of the tensor, OR if a view of
+                    # the tensor has been saved as a part of checkpoint's recompute
+                    # process, OR the user has abusively incurred a reference on the
+                    # unpacked tensor, THEN the tensor might be used later and we
+                    # cannot presume to delete it after only the current node is
+                    # done! So we use our frenemy, record_stream, to ensure the
+                    # Tensor stays unmessed with until it's done getting used in the
+                    # compute stream (s0 here). Note that the con here is we introduce
+                    # non-deterministic (thus higher) memory usage, but this case
+                    # should not happen often.
                     unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
-
-                    storage_might_be_used_later = (
-                        unpacked_tensor.untyped_storage().use_count()
-                        > storage_refcount
-                        if torch.__version__ > "2.6.0.dev20241101"
-                        else any(
-                            o.untyped_storage() is unpacked_tensor.untyped_storage()
-                            for o in outputs
-                            if o is not None
-                        )
-                    )
-
-                    if storage_might_be_used_later:
+                    if (
+                        torch._C._storage_Use_Count(
+                            unpacked_tensor.untyped_storage()._cdata
+                        )
+                        > storage_refcount
+                    ):
                         unpacked_tensor.record_stream(self.s0)
                         del self.bwd_tensor_stash[unpack_tensor_id]
                     else:
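
Below is a minimal, standalone sketch (not part of the patch above) of the storage use-count heuristic that Note: [Track views of the unpacked] relies on, using the same private torch._C._storage_Use_Count API the patch uses. Creating a view bumps the untyped storage's use count above the baseline captured at unpack time, which is the condition the post-backward hook uses to fall back to record_stream() instead of freeing the offloaded buffer right away. The tensor names here are illustrative only.

    import torch

    t = torch.ones(4, 4)
    storage = t.untyped_storage()

    # Baseline use count, analogous to `storage_refcount` captured at unpack time.
    baseline = torch._C._storage_Use_Count(storage._cdata)

    view = t.view(16)  # a view shares the same untyped storage as `t`

    # The extra reference held by `view` raises the count above the baseline,
    # i.e. the signal that the storage might still be used later.
    assert torch._C._storage_Use_Count(storage._cdata) > baseline

    # Dropping the view returns the count to the baseline, so it would be safe to free.
    del view
    assert torch._C._storage_Use_Count(storage._cdata) == baseline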