This commit is contained in:
kijai 2026-05-25 21:05:09 +03:00
parent 1bb3bea2d3
commit 238f8aa9fa
6 changed files with 62 additions and 76 deletions

View File

@ -13,7 +13,8 @@ from .modules import (
PatchTokenEmbedder,
PiTBlock,
PixelTokenEmbedder,
apply_adaln,
_cache_set,
apply_adaln_,
precompute_freqs_cis_2d,
)
@ -107,14 +108,14 @@ class MMDiTBlockT2I(nn.Module):
shift_msa_x, scale_msa_x, gate_msa_x, shift_mlp_x, scale_mlp_x, gate_mlp_x = self.adaLN_modulation_img(c).chunk(6, dim=-1)
shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = self.adaLN_modulation_txt(c).chunk(6, dim=-1)
x_norm = apply_adaln(self.norm_x1(x), shift_msa_x, scale_msa_x)
y_norm = apply_adaln(self.norm_y1(y), shift_msa_y, scale_msa_y)
x_norm = apply_adaln_(self.norm_x1(x), shift_msa_x, scale_msa_x)
y_norm = apply_adaln_(self.norm_y1(y), shift_msa_y, scale_msa_y)
attn_x, attn_y = self.attn(x_norm, y_norm, pos_img, pos_txt, attn_mask, transformer_options=transformer_options)
x = torch.addcmul(x, gate_msa_x, attn_x)
y = torch.addcmul(y, gate_msa_y, attn_y)
x = torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln(self.norm_x2(x), shift_mlp_x, scale_mlp_x)))
y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln(self.norm_y2(y), shift_mlp_y, scale_mlp_y)))
x = torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln_(self.norm_x2(x), shift_mlp_x, scale_mlp_x)))
y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln_(self.norm_y2(y), shift_mlp_y, scale_mlp_y)))
return x, y
@ -216,14 +217,14 @@ class PixDiT_T2I(nn.Module):
pos = self._patch_pos_cache.get(key)
if pos is None:
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width)
self._patch_pos_cache[key] = pos
_cache_set(self._patch_pos_cache, key, pos)
return pos.to(device=device, dtype=dtype)
def _fetch_text_pos(self, length, device, dtype):
pos = self._text_pos_cache.get(length)
if pos is None:
pos = rope(torch.arange(length, dtype=torch.float32).reshape(1, -1), self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0)
self._text_pos_cache[length] = pos
_cache_set(self._text_pos_cache, length, pos)
return pos.to(device=device, dtype=dtype)
def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
@ -233,6 +234,10 @@ class PixDiT_T2I(nn.Module):
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)
def _pre_patch_block(self, s, i, **kwargs):
"""Hook for subclasses to inject per-block state into the patch stream (e.g. PiD's LQ gate)."""
return s
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
B, _, H, W = x.shape
Hs = H // self.patch_size
@ -249,13 +254,14 @@ class PixDiT_T2I(nn.Module):
Ltxt = min(context.shape[1], self.txt_max_length)
y = context[:, :Ltxt, :]
y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size)
y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb.dtype)
y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb) # y_pos_embedding is a raw nn.Parameter
condition = F.silu(t_emb)
pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None
s = self.s_embedder(x_patches)
for blk in self.patch_blocks:
for i, blk in enumerate(self.patch_blocks):
s = self._pre_patch_block(s, i, **kwargs)
s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options)
s = F.silu(t_emb + s)

View File

@ -6,10 +6,20 @@ from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp
def apply_adaln(x, shift, scale):
def apply_adaln_(x, shift, scale):
return x.addcmul_(x, scale).add_(shift)
_POS_CACHE_MAX = 16
def _cache_set(cache, key, value):
"""Set with a soft LRU cap — evicts the oldest entry if at capacity."""
if len(cache) >= _POS_CACHE_MAX:
del cache[next(iter(cache))]
cache[key] = value
def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0, device=None, dtype=torch.float32):
"""2D RoPE with x/y axis frequencies interleaved at stride 2 across head dim.
@ -119,7 +129,7 @@ class PixelTokenEmbedder(nn.Module):
pe = self._pos_cache.get(key)
if pe is None:
pe = get_2d_sincos_pos_embed(self.hidden_size_output, height, width)
self._pos_cache[key] = pe
_cache_set(self._pos_cache, key, pe)
return pe.to(device=device, dtype=dtype)
def forward(self, inputs, img_height, img_width, patch_size):
@ -176,7 +186,7 @@ class PiTBlock(nn.Module):
pos = self._pos_cache.get(key)
if pos is None:
pos = self._rope_fn(self.attn_dim // self.num_heads, height, width)
self._pos_cache[key] = pos
_cache_set(self._pos_cache, key, pos)
return pos.to(device=device, dtype=dtype)
def forward(self, x, s_cond, image_height, image_width, patch_size, mask=None, transformer_options={}):
@ -188,7 +198,7 @@ class PiTBlock(nn.Module):
# Attention path uses only msa params; compute, use, free before mlp params allocate.
msa_params = self.adaLN_modulation_msa(s_cond).view(BL, P2, 3 * self.pixel_dim)
shift_msa, scale_msa, gate_msa = msa_params.chunk(3, dim=-1)
x_norm = apply_adaln(self.norm1(x), shift_msa, scale_msa)
x_norm = apply_adaln_(self.norm1(x), shift_msa, scale_msa)
x_flat = x_norm.view(BL, P2 * self.pixel_dim)
x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim)
pos_comp = self._fetch_pos(Hs, Ws, x.device, x.dtype)
@ -201,7 +211,7 @@ class PiTBlock(nn.Module):
mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim)
shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1)
gate_mlp = gate_mlp.contiguous()
mlp_input = apply_adaln(self.norm2(x), shift_mlp, scale_mlp)
mlp_input = apply_adaln_(self.norm2(x), shift_mlp, scale_mlp)
del mlp_params, shift_mlp, scale_mlp
chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks
for s in range(0, BL, chunk_size):

View File

@ -12,6 +12,7 @@ import torch.nn.functional as F
from comfy.ldm.flux.math import rope
from .model import PixDiT_T2I
from .modules import _cache_set
def precompute_freqs_cis_2d_ntk(dim: int, height: int, width: int,
@ -221,66 +222,36 @@ class PidNet(PixDiT_T2I):
height, width,
self.rope_ref_grid_h, self.rope_ref_grid_w,
)
self._patch_pos_cache[key] = pos
_cache_set(self._patch_pos_cache, key, pos)
return pos.to(device=device, dtype=dtype)
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={},
lq_latent=None, degrade_sigma=None, **kwargs):
B, _, H, W = x.shape
Hs = H // self.patch_size
Ws = W // self.patch_size
L = Hs * Ws
def _pre_patch_block(self, s, i, pid_lq_features=None, pid_degrade_sigma=None, **kwargs):
if pid_lq_features is None or not self.lq_proj.is_gate_active(i):
return s
out_idx = self.lq_proj.output_index(i)
if out_idx >= len(pid_lq_features):
return s
return self.lq_proj.gate(s, pid_lq_features[out_idx], pid_degrade_sigma, out_idx)
if context is None or context.dim() != 3:
raise ValueError("PidNet requires context [B, L, D]")
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, lq_latent=None, degrade_sigma=None, **kwargs):
if lq_latent is None:
raise ValueError("PidNet requires lq_latent — attach via PiDConditioning")
B = x.shape[0]
Hs = x.shape[2] // self.patch_size
Ws = x.shape[3] // self.patch_size
if degrade_sigma is None:
degrade_sigma = torch.zeros(B, device=x.device, dtype=torch.float32)
elif not isinstance(degrade_sigma, torch.Tensor):
degrade_sigma = torch.tensor([float(degrade_sigma)] * B, device=x.device, dtype=torch.float32)
else:
degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
if degrade_sigma.numel() == 1 and B > 1:
degrade_sigma = degrade_sigma.expand(B).contiguous()
degrade_sigma = torch.as_tensor(degrade_sigma if degrade_sigma is not None else 0.0, device=x.device, dtype=torch.float32).reshape(-1)
if degrade_sigma.numel() == 1 and B > 1:
degrade_sigma = degrade_sigma.expand(B).contiguous()
lq_latent = lq_latent.to(device=x.device, dtype=x.dtype)
lq_features = self.lq_proj(lq_latent=lq_latent, target_pH=Hs, target_pW=Ws)
pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype)
x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size)
Ltxt = min(context.shape[1], self.txt_max_length)
y = context[:, :Ltxt, :]
y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size)
# y_pos_embedding is raw nn.Parameter -> doesn't auto-cast under dynamic VRAM.
y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(device=y_emb.device, dtype=y_emb.dtype)
condition = F.silu(t_emb)
pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None
s = self.s_embedder(x_patches)
for i, blk in enumerate(self.patch_blocks):
if self.lq_proj.is_gate_active(i):
out_idx = self.lq_proj.output_index(i)
if out_idx < len(lq_features):
s = self.lq_proj.gate(s, lq_features[out_idx], degrade_sigma, out_idx)
s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None,
transformer_options=transformer_options)
s = F.silu(t_emb + s)
s_cond = s.view(B * L, self.hidden_size)
x_pixels = self.pixel_embedder(x, img_height=H, img_width=W, patch_size=self.patch_size)
for blk in self.pixel_blocks:
x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None,
transformer_options=transformer_options)
x_pixels = self.final_layer(x_pixels)
C_out = self.out_channels
P2 = self.patch_size * self.patch_size
x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).contiguous()
x_pixels = x_pixels.view(B, C_out * P2, L)
return F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size)
return super()._forward(
x, timesteps,
context=context, attention_mask=attention_mask,
transformer_options=transformer_options,
pid_lq_features=lq_features,
pid_degrade_sigma=degrade_sigma,
**kwargs,
)

View File

@ -93,10 +93,8 @@ class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer):
class PixelDiTGemma2TE(LuminaModel):
"""Text encoder wrapper for PixelDiT.
Overrides `encode_token_weights` to perform PixelDiT's `select_index` step:
encode the full padded sequence (up to ~chi_prompt_tokens + 298), then
return `[BOS_emb] + last_299_embs` as the 300-position conditioning that
matches the diffusion model's learned `y_pos_embedding` positions.
Encodes the full padded sequence, then returns BOS + last 299 embeddings
(PixelDiT's `select_index` step) to match the trained y_pos_embedding length.
"""
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="gemma2_2b",

View File

@ -7,10 +7,11 @@ import comfy.latent_formats
from comfy_api.latest import ComfyExtension, io
# Since this can be used only as upscaler with VAE, can't depend on latent format detection from any model
_LATENT_FORMAT_CLASSES = {
"flux": comfy.latent_formats.Flux,
"sd3": comfy.latent_formats.SD3,
"flux1": comfy.latent_formats.Flux,
"flux2": comfy.latent_formats.Flux2,
"sd3": comfy.latent_formats.SD3,
}
@ -22,9 +23,9 @@ class PiDConditioning(io.ComfyNode):
display_name="PiD Conditioning",
category="advanced/conditioning",
description=(
"Attaches an LDM latent (Flux/SD3/Flux2/Z-Image) and a degrade_sigma scalar "
"Attaches an LDM latent (Flux1/Flux2/SD3) and a degrade_sigma scalar "
"to a CONDITIONING for PiD decoding. Latent is renormalized into PiD space "
"via the chosen latent_format. Z-Image uses 'flux'."
"via the chosen latent_format. Z-Image uses 'flux1'."
),
inputs=[
io.Conditioning.Input("positive"),
@ -32,7 +33,7 @@ class PiDConditioning(io.ComfyNode):
io.Combo.Input(
"latent_format",
options=list(_LATENT_FORMAT_CLASSES.keys()),
default="flux",
default="flux1",
),
io.Float.Input(
"degrade_sigma", default=0.0, min=0.0, max=1.0, step=0.01,