Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-12-14 16:57:05 +08:00
Basic implementation of z image fun control union 2.0 (#11304)
The inpaint part is currently missing and will be implemented later. I think they messed up this model pretty badly. They added some control_noise_refiner blocks but don't actually use them: there is a typo in their code, so instead of running control_noise_refiner -> control_layers it runs the whole control_layers stack twice. Unfortunately they trained with this typo, so the model works but is kind of slow, and it would probably perform a lot better if they corrected the code and trained it again.
This commit is contained in:
parent c5a47a1692
commit da2bfb5b0a
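To make the routing issue from the commit message concrete, here is a minimal sketch of the two code paths it describes, using toy stand-in blocks rather than the real transformer modules (all names here are illustrative, not part of the ComfyUI or upstream API):

# Toy stand-ins: a "block" is just a callable taking and returning features.
def make_blocks(n, scale):
    return [lambda h, s=scale + i * 0.01: h * s for i in range(n)]

control_noise_refiner = make_blocks(2, 1.01)   # refiner blocks (unused upstream)
control_layers = make_blocks(6, 1.02)          # control blocks

def intended_routing(h):
    # What the upstream code apparently meant: refine the control features
    # first, then run the control stack once.
    for block in control_noise_refiner:
        h = block(h)
    for block in control_layers:
        h = block(h)
    return h

def routing_as_trained(h):
    # What the typo effectively produces: the refiner blocks are skipped and
    # the full control stack is evaluated twice, which is why the released
    # weights only work (slowly) with this routing.
    for _ in range(2):
        for block in control_layers:
            h = block(h)
    return h

print(intended_routing(1.0), routing_as_trained(1.0))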
@@ -41,6 +41,11 @@ class ZImage_Control(torch.nn.Module):
         ffn_dim_multiplier: float = (8.0 / 3.0),
         norm_eps: float = 1e-5,
         qk_norm: bool = True,
+        n_control_layers=6,
+        control_in_dim=16,
+        additional_in_dim=0,
+        broken=False,
+        refiner_control=False,
         dtype=None,
         device=None,
         operations=None,
@@ -49,10 +54,11 @@ class ZImage_Control(torch.nn.Module):
         super().__init__()
         operation_settings = {"operations": operations, "device": device, "dtype": dtype}

-        self.additional_in_dim = 0
-        self.control_in_dim = 16
+        self.broken = broken
+        self.additional_in_dim = additional_in_dim
+        self.control_in_dim = control_in_dim
         n_refiner_layers = 2
-        self.n_control_layers = 6
+        self.n_control_layers = n_control_layers
         self.control_layers = nn.ModuleList(
             [
                 ZImageControlTransformerBlock(
@@ -74,28 +80,49 @@ class ZImage_Control(torch.nn.Module):
         all_x_embedder = {}
         patch_size = 2
         f_patch_size = 1
-        x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
+        x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
         all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder

+        self.refiner_control = refiner_control
+
         self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
-        self.control_noise_refiner = nn.ModuleList(
-            [
-                JointTransformerBlock(
-                    layer_id,
-                    dim,
-                    n_heads,
-                    n_kv_heads,
-                    multiple_of,
-                    ffn_dim_multiplier,
-                    norm_eps,
-                    qk_norm,
-                    modulation=True,
-                    z_image_modulation=True,
-                    operation_settings=operation_settings,
-                )
-                for layer_id in range(n_refiner_layers)
-            ]
-        )
+        if self.refiner_control:
+            self.control_noise_refiner = nn.ModuleList(
+                [
+                    ZImageControlTransformerBlock(
+                        layer_id,
+                        dim,
+                        n_heads,
+                        n_kv_heads,
+                        multiple_of,
+                        ffn_dim_multiplier,
+                        norm_eps,
+                        qk_norm,
+                        block_id=layer_id,
+                        operation_settings=operation_settings,
+                    )
+                    for layer_id in range(n_refiner_layers)
+                ]
+            )
+        else:
+            self.control_noise_refiner = nn.ModuleList(
+                [
+                    JointTransformerBlock(
+                        layer_id,
+                        dim,
+                        n_heads,
+                        n_kv_heads,
+                        multiple_of,
+                        ffn_dim_multiplier,
+                        norm_eps,
+                        qk_norm,
+                        modulation=True,
+                        z_image_modulation=True,
+                        operation_settings=operation_settings,
+                    )
+                    for layer_id in range(n_refiner_layers)
+                ]
+            )

     def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
         patch_size = 2
@@ -105,9 +132,29 @@ class ZImage_Control(torch.nn.Module):
         control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))

         x_attn_mask = None
-        for layer in self.control_noise_refiner:
-            control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
+        if not self.refiner_control:
+            for layer in self.control_noise_refiner:
+                control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)

         return control_context

+    def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
+        if self.refiner_control:
+            if self.broken:
+                if layer_id == 0:
+                    return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+                if layer_id > 0:
+                    out = None
+                    for i in range(1, len(self.control_layers)):
+                        o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+                        if out is None:
+                            out = o
+
+                    return (out, control_context)
+            else:
+                return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
+        else:
+            return (None, control_context)
+
     def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
         return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
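For reference, a small sketch of which control_layers indices the broken branch above ends up evaluating during the base model's noise-refiner phase. This mirrors forward_noise_refiner_block with broken=True; it is an illustration, not additional code from the PR:

def broken_refiner_dispatch(layer_id, n_control_layers):
    # Refiner step 0 runs control_layers[0]; any later refiner step runs
    # every remaining control layer back to back.
    if layer_id == 0:
        return [0]
    return list(range(1, n_control_layers))

# With the union 2.0 checkpoint (15 control layers, n_refiner_layers = 2):
#   layer_id 0 -> [0]
#   layer_id 1 -> [1, 2, ..., 14]
# so the whole control stack has already run once before the main transformer
# blocks start, and forward_control_block then runs it a second time.
print(broken_refiner_dispatch(0, 15), broken_refiner_dispatch(1, 15))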
@@ -536,6 +536,7 @@ class NextDiT(nn.Module):
         bsz = len(x)
         pH = pW = self.patch_size
         device = x[0].device
+        orig_x = x

         if self.pad_tokens_multiple is not None:
             pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
@@ -572,13 +573,21 @@ class NextDiT(nn.Module):

         freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)

+        patches = transformer_options.get("patches", {})
+
         # refine context
         for layer in self.context_refiner:
             cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)

         padded_img_mask = None
-        for layer in self.noise_refiner:
+        x_input = x
+        for i, layer in enumerate(self.noise_refiner):
             x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
+            if "noise_refiner" in patches:
+                for p in patches["noise_refiner"]:
+                    out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
+                    if "img" in out:
+                        x = out["img"]

         padded_full_embed = torch.cat((cap_feats, x), dim=1)
         mask = None
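The new hook passes each noise-refiner step through any registered "noise_refiner" patches as a plain dict and only reads an updated "img" back. A minimal patch callable, as a sketch of the calling convention shown above (the 0.9/0.1 blend is arbitrary and purely for illustration):

def example_noise_refiner_patch(kwargs):
    # The dict carries the keys built above: "img", "img_input", "txt", "pe",
    # "vec", "x", "block_index", "block_type", "transformer_options".
    img = kwargs.get("img")
    img_input = kwargs.get("img_input")
    block_index = kwargs.get("block_index")
    # Example: softly pull the refined features back toward the block's input
    # on the first refiner step.
    if block_index == 0:
        img = img * 0.9 + img_input * 0.1
    return {"img": img}  # the hook only reads "img" back for noise_refiner patches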
@@ -622,14 +631,15 @@ class NextDiT(nn.Module):

         patches = transformer_options.get("patches", {})
         x_is_tensor = isinstance(x, torch.Tensor)
-        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
+        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
         freqs_cis = freqs_cis.to(img.device)

+        img_input = img
         for i, layer in enumerate(self.layers):
             img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
             if "double_block" in patches:
                 for p in patches["double_block"]:
-                    out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
+                    out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
                     if "img" in out:
                         img[:, cap_size[0]:] = out["img"]
                     if "txt" in out:
@@ -454,6 +454,9 @@ class ModelPatcher:
     def set_model_post_input_patch(self, patch):
        self.set_model_patch(patch, "post_input")

+    def set_model_noise_refiner_patch(self, patch):
+        self.set_model_patch(patch, "noise_refiner")
+
     def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
         rope_options = self.model_options["transformer_options"].get("rope_options", {})
         rope_options["scale_x"] = scale_x
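Registering such a patch goes through the new setter, mirroring the existing set_model_*_patch helpers. A rough usage sketch, assuming model is an already-loaded ModelPatcher and example_noise_refiner_patch is the callable sketched earlier:

patched = model.clone()  # keep the original patcher untouched
patched.set_model_noise_refiner_patch(example_noise_refiner_patch)
# set_model_patch stores the callable under
# model_options["transformer_options"]["patches"]["noise_refiner"],
# which is the list the NextDiT noise-refiner loop above iterates over.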
@@ -243,7 +243,13 @@ class ModelPatchLoader:
             model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
         elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
             sd = z_image_convert(sd)
-            model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
+            config = {}
+            if 'control_layers.14.adaLN_modulation.0.weight' in sd:
+                config['n_control_layers'] = 15
+                config['additional_in_dim'] = 17
+                config['refiner_control'] = True
+                config['broken'] = True
+            model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)

         model.load_state_dict(sd)
         model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
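The config detection above keys off the checkpoint contents: a 15th control layer in the state dict marks the fun control union 2.0 variant, and the extra 17 input channels match the 1 mask channel plus 16 inpaint latent channels that encode_latent_cond concatenates later in this diff. A standalone restatement of that mapping (a sketch only, not an additional code path in the PR):

def z_image_control_config(sd):
    # 'control_layers.14.*' only exists when the checkpoint has 15 control blocks.
    config = {}
    if 'control_layers.14.adaLN_modulation.0.weight' in sd:
        config['n_control_layers'] = 15
        config['additional_in_dim'] = 17    # 1 mask channel + 16 inpaint latent channels
        config['refiner_control'] = True    # refiner stack uses ZImageControlTransformerBlock
        config['broken'] = True             # replicate the upstream routing typo the weights were trained with
    return config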
@@ -297,56 +303,86 @@ class DiffSynthCnetPatch:
         return [self.model_patch]

 class ZImageControlPatch:
-    def __init__(self, model_patch, vae, image, strength):
+    def __init__(self, model_patch, vae, image, strength, inpaint_image=None, mask=None):
         self.model_patch = model_patch
         self.vae = vae
         self.image = image
+        self.inpaint_image = inpaint_image
+        self.mask = mask
         self.strength = strength
         self.encoded_image = self.encode_latent_cond(image)
         self.encoded_image_size = (image.shape[1], image.shape[2])
         self.temp_data = None

-    def encode_latent_cond(self, image):
-        latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
-        return latent_image
+    def encode_latent_cond(self, control_image, inpaint_image=None):
+        latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
+        if self.model_patch.model.additional_in_dim > 0:
+            if self.mask is None:
+                mask_ = torch.zeros_like(latent_image)[:, :1]
+            else:
+                mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none")
+            if inpaint_image is None:
+                inpaint_image = torch.ones_like(control_image) * 0.5
+
+            inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image))
+
+            return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
+        else:
+            return latent_image

     def __call__(self, kwargs):
         x = kwargs.get("x")
         img = kwargs.get("img")
+        img_input = kwargs.get("img_input")
         txt = kwargs.get("txt")
         pe = kwargs.get("pe")
         vec = kwargs.get("vec")
         block_index = kwargs.get("block_index")
+        block_type = kwargs.get("block_type", "")
         spacial_compression = self.vae.spacial_compression_encode()
         if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
             image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
+            inpaint_scaled = None
+            if self.inpaint_image is not None:
+                inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
             loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
-            self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
+            self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1), inpaint_scaled)
             self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
             comfy.model_management.load_models_gpu(loaded_models)

-        cnet_index = (block_index // 5)
-        cnet_index_float = (block_index / 5)
+        cnet_blocks = self.model_patch.model.n_control_layers
+        div = round(30 / cnet_blocks)
+
+        cnet_index = (block_index // div)
+        cnet_index_float = (block_index / div)

         kwargs.pop("img") # we do ops in place
         kwargs.pop("txt")

-        cnet_blocks = self.model_patch.model.n_control_layers
         if cnet_index_float > (cnet_blocks - 1):
             self.temp_data = None
             return kwargs

         if self.temp_data is None or self.temp_data[0] > cnet_index:
-            self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
+            if block_type == "noise_refiner":
+                self.temp_data = (-3, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
+            else:
+                self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))

-        while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
+        if block_type == "noise_refiner":
             next_layer = self.temp_data[0] + 1
-            self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
+            self.temp_data = (next_layer, self.model_patch.model.forward_noise_refiner_block(block_index, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
+            if self.temp_data[1][0] is not None:
+                img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
+        else:
+            while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
+                next_layer = self.temp_data[0] + 1
+                self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))

-        if cnet_index_float == self.temp_data[0]:
-            img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
-            if cnet_blocks == self.temp_data[0] + 1:
-                self.temp_data = None
+            if cnet_index_float == self.temp_data[0]:
+                img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
+                if cnet_blocks == self.temp_data[0] + 1:
+                    self.temp_data = None

         return kwargs
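The divisor change above spreads the control block outputs over the base model's double blocks instead of hard-coding a stride of 5. A worked example of the mapping, assuming the hard-coded 30 refers to the base model's 30 double blocks (the values below are computed for illustration, not taken from the PR):

n_main_blocks = 30
for cnet_blocks in (6, 15):                   # old default vs. union 2.0 checkpoint
    div = round(n_main_blocks / cnet_blocks)  # 5 for 6 control blocks, 2 for 15
    inject_at = [b for b in range(n_main_blocks)
                 if b % div == 0 and b // div < cnet_blocks]
    print(cnet_blocks, div, inject_at)
# 6 control blocks  -> injected at double blocks 0, 5, 10, 15, 20, 25
# 15 control blocks -> injected at double blocks 0, 2, 4, ..., 28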
@@ -386,7 +422,9 @@ class QwenImageDiffsynthControlnet:
             mask = 1.0 - mask

         if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
-            model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
+            patch = ZImageControlPatch(model_patch, vae, image, strength, mask=mask)
+            model_patched.set_model_noise_refiner_patch(patch)
+            model_patched.set_model_double_block_patch(patch)
         else:
             model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
         return (model_patched,)