diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 677749f3b..2105a1894 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -887,19 +887,6 @@ class PixelResBlock(nn.Module): nn.SiLU(), operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device), ) - self._init_weights() - - def _init_weights(self): - for m in self.mlp: - if hasattr(m, 'weight'): - nn.init.kaiming_normal_(m.weight, nonlinearity="linear") - if hasattr(m, 'bias') and m.bias is not None: - nn.init.constant_(m.bias, 0) - # Zero-init modulation → identity at init - if hasattr(self.adaLN_modulation[-1], 'weight'): - nn.init.constant_(self.adaLN_modulation[-1].weight, 0) - if hasattr(self.adaLN_modulation[-1], 'bias') and self.adaLN_modulation[-1].bias is not None: - nn.init.constant_(self.adaLN_modulation[-1].bias, 0) def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1) @@ -915,10 +902,6 @@ class DCTFinalLayer(nn.Module): super().__init__() self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device) - if hasattr(self.linear, 'weight'): - nn.init.constant_(self.linear.weight, 0) - if hasattr(self.linear, 'bias') and self.linear.bias is not None: - nn.init.constant_(self.linear.bias, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(self.norm_final(x)) @@ -949,13 +932,10 @@ class SimpleMLPAdaLN(nn.Module): operations=None, ): super().__init__() + self.dtype = dtype # Project backbone hidden state → per-patch conditioning self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device) - if hasattr(self.cond_embed, 'weight'): - nn.init.xavier_uniform_(self.cond_embed.weight) - if hasattr(self.cond_embed, 'bias') and self.cond_embed.bias is not None: - nn.init.constant_(self.cond_embed.bias, 0) # Input projection with DCT positional encoding self.input_embedder = NerfEmbedder( @@ -978,7 +958,7 @@ class SimpleMLPAdaLN(nn.Module): def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: # x: [B*N, 1, P^2*C], c: [B*N, dim] original_dtype = x.dtype - weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, 'weight') else x.dtype + weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype) x = self.input_embedder(x) # [B*N, 1, model_channels] y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels] x = x.to(weight_dtype)