This commit is contained in:
Yousef Rafat 2026-03-04 00:58:07 +02:00
parent 8ed12509f0
commit fd79579fb7

View File

@ -438,24 +438,24 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
    """BatchNorm2d variant that can run with weights/bias cast on the fly.

    Part of the `disable_weight_init` op collection: parameter
    initialization is a no-op (weights are loaded externally), and
    `forward` routes through a cast-aware path whenever weight/bias
    casting or patch functions are active.
    """

    def reset_parameters(self):
        # Deliberately skip torch's default (re)initialization;
        # parameters are expected to be loaded from a checkpoint.
        return None

    def forward_comfy_cast_weights(self, input):
        """Apply batch norm using weight/bias cast to match `input`.

        The running statistics are converted to the cast weight dtype,
        and training-mode stat updates are disabled (training=False).
        """
        weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)

        def _as_weight_dtype(stat):
            # Running stats may be absent (track_running_stats=False).
            return None if stat is None else stat.to(weight.dtype)

        out = torch.nn.functional.batch_norm(
            input,
            _as_weight_dtype(self.running_mean),
            _as_weight_dtype(self.running_var),
            weight,
            bias,
            False,  # never update running stats here
            self.momentum,
            self.eps,
        )
        # Return the cast weights (possibly to an offload stream).
        uncast_bias_weight(self, weight, bias, offload_stream)
        return out

    def forward(self, *args, **kwargs):
        run_every_op()
        needs_cast_path = (
            self.comfy_cast_weights
            or len(self.weight_function) > 0
            or len(self.bias_function) > 0
        )
        if needs_cast_path:
            return self.forward_comfy_cast_weights(*args, **kwargs)
        return super().forward(*args, **kwargs)
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
def reset_parameters(self):