diff --git a/comfy/ops.py b/comfy/ops.py index f4e8d9104..62e1bb834 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -444,9 +444,9 @@ class disable_weight_init: def forward_comfy_cast_weights(self, input): weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) - running_mean = self.running_mean.to(weight.dtype) if self.running_mean is not None else None - running_var = self.running_var.to(weight.dtype) if self.running_var is not None else None - x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, False, self.momentum, self.eps) + running_mean = self.running_mean.to(device=input.device, dtype=weight.dtype) if self.running_mean is not None else None + running_var = self.running_var.to(device=input.device, dtype=weight.dtype) if self.running_var is not None else None + x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, self.training, self.momentum, self.eps) uncast_bias_weight(self, weight, bias, offload_stream) return x