mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-23 00:42:30 +08:00
259 lines
12 KiB
Python
259 lines
12 KiB
Python
"""FILM: Frame Interpolation for Large Motion (ECCV 2022)."""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
import comfy.ops
|
|
|
|
ops = comfy.ops.disable_weight_init
|
|
|
|
|
|
class FilmConv2d(nn.Module):
    """Conv2d with optional LeakyReLU and FILM-style ("same") padding.

    Odd kernels are padded symmetrically by the convolution itself
    (``padding=size // 2``). Even kernels cannot be padded symmetrically, so the
    input is zero-padded explicitly before an unpadded convolution, with the
    extra row/column placed on the bottom/right. In both cases the spatial size
    of the output equals that of the input.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        size: square kernel size.
        activation: when true, apply ``LeakyReLU(0.2)`` after the convolution.
        device: forwarded to the convolution constructor.
        dtype: forwarded to the convolution constructor.
        operations: namespace providing the ``Conv2d`` class to instantiate.
    """

    def __init__(self, in_channels, out_channels, size, activation=True, device=None, dtype=None, operations=ops):
        super().__init__()
        self.even_pad = not size % 2
        # Explicit pad used for even kernels: (left, right, top, bottom).
        # A kernel of size k needs k - 1 padded pixels in total per axis; the
        # smaller half goes first. For size == 2 this is the (0, 1, 0, 1) pad.
        before, after = (size - 1) // 2, size // 2
        self._pad = (before, after, before, after)
        self.conv = operations.Conv2d(in_channels, out_channels, kernel_size=size, padding=size // 2 if size % 2 else 0, device=device, dtype=dtype)
        self.activation = nn.LeakyReLU(0.2) if activation else None

    def forward(self, x):
        """Apply padding, convolution and the optional activation to ``x``."""
        if self.even_pad:
            x = F.pad(x, self._pad)
        x = self.conv(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
|
|
|
|
|
|
def _warp_core(image, flow, grid_x, grid_y):
|
|
dtype = image.dtype
|
|
H, W = flow.shape[2], flow.shape[3]
|
|
dx = flow[:, 0].float() / (W * 0.5)
|
|
dy = flow[:, 1].float() / (H * 0.5)
|
|
grid = torch.stack([grid_x[None, None, :] + dx, grid_y[None, :, None] + dy], dim=3)
|
|
return F.grid_sample(image.float(), grid, mode="bilinear", padding_mode="border", align_corners=False).to(dtype)
|
|
|
|
|
|
def build_image_pyramid(image, pyramid_levels):
    """Return ``[image, image/2, image/4, ...]`` built by repeated 2x2 average pooling.

    The first entry is ``image`` itself. The list always contains at least that
    entry, and has ``pyramid_levels`` entries when ``pyramid_levels >= 1``.
    """
    levels = [image]
    for _ in range(pyramid_levels - 1):
        # Each level halves the previous one (kernel 2, stride 2).
        levels.append(F.avg_pool2d(levels[-1], 2, 2))
    return levels
|
|
|
|
|
|
def flow_pyramid_synthesis(residual_pyramid):
    """Accumulate per-level residual flows into full flows, coarse to fine.

    ``residual_pyramid`` is ordered fine to coarse. The coarsest residual is used
    as the coarsest flow. Every finer flow is the next-coarser flow bilinearly
    resized to the finer level's spatial size, doubled, plus that level's
    residual. The returned list is ordered like the input (fine to coarse).
    """
    coarse_to_fine = residual_pyramid[::-1]
    current = coarse_to_fine[0]
    accumulated = [current]
    for residual in coarse_to_fine[1:]:
        upsampled = F.interpolate(current, size=residual.shape[2:4], mode="bilinear")
        # In-place ops act on the freshly interpolated tensor only, never on the inputs.
        current = upsampled.mul_(2).add_(residual)
        accumulated.append(current)
    return accumulated[::-1]
|
|
|
|
|
|
def multiply_pyramid(pyramid, scalar):
    """Scale every level of ``pyramid`` by ``scalar``, one factor per batch element.

    ``scalar`` is indexed as ``scalar[:, None, None, None]`` so that it broadcasts
    over the three trailing dimensions of each level. New tensors are returned and
    the inputs are left untouched.
    """
    factor = scalar[:, None, None, None]
    scaled = []
    for level in pyramid:
        scaled.append(level * factor)
    return scaled
|
|
|
|
|
|
def pyramid_warp(feature_pyramid, flow_pyramid, warp_fn):
    """Apply ``warp_fn(features, flow)`` level by level and return the results as a list.

    Levels are paired positionally; pairing stops at the shorter of the two pyramids.
    """
    return list(map(warp_fn, feature_pyramid, flow_pyramid))
|
|
|
|
|
|
def concatenate_pyramids(pyramid1, pyramid2):
    """Concatenate two pyramids level by level along the channel dimension (dim 1).

    Levels are paired positionally; pairing stops at the shorter of the two pyramids.
    """
    merged = []
    for first, second in zip(pyramid1, pyramid2):
        merged.append(torch.cat((first, second), dim=1))
    return merged
|
|
|
|
|
|
class SubTreeExtractor(nn.Module):
    """Stack of double 3x3 conv stages whose width doubles at every stage.

    Stage ``i`` maps its input to ``channels << i`` channels using two
    ``FilmConv2d`` layers wrapped in an ``nn.Sequential``.
    """

    def __init__(self, in_channels=3, channels=64, n_layers=4, device=None, dtype=None, operations=ops):
        super().__init__()
        conv_kwargs = {"device": device, "dtype": dtype, "operations": operations}
        stages = []
        stage_in = in_channels
        for level in range(n_layers):
            stage_out = channels << level
            stages.append(nn.Sequential(
                FilmConv2d(stage_in, stage_out, 3, **conv_kwargs),
                FilmConv2d(stage_out, stage_out, 3, **conv_kwargs),
            ))
            stage_in = stage_out
        self.convs = nn.ModuleList(stages)

    def forward(self, image, n):
        """Run all stages on ``image`` and return the output of each stage.

        A 2x2 average pool is applied after stage ``i`` only while ``i < n - 1``.
        Note that every stage in ``self.convs`` is executed regardless of ``n``;
        ``n`` only limits how many times the signal is downsampled.
        """
        outputs = []
        last_pooled_stage = n - 1
        x = image
        for level, stage in enumerate(self.convs):
            x = stage(x)
            outputs.append(x)
            if level < last_pooled_stage:
                x = F.avg_pool2d(x, 2, 2)
        return outputs
|
|
|
|
|
|
class FeatureExtractor(nn.Module):
    """Build a multi-scale feature pyramid from an image pyramid.

    One shared ``SubTreeExtractor`` is applied to every image-pyramid level.
    The features for output level ``i`` are the channel-wise concatenation of
    stage ``j`` of the sub-pyramid rooted at level ``i - j``, for every
    ``j < sub_levels`` with ``j <= i``.
    """

    def __init__(self, in_channels=3, channels=64, sub_levels=4, device=None, dtype=None, operations=ops):
        super().__init__()
        self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels, device=device, dtype=dtype, operations=operations)
        self.sub_levels = sub_levels

    def forward(self, image_pyramid):
        """Return one feature tensor per entry of ``image_pyramid``."""
        num_levels = len(image_pyramid)
        sub_pyramids = [
            self.extract_sublevels(image, min(num_levels - level, self.sub_levels))
            for level, image in enumerate(image_pyramid)
        ]

        feature_pyramid = []
        for level in range(num_levels):
            # Stage 0 of this level, then stage j of the sub-pyramid rooted j levels finer.
            deepest = min(level, self.sub_levels - 1)
            parts = [sub_pyramids[level][0]]
            parts.extend(sub_pyramids[level - j][j] for j in range(1, deepest + 1))
            feature_pyramid.append(parts[0] if len(parts) == 1 else torch.cat(parts, dim=1))

            # Free sub-pyramids no longer needed by future levels
            oldest = level - self.sub_levels + 1
            if oldest >= 0:
                sub_pyramids[oldest] = None
        return feature_pyramid
|
|
|
|
|
|
class FlowEstimator(nn.Module):
    """Small conv stack that predicts a 2-channel flow from two feature maps.

    The stack is ``num_convs`` 3x3 convolutions with ``num_filters`` channels,
    a 1x1 convolution down to ``num_filters // 2`` channels, and a final 1x1
    convolution without activation producing 2 channels.
    """

    def __init__(self, in_channels, num_convs, num_filters, device=None, dtype=None, operations=ops):
        super().__init__()
        conv_kwargs = {"device": device, "dtype": dtype, "operations": operations}
        half_filters = num_filters // 2

        layers = []
        width = in_channels
        for _ in range(num_convs):
            layers.append(FilmConv2d(width, num_filters, 3, **conv_kwargs))
            width = num_filters
        layers.append(FilmConv2d(width, half_filters, 1, **conv_kwargs))
        layers.append(FilmConv2d(half_filters, 2, 1, activation=False, **conv_kwargs))
        self._convs = nn.ModuleList(layers)

    def forward(self, features_a, features_b):
        """Concatenate both inputs along dim 1 and run them through the conv stack."""
        out = torch.cat((features_a, features_b), dim=1)
        for layer in self._convs:
            out = layer(out)
        return out
|
|
|
|
|
|
class PyramidFlowEstimator(nn.Module):
    """Coarse-to-fine residual flow prediction over a feature pyramid.

    One ``FlowEstimator`` is created per entry of ``flow_convs``. The last one
    (``_predictor``) is shared by all coarse levels; the others (``_predictors``,
    stored coarse to fine) each handle one of the finest levels.
    """

    def __init__(self, filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops):
        super().__init__()
        estimators = []
        # Both feature maps are concatenated, hence twice the level-0 width.
        width = filters << 1
        for level, num_convs in enumerate(flow_convs):
            estimators.append(FlowEstimator(width, num_convs, flow_filters[level], device=device, dtype=dtype, operations=operations))
            width += filters << (level + 2)
        self._predictor = estimators[-1]
        self._predictors = nn.ModuleList(estimators[:-1][::-1])

    def forward(self, feature_pyramid_a, feature_pyramid_b, warp_fn):
        """Return the residual flows from ``a`` to ``b``, ordered fine to coarse.

        ``warp_fn(features, flow)`` is used to warp ``b``'s features with the
        current flow estimate before each residual is predicted.
        """
        num_levels = len(feature_pyramid_a)
        num_specialized = len(self._predictors)

        flow = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1])
        residuals = [flow]

        # Coarse-to-fine: shared predictor for deep levels, then specialized predictors for fine levels
        schedule = []
        for level in range(num_levels - 2, num_specialized - 1, -1):
            schedule.append((level, self._predictor))
        for offset, specialized in enumerate(self._predictors):
            schedule.append((num_specialized - 1 - offset, specialized))

        for level, estimator in schedule:
            target_size = feature_pyramid_a[level].shape[2:4]
            # Upsampled flow is doubled because displacements are in pixels of the finer level.
            flow = F.interpolate(flow, size=target_size, mode="bilinear").mul_(2)
            warped_b = warp_fn(feature_pyramid_b[level], flow)
            residual = estimator(feature_pyramid_a[level], warped_b)
            residuals.append(residual)
            flow = flow.add_(residual)

        return residuals[::-1]
|
|
|
|
|
|
def _get_fusion_channels(level, filters):
|
|
# Per direction: multi-scale features + RGB image (3ch) + flow (2ch), doubled for both directions
|
|
return (sum(filters << i for i in range(level)) + 3 + 2) * 2
|
|
|
|
|
|
class Fusion(nn.Module):
    """Decoder that fuses an aligned pyramid into a 3-channel image.

    Starting from the coarsest pyramid entry, each step upsamples the running
    tensor to the next finer level (nearest neighbour), applies a 2x2 conv,
    concatenates the result with that level's pyramid entry and applies two
    3x3 convs. A final 1x1 conv maps to 3 channels.

    The layer widths assume that ``pyramid[i]`` has
    ``_get_fusion_channels(i + 1, filters)`` channels for ``i < n_layers`` and
    that the coarsest entry ``pyramid[-1]`` has
    ``_get_fusion_channels(n_layers, filters)`` channels.

    Args:
        n_layers: number of decoder steps.
        specialized_layers: level index from which the per-step filter count
            stops growing (it is capped at ``filters << specialized_layers``).
        filters: base filter count; step ``i`` uses ``filters << i`` filters
            below the cap.
        device: forwarded to the convolution constructors.
        dtype: forwarded to the convolution constructors.
        operations: namespace providing the ``Conv2d`` class to instantiate.
    """

    def __init__(self, n_layers=4, specialized_layers=3, filters=64, device=None, dtype=None, operations=ops):
        super().__init__()
        self.output_conv = operations.Conv2d(filters, 3, kernel_size=1, device=device, dtype=dtype)
        self.convs = nn.ModuleList()
        in_channels = _get_fusion_channels(n_layers, filters)
        for i in reversed(range(n_layers)):
            num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers)
            # The second conv sees pyramid[i] concatenated with the first conv's
            # output. Sizing it directly from those two widths stays correct
            # even when consecutive steps share the capped filter count.
            skip_channels = _get_fusion_channels(i + 1, filters)
            self.convs.append(nn.ModuleList([
                FilmConv2d(in_channels, num_filters, 2, activation=False, device=device, dtype=dtype, operations=operations),
                FilmConv2d(skip_channels + num_filters, num_filters, 3, device=device, dtype=dtype, operations=operations),
                FilmConv2d(num_filters, num_filters, 3, device=device, dtype=dtype, operations=operations)]))
            in_channels = num_filters

    def forward(self, pyramid):
        """Fuse ``pyramid`` (ordered fine to coarse) and return the 3-channel output."""
        net = pyramid[-1]
        for k, layers in enumerate(self.convs):
            # self.convs is stored coarse to fine; i is the pyramid level for this step.
            i = len(self.convs) - 1 - k
            net = layers[0](F.interpolate(net, size=pyramid[i].shape[2:4], mode="nearest"))
            net = layers[2](layers[1](torch.cat([pyramid[i], net], dim=1)))
        return self.output_conv(net)
|
|
|
|
|
|
class FILMNet(nn.Module):
    """FILM frame-interpolation network: feature extraction, flow estimation, fusion.

    Args:
        pyramid_levels: number of image-pyramid levels built per frame.
        fusion_pyramid_levels: number of finest levels passed on to the fusion stage.
        specialized_levels: forwarded to ``Fusion`` as ``specialized_layers``.
        sub_levels: forwarded to ``FeatureExtractor`` and to ``Fusion`` as ``n_layers``.
        filters: base filter count shared by all sub-modules.
        flow_convs: forwarded to ``PyramidFlowEstimator``.
        flow_filters: forwarded to ``PyramidFlowEstimator``.
        device: forwarded to the sub-module constructors.
        dtype: forwarded to the sub-module constructors.
        operations: namespace providing the layer classes to instantiate.
    """

    def __init__(self, pyramid_levels=7, fusion_pyramid_levels=5, specialized_levels=3, sub_levels=4,
                 filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops):
        super().__init__()
        self.pyramid_levels = pyramid_levels
        self.fusion_pyramid_levels = fusion_pyramid_levels
        self.extract = FeatureExtractor(3, filters, sub_levels, device=device, dtype=dtype, operations=operations)
        self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters, device=device, dtype=dtype, operations=operations)
        self.fuse = Fusion(sub_levels, specialized_levels, filters, device=device, dtype=dtype, operations=operations)
        # Per-level sampling grids, keyed by (H, W) of the level.
        self._warp_grids = {}
        # (H, W, device) of the full-resolution input the cached grids were built for.
        self._warp_grid_key = None

    def get_dtype(self):
        """Return the dtype of the first convolution's weight."""
        return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype

    def _build_warp_grids(self, H, W, device):
        """Pre-compute warp grids for all pyramid levels.

        The cache is valid only for the exact full-resolution size and device it
        was built for. Checking merely whether ``(H, W)`` is among the cached
        levels is not enough: a smaller input can match a sub-level of a larger
        cached one while its own coarsest levels are missing.
        """
        key = (H, W, device)
        if self._warp_grids and self._warp_grid_key == key:
            return
        grids = {}
        for _ in range(self.pyramid_levels):
            # Pixel-centre coordinates in grid_sample's [-1, 1] range (align_corners=False).
            grids[(H, W)] = (
                torch.linspace(-(1 - 1 / W), 1 - 1 / W, W, dtype=torch.float32, device=device),
                torch.linspace(-(1 - 1 / H), 1 - 1 / H, H, dtype=torch.float32, device=device),
            )
            H, W = H // 2, W // 2
        # Replace (rather than extend) the cache so grids for old resolutions are released.
        self._warp_grids = grids
        self._warp_grid_key = key

    def warp(self, image, flow):
        """Warp ``image`` by ``flow`` using the cached grid matching ``flow``'s spatial size."""
        grid_x, grid_y = self._warp_grids[(flow.shape[2], flow.shape[3])]
        return _warp_core(image, flow, grid_x, grid_y)

    def extract_features(self, img):
        """Extract image and feature pyramids for a single frame. Can be cached across pairs."""
        image_pyramid = build_image_pyramid(img, self.pyramid_levels)
        feature_pyramid = self.extract(image_pyramid)
        return image_pyramid, feature_pyramid

    def forward(self, img0, img1, timestep=0.5, cache=None):
        """Interpolate a single frame between ``img0`` and ``img1`` at ``timestep``.

        ``timestep`` may be a Python number or a 4-D tensor; a tensor is averaged
        over dims 1-3 and converted with ``.item()``, so it must describe a single
        batch element. ``cache`` is forwarded to ``forward_multi_timestep``.
        """
        # FILM uses a scalar timestep per batch element (spatially-varying timesteps not supported)
        t = timestep.mean(dim=(1, 2, 3)).item() if isinstance(timestep, torch.Tensor) else timestep
        return self.forward_multi_timestep(img0, img1, [t], cache=cache)

    def forward_multi_timestep(self, img0, img1, timesteps, cache=None):
        """Compute flow once, synthesize at multiple timesteps. Expects batch=1 inputs.

        Args:
            img0: first frame.
            img1: second frame.
            timesteps: sequence of scalar timesteps.
            cache: optional mapping; entries ``"img0"`` / ``"img1"`` hold the
                ``(image_pyramid, feature_pyramid)`` tuple returned by
                ``extract_features`` for the respective frame and are used
                instead of recomputing it.

        Returns:
            The fused outputs for all timesteps, concatenated along dim 0 in the
            order of ``timesteps``.
        """
        self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device)

        image_pyr0, feat_pyr0 = cache["img0"] if cache and "img0" in cache else self.extract_features(img0)
        image_pyr1, feat_pyr1 = cache["img1"] if cache and "img1" in cache else self.extract_features(img1)

        fwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr0, feat_pyr1, self.warp))[:self.fusion_pyramid_levels]
        bwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr1, feat_pyr0, self.warp))[:self.fusion_pyramid_levels]

        # Build warp targets and free full pyramids (only first fpl levels needed from here)
        fpl = self.fusion_pyramid_levels
        p2w = [concatenate_pyramids(image_pyr0[:fpl], feat_pyr0[:fpl]),
               concatenate_pyramids(image_pyr1[:fpl], feat_pyr1[:fpl])]
        del image_pyr0, image_pyr1, feat_pyr0, feat_pyr1

        results = []
        dt_tensors = torch.tensor(timesteps, device=img0.device, dtype=img0.dtype)
        for idx in range(len(timesteps)):
            # Keep a length-1 tensor so multiply_pyramid can broadcast it over the batch dim.
            batch_dt = dt_tensors[idx:idx + 1]
            bwd_scaled = multiply_pyramid(bwd_flow, batch_dt)
            fwd_scaled = multiply_pyramid(fwd_flow, 1 - batch_dt)
            fwd_warped = pyramid_warp(p2w[0], bwd_scaled, self.warp)
            bwd_warped = pyramid_warp(p2w[1], fwd_scaled, self.warp)
            aligned = [torch.cat([fw, bw, bf, ff], dim=1)
                       for fw, bw, bf, ff in zip(fwd_warped, bwd_warped, bwd_scaled, fwd_scaled)]
            # Drop intermediates before fusion to keep peak memory down.
            del fwd_warped, bwd_warped, bwd_scaled, fwd_scaled
            results.append(self.fuse(aligned))
            del aligned
        return torch.cat(results, dim=0)
|