mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 01:17:24 +08:00
Cleanup
This commit is contained in:
parent
1bb3bea2d3
commit
238f8aa9fa
@ -13,7 +13,8 @@ from .modules import (
|
||||
PatchTokenEmbedder,
|
||||
PiTBlock,
|
||||
PixelTokenEmbedder,
|
||||
apply_adaln,
|
||||
_cache_set,
|
||||
apply_adaln_,
|
||||
precompute_freqs_cis_2d,
|
||||
)
|
||||
|
||||
@ -107,14 +108,14 @@ class MMDiTBlockT2I(nn.Module):
|
||||
shift_msa_x, scale_msa_x, gate_msa_x, shift_mlp_x, scale_mlp_x, gate_mlp_x = self.adaLN_modulation_img(c).chunk(6, dim=-1)
|
||||
shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = self.adaLN_modulation_txt(c).chunk(6, dim=-1)
|
||||
|
||||
x_norm = apply_adaln(self.norm_x1(x), shift_msa_x, scale_msa_x)
|
||||
y_norm = apply_adaln(self.norm_y1(y), shift_msa_y, scale_msa_y)
|
||||
x_norm = apply_adaln_(self.norm_x1(x), shift_msa_x, scale_msa_x)
|
||||
y_norm = apply_adaln_(self.norm_y1(y), shift_msa_y, scale_msa_y)
|
||||
attn_x, attn_y = self.attn(x_norm, y_norm, pos_img, pos_txt, attn_mask, transformer_options=transformer_options)
|
||||
x = torch.addcmul(x, gate_msa_x, attn_x)
|
||||
y = torch.addcmul(y, gate_msa_y, attn_y)
|
||||
|
||||
x = torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln(self.norm_x2(x), shift_mlp_x, scale_mlp_x)))
|
||||
y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln(self.norm_y2(y), shift_mlp_y, scale_mlp_y)))
|
||||
x = torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln_(self.norm_x2(x), shift_mlp_x, scale_mlp_x)))
|
||||
y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln_(self.norm_y2(y), shift_mlp_y, scale_mlp_y)))
|
||||
return x, y
|
||||
|
||||
|
||||
@ -216,14 +217,14 @@ class PixDiT_T2I(nn.Module):
|
||||
pos = self._patch_pos_cache.get(key)
|
||||
if pos is None:
|
||||
pos = precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width)
|
||||
self._patch_pos_cache[key] = pos
|
||||
_cache_set(self._patch_pos_cache, key, pos)
|
||||
return pos.to(device=device, dtype=dtype)
|
||||
|
||||
def _fetch_text_pos(self, length, device, dtype):
|
||||
pos = self._text_pos_cache.get(length)
|
||||
if pos is None:
|
||||
pos = rope(torch.arange(length, dtype=torch.float32).reshape(1, -1), self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0)
|
||||
self._text_pos_cache[length] = pos
|
||||
_cache_set(self._text_pos_cache, length, pos)
|
||||
return pos.to(device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
|
||||
@ -233,6 +234,10 @@ class PixDiT_T2I(nn.Module):
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
|
||||
).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)
|
||||
|
||||
def _pre_patch_block(self, s, i, **kwargs):
|
||||
"""Hook for subclasses to inject per-block state into the patch stream (e.g. PiD's LQ gate)."""
|
||||
return s
|
||||
|
||||
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
|
||||
B, _, H, W = x.shape
|
||||
Hs = H // self.patch_size
|
||||
@ -249,13 +254,14 @@ class PixDiT_T2I(nn.Module):
|
||||
Ltxt = min(context.shape[1], self.txt_max_length)
|
||||
y = context[:, :Ltxt, :]
|
||||
y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size)
|
||||
y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb.dtype)
|
||||
y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb) # y_pos_embedding is a raw nn.Parameter
|
||||
|
||||
condition = F.silu(t_emb)
|
||||
pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None
|
||||
|
||||
s = self.s_embedder(x_patches)
|
||||
for blk in self.patch_blocks:
|
||||
for i, blk in enumerate(self.patch_blocks):
|
||||
s = self._pre_patch_block(s, i, **kwargs)
|
||||
s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options)
|
||||
s = F.silu(t_emb + s)
|
||||
|
||||
|
||||
@ -6,10 +6,20 @@ from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp
|
||||
|
||||
|
||||
def apply_adaln(x, shift, scale):
|
||||
def apply_adaln_(x, shift, scale):
|
||||
return x.addcmul_(x, scale).add_(shift)
|
||||
|
||||
|
||||
_POS_CACHE_MAX = 16
|
||||
|
||||
|
||||
def _cache_set(cache, key, value):
|
||||
"""Set with a soft LRU cap — evicts the oldest entry if at capacity."""
|
||||
if len(cache) >= _POS_CACHE_MAX:
|
||||
del cache[next(iter(cache))]
|
||||
cache[key] = value
|
||||
|
||||
|
||||
def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0, device=None, dtype=torch.float32):
|
||||
"""2D RoPE with x/y axis frequencies interleaved at stride 2 across head dim.
|
||||
|
||||
@ -119,7 +129,7 @@ class PixelTokenEmbedder(nn.Module):
|
||||
pe = self._pos_cache.get(key)
|
||||
if pe is None:
|
||||
pe = get_2d_sincos_pos_embed(self.hidden_size_output, height, width)
|
||||
self._pos_cache[key] = pe
|
||||
_cache_set(self._pos_cache, key, pe)
|
||||
return pe.to(device=device, dtype=dtype)
|
||||
|
||||
def forward(self, inputs, img_height, img_width, patch_size):
|
||||
@ -176,7 +186,7 @@ class PiTBlock(nn.Module):
|
||||
pos = self._pos_cache.get(key)
|
||||
if pos is None:
|
||||
pos = self._rope_fn(self.attn_dim // self.num_heads, height, width)
|
||||
self._pos_cache[key] = pos
|
||||
_cache_set(self._pos_cache, key, pos)
|
||||
return pos.to(device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, s_cond, image_height, image_width, patch_size, mask=None, transformer_options={}):
|
||||
@ -188,7 +198,7 @@ class PiTBlock(nn.Module):
|
||||
# Attention path uses only msa params; compute, use, free before mlp params allocate.
|
||||
msa_params = self.adaLN_modulation_msa(s_cond).view(BL, P2, 3 * self.pixel_dim)
|
||||
shift_msa, scale_msa, gate_msa = msa_params.chunk(3, dim=-1)
|
||||
x_norm = apply_adaln(self.norm1(x), shift_msa, scale_msa)
|
||||
x_norm = apply_adaln_(self.norm1(x), shift_msa, scale_msa)
|
||||
x_flat = x_norm.view(BL, P2 * self.pixel_dim)
|
||||
x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim)
|
||||
pos_comp = self._fetch_pos(Hs, Ws, x.device, x.dtype)
|
||||
@ -201,7 +211,7 @@ class PiTBlock(nn.Module):
|
||||
mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim)
|
||||
shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1)
|
||||
gate_mlp = gate_mlp.contiguous()
|
||||
mlp_input = apply_adaln(self.norm2(x), shift_mlp, scale_mlp)
|
||||
mlp_input = apply_adaln_(self.norm2(x), shift_mlp, scale_mlp)
|
||||
del mlp_params, shift_mlp, scale_mlp
|
||||
chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks
|
||||
for s in range(0, BL, chunk_size):
|
||||
|
||||
@ -12,6 +12,7 @@ import torch.nn.functional as F
|
||||
from comfy.ldm.flux.math import rope
|
||||
|
||||
from .model import PixDiT_T2I
|
||||
from .modules import _cache_set
|
||||
|
||||
|
||||
def precompute_freqs_cis_2d_ntk(dim: int, height: int, width: int,
|
||||
@ -221,66 +222,36 @@ class PidNet(PixDiT_T2I):
|
||||
height, width,
|
||||
self.rope_ref_grid_h, self.rope_ref_grid_w,
|
||||
)
|
||||
self._patch_pos_cache[key] = pos
|
||||
_cache_set(self._patch_pos_cache, key, pos)
|
||||
return pos.to(device=device, dtype=dtype)
|
||||
|
||||
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={},
|
||||
lq_latent=None, degrade_sigma=None, **kwargs):
|
||||
B, _, H, W = x.shape
|
||||
Hs = H // self.patch_size
|
||||
Ws = W // self.patch_size
|
||||
L = Hs * Ws
|
||||
def _pre_patch_block(self, s, i, pid_lq_features=None, pid_degrade_sigma=None, **kwargs):
|
||||
if pid_lq_features is None or not self.lq_proj.is_gate_active(i):
|
||||
return s
|
||||
out_idx = self.lq_proj.output_index(i)
|
||||
if out_idx >= len(pid_lq_features):
|
||||
return s
|
||||
return self.lq_proj.gate(s, pid_lq_features[out_idx], pid_degrade_sigma, out_idx)
|
||||
|
||||
if context is None or context.dim() != 3:
|
||||
raise ValueError("PidNet requires context [B, L, D]")
|
||||
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, lq_latent=None, degrade_sigma=None, **kwargs):
|
||||
if lq_latent is None:
|
||||
raise ValueError("PidNet requires lq_latent — attach via PiDConditioning")
|
||||
B = x.shape[0]
|
||||
Hs = x.shape[2] // self.patch_size
|
||||
Ws = x.shape[3] // self.patch_size
|
||||
|
||||
if degrade_sigma is None:
|
||||
degrade_sigma = torch.zeros(B, device=x.device, dtype=torch.float32)
|
||||
elif not isinstance(degrade_sigma, torch.Tensor):
|
||||
degrade_sigma = torch.tensor([float(degrade_sigma)] * B, device=x.device, dtype=torch.float32)
|
||||
else:
|
||||
degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
|
||||
if degrade_sigma.numel() == 1 and B > 1:
|
||||
degrade_sigma = degrade_sigma.expand(B).contiguous()
|
||||
degrade_sigma = torch.as_tensor(degrade_sigma if degrade_sigma is not None else 0.0, device=x.device, dtype=torch.float32).reshape(-1)
|
||||
if degrade_sigma.numel() == 1 and B > 1:
|
||||
degrade_sigma = degrade_sigma.expand(B).contiguous()
|
||||
|
||||
lq_latent = lq_latent.to(device=x.device, dtype=x.dtype)
|
||||
lq_features = self.lq_proj(lq_latent=lq_latent, target_pH=Hs, target_pW=Ws)
|
||||
|
||||
pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype)
|
||||
x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
|
||||
|
||||
t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size)
|
||||
|
||||
Ltxt = min(context.shape[1], self.txt_max_length)
|
||||
y = context[:, :Ltxt, :]
|
||||
y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size)
|
||||
# y_pos_embedding is raw nn.Parameter -> doesn't auto-cast under dynamic VRAM.
|
||||
y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(device=y_emb.device, dtype=y_emb.dtype)
|
||||
|
||||
condition = F.silu(t_emb)
|
||||
pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None
|
||||
|
||||
s = self.s_embedder(x_patches)
|
||||
for i, blk in enumerate(self.patch_blocks):
|
||||
if self.lq_proj.is_gate_active(i):
|
||||
out_idx = self.lq_proj.output_index(i)
|
||||
if out_idx < len(lq_features):
|
||||
s = self.lq_proj.gate(s, lq_features[out_idx], degrade_sigma, out_idx)
|
||||
s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None,
|
||||
transformer_options=transformer_options)
|
||||
s = F.silu(t_emb + s)
|
||||
|
||||
s_cond = s.view(B * L, self.hidden_size)
|
||||
x_pixels = self.pixel_embedder(x, img_height=H, img_width=W, patch_size=self.patch_size)
|
||||
for blk in self.pixel_blocks:
|
||||
x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None,
|
||||
transformer_options=transformer_options)
|
||||
|
||||
x_pixels = self.final_layer(x_pixels)
|
||||
C_out = self.out_channels
|
||||
P2 = self.patch_size * self.patch_size
|
||||
x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).contiguous()
|
||||
x_pixels = x_pixels.view(B, C_out * P2, L)
|
||||
return F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size)
|
||||
return super()._forward(
|
||||
x, timesteps,
|
||||
context=context, attention_mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
pid_lq_features=lq_features,
|
||||
pid_degrade_sigma=degrade_sigma,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -93,10 +93,8 @@ class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
class PixelDiTGemma2TE(LuminaModel):
|
||||
"""Text encoder wrapper for PixelDiT.
|
||||
|
||||
Overrides `encode_token_weights` to perform PixelDiT's `select_index` step:
|
||||
encode the full padded sequence (up to ~chi_prompt_tokens + 298), then
|
||||
return `[BOS_emb] + last_299_embs` as the 300-position conditioning that
|
||||
matches the diffusion model's learned `y_pos_embedding` positions.
|
||||
Encodes the full padded sequence, then returns BOS + last 299 embeddings
|
||||
(PixelDiT's `select_index` step) to match the trained y_pos_embedding length.
|
||||
"""
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="gemma2_2b",
|
||||
|
||||
@ -7,10 +7,11 @@ import comfy.latent_formats
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
# Since this can be used only as upscaler with VAE, can't depend on latent format detection from any model
|
||||
_LATENT_FORMAT_CLASSES = {
|
||||
"flux": comfy.latent_formats.Flux,
|
||||
"sd3": comfy.latent_formats.SD3,
|
||||
"flux1": comfy.latent_formats.Flux,
|
||||
"flux2": comfy.latent_formats.Flux2,
|
||||
"sd3": comfy.latent_formats.SD3,
|
||||
}
|
||||
|
||||
|
||||
@ -22,9 +23,9 @@ class PiDConditioning(io.ComfyNode):
|
||||
display_name="PiD Conditioning",
|
||||
category="advanced/conditioning",
|
||||
description=(
|
||||
"Attaches an LDM latent (Flux/SD3/Flux2/Z-Image) and a degrade_sigma scalar "
|
||||
"Attaches an LDM latent (Flux1/Flux2/SD3) and a degrade_sigma scalar "
|
||||
"to a CONDITIONING for PiD decoding. Latent is renormalized into PiD space "
|
||||
"via the chosen latent_format. Z-Image uses 'flux'."
|
||||
"via the chosen latent_format. Z-Image uses 'flux1'."
|
||||
),
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
@ -32,7 +33,7 @@ class PiDConditioning(io.ComfyNode):
|
||||
io.Combo.Input(
|
||||
"latent_format",
|
||||
options=list(_LATENT_FORMAT_CLASSES.keys()),
|
||||
default="flux",
|
||||
default="flux1",
|
||||
),
|
||||
io.Float.Input(
|
||||
"degrade_sigma", default=0.0, min=0.0, max=1.0, step=0.01,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user