Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 06:40:48 +08:00)

Commit f49bcd4f3c (parent 48ca1a4910)

Upstream InstantX Union ControlNet support for Flux
comfy/controlnet.py

@@ -7,4 +7,11 @@ UNION_CONTROLNET_TYPES = {
     "segment": 5,
     "tile": 6,
     "repaint": 7,
+    "canny (InstantX)": 0,
+    "tile (InstantX)": 1,
+    "depth (InstantX)": 2,
+    "blur (InstantX)": 3,
+    "pose (InstantX)": 4,
+    "gray (InstantX)": 5,
+    "lq (InstantX)": 6
 }
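For orientation (an illustration, not part of the commit): the integer values are mode indices, and the names ending in "(InstantX)" select modes of the new InstantX union models. The SetUnionControlNetType node looks a name up in this table, and the resulting index is what eventually reaches the model as control_type. A rough sketch of that lookup:

    UNION_CONTROLNET_TYPES = {"canny (InstantX)": 0, "tile (InstantX)": 1}  # excerpt

    def set_union_controlnet_type(control_net, type_name):
        # roughly what the node does: attach the mode index as an extra arg
        control_net.set_extra_arg("control_type", [UNION_CONTROLNET_TYPES[type_name]])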
@@ -16,29 +16,31 @@
     along with this program. If not, see <https://www.gnu.org/licenses/>.
 """

 import logging
 import math
 import os
 from enum import Enum

 import torch

 from . import latent_formats
 from . import model_detection
 from . import model_management
 from . import model_patcher
 from . import ops
 from . import utils
 from .cldm import cldm, mmdit
 from .ldm import hydit, flux
 from .ldm.cascade import controlnet as cascade_controlnet
+from .ldm.flux.controlnet_instantx import InstantXControlNetFlux
+from .ldm.flux.controlnet_instantx_format2 import InstantXControlNetFluxFormat2
+from .ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
 from .t2i_adapter import adapter


 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
-    #print(current_batch_size, target_batch_size)
+    # print(current_batch_size, target_batch_size)
     if current_batch_size == 1:
         return tensor
@@ -54,10 +56,12 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
     else:
         return torch.cat([tensor] * batched_number, dim=0)

+
 class StrengthType(Enum):
     CONSTANT = 1
     LINEAR_UP = 2

+
 class ControlBase:
     def __init__(self, device=None):
         self.cond_hint_original = None
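As a standalone illustration of what this helper is for (assumed shapes, not code from the commit): a single conditioning hint is tiled along the batch dimension so it lines up with the latents being sampled, which is what the concatenation branch at the end of the function does.

    import torch

    hint = torch.zeros(1, 3, 64, 64)         # one hint image
    batched = torch.cat([hint] * 4, dim=0)   # broadcast to a batch of 4
    assert batched.shape == (4, 3, 64, 64)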
@@ -129,7 +133,7 @@ class ControlBase:
         return 0

     def control_merge(self, control, control_prev, output_dtype):
-        out = {'input':[], 'middle':[], 'output': []}
+        out = {'input': [], 'middle': [], 'output': []}

         for key in control:
             control_output = control[key]
@@ -140,7 +144,7 @@ class ControlBase:
                     if self.global_average_pooling:
                         x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])

-                    if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
+                    if x not in applied_to:  # memory saving strategy, allow shared tensors and only apply strength to shared tensors once
                         applied_to.add(x)
                         if self.strength_type == StrengthType.CONSTANT:
                             x *= self.strength
@@ -166,7 +170,7 @@ class ControlBase:
                     if o[i].shape[0] < prev_val.shape[0]:
                         o[i] = prev_val + o[i]
                     else:
-                        o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
+                        o[i] = prev_val + o[i]  # TODO: change back to inplace add if shared tensors stop being an issue
         return out

     def set_extra_arg(self, argument, value=None):
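The applied_to set exploits the fact that torch tensors hash by identity, so a residual tensor shared between several outputs is scaled exactly once. A minimal sketch of the idea (standalone, with made-up values):

    import torch

    shared = torch.ones(1, 4)
    outputs = [shared, shared, torch.ones(1, 4)]

    strength = 0.5
    applied_to = set()
    for x in outputs:
        if x not in applied_to:  # identity-based membership for tensors
            applied_to.add(x)
            x *= strength
    # each entry now sums to 2.0; without the guard the shared tensor
    # would have been halved twice and summed to 1.0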
@@ -258,10 +262,11 @@ class ControlNet(ControlBase):
         self.model_sampling_current = None
         super().cleanup()

+
 class ControlLoraOps:
     class Linear(torch.nn.Module, ops.CastWeightBiasOp):
         def __init__(self, in_features: int, out_features: int, bias: bool = True,
                      device=None, dtype=None) -> None:
             factory_kwargs = {'device': device, 'dtype': dtype}
             super().__init__()
             self.in_features = in_features
@@ -280,18 +285,18 @@ class ControlLoraOps:

     class Conv2d(torch.nn.Module, ops.CastWeightBiasOp):
         def __init__(
                 self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 padding_mode='zeros',
                 device=None,
                 dtype=None
         ):
             super().__init__()
             self.in_channels = in_channels
@@ -310,7 +315,6 @@ class ControlLoraOps:
             self.up = None
             self.down = None

-
         def forward(self, input):
             weight, bias = ops.cast_bias_weight(self, input)
             if self.up is not None:
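These Linear and Conv2d replacements carry optional low-rank up/down factors that ControlLora patches in. A hedged sketch of the effective computation in the Linear forward when a delta is attached (names assumed from the surrounding code):

    import torch
    import torch.nn.functional as F

    W = torch.randn(8, 8)                            # base weight
    up, down = torch.randn(8, 2), torch.randn(2, 8)  # low-rank factors
    x = torch.randn(1, 8)
    y = F.linear(x, W + up @ down)                   # W_eff = W + up @ down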
@@ -339,6 +343,7 @@ class ControlLora(ControlNet):
         else:
             class control_lora_ops(ControlLoraOps, ops.manual_cast):
                 pass
+
             dtype = self.manual_cast_dtype

         controlnet_config["operations"] = control_lora_ops
@@ -377,6 +382,7 @@ class ControlLora(ControlNet):
     def inference_memory_requirements(self, dtype):
         return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)

+
 def controlnet_config(sd):
     model_config = model_detection.model_config_from_unet(sd, "", True)
@@ -394,6 +400,7 @@ def controlnet_config(sd):
     offload_device = model_management.unet_offload_device()
     return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device

+
 def controlnet_load_state_dict(control_model, sd):
     missing, unexpected = control_model.load_state_dict(sd, strict=False)
@@ -404,6 +411,7 @@ def controlnet_load_state_dict(control_model, sd):
     logging.debug("unexpected controlnet keys: {}".format(unexpected))
     return control_model

+
 def load_controlnet_mmdit(sd):
     new_sd = model_detection.convert_diffusers_mmdit(sd, "")
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
@@ -415,7 +423,7 @@ def load_controlnet_mmdit(sd):
     control_model = controlnet_load_state_dict(control_model, new_sd)

     latent_format = latent_formats.SD3()
-    latent_format.shift_factor = 0 #SD3 controlnet weirdness
+    latent_format.shift_factor = 0  # SD3 controlnet weirdness
     control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
     return control
@@ -431,6 +439,62 @@ def load_controlnet_hunyuandit(controlnet_data):
     control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
     return control

+
+def load_controlnet_flux_instantx(sd, controlnet_class, weight_dtype, full_path):
+    keys_to_keep = [
+        "controlnet_",
+        "single_transformer_blocks",
+        "transformer_blocks"
+    ]
+    preserved_keys = {k: v.cpu() for k, v in sd.items() if any(k.startswith(key) for key in keys_to_keep)}
+
+    new_sd = model_detection.convert_diffusers_mmdit(sd, "")
+
+    keys_to_discard = [
+        "double_blocks",
+        "single_blocks"
+    ]
+    new_sd = {k: v for k, v in new_sd.items() if not any(k.startswith(discard_key) for discard_key in keys_to_discard)}
+    new_sd.update(preserved_keys)
+
+    config = {
+        "image_model": "flux",
+        "axes_dim": [16, 56, 56],
+        "in_channels": 16,
+        "depth": 5,
+        "depth_single_blocks": 10,
+        "context_in_dim": 4096,
+        "num_heads": 24,
+        "guidance_embed": True,
+        "hidden_size": 3072,
+        "mlp_ratio": 4.0,
+        "theta": 10000,
+        "qkv_bias": True,
+        "vec_in_dim": 768
+    }
+
+    device = model_management.get_torch_device()
+
+    if weight_dtype == "fp8_e4m3fn":
+        dtype = torch.float8_e4m3fn
+        operations = ops.manual_cast
+    elif weight_dtype == "fp8_e5m2":
+        dtype = torch.float8_e5m2
+        operations = ops.manual_cast
+    else:
+        dtype = torch.bfloat16
+        operations = ops.disable_weight_init
+
+    control_model = controlnet_class(operations=operations, device=device, dtype=dtype, **config)
+    control_model = controlnet_load_state_dict(control_model, new_sd)
+    extra_conds = ['y', 'guidance', 'control_type']
+    latent_format = latent_formats.SD3()
+    # TODO check manual cast dtype
+    control = ControlNet(control_model, compression_ratio=1, load_device=device, manual_cast_dtype=torch.bfloat16,
+                         extra_conds=extra_conds, latent_format=latent_format, ckpt_name=full_path)
+    return control
+
+
 def load_controlnet_flux_xlabs(sd):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
     control_model = flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
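The dtype selection above, isolated as a sketch for reference (requires a torch build with float8 support): the two fp8 options store the weights in float8 and route through manual-cast ops, while anything else falls back to bfloat16 with weight init disabled.

    import torch

    def pick_flux_controlnet_dtype(weight_dtype: str) -> torch.dtype:
        return {
            "fp8_e4m3fn": torch.float8_e4m3fn,
            "fp8_e5m2": torch.float8_e5m2,
        }.get(weight_dtype, torch.bfloat16)

    assert pick_flux_controlnet_dtype("default") is torch.bfloat16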
@@ -440,9 +504,13 @@ def load_controlnet_flux_xlabs(sd):
     return control


-def load_controlnet(ckpt_path, model=None):
+def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
     controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
-    if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
+    if "controlnet_mode_embedder.weight" in controlnet_data:
+        return load_controlnet_flux_instantx(controlnet_data, InstantXControlNetFluxFormat2, weight_dtype, ckpt_path)
+    if "controlnet_mode_embedder.fc.weight" in controlnet_data:
+        return load_controlnet_flux_instantx(controlnet_data, InstantXControlNetFlux, weight_dtype, ckpt_path)
+    if 'after_proj_list.18.bias' in controlnet_data.keys():  # Hunyuan DiT
         return load_controlnet_hunyuandit(controlnet_data)
     if "lora_controlnet" in controlnet_data:
         return ControlLora(controlnet_data)
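The two InstantX checkpoint formats are told apart purely by state-dict keys: format 2 stores its mode embedder as a bare nn.Embedding ("controlnet_mode_embedder.weight"), while format 1 uses the embedding+norm+fc module and therefore exposes "controlnet_mode_embedder.fc.weight". A sketch of the dispatch order, mirroring the checks above:

    def detect_flux_controlnet_format(sd: dict) -> str:
        if "controlnet_mode_embedder.weight" in sd:     # bare embedding -> format 2
            return "instantx_format2"
        if "controlnet_mode_embedder.fc.weight" in sd:  # embedder with fc head -> format 1
            return "instantx"
        if "after_proj_list.18.bias" in sd:             # Hunyuan DiT
            return "hunyuandit"
        return "other"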
@@ -450,7 +518,7 @@ def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
     controlnet_config = None
     supported_inference_dtypes = None

-    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
+    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data:  # diffusers format
         controlnet_config = model_detection.unet_config_from_diffusers_unet(controlnet_data)
         diffusers_keys = utils.unet_to_diffusers(controlnet_config)
         diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
@@ -490,7 +558,7 @@ def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
             if k in controlnet_data:
                 new_sd[diffusers_keys[k]] = controlnet_data.pop(k)

-        if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
+        if "control_add_embedding.linear_1.bias" in controlnet_data:  # Union Controlnet
             controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0]
             for k in list(controlnet_data.keys()):
                 new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
@@ -500,7 +568,7 @@ def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
         if len(leftover_keys) > 0:
             logging.warning("leftover keys: {}".format(leftover_keys))
         controlnet_data = new_sd
-    elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
+    elif "controlnet_blocks.0.weight" in controlnet_data:  # SD3 diffusers format
         if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
             return load_controlnet_flux_xlabs(controlnet_data)
         else:
@@ -558,6 +626,7 @@ def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):

     class WeightsLoader(torch.nn.Module):
         pass

+
     w = WeightsLoader()
     w.control_model = control_model
     missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
@@ -572,12 +641,13 @@ def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):

     global_average_pooling = False
     filename = os.path.splitext(ckpt_path)[0]
-    if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
+    if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"):  # TODO: smarter way of enabling global_average_pooling
         global_average_pooling = True

     control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype, ckpt_name=filename)
     return control

+
 class T2IAdapter(ControlBase):
     def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
         super().__init__(device)
@@ -633,13 +703,14 @@ class T2IAdapter(ControlBase):
         self.copy_to(c)
         return c

+
 def load_t2i_adapter(t2i_data):
     compression_ratio = 8
     upscale_algorithm = 'nearest-exact'

     if 'adapter' in t2i_data:
         t2i_data = t2i_data['adapter']
-    if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
+    if 'adapter.body.0.resnets.0.block1.weight' in t2i_data:  # diffusers format
         prefix_replace = {}
         for i in range(4):
             for j in range(2):
@@ -663,7 +734,7 @@ def load_t2i_adapter(t2i_data):
         xl = False
         if cin == 256 or cin == 768:
             xl = True
-        model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
+        model_ad = adapter.Adapter(cin=cin, channels=[channel, channel * 2, channel * 4, channel * 4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
     elif "backbone.0.0.weight" in keys:
         model_ad = cascade_controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
         compression_ratio = 32
comfy/ldm/flux/controlnet_instantx.py (new file, 309 lines)

@@ -0,0 +1,309 @@
# Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py

import numbers

import numpy as np
import torch
import torch.nn.functional as F
from diffusers.models.normalization import AdaLayerNormContinuous
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from diffusers.utils.import_utils import is_torch_version
from einops import rearrange, repeat
from torch import Tensor, nn

import comfy.ldm.common_dit
from comfy.ldm.flux.layers import (timestep_embedding)
from comfy.ldm.flux.model import Flux

if is_torch_version(">=", "2.1.0"):
    LayerNorm = nn.LayerNorm
else:
    # Has optional bias parameter compared to torch layer norm
    # TODO: replace with torch layernorm once min required torch version >= 2.1
    class LayerNorm(nn.Module):
        def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
            super().__init__()

            self.eps = eps

            if isinstance(dim, numbers.Integral):
                dim = (dim,)

            self.dim = torch.Size(dim)

            if elementwise_affine:
                self.weight = nn.Parameter(torch.ones(dim))
                self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
            else:
                self.weight = None
                self.bias = None

        def forward(self, input):
            return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)


class FluxUnionControlNetModeEmbedder(nn.Module):
    def __init__(self, num_mode, out_channels):
        super().__init__()
        self.mode_embber = nn.Embedding(num_mode, out_channels)
        self.norm = nn.LayerNorm(out_channels, eps=1e-6)
        self.fc = nn.Linear(out_channels, out_channels)

    def forward(self, x):
        x_emb = self.mode_embber(x)
        x_emb = self.norm(x_emb)
        x_emb = self.fc(x_emb)
        x_emb = x_emb[:, 0]
        return x_emb


def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


# YiYi to-do: refactor rope related functions/classes
def apply_rope(xq, xk, freqs_cis):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


class FluxUnionControlNetInputEmbedder(nn.Module):
    def __init__(self, in_channels, out_channels, num_attention_heads=24, mlp_ratio=4.0, attention_head_dim=128, dtype=None, device=None, operations=None, depth=2):
        super().__init__()
        self.x_embedder = nn.Sequential(nn.LayerNorm(in_channels), nn.Linear(in_channels, out_channels))
        self.norm = AdaLayerNormContinuous(out_channels, out_channels, elementwise_affine=False, eps=1e-6)
        self.fc = nn.Linear(out_channels, out_channels)
        self.emb_embedder = nn.Sequential(nn.LayerNorm(out_channels), nn.Linear(out_channels, out_channels))

        """ self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    out_channels, num_attention_heads, dtype=dtype, device=device, operations=operations
                )
                for i in range(2)
            ]
        ) """
        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=out_channels,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for i in range(depth)
            ]
        )

        self.out = zero_module(nn.Linear(out_channels, out_channels))

    def forward(self, x, mode_emb):
        mode_token = self.emb_embedder(mode_emb)[:, None]
        x_emb = self.fc(self.norm(self.x_embedder(x), mode_emb))
        hidden_states = torch.cat([mode_token, x_emb], dim=1)
        for index_block, block in enumerate(self.single_transformer_blocks):
            hidden_states = block(
                hidden_states=hidden_states,
                temb=mode_emb,
            )
        hidden_states = self.out(hidden_states)
        res = hidden_states[:, 1:]
        return res


class InstantXControlNetFlux(Flux):
    def __init__(self, image_model=None, dtype=None, device=None, operations=None, joint_attention_dim=4096, **kwargs):
        kwargs["depth"] = 0
        kwargs["depth_single_blocks"] = 0
        depth_single_blocks_controlnet = kwargs.pop("depth_single_blocks_controlnet", 2)
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.hidden_size,
                    num_attention_heads=24,
                    attention_head_dim=128,
                ).to(dtype=dtype)
                for i in range(5)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.hidden_size,
                    num_attention_heads=24,
                    attention_head_dim=128,
                ).to(dtype=dtype)
                for i in range(10)
            ]
        )

        self.require_vae = True
        # add ControlNet blocks
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)

        self.controlnet_single_blocks = nn.ModuleList([])
        for _ in range(len(self.single_transformer_blocks)):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_single_blocks.append(controlnet_block)

        # TODO support both union and unimodal
        self.union = True  # num_mode is not None
        num_mode = 10
        if self.union:
            self.controlnet_mode_embedder = zero_module(FluxUnionControlNetModeEmbedder(num_mode, self.hidden_size)).to(device=device, dtype=dtype)
            self.controlnet_x_embedder = FluxUnionControlNetInputEmbedder(self.in_channels, self.hidden_size, operations=operations, depth=depth_single_blocks_controlnet).to(device=device, dtype=dtype)
            self.controlnet_mode_token_embedder = nn.Sequential(nn.LayerNorm(self.hidden_size, eps=1e-6), nn.Linear(self.hidden_size, self.hidden_size)).to(device=device, dtype=dtype)
        else:
            self.controlnet_x_embedder = zero_module(torch.nn.Linear(self.in_channels, self.hidden_size)).to(device=device, dtype=dtype)
        self.gradient_checkpointing = False

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

    def set_hint_latents(self, hint_latents):
        vae_shift_factor = 0.1159
        vae_scaling_factor = 0.3611
        num_channels_latents = self.in_channels // 4
        hint_latents = (hint_latents - vae_shift_factor) * vae_scaling_factor

        height, width = hint_latents.shape[2:]
        hint_latents = self._pack_latents(
            hint_latents,
            hint_latents.shape[0],
            num_channels_latents,
            height,
            width,
        )
        self.hint_latents = hint_latents.to(device=self.device, dtype=self.dtype)

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        controlnet_cond: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        controlnet_mode=None
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        batch_size = img.shape[0]

        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256).to(self.dtype))
        if self.params.guidance_embed:
            vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(self.dtype)))
        vec.add_(self.vector_in(y))

        if self.union:
            if controlnet_mode is None:
                raise ValueError('using union-controlnet, but controlnet_mode is not a list or is empty')
            controlnet_mode = torch.tensor([[controlnet_mode]], device=self.device)
            emb_controlnet_mode = self.controlnet_mode_embedder(controlnet_mode).to(self.dtype)
            vec = vec + emb_controlnet_mode
            img = img + self.controlnet_x_embedder(controlnet_cond, emb_controlnet_mode)
        else:
            img = img + self.controlnet_x_embedder(controlnet_cond)

        txt = self.txt_in(txt)

        if self.union:
            token_controlnet_mode = self.controlnet_mode_token_embedder(emb_controlnet_mode)[:, None]
            token_controlnet_mode = token_controlnet_mode.expand(txt.size(0), -1, -1)
            txt = torch.cat([token_controlnet_mode, txt], dim=1)
            txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids).to(dtype=self.dtype, device=self.device)

        block_res_samples = ()
        for block in self.transformer_blocks:
            txt, img = block(hidden_states=img, encoder_hidden_states=txt, temb=vec, image_rotary_emb=pe)
            block_res_samples = block_res_samples + (img,)

        img = torch.cat([txt, img], dim=1)

        single_block_res_samples = ()
        for block in self.single_transformer_blocks:
            img = block(hidden_states=img, temb=vec, image_rotary_emb=pe)
            single_block_res_samples = single_block_res_samples + (img[:, txt.shape[1]:],)

        controlnet_block_res_samples = ()
        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
            block_res_sample = controlnet_block(block_res_sample)
            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)

        controlnet_single_block_res_samples = ()
        for single_block_res_sample, single_controlnet_block in zip(single_block_res_samples, self.controlnet_single_blocks):
            single_block_res_sample = single_controlnet_block(single_block_res_sample)
            controlnet_single_block_res_samples = controlnet_single_block_res_samples + (single_block_res_sample,)

        n_single_blocks = 38
        n_double_blocks = 19

        # Expand controlnet_block_res_samples to match n_double_blocks
        expanded_controlnet_block_res_samples = []
        interval_control_double = int(np.ceil(n_double_blocks / len(controlnet_block_res_samples)))
        for i in range(n_double_blocks):
            index = i // interval_control_double
            expanded_controlnet_block_res_samples.append(controlnet_block_res_samples[index])

        # Expand controlnet_single_block_res_samples to match n_single_blocks
        expanded_controlnet_single_block_res_samples = []
        interval_control_single = int(np.ceil(n_single_blocks / len(controlnet_single_block_res_samples)))
        for i in range(n_single_blocks):
            index = i // interval_control_single
            expanded_controlnet_single_block_res_samples.append(controlnet_single_block_res_samples[index])

        return {
            "input": expanded_controlnet_block_res_samples,
            "output": expanded_controlnet_single_block_res_samples
        }

    def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs):
        bs, c, h, w = x.shape
        patch_size = 2
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

        height_control_image, width_control_image = hint.shape[2:]
        num_channels_latents = self.in_channels // 4
        hint = self._pack_latents(
            hint,
            hint.shape[0],
            num_channels_latents,
            height_control_image,
            width_control_image,
        )
        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type)
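A standalone shape check for _pack_latents (not part of the commit): a (B, C, H, W) latent becomes a sequence of (H/2)*(W/2) tokens of width 4*C by folding each 2x2 spatial patch into the channel dimension.

    import torch

    B, C, H, W = 1, 16, 64, 64
    latents = torch.randn(B, C, H, W)
    x = latents.view(B, C, H // 2, 2, W // 2, 2)
    x = x.permute(0, 2, 4, 1, 3, 5)
    x = x.reshape(B, (H // 2) * (W // 2), C * 4)
    assert x.shape == (1, 1024, 64)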
comfy/ldm/flux/controlnet_instantx_format2.py (new file, 227 lines)

@@ -0,0 +1,227 @@
# Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py

import numbers

import numpy as np
import torch
import torch.nn.functional as F
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from diffusers.utils.import_utils import is_torch_version
from einops import rearrange, repeat
from torch import Tensor, nn

from .layers import timestep_embedding
from .model import Flux
from ..common_dit import pad_to_patch_size

if is_torch_version(">=", "2.1.0"):
    LayerNorm = nn.LayerNorm
else:
    # Has optional bias parameter compared to torch layer norm
    # TODO: replace with torch layernorm once min required torch version >= 2.1
    class LayerNorm(nn.Module):
        def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
            super().__init__()

            self.eps = eps

            if isinstance(dim, numbers.Integral):
                dim = (dim,)

            self.dim = torch.Size(dim)

            if elementwise_affine:
                self.weight = nn.Parameter(torch.ones(dim))
                self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
            else:
                self.weight = None
                self.bias = None

        def forward(self, input):
            return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)


def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


# YiYi to-do: refactor rope related functions/classes
def apply_rope(xq, xk, freqs_cis):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


class InstantXControlNetFluxFormat2(Flux):
    def __init__(self, image_model=None, dtype=None, device=None, operations=None, joint_attention_dim=4096, **kwargs):
        kwargs["depth"] = 0
        kwargs["depth_single_blocks"] = 0
        depth_single_blocks_controlnet = kwargs.pop("depth_single_blocks_controlnet", 2)
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.hidden_size,
                    num_attention_heads=24,
                    attention_head_dim=128,
                ).to(dtype=dtype)
                for i in range(5)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.hidden_size,
                    num_attention_heads=24,
                    attention_head_dim=128,
                ).to(dtype=dtype)
                for i in range(10)
            ]
        )

        self.require_vae = True
        # add ControlNet blocks
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)

        self.controlnet_single_blocks = nn.ModuleList([])
        for _ in range(len(self.single_transformer_blocks)):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_single_blocks.append(controlnet_block)

        # TODO support both union and unimodal
        self.union = True  # num_mode is not None
        num_mode = 10
        if self.union:
            self.controlnet_mode_embedder = nn.Embedding(num_mode, self.hidden_size)
        self.controlnet_x_embedder = zero_module(operations.Linear(self.in_channels, self.hidden_size).to(device=device, dtype=dtype))
        self.gradient_checkpointing = False

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        controlnet_cond: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        controlnet_mode=None
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        batch_size = img.shape[0]

        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256).to(self.dtype))
        if self.params.guidance_embed:
            vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(self.dtype)))
        vec.add_(self.vector_in(y))

        txt = self.txt_in(txt)

        if self.union:
            if controlnet_mode is None:
                raise ValueError('using union-controlnet, but controlnet_mode is not a list or is empty')
            controlnet_mode = torch.tensor(controlnet_mode).to(self.device, dtype=torch.long)
            controlnet_mode = controlnet_mode.reshape([-1, 1])
            emb_controlnet_mode = self.controlnet_mode_embedder(controlnet_mode).to(self.dtype)
            txt = torch.cat([emb_controlnet_mode, txt], dim=1)
            txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1)

        img = img + self.controlnet_x_embedder(controlnet_cond)

        txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        block_res_samples = ()
        for block in self.transformer_blocks:
            txt, img = block(hidden_states=img, encoder_hidden_states=txt, temb=vec, image_rotary_emb=pe)
            block_res_samples = block_res_samples + (img,)

        img = torch.cat([txt, img], dim=1)

        single_block_res_samples = ()
        for block in self.single_transformer_blocks:
            img = block(hidden_states=img, temb=vec, image_rotary_emb=pe)
            single_block_res_samples = single_block_res_samples + (img[:, txt.shape[1]:],)

        controlnet_block_res_samples = ()
        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
            block_res_sample = controlnet_block(block_res_sample)
            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)

        controlnet_single_block_res_samples = ()
        for single_block_res_sample, single_controlnet_block in zip(single_block_res_samples, self.controlnet_single_blocks):
            single_block_res_sample = single_controlnet_block(single_block_res_sample)
            controlnet_single_block_res_samples = controlnet_single_block_res_samples + (single_block_res_sample,)

        n_single_blocks = 38
        n_double_blocks = 19

        # Expand controlnet_block_res_samples to match n_double_blocks
        expanded_controlnet_block_res_samples = []
        interval_control_double = int(np.ceil(n_double_blocks / len(controlnet_block_res_samples)))
        for i in range(n_double_blocks):
            index = i // interval_control_double
            expanded_controlnet_block_res_samples.append(controlnet_block_res_samples[index])

        # Expand controlnet_single_block_res_samples to match n_single_blocks
        expanded_controlnet_single_block_res_samples = []
        interval_control_single = int(np.ceil(n_single_blocks / len(controlnet_single_block_res_samples)))
        for i in range(n_single_blocks):
            index = i // interval_control_single
            expanded_controlnet_single_block_res_samples.append(controlnet_single_block_res_samples[index])

        return {
            "input": expanded_controlnet_block_res_samples,
            "output": expanded_controlnet_single_block_res_samples
        }

    def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs):
        bs, c, h, w = x.shape
        patch_size = 2
        x = pad_to_patch_size(x, (patch_size, patch_size))

        height_control_image, width_control_image = hint.shape[2:]
        num_channels_latents = self.in_channels // 4
        hint = self._pack_latents(
            hint,
            hint.shape[0],
            num_channels_latents,
            height_control_image,
            width_control_image,
        )
        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type)
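Both files use the same interval trick to stretch the ControlNet's few residuals across all of Flux's blocks: the 5 double-block residuals cover 19 double blocks and the 10 single-block residuals cover 38 single blocks, each repeated ceil(target / available) times. A quick standalone check of the index pattern:

    import numpy as np

    interval = int(np.ceil(19 / 5))              # = 4
    indices = [i // interval for i in range(19)]
    print(indices)  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4]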
comfy/ldm/flux/model.py

@@ -40,6 +40,7 @@ class Flux(nn.Module):

     def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
         super().__init__()
+        self.device = device
         self.dtype = dtype
         params = FluxParams(**kwargs)
         self.params = params
comfy/ldm/flux/weight_dtypes.py (new file, 1 line)

@@ -0,0 +1 @@
+FLUX_WEIGHT_DTYPES = ["default", "fp8_e4m3fn", "fp8_e5m2"]
comfy/model_downloader.py

@@ -385,6 +385,8 @@ KNOWN_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
     HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0-promax.safetensors"),
     HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0.safetensors"),
     HuggingFile("InstantX/FLUX.1-dev-Controlnet-Canny", "diffusion_pytorch_model.safetensors", save_with_filename="instantx-flux.1-dev-controlnet-canny.safetensors"),
+    HuggingFile("InstantX/FLUX.1-dev-Controlnet-Union", "diffusion_pytorch_model.safetensors", save_with_filename="instantx-flux.1-dev-controlnet-union.safetensors"),
+    HuggingFile("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", save_with_filename="shakker-labs-flux.1-dev-controlnet-union-pro.safetensors"),
 ], folder_name="controlnet")

 KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
comfy/nodes/base_nodes.py

@@ -27,6 +27,7 @@ from ..cmd import folder_paths, latent_preview
 from ..component_model.tensor_types import RGBImage
 from ..execution_context import current_execution_context
 from ..images import open_image
+from ..ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
 from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS
 from ..nodes.common import MAX_RESOLUTION
 from .. import controlnet
@@ -756,6 +757,27 @@ class ControlNetLoader:
         controlnet_ = controlnet.load_controlnet(controlnet_path)
         return (controlnet_,)


+class ControlNetLoaderWeights:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "control_net_name": (get_filename_list_with_downloadable("controlnet", KNOWN_CONTROLNETS),),
+                "weight_dtype": (FLUX_WEIGHT_DTYPES,)
+            }
+        }
+
+    RETURN_TYPES = ("CONTROL_NET",)
+    FUNCTION = "load_controlnet"
+
+    CATEGORY = "loaders"
+
+    def load_controlnet(self, control_net_name, weight_dtype):
+        controlnet_path = get_or_download("controlnet", control_net_name, KNOWN_CONTROLNETS)
+        controlnet_ = controlnet.load_controlnet(controlnet_path, weight_dtype=weight_dtype)
+        return (controlnet_,)
+
+
 class DiffControlNetLoader:
     @classmethod
     def INPUT_TYPES(s):
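Outside the graph, the new node boils down to one extra keyword argument threaded into controlnet.load_controlnet. A hedged sketch of invoking it directly from this module (the checkpoint name is a placeholder; get_or_download would resolve or fetch the file):

    loader = ControlNetLoaderWeights()
    (control_net,) = loader.load_controlnet(
        "instantx-flux.1-dev-controlnet-union.safetensors",  # from KNOWN_CONTROLNETS
        "fp8_e4m3fn",                                        # one of FLUX_WEIGHT_DTYPES
    )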
@@ -854,7 +876,7 @@ class UNETLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "unet_name": (get_filename_list_with_downloadable("diffusion_models", KNOWN_UNET_MODELS),),
-                              "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
+                              "weight_dtype": (FLUX_WEIGHT_DTYPES,)
                              }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "load_unet"
@@ -1882,6 +1904,7 @@ NODE_CLASS_MAPPINGS = {
     "ControlNetApply": ControlNetApply,
     "ControlNetApplyAdvanced": ControlNetApplyAdvanced,
     "ControlNetLoader": ControlNetLoader,
+    "ControlNetLoaderWeights": ControlNetLoaderWeights,
     "DiffControlNetLoader": DiffControlNetLoader,
     "StyleModelLoader": StyleModelLoader,
     "CLIPVisionLoader": CLIPVisionLoader,
@@ -1914,6 +1937,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "LoraLoader": "Load LoRA",
     "CLIPLoader": "Load CLIP",
     "ControlNetLoader": "Load ControlNet Model",
+    "ControlNetLoaderWeights": "Load ControlNet Model (Weights)",
     "DiffControlNetLoader": "Load ControlNet Model (diff)",
     "StyleModelLoader": "Load Style Model",
     "CLIPVisionLoader": "Load CLIP Vision",
tests/inference/workflows/flux-controlnet-0.json (new file, 314 lines)

@@ -0,0 +1,314 @@
{
  "1": {
    "inputs": {
      "noise": ["2", 0],
      "guider": ["3", 0],
      "sampler": ["6", 0],
      "sigmas": ["7", 0],
      "latent_image": ["9", 0]
    },
    "class_type": "SamplerCustomAdvanced",
    "_meta": {"title": "SamplerCustomAdvanced"}
  },
  "2": {
    "inputs": {"noise_seed": 868192789249410},
    "class_type": "RandomNoise",
    "_meta": {"title": "RandomNoise"}
  },
  "3": {
    "inputs": {"model": ["12", 0], "conditioning": ["17", 0]},
    "class_type": "BasicGuider",
    "_meta": {"title": "BasicGuider"}
  },
  "4": {
    "inputs": {"guidance": 4, "conditioning": ["13", 0]},
    "class_type": "FluxGuidance",
    "_meta": {"title": "FluxGuidance"}
  },
  "6": {
    "inputs": {"sampler_name": "euler"},
    "class_type": "KSamplerSelect",
    "_meta": {"title": "KSamplerSelect"}
  },
  "7": {
    "inputs": {"scheduler": "ddim_uniform", "steps": 10, "denoise": 1, "model": ["12", 0]},
    "class_type": "BasicScheduler",
    "_meta": {"title": "BasicScheduler"}
  },
  "9": {
    "inputs": {"width": ["27", 4], "height": ["27", 5], "batch_size": 1},
    "class_type": "EmptySD3LatentImage",
    "_meta": {"title": "EmptySD3LatentImage"}
  },
  "10": {
    "inputs": {"samples": ["1", 0], "vae": ["11", 0]},
    "class_type": "VAEDecode",
    "_meta": {"title": "VAE Decode"}
  },
  "11": {
    "inputs": {"vae_name": "ae.safetensors"},
    "class_type": "VAELoader",
    "_meta": {"title": "Load VAE"}
  },
  "12": {
    "inputs": {"unet_name": "flux1-dev.safetensors", "weight_dtype": "fp8_e4m3fn"},
    "class_type": "UNETLoader",
    "_meta": {"title": "Load Diffusion Model"}
  },
  "13": {
    "inputs": {"text": "A photo of a girl.", "clip": ["15", 0]},
    "class_type": "CLIPTextEncode",
    "_meta": {"title": "CLIP Text Encode (Prompt)"}
  },
  "15": {
    "inputs": {"clip_name1": "clip_l.safetensors", "clip_name2": "t5xxl_fp16.safetensors", "type": "flux"},
    "class_type": "DualCLIPLoader",
    "_meta": {"title": "DualCLIPLoader"}
  },
  "16": {
    "inputs": {"filename_prefix": "ComfyUI", "images": ["10", 0]},
    "class_type": "SaveImage",
    "_meta": {"title": "Save Image"}
  },
  "17": {
    "inputs": {
      "strength": 0.6,
      "start_percent": 0,
      "end_percent": 1,
      "positive": ["4", 0],
      "negative": ["21", 0],
      "control_net": ["24", 0],
      "vae": ["11", 0],
      "image": ["25", 0]
    },
    "class_type": "ControlNetApplySD3",
    "_meta": {"title": "ControlNetApply SD3 and HunyuanDiT"}
  },
  "19": {
    "inputs": {
      "value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
      "name": "",
      "title": "",
      "description": "",
      "__required": true
    },
    "class_type": "ImageRequestParameter",
    "_meta": {"title": "ImageRequestParameter"}
  },
  "21": {
    "inputs": {"text": "", "clip": ["15", 0]},
    "class_type": "CLIPTextEncode",
    "_meta": {"title": "CLIP Text Encode (Prompt)"}
  },
  "24": {
    "inputs": {"type": "canny (InstantX)", "control_net": ["28", 0]},
    "class_type": "SetUnionControlNetType",
    "_meta": {"title": "SetUnionControlNetType"}
  },
  "25": {
    "inputs": {"low_threshold": 0.4, "high_threshold": 0.8, "image": ["31", 0]},
    "class_type": "Canny",
    "_meta": {"title": "Canny"}
  },
  "26": {
    "inputs": {"images": ["25", 0]},
    "class_type": "PreviewImage",
    "_meta": {"title": "Preview Image"}
  },
  "27": {
    "inputs": {"image": ["25", 0]},
    "class_type": "Image Size to Number",
    "_meta": {"title": "Image Size to Number"}
  },
  "28": {
    "inputs": {"control_net_name": "shakker-labs-flux.1-dev-controlnet-union-pro.safetensors", "weight_dtype": "default"},
    "class_type": "ControlNetLoaderWeights",
    "_meta": {"title": "Load ControlNet Model (Weights)"}
  },
  "31": {
    "inputs": {
      "image_gen_width": 1024,
      "image_gen_height": 1024,
      "resize_mode": "Crop and Resize",
      "hint_image": ["19", 0]
    },
    "class_type": "HintImageEnchance",
    "_meta": {"title": "Enchance And Resize Hint Images"}
  }
}
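One way to exercise this test workflow by hand is to POST it to a running ComfyUI server's /prompt endpoint (a sketch; assumes the default listen address and that the referenced model files are present):

    import json
    import urllib.request

    with open("tests/inference/workflows/flux-controlnet-0.json") as f:
        workflow = json.load(f)

    req = urllib.request.Request(
        "http://127.0.0.1:8188/prompt",
        data=json.dumps({"prompt": workflow}).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    print(urllib.request.urlopen(req).read().decode())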