Skip to content

Commit

Permalink
add_npu_support for goup_norm (#10496)
Browse files Browse the repository at this point in the history
原来的实现中只考虑了cuda作为底层实现,而没考虑npu,mlu等其他硬件
  • Loading branch information
woaixiaoxiao authored May 20, 2024
1 parent b8c457c commit 57951e4
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion python/oneflow/nn/modules/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def group_norm(
), "The channels of input tensor must equal num_channels"

affine = weight is not None and bias is not None
if input.is_cuda:
if not input.is_cpu:
return flow._C.group_norm(input, weight, bias, affine, num_groups, eps)
else:
origin_shape = input.shape
Expand Down

0 comments on commit 57951e4

Please sign in to comment.