Skip to content

Commit

Permalink
implement align your steps scheduler (#726)
Browse files Browse the repository at this point in the history
Signed-off-by: blob42 <[email protected]>
  • Loading branch information
blob42 authored May 22, 2024
1 parent 62e60ad commit 49c3a08
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 3 deletions.
30 changes: 30 additions & 0 deletions ldm_patched/k_diffusion/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from scipy import integrate
import torch
import numpy as np
from torch import nn
import torchsde
from tqdm.auto import trange, tqdm
Expand Down Expand Up @@ -38,6 +39,35 @@ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
return append_zero(sigmas)

# align your steps
def get_sigmas_ays(n, sigma_min, sigma_max, is_sdxl=False, device='cpu'):
# https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
def loglinear_interp(t_steps, num_steps):
"""
Performs log-linear interpolation of a given array of decreasing numbers.
"""
xs = torch.linspace(0, 1, len(t_steps))
ys = torch.log(torch.tensor(t_steps[::-1]))

new_xs = torch.linspace(0, 1, num_steps)
new_ys = np.interp(new_xs, xs, ys)

interped_ys = torch.exp(torch.tensor(new_ys)).numpy()[::-1].copy()
return interped_ys

if is_sdxl:
sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029]
else:
# Default to SD 1.5 sigmas.
sigmas = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029]

if n != len(sigmas):
sigmas = np.append(loglinear_interp(sigmas, n), [0.0])
else:
sigmas.append(0.0)

return torch.FloatTensor(sigmas).to(device)


def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
"""Constructs a continuous VP noise schedule."""
Expand Down
6 changes: 4 additions & 2 deletions ldm_patched/modules/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,14 +662,16 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
samples = sampler.sample(model_wrap, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return model.process_latent_out(samples.to(torch.float32))

SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "ays"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]

def calculate_sigmas_scheduler(model, scheduler_name, steps):
def calculate_sigmas_scheduler(model, scheduler_name, steps, is_sdxl=False):
if scheduler_name == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
elif scheduler_name == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
elif scheduler_name == "ays":
sigmas = k_diffusion_sampling.get_sigmas_ays(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max), is_sdxl=is_sdxl)
elif scheduler_name == "normal":
sigmas = normal_scheduler(model, steps)
elif scheduler_name == "simple":
Expand Down
7 changes: 6 additions & 1 deletion modules_forge/forge_alter_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def __init__(self, sd_model, sampler_name, scheduler_name):
self.sampler_name = sampler_name
self.scheduler_name = scheduler_name
self.unet = sd_model.forge_objects.unet
self.model = sd_model

sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
super().__init__(sampler_function, sd_model, None)
Expand All @@ -20,7 +21,7 @@ def get_sigmas(self, p, steps):
sigmas = self.unet.model.model_sampling.sigma(timesteps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
else:
sigmas = calculate_sigmas_scheduler(self.unet.model, self.scheduler_name, steps)
sigmas = calculate_sigmas_scheduler(self.unet.model, self.scheduler_name, steps, is_sdxl=getattr(self.model, "is_sdxl", False))
return sigmas.to(self.unet.load_device)


Expand All @@ -34,9 +35,13 @@ def constructor(m):
samplers_data_alter = [
sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm', scheduler_name='normal'), ['ddpm'], {}),
sd_samplers_common.SamplerData('DDPM Karras', build_constructor(sampler_name='ddpm', scheduler_name='karras'), ['ddpm_karras'], {}),
sd_samplers_common.SamplerData('Euler AYS', build_constructor(sampler_name='euler', scheduler_name='ays'), ['euler_ays'], {}),
sd_samplers_common.SamplerData('Euler A Turbo', build_constructor(sampler_name='euler_ancestral', scheduler_name='turbo'), ['euler_ancestral_turbo'], {}),
sd_samplers_common.SamplerData('Euler A AYS', build_constructor(sampler_name='euler_ancestral', scheduler_name='ays'), ['euler_ancestral_ays'], {}),
sd_samplers_common.SamplerData('DPM++ 2M Turbo', build_constructor(sampler_name='dpmpp_2m', scheduler_name='turbo'), ['dpmpp_2m_turbo'], {}),
sd_samplers_common.SamplerData('DPM++ 2M AYS', build_constructor(sampler_name='dpmpp_2m', scheduler_name='ays'), ['dpmpp_2m_ays'], {}),
sd_samplers_common.SamplerData('DPM++ 2M SDE Turbo', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='turbo'), ['dpmpp_2m_sde_turbo'], {}),
sd_samplers_common.SamplerData('DPM++ 2M SDE AYS', build_constructor(sampler_name='dpmpp_2m_sde', scheduler_name='ays'), ['dpmpp_2m_sde_ays'], {}),
sd_samplers_common.SamplerData('LCM Karras', build_constructor(sampler_name='lcm', scheduler_name='karras'), ['lcm_karras'], {}),
sd_samplers_common.SamplerData('Euler SGMUniform', build_constructor(sampler_name='euler', scheduler_name='sgm_uniform'), ['euler_sgm_uniform'], {}),
sd_samplers_common.SamplerData('Euler A SGMUniform', build_constructor(sampler_name='euler_ancestral', scheduler_name='sgm_uniform'), ['euler_ancestral_sgm_uniform'], {}),
Expand Down

0 comments on commit 49c3a08

Please sign in to comment.