mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 17:37:39 +08:00
227 lines
9.5 KiB
Python
227 lines
9.5 KiB
Python
"""PiD — Pixel Diffusion Decoder. Decodes a Flux/SD3/Flux2/Z-Image latent
|
|
directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I
|
|
body + LQ projection branch injected before each MMDiT patch block.
|
|
"""
|
|
|
|
from typing import List
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from .model import PixDiT_T2I
|
|
from .modules import precompute_freqs_cis_2d
|
|
|
|
|
|
class SigmaAwareGatePerTokenPerDim(nn.Module):
    """Sigma-aware additive gate with one gate value per token and per channel.

    Computes ``gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma)``
    and returns ``x + gate * lq``.

    Trained init gives ~0.88 gate at sigma=0, ~0.05 at sigma=1.
    """

    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        # The projection sees x and lq concatenated on the last axis: 2 * dim in.
        self.content_proj = operations.Linear(2 * dim, dim, **factory)
        # Scalar parameter; allocated uninitialised (torch.empty) here.
        self.log_alpha = nn.Parameter(torch.empty((), **factory))

    def forward(self, x: torch.Tensor, lq: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """Blend ``lq`` into ``x`` using a content- and sigma-dependent gate.

        ``sigma`` is flattened to one value per leading-axis entry and
        broadcast over the two trailing axes of the projected logits.
        """
        # log_alpha is a raw nn.Parameter -> doesn't auto-cast under dynamic VRAM,
        # so it is moved to x's device and evaluated in float32 by hand.
        alpha = torch.exp(self.log_alpha.to(device=x.device, dtype=torch.float32))
        per_sample_sigma = sigma.float().view(-1, 1, 1)

        logits = self.content_proj(torch.cat([x, lq], dim=-1))
        gate = torch.sigmoid(logits - alpha * per_sample_sigma)

        # Cast the gated residual back to x's dtype before adding it.
        gated_lq = (gate * lq).to(x.dtype)
        return x + gated_lq
|
|
|
|
|
|
class ResBlock(nn.Module):
    """Residual block made of two GroupNorm -> SiLU -> 3x3 Conv stages plus a skip.

    The channel count and spatial size are preserved (3x3 convs, padding 1),
    so the stage output can be added straight onto the input.
    """

    def __init__(self, channels: int, num_groups: int = 4, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        stages = []
        # Two identical pre-activation stages, appended in order so the
        # Sequential indices come out as 0..5.
        for _ in range(2):
            stages.append(operations.GroupNorm(num_groups, channels, **factory))
            stages.append(nn.SiLU())
            stages.append(operations.Conv2d(channels, channels, kernel_size=3, padding=1, **factory))
        self.block = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the output of the two-stage branch."""
        residual = self.block(x)
        return x + residual
|
|
|
|
|
|
class LQProjection2D(nn.Module):
    """LQ latent -> per-block patch-aligned features for controlnet-style injection.

    The latent is first aligned to the patch grid (resampled, or folded
    space-to-depth when one patch covers several latent cells), passed through
    a small conv trunk, flattened to tokens, and projected by ``num_outputs``
    independent linear heads. Each head has a matching sigma-aware gate module.

    Args:
        latent_channels: Channel count of the incoming LQ latent.
        hidden_dim: Channel width of the conv trunk.
        out_dim: Feature width produced by every output head.
        patch_size: Patch size the target grid is expressed in.
        sr_scale: Super-resolution factor.
        latent_spatial_down_factor: Spatial down factor of the latent.
        num_res_blocks: Number of ``ResBlock`` modules appended to the trunk.
        num_outputs: Number of output heads / gate modules.
        interval: A gate is active on every ``interval``-th block index.
        dtype, device: Forwarded to every created layer.
        operations: Namespace providing ``Conv2d``, ``Linear`` and ``GroupNorm``.

    Raises:
        ValueError: If a geometry argument is not positive, or if one patch
            covers more than one latent cell but not a whole number of them.
    """

    def __init__(
        self,
        latent_channels: int,
        hidden_dim: int = 512,
        out_dim: int = 1536,
        patch_size: int = 16,
        sr_scale: int = 4,
        latent_spatial_down_factor: int = 8,
        num_res_blocks: int = 4,
        num_outputs: int = 7,
        interval: int = 2,
        dtype=None, device=None, operations=None,
    ):
        super().__init__()
        if patch_size <= 0 or sr_scale <= 0 or latent_spatial_down_factor <= 0:
            raise ValueError(
                "patch_size, sr_scale and latent_spatial_down_factor must be positive, got "
                f"{patch_size}, {sr_scale}, {latent_spatial_down_factor}"
            )

        self.latent_channels = latent_channels
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.sr_scale = sr_scale
        self.latent_spatial_down_factor = latent_spatial_down_factor
        self.num_outputs = num_outputs
        self.interval = interval

        # Extent of one latent cell, in the same units as patch_size.
        latent_stride = sr_scale * latent_spatial_down_factor
        self.z_to_patch_ratio = latent_stride / patch_size
        if latent_stride >= patch_size:
            # Ratio >= 1: the latent is resampled to the patch grid, no folding.
            self.latent_fold_factor = 0
            latent_proj_in_ch = latent_channels
        else:
            # Ratio < 1: fold an f x f neighbourhood of latent cells into channels.
            # Exact integer arithmetic is used on purpose: the float round-trip
            # int(1 / r) * r == 1.0 rejects valid factors such as 49.
            fold_factor, remainder = divmod(patch_size, latent_stride)
            if remainder:
                raise ValueError(
                    f"patch_size ({patch_size}) must be a multiple of "
                    f"sr_scale * latent_spatial_down_factor ({latent_stride})"
                )
            fold_factor = int(fold_factor)
            self.latent_fold_factor = fold_factor
            latent_proj_in_ch = latent_channels * fold_factor * fold_factor

        layers = [
            operations.Conv2d(latent_proj_in_ch, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
        ]
        layers.extend(
            ResBlock(hidden_dim, dtype=dtype, device=device, operations=operations)
            for _ in range(num_res_blocks)
        )
        self.latent_proj = nn.Sequential(*layers)

        self.output_heads = nn.ModuleList(
            [operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) for _ in range(num_outputs)]
        )
        self.gate_modules = nn.ModuleList(
            [SigmaAwareGatePerTokenPerDim(out_dim, dtype=dtype, device=device, operations=operations)
             for _ in range(num_outputs)]
        )

    def is_gate_active(self, block_idx: int) -> bool:
        """Return True when ``block_idx`` is a multiple of ``interval``."""
        return block_idx % self.interval == 0

    def output_index(self, block_idx: int) -> int:
        """Map a block index to the index of its output head / gate module."""
        return block_idx // self.interval

    def gate(self, x: torch.Tensor, lq_feature: torch.Tensor, sigma: torch.Tensor, out_idx: int) -> torch.Tensor:
        """Apply gate module ``out_idx`` to ``x`` and ``lq_feature``."""
        return self.gate_modules[out_idx](x, lq_feature, sigma)

    def _align_latent_to_patch_grid(self, lq_latent: torch.Tensor, pH: int, pW: int) -> torch.Tensor:
        """Bring the latent onto a ``pH x pW`` grid and run the conv trunk.

        Returns the trunk output, a 4-D tensor with ``hidden_dim`` channels on
        a ``pH x pW`` grid.
        """
        B, z_dim = lq_latent.shape[:2]
        if self.z_to_patch_ratio >= 1:
            # Nearest-neighbour resample only when the grid does not already match.
            if lq_latent.shape[2] != pH or lq_latent.shape[3] != pW:
                z_aligned = F.interpolate(lq_latent, size=(pH, pW), mode="nearest")
            else:
                z_aligned = lq_latent
        else:
            f = self.latent_fold_factor
            zH_expected, zW_expected = pH * f, pW * f
            if lq_latent.shape[2] != zH_expected or lq_latent.shape[3] != zW_expected:
                lq_latent = F.interpolate(lq_latent, size=(zH_expected, zW_expected), mode="nearest")
            # Space-to-depth: (B, C, pH*f, pW*f) -> (B, C*f*f, pH, pW).
            z_aligned = lq_latent.reshape(B, z_dim, pH, f, pW, f).permute(0, 1, 3, 5, 2, 4)
            z_aligned = z_aligned.reshape(B, z_dim * f * f, pH, pW)
        return self.latent_proj(z_aligned)

    def forward(self, lq_latent: torch.Tensor, target_pH: int, target_pW: int) -> List[torch.Tensor]:
        """Return one ``(B, target_pH * target_pW, out_dim)`` tensor per output head."""
        feat = self._align_latent_to_patch_grid(lq_latent, target_pH, target_pW)
        # (B, C, H, W) -> (B, H*W, C), row-major over the patch grid.
        tokens = feat.flatten(2).transpose(1, 2).contiguous()
        return [head(tokens) for head in self.output_heads]
|
|
|
|
|
|
class PidNet(PixDiT_T2I):
    """PixDiT_T2I + LQ injection.

    An ``LQProjection2D`` branch turns the LQ latent into one feature tensor
    per gated patch block; a sigma-aware gate adds it to the patch-block input
    on every ``lq_interval``-th block. RoPE tables are built with a reference
    grid derived from ``rope_ref_h`` / ``rope_ref_w``.

    Args:
        lq_latent_channels: Channel count expected of ``lq_latent``.
        lq_hidden_dim: Conv-trunk width of the LQ projection.
        lq_num_res_blocks: Number of ResBlocks in the LQ projection trunk.
        lq_interval: Inject before every ``lq_interval``-th patch block.
        sr_scale: Super-resolution factor forwarded to the LQ projection.
        latent_spatial_down_factor: Forwarded to the LQ projection.
        rope_ref_h, rope_ref_w: NTK reference resolution in pixel units.
        image_model: Accepted but not used by this class.
        dtype, device, operations: Forwarded to the parent and the LQ branch.
        **pixdit_kwargs: Forwarded unchanged to ``PixDiT_T2I.__init__``.
    """

    def __init__(
        self,
        lq_latent_channels: int = 16,
        lq_hidden_dim: int = 512,
        lq_num_res_blocks: int = 4,
        lq_interval: int = 2,
        sr_scale: int = 4,
        latent_spatial_down_factor: int = 8,
        rope_ref_h: int = 1024,  # NTK ref resolution in PIXEL units: 1024px / patch=16 -> grid_ref=64.
        rope_ref_w: int = 1024,
        image_model=None,
        dtype=None, device=None, operations=None,
        **pixdit_kwargs,
    ):
        super().__init__(dtype=dtype, device=device, operations=operations, **pixdit_kwargs)

        # Reference grid size in patch units.
        self.rope_ref_grid_h = rope_ref_h // self.patch_size
        self.rope_ref_grid_w = rope_ref_w // self.patch_size

        # Parent's PiTBlocks were built with plain RoPE — swap in NTK-aware.
        def _pit_rope_fn(head_dim, h, w, device=None, dtype=torch.float32, **rope_opts):
            return precompute_freqs_cis_2d(
                head_dim, h, w,
                ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w,
                device=device, dtype=dtype, **rope_opts,
            )

        for blk in self.pixel_blocks:
            blk._rope_fn = _pit_rope_fn

        # Ceil division: one LQ output for every gated patch block.
        num_lq_outputs = (self.patch_depth + lq_interval - 1) // lq_interval
        self.lq_proj = LQProjection2D(
            latent_channels=lq_latent_channels,
            hidden_dim=lq_hidden_dim,
            out_dim=self.hidden_size,
            patch_size=self.patch_size,
            sr_scale=sr_scale,
            latent_spatial_down_factor=latent_spatial_down_factor,
            num_res_blocks=lq_num_res_blocks,
            num_outputs=num_lq_outputs,
            interval=lq_interval,
            dtype=dtype,
            device=device,
            operations=operations,
        )

    def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
        """Build the patch-level RoPE table using the stored reference grid.

        NOTE(review): presumably overrides a parent hook of the same name —
        confirm against PixDiT_T2I.
        """
        return precompute_freqs_cis_2d(
            self.hidden_size // self.num_groups,
            height, width,
            ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w,
            device=device, dtype=dtype, **rope_opts,
        )

    def _pre_patch_block(self, s, i, pid_lq_features, pid_degrade_sigma, **kwargs):
        """Return ``s`` with the LQ feature for block ``i`` gated in, if any.

        ``s`` is returned unchanged when block ``i`` is not a gated block or
        when no feature exists for it.
        NOTE(review): presumably invoked by the parent before patch block
        ``i`` — confirm against PixDiT_T2I.
        """
        if not self.lq_proj.is_gate_active(i):
            return s
        out_idx = self.lq_proj.output_index(i)
        if out_idx >= len(pid_lq_features):
            return s
        return self.lq_proj.gate(s, pid_lq_features[out_idx], pid_degrade_sigma, out_idx)

    def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options=None, lq_latent=None, degrade_sigma=None, **kwargs):
        """Project the LQ latent and run the parent forward with it attached.

        Raises:
            ValueError: If ``lq_latent`` or ``degrade_sigma`` is missing, or if
                ``lq_latent`` has the wrong number of channels.
        """
        if lq_latent is None:
            raise ValueError("PidNet requires lq_latent — attach via PiDConditioning")
        if degrade_sigma is None:
            # Fail with a clear message instead of AttributeError on None.to(...).
            raise ValueError("PidNet requires degrade_sigma alongside lq_latent")
        if transformer_options is None:
            # Fresh dict per call; a `{}` default would be shared across calls.
            transformer_options = {}

        expected_c = self.lq_proj.latent_channels
        if lq_latent.shape[1] != expected_c:
            raise ValueError(
                f"Input latent has {lq_latent.shape[1]} channels, this model variant expects {expected_c}. "
                "Flux1/SD3 = 16 channels, Flux2 = 128 channels."
            )

        B = x.shape[0]
        # Patch-grid size of the input.
        Hs = x.shape[2] // self.patch_size
        Ws = x.shape[3] // self.patch_size

        # Flatten sigma to 1-D float32 on x's device; a single value is
        # replicated across the batch.
        degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
        if degrade_sigma.numel() == 1 and B > 1:
            degrade_sigma = degrade_sigma.expand(B).contiguous()

        lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws)

        return super()._forward(
            x, timesteps,
            context=context, attention_mask=attention_mask,
            transformer_options=transformer_options,
            pid_lq_features=lq_features,
            pid_degrade_sigma=degrade_sigma,
            **kwargs,
        )
|