mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 05:23:37 +08:00
spacing
This commit is contained in:
parent
8ed12509f0
commit
fd79579fb7
32
comfy/ops.py
32
comfy/ops.py
@ -438,24 +438,24 @@ class disable_weight_init:
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
running_mean = self.running_mean.to(weight.dtype) if self.running_mean is not None else None
|
||||
running_var = self.running_var.to(weight.dtype) if self.running_var is not None else None
|
||||
x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, False, self.momentum, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
running_mean = self.running_mean.to(weight.dtype) if self.running_mean is not None else None
|
||||
running_var = self.running_var.to(weight.dtype) if self.running_var is not None else None
|
||||
x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, False, self.momentum, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user