mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-23 12:50:18 +08:00

Upstream InstantX Union ControlNet support for Flux

This commit is contained in:
parent 48ca1a4910
commit f49bcd4f3c
@@ -7,4 +7,11 @@ UNION_CONTROLNET_TYPES = {
     "segment": 5,
     "tile": 6,
     "repaint": 7,
+    "canny (InstantX)": 0,
+    "tile (InstantX)": 1,
+    "depth (InstantX)": 2,
+    "blur (InstantX)": 3,
+    "pose (InstantX)": 4,
+    "gray (InstantX)": 5,
+    "lq (InstantX)": 6
 }
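The InstantX entries carry their own mode indices because the InstantX/Shakker-Labs union models number their control modes independently of the existing union numbering. A minimal sketch (assuming the dictionary above, shown here only in part) of how a human-readable type name resolves to the integer that ends up in the 'control_type' extra cond consumed by the new InstantX classes:

    # Minimal sketch, assuming the UNION_CONTROLNET_TYPES dict patched above.
    UNION_CONTROLNET_TYPES = {
        "segment": 5, "tile": 6, "repaint": 7,
        "canny (InstantX)": 0, "tile (InstantX)": 1, "depth (InstantX)": 2,
        "blur (InstantX)": 3, "pose (InstantX)": 4, "gray (InstantX)": 5,
        "lq (InstantX)": 6,
    }

    def union_type_to_mode(name: str) -> int:
        # The integer becomes the controlnet_mode fed to forward_orig below.
        return UNION_CONTROLNET_TYPES[name]

    assert union_type_to_mode("canny (InstantX)") == 0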
@@ -16,29 +16,31 @@
 along with this program. If not, see <https://www.gnu.org/licenses/>.
 """
 
-import torch
-from enum import Enum
+import logging
 import math
 import os
-import logging
+from enum import Enum
 
-from . import utils
-from . import model_management
+import torch
+
+from . import latent_formats
 from . import model_detection
+from . import model_management
 from . import model_patcher
 from . import ops
-from . import latent_formats
+from . import utils
 
 from .cldm import cldm, mmdit
-from .t2i_adapter import adapter
 from .ldm import hydit, flux
 from .ldm.cascade import controlnet as cascade_controlnet
+from .ldm.flux.controlnet_instantx import InstantXControlNetFlux
+from .ldm.flux.controlnet_instantx_format2 import InstantXControlNetFluxFormat2
+from .ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
+from .t2i_adapter import adapter
 
 
 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
-    #print(current_batch_size, target_batch_size)
+    # print(current_batch_size, target_batch_size)
     if current_batch_size == 1:
         return tensor
 
@@ -54,10 +56,12 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
     else:
         return torch.cat([tensor] * batched_number, dim=0)
 
+
 class StrengthType(Enum):
     CONSTANT = 1
     LINEAR_UP = 2
 
+
 class ControlBase:
     def __init__(self, device=None):
         self.cond_hint_original = None
@@ -129,7 +133,7 @@ class ControlBase:
         return 0
 
     def control_merge(self, control, control_prev, output_dtype):
-        out = {'input':[], 'middle':[], 'output': []}
+        out = {'input': [], 'middle': [], 'output': []}
 
         for key in control:
             control_output = control[key]
@@ -140,7 +144,7 @@ class ControlBase:
                 if self.global_average_pooling:
                     x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
 
-                if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
+                if x not in applied_to:  # memory saving strategy, allow shared tensors and only apply strength to shared tensors once
                     applied_to.add(x)
                     if self.strength_type == StrengthType.CONSTANT:
                         x *= self.strength
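For context, control_merge keeps an applied_to set so a residual tensor shared between several outputs is only multiplied by the strength once. A toy sketch of that dedup idea on plain tensors (illustrative, not the upstream code; set membership works here because identical tensor objects short-circuit on identity):

    import torch

    def scale_once(outputs, strength):
        # Mirrors the memory-saving strategy above: shared tensors scaled once.
        applied_to = set()
        for x in outputs:
            if x not in applied_to:
                applied_to.add(x)
                x *= strength
        return outputs

    a = torch.ones(2, 2)
    outs = scale_once([a, a], 0.5)  # 'a' appears twice but is halved only once
    assert torch.allclose(outs[0], torch.full((2, 2), 0.5))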
@@ -166,7 +170,7 @@ class ControlBase:
                     if o[i].shape[0] < prev_val.shape[0]:
                         o[i] = prev_val + o[i]
                     else:
-                        o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
+                        o[i] = prev_val + o[i]  # TODO: change back to inplace add if shared tensors stop being an issue
         return out
 
     def set_extra_arg(self, argument, value=None):
@@ -258,10 +262,11 @@ class ControlNet(ControlBase):
         self.model_sampling_current = None
         super().cleanup()
 
+
 class ControlLoraOps:
     class Linear(torch.nn.Module, ops.CastWeightBiasOp):
         def __init__(self, in_features: int, out_features: int, bias: bool = True,
                      device=None, dtype=None) -> None:
             factory_kwargs = {'device': device, 'dtype': dtype}
             super().__init__()
             self.in_features = in_features
@@ -280,18 +285,18 @@ class ControlLoraOps:
 
     class Conv2d(torch.nn.Module, ops.CastWeightBiasOp):
         def __init__(
             self,
             in_channels,
             out_channels,
             kernel_size,
             stride=1,
             padding=0,
             dilation=1,
             groups=1,
             bias=True,
             padding_mode='zeros',
             device=None,
             dtype=None
         ):
             super().__init__()
             self.in_channels = in_channels
@@ -310,7 +315,6 @@ class ControlLoraOps:
             self.up = None
             self.down = None
 
-
         def forward(self, input):
             weight, bias = ops.cast_bias_weight(self, input)
             if self.up is not None:
@@ -339,6 +343,7 @@ class ControlLora(ControlNet):
         else:
             class control_lora_ops(ControlLoraOps, ops.manual_cast):
                 pass
 
             dtype = self.manual_cast_dtype
+
         controlnet_config["operations"] = control_lora_ops
@@ -377,6 +382,7 @@ class ControlLora(ControlNet):
     def inference_memory_requirements(self, dtype):
         return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
 
+
 def controlnet_config(sd):
     model_config = model_detection.model_config_from_unet(sd, "", True)
 
@@ -394,6 +400,7 @@ def controlnet_config(sd):
     offload_device = model_management.unet_offload_device()
     return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
 
+
 def controlnet_load_state_dict(control_model, sd):
     missing, unexpected = control_model.load_state_dict(sd, strict=False)
 
@@ -404,6 +411,7 @@ def controlnet_load_state_dict(control_model, sd):
     logging.debug("unexpected controlnet keys: {}".format(unexpected))
     return control_model
 
+
 def load_controlnet_mmdit(sd):
     new_sd = model_detection.convert_diffusers_mmdit(sd, "")
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
@@ -415,7 +423,7 @@ def load_controlnet_mmdit(sd):
     control_model = controlnet_load_state_dict(control_model, new_sd)
 
     latent_format = latent_formats.SD3()
-    latent_format.shift_factor = 0 #SD3 controlnet weirdness
+    latent_format.shift_factor = 0  # SD3 controlnet weirdness
     control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
     return control
 
@@ -431,6 +439,62 @@ def load_controlnet_hunyuandit(controlnet_data):
     control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
     return control
 
+
+def load_controlnet_flux_instantx(sd, controlnet_class, weight_dtype, full_path):
+    keys_to_keep = [
+        "controlnet_",
+        "single_transformer_blocks",
+        "transformer_blocks"
+    ]
+    preserved_keys = {k: v.cpu() for k, v in sd.items() if any(k.startswith(key) for key in keys_to_keep)}
+
+    new_sd = model_detection.convert_diffusers_mmdit(sd, "")
+
+    keys_to_discard = [
+        "double_blocks",
+        "single_blocks"
+    ]
+    new_sd = {k: v for k, v in new_sd.items() if not any(k.startswith(discard_key) for discard_key in keys_to_discard)}
+    new_sd.update(preserved_keys)
+
+    config = {
+        "image_model": "flux",
+        "axes_dim": [16, 56, 56],
+        "in_channels": 16,
+        "depth": 5,
+        "depth_single_blocks": 10,
+        "context_in_dim": 4096,
+        "num_heads": 24,
+        "guidance_embed": True,
+        "hidden_size": 3072,
+        "mlp_ratio": 4.0,
+        "theta": 10000,
+        "qkv_bias": True,
+        "vec_in_dim": 768
+    }
+
+    device = model_management.get_torch_device()
+
+    if weight_dtype == "fp8_e4m3fn":
+        dtype = torch.float8_e4m3fn
+        operations = ops.manual_cast
+    elif weight_dtype == "fp8_e5m2":
+        dtype = torch.float8_e5m2
+        operations = ops.manual_cast
+    else:
+        dtype = torch.bfloat16
+        operations = ops.disable_weight_init
+
+    control_model = controlnet_class(operations=operations, device=device, dtype=dtype, **config)
+    control_model = controlnet_load_state_dict(control_model, new_sd)
+    extra_conds = ['y', 'guidance', 'control_type']
+    latent_format = latent_formats.SD3()
+    # TODO check manual cast dtype
+    control = ControlNet(control_model, compression_ratio=1, load_device=device, manual_cast_dtype=torch.bfloat16,
+                         extra_conds=extra_conds, latent_format=latent_format, ckpt_name=full_path)
+    return control
+
+
 def load_controlnet_flux_xlabs(sd):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
     control_model = flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
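load_controlnet_flux_instantx picks the storage dtype and the ops implementation from the weight_dtype string: the fp8 variants keep weights in float8 and rely on ops.manual_cast to upcast at compute time, while "default" keeps bf16 weights. A condensed, standalone sketch of that mapping (only torch assumed):

    import torch

    def resolve_flux_weight_dtype(weight_dtype: str):
        # Same branch structure as load_controlnet_flux_instantx above.
        if weight_dtype == "fp8_e4m3fn":
            return torch.float8_e4m3fn, "manual_cast"
        if weight_dtype == "fp8_e5m2":
            return torch.float8_e5m2, "manual_cast"
        return torch.bfloat16, "disable_weight_init"

    assert resolve_flux_weight_dtype("default") == (torch.bfloat16, "disable_weight_init")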
@@ -440,9 +504,13 @@ def load_controlnet_flux_xlabs(sd):
     return control
 
 
-def load_controlnet(ckpt_path, model=None):
+def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
     controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
-    if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
+    if "controlnet_mode_embedder.weight" in controlnet_data:
+        return load_controlnet_flux_instantx(controlnet_data, InstantXControlNetFluxFormat2, weight_dtype, ckpt_path)
+    if "controlnet_mode_embedder.fc.weight" in controlnet_data:
+        return load_controlnet_flux_instantx(controlnet_data, InstantXControlNetFlux, weight_dtype, ckpt_path)
+    if 'after_proj_list.18.bias' in controlnet_data.keys():  # Hunyuan DiT
        return load_controlnet_hunyuandit(controlnet_data)
     if "lora_controlnet" in controlnet_data:
         return ControlLora(controlnet_data)
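Which InstantX class gets instantiated is sniffed from state-dict keys: checkpoints whose mode embedder is a bare nn.Embedding expose controlnet_mode_embedder.weight and go to InstantXControlNetFluxFormat2, while the original union format wraps the embedder with a norm and fc layer, hence controlnet_mode_embedder.fc.weight. A hedged usage sketch (file paths are placeholders; any other checkpoint falls through the remaining branches unchanged):

    from comfy.controlnet import load_controlnet

    union_pro = load_controlnet(
        "controlnet/shakker-labs-flux.1-dev-controlnet-union-pro.safetensors",
        weight_dtype="default",
    )
    union_fp8 = load_controlnet(
        "controlnet/instantx-flux.1-dev-controlnet-union.safetensors",
        weight_dtype="fp8_e4m3fn",
    )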
@@ -450,7 +518,7 @@ def load_controlnet(ckpt_path, model=None):
     controlnet_config = None
     supported_inference_dtypes = None
 
-    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
+    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data:  # diffusers format
         controlnet_config = model_detection.unet_config_from_diffusers_unet(controlnet_data)
         diffusers_keys = utils.unet_to_diffusers(controlnet_config)
         diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
@@ -490,7 +558,7 @@ def load_controlnet(ckpt_path, model=None):
             if k in controlnet_data:
                 new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
 
-        if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
+        if "control_add_embedding.linear_1.bias" in controlnet_data:  # Union Controlnet
             controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0]
             for k in list(controlnet_data.keys()):
                 new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
@@ -500,7 +568,7 @@ def load_controlnet(ckpt_path, model=None):
         if len(leftover_keys) > 0:
             logging.warning("leftover keys: {}".format(leftover_keys))
         controlnet_data = new_sd
-    elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
+    elif "controlnet_blocks.0.weight" in controlnet_data:  # SD3 diffusers format
         if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
             return load_controlnet_flux_xlabs(controlnet_data)
         else:
@@ -558,6 +626,7 @@ def load_controlnet(ckpt_path, model=None):
 
         class WeightsLoader(torch.nn.Module):
             pass
+
         w = WeightsLoader()
         w.control_model = control_model
         missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
@@ -572,12 +641,13 @@ def load_controlnet(ckpt_path, model=None):
 
     global_average_pooling = False
     filename = os.path.splitext(ckpt_path)[0]
-    if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
+    if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"):  # TODO: smarter way of enabling global_average_pooling
         global_average_pooling = True
 
     control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype, ckpt_name=filename)
     return control
 
+
 class T2IAdapter(ControlBase):
     def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
         super().__init__(device)
@@ -633,13 +703,14 @@ class T2IAdapter(ControlBase):
         self.copy_to(c)
         return c
 
+
 def load_t2i_adapter(t2i_data):
     compression_ratio = 8
     upscale_algorithm = 'nearest-exact'
 
     if 'adapter' in t2i_data:
         t2i_data = t2i_data['adapter']
-    if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
+    if 'adapter.body.0.resnets.0.block1.weight' in t2i_data:  # diffusers format
         prefix_replace = {}
         for i in range(4):
             for j in range(2):
@@ -663,7 +734,7 @@ def load_t2i_adapter(t2i_data):
         xl = False
         if cin == 256 or cin == 768:
             xl = True
-        model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
+        model_ad = adapter.Adapter(cin=cin, channels=[channel, channel * 2, channel * 4, channel * 4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
     elif "backbone.0.0.weight" in keys:
         model_ad = cascade_controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
         compression_ratio = 32
comfy/ldm/flux/controlnet_instantx.py (new file, 309 lines):

# Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py

import numbers

import numpy as np
import torch
import torch.nn.functional as F
from diffusers.models.normalization import AdaLayerNormContinuous
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from diffusers.utils.import_utils import is_torch_version
from einops import rearrange, repeat
from torch import Tensor, nn

import comfy.ldm.common_dit
from comfy.ldm.flux.layers import (timestep_embedding)
from comfy.ldm.flux.model import Flux

if is_torch_version(">=", "2.1.0"):
    LayerNorm = nn.LayerNorm
else:
    # Has optional bias parameter compared to torch layer norm
    # TODO: replace with torch layernorm once min required torch version >= 2.1
    class LayerNorm(nn.Module):
        def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
            super().__init__()

            self.eps = eps

            if isinstance(dim, numbers.Integral):
                dim = (dim,)

            self.dim = torch.Size(dim)

            if elementwise_affine:
                self.weight = nn.Parameter(torch.ones(dim))
                self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
            else:
                self.weight = None
                self.bias = None

        def forward(self, input):
            return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)


class FluxUnionControlNetModeEmbedder(nn.Module):
    def __init__(self, num_mode, out_channels):
        super().__init__()
        self.mode_embber = nn.Embedding(num_mode, out_channels)
        self.norm = nn.LayerNorm(out_channels, eps=1e-6)
        self.fc = nn.Linear(out_channels, out_channels)

    def forward(self, x):
        x_emb = self.mode_embber(x)
        x_emb = self.norm(x_emb)
        x_emb = self.fc(x_emb)
        x_emb = x_emb[:, 0]
        return x_emb


def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


# YiYi to-do: refactor rope related functions/classes
def apply_rope(xq, xk, freqs_cis):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


class FluxUnionControlNetInputEmbedder(nn.Module):
    def __init__(self, in_channels, out_channels, num_attention_heads=24, mlp_ratio=4.0, attention_head_dim=128, dtype=None, device=None, operations=None, depth=2):
        super().__init__()
        self.x_embedder = nn.Sequential(nn.LayerNorm(in_channels), nn.Linear(in_channels, out_channels))
        self.norm = AdaLayerNormContinuous(out_channels, out_channels, elementwise_affine=False, eps=1e-6)
        self.fc = nn.Linear(out_channels, out_channels)
        self.emb_embedder = nn.Sequential(nn.LayerNorm(out_channels), nn.Linear(out_channels, out_channels))

        """ self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    out_channels, num_attention_heads, dtype=dtype, device=device, operations=operations
                )
                for i in range(2)
            ]
        ) """
        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=out_channels,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for i in range(depth)
            ]
        )

        self.out = zero_module(nn.Linear(out_channels, out_channels))

    def forward(self, x, mode_emb):
        mode_token = self.emb_embedder(mode_emb)[:, None]
        x_emb = self.fc(self.norm(self.x_embedder(x), mode_emb))
        hidden_states = torch.cat([mode_token, x_emb], dim=1)
        for index_block, block in enumerate(self.single_transformer_blocks):
            hidden_states = block(
                hidden_states=hidden_states,
                temb=mode_emb,
            )
        hidden_states = self.out(hidden_states)
        res = hidden_states[:, 1:]
        return res


class InstantXControlNetFlux(Flux):
    def __init__(self, image_model=None, dtype=None, device=None, operations=None, joint_attention_dim=4096, **kwargs):
        kwargs["depth"] = 0
        kwargs["depth_single_blocks"] = 0
        depth_single_blocks_controlnet = kwargs.pop("depth_single_blocks_controlnet", 2)
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.hidden_size,
                    num_attention_heads=24,
                    attention_head_dim=128,
                ).to(dtype=dtype)
                for i in range(5)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.hidden_size,
                    num_attention_heads=24,
                    attention_head_dim=128,
                ).to(dtype=dtype)
                for i in range(10)
            ]
        )

        self.require_vae = True
        # add ControlNet blocks
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)

        self.controlnet_single_blocks = nn.ModuleList([])
        for _ in range(len(self.single_transformer_blocks)):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_single_blocks.append(controlnet_block)

        # TODO support both union and unimodal
        self.union = True  # num_mode is not None
        num_mode = 10
        if self.union:
            self.controlnet_mode_embedder = zero_module(FluxUnionControlNetModeEmbedder(num_mode, self.hidden_size)).to(device=device, dtype=dtype)
            self.controlnet_x_embedder = FluxUnionControlNetInputEmbedder(self.in_channels, self.hidden_size, operations=operations, depth=depth_single_blocks_controlnet).to(device=device, dtype=dtype)
            self.controlnet_mode_token_embedder = nn.Sequential(nn.LayerNorm(self.hidden_size, eps=1e-6), nn.Linear(self.hidden_size, self.hidden_size)).to(device=device, dtype=dtype)
        else:
            self.controlnet_x_embedder = zero_module(torch.nn.Linear(self.in_channels, self.hidden_size)).to(device=device, dtype=dtype)
        self.gradient_checkpointing = False

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

    def set_hint_latents(self, hint_latents):
        vae_shift_factor = 0.1159
        vae_scaling_factor = 0.3611
        num_channels_latents = self.in_channels // 4
        hint_latents = (hint_latents - vae_shift_factor) * vae_scaling_factor

        height, width = hint_latents.shape[2:]
        hint_latents = self._pack_latents(
            hint_latents,
            hint_latents.shape[0],
            num_channels_latents,
            height,
            width,
        )
        self.hint_latents = hint_latents.to(device=self.device, dtype=self.dtype)

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        controlnet_cond: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        controlnet_mode=None
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        batch_size = img.shape[0]

        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256).to(self.dtype))
        if self.params.guidance_embed:
            vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(self.dtype)))
        vec.add_(self.vector_in(y))

        if self.union:
            if controlnet_mode is None:
                raise ValueError('using union-controlnet, but controlnet_mode is not a list or is empty')
            controlnet_mode = torch.tensor([[controlnet_mode]], device=self.device)
            emb_controlnet_mode = self.controlnet_mode_embedder(controlnet_mode).to(self.dtype)
            vec = vec + emb_controlnet_mode
            img = img + self.controlnet_x_embedder(controlnet_cond, emb_controlnet_mode)
        else:
            img = img + self.controlnet_x_embedder(controlnet_cond)

        txt = self.txt_in(txt)

        if self.union:
            token_controlnet_mode = self.controlnet_mode_token_embedder(emb_controlnet_mode)[:, None]
            token_controlnet_mode = token_controlnet_mode.expand(txt.size(0), -1, -1)
            txt = torch.cat([token_controlnet_mode, txt], dim=1)
            txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids).to(dtype=self.dtype, device=self.device)

        block_res_samples = ()
        for block in self.transformer_blocks:
            txt, img = block(hidden_states=img, encoder_hidden_states=txt, temb=vec, image_rotary_emb=pe)
            block_res_samples = block_res_samples + (img,)

        img = torch.cat([txt, img], dim=1)

        single_block_res_samples = ()
        for block in self.single_transformer_blocks:
            img = block(hidden_states=img, temb=vec, image_rotary_emb=pe)
            single_block_res_samples = single_block_res_samples + (img[:, txt.shape[1]:],)

        controlnet_block_res_samples = ()
        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
            block_res_sample = controlnet_block(block_res_sample)
            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)

        controlnet_single_block_res_samples = ()
        for single_block_res_sample, single_controlnet_block in zip(single_block_res_samples, self.controlnet_single_blocks):
            single_block_res_sample = single_controlnet_block(single_block_res_sample)
            controlnet_single_block_res_samples = controlnet_single_block_res_samples + (single_block_res_sample,)

        n_single_blocks = 38
        n_double_blocks = 19

        # Expand controlnet_block_res_samples to match n_double_blocks
        expanded_controlnet_block_res_samples = []
        interval_control_double = int(np.ceil(n_double_blocks / len(controlnet_block_res_samples)))
        for i in range(n_double_blocks):
            index = i // interval_control_double
            expanded_controlnet_block_res_samples.append(controlnet_block_res_samples[index])

        # Expand controlnet_single_block_res_samples to match n_single_blocks
        expanded_controlnet_single_block_res_samples = []
        interval_control_single = int(np.ceil(n_single_blocks / len(controlnet_single_block_res_samples)))
        for i in range(n_single_blocks):
            index = i // interval_control_single
            expanded_controlnet_single_block_res_samples.append(controlnet_single_block_res_samples[index])

        return {
            "input": expanded_controlnet_block_res_samples,
            "output": expanded_controlnet_single_block_res_samples
        }

    def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs):
        bs, c, h, w = x.shape
        patch_size = 2
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

        height_control_image, width_control_image = hint.shape[2:]
        num_channels_latents = self.in_channels // 4
        hint = self._pack_latents(
            hint,
            hint.shape[0],
            num_channels_latents,
            height_control_image,
            width_control_image,
        )
        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type)
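The tail of forward_orig fans the 5 double-block and 10 single-block controlnet residuals out to Flux's 19 double and 38 single base-model blocks by ceil-interval indexing, so each residual is reused for a contiguous run of blocks. A standalone check of that mapping:

    import numpy as np

    def expand_indices(n_target, n_samples):
        # Same ceil-interval scheme as the expansion loops above.
        interval = int(np.ceil(n_target / n_samples))
        return [i // interval for i in range(n_target)]

    # 5 double-block residuals cover 19 blocks in runs of 4 (the last run is short):
    assert expand_indices(19, 5) == [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4]
    # 10 single-block residuals cover 38 blocks in runs of 4:
    assert expand_indices(38, 10) == [i // 4 for i in range(38)]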
comfy/ldm/flux/controlnet_instantx_format2.py (new file, 227 lines):

# Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py

import numbers

import numpy as np
import torch
import torch.nn.functional as F
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from diffusers.utils.import_utils import is_torch_version
from einops import rearrange, repeat
from torch import Tensor, nn

from .layers import timestep_embedding
from .model import Flux
from ..common_dit import pad_to_patch_size

if is_torch_version(">=", "2.1.0"):
    LayerNorm = nn.LayerNorm
else:
    # Has optional bias parameter compared to torch layer norm
    # TODO: replace with torch layernorm once min required torch version >= 2.1
    class LayerNorm(nn.Module):
        def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
            super().__init__()

            self.eps = eps

            if isinstance(dim, numbers.Integral):
                dim = (dim,)

            self.dim = torch.Size(dim)

            if elementwise_affine:
                self.weight = nn.Parameter(torch.ones(dim))
                self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
            else:
                self.weight = None
                self.bias = None

        def forward(self, input):
            return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)


def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


# YiYi to-do: refactor rope related functions/classes
def apply_rope(xq, xk, freqs_cis):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


class InstantXControlNetFluxFormat2(Flux):
    def __init__(self, image_model=None, dtype=None, device=None, operations=None, joint_attention_dim=4096, **kwargs):
        kwargs["depth"] = 0
        kwargs["depth_single_blocks"] = 0
        depth_single_blocks_controlnet = kwargs.pop("depth_single_blocks_controlnet", 2)
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.hidden_size,
                    num_attention_heads=24,
                    attention_head_dim=128,
                ).to(dtype=dtype)
                for i in range(5)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.hidden_size,
                    num_attention_heads=24,
                    attention_head_dim=128,
                ).to(dtype=dtype)
                for i in range(10)
            ]
        )

        self.require_vae = True
        # add ControlNet blocks
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)

        self.controlnet_single_blocks = nn.ModuleList([])
        for _ in range(len(self.single_transformer_blocks)):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_single_blocks.append(controlnet_block)

        # TODO support both union and unimodal
        self.union = True  # num_mode is not None
        num_mode = 10
        if self.union:
            self.controlnet_mode_embedder = nn.Embedding(num_mode, self.hidden_size)
        self.controlnet_x_embedder = zero_module(operations.Linear(self.in_channels, self.hidden_size).to(device=device, dtype=dtype))
        self.gradient_checkpointing = False

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        controlnet_cond: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        controlnet_mode=None
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        batch_size = img.shape[0]

        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256).to(self.dtype))
        if self.params.guidance_embed:
            vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(self.dtype)))
        vec.add_(self.vector_in(y))

        txt = self.txt_in(txt)

        if self.union:
            if controlnet_mode is None:
                raise ValueError('using union-controlnet, but controlnet_mode is not a list or is empty')
            controlnet_mode = torch.tensor(controlnet_mode).to(self.device, dtype=torch.long)
            controlnet_mode = controlnet_mode.reshape([-1, 1])
            emb_controlnet_mode = self.controlnet_mode_embedder(controlnet_mode).to(self.dtype)
            txt = torch.cat([emb_controlnet_mode, txt], dim=1)
            txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1)

        img = img + self.controlnet_x_embedder(controlnet_cond)

        txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        block_res_samples = ()
        for block in self.transformer_blocks:
            txt, img = block(hidden_states=img, encoder_hidden_states=txt, temb=vec, image_rotary_emb=pe)
            block_res_samples = block_res_samples + (img,)

        img = torch.cat([txt, img], dim=1)

        single_block_res_samples = ()
        for block in self.single_transformer_blocks:
            img = block(hidden_states=img, temb=vec, image_rotary_emb=pe)
            single_block_res_samples = single_block_res_samples + (img[:, txt.shape[1]:],)

        controlnet_block_res_samples = ()
        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
            block_res_sample = controlnet_block(block_res_sample)
            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)

        controlnet_single_block_res_samples = ()
        for single_block_res_sample, single_controlnet_block in zip(single_block_res_samples, self.controlnet_single_blocks):
            single_block_res_sample = single_controlnet_block(single_block_res_sample)
            controlnet_single_block_res_samples = controlnet_single_block_res_samples + (single_block_res_sample,)

        n_single_blocks = 38
        n_double_blocks = 19

        # Expand controlnet_block_res_samples to match n_double_blocks
        expanded_controlnet_block_res_samples = []
        interval_control_double = int(np.ceil(n_double_blocks / len(controlnet_block_res_samples)))
        for i in range(n_double_blocks):
            index = i // interval_control_double
            expanded_controlnet_block_res_samples.append(controlnet_block_res_samples[index])

        # Expand controlnet_single_block_res_samples to match n_single_blocks
        expanded_controlnet_single_block_res_samples = []
        interval_control_single = int(np.ceil(n_single_blocks / len(controlnet_single_block_res_samples)))
        for i in range(n_single_blocks):
            index = i // interval_control_single
            expanded_controlnet_single_block_res_samples.append(controlnet_single_block_res_samples[index])

        return {
            "input": expanded_controlnet_block_res_samples,
            "output": expanded_controlnet_single_block_res_samples
        }

    def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs):
        bs, c, h, w = x.shape
        patch_size = 2
        x = pad_to_patch_size(x, (patch_size, patch_size))

        height_control_image, width_control_image = hint.shape[2:]
        num_channels_latents = self.in_channels // 4
        hint = self._pack_latents(
            hint,
            hint.shape[0],
            num_channels_latents,
            height_control_image,
            width_control_image,
        )
        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type)
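Both InstantX classes pack the VAE latent hint the same way FluxPipeline does: a 2x2 pixel-shuffle turns (B, 16, H, W) latents into (B, H/2 * W/2, 64) tokens, matching the patchified image sequence. A quick shape check (only torch assumed):

    import torch

    def pack_latents(latents, batch_size, num_channels_latents, height, width):
        # Identical reshuffle to _pack_latents in the two files above.
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        return latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

    hint = torch.randn(1, 16, 128, 128)      # a 1024x1024 image after the 8x VAE
    packed = pack_latents(hint, 1, 16, 128, 128)
    assert packed.shape == (1, 64 * 64, 64)  # one token per 2x2 latent patch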
@@ -40,6 +40,7 @@ class Flux(nn.Module):
 
     def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
         super().__init__()
+        self.device = device
         self.dtype = dtype
         params = FluxParams(**kwargs)
         self.params = params
comfy/ldm/flux/weight_dtypes.py (new file, 1 line):

FLUX_WEIGHT_DTYPES = ["default", "fp8_e4m3fn", "fp8_e5m2"]
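This single constant is now the source of truth for the weight_dtype dropdown in both UNETLoader and the new ControlNetLoaderWeights node, so the two menus cannot drift apart. In brief:

    from comfy.ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES

    # "default" (index 0) is also the default value of load_controlnet's
    # weight_dtype parameter, as patched above.
    assert FLUX_WEIGHT_DTYPES[0] == "default"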
@@ -385,6 +385,8 @@ KNOWN_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
     HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0-promax.safetensors"),
     HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0.safetensors"),
     HuggingFile("InstantX/FLUX.1-dev-Controlnet-Canny", "diffusion_pytorch_model.safetensors", save_with_filename="instantx-flux.1-dev-controlnet-canny.safetensors"),
+    HuggingFile("InstantX/FLUX.1-dev-Controlnet-Union", "diffusion_pytorch_model.safetensors", save_with_filename="instantx-flux.1-dev-controlnet-union.safetensors"),
+    HuggingFile("Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro", "diffusion_pytorch_model.safetensors", save_with_filename="shakker-labs-flux.1-dev-controlnet-union-pro.safetensors"),
 ], folder_name="controlnet")
 
 KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
@@ -27,6 +27,7 @@ from ..cmd import folder_paths, latent_preview
 from ..component_model.tensor_types import RGBImage
 from ..execution_context import current_execution_context
 from ..images import open_image
+from ..ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
 from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS
 from ..nodes.common import MAX_RESOLUTION
 from .. import controlnet
@@ -756,6 +757,27 @@ class ControlNetLoader:
         controlnet_ = controlnet.load_controlnet(controlnet_path)
         return (controlnet_,)
 
+
+class ControlNetLoaderWeights:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "control_net_name": (get_filename_list_with_downloadable("controlnet", KNOWN_CONTROLNETS),),
+                "weight_dtype": (FLUX_WEIGHT_DTYPES,)
+            }
+        }
+
+    RETURN_TYPES = ("CONTROL_NET",)
+    FUNCTION = "load_controlnet"
+
+    CATEGORY = "loaders"
+
+    def load_controlnet(self, control_net_name, weight_dtype):
+        controlnet_path = get_or_download("controlnet", control_net_name, KNOWN_CONTROLNETS)
+        controlnet_ = controlnet.load_controlnet(controlnet_path, weight_dtype=weight_dtype)
+        return (controlnet_,)
+
+
 class DiffControlNetLoader:
     @classmethod
     def INPUT_TYPES(s):
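The new node mirrors ControlNetLoader but threads weight_dtype through to load_controlnet. A hedged sketch of invoking it directly in Python (assumes the checkpoint is already present or downloadable via KNOWN_CONTROLNETS):

    loader = ControlNetLoaderWeights()
    (control_net,) = loader.load_controlnet(
        "shakker-labs-flux.1-dev-controlnet-union-pro.safetensors",
        "default",
    )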
@@ -854,7 +876,7 @@ class UNETLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "unet_name": (get_filename_list_with_downloadable("diffusion_models", KNOWN_UNET_MODELS),),
-                              "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
+                              "weight_dtype": (FLUX_WEIGHT_DTYPES,)
                             }}
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "load_unet"
@@ -1882,6 +1904,7 @@ NODE_CLASS_MAPPINGS = {
     "ControlNetApply": ControlNetApply,
     "ControlNetApplyAdvanced": ControlNetApplyAdvanced,
     "ControlNetLoader": ControlNetLoader,
+    "ControlNetLoaderWeights": ControlNetLoaderWeights,
     "DiffControlNetLoader": DiffControlNetLoader,
     "StyleModelLoader": StyleModelLoader,
     "CLIPVisionLoader": CLIPVisionLoader,
@@ -1914,6 +1937,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "LoraLoader": "Load LoRA",
     "CLIPLoader": "Load CLIP",
     "ControlNetLoader": "Load ControlNet Model",
+    "ControlNetLoaderWeights": "Load ControlNet Model (Weights)",
     "DiffControlNetLoader": "Load ControlNet Model (diff)",
     "StyleModelLoader": "Load Style Model",
     "CLIPVisionLoader": "Load CLIP Vision",
tests/inference/workflows/flux-controlnet-0.json (new file, 314 lines):

{
  "1": {
    "inputs": {
      "noise": ["2", 0],
      "guider": ["3", 0],
      "sampler": ["6", 0],
      "sigmas": ["7", 0],
      "latent_image": ["9", 0]
    },
    "class_type": "SamplerCustomAdvanced",
    "_meta": {"title": "SamplerCustomAdvanced"}
  },
  "2": {
    "inputs": {"noise_seed": 868192789249410},
    "class_type": "RandomNoise",
    "_meta": {"title": "RandomNoise"}
  },
  "3": {
    "inputs": {"model": ["12", 0], "conditioning": ["17", 0]},
    "class_type": "BasicGuider",
    "_meta": {"title": "BasicGuider"}
  },
  "4": {
    "inputs": {"guidance": 4, "conditioning": ["13", 0]},
    "class_type": "FluxGuidance",
    "_meta": {"title": "FluxGuidance"}
  },
  "6": {
    "inputs": {"sampler_name": "euler"},
    "class_type": "KSamplerSelect",
    "_meta": {"title": "KSamplerSelect"}
  },
  "7": {
    "inputs": {"scheduler": "ddim_uniform", "steps": 10, "denoise": 1, "model": ["12", 0]},
    "class_type": "BasicScheduler",
    "_meta": {"title": "BasicScheduler"}
  },
  "9": {
    "inputs": {"width": ["27", 4], "height": ["27", 5], "batch_size": 1},
    "class_type": "EmptySD3LatentImage",
    "_meta": {"title": "EmptySD3LatentImage"}
  },
  "10": {
    "inputs": {"samples": ["1", 0], "vae": ["11", 0]},
    "class_type": "VAEDecode",
    "_meta": {"title": "VAE Decode"}
  },
  "11": {
    "inputs": {"vae_name": "ae.safetensors"},
    "class_type": "VAELoader",
    "_meta": {"title": "Load VAE"}
  },
  "12": {
    "inputs": {"unet_name": "flux1-dev.safetensors", "weight_dtype": "fp8_e4m3fn"},
    "class_type": "UNETLoader",
    "_meta": {"title": "Load Diffusion Model"}
  },
  "13": {
    "inputs": {"text": "A photo of a girl.", "clip": ["15", 0]},
    "class_type": "CLIPTextEncode",
    "_meta": {"title": "CLIP Text Encode (Prompt)"}
  },
  "15": {
    "inputs": {"clip_name1": "clip_l.safetensors", "clip_name2": "t5xxl_fp16.safetensors", "type": "flux"},
    "class_type": "DualCLIPLoader",
    "_meta": {"title": "DualCLIPLoader"}
  },
  "16": {
    "inputs": {"filename_prefix": "ComfyUI", "images": ["10", 0]},
    "class_type": "SaveImage",
    "_meta": {"title": "Save Image"}
  },
  "17": {
    "inputs": {
      "strength": 0.6,
      "start_percent": 0,
      "end_percent": 1,
      "positive": ["4", 0],
      "negative": ["21", 0],
      "control_net": ["24", 0],
      "vae": ["11", 0],
      "image": ["25", 0]
    },
    "class_type": "ControlNetApplySD3",
    "_meta": {"title": "ControlNetApply SD3 and HunyuanDiT"}
  },
  "19": {
    "inputs": {
      "value": "https://upload.wikimedia.org/wikipedia/commons/5/5a/Gibson_Girl.png",
      "name": "",
      "title": "",
      "description": "",
      "__required": true
    },
    "class_type": "ImageRequestParameter",
    "_meta": {"title": "ImageRequestParameter"}
  },
  "21": {
    "inputs": {"text": "", "clip": ["15", 0]},
    "class_type": "CLIPTextEncode",
    "_meta": {"title": "CLIP Text Encode (Prompt)"}
  },
  "24": {
    "inputs": {"type": "canny (InstantX)", "control_net": ["28", 0]},
    "class_type": "SetUnionControlNetType",
    "_meta": {"title": "SetUnionControlNetType"}
  },
  "25": {
    "inputs": {"low_threshold": 0.4, "high_threshold": 0.8, "image": ["31", 0]},
    "class_type": "Canny",
    "_meta": {"title": "Canny"}
  },
  "26": {
    "inputs": {"images": ["25", 0]},
    "class_type": "PreviewImage",
    "_meta": {"title": "Preview Image"}
  },
  "27": {
    "inputs": {"image": ["25", 0]},
    "class_type": "Image Size to Number",
    "_meta": {"title": "Image Size to Number"}
  },
  "28": {
    "inputs": {"control_net_name": "shakker-labs-flux.1-dev-controlnet-union-pro.safetensors", "weight_dtype": "default"},
    "class_type": "ControlNetLoaderWeights",
    "_meta": {"title": "Load ControlNet Model (Weights)"}
  },
  "31": {
    "inputs": {
      "image_gen_width": 1024,
      "image_gen_height": 1024,
      "resize_mode": "Crop and Resize",
      "hint_image": ["19", 0]
    },
    "class_type": "HintImageEnchance",
    "_meta": {"title": "Enchance And Resize Hint Images"}
  }
}
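The test workflow wires ControlNetLoaderWeights (node 28) into SetUnionControlNetType (node 24) with the "canny (InstantX)" mode, then applies it through ControlNetApplySD3 (node 17). A minimal sketch of queueing this file against a running server over the standard /prompt endpoint (host and port are assumptions; the workflow's ImageRequestParameter already carries a default hint image URL):

    import json
    import urllib.request

    with open("tests/inference/workflows/flux-controlnet-0.json") as f:
        prompt = json.load(f)

    req = urllib.request.Request(
        "http://127.0.0.1:8188/prompt",  # default local ComfyUI server
        data=json.dumps({"prompt": prompt}).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    print(urllib.request.urlopen(req).read().decode())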