Compare commits

...

2 Commits

Author SHA1 Message Date
Aseem Saxena
9b1a7dd83a
Merge 2ff056431d into c4a14df9a3 2026-01-21 00:13:50 +00:00
codeflash-ai[bot]
2ff056431d
️ Speed up method PixelNorm.forward by 7%
To optimize the runtime of this program, we can leverage some of PyTorch's functions for better performance. Specifically, we can use `torch.rsqrt` and `torch.mean` wisely to optimize the normalization calculation. This can be beneficial from a performance perspective since certain operations might be optimized internally.

Here is an optimized version of the code.



### Explanation.
- `torch.mean(x * x, dim=self.dim, keepdim=True)`: Calculating the mean of the squared values directly.
- `torch.rsqrt(mean_square)`: Using `torch.rsqrt` to compute the reciprocal of the square root. This can be more efficient than computing the square root and then taking the reciprocal separately.
- `x * torch.rsqrt(mean_square)`: Multiplying `x` by the reciprocal square root we computed above.

This reformulation can lead to improved performance because it reduces the number of operations by specifically leveraging PyTorch's optimized backend operations.
2025-04-15 22:19:03 +00:00

View File

@ -9,4 +9,5 @@ class PixelNorm(nn.Module):
self.eps = eps
def forward(self, x):
return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)
mean_square = torch.mean(x * x, dim=self.dim, keepdim=True) + self.eps
return x * torch.rsqrt(mean_square)