From 2ff056431df1fb3287aec3489da408263c0ffb31 Mon Sep 17 00:00:00 2001 From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com> Date: Tue, 15 Apr 2025 22:19:03 +0000 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Speed=20up=20method=20`Pix?= =?UTF-8?q?elNorm.forward`=20by=207%=20To=20optimize=20the=20runtime=20of?= =?UTF-8?q?=20this=20program,=20we=20can=20leverage=20some=20of=20PyTorch'?= =?UTF-8?q?s=20functions=20for=20better=20performance.=20Specifically,=20w?= =?UTF-8?q?e=20can=20use=20`torch.rsqrt`=20and=20`torch.mean`=20wisely=20t?= =?UTF-8?q?o=20optimize=20the=20normalization=20calculation.=20This=20can?= =?UTF-8?q?=20be=20beneficial=20from=20a=20performance=20perspective=20sin?= =?UTF-8?q?ce=20certain=20operations=20might=20be=20optimized=20internally?= =?UTF-8?q?.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Here is an optimized version of the code. ### Explanation. - `torch.mean(x * x, dim=self.dim, keepdim=True)`: Calculating the mean of the squared values directly. - `torch.rsqrt(mean_square)`: Using `torch.rsqrt` to compute the reciprocal of the square root. This can be more efficient than computing the square root and then taking the reciprocal separately. - `x * torch.rsqrt(mean_square)`: Multiplying `x` by the reciprocal square root we computed above. This reformulation can lead to improved performance because it reduces the number of operations by specifically leveraging PyTorch's optimized backend operations. --- comfy/ldm/lightricks/vae/pixel_norm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/lightricks/vae/pixel_norm.py b/comfy/ldm/lightricks/vae/pixel_norm.py index 9bc3ea60e..0edf06927 100644 --- a/comfy/ldm/lightricks/vae/pixel_norm.py +++ b/comfy/ldm/lightricks/vae/pixel_norm.py @@ -9,4 +9,5 @@ class PixelNorm(nn.Module): self.eps = eps def forward(self, x): - return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps) + mean_square = torch.mean(x * x, dim=self.dim, keepdim=True) + self.eps + return x * torch.rsqrt(mean_square)