Skip to content

Commit

Permalink
use torchao copy_
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
weifengpy committed May 1, 2024
1 parent cb3abb3 commit e68804a
Showing 1 changed file with 0 additions and 23 deletions.
23 changes: 0 additions & 23 deletions torchtune/utils/_register_nf4_dispatch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,3 @@ def clone(func, *args, **kwargs):
in precision.
"""
return to_nf4(args[0][0].get_original_weight())


@nf4_tensor_impl([torch.ops.aten.copy_.default])
def inplace_copy(func, *args, **kwargs):
"""
Performs an inplace copy of an incoming tensor into the tensor
being copied into. The inplace tensor is given by args[0][1] and the
tensor being copied into is given by args[0][0]. The copy is performed
by copying over all attributes. This method would have to be updated
if additional attributes are added to NF4Tensor.
"""
dest_tensor = args[0][0] # tensor we are inplace copying into
ref_tensor = to_nf4(
args[0][1].to(dest_tensor.device)
) # TODO check if nf4 tensor takes in device arg
dest_tensor.block_size = ref_tensor.block_size
dest_tensor.n_blocks = ref_tensor.n_blocks
dest_tensor.scaler_block_size = ref_tensor.scaler_block_size
dest_tensor.quantized_scalers = ref_tensor.quantized_scalers
dest_tensor.quantization_factor = ref_tensor.quantization_factor
dest_tensor.scaler_mean = ref_tensor.scaler_mean
dest_tensor.quantized_data = ref_tensor.quantized_data
dest_tensor.nf4 = ref_tensor.nf4

0 comments on commit e68804a

Please sign in to comment.