mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-24 01:12:37 +08:00
spacing
This commit is contained in:
parent
8ed12509f0
commit
fd79579fb7
32
comfy/ops.py
32
comfy/ops.py
@ -438,24 +438,24 @@ class disable_weight_init:
|
|||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
    """``torch.nn.BatchNorm2d`` with optional on-the-fly weight casting.

    When comfy weight casting is active (or per-tensor weight/bias hook
    functions are installed), the forward pass casts weight, bias and the
    running statistics to the working dtype before calling the functional
    batch-norm; otherwise it defers to the stock implementation.
    """

    def reset_parameters(self):
        # Intentionally a no-op: parameters are populated externally
        # (e.g. loaded from a checkpoint), so skip default initialization.
        return None

    def forward_comfy_cast_weights(self, input):
        # Cast weight/bias for this input, possibly through an offload
        # stream that must be handed back after the op completes.
        weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
        mean = None if self.running_mean is None else self.running_mean.to(weight.dtype)
        var = None if self.running_var is None else self.running_var.to(weight.dtype)
        # training=False: always normalize with the (cast) running stats.
        result = torch.nn.functional.batch_norm(
            input, mean, var, weight, bias, False, self.momentum, self.eps
        )
        uncast_bias_weight(self, weight, bias, offload_stream)
        return result

    def forward(self, *args, **kwargs):
        run_every_op()
        wants_cast_path = (
            self.comfy_cast_weights
            or len(self.weight_function) > 0
            or len(self.bias_function) > 0
        )
        if wants_cast_path:
            return self.forward_comfy_cast_weights(*args, **kwargs)
        return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
|
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user