From b33cd610703213dbe73baa6aaa3fdc2c61a84adc Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 28 Aug 2024 18:56:33 -0400
Subject: [PATCH] InstantX canny controlnet.

---
Review notes (free text after the three-dash line; not part of the commit
message and not part of any hunk):

 * comfy/controlnet.py
   . The module import moves from comfy.ldm.flux.controlnet_xlabs to
     comfy.ldm.flux.controlnet, following the file rename below.
   . New load_controlnet_flux_instantx(sd): converts the state dict with
     comfy.model_detection.convert_diffusers_mmdit(sd, ""), derives the
     config from the converted dict via controlnet_config(), then copies
     every key of the original dict into the converted dict before loading.
     The model is built with latent_input=True and wrapped in ControlNet
     with compression_ratio=1 and latent_format=comfy.latent_formats.Flux().
   . load_controlnet(): inside the "controlnet_blocks.0.weight" branch the
     former unconditional "else" becomes two explicit key checks
     ("pos_embed_input.proj.weight" selects load_controlnet_mmdit,
     "controlnet_x_embedder.weight" selects load_controlnet_flux_instantx).
     When neither key is present nothing is returned from this branch and
     execution continues with the statements that follow it.

 * comfy/ldm/flux/controlnet.py (renamed from controlnet_xlabs.py)
   . ControlNetFlux.__init__ gains latent_input=False and stores
     main_model_double = 19 and main_model_single = 38. The input_hint_block
     conv stack is only constructed when latent_input is false.
   . forward_orig(): the hint block and its rearrange are skipped when
     latent_input is true; pos_embed_input is applied in both modes.
     The returned tuple is built from
     repeat = ceil(main_model_double / number_of_block_outputs):
     with latent_input each output is repeated consecutively, otherwise the
     whole tuple is tiled; both are truncated to main_model_double entries
     (the previous code tiled 10 times and truncated to 19).
   . forward(): with latent_input the hint is padded to the patch size and
     rearranged into patch tokens; otherwise it keeps the previous
     "hint * 2.0 - 1.0" rescale.
   . NOTE(review): the local name "repeat" in forward_orig() shadows the
     "repeat" imported from einops at the top of the file, for the scope of
     that function only.
   . NOTE(review): latent_input is inserted ahead of image_model in the
     parameter order; the callers visible in this patch pass keywords only.
     Confirm no caller passes image_model positionally.
   . NOTE(review): main_model_single is assigned but not read anywhere in
     this patch.
   . NOTE(review): the ceil() division has no guard for an empty tuple of
     block outputs; presumably depth is always >= 1 - confirm.

 * comfy/utils.py
   . flux_to_diffusers(): the basic key map gains two pairs linking
     pos_embed_input.{bias,weight} with controlnet_x_embedder.{bias,weight}.

 * Layout: blank context lines were restored from the hunk line counts.
   The blank after super().__init__() (second file, second hunk) and the
   blank after the hint assignment in forward() (second file, last hunk)
   have inferred positions - confirm against the target files.

 comfy/controlnet.py                           | 21 +++++-
 .../{controlnet_xlabs.py => controlnet.py}    | 67 ++++++++++++-------
 comfy/utils.py                                |  2 +
 3 files changed, 63 insertions(+), 27 deletions(-)
 rename comfy/ldm/flux/{controlnet_xlabs.py => controlnet.py} (62%)

diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index d4479589e8b..0c8cd30c404 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -34,7 +34,7 @@
 import comfy.ldm.cascade.controlnet
 import comfy.cldm.mmdit
 import comfy.ldm.hydit.controlnet
-import comfy.ldm.flux.controlnet_xlabs
+import comfy.ldm.flux.controlnet
 
 
 def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -433,12 +433,25 @@ def load_controlnet_hunyuandit(controlnet_data):
 
 def load_controlnet_flux_xlabs(sd):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
-    control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = comfy.ldm.flux.controlnet.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, sd)
     extra_conds = ['y', 'guidance']
     control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
     return control
 
+def load_controlnet_flux_instantx(sd):
+    new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
+    for k in sd:
+        new_sd[k] = sd[k]
+
+    control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = controlnet_load_state_dict(control_model, new_sd)
+
+    latent_format = comfy.latent_formats.Flux()
+    extra_conds = ['y', 'guidance']
+    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
+    return control
 
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
@@ -504,8 +517,10 @@ def load_controlnet(ckpt_path, model=None):
     elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
         if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
             return load_controlnet_flux_xlabs(controlnet_data)
-        else:
+        elif "pos_embed_input.proj.weight" in controlnet_data:
             return load_controlnet_mmdit(controlnet_data)
+        elif "controlnet_x_embedder.weight" in controlnet_data:
+            return load_controlnet_flux_instantx(controlnet_data)
 
     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False
diff --git a/comfy/ldm/flux/controlnet_xlabs.py b/comfy/ldm/flux/controlnet.py
similarity index 62%
rename from comfy/ldm/flux/controlnet_xlabs.py
rename to comfy/ldm/flux/controlnet.py
index 5d700f16c9f..0e160b07529 100644
--- a/comfy/ldm/flux/controlnet_xlabs.py
+++ b/comfy/ldm/flux/controlnet.py
@@ -1,6 +1,7 @@
 #Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
 
 import torch
+import math
 from torch import Tensor, nn
 from einops import rearrange, repeat
 
@@ -13,34 +14,38 @@
 
 
 class ControlNetFlux(Flux):
-    def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
+    def __init__(self, latent_input=False, image_model=None, dtype=None, device=None, operations=None, **kwargs):
         super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
 
+        self.main_model_double = 19
+        self.main_model_single = 38
         # add ControlNet blocks
         self.controlnet_blocks = nn.ModuleList([])
         for _ in range(self.params.depth):
             controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
             # controlnet_block = zero_module(controlnet_block)
             self.controlnet_blocks.append(controlnet_block)
-        self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
         self.gradient_checkpointing = False
-        self.input_hint_block = nn.Sequential(
-            operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
-        )
+        self.latent_input = latent_input
+        self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
+        if not self.latent_input:
+            self.input_hint_block = nn.Sequential(
+                operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
+            )
 
     def forward_orig(
         self,
@@ -58,8 +63,10 @@ def forward_orig(
 
         # running on sequences img
         img = self.img_in(img)
-        controlnet_cond = self.input_hint_block(controlnet_cond)
-        controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+        if not self.latent_input:
+            controlnet_cond = self.input_hint_block(controlnet_cond)
+            controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+
         controlnet_cond = self.pos_embed_input(controlnet_cond)
         img = img + controlnet_cond
         vec = self.time_in(timestep_embedding(timesteps, 256))
@@ -82,13 +89,25 @@ def forward_orig(
             block_res_sample = controlnet_block(block_res_sample)
             controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
 
-        return {"input": (controlnet_block_res_samples * 10)[:19]}
+
+        repeat = math.ceil(self.main_model_double / len(controlnet_block_res_samples))
+        if self.latent_input:
+            out_input = ()
+            for x in controlnet_block_res_samples:
+                out_input += (x,) * repeat
+        else:
+            out_input = (controlnet_block_res_samples * repeat)
+        return {"input": out_input[:self.main_model_double]}
 
     def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
-        hint = hint * 2.0 - 1.0
+        patch_size = 2
+        if self.latent_input:
+            hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
+            hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+        else:
+            hint = hint * 2.0 - 1.0
 
         bs, c, h, w = x.shape
-        patch_size = 2
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
 
         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
diff --git a/comfy/utils.py b/comfy/utils.py
index d0d410d9756..1bc35df7a4c 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -528,6 +528,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
         ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
         ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
         ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
+        ("pos_embed_input.bias", "controlnet_x_embedder.bias"),
+        ("pos_embed_input.weight", "controlnet_x_embedder.weight"),
     }
 
     for k in MAP_BASIC: