Upstream InstantX Union ControlNet support for Flux

doctorpangloss 2024-08-26 16:54:29 -07:00
parent 48ca1a4910
commit f49bcd4f3c
9 changed files with 993 additions and 37 deletions

View File

@ -7,4 +7,11 @@ UNION_CONTROLNET_TYPES = {
"segment": 5,
"tile": 6,
"repaint": 7,
"canny (InstantX)": 0,
"tile (InstantX)": 1,
"depth (InstantX)": 2,
"blur (InstantX)": 3,
"pose (InstantX)": 4,
"gray (InstantX)": 5,
"lq (InstantX)": 6
}
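# The "(InstantX)" entries reuse indices 0-6 for the Flux union model's own mode embedding;
# the selected integer is forwarded as the control_type extra condition and consumed as
# controlnet_mode by the InstantX controlnet's mode embedder.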

View File

@ -16,29 +16,31 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch
from enum import Enum
import logging
import math
import os
import logging
from enum import Enum
from . import utils
from . import model_management
import torch
from . import latent_formats
from . import model_detection
from . import model_management
from . import model_patcher
from . import ops
from . import latent_formats
from . import utils
from .cldm import cldm, mmdit
from .t2i_adapter import adapter
from .ldm import hydit, flux
from .ldm.cascade import controlnet as cascade_controlnet
from .ldm.flux.controlnet_instantx import InstantXControlNetFlux
from .ldm.flux.controlnet_instantx_format2 import InstantXControlNetFluxFormat2
from .ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
from .t2i_adapter import adapter
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
#print(current_batch_size, target_batch_size)
# print(current_batch_size, target_batch_size)
if current_batch_size == 1:
return tensor
@ -54,10 +56,12 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
else:
return torch.cat([tensor] * batched_number, dim=0)
class StrengthType(Enum):
CONSTANT = 1
LINEAR_UP = 2
class ControlBase:
def __init__(self, device=None):
self.cond_hint_original = None
@ -129,7 +133,7 @@ class ControlBase:
return 0
def control_merge(self, control, control_prev, output_dtype):
out = {'input':[], 'middle':[], 'output': []}
out = {'input': [], 'middle': [], 'output': []}
for key in control:
control_output = control[key]
@ -140,7 +144,7 @@ class ControlBase:
if self.global_average_pooling:
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
if x not in applied_to: # memory saving strategy, allow shared tensors and only apply strength to shared tensors once
applied_to.add(x)
if self.strength_type == StrengthType.CONSTANT:
x *= self.strength
@ -166,7 +170,7 @@ class ControlBase:
if o[i].shape[0] < prev_val.shape[0]:
o[i] = prev_val + o[i]
else:
o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
o[i] = prev_val + o[i] # TODO: change back to inplace add if shared tensors stop being an issue
return out
def set_extra_arg(self, argument, value=None):
@ -258,10 +262,11 @@ class ControlNet(ControlBase):
self.model_sampling_current = None
super().cleanup()
class ControlLoraOps:
class Linear(torch.nn.Module, ops.CastWeightBiasOp):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.in_features = in_features
@ -280,18 +285,18 @@ class ControlLoraOps:
class Conv2d(torch.nn.Module, ops.CastWeightBiasOp):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode='zeros',
device=None,
dtype=None
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode='zeros',
device=None,
dtype=None
):
super().__init__()
self.in_channels = in_channels
@ -310,7 +315,6 @@ class ControlLoraOps:
self.up = None
self.down = None
def forward(self, input):
weight, bias = ops.cast_bias_weight(self, input)
if self.up is not None:
@ -339,6 +343,7 @@ class ControlLora(ControlNet):
else:
class control_lora_ops(ControlLoraOps, ops.manual_cast):
pass
dtype = self.manual_cast_dtype
controlnet_config["operations"] = control_lora_ops
@ -377,6 +382,7 @@ class ControlLora(ControlNet):
def inference_memory_requirements(self, dtype):
return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
def controlnet_config(sd):
model_config = model_detection.model_config_from_unet(sd, "", True)
@ -394,6 +400,7 @@ def controlnet_config(sd):
offload_device = model_management.unet_offload_device()
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
def controlnet_load_state_dict(control_model, sd):
missing, unexpected = control_model.load_state_dict(sd, strict=False)
@ -404,6 +411,7 @@ def controlnet_load_state_dict(control_model, sd):
logging.debug("unexpected controlnet keys: {}".format(unexpected))
return control_model
def load_controlnet_mmdit(sd):
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
@ -415,7 +423,7 @@ def load_controlnet_mmdit(sd):
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = latent_formats.SD3()
latent_format.shift_factor = 0 #SD3 controlnet weirdness
latent_format.shift_factor = 0 # SD3 controlnet weirdness
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
@ -431,6 +439,62 @@ def load_controlnet_hunyuandit(controlnet_data):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
return control
def load_controlnet_flux_instantx(sd, controlnet_class, weight_dtype, full_path):
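# The InstantX checkpoints mix diffusers-style transformer blocks with controlnet-specific heads.
# Keep CPU copies of the controlnet_* embedders and the (single_)transformer_blocks before
# convert_diffusers_mmdit renames everything, then drop the converted double_/single_blocks and
# restore the preserved tensors so the state dict matches this model's module names.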
keys_to_keep = [
"controlnet_",
"single_transformer_blocks",
"transformer_blocks"
]
preserved_keys = {k: v.cpu() for k, v in sd.items() if any(k.startswith(key) for key in keys_to_keep)}
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
keys_to_discard = [
"double_blocks",
"single_blocks"
]
new_sd = {k: v for k, v in new_sd.items() if not any(k.startswith(discard_key) for discard_key in keys_to_discard)}
new_sd.update(preserved_keys)
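# Reduced Flux configuration for the controlnet trunk: 5 double-stream and 10 single-stream
# blocks with the same hidden size, head count and embedding dims as FLUX.1-dev.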
config = {
"image_model": "flux",
"axes_dim": [16, 56, 56],
"in_channels": 16,
"depth": 5,
"depth_single_blocks": 10,
"context_in_dim": 4096,
"num_heads": 24,
"guidance_embed": True,
"hidden_size": 3072,
"mlp_ratio": 4.0,
"theta": 10000,
"qkv_bias": True,
"vec_in_dim": 768
}
device = model_management.get_torch_device()
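# fp8 weights are stored in float8 but must be cast to a compute dtype at runtime, so use the
# manual_cast ops; bfloat16 weights can use the regular weight-init-disabled ops.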
if weight_dtype == "fp8_e4m3fn":
dtype = torch.float8_e4m3fn
operations = ops.manual_cast
elif weight_dtype == "fp8_e5m2":
dtype = torch.float8_e5m2
operations = ops.manual_cast
else:
dtype = torch.bfloat16
operations = ops.disable_weight_init
control_model = controlnet_class(operations=operations, device=device, dtype=dtype, **config)
control_model = controlnet_load_state_dict(control_model, new_sd)
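# Besides the hint image, the union controlnet needs the pooled text vector (y), the guidance
# embedding, and the selected union mode (control_type) from the conditioning.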
extra_conds = ['y', 'guidance', 'control_type']
latent_format = latent_formats.SD3()
# TODO check manual cast dtype
control = ControlNet(control_model, compression_ratio=1, load_device=device, manual_cast_dtype=torch.bfloat16,
extra_conds=extra_conds, latent_format=latent_format, ckpt_name=full_path)
return control
def load_controlnet_flux_xlabs(sd):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
control_model = flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
@ -440,9 +504,13 @@ def load_controlnet_flux_xlabs(sd):
return control
def load_controlnet(ckpt_path, model=None):
def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
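# Distinguish the two InstantX union checkpoint layouts by their mode-embedder keys: a bare
# controlnet_mode_embedder.weight means the mode embedder is a plain nn.Embedding (Format2),
# while controlnet_mode_embedder.fc.weight means the Embedding + LayerNorm + Linear variant.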
if "controlnet_mode_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, InstantXControlNetFluxFormat2, weight_dtype, ckpt_path)
if "controlnet_mode_embedder.fc.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, InstantXControlNetFlux, weight_dtype, ckpt_path)
if 'after_proj_list.18.bias' in controlnet_data.keys(): # Hunyuan DiT
return load_controlnet_hunyuandit(controlnet_data)
if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data)
@ -450,7 +518,7 @@ def load_controlnet(ckpt_path, model=None):
controlnet_config = None
supported_inference_dtypes = None
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: # diffusers format
controlnet_config = model_detection.unet_config_from_diffusers_unet(controlnet_data)
diffusers_keys = utils.unet_to_diffusers(controlnet_config)
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
@ -490,7 +558,7 @@ def load_controlnet(ckpt_path, model=None):
if k in controlnet_data:
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
if "control_add_embedding.linear_1.bias" in controlnet_data: # Union Controlnet
controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0]
for k in list(controlnet_data.keys()):
new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
@ -500,7 +568,7 @@ def load_controlnet(ckpt_path, model=None):
if len(leftover_keys) > 0:
logging.warning("leftover keys: {}".format(leftover_keys))
controlnet_data = new_sd
elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
elif "controlnet_blocks.0.weight" in controlnet_data: # SD3 diffusers format
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
return load_controlnet_flux_xlabs(controlnet_data)
else:
@ -558,6 +626,7 @@ def load_controlnet(ckpt_path, model=None):
class WeightsLoader(torch.nn.Module):
pass
w = WeightsLoader()
w.control_model = control_model
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
@ -572,12 +641,13 @@ def load_controlnet(ckpt_path, model=None):
global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): # TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype, ckpt_name=filename)
return control
class T2IAdapter(ControlBase):
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
super().__init__(device)
@ -633,13 +703,14 @@ class T2IAdapter(ControlBase):
self.copy_to(c)
return c
def load_t2i_adapter(t2i_data):
compression_ratio = 8
upscale_algorithm = 'nearest-exact'
if 'adapter' in t2i_data:
t2i_data = t2i_data['adapter']
if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: # diffusers format
prefix_replace = {}
for i in range(4):
for j in range(2):
@ -663,7 +734,7 @@ def load_t2i_adapter(t2i_data):
xl = False
if cin == 256 or cin == 768:
xl = True
model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
model_ad = adapter.Adapter(cin=cin, channels=[channel, channel * 2, channel * 4, channel * 4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
elif "backbone.0.0.weight" in keys:
model_ad = cascade_controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
compression_ratio = 32

View File

@ -0,0 +1,309 @@
# Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
import numbers
import numpy as np
import torch
import torch.nn.functional as F
from diffusers.models.normalization import AdaLayerNormContinuous
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from diffusers.utils.import_utils import is_torch_version
from einops import rearrange, repeat
from torch import Tensor, nn
import comfy.ldm.common_dit
from comfy.ldm.flux.layers import (timestep_embedding)
from comfy.ldm.flux.model import Flux
if is_torch_version(">=", "2.1.0"):
LayerNorm = nn.LayerNorm
else:
# Has optional bias parameter compared to torch layer norm
# TODO: replace with torch layernorm once min required torch version >= 2.1
class LayerNorm(nn.Module):
def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
super().__init__()
self.eps = eps
if isinstance(dim, numbers.Integral):
dim = (dim,)
self.dim = torch.Size(dim)
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
else:
self.weight = None
self.bias = None
def forward(self, input):
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
class FluxUnionControlNetModeEmbedder(nn.Module):
def __init__(self, num_mode, out_channels):
super().__init__()
self.mode_embber = nn.Embedding(num_mode, out_channels)
self.norm = nn.LayerNorm(out_channels, eps=1e-6)
self.fc = nn.Linear(out_channels, out_channels)
def forward(self, x):
x_emb = self.mode_embber(x)
x_emb = self.norm(x_emb)
x_emb = self.fc(x_emb)
x_emb = x_emb[:, 0]
return x_emb
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
# YiYi to-do: refactor rope related functions/classes
def apply_rope(xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
class FluxUnionControlNetInputEmbedder(nn.Module):
def __init__(self, in_channels, out_channels, num_attention_heads=24, mlp_ratio=4.0, attention_head_dim=128, dtype=None, device=None, operations=None, depth=2):
super().__init__()
self.x_embedder = nn.Sequential(nn.LayerNorm(in_channels), nn.Linear(in_channels, out_channels))
self.norm = AdaLayerNormContinuous(out_channels, out_channels, elementwise_affine=False, eps=1e-6)
self.fc = nn.Linear(out_channels, out_channels)
self.emb_embedder = nn.Sequential(nn.LayerNorm(out_channels), nn.Linear(out_channels, out_channels))
""" self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(
out_channels, num_attention_heads, dtype=dtype, device=device, operations=operations
)
for i in range(2)
]
) """
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=out_channels,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for i in range(depth)
]
)
self.out = zero_module(nn.Linear(out_channels, out_channels))
def forward(self, x, mode_emb):
mode_token = self.emb_embedder(mode_emb)[:, None]
x_emb = self.fc(self.norm(self.x_embedder(x), mode_emb))
hidden_states = torch.cat([mode_token, x_emb], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
hidden_states = block(
hidden_states=hidden_states,
temb=mode_emb,
)
hidden_states = self.out(hidden_states)
res = hidden_states[:, 1:]
return res
class InstantXControlNetFlux(Flux):
def __init__(self, image_model=None, dtype=None, device=None, operations=None, joint_attention_dim=4096, **kwargs):
kwargs["depth"] = 0
kwargs["depth_single_blocks"] = 0
depth_single_blocks_controlnet = kwargs.pop("depth_single_blocks_controlnet", 2)
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.hidden_size,
num_attention_heads=24,
attention_head_dim=128,
).to(dtype=dtype)
for i in range(5)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.hidden_size,
num_attention_heads=24,
attention_head_dim=128,
).to(dtype=dtype)
for i in range(10)
]
)
self.require_vae = True
# add ControlNet blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(len(self.single_transformer_blocks)):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
controlnet_block = zero_module(controlnet_block)
self.controlnet_single_blocks.append(controlnet_block)
# TODO support both union and unimodal
self.union = True # num_mode is not None
num_mode = 10
if self.union:
self.controlnet_mode_embedder = zero_module(FluxUnionControlNetModeEmbedder(num_mode, self.hidden_size)).to(device=device, dtype=dtype)
self.controlnet_x_embedder = FluxUnionControlNetInputEmbedder(self.in_channels, self.hidden_size, operations=operations, depth=depth_single_blocks_controlnet).to(device=device, dtype=dtype)
self.controlnet_mode_token_embedder = nn.Sequential(nn.LayerNorm(self.hidden_size, eps=1e-6), nn.Linear(self.hidden_size, self.hidden_size)).to(device=device, dtype=dtype)
else:
self.controlnet_x_embedder = zero_module(torch.nn.Linear(self.in_channels, self.hidden_size)).to(device=device, dtype=dtype)
self.gradient_checkpointing = False
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
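# Pack (B, C, H, W) latents into FLUX's token layout: every 2x2 spatial patch becomes one
# token, producing (B, (H//2)*(W//2), C*4).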
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
def set_hint_latents(self, hint_latents):
vae_shift_factor = 0.1159
vae_scaling_factor = 0.3611
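# Normalize the hint latents with the FLUX autoencoder's shift/scale factors so they match the
# distribution of the model's input latents.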
num_channels_latents = self.in_channels // 4
hint_latents = (hint_latents - vae_shift_factor) * vae_scaling_factor
height, width = hint_latents.shape[2:]
hint_latents = self._pack_latents(
hint_latents,
hint_latents.shape[0],
num_channels_latents,
height,
width,
)
self.hint_latents = hint_latents.to(device=self.device, dtype=self.dtype)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
controlnet_cond: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
controlnet_mode=None
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
batch_size = img.shape[0]
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256).to(self.dtype))
if self.params.guidance_embed:
vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(self.dtype)))
vec.add_(self.vector_in(y))
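# Union mode conditioning: the embedded mode index is added to the global modulation vector,
# conditions the controlnet hint embedder, and is later prepended to the text sequence as an
# extra token (with a duplicated position id).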
if self.union:
if controlnet_mode is None:
raise ValueError('using union-controlnet, but controlnet_mode was not provided')
controlnet_mode = torch.tensor([[controlnet_mode]], device=self.device)
emb_controlnet_mode = self.controlnet_mode_embedder(controlnet_mode).to(self.dtype)
vec = vec + emb_controlnet_mode
img = img + self.controlnet_x_embedder(controlnet_cond, emb_controlnet_mode)
else:
img = img + self.controlnet_x_embedder(controlnet_cond)
txt = self.txt_in(txt)
if self.union:
token_controlnet_mode = self.controlnet_mode_token_embedder(emb_controlnet_mode)[:, None]
token_controlnet_mode = token_controlnet_mode.expand(txt.size(0), -1, -1)
txt = torch.cat([token_controlnet_mode, txt], dim=1)
txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids).to(dtype=self.dtype, device=self.device)
block_res_samples = ()
for block in self.transformer_blocks:
txt, img = block(hidden_states=img, encoder_hidden_states=txt, temb=vec, image_rotary_emb=pe)
block_res_samples = block_res_samples + (img,)
img = torch.cat([txt, img], dim=1)
single_block_res_samples = ()
for block in self.single_transformer_blocks:
img = block(hidden_states=img, temb=vec, image_rotary_emb=pe)
single_block_res_samples = single_block_res_samples + (img[:, txt.shape[1]:],)
controlnet_block_res_samples = ()
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
controlnet_single_block_res_samples = ()
for single_block_res_sample, single_controlnet_block in zip(single_block_res_samples, self.controlnet_single_blocks):
single_block_res_sample = single_controlnet_block(single_block_res_sample)
controlnet_single_block_res_samples = controlnet_single_block_res_samples + (single_block_res_sample,)
n_single_blocks = 38
n_double_blocks = 19
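# FLUX.1-dev has 19 double-stream and 38 single-stream blocks; the 5 double / 10 single
# controlnet residuals are repeated at a fixed interval so every base block receives a
# control signal.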
# Expand controlnet_block_res_samples to match n_double_blocks
expanded_controlnet_block_res_samples = []
interval_control_double = int(np.ceil(n_double_blocks / len(controlnet_block_res_samples)))
for i in range(n_double_blocks):
index = i // interval_control_double
expanded_controlnet_block_res_samples.append(controlnet_block_res_samples[index])
# Expand controlnet_single_block_res_samples to match n_single_blocks
expanded_controlnet_single_block_res_samples = []
interval_control_single = int(np.ceil(n_single_blocks / len(controlnet_single_block_res_samples)))
for i in range(n_single_blocks):
index = i // interval_control_single
expanded_controlnet_single_block_res_samples.append(controlnet_single_block_res_samples[index])
return {
"input": expanded_controlnet_block_res_samples,
"output": expanded_controlnet_single_block_res_samples
}
def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
height_control_image, width_control_image = hint.shape[2:]
num_channels_latents = self.in_channels // 4
hint = self._pack_latents(
hint,
hint.shape[0],
num_channels_latents,
height_control_image,
width_control_image,
)
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type)

View File

@ -0,0 +1,227 @@
# Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
import numbers
import numpy as np
import torch
import torch.nn.functional as F
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from diffusers.utils.import_utils import is_torch_version
from einops import rearrange, repeat
from torch import Tensor, nn
from .layers import timestep_embedding
from .model import Flux
from ..common_dit import pad_to_patch_size
if is_torch_version(">=", "2.1.0"):
LayerNorm = nn.LayerNorm
else:
# Has optional bias parameter compared to torch layer norm
# TODO: replace with torch layernorm once min required torch version >= 2.1
class LayerNorm(nn.Module):
def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
super().__init__()
self.eps = eps
if isinstance(dim, numbers.Integral):
dim = (dim,)
self.dim = torch.Size(dim)
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
else:
self.weight = None
self.bias = None
def forward(self, input):
return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
# YiYi to-do: refactor rope related functions/classes
def apply_rope(xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
class InstantXControlNetFluxFormat2(Flux):
def __init__(self, image_model=None, dtype=None, device=None, operations=None, joint_attention_dim=4096, **kwargs):
kwargs["depth"] = 0
kwargs["depth_single_blocks"] = 0
depth_single_blocks_controlnet = kwargs.pop("depth_single_blocks_controlnet", 2)
super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.hidden_size,
num_attention_heads=24,
attention_head_dim=128,
).to(dtype=dtype)
for i in range(5)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.hidden_size,
num_attention_heads=24,
attention_head_dim=128,
).to(dtype=dtype)
for i in range(10)
]
)
self.require_vae = True
# add ControlNet blocks
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(len(self.single_transformer_blocks)):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
controlnet_block = zero_module(controlnet_block)
self.controlnet_single_blocks.append(controlnet_block)
# TODO support both union and unimodal
self.union = True # num_mode is not None
num_mode = 10
if self.union:
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.hidden_size)
self.controlnet_x_embedder = zero_module(operations.Linear(self.in_channels, self.hidden_size).to(device=device, dtype=dtype))
self.gradient_checkpointing = False
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
controlnet_cond: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
controlnet_mode=None
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
batch_size = img.shape[0]
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256).to(self.dtype))
if self.params.guidance_embed:
vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(self.dtype)))
vec.add_(self.vector_in(y))
txt = self.txt_in(txt)
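# Format2 union conditioning: the mode index is looked up in a plain embedding table and
# prepended to the text tokens (duplicating the first position id); unlike the other layout
# it does not modulate vec.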
if self.union:
if controlnet_mode is None:
raise ValueError('using union-controlnet, but controlnet_mode was not provided')
controlnet_mode = torch.tensor(controlnet_mode).to(self.device, dtype=torch.long)
controlnet_mode = controlnet_mode.reshape([-1, 1])
emb_controlnet_mode = self.controlnet_mode_embedder(controlnet_mode).to(self.dtype)
txt = torch.cat([emb_controlnet_mode, txt], dim=1)
txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1)
img = img + self.controlnet_x_embedder(controlnet_cond)
txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
block_res_samples = ()
for block in self.transformer_blocks:
txt, img = block(hidden_states=img, encoder_hidden_states=txt, temb=vec, image_rotary_emb=pe)
block_res_samples = block_res_samples + (img,)
img = torch.cat([txt, img], dim=1)
single_block_res_samples = ()
for block in self.single_transformer_blocks:
img = block(hidden_states=img, temb=vec, image_rotary_emb=pe)
single_block_res_samples = single_block_res_samples + (img[:, txt.shape[1]:],)
controlnet_block_res_samples = ()
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
controlnet_single_block_res_samples = ()
for single_block_res_sample, single_controlnet_block in zip(single_block_res_samples, self.controlnet_single_blocks):
single_block_res_sample = single_controlnet_block(single_block_res_sample)
controlnet_single_block_res_samples = controlnet_single_block_res_samples + (single_block_res_sample,)
n_single_blocks = 38
n_double_blocks = 19
# Expand controlnet_block_res_samples to match n_double_blocks
expanded_controlnet_block_res_samples = []
interval_control_double = int(np.ceil(n_double_blocks / len(controlnet_block_res_samples)))
for i in range(n_double_blocks):
index = i // interval_control_double
expanded_controlnet_block_res_samples.append(controlnet_block_res_samples[index])
# Expand controlnet_single_block_res_samples to match n_single_blocks
expanded_controlnet_single_block_res_samples = []
interval_control_single = int(np.ceil(n_single_blocks / len(controlnet_single_block_res_samples)))
for i in range(n_single_blocks):
index = i // interval_control_single
expanded_controlnet_single_block_res_samples.append(controlnet_single_block_res_samples[index])
return {
"input": expanded_controlnet_block_res_samples,
"output": expanded_controlnet_single_block_res_samples
}
def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
x = pad_to_patch_size(x, (patch_size, patch_size))
height_control_image, width_control_image = hint.shape[2:]
num_channels_latents = self.in_channels // 4
hint = self._pack_latents(
hint,
hint.shape[0],
num_channels_latents,
height_control_image,
width_control_image,
)
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type)

View File

@ -40,6 +40,7 @@ class Flux(nn.Module):
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.device = device
self.dtype = dtype
params = FluxParams(**kwargs)
self.params = params

View File

@ -0,0 +1 @@
FLUX_WEIGHT_DTYPES = ["default", "fp8_e4m3fn", "fp8_e5m2"]

View File

@ -385,6 +385,8 @@ KNOWN_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0-promax.safetensors"),
HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0.safetensors"),
HuggingFile("InstantX/FLUX.1-dev-Controlnet-Canny", "diffusion_pytorch_model.safetensors", save_with_filename="instantx-flux.1-dev-controlnet-canny.safetensors"),
HuggingFile("InstantX/FLUX.1-dev-Controlnet-Union", "diffusion_pytorch_model.safetensors", save_with_filename="instantx-flux.1-dev-controlnet-union.safetensors"),
HuggingFile("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", save_with_filename="shakker-labs-flux.1-dev-controlnet-union-pro.safetensors"),
], folder_name="controlnet")
KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([

View File

@ -27,6 +27,7 @@ from ..cmd import folder_paths, latent_preview
from ..component_model.tensor_types import RGBImage
from ..execution_context import current_execution_context
from ..images import open_image
from ..ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS
from ..nodes.common import MAX_RESOLUTION
from .. import controlnet
@ -756,6 +757,27 @@ class ControlNetLoader:
controlnet_ = controlnet.load_controlnet(controlnet_path)
return (controlnet_,)
class ControlNetLoaderWeights:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"control_net_name": (get_filename_list_with_downloadable("controlnet", KNOWN_CONTROLNETS),),
"weight_dtype": (FLUX_WEIGHT_DTYPES,)
}
}
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "load_controlnet"
CATEGORY = "loaders"
def load_controlnet(self, control_net_name, weight_dtype):
controlnet_path = get_or_download("controlnet", control_net_name, KNOWN_CONTROLNETS)
controlnet_ = controlnet.load_controlnet(controlnet_path, weight_dtype=weight_dtype)
return (controlnet_,)
class DiffControlNetLoader:
@classmethod
def INPUT_TYPES(s):
@ -854,7 +876,7 @@ class UNETLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "unet_name": (get_filename_list_with_downloadable("diffusion_models", KNOWN_UNET_MODELS),),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
"weight_dtype": (FLUX_WEIGHT_DTYPES,)
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet"
@ -1882,6 +1904,7 @@ NODE_CLASS_MAPPINGS = {
"ControlNetApply": ControlNetApply,
"ControlNetApplyAdvanced": ControlNetApplyAdvanced,
"ControlNetLoader": ControlNetLoader,
"ControlNetLoaderWeights": ControlNetLoaderWeights,
"DiffControlNetLoader": DiffControlNetLoader,
"StyleModelLoader": StyleModelLoader,
"CLIPVisionLoader": CLIPVisionLoader,
@ -1914,6 +1937,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LoraLoader": "Load LoRA",
"CLIPLoader": "Load CLIP",
"ControlNetLoader": "Load ControlNet Model",
"ControlNetLoaderWeights": "Load ControlNet Model (Weights)",
"DiffControlNetLoader": "Load ControlNet Model (diff)",
"StyleModelLoader": "Load Style Model",
"CLIPVisionLoader": "Load CLIP Vision",

View File

@ -0,0 +1,314 @@
{
"1": {
"inputs": {
"noise": [
"2",
0
],
"guider": [
"3",
0
],
"sampler": [
"6",
0
],
"sigmas": [
"7",
0
],
"latent_image": [
"9",
0
]
},
"class_type": "SamplerCustomAdvanced",
"_meta": {
"title": "SamplerCustomAdvanced"
}
},
"2": {
"inputs": {
"noise_seed": 868192789249410
},
"class_type": "RandomNoise",
"_meta": {
"title": "RandomNoise"
}
},
"3": {
"inputs": {
"model": [
"12",
0
],
"conditioning": [
"17",
0
]
},
"class_type": "BasicGuider",
"_meta": {
"title": "BasicGuider"
}
},
"4": {
"inputs": {
"guidance": 4,
"conditioning": [
"13",
0
]
},
"class_type": "FluxGuidance",
"_meta": {
"title": "FluxGuidance"
}
},
"6": {
"inputs": {
"sampler_name": "euler"
},
"class_type": "KSamplerSelect",
"_meta": {
"title": "KSamplerSelect"
}
},
"7": {
"inputs": {
"scheduler": "ddim_uniform",
"steps": 10,
"denoise": 1,
"model": [
"12",
0
]
},
"class_type": "BasicScheduler",
"_meta": {
"title": "BasicScheduler"
}
},
"9": {
"inputs": {
"width": [
"27",
4
],
"height": [
"27",
5
],
"batch_size": 1
},
"class_type": "EmptySD3LatentImage",
"_meta": {
"title": "EmptySD3LatentImage"
}
},
"10": {
"inputs": {
"samples": [
"1",
0
],
"vae": [
"11",
0
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"11": {
"inputs": {
"vae_name": "ae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"12": {
"inputs": {
"unet_name": "flux1-dev.safetensors",
"weight_dtype": "fp8_e4m3fn"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"13": {
"inputs": {
"text": "A photo of a girl.",
"clip": [
"15",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"15": {
"inputs": {
"clip_name1": "clip_l.safetensors",
"clip_name2": "t5xxl_fp16.safetensors",
"type": "flux"
},
"class_type": "DualCLIPLoader",
"_meta": {
"title": "DualCLIPLoader"
}
},
"16": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"10",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"17": {
"inputs": {
"strength": 0.6,
"start_percent": 0,
"end_percent": 1,
"positive": [
"4",
0
],
"negative": [
"21",
0
],
"control_net": [
"24",
0
],
"vae": [
"11",
0
],
"image": [
"25",
0
]
},
"class_type": "ControlNetApplySD3",
"_meta": {
"title": "ControlNetApply SD3 and HunyuanDiT"
}
},
"19": {
"inputs": {
"value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
"name": "",
"title": "",
"description": "",
"__required": true
},
"class_type": "ImageRequestParameter",
"_meta": {
"title": "ImageRequestParameter"
}
},
"21": {
"inputs": {
"text": "",
"clip": [
"15",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"24": {
"inputs": {
"type": "canny (InstantX)",
"control_net": [
"28",
0
]
},
"class_type": "SetUnionControlNetType",
"_meta": {
"title": "SetUnionControlNetType"
}
},
"25": {
"inputs": {
"low_threshold": 0.4,
"high_threshold": 0.8,
"image": [
"31",
0
]
},
"class_type": "Canny",
"_meta": {
"title": "Canny"
}
},
"26": {
"inputs": {
"images": [
"25",
0
]
},
"class_type": "PreviewImage",
"_meta": {
"title": "Preview Image"
}
},
"27": {
"inputs": {
"image": [
"25",
0
]
},
"class_type": "Image Size to Number",
"_meta": {
"title": "Image Size to Number"
}
},
"28": {
"inputs": {
"control_net_name": "shakker-labs-flux.1-dev-controlnet-union-pro.safetensors",
"weight_dtype": "default"
},
"class_type": "ControlNetLoaderWeights",
"_meta": {
"title": "Load ControlNet Model (Weights)"
}
},
"31": {
"inputs": {
"image_gen_width": 1024,
"image_gen_height": 1024,
"resize_mode": "Crop and Resize",
"hint_image": [
"19",
0
]
},
"class_type": "HintImageEnchance",
"_meta": {
"title": "Enchance And Resize Hint Images"
}
}
}