diff --git a/comfy/ldm/lightricks/vae/pixel_norm.py b/comfy/ldm/lightricks/vae/pixel_norm.py index 9bc3ea60e..0edf06927 100644 --- a/comfy/ldm/lightricks/vae/pixel_norm.py +++ b/comfy/ldm/lightricks/vae/pixel_norm.py @@ -9,4 +9,5 @@ class PixelNorm(nn.Module): self.eps = eps def forward(self, x): - return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps) + mean_square = torch.mean(x * x, dim=self.dim, keepdim=True) + self.eps + return x * torch.rsqrt(mean_square)