Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Work with new ComfyUI ModelPatcher update (NOT backwards compatible) #460

Merged
merged 3 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 64 additions & 55 deletions animatediff/model_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,53 +220,54 @@ def get_combined_hooked_patches(self, lora_hooks: LoraHookGroup):
combined_patches[key] = current_patches
return combined_patches

def model_patches_to(self, device):
super().model_patches_to(device)

def patch_model(self, device_to=None, patch_weights=True):
def patch_model(self, *args, **kwargs):
was_injected = False
if self.currently_injected:
self.eject_model()
was_injected = True
# first, perform model patching
if patch_weights: # TODO: keep only 'else' portion when don't need to worry about past comfy versions
patched_model = super().patch_model(device_to)
else:
patched_model = super().patch_model(device_to, patch_weights)
# finally, perform motion model injection
self.inject_model()
patched_model = super().patch_model(*args, **kwargs)
# bring injection back to original state
if was_injected and not self.currently_injected:
self.inject_model()
return patched_model

def patch_model_lowvram(self, *args, **kwargs):
def load(self, device_to=None, lowvram_model_memory=0, *args, **kwargs):
self.eject_model()
try:
return super().patch_model_lowvram(*args, **kwargs)
return super().load(device_to=device_to, lowvram_model_memory=lowvram_model_memory, *args, **kwargs)
finally:
# check if any modules have weight_function or bias_function that is not None
# NOTE: this serves no purpose currently, but I have it here for future reasons
for n, m in self.model.named_modules():
if not hasattr(m, "comfy_cast_weights"):
continue
if getattr(m, "weight_function", None) is not None:
self.model_params_lowvram = True
self.model_params_lowvram_keys[f"{n}.weight"] = n
if getattr(m, "bias_function", None) is not None:
self.model_params_lowvram = True
self.model_params_lowvram_keys[f"{n}.bias"] = n
self.inject_model()
if lowvram_model_memory > 0:
self._patch_lowvram_extras()

def _patch_lowvram_extras(self):
# check if any modules have weight_function or bias_function that is not None
# NOTE: this serves no purpose currently, but I have it here for future reasons
self.model_params_lowvram = False
self.model_params_lowvram_keys.clear()
for n, m in self.model.named_modules():
if not hasattr(m, "comfy_cast_weights"):
continue
if getattr(m, "weight_function", None) is not None:
self.model_params_lowvram = True
self.model_params_lowvram_keys[f"{n}.weight"] = n
if getattr(m, "bias_function", None) is not None:
self.model_params_lowvram = True
self.model_params_lowvram_keys[f"{n}.bias"] = n

def unpatch_model(self, device_to=None, unpatch_weights=True):
# first, eject motion model from unet
self.eject_model()
# finally, do normal model unpatching
if unpatch_weights: # TODO: keep only 'else' portion when don't need to worry about past comfy versions
if unpatch_weights:
# handle hooked_patches first
self.clean_hooks()
try:
return super().unpatch_model(device_to)
finally:
self.model_params_lowvram = False
self.model_params_lowvram_keys.clear()
else:
try:
return super().unpatch_model(device_to, unpatch_weights)
finally:
self.model_params_lowvram = False
self.model_params_lowvram_keys.clear()
try:
return super().unpatch_model(device_to, unpatch_weights)
finally:
self.model_params_lowvram = False
self.model_params_lowvram_keys.clear()

def partially_load(self, *args, **kwargs):
# partially_load calls patch_model, but we don't want to inject model in the intermediate call;
Expand Down Expand Up @@ -625,31 +626,37 @@ def patch_hooked_replace_weight_to_device(self, model_sd: dict, replace_patches:
else:
comfy.utils.set_attr_param(self.model, key, out_weight)

def patch_model(self, device_to=None, patch_weights=True, *args, **kwargs):
def patch_model(self, device_to=None, *args, **kwargs):
if self.desired_lora_hooks is not None:
self.patches_backup = self.patches.copy()
relevant_patches = self.get_combined_hooked_patches(lora_hooks=self.desired_lora_hooks)
for key in relevant_patches:
self.patches.setdefault(key, [])
self.patches[key].extend(relevant_patches[key])
self.current_lora_hooks = self.desired_lora_hooks
return super().patch_model(device_to, patch_weights, *args, **kwargs)
return super().patch_model(device_to, *args, **kwargs)

def patch_model_lowvram(self, *args, **kwargs):
def load(self, device_to=None, lowvram_model_memory=0, *args, **kwargs):
try:
return super().patch_model_lowvram(*args, **kwargs)
return super().load(device_to=device_to, lowvram_model_memory=lowvram_model_memory, *args, **kwargs)
finally:
# check if any modules have weight_function or bias_function that is not None
# NOTE: this serves no purpose currently, but I have it here for future reasons
for n, m in self.model.named_modules():
if not hasattr(m, "comfy_cast_weights"):
continue
if getattr(m, "weight_function", None) is not None:
self.model_params_lowvram = True
self.model_params_lowvram_keys[f"{n}.weight"] = n
if getattr(m, "bias_function", None) is not None:
self.model_params_lowvram = True
self.model_params_lowvram_keys[f"{n}.weight"] = n
if lowvram_model_memory > 0:
self._patch_lowvram_extras()

def _patch_lowvram_extras(self):
# check if any modules have weight_function or bias_function that is not None
# NOTE: this serves no purpose currently, but I have it here for future reasons
self.model_params_lowvram = False
self.model_params_lowvram_keys.clear()
for n, m in self.model.named_modules():
if not hasattr(m, "comfy_cast_weights"):
continue
if getattr(m, "weight_function", None) is not None:
self.model_params_lowvram = True
self.model_params_lowvram_keys[f"{n}.weight"] = n
if getattr(m, "bias_function", None) is not None:
self.model_params_lowvram = True
self.model_params_lowvram_keys[f"{n}.weight"] = n

def unpatch_model(self, device_to=None, unpatch_weights=True, *args, **kwargs):
try:
Expand Down Expand Up @@ -797,10 +804,14 @@ def __init__(self, *args, **kwargs):
self.was_within_range = False
self.prev_sub_idxs = None
self.prev_batched_number = None

def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, *args, **kwargs):
patched_model = super().patch_model_lowvram(device_to, lowvram_model_memory, force_patch_weights, *args, **kwargs)

def load(self, device_to=None, lowvram_model_memory=0, *args, **kwargs):
to_return = super().load(device_to=device_to, lowvram_model_memory=lowvram_model_memory, *args, **kwargs)
if lowvram_model_memory > 0:
self._patch_lowvram_extras(device_to=device_to)
return to_return

def _patch_lowvram_extras(self, device_to=None):
# figure out the tensors (likely pe's) that should be cast to device besides just the named_modules
remaining_tensors = list(self.model.state_dict().keys())
named_modules = []
Expand All @@ -817,8 +828,6 @@ def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patc
if device_to is not None:
comfy.utils.set_attr(self.model, key, comfy.utils.get_attr(self.model, key).to(device_to))

return patched_model

def pre_run(self, model: ModelPatcherAndInjector):
self.cleanup()
self.model.set_scale(self.scale_multival, self.per_block_list)
Expand Down
81 changes: 20 additions & 61 deletions animatediff/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,7 @@
from comfy.ldm.modules.diffusionmodules import openaimodel
import comfy.model_management
import comfy.samplers
import comfy.sample
SAMPLE_FALLBACK = False
try:
import comfy.sampler_helpers
except ImportError:
SAMPLE_FALLBACK = True
import comfy.sampler_helpers
import comfy.utils
from comfy.controlnet import ControlBase
from comfy.model_base import BaseModel
Expand Down Expand Up @@ -291,10 +286,7 @@ def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionPara
self.orig_diffusion_model_forward = model.model.diffusion_model.forward
self.orig_sampling_function = comfy.samplers.sampling_function # used to support sliding context windows in samplers
self.orig_get_area_and_mult = comfy.samplers.get_area_and_mult
if SAMPLE_FALLBACK: # for backwards compatibility, for now
self.orig_get_additional_models = comfy.sample.get_additional_models
else:
self.orig_get_additional_models = comfy.sampler_helpers.get_additional_models
self.orig_get_additional_models = comfy.sampler_helpers.get_additional_models
self.orig_apply_model = model.model.apply_model
# Inject Functions
openaimodel.forward_timestep_embed = forward_timestep_embed_factory()
Expand Down Expand Up @@ -324,10 +316,7 @@ def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionPara
del info
comfy.samplers.sampling_function = evolved_sampling_function
comfy.samplers.get_area_and_mult = get_area_and_mult_ADE
if SAMPLE_FALLBACK: # for backwards compatibility, for now
comfy.sample.get_additional_models = get_additional_models_factory(self.orig_get_additional_models, model.motion_models)
else:
comfy.sampler_helpers.get_additional_models = get_additional_models_factory(self.orig_get_additional_models, model.motion_models)
comfy.sampler_helpers.get_additional_models = get_additional_models_factory(self.orig_get_additional_models, model.motion_models)
# create temp_uninjector to help facilitate uninjecting functions
self.temp_uninjector = GroupnormUninjectHelper(self)

Expand All @@ -341,10 +330,7 @@ def restore_functions(self, model: ModelPatcherAndInjector):
model.model.diffusion_model.forward = self.orig_diffusion_model_forward
comfy.samplers.sampling_function = self.orig_sampling_function
comfy.samplers.get_area_and_mult = self.orig_get_area_and_mult
if SAMPLE_FALLBACK: # for backwards compatibility, for now
comfy.sample.get_additional_models = self.orig_get_additional_models
else:
comfy.sampler_helpers.get_additional_models = self.orig_get_additional_models
comfy.sampler_helpers.get_additional_models = self.orig_get_additional_models
model.model.apply_model = self.orig_apply_model
except AttributeError:
logger.error("Encountered AttributeError while attempting to restore functions - likely, an error occured while trying " + \
Expand Down Expand Up @@ -505,17 +491,8 @@ def ad_callback(step, x0, x, total_steps):
if is_custom:
iter_kwargs[IterationOptions.SAMPLER] = None #args[-5]
else:
if SAMPLE_FALLBACK: # backwards compatibility, for now
# in older comfy, model needs to be loaded to get proper model_sampling to be used for sigmas
comfy.model_management.load_model_gpu(model)
iter_model = model.model
else:
iter_model = model
current_device = None
if hasattr(model, "current_device"): # backwards compatibility, for now
current_device = model.current_device
else:
current_device = model.model.device
iter_model = model
current_device = model.model.device
iter_kwargs[IterationOptions.SAMPLER] = comfy.samplers.KSampler(
iter_model, steps=999, #steps=args[-7],
device=current_device, sampler=args[-5],
Expand Down Expand Up @@ -653,35 +630,20 @@ def evolved_sampling_function(model, x: Tensor, timestep: Tensor, uncond, cond,
model_options["transformer_options"]["ad_params"] = ADGS.create_exposed_params()

if not ADGS.is_using_sliding_context():
cond_pred, uncond_pred = calc_cond_uncond_batch_wrapper(model, [cond, uncond_], x, timestep, model_options)
cond_pred, uncond_pred = calc_conds_batch_wrapper(model, [cond, uncond_], x, timestep, model_options)
else:
cond_pred, uncond_pred = sliding_calc_conds_batch(model, [cond, uncond_], x, timestep, model_options)

if hasattr(comfy.samplers, "cfg_function"):
if ADGS.sample_settings.custom_cfg is not None:
cond_scale = ADGS.sample_settings.custom_cfg.get_cfg_scale(cond_pred)
model_options = ADGS.sample_settings.custom_cfg.get_model_options(model_options)
try:
cached_calc_cond_batch = comfy.samplers.calc_cond_batch
# support hooks and sliding context for PAG/other sampler_post_cfg_function tech that may use calc_cond_batch
comfy.samplers.calc_cond_batch = wrapped_cfg_sliding_calc_cond_batch_factory(cached_calc_cond_batch)
return comfy.samplers.cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options, cond, uncond)
finally:
comfy.samplers.calc_cond_batch = cached_calc_cond_batch
else: # for backwards compatibility, for now
if "sampler_cfg_function" in model_options:
args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
"cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
cfg_result = x - model_options["sampler_cfg_function"](args)
else:
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale

for fn in model_options.get("sampler_post_cfg_function", []):
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
"sigma": timestep, "model_options": model_options, "input": x}
cfg_result = fn(args)

return cfg_result
if ADGS.sample_settings.custom_cfg is not None:
cond_scale = ADGS.sample_settings.custom_cfg.get_cfg_scale(cond_pred)
model_options = ADGS.sample_settings.custom_cfg.get_model_options(model_options)
try:
cached_calc_cond_batch = comfy.samplers.calc_cond_batch
# support hooks and sliding context for PAG/other sampler_post_cfg_function tech that may use calc_cond_batch
comfy.samplers.calc_cond_batch = wrapped_cfg_sliding_calc_cond_batch_factory(cached_calc_cond_batch)
return comfy.samplers.cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options, cond, uncond)
finally:
comfy.samplers.calc_cond_batch = cached_calc_cond_batch
finally:
ADGS.restore_special_model_features(model)

Expand Down Expand Up @@ -745,7 +707,7 @@ def wrapped_cfg_sliding_calc_cond_batch(model, conds, x_in, timestep, model_opti
# when inside sliding_calc_conds_batch, should return to original calc_cond_batch
comfy.samplers.calc_cond_batch = orig_calc_cond_batch
if not ADGS.is_using_sliding_context():
return calc_cond_uncond_batch_wrapper(model, conds, x_in, timestep, model_options)
return calc_conds_batch_wrapper(model, conds, x_in, timestep, model_options)
else:
return sliding_calc_conds_batch(model, conds, x_in, timestep, model_options)
finally:
Expand Down Expand Up @@ -922,7 +884,7 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list
model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.OFF
#logger.info(f"window: {curr_window_idx} - {model_options['transformer_options'][CONTEXTREF_MACHINE_STATE]}")

sub_conds_out = calc_cond_uncond_batch_wrapper(model, sub_conds, sub_x, sub_timestep, model_options)
sub_conds_out = calc_conds_batch_wrapper(model, sub_conds, sub_x, sub_timestep, model_options)

if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE:
full_length = ADGS.params.full_length
Expand Down Expand Up @@ -1008,7 +970,7 @@ def get_conds_with_c_concat(conds: list[dict], c_concat: comfy.conds.CONDNoiseSh
return new_conds


def calc_cond_uncond_batch_wrapper(model, conds: list[dict], x_in: Tensor, timestep, model_options):
def calc_conds_batch_wrapper(model, conds: list[dict], x_in: Tensor, timestep, model_options):
# check if conds or unconds contain lora_hook or default_cond
contains_lora_hooks = False
has_default_cond = False
Expand All @@ -1028,9 +990,6 @@ def calc_cond_uncond_batch_wrapper(model, conds: list[dict], x_in: Tensor, times
ADGS.hooks_initialize(model, hook_groups=hook_groups)
ADGS.prepare_hooks_current_keyframes(timestep, hook_groups=hook_groups)
return calc_conds_batch_lora_hook(model, conds, x_in, timestep, model_options, has_default_cond)
# keep for backwards compatibility, for now
if not hasattr(comfy.samplers, "calc_cond_batch"):
return comfy.samplers.calc_cond_uncond_batch(model, conds[0], conds[1], x_in, timestep, model_options)
return comfy.samplers.calc_cond_batch(model, conds, x_in, timestep, model_options)


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "comfyui-animatediff-evolved"
description = "Improved AnimateDiff integration for ComfyUI."
version = "1.1.4"
version = "1.2.0"
license = { file = "LICENSE" }
dependencies = []

Expand Down