From da2bfb5b0af26c7a1c44ec951dbd0fffe413c793 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 12 Dec 2025 22:39:11 -0800 Subject: [PATCH] Basic implementation of z image fun control union 2.0 (#11304) The inpaint part is currently missing and will be implemented later. I think they messed up this model pretty bad. They added some control_noise_refiner blocks but don't actually use them. There is a typo in their code so instead of doing control_noise_refiner -> control_layers it runs the whole control_layers twice. Unfortunately they trained with this typo so the model works but is kind of slow and would probably perform a lot better if they corrected their code and trained it again. --- comfy/ldm/lumina/controlnet.py | 95 +++++++++++++++++++++++-------- comfy/ldm/lumina/model.py | 16 +++++- comfy/model_patcher.py | 3 + comfy_extras/nodes_model_patch.py | 72 +++++++++++++++++------ 4 files changed, 142 insertions(+), 44 deletions(-) diff --git a/comfy/ldm/lumina/controlnet.py b/comfy/ldm/lumina/controlnet.py index fd7ce3b5c..8e2de7977 100644 --- a/comfy/ldm/lumina/controlnet.py +++ b/comfy/ldm/lumina/controlnet.py @@ -41,6 +41,11 @@ class ZImage_Control(torch.nn.Module): ffn_dim_multiplier: float = (8.0 / 3.0), norm_eps: float = 1e-5, qk_norm: bool = True, + n_control_layers=6, + control_in_dim=16, + additional_in_dim=0, + broken=False, + refiner_control=False, dtype=None, device=None, operations=None, @@ -49,10 +54,11 @@ class ZImage_Control(torch.nn.Module): super().__init__() operation_settings = {"operations": operations, "device": device, "dtype": dtype} - self.additional_in_dim = 0 - self.control_in_dim = 16 + self.broken = broken + self.additional_in_dim = additional_in_dim + self.control_in_dim = control_in_dim n_refiner_layers = 2 - self.n_control_layers = 6 + self.n_control_layers = n_control_layers self.control_layers = nn.ModuleList( [ ZImageControlTransformerBlock( @@ -74,28 +80,49 @@ class ZImage_Control(torch.nn.Module): all_x_embedder = {} patch_size = 2 f_patch_size = 1 - x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype) + x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype) all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + self.refiner_control = refiner_control + self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) - self.control_noise_refiner = nn.ModuleList( - [ - JointTransformerBlock( - layer_id, - dim, - n_heads, - n_kv_heads, - multiple_of, - ffn_dim_multiplier, - norm_eps, - qk_norm, - modulation=True, - z_image_modulation=True, - operation_settings=operation_settings, - ) - for layer_id in range(n_refiner_layers) - ] - ) + if self.refiner_control: + self.control_noise_refiner = nn.ModuleList( + [ + ZImageControlTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + block_id=layer_id, + operation_settings=operation_settings, + ) + for layer_id in range(n_refiner_layers) + ] + ) + else: + self.control_noise_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=True, + z_image_modulation=True, + operation_settings=operation_settings, + ) + for layer_id in range(n_refiner_layers) + ] + ) def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input): patch_size = 2 @@ -105,9 +132,29 @@ class ZImage_Control(torch.nn.Module): control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) x_attn_mask = None - for layer in self.control_noise_refiner: - control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input) + if not self.refiner_control: + for layer in self.control_noise_refiner: + control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input) + return control_context + def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input): + if self.refiner_control: + if self.broken: + if layer_id == 0: + return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input) + if layer_id > 0: + out = None + for i in range(1, len(self.control_layers)): + o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input) + if out is None: + out = o + + return (out, control_context) + else: + return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input) + else: + return (None, control_context) + def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input): return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index c47df49ca..96cb37fa6 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -536,6 +536,7 @@ class NextDiT(nn.Module): bsz = len(x) pH = pW = self.patch_size device = x[0].device + orig_x = x if self.pad_tokens_multiple is not None: pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple @@ -572,13 +573,21 @@ class NextDiT(nn.Module): freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) + patches = transformer_options.get("patches", {}) + # refine context for layer in self.context_refiner: cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options) padded_img_mask = None - for layer in self.noise_refiner: + x_input = x + for i, layer in enumerate(self.noise_refiner): x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options) + if "noise_refiner" in patches: + for p in patches["noise_refiner"]: + out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"}) + if "img" in out: + x = out["img"] padded_full_embed = torch.cat((cap_feats, x), dim=1) mask = None @@ -622,14 +631,15 @@ class NextDiT(nn.Module): patches = transformer_options.get("patches", {}) x_is_tensor = isinstance(x, torch.Tensor) - img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) + img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options) freqs_cis = freqs_cis.to(img.device) + img_input = img for i, layer in enumerate(self.layers): img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) if "double_block" in patches: for p in patches["double_block"]: - out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) + out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) if "img" in out: img[:, cap_size[0]:] = out["img"] if "txt" in out: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index a486c2723..93d26c690 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -454,6 +454,9 @@ class ModelPatcher: def set_model_post_input_patch(self, patch): self.set_model_patch(patch, "post_input") + def set_model_noise_refiner_patch(self, patch): + self.set_model_patch(patch, "noise_refiner") + def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs): rope_options = self.model_options["transformer_options"].get("rope_options", {}) rope_options["scale_x"] = scale_x diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index c61810dbf..ec0e790dc 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -243,7 +243,13 @@ class ModelPatchLoader: model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet sd = z_image_convert(sd) - model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + config = {} + if 'control_layers.14.adaLN_modulation.0.weight' in sd: + config['n_control_layers'] = 15 + config['additional_in_dim'] = 17 + config['refiner_control'] = True + config['broken'] = True + model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config) model.load_state_dict(sd) model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) @@ -297,56 +303,86 @@ class DiffSynthCnetPatch: return [self.model_patch] class ZImageControlPatch: - def __init__(self, model_patch, vae, image, strength): + def __init__(self, model_patch, vae, image, strength, inpaint_image=None, mask=None): self.model_patch = model_patch self.vae = vae self.image = image + self.inpaint_image = inpaint_image + self.mask = mask self.strength = strength self.encoded_image = self.encode_latent_cond(image) self.encoded_image_size = (image.shape[1], image.shape[2]) self.temp_data = None - def encode_latent_cond(self, image): - latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image)) - return latent_image + def encode_latent_cond(self, control_image, inpaint_image=None): + latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image)) + if self.model_patch.model.additional_in_dim > 0: + if self.mask is None: + mask_ = torch.zeros_like(latent_image)[:, :1] + else: + mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none") + if inpaint_image is None: + inpaint_image = torch.ones_like(control_image) * 0.5 + + inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image)) + + return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1) + else: + return latent_image def __call__(self, kwargs): x = kwargs.get("x") img = kwargs.get("img") + img_input = kwargs.get("img_input") txt = kwargs.get("txt") pe = kwargs.get("pe") vec = kwargs.get("vec") block_index = kwargs.get("block_index") + block_type = kwargs.get("block_type", "") spacial_compression = self.vae.spacial_compression_encode() if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression): image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center") + inpaint_scaled = None + if self.inpaint_image is not None: + inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1) loaded_models = comfy.model_management.loaded_models(only_currently_used=True) - self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1)) + self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1), inpaint_scaled) self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1]) comfy.model_management.load_models_gpu(loaded_models) - cnet_index = (block_index // 5) - cnet_index_float = (block_index / 5) + cnet_blocks = self.model_patch.model.n_control_layers + div = round(30 / cnet_blocks) + + cnet_index = (block_index // div) + cnet_index_float = (block_index / div) kwargs.pop("img") # we do ops in place kwargs.pop("txt") - cnet_blocks = self.model_patch.model.n_control_layers if cnet_index_float > (cnet_blocks - 1): self.temp_data = None return kwargs if self.temp_data is None or self.temp_data[0] > cnet_index: - self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec))) + if block_type == "noise_refiner": + self.temp_data = (-3, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec))) + else: + self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec))) - while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks: + if block_type == "noise_refiner": next_layer = self.temp_data[0] + 1 - self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec)) + self.temp_data = (next_layer, self.model_patch.model.forward_noise_refiner_block(block_index, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec)) + if self.temp_data[1][0] is not None: + img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength) + else: + while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks: + next_layer = self.temp_data[0] + 1 + self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec)) - if cnet_index_float == self.temp_data[0]: - img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength) - if cnet_blocks == self.temp_data[0] + 1: - self.temp_data = None + if cnet_index_float == self.temp_data[0]: + img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength) + if cnet_blocks == self.temp_data[0] + 1: + self.temp_data = None return kwargs @@ -386,7 +422,9 @@ class QwenImageDiffsynthControlnet: mask = 1.0 - mask if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control): - model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength)) + patch = ZImageControlPatch(model_patch, vae, image, strength, mask=mask) + model_patched.set_model_noise_refiner_patch(patch) + model_patched.set_model_double_block_patch(patch) else: model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask)) return (model_patched,)