A more encompassing fix for offloading + ac
janeyx99 committed Oct 31, 2024
1 parent 83e6b51 commit 63e3a08
Showing 1 changed file with 38 additions and 35 deletions.

torchtune/training/_activation_offloading.py
@@ -282,47 +282,50 @@ def wait_and_del_remaining_references() -> None:
 # Stash the tensor to keep memory alive until compute stream is complete
 self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor

-# Note: [AC interaction]
-# Now, in the case that this unpacked tensor will be used in a
-# checkpointed region, there is a chance that one of the recomputed
-# saved tensors will be a view of the unpacked tensor. We need to
-# track this case so that we call record_stream on the unpacked
-# tensor when this happens instead of freeing too early.
-recomputed_tensors_that_are_views = []
-
-# This hook will be called when the recompute logic of the AC
-# caches a recomputed saved tensor.
-def gather_views_hook(recomputed_tensor):
-    if (
-        recomputed_tensor.untyped_storage()
-        is maybe_gpu_tensor.untyped_storage()
-    ):
-        recomputed_tensors_that_are_views.append(
-            recomputed_tensor.data_ptr()
-        )
-
-hook_handle = (
-    torch.utils.checkpoint._register_checkpoint_saved_tensor_hook(
-        gather_views_hook
-    )
-)
+# Note: [Track views of the unpacked]
+# Why do we get the use count of the unpacked tensor here? We want an initial
+# count to compare to later, during the post-hook of the backward node, when we
+# need to decide whether we're allowed to free the tensor yet. In what obscure
+# cases must we delay freeing the tensor (and thus call record_stream)?
+# 1. Any of the outputs of the backward node is a view of the unpacked tensor.
+# 2. In the case that this unpacked tensor will be used in a checkpointed
+#    region, if one of the recomputed saved tensors ends up as a view of the
+#    unpacked tensor.
+# 3. The user abuses the system somehow and manually relies on the unpacked
+#    tensor to exist after the backward node has executed.
+#
+# Side note: the use_count() API is new, so this check is only valid for
+# torch versions after the API has been introduced.
+if torch.__version__ > "2.6.0.dev20241101":
+    storage_refcount = maybe_gpu_tensor.untyped_storage().use_count()

 def hook(outputs, inputs):
     # create events for the current node inputs/outputs if they were streamed in
     if brought_back_from_cpu:
-        # if any of the outputs is a view of the tensor, OR if a view of the tensor has been saved
-        # as a part of checkpoint's recompute process, meaning the tensor might be used later,
-        # we cannot presume to delete it after only the current node is done! So we use our frenemy,
-        # record_stream, to ensure the Tensor stays unmessed with until it's done getting used
-        # in the compute stream (s0 here). Note that the con here is we introduce non-deterministic
-        # memory usage, but this case should not happen often.
+        # See Note: [Track views of the unpacked]
+        # IF any of the outputs is a view of the tensor, OR if a view of the tensor has been saved
+        # as a part of checkpoint's recompute process, OR the user has abusedly incurred a reference
+        # on the unpacked tensor, THEN the tensor might be used later and we cannot presume to
+        # delete it after only the current node is done! So we use our frenemy, record_stream, to
+        # ensure the Tensor stays unmessed with until it's done getting used in the compute stream
+        # (s0 here). Note that the con here is we introduce non-deterministic (thus higher) memory
+        # usage, but this case should not happen often.
         unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
-        hook_handle.remove()
-        if any(
-            o.untyped_storage() is unpacked_tensor.untyped_storage()
-            for o in outputs
-            if o is not None
-        ) or (len(recomputed_tensors_that_are_views) > 0):
+
+        storage_might_be_used_later = (
+            unpacked_tensor.untyped_storage().use_count() > storage_refcount
+            if torch.__version__ > "2.6.0.dev20241101"
+            else any(
+                o.untyped_storage() is unpacked_tensor.untyped_storage()
+                for o in outputs
+                if o is not None
+            )
+        )
+
+        if storage_might_be_used_later:
             unpacked_tensor.record_stream(self.s0)
             del self.bwd_tensor_stash[unpack_tensor_id]
         else:
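The new strategy hinges on UntypedStorage.use_count(), which counts live references to a tensor's underlying storage: any alias created between unpack time and the backward node's post-hook — a view among the node's outputs, a recomputed saved tensor from a checkpointed region, or a user-held reference — pushes the count above the baseline snapshotted at unpack time. Below is a minimal standalone sketch of both detection paths; the tensor names are illustrative, not from the commit, and the use_count() comparison assumes a torch build newer than the 2.6.0.dev20241101 gate above:

import torch

t = torch.randn(8)  # stands in for maybe_gpu_tensor

# New path: snapshot the storage use count at "unpack" time...
baseline = t.untyped_storage().use_count()

v = t.view(2, 4)  # any alias: a view output, a recomputed saved tensor, ...

# ...and later, a count above the baseline means some alias is still alive,
# so the storage must not be freed yet.
assert t.untyped_storage().use_count() > baseline

# Old fallback path: storage identity, which only catches aliases we can
# enumerate (here, the node's outputs) and so misses cases 2 and 3 from
# the note above.
other = torch.randn(8)
assert v.untyped_storage() is t.untyped_storage()
assert other.untyped_storage() is not t.untyped_storage()

This is what makes the fix "more encompassing": the refcount comparison catches any surviving alias, not just the ones the old checkpoint hook happened to observe.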
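When a surviving alias is detected, the hook reaches for record_stream instead of deleting the stash entry outright. The "frenemy" label fits the trade-off: record_stream tells the CUDA caching allocator that the tensor is still in flight on another stream, so its memory block is recycled only after that stream catches up — correct, but at the cost of allocator behavior that depends on runtime timing (the non-deterministic, typically higher, memory usage the comment warns about). A minimal sketch of the general pattern, independent of torchtune (stream and tensor names are made up; requires a CUDA device):

import torch

assert torch.cuda.is_available()

side_stream = torch.cuda.Stream()  # plays the role of the compute stream s0

# x is allocated on the default stream by the caching allocator.
x = torch.randn(1 << 20, device="cuda")

# ...but consumed on a different stream.
with torch.cuda.stream(side_stream):
    y = x * 2

# Without this, freeing x on the default stream could let the allocator hand
# x's memory block to a new allocation while side_stream is still reading it.
x.record_stream(side_stream)
del x  # safe: the block is reused only after side_stream passes this point

torch.cuda.synchronize()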
