Remove the custom cast-weight BatchNorm2d operation — not necessary for this model anymore

This commit is contained in:
kijai 2026-03-05 15:20:42 +02:00
parent c76645f9bc
commit 7250d013df
2 changed files with 1 addition and 32 deletions

View File

@ -31,7 +31,7 @@ class ConvBNAct(nn.Module):
super().__init__()
self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype)
self.bn = operations.BatchNorm2d(oc, device=device, dtype=dtype)
self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype)
self.act = nn.ReLU() if use_act else nn.Identity()
def forward(self, x):

View File

@ -483,34 +483,6 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
    """BatchNorm2d variant that casts its parameters to the input's
    dtype/device on every forward pass instead of storing them pre-cast.

    NOTE(review): relies on the surrounding ops framework
    (`cast_bias_weight` / `uncast_bias_weight` / `run_every_op`) — behavior
    outside that framework is not defined here.
    """

    def reset_parameters(self):
        # Deliberately a no-op: parameters are expected to be loaded from a
        # checkpoint, so default (re)initialization is skipped.
        return None

    def forward_comfy_cast_weights(self, input):
        # Affine parameters exist only when the layer was built with
        # affine=True; otherwise pass None through to batch_norm.
        if self.weight is None:
            w = b = stream = None
        else:
            w, b, stream = cast_bias_weight(self, input, offloadable=True)

        # Running statistics are cast to match the input each call.
        mean = comfy.model_management.cast_to(
            self.running_mean, dtype=input.dtype, device=input.device)
        var = comfy.model_management.cast_to(
            self.running_var, dtype=input.dtype, device=input.device)

        # Mirrors torch's BatchNorm semantics: use batch statistics while
        # training or when running stats are not tracked.
        use_batch_stats = self.training or not self.track_running_stats
        out = torch.nn.functional.batch_norm(
            input, mean, var, w, b,
            use_batch_stats, self.momentum, self.eps)

        uncast_bias_weight(self, w, b, stream)
        return out

    def forward(self, *args, **kwargs):
        run_every_op()
        needs_cast = (self.comfy_cast_weights
                      or len(self.weight_function) > 0
                      or len(self.bias_function) > 0)
        if needs_cast:
            return self.forward_comfy_cast_weights(*args, **kwargs)
        return super().forward(*args, **kwargs)
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
def reset_parameters(self):
@ -625,9 +597,6 @@ class manual_cast(disable_weight_init):
class Embedding(disable_weight_init.Embedding):
    # Force the cast path unconditionally: forward() dispatches to
    # forward_comfy_cast_weights when this flag is set.
    comfy_cast_weights = True
class BatchNorm2d(disable_weight_init.BatchNorm2d):
    # Force the cast path unconditionally: forward() dispatches to
    # forward_comfy_cast_weights when this flag is set.
    comfy_cast_weights = True
def fp8_linear(self, input):
"""