mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-15 01:07:03 +08:00
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
The inpaint part is currently missing and will be implemented later. I think they messed up this model pretty badly. They added some control_noise_refiner blocks but don't actually use them: a typo in their code means that instead of running control_noise_refiner -> control_layers, it runs the whole control_layers stack twice. Unfortunately they trained with this typo, so the model works but is somewhat slow, and it would probably perform a lot better if they corrected the code and retrained it.
161 lines
6.4 KiB
Python
161 lines
6.4 KiB
Python
import torch
|
|
from torch import nn
|
|
|
|
from .model import JointTransformerBlock
|
|
|
|
class ZImageControlTransformerBlock(JointTransformerBlock):
    """Control-branch transformer block for the ZImage ControlNet.

    A JointTransformerBlock (with z_image_modulation enabled) extended with
    projection layers: block 0 gets a ``before_proj`` that mixes the main
    model's hidden states into the control stream, and every block gets an
    ``after_proj`` that produces the per-layer control skip signal.
    """

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
        block_id=0,
        operation_settings=None,
    ):
        super().__init__(
            layer_id,
            dim,
            n_heads,
            n_kv_heads,
            multiple_of,
            ffn_dim_multiplier,
            norm_eps,
            qk_norm,
            modulation,
            z_image_modulation=True,
            operation_settings=operation_settings,
        )
        self.block_id = block_id
        ops = operation_settings.get("operations")
        factory_kwargs = {
            "device": operation_settings.get("device"),
            "dtype": operation_settings.get("dtype"),
        }
        # Only the first block injects the model's hidden states, so only it
        # needs the input projection.
        if block_id == 0:
            self.before_proj = ops.Linear(self.dim, self.dim, **factory_kwargs)
        self.after_proj = ops.Linear(self.dim, self.dim, **factory_kwargs)

    def forward(self, c, x, **kwargs):
        """Run one control block.

        Args:
            c: control-stream hidden states.
            x: main-model hidden states (mixed in at block 0 only).
            **kwargs: forwarded to JointTransformerBlock.forward.

        Returns:
            (c_skip, c): projected skip signal for the main stream, and the
            updated control hidden states.
        """
        if self.block_id == 0:
            c = self.before_proj(c) + x
        c = super().forward(c, **kwargs)
        return self.after_proj(c), c
|
|
|
|
class ZImage_Control(torch.nn.Module):
    """ControlNet-style conditioning branch for the ZImage model.

    Patch-embeds a control latent into tokens and runs it through a stack of
    ZImageControlTransformerBlock layers; the outer model (not visible in this
    file) pulls per-layer skip outputs via forward_control_block /
    forward_noise_refiner_block and adds them to its own stream.

    NOTE(review): per the upstream commit message, the reference code has a
    typo — the control_noise_refiner blocks exist but are never actually used,
    and the whole control_layers stack runs twice. The model was *trained*
    with that typo, so ``broken=True`` reproduces it on purpose.
    """

    def __init__(
        self,
        dim: int = 3840,
        n_heads: int = 30,
        n_kv_heads: int = 30,
        multiple_of: int = 256,
        ffn_dim_multiplier: float = (8.0 / 3.0),
        norm_eps: float = 1e-5,
        qk_norm: bool = True,
        n_control_layers=6,
        control_in_dim=16,
        additional_in_dim=0,
        broken=False,
        refiner_control=False,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):
        """Build the control branch.

        Args:
            dim: transformer hidden size.
            n_heads / n_kv_heads: attention head counts passed to the blocks.
            multiple_of, ffn_dim_multiplier: FFN sizing parameters.
            norm_eps, qk_norm: normalization settings for the blocks.
            n_control_layers: number of control transformer blocks.
            control_in_dim: channels of the control latent.
            additional_in_dim: extra input channels (e.g. for inpaint; the
                inpaint path itself is not implemented yet).
            broken: reproduce the upstream typo'd control flow (see class
                docstring).
            refiner_control: build the noise refiner from control blocks
                (with skip projections) instead of plain JointTransformerBlocks.
            dtype / device / operations: factory settings for the layers.
        """
        super().__init__()
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}

        self.broken = broken
        self.additional_in_dim = additional_in_dim
        self.control_in_dim = control_in_dim
        # Refiner depth is fixed at 2 upstream.
        n_refiner_layers = 2
        self.n_control_layers = n_control_layers
        # Main control stack; block_id == 0 is the block that mixes the
        # model's hidden states into the control stream.
        self.control_layers = nn.ModuleList(
            [
                ZImageControlTransformerBlock(
                    i,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    block_id=i,
                    operation_settings=operation_settings,
                )
                for i in range(self.n_control_layers)
            ]
        )

        # Patch embedder(s), keyed by "<patch_size>-<f_patch_size>". Only the
        # 2-1 configuration is built here.
        all_x_embedder = {}
        patch_size = 2
        f_patch_size = 1
        x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
        all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder

        self.refiner_control = refiner_control

        self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
        if self.refiner_control:
            # Refiner built from control blocks (each returns a skip output).
            self.control_noise_refiner = nn.ModuleList(
                [
                    ZImageControlTransformerBlock(
                        layer_id,
                        dim,
                        n_heads,
                        n_kv_heads,
                        multiple_of,
                        ffn_dim_multiplier,
                        norm_eps,
                        qk_norm,
                        block_id=layer_id,
                        operation_settings=operation_settings,
                    )
                    for layer_id in range(n_refiner_layers)
                ]
            )
        else:
            # Plain refiner blocks (no skip projections); these are run once
            # up-front in forward() instead of per-layer.
            self.control_noise_refiner = nn.ModuleList(
                [
                    JointTransformerBlock(
                        layer_id,
                        dim,
                        n_heads,
                        n_kv_heads,
                        multiple_of,
                        ffn_dim_multiplier,
                        norm_eps,
                        qk_norm,
                        modulation=True,
                        z_image_modulation=True,
                        operation_settings=operation_settings,
                    )
                    for layer_id in range(n_refiner_layers)
                ]
            )

    def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
        """Embed the control latent into tokens (and pre-refine them when
        refiner_control is off).

        Args:
            cap_feats: accepted for interface compatibility; unused here.
            control_context: control latent, shape (B, C, H, W).
            x_freqs_cis: rotary embeddings; sliced to the token count below.
            adaln_input: AdaLN modulation input passed to the refiner blocks.

        Returns:
            Embedded (and possibly refined) control tokens, (B, tokens, dim).
        """
        patch_size = 2
        f_patch_size = 1
        pH = pW = patch_size
        B, C, H, W = control_context.shape
        # Patchify: (B, C, H, W) -> (B, (H/pH)*(W/pW), pH*pW*C) -> linear -> dim.
        control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))

        x_attn_mask = None
        if not self.refiner_control:
            # Non-refiner-control mode: run the plain refiner stack once here;
            # otherwise the refiner is applied lazily per layer via
            # forward_noise_refiner_block.
            for layer in self.control_noise_refiner:
                control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)

        return control_context

    def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
        """Run refiner step ``layer_id`` on the control stream.

        Returns (skip, control_context); skip is None when there is nothing
        to add to the main stream.
        """
        if self.refiner_control:
            if self.broken:
                # Reproduce the upstream typo the model was trained with:
                # instead of the refiner, the control_layers stack itself is
                # run again here.
                if layer_id == 0:
                    return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
                if layer_id > 0:
                    # Run ALL remaining control layers in one go; only the
                    # first layer's skip output is returned.
                    out = None
                    for i in range(1, len(self.control_layers)):
                        o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
                        if out is None:
                            out = o

                    return (out, control_context)
            else:
                # Intended (non-typo) behavior: use the actual refiner block.
                return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
        else:
            # Refiner already applied in forward(); nothing to add here.
            return (None, control_context)

    def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
        """Run control block ``layer_id``; returns (skip, control_context)."""
        return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|