From 4d6a5016693fd430006efc84623b418571a89f43 Mon Sep 17 00:00:00 2001 From: lodestone-rock Date: Sat, 28 Feb 2026 17:25:10 +0700 Subject: [PATCH] draft zeta (z-image pixel space) --- .gitignore | 4 + comfy/latent_formats.py | 21 +++ comfy/ldm/lumina/model.py | 317 ++++++++++++++++++++++++++++++++++++++ comfy/model_base.py | 5 + comfy/model_detection.py | 28 ++++ comfy/supported_models.py | 14 ++ 6 files changed, 389 insertions(+) diff --git a/.gitignore b/.gitignore index 2700ad5c2..bbe32a8e8 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,7 @@ web_custom_versions/ openapi.yaml filtered-openapi.yaml uv.lock + +init.sh +/.vscode + diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index f59999af6..30fbe43e3 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -776,3 +776,24 @@ class ChromaRadiance(LatentFormat): def process_out(self, latent): return latent + + +class ZImagePixelSpace(LatentFormat): + """Pixel-space latent format for ZImage DCT variant. + No VAE encoding/decoding — the model operates directly on RGB pixels. + """ + latent_channels = 3 + + def __init__(self): + self.latent_rgb_factors = [ + # R G B + [ 1.0, 0.0, 0.0 ], + [ 0.0, 1.0, 0.0 ], + [ 0.0, 0.0, 1.0 ] + ] + + def process_in(self, latent): + return latent + + def process_out(self, latent): + return latent diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 77d1abc97..90a6632f0 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -1,6 +1,7 @@ # Code from: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py from __future__ import annotations +from functools import lru_cache from typing import List, Optional, Tuple import torch @@ -858,3 +859,319 @@ class NextDiT(nn.Module): img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] return -img + +############################################################################# +# Pixel Space Decoder Components # +############################################################################# + +def _modulate_shift_scale(x, shift, scale): + return x * (1 + scale) + shift + + +class NerfEmbedder(nn.Module): + """ + Combines input pixel features with 2D DCT-like positional encodings before + projecting to the decoder hidden size. + + Input: [B, P^2, C] + Output: [B, P^2, hidden_size] + """ + + def __init__(self, in_channels: int, hidden_size_input: int, max_freqs: int): + super().__init__() + self.max_freqs = max_freqs + self.hidden_size_input = hidden_size_input + self.embedder = nn.Sequential( + nn.Linear(in_channels + max_freqs ** 2, hidden_size_input) + ) + + @lru_cache(maxsize=4) + def fetch_pos(self, patch_size: int, device, dtype): + """Generates and caches 2D DCT-like positional embeddings.""" + pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) + pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij") + + pos_x = pos_x.reshape(-1, 1, 1) + pos_y = pos_y.reshape(-1, 1, 1) + + freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device) + freqs_x = freqs[None, :, None] + freqs_y = freqs[None, None, :] + + coeffs = (1 + freqs_x * freqs_y) ** -1 + dct_x = torch.cos(pos_x * freqs_x * torch.pi) + dct_y = torch.cos(pos_y * freqs_y * torch.pi) + dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2) + + return dct + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + B, P2, C = inputs.shape + original_dtype = inputs.dtype + + with torch.autocast("cuda", enabled=False): + patch_size = int(P2 ** 0.5) + inputs = inputs.float() + dct = self.fetch_pos(patch_size, inputs.device, torch.float32) + dct = dct.expand(B, -1, -1) + inputs = torch.cat([inputs, dct], dim=-1) + inputs = self.embedder.float()(inputs) + + return inputs.to(original_dtype) + + +class PixelResBlock(nn.Module): + """ + Residual block with AdaLN modulation, zero-initialised so it starts as + an identity at the beginning of training. + """ + + def __init__(self, channels: int): + super().__init__() + self.in_ln = nn.LayerNorm(channels, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(channels, channels, bias=True), + nn.SiLU(), + nn.Linear(channels, channels, bias=True), + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 3 * channels, bias=True), + ) + self._init_weights() + + def _init_weights(self): + for m in self.mlp: + if isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, nonlinearity="linear") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + # Zero-init modulation → identity at init + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1) + h = _modulate_shift_scale(self.in_ln(x), shift, scale) + h = self.mlp(h) + return x + gate * h + + +class DCTFinalLayer(nn.Module): + """Zero-initialised output projection (adopted from DiT).""" + + def __init__(self, model_channels: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(model_channels, out_channels, bias=True) + nn.init.constant_(self.linear.weight, 0) + nn.init.constant_(self.linear.bias, 0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(self.norm_final(x)) + + +class SimpleMLPAdaLN(nn.Module): + """ + Small MLP decoder head for the pixel-space variant. + + Takes per-patch pixel values and a per-patch conditioning vector from the + transformer backbone and predicts the denoised pixel values. + + x : [B*N, P^2, C] – noisy pixel values per patch position + c : [B*N, dim] – backbone hidden state per patch (conditioning) + → [B*N, P^2, C] + """ + + def __init__( + self, + in_channels: int, + model_channels: int, + out_channels: int, + z_channels: int, + num_res_blocks: int, + patch_size: int, + max_freqs: int = 8, + ): + super().__init__() + self.patch_size = patch_size + + # Project backbone hidden state → per-position conditioning + self.cond_embed = nn.Linear(z_channels, patch_size ** 2 * model_channels) + nn.init.xavier_uniform_(self.cond_embed.weight) + nn.init.constant_(self.cond_embed.bias, 0) + + # Input projection with DCT positional encoding + self.input_embedder = NerfEmbedder( + in_channels=in_channels, + hidden_size_input=model_channels, + max_freqs=max_freqs, + ) + + # Residual blocks + self.res_blocks = nn.ModuleList([ + PixelResBlock(model_channels) for _ in range(num_res_blocks) + ]) + + # Output projection + self.final_layer = DCTFinalLayer(model_channels, out_channels) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + # x: [B*N, P^2, C], c: [B*N, dim] + x = self.input_embedder(x) # [B*N, P^2, model_channels] + y = self.cond_embed(c).reshape(c.shape[0], self.patch_size ** 2, -1) # [B*N, P^2, model_channels] + for block in self.res_blocks: + x = block(x, y) + return self.final_layer(x) # [B*N, P^2, C] + + +############################################################################# +# NextDiT – Pixel Space # +############################################################################# + +class NextDiTPixelSpace(NextDiT): + """ + Pixel-space variant of NextDiT. + + Identical transformer backbone to NextDiT, but the output head is replaced + with a small MLP decoder (SimpleMLPAdaLN) that operates on raw pixel values + per patch rather than a single affine projection. + + Key differences vs NextDiT: + • ``final_layer`` is removed; ``dec_net`` (SimpleMLPAdaLN) is used instead. + • ``_forward`` stores the raw patchified pixel values before the backbone + embedding and feeds them to ``dec_net`` together with the per-patch + backbone hidden states. + • Supports optional x0 prediction via ``use_x0``. + """ + + def __init__( + self, + # decoder-specific + decoder_hidden_size: int = 3840, + decoder_num_res_blocks: int = 4, + decoder_max_freqs: int = 8, + use_x0: bool = False, + # all NextDiT args forwarded unchanged + **kwargs, + ): + super().__init__(**kwargs) + + # Remove the latent-space final layer – not used in pixel space + del self.final_layer + + patch_size = kwargs.get("patch_size", 2) + in_channels = kwargs.get("in_channels", 4) + dim = kwargs.get("dim", 4096) + + self.dec_net = SimpleMLPAdaLN( + in_channels=in_channels, + model_channels=decoder_hidden_size, + out_channels=in_channels, + z_channels=dim, + num_res_blocks=decoder_num_res_blocks, + patch_size=patch_size, + max_freqs=decoder_max_freqs, + ) + + if use_x0: + self.register_buffer("__x0__", torch.tensor([])) + + # ------------------------------------------------------------------ + # Override patchify_and_embed to also return the raw pixel patches + # ------------------------------------------------------------------ + def patchify_and_embed(self, x, cap_feats, cap_mask, t, num_tokens, transformer_options={}): + # Run the parent implementation unchanged; we capture pixel values + # separately in _forward before calling this. + return super().patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options) + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs): + t = 1.0 - timesteps + cap_feats = context + cap_mask = attention_mask + bs, c, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + + t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D) + adaln_input = t + + cap_feats = self.cap_embedder(cap_feats) + + if self.clip_text_pooled_proj is not None: + pooled = kwargs.get("clip_text_pooled", None) + if pooled is not None: + pooled = self.clip_text_pooled_proj(pooled) + else: + pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype) + adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1)) + + # ---- capture raw pixel patches before backbone embedding ---- + pH = pW = self.patch_size + B, C, H, W = x.shape + # [B, N, P*P*C] (same layout as what x_embedder receives) + pixel_patches = ( + x.view(B, C, H // pH, pH, W // pW, pW) + .permute(0, 2, 4, 3, 5, 1) # [B, Ht, Wt, pH, pW, C] + .flatten(3) # [B, Ht, Wt, pH*pW*C] + .flatten(1, 2) # [B, N, pH*pW*C] + ) + # reshape to [B*N, P^2, C] for the decoder + N = pixel_patches.shape[1] + pixel_values = pixel_patches.reshape(B * N, pH * pW, C) + + patches = transformer_options.get("patches", {}) + x_is_tensor = isinstance(x, torch.Tensor) + img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed( + x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options + ) + freqs_cis = freqs_cis.to(img.device) + + for i, layer in enumerate(self.layers): + img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) + if "double_block" in patches: + for p in patches["double_block"]: + out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) + if "img" in out: + img[:, cap_size[0]:] = out["img"] + if "txt" in out: + img[:, :cap_size[0]] = out["txt"] + + # ---- pixel-space decoder ---- + # img: [B, txt_len+N, dim] → extract image tokens → [B, N, dim] + img_hidden = img[:, cap_size[0]:, :] # [B, N, dim] + # per-patch conditioning: [B*N, dim] + decoder_cond = img_hidden.reshape(B * N, self.dim) + + # decode: [B*N, P^2, C] + output = self.dec_net(pixel_values, decoder_cond) + + # reshape back: [B*N, P^2, C] → [B, N, P^2*C] + output = output.reshape(B, N, -1) + + # unpatchify expects [B, txt_len+N, P^2*C] with cap tokens prepended + # re-prepend a zero placeholder for the cap positions so unpatchify works + cap_placeholder = torch.zeros( + B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype + ) + img_out = torch.cat([cap_placeholder, output], dim=1) + + img_out = self.unpatchify(img_out, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] + + return -img_out + + def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): + # _forward returns -x0 (negated decoder output, matching the latent-space convention). + # Reference x0→v conversion: v = (noisy - out) / t, where out = -x0 + # → v = (noisy - (-x0)) / t = (noisy + x0) / t + # Since neg_x0 = -x0: v = (x - neg_x0) / t + neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) + + return (x - neg_x0) / timesteps.view(-1, 1, 1, 1) diff --git a/comfy/model_base.py b/comfy/model_base.py index a1c690b9b..1842b3b08 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1214,6 +1214,11 @@ class Lumina2(BaseModel): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT) self.memory_usage_factor_conds = ("ref_latents",) + +class ZImagePixelSpace(Lumina2): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace) + def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) attention_mask = kwargs.get("attention_mask", None) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 3faa950ca..2f3e185f7 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -464,6 +464,34 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if sig_weight is not None: dit_config["siglip_feat_dim"] = sig_weight.shape[0] + dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix) + if dec_cond_key in state_dict_keys: # pixel-space variant + dit_config["image_model"] = "zimage_pixel" + w = state_dict[dec_cond_key] # [patch_size^2 * decoder_hidden_size, dim] + dit_config["decoder_hidden_size"] = w.shape[0] // (dit_config["patch_size"] ** 2) + dit_config["decoder_num_res_blocks"] = count_blocks( + state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.' + ) + dit_config["decoder_max_freqs"] = 8 # fixed in NerfEmbedder + dit_config["in_channels"] = w.shape[1] // dit_config["dim"] if False else \ + state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1] // (dit_config["patch_size"] ** 2) + if '{}__x0__'.format(key_prefix) in state_dict_keys: + dit_config["use_x0"] = True + + dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix) + if dec_cond_key in state_dict_keys: # pixel-space variant + dit_config["image_model"] = "zimage_pixel" + w = state_dict[dec_cond_key] # [patch_size^2 * decoder_hidden_size, dim] + dit_config["decoder_hidden_size"] = w.shape[0] // (dit_config["patch_size"] ** 2) + dit_config["decoder_num_res_blocks"] = count_blocks( + state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.' + ) + dit_config["decoder_max_freqs"] = 8 # fixed in NerfEmbedder + dit_config["in_channels"] = w.shape[1] // dit_config["dim"] if False else \ + state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1] // (dit_config["patch_size"] ** 2) + if '{}__x0__'.format(key_prefix) in state_dict_keys: + dit_config["use_x0"] = True + return dit_config if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 4f63e8327..5259e143f 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1118,6 +1118,20 @@ class ZImage(Lumina2): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect)) +class ZImagePixelSpace(ZImage): + unet_config = { + "image_model": "zimage_pixel", + } + + # Pixel-space model: no spatial compression, operates on raw RGB patches. + latent_format = latent_formats.ZImagePixelSpace + + # Much lower memory than latent-space models (no VAE, small patches). + memory_usage_factor = 0.05 # TODO: figure out the optimal value for this. + + def get_model(self, state_dict, prefix="", device=None): + return model_base.ZImagePixelSpace(self, device=device) + class WAN21_T2V(supported_models_base.BASE): unet_config = { "image_model": "wan2.1",