diff --git a/ldm_patched/modules/model_sampling.py b/ldm_patched/modules/model_sampling.py index 00b3ac81b..c62212e62 100644 --- a/ldm_patched/modules/model_sampling.py +++ b/ldm_patched/modules/model_sampling.py @@ -82,7 +82,7 @@ def timestep(self, sigma): w = (low - log_sigma) / (low - high) w = w.clamp(0, 1) t = (1 - w) * low_idx + w * high_idx - return t.view(sigma.shape) + return t.view(sigma.shape).to(sigma.device) def sigma(self, timestep): t = torch.clamp(timestep.float().to(self.log_sigmas.device), min=0, max=(len(self.sigmas) - 1))