diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 75d459b59..12a934d71 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -799,13 +799,15 @@ class ZImagePixelSpace(ChromaRadiance): """ pass - class HiDreamO1Pixel(ChromaRadiance): """Pixel-space latent format for HiDream-O1. No VAE — model patches/unpatches raw RGB internally with patch_size=32. """ pass +class PixelDiTPixel(ChromaRadiance): + pass + class CogVideoX(LatentFormat): """Latent format for CogVideoX-2b (THUDM/CogVideoX-2b). diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 0dc8fe789..9ab3c463c 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -211,7 +211,7 @@ class TimestepEmbedder(nn.Module): Embeds scalar timesteps into vector representations. """ - def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None): + def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None, max_period=10000): super().__init__() if output_size is None: output_size = hidden_size @@ -221,9 +221,10 @@ class TimestepEmbedder(nn.Module): operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device), ) self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period def forward(self, t, dtype, **kwargs): - t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype) + t_freq = timestep_embedding(t, self.frequency_embedding_size, max_period=self.max_period).to(dtype) t_emb = self.mlp(t_freq) return t_emb diff --git a/comfy/ldm/pixeldit/model.py b/comfy/ldm/pixeldit/model.py new file mode 100644 index 000000000..b044b9b29 --- /dev/null +++ b/comfy/ldm/pixeldit/model.py @@ -0,0 +1,239 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ldm.common_dit +import comfy.patcher_extension +from comfy.ldm.flux.math import apply_rope, rope +from comfy.ldm.hidream.model import FeedForwardSwiGLU +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder + +from .modules import ( + FinalLayer, + PatchTokenEmbedder, + PiTBlock, + PixelTokenEmbedder, + apply_adaln_, + precompute_freqs_cis_2d, +) + + +class MMDiTJointAttention(nn.Module): + """Joint MMDiT attention with separate Q/K/V/proj for image and text streams. + + RoPE is applied to each stream before concatenation so each stream uses its own + 2D/1D positional encoding. Concat order is [text, image] (text first). + """ + def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None): + super().__init__() + assert dim % num_heads == 0 + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.qkv_x = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + self.qkv_y = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + + self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + + self.proj_x = operations.Linear(dim, dim, dtype=dtype, device=device) + self.proj_y = operations.Linear(dim, dim, dtype=dtype, device=device) + + def forward(self, x, y, pos_img, pos_txt=None, attn_mask=None, transformer_options={}): + B, Nx, _ = x.shape + _, Ny, _ = y.shape + H = self.num_heads + D = self.head_dim + + qkv_x = self.qkv_x(x).reshape(B, Nx, 3, H, D).permute(2, 0, 3, 1, 4) + qx, kx, vx = qkv_x.unbind(0) + qx = self.q_norm_x(qx) + kx = self.k_norm_x(kx) + + qkv_y = self.qkv_y(y).reshape(B, Ny, 3, H, D).permute(2, 0, 3, 1, 4) + qy, ky, vy = qkv_y.unbind(0) + qy = self.q_norm_y(qy) + ky = self.k_norm_y(ky) + + qx, kx = apply_rope(qx, kx, pos_img[None, None]) + if pos_txt is not None: + qy, ky = apply_rope(qy, ky, pos_txt[None, None]) + + q_joint = torch.cat([qy, qx], dim=2) + k_joint = torch.cat([ky, kx], dim=2) + v_joint = torch.cat([vy, vx], dim=2) + + out_joint = optimized_attention( + q_joint, k_joint, v_joint, H, + mask=attn_mask, skip_reshape=True, skip_output_reshape=True, + transformer_options=transformer_options, + ) + + out_y = out_joint[:, :, :Ny, :].transpose(1, 2).reshape(B, Ny, H * D) + out_x = out_joint[:, :, Ny:, :].transpose(1, 2).reshape(B, Nx, H * D) + + return self.proj_x(out_x), self.proj_y(out_y) + + +class MMDiTBlockT2I(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, dtype=None, device=None, operations=None): + super().__init__() + self.norm_x1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.norm_y1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.attn = MMDiTJointAttention(hidden_size, num_heads=groups, qkv_bias=False, dtype=dtype, device=device, operations=operations) + self.norm_x2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.norm_y2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp_x = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations) + self.mlp_y = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations) + self.adaLN_modulation_img = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)) + self.adaLN_modulation_txt = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)) + + def forward(self, x, y, c, pos_img, pos_txt=None, attn_mask=None, transformer_options={}): + shift_msa_x, scale_msa_x, gate_msa_x, shift_mlp_x, scale_mlp_x, gate_mlp_x = self.adaLN_modulation_img(c).chunk(6, dim=-1) + shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = self.adaLN_modulation_txt(c).chunk(6, dim=-1) + + x_norm = apply_adaln_(self.norm_x1(x), shift_msa_x, scale_msa_x) + y_norm = apply_adaln_(self.norm_y1(y), shift_msa_y, scale_msa_y) + attn_x, attn_y = self.attn(x_norm, y_norm, pos_img, pos_txt, attn_mask, transformer_options=transformer_options) + x = torch.addcmul(x, gate_msa_x, attn_x) + y = torch.addcmul(y, gate_msa_y, attn_y) + + x = torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln_(self.norm_x2(x), shift_mlp_x, scale_mlp_x))) + y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln_(self.norm_y2(y), shift_mlp_y, scale_mlp_y))) + return x, y + + +class PixDiT_T2I(nn.Module): + """PixelDiT T2I model. Hardcoded for the released 1024px Stage-3 checkpoint + (also runs at 512px when fed the appropriate latent size and flow_shift). + + Forward: + x: [B, 3, H, W] pixel-space input (no VAE) + timesteps:[B] in [0, 1000] (ComfyUI flow sampling convention) + context: [B, Ltxt, 2304] Gemma-2-2b-it hidden states (chi_prompt prepended) + Returns flow-matching velocity [B, 3, H, W]. + """ + def __init__( + self, + in_channels=3, + num_groups=24, + hidden_size=1536, + pixel_hidden_size=16, + pixel_attn_hidden_size=1152, + pixel_num_groups=16, + patch_depth=14, + pixel_depth=2, + patch_size=16, + txt_embed_dim=2304, + txt_max_length=300, + use_text_rope=True, + text_rope_theta=10000.0, + image_model=None, + dtype=None, + device=None, + operations=None, + pixel_mlp_chunks=2, + ): + super().__init__() + self.dtype = dtype + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.patch_depth = patch_depth + self.pixel_depth = pixel_depth + self.patch_size = patch_size + self.pixel_hidden_size = pixel_hidden_size + self.pixel_attn_hidden_size = pixel_attn_hidden_size + self.pixel_num_groups = pixel_num_groups + self.txt_embed_dim = txt_embed_dim + self.txt_max_length = txt_max_length + self.use_text_rope = use_text_rope + self.text_rope_theta = text_rope_theta + + self.pixel_embedder = PixelTokenEmbedder(self.in_channels, self.pixel_hidden_size, dtype=dtype, device=device, operations=operations) + self.s_embedder = PatchTokenEmbedder(self.in_channels * self.patch_size ** 2, self.hidden_size, bias=True, dtype=dtype, device=device, operations=operations) + self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations, max_period=10) + self.y_embedder = PatchTokenEmbedder(self.txt_embed_dim, self.hidden_size, bias=True, use_norm=True, dtype=dtype, device=device, operations=operations) + self.y_pos_embedding = nn.Parameter(torch.empty(1, self.txt_max_length, self.hidden_size, dtype=dtype, device=device)) + + self.patch_blocks = nn.ModuleList([ + MMDiTBlockT2I(self.hidden_size, self.num_groups, + dtype=dtype, device=device, operations=operations) + for _ in range(self.patch_depth) + ]) + self.pixel_blocks = nn.ModuleList([ + PiTBlock( + self.pixel_hidden_size, + self.hidden_size, + patch_size=self.patch_size, + num_heads=self.num_groups, + attn_hidden_size=self.pixel_attn_hidden_size, + attn_num_heads=self.pixel_num_groups, + dtype=dtype, device=device, operations=operations, + mlp_chunks=pixel_mlp_chunks, + ) + for _ in range(self.pixel_depth) + ]) + + self.final_layer = FinalLayer(self.pixel_hidden_size, self.out_channels, dtype=dtype, device=device, operations=operations) + + def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts): + return precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width, device=device, dtype=dtype, **rope_opts) + + def _fetch_text_pos(self, length, device, dtype): + return rope(torch.arange(length, dtype=torch.float32, device=device).reshape(1, -1), self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0).to(dtype=dtype) + + def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options), + ).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs) + + def _pre_patch_block(self, s, i, **kwargs): + """Hook for subclasses to inject per-block state into the patch stream (e.g. PiD's LQ gate).""" + return s + + def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs): + H_orig, W_orig = x.shape[2], x.shape[3] + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + B, _, H, W = x.shape + Hs = H // self.patch_size + Ws = W // self.patch_size + L = Hs * Ws + + pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {})) + x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + + t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size) + + if context is None or context.dim() != 3: + raise ValueError("PixDiT_T2I requires context (text embeddings) of shape [B, L, D]") + Ltxt = min(context.shape[1], self.txt_max_length) + y = context[:, :Ltxt, :] + y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size) + y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb) # y_pos_embedding is a raw nn.Parameter + + condition = F.silu(t_emb) + pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None + + s = self.s_embedder(x_patches) + for i, blk in enumerate(self.patch_blocks): + s = self._pre_patch_block(s, i, **kwargs) + s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options) + s = F.silu(t_emb + s) + + s_cond = s.view(B * L, self.hidden_size) + x_pixels = self.pixel_embedder(x, patch_size=self.patch_size) + for blk in self.pixel_blocks: + x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None, transformer_options=transformer_options) + + x_pixels = self.final_layer(x_pixels) + C_out = self.out_channels + P2 = self.patch_size * self.patch_size + x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).reshape(B, C_out * P2, L) + out = F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return out[:, :, :H_orig, :W_orig] diff --git a/comfy/ldm/pixeldit/modules.py b/comfy/ldm/pixeldit/modules.py new file mode 100644 index 000000000..4b1e538c7 --- /dev/null +++ b/comfy/ldm/pixeldit/modules.py @@ -0,0 +1,187 @@ +import torch +import torch.nn as nn + +from comfy.ldm.flux.math import apply_rope, rope +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, get_1d_sincos_pos_embed_from_grid_torch + + +def apply_adaln_(x, shift, scale): + return x.addcmul_(x, scale).add_(shift) + + +def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0, + ref_grid_h=None, ref_grid_w=None, + scale_x=1.0, scale_y=1.0, shift_x=0.0, shift_y=0.0, + device=None, dtype=torch.float32, **kwargs): + """2D RoPE with x/y axis frequencies interleaved at stride 2 across head dim. + + rope_options: + scale_x / scale_y multiply the position range (RoPE extrapolation). + shift_x / shift_y offset the position origin (tiled / regional inference). + With ref_grid_h/w set, also applies NTK-aware per-axis theta scaling + (rope_mode='ntk_aware'): theta_axis = theta * (current/ref)^(dim_axis/(dim_axis-2)). + Returns Flux-format rotation matrices of shape [H*W, dim/2, 2, 2]. + Layout of head-dim pairs: [x_0, y_0, x_1, y_1, ..., x_{dim/4-1}, y_{dim/4-1}]. + """ + dim_axis = dim // 2 + if ref_grid_h is not None and dim_axis > 2: + h_ntk = (height / ref_grid_h) ** (dim_axis / (dim_axis - 2)) + w_ntk = (width / ref_grid_w) ** (dim_axis / (dim_axis - 2)) + else: + h_ntk = w_ntk = 1.0 + + x_lin = torch.linspace(shift_x, scale * scale_x + shift_x, width, device=device) + y_lin = torch.linspace(shift_y, scale * scale_y + shift_y, height, device=device) + y_grid, x_grid = torch.meshgrid(y_lin, x_lin, indexing="ij") + x_rope = rope(x_grid.reshape(1, -1), dim_axis, theta * w_ntk).squeeze(0) + y_rope = rope(y_grid.reshape(1, -1), dim_axis, theta * h_ntk).squeeze(0) + out = torch.stack([x_rope, y_rope], dim=2).reshape(height * width, dim // 2, 2, 2) + return out.to(dtype=dtype) + + +def get_2d_sincos_pos_embed(embed_dim, height, width, device=None, dtype=torch.float32): + """Standard 2D sin/cos absolute positional embedding (ViT-style). + + first half encodes W-coordinates, second half H. + """ + assert embed_dim % 4 == 0 + grid_h = torch.arange(height, dtype=torch.float32, device=device) + grid_w = torch.arange(width, dtype=torch.float32, device=device) + grid_y, grid_x = torch.meshgrid(grid_h, grid_w, indexing="ij") + emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_x.reshape(-1), device=device) + emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_y.reshape(-1), device=device) + return torch.cat([emb_w, emb_h], dim=1).to(dtype=dtype) + + +class RotaryAttention(nn.Module): + """Single-stream self-attention with rotary positional encoding (used inside PiTBlock).""" + def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None): + super().__init__() + assert dim % num_heads == 0 + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.proj = operations.Linear(dim, dim, dtype=dtype, device=device) + + def forward(self, x, pos, mask=None, transformer_options={}): + B, N, C = x.shape + H = self.num_heads + D = self.head_dim + qkv = self.qkv(x).reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = apply_rope(self.q_norm(q), self.k_norm(k), pos[None, None]) + x = optimized_attention(q, k, v, H, mask=mask, skip_reshape=True, transformer_options=transformer_options) + return self.proj(x) + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.norm = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device) + + def forward(self, x): + return self.linear(self.norm(x)) + + +class PatchTokenEmbedder(nn.Module): + """Linear projection used both for patchified-image tokens and text-feature tokens.""" + def __init__(self, in_chans, embed_dim, use_norm=False, bias=True, dtype=None, device=None, operations=None): + super().__init__() + self.proj = operations.Linear(in_chans, embed_dim, bias=bias, dtype=dtype, device=device) + self.norm = operations.RMSNorm(embed_dim, eps=1e-6, dtype=dtype, device=device) if use_norm else nn.Identity() + + def forward(self, x): + return self.norm(self.proj(x)) + + +class PixelTokenEmbedder(nn.Module): + """Pixel-level embedder: lifts each RGB pixel to hidden_size and packs into per-patch sequences.""" + def __init__(self, in_channels, hidden_size_output, dtype=None, device=None, operations=None): + super().__init__() + self.in_channels = in_channels + self.hidden_size_output = hidden_size_output + self.proj = operations.Linear(self.in_channels, self.hidden_size_output, bias=True, dtype=dtype, device=device) + + def forward(self, inputs, patch_size): + B, _, H, W = inputs.shape + Hs, Ws = H // patch_size, W // patch_size + P2 = patch_size * patch_size + x = inputs.permute(0, 2, 3, 1).contiguous() + x = self.proj(x) + pos_full = get_2d_sincos_pos_embed(self.hidden_size_output, H, W, device=x.device, dtype=x.dtype).view(H, W, self.hidden_size_output) + x = x + pos_full.unsqueeze(0) + x = x.view(B, Hs, patch_size, Ws, patch_size, self.hidden_size_output) + return x.permute(0, 1, 3, 2, 4, 5).reshape(B * Hs * Ws, P2, self.hidden_size_output) + + +class PiTBlock(nn.Module): + """Pixel-level transformer block. + + Compresses each patch's P^2 pixel tokens → 1 attention token via a linear, + runs global self-attention across patches with 2D RoPE, then expands back to P^2 tokens. + Conditioning is per-pixel adaLN from the patch-level features. + """ + def __init__(self, pixel_hidden_size, patch_hidden_size, patch_size, num_heads, mlp_ratio=4.0, + attn_hidden_size=None, attn_num_heads=None, dtype=None, device=None, operations=None, mlp_chunks=1): + super().__init__() + self.pixel_dim = pixel_hidden_size + self.context_dim = patch_hidden_size + self.attn_dim = attn_hidden_size if attn_hidden_size is not None else patch_hidden_size + self.num_heads = attn_num_heads if attn_num_heads is not None else num_heads + assert self.attn_dim % self.num_heads == 0 + + p2 = patch_size * patch_size + self.compress_to_attn = operations.Linear(p2 * self.pixel_dim, self.attn_dim, bias=True, dtype=dtype, device=device) + self.expand_from_attn = operations.Linear(self.attn_dim, p2 * self.pixel_dim, bias=True, dtype=dtype, device=device) + + self.norm1 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device) + self.attn = RotaryAttention(self.attn_dim, num_heads=self.num_heads, qkv_bias=False, dtype=dtype, device=device, operations=operations) + self.norm2 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device) + self.mlp = Mlp(self.pixel_dim, hidden_features=int(self.pixel_dim * mlp_ratio), dtype=dtype, device=device, operations=operations) + + self.adaLN_modulation_msa = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device) + self.adaLN_modulation_mlp = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device) + + self._rope_fn = precompute_freqs_cis_2d + self.mlp_chunks = max(1, int(mlp_chunks)) + + def _fetch_pos(self, height, width, device, dtype, **rope_opts): + return self._rope_fn(self.attn_dim // self.num_heads, height, width, device=device, dtype=dtype, **rope_opts) + + def forward(self, x, s_cond, image_height, image_width, patch_size, mask=None, transformer_options={}): + BL, P2, _ = x.shape + Hs, Ws = image_height // patch_size, image_width // patch_size + L = Hs * Ws + B = BL // L + + # Attention path uses only msa params; compute, use, free before mlp params allocate. + msa_params = self.adaLN_modulation_msa(s_cond).view(BL, P2, 3 * self.pixel_dim) + shift_msa, scale_msa, gate_msa = msa_params.chunk(3, dim=-1) + + x_norm = apply_adaln_(self.norm1(x), shift_msa, scale_msa) + x_flat = x_norm.view(BL, P2 * self.pixel_dim) + + x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim) + pos_comp = self._fetch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {})) + attn_out = self.attn(x_comp, pos_comp, mask=mask, transformer_options=transformer_options) + attn_flat = self.expand_from_attn(attn_out.view(B * L, self.attn_dim)) + attn_exp = attn_flat.view(BL, P2, self.pixel_dim) + x = torch.addcmul(x, gate_msa, attn_exp) + del msa_params, shift_msa, scale_msa, gate_msa + + mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim) + shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1) + gate_mlp = gate_mlp.contiguous() # detach from mlp_params so the del below frees shift+scale storage before the MLP + mlp_input = apply_adaln_(self.norm2(x), shift_mlp, scale_mlp) + del mlp_params, shift_mlp, scale_mlp + + # MLP in chunks since the peak memory usage is huge here + chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks + for s in range(0, BL, chunk_size): + e = min(s + chunk_size, BL) + x[s:e].addcmul_(gate_mlp[s:e], self.mlp(mlp_input[s:e])) + return x diff --git a/comfy/ldm/pixeldit/pid.py b/comfy/ldm/pixeldit/pid.py new file mode 100644 index 000000000..0ad4b7ce8 --- /dev/null +++ b/comfy/ldm/pixeldit/pid.py @@ -0,0 +1,226 @@ +"""PiD — Pixel Diffusion Decoder. Decodes a Flux/SD3/Flux2/Z-Image latent +directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I +body + LQ projection branch injected before each MMDiT patch block. +""" + +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .model import PixDiT_T2I +from .modules import precompute_freqs_cis_2d + + +class SigmaAwareGatePerTokenPerDim(nn.Module): + """gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq. + + Trained init gives ~0.88 gate at sigma=0, ~0.05 at sigma=1. + """ + + def __init__(self, dim: int, dtype=None, device=None, operations=None): + super().__init__() + self.content_proj = operations.Linear(dim * 2, dim, dtype=dtype, device=device) + self.log_alpha = nn.Parameter(torch.empty((), dtype=dtype, device=device)) + + def forward(self, x: torch.Tensor, lq: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + content_logit = self.content_proj(torch.cat([x, lq], dim=-1)) + # log_alpha is a raw nn.Parameter -> doesn't auto-cast under dynamic VRAM. + log_alpha = self.log_alpha.to(device=x.device, dtype=torch.float32) + sigma_offset = -log_alpha.exp() * sigma.float().view(-1, 1, 1) + gate = torch.sigmoid(content_logit + sigma_offset) + return x + (gate * lq).to(x.dtype) + + +class ResBlock(nn.Module): + """Pre-activation ResNet block: GN -> SiLU -> Conv -> GN -> SiLU -> Conv + skip.""" + + def __init__(self, channels: int, num_groups: int = 4, dtype=None, device=None, operations=None): + super().__init__() + self.block = nn.Sequential( + operations.GroupNorm(num_groups, channels, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device), + operations.GroupNorm(num_groups, channels, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.block(x) + + +class LQProjection2D(nn.Module): + """LQ latent -> per-block patch-aligned features for controlnet-style injection.""" + + def __init__( + self, + latent_channels: int, + hidden_dim: int = 512, + out_dim: int = 1536, + patch_size: int = 16, + sr_scale: int = 4, + latent_spatial_down_factor: int = 8, + num_res_blocks: int = 4, + num_outputs: int = 7, + interval: int = 2, + dtype=None, device=None, operations=None, + ): + super().__init__() + self.latent_channels = latent_channels + self.hidden_dim = hidden_dim + self.out_dim = out_dim + self.patch_size = patch_size + self.sr_scale = sr_scale + self.latent_spatial_down_factor = latent_spatial_down_factor + self.num_outputs = num_outputs + self.interval = interval + + z_to_patch_ratio = (sr_scale * latent_spatial_down_factor) / patch_size + self.z_to_patch_ratio = z_to_patch_ratio + if z_to_patch_ratio >= 1: + self.latent_fold_factor = 0 + latent_proj_in_ch = latent_channels + else: + fold_factor = int(1 / z_to_patch_ratio) + assert fold_factor * z_to_patch_ratio == 1.0 + self.latent_fold_factor = fold_factor + latent_proj_in_ch = latent_channels * fold_factor * fold_factor + + layers = [ + operations.Conv2d(latent_proj_in_ch, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device), + ] + for _ in range(num_res_blocks): + layers.append(ResBlock(hidden_dim, dtype=dtype, device=device, operations=operations)) + self.latent_proj = nn.Sequential(*layers) + + self.output_heads = nn.ModuleList( + [operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) for _ in range(num_outputs)] + ) + self.gate_modules = nn.ModuleList( + [SigmaAwareGatePerTokenPerDim(out_dim, dtype=dtype, device=device, operations=operations) + for _ in range(num_outputs)] + ) + + def is_gate_active(self, block_idx: int) -> bool: + return block_idx % self.interval == 0 + + def output_index(self, block_idx: int) -> int: + return block_idx // self.interval + + def gate(self, x: torch.Tensor, lq_feature: torch.Tensor, sigma: torch.Tensor, out_idx: int) -> torch.Tensor: + return self.gate_modules[out_idx](x, lq_feature, sigma) + + def _align_latent_to_patch_grid(self, lq_latent: torch.Tensor, pH: int, pW: int) -> torch.Tensor: + B, z_dim = lq_latent.shape[:2] + if self.z_to_patch_ratio >= 1: + if lq_latent.shape[2] != pH or lq_latent.shape[3] != pW: + z_aligned = F.interpolate(lq_latent, size=(pH, pW), mode="nearest") + else: + z_aligned = lq_latent + else: + f = self.latent_fold_factor + zH_expected, zW_expected = pH * f, pW * f + if lq_latent.shape[2] != zH_expected or lq_latent.shape[3] != zW_expected: + lq_latent = F.interpolate(lq_latent, size=(zH_expected, zW_expected), mode="nearest") + z_aligned = lq_latent.reshape(B, z_dim, pH, f, pW, f).permute(0, 1, 3, 5, 2, 4) + z_aligned = z_aligned.reshape(B, z_dim * f * f, pH, pW) + return self.latent_proj(z_aligned) + + def forward(self, lq_latent: torch.Tensor, target_pH: int, target_pW: int) -> List[torch.Tensor]: + feat = self._align_latent_to_patch_grid(lq_latent, target_pH, target_pW) + B, C, H, W = feat.shape + tokens = feat.permute(0, 2, 3, 1).contiguous().view(B, H * W, C) + return [head(tokens) for head in self.output_heads] + + +class PidNet(PixDiT_T2I): + """PixDiT_T2I + LQ injection (one sigma-gated feature inserted before each patch block).""" + + def __init__( + self, + lq_latent_channels: int = 16, + lq_hidden_dim: int = 512, + lq_num_res_blocks: int = 4, + lq_interval: int = 2, + sr_scale: int = 4, + latent_spatial_down_factor: int = 8, + rope_ref_h: int = 1024, # NTK ref resolution in PIXEL units: 1024px / patch=16 -> grid_ref=64. + rope_ref_w: int = 1024, + image_model=None, + dtype=None, device=None, operations=None, + **pixdit_kwargs, + ): + super().__init__(dtype=dtype, device=device, operations=operations, **pixdit_kwargs) + + self.rope_ref_grid_h = rope_ref_h // self.patch_size + self.rope_ref_grid_w = rope_ref_w // self.patch_size + + # Parent's PiTBlocks were built with plain RoPE — swap in NTK-aware. + def _pit_rope_fn(head_dim, h, w, device=None, dtype=torch.float32, **rope_opts): + return precompute_freqs_cis_2d(head_dim, h, w, ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w, device=device, dtype=dtype, **rope_opts) + for blk in self.pixel_blocks: + blk._rope_fn = _pit_rope_fn + + num_lq_outputs = (self.patch_depth + lq_interval - 1) // lq_interval + self.lq_proj = LQProjection2D( + latent_channels=lq_latent_channels, + hidden_dim=lq_hidden_dim, + out_dim=self.hidden_size, + patch_size=self.patch_size, + sr_scale=sr_scale, + latent_spatial_down_factor=latent_spatial_down_factor, + num_res_blocks=lq_num_res_blocks, + num_outputs=num_lq_outputs, + interval=lq_interval, + dtype=dtype, + device=device, + operations=operations, + ) + + def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts): + return precompute_freqs_cis_2d( + self.hidden_size // self.num_groups, + height, width, + ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w, + device=device, dtype=dtype, **rope_opts, + ) + + def _pre_patch_block(self, s, i, pid_lq_features, pid_degrade_sigma, **kwargs): + if not self.lq_proj.is_gate_active(i): + return s + out_idx = self.lq_proj.output_index(i) + if out_idx >= len(pid_lq_features): + return s + return self.lq_proj.gate(s, pid_lq_features[out_idx], pid_degrade_sigma, out_idx) + + def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, lq_latent=None, degrade_sigma=None, **kwargs): + if lq_latent is None: + raise ValueError("PidNet requires lq_latent — attach via PiDConditioning") + expected_c = self.lq_proj.latent_channels + if lq_latent.shape[1] != expected_c: + raise ValueError( + f"Input latent has {lq_latent.shape[1]} channels, this model variant expects {expected_c}. " + f"Flux1/SD3 = 16 channels, Flux2 = 128 channels." + ) + B = x.shape[0] + Hs = x.shape[2] // self.patch_size + Ws = x.shape[3] // self.patch_size + + degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1) + if degrade_sigma.numel() == 1 and B > 1: + degrade_sigma = degrade_sigma.expand(B).contiguous() + + lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws) + + return super()._forward( + x, timesteps, + context=context, attention_mask=attention_mask, + transformer_options=transformer_options, + pid_lq_features=lq_features, + pid_degrade_sigma=degrade_sigma, + **kwargs, + ) diff --git a/comfy/model_base.py b/comfy/model_base.py index d4ab1499e..e55808633 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -49,6 +49,8 @@ import comfy.ldm.hunyuan3d.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model import comfy.ldm.chroma_radiance.model +import comfy.ldm.pixeldit.model +import comfy.ldm.pixeldit.pid import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model @@ -1397,6 +1399,36 @@ class ZImagePixelSpace(Lumina2): BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace) self.memory_usage_factor_conds = ("ref_latents",) + +class PixelDiTT2I(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, + unet_model=comfy.ldm.pixeldit.model.PixDiT_T2I) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out["attention_mask"] = comfy.conds.CONDRegular(attention_mask) + return out + + +class PiD(PixelDiTT2I): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + BaseModel.__init__(self, model_config, model_type, device=device, + unet_model=comfy.ldm.pixeldit.pid.PidNet) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + lq_latent = kwargs.get("lq_latent", None) + if lq_latent is not None: + out["lq_latent"] = comfy.conds.CONDRegular(lq_latent) + degrade_sigma = kwargs.get("degrade_sigma", None) + if degrade_sigma is not None: + out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma) + return out + + class WAN21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 2b0b98cd8..f0db7d388 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -463,6 +463,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["extra_per_block_abs_pos_emb_type"] = "learnable" return dit_config + # PiD (Pixel Diffusion Decoder). Must check BEFORE plain PixelDiT_T2I. + _lq_w_key = '{}lq_proj.latent_proj.0.weight'.format(key_prefix) + if _lq_w_key in state_dict_keys: + in_ch = int(state_dict[_lq_w_key].shape[1]) + _gate_prefix = '{}lq_proj.gate_modules.'.format(key_prefix) + num_gates = len({k[len(_gate_prefix):].split('.')[0] + for k in state_dict_keys if k.startswith(_gate_prefix)}) + dit_config = {"image_model": "pid", + "lq_latent_channels": in_ch, + "latent_spatial_down_factor": 16 if in_ch >= 64 else 8} + if num_gates > 0: + dit_config["lq_interval"] = (14 + num_gates - 1) // num_gates + return dit_config + + if '{}core.pixel_embedder.proj.weight'.format(key_prefix) in state_dict_keys: # PixelDiT T2I + return {"image_model": "pixeldit_t2i"} + if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 dit_config = {} dit_config["image_model"] = "lumina2" diff --git a/comfy/sd.py b/comfy/sd.py index beb782310..30b877b85 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -49,6 +49,7 @@ import comfy.text_encoders.lt import comfy.text_encoders.hunyuan_video import comfy.text_encoders.cosmos import comfy.text_encoders.lumina2 +import comfy.text_encoders.pixeldit import comfy.text_encoders.wan import comfy.text_encoders.hidream import comfy.text_encoders.ace @@ -1285,6 +1286,7 @@ class CLIPType(Enum): LONGCAT_IMAGE = 26 COGVIDEOX = 27 LENS = 28 + PIXELDIT = 29 @@ -1528,8 +1530,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.tokenizer = variant.tokenizer tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) elif te_model == TEModel.GEMMA_2_2B: - clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) - clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer + if clip_type == CLIPType.PIXELDIT: + clip_target.clip = comfy.text_encoders.pixeldit.pixeldit_te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer + else: + clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model == TEModel.GEMMA_3_4B: clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b") diff --git a/comfy/supported_models.py b/comfy/supported_models.py index e451892e9..4723caff5 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -30,6 +30,7 @@ import comfy.text_encoders.longcat_image import comfy.text_encoders.ernie import comfy.text_encoders.cogvideo import comfy.text_encoders.hidream_o1 +import comfy.text_encoders.pixeldit from . import supported_models_base from . import latent_formats @@ -1201,6 +1202,72 @@ class ZImagePixelSpace(ZImage): def get_model(self, state_dict, prefix="", device=None): return model_base.ZImagePixelSpace(self, device=device) +class PixelDiTT2I(supported_models_base.BASE): + unet_config = { + "image_model": "pixeldit_t2i", + } + + unet_extra_config = {} + + sampling_settings = { + "shift": 4.0, # 1024px stage 3 default; 2.0 for 512px + } + + latent_format = latent_formats.PixelDiTPixel + memory_usage_factor = 0.18 + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + return model_base.PixelDiTT2I(self, device=device) + + def process_unet_state_dict(self, state_dict): + # pixel_dim from pixel_embedder.proj.weight = (pixel_dim, in_channels); p2 derived per-weight from total // (6 * pixel_dim). + pixel_dim = next(v for k, v in state_dict.items() if k.endswith("pixel_embedder.proj.weight")).shape[0] + + out = {} + marker = ".adaLN_modulation.0." + for k, v in state_dict.items(): + if k.startswith("_repa_projector") or k.startswith("net_ema."): + continue + if k.startswith("core."): + k = k[len("core."):] + elif k.startswith("net."): + k = k[len("net."):] + if "pixel_blocks." in k and marker in k: + # Split into msa (chunks 0-2) and mlp (chunks 3-5) for the two-Linear PiTBlock to reduce peak VRAM + p2 = v.shape[0] // (6 * pixel_dim) + trail = v.shape[1:] # () for bias, (in_dim,) for weight + vv = v.view(p2, 6, pixel_dim, *trail) + base, suffix = k.split(marker) + out[f"{base}.adaLN_modulation_msa.{suffix}"] = vv[:, 0:3].reshape(3 * p2 * pixel_dim, *trail).contiguous() + out[f"{base}.adaLN_modulation_mlp.{suffix}"] = vv[:, 3:6].reshape(3 * p2 * pixel_dim, *trail).contiguous() + else: + out[k] = v + return out + + def clip_target(self, state_dict={}): + return supported_models_base.ClipTarget( + comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer, + comfy.text_encoders.pixeldit.PixelDiTGemma2TE, + ) + +class PiD(PixelDiTT2I): + unet_config = { + "image_model": "pid", + } + + sampling_settings = { + "shift": 1.5, # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0] + } + + memory_usage_factor = 0.07 + + def get_model(self, state_dict, prefix="", device=None): + return model_base.PiD(self, device=device) + class WAN21_T2V(supported_models_base.BASE): unet_config = { "image_model": "wan2.1", @@ -2111,6 +2178,8 @@ models = [ CosmosI2VPredict2, ZImagePixelSpace, ZImage, + PiD, + PixelDiTT2I, Lumina2, WAN22_T2V, WAN21_CausalAR_T2V, diff --git a/comfy/text_encoders/pixeldit.py b/comfy/text_encoders/pixeldit.py new file mode 100644 index 000000000..3539711e4 --- /dev/null +++ b/comfy/text_encoders/pixeldit.py @@ -0,0 +1,104 @@ +import torch + +from comfy import sd1_clip +from .lumina2 import Gemma2BTokenizer, LuminaModel +import comfy.text_encoders.llama + + +class PixelDiTGemma2_2BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + + super().__init__( + device=device, layer=layer, layer_idx=layer_idx, + textmodel_json_config={}, dtype=dtype, + special_tokens={"start": 2, "pad": 0}, + layer_norm_hidden_state=False, + model_class=comfy.text_encoders.llama.Gemma2_2B, + enable_attention_masks=attention_mask, + return_attention_masks=attention_mask, + model_options=model_options, + ) + + +_PIXELDIT_CHI_PROMPT = ( + 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions ' + "suitable for image generation. Evaluate the level of detail in the user prompt:\n" + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, " + "and spatial relationships to create vivid and concrete scenes.\n" + "- If the prompt is already detailed, refine and enhance the existing details slightly without " + "overcomplicating.\n" + "Here are examples of how to transform or refine prompts:\n" + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, " + "sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n" + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring " + "glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus " + "passing by towering glass skyscrapers.\n" + "Please generate only the enhanced description for the prompt below and avoid including any " + "additional commentary or evaluations:\n" + "User Prompt: " +) + +_PIXELDIT_MAX_LENGTH = 300 +_PIXELDIT_CHI_PROMPT_DETECT_PREFIX = 'Given a user prompt, generate an "Enhanced prompt"' + + +class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data=None): + if tokenizer_data is None: + tokenizer_data = {} + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, + name="gemma2_2b", tokenizer=Gemma2BTokenizer) + + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + if not text.strip(): + return super().tokenize_with_weights("", return_word_ids=return_word_ids, disable_weights=True, min_length=_PIXELDIT_MAX_LENGTH) + + chi_token_count = len(self.gemma2_2b.tokenizer(_PIXELDIT_CHI_PROMPT)["input_ids"]) + combined = text if text.startswith(_PIXELDIT_CHI_PROMPT_DETECT_PREFIX) else _PIXELDIT_CHI_PROMPT + text + max_length_all = chi_token_count + _PIXELDIT_MAX_LENGTH - 2 + out = super().tokenize_with_weights(combined, return_word_ids=return_word_ids, + disable_weights=True, min_length=max_length_all) + out["gemma2_2b"] = [out["gemma2_2b"][0][:max_length_all]] + return out + + def untokenize(self, token_weight_pair): + return self.gemma2_2b.untokenize(token_weight_pair) + + def state_dict(self): + return self.gemma2_2b.state_dict() + + +class PixelDiTGemma2TE(LuminaModel): + # PixelDiT's select_index: keep BOS + last 299 embeddings of the padded sequence. + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="gemma2_2b", + clip_model=PixelDiTGemma2_2BModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + result = super().encode_token_weights(token_weight_pairs) + cond, pooled = result[0], result[1] + extra = result[2] if len(result) > 2 else None + if cond.shape[1] > _PIXELDIT_MAX_LENGTH: + cond = torch.cat([cond[:, :1], cond[:, -(_PIXELDIT_MAX_LENGTH - 1):]], dim=1) + if extra is not None and "attention_mask" in extra: + am = extra["attention_mask"] + extra["attention_mask"] = torch.cat([am[..., :1], am[..., -(_PIXELDIT_MAX_LENGTH - 1):]], dim=-1) + if extra is not None: + return cond, pooled, extra + return cond, pooled + + +def pixeldit_te(dtype_llama=None, llama_quantization_metadata=None): + class PixelDiTTE_(PixelDiTGemma2TE): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["llama_quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + return PixelDiTTE_ diff --git a/comfy_extras/nodes_pid.py b/comfy_extras/nodes_pid.py new file mode 100644 index 000000000..811b9ae8e --- /dev/null +++ b/comfy_extras/nodes_pid.py @@ -0,0 +1,55 @@ +"""PiD (Pixel Diffusion Decoder) node""" + +import torch +from typing_extensions import override + +import node_helpers +import comfy.latent_formats +from comfy_api.latest import ComfyExtension, io + + +class PiDConditioning(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="PiDConditioning", + display_name="PiD Conditioning", + category="advanced/conditioning", + description=( + "Attaches a latent and a degrade_sigma scalar to a CONDITIONING for PiD decoding/upscaling" + ), + inputs=[ + io.Conditioning.Input("positive"), + io.Latent.Input("latent", tooltip="latent (from VAEEncode or a KSampler)."), + io.Combo.Input("latent_format", options=["flux", "sd3"], default="flux", + tooltip="Flux1 and Flux2 latents auto-detected from channel dim, sd3 has to be selected manually."), + io.Float.Input( + "degrade_sigma", default=0.0, min=0.0, max=1.0, step=0.01, + tooltip="0 = clean latent. Increase to denoise corrupted latent outputs.", + ), + ], + outputs=[io.Conditioning.Output()], + ) + + @classmethod + def execute(cls, positive, latent, latent_format: str, degrade_sigma: float) -> io.NodeOutput: + samples = latent["samples"] + if latent_format == "flux": + fmt_cls = comfy.latent_formats.Flux2 if samples.shape[1] == 128 else comfy.latent_formats.Flux + else: + fmt_cls = comfy.latent_formats.SD3 + lq_latent = fmt_cls().process_in(samples) + sigma_t = torch.tensor([float(degrade_sigma)], dtype=torch.float32) + return io.NodeOutput(node_helpers.conditioning_set_values( + positive, {"lq_latent": lq_latent, "degrade_sigma": sigma_t}, + )) + + +class PiDExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [PiDConditioning] + + +async def comfy_entrypoint() -> PiDExtension: + return PiDExtension() diff --git a/nodes.py b/nodes.py index 13d3864cd..87d81b5b7 100644 --- a/nodes.py +++ b/nodes.py @@ -969,7 +969,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -979,7 +979,7 @@ class CLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B\nlens: gpt-oss-20b" + DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B\nlens: gpt-oss-20b\n pixeldit: gemma 2 2B elm" def load_clip(self, clip_name, type="stable_diffusion", device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) @@ -2420,6 +2420,7 @@ async def init_builtin_extra_nodes(): "nodes_context_windows.py", "nodes_qwen.py", "nodes_chroma_radiance.py", + "nodes_pid.py", "nodes_model_patch.py", "nodes_easycache.py", "nodes_audio_encoder.py",