diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index dea98e66d..677749f3b 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -15,6 +15,7 @@ from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope import comfy.patcher_extension import comfy.utils +from comfy.ldm.chroma_radiance.layers import NerfEmbedder def invert_slices(slices, length): @@ -868,88 +869,37 @@ def _modulate_shift_scale(x, shift, scale): return x * (1 + scale) + shift -class NerfEmbedder(nn.Module): - """ - Combines input pixel features with 2D DCT-like positional encodings before - projecting to the decoder hidden size. - - Input: [B, P^2, C] - Output: [B, P^2, hidden_size] - """ - - def __init__(self, in_channels: int, hidden_size_input: int, max_freqs: int): - super().__init__() - self.max_freqs = max_freqs - self.hidden_size_input = hidden_size_input - self.embedder = nn.Sequential( - nn.Linear(in_channels + max_freqs ** 2, hidden_size_input) - ) - - @lru_cache(maxsize=4) - def fetch_pos(self, patch_size: int, device, dtype): - """Generates and caches 2D DCT-like positional embeddings.""" - pos_x = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) - pos_y = torch.linspace(0, 1, patch_size, device=device, dtype=dtype) - pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij") - - pos_x = pos_x.reshape(-1, 1, 1) - pos_y = pos_y.reshape(-1, 1, 1) - - freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs, dtype=dtype, device=device) - freqs_x = freqs[None, :, None] - freqs_y = freqs[None, None, :] - - coeffs = (1 + freqs_x * freqs_y) ** -1 - dct_x = torch.cos(pos_x * freqs_x * torch.pi) - dct_y = torch.cos(pos_y * freqs_y * torch.pi) - dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2) - - return dct - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - B, P2, C = inputs.shape - original_dtype = inputs.dtype - - with torch.autocast("cuda", enabled=False): - patch_size = int(P2 ** 0.5) - inputs = inputs.float() - dct = self.fetch_pos(patch_size, inputs.device, torch.float32) - dct = dct.expand(B, -1, -1) - inputs = torch.cat([inputs, dct], dim=-1) - inputs = self.embedder.float()(inputs) - - return inputs.to(original_dtype) - - class PixelResBlock(nn.Module): """ Residual block with AdaLN modulation, zero-initialised so it starts as an identity at the beginning of training. """ - def __init__(self, channels: int): + def __init__(self, channels: int, dtype=None, device=None, operations=None): super().__init__() - self.in_ln = nn.LayerNorm(channels, eps=1e-6) + self.in_ln = operations.LayerNorm(channels, eps=1e-6, dtype=dtype, device=device) self.mlp = nn.Sequential( - nn.Linear(channels, channels, bias=True), + operations.Linear(channels, channels, bias=True, dtype=dtype, device=device), nn.SiLU(), - nn.Linear(channels, channels, bias=True), + operations.Linear(channels, channels, bias=True, dtype=dtype, device=device), ) self.adaLN_modulation = nn.Sequential( nn.SiLU(), - nn.Linear(channels, 3 * channels, bias=True), + operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device), ) self._init_weights() def _init_weights(self): for m in self.mlp: - if isinstance(m, nn.Linear): + if hasattr(m, 'weight'): nn.init.kaiming_normal_(m.weight, nonlinearity="linear") - if m.bias is not None: + if hasattr(m, 'bias') and m.bias is not None: nn.init.constant_(m.bias, 0) # Zero-init modulation → identity at init - nn.init.constant_(self.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + if hasattr(self.adaLN_modulation[-1], 'weight'): + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + if hasattr(self.adaLN_modulation[-1], 'bias') and self.adaLN_modulation[-1].bias is not None: + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1) @@ -961,12 +911,14 @@ class PixelResBlock(nn.Module): class DCTFinalLayer(nn.Module): """Zero-initialised output projection (adopted from DiT).""" - def __init__(self, model_channels: int, out_channels: int): + def __init__(self, model_channels: int, out_channels: int, dtype=None, device=None, operations=None): super().__init__() - self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6) - self.linear = nn.Linear(model_channels, out_channels, bias=True) - nn.init.constant_(self.linear.weight, 0) - nn.init.constant_(self.linear.bias, 0) + self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device) + if hasattr(self.linear, 'weight'): + nn.init.constant_(self.linear.weight, 0) + if hasattr(self.linear, 'bias') and self.linear.bias is not None: + nn.init.constant_(self.linear.bias, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(self.norm_final(x)) @@ -992,33 +944,41 @@ class SimpleMLPAdaLN(nn.Module): z_channels: int, num_res_blocks: int, max_freqs: int = 8, + dtype=None, + device=None, + operations=None, ): super().__init__() # Project backbone hidden state → per-patch conditioning - self.cond_embed = nn.Linear(z_channels, model_channels) - nn.init.xavier_uniform_(self.cond_embed.weight) - nn.init.constant_(self.cond_embed.bias, 0) + self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device) + if hasattr(self.cond_embed, 'weight'): + nn.init.xavier_uniform_(self.cond_embed.weight) + if hasattr(self.cond_embed, 'bias') and self.cond_embed.bias is not None: + nn.init.constant_(self.cond_embed.bias, 0) # Input projection with DCT positional encoding self.input_embedder = NerfEmbedder( in_channels=in_channels, hidden_size_input=model_channels, max_freqs=max_freqs, + dtype=dtype, + device=device, + operations=operations, ) # Residual blocks self.res_blocks = nn.ModuleList([ - PixelResBlock(model_channels) for _ in range(num_res_blocks) + PixelResBlock(model_channels, dtype=dtype, device=device, operations=operations) for _ in range(num_res_blocks) ]) # Output projection - self.final_layer = DCTFinalLayer(model_channels, out_channels) + self.final_layer = DCTFinalLayer(model_channels, out_channels, dtype=dtype, device=device, operations=operations) def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: # x: [B*N, 1, P^2*C], c: [B*N, dim] original_dtype = x.dtype - weight_dtype = self.cond_embed.weight.dtype + weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, 'weight') else x.dtype x = self.input_embedder(x) # [B*N, 1, model_channels] y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels] x = x.to(weight_dtype) @@ -1077,6 +1037,9 @@ class NextDiTPixelSpace(NextDiT): z_channels=dim, num_res_blocks=decoder_num_res_blocks, max_freqs=decoder_max_freqs, + dtype=kwargs.get("dtype"), + device=kwargs.get("device"), + operations=kwargs.get("operations"), ) if use_x0: