ComfyUI/comfy/ldm/lightricks/latent_upsampler.py
2026-01-05 01:58:59 -05:00

293 lines
10 KiB
Python

from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
def _rational_for_scale(scale: float) -> Tuple[int, int]:
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
if float(scale) not in mapping:
raise ValueError(
f"Unsupported spatial_scale {scale}. Choose from {list(mapping.keys())}"
)
return mapping[float(scale)]
class PixelShuffleND(nn.Module):
def __init__(self, dims, upscale_factors=(2, 2, 2)):
super().__init__()
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
self.dims = dims
self.upscale_factors = upscale_factors
def forward(self, x):
if self.dims == 3:
return rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.upscale_factors[0],
p2=self.upscale_factors[1],
p3=self.upscale_factors[2],
)
elif self.dims == 2:
return rearrange(
x,
"b (c p1 p2) h w -> b c (h p1) (w p2)",
p1=self.upscale_factors[0],
p2=self.upscale_factors[1],
)
elif self.dims == 1:
return rearrange(
x,
"b (c p1) f h w -> b c (f p1) h w",
p1=self.upscale_factors[0],
)
class BlurDownsample(nn.Module):
"""
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
"""
def __init__(self, dims: int, stride: int):
super().__init__()
assert dims in (2, 3)
assert stride >= 1 and isinstance(stride, int)
self.dims = dims
self.stride = stride
# 5x5 separable binomial kernel [1,4,6,4,1] (outer product), normalized
k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0])
k2d = k[:, None] @ k[None, :]
k2d = (k2d / k2d.sum()).float() # shape (5,5)
self.register_buffer("kernel", k2d[None, None, :, :]) # (1,1,5,5)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.stride == 1:
return x
def _apply_2d(x2d: torch.Tensor) -> torch.Tensor:
# x2d: (B, C, H, W)
B, C, H, W = x2d.shape
weight = self.kernel.expand(C, 1, 5, 5) # depthwise
x2d = F.conv2d(
x2d, weight=weight, bias=None, stride=self.stride, padding=2, groups=C
)
return x2d
if self.dims == 2:
return _apply_2d(x)
else:
# dims == 3: apply per-frame on H,W
b, c, f, h, w = x.shape
x = rearrange(x, "b c f h w -> (b f) c h w")
x = _apply_2d(x)
h2, w2 = x.shape[-2:]
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
return x
class SpatialRationalResampler(nn.Module):
"""
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
downsample by 'den' using fixed blur + stride. Operates on H,W only.
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
"""
def __init__(self, mid_channels: int, scale: float):
super().__init__()
self.scale = float(scale)
self.num, self.den = _rational_for_scale(self.scale)
self.conv = nn.Conv2d(
mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1
)
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
self.blur_down = BlurDownsample(dims=2, stride=self.den)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, c, f, h, w = x.shape
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.blur_down(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
return x
class ResBlock(nn.Module):
def __init__(
self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
):
super().__init__()
if mid_channels is None:
mid_channels = channels
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
self.norm1 = nn.GroupNorm(32, mid_channels)
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
self.norm2 = nn.GroupNorm(32, channels)
self.activation = nn.SiLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = self.conv1(x)
x = self.norm1(x)
x = self.activation(x)
x = self.conv2(x)
x = self.norm2(x)
x = self.activation(x + residual)
return x
class LatentUpsampler(nn.Module):
"""
Model to spatially upsample VAE latents.
Args:
in_channels (`int`): Number of channels in the input latent
mid_channels (`int`): Number of channels in the middle layers
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
dims (`int`): Number of dimensions for convolutions (2 or 3)
spatial_upsample (`bool`): Whether to spatially upsample the latent
temporal_upsample (`bool`): Whether to temporally upsample the latent
"""
def __init__(
self,
in_channels: int = 128,
mid_channels: int = 512,
num_blocks_per_stage: int = 4,
dims: int = 3,
spatial_upsample: bool = True,
temporal_upsample: bool = False,
spatial_scale: float = 2.0,
rational_resampler: bool = False,
):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
self.num_blocks_per_stage = num_blocks_per_stage
self.dims = dims
self.spatial_upsample = spatial_upsample
self.temporal_upsample = temporal_upsample
self.spatial_scale = float(spatial_scale)
self.rational_resampler = rational_resampler
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
self.initial_norm = nn.GroupNorm(32, mid_channels)
self.initial_activation = nn.SiLU()
self.res_blocks = nn.ModuleList(
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
)
if spatial_upsample and temporal_upsample:
self.upsampler = nn.Sequential(
nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(3),
)
elif spatial_upsample:
if rational_resampler:
self.upsampler = SpatialRationalResampler(
mid_channels=mid_channels, scale=self.spatial_scale
)
else:
self.upsampler = nn.Sequential(
nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(2),
)
elif temporal_upsample:
self.upsampler = nn.Sequential(
nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(1),
)
else:
raise ValueError(
"Either spatial_upsample or temporal_upsample must be True"
)
self.post_upsample_res_blocks = nn.ModuleList(
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
)
self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)
def forward(self, latent: torch.Tensor) -> torch.Tensor:
b, c, f, h, w = latent.shape
if self.dims == 2:
x = rearrange(latent, "b c f h w -> (b f) c h w")
x = self.initial_conv(x)
x = self.initial_norm(x)
x = self.initial_activation(x)
for block in self.res_blocks:
x = block(x)
x = self.upsampler(x)
for block in self.post_upsample_res_blocks:
x = block(x)
x = self.final_conv(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
else:
x = self.initial_conv(latent)
x = self.initial_norm(x)
x = self.initial_activation(x)
for block in self.res_blocks:
x = block(x)
if self.temporal_upsample:
x = self.upsampler(x)
x = x[:, :, 1:, :, :]
else:
if isinstance(self.upsampler, SpatialRationalResampler):
x = self.upsampler(x)
else:
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self.upsampler(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
for block in self.post_upsample_res_blocks:
x = block(x)
x = self.final_conv(x)
return x
@classmethod
def from_config(cls, config):
return cls(
in_channels=config.get("in_channels", 4),
mid_channels=config.get("mid_channels", 128),
num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
dims=config.get("dims", 2),
spatial_upsample=config.get("spatial_upsample", True),
temporal_upsample=config.get("temporal_upsample", False),
spatial_scale=config.get("spatial_scale", 2.0),
rational_resampler=config.get("rational_resampler", False),
)
def config(self):
return {
"_class_name": "LatentUpsampler",
"in_channels": self.in_channels,
"mid_channels": self.mid_channels,
"num_blocks_per_stage": self.num_blocks_per_stage,
"dims": self.dims,
"spatial_upsample": self.spatial_upsample,
"temporal_upsample": self.temporal_upsample,
"spatial_scale": self.spatial_scale,
"rational_resampler": self.rational_resampler,
}