Initial flux support, refactoring weight control #169

Merged · 11 commits · Aug 30, 2024
84 changes: 61 additions & 23 deletions adv_control/control.py
@@ -8,7 +8,7 @@
import comfy.model_management
import comfy.model_detection
import comfy.controlnet as comfy_cn
from comfy.controlnet import ControlBase, ControlNet, ControlLora, T2IAdapter
from comfy.controlnet import ControlBase, ControlNet, ControlLora, T2IAdapter, StrengthType
from comfy.model_patcher import ModelPatcher

from .control_sparsectrl import SparseModelPatcher, SparseControlNet, SparseCtrlMotionWrapper, SparseSettings, SparseConst
@@ -21,13 +21,23 @@


class ControlNetAdvanced(ControlNet, AdvancedControlBase):
def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, compression_ratio=compression_ratio, latent_format=latent_format, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controlnet())
self.is_flux = False
self.x_noisy_shape = None

def get_universal_weights(self) -> ControlWeights:
raw_weights = [(self.weights.base_multiplier ** float(12 - i)) for i in range(13)]
return self.weights.copy_with_new_weights(raw_weights)
def cn_weights_func(idx: int, control: dict[str, list[Tensor]], key: str):
if key == "middle":
return 1.0
c_len = len(control[key])
raw_weights = [(self.weights.base_multiplier ** float((c_len) - i)) for i in range(c_len+1)]
raw_weights = raw_weights[:-1]
if key == "input":
raw_weights.reverse()
return raw_weights[idx]
return self.weights.copy_with_new_weights(new_weight_func=cn_weights_func)
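To illustrate what the refactored weight function produces, here is a minimal standalone sketch (not part of the diff; the base_multiplier value and output-layer count are assumed for illustration):

base_multiplier = 0.825   # assumed user-chosen ControlWeights value
c_len = 12                # assumed number of tensors under the "output" key
raw_weights = [base_multiplier ** float(c_len - i) for i in range(c_len + 1)][:-1]
# raw_weights == [bm**12, ..., bm**1]; for "output", idx 0 gets ~0.1 and idx 11 gets 0.825,
# "input" uses the reversed list, and "middle" always returns 1.0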

def get_control_advanced(self, x_noisy, t, cond, batched_number):
# perform special version of get_control that supports sliding context and masks
@@ -49,7 +59,6 @@ def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype

output_dtype = x_noisy.dtype
# make cond_hint appropriate dimensions
# TODO: change this to not require cond_hint upscaling every step when self.sub_idxs are present
if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
@@ -64,9 +73,9 @@ def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
actual_cond_hint_orig = self.cond_hint_original
if self.cond_hint_original.size(0) < self.full_latent_length:
actual_cond_hint_orig = extend_to_batch_size(tensor=actual_cond_hint_orig, batch_size=self.full_latent_length)
self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, 'nearest-exact', "center")
self.cond_hint = comfy.utils.common_upscale(actual_cond_hint_orig[self.sub_idxs], x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
else:
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, 'nearest-exact', "center")
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
@@ -81,25 +90,44 @@ def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)

context = cond.get('crossattn_controlnet', cond['c_crossattn'])
y = cond.get('y', None)
if y is not None:
y = y.to(dtype)
extra = self.extra_args.copy()
for c in self.extra_conds:
temp = cond.get(c, None)
if temp is not None:
extra[c] = temp.to(dtype)

timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
self.x_noisy_shape = x_noisy.shape
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
return self.control_merge(control, control_prev, output_dtype=None)

control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
return self.control_merge(control, control_prev, output_dtype)
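The generic extra_conds loop above replaces the hard-coded y handling shown in the two lines it supersedes; a toy sketch of what that loop gathers for a Flux-style cond dict (key names and tensor shapes are assumptions, not taken from the PR):

import torch
dtype = torch.float16
cond = {"y": torch.zeros(1, 256), "guidance": torch.tensor([3.5])}   # hypothetical entries
extra_conds = ["y", "guidance"]                                      # assumed Flux-style list
extra = {c: cond[c].to(dtype) for c in extra_conds if cond.get(c) is not None}
# extra is then splatted into the control model call as **extra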
def pre_run_advanced(self, *args, **kwargs):
self.is_flux = "Flux" in str(type(self.control_model).__name__)
return super().pre_run_advanced(*args, **kwargs)

def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int, flux_shape=None):
if self.is_flux:
flux_shape = self.x_noisy_shape
return super().apply_advanced_strengths_and_masks(x, batched_number, flux_shape)

def copy(self):
c = ControlNetAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
self.copy_to_advanced(c)
return c

def cleanup_advanced(self):
self.x_noisy_shape = None
return super().cleanup_advanced()
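Note that the Flux handling hinges on a class-name check plus the recorded x_noisy shape; a toy sketch with a stand-in class (the class name here is hypothetical):

class FluxControlNetStandIn:   # hypothetical stand-in for a Flux control model class
    pass

is_flux = "Flux" in type(FluxControlNetStandIn()).__name__   # True
# when True, the x_noisy shape captured in get_control_advanced is forwarded as
# flux_shape, presumably so masks/strengths can be mapped onto Flux's output layout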

@staticmethod
def from_vanilla(v: ControlNet, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlNetAdvanced':
to_return = ControlNetAdvanced(control_model=v.control_model, timestep_keyframes=timestep_keyframe,
global_average_pooling=v.global_average_pooling, compression_ratio=v.compression_ratio, latent_format=v.latent_format, device=v.device, load_device=v.load_device, manual_cast_dtype=v.manual_cast_dtype)
global_average_pooling=v.global_average_pooling, compression_ratio=v.compression_ratio, latent_format=v.latent_format, device=v.device, load_device=v.load_device,
manual_cast_dtype=v.manual_cast_dtype)
v.copy_to(to_return)
return to_return

@@ -121,18 +149,28 @@ def control_merge_inject(self, control: dict[str, list[Tensor]], control_prev, o
return AdvancedControlBase.control_merge_inject(self, control, control_prev, output_dtype)

def get_universal_weights(self) -> ControlWeights:
raw_weights = [(self.weights.base_multiplier ** float(7 - i)) for i in range(8)]
raw_weights = [raw_weights[-8], raw_weights[-3], raw_weights[-2], raw_weights[-1]]
raw_weights = get_properly_arranged_t2i_weights(raw_weights)
raw_weights.reverse() # need to reverse to match recent ComfyUI changes
return self.weights.copy_with_new_weights(raw_weights)
def t2i_weights_func(idx: int, control: dict[str, list[Tensor]], key: str):
if key == "middle":
return 1.0
c_len = 8 #len(control[key])
raw_weights = [(self.weights.base_multiplier ** float((c_len-1) - i)) for i in range(c_len)]
raw_weights = [raw_weights[-c_len], raw_weights[-3], raw_weights[-2], raw_weights[-1]]
raw_weights = get_properly_arranged_t2i_weights(raw_weights)
if key == "input":
raw_weights.reverse()
return raw_weights[idx]
return self.weights.copy_with_new_weights(new_weight_func=t2i_weights_func)
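As a worked example of the selection above (base_multiplier is assumed; get_properly_arranged_t2i_weights is elided since its ordering lives elsewhere in the repo):

base_multiplier = 0.825                       # assumed user setting
c_len = 8                                     # T2IAdapter treated as 8 conceptual levels
raw_weights = [base_multiplier ** float((c_len - 1) - i) for i in range(c_len)]
picked = [raw_weights[-c_len], raw_weights[-3], raw_weights[-2], raw_weights[-1]]
# picked ≈ [0.26, 0.68, 0.825, 1.0] -- the shallowest level plus the three deepest,
# which get_properly_arranged_t2i_weights then arranges across the adapter's outputs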

def get_calc_pow(self, idx: int, control: dict[str, list[Tensor]], key: str) -> int:
if key == "middle":
return 0
# match how T2IAdapterAdvanced deals with universal weights
indeces = [7 - i for i in range(8)]
indeces = [indeces[-8], indeces[-3], indeces[-2], indeces[-1]]
c_len = 8 #len(control[key])
indeces = [(c_len-1) - i for i in range(c_len)]
indeces = [indeces[-c_len], indeces[-3], indeces[-2], indeces[-1]]
indeces = get_properly_arranged_t2i_weights(indeces)
indeces.reverse() # need to reverse to match recent ComfyUI changes
if key == "input":
indeces.reverse() # need to reverse to match recent ComfyUI changes
return indeces[idx]

def get_control_advanced(self, x_noisy, t, cond, batched_number):
@@ -381,11 +419,11 @@ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
return self.control_merge(control, control_prev, output_dtype)

def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int):
def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int, *args, **kwargs):
# apply mults to indexes with and without a direct condhint
x[self.local_sparse_idxs] *= self.sparse_settings.sparse_hint_mult * self.weights.extras.get(SparseConst.HINT_MULT, 1.0)
x[self.local_sparse_idxs_inverse] *= self.sparse_settings.sparse_nonhint_mult * self.weights.extras.get(SparseConst.NONHINT_MULT, 1.0)
return super().apply_advanced_strengths_and_masks(x, batched_number)
return super().apply_advanced_strengths_and_masks(x, batched_number, *args, **kwargs)
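For intuition, a toy sketch of the two SparseCtrl multipliers (tensor shape and multiplier values are assumed):

import torch
x = torch.ones(4, 320, 8, 8)                  # pretend ControlNet output for 4 latent frames
hint_idxs, nonhint_idxs = [0, 2], [1, 3]      # frames with / without a direct cond hint
x[hint_idxs] *= 1.0                           # sparse_hint_mult * HINT_MULT extra (assumed values)
x[nonhint_idxs] *= 0.5                        # sparse_nonhint_mult * NONHINT_MULT extra (assumed values)
# afterwards the base class applies timestep-keyframe strengths and latent masks as usual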

def pre_run_advanced(self, model, percent_to_timestep_function):
super().pre_run_advanced(model, percent_to_timestep_function)
13 changes: 10 additions & 3 deletions adv_control/control_plusplus.py
@@ -237,9 +237,16 @@ def __init__(self, control_model: ControlNetPlusPlus, timestep_keyframes: Timest
self.single_control_type: str = None

def get_universal_weights(self) -> ControlWeights:
# TODO: match actual layer count of model
raw_weights = [(self.weights.base_multiplier ** float(12 - i)) for i in range(13)]
return self.weights.copy_with_new_weights(raw_weights)
def cn_weights_func(idx: int, control: dict[str, list[Tensor]], key: str):
if key == "middle":
return 1.0
c_len = len(control[key])
raw_weights = [(self.weights.base_multiplier ** float((c_len) - i)) for i in range(c_len+1)]
raw_weights = raw_weights[:-1]
if key == "input":
raw_weights.reverse()
return raw_weights[idx]
return self.weights.copy_with_new_weights(new_weight_func=cn_weights_func)
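Because len(control[key]) is now read at call time, the decay adapts to ControlNet++'s actual layer count, addressing the removed TODO above; a one-line sketch with an assumed 9 output tensors:

c_len = 9                                                              # assumed ControlNet++ output count
weights = [0.825 ** float(c_len - i) for i in range(c_len + 1)][:-1]   # 9 values, deepest = 0.825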

def verify_control_type(self, model_name: str, pp_group: PlusPlusInputGroup=None):
if pp_group is not None: