From f49bcd4f3cddfb16f516d8d67d56f00164b26261 Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Mon, 26 Aug 2024 16:54:29 -0700
Subject: [PATCH] Upstream InstantX Union ControlNet support for Flux

---
 comfy/cldm/control_types.py                   |   7 +
 comfy/controlnet.py                           | 143 ++++--
 comfy/ldm/flux/controlnet_instantx.py         | 309 +++++++++++++++++
 comfy/ldm/flux/controlnet_instantx_format2.py | 227 +++++++++++++
 comfy/ldm/flux/model.py                       |   1 +
 comfy/ldm/flux/weight_dtypes.py               |   1 +
 comfy/model_downloader.py                     |   2 +
 comfy/nodes/base_nodes.py                     |  26 +-
 .../workflows/flux-controlnet-0.json          | 314 ++++++++++++++++++
 9 files changed, 993 insertions(+), 37 deletions(-)
 create mode 100644 comfy/ldm/flux/controlnet_instantx.py
 create mode 100644 comfy/ldm/flux/controlnet_instantx_format2.py
 create mode 100644 comfy/ldm/flux/weight_dtypes.py
 create mode 100644 tests/inference/workflows/flux-controlnet-0.json

diff --git a/comfy/cldm/control_types.py b/comfy/cldm/control_types.py
index 4128631a3..334aa3770 100644
--- a/comfy/cldm/control_types.py
+++ b/comfy/cldm/control_types.py
@@ -7,4 +7,11 @@ UNION_CONTROLNET_TYPES = {
     "segment": 5,
     "tile": 6,
     "repaint": 7,
+    "canny (InstantX)": 0,
+    "tile (InstantX)": 1,
+    "depth (InstantX)": 2,
+    "blur (InstantX)": 3,
+    "pose (InstantX)": 4,
+    "gray (InstantX)": 5,
+    "lq (InstantX)": 6
 }
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index d874c8cd8..e56ad97ef 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -16,29 +16,31 @@
     along with this program.  If not, see <https://www.gnu.org/licenses/>.
 """
 
-
-import torch
-from enum import Enum
+import logging
 import math
 import os
-import logging
+from enum import Enum
 
-from . import utils
-from . import model_management
+import torch
+
+from . import latent_formats
 from . import model_detection
+from . import model_management
 from . import model_patcher
 from . import ops
-from . import latent_formats
-
+from . 
import utils from .cldm import cldm, mmdit -from .t2i_adapter import adapter from .ldm import hydit, flux from .ldm.cascade import controlnet as cascade_controlnet +from .ldm.flux.controlnet_instantx import InstantXControlNetFlux +from .ldm.flux.controlnet_instantx_format2 import InstantXControlNetFluxFormat2 +from .ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES +from .t2i_adapter import adapter def broadcast_image_to(tensor, target_batch_size, batched_number): current_batch_size = tensor.shape[0] - #print(current_batch_size, target_batch_size) + # print(current_batch_size, target_batch_size) if current_batch_size == 1: return tensor @@ -54,10 +56,12 @@ def broadcast_image_to(tensor, target_batch_size, batched_number): else: return torch.cat([tensor] * batched_number, dim=0) + class StrengthType(Enum): CONSTANT = 1 LINEAR_UP = 2 + class ControlBase: def __init__(self, device=None): self.cond_hint_original = None @@ -129,7 +133,7 @@ class ControlBase: return 0 def control_merge(self, control, control_prev, output_dtype): - out = {'input':[], 'middle':[], 'output': []} + out = {'input': [], 'middle': [], 'output': []} for key in control: control_output = control[key] @@ -140,7 +144,7 @@ class ControlBase: if self.global_average_pooling: x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) - if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once + if x not in applied_to: # memory saving strategy, allow shared tensors and only apply strength to shared tensors once applied_to.add(x) if self.strength_type == StrengthType.CONSTANT: x *= self.strength @@ -166,7 +170,7 @@ class ControlBase: if o[i].shape[0] < prev_val.shape[0]: o[i] = prev_val + o[i] else: - o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue + o[i] = prev_val + o[i] # TODO: change back to inplace add if shared tensors stop being an issue return out def set_extra_arg(self, argument, value=None): @@ -258,10 +262,11 @@ class ControlNet(ControlBase): self.model_sampling_current = None super().cleanup() + class ControlLoraOps: class Linear(torch.nn.Module, ops.CastWeightBiasOp): def __init__(self, in_features: int, out_features: int, bias: bool = True, - device=None, dtype=None) -> None: + device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() self.in_features = in_features @@ -280,18 +285,18 @@ class ControlLoraOps: class Conv2d(torch.nn.Module, ops.CastWeightBiasOp): def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - padding_mode='zeros', - device=None, - dtype=None + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros', + device=None, + dtype=None ): super().__init__() self.in_channels = in_channels @@ -310,7 +315,6 @@ class ControlLoraOps: self.up = None self.down = None - def forward(self, input): weight, bias = ops.cast_bias_weight(self, input) if self.up is not None: @@ -339,6 +343,7 @@ class ControlLora(ControlNet): else: class control_lora_ops(ControlLoraOps, ops.manual_cast): pass + dtype = self.manual_cast_dtype controlnet_config["operations"] = control_lora_ops @@ -377,6 +382,7 @@ class ControlLora(ControlNet): def inference_memory_requirements(self, dtype): return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + 
ControlBase.inference_memory_requirements(self, dtype) + def controlnet_config(sd): model_config = model_detection.model_config_from_unet(sd, "", True) @@ -394,6 +400,7 @@ def controlnet_config(sd): offload_device = model_management.unet_offload_device() return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device + def controlnet_load_state_dict(control_model, sd): missing, unexpected = control_model.load_state_dict(sd, strict=False) @@ -404,6 +411,7 @@ def controlnet_load_state_dict(control_model, sd): logging.debug("unexpected controlnet keys: {}".format(unexpected)) return control_model + def load_controlnet_mmdit(sd): new_sd = model_detection.convert_diffusers_mmdit(sd, "") model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd) @@ -415,7 +423,7 @@ def load_controlnet_mmdit(sd): control_model = controlnet_load_state_dict(control_model, new_sd) latent_format = latent_formats.SD3() - latent_format.shift_factor = 0 #SD3 controlnet weirdness + latent_format.shift_factor = 0 # SD3 controlnet weirdness control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype) return control @@ -431,6 +439,62 @@ def load_controlnet_hunyuandit(controlnet_data): control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT) return control + +def load_controlnet_flux_instantx(sd, controlnet_class, weight_dtype, full_path): + keys_to_keep = [ + "controlnet_", + "single_transformer_blocks", + "transformer_blocks" + ] + preserved_keys = {k: v.cpu() for k, v in sd.items() if any(k.startswith(key) for key in keys_to_keep)} + + new_sd = model_detection.convert_diffusers_mmdit(sd, "") + + keys_to_discard = [ + "double_blocks", + "single_blocks" + ] + new_sd = {k: v for k, v in new_sd.items() if not any(k.startswith(discard_key) for discard_key in keys_to_discard)} + new_sd.update(preserved_keys) + + config = { + "image_model": "flux", + "axes_dim": [16, 56, 56], + "in_channels": 16, + "depth": 5, + "depth_single_blocks": 10, + "context_in_dim": 4096, + "num_heads": 24, + "guidance_embed": True, + "hidden_size": 3072, + "mlp_ratio": 4.0, + "theta": 10000, + "qkv_bias": True, + "vec_in_dim": 768 + } + + device = model_management.get_torch_device() + + if weight_dtype == "fp8_e4m3fn": + dtype = torch.float8_e4m3fn + operations = ops.manual_cast + elif weight_dtype == "fp8_e5m2": + dtype = torch.float8_e5m2 + operations = ops.manual_cast + else: + dtype = torch.bfloat16 + operations = ops.disable_weight_init + + control_model = controlnet_class(operations=operations, device=device, dtype=dtype, **config) + control_model = controlnet_load_state_dict(control_model, new_sd) + extra_conds = ['y', 'guidance', 'control_type'] + latent_format = latent_formats.SD3() + # TODO check manual cast dtype + control = ControlNet(control_model, compression_ratio=1, load_device=device, manual_cast_dtype=torch.bfloat16, + extra_conds=extra_conds, latent_format=latent_format, ckpt_name=full_path) + return control + + def load_controlnet_flux_xlabs(sd): model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd) control_model = flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) @@ -440,9 +504,13 @@ 
def load_controlnet_flux_xlabs(sd): return control -def load_controlnet(ckpt_path, model=None): +def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]): controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True) - if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT + if "controlnet_mode_embedder.weight" in controlnet_data: + return load_controlnet_flux_instantx(controlnet_data, InstantXControlNetFluxFormat2, weight_dtype, ckpt_path) + if "controlnet_mode_embedder.fc.weight" in controlnet_data: + return load_controlnet_flux_instantx(controlnet_data, InstantXControlNetFlux, weight_dtype, ckpt_path) + if 'after_proj_list.18.bias' in controlnet_data.keys(): # Hunyuan DiT return load_controlnet_hunyuandit(controlnet_data) if "lora_controlnet" in controlnet_data: return ControlLora(controlnet_data) @@ -450,7 +518,7 @@ def load_controlnet(ckpt_path, model=None): controlnet_config = None supported_inference_dtypes = None - if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format + if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: # diffusers format controlnet_config = model_detection.unet_config_from_diffusers_unet(controlnet_data) diffusers_keys = utils.unet_to_diffusers(controlnet_config) diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight" @@ -490,7 +558,7 @@ def load_controlnet(ckpt_path, model=None): if k in controlnet_data: new_sd[diffusers_keys[k]] = controlnet_data.pop(k) - if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet + if "control_add_embedding.linear_1.bias" in controlnet_data: # Union Controlnet controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0] for k in list(controlnet_data.keys()): new_k = k.replace('.attn.in_proj_', '.attn.in_proj.') @@ -500,7 +568,7 @@ def load_controlnet(ckpt_path, model=None): if len(leftover_keys) > 0: logging.warning("leftover keys: {}".format(leftover_keys)) controlnet_data = new_sd - elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format + elif "controlnet_blocks.0.weight" in controlnet_data: # SD3 diffusers format if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data: return load_controlnet_flux_xlabs(controlnet_data) else: @@ -558,6 +626,7 @@ def load_controlnet(ckpt_path, model=None): class WeightsLoader(torch.nn.Module): pass + w = WeightsLoader() w.control_model = control_model missing, unexpected = w.load_state_dict(controlnet_data, strict=False) @@ -572,12 +641,13 @@ def load_controlnet(ckpt_path, model=None): global_average_pooling = False filename = os.path.splitext(ckpt_path)[0] - if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling + if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): # TODO: smarter way of enabling global_average_pooling global_average_pooling = True control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype, ckpt_name=filename) return control + class T2IAdapter(ControlBase): def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None): super().__init__(device) @@ -633,13 +703,14 @@ class T2IAdapter(ControlBase): self.copy_to(c) return c + def load_t2i_adapter(t2i_data): compression_ratio = 8 upscale_algorithm = 'nearest-exact' if 'adapter' in t2i_data: t2i_data = t2i_data['adapter'] - if 
'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format + if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: # diffusers format prefix_replace = {} for i in range(4): for j in range(2): @@ -663,7 +734,7 @@ def load_t2i_adapter(t2i_data): xl = False if cin == 256 or cin == 768: xl = True - model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl) + model_ad = adapter.Adapter(cin=cin, channels=[channel, channel * 2, channel * 4, channel * 4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl) elif "backbone.0.0.weight" in keys: model_ad = cascade_controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63]) compression_ratio = 32 diff --git a/comfy/ldm/flux/controlnet_instantx.py b/comfy/ldm/flux/controlnet_instantx.py new file mode 100644 index 000000000..5fd1f8026 --- /dev/null +++ b/comfy/ldm/flux/controlnet_instantx.py @@ -0,0 +1,309 @@ +# Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py + +import numbers + +import numpy as np +import torch +import torch.nn.functional as F +from diffusers.models.normalization import AdaLayerNormContinuous +from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock +from diffusers.utils.import_utils import is_torch_version +from einops import rearrange, repeat +from torch import Tensor, nn + +import comfy.ldm.common_dit +from comfy.ldm.flux.layers import (timestep_embedding) +from comfy.ldm.flux.model import Flux + +if is_torch_version(">=", "2.1.0"): + LayerNorm = nn.LayerNorm +else: + # Has optional bias parameter compared to torch layer norm + # TODO: replace with torch layernorm once min required torch version >= 2.1 + class LayerNorm(nn.Module): + def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True): + super().__init__() + + self.eps = eps + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + self.bias = nn.Parameter(torch.zeros(dim)) if bias else None + else: + self.weight = None + self.bias = None + + def forward(self, input): + return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps) + + +class FluxUnionControlNetModeEmbedder(nn.Module): + def __init__(self, num_mode, out_channels): + super().__init__() + self.mode_embber = nn.Embedding(num_mode, out_channels) + self.norm = nn.LayerNorm(out_channels, eps=1e-6) + self.fc = nn.Linear(out_channels, out_channels) + + def forward(self, x): + x_emb = self.mode_embber(x) + x_emb = self.norm(x_emb) + x_emb = self.fc(x_emb) + x_emb = x_emb[:, 0] + return x_emb + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +# YiYi to-do: refactor rope related functions/classes +def apply_rope(xq, xk, freqs_cis): + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + +class FluxUnionControlNetInputEmbedder(nn.Module): + def __init__(self, in_channels, out_channels, num_attention_heads=24, mlp_ratio=4.0, attention_head_dim=128, dtype=None, 
device=None, operations=None, depth=2): + super().__init__() + self.x_embedder = nn.Sequential(nn.LayerNorm(in_channels), nn.Linear(in_channels, out_channels)) + self.norm = AdaLayerNormContinuous(out_channels, out_channels, elementwise_affine=False, eps=1e-6) + self.fc = nn.Linear(out_channels, out_channels) + self.emb_embedder = nn.Sequential(nn.LayerNorm(out_channels), nn.Linear(out_channels, out_channels)) + + """ self.single_blocks = nn.ModuleList( + [ + SingleStreamBlock( + out_channels, num_attention_heads, dtype=dtype, device=device, operations=operations + ) + for i in range(2) + ] + ) """ + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=out_channels, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for i in range(depth) + ] + ) + + self.out = zero_module(nn.Linear(out_channels, out_channels)) + + def forward(self, x, mode_emb): + mode_token = self.emb_embedder(mode_emb)[:, None] + x_emb = self.fc(self.norm(self.x_embedder(x), mode_emb)) + hidden_states = torch.cat([mode_token, x_emb], dim=1) + for index_block, block in enumerate(self.single_transformer_blocks): + hidden_states = block( + hidden_states=hidden_states, + temb=mode_emb, + ) + hidden_states = self.out(hidden_states) + res = hidden_states[:, 1:] + return res + + +class InstantXControlNetFlux(Flux): + def __init__(self, image_model=None, dtype=None, device=None, operations=None, joint_attention_dim=4096, **kwargs): + kwargs["depth"] = 0 + kwargs["depth_single_blocks"] = 0 + depth_single_blocks_controlnet = kwargs.pop("depth_single_blocks_controlnet", 2) + super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs) + + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.hidden_size, + num_attention_heads=24, + attention_head_dim=128, + ).to(dtype=dtype) + for i in range(5) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=self.hidden_size, + num_attention_heads=24, + attention_head_dim=128, + ).to(dtype=dtype) + for i in range(10) + ] + ) + + self.require_vae = True + # add ControlNet blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(len(self.transformer_blocks)): + controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks.append(controlnet_block) + + self.controlnet_single_blocks = nn.ModuleList([]) + for _ in range(len(self.single_transformer_blocks)): + controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) + controlnet_block = zero_module(controlnet_block) + self.controlnet_single_blocks.append(controlnet_block) + + # TODO support both union and unimodal + self.union = True # num_mode is not None + num_mode = 10 + if self.union: + self.controlnet_mode_embedder = zero_module(FluxUnionControlNetModeEmbedder(num_mode, self.hidden_size)).to(device=device, dtype=dtype) + self.controlnet_x_embedder = FluxUnionControlNetInputEmbedder(self.in_channels, self.hidden_size, operations=operations, depth=depth_single_blocks_controlnet).to(device=device, dtype=dtype) + self.controlnet_mode_token_embedder = nn.Sequential(nn.LayerNorm(self.hidden_size, eps=1e-6), nn.Linear(self.hidden_size, self.hidden_size)).to(device=device, dtype=dtype) + else: + self.controlnet_x_embedder = zero_module(torch.nn.Linear(self.in_channels, self.hidden_size)).to(device=device, 
dtype=dtype) + self.gradient_checkpointing = False + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + def set_hint_latents(self, hint_latents): + vae_shift_factor = 0.1159 + vae_scaling_factor = 0.3611 + num_channels_latents = self.in_channels // 4 + hint_latents = (hint_latents - vae_shift_factor) * vae_scaling_factor + + height, width = hint_latents.shape[2:] + hint_latents = self._pack_latents( + hint_latents, + hint_latents.shape[0], + num_channels_latents, + height, + width, + ) + self.hint_latents = hint_latents.to(device=self.device, dtype=self.dtype) + + def forward_orig( + self, + img: Tensor, + img_ids: Tensor, + controlnet_cond: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor = None, + controlnet_mode=None + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + batch_size = img.shape[0] + + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256).to(self.dtype)) + if self.params.guidance_embed: + vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(self.dtype))) + vec.add_(self.vector_in(y)) + + if self.union: + if controlnet_mode is None: + raise ValueError('using union-controlnet, but controlnet_mode is not a list or is empty') + controlnet_mode = torch.tensor([[controlnet_mode]], device=self.device) + emb_controlnet_mode = self.controlnet_mode_embedder(controlnet_mode).to(self.dtype) + vec = vec + emb_controlnet_mode + img = img + self.controlnet_x_embedder(controlnet_cond, emb_controlnet_mode) + else: + img = img + self.controlnet_x_embedder(controlnet_cond) + + txt = self.txt_in(txt) + + if self.union: + token_controlnet_mode = self.controlnet_mode_token_embedder(emb_controlnet_mode)[:, None] + token_controlnet_mode = token_controlnet_mode.expand(txt.size(0), -1, -1) + txt = torch.cat([token_controlnet_mode, txt], dim=1) + txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1) + + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids).to(dtype=self.dtype, device=self.device) + + block_res_samples = () + for block in self.transformer_blocks: + txt, img = block(hidden_states=img, encoder_hidden_states=txt, temb=vec, image_rotary_emb=pe) + block_res_samples = block_res_samples + (img,) + + img = torch.cat([txt, img], dim=1) + + single_block_res_samples = () + for block in self.single_transformer_blocks: + img = block(hidden_states=img, temb=vec, image_rotary_emb=pe) + single_block_res_samples = single_block_res_samples + (img[:, txt.shape[1]:],) + + controlnet_block_res_samples = () + for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): + block_res_sample = controlnet_block(block_res_sample) + controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) + + controlnet_single_block_res_samples = () + for single_block_res_sample, single_controlnet_block in zip(single_block_res_samples, self.controlnet_single_blocks): + single_block_res_sample = single_controlnet_block(single_block_res_sample) + controlnet_single_block_res_samples = controlnet_single_block_res_samples + 
(single_block_res_sample,) + + n_single_blocks = 38 + n_double_blocks = 19 + + # Expand controlnet_block_res_samples to match n_double_blocks + expanded_controlnet_block_res_samples = [] + interval_control_double = int(np.ceil(n_double_blocks / len(controlnet_block_res_samples))) + for i in range(n_double_blocks): + index = i // interval_control_double + expanded_controlnet_block_res_samples.append(controlnet_block_res_samples[index]) + + # Expand controlnet_single_block_res_samples to match n_single_blocks + expanded_controlnet_single_block_res_samples = [] + interval_control_single = int(np.ceil(n_single_blocks / len(controlnet_single_block_res_samples))) + for i in range(n_single_blocks): + index = i // interval_control_single + expanded_controlnet_single_block_res_samples.append(controlnet_single_block_res_samples[index]) + + return { + "input": expanded_controlnet_block_res_samples, + "output": expanded_controlnet_single_block_res_samples + } + + def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs): + bs, c, h, w = x.shape + patch_size = 2 + x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size)) + + height_control_image, width_control_image = hint.shape[2:] + num_channels_latents = self.in_channels // 4 + hint = self._pack_latents( + hint, + hint.shape[0], + num_channels_latents, + height_control_image, + width_control_image, + ) + img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + + h_len = ((h + (patch_size // 2)) // patch_size) + w_len = ((w + (patch_size // 2)) // patch_size) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type) diff --git a/comfy/ldm/flux/controlnet_instantx_format2.py b/comfy/ldm/flux/controlnet_instantx_format2.py new file mode 100644 index 000000000..0a5abe920 --- /dev/null +++ b/comfy/ldm/flux/controlnet_instantx_format2.py @@ -0,0 +1,227 @@ +# Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py + +import numbers + +import numpy as np +import torch +import torch.nn.functional as F +from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock +from diffusers.utils.import_utils import is_torch_version +from einops import rearrange, repeat +from torch import Tensor, nn + +from .layers import timestep_embedding +from .model import Flux +from ..common_dit import pad_to_patch_size + +if is_torch_version(">=", "2.1.0"): + LayerNorm = nn.LayerNorm +else: + # Has optional bias parameter compared to torch layer norm + # TODO: replace with torch layernorm once min required torch version >= 2.1 + class LayerNorm(nn.Module): + def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True): + super().__init__() + + self.eps = eps + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + self.bias = nn.Parameter(torch.zeros(dim)) if bias else None 
+ else: + self.weight = None + self.bias = None + + def forward(self, input): + return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +# YiYi to-do: refactor rope related functions/classes +def apply_rope(xq, xk, freqs_cis): + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) + + +class InstantXControlNetFluxFormat2(Flux): + def __init__(self, image_model=None, dtype=None, device=None, operations=None, joint_attention_dim=4096, **kwargs): + kwargs["depth"] = 0 + kwargs["depth_single_blocks"] = 0 + depth_single_blocks_controlnet = kwargs.pop("depth_single_blocks_controlnet", 2) + super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs) + + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.hidden_size, + num_attention_heads=24, + attention_head_dim=128, + ).to(dtype=dtype) + for i in range(5) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=self.hidden_size, + num_attention_heads=24, + attention_head_dim=128, + ).to(dtype=dtype) + for i in range(10) + ] + ) + + self.require_vae = True + # add ControlNet blocks + self.controlnet_blocks = nn.ModuleList([]) + for _ in range(len(self.transformer_blocks)): + controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks.append(controlnet_block) + + self.controlnet_single_blocks = nn.ModuleList([]) + for _ in range(len(self.single_transformer_blocks)): + controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device) + controlnet_block = zero_module(controlnet_block) + self.controlnet_single_blocks.append(controlnet_block) + + # TODO support both union and unimodal + self.union = True # num_mode is not None + num_mode = 10 + if self.union: + self.controlnet_mode_embedder = nn.Embedding(num_mode, self.hidden_size) + self.controlnet_x_embedder = zero_module(operations.Linear(self.in_channels, self.hidden_size).to(device=device, dtype=dtype)) + self.gradient_checkpointing = False + + @staticmethod + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + def forward_orig( + self, + img: Tensor, + img_ids: Tensor, + controlnet_cond: Tensor, + txt: Tensor, + txt_ids: Tensor, + timesteps: Tensor, + y: Tensor, + guidance: Tensor = None, + controlnet_mode=None + ) -> Tensor: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + batch_size = img.shape[0] + + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256).to(self.dtype)) + if self.params.guidance_embed: + vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(self.dtype))) + 
vec.add_(self.vector_in(y)) + + txt = self.txt_in(txt) + + if self.union: + if controlnet_mode is None: + raise ValueError('using union-controlnet, but controlnet_mode is not a list or is empty') + controlnet_mode = torch.tensor(controlnet_mode).to(self.device, dtype=torch.long) + controlnet_mode = controlnet_mode.reshape([-1, 1]) + emb_controlnet_mode = self.controlnet_mode_embedder(controlnet_mode).to(self.dtype) + txt = torch.cat([emb_controlnet_mode, txt], dim=1) + txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1) + + img = img + self.controlnet_x_embedder(controlnet_cond) + + txt_ids = txt_ids.expand(img_ids.size(0), -1, -1) + ids = torch.cat((txt_ids, img_ids), dim=1) + pe = self.pe_embedder(ids) + + block_res_samples = () + for block in self.transformer_blocks: + txt, img = block(hidden_states=img, encoder_hidden_states=txt, temb=vec, image_rotary_emb=pe) + block_res_samples = block_res_samples + (img,) + + img = torch.cat([txt, img], dim=1) + + single_block_res_samples = () + for block in self.single_transformer_blocks: + img = block(hidden_states=img, temb=vec, image_rotary_emb=pe) + single_block_res_samples = single_block_res_samples + (img[:, txt.shape[1]:],) + + controlnet_block_res_samples = () + for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): + block_res_sample = controlnet_block(block_res_sample) + controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) + + controlnet_single_block_res_samples = () + for single_block_res_sample, single_controlnet_block in zip(single_block_res_samples, self.controlnet_single_blocks): + single_block_res_sample = single_controlnet_block(single_block_res_sample) + controlnet_single_block_res_samples = controlnet_single_block_res_samples + (single_block_res_sample,) + + n_single_blocks = 38 + n_double_blocks = 19 + + # Expand controlnet_block_res_samples to match n_double_blocks + expanded_controlnet_block_res_samples = [] + interval_control_double = int(np.ceil(n_double_blocks / len(controlnet_block_res_samples))) + for i in range(n_double_blocks): + index = i // interval_control_double + expanded_controlnet_block_res_samples.append(controlnet_block_res_samples[index]) + + # Expand controlnet_single_block_res_samples to match n_single_blocks + expanded_controlnet_single_block_res_samples = [] + interval_control_single = int(np.ceil(n_single_blocks / len(controlnet_single_block_res_samples))) + for i in range(n_single_blocks): + index = i // interval_control_single + expanded_controlnet_single_block_res_samples.append(controlnet_single_block_res_samples[index]) + + return { + "input": expanded_controlnet_block_res_samples, + "output": expanded_controlnet_single_block_res_samples + } + + def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs): + bs, c, h, w = x.shape + patch_size = 2 + x = pad_to_patch_size(x, (patch_size, patch_size)) + + height_control_image, width_control_image = hint.shape[2:] + num_channels_latents = self.in_channels // 4 + hint = self._pack_latents( + hint, + hint.shape[0], + num_channels_latents, + height_control_image, + width_control_image, + ) + img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + + h_len = ((h + (patch_size // 2)) // patch_size) + w_len = ((w + (patch_size // 2)) // patch_size) + img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, 
dtype=x.dtype)[:, None] + img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :] + img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) + + txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index bdc4a977c..cc74fccd6 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -40,6 +40,7 @@ class Flux(nn.Module): def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs): super().__init__() + self.device = device self.dtype = dtype params = FluxParams(**kwargs) self.params = params diff --git a/comfy/ldm/flux/weight_dtypes.py b/comfy/ldm/flux/weight_dtypes.py new file mode 100644 index 000000000..0d76f751f --- /dev/null +++ b/comfy/ldm/flux/weight_dtypes.py @@ -0,0 +1 @@ +FLUX_WEIGHT_DTYPES = ["default", "fp8_e4m3fn", "fp8_e5m2"] diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 56c4ae708..bf57afc6c 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -385,6 +385,8 @@ KNOWN_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([ HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0-promax.safetensors"), HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0.safetensors"), HuggingFile("InstantX/FLUX.1-dev-Controlnet-Canny", "diffusion_pytorch_model.safetensors", save_with_filename="instantx-flux.1-dev-controlnet-canny.safetensors"), + HuggingFile("InstantX/FLUX.1-dev-Controlnet-Union", "diffusion_pytorch_model.safetensors", save_with_filename="instantx-flux.1-dev-controlnet-union.safetensors"), + HuggingFile("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", save_with_filename="shakker-labs-flux.1-dev-controlnet-union-pro.safetensors"), ], folder_name="controlnet") KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([ diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index 2099576c7..b4368dd97 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -27,6 +27,7 @@ from ..cmd import folder_paths, latent_preview from ..component_model.tensor_types import RGBImage from ..execution_context import current_execution_context from ..images import open_image +from ..ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS from ..nodes.common import MAX_RESOLUTION from .. 
import controlnet @@ -756,6 +757,27 @@ class ControlNetLoader: controlnet_ = controlnet.load_controlnet(controlnet_path) return (controlnet_,) + +class ControlNetLoaderWeights: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "control_net_name": (get_filename_list_with_downloadable("controlnet", KNOWN_CONTROLNETS),), + "weight_dtype": (FLUX_WEIGHT_DTYPES,) + } + } + + RETURN_TYPES = ("CONTROL_NET",) + FUNCTION = "load_controlnet" + + CATEGORY = "loaders" + + def load_controlnet(self, control_net_name, weight_dtype): + controlnet_path = get_or_download("controlnet", control_net_name, KNOWN_CONTROLNETS) + controlnet_ = controlnet.load_controlnet(controlnet_path, weight_dtype=weight_dtype) + return (controlnet_,) + class DiffControlNetLoader: @classmethod def INPUT_TYPES(s): @@ -854,7 +876,7 @@ class UNETLoader: @classmethod def INPUT_TYPES(s): return {"required": { "unet_name": (get_filename_list_with_downloadable("diffusion_models", KNOWN_UNET_MODELS),), - "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],) + "weight_dtype": (FLUX_WEIGHT_DTYPES,) }} RETURN_TYPES = ("MODEL",) FUNCTION = "load_unet" @@ -1882,6 +1904,7 @@ NODE_CLASS_MAPPINGS = { "ControlNetApply": ControlNetApply, "ControlNetApplyAdvanced": ControlNetApplyAdvanced, "ControlNetLoader": ControlNetLoader, + "ControlNetLoaderWeights": ControlNetLoaderWeights, "DiffControlNetLoader": DiffControlNetLoader, "StyleModelLoader": StyleModelLoader, "CLIPVisionLoader": CLIPVisionLoader, @@ -1914,6 +1937,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LoraLoader": "Load LoRA", "CLIPLoader": "Load CLIP", "ControlNetLoader": "Load ControlNet Model", + "ControlNetLoaderWeights": "Load ControlNet Model (Weights)", "DiffControlNetLoader": "Load ControlNet Model (diff)", "StyleModelLoader": "Load Style Model", "CLIPVisionLoader": "Load CLIP Vision", diff --git a/tests/inference/workflows/flux-controlnet-0.json b/tests/inference/workflows/flux-controlnet-0.json new file mode 100644 index 000000000..1e1347957 --- /dev/null +++ b/tests/inference/workflows/flux-controlnet-0.json @@ -0,0 +1,314 @@ +{ + "1": { + "inputs": { + "noise": [ + "2", + 0 + ], + "guider": [ + "3", + 0 + ], + "sampler": [ + "6", + 0 + ], + "sigmas": [ + "7", + 0 + ], + "latent_image": [ + "9", + 0 + ] + }, + "class_type": "SamplerCustomAdvanced", + "_meta": { + "title": "SamplerCustomAdvanced" + } + }, + "2": { + "inputs": { + "noise_seed": 868192789249410 + }, + "class_type": "RandomNoise", + "_meta": { + "title": "RandomNoise" + } + }, + "3": { + "inputs": { + "model": [ + "12", + 0 + ], + "conditioning": [ + "17", + 0 + ] + }, + "class_type": "BasicGuider", + "_meta": { + "title": "BasicGuider" + } + }, + "4": { + "inputs": { + "guidance": 4, + "conditioning": [ + "13", + 0 + ] + }, + "class_type": "FluxGuidance", + "_meta": { + "title": "FluxGuidance" + } + }, + "6": { + "inputs": { + "sampler_name": "euler" + }, + "class_type": "KSamplerSelect", + "_meta": { + "title": "KSamplerSelect" + } + }, + "7": { + "inputs": { + "scheduler": "ddim_uniform", + "steps": 10, + "denoise": 1, + "model": [ + "12", + 0 + ] + }, + "class_type": "BasicScheduler", + "_meta": { + "title": "BasicScheduler" + } + }, + "9": { + "inputs": { + "width": [ + "27", + 4 + ], + "height": [ + "27", + 5 + ], + "batch_size": 1 + }, + "class_type": "EmptySD3LatentImage", + "_meta": { + "title": "EmptySD3LatentImage" + } + }, + "10": { + "inputs": { + "samples": [ + "1", + 0 + ], + "vae": [ + "11", + 0 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "11": { + 
"inputs": { + "vae_name": "ae.safetensors" + }, + "class_type": "VAELoader", + "_meta": { + "title": "Load VAE" + } + }, + "12": { + "inputs": { + "unet_name": "flux1-dev.safetensors", + "weight_dtype": "fp8_e4m3fn" + }, + "class_type": "UNETLoader", + "_meta": { + "title": "Load Diffusion Model" + } + }, + "13": { + "inputs": { + "text": "A photo of a girl.", + "clip": [ + "15", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "15": { + "inputs": { + "clip_name1": "clip_l.safetensors", + "clip_name2": "t5xxl_fp16.safetensors", + "type": "flux" + }, + "class_type": "DualCLIPLoader", + "_meta": { + "title": "DualCLIPLoader" + } + }, + "16": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "10", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + }, + "17": { + "inputs": { + "strength": 0.6, + "start_percent": 0, + "end_percent": 1, + "positive": [ + "4", + 0 + ], + "negative": [ + "21", + 0 + ], + "control_net": [ + "24", + 0 + ], + "vae": [ + "11", + 0 + ], + "image": [ + "25", + 0 + ] + }, + "class_type": "ControlNetApplySD3", + "_meta": { + "title": "ControlNetApply SD3 and HunyuanDiT" + } + }, + "19": { + "inputs": { + "value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png", + "name": "", + "title": "", + "description": "", + "__required": true + }, + "class_type": "ImageRequestParameter", + "_meta": { + "title": "ImageRequestParameter" + } + }, + "21": { + "inputs": { + "text": "", + "clip": [ + "15", + 0 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "24": { + "inputs": { + "type": "canny (InstantX)", + "control_net": [ + "28", + 0 + ] + }, + "class_type": "SetUnionControlNetType", + "_meta": { + "title": "SetUnionControlNetType" + } + }, + "25": { + "inputs": { + "low_threshold": 0.4, + "high_threshold": 0.8, + "image": [ + "31", + 0 + ] + }, + "class_type": "Canny", + "_meta": { + "title": "Canny" + } + }, + "26": { + "inputs": { + "images": [ + "25", + 0 + ] + }, + "class_type": "PreviewImage", + "_meta": { + "title": "Preview Image" + } + }, + "27": { + "inputs": { + "image": [ + "25", + 0 + ] + }, + "class_type": "Image Size to Number", + "_meta": { + "title": "Image Size to Number" + } + }, + "28": { + "inputs": { + "control_net_name": "shakker-labs-flux.1-dev-controlnet-union-pro.safetensors", + "weight_dtype": "default" + }, + "class_type": "ControlNetLoaderWeights", + "_meta": { + "title": "Load ControlNet Model (Weights)" + } + }, + "31": { + "inputs": { + "image_gen_width": 1024, + "image_gen_height": 1024, + "resize_mode": "Crop and Resize", + "hint_image": [ + "19", + 0 + ] + }, + "class_type": "HintImageEnchance", + "_meta": { + "title": "Enchance And Resize Hint Images" + } + } +} \ No newline at end of file