diff --git a/opacus/utils/tensor_utils.py b/opacus/utils/tensor_utils.py index 27a111d8..79eb7ec8 100644 --- a/opacus/utils/tensor_utils.py +++ b/opacus/utils/tensor_utils.py @@ -322,14 +322,15 @@ def filter_dilated_rows( kernel_rank = len(kernel_size) indices_to_keep = [ - list(range(0, dilated_kernel_size[i], dilation[i])) for i in range(kernel_rank) + torch.arange(0, dilated_kernel_size[i], dilation[i], device=tensor.device) + for i in range(kernel_rank) ] - tensor_np = tensor.numpy() - axis_offset = len(tensor.shape) - kernel_rank for dim in range(kernel_rank): - tensor_np = np.take(tensor_np, indices_to_keep[dim], axis=axis_offset + dim) + tensor = torch.index_select( + tensor, dim=axis_offset + dim, index=indices_to_keep[dim] + ) - return torch.Tensor(tensor_np) + return tensor