mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-27 03:53:27 +08:00
Not necessary for this model anymore
This commit is contained in:
parent
c76645f9bc
commit
7250d013df
@ -31,7 +31,7 @@ class ConvBNAct(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype)
|
self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype)
|
||||||
self.bn = operations.BatchNorm2d(oc, device=device, dtype=dtype)
|
self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype)
|
||||||
self.act = nn.ReLU() if use_act else nn.Identity()
|
self.act = nn.ReLU() if use_act else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|||||||
31
comfy/ops.py
31
comfy/ops.py
@ -483,34 +483,6 @@ class disable_weight_init:
|
|||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
|
|
||||||
def reset_parameters(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
|
||||||
if self.weight is not None:
|
|
||||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
|
||||||
else:
|
|
||||||
weight = None
|
|
||||||
bias = None
|
|
||||||
offload_stream = None
|
|
||||||
|
|
||||||
x = torch.nn.functional.batch_norm(
|
|
||||||
input,
|
|
||||||
comfy.model_management.cast_to(self.running_mean, dtype=input.dtype, device=input.device),
|
|
||||||
comfy.model_management.cast_to(self.running_var, dtype=input.dtype, device=input.device),
|
|
||||||
weight, bias, self.training or not self.track_running_stats,
|
|
||||||
self.momentum, self.eps
|
|
||||||
)
|
|
||||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
|
||||||
run_every_op()
|
|
||||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
return super().forward(*args, **kwargs)
|
|
||||||
|
|
||||||
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
|
class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
@ -625,9 +597,6 @@ class manual_cast(disable_weight_init):
|
|||||||
class Embedding(disable_weight_init.Embedding):
|
class Embedding(disable_weight_init.Embedding):
|
||||||
comfy_cast_weights = True
|
comfy_cast_weights = True
|
||||||
|
|
||||||
class BatchNorm2d(disable_weight_init.BatchNorm2d):
|
|
||||||
comfy_cast_weights = True
|
|
||||||
|
|
||||||
|
|
||||||
def fp8_linear(self, input):
|
def fp8_linear(self, input):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user