mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-12 17:27:26 +08:00
draft zeta (z-image pixel space)
This commit is contained in:
parent
1080bd442a
commit
4d6a501669
4
.gitignore
vendored
4
.gitignore
vendored
@ -24,3 +24,7 @@ web_custom_versions/
|
||||
openapi.yaml
|
||||
filtered-openapi.yaml
|
||||
uv.lock
|
||||
|
||||
init.sh
|
||||
/.vscode
|
||||
|
||||
|
||||
@ -776,3 +776,24 @@ class ChromaRadiance(LatentFormat):
|
||||
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
|
||||
class ZImagePixelSpace(LatentFormat):
|
||||
"""Pixel-space latent format for ZImage DCT variant.
|
||||
No VAE encoding/decoding — the model operates directly on RGB pixels.
|
||||
"""
|
||||
latent_channels = 3
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
# R G B
|
||||
[ 1.0, 0.0, 0.0 ],
|
||||
[ 0.0, 1.0, 0.0 ],
|
||||
[ 0.0, 0.0, 1.0 ]
|
||||
]
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent
|
||||
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# Code from: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@ -858,3 +859,319 @@ class NextDiT(nn.Module):
|
||||
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
|
||||
return -img
|
||||
|
||||
|
||||
#############################################################################
|
||||
# Pixel Space Decoder Components #
|
||||
#############################################################################
|
||||
|
||||
def _modulate_shift_scale(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
|
||||
class NerfEmbedder(nn.Module):
|
||||
"""
|
||||
Combines input pixel features with 2D DCT-like positional encodings before
|
||||
projecting to the decoder hidden size.
|
||||
|
||||
Input: [B, P^2, C]
|
||||
Output: [B, P^2, hidden_size]
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, hidden_size_input: int, max_freqs: int):
|
||||
super().__init__()
|
||||
self.max_freqs = max_freqs
|
||||
self.hidden_size_input = hidden_size_input
|
||||
self.embedder = nn.Sequential(
|
||||
nn.Linear(in_channels + max_freqs ** 2, hidden_size_input)
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=4)
|
||||
def fetch_pos(self, patch_size: int, device, dtype):
|
||||
"""Generates and caches 2D DCT-like positional embeddings."""
|
||||
pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
|
||||
pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
|
||||
pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
|
||||
|
||||
pos_x = pos_x.reshape(-1, 1, 1)
|
||||
pos_y = pos_y.reshape(-1, 1, 1)
|
||||
|
||||
freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device)
|
||||
freqs_x = freqs[None, :, None]
|
||||
freqs_y = freqs[None, None, :]
|
||||
|
||||
coeffs = (1 + freqs_x * freqs_y) ** -1
|
||||
dct_x = torch.cos(pos_x * freqs_x * torch.pi)
|
||||
dct_y = torch.cos(pos_y * freqs_y * torch.pi)
|
||||
dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2)
|
||||
|
||||
return dct
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
B, P2, C = inputs.shape
|
||||
original_dtype = inputs.dtype
|
||||
|
||||
with torch.autocast("cuda", enabled=False):
|
||||
patch_size = int(P2 ** 0.5)
|
||||
inputs = inputs.float()
|
||||
dct = self.fetch_pos(patch_size, inputs.device, torch.float32)
|
||||
dct = dct.expand(B, -1, -1)
|
||||
inputs = torch.cat([inputs, dct], dim=-1)
|
||||
inputs = self.embedder.float()(inputs)
|
||||
|
||||
return inputs.to(original_dtype)
|
||||
|
||||
|
||||
class PixelResBlock(nn.Module):
|
||||
"""
|
||||
Residual block with AdaLN modulation, zero-initialised so it starts as
|
||||
an identity at the beginning of training.
|
||||
"""
|
||||
|
||||
def __init__(self, channels: int):
|
||||
super().__init__()
|
||||
self.in_ln = nn.LayerNorm(channels, eps=1e-6)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(channels, channels, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(channels, channels, bias=True),
|
||||
)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(channels, 3 * channels, bias=True),
|
||||
)
|
||||
self._init_weights()
|
||||
|
||||
def _init_weights(self):
|
||||
for m in self.mlp:
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_normal_(m.weight, nonlinearity="linear")
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
# Zero-init modulation → identity at init
|
||||
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
|
||||
h = _modulate_shift_scale(self.in_ln(x), shift, scale)
|
||||
h = self.mlp(h)
|
||||
return x + gate * h
|
||||
|
||||
|
||||
class DCTFinalLayer(nn.Module):
|
||||
"""Zero-initialised output projection (adopted from DiT)."""
|
||||
|
||||
def __init__(self, model_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(model_channels, out_channels, bias=True)
|
||||
nn.init.constant_(self.linear.weight, 0)
|
||||
nn.init.constant_(self.linear.bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.linear(self.norm_final(x))
|
||||
|
||||
|
||||
class SimpleMLPAdaLN(nn.Module):
|
||||
"""
|
||||
Small MLP decoder head for the pixel-space variant.
|
||||
|
||||
Takes per-patch pixel values and a per-patch conditioning vector from the
|
||||
transformer backbone and predicts the denoised pixel values.
|
||||
|
||||
x : [B*N, P^2, C] – noisy pixel values per patch position
|
||||
c : [B*N, dim] – backbone hidden state per patch (conditioning)
|
||||
→ [B*N, P^2, C]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
model_channels: int,
|
||||
out_channels: int,
|
||||
z_channels: int,
|
||||
num_res_blocks: int,
|
||||
patch_size: int,
|
||||
max_freqs: int = 8,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
|
||||
# Project backbone hidden state → per-position conditioning
|
||||
self.cond_embed = nn.Linear(z_channels, patch_size ** 2 * model_channels)
|
||||
nn.init.xavier_uniform_(self.cond_embed.weight)
|
||||
nn.init.constant_(self.cond_embed.bias, 0)
|
||||
|
||||
# Input projection with DCT positional encoding
|
||||
self.input_embedder = NerfEmbedder(
|
||||
in_channels=in_channels,
|
||||
hidden_size_input=model_channels,
|
||||
max_freqs=max_freqs,
|
||||
)
|
||||
|
||||
# Residual blocks
|
||||
self.res_blocks = nn.ModuleList([
|
||||
PixelResBlock(model_channels) for _ in range(num_res_blocks)
|
||||
])
|
||||
|
||||
# Output projection
|
||||
self.final_layer = DCTFinalLayer(model_channels, out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
# x: [B*N, P^2, C], c: [B*N, dim]
|
||||
x = self.input_embedder(x) # [B*N, P^2, model_channels]
|
||||
y = self.cond_embed(c).reshape(c.shape[0], self.patch_size ** 2, -1) # [B*N, P^2, model_channels]
|
||||
for block in self.res_blocks:
|
||||
x = block(x, y)
|
||||
return self.final_layer(x) # [B*N, P^2, C]
|
||||
|
||||
|
||||
#############################################################################
|
||||
# NextDiT – Pixel Space #
|
||||
#############################################################################
|
||||
|
||||
class NextDiTPixelSpace(NextDiT):
|
||||
"""
|
||||
Pixel-space variant of NextDiT.
|
||||
|
||||
Identical transformer backbone to NextDiT, but the output head is replaced
|
||||
with a small MLP decoder (SimpleMLPAdaLN) that operates on raw pixel values
|
||||
per patch rather than a single affine projection.
|
||||
|
||||
Key differences vs NextDiT:
|
||||
• ``final_layer`` is removed; ``dec_net`` (SimpleMLPAdaLN) is used instead.
|
||||
• ``_forward`` stores the raw patchified pixel values before the backbone
|
||||
embedding and feeds them to ``dec_net`` together with the per-patch
|
||||
backbone hidden states.
|
||||
• Supports optional x0 prediction via ``use_x0``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# decoder-specific
|
||||
decoder_hidden_size: int = 3840,
|
||||
decoder_num_res_blocks: int = 4,
|
||||
decoder_max_freqs: int = 8,
|
||||
use_x0: bool = False,
|
||||
# all NextDiT args forwarded unchanged
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Remove the latent-space final layer – not used in pixel space
|
||||
del self.final_layer
|
||||
|
||||
patch_size = kwargs.get("patch_size", 2)
|
||||
in_channels = kwargs.get("in_channels", 4)
|
||||
dim = kwargs.get("dim", 4096)
|
||||
|
||||
self.dec_net = SimpleMLPAdaLN(
|
||||
in_channels=in_channels,
|
||||
model_channels=decoder_hidden_size,
|
||||
out_channels=in_channels,
|
||||
z_channels=dim,
|
||||
num_res_blocks=decoder_num_res_blocks,
|
||||
patch_size=patch_size,
|
||||
max_freqs=decoder_max_freqs,
|
||||
)
|
||||
|
||||
if use_x0:
|
||||
self.register_buffer("__x0__", torch.tensor([]))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Override patchify_and_embed to also return the raw pixel patches
|
||||
# ------------------------------------------------------------------
|
||||
def patchify_and_embed(self, x, cap_feats, cap_mask, t, num_tokens, transformer_options={}):
|
||||
# Run the parent implementation unchanged; we capture pixel values
|
||||
# separately in _forward before calling this.
|
||||
return super().patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Forward
|
||||
# ------------------------------------------------------------------
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
|
||||
t = 1.0 - timesteps
|
||||
cap_feats = context
|
||||
cap_mask = attention_mask
|
||||
bs, c, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
|
||||
t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
|
||||
adaln_input = t
|
||||
|
||||
cap_feats = self.cap_embedder(cap_feats)
|
||||
|
||||
if self.clip_text_pooled_proj is not None:
|
||||
pooled = kwargs.get("clip_text_pooled", None)
|
||||
if pooled is not None:
|
||||
pooled = self.clip_text_pooled_proj(pooled)
|
||||
else:
|
||||
pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
|
||||
|
||||
# ---- capture raw pixel patches before backbone embedding ----
|
||||
pH = pW = self.patch_size
|
||||
B, C, H, W = x.shape
|
||||
# [B, N, P*P*C] (same layout as what x_embedder receives)
|
||||
pixel_patches = (
|
||||
x.view(B, C, H // pH, pH, W // pW, pW)
|
||||
.permute(0, 2, 4, 3, 5, 1) # [B, Ht, Wt, pH, pW, C]
|
||||
.flatten(3) # [B, Ht, Wt, pH*pW*C]
|
||||
.flatten(1, 2) # [B, N, pH*pW*C]
|
||||
)
|
||||
# reshape to [B*N, P^2, C] for the decoder
|
||||
N = pixel_patches.shape[1]
|
||||
pixel_values = pixel_patches.reshape(B * N, pH * pW, C)
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
x_is_tensor = isinstance(x, torch.Tensor)
|
||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(
|
||||
x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options
|
||||
)
|
||||
freqs_cis = freqs_cis.to(img.device)
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||
if "img" in out:
|
||||
img[:, cap_size[0]:] = out["img"]
|
||||
if "txt" in out:
|
||||
img[:, :cap_size[0]] = out["txt"]
|
||||
|
||||
# ---- pixel-space decoder ----
|
||||
# img: [B, txt_len+N, dim] → extract image tokens → [B, N, dim]
|
||||
img_hidden = img[:, cap_size[0]:, :] # [B, N, dim]
|
||||
# per-patch conditioning: [B*N, dim]
|
||||
decoder_cond = img_hidden.reshape(B * N, self.dim)
|
||||
|
||||
# decode: [B*N, P^2, C]
|
||||
output = self.dec_net(pixel_values, decoder_cond)
|
||||
|
||||
# reshape back: [B*N, P^2, C] → [B, N, P^2*C]
|
||||
output = output.reshape(B, N, -1)
|
||||
|
||||
# unpatchify expects [B, txt_len+N, P^2*C] with cap tokens prepended
|
||||
# re-prepend a zero placeholder for the cap positions so unpatchify works
|
||||
cap_placeholder = torch.zeros(
|
||||
B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype
|
||||
)
|
||||
img_out = torch.cat([cap_placeholder, output], dim=1)
|
||||
|
||||
img_out = self.unpatchify(img_out, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
|
||||
|
||||
return -img_out
|
||||
|
||||
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
# _forward returns -x0 (negated decoder output, matching the latent-space convention).
|
||||
# Reference x0→v conversion: v = (noisy - out) / t, where out = -x0
|
||||
# → v = (noisy - (-x0)) / t = (noisy + x0) / t
|
||||
# Since neg_x0 = -x0: v = (x - neg_x0) / t
|
||||
neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
|
||||
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
|
||||
|
||||
return (x - neg_x0) / timesteps.view(-1, 1, 1, 1)
|
||||
|
||||
@ -1214,6 +1214,11 @@ class Lumina2(BaseModel):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
|
||||
class ZImagePixelSpace(Lumina2):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
|
||||
@ -464,6 +464,34 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
if sig_weight is not None:
|
||||
dit_config["siglip_feat_dim"] = sig_weight.shape[0]
|
||||
|
||||
dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix)
|
||||
if dec_cond_key in state_dict_keys: # pixel-space variant
|
||||
dit_config["image_model"] = "zimage_pixel"
|
||||
w = state_dict[dec_cond_key] # [patch_size^2 * decoder_hidden_size, dim]
|
||||
dit_config["decoder_hidden_size"] = w.shape[0] // (dit_config["patch_size"] ** 2)
|
||||
dit_config["decoder_num_res_blocks"] = count_blocks(
|
||||
state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.'
|
||||
)
|
||||
dit_config["decoder_max_freqs"] = 8 # fixed in NerfEmbedder
|
||||
dit_config["in_channels"] = w.shape[1] // dit_config["dim"] if False else \
|
||||
state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1] // (dit_config["patch_size"] ** 2)
|
||||
if '{}__x0__'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["use_x0"] = True
|
||||
|
||||
dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix)
|
||||
if dec_cond_key in state_dict_keys: # pixel-space variant
|
||||
dit_config["image_model"] = "zimage_pixel"
|
||||
w = state_dict[dec_cond_key] # [patch_size^2 * decoder_hidden_size, dim]
|
||||
dit_config["decoder_hidden_size"] = w.shape[0] // (dit_config["patch_size"] ** 2)
|
||||
dit_config["decoder_num_res_blocks"] = count_blocks(
|
||||
state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.'
|
||||
)
|
||||
dit_config["decoder_max_freqs"] = 8 # fixed in NerfEmbedder
|
||||
dit_config["in_channels"] = w.shape[1] // dit_config["dim"] if False else \
|
||||
state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1] // (dit_config["patch_size"] ** 2)
|
||||
if '{}__x0__'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["use_x0"] = True
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||
|
||||
@ -1118,6 +1118,20 @@ class ZImage(Lumina2):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
|
||||
|
||||
class ZImagePixelSpace(ZImage):
|
||||
unet_config = {
|
||||
"image_model": "zimage_pixel",
|
||||
}
|
||||
|
||||
# Pixel-space model: no spatial compression, operates on raw RGB patches.
|
||||
latent_format = latent_formats.ZImagePixelSpace
|
||||
|
||||
# Much lower memory than latent-space models (no VAE, small patches).
|
||||
memory_usage_factor = 0.05 # TODO: figure out the optimal value for this.
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.ZImagePixelSpace(self, device=device)
|
||||
|
||||
class WAN21_T2V(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user