mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 09:27:24 +08:00
Support PixelDiT and PiD
This commit is contained in:
parent
0155ddcbe3
commit
7b72f322a5
@ -792,13 +792,15 @@ class ZImagePixelSpace(ChromaRadiance):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class HiDreamO1Pixel(ChromaRadiance):
|
class HiDreamO1Pixel(ChromaRadiance):
|
||||||
"""Pixel-space latent format for HiDream-O1.
|
"""Pixel-space latent format for HiDream-O1.
|
||||||
No VAE — model patches/unpatches raw RGB internally with patch_size=32.
|
No VAE — model patches/unpatches raw RGB internally with patch_size=32.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class PixelDiTPixel(ChromaRadiance):
    """Pixel-space latent format for PixelDiT.

    Adds nothing of its own: every setting is inherited from ChromaRadiance.
    """
|
||||||
|
|
||||||
class CogVideoX(LatentFormat):
|
class CogVideoX(LatentFormat):
|
||||||
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).
|
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).
|
||||||
|
|
||||||
|
|||||||
@ -211,7 +211,7 @@ class TimestepEmbedder(nn.Module):
|
|||||||
Embeds scalar timesteps into vector representations.
|
Embeds scalar timesteps into vector representations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
|
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None, max_period=10000):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if output_size is None:
|
if output_size is None:
|
||||||
output_size = hidden_size
|
output_size = hidden_size
|
||||||
@ -221,9 +221,10 @@ class TimestepEmbedder(nn.Module):
|
|||||||
operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
|
operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
|
||||||
)
|
)
|
||||||
self.frequency_embedding_size = frequency_embedding_size
|
self.frequency_embedding_size = frequency_embedding_size
|
||||||
|
self.max_period = max_period
|
||||||
|
|
||||||
def forward(self, t, dtype, **kwargs):
|
def forward(self, t, dtype, **kwargs):
|
||||||
t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
t_freq = timestep_embedding(t, self.frequency_embedding_size, max_period=self.max_period).to(dtype)
|
||||||
t_emb = self.mlp(t_freq)
|
t_emb = self.mlp(t_freq)
|
||||||
return t_emb
|
return t_emb
|
||||||
|
|
||||||
|
|||||||
0
comfy/ldm/pixeldit/__init__.py
Normal file
0
comfy/ldm/pixeldit/__init__.py
Normal file
270
comfy/ldm/pixeldit/model.py
Normal file
270
comfy/ldm/pixeldit/model.py
Normal file
@ -0,0 +1,270 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import comfy.patcher_extension
|
||||||
|
from comfy.ldm.flux.math import apply_rope, rope
|
||||||
|
from comfy.ldm.hidream.model import FeedForwardSwiGLU
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
|
||||||
|
|
||||||
|
from .modules import (
|
||||||
|
FinalLayer,
|
||||||
|
PatchTokenEmbedder,
|
||||||
|
PiTBlock,
|
||||||
|
PixelTokenEmbedder,
|
||||||
|
apply_adaln,
|
||||||
|
precompute_freqs_cis_2d,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MMDiTJointAttention(nn.Module):
    """Two-stream (image + text) joint attention used by the MMDiT patch blocks.

    Each stream owns its fused QKV projection, its per-head RMSNorm on Q and K,
    and its output projection. Rotary embeddings are applied to each stream
    *before* the streams are concatenated, so image tokens are rotated with
    ``pos_img`` and text tokens with ``pos_txt`` (text is left unrotated when
    ``pos_txt`` is None). The joint sequence is ordered [text, image].
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
        super().__init__()
        assert dim % num_heads == 0
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        factory = {"dtype": dtype, "device": device}

        # Fused Q/K/V projections, one per stream (x = image, y = text).
        self.qkv_x = operations.Linear(dim, dim * 3, bias=qkv_bias, **factory)
        self.qkv_y = operations.Linear(dim, dim * 3, bias=qkv_bias, **factory)

        # Q/K normalisation is done per head, hence head_dim rather than dim.
        self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, **factory)
        self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, **factory)
        self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, **factory)
        self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, **factory)

        self.proj_x = operations.Linear(dim, dim, **factory)
        self.proj_y = operations.Linear(dim, dim, **factory)

    def _split_qkv(self, fused, batch, tokens):
        """Turn a fused [batch, tokens, 3*heads*head_dim] tensor into (q, k, v),
        each laid out as [batch, heads, tokens, head_dim]."""
        fused = fused.reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        return fused.permute(2, 0, 3, 1, 4).unbind(0)

    def forward(self, x, y, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
        """Attend jointly over text tokens ``y`` and image tokens ``x``.

        Returns ``(image_out, text_out)``, each projected by its own output layer.
        """
        batch, n_img, _ = x.shape
        _, n_txt, _ = y.shape

        q_img, k_img, v_img = self._split_qkv(self.qkv_x(x), batch, n_img)
        q_img, k_img = self.q_norm_x(q_img), self.k_norm_x(k_img)

        q_txt, k_txt, v_txt = self._split_qkv(self.qkv_y(y), batch, n_txt)
        q_txt, k_txt = self.q_norm_y(q_txt), self.k_norm_y(k_txt)

        # Rotate each stream with its own positional table before joining them.
        q_img, k_img = apply_rope(q_img, k_img, pos_img[None, None])
        if pos_txt is not None:
            q_txt, k_txt = apply_rope(q_txt, k_txt, pos_txt[None, None])

        joint = optimized_attention(
            torch.cat([q_txt, q_img], dim=2),
            torch.cat([k_txt, k_img], dim=2),
            torch.cat([v_txt, v_img], dim=2),
            self.num_heads,
            mask=attn_mask, skip_reshape=True, skip_output_reshape=True,
            transformer_options=transformer_options,
        )

        # Text occupies the first n_txt positions of the token axis (dim 2).
        width = self.num_heads * self.head_dim
        txt_out = joint[:, :, :n_txt, :].transpose(1, 2).reshape(batch, n_txt, width)
        img_out = joint[:, :, n_txt:, :].transpose(1, 2).reshape(batch, n_img, width)

        return self.proj_x(img_out), self.proj_y(txt_out)
|
||||||
|
|
||||||
|
|
||||||
|
class MMDiTBlockT2I(nn.Module):
    """One dual-stream MMDiT block: joint attention followed by per-stream SwiGLU MLPs.

    Both streams are modulated (shift / scale / gate) from the same conditioning
    vector ``c``, but through separate adaLN projections for image and text.
    """

    def __init__(self, hidden_size, groups, mlp_ratio=4.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.groups = groups
        self.head_dim = hidden_size // groups

        factory = {"dtype": dtype, "device": device}

        def rms_norm():
            return operations.RMSNorm(hidden_size, eps=1e-6, **factory)

        def swiglu(inner_dim):
            return FeedForwardSwiGLU(hidden_size, inner_dim, multiple_of=1,
                                     operations=operations, **factory)

        def modulation():
            # Kept inside a Sequential so the parameter names keep their ".0." index.
            return nn.Sequential(
                operations.Linear(hidden_size, 6 * hidden_size, bias=True, **factory),
            )

        self.norm_x1 = rms_norm()
        self.norm_y1 = rms_norm()
        self.attn = MMDiTJointAttention(hidden_size, num_heads=groups, qkv_bias=False,
                                        operations=operations, **factory)
        self.norm_x2 = rms_norm()
        self.norm_y2 = rms_norm()
        ff_dim = int(hidden_size * mlp_ratio)
        self.mlp_x = swiglu(ff_dim)
        self.mlp_y = swiglu(ff_dim)
        self.adaLN_modulation_img = modulation()
        self.adaLN_modulation_txt = modulation()

    def forward(self, x, y, c, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
        """Update image tokens ``x`` and text tokens ``y``; returns ``(x, y)``."""
        # Six modulation chunks per stream: (shift, scale, gate) for attention, then for the MLP.
        img_mod = self.adaLN_modulation_img(c).chunk(6, dim=-1)
        txt_mod = self.adaLN_modulation_txt(c).chunk(6, dim=-1)
        shift_attn_x, scale_attn_x, gate_attn_x, shift_ff_x, scale_ff_x, gate_ff_x = img_mod
        shift_attn_y, scale_attn_y, gate_attn_y, shift_ff_y, scale_ff_y, gate_ff_y = txt_mod

        # Joint attention branch.
        attn_x, attn_y = self.attn(
            apply_adaln(self.norm_x1(x), shift_attn_x, scale_attn_x),
            apply_adaln(self.norm_y1(y), shift_attn_y, scale_attn_y),
            pos_img, pos_txt, attn_mask,
            transformer_options=transformer_options,
        )
        x = torch.addcmul(x, gate_attn_x, attn_x)
        y = torch.addcmul(y, gate_attn_y, attn_y)

        # Per-stream feed-forward branch.
        ff_x = self.mlp_x(apply_adaln(self.norm_x2(x), shift_ff_x, scale_ff_x))
        x = torch.addcmul(x, gate_ff_x, ff_x)
        ff_y = self.mlp_y(apply_adaln(self.norm_y2(y), shift_ff_y, scale_ff_y))
        y = torch.addcmul(y, gate_ff_y, ff_y)
        return x, y
|
||||||
|
|
||||||
|
|
||||||
|
class PixDiT_T2I(nn.Module):
    """PixelDiT T2I model. Hardcoded for the released 1024px Stage-3 checkpoint
    (also runs at 512px when fed the appropriate latent size and flow_shift).

    Forward:
        x:        [B, 3, H, W]    pixel-space input (no VAE); H and W must be
                                  multiples of ``patch_size``
        timesteps:[B]             in [0, 1000] (ComfyUI flow sampling convention)
        context:  [B, Ltxt, 2304] Gemma-2-2b-it hidden states (chi_prompt prepended)
    Returns flow-matching velocity [B, 3, H, W].

    Raises:
        ValueError: if H/W are not multiples of ``patch_size`` or ``context`` is
            missing / not 3-dimensional.
    """

    def __init__(
        self,
        in_channels=3,
        num_groups=24,
        hidden_size=1536,
        pixel_hidden_size=16,
        pixel_attn_hidden_size=1152,
        pixel_num_groups=16,
        patch_depth=14,
        pixel_depth=2,
        patch_size=16,
        txt_embed_dim=2304,
        txt_max_length=300,
        use_text_rope=True,
        text_rope_theta=10000.0,
        use_pixel_abs_pos=True,
        image_model=None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dtype = dtype
        self.in_channels = int(in_channels)
        self.out_channels = int(in_channels)
        self.hidden_size = int(hidden_size)
        self.num_groups = int(num_groups)
        self.patch_depth = int(patch_depth)
        self.pixel_depth = int(pixel_depth)
        self.patch_size = int(patch_size)
        self.pixel_hidden_size = int(pixel_hidden_size)
        self.pixel_attn_hidden_size = int(pixel_attn_hidden_size)
        self.pixel_num_groups = int(pixel_num_groups)
        self.txt_embed_dim = int(txt_embed_dim)
        self.txt_max_length = int(txt_max_length)
        self.use_text_rope = bool(use_text_rope)
        self.text_rope_theta = float(text_rope_theta)
        self.use_pixel_abs_pos = bool(use_pixel_abs_pos)

        # Per-pixel embedding consumed by the pixel-level blocks.
        self.pixel_embedder = PixelTokenEmbedder(
            self.in_channels, self.pixel_hidden_size, use_pixel_abs_pos=self.use_pixel_abs_pos,
            dtype=dtype, device=device, operations=operations,
        )
        # Patch-level embedding: one token per flattened patch_size x patch_size patch.
        self.s_embedder = PatchTokenEmbedder(
            self.in_channels * self.patch_size ** 2, self.hidden_size, bias=True,
            dtype=dtype, device=device, operations=operations,
        )
        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations, max_period=10)
        self.y_embedder = PatchTokenEmbedder(
            self.txt_embed_dim, self.hidden_size, bias=True, norm_layer=True,
            dtype=dtype, device=device, operations=operations,
        )
        # Additive text position table; allocated uninitialised, so its values
        # are expected to come from the loaded checkpoint.
        self.y_pos_embedding = nn.Parameter(
            torch.empty(1, self.txt_max_length, self.hidden_size, dtype=dtype, device=device)
        )

        self.patch_blocks = nn.ModuleList([
            MMDiTBlockT2I(self.hidden_size, self.num_groups,
                          dtype=dtype, device=device, operations=operations)
            for _ in range(self.patch_depth)
        ])
        self.pixel_blocks = nn.ModuleList([
            PiTBlock(
                self.pixel_hidden_size,
                self.hidden_size,
                patch_size=self.patch_size,
                num_heads=self.num_groups,
                mlp_ratio=4.0,
                attn_hidden_size=self.pixel_attn_hidden_size,
                attn_num_heads=self.pixel_num_groups,
                dtype=dtype, device=device, operations=operations,
            )
            for _ in range(self.pixel_depth)
        ])

        self.final_layer = FinalLayer(self.pixel_hidden_size, self.out_channels,
                                      dtype=dtype, device=device, operations=operations)

        # Plain-dict caches of RoPE tables, keyed by grid size / text length.
        self._patch_pos_cache = {}
        self._text_pos_cache = {}

    def _fetch_patch_pos(self, height, width, device, dtype):
        """Return the 2D RoPE table for a ``height`` x ``width`` patch grid (cached per size)."""
        key = (height, width)
        pos = self._patch_pos_cache.get(key)
        if pos is None:
            pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width)
            self._patch_pos_cache[key] = pos
        return pos.to(device=device, dtype=dtype)

    def _fetch_text_pos(self, length, device, dtype):
        """Return the 1D RoPE table for ``length`` text tokens (cached per length)."""
        pos = self._text_pos_cache.get(length)
        if pos is None:
            pos = rope(torch.arange(length, dtype=torch.float32).reshape(1, -1), self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0)
            self._text_pos_cache[length] = pos
        return pos.to(device=device, dtype=dtype)

    def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
        """Run ``_forward`` through any DIFFUSION_MODEL wrappers found in ``transformer_options``."""
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
            self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
        ).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)

    def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
        """Patch-level MMDiT pass followed by the pixel-level blocks; see the class docstring.

        NOTE: ``attention_mask`` is accepted but not applied; every block is
        called with a ``None`` mask.
        """
        B, _, H, W = x.shape

        # Validate up front: unfold would silently drop a ragged border and the
        # failure would only surface (as a bare assert) after all patch blocks ran.
        if H % self.patch_size != 0 or W % self.patch_size != 0:
            raise ValueError(
                f"PixDiT_T2I requires height and width divisible by patch_size={self.patch_size}, got {H}x{W}"
            )
        if context is None or context.dim() != 3:
            raise ValueError("PixDiT_T2I requires context (text embeddings) of shape [B, L, D]")

        Hs = H // self.patch_size
        Ws = W // self.patch_size
        L = Hs * Ws

        pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype)
        x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)

        t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size)

        # Text longer than the position table is truncated to txt_max_length tokens.
        Ltxt = min(context.shape[1], self.txt_max_length)
        y = context[:, :Ltxt, :]
        y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size)
        # y_pos_embedding is a raw nn.Parameter, so move it to the activation's
        # device as well as its dtype rather than relying on any automatic cast.
        y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(device=y_emb.device, dtype=y_emb.dtype)

        condition = F.silu(t_emb)
        pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None

        s = self.s_embedder(x_patches)
        for blk in self.patch_blocks:
            s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options)
        s = F.silu(t_emb + s)

        # One conditioning vector per patch for the pixel-level blocks.
        s_cond = s.view(B * L, self.hidden_size)
        x_pixels = self.pixel_embedder(x, img_height=H, img_width=W, patch_size=self.patch_size)
        for blk in self.pixel_blocks:
            x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None, transformer_options=transformer_options)

        x_pixels = self.final_layer(x_pixels)
        C_out = self.out_channels
        P2 = self.patch_size * self.patch_size
        # Rearrange [B*L, P2, C] into the [B, C*P2, L] column layout F.fold expects.
        x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).contiguous()
        x_pixels = x_pixels.view(B, C_out * P2, L)
        return F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size)
|
||||||
199
comfy/ldm/pixeldit/modules.py
Normal file
199
comfy/ldm/pixeldit/modules.py
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from comfy.ldm.flux.math import apply_rope
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp
|
||||||
|
|
||||||
|
|
||||||
|
def apply_adaln(x, shift, scale):
    """Apply adaLN modulation: returns ``x + shift + x * scale``."""
    shifted = x + shift
    return torch.addcmul(shifted, x, scale)
|
||||||
|
|
||||||
|
|
||||||
|
def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0, device=None, dtype=torch.float32):
    """Build a 2D RoPE table whose x/y axis frequencies alternate across the head dim.

    Coordinates on each axis are spread evenly over ``[0, scale]`` regardless of
    the grid size. The result holds Flux-format 2x2 rotation matrices with shape
    [height*width, dim/2, 2, 2]; head-dim pairs are laid out as
    [x_0, y_0, x_1, y_1, ..., x_{dim/4-1}, y_{dim/4-1}].
    """
    cols = torch.linspace(0, scale, width, device=device)
    rows = torch.linspace(0, scale, height, device=device)
    row_grid, col_grid = torch.meshgrid(rows, cols, indexing="ij")

    exponents = torch.arange(0, dim, 4, device=device, dtype=torch.float32)[: (dim // 4)] / dim
    inv_freq = 1.0 / (theta ** exponents)

    angles_x = torch.outer(col_grid.reshape(-1), inv_freq)
    angles_y = torch.outer(row_grid.reshape(-1), inv_freq)
    # Stacking on a new last axis and flattening interleaves x and y per frequency.
    angles = torch.stack([angles_x, angles_y], dim=-1).reshape(height * width, -1)

    cos = torch.cos(angles)
    sin = torch.sin(angles)
    rotations = torch.stack([cos, -sin, sin, cos], dim=-1)
    return rotations.reshape(*cos.shape, 2, 2).to(dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def get_2d_sincos_pos_embed(embed_dim, height, width, device=None, dtype=torch.float32):
    """Torch port of MAE's 2D sin/cos absolute positional embedding (pixel embedder).

    Returns [height*width, embed_dim]. The first half of the channels encodes
    the W coordinate and the second half the H coordinate; inside each half the
    sines come before the cosines.
    """
    assert embed_dim % 4 == 0
    rows = torch.arange(height, dtype=torch.float32, device=device)
    cols = torch.arange(width, dtype=torch.float32, device=device)
    row_idx, col_idx = torch.meshgrid(rows, cols, indexing="ij")

    omega = torch.arange(embed_dim // 4, dtype=torch.float32, device=device) / (embed_dim / 4.0)
    omega = 1.0 / (10000.0 ** omega)

    def encode(coords):
        phase = torch.outer(coords.reshape(-1), omega)
        return torch.cat([torch.sin(phase), torch.cos(phase)], dim=1)

    return torch.cat([encode(col_idx), encode(row_idx)], dim=1).to(dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryAttention(nn.Module):
    """Single-stream self-attention with rotary positions (the attention inside PiTBlock).

    Q and K get a per-head RMSNorm before the rotary embedding is applied.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
        super().__init__()
        assert dim % num_heads == 0
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        factory = {"dtype": dtype, "device": device}
        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, **factory)
        self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, **factory)
        self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, **factory)
        self.proj = operations.Linear(dim, dim, **factory)

    def forward(self, x, pos, mask=None, transformer_options={}):
        """Self-attend over ``x`` ([batch, tokens, dim]) using the RoPE table ``pos``."""
        batch, tokens, _ = x.shape
        heads = self.num_heads

        # Fused projection -> three tensors shaped [batch, heads, tokens, head_dim].
        fused = self.qkv(x).reshape(batch, tokens, 3, heads, self.head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        q, k = apply_rope(self.q_norm(q), self.k_norm(k), pos[None, None])
        attended = optimized_attention(q, k, v, heads, mask=mask, skip_reshape=True,
                                       transformer_options=transformer_options)
        return self.proj(attended)
|
||||||
|
|
||||||
|
|
||||||
|
class FinalLayer(nn.Module):
    """Output head: RMSNorm followed by a linear projection to ``out_channels``."""

    def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.norm = operations.RMSNorm(hidden_size, eps=1e-6, **factory)
        self.linear = operations.Linear(hidden_size, out_channels, bias=True, **factory)

    def forward(self, x):
        normed = self.norm(x)
        return self.linear(normed)
|
||||||
|
|
||||||
|
|
||||||
|
class PatchTokenEmbedder(nn.Module):
    """Linear projection used both for patchified-image tokens and text-feature tokens.

    Args:
        in_chans: size of the last dimension of the input.
        embed_dim: size of the last dimension of the output.
        norm_layer: flag; any truthy value adds an RMSNorm after the projection,
            while ``None`` / ``False`` leaves the projection output untouched.
        bias: whether the projection has a bias term.
    """

    def __init__(self, in_chans, embed_dim, norm_layer=None, bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.proj = operations.Linear(in_chans, embed_dim, bias=bias, dtype=dtype, device=device)
        # Truthiness test on purpose: an `is not None` check would make an
        # explicit norm_layer=False switch the norm on.
        if norm_layer:
            self.norm = operations.RMSNorm(embed_dim, eps=1e-6, dtype=dtype, device=device)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        return self.norm(self.proj(x))
|
||||||
|
|
||||||
|
|
||||||
|
class PixelTokenEmbedder(nn.Module):
    """Pixel-level embedder.

    Projects every pixel's channel vector to ``hidden_size_output`` features,
    optionally adds a 2D sin/cos absolute position, and regroups the pixels so
    that each patch becomes its own sequence of ``patch_size**2`` tokens.
    """

    def __init__(self, in_channels, hidden_size_output, use_pixel_abs_pos=True,
                 dtype=None, device=None, operations=None):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_size_output = hidden_size_output
        self.use_pixel_abs_pos = bool(use_pixel_abs_pos)
        self.proj = operations.Linear(self.in_channels, self.hidden_size_output, bias=True, dtype=dtype, device=device)
        # Position tables keyed by (height, width), stored as first computed.
        self._pos_cache = {}

    def _fetch_pixel_pos(self, height, width, device, dtype):
        """Return the cached sin/cos table for this image size, cast to ``device``/``dtype``."""
        key = (height, width)
        table = self._pos_cache.get(key)
        if table is None:
            table = get_2d_sincos_pos_embed(self.hidden_size_output, height, width)
            self._pos_cache[key] = table
        return table.to(device=device, dtype=dtype)

    def forward(self, inputs, img_height, img_width, patch_size):
        """Embed ``inputs`` ([B, C, H, W]) into [B * patches, patch_size**2, hidden_size_output]."""
        batch, _, height, width = inputs.shape
        assert height == img_height and width == img_width
        assert (height % patch_size == 0) and (width % patch_size == 0)
        feat = self.hidden_size_output
        grid_h, grid_w = height // patch_size, width // patch_size

        # Channels-last so the linear layer acts on each pixel independently.
        pixels = self.proj(inputs.permute(0, 2, 3, 1).contiguous())
        if self.use_pixel_abs_pos:
            pos = self._fetch_pixel_pos(height, width, pixels.device, pixels.dtype)
            pixels = pixels + pos.view(height, width, feat).unsqueeze(0)

        # [B, H, W, F] -> [B, grid_h, grid_w, p, p, F] -> one row-major sequence per patch.
        tiles = pixels.view(batch, grid_h, patch_size, grid_w, patch_size, feat)
        tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
        return tiles.view(batch * grid_h * grid_w, patch_size * patch_size, feat)
|
||||||
|
|
||||||
|
|
||||||
|
class PiTBlock(nn.Module):
    """Pixel-level transformer block.

    The P^2 pixel tokens of each patch are flattened and compressed by a linear
    layer into one attention token. Self-attention with 2D RoPE then runs across
    the patches, and a second linear layer expands the result back to P^2 pixel
    tokens. A per-pixel MLP follows. Both branches are adaLN-modulated per pixel
    from the patch-level conditioning features.
    """

    def __init__(self, pixel_hidden_size, patch_hidden_size, patch_size, num_heads, mlp_ratio=4.0,
                 attn_hidden_size=None, attn_num_heads=None, rope_fn=None,
                 dtype=None, device=None, operations=None):
        super().__init__()
        self.pixel_dim = pixel_hidden_size
        self.context_dim = patch_hidden_size
        self.patch_size = patch_size
        # The attention width / head count fall back to the patch-level settings.
        self.attn_dim = patch_hidden_size if attn_hidden_size is None else attn_hidden_size
        self.num_heads = num_heads if attn_num_heads is None else attn_num_heads
        assert self.attn_dim % self.num_heads == 0

        factory = {"dtype": dtype, "device": device}
        pixels_per_patch = patch_size * patch_size
        flat_dim = pixels_per_patch * self.pixel_dim

        self.compress_to_attn = operations.Linear(flat_dim, self.attn_dim, bias=True, **factory)
        self.expand_from_attn = operations.Linear(self.attn_dim, flat_dim, bias=True, **factory)
        self.norm1 = operations.RMSNorm(self.pixel_dim, eps=1e-6, **factory)
        self.attn = RotaryAttention(self.attn_dim, num_heads=self.num_heads, qkv_bias=False,
                                    operations=operations, **factory)
        self.norm2 = operations.RMSNorm(self.pixel_dim, eps=1e-6, **factory)
        self.mlp = Mlp(self.pixel_dim, hidden_features=int(self.pixel_dim * mlp_ratio),
                       operations=operations, **factory)
        # One (shift, scale, gate) x 2 set for every pixel of the patch.
        self.adaLN_modulation = nn.Sequential(
            operations.Linear(self.context_dim, 6 * flat_dim, bias=True, **factory),
        )
        self._pos_cache = {}
        self._rope_fn = precompute_freqs_cis_2d if rope_fn is None else rope_fn

    def _fetch_pos(self, height, width, device, dtype):
        """Return the RoPE table for a ``height`` x ``width`` patch grid (cached per size)."""
        key = (height, width)
        table = self._pos_cache.get(key)
        if table is None:
            table = self._rope_fn(self.attn_dim // self.num_heads, height, width)
            self._pos_cache[key] = table
        return table.to(device=device, dtype=dtype)

    def forward(self, x, s_cond, image_height, image_width, patch_size, mask=None, transformer_options={}):
        """Update pixel tokens ``x`` ([B*L, P^2, pixel_dim]) given per-patch conditioning ``s_cond``."""
        total_patches, pixels_per_patch, _ = x.shape
        grid_h = image_height // patch_size
        grid_w = image_width // patch_size
        patches = grid_h * grid_w
        batch = total_patches // patches

        modulation = self.adaLN_modulation(s_cond).view(total_patches, pixels_per_patch, 6 * self.pixel_dim)
        shift_attn, scale_attn, gate_attn, shift_ff, scale_ff, gate_ff = modulation.chunk(6, dim=-1)

        # Attention branch: compress each patch to one token, attend across patches, expand back.
        modulated = apply_adaln(self.norm1(x), shift_attn, scale_attn)
        patch_tokens = self.compress_to_attn(modulated.view(total_patches, pixels_per_patch * self.pixel_dim))
        patch_tokens = patch_tokens.view(batch, patches, self.attn_dim)
        rope_pos = self._fetch_pos(grid_h, grid_w, x.device, x.dtype)
        attended = self.attn(patch_tokens, rope_pos, mask=mask, transformer_options=transformer_options)
        expanded = self.expand_from_attn(attended.view(batch * patches, self.attn_dim))
        x = torch.addcmul(x, gate_attn, expanded.view(total_patches, pixels_per_patch, self.pixel_dim))

        # Per-pixel MLP branch.
        ff_out = self.mlp(apply_adaln(self.norm2(x), shift_ff, scale_ff))
        return torch.addcmul(x, gate_ff, ff_out)
|
||||||
286
comfy/ldm/pixeldit/pid.py
Normal file
286
comfy/ldm/pixeldit/pid.py
Normal file
@ -0,0 +1,286 @@
|
|||||||
|
"""PiD — Pixel Diffusion Decoder. Decodes a Flux/SD3/Flux2/Z-Image latent
|
||||||
|
directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I
|
||||||
|
body + LQ projection branch injected before each MMDiT patch block.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from comfy.ldm.flux.math import rope
|
||||||
|
|
||||||
|
from .model import PixDiT_T2I
|
||||||
|
|
||||||
|
|
||||||
|
def precompute_freqs_cis_2d_ntk(dim: int, height: int, width: int,
                                ref_grid_h: int, ref_grid_w: int,
                                theta: float = 10000.0, scale: float = 16.0,
                                device=None, dtype=torch.float32):
    """NTK-aware 2D RoPE (rope_mode='ntk_aware' in upstream PiD).

    Each axis uses theta * (current/ref)^(dim_axis/(dim_axis-2)), where
    dim_axis = dim // 2; the factor is 1 when dim_axis <= 2. Returns
    [H*W, dim/2, 2, 2] with x/y axis freqs interleaved at stride 2 (matches
    the head-dim layout PiD's Q/K weights expect).
    """
    dim_axis = dim // 2
    if dim_axis > 2:
        ntk_power = dim_axis / (dim_axis - 2)
        h_ntk = (height / ref_grid_h) ** ntk_power
        w_ntk = (width / ref_grid_w) ** ntk_power
    else:
        # The exponent is undefined / degenerate here, so theta is left unscaled.
        h_ntk = w_ntk = 1.0

    # Coordinates always span [0, scale] on each axis, independent of grid size.
    cols = torch.linspace(0, scale, width, device=device)
    rows = torch.linspace(0, scale, height, device=device)
    row_grid, col_grid = torch.meshgrid(rows, cols, indexing="ij")

    rope_x = rope(col_grid.reshape(1, -1), dim_axis, theta * w_ntk).squeeze(0)
    rope_y = rope(row_grid.reshape(1, -1), dim_axis, theta * h_ntk).squeeze(0)

    # Stack on a new axis right after the frequency axis, then merge the two,
    # which alternates x and y entries along the head dimension.
    interleaved = torch.stack([rope_x, rope_y], dim=2)
    return interleaved.reshape(height * width, dim // 2, 2, 2).to(dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class SigmaAwareGatePerTokenPerDim(nn.Module):
    """Sigma-dependent gate that mixes LQ features into ``x``.

    gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma)
    out  = x + gate * lq

    Trained init gives ~0.88 gate at sigma=0, ~0.05 at sigma=1.
    """

    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.content_proj = operations.Linear(dim * 2, dim, dtype=dtype, device=device)
        # Scalar parameter, allocated uninitialised.
        self.log_alpha = nn.Parameter(torch.empty((), dtype=dtype, device=device))

    def forward(self, x: torch.Tensor, lq: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        logits = self.content_proj(torch.cat([x, lq], dim=-1))
        # log_alpha is a raw nn.Parameter -> doesn't auto-cast under dynamic VRAM.
        alpha = self.log_alpha.to(device=x.device, dtype=torch.float32).exp()
        # One sigma per batch element, broadcast over the two trailing axes.
        per_sample_sigma = sigma.float().view(-1, 1, 1)
        gate = torch.sigmoid(logits - alpha * per_sample_sigma)
        return x + (gate * lq).to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(nn.Module):
    """Pre-activation residual block: two rounds of GroupNorm -> SiLU -> 3x3 Conv, plus a skip."""

    def __init__(self, channels: int, num_groups: int = 4, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}

        stages = []
        for _ in range(2):
            stages.append(operations.GroupNorm(num_groups, channels, **factory))
            stages.append(nn.SiLU())
            # padding=1 with a 3x3 kernel keeps the spatial size unchanged.
            stages.append(operations.Conv2d(channels, channels, kernel_size=3, padding=1, **factory))
        self.block = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.block(x)
        return x + residual
|
||||||
|
|
||||||
|
|
||||||
|
class LQProjection2D(nn.Module):
    """LQ latent -> per-block patch-aligned features for controlnet-style injection.

    The latent is first aligned to the patch grid (nearest-neighbour resize
    when one latent cell covers at least one patch, otherwise an ``f x f``
    space-to-channel fold), run through a small conv/ResBlock trunk, flattened
    to tokens, and projected by one linear head per injection point. Each head
    has a matching sigma-aware gate module.

    Args:
        latent_channels: Channel count of the incoming latent (must be > 0).
        hidden_dim: Width of the conv trunk.
        out_dim: Width of each output head / gate.
        patch_size: Pixel patch size of the consumer model.
        sr_scale: Super-resolution factor between latent-decoded and target pixels.
        latent_spatial_down_factor: Pixel-to-latent downsampling of the source VAE.
        num_res_blocks: Number of ``ResBlock`` modules appended to the trunk.
        num_outputs: Number of output heads and gate modules.
        interval: Inject every ``interval``-th block (values <= 1 mean every block).
        dtype, device, operations: Forwarded to the layer factories.

    Raises:
        ValueError: If ``latent_channels`` is not positive, or if the latent
            cell is smaller than a patch but ``patch_size`` is not an integer
            multiple of ``sr_scale * latent_spatial_down_factor``.
    """

    def __init__(
        self,
        latent_channels: int,
        hidden_dim: int = 512,
        out_dim: int = 1536,
        patch_size: int = 16,
        sr_scale: int = 4,
        latent_spatial_down_factor: int = 8,
        num_res_blocks: int = 4,
        num_outputs: int = 14,
        interval: int = 1,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        if latent_channels <= 0:
            raise ValueError(f"latent_channels must be positive, got {latent_channels}")

        self.latent_channels = latent_channels
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.sr_scale = sr_scale
        self.latent_spatial_down_factor = latent_spatial_down_factor
        self.num_outputs = num_outputs
        self.interval = interval

        # Pixels (in target resolution) covered by one latent cell.
        latent_stride = sr_scale * latent_spatial_down_factor
        z_to_patch_ratio = latent_stride / patch_size
        self.z_to_patch_ratio = z_to_patch_ratio

        if z_to_patch_ratio >= 1:
            # One latent cell spans >= 1 patch: the latent is resized to the grid.
            self.latent_upsample_ratio = int(z_to_patch_ratio)
            self.latent_fold_factor = 0
            latent_proj_in_ch = latent_channels
        else:
            # One patch spans several latent cells: fold f x f cells into channels.
            # Integer arithmetic avoids float round-off in the exactness check
            # (e.g. 49 * (1 / 49) != 1.0 in IEEE doubles).
            fold_factor, remainder = divmod(patch_size, latent_stride)
            if remainder != 0:
                raise ValueError(
                    "patch_size must be an integer multiple of "
                    "sr_scale * latent_spatial_down_factor when the latent cell is "
                    f"smaller than a patch (patch_size={patch_size}, "
                    f"sr_scale={sr_scale}, latent_spatial_down_factor={latent_spatial_down_factor})"
                )
            fold_factor = int(fold_factor)
            self.latent_upsample_ratio = 0
            self.latent_fold_factor = fold_factor
            latent_proj_in_ch = latent_channels * fold_factor * fold_factor

        layers = [
            operations.Conv2d(latent_proj_in_ch, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
        ]
        for _ in range(num_res_blocks):
            layers.append(ResBlock(hidden_dim, dtype=dtype, device=device, operations=operations))
        self.latent_proj = nn.Sequential(*layers)

        self.output_heads = nn.ModuleList(
            [operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) for _ in range(num_outputs)]
        )
        self.gate_modules = nn.ModuleList(
            [SigmaAwareGatePerTokenPerDim(out_dim, dtype=dtype, device=device, operations=operations)
             for _ in range(num_outputs)]
        )

    def is_gate_active(self, block_idx: int) -> bool:
        """Return True when block ``block_idx`` receives an LQ injection."""
        if self.interval > 1:
            return block_idx % self.interval == 0
        return True

    def output_index(self, block_idx: int) -> int:
        """Map a block index to the index of the head/gate that serves it."""
        if self.interval > 1:
            return block_idx // self.interval
        return block_idx

    def gate(self, x: torch.Tensor, lq_feature: torch.Tensor, sigma: torch.Tensor, out_idx: int) -> torch.Tensor:
        """Apply gate module ``out_idx`` to ``x`` with the given LQ feature and sigma."""
        return self.gate_modules[out_idx](x, lq_feature, sigma)

    def _align_latent_to_patch_grid(self, lq_latent: torch.Tensor, pH: int, pW: int) -> torch.Tensor:
        """Resize or fold a (B, C, H, W) latent onto a ``pH x pW`` grid and run the conv trunk."""
        B, z_dim = lq_latent.shape[:2]
        if self.z_to_patch_ratio >= 1:
            if lq_latent.shape[2] != pH or lq_latent.shape[3] != pW:
                z_aligned = F.interpolate(lq_latent, size=(pH, pW), mode="nearest")
            else:
                z_aligned = lq_latent
        else:
            f = self.latent_fold_factor
            zH_expected, zW_expected = pH * f, pW * f
            if lq_latent.shape[2] != zH_expected or lq_latent.shape[3] != zW_expected:
                lq_latent = F.interpolate(lq_latent, size=(zH_expected, zW_expected), mode="nearest")
            # Space-to-channel: move each f x f neighbourhood into the channel axis.
            z_aligned = lq_latent.reshape(B, z_dim, pH, f, pW, f).permute(0, 1, 3, 5, 2, 4)
            z_aligned = z_aligned.reshape(B, z_dim * f * f, pH, pW)
        return self.latent_proj(z_aligned)

    def forward(self, lq_latent: torch.Tensor, target_pH: int, target_pW: int) -> List[torch.Tensor]:
        """Return one (B, target_pH * target_pW, out_dim) token tensor per output head."""
        feat = self._align_latent_to_patch_grid(lq_latent, target_pH, target_pW)
        tokens = feat.flatten(2).transpose(1, 2)
        return [head(tokens) for head in self.output_heads]
|
||||||
|
|
||||||
|
|
||||||
|
class PidNet(PixDiT_T2I):
    """PixDiT_T2I extended with low-quality (LQ) latent injection.

    Before each patch block for which the LQ projection is active, the patch
    stream is passed through a sigma-gated mix with the matching LQ feature.
    Rotary position tables are computed with the NTK-aware helper using a
    reference grid derived from ``rope_ref_h`` / ``rope_ref_w``.
    """

    def __init__(
        self,
        lq_latent_channels: int = 16,
        lq_hidden_dim: int = 512,
        lq_num_res_blocks: int = 4,
        lq_interval: int = 1,
        sr_scale: int = 4,
        latent_spatial_down_factor: int = 8,
        rope_ref_h: int = 1024,  # NTK ref resolution in PIXEL units: 1024px / patch=16 -> grid_ref=64.
        rope_ref_w: int = 1024,
        image_model=None,
        dtype=None, device=None, operations=None,
        **pixdit_kwargs,
    ):
        super().__init__(dtype=dtype, device=device, operations=operations, **pixdit_kwargs)

        # Reference grid is expressed in patches, not pixels.
        self.rope_ref_grid_h = int(rope_ref_h) // int(self.patch_size)
        self.rope_ref_grid_w = int(rope_ref_w) // int(self.patch_size)

        # The parent built its pixel blocks with the plain RoPE function;
        # replace it with the NTK-aware one and drop any cached tables.
        def _ntk_rope(head_dim, h, w):
            return precompute_freqs_cis_2d_ntk(head_dim, h, w, self.rope_ref_grid_h, self.rope_ref_grid_w)

        for pixel_blk in self.pixel_blocks:
            pixel_blk._rope_fn = _ntk_rope
            pixel_blk._pos_cache = {}

        # Ceil-divide: one LQ head per injected patch block.
        injected_blocks = (self.patch_depth + lq_interval - 1) // lq_interval
        self.lq_proj = LQProjection2D(
            latent_channels=lq_latent_channels,
            hidden_dim=lq_hidden_dim,
            out_dim=self.hidden_size,
            patch_size=self.patch_size,
            sr_scale=sr_scale,
            latent_spatial_down_factor=latent_spatial_down_factor,
            num_res_blocks=lq_num_res_blocks,
            num_outputs=injected_blocks,
            interval=lq_interval,
            dtype=dtype,
            device=device,
            operations=operations,
        )

    def _fetch_patch_pos(self, height, width, device, dtype):
        """Return the NTK RoPE table for a ``height x width`` patch grid, cached per grid size."""
        cache_key = (height, width)
        table = self._patch_pos_cache.get(cache_key)
        if table is None:
            head_dim = self.hidden_size // self.num_groups
            table = precompute_freqs_cis_2d_ntk(
                head_dim, height, width, self.rope_ref_grid_h, self.rope_ref_grid_w,
            )
            self._patch_pos_cache[cache_key] = table
        return table.to(device=device, dtype=dtype)

    def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={},
                 lq_latent=None, degrade_sigma=None, **kwargs):
        """Run the patch and pixel stacks on ``x`` with LQ injection.

        Raises:
            ValueError: If ``context`` is missing or not 3-D, or ``lq_latent`` is missing.
        """
        batch, _, height, width = x.shape
        grid_h = height // self.patch_size
        grid_w = width // self.patch_size
        num_patches = grid_h * grid_w

        if context is None or context.dim() != 3:
            raise ValueError("PidNet requires context [B, L, D]")
        if lq_latent is None:
            raise ValueError("PidNet requires lq_latent — attach via PiDConditioning")

        # Normalise degrade_sigma to a flat float32 tensor on x's device.
        if isinstance(degrade_sigma, torch.Tensor):
            degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
            if degrade_sigma.numel() == 1 and batch > 1:
                degrade_sigma = degrade_sigma.expand(batch).contiguous()
        elif degrade_sigma is None:
            degrade_sigma = torch.zeros(batch, device=x.device, dtype=torch.float32)
        else:
            degrade_sigma = torch.tensor([float(degrade_sigma)] * batch, device=x.device, dtype=torch.float32)

        lq_latent = lq_latent.to(device=x.device, dtype=x.dtype)
        lq_features = self.lq_proj(lq_latent=lq_latent, target_pH=grid_h, target_pW=grid_w)

        pos_img = self._fetch_patch_pos(grid_h, grid_w, x.device, x.dtype)
        x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)

        t_emb = self.t_embedder(timesteps.view(-1)).view(batch, -1, self.hidden_size)

        txt_len = min(context.shape[1], self.txt_max_length)
        y_emb = self.y_embedder(context[:, :txt_len, :]).view(batch, txt_len, self.hidden_size)
        # y_pos_embedding is raw nn.Parameter -> doesn't auto-cast under dynamic VRAM.
        y_pos = self.y_pos_embedding[:, :txt_len, :].to(device=y_emb.device, dtype=y_emb.dtype)
        y_emb = y_emb + y_pos

        condition = F.silu(t_emb)
        pos_txt = None
        if self.use_text_rope:
            pos_txt = self._fetch_text_pos(txt_len, x.device, x.dtype)

        s = self.s_embedder(x_patches)
        available = len(lq_features)
        for idx, patch_blk in enumerate(self.patch_blocks):
            if self.lq_proj.is_gate_active(idx):
                feat_idx = self.lq_proj.output_index(idx)
                if feat_idx < available:
                    s = self.lq_proj.gate(s, lq_features[feat_idx], degrade_sigma, feat_idx)
            s, y_emb = patch_blk(s, y_emb, condition, pos_img, pos_txt, None,
                                 transformer_options=transformer_options)
        s = F.silu(t_emb + s)

        s_cond = s.view(batch * num_patches, self.hidden_size)
        x_pixels = self.pixel_embedder(x, img_height=height, img_width=width, patch_size=self.patch_size)
        for pixel_blk in self.pixel_blocks:
            x_pixels = pixel_blk(x_pixels, s_cond, height, width, self.patch_size, mask=None,
                                 transformer_options=transformer_options)

        x_pixels = self.final_layer(x_pixels)
        out_ch = self.out_channels
        pixels_per_patch = self.patch_size * self.patch_size
        # Rearrange to the (B, C * P*P, L) layout expected by F.fold.
        x_pixels = x_pixels.view(batch, num_patches, pixels_per_patch, out_ch).permute(0, 3, 2, 1).contiguous()
        x_pixels = x_pixels.view(batch, out_ch * pixels_per_patch, num_patches)
        return F.fold(x_pixels, (height, width), kernel_size=self.patch_size, stride=self.patch_size)
|
||||||
@ -48,6 +48,8 @@ import comfy.ldm.hunyuan3d.model
|
|||||||
import comfy.ldm.hidream.model
|
import comfy.ldm.hidream.model
|
||||||
import comfy.ldm.chroma.model
|
import comfy.ldm.chroma.model
|
||||||
import comfy.ldm.chroma_radiance.model
|
import comfy.ldm.chroma_radiance.model
|
||||||
|
import comfy.ldm.pixeldit.model
|
||||||
|
import comfy.ldm.pixeldit.pid
|
||||||
import comfy.ldm.ace.model
|
import comfy.ldm.ace.model
|
||||||
import comfy.ldm.omnigen.omnigen2
|
import comfy.ldm.omnigen.omnigen2
|
||||||
import comfy.ldm.qwen_image.model
|
import comfy.ldm.qwen_image.model
|
||||||
@ -1296,6 +1298,41 @@ class ZImagePixelSpace(Lumina2):
|
|||||||
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
|
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
|
||||||
self.memory_usage_factor_conds = ("ref_latents",)
|
self.memory_usage_factor_conds = ("ref_latents",)
|
||||||
|
|
||||||
|
|
||||||
|
class PixelDiTT2I(BaseModel):
    """BaseModel wrapper that instantiates ``PixDiT_T2I`` as the diffusion network."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(
            model_config,
            model_type,
            device=device,
            unet_model=comfy.ldm.pixeldit.model.PixDiT_T2I,
        )

    def extra_conds(self, **kwargs):
        """Extend the parent's conds with ``attention_mask`` when one is supplied."""
        conds = super().extra_conds(**kwargs)
        mask = kwargs.get("attention_mask", None)
        if mask is None:
            return conds
        conds["attention_mask"] = comfy.conds.CONDRegular(mask)
        return conds
|
||||||
|
|
||||||
|
|
||||||
|
class PiD(BaseModel):
    """BaseModel wrapper that instantiates ``PidNet`` as the diffusion network."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(
            model_config,
            model_type,
            device=device,
            unet_model=comfy.ldm.pixeldit.pid.PidNet,
        )

    def extra_conds(self, **kwargs):
        """Extend the parent's conds with attention_mask, lq_latent and degrade_sigma.

        Each entry is added only when the corresponding kwarg is not None.
        A non-tensor ``degrade_sigma`` is wrapped into a one-element float32 tensor.
        """
        conds = super().extra_conds(**kwargs)

        mask = kwargs.get("attention_mask", None)
        if mask is not None:
            conds["attention_mask"] = comfy.conds.CONDRegular(mask)

        lq = kwargs.get("lq_latent", None)
        if lq is not None:
            conds["lq_latent"] = comfy.conds.CONDRegular(lq)

        sigma = kwargs.get("degrade_sigma", None)
        if sigma is not None:
            sigma_tensor = sigma
            if not isinstance(sigma_tensor, torch.Tensor):
                sigma_tensor = torch.tensor([float(sigma_tensor)], dtype=torch.float32)
            conds["degrade_sigma"] = comfy.conds.CONDRegular(sigma_tensor)

        return conds
|
||||||
|
|
||||||
|
|
||||||
class WAN21(BaseModel):
|
class WAN21(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||||
|
|||||||
@ -424,6 +424,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
|
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
|
# PiD (Pixel Diffusion Decoder). Must check BEFORE plain PixelDiT_T2I.
|
||||||
|
_lq_w_key = '{}lq_proj.latent_proj.0.weight'.format(key_prefix)
|
||||||
|
if _lq_w_key in state_dict_keys:
|
||||||
|
in_ch = int(state_dict[_lq_w_key].shape[1])
|
||||||
|
_gate_prefix = '{}lq_proj.gate_modules.'.format(key_prefix)
|
||||||
|
num_gates = len({k[len(_gate_prefix):].split('.')[0]
|
||||||
|
for k in state_dict_keys if k.startswith(_gate_prefix)})
|
||||||
|
dit_config = {"image_model": "pid",
|
||||||
|
"lq_latent_channels": in_ch,
|
||||||
|
"latent_spatial_down_factor": 16 if in_ch >= 64 else 8}
|
||||||
|
if num_gates > 0:
|
||||||
|
dit_config["lq_interval"] = (14 + num_gates - 1) // num_gates
|
||||||
|
return dit_config
|
||||||
|
|
||||||
|
if '{}core.pixel_embedder.proj.weight'.format(key_prefix) in state_dict_keys: # PixelDiT T2I
|
||||||
|
return {"image_model": "pixeldit_t2i"}
|
||||||
|
|
||||||
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
|
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
dit_config["image_model"] = "lumina2"
|
dit_config["image_model"] = "lumina2"
|
||||||
|
|||||||
10
comfy/sd.py
10
comfy/sd.py
@ -49,6 +49,7 @@ import comfy.text_encoders.lt
|
|||||||
import comfy.text_encoders.hunyuan_video
|
import comfy.text_encoders.hunyuan_video
|
||||||
import comfy.text_encoders.cosmos
|
import comfy.text_encoders.cosmos
|
||||||
import comfy.text_encoders.lumina2
|
import comfy.text_encoders.lumina2
|
||||||
|
import comfy.text_encoders.pixeldit
|
||||||
import comfy.text_encoders.wan
|
import comfy.text_encoders.wan
|
||||||
import comfy.text_encoders.hidream
|
import comfy.text_encoders.hidream
|
||||||
import comfy.text_encoders.ace
|
import comfy.text_encoders.ace
|
||||||
@ -1228,6 +1229,7 @@ class CLIPType(Enum):
|
|||||||
FLUX2 = 25
|
FLUX2 = 25
|
||||||
LONGCAT_IMAGE = 26
|
LONGCAT_IMAGE = 26
|
||||||
COGVIDEOX = 27
|
COGVIDEOX = 27
|
||||||
|
PIXELDIT = 28
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -1460,8 +1462,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
clip_target.tokenizer = variant.tokenizer
|
clip_target.tokenizer = variant.tokenizer
|
||||||
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||||
elif te_model == TEModel.GEMMA_2_2B:
|
elif te_model == TEModel.GEMMA_2_2B:
|
||||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
if clip_type == CLIPType.PIXELDIT:
|
||||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
clip_target.clip = comfy.text_encoders.pixeldit.PixelDiTGemma2TE
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer
|
||||||
|
else:
|
||||||
|
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||||
elif te_model == TEModel.GEMMA_3_4B:
|
elif te_model == TEModel.GEMMA_3_4B:
|
||||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
|
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
|
||||||
|
|||||||
@ -29,6 +29,7 @@ import comfy.text_encoders.longcat_image
|
|||||||
import comfy.text_encoders.ernie
|
import comfy.text_encoders.ernie
|
||||||
import comfy.text_encoders.cogvideo
|
import comfy.text_encoders.cogvideo
|
||||||
import comfy.text_encoders.hidream_o1
|
import comfy.text_encoders.hidream_o1
|
||||||
|
import comfy.text_encoders.pixeldit
|
||||||
|
|
||||||
from . import supported_models_base
|
from . import supported_models_base
|
||||||
from . import latent_formats
|
from . import latent_formats
|
||||||
@ -1135,6 +1136,71 @@ class ZImagePixelSpace(ZImage):
|
|||||||
def get_model(self, state_dict, prefix="", device=None):
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
return model_base.ZImagePixelSpace(self, device=device)
|
return model_base.ZImagePixelSpace(self, device=device)
|
||||||
|
|
||||||
|
class PixelDiTT2I(supported_models_base.BASE):
    """Supported-model entry for checkpoints detected as ``pixeldit_t2i``."""

    unet_config = {"image_model": "pixeldit_t2i"}
    unet_extra_config = {}

    sampling_settings = {
        "shift": 4.0,  # 1024px stage 3 default; 2.0 for 512px
        "multiplier": 1000,
    }

    latent_format = latent_formats.PixelDiTPixel
    memory_usage_factor = 0.7
    supported_inference_dtypes = [torch.bfloat16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Build the ``model_base.PixelDiTT2I`` wrapper for this config."""
        return model_base.PixelDiTT2I(self, device=device)

    def process_unet_state_dict(self, state_dict):
        """Drop ``_repa_projector*`` keys and strip a leading ``core.`` from the rest."""
        core_prefix = "core."
        cleaned = {}
        for key, value in state_dict.items():
            if key.startswith("_repa_projector"):
                continue
            target_key = key[len(core_prefix):] if key.startswith(core_prefix) else key
            cleaned[target_key] = value
        return cleaned

    def clip_target(self, state_dict={}):
        """Return the PixelDiT Gemma-2 tokenizer / text-encoder pair."""
        tokenizer_cls = comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer
        encoder_cls = comfy.text_encoders.pixeldit.PixelDiTGemma2TE
        return supported_models_base.ClipTarget(tokenizer_cls, encoder_cls)
|
||||||
|
|
||||||
|
class PiD(PixelDiTT2I):
    """Supported-model entry for checkpoints detected as ``pid``."""

    unet_config = {"image_model": "pid"}

    sampling_settings = {
        "shift": 1.5,  # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0]
        "multiplier": 1000,
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Build the ``model_base.PiD`` wrapper for this config."""
        return model_base.PiD(self, device=device)

    def process_unet_state_dict(self, state_dict):
        """Drop ``_repa_projector*`` / ``net_ema.*`` keys and strip a leading ``core.`` or ``net.``."""
        dropped_prefixes = ("_repa_projector", "net_ema.")
        stripped_prefixes = ("core.", "net.")
        cleaned = {}
        for key, value in state_dict.items():
            if key.startswith(dropped_prefixes):
                continue
            target_key = key
            for prefix in stripped_prefixes:
                if key.startswith(prefix):
                    target_key = key[len(prefix):]
                    break
            cleaned[target_key] = value
        return cleaned
|
||||||
|
|
||||||
class WAN21_T2V(supported_models_base.BASE):
|
class WAN21_T2V(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "wan2.1",
|
"image_model": "wan2.1",
|
||||||
@ -2044,6 +2110,8 @@ models = [
|
|||||||
CosmosI2VPredict2,
|
CosmosI2VPredict2,
|
||||||
ZImagePixelSpace,
|
ZImagePixelSpace,
|
||||||
ZImage,
|
ZImage,
|
||||||
|
PiD,
|
||||||
|
PixelDiTT2I,
|
||||||
Lumina2,
|
Lumina2,
|
||||||
WAN22_T2V,
|
WAN22_T2V,
|
||||||
WAN21_CausalAR_T2V,
|
WAN21_CausalAR_T2V,
|
||||||
|
|||||||
124
comfy/text_encoders/pixeldit.py
Normal file
124
comfy/text_encoders/pixeldit.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy import sd1_clip
|
||||||
|
from .lumina2 import Gemma2BTokenizer, LuminaModel
|
||||||
|
import comfy.text_encoders.llama
|
||||||
|
|
||||||
|
|
||||||
|
class PixelDiTGemma2_2BModel(sd1_clip.SDClipModel):
    """Gemma-2-2b-it text encoder for PixelDiT.

    Defaults to the final hidden state (``layer="last"``) and passes
    ``attention_mask`` through as both the enable and return flags.
    """

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
        gemma_special_tokens = {"start": 2, "pad": 0}
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config={},
            dtype=dtype,
            special_tokens=gemma_special_tokens,
            layer_norm_hidden_state=False,
            model_class=comfy.text_encoders.llama.Gemma2_2B,
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            model_options=model_options,
        )
|
||||||
|
|
||||||
|
|
||||||
|
_PIXELDIT_CHI_PROMPT = (
|
||||||
|
'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions '
|
||||||
|
"suitable for image generation. Evaluate the level of detail in the user prompt:\n"
|
||||||
|
"- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, "
|
||||||
|
"and spatial relationships to create vivid and concrete scenes.\n"
|
||||||
|
"- If the prompt is already detailed, refine and enhance the existing details slightly without "
|
||||||
|
"overcomplicating.\n"
|
||||||
|
"Here are examples of how to transform or refine prompts:\n"
|
||||||
|
"- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, "
|
||||||
|
"sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n"
|
||||||
|
"- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring "
|
||||||
|
"glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus "
|
||||||
|
"passing by towering glass skyscrapers.\n"
|
||||||
|
"Please generate only the enhanced description for the prompt below and avoid including any "
|
||||||
|
"additional commentary or evaluations:\n"
|
||||||
|
"User Prompt: "
|
||||||
|
)
|
||||||
|
|
||||||
|
_PIXELDIT_MAX_LENGTH = 300
|
||||||
|
_PIXELDIT_CHI_PROMPT_DETECT_PREFIX = 'Given a user prompt, generate an "Enhanced prompt"'
|
||||||
|
|
||||||
|
|
||||||
|
def _build_padded_tokens(combined_text: str, spiece_tokenizer, pad_id: int, chi_token_count: int):
|
||||||
|
# Right-pad to chi_token_count + 300 - 2 (matches upstream's max_length_all).
|
||||||
|
max_length_all = chi_token_count + _PIXELDIT_MAX_LENGTH - 2
|
||||||
|
ids = spiece_tokenizer(combined_text)["input_ids"]
|
||||||
|
if len(ids) > max_length_all:
|
||||||
|
ids = ids[:max_length_all]
|
||||||
|
elif len(ids) < max_length_all:
|
||||||
|
ids = ids + [pad_id] * (max_length_all - len(ids))
|
||||||
|
return ids
|
||||||
|
|
||||||
|
|
||||||
|
class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer):
    """Gemma-2-2b-it tokenizer that prepends PixelDiT's chi_prompt.

    Blank or non-string text is tokenized as the empty string and right-padded
    to ``_PIXELDIT_MAX_LENGTH``. Text that already starts with the chi_prompt
    detection prefix is tokenized verbatim; any other text gets the chi_prompt
    prepended. Every token is emitted with weight 1.0.
    """

    def __init__(self, embedding_directory=None, tokenizer_data=None):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data={} if tokenizer_data is None else tokenizer_data,
            name="gemma2_2b",
            tokenizer=Gemma2BTokenizer,
        )

    def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
        """Return ``{"gemma2_2b": [[(token_id, 1.0), ...]]}`` for ``text``.

        NOTE(review): ``return_word_ids`` is accepted but not used here; the
        output always contains (token, weight) pairs.
        """
        encode = self.gemma2_2b.tokenizer
        pad_id = self.gemma2_2b.pad_token

        has_prompt = isinstance(text, str) and bool(text.strip())
        if has_prompt:
            chi_len = len(encode(_PIXELDIT_CHI_PROMPT)["input_ids"])
            if text.startswith(_PIXELDIT_CHI_PROMPT_DETECT_PREFIX):
                full_text = text
            else:
                full_text = _PIXELDIT_CHI_PROMPT + text
            ids = _build_padded_tokens(full_text, encode, pad_id, chi_len)
        else:
            ids = encode("")["input_ids"]
            ids = ids + [pad_id] * (_PIXELDIT_MAX_LENGTH - len(ids))

        weighted = [(token, 1.0) for token in ids]
        return {"gemma2_2b": [weighted]}

    def untokenize(self, token_weight_pair):
        """Delegate to the wrapped ``gemma2_2b`` tokenizer."""
        return self.gemma2_2b.untokenize(token_weight_pair)

    def state_dict(self):
        """Delegate to the wrapped ``gemma2_2b`` tokenizer."""
        return self.gemma2_2b.state_dict()
|
||||||
|
|
||||||
|
|
||||||
|
class PixelDiTGemma2TE(LuminaModel):
    """Text encoder wrapper for PixelDiT.

    After the parent encodes the full padded sequence, sequences longer than
    ``_PIXELDIT_MAX_LENGTH`` are reduced to the first position followed by the
    last ``_PIXELDIT_MAX_LENGTH - 1`` positions (PixelDiT's ``select_index``
    step). A returned ``attention_mask`` of matching length is reduced the
    same way.
    """

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(
            device=device,
            dtype=dtype,
            name="gemma2_2b",
            clip_model=PixelDiTGemma2_2BModel,
            model_options=model_options,
        )

    def encode_token_weights(self, token_weight_pairs):
        """Encode via the parent, then apply the head + tail selection."""
        encoded = super().encode_token_weights(token_weight_pairs)
        cond, pooled = encoded[0], encoded[1]
        extra = encoded[2] if len(encoded) > 2 else None

        seq_len = cond.shape[1]
        tail_len = _PIXELDIT_MAX_LENGTH - 1
        if seq_len > _PIXELDIT_MAX_LENGTH:
            cond = torch.cat([cond[:, :1], cond[:, -tail_len:]], dim=1)
            if extra is not None and "attention_mask" in extra:
                mask = extra["attention_mask"]
                if mask.dim() == 1:
                    mask = mask.unsqueeze(0)
                # Only rewrite the mask when it lines up with the encoded sequence.
                if mask.shape[-1] == seq_len:
                    selected = torch.cat([mask[..., :1], mask[..., -tail_len:]], dim=-1)
                    extra = {**extra, "attention_mask": selected}

        if extra is None:
            return cond, pooled
        return cond, pooled, extra
|
||||||
BIN
comfy_cuda_memory_history20260525_155146.pt
Normal file
BIN
comfy_cuda_memory_history20260525_155146.pt
Normal file
Binary file not shown.
60
comfy_extras/nodes_pid.py
Normal file
60
comfy_extras/nodes_pid.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
"""PiD (Pixel Diffusion Decoder) node"""
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
import node_helpers
|
||||||
|
import comfy.latent_formats
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
|
# Combo option name -> latent format class used to renormalize the incoming latent.
_LATENT_FORMAT_CLASSES = {
    "flux": comfy.latent_formats.Flux,
    "sd3": comfy.latent_formats.SD3,
    "flux2": comfy.latent_formats.Flux2,
}


class PiDConditioning(io.ComfyNode):
    """Node that attaches an LQ latent and a degrade_sigma value to a conditioning."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node's id, inputs and single conditioning output."""
        node_description = (
            "Attaches an LDM latent (Flux/SD3/Flux2/Z-Image) and a degrade_sigma scalar "
            "to a CONDITIONING for PiD decoding. Latent is renormalized into PiD space "
            "via the chosen latent_format. Z-Image uses 'flux'."
        )
        node_inputs = [
            io.Conditioning.Input("positive"),
            io.Latent.Input("latent", tooltip="LDM latent (from VAEEncode or a KSampler)."),
            io.Combo.Input(
                "latent_format",
                options=list(_LATENT_FORMAT_CLASSES.keys()),
                default="flux",
            ),
            io.Float.Input(
                "degrade_sigma", default=0.0, min=0.0, max=1.0, step=0.01,
                tooltip="0 = clean latent. Increase to denoise corrupted LDM outputs.",
            ),
        ]
        return io.Schema(
            node_id="PiDConditioning",
            display_name="PiD Conditioning",
            category="advanced/conditioning",
            description=node_description,
            inputs=node_inputs,
            outputs=[io.Conditioning.Output()],
        )

    @classmethod
    def execute(cls, positive, latent, latent_format: str, degrade_sigma: float) -> io.NodeOutput:
        """Run ``process_in`` on the latent samples and store both values on ``positive``."""
        format_cls = _LATENT_FORMAT_CLASSES[latent_format]
        lq_latent = format_cls().process_in(latent["samples"])
        values = {"lq_latent": lq_latent, "degrade_sigma": float(degrade_sigma)}
        updated = node_helpers.conditioning_set_values(positive, values)
        return io.NodeOutput(updated)
|
||||||
|
|
||||||
|
|
||||||
|
class PiDExtension(ComfyExtension):
    """Extension that registers the PiD conditioning node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        nodes = [PiDConditioning]
        return nodes
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> PiDExtension:
    """Create and return the PiD extension instance."""
    extension = PiDExtension()
    return extension
|
||||||
3
nodes.py
3
nodes.py
@ -958,7 +958,7 @@ class CLIPLoader:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox"], ),
|
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "pixeldit"], ),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"device": (["default", "cpu"], {"advanced": True}),
|
"device": (["default", "cpu"], {"advanced": True}),
|
||||||
@ -2405,6 +2405,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_context_windows.py",
|
"nodes_context_windows.py",
|
||||||
"nodes_qwen.py",
|
"nodes_qwen.py",
|
||||||
"nodes_chroma_radiance.py",
|
"nodes_chroma_radiance.py",
|
||||||
|
"nodes_pid.py",
|
||||||
"nodes_model_patch.py",
|
"nodes_model_patch.py",
|
||||||
"nodes_easycache.py",
|
"nodes_easycache.py",
|
||||||
"nodes_audio_encoder.py",
|
"nodes_audio_encoder.py",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user