From 238f8aa9fa0115e30f40cc7a00627ceef0b1c2e9 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 25 May 2026 21:05:09 +0300 Subject: [PATCH] Cleanup --- comfy/ldm/pixeldit/__init__.py | 0 comfy/ldm/pixeldit/model.py | 24 ++++++---- comfy/ldm/pixeldit/modules.py | 20 ++++++--- comfy/ldm/pixeldit/pid.py | 77 ++++++++++----------------------- comfy/text_encoders/pixeldit.py | 6 +-- comfy_extras/nodes_pid.py | 11 ++--- 6 files changed, 62 insertions(+), 76 deletions(-) delete mode 100644 comfy/ldm/pixeldit/__init__.py diff --git a/comfy/ldm/pixeldit/__init__.py b/comfy/ldm/pixeldit/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/comfy/ldm/pixeldit/model.py b/comfy/ldm/pixeldit/model.py index a76307099..ece992c22 100644 --- a/comfy/ldm/pixeldit/model.py +++ b/comfy/ldm/pixeldit/model.py @@ -13,7 +13,8 @@ from .modules import ( PatchTokenEmbedder, PiTBlock, PixelTokenEmbedder, - apply_adaln, + _cache_set, + apply_adaln_, precompute_freqs_cis_2d, ) @@ -107,14 +108,14 @@ class MMDiTBlockT2I(nn.Module): shift_msa_x, scale_msa_x, gate_msa_x, shift_mlp_x, scale_mlp_x, gate_mlp_x = self.adaLN_modulation_img(c).chunk(6, dim=-1) shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = self.adaLN_modulation_txt(c).chunk(6, dim=-1) - x_norm = apply_adaln(self.norm_x1(x), shift_msa_x, scale_msa_x) - y_norm = apply_adaln(self.norm_y1(y), shift_msa_y, scale_msa_y) + x_norm = apply_adaln_(self.norm_x1(x), shift_msa_x, scale_msa_x) + y_norm = apply_adaln_(self.norm_y1(y), shift_msa_y, scale_msa_y) attn_x, attn_y = self.attn(x_norm, y_norm, pos_img, pos_txt, attn_mask, transformer_options=transformer_options) x = torch.addcmul(x, gate_msa_x, attn_x) y = torch.addcmul(y, gate_msa_y, attn_y) - x = torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln(self.norm_x2(x), shift_mlp_x, scale_mlp_x))) - y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln(self.norm_y2(y), shift_mlp_y, scale_mlp_y))) + x = 
torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln_(self.norm_x2(x), shift_mlp_x, scale_mlp_x))) + y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln_(self.norm_y2(y), shift_mlp_y, scale_mlp_y))) return x, y @@ -216,14 +217,14 @@ class PixDiT_T2I(nn.Module): pos = self._patch_pos_cache.get(key) if pos is None: pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width) - self._patch_pos_cache[key] = pos + _cache_set(self._patch_pos_cache, key, pos) return pos.to(device=device, dtype=dtype) def _fetch_text_pos(self, length, device, dtype): pos = self._text_pos_cache.get(length) if pos is None: pos = rope(torch.arange(length, dtype=torch.float32).reshape(1, -1), self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0) - self._text_pos_cache[length] = pos + _cache_set(self._text_pos_cache, length, pos) return pos.to(device=device, dtype=dtype) def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs): @@ -233,6 +234,10 @@ class PixDiT_T2I(nn.Module): comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options), ).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs) + def _pre_patch_block(self, s, i, **kwargs): + """Hook for subclasses to inject per-block state into the patch stream (e.g. 
PiD's LQ gate).""" + return s + def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs): B, _, H, W = x.shape Hs = H // self.patch_size @@ -249,13 +254,14 @@ class PixDiT_T2I(nn.Module): Ltxt = min(context.shape[1], self.txt_max_length) y = context[:, :Ltxt, :] y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size) - y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb.dtype) + y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb) # y_pos_embedding is a raw nn.Parameter condition = F.silu(t_emb) pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None s = self.s_embedder(x_patches) - for blk in self.patch_blocks: + for i, blk in enumerate(self.patch_blocks): + s = self._pre_patch_block(s, i, **kwargs) s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options) s = F.silu(t_emb + s) diff --git a/comfy/ldm/pixeldit/modules.py b/comfy/ldm/pixeldit/modules.py index 2f9dd6174..72067643d 100644 --- a/comfy/ldm/pixeldit/modules.py +++ b/comfy/ldm/pixeldit/modules.py @@ -6,10 +6,20 @@ from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.diffusionmodules.mmdit import Mlp -def apply_adaln(x, shift, scale): +def apply_adaln_(x, shift, scale): return x.addcmul_(x, scale).add_(shift) +_POS_CACHE_MAX = 16 + + +def _cache_set(cache, key, value): + """Set with a soft FIFO cap — evicts the oldest-inserted entry if at capacity.""" + if len(cache) >= _POS_CACHE_MAX: + del cache[next(iter(cache))] + cache[key] = value + + def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0, device=None, dtype=torch.float32): """2D RoPE with x/y axis frequencies interleaved at stride 2 across head dim. 
@@ -119,7 +129,7 @@ class PixelTokenEmbedder(nn.Module): pe = self._pos_cache.get(key) if pe is None: pe = get_2d_sincos_pos_embed(self.hidden_size_output, height, width) - self._pos_cache[key] = pe + _cache_set(self._pos_cache, key, pe) return pe.to(device=device, dtype=dtype) def forward(self, inputs, img_height, img_width, patch_size): @@ -176,7 +186,7 @@ class PiTBlock(nn.Module): pos = self._pos_cache.get(key) if pos is None: pos = self._rope_fn(self.attn_dim // self.num_heads, height, width) - self._pos_cache[key] = pos + _cache_set(self._pos_cache, key, pos) return pos.to(device=device, dtype=dtype) def forward(self, x, s_cond, image_height, image_width, patch_size, mask=None, transformer_options={}): @@ -188,7 +198,7 @@ class PiTBlock(nn.Module): # Attention path uses only msa params; compute, use, free before mlp params allocate. msa_params = self.adaLN_modulation_msa(s_cond).view(BL, P2, 3 * self.pixel_dim) shift_msa, scale_msa, gate_msa = msa_params.chunk(3, dim=-1) - x_norm = apply_adaln(self.norm1(x), shift_msa, scale_msa) + x_norm = apply_adaln_(self.norm1(x), shift_msa, scale_msa) x_flat = x_norm.view(BL, P2 * self.pixel_dim) x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim) pos_comp = self._fetch_pos(Hs, Ws, x.device, x.dtype) @@ -201,7 +211,7 @@ class PiTBlock(nn.Module): mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim) shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1) gate_mlp = gate_mlp.contiguous() - mlp_input = apply_adaln(self.norm2(x), shift_mlp, scale_mlp) + mlp_input = apply_adaln_(self.norm2(x), shift_mlp, scale_mlp) del mlp_params, shift_mlp, scale_mlp chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks for s in range(0, BL, chunk_size): diff --git a/comfy/ldm/pixeldit/pid.py b/comfy/ldm/pixeldit/pid.py index f95698d68..7283d9788 100644 --- a/comfy/ldm/pixeldit/pid.py +++ b/comfy/ldm/pixeldit/pid.py @@ -12,6 +12,7 @@ import torch.nn.functional as F from comfy.ldm.flux.math 
import rope from .model import PixDiT_T2I +from .modules import _cache_set def precompute_freqs_cis_2d_ntk(dim: int, height: int, width: int, @@ -221,66 +222,36 @@ class PidNet(PixDiT_T2I): height, width, self.rope_ref_grid_h, self.rope_ref_grid_w, ) - self._patch_pos_cache[key] = pos + _cache_set(self._patch_pos_cache, key, pos) return pos.to(device=device, dtype=dtype) - def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, - lq_latent=None, degrade_sigma=None, **kwargs): - B, _, H, W = x.shape - Hs = H // self.patch_size - Ws = W // self.patch_size - L = Hs * Ws + def _pre_patch_block(self, s, i, pid_lq_features=None, pid_degrade_sigma=None, **kwargs): + if pid_lq_features is None or not self.lq_proj.is_gate_active(i): + return s + out_idx = self.lq_proj.output_index(i) + if out_idx >= len(pid_lq_features): + return s + return self.lq_proj.gate(s, pid_lq_features[out_idx], pid_degrade_sigma, out_idx) - if context is None or context.dim() != 3: - raise ValueError("PidNet requires context [B, L, D]") + def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, lq_latent=None, degrade_sigma=None, **kwargs): if lq_latent is None: raise ValueError("PidNet requires lq_latent — attach via PiDConditioning") + B = x.shape[0] + Hs = x.shape[2] // self.patch_size + Ws = x.shape[3] // self.patch_size - if degrade_sigma is None: - degrade_sigma = torch.zeros(B, device=x.device, dtype=torch.float32) - elif not isinstance(degrade_sigma, torch.Tensor): - degrade_sigma = torch.tensor([float(degrade_sigma)] * B, device=x.device, dtype=torch.float32) - else: - degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1) - if degrade_sigma.numel() == 1 and B > 1: - degrade_sigma = degrade_sigma.expand(B).contiguous() + degrade_sigma = torch.as_tensor(degrade_sigma if degrade_sigma is not None else 0.0, device=x.device, dtype=torch.float32).reshape(-1) + if degrade_sigma.numel() == 1 and B > 
1: + degrade_sigma = degrade_sigma.expand(B).contiguous() lq_latent = lq_latent.to(device=x.device, dtype=x.dtype) lq_features = self.lq_proj(lq_latent=lq_latent, target_pH=Hs, target_pW=Ws) - pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype) - x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) - - t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size) - - Ltxt = min(context.shape[1], self.txt_max_length) - y = context[:, :Ltxt, :] - y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size) - # y_pos_embedding is raw nn.Parameter -> doesn't auto-cast under dynamic VRAM. - y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(device=y_emb.device, dtype=y_emb.dtype) - - condition = F.silu(t_emb) - pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None - - s = self.s_embedder(x_patches) - for i, blk in enumerate(self.patch_blocks): - if self.lq_proj.is_gate_active(i): - out_idx = self.lq_proj.output_index(i) - if out_idx < len(lq_features): - s = self.lq_proj.gate(s, lq_features[out_idx], degrade_sigma, out_idx) - s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, - transformer_options=transformer_options) - s = F.silu(t_emb + s) - - s_cond = s.view(B * L, self.hidden_size) - x_pixels = self.pixel_embedder(x, img_height=H, img_width=W, patch_size=self.patch_size) - for blk in self.pixel_blocks: - x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None, - transformer_options=transformer_options) - - x_pixels = self.final_layer(x_pixels) - C_out = self.out_channels - P2 = self.patch_size * self.patch_size - x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).contiguous() - x_pixels = x_pixels.view(B, C_out * P2, L) - return F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return super()._forward( + x, timesteps, + context=context, attention_mask=attention_mask, + 
transformer_options=transformer_options, + pid_lq_features=lq_features, + pid_degrade_sigma=degrade_sigma, + **kwargs, + ) diff --git a/comfy/text_encoders/pixeldit.py b/comfy/text_encoders/pixeldit.py index 8853e3584..0dc65114b 100644 --- a/comfy/text_encoders/pixeldit.py +++ b/comfy/text_encoders/pixeldit.py @@ -93,10 +93,8 @@ class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer): class PixelDiTGemma2TE(LuminaModel): """Text encoder wrapper for PixelDiT. - Overrides `encode_token_weights` to perform PixelDiT's `select_index` step: - encode the full padded sequence (up to ~chi_prompt_tokens + 298), then - return `[BOS_emb] + last_299_embs` as the 300-position conditioning that - matches the diffusion model's learned `y_pos_embedding` positions. + Encodes the full padded sequence, then returns BOS + last 299 embeddings + (PixelDiT's `select_index` step) to match the trained y_pos_embedding length. """ def __init__(self, device="cpu", dtype=None, model_options={}): super().__init__(device=device, dtype=dtype, name="gemma2_2b", diff --git a/comfy_extras/nodes_pid.py b/comfy_extras/nodes_pid.py index fdf8b3a9c..3fecd4bfa 100644 --- a/comfy_extras/nodes_pid.py +++ b/comfy_extras/nodes_pid.py @@ -7,10 +7,11 @@ import comfy.latent_formats from comfy_api.latest import ComfyExtension, io +# Since this can be used only as upscaler with VAE, can't depend on latent format detection from any model _LATENT_FORMAT_CLASSES = { - "flux": comfy.latent_formats.Flux, - "sd3": comfy.latent_formats.SD3, + "flux1": comfy.latent_formats.Flux, "flux2": comfy.latent_formats.Flux2, + "sd3": comfy.latent_formats.SD3, } @@ -22,9 +23,9 @@ class PiDConditioning(io.ComfyNode): display_name="PiD Conditioning", category="advanced/conditioning", description=( - "Attaches an LDM latent (Flux/SD3/Flux2/Z-Image) and a degrade_sigma scalar " + "Attaches an LDM latent (Flux1/Flux2/SD3) and a degrade_sigma scalar " "to a CONDITIONING for PiD decoding. 
Latent is renormalized into PiD space " - "via the chosen latent_format. Z-Image uses 'flux'." + "via the chosen latent_format. Z-Image uses 'flux1'." ), inputs=[ io.Conditioning.Input("positive"), @@ -32,7 +33,7 @@ class PiDConditioning(io.ComfyNode): io.Combo.Input( "latent_format", options=list(_LATENT_FORMAT_CLASSES.keys()), - default="flux", + default="flux1", ), io.Float.Input( "degrade_sigma", default=0.0, min=0.0, max=1.0, step=0.01,