diff --git a/torchtune/training/_activation_offloading.py b/torchtune/training/_activation_offloading.py
index a92f18f57..cdd77aba2 100644
--- a/torchtune/training/_activation_offloading.py
+++ b/torchtune/training/_activation_offloading.py
@@ -292,16 +292,19 @@ def wait_and_del_remaining_references() -> None:
 
                 # This hook will be called when the recompute logic of the AC
                 # caches a recomputed saved tensor.
-                storage_id = id(maybe_gpu_tensor.untyped_storage())
-
                 def gather_views_hook(recomputed_tensor):
-                    if id(recomputed_tensor.untyped_storage()) == storage_id:
+                    if (
+                        recomputed_tensor.untyped_storage()
+                        is maybe_gpu_tensor.untyped_storage()
+                    ):
                         recomputed_tensors_that_are_views.append(
                             recomputed_tensor.data_ptr()
                         )
 
-                torch.utils.checkpoint._register_checkpoint_saved_tensor_hook(
-                    gather_views_hook
+                hook_handle = (
+                    torch.utils.checkpoint._register_checkpoint_saved_tensor_hook(
+                        gather_views_hook
+                    )
                 )
 
                 def hook(outputs, inputs):
@@ -314,6 +317,7 @@ def hook(outputs, inputs):
                         # in the compute stream (s0 here). Note that the con here is we introduce non-deterministic
                         # memory usage, but this case should not happen often.
                         unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
+                        hook_handle.remove()
                         if any(
                             o.untyped_storage() is unpacked_tensor.untyped_storage()
                             for o in outputs