Commit 83e6b51

Be readable AND don't keep the memory forever
janeyx99 committed Oct 28, 2024
1 parent 862d462 commit 83e6b51
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions torchtune/training/_activation_offloading.py
@@ -292,16 +292,19 @@ def wait_and_del_remaining_references() -> None:
 
             # This hook will be called when the recompute logic of the AC
             # caches a recomputed saved tensor.
-            storage_id = id(maybe_gpu_tensor.untyped_storage())
-
             def gather_views_hook(recomputed_tensor):
-                if id(recomputed_tensor.untyped_storage()) == storage_id:
+                if (
+                    recomputed_tensor.untyped_storage()
+                    is maybe_gpu_tensor.untyped_storage()
+                ):
                     recomputed_tensors_that_are_views.append(
                         recomputed_tensor.data_ptr()
                     )
 
-            torch.utils.checkpoint._register_checkpoint_saved_tensor_hook(
-                gather_views_hook
+            hook_handle = (
+                torch.utils.checkpoint._register_checkpoint_saved_tensor_hook(
+                    gather_views_hook
+                )
             )
 
             def hook(outputs, inputs):
@@ -314,6 +317,7 @@ def hook(outputs, inputs):
                     # in the compute stream (s0 here). Note that the con here is we introduce non-deterministic
                     # memory usage, but this case should not happen often.
                     unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
+                    hook_handle.remove()
                     if any(
                         o.untyped_storage() is unpacked_tensor.untyped_storage()
                         for o in outputs
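The second hunk is what the "don't keep the memory forever" half of the title refers to: the registration now returns a handle, and hook_handle.remove() is called once the stashed tensor has been retrieved, so the gather_views_hook closure (which captures maybe_gpu_tensor) is not left registered for the rest of training. Below is a minimal sketch of the same register-then-remove pattern using the public torch.Tensor.register_hook API instead of the private checkpoint hook above; the names big_buffer and grad_hook are illustrative only.

import torch

# Minimal sketch (not the torchtune code) of the register-then-remove pattern,
# using the public torch.Tensor.register_hook API.

x = torch.randn(4, requires_grad=True)
big_buffer = torch.randn(1024, 1024)  # stand-in for memory we do not want pinned forever

def grad_hook(grad):
    # The closure captures big_buffer; while the hook stays registered,
    # that capture keeps big_buffer alive.
    return grad + big_buffer.mean()

handle = x.register_hook(grad_hook)  # RemovableHandle, analogous to hook_handle above

x.sum().backward()  # the hook fires during this backward pass

# Once the hook has done its job, remove it so the registration no longer
# holds a reference to the closure (mirroring hook_handle.remove() in the diff).
handle.remove()

The "Be readable" half is the first hunk: comparing the two storages directly with "is" replaces stashing id(maybe_gpu_tensor.untyped_storage()) in a separate storage_id variable and comparing integer ids.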
