mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-14 16:20:50 +08:00
To reduce the runtime of this program, we can lean on PyTorch's optimized primitives, specifically `torch.rsqrt` and `torch.mean`, to streamline the normalization calculation. These operations are optimized internally, so expressing the math in terms of them typically outperforms a hand-rolled equivalent. An optimized version of the code follows.

### Explanation

- `torch.mean(x * x, dim=self.dim, keepdim=True)`: computes the mean of the squared values directly.
- `torch.rsqrt(mean_square)`: computes the reciprocal of the square root in a single operation, which is typically faster than taking the square root and then dividing.
- `x * torch.rsqrt(mean_square)`: multiplies `x` by the reciprocal square root computed above.

This reformulation can improve performance because it reduces the number of operations and routes the work through PyTorch's optimized backend kernels.
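The claim about `torch.rsqrt` is easy to verify: it computes the same value as dividing by `torch.sqrt`, just as one fused operation instead of two. A minimal sketch:

```python
import torch

t = torch.rand(4) + 0.5      # positive values so the square root is well defined
a = torch.rsqrt(t)           # fused reciprocal square root
b = 1.0 / torch.sqrt(t)      # two separate ops: sqrt, then division

print(torch.allclose(a, b))  # the results agree numerically
```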
14 lines
343 B
Python
import torch
from torch import nn


class PixelNorm(nn.Module):
    def __init__(self, dim=1, eps=1e-8):
        super(PixelNorm, self).__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        mean_square = torch.mean(x * x, dim=self.dim, keepdim=True) + self.eps
        return x * torch.rsqrt(mean_square)
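As a quick sanity check of the normalization above, the snippet below (a minimal sketch; the input shape is illustrative) applies `PixelNorm` across the channel dimension and confirms that every pixel's channel vector comes out with unit mean square:

```python
import torch
from torch import nn


class PixelNorm(nn.Module):
    """Normalizes each pixel's feature vector to unit RMS along `dim`."""

    def __init__(self, dim=1, eps=1e-8):
        super(PixelNorm, self).__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        # Mean of squares along the feature dimension, kept for broadcasting.
        mean_square = torch.mean(x * x, dim=self.dim, keepdim=True) + self.eps
        return x * torch.rsqrt(mean_square)


norm = PixelNorm(dim=1)
x = torch.randn(2, 8, 4, 4)  # (batch, channels, height, width)
out = norm(x)

# After normalization, the mean square over channels is ~1 at every pixel
# (up to the eps term added for numerical stability).
print(torch.allclose(torch.mean(out * out, dim=1),
                     torch.ones(2, 4, 4), atol=1e-4))
```

The `eps` term keeps `torch.rsqrt` finite when a pixel's channel vector is all zeros, at the cost of a negligible bias in the normalization.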