"""PiD — Pixel Diffusion Decoder. Decodes a Flux/SD3/Flux2/Z-Image latent directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I body + LQ projection branch injected before each MMDiT patch block. """ from typing import List import torch import torch.nn as nn import torch.nn.functional as F from .model import PixDiT_T2I from .modules import precompute_freqs_cis_2d class SigmaAwareGatePerTokenPerDim(nn.Module): """gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq. Trained init gives ~0.88 gate at sigma=0, ~0.05 at sigma=1. """ def __init__(self, dim: int, dtype=None, device=None, operations=None): super().__init__() self.content_proj = operations.Linear(dim * 2, dim, dtype=dtype, device=device) self.log_alpha = nn.Parameter(torch.empty((), dtype=dtype, device=device)) def forward(self, x: torch.Tensor, lq: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: content_logit = self.content_proj(torch.cat([x, lq], dim=-1)) # log_alpha is a raw nn.Parameter -> doesn't auto-cast under dynamic VRAM. log_alpha = self.log_alpha.to(device=x.device, dtype=torch.float32) sigma_offset = -log_alpha.exp() * sigma.float().view(-1, 1, 1) gate = torch.sigmoid(content_logit + sigma_offset) return x + (gate * lq).to(x.dtype) class ResBlock(nn.Module): """Pre-activation ResNet block: GN -> SiLU -> Conv -> GN -> SiLU -> Conv + skip.""" def __init__(self, channels: int, num_groups: int = 4, dtype=None, device=None, operations=None): super().__init__() self.block = nn.Sequential( operations.GroupNorm(num_groups, channels, dtype=dtype, device=device), nn.SiLU(), operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device), operations.GroupNorm(num_groups, channels, dtype=dtype, device=device), nn.SiLU(), operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device), ) def forward(self, x: torch.Tensor) -> torch.Tensor: return x + self.block(x) class LQProjection2D(nn.Module): """LQ latent -> per-block patch-aligned features for controlnet-style injection.""" def __init__( self, latent_channels: int, hidden_dim: int = 512, out_dim: int = 1536, patch_size: int = 16, sr_scale: int = 4, latent_spatial_down_factor: int = 8, num_res_blocks: int = 4, num_outputs: int = 7, interval: int = 2, dtype=None, device=None, operations=None, ): super().__init__() self.latent_channels = latent_channels self.hidden_dim = hidden_dim self.out_dim = out_dim self.patch_size = patch_size self.sr_scale = sr_scale self.latent_spatial_down_factor = latent_spatial_down_factor self.num_outputs = num_outputs self.interval = interval z_to_patch_ratio = (sr_scale * latent_spatial_down_factor) / patch_size self.z_to_patch_ratio = z_to_patch_ratio if z_to_patch_ratio >= 1: self.latent_fold_factor = 0 latent_proj_in_ch = latent_channels else: fold_factor = int(1 / z_to_patch_ratio) assert fold_factor * z_to_patch_ratio == 1.0 self.latent_fold_factor = fold_factor latent_proj_in_ch = latent_channels * fold_factor * fold_factor layers = [ operations.Conv2d(latent_proj_in_ch, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device), nn.SiLU(), operations.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device), ] for _ in range(num_res_blocks): layers.append(ResBlock(hidden_dim, dtype=dtype, device=device, operations=operations)) self.latent_proj = nn.Sequential(*layers) self.output_heads = nn.ModuleList( [operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) for _ in range(num_outputs)] ) self.gate_modules = nn.ModuleList( [SigmaAwareGatePerTokenPerDim(out_dim, dtype=dtype, device=device, operations=operations) for _ in range(num_outputs)] ) def is_gate_active(self, block_idx: int) -> bool: return block_idx % self.interval == 0 def output_index(self, block_idx: int) -> int: return block_idx // self.interval def gate(self, x: torch.Tensor, lq_feature: torch.Tensor, sigma: torch.Tensor, out_idx: int) -> torch.Tensor: return self.gate_modules[out_idx](x, lq_feature, sigma) def _align_latent_to_patch_grid(self, lq_latent: torch.Tensor, pH: int, pW: int) -> torch.Tensor: B, z_dim = lq_latent.shape[:2] if self.z_to_patch_ratio >= 1: if lq_latent.shape[2] != pH or lq_latent.shape[3] != pW: z_aligned = F.interpolate(lq_latent, size=(pH, pW), mode="nearest") else: z_aligned = lq_latent else: f = self.latent_fold_factor zH_expected, zW_expected = pH * f, pW * f if lq_latent.shape[2] != zH_expected or lq_latent.shape[3] != zW_expected: lq_latent = F.interpolate(lq_latent, size=(zH_expected, zW_expected), mode="nearest") z_aligned = lq_latent.reshape(B, z_dim, pH, f, pW, f).permute(0, 1, 3, 5, 2, 4) z_aligned = z_aligned.reshape(B, z_dim * f * f, pH, pW) return self.latent_proj(z_aligned) def forward(self, lq_latent: torch.Tensor, target_pH: int, target_pW: int) -> List[torch.Tensor]: feat = self._align_latent_to_patch_grid(lq_latent, target_pH, target_pW) B, C, H, W = feat.shape tokens = feat.permute(0, 2, 3, 1).contiguous().view(B, H * W, C) return [head(tokens) for head in self.output_heads] class PidNet(PixDiT_T2I): """PixDiT_T2I + LQ injection (one sigma-gated feature inserted before each patch block).""" def __init__( self, lq_latent_channels: int = 16, lq_hidden_dim: int = 512, lq_num_res_blocks: int = 4, lq_interval: int = 2, sr_scale: int = 4, latent_spatial_down_factor: int = 8, rope_ref_h: int = 1024, # NTK ref resolution in PIXEL units: 1024px / patch=16 -> grid_ref=64. rope_ref_w: int = 1024, image_model=None, dtype=None, device=None, operations=None, **pixdit_kwargs, ): super().__init__(dtype=dtype, device=device, operations=operations, **pixdit_kwargs) self.rope_ref_grid_h = rope_ref_h // self.patch_size self.rope_ref_grid_w = rope_ref_w // self.patch_size # Parent's PiTBlocks were built with plain RoPE — swap in NTK-aware. def _pit_rope_fn(head_dim, h, w, device=None, dtype=torch.float32, **rope_opts): return precompute_freqs_cis_2d(head_dim, h, w, ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w, device=device, dtype=dtype, **rope_opts) for blk in self.pixel_blocks: blk._rope_fn = _pit_rope_fn num_lq_outputs = (self.patch_depth + lq_interval - 1) // lq_interval self.lq_proj = LQProjection2D( latent_channels=lq_latent_channels, hidden_dim=lq_hidden_dim, out_dim=self.hidden_size, patch_size=self.patch_size, sr_scale=sr_scale, latent_spatial_down_factor=latent_spatial_down_factor, num_res_blocks=lq_num_res_blocks, num_outputs=num_lq_outputs, interval=lq_interval, dtype=dtype, device=device, operations=operations, ) def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts): return precompute_freqs_cis_2d( self.hidden_size // self.num_groups, height, width, ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w, device=device, dtype=dtype, **rope_opts, ) def _pre_patch_block(self, s, i, pid_lq_features, pid_degrade_sigma, **kwargs): if not self.lq_proj.is_gate_active(i): return s out_idx = self.lq_proj.output_index(i) if out_idx >= len(pid_lq_features): return s return self.lq_proj.gate(s, pid_lq_features[out_idx], pid_degrade_sigma, out_idx) def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, lq_latent=None, degrade_sigma=None, **kwargs): if lq_latent is None: raise ValueError("PidNet requires lq_latent — attach via PiDConditioning") expected_c = self.lq_proj.latent_channels if lq_latent.shape[1] != expected_c: raise ValueError( f"Input latent has {lq_latent.shape[1]} channels, this model variant expects {expected_c}. " f"Flux1/SD3 = 16 channels, Flux2 = 128 channels." ) B = x.shape[0] Hs = x.shape[2] // self.patch_size Ws = x.shape[3] // self.patch_size degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1) if degrade_sigma.numel() == 1 and B > 1: degrade_sigma = degrade_sigma.expand(B).contiguous() lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws) return super()._forward( x, timesteps, context=context, attention_mask=attention_mask, transformer_options=transformer_options, pid_lq_features=lq_features, pid_degrade_sigma=degrade_sigma, **kwargs, )