mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 05:23:37 +08:00
Apply CodeRabbit suggestion for the BatchNorm2d cast-weights forward
This commit is contained in:
parent
538cdab031
commit
0a5a986bcf
@ -444,9 +444,9 @@ class disable_weight_init:
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
running_mean = self.running_mean.to(weight.dtype) if self.running_mean is not None else None
|
||||
running_var = self.running_var.to(weight.dtype) if self.running_var is not None else None
|
||||
x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, False, self.momentum, self.eps)
|
||||
running_mean = self.running_mean.to(device=input.device, dtype=weight.dtype) if self.running_mean is not None else None
|
||||
running_var = self.running_var.to(device=input.device, dtype=weight.dtype) if self.running_var is not None else None
|
||||
x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, self.training, self.momentum, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user