remove arbitrary nn.init instances

This commit is contained in:
silveroxides 2026-03-01 11:39:38 +01:00
parent ed3090319c
commit 3419068c94

View File

@ -887,19 +887,6 @@ class PixelResBlock(nn.Module):
nn.SiLU(),
operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device),
)
self._init_weights()
def _init_weights(self):
for m in self.mlp:
if hasattr(m, 'weight'):
nn.init.kaiming_normal_(m.weight, nonlinearity="linear")
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias, 0)
# Zero-init modulation → identity at init
if hasattr(self.adaLN_modulation[-1], 'weight'):
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
if hasattr(self.adaLN_modulation[-1], 'bias') and self.adaLN_modulation[-1].bias is not None:
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
@ -915,10 +902,6 @@ class DCTFinalLayer(nn.Module):
super().__init__()
self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device)
if hasattr(self.linear, 'weight'):
nn.init.constant_(self.linear.weight, 0)
if hasattr(self.linear, 'bias') and self.linear.bias is not None:
nn.init.constant_(self.linear.bias, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(self.norm_final(x))
@ -949,13 +932,10 @@ class SimpleMLPAdaLN(nn.Module):
operations=None,
):
super().__init__()
self.dtype = dtype
# Project backbone hidden state → per-patch conditioning
self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device)
if hasattr(self.cond_embed, 'weight'):
nn.init.xavier_uniform_(self.cond_embed.weight)
if hasattr(self.cond_embed, 'bias') and self.cond_embed.bias is not None:
nn.init.constant_(self.cond_embed.bias, 0)
# Input projection with DCT positional encoding
self.input_embedder = NerfEmbedder(
@ -978,7 +958,7 @@ class SimpleMLPAdaLN(nn.Module):
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
# x: [B*N, 1, P^2*C], c: [B*N, dim]
original_dtype = x.dtype
weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, 'weight') else x.dtype
weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype)
x = self.input_embedder(x) # [B*N, 1, model_channels]
y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels]
x = x.to(weight_dtype)