diff --git a/mmdet/models/layers/ema.py b/mmdet/models/layers/ema.py index bce503c4641..73a0ca67c28 100644 --- a/mmdet/models/layers/ema.py +++ b/mmdet/models/layers/ema.py @@ -63,4 +63,4 @@ def avg_func(self, averaged_param: Tensor, source_param: Tensor, """ momentum = (1 - self.momentum) * math.exp( -float(1 + steps) / self.gamma) + self.momentum - averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum) + averaged_param.lerp_(source_param, momentum)