From fd79579fb73a236c1b08e9f3d389e3c6064fa68d Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Wed, 4 Mar 2026 00:58:07 +0200
Subject: [PATCH] spacing

---
 comfy/ops.py | 32 ++++++++++++++++----------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/comfy/ops.py b/comfy/ops.py
index 9a658f269..7bc0b40f1 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -438,24 +438,24 @@ class disable_weight_init:
         else:
             return super().forward(*args, **kwargs)
 
-        class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
-            def reset_parameters(self):
-                return None
+    class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
+        def reset_parameters(self):
+            return None
 
-            def forward_comfy_cast_weights(self, input):
-                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
-                running_mean = self.running_mean.to(weight.dtype) if self.running_mean is not None else None
-                running_var = self.running_var.to(weight.dtype) if self.running_var is not None else None
-                x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, False, self.momentum, self.eps)
-                uncast_bias_weight(self, weight, bias, offload_stream)
-                return x
+        def forward_comfy_cast_weights(self, input):
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            running_mean = self.running_mean.to(weight.dtype) if self.running_mean is not None else None
+            running_var = self.running_var.to(weight.dtype) if self.running_var is not None else None
+            x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, False, self.momentum, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 
-            def forward(self, *args, **kwargs):
-                run_every_op()
-                if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
-                    return self.forward_comfy_cast_weights(*args, **kwargs)
-                else:
-                    return super().forward(*args, **kwargs)
+        def forward(self, *args, **kwargs):
+            run_every_op()
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+                return self.forward_comfy_cast_weights(*args, **kwargs)
+            else:
+                return super().forward(*args, **kwargs)
 
     class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
         def reset_parameters(self):