From 7250d013dfac0269b16d30a31b77242a4b7b3672 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 5 Mar 2026 15:20:42 +0200 Subject: [PATCH] Not necessary for this model anymore --- comfy/ldm/rt_detr/rtdetr_v4.py | 2 +- comfy/ops.py | 31 ------------------------------- 2 files changed, 1 insertion(+), 32 deletions(-) diff --git a/comfy/ldm/rt_detr/rtdetr_v4.py b/comfy/ldm/rt_detr/rtdetr_v4.py index 3233dbdf6..426d3e6c3 100644 --- a/comfy/ldm/rt_detr/rtdetr_v4.py +++ b/comfy/ldm/rt_detr/rtdetr_v4.py @@ -31,7 +31,7 @@ class ConvBNAct(nn.Module): super().__init__() self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype) - self.bn = operations.BatchNorm2d(oc, device=device, dtype=dtype) + self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype) self.act = nn.ReLU() if use_act else nn.Identity() def forward(self, x): diff --git a/comfy/ops.py b/comfy/ops.py index 130dad74e..3752ed395 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -483,34 +483,6 @@ class disable_weight_init: else: return super().forward(*args, **kwargs) - class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp): - def reset_parameters(self): - return None - - def forward_comfy_cast_weights(self, input): - if self.weight is not None: - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) - else: - weight = None - bias = None - offload_stream = None - - x = torch.nn.functional.batch_norm( - input, - comfy.model_management.cast_to(self.running_mean, dtype=input.dtype, device=input.device), - comfy.model_management.cast_to(self.running_var, dtype=input.dtype, device=input.device), - weight, bias, self.training or not self.track_running_stats, - self.momentum, self.eps - ) - uncast_bias_weight(self, weight, bias, offload_stream) - return x - - def forward(self, *args, **kwargs): - run_every_op() - if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: - return self.forward_comfy_cast_weights(*args, **kwargs) - else: - return super().forward(*args, **kwargs) class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp): def reset_parameters(self): @@ -625,9 +597,6 @@ class manual_cast(disable_weight_init): class Embedding(disable_weight_init.Embedding): comfy_cast_weights = True - class BatchNorm2d(disable_weight_init.BatchNorm2d): - comfy_cast_weights = True - def fp8_linear(self, input): """