Apply CodeRabbit suggestion for BatchNorm2d: cast running stats to the input's device and pass self.training to batch_norm

This commit is contained in:
Yousef Rafat 2026-03-04 01:10:31 +02:00
parent 538cdab031
commit 0a5a986bcf

View File

@ -444,9 +444,9 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
    """Batch-norm forward pass with on-the-fly weight casting.

    Casts the layer's weight/bias (via the project's cast_bias_weight helper,
    possibly using an offload stream) and moves the running statistics to the
    input's device and the cast weight's dtype before calling
    torch.nn.functional.batch_norm, honoring the module's training flag.
    """
    weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)

    def _cast_stat(stat):
        # Running stats may be None (track_running_stats=False); otherwise
        # match them to the input's device and the cast weight's dtype.
        return None if stat is None else stat.to(device=input.device, dtype=weight.dtype)

    x = torch.nn.functional.batch_norm(
        input,
        _cast_stat(self.running_mean),
        _cast_stat(self.running_var),
        weight,
        bias,
        self.training,
        self.momentum,
        self.eps,
    )
    # Return the (possibly offloaded) cast weights to their original placement.
    uncast_bias_weight(self, weight, bias, offload_stream)
    return x