mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-12 09:17:51 +08:00
Merge pull request #3 from silveroxides/zeta-x0-dino-class-patch
Zeta x0 dino class patch
This commit is contained in:
commit
f36c967c82
@ -15,6 +15,7 @@ from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
from comfy.ldm.chroma_radiance.layers import NerfEmbedder
|
||||
|
||||
|
||||
def invert_slices(slices, length):
|
||||
@ -868,88 +869,24 @@ def _modulate_shift_scale(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
|
||||
class NerfEmbedder(nn.Module):
|
||||
"""
|
||||
Combines input pixel features with 2D DCT-like positional encodings before
|
||||
projecting to the decoder hidden size.
|
||||
|
||||
Input: [B, P^2, C]
|
||||
Output: [B, P^2, hidden_size]
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, hidden_size_input: int, max_freqs: int):
|
||||
super().__init__()
|
||||
self.max_freqs = max_freqs
|
||||
self.hidden_size_input = hidden_size_input
|
||||
self.embedder = nn.Sequential(
|
||||
nn.Linear(in_channels + max_freqs ** 2, hidden_size_input)
|
||||
)
|
||||
|
||||
@lru_cache(maxsize=4)
|
||||
def fetch_pos(self, patch_size: int, device, dtype):
|
||||
"""Generates and caches 2D DCT-like positional embeddings."""
|
||||
pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
|
||||
pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype)
|
||||
pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij")
|
||||
|
||||
pos_x = pos_x.reshape(-1, 1, 1)
|
||||
pos_y = pos_y.reshape(-1, 1, 1)
|
||||
|
||||
freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device)
|
||||
freqs_x = freqs[None, :, None]
|
||||
freqs_y = freqs[None, None, :]
|
||||
|
||||
coeffs = (1 + freqs_x * freqs_y) ** -1
|
||||
dct_x = torch.cos(pos_x * freqs_x * torch.pi)
|
||||
dct_y = torch.cos(pos_y * freqs_y * torch.pi)
|
||||
dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2)
|
||||
|
||||
return dct
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
B, P2, C = inputs.shape
|
||||
original_dtype = inputs.dtype
|
||||
|
||||
with torch.autocast("cuda", enabled=False):
|
||||
patch_size = int(P2 ** 0.5)
|
||||
inputs = inputs.float()
|
||||
dct = self.fetch_pos(patch_size, inputs.device, torch.float32)
|
||||
dct = dct.expand(B, -1, -1)
|
||||
inputs = torch.cat([inputs, dct], dim=-1)
|
||||
inputs = self.embedder.float()(inputs)
|
||||
|
||||
return inputs.to(original_dtype)
|
||||
|
||||
|
||||
class PixelResBlock(nn.Module):
|
||||
"""
|
||||
Residual block with AdaLN modulation, zero-initialised so it starts as
|
||||
an identity at the beginning of training.
|
||||
"""
|
||||
|
||||
def __init__(self, channels: int):
|
||||
def __init__(self, channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.in_ln = nn.LayerNorm(channels, eps=1e-6)
|
||||
self.in_ln = operations.LayerNorm(channels, eps=1e-6, dtype=dtype, device=device)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(channels, channels, bias=True),
|
||||
operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
nn.Linear(channels, channels, bias=True),
|
||||
operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(channels, 3 * channels, bias=True),
|
||||
operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self._init_weights()
|
||||
|
||||
def _init_weights(self):
|
||||
for m in self.mlp:
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_normal_(m.weight, nonlinearity="linear")
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
# Zero-init modulation → identity at init
|
||||
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
|
||||
@ -961,12 +898,10 @@ class PixelResBlock(nn.Module):
|
||||
class DCTFinalLayer(nn.Module):
|
||||
"""Zero-initialised output projection (adopted from DiT)."""
|
||||
|
||||
def __init__(self, model_channels: int, out_channels: int):
|
||||
def __init__(self, model_channels: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(model_channels, out_channels, bias=True)
|
||||
nn.init.constant_(self.linear.weight, 0)
|
||||
nn.init.constant_(self.linear.bias, 0)
|
||||
self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.linear(self.norm_final(x))
|
||||
@ -992,33 +927,38 @@ class SimpleMLPAdaLN(nn.Module):
|
||||
z_channels: int,
|
||||
num_res_blocks: int,
|
||||
max_freqs: int = 8,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
|
||||
# Project backbone hidden state → per-patch conditioning
|
||||
self.cond_embed = nn.Linear(z_channels, model_channels)
|
||||
nn.init.xavier_uniform_(self.cond_embed.weight)
|
||||
nn.init.constant_(self.cond_embed.bias, 0)
|
||||
self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device)
|
||||
|
||||
# Input projection with DCT positional encoding
|
||||
self.input_embedder = NerfEmbedder(
|
||||
in_channels=in_channels,
|
||||
hidden_size_input=model_channels,
|
||||
max_freqs=max_freqs,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
# Residual blocks
|
||||
self.res_blocks = nn.ModuleList([
|
||||
PixelResBlock(model_channels) for _ in range(num_res_blocks)
|
||||
PixelResBlock(model_channels, dtype=dtype, device=device, operations=operations) for _ in range(num_res_blocks)
|
||||
])
|
||||
|
||||
# Output projection
|
||||
self.final_layer = DCTFinalLayer(model_channels, out_channels)
|
||||
self.final_layer = DCTFinalLayer(model_channels, out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
# x: [B*N, 1, P^2*C], c: [B*N, dim]
|
||||
original_dtype = x.dtype
|
||||
weight_dtype = self.cond_embed.weight.dtype
|
||||
weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype)
|
||||
x = self.input_embedder(x) # [B*N, 1, model_channels]
|
||||
y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels]
|
||||
x = x.to(weight_dtype)
|
||||
@ -1077,6 +1017,9 @@ class NextDiTPixelSpace(NextDiT):
|
||||
z_channels=dim,
|
||||
num_res_blocks=decoder_num_res_blocks,
|
||||
max_freqs=decoder_max_freqs,
|
||||
dtype=kwargs.get("dtype"),
|
||||
device=kwargs.get("device"),
|
||||
operations=kwargs.get("operations"),
|
||||
)
|
||||
|
||||
if use_x0:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user