From 0a5a986bcf0323e1246d98e0b87e2ebb3dfbc246 Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Wed, 4 Mar 2026 01:10:31 +0200
Subject: [PATCH] coderabbit suggestion for batchnorm2d

---
 comfy/ops.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/comfy/ops.py b/comfy/ops.py
index f4e8d9104..62e1bb834 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -444,9 +444,9 @@ class disable_weight_init:
 
         def forward_comfy_cast_weights(self, input):
             weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
-            running_mean = self.running_mean.to(weight.dtype) if self.running_mean is not None else None
-            running_var = self.running_var.to(weight.dtype) if self.running_var is not None else None
-            x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, False, self.momentum, self.eps)
+            running_mean = self.running_mean.to(device=input.device, dtype=weight.dtype) if self.running_mean is not None else None
+            running_var = self.running_var.to(device=input.device, dtype=weight.dtype) if self.running_var is not None else None
+            x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, self.training, self.momentum, self.eps)
             uncast_bias_weight(self, weight, bias, offload_stream)
             return x
 