From e36ab6c2061807861dcdb90269254ed3603609bc Mon Sep 17 00:00:00 2001 From: willfengg Date: Tue, 30 Apr 2024 15:50:03 -0700 Subject: [PATCH] use original copy_ Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchao/dtypes/nf4tensor.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py index 9ae3f87ec..06c985800 100644 --- a/torchao/dtypes/nf4tensor.py +++ b/torchao/dtypes/nf4tensor.py @@ -288,18 +288,18 @@ def mm_default(func, *args, **kwargs): aten.copy_.default, ] ) -def nf4_copy_(aten_op, args, kwargs=None): - assert len(args) == 2 and (kwargs is None or len(kwargs) == 0), "only support aten.copy_.default with 2 args" - original: NF4Tensor = args[0] - copy_in: torch.Tensor = args[1] +def copy_(func, *args, **kwargs): + assert len(args[0]) == 2 and len(kwargs) == 0, "only support aten.copy_.default with 2 args" + original: NF4Tensor = args[0][0] + copy_in: torch.Tensor = args[0][1] + + # Base Case if same_metadata(original, copy_in): - attrs, _ = original.__tensor_flatten__() - for attr in attrs: - inner_tensor_orig = getattr(original, attr) - inner_tensor_copy_in = getattr(copy_in, attr) - aten_op(inner_tensor_orig, inner_tensor_copy_in, **kwargs) - return original + original_tensors = original.__tensor_flatten__()[0] + for tensor_name in original_tensors: + getattr(original, tensor_name).copy_(getattr(copy_in, tensor_name)) + return # Convert Non NF4Tensor into NF4 for copy in if not isinstance(copy_in, NF4Tensor):