Merge pull request #3 from silveroxides/zeta-x0-dino-class-patch

Zeta x0 dino class patch
This commit is contained in:
Lodestone 2026-03-01 17:54:44 +07:00 committed by GitHub
commit f36c967c82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,6 +15,7 @@ from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.patcher_extension
import comfy.utils
from comfy.ldm.chroma_radiance.layers import NerfEmbedder
def invert_slices(slices, length):
@ -868,88 +869,24 @@ def _modulate_shift_scale(x, shift, scale):
return x * (1 + scale) + shift
class NerfEmbedder(nn.Module):
"""
Combines input pixel features with 2D DCT-like positional encodings before
projecting to the decoder hidden size.
Input: [B, P^2, C]
Output: [B, P^2, hidden_size]
"""
def __init__(self, in_channels: int, hidden_size_input: int, max_freqs: int):
super().__init__()
self.max_freqs = max_freqs
self.hidden_size_input = hidden_size_input
self.embedder = nn.Sequential(
nn.Linear(in_channels + max_freqs ** 2, hidden_size_input)
)
@lru_cache(maxsize=4)
def fetch_pos(self, patch_size: int, device, dtype):
"""Generates and caches 2D DCT-like positional embeddings."""
pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
pos_x = pos_x.reshape(-1, 1, 1)
pos_y = pos_y.reshape(-1, 1, 1)
freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device)
freqs_x = freqs[None, :, None]
freqs_y = freqs[None, None, :]
coeffs = (1 + freqs_x * freqs_y) ** -1
dct_x = torch.cos(pos_x * freqs_x * torch.pi)
dct_y = torch.cos(pos_y * freqs_y * torch.pi)
dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2)
return dct
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
B, P2, C = inputs.shape
original_dtype = inputs.dtype
with torch.autocast("cuda", enabled=False):
patch_size = int(P2 ** 0.5)
inputs = inputs.float()
dct = self.fetch_pos(patch_size, inputs.device, torch.float32)
dct = dct.expand(B, -1, -1)
inputs = torch.cat([inputs, dct], dim=-1)
inputs = self.embedder.float()(inputs)
return inputs.to(original_dtype)
class PixelResBlock(nn.Module):
"""
Residual block with AdaLN modulation, zero-initialised so it starts as
an identity at the beginning of training.
"""
def __init__(self, channels: int):
def __init__(self, channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.in_ln = nn.LayerNorm(channels, eps=1e-6)
self.in_ln = operations.LayerNorm(channels, eps=1e-6, dtype=dtype, device=device)
self.mlp = nn.Sequential(
nn.Linear(channels, channels, bias=True),
operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
nn.SiLU(),
nn.Linear(channels, channels, bias=True),
operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(channels, 3 * channels, bias=True),
operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device),
)
self._init_weights()
def _init_weights(self):
for m in self.mlp:
if isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, nonlinearity="linear")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
# Zero-init modulation → identity at init
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
@ -961,12 +898,10 @@ class PixelResBlock(nn.Module):
class DCTFinalLayer(nn.Module):
"""Zero-initialised output projection (adopted from DiT)."""
def __init__(self, model_channels: int, out_channels: int):
def __init__(self, model_channels: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(model_channels, out_channels, bias=True)
nn.init.constant_(self.linear.weight, 0)
nn.init.constant_(self.linear.bias, 0)
self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(self.norm_final(x))
@ -992,33 +927,38 @@ class SimpleMLPAdaLN(nn.Module):
z_channels: int,
num_res_blocks: int,
max_freqs: int = 8,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dtype = dtype
# Project backbone hidden state → per-patch conditioning
self.cond_embed = nn.Linear(z_channels, model_channels)
nn.init.xavier_uniform_(self.cond_embed.weight)
nn.init.constant_(self.cond_embed.bias, 0)
self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device)
# Input projection with DCT positional encoding
self.input_embedder = NerfEmbedder(
in_channels=in_channels,
hidden_size_input=model_channels,
max_freqs=max_freqs,
dtype=dtype,
device=device,
operations=operations,
)
# Residual blocks
self.res_blocks = nn.ModuleList([
PixelResBlock(model_channels) for _ in range(num_res_blocks)
PixelResBlock(model_channels, dtype=dtype, device=device, operations=operations) for _ in range(num_res_blocks)
])
# Output projection
self.final_layer = DCTFinalLayer(model_channels, out_channels)
self.final_layer = DCTFinalLayer(model_channels, out_channels, dtype=dtype, device=device, operations=operations)
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
# x: [B*N, 1, P^2*C], c: [B*N, dim]
original_dtype = x.dtype
weight_dtype = self.cond_embed.weight.dtype
weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype)
x = self.input_embedder(x) # [B*N, 1, model_channels]
y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels]
x = x.to(weight_dtype)
@ -1077,6 +1017,9 @@ class NextDiTPixelSpace(NextDiT):
z_channels=dim,
num_res_blocks=decoder_num_res_blocks,
max_freqs=decoder_max_freqs,
dtype=kwargs.get("dtype"),
device=kwargs.get("device"),
operations=kwargs.get("operations"),
)
if use_x0: