Support PixelDiT and PiD

This commit is contained in:
kijai 2026-05-25 16:24:54 +03:00
parent 0155ddcbe3
commit 7b72f322a5
14 changed files with 1077 additions and 6 deletions

View File

@ -792,13 +792,15 @@ class ZImagePixelSpace(ChromaRadiance):
""" """
pass pass
class HiDreamO1Pixel(ChromaRadiance): class HiDreamO1Pixel(ChromaRadiance):
"""Pixel-space latent format for HiDream-O1. """Pixel-space latent format for HiDream-O1.
No VAE model patches/unpatches raw RGB internally with patch_size=32. No VAE model patches/unpatches raw RGB internally with patch_size=32.
""" """
pass pass
class PixelDiTPixel(ChromaRadiance):
pass
class CogVideoX(LatentFormat): class CogVideoX(LatentFormat):
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b). """Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).

View File

@ -211,7 +211,7 @@ class TimestepEmbedder(nn.Module):
Embeds scalar timesteps into vector representations. Embeds scalar timesteps into vector representations.
""" """
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None): def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None, max_period=10000):
super().__init__() super().__init__()
if output_size is None: if output_size is None:
output_size = hidden_size output_size = hidden_size
@ -221,9 +221,10 @@ class TimestepEmbedder(nn.Module):
operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device), operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
) )
self.frequency_embedding_size = frequency_embedding_size self.frequency_embedding_size = frequency_embedding_size
self.max_period = max_period
def forward(self, t, dtype, **kwargs): def forward(self, t, dtype, **kwargs):
t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype) t_freq = timestep_embedding(t, self.frequency_embedding_size, max_period=self.max_period).to(dtype)
t_emb = self.mlp(t_freq) t_emb = self.mlp(t_freq)
return t_emb return t_emb

View File

270
comfy/ldm/pixeldit/model.py Normal file
View File

@ -0,0 +1,270 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.patcher_extension
from comfy.ldm.flux.math import apply_rope, rope
from comfy.ldm.hidream.model import FeedForwardSwiGLU
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from .modules import (
FinalLayer,
PatchTokenEmbedder,
PiTBlock,
PixelTokenEmbedder,
apply_adaln,
precompute_freqs_cis_2d,
)
class MMDiTJointAttention(nn.Module):
"""Joint MMDiT attention with separate Q/K/V/proj for image and text streams.
RoPE is applied to each stream before concatenation so each stream uses its own
2D/1D positional encoding. Concat order is [text, image] (text first).
"""
def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
super().__init__()
assert dim % num_heads == 0
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv_x = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.qkv_y = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.proj_x = operations.Linear(dim, dim, dtype=dtype, device=device)
self.proj_y = operations.Linear(dim, dim, dtype=dtype, device=device)
def forward(self, x, y, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
B, Nx, _ = x.shape
_, Ny, _ = y.shape
H = self.num_heads
D = self.head_dim
qkv_x = self.qkv_x(x).reshape(B, Nx, 3, H, D).permute(2, 0, 3, 1, 4)
qx, kx, vx = qkv_x.unbind(0)
qx = self.q_norm_x(qx)
kx = self.k_norm_x(kx)
qkv_y = self.qkv_y(y).reshape(B, Ny, 3, H, D).permute(2, 0, 3, 1, 4)
qy, ky, vy = qkv_y.unbind(0)
qy = self.q_norm_y(qy)
ky = self.k_norm_y(ky)
qx, kx = apply_rope(qx, kx, pos_img[None, None])
if pos_txt is not None:
qy, ky = apply_rope(qy, ky, pos_txt[None, None])
q_joint = torch.cat([qy, qx], dim=2)
k_joint = torch.cat([ky, kx], dim=2)
v_joint = torch.cat([vy, vx], dim=2)
out_joint = optimized_attention(
q_joint, k_joint, v_joint, H,
mask=attn_mask, skip_reshape=True, skip_output_reshape=True,
transformer_options=transformer_options,
)
out_y = out_joint[:, :, :Ny, :].transpose(1, 2).reshape(B, Ny, H * D)
out_x = out_joint[:, :, Ny:, :].transpose(1, 2).reshape(B, Nx, H * D)
return self.proj_x(out_x), self.proj_y(out_y)
class MMDiTBlockT2I(nn.Module):
def __init__(self, hidden_size, groups, mlp_ratio=4.0, dtype=None, device=None, operations=None):
super().__init__()
self.hidden_size = hidden_size
self.groups = groups
self.head_dim = hidden_size // groups
self.norm_x1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
self.norm_y1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
self.attn = MMDiTJointAttention(hidden_size, num_heads=groups, qkv_bias=False,
dtype=dtype, device=device, operations=operations)
self.norm_x2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
self.norm_y2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp_x = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1,
dtype=dtype, device=device, operations=operations)
self.mlp_y = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1,
dtype=dtype, device=device, operations=operations)
self.adaLN_modulation_img = nn.Sequential(
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device),
)
self.adaLN_modulation_txt = nn.Sequential(
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, x, y, c, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
shift_msa_x, scale_msa_x, gate_msa_x, shift_mlp_x, scale_mlp_x, gate_mlp_x = self.adaLN_modulation_img(c).chunk(6, dim=-1)
shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = self.adaLN_modulation_txt(c).chunk(6, dim=-1)
x_norm = apply_adaln(self.norm_x1(x), shift_msa_x, scale_msa_x)
y_norm = apply_adaln(self.norm_y1(y), shift_msa_y, scale_msa_y)
attn_x, attn_y = self.attn(x_norm, y_norm, pos_img, pos_txt, attn_mask, transformer_options=transformer_options)
x = torch.addcmul(x, gate_msa_x, attn_x)
y = torch.addcmul(y, gate_msa_y, attn_y)
x = torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln(self.norm_x2(x), shift_mlp_x, scale_mlp_x)))
y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln(self.norm_y2(y), shift_mlp_y, scale_mlp_y)))
return x, y
class PixDiT_T2I(nn.Module):
"""PixelDiT T2I model. Hardcoded for the released 1024px Stage-3 checkpoint
(also runs at 512px when fed the appropriate latent size and flow_shift).
Forward:
x: [B, 3, H, W] pixel-space input (no VAE)
timesteps:[B] in [0, 1000] (ComfyUI flow sampling convention)
context: [B, Ltxt, 2304] Gemma-2-2b-it hidden states (chi_prompt prepended)
Returns flow-matching velocity [B, 3, H, W].
"""
def __init__(
self,
in_channels=3,
num_groups=24,
hidden_size=1536,
pixel_hidden_size=16,
pixel_attn_hidden_size=1152,
pixel_num_groups=16,
patch_depth=14,
pixel_depth=2,
patch_size=16,
txt_embed_dim=2304,
txt_max_length=300,
use_text_rope=True,
text_rope_theta=10000.0,
use_pixel_abs_pos=True,
image_model=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dtype = dtype
self.in_channels = int(in_channels)
self.out_channels = int(in_channels)
self.hidden_size = int(hidden_size)
self.num_groups = int(num_groups)
self.patch_depth = int(patch_depth)
self.pixel_depth = int(pixel_depth)
self.patch_size = int(patch_size)
self.pixel_hidden_size = int(pixel_hidden_size)
self.pixel_attn_hidden_size = int(pixel_attn_hidden_size)
self.pixel_num_groups = int(pixel_num_groups)
self.txt_embed_dim = int(txt_embed_dim)
self.txt_max_length = int(txt_max_length)
self.use_text_rope = bool(use_text_rope)
self.text_rope_theta = float(text_rope_theta)
self.use_pixel_abs_pos = bool(use_pixel_abs_pos)
self.pixel_embedder = PixelTokenEmbedder(
self.in_channels, self.pixel_hidden_size, use_pixel_abs_pos=self.use_pixel_abs_pos,
dtype=dtype, device=device, operations=operations,
)
self.s_embedder = PatchTokenEmbedder(
self.in_channels * self.patch_size ** 2, self.hidden_size, bias=True,
dtype=dtype, device=device, operations=operations,
)
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations, max_period=10)
self.y_embedder = PatchTokenEmbedder(
self.txt_embed_dim, self.hidden_size, bias=True, norm_layer=True,
dtype=dtype, device=device, operations=operations,
)
self.y_pos_embedding = nn.Parameter(
torch.empty(1, self.txt_max_length, self.hidden_size, dtype=dtype, device=device)
)
self.patch_blocks = nn.ModuleList([
MMDiTBlockT2I(self.hidden_size, self.num_groups,
dtype=dtype, device=device, operations=operations)
for _ in range(self.patch_depth)
])
self.pixel_blocks = nn.ModuleList([
PiTBlock(
self.pixel_hidden_size,
self.hidden_size,
patch_size=self.patch_size,
num_heads=self.num_groups,
mlp_ratio=4.0,
attn_hidden_size=self.pixel_attn_hidden_size,
attn_num_heads=self.pixel_num_groups,
dtype=dtype, device=device, operations=operations,
)
for _ in range(self.pixel_depth)
])
self.final_layer = FinalLayer(self.pixel_hidden_size, self.out_channels,
dtype=dtype, device=device, operations=operations)
self._patch_pos_cache = {}
self._text_pos_cache = {}
def _fetch_patch_pos(self, height, width, device, dtype):
key = (height, width)
pos = self._patch_pos_cache.get(key)
if pos is None:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width)
self._patch_pos_cache[key] = pos
return pos.to(device=device, dtype=dtype)
def _fetch_text_pos(self, length, device, dtype):
pos = self._text_pos_cache.get(length)
if pos is None:
pos = rope(torch.arange(length, dtype=torch.float32).reshape(1, -1), self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0)
self._text_pos_cache[length] = pos
return pos.to(device=device, dtype=dtype)
def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
B, _, H, W = x.shape
Hs = H // self.patch_size
Ws = W // self.patch_size
L = Hs * Ws
pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype)
x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size)
if context is None or context.dim() != 3:
raise ValueError("PixDiT_T2I requires context (text embeddings) of shape [B, L, D]")
Ltxt = min(context.shape[1], self.txt_max_length)
y = context[:, :Ltxt, :]
y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size)
y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb.dtype)
condition = F.silu(t_emb)
pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None
s = self.s_embedder(x_patches)
for blk in self.patch_blocks:
s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options)
s = F.silu(t_emb + s)
s_cond = s.view(B * L, self.hidden_size)
x_pixels = self.pixel_embedder(x, img_height=H, img_width=W, patch_size=self.patch_size)
for blk in self.pixel_blocks:
x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None, transformer_options=transformer_options)
x_pixels = self.final_layer(x_pixels)
C_out = self.out_channels
P2 = self.patch_size * self.patch_size
x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).contiguous()
x_pixels = x_pixels.view(B, C_out * P2, L)
return F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size)

View File

@ -0,0 +1,199 @@
import torch
import torch.nn as nn
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp
def apply_adaln(x, shift, scale):
return torch.addcmul(x + shift, x, scale)
def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0, device=None, dtype=torch.float32):
"""2D RoPE with x/y axis frequencies interleaved at stride 2 across head dim.
Returns Flux-format rotation matrices of shape [H*W, dim/2, 2, 2].
Layout of head-dim pairs: [x_0, y_0, x_1, y_1, ..., x_{dim/4-1}, y_{dim/4-1}].
"""
x_pos = torch.linspace(0, scale, width, device=device)
y_pos = torch.linspace(0, scale, height, device=device)
y_grid, x_grid = torch.meshgrid(y_pos, x_pos, indexing="ij")
x_pos = x_grid.reshape(-1)
y_pos = y_grid.reshape(-1)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 4, device=device, dtype=torch.float32)[: (dim // 4)] / dim))
x_freqs = torch.outer(x_pos, freqs)
y_freqs = torch.outer(y_pos, freqs)
freqs_interleaved = torch.stack([x_freqs, y_freqs], dim=-1).reshape(height * width, -1)
cos = torch.cos(freqs_interleaved)
sin = torch.sin(freqs_interleaved)
out = torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*cos.shape, 2, 2)
return out.to(dtype=dtype)
def get_2d_sincos_pos_embed(embed_dim, height, width, device=None, dtype=torch.float32):
"""Torch port of MAE's 2D sin/cos absolute positional embedding for the pixel embedder.
first half encodes W-coordinates, second half H.
"""
assert embed_dim % 4 == 0
grid_h = torch.arange(height, dtype=torch.float32, device=device)
grid_w = torch.arange(width, dtype=torch.float32, device=device)
grid_y, grid_x = torch.meshgrid(grid_h, grid_w, indexing="ij")
grid_y = grid_y.reshape(-1)
grid_x = grid_x.reshape(-1)
omega = torch.arange(embed_dim // 4, dtype=torch.float32, device=device) / (embed_dim / 4.0)
omega = 1.0 / (10000.0 ** omega)
out_w = torch.outer(grid_x, omega)
out_h = torch.outer(grid_y, omega)
emb_w = torch.cat([torch.sin(out_w), torch.cos(out_w)], dim=1)
emb_h = torch.cat([torch.sin(out_h), torch.cos(out_h)], dim=1)
return torch.cat([emb_w, emb_h], dim=1).to(dtype=dtype)
class RotaryAttention(nn.Module):
"""Single-stream self-attention with rotary positional encoding (used inside PiTBlock)."""
def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
super().__init__()
assert dim % num_heads == 0
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
def forward(self, x, pos, mask=None, transformer_options={}):
B, N, C = x.shape
H = self.num_heads
D = self.head_dim
qkv = self.qkv(x).reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q = self.q_norm(q)
k = self.k_norm(k)
q, k = apply_rope(q, k, pos[None, None])
x = optimized_attention(q, k, v, H, mask=mask, skip_reshape=True, transformer_options=transformer_options)
return self.proj(x)
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
def forward(self, x):
return self.linear(self.norm(x))
class PatchTokenEmbedder(nn.Module):
"""Linear projection used both for patchified-image tokens and text-feature tokens."""
def __init__(self, in_chans, embed_dim, norm_layer=None, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = operations.Linear(in_chans, embed_dim, bias=bias, dtype=dtype, device=device)
if norm_layer is not None:
self.norm = operations.RMSNorm(embed_dim, eps=1e-6, dtype=dtype, device=device)
else:
self.norm = nn.Identity()
def forward(self, x):
return self.norm(self.proj(x))
class PixelTokenEmbedder(nn.Module):
"""Pixel-level embedder: lifts each RGB pixel to hidden_size and packs into per-patch sequences."""
def __init__(self, in_channels, hidden_size_output, use_pixel_abs_pos=True,
dtype=None, device=None, operations=None):
super().__init__()
self.in_channels = in_channels
self.hidden_size_output = hidden_size_output
self.use_pixel_abs_pos = bool(use_pixel_abs_pos)
self.proj = operations.Linear(self.in_channels, self.hidden_size_output, bias=True, dtype=dtype, device=device)
self._pos_cache = {}
def _fetch_pixel_pos(self, height, width, device, dtype):
key = (height, width)
pe = self._pos_cache.get(key)
if pe is None:
pe = get_2d_sincos_pos_embed(self.hidden_size_output, height, width)
self._pos_cache[key] = pe
return pe.to(device=device, dtype=dtype)
def forward(self, inputs, img_height, img_width, patch_size):
B, C, H, W = inputs.shape
assert H == img_height and W == img_width
assert (H % patch_size == 0) and (W % patch_size == 0)
Hs, Ws = H // patch_size, W // patch_size
P2 = patch_size * patch_size
x = inputs.permute(0, 2, 3, 1).contiguous()
x = self.proj(x)
if self.use_pixel_abs_pos:
pos_full = self._fetch_pixel_pos(H, W, x.device, x.dtype)
pos_full = pos_full.view(H, W, self.hidden_size_output)
x = x + pos_full.unsqueeze(0)
x = x.view(B, Hs, patch_size, Ws, patch_size, self.hidden_size_output)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
return x.view(B * Hs * Ws, P2, self.hidden_size_output)
class PiTBlock(nn.Module):
"""Pixel-level transformer block.
Compresses each patch's P^2 pixel tokens → 1 attention token via a linear,
runs global self-attention across patches with 2D RoPE, then expands back to P^2 tokens.
Conditioning is per-pixel adaLN from the patch-level features.
"""
def __init__(self, pixel_hidden_size, patch_hidden_size, patch_size, num_heads, mlp_ratio=4.0,
attn_hidden_size=None, attn_num_heads=None, rope_fn=None,
dtype=None, device=None, operations=None):
super().__init__()
self.pixel_dim = pixel_hidden_size
self.context_dim = patch_hidden_size
self.patch_size = patch_size
self.attn_dim = attn_hidden_size if attn_hidden_size is not None else patch_hidden_size
self.num_heads = attn_num_heads if attn_num_heads is not None else num_heads
assert self.attn_dim % self.num_heads == 0
p2 = patch_size * patch_size
self.compress_to_attn = operations.Linear(p2 * self.pixel_dim, self.attn_dim, bias=True, dtype=dtype, device=device)
self.expand_from_attn = operations.Linear(self.attn_dim, p2 * self.pixel_dim, bias=True, dtype=dtype, device=device)
self.norm1 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
self.attn = RotaryAttention(self.attn_dim, num_heads=self.num_heads, qkv_bias=False,
dtype=dtype, device=device, operations=operations)
self.norm2 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
self.mlp = Mlp(self.pixel_dim, hidden_features=int(self.pixel_dim * mlp_ratio),
dtype=dtype, device=device, operations=operations)
self.adaLN_modulation = nn.Sequential(
operations.Linear(self.context_dim, 6 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device),
)
self._pos_cache = {}
self._rope_fn = rope_fn if rope_fn is not None else precompute_freqs_cis_2d
def _fetch_pos(self, height, width, device, dtype):
key = (height, width)
pos = self._pos_cache.get(key)
if pos is None:
pos = self._rope_fn(self.attn_dim // self.num_heads, height, width)
self._pos_cache[key] = pos
return pos.to(device=device, dtype=dtype)
def forward(self, x, s_cond, image_height, image_width, patch_size, mask=None, transformer_options={}):
BL, P2, _ = x.shape
Hs, Ws = image_height // patch_size, image_width // patch_size
L = Hs * Ws
B = BL // L
cond_params = self.adaLN_modulation(s_cond).view(BL, P2, 6 * self.pixel_dim)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = cond_params.chunk(6, dim=-1)
x_norm = apply_adaln(self.norm1(x), shift_msa, scale_msa)
x_flat = x_norm.view(BL, P2 * self.pixel_dim)
x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim)
pos_comp = self._fetch_pos(Hs, Ws, x.device, x.dtype)
attn_out = self.attn(x_comp, pos_comp, mask=mask, transformer_options=transformer_options)
attn_flat = self.expand_from_attn(attn_out.view(B * L, self.attn_dim))
attn_exp = attn_flat.view(BL, P2, self.pixel_dim)
x = torch.addcmul(x, gate_msa, attn_exp)
mlp_out = self.mlp(apply_adaln(self.norm2(x), shift_mlp, scale_mlp))
x = torch.addcmul(x, gate_mlp, mlp_out)
return x

286
comfy/ldm/pixeldit/pid.py Normal file
View File

@ -0,0 +1,286 @@
"""PiD — Pixel Diffusion Decoder. Decodes a Flux/SD3/Flux2/Z-Image latent
directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I
body + LQ projection branch injected before each MMDiT patch block.
"""
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.flux.math import rope
from .model import PixDiT_T2I
def precompute_freqs_cis_2d_ntk(dim: int, height: int, width: int,
ref_grid_h: int, ref_grid_w: int,
theta: float = 10000.0, scale: float = 16.0,
device=None, dtype=torch.float32):
"""NTK-aware 2D RoPE (rope_mode='ntk_aware' in upstream PiD).
Per-axis theta = theta * (current/ref)^(dim_axis/(dim_axis-2)). Returns
[H*W, dim/2, 2, 2] with x/y axis freqs interleaved at stride 2 (matches
the head-dim layout PiD's Q/K weights expect).
"""
dim_axis = dim // 2
h_ntk = (height / ref_grid_h) ** (dim_axis / (dim_axis - 2)) if dim_axis > 2 else 1.0
w_ntk = (width / ref_grid_w) ** (dim_axis / (dim_axis - 2)) if dim_axis > 2 else 1.0
x_lin = torch.linspace(0, scale, width, device=device)
y_lin = torch.linspace(0, scale, height, device=device)
y_grid, x_grid = torch.meshgrid(y_lin, x_lin, indexing="ij")
x_rope = rope(x_grid.reshape(1, -1), dim_axis, theta * w_ntk).squeeze(0)
y_rope = rope(y_grid.reshape(1, -1), dim_axis, theta * h_ntk).squeeze(0)
out = torch.stack([x_rope, y_rope], dim=2).reshape(height * width, dim // 2, 2, 2)
return out.to(dtype=dtype)
class SigmaAwareGatePerTokenPerDim(nn.Module):
"""gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq.
Trained init gives ~0.88 gate at sigma=0, ~0.05 at sigma=1.
"""
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.content_proj = operations.Linear(dim * 2, dim, dtype=dtype, device=device)
self.log_alpha = nn.Parameter(torch.empty((), dtype=dtype, device=device))
def forward(self, x: torch.Tensor, lq: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
content_logit = self.content_proj(torch.cat([x, lq], dim=-1))
# log_alpha is a raw nn.Parameter -> doesn't auto-cast under dynamic VRAM.
log_alpha = self.log_alpha.to(device=x.device, dtype=torch.float32)
sigma_offset = -log_alpha.exp() * sigma.float().view(-1, 1, 1)
gate = torch.sigmoid(content_logit + sigma_offset)
return x + (gate * lq).to(x.dtype)
class ResBlock(nn.Module):
"""Pre-activation ResNet block: GN -> SiLU -> Conv -> GN -> SiLU -> Conv + skip."""
def __init__(self, channels: int, num_groups: int = 4, dtype=None, device=None, operations=None):
super().__init__()
self.block = nn.Sequential(
operations.GroupNorm(num_groups, channels, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device),
operations.GroupNorm(num_groups, channels, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.block(x)
class LQProjection2D(nn.Module):
"""LQ latent -> per-block patch-aligned features for controlnet-style injection."""
def __init__(
self,
latent_channels: int,
hidden_dim: int = 512,
out_dim: int = 1536,
patch_size: int = 16,
sr_scale: int = 4,
latent_spatial_down_factor: int = 8,
num_res_blocks: int = 4,
num_outputs: int = 14,
interval: int = 1,
dtype=None,
device=None,
operations=None,
):
super().__init__()
assert latent_channels > 0
self.latent_channels = latent_channels
self.hidden_dim = hidden_dim
self.out_dim = out_dim
self.patch_size = patch_size
self.sr_scale = sr_scale
self.latent_spatial_down_factor = latent_spatial_down_factor
self.num_outputs = num_outputs
self.interval = interval
z_to_patch_ratio = (sr_scale * latent_spatial_down_factor) / patch_size
self.z_to_patch_ratio = z_to_patch_ratio
if z_to_patch_ratio >= 1:
self.latent_upsample_ratio = int(z_to_patch_ratio) if z_to_patch_ratio > 1 else 1
self.latent_fold_factor = 0
latent_proj_in_ch = latent_channels
else:
fold_factor = int(1 / z_to_patch_ratio)
assert fold_factor * z_to_patch_ratio == 1.0
self.latent_upsample_ratio = 0
self.latent_fold_factor = fold_factor
latent_proj_in_ch = latent_channels * fold_factor * fold_factor
layers = [
operations.Conv2d(latent_proj_in_ch, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
nn.SiLU(),
operations.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
]
for _ in range(num_res_blocks):
layers.append(ResBlock(hidden_dim, dtype=dtype, device=device, operations=operations))
self.latent_proj = nn.Sequential(*layers)
self.output_heads = nn.ModuleList(
[operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) for _ in range(num_outputs)]
)
self.gate_modules = nn.ModuleList(
[SigmaAwareGatePerTokenPerDim(out_dim, dtype=dtype, device=device, operations=operations)
for _ in range(num_outputs)]
)
def is_gate_active(self, block_idx: int) -> bool:
return block_idx % self.interval == 0 if self.interval > 1 else True
def output_index(self, block_idx: int) -> int:
return block_idx // self.interval if self.interval > 1 else block_idx
def gate(self, x: torch.Tensor, lq_feature: torch.Tensor, sigma: torch.Tensor, out_idx: int) -> torch.Tensor:
return self.gate_modules[out_idx](x, lq_feature, sigma)
def _align_latent_to_patch_grid(self, lq_latent: torch.Tensor, pH: int, pW: int) -> torch.Tensor:
B, z_dim = lq_latent.shape[:2]
if self.z_to_patch_ratio >= 1:
if lq_latent.shape[2] != pH or lq_latent.shape[3] != pW:
z_aligned = F.interpolate(lq_latent, size=(pH, pW), mode="nearest")
else:
z_aligned = lq_latent
else:
f = self.latent_fold_factor
zH_expected, zW_expected = pH * f, pW * f
if lq_latent.shape[2] != zH_expected or lq_latent.shape[3] != zW_expected:
lq_latent = F.interpolate(lq_latent, size=(zH_expected, zW_expected), mode="nearest")
z_aligned = lq_latent.reshape(B, z_dim, pH, f, pW, f).permute(0, 1, 3, 5, 2, 4)
z_aligned = z_aligned.reshape(B, z_dim * f * f, pH, pW)
return self.latent_proj(z_aligned)
def forward(self, lq_latent: torch.Tensor, target_pH: int, target_pW: int) -> List[torch.Tensor]:
feat = self._align_latent_to_patch_grid(lq_latent, target_pH, target_pW)
tokens = feat.flatten(2).transpose(1, 2)
return [head(tokens) for head in self.output_heads]
class PidNet(PixDiT_T2I):
"""PixDiT_T2I + LQ injection (one sigma-gated feature inserted before each patch block)."""
def __init__(
self,
lq_latent_channels: int = 16,
lq_hidden_dim: int = 512,
lq_num_res_blocks: int = 4,
lq_interval: int = 1,
sr_scale: int = 4,
latent_spatial_down_factor: int = 8,
rope_ref_h: int = 1024, # NTK ref resolution in PIXEL units: 1024px / patch=16 -> grid_ref=64.
rope_ref_w: int = 1024,
image_model=None,
dtype=None, device=None, operations=None,
**pixdit_kwargs,
):
super().__init__(dtype=dtype, device=device, operations=operations, **pixdit_kwargs)
self.rope_ref_grid_h = int(rope_ref_h) // int(self.patch_size)
self.rope_ref_grid_w = int(rope_ref_w) // int(self.patch_size)
# Parent's PiTBlocks were built with plain RoPE — swap in NTK-aware.
def _pit_rope_fn(head_dim, h, w):
return precompute_freqs_cis_2d_ntk(head_dim, h, w, self.rope_ref_grid_h, self.rope_ref_grid_w)
for blk in self.pixel_blocks:
blk._rope_fn = _pit_rope_fn
blk._pos_cache = {}
num_lq_outputs = (self.patch_depth + lq_interval - 1) // lq_interval
self.lq_proj = LQProjection2D(
latent_channels=lq_latent_channels,
hidden_dim=lq_hidden_dim,
out_dim=self.hidden_size,
patch_size=self.patch_size,
sr_scale=sr_scale,
latent_spatial_down_factor=latent_spatial_down_factor,
num_res_blocks=lq_num_res_blocks,
num_outputs=num_lq_outputs,
interval=lq_interval,
dtype=dtype,
device=device,
operations=operations,
)
def _fetch_patch_pos(self, height, width, device, dtype):
key = (height, width)
pos = self._patch_pos_cache.get(key)
if pos is None:
pos = precompute_freqs_cis_2d_ntk(
self.hidden_size // self.num_groups,
height, width,
self.rope_ref_grid_h, self.rope_ref_grid_w,
)
self._patch_pos_cache[key] = pos
return pos.to(device=device, dtype=dtype)
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={},
lq_latent=None, degrade_sigma=None, **kwargs):
B, _, H, W = x.shape
Hs = H // self.patch_size
Ws = W // self.patch_size
L = Hs * Ws
if context is None or context.dim() != 3:
raise ValueError("PidNet requires context [B, L, D]")
if lq_latent is None:
raise ValueError("PidNet requires lq_latent — attach via PiDConditioning")
if degrade_sigma is None:
degrade_sigma = torch.zeros(B, device=x.device, dtype=torch.float32)
elif not isinstance(degrade_sigma, torch.Tensor):
degrade_sigma = torch.tensor([float(degrade_sigma)] * B, device=x.device, dtype=torch.float32)
else:
degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
if degrade_sigma.numel() == 1 and B > 1:
degrade_sigma = degrade_sigma.expand(B).contiguous()
lq_latent = lq_latent.to(device=x.device, dtype=x.dtype)
lq_features = self.lq_proj(lq_latent=lq_latent, target_pH=Hs, target_pW=Ws)
pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype)
x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t_emb = self.t_embedder(timesteps.view(-1)).view(B, -1, self.hidden_size)
Ltxt = min(context.shape[1], self.txt_max_length)
y = context[:, :Ltxt, :]
y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size)
# y_pos_embedding is raw nn.Parameter -> doesn't auto-cast under dynamic VRAM.
y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(device=y_emb.device, dtype=y_emb.dtype)
condition = F.silu(t_emb)
pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None
s = self.s_embedder(x_patches)
for i, blk in enumerate(self.patch_blocks):
if self.lq_proj.is_gate_active(i):
out_idx = self.lq_proj.output_index(i)
if out_idx < len(lq_features):
s = self.lq_proj.gate(s, lq_features[out_idx], degrade_sigma, out_idx)
s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None,
transformer_options=transformer_options)
s = F.silu(t_emb + s)
s_cond = s.view(B * L, self.hidden_size)
x_pixels = self.pixel_embedder(x, img_height=H, img_width=W, patch_size=self.patch_size)
for blk in self.pixel_blocks:
x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None,
transformer_options=transformer_options)
x_pixels = self.final_layer(x_pixels)
C_out = self.out_channels
P2 = self.patch_size * self.patch_size
x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).contiguous()
x_pixels = x_pixels.view(B, C_out * P2, L)
return F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size)

View File

@ -48,6 +48,8 @@ import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model import comfy.ldm.hidream.model
import comfy.ldm.chroma.model import comfy.ldm.chroma.model
import comfy.ldm.chroma_radiance.model import comfy.ldm.chroma_radiance.model
import comfy.ldm.pixeldit.model
import comfy.ldm.pixeldit.pid
import comfy.ldm.ace.model import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2 import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model import comfy.ldm.qwen_image.model
@ -1296,6 +1298,41 @@ class ZImagePixelSpace(Lumina2):
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace) BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
self.memory_usage_factor_conds = ("ref_latents",) self.memory_usage_factor_conds = ("ref_latents",)
class PixelDiTT2I(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device,
unet_model=comfy.ldm.pixeldit.model.PixDiT_T2I)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out["attention_mask"] = comfy.conds.CONDRegular(attention_mask)
return out
class PiD(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device,
unet_model=comfy.ldm.pixeldit.pid.PidNet)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out["attention_mask"] = comfy.conds.CONDRegular(attention_mask)
lq_latent = kwargs.get("lq_latent", None)
if lq_latent is not None:
out["lq_latent"] = comfy.conds.CONDRegular(lq_latent)
degrade_sigma = kwargs.get("degrade_sigma", None)
if degrade_sigma is not None:
if not isinstance(degrade_sigma, torch.Tensor):
degrade_sigma = torch.tensor([float(degrade_sigma)], dtype=torch.float32)
out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma)
return out
class WAN21(BaseModel): class WAN21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)

View File

@ -424,6 +424,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable" dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
return dit_config return dit_config
# PiD (Pixel Diffusion Decoder). Must check BEFORE plain PixelDiT_T2I.
_lq_w_key = '{}lq_proj.latent_proj.0.weight'.format(key_prefix)
if _lq_w_key in state_dict_keys:
in_ch = int(state_dict[_lq_w_key].shape[1])
_gate_prefix = '{}lq_proj.gate_modules.'.format(key_prefix)
num_gates = len({k[len(_gate_prefix):].split('.')[0]
for k in state_dict_keys if k.startswith(_gate_prefix)})
dit_config = {"image_model": "pid",
"lq_latent_channels": in_ch,
"latent_spatial_down_factor": 16 if in_ch >= 64 else 8}
if num_gates > 0:
dit_config["lq_interval"] = (14 + num_gates - 1) // num_gates
return dit_config
if '{}core.pixel_embedder.proj.weight'.format(key_prefix) in state_dict_keys: # PixelDiT T2I
return {"image_model": "pixeldit_t2i"}
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
dit_config = {} dit_config = {}
dit_config["image_model"] = "lumina2" dit_config["image_model"] = "lumina2"

View File

@ -49,6 +49,7 @@ import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2 import comfy.text_encoders.lumina2
import comfy.text_encoders.pixeldit
import comfy.text_encoders.wan import comfy.text_encoders.wan
import comfy.text_encoders.hidream import comfy.text_encoders.hidream
import comfy.text_encoders.ace import comfy.text_encoders.ace
@ -1228,6 +1229,7 @@ class CLIPType(Enum):
FLUX2 = 25 FLUX2 = 25
LONGCAT_IMAGE = 26 LONGCAT_IMAGE = 26
COGVIDEOX = 27 COGVIDEOX = 27
PIXELDIT = 28
@ -1460,8 +1462,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.tokenizer = variant.tokenizer clip_target.tokenizer = variant.tokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_2_2B: elif te_model == TEModel.GEMMA_2_2B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) if clip_type == CLIPType.PIXELDIT:
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer clip_target.clip = comfy.text_encoders.pixeldit.PixelDiTGemma2TE
clip_target.tokenizer = comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer
else:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_4B: elif te_model == TEModel.GEMMA_3_4B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b") clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")

View File

@ -29,6 +29,7 @@ import comfy.text_encoders.longcat_image
import comfy.text_encoders.ernie import comfy.text_encoders.ernie
import comfy.text_encoders.cogvideo import comfy.text_encoders.cogvideo
import comfy.text_encoders.hidream_o1 import comfy.text_encoders.hidream_o1
import comfy.text_encoders.pixeldit
from . import supported_models_base from . import supported_models_base
from . import latent_formats from . import latent_formats
@ -1135,6 +1136,71 @@ class ZImagePixelSpace(ZImage):
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
return model_base.ZImagePixelSpace(self, device=device) return model_base.ZImagePixelSpace(self, device=device)
class PixelDiTT2I(supported_models_base.BASE):
unet_config = {
"image_model": "pixeldit_t2i",
}
unet_extra_config = {}
sampling_settings = {
"shift": 4.0, # 1024px stage 3 default; 2.0 for 512px
"multiplier": 1000,
}
latent_format = latent_formats.PixelDiTPixel
memory_usage_factor = 0.7
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
return model_base.PixelDiTT2I(self, device=device)
def process_unet_state_dict(self, state_dict):
out = {}
for k, v in state_dict.items():
if k.startswith("_repa_projector"):
continue
if k.startswith("core."):
out[k[len("core."):]] = v
else:
out[k] = v
return out
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(
comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer,
comfy.text_encoders.pixeldit.PixelDiTGemma2TE,
)
class PiD(PixelDiTT2I):
unet_config = {
"image_model": "pid",
}
sampling_settings = {
"shift": 1.5, # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0]
"multiplier": 1000,
}
def get_model(self, state_dict, prefix="", device=None):
return model_base.PiD(self, device=device)
def process_unet_state_dict(self, state_dict):
out = {}
for k, v in state_dict.items():
if k.startswith("_repa_projector") or k.startswith("net_ema."):
continue
if k.startswith("core."):
out[k[len("core."):]] = v
elif k.startswith("net."):
out[k[len("net."):]] = v
else:
out[k] = v
return out
class WAN21_T2V(supported_models_base.BASE): class WAN21_T2V(supported_models_base.BASE):
unet_config = { unet_config = {
"image_model": "wan2.1", "image_model": "wan2.1",
@ -2044,6 +2110,8 @@ models = [
CosmosI2VPredict2, CosmosI2VPredict2,
ZImagePixelSpace, ZImagePixelSpace,
ZImage, ZImage,
PiD,
PixelDiTT2I,
Lumina2, Lumina2,
WAN22_T2V, WAN22_T2V,
WAN21_CausalAR_T2V, WAN21_CausalAR_T2V,

View File

@ -0,0 +1,124 @@
import torch
from comfy import sd1_clip
from .lumina2 import Gemma2BTokenizer, LuminaModel
import comfy.text_encoders.llama
class PixelDiTGemma2_2BModel(sd1_clip.SDClipModel):
"""Gemma-2-2b-it text encoder for PixelDiT.
Uses the FINAL hidden state (layer='last')
"""
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(
device=device, layer=layer, layer_idx=layer_idx,
textmodel_json_config={}, dtype=dtype,
special_tokens={"start": 2, "pad": 0},
layer_norm_hidden_state=False,
model_class=comfy.text_encoders.llama.Gemma2_2B,
enable_attention_masks=attention_mask,
return_attention_masks=attention_mask,
model_options=model_options,
)
_PIXELDIT_CHI_PROMPT = (
'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions '
"suitable for image generation. Evaluate the level of detail in the user prompt:\n"
"- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, "
"and spatial relationships to create vivid and concrete scenes.\n"
"- If the prompt is already detailed, refine and enhance the existing details slightly without "
"overcomplicating.\n"
"Here are examples of how to transform or refine prompts:\n"
"- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, "
"sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n"
"- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring "
"glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus "
"passing by towering glass skyscrapers.\n"
"Please generate only the enhanced description for the prompt below and avoid including any "
"additional commentary or evaluations:\n"
"User Prompt: "
)
_PIXELDIT_MAX_LENGTH = 300
_PIXELDIT_CHI_PROMPT_DETECT_PREFIX = 'Given a user prompt, generate an "Enhanced prompt"'
def _build_padded_tokens(combined_text: str, spiece_tokenizer, pad_id: int, chi_token_count: int):
# Right-pad to chi_token_count + 300 - 2 (matches upstream's max_length_all).
max_length_all = chi_token_count + _PIXELDIT_MAX_LENGTH - 2
ids = spiece_tokenizer(combined_text)["input_ids"]
if len(ids) > max_length_all:
ids = ids[:max_length_all]
elif len(ids) < max_length_all:
ids = ids + [pad_id] * (max_length_all - len(ids))
return ids
class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer):
"""Gemma-2-2b-it tokenizer that prepends PixelDiT's chi_prompt.
Empty text -> BOS + pad to 300. Text already starting with the chi_prompt
preamble is tokenized verbatim (override mirrors QwenImageTokenizer's
`<|im_start|>` detection). Else chi_prompt is prepended.
"""
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
name="gemma2_2b", tokenizer=Gemma2BTokenizer)
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
spiece_tokenizer = self.gemma2_2b.tokenizer
pad_id = self.gemma2_2b.pad_token
if not (isinstance(text, str) and text.strip()):
ids = spiece_tokenizer("")["input_ids"]
ids = ids + [pad_id] * (_PIXELDIT_MAX_LENGTH - len(ids))
return {"gemma2_2b": [[(t, 1.0) for t in ids]]}
chi_token_count = len(spiece_tokenizer(_PIXELDIT_CHI_PROMPT)["input_ids"])
combined = text if text.startswith(_PIXELDIT_CHI_PROMPT_DETECT_PREFIX) else _PIXELDIT_CHI_PROMPT + text
ids = _build_padded_tokens(combined, spiece_tokenizer, pad_id, chi_token_count)
return {"gemma2_2b": [[(t, 1.0) for t in ids]]}
def untokenize(self, token_weight_pair):
return self.gemma2_2b.untokenize(token_weight_pair)
def state_dict(self):
return self.gemma2_2b.state_dict()
class PixelDiTGemma2TE(LuminaModel):
"""Text encoder wrapper for PixelDiT.
Overrides `encode_token_weights` to perform PixelDiT's `select_index` step:
encode the full padded sequence (up to ~chi_prompt_tokens + 298), then
return `[BOS_emb] + last_299_embs` as the 300-position conditioning that
matches the diffusion model's learned `y_pos_embedding` positions.
"""
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="gemma2_2b",
clip_model=PixelDiTGemma2_2BModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
result = super().encode_token_weights(token_weight_pairs)
cond, pooled = result[0], result[1]
extra = result[2] if len(result) > 2 else None
L = cond.shape[1]
if L > _PIXELDIT_MAX_LENGTH:
head = cond[:, :1]
tail = cond[:, -(_PIXELDIT_MAX_LENGTH - 1):]
cond = torch.cat([head, tail], dim=1)
if extra is not None and "attention_mask" in extra:
am = extra["attention_mask"]
if am.dim() == 1:
am = am.unsqueeze(0)
if am.shape[-1] == L:
head_m = am[..., :1]
tail_m = am[..., -(_PIXELDIT_MAX_LENGTH - 1):]
extra = {**extra, "attention_mask": torch.cat([head_m, tail_m], dim=-1)}
if extra is not None:
return cond, pooled, extra
return cond, pooled

Binary file not shown.

60
comfy_extras/nodes_pid.py Normal file
View File

@ -0,0 +1,60 @@
"""PiD (Pixel Diffusion Decoder) node"""
from typing_extensions import override
import node_helpers
import comfy.latent_formats
from comfy_api.latest import ComfyExtension, io
_LATENT_FORMAT_CLASSES = {
"flux": comfy.latent_formats.Flux,
"sd3": comfy.latent_formats.SD3,
"flux2": comfy.latent_formats.Flux2,
}
class PiDConditioning(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="PiDConditioning",
display_name="PiD Conditioning",
category="advanced/conditioning",
description=(
"Attaches an LDM latent (Flux/SD3/Flux2/Z-Image) and a degrade_sigma scalar "
"to a CONDITIONING for PiD decoding. Latent is renormalized into PiD space "
"via the chosen latent_format. Z-Image uses 'flux'."
),
inputs=[
io.Conditioning.Input("positive"),
io.Latent.Input("latent", tooltip="LDM latent (from VAEEncode or a KSampler)."),
io.Combo.Input(
"latent_format",
options=list(_LATENT_FORMAT_CLASSES.keys()),
default="flux",
),
io.Float.Input(
"degrade_sigma", default=0.0, min=0.0, max=1.0, step=0.01,
tooltip="0 = clean latent. Increase to denoise corrupted LDM outputs.",
),
],
outputs=[io.Conditioning.Output()],
)
@classmethod
def execute(cls, positive, latent, latent_format: str, degrade_sigma: float) -> io.NodeOutput:
lq_latent = _LATENT_FORMAT_CLASSES[latent_format]().process_in(latent["samples"])
return io.NodeOutput(node_helpers.conditioning_set_values(
positive, {"lq_latent": lq_latent, "degrade_sigma": float(degrade_sigma)},
))
class PiDExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [PiDConditioning]
async def comfy_entrypoint() -> PiDExtension:
return PiDExtension()

View File

@ -958,7 +958,7 @@ class CLIPLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox"], ), "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "pixeldit"], ),
}, },
"optional": { "optional": {
"device": (["default", "cpu"], {"advanced": True}), "device": (["default", "cpu"], {"advanced": True}),
@ -2405,6 +2405,7 @@ async def init_builtin_extra_nodes():
"nodes_context_windows.py", "nodes_context_windows.py",
"nodes_qwen.py", "nodes_qwen.py",
"nodes_chroma_radiance.py", "nodes_chroma_radiance.py",
"nodes_pid.py",
"nodes_model_patch.py", "nodes_model_patch.py",
"nodes_easycache.py", "nodes_easycache.py",
"nodes_audio_encoder.py", "nodes_audio_encoder.py",