Skip to content

Commit

Permalink
use original copy_
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
weifengpy committed Apr 30, 2024
1 parent 925602c commit e36ab6c
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,18 +288,18 @@ def mm_default(func, *args, **kwargs):
aten.copy_.default,
]
)
def nf4_copy_(aten_op, args, kwargs=None):
assert len(args) == 2 and (kwargs is None or len(kwargs) == 0), "only support aten.copy_.default with 2 args"
original: NF4Tensor = args[0]
copy_in: torch.Tensor = args[1]
def copy_(func, *args, **kwargs):
assert len(args[0]) == 2 and len(kwargs) == 0, "only support aten.copy_.default with 2 args"
original: NF4Tensor = args[0][0]
copy_in: torch.Tensor = args[0][1]

# Base Case

if same_metadata(original, copy_in):
attrs, _ = original.__tensor_flatten__()
for attr in attrs:
inner_tensor_orig = getattr(original, attr)
inner_tensor_copy_in = getattr(copy_in, attr)
aten_op(inner_tensor_orig, inner_tensor_copy_in, **kwargs)
return original
original_tensors = original.__tensor_flatten__()[0]
for tensor_name in original_tensors:
getattr(original, tensor_name).copy_(getattr(copy_in, tensor_name))
return

# Convert Non NF4Tensor into NF4 for copy in
if not isinstance(copy_in, NF4Tensor):
Expand Down

0 comments on commit e36ab6c

Please sign in to comment.