mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-18 23:55:08 +08:00
Not necessary for this model anymore
This commit is contained in:
parent
c76645f9bc
commit
7250d013df
@ -31,7 +31,7 @@ class ConvBNAct(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype)
|
||||
self.bn = operations.BatchNorm2d(oc, device=device, dtype=dtype)
|
||||
self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype)
|
||||
self.act = nn.ReLU() if use_act else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
31
comfy/ops.py
31
comfy/ops.py
@ -483,34 +483,6 @@ class disable_weight_init:
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if self.weight is not None:
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
else:
|
||||
weight = None
|
||||
bias = None
|
||||
offload_stream = None
|
||||
|
||||
x = torch.nn.functional.batch_norm(
|
||||
input,
|
||||
comfy.model_management.cast_to(self.running_mean, dtype=input.dtype, device=input.device),
|
||||
comfy.model_management.cast_to(self.running_var, dtype=input.dtype, device=input.device),
|
||||
weight, bias, self.training or not self.track_running_stats,
|
||||
self.momentum, self.eps
|
||||
)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
@ -625,9 +597,6 @@ class manual_cast(disable_weight_init):
|
||||
class Embedding(disable_weight_init.Embedding):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class BatchNorm2d(disable_weight_init.BatchNorm2d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
|
||||
def fp8_linear(self, input):
|
||||
"""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user