cascade: remove dead weight init code

This weight init process is fully shadowed by the weight load and
doesn't work in dynamic_vram, where the weight allocation is deferred.
Rattus 2026-03-18 08:12:48 +10:00
parent 8b9d039f26
commit b68a96a436

@@ -136,16 +136,7 @@ class ResBlock(nn.Module):
             ops.Linear(c_hidden, c),
         )
 
-        self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
-
-        # Init weights
-        def _basic_init(module):
-            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
-                torch.nn.init.xavier_uniform_(module.weight)
-                if module.bias is not None:
-                    nn.init.constant_(module.bias, 0)
-
-        self.apply(_basic_init)
+        self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False)
 
     def _norm(self, x, norm):
        return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
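
For context, a minimal sketch (not code from this repository) of why the removed block is dead weight: a plain load_state_dict overwrites whatever the xavier/constant init produced, and under deferred allocation (assumed here to mean parameters built on PyTorch's meta device, which is one way to read the commit message's dynamic_vram) there is no storage for the init to write into until the checkpoint is loaded.

    import torch
    import torch.nn as nn

    # 1) Eager case: the init runs, but loading the checkpoint overwrites every
    #    value it produced, so the init pass is fully shadowed by the weight load.
    layer = nn.Linear(4, 4)
    nn.init.xavier_uniform_(layer.weight)
    checkpoint = {"weight": torch.ones(4, 4), "bias": torch.zeros(4)}
    layer.load_state_dict(checkpoint)
    assert torch.equal(layer.weight, checkpoint["weight"])  # init result is gone

    # 2) Deferred case (assumption): parameters on the meta device have no
    #    storage, so an in-place init cannot write real values (it may even
    #    raise, depending on the PyTorch version); real values only exist once
    #    the checkpoint is loaded.
    with torch.device("meta"):
        lazy = nn.Linear(4, 4)
    print(lazy.weight.is_meta)                     # True -> nothing to initialize
    lazy.load_state_dict(checkpoint, assign=True)  # materialize from checkpoint (PyTorch >= 2.1)
    print(lazy.weight.device)                      # cpu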