cascade: remove dead weight init code (#13026)

This weight init process is fully shadowed be the weight load and
doesnt work in dynamic_vram were the weight allocation is deferred.
This commit is contained in:
rattus 2026-03-17 17:59:10 -07:00 committed by GitHub
parent 68d542cc06
commit cad24ce262
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -136,16 +136,7 @@ class ResBlock(nn.Module):
ops.Linear(c_hidden, c),
)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
# Init weights
def _basic_init(module):
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False)
def _norm(self, x, norm):
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)