Update modules.py

This commit is contained in:
kijai 2026-05-13 21:58:16 +03:00
parent 6617e76b1c
commit 719e6facf9

View File

@ -36,8 +36,8 @@ class ResidualConvBlock(nn.Module):
super().__init__() super().__init__()
hidden_channels = hidden_channels if hidden_channels is not None else channels hidden_channels = hidden_channels if hidden_channels is not None else channels
in_norm_layer = operations.GroupNorm(1, channels) if in_norm == "layer_norm" else nn.Identity() in_norm_layer = operations.GroupNorm(1, channels, dtype=dtype, device=device) if in_norm == "layer_norm" else nn.Identity()
hidden_norm_layer = (operations.GroupNorm(max(hidden_channels // 32, 1), hidden_channels) hidden_norm_layer = (operations.GroupNorm(max(hidden_channels // 32, 1), hidden_channels, dtype=dtype, device=device)
if hidden_norm == "group_norm" else nn.Identity()) if hidden_norm == "group_norm" else nn.Identity())
self.layers = nn.Sequential( self.layers = nn.Sequential(