mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-12 17:27:26 +08:00
remove arbitrary nn.init instances
This commit is contained in:
parent
ed3090319c
commit
3419068c94
@ -887,19 +887,6 @@ class PixelResBlock(nn.Module):
|
||||
nn.SiLU(),
|
||||
operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self._init_weights()
|
||||
|
||||
def _init_weights(self):
|
||||
for m in self.mlp:
|
||||
if hasattr(m, 'weight'):
|
||||
nn.init.kaiming_normal_(m.weight, nonlinearity="linear")
|
||||
if hasattr(m, 'bias') and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
# Zero-init modulation → identity at init
|
||||
if hasattr(self.adaLN_modulation[-1], 'weight'):
|
||||
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
|
||||
if hasattr(self.adaLN_modulation[-1], 'bias') and self.adaLN_modulation[-1].bias is not None:
|
||||
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
|
||||
@ -915,10 +902,6 @@ class DCTFinalLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device)
|
||||
if hasattr(self.linear, 'weight'):
|
||||
nn.init.constant_(self.linear.weight, 0)
|
||||
if hasattr(self.linear, 'bias') and self.linear.bias is not None:
|
||||
nn.init.constant_(self.linear.bias, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.linear(self.norm_final(x))
|
||||
@ -949,13 +932,10 @@ class SimpleMLPAdaLN(nn.Module):
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
|
||||
# Project backbone hidden state → per-patch conditioning
|
||||
self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device)
|
||||
if hasattr(self.cond_embed, 'weight'):
|
||||
nn.init.xavier_uniform_(self.cond_embed.weight)
|
||||
if hasattr(self.cond_embed, 'bias') and self.cond_embed.bias is not None:
|
||||
nn.init.constant_(self.cond_embed.bias, 0)
|
||||
|
||||
# Input projection with DCT positional encoding
|
||||
self.input_embedder = NerfEmbedder(
|
||||
@ -978,7 +958,7 @@ class SimpleMLPAdaLN(nn.Module):
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
||||
# x: [B*N, 1, P^2*C], c: [B*N, dim]
|
||||
original_dtype = x.dtype
|
||||
weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, 'weight') else x.dtype
|
||||
weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype)
|
||||
x = self.input_embedder(x) # [B*N, 1, model_channels]
|
||||
y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels]
|
||||
x = x.to(weight_dtype)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user