From 7b72f322a58aaa7c358c5704bd5274a936fb0246 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 25 May 2026 16:24:54 +0300 Subject: [PATCH] Support PixelDiT and PiD --- comfy/latent_formats.py | 4 +- comfy/ldm/modules/diffusionmodules/mmdit.py | 5 +- comfy/ldm/pixeldit/__init__.py | 0 comfy/ldm/pixeldit/model.py | 270 ++++++++++++++++++ comfy/ldm/pixeldit/modules.py | 199 ++++++++++++++ comfy/ldm/pixeldit/pid.py | 286 ++++++++++++++++++++ comfy/model_base.py | 37 +++ comfy/model_detection.py | 17 ++ comfy/sd.py | 10 +- comfy/supported_models.py | 68 +++++ comfy/text_encoders/pixeldit.py | 124 +++++++++ comfy_cuda_memory_history20260525_155146.pt | Bin 0 -> 15684590 bytes comfy_extras/nodes_pid.py | 60 ++++ nodes.py | 3 +- 14 files changed, 1077 insertions(+), 6 deletions(-) create mode 100644 comfy/ldm/pixeldit/__init__.py create mode 100644 comfy/ldm/pixeldit/model.py create mode 100644 comfy/ldm/pixeldit/modules.py create mode 100644 comfy/ldm/pixeldit/pid.py create mode 100644 comfy/text_encoders/pixeldit.py create mode 100644 comfy_cuda_memory_history20260525_155146.pt create mode 100644 comfy_extras/nodes_pid.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index d527eec4a..70cb433b0 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -792,13 +792,15 @@ class ZImagePixelSpace(ChromaRadiance): """ pass - class HiDreamO1Pixel(ChromaRadiance): """Pixel-space latent format for HiDream-O1. No VAE — model patches/unpatches raw RGB internally with patch_size=32. """ pass +class PixelDiTPixel(ChromaRadiance): + pass + class CogVideoX(LatentFormat): """Latent format for CogVideoX-2b (THUDM/CogVideoX-2b). 
diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 0dc8fe789..9ab3c463c 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -211,7 +211,7 @@ class TimestepEmbedder(nn.Module): Embeds scalar timesteps into vector representations. """ - def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None): + def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None, max_period=10000): super().__init__() if output_size is None: output_size = hidden_size @@ -221,9 +221,10 @@ class TimestepEmbedder(nn.Module): operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device), ) self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period def forward(self, t, dtype, **kwargs): - t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype) + t_freq = timestep_embedding(t, self.frequency_embedding_size, max_period=self.max_period).to(dtype) t_emb = self.mlp(t_freq) return t_emb diff --git a/comfy/ldm/pixeldit/__init__.py b/comfy/ldm/pixeldit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/ldm/pixeldit/model.py b/comfy/ldm/pixeldit/model.py new file mode 100644 index 000000000..3b35b1a96 --- /dev/null +++ b/comfy/ldm/pixeldit/model.py @@ -0,0 +1,270 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.patcher_extension +from comfy.ldm.flux.math import apply_rope, rope +from comfy.ldm.hidream.model import FeedForwardSwiGLU +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder + +from .modules import ( + FinalLayer, + PatchTokenEmbedder, + PiTBlock, + PixelTokenEmbedder, + apply_adaln, + precompute_freqs_cis_2d, +) + + +class 
MMDiTJointAttention(nn.Module): + """Joint MMDiT attention with separate Q/K/V/proj for image and text streams. + + RoPE is applied to each stream before concatenation so each stream uses its own + 2D/1D positional encoding. Concat order is [text, image] (text first). + """ + def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None): + super().__init__() + assert dim % num_heads == 0 + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.qkv_x = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + self.qkv_y = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + + self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + + self.proj_x = operations.Linear(dim, dim, dtype=dtype, device=device) + self.proj_y = operations.Linear(dim, dim, dtype=dtype, device=device) + + def forward(self, x, y, pos_img, pos_txt=None, attn_mask=None, transformer_options={}): + B, Nx, _ = x.shape + _, Ny, _ = y.shape + H = self.num_heads + D = self.head_dim + + qkv_x = self.qkv_x(x).reshape(B, Nx, 3, H, D).permute(2, 0, 3, 1, 4) + qx, kx, vx = qkv_x.unbind(0) + qx = self.q_norm_x(qx) + kx = self.k_norm_x(kx) + + qkv_y = self.qkv_y(y).reshape(B, Ny, 3, H, D).permute(2, 0, 3, 1, 4) + qy, ky, vy = qkv_y.unbind(0) + qy = self.q_norm_y(qy) + ky = self.k_norm_y(ky) + + qx, kx = apply_rope(qx, kx, pos_img[None, None]) + if pos_txt is not None: + qy, ky = apply_rope(qy, ky, pos_txt[None, None]) + + q_joint = torch.cat([qy, qx], dim=2) + k_joint = torch.cat([ky, kx], dim=2) + v_joint = torch.cat([vy, vx], dim=2) + + out_joint = optimized_attention( + q_joint, k_joint, v_joint, H, + 
mask=attn_mask, skip_reshape=True, skip_output_reshape=True, + transformer_options=transformer_options, + ) + + out_y = out_joint[:, :, :Ny, :].transpose(1, 2).reshape(B, Ny, H * D) + out_x = out_joint[:, :, Ny:, :].transpose(1, 2).reshape(B, Nx, H * D) + + return self.proj_x(out_x), self.proj_y(out_y) + + +class MMDiTBlockT2I(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, dtype=None, device=None, operations=None): + super().__init__() + self.hidden_size = hidden_size + self.groups = groups + self.head_dim = hidden_size // groups + + self.norm_x1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.norm_y1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.attn = MMDiTJointAttention(hidden_size, num_heads=groups, qkv_bias=False, + dtype=dtype, device=device, operations=operations) + self.norm_x2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.norm_y2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp_x = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, + dtype=dtype, device=device, operations=operations) + self.mlp_y = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, + dtype=dtype, device=device, operations=operations) + self.adaLN_modulation_img = nn.Sequential( + operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device), + ) + self.adaLN_modulation_txt = nn.Sequential( + operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device), + ) + + def forward(self, x, y, c, pos_img, pos_txt=None, attn_mask=None, transformer_options={}): + shift_msa_x, scale_msa_x, gate_msa_x, shift_mlp_x, scale_mlp_x, gate_mlp_x = self.adaLN_modulation_img(c).chunk(6, dim=-1) + shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = self.adaLN_modulation_txt(c).chunk(6, dim=-1) + + x_norm = 
apply_adaln(self.norm_x1(x), shift_msa_x, scale_msa_x) + y_norm = apply_adaln(self.norm_y1(y), shift_msa_y, scale_msa_y) + attn_x, attn_y = self.attn(x_norm, y_norm, pos_img, pos_txt, attn_mask, transformer_options=transformer_options) + x = torch.addcmul(x, gate_msa_x, attn_x) + y = torch.addcmul(y, gate_msa_y, attn_y) + + x = torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln(self.norm_x2(x), shift_mlp_x, scale_mlp_x))) + y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln(self.norm_y2(y), shift_mlp_y, scale_mlp_y))) + return x, y + + +class PixDiT_T2I(nn.Module): + """PixelDiT T2I model. Hardcoded for the released 1024px Stage-3 checkpoint + (also runs at 512px when fed the appropriate latent size and flow_shift). + + Forward: + x: [B, 3, H, W] pixel-space input (no VAE) + timesteps:[B] in [0, 1000] (ComfyUI flow sampling convention) + context: [B, Ltxt, 2304] Gemma-2-2b-it hidden states (chi_prompt prepended) + Returns flow-matching velocity [B, 3, H, W]. + """ + def __init__( + self, + in_channels=3, + num_groups=24, + hidden_size=1536, + pixel_hidden_size=16, + pixel_attn_hidden_size=1152, + pixel_num_groups=16, + patch_depth=14, + pixel_depth=2, + patch_size=16, + txt_embed_dim=2304, + txt_max_length=300, + use_text_rope=True, + text_rope_theta=10000.0, + use_pixel_abs_pos=True, + image_model=None, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + self.in_channels = int(in_channels) + self.out_channels = int(in_channels) + self.hidden_size = int(hidden_size) + self.num_groups = int(num_groups) + self.patch_depth = int(patch_depth) + self.pixel_depth = int(pixel_depth) + self.patch_size = int(patch_size) + self.pixel_hidden_size = int(pixel_hidden_size) + self.pixel_attn_hidden_size = int(pixel_attn_hidden_size) + self.pixel_num_groups = int(pixel_num_groups) + self.txt_embed_dim = int(txt_embed_dim) + self.txt_max_length = int(txt_max_length) + self.use_text_rope = bool(use_text_rope) + self.text_rope_theta 
= float(text_rope_theta) + self.use_pixel_abs_pos = bool(use_pixel_abs_pos) + + self.pixel_embedder = PixelTokenEmbedder( + self.in_channels, self.pixel_hidden_size, use_pixel_abs_pos=self.use_pixel_abs_pos, + dtype=dtype, device=device, operations=operations, + ) + self.s_embedder = PatchTokenEmbedder( + self.in_channels * self.patch_size ** 2, self.hidden_size, bias=True, + dtype=dtype, device=device, operations=operations, + ) + self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations, max_period=10) + self.y_embedder = PatchTokenEmbedder( + self.txt_embed_dim, self.hidden_size, bias=True, norm_layer=True, + dtype=dtype, device=device, operations=operations, + ) + self.y_pos_embedding = nn.Parameter( + torch.empty(1, self.txt_max_length, self.hidden_size, dtype=dtype, device=device) + ) + + self.patch_blocks = nn.ModuleList([ + MMDiTBlockT2I(self.hidden_size, self.num_groups, + dtype=dtype, device=device, operations=operations) + for _ in range(self.patch_depth) + ]) + self.pixel_blocks = nn.ModuleList([ + PiTBlock( + self.pixel_hidden_size, + self.hidden_size, + patch_size=self.patch_size, + num_heads=self.num_groups, + mlp_ratio=4.0, + attn_hidden_size=self.pixel_attn_hidden_size, + attn_num_heads=self.pixel_num_groups, + dtype=dtype, device=device, operations=operations, + ) + for _ in range(self.pixel_depth) + ]) + + self.final_layer = FinalLayer(self.pixel_hidden_size, self.out_channels, + dtype=dtype, device=device, operations=operations) + + self._patch_pos_cache = {} + self._text_pos_cache = {} + + def _fetch_patch_pos(self, height, width, device, dtype): + key = (height, width) + pos = self._patch_pos_cache.get(key) + if pos is None: + pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width) + self._patch_pos_cache[key] = pos + return pos.to(device=device, dtype=dtype) + + def _fetch_text_pos(self, length, device, dtype): + pos = self._text_pos_cache.get(length) + if pos is None: + 
pos = rope(torch.arange(length, dtype=torch.float32).reshape(1, -1), self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0) + self._text_pos_cache[length] = pos + return pos.to(device=device, dtype=dtype) + + def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options), + ).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs) + + def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs): + B, _, H, W = x.shape + Hs = H // self.patch_size + Ws = W // self.patch_size + L = Hs * Ws + + pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype) + x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + + t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size) + + if context is None or context.dim() != 3: + raise ValueError("PixDiT_T2I requires context (text embeddings) of shape [B, L, D]") + Ltxt = min(context.shape[1], self.txt_max_length) + y = context[:, :Ltxt, :] + y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size) + y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(device=y_emb.device, dtype=y_emb.dtype) + + condition = F.silu(t_emb) + pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None + + s = self.s_embedder(x_patches) + for blk in self.patch_blocks: + s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options) + s = F.silu(t_emb + s) + + s_cond = s.view(B * L, self.hidden_size) + x_pixels = self.pixel_embedder(x, img_height=H, img_width=W, patch_size=self.patch_size) + for blk in self.pixel_blocks: + x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None, transformer_options=transformer_options) +
+ x_pixels = self.final_layer(x_pixels) + C_out = self.out_channels + P2 = self.patch_size * self.patch_size + x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).contiguous() + x_pixels = x_pixels.view(B, C_out * P2, L) + return F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size) diff --git a/comfy/ldm/pixeldit/modules.py b/comfy/ldm/pixeldit/modules.py new file mode 100644 index 000000000..144735353 --- /dev/null +++ b/comfy/ldm/pixeldit/modules.py @@ -0,0 +1,199 @@ +import torch +import torch.nn as nn + +from comfy.ldm.flux.math import apply_rope +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.modules.diffusionmodules.mmdit import Mlp + + +def apply_adaln(x, shift, scale): + return torch.addcmul(x + shift, x, scale) + + +def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0, device=None, dtype=torch.float32): + """2D RoPE with x/y axis frequencies interleaved at stride 2 across head dim. + + Returns Flux-format rotation matrices of shape [H*W, dim/2, 2, 2]. + Layout of head-dim pairs: [x_0, y_0, x_1, y_1, ..., x_{dim/4-1}, y_{dim/4-1}]. 
+ """ + x_pos = torch.linspace(0, scale, width, device=device) + y_pos = torch.linspace(0, scale, height, device=device) + y_grid, x_grid = torch.meshgrid(y_pos, x_pos, indexing="ij") + x_pos = x_grid.reshape(-1) + y_pos = y_grid.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4, device=device, dtype=torch.float32)[: (dim // 4)] / dim)) + x_freqs = torch.outer(x_pos, freqs) + y_freqs = torch.outer(y_pos, freqs) + freqs_interleaved = torch.stack([x_freqs, y_freqs], dim=-1).reshape(height * width, -1) + cos = torch.cos(freqs_interleaved) + sin = torch.sin(freqs_interleaved) + out = torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*cos.shape, 2, 2) + return out.to(dtype=dtype) + + +def get_2d_sincos_pos_embed(embed_dim, height, width, device=None, dtype=torch.float32): + """Torch port of MAE's 2D sin/cos absolute positional embedding for the pixel embedder. + + first half encodes W-coordinates, second half H. + """ + assert embed_dim % 4 == 0 + grid_h = torch.arange(height, dtype=torch.float32, device=device) + grid_w = torch.arange(width, dtype=torch.float32, device=device) + grid_y, grid_x = torch.meshgrid(grid_h, grid_w, indexing="ij") + grid_y = grid_y.reshape(-1) + grid_x = grid_x.reshape(-1) + omega = torch.arange(embed_dim // 4, dtype=torch.float32, device=device) / (embed_dim / 4.0) + omega = 1.0 / (10000.0 ** omega) + out_w = torch.outer(grid_x, omega) + out_h = torch.outer(grid_y, omega) + emb_w = torch.cat([torch.sin(out_w), torch.cos(out_w)], dim=1) + emb_h = torch.cat([torch.sin(out_h), torch.cos(out_h)], dim=1) + return torch.cat([emb_w, emb_h], dim=1).to(dtype=dtype) + + +class RotaryAttention(nn.Module): + """Single-stream self-attention with rotary positional encoding (used inside PiTBlock).""" + def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None): + super().__init__() + assert dim % num_heads == 0 + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = 
operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.proj = operations.Linear(dim, dim, dtype=dtype, device=device) + + def forward(self, x, pos, mask=None, transformer_options={}): + B, N, C = x.shape + H = self.num_heads + D = self.head_dim + qkv = self.qkv(x).reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q = self.q_norm(q) + k = self.k_norm(k) + q, k = apply_rope(q, k, pos[None, None]) + x = optimized_attention(q, k, v, H, mask=mask, skip_reshape=True, transformer_options=transformer_options) + return self.proj(x) + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.norm = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device) + + def forward(self, x): + return self.linear(self.norm(x)) + + +class PatchTokenEmbedder(nn.Module): + """Linear projection used both for patchified-image tokens and text-feature tokens.""" + def __init__(self, in_chans, embed_dim, norm_layer=None, bias=True, dtype=None, device=None, operations=None): + super().__init__() + self.in_chans = in_chans + self.embed_dim = embed_dim + self.proj = operations.Linear(in_chans, embed_dim, bias=bias, dtype=dtype, device=device) + if norm_layer is not None: + self.norm = operations.RMSNorm(embed_dim, eps=1e-6, dtype=dtype, device=device) + else: + self.norm = nn.Identity() + + def forward(self, x): + return self.norm(self.proj(x)) + + +class PixelTokenEmbedder(nn.Module): + """Pixel-level embedder: lifts each RGB pixel to hidden_size and packs into per-patch sequences.""" + def __init__(self, in_channels, hidden_size_output, use_pixel_abs_pos=True, + 
dtype=None, device=None, operations=None): + super().__init__() + self.in_channels = in_channels + self.hidden_size_output = hidden_size_output + self.use_pixel_abs_pos = bool(use_pixel_abs_pos) + self.proj = operations.Linear(self.in_channels, self.hidden_size_output, bias=True, dtype=dtype, device=device) + self._pos_cache = {} + + def _fetch_pixel_pos(self, height, width, device, dtype): + key = (height, width) + pe = self._pos_cache.get(key) + if pe is None: + pe = get_2d_sincos_pos_embed(self.hidden_size_output, height, width) + self._pos_cache[key] = pe + return pe.to(device=device, dtype=dtype) + + def forward(self, inputs, img_height, img_width, patch_size): + B, C, H, W = inputs.shape + assert H == img_height and W == img_width + assert (H % patch_size == 0) and (W % patch_size == 0) + Hs, Ws = H // patch_size, W // patch_size + P2 = patch_size * patch_size + x = inputs.permute(0, 2, 3, 1).contiguous() + x = self.proj(x) + if self.use_pixel_abs_pos: + pos_full = self._fetch_pixel_pos(H, W, x.device, x.dtype) + pos_full = pos_full.view(H, W, self.hidden_size_output) + x = x + pos_full.unsqueeze(0) + x = x.view(B, Hs, patch_size, Ws, patch_size, self.hidden_size_output) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous() + return x.view(B * Hs * Ws, P2, self.hidden_size_output) + + +class PiTBlock(nn.Module): + """Pixel-level transformer block. + + Compresses each patch's P^2 pixel tokens → 1 attention token via a linear, + runs global self-attention across patches with 2D RoPE, then expands back to P^2 tokens. + Conditioning is per-pixel adaLN from the patch-level features. 
+ """ + def __init__(self, pixel_hidden_size, patch_hidden_size, patch_size, num_heads, mlp_ratio=4.0, + attn_hidden_size=None, attn_num_heads=None, rope_fn=None, + dtype=None, device=None, operations=None): + super().__init__() + self.pixel_dim = pixel_hidden_size + self.context_dim = patch_hidden_size + self.patch_size = patch_size + self.attn_dim = attn_hidden_size if attn_hidden_size is not None else patch_hidden_size + self.num_heads = attn_num_heads if attn_num_heads is not None else num_heads + assert self.attn_dim % self.num_heads == 0 + p2 = patch_size * patch_size + self.compress_to_attn = operations.Linear(p2 * self.pixel_dim, self.attn_dim, bias=True, dtype=dtype, device=device) + self.expand_from_attn = operations.Linear(self.attn_dim, p2 * self.pixel_dim, bias=True, dtype=dtype, device=device) + self.norm1 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device) + self.attn = RotaryAttention(self.attn_dim, num_heads=self.num_heads, qkv_bias=False, + dtype=dtype, device=device, operations=operations) + self.norm2 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device) + self.mlp = Mlp(self.pixel_dim, hidden_features=int(self.pixel_dim * mlp_ratio), + dtype=dtype, device=device, operations=operations) + self.adaLN_modulation = nn.Sequential( + operations.Linear(self.context_dim, 6 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device), + ) + self._pos_cache = {} + self._rope_fn = rope_fn if rope_fn is not None else precompute_freqs_cis_2d + + def _fetch_pos(self, height, width, device, dtype): + key = (height, width) + pos = self._pos_cache.get(key) + if pos is None: + pos = self._rope_fn(self.attn_dim // self.num_heads, height, width) + self._pos_cache[key] = pos + return pos.to(device=device, dtype=dtype) + + def forward(self, x, s_cond, image_height, image_width, patch_size, mask=None, transformer_options={}): + BL, P2, _ = x.shape + Hs, Ws = image_height // patch_size, image_width // patch_size + L = Hs * 
Ws + B = BL // L + cond_params = self.adaLN_modulation(s_cond).view(BL, P2, 6 * self.pixel_dim) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = cond_params.chunk(6, dim=-1) + x_norm = apply_adaln(self.norm1(x), shift_msa, scale_msa) + x_flat = x_norm.view(BL, P2 * self.pixel_dim) + x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim) + pos_comp = self._fetch_pos(Hs, Ws, x.device, x.dtype) + attn_out = self.attn(x_comp, pos_comp, mask=mask, transformer_options=transformer_options) + attn_flat = self.expand_from_attn(attn_out.view(B * L, self.attn_dim)) + attn_exp = attn_flat.view(BL, P2, self.pixel_dim) + x = torch.addcmul(x, gate_msa, attn_exp) + mlp_out = self.mlp(apply_adaln(self.norm2(x), shift_mlp, scale_mlp)) + x = torch.addcmul(x, gate_mlp, mlp_out) + return x diff --git a/comfy/ldm/pixeldit/pid.py b/comfy/ldm/pixeldit/pid.py new file mode 100644 index 000000000..ceb601647 --- /dev/null +++ b/comfy/ldm/pixeldit/pid.py @@ -0,0 +1,286 @@ +"""PiD — Pixel Diffusion Decoder. Decodes a Flux/SD3/Flux2/Z-Image latent +directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I +body + LQ projection branch injected before each MMDiT patch block. +""" + +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.flux.math import rope + +from .model import PixDiT_T2I + + +def precompute_freqs_cis_2d_ntk(dim: int, height: int, width: int, + ref_grid_h: int, ref_grid_w: int, + theta: float = 10000.0, scale: float = 16.0, + device=None, dtype=torch.float32): + """NTK-aware 2D RoPE (rope_mode='ntk_aware' in upstream PiD). + + Per-axis theta = theta * (current/ref)^(dim_axis/(dim_axis-2)). Returns + [H*W, dim/2, 2, 2] with x/y axis freqs interleaved at stride 2 (matches + the head-dim layout PiD's Q/K weights expect). 
+ """ + dim_axis = dim // 2 + h_ntk = (height / ref_grid_h) ** (dim_axis / (dim_axis - 2)) if dim_axis > 2 else 1.0 + w_ntk = (width / ref_grid_w) ** (dim_axis / (dim_axis - 2)) if dim_axis > 2 else 1.0 + + x_lin = torch.linspace(0, scale, width, device=device) + y_lin = torch.linspace(0, scale, height, device=device) + y_grid, x_grid = torch.meshgrid(y_lin, x_lin, indexing="ij") + + x_rope = rope(x_grid.reshape(1, -1), dim_axis, theta * w_ntk).squeeze(0) + y_rope = rope(y_grid.reshape(1, -1), dim_axis, theta * h_ntk).squeeze(0) + + out = torch.stack([x_rope, y_rope], dim=2).reshape(height * width, dim // 2, 2, 2) + return out.to(dtype=dtype) + + +class SigmaAwareGatePerTokenPerDim(nn.Module): + """gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq. + + Trained init gives ~0.88 gate at sigma=0, ~0.05 at sigma=1. + """ + + def __init__(self, dim: int, dtype=None, device=None, operations=None): + super().__init__() + self.content_proj = operations.Linear(dim * 2, dim, dtype=dtype, device=device) + self.log_alpha = nn.Parameter(torch.empty((), dtype=dtype, device=device)) + + def forward(self, x: torch.Tensor, lq: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + content_logit = self.content_proj(torch.cat([x, lq], dim=-1)) + # log_alpha is a raw nn.Parameter -> doesn't auto-cast under dynamic VRAM. 
+ log_alpha = self.log_alpha.to(device=x.device, dtype=torch.float32) + sigma_offset = -log_alpha.exp() * sigma.float().view(-1, 1, 1) + gate = torch.sigmoid(content_logit + sigma_offset) + return x + (gate * lq).to(x.dtype) + + +class ResBlock(nn.Module): + """Pre-activation ResNet block: GN -> SiLU -> Conv -> GN -> SiLU -> Conv + skip.""" + + def __init__(self, channels: int, num_groups: int = 4, dtype=None, device=None, operations=None): + super().__init__() + self.block = nn.Sequential( + operations.GroupNorm(num_groups, channels, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device), + operations.GroupNorm(num_groups, channels, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.block(x) + + +class LQProjection2D(nn.Module): + """LQ latent -> per-block patch-aligned features for controlnet-style injection.""" + + def __init__( + self, + latent_channels: int, + hidden_dim: int = 512, + out_dim: int = 1536, + patch_size: int = 16, + sr_scale: int = 4, + latent_spatial_down_factor: int = 8, + num_res_blocks: int = 4, + num_outputs: int = 14, + interval: int = 1, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + assert latent_channels > 0 + self.latent_channels = latent_channels + self.hidden_dim = hidden_dim + self.out_dim = out_dim + self.patch_size = patch_size + self.sr_scale = sr_scale + self.latent_spatial_down_factor = latent_spatial_down_factor + self.num_outputs = num_outputs + self.interval = interval + + z_to_patch_ratio = (sr_scale * latent_spatial_down_factor) / patch_size + self.z_to_patch_ratio = z_to_patch_ratio + if z_to_patch_ratio >= 1: + self.latent_upsample_ratio = int(z_to_patch_ratio) if z_to_patch_ratio > 1 else 1 + self.latent_fold_factor = 0 + latent_proj_in_ch = 
latent_channels + else: + fold_factor = int(1 / z_to_patch_ratio) + assert fold_factor * z_to_patch_ratio == 1.0 + self.latent_upsample_ratio = 0 + self.latent_fold_factor = fold_factor + latent_proj_in_ch = latent_channels * fold_factor * fold_factor + + layers = [ + operations.Conv2d(latent_proj_in_ch, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device), + ] + for _ in range(num_res_blocks): + layers.append(ResBlock(hidden_dim, dtype=dtype, device=device, operations=operations)) + self.latent_proj = nn.Sequential(*layers) + + self.output_heads = nn.ModuleList( + [operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) for _ in range(num_outputs)] + ) + self.gate_modules = nn.ModuleList( + [SigmaAwareGatePerTokenPerDim(out_dim, dtype=dtype, device=device, operations=operations) + for _ in range(num_outputs)] + ) + + def is_gate_active(self, block_idx: int) -> bool: + return block_idx % self.interval == 0 if self.interval > 1 else True + + def output_index(self, block_idx: int) -> int: + return block_idx // self.interval if self.interval > 1 else block_idx + + def gate(self, x: torch.Tensor, lq_feature: torch.Tensor, sigma: torch.Tensor, out_idx: int) -> torch.Tensor: + return self.gate_modules[out_idx](x, lq_feature, sigma) + + def _align_latent_to_patch_grid(self, lq_latent: torch.Tensor, pH: int, pW: int) -> torch.Tensor: + B, z_dim = lq_latent.shape[:2] + if self.z_to_patch_ratio >= 1: + if lq_latent.shape[2] != pH or lq_latent.shape[3] != pW: + z_aligned = F.interpolate(lq_latent, size=(pH, pW), mode="nearest") + else: + z_aligned = lq_latent + else: + f = self.latent_fold_factor + zH_expected, zW_expected = pH * f, pW * f + if lq_latent.shape[2] != zH_expected or lq_latent.shape[3] != zW_expected: + lq_latent = F.interpolate(lq_latent, size=(zH_expected, zW_expected), mode="nearest") + z_aligned = lq_latent.reshape(B, z_dim, 
pH, f, pW, f).permute(0, 1, 3, 5, 2, 4) + z_aligned = z_aligned.reshape(B, z_dim * f * f, pH, pW) + return self.latent_proj(z_aligned) + + def forward(self, lq_latent: torch.Tensor, target_pH: int, target_pW: int) -> List[torch.Tensor]: + feat = self._align_latent_to_patch_grid(lq_latent, target_pH, target_pW) + tokens = feat.flatten(2).transpose(1, 2) + return [head(tokens) for head in self.output_heads] + + +class PidNet(PixDiT_T2I): + """PixDiT_T2I + LQ injection (one sigma-gated feature inserted before each patch block).""" + + def __init__( + self, + lq_latent_channels: int = 16, + lq_hidden_dim: int = 512, + lq_num_res_blocks: int = 4, + lq_interval: int = 1, + sr_scale: int = 4, + latent_spatial_down_factor: int = 8, + rope_ref_h: int = 1024, # NTK ref resolution in PIXEL units: 1024px / patch=16 -> grid_ref=64. + rope_ref_w: int = 1024, + image_model=None, + dtype=None, device=None, operations=None, + **pixdit_kwargs, + ): + super().__init__(dtype=dtype, device=device, operations=operations, **pixdit_kwargs) + + self.rope_ref_grid_h = int(rope_ref_h) // int(self.patch_size) + self.rope_ref_grid_w = int(rope_ref_w) // int(self.patch_size) + + # Parent's PiTBlocks were built with plain RoPE — swap in NTK-aware. 
+ def _pit_rope_fn(head_dim, h, w): + return precompute_freqs_cis_2d_ntk(head_dim, h, w, self.rope_ref_grid_h, self.rope_ref_grid_w) + for blk in self.pixel_blocks: + blk._rope_fn = _pit_rope_fn + blk._pos_cache = {} + + num_lq_outputs = (self.patch_depth + lq_interval - 1) // lq_interval + self.lq_proj = LQProjection2D( + latent_channels=lq_latent_channels, + hidden_dim=lq_hidden_dim, + out_dim=self.hidden_size, + patch_size=self.patch_size, + sr_scale=sr_scale, + latent_spatial_down_factor=latent_spatial_down_factor, + num_res_blocks=lq_num_res_blocks, + num_outputs=num_lq_outputs, + interval=lq_interval, + dtype=dtype, + device=device, + operations=operations, + ) + + def _fetch_patch_pos(self, height, width, device, dtype): + key = (height, width) + pos = self._patch_pos_cache.get(key) + if pos is None: + pos = precompute_freqs_cis_2d_ntk( + self.hidden_size // self.num_groups, + height, width, + self.rope_ref_grid_h, self.rope_ref_grid_w, + ) + self._patch_pos_cache[key] = pos + return pos.to(device=device, dtype=dtype) + + def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, + lq_latent=None, degrade_sigma=None, **kwargs): + B, _, H, W = x.shape + Hs = H // self.patch_size + Ws = W // self.patch_size + L = Hs * Ws + + if context is None or context.dim() != 3: + raise ValueError("PidNet requires context [B, L, D]") + if lq_latent is None: + raise ValueError("PidNet requires lq_latent — attach via PiDConditioning") + + if degrade_sigma is None: + degrade_sigma = torch.zeros(B, device=x.device, dtype=torch.float32) + elif not isinstance(degrade_sigma, torch.Tensor): + degrade_sigma = torch.tensor([float(degrade_sigma)] * B, device=x.device, dtype=torch.float32) + else: + degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1) + if degrade_sigma.numel() == 1 and B > 1: + degrade_sigma = degrade_sigma.expand(B).contiguous() + + lq_latent = lq_latent.to(device=x.device, dtype=x.dtype) + lq_features = 
self.lq_proj(lq_latent=lq_latent, target_pH=Hs, target_pW=Ws) + + pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype) + x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + + t_emb = self.t_embedder(timesteps.view(-1)).view(B, -1, self.hidden_size) + + Ltxt = min(context.shape[1], self.txt_max_length) + y = context[:, :Ltxt, :] + y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size) + # y_pos_embedding is raw nn.Parameter -> doesn't auto-cast under dynamic VRAM. + y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(device=y_emb.device, dtype=y_emb.dtype) + + condition = F.silu(t_emb) + pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None + + s = self.s_embedder(x_patches) + for i, blk in enumerate(self.patch_blocks): + if self.lq_proj.is_gate_active(i): + out_idx = self.lq_proj.output_index(i) + if out_idx < len(lq_features): + s = self.lq_proj.gate(s, lq_features[out_idx], degrade_sigma, out_idx) + s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, + transformer_options=transformer_options) + s = F.silu(t_emb + s) + + s_cond = s.view(B * L, self.hidden_size) + x_pixels = self.pixel_embedder(x, img_height=H, img_width=W, patch_size=self.patch_size) + for blk in self.pixel_blocks: + x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None, + transformer_options=transformer_options) + + x_pixels = self.final_layer(x_pixels) + C_out = self.out_channels + P2 = self.patch_size * self.patch_size + x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).contiguous() + x_pixels = x_pixels.view(B, C_out * P2, L) + return F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size) diff --git a/comfy/model_base.py b/comfy/model_base.py index 0736321b3..2ad705117 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -48,6 +48,8 @@ import comfy.ldm.hunyuan3d.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model import 
comfy.ldm.chroma_radiance.model +import comfy.ldm.pixeldit.model +import comfy.ldm.pixeldit.pid import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model @@ -1296,6 +1298,41 @@ class ZImagePixelSpace(Lumina2): BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace) self.memory_usage_factor_conds = ("ref_latents",) + +class PixelDiTT2I(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, + unet_model=comfy.ldm.pixeldit.model.PixDiT_T2I) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out["attention_mask"] = comfy.conds.CONDRegular(attention_mask) + return out + + +class PiD(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, + unet_model=comfy.ldm.pixeldit.pid.PidNet) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out["attention_mask"] = comfy.conds.CONDRegular(attention_mask) + lq_latent = kwargs.get("lq_latent", None) + if lq_latent is not None: + out["lq_latent"] = comfy.conds.CONDRegular(lq_latent) + degrade_sigma = kwargs.get("degrade_sigma", None) + if degrade_sigma is not None: + if not isinstance(degrade_sigma, torch.Tensor): + degrade_sigma = torch.tensor([float(degrade_sigma)], dtype=torch.float32) + out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma) + return out + + class WAN21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) diff --git a/comfy/model_detection.py 
b/comfy/model_detection.py index bc0b933bc..70f8c41c0 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -424,6 +424,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["extra_per_block_abs_pos_emb_type"] = "learnable" return dit_config + # PiD (Pixel Diffusion Decoder). Must check BEFORE plain PixelDiT_T2I. + _lq_w_key = '{}lq_proj.latent_proj.0.weight'.format(key_prefix) + if _lq_w_key in state_dict_keys: + in_ch = int(state_dict[_lq_w_key].shape[1]) + _gate_prefix = '{}lq_proj.gate_modules.'.format(key_prefix) + num_gates = len({k[len(_gate_prefix):].split('.')[0] + for k in state_dict_keys if k.startswith(_gate_prefix)}) + dit_config = {"image_model": "pid", + "lq_latent_channels": in_ch, + "latent_spatial_down_factor": 16 if in_ch >= 64 else 8} + if num_gates > 0: + dit_config["lq_interval"] = (14 + num_gates - 1) // num_gates + return dit_config + + if '{}core.pixel_embedder.proj.weight'.format(key_prefix) in state_dict_keys: # PixelDiT T2I + return {"image_model": "pixeldit_t2i"} + if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 dit_config = {} dit_config["image_model"] = "lumina2" diff --git a/comfy/sd.py b/comfy/sd.py index ab2718892..b4af4b8af 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -49,6 +49,7 @@ import comfy.text_encoders.lt import comfy.text_encoders.hunyuan_video import comfy.text_encoders.cosmos import comfy.text_encoders.lumina2 +import comfy.text_encoders.pixeldit import comfy.text_encoders.wan import comfy.text_encoders.hidream import comfy.text_encoders.ace @@ -1228,6 +1229,7 @@ class CLIPType(Enum): FLUX2 = 25 LONGCAT_IMAGE = 26 COGVIDEOX = 27 + PIXELDIT = 28 @@ -1460,8 +1462,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.tokenizer = variant.tokenizer tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) 
elif te_model == TEModel.GEMMA_2_2B: - clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) - clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer + if clip_type == CLIPType.PIXELDIT: + clip_target.clip = comfy.text_encoders.pixeldit.PixelDiTGemma2TE + clip_target.tokenizer = comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer + else: + clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model == TEModel.GEMMA_3_4B: clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b") diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 8d2e02f68..83162f8f1 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -29,6 +29,7 @@ import comfy.text_encoders.longcat_image import comfy.text_encoders.ernie import comfy.text_encoders.cogvideo import comfy.text_encoders.hidream_o1 +import comfy.text_encoders.pixeldit from . import supported_models_base from . 
import latent_formats @@ -1135,6 +1136,71 @@ class ZImagePixelSpace(ZImage): def get_model(self, state_dict, prefix="", device=None): return model_base.ZImagePixelSpace(self, device=device) +class PixelDiTT2I(supported_models_base.BASE): + unet_config = { + "image_model": "pixeldit_t2i", + } + + unet_extra_config = {} + + sampling_settings = { + "shift": 4.0, # 1024px stage 3 default; 2.0 for 512px + "multiplier": 1000, + } + + latent_format = latent_formats.PixelDiTPixel + memory_usage_factor = 0.7 + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + return model_base.PixelDiTT2I(self, device=device) + + def process_unet_state_dict(self, state_dict): + out = {} + for k, v in state_dict.items(): + if k.startswith("_repa_projector"): + continue + if k.startswith("core."): + out[k[len("core."):]] = v + else: + out[k] = v + return out + + def clip_target(self, state_dict={}): + return supported_models_base.ClipTarget( + comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer, + comfy.text_encoders.pixeldit.PixelDiTGemma2TE, + ) + +class PiD(PixelDiTT2I): + unet_config = { + "image_model": "pid", + } + + sampling_settings = { + "shift": 1.5, # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0] + "multiplier": 1000, + } + + def get_model(self, state_dict, prefix="", device=None): + return model_base.PiD(self, device=device) + + def process_unet_state_dict(self, state_dict): + out = {} + for k, v in state_dict.items(): + if k.startswith("_repa_projector") or k.startswith("net_ema."): + continue + if k.startswith("core."): + out[k[len("core."):]] = v + elif k.startswith("net."): + out[k[len("net."):]] = v + else: + out[k] = v + return out + class WAN21_T2V(supported_models_base.BASE): unet_config = { "image_model": "wan2.1", @@ -2044,6 +2110,8 @@ models = [ CosmosI2VPredict2, ZImagePixelSpace, 
ZImage, + PiD, + PixelDiTT2I, Lumina2, WAN22_T2V, WAN21_CausalAR_T2V, diff --git a/comfy/text_encoders/pixeldit.py b/comfy/text_encoders/pixeldit.py new file mode 100644 index 000000000..8853e3584 --- /dev/null +++ b/comfy/text_encoders/pixeldit.py @@ -0,0 +1,124 @@ +import torch + +from comfy import sd1_clip +from .lumina2 import Gemma2BTokenizer, LuminaModel +import comfy.text_encoders.llama + + +class PixelDiTGemma2_2BModel(sd1_clip.SDClipModel): + """Gemma-2-2b-it text encoder for PixelDiT. + + Uses the FINAL hidden state (layer='last') + """ + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__( + device=device, layer=layer, layer_idx=layer_idx, + textmodel_json_config={}, dtype=dtype, + special_tokens={"start": 2, "pad": 0}, + layer_norm_hidden_state=False, + model_class=comfy.text_encoders.llama.Gemma2_2B, + enable_attention_masks=attention_mask, + return_attention_masks=attention_mask, + model_options=model_options, + ) + + +_PIXELDIT_CHI_PROMPT = ( + 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions ' + "suitable for image generation. 
Evaluate the level of detail in the user prompt:\n" + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, " + "and spatial relationships to create vivid and concrete scenes.\n" + "- If the prompt is already detailed, refine and enhance the existing details slightly without " + "overcomplicating.\n" + "Here are examples of how to transform or refine prompts:\n" + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, " + "sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n" + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring " + "glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus " + "passing by towering glass skyscrapers.\n" + "Please generate only the enhanced description for the prompt below and avoid including any " + "additional commentary or evaluations:\n" + "User Prompt: " +) + +_PIXELDIT_MAX_LENGTH = 300 +_PIXELDIT_CHI_PROMPT_DETECT_PREFIX = 'Given a user prompt, generate an "Enhanced prompt"' + + +def _build_padded_tokens(combined_text: str, spiece_tokenizer, pad_id: int, chi_token_count: int): + # Right-pad to chi_token_count + 300 - 2 (matches upstream's max_length_all). + max_length_all = chi_token_count + _PIXELDIT_MAX_LENGTH - 2 + ids = spiece_tokenizer(combined_text)["input_ids"] + if len(ids) > max_length_all: + ids = ids[:max_length_all] + elif len(ids) < max_length_all: + ids = ids + [pad_id] * (max_length_all - len(ids)) + return ids + + +class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer): + """Gemma-2-2b-it tokenizer that prepends PixelDiT's chi_prompt. + + Empty text -> BOS + pad to 300. Text already starting with the chi_prompt + preamble is tokenized verbatim (override mirrors QwenImageTokenizer's + `<|im_start|>` detection). Else chi_prompt is prepended. 
+ """ + def __init__(self, embedding_directory=None, tokenizer_data=None): + if tokenizer_data is None: + tokenizer_data = {} + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, + name="gemma2_2b", tokenizer=Gemma2BTokenizer) + + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + spiece_tokenizer = self.gemma2_2b.tokenizer + pad_id = self.gemma2_2b.pad_token + + if not (isinstance(text, str) and text.strip()): + ids = spiece_tokenizer("")["input_ids"] + ids = ids + [pad_id] * (_PIXELDIT_MAX_LENGTH - len(ids)) + return {"gemma2_2b": [[(t, 1.0) for t in ids]]} + + chi_token_count = len(spiece_tokenizer(_PIXELDIT_CHI_PROMPT)["input_ids"]) + combined = text if text.startswith(_PIXELDIT_CHI_PROMPT_DETECT_PREFIX) else _PIXELDIT_CHI_PROMPT + text + ids = _build_padded_tokens(combined, spiece_tokenizer, pad_id, chi_token_count) + return {"gemma2_2b": [[(t, 1.0) for t in ids]]} + + def untokenize(self, token_weight_pair): + return self.gemma2_2b.untokenize(token_weight_pair) + + def state_dict(self): + return self.gemma2_2b.state_dict() + + +class PixelDiTGemma2TE(LuminaModel): + """Text encoder wrapper for PixelDiT. + + Overrides `encode_token_weights` to perform PixelDiT's `select_index` step: + encode the full padded sequence (up to ~chi_prompt_tokens + 298), then + return `[BOS_emb] + last_299_embs` as the 300-position conditioning that + matches the diffusion model's learned `y_pos_embedding` positions. 
+ """ + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="gemma2_2b", + clip_model=PixelDiTGemma2_2BModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + result = super().encode_token_weights(token_weight_pairs) + cond, pooled = result[0], result[1] + extra = result[2] if len(result) > 2 else None + L = cond.shape[1] + if L > _PIXELDIT_MAX_LENGTH: + head = cond[:, :1] + tail = cond[:, -(_PIXELDIT_MAX_LENGTH - 1):] + cond = torch.cat([head, tail], dim=1) + if extra is not None and "attention_mask" in extra: + am = extra["attention_mask"] + if am.dim() == 1: + am = am.unsqueeze(0) + if am.shape[-1] == L: + head_m = am[..., :1] + tail_m = am[..., -(_PIXELDIT_MAX_LENGTH - 1):] + extra = {**extra, "attention_mask": torch.cat([head_m, tail_m], dim=-1)} + if extra is not None: + return cond, pooled, extra + return cond, pooled diff --git a/comfy_cuda_memory_history20260525_155146.pt b/comfy_cuda_memory_history20260525_155146.pt new file mode 100644 index 0000000000000000000000000000000000000000..2a02983f23fbb89c13a0af6d7bc71e07b4a1ecb8 GIT binary patch literal 15684590 zcmd>{37l6|`TxHnvWbA;E-CJc3a+^;ih{C;A?}02zzp1Z7-q&
7nnoFjc{qOxQcb*47_dMLgFz1~9^?Hrel{?SpeV+5&bIv{Y
zeCGGOxKLQ3ko@PQMVl||(^@xcPF-VL>!PWPHc$SKUbS^|>u1z08edr0yQa3brLMJg
z(K#y>3X6NLyRctdQ(H|#b!+|cb&JLo3Zqt7IH0DXp=m} G&S#wxoHOY+*k(;ndR#tVSt4VGo
zh$OX;Ks)42(qS5rG)dFdBsXMul4AF)oi!<)t4VGoi2Q{8Y?SX3M1ERGpuI$X3L-Gw
z_X&!Kg#_AH5aDT#{X-7bh@?F>T}^VI?zqR!(xh~*Cb^No)7?e_O%RPM(jAV3NhRVQ
z1RleiBd2c(5IoEA?HW}O9$5Arhzx1wdKx*;Wn5Xzk<$^P2}Lf$-P^~hAbhp#Ge=HG
zXM#W$Jtt@w+?ti;51D3U>PP-}C;6>Ef5;SN>c>7Mr}#|(2R+{yG|kCmfGzt@HQuiB
zRgSZdacqL7mP`iN{qA(X8Q26($7V9XXJnk|Hv^lX>BLM1g=qtd2Xo3dWoK&)99s5S
z1Wc!9>xaJ73q00u5ip%8^z*STdbr-Csjtl6yteFnH>V0;E&H~Zr>$&)H+_Cu_H7l8
z6h3(<9ab&f2X_od3ZGWxmqM3h)(-sa*i9wVk-X`1 P3L7ZfE$$m
z={W<7plOkT0s5%&U%oTQA2eN-%>ce9_`-7r7D3bZvl+m}!vA>Az#?e6&cFZ-V88O6
zLH?lWrfdfAu#F!3Y73WlC4+=rqdk&`WHc)N)97eS{PTclLNqCw9RKN%= 2
z@UB~&3SHpzD}Fb8jH*mSWj85KR+{K;@JiD*WnpOdR;!m010=EW_tc41rHGgcrFQ
z_{@KUsz#i4lh!!qqUU>0V8UeA=JyCdJ9)5maCf*Si{(rFN6z@DSdFP4ah@
zN4}dBhk`I^00o*;(4?!};ad&yTU_$pq&UiOlj8KO-8Cvpc9Y^z5dMVyY+vgXgg*l)
z(4E4cLJ?^9{gYAzQaHd%f$kNGaGT@&kR!d1<2^RtO^R=K++#;;RF>=}#i79M9zua0
z5sfSC4rjt+RYl?pQ3^m#PY5Bn%kkqTs6rTs>^ry}i0qqU6gl@Yt}Lf`k<;^{35A#8
z>Fra!kE`sAT8-@UBd3?0LZFDAshSLNgSv`Crq`YPQU2Xow)q!_On-Fp$LV(swiO5g
zO^clx;Pkt5ZEFw$nm%x9fYa~Jx2-`4X!^*h0lw${0^1sdfTn*qH7E^fQ=K8mH)WS<
z8bplj3jj?2a`W$iLXCf6n129Z`daYkb6YgH-lwszZws1^?E7ay7FHwsw!Bl>|3c{!
zHY5ACPKOF#JTzWbEhGDWkPa2TtSU~0R@s(+!^&$19y_|K$~08o^u@_a6Y#sQ42Hh_
zRru96eR1G(S7P-KvJ9U)F$6xf7hV)s*>vN87
T3rP
}d_da!rkooumdI4&DYHvM)Jh@`-aMrK45l;O;N~pEId!ALaW3zWV==34h|EJQ;Af
zGLfSHcu-bR
66oFRmuJ37=;r
zrvF!;!{?c36F%K1a*@-(yB6P9X}D=e4#FmKdPsl}Sq;saR%%t2F4AFtx|-z1kw-{@
zPF$O4QkGava-&1!CTx&$S-;?1;;ngF*6zVsW
zf6U5j2Yz-utCHzRUirB(O%vgd>H{8qFR0-cU-`KSpJyee|9hXq=b4DZf6W*^tuAs=
z(16=eUN_z@d5*hjM*+eja#}}#;8~3qPCivZfPEu=diQ-6juf7tXtTvv;Yi^bn_miT
zl9{jg$mX6WXbbSghlm@~G|}ClcVB-s{Nh8zP53-3G5sMvhtD(7CVbjP
yZH
zUvoH8c!Hum7RQAng=cJjDYVjh{L@vQulP*klT|Xshlm@mG#);lmUpP(7at;S!sl6u
z>7VCw_&gJB!lywZ7dZ`FJ16||@_pWSy=
z;;e8g(6tG&;#CEw+})BOD_(IzSka$F#o;31roT0DkwRF}--Jk_EW
g-WOWA#JNB8N(S>dLCHbGXro8n&o
za)PY%TIOd(_ls`IMZ!)0O5!4gu%ZP*BvH$1C(&ef-(NJ5@+qeI?7mH})^_5+mW!qA
zK6rNM!o*pjprF|aa^m#{uY%u7kd-XTD0k&;FDrUc)Eh1mYWR1(5376#Vy}yGiHdWg
zSA{^LAj2ba$6XUp#mr~-otHQ()bJN2$ccAQT*a
A^