Commit: i

Update sd_samplers_common.py

Update sd_hijack.py

i

Update sd_models.py

Update sd_models.py

Update forge_loader.py

Update sd_models.py

i

Update sd_model.py

i

Update sd_models.py

Create sd_model.py

i

i

Update sd_models.py

i

Update sd_models.py

Update sd_models.py

i

i

Update sd_samplers_common.py

i

Update sd_models.py

Update sd_models.py

Update sd_samplers_common.py

Update sd_models.py

Update sd_models.py

Update sd_models.py

Update sd_models.py

Update sd_samplers_common.py

i

Update shared_options.py

Update prompt_parser.py
lllyasviel committed Jan 14, 2024
1 parent 2af39f1 commit 1c7da49
Showing 11 changed files with 102 additions and 175 deletions.
48 changes: 5 additions & 43 deletions modules/processing.py
@@ -627,44 +627,7 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):

     for i in range(batch.shape[0]):
         sample = decode_first_stage(model, batch[i:i + 1])[0]
-
-        if check_for_nans:
-            try:
-                devices.test_for_nans(sample, "vae")
-            except devices.NansException as e:
-                if shared.opts.auto_vae_precision_bfloat16:
-                    autofix_dtype = torch.bfloat16
-                    autofix_dtype_text = "bfloat16"
-                    autofix_dtype_setting = "Automatically convert VAE to bfloat16"
-                    autofix_dtype_comment = ""
-                elif shared.opts.auto_vae_precision:
-                    autofix_dtype = torch.float32
-                    autofix_dtype_text = "32-bit float"
-                    autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
-                    autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
-                else:
-                    raise e
-
-                if devices.dtype_vae == autofix_dtype:
-                    raise e
-
-                errors.print_error_explanation(
-                    "A tensor with all NaNs was produced in VAE.\n"
-                    f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
-                    f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
-                )
-
-                devices.dtype_vae = autofix_dtype
-                model.first_stage_model.to(devices.dtype_vae)
-                batch = batch.to(devices.dtype_vae)
-
-                sample = decode_first_stage(model, batch[i:i + 1])[0]
-
-        if target_device is not None:
-            sample = sample.to(target_device)
-
-        samples.append(sample)
+        samples.append(sample.to(target_device))
 
     return samples

@@ -940,8 +903,7 @@ def rescale_zero_terminal_snr_abar(alphas_cumprod):
                 p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
                 p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
 
-            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
-                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
+            samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
             if p.scripts is not None:
                 ps = scripts.PostSampleArgs(samples_ddim)
@@ -1255,7 +1217,7 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subs
             image = np.array(self.firstpass_image).astype(np.float32) / 255.0
             image = np.moveaxis(image, 2, 0)
             image = torch.from_numpy(np.expand_dims(image, axis=0))
-            image = image.to(shared.device, dtype=devices.dtype_vae)
+            image = image.to(shared.device, dtype=torch.float32)
 
             if opts.sd_vae_encode_method != 'Full':
                 self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
@@ -1339,7 +1301,7 @@ def save_intermediate(image, index):
                 batch_images.append(image)
 
             decoded_samples = torch.from_numpy(np.array(batch_images))
-            decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
+            decoded_samples = decoded_samples.to(shared.device, dtype=torch.float32)
 
         if opts.sd_vae_encode_method != 'Full':
             self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
@@ -1631,7 +1593,7 @@ def init(self, all_prompts, all_seeds, all_subseeds):
                 raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
 
         image = torch.from_numpy(batch_images)
-        image = image.to(shared.device, dtype=devices.dtype_vae)
+        image = image.to(shared.device, dtype=torch.float32)
 
         if opts.sd_vae_encode_method != 'Full':
             self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
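Note on the processing.py changes above: every tensor headed into the VAE is now cast to plain float32, and the NaN-autofix retry (which flipped devices.dtype_vae to bfloat16 or float32 and re-decoded) is gone; VAE precision is managed by the Forge backend from this commit on (see the vae_patcher wiring in modules/sd_models.py below). The simplified decode loop also leans on Tensor.to(None) being a no-op, which is what makes the old `if target_device is not None` guard unnecessary. A quick check of that assumption:

    import torch

    x = torch.randn(2, 3)
    assert x.to(None).dtype == x.dtype and x.to(None).device == x.device  # to(None) changes nothing
    assert x.to('cpu').device.type == 'cpu'  # explicit targets behave as before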
6 changes: 6 additions & 0 deletions modules/prompt_parser.py
@@ -276,6 +276,12 @@ def __init__(self, x, shape):
     def shape(self):
         return self["crossattn"].shape
 
+    def to(self, *args, **kwargs):
+        for k in self.keys():
+            if isinstance(self[k], torch.Tensor):
+                self[k] = self[k].to(*args, **kwargs)
+        return self
+
 
 def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
     param = c[0][0].cond
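The new DictWithShape.to() above exists because SDXL conditioning travels as a dict of tensors ('crossattn' plus 'vector') rather than a single tensor, and the reworked sampler code (see sd_samplers_cfg_denoiser.py below) wants to cast conditioning with the same .to(...) call it uses on plain tensors. A minimal sketch of the pattern, using a hypothetical TensorDict stand-in:

    import torch

    class TensorDict(dict):
        def to(self, *args, **kwargs):
            # cast every tensor value in place, then return self so calls chain like Tensor.to
            for k in self.keys():
                if isinstance(self[k], torch.Tensor):
                    self[k] = self[k].to(*args, **kwargs)
            return self

    cond = TensorDict(crossattn=torch.randn(1, 77, 2048), vector=torch.randn(1, 2816))
    print(cond.to(torch.float16)['crossattn'].dtype)  # torch.float16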
52 changes: 4 additions & 48 deletions modules/sd_hijack.py
@@ -56,58 +56,14 @@ def list_optimizers():
     optimizers.extend(new_optimizers)
 
 
-def apply_optimizations(option=None):
+def apply_optimizations(*args, **kwargs):
     global current_optimizer
 
-    undo_optimizations()
-
-    if len(optimizers) == 0:
-        # a script can access the model very early, and optimizations would not be filled by then
-        current_optimizer = None
-        return ''
-
-    ldm.modules.diffusionmodules.model.nonlinearity = silu
-    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
-
-    sgm.modules.diffusionmodules.model.nonlinearity = silu
-    sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
-
-    if current_optimizer is not None:
-        current_optimizer.undo()
-        current_optimizer = None
-
-    selection = option or shared.opts.cross_attention_optimization
-    if selection == "Automatic" and len(optimizers) > 0:
-        matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
-    else:
-        matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)
-
-    if selection == "None":
-        matching_optimizer = None
-    elif selection == "Automatic" and shared.cmd_opts.disable_opt_split_attention:
-        matching_optimizer = None
-    elif matching_optimizer is None:
-        matching_optimizer = optimizers[0]
-
-    if matching_optimizer is not None:
-        print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
-        matching_optimizer.apply()
-        print("done.")
-        current_optimizer = matching_optimizer
-        return current_optimizer.name
-    else:
-        print("Disabling attention optimization")
-        return ''
+    current_optimizer = None
+    return ''
 
 
 def undo_optimizations():
-    ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
-    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
-    ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
-
-    sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
-    sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
-    sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+    return
 
 
 def fix_checkpoint():
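apply_optimizations and undo_optimizations are reduced to stubs here because attention optimization now happens inside the Forge backend instead of by monkey-patching ldm/sgm. Widening the signature to *args, **kwargs presumably keeps every historical call shape, including extensions that pass an option string, from raising a TypeError:

    def apply_optimizations(*args, **kwargs):
        return ''

    apply_optimizations()                      # internal call during model load
    apply_optimizations('sdp')                 # extension passing a positional option
    apply_optimizations(option='sdp-no-mem')   # keyword form, still accepted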
5 changes: 0 additions & 5 deletions modules/sd_hijack_unet.py
@@ -62,8 +62,6 @@ def forward(self, x):
 def hijack_ddpm_edit():
     global ddpm_edit_hijack
     if not ddpm_edit_hijack:
-        CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
-        CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
         ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
 
 
@@ -76,9 +74,6 @@ def hijack_ddpm_edit():
 CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
 
 first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16
-first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
-CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
-CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
 CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
 
 CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast)
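The deleted CondFunc registrations were the hooks that upcast inputs to encode_first_stage/decode_first_stage to devices.dtype_vae whenever an fp16 UNet was paired with a higher-precision VAE; with encode/decode now routed through Forge's VAE patcher, there is nothing left for them to intercept. The removed hooks followed this general shape (a self-contained sketch, not the real CondFunc API from modules/sd_hijack_utils.py):

    import torch

    def conditional_upcast(orig_func, dtype, cond):
        # call the original with an upcast input only when the condition holds at call time
        def wrapped(x, **kwargs):
            return orig_func(x.to(dtype) if cond() else x, **kwargs)
        return wrapped

    decode = conditional_upcast(lambda x: x.dtype, torch.float32, lambda: True)
    print(decode(torch.ones(1, dtype=torch.float16)))  # torch.float32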
96 changes: 27 additions & 69 deletions modules/sd_models.py
@@ -17,6 +17,8 @@
 from modules.timer import Timer
 import tomesd
 import numpy as np
+from modules_forge import forge_loader
+
 
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -366,10 +368,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")
 
-    if devices.fp8:
-        # prevent model to load state dict in fp8
-        model.half()
-
     if not SkipWritingToConfig.skip:
         shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
 
@@ -379,13 +377,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     model.is_sdxl = hasattr(model, 'conditioner')
     model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
     model.is_sd1 = not model.is_sdxl and not model.is_sd2
-    model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
+    model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in model.state_dict().keys()
     if model.is_sdxl:
         sd_models_xl.extend_sdxl(model)
 
-    if model.is_ssd:
-        sd_hijack.model_hijack.convert_sdxl_to_ssd(model)
-
     if shared.opts.sd_checkpoint_cache > 0:
         # cache newly loaded model
         checkpoints_loaded[checkpoint_info] = state_dict.copy()
@@ -395,65 +390,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer

     del state_dict
 
-    if shared.cmd_opts.opt_channelslast:
-        model.to(memory_format=torch.channels_last)
-        timer.record("apply channels_last")
-
-    if shared.cmd_opts.no_half:
-        model.float()
-        model.alphas_cumprod_original = model.alphas_cumprod
-        devices.dtype_unet = torch.float32
-        timer.record("apply float()")
-    else:
-        vae = model.first_stage_model
-        depth_model = getattr(model, 'depth_model', None)
-
-        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
-        if shared.cmd_opts.no_half_vae:
-            model.first_stage_model = None
-        # with --upcast-sampling, don't convert the depth model weights to float16
-        if shared.cmd_opts.upcast_sampling and depth_model:
-            model.depth_model = None
-
-        alphas_cumprod = model.alphas_cumprod
-        model.alphas_cumprod = None
-        model.half()
-        model.alphas_cumprod = alphas_cumprod
-        model.alphas_cumprod_original = alphas_cumprod
-        model.first_stage_model = vae
-        if depth_model:
-            model.depth_model = depth_model
-
-        devices.dtype_unet = torch.float16
-        timer.record("apply half()")
-
-    for module in model.modules():
-        if hasattr(module, 'fp16_weight'):
-            del module.fp16_weight
-        if hasattr(module, 'fp16_bias'):
-            del module.fp16_bias
-
-    if check_fp8(model):
-        devices.fp8 = True
-        first_stage = model.first_stage_model
-        model.first_stage_model = None
-        for module in model.modules():
-            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
-                if shared.opts.cache_fp16_weight:
-                    module.fp16_weight = module.weight.data.clone().cpu().half()
-                    if module.bias is not None:
-                        module.fp16_bias = module.bias.data.clone().cpu().half()
-                module.to(torch.float8_e4m3fn)
-        model.first_stage_model = first_stage
-        timer.record("apply fp8")
-    else:
-        devices.fp8 = False
-
-    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
-
-    model.first_stage_model.to(devices.dtype_vae)
-    timer.record("apply dtype to VAE")
-
     # clean up cache if limit is reached
     while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
         checkpoints_loaded.popitem(last=False)
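Everything deleted from load_model_weights here (channels_last, the half()/float() dance that shielded the VAE and alphas_cumprod from fp16 conversion, and the fp8 weight storage) is precision management that moves behind the patchers created in load_model below. For orientation, the core of the removed fp8 path stored Linear/Conv2d weights as float8_e4m3fn while caching fp16 copies for later restore; in miniature (assuming a torch build with float8 support):

    import torch

    linear = torch.nn.Linear(8, 8).half()
    linear.fp16_weight = linear.weight.data.clone().cpu()  # cached so the weight can be restored later
    linear.to(torch.float8_e4m3fn)                         # storage dtype only; compute must upcast
    print(linear.weight.dtype)                             # torch.float8_e4m3fn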
@@ -670,6 +606,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):

     timer.record("load config")
 
+    if hasattr(sd_config.model.params, 'network_config'):
+        sd_config.model.params.network_config.target = 'modules_forge.forge_loader.FakeObject'
+
+    if hasattr(sd_config.model.params, 'unet_config'):
+        sd_config.model.params.unet_config.target = 'modules_forge.forge_loader.FakeObject'
+
+    if hasattr(sd_config.model.params, 'first_stage_config'):
+        sd_config.model.params.first_stage_config.target = 'modules_forge.forge_loader.FakeObject'
+
     print(f"Creating model from config: {checkpoint_config}")
 
     sd_model = None
@@ -700,8 +645,21 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
         '': torch.float16,
     }
 
-    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
-        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+    state_dict_for_a1111 = {k: v for k, v in state_dict.items() if not k.startswith('model.diffusion_model.') and not k.startswith('first_stage_model.')}
+    state_dict_for_forge = {k: v for k, v in state_dict.items()}
+    del state_dict
+
+    unet_patcher, vae_patcher = forge_loader.load_unet_and_vae(state_dict_for_forge)
+    sd_model.first_stage_model = vae_patcher.first_stage_model
+    sd_model.model.diffusion_model = unet_patcher.model.diffusion_model
+    sd_model.unet_patcher = unet_patcher
+    sd_model.vae_patcher = vae_patcher
+    timer.record("create unet patcher")
+    del state_dict_for_forge
+
+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict_for_a1111, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
+        load_model_weights(sd_model, checkpoint_info, state_dict_for_a1111, timer)
+    del state_dict_for_a1111
     timer.record("load weights from state dict")
 
     send_model_to_device(sd_model)
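The load path now forks the checkpoint. The full state dict goes to forge_loader.load_unet_and_vae, which builds the UNet and VAE inside the ldm_patched backend and hands back patcher objects whose inner modules are grafted onto the a1111-side model; a filtered copy (everything except model.diffusion_model.* and first_stage_model.* keys) still feeds the legacy loading path for CLIP and the remaining submodules. Retargeting network_config/unet_config/first_stage_config at FakeObject beforehand stops instantiate_from_config from building a second, throwaway UNet and VAE. A hypothetical sketch of such a stub (the real FakeObject lives in modules_forge/forge_loader.py, whose diff did not load on this page):

    import torch

    class FakeObject(torch.nn.Module):
        # accept any constructor arguments and allocate nothing,
        # so instantiating it from a config is effectively free
        def __init__(self, *args, **kwargs):
            super().__init__()

        def forward(self, *args, **kwargs):
            raise NotImplementedError('stub: the real module is attached by forge_loader')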
9 changes: 4 additions & 5 deletions modules/sd_models_xl.py
@@ -8,6 +8,8 @@
 from modules import devices, shared, prompt_parser
 from modules import torch_utils
 
+import ldm_patched.modules.model_management as model_management
+
 
 def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
     for embedder in self.conditioner.embedders:
@@ -35,11 +37,8 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:


 def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
-    sd = self.model.state_dict()
-    diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
-    if diffusion_model_input is not None:
-        if diffusion_model_input.shape[1] == 9:
-            x = torch.cat([x] + cond['c_concat'], dim=1)
+    if self.model.diffusion_model.in_channels == 9:
+        x = torch.cat([x] + cond['c_concat'], dim=1)
 
     return self.model(x, t, cond)
 
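The old apply_model rebuilt self.model.state_dict() on every UNet call just to check whether the first input conv takes 9 channels; reading in_channels off the module is equivalent and skips materializing a dict over every parameter once per sampling step. The 9 comes from the inpainting layout: noised latent (4 channels) concatenated with the masked-image latent (4) and the mask (1):

    import torch

    x = torch.randn(1, 4, 128, 128)           # noised latent
    c_concat = [torch.randn(1, 5, 128, 128)]  # masked-image latent + mask
    x_in = torch.cat([x] + c_concat, dim=1)
    assert x_in.shape[1] == 9                  # what an inpainting UNet's first conv expects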
11 changes: 11 additions & 0 deletions modules/sd_samplers_cfg_denoiser.py
@@ -173,6 +173,15 @@ def apply_blend(current_latent):
                 uncond = pad_cond(uncond, num_repeats, empty)
             self.padded_cond_uncond = True
 
+        unet_dtype = self.inner_model.inner_model.unet_patcher.model.model_config.unet_config['dtype']
+        x_input_dtype = x_in.dtype
+
+        x_in = x_in.to(unet_dtype)
+        sigma_in = sigma_in.to(unet_dtype)
+        image_cond_in = image_cond_in.to(unet_dtype)
+        tensor = tensor.to(unet_dtype)
+        uncond = uncond.to(unet_dtype)
+
         if tensor.shape[1] == uncond.shape[1] or skip_uncond:
             if is_edit_model:
                 cond_in = catenate_conds([tensor, uncond, uncond])
@@ -211,6 +220,8 @@ def apply_blend(current_latent):
             fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
             x_out = torch.cat([x_out, fake_uncond])  # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
 
+        x_out = x_out.to(x_input_dtype)
+
         denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
         cfg_denoised_callback(denoised_params)
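With the autocast wrapper gone from processing.py, the denoiser now does its own precision handling: every input is cast to the backend UNet's configured dtype before the forward pass, and x_out is cast back to the dtype the sampler supplied, so k-diffusion keeps running in its original precision. The round-trip in miniature (hypothetical names):

    import torch

    def run_in_model_dtype(model, x, model_dtype):
        x_input_dtype = x.dtype
        y = model(x.to(model_dtype))   # heavy work happens in the model's dtype
        return y.to(x_input_dtype)     # caller gets its own dtype back

    out = run_in_model_dtype(lambda t: t * 2, torch.randn(1, 4, 8, 8), torch.float16)
    print(out.dtype)  # torch.float32, same as the input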
5 changes: 2 additions & 3 deletions modules/sd_samplers_common.py
@@ -54,8 +54,8 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
     else:
         if model is None:
             model = shared.sd_model
-        with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
-            x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
+        sample = model.unet_patcher.model.model_config.latent_format.process_out(sample)
+        x_sample = model.vae_patcher.decode(sample).movedim(-1, 1) * 2.0 - 1.0
 
     return x_sample

@@ -71,7 +71,6 @@ def single_sample_to_image(sample, approximation=None):


 def decode_first_stage(model, x):
-    x = x.to(devices.dtype_vae)
     approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
     return samples_to_images_tensor(x, approx_index, model)
 
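Full-quality decode now goes through the Forge VAE patcher. latent_format.process_out rescales latents from model space back to VAE-input space (dividing out the scale factor, e.g. 0.18215 for SD1.x-family latents, if the backend follows the usual ldm_patched conventions), and since the patcher returns NHWC images in [0, 1] while the rest of the webui expects NCHW in [-1, 1], the result is transposed and rescaled. The bookkeeping in isolation:

    import torch

    decoded_nhwc = torch.rand(1, 512, 512, 3)           # stand-in for model.vae_patcher.decode(...)
    x_sample = decoded_nhwc.movedim(-1, 1) * 2.0 - 1.0  # NCHW, range [-1, 1]
    assert x_sample.shape == (1, 3, 512, 512)
    assert x_sample.min() >= -1.0 and x_sample.max() <= 1.0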
1 change: 0 additions & 1 deletion modules/sd_vae.py
@@ -237,7 +237,6 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
 # don't call this from outside
 def _load_vae_dict(model, vae_dict_1):
     model.first_stage_model.load_state_dict(vae_dict_1)
-    model.first_stage_model.to(devices.dtype_vae)
 
 
 def clear_loaded_vae():
(Diffs for the two remaining changed files did not load on this page.)
