# Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py

import numbers

import numpy as np
import torch
import torch.nn.functional as F
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from diffusers.utils.import_utils import is_torch_version
from einops import rearrange, repeat
from torch import Tensor, nn

from .layers import timestep_embedding
from .model import Flux
from ..common_dit import pad_to_patch_size

if is_torch_version(">=", "2.1.0"):
    LayerNorm = nn.LayerNorm
else:
    # Has optional bias parameter compared to torch layer norm
    # TODO: replace with torch layernorm once min required torch version >= 2.1
    class LayerNorm(nn.Module):
        def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
            super().__init__()

            self.eps = eps

            if isinstance(dim, numbers.Integral):
                dim = (dim,)

            self.dim = torch.Size(dim)

            if elementwise_affine:
                self.weight = nn.Parameter(torch.ones(dim))
                self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
            else:
                self.weight = None
                self.bias = None

        def forward(self, input):
            return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)


def zero_module(module):
    # Zero-initialize every parameter so the module contributes nothing until
    # trained (the standard ControlNet initialization trick).
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


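def _demo_zero_module():
    # Hedged usage sketch, not part of the upstream file: a zero-initialized
    # projection outputs zeros for any input, so a freshly created ControlNet
    # leaves the frozen base model's behaviour unchanged at training start.
    proj = zero_module(nn.Linear(8, 8))
    out = proj(torch.randn(2, 8))
    assert torch.count_nonzero(out) == 0

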
# YiYi to-do: refactor rope related functions/classes
def apply_rope(xq, xk, freqs_cis):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


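def _demo_apply_rope():
    # Hedged shape sketch, not part of the upstream file. The freqs_cis layout
    # is an assumption inferred from the reshape logic above: one 2x2 rotation
    # per frequency pair, broadcast over batch and heads.
    b, h, l, d = 1, 24, 16, 128
    xq = torch.randn(b, h, l, d)
    xk = torch.randn(b, h, l, d)
    freqs_cis = torch.randn(1, 1, l, d // 2, 2, 2)
    q_out, k_out = apply_rope(xq, xk, freqs_cis)
    assert q_out.shape == xq.shape and k_out.shape == xk.shape

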
class InstantXControlNetFluxFormat2(Flux):
    def __init__(self, image_model=None, dtype=None, device=None, operations=None, joint_attention_dim=4096, **kwargs):
        # Build the base Flux model with zero depth; this ControlNet creates
        # its own (smaller) stacks of transformer blocks below.
        kwargs["depth"] = 0
        kwargs["depth_single_blocks"] = 0
        depth_single_blocks_controlnet = kwargs.pop("depth_single_blocks_controlnet", 2)
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.hidden_size,
                    num_attention_heads=24,
                    attention_head_dim=128,
                ).to(dtype=dtype)
                for i in range(5)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.hidden_size,
                    num_attention_heads=24,
                    attention_head_dim=128,
                ).to(dtype=dtype)
                for i in range(10)
            ]
        )

        self.require_vae = True
        # add ControlNet blocks: one zero-initialized projection per block,
        # so the control residuals start out as zeros
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)

        self.controlnet_single_blocks = nn.ModuleList([])
        for _ in range(len(self.single_transformer_blocks)):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_single_blocks.append(controlnet_block)

        # TODO support both union and unimodal
        self.union = True  # num_mode is not None
        num_mode = 10
        if self.union:
            self.controlnet_mode_embedder = nn.Embedding(num_mode, self.hidden_size)
        self.controlnet_x_embedder = zero_module(operations.Linear(self.in_channels, self.hidden_size).to(device=device, dtype=dtype))
        self.gradient_checkpointing = False

    @staticmethod
    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

        return latents

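    # Shape note (illustrative, not upstream code): for latents of shape
    # (B, C, H, W), _pack_latents returns (B, (H // 2) * (W // 2), C * 4)
    # patch tokens, e.g. (1, 16, 64, 64) -> (1, 1024, 64); see also the
    # _demo_patchify sketch at the end of the file.
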
    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        controlnet_cond: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        controlnet_mode=None
    ) -> dict:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        batch_size = img.shape[0]

        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256).to(self.dtype))
        if self.params.guidance_embed:
            vec.add_(self.guidance_in(timestep_embedding(guidance, 256).to(self.dtype)))
        vec.add_(self.vector_in(y))

        txt = self.txt_in(txt)

        if self.union:
            if controlnet_mode is None:
                raise ValueError("using union-controlnet, but controlnet_mode is None")
            # Prepend the control-mode embedding as an extra text token.
            controlnet_mode = torch.tensor(controlnet_mode).to(self.device, dtype=torch.long)
            controlnet_mode = controlnet_mode.reshape([-1, 1])
            emb_controlnet_mode = self.controlnet_mode_embedder(controlnet_mode).to(self.dtype)
            txt = torch.cat([emb_controlnet_mode, txt], dim=1)
            txt_ids = torch.cat([txt_ids[:, :1], txt_ids], dim=1)

        # Inject the packed control image into the image stream.
        img = img + self.controlnet_x_embedder(controlnet_cond)

        txt_ids = txt_ids.expand(img_ids.size(0), -1, -1)
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        block_res_samples = ()
        for block in self.transformer_blocks:
            txt, img = block(hidden_states=img, encoder_hidden_states=txt, temb=vec, image_rotary_emb=pe)
            block_res_samples = block_res_samples + (img,)

        img = torch.cat([txt, img], dim=1)

        single_block_res_samples = ()
        for block in self.single_transformer_blocks:
            img = block(hidden_states=img, temb=vec, image_rotary_emb=pe)
            single_block_res_samples = single_block_res_samples + (img[:, txt.shape[1]:],)

        controlnet_block_res_samples = ()
        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
            block_res_sample = controlnet_block(block_res_sample)
            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)

        controlnet_single_block_res_samples = ()
        for single_block_res_sample, single_controlnet_block in zip(single_block_res_samples, self.controlnet_single_blocks):
            single_block_res_sample = single_controlnet_block(single_block_res_sample)
            controlnet_single_block_res_samples = controlnet_single_block_res_samples + (single_block_res_sample,)

        n_single_blocks = 38
        n_double_blocks = 19

        # Expand controlnet_block_res_samples to match n_double_blocks
        expanded_controlnet_block_res_samples = []
        interval_control_double = int(np.ceil(n_double_blocks / len(controlnet_block_res_samples)))
        for i in range(n_double_blocks):
            index = i // interval_control_double
            expanded_controlnet_block_res_samples.append(controlnet_block_res_samples[index])

        # Expand controlnet_single_block_res_samples to match n_single_blocks
        expanded_controlnet_single_block_res_samples = []
        interval_control_single = int(np.ceil(n_single_blocks / len(controlnet_single_block_res_samples)))
        for i in range(n_single_blocks):
            index = i // interval_control_single
            expanded_controlnet_single_block_res_samples.append(controlnet_single_block_res_samples[index])

        return {
            "input": expanded_controlnet_block_res_samples,
            "output": expanded_controlnet_single_block_res_samples
        }

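    # Note on the expansion above (illustrative, not upstream code): with 5
    # double-block residuals and 19 base double blocks,
    # interval_control_double = ceil(19 / 5) = 4, so residual i // 4 is
    # reused, giving the index pattern [0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3,
    # 4,4,4]; the 10 single-block residuals map onto 38 single blocks the
    # same way.
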
    def forward(self, x, timesteps, context, y, guidance=None, hint=None, control_type=None, **kwargs):
        bs, c, h, w = x.shape
        patch_size = 2
        x = pad_to_patch_size(x, (patch_size, patch_size))

        height_control_image, width_control_image = hint.shape[2:]
        num_channels_latents = self.in_channels // 4
        hint = self._pack_latents(
            hint,
            hint.shape[0],
            num_channels_latents,
            height_control_image,
            width_control_image,
        )

        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        # Positional ids for the image tokens: axis 1 indexes rows, axis 2
        # indexes columns.
        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type)
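
def _demo_patchify():
    # Hedged sketch, not part of the upstream file: the rearrange in forward()
    # and the view/permute in _pack_latents produce the same
    # (B, (H // 2) * (W // 2), C * 4) patch tokens.
    x = torch.randn(1, 16, 64, 64)
    tokens = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    packed = InstantXControlNetFluxFormat2._pack_latents(x, 1, 16, 64, 64)
    assert tokens.shape == (1, 32 * 32, 16 * 4)
    assert torch.equal(tokens, packed)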