Also support FILM

This commit is contained in:
kijai 2026-04-04 16:09:34 +03:00
parent a859152817
commit 257c5312d9
4 changed files with 464 additions and 226 deletions

View File

@ -0,0 +1,251 @@
"""FILM: Frame Interpolation for Large Motion (ECCV 2022)."""
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
ops = comfy.ops.disable_weight_init
class FilmConv2d(nn.Module):
    """Conv2d with optional LeakyReLU and FILM-style padding.

    Odd kernels use the built-in symmetric padding; even kernels are padded
    explicitly by one pixel on the right/bottom before an unpadded conv so
    the spatial size is preserved.
    """
    def __init__(self, in_channels, out_channels, size, activation=True, device=None, dtype=None, operations=ops):
        super().__init__()
        is_odd = bool(size % 2)
        self.even_pad = not is_odd
        self.conv = operations.Conv2d(
            in_channels, out_channels, kernel_size=size,
            padding=(size // 2) if is_odd else 0,
            device=device, dtype=dtype)
        self.activation = nn.LeakyReLU(0.2) if activation else None

    def forward(self, x):
        if self.even_pad:
            # One-pixel zero pad (right, bottom) keeps HxW constant for even kernels.
            x = F.pad(x, (0, 1, 0, 1))
        out = self.conv(x)
        if self.activation is not None:
            out = self.activation(out)
        return out
def _warp_core(image, flow, grid_x, grid_y):
dtype = image.dtype
H, W = flow.shape[2], flow.shape[3]
dx = flow[:, 0].float() / (W * 0.5)
dy = flow[:, 1].float() / (H * 0.5)
grid = torch.stack([grid_x[None, None, :] + dx, grid_y[None, :, None] + dy], dim=3)
return F.grid_sample(image.float(), grid, mode="bilinear", padding_mode="border", align_corners=False).to(dtype)
def build_image_pyramid(image, pyramid_levels):
    """Return `pyramid_levels` images, each 2x average-pooled from the previous."""
    levels = [image]
    while len(levels) < pyramid_levels:
        image = F.avg_pool2d(image, 2, 2)
        levels.append(image)
    return levels
def flow_pyramid_synthesis(residual_pyramid):
    """Accumulate residual flows coarse-to-fine into absolute per-level flows.

    Each upsampled flow is doubled (pixel units double per level) before
    adding the finer level's residual. Returns flows finest level first.
    """
    accumulated = residual_pyramid[-1]
    out = [accumulated]
    for residual in reversed(residual_pyramid[:-1]):
        upsampled = F.interpolate(accumulated, size=residual.shape[2:4], mode="bilinear")
        accumulated = upsampled.mul_(2).add_(residual)
        out.append(accumulated)
    out.reverse()
    return out
def multiply_pyramid(pyramid, scalar):
    """Scale every pyramid level by a per-batch scalar (broadcast over C, H, W)."""
    weight = scalar[:, None, None, None]
    return [level * weight for level in pyramid]
def pyramid_warp(feature_pyramid, flow_pyramid, warp_fn):
    """Warp each feature level by its corresponding flow level."""
    warped = []
    for features, flow in zip(feature_pyramid, flow_pyramid):
        warped.append(warp_fn(features, flow))
    return warped
def concatenate_pyramids(pyramid1, pyramid2):
    """Channel-concatenate corresponding levels of two pyramids."""
    return [torch.cat(pair, dim=1) for pair in zip(pyramid1, pyramid2)]
class SubTreeExtractor(nn.Module):
    """Stack of conv pairs producing a feature sub-pyramid from one image.

    Level i applies two 3x3 FilmConv2d layers at `channels << i` width, with
    2x average pooling between levels.
    """
    def __init__(self, in_channels=3, channels=64, n_layers=4, device=None, dtype=None, operations=ops):
        super().__init__()
        convs = []
        for i in range(n_layers):
            out_ch = channels << i
            convs.append(nn.Sequential(
                FilmConv2d(in_channels, out_ch, 3, device=device, dtype=dtype, operations=operations),
                FilmConv2d(out_ch, out_ch, 3, device=device, dtype=dtype, operations=operations)))
            in_channels = out_ch
        self.convs = nn.ModuleList(convs)

    def forward(self, image, n):
        """Extract up to `n` sub-pyramid levels from `image`.

        Only the first `n` conv stages are run: the original loop executed
        every stage regardless of `n`, computing (un-pooled) levels the
        caller (FeatureExtractor) never indexes — pure wasted compute.
        """
        head = image
        pyramid = []
        for i, layer in enumerate(self.convs[:n]):
            head = layer(head)
            pyramid.append(head)
            if i < n - 1:
                head = F.avg_pool2d(head, 2, 2)
        return pyramid
class FeatureExtractor(nn.Module):
    """Cascaded feature extractor fusing sub-tree features across pyramid levels."""
    def __init__(self, in_channels=3, channels=64, sub_levels=4, device=None, dtype=None, operations=ops):
        super().__init__()
        self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels, device=device, dtype=dtype, operations=operations)
        self.sub_levels = sub_levels

    def forward(self, image_pyramid):
        num_levels = len(image_pyramid)
        sub_pyramids = []
        for i, image in enumerate(image_pyramid):
            depth = min(num_levels - i, self.sub_levels)
            sub_pyramids.append(self.extract_sublevels(image, depth))
        feature_pyramid = []
        for i in range(num_levels):
            # Gather same-resolution features from up to `sub_levels` shallower trees.
            parts = [sub_pyramids[i][0]]
            for j in range(1, min(i + 1, self.sub_levels)):
                parts.append(sub_pyramids[i - j][j])
            feature_pyramid.append(torch.cat(parts, dim=1) if len(parts) > 1 else parts[0])
        return feature_pyramid
class FlowEstimator(nn.Module):
    """Small CNN predicting a 2-channel residual flow from two feature maps."""
    def __init__(self, in_channels, num_convs, num_filters, device=None, dtype=None, operations=ops):
        super().__init__()
        layers = nn.ModuleList()
        ch = in_channels
        for _ in range(num_convs):
            layers.append(FilmConv2d(ch, num_filters, 3, device=device, dtype=dtype, operations=operations))
            ch = num_filters
        # 1x1 bottleneck, then an unactivated 1x1 projection to (dx, dy).
        layers.append(FilmConv2d(ch, num_filters // 2, 1, device=device, dtype=dtype, operations=operations))
        layers.append(FilmConv2d(num_filters // 2, 2, 1, activation=False, device=device, dtype=dtype, operations=operations))
        self._convs = layers

    def forward(self, features_a, features_b):
        net = torch.cat((features_a, features_b), dim=1)
        for layer in self._convs:
            net = layer(net)
        return net
class PyramidFlowEstimator(nn.Module):
    """Coarse-to-fine residual flow estimation over a feature pyramid.

    The deepest levels share a single predictor (`_predictor`, the last one
    built); the finest `len(flow_convs) - 1` levels each get a specialized
    predictor (`_predictors`, stored reversed so enumeration below walks
    coarse-to-fine).
    """
    def __init__(self, filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops):
        super().__init__()
        in_channels = filters << 1  # features from both frames are concatenated
        predictors = []
        for i in range(len(flow_convs)):
            predictors.append(FlowEstimator(in_channels, flow_convs[i], flow_filters[i], device=device, dtype=dtype, operations=operations))
            in_channels += filters << (i + 2)
        self._predictor = predictors[-1]
        self._predictors = nn.ModuleList(predictors[:-1][::-1])
    def forward(self, feature_pyramid_a, feature_pyramid_b, warp_fn):
        """Return residual flows, finest level first.

        warp_fn: callable(features, flow) aligning B's features to A before
        each refinement.
        """
        levels = len(feature_pyramid_a)
        # Initial estimate at the coarsest level.
        v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1])
        residuals = [v]
        # Coarse-to-fine: shared predictor for deep levels, then specialized predictors for fine levels
        steps = [(i, self._predictor) for i in range(levels - 2, len(self._predictors) - 1, -1)]
        steps += [(len(self._predictors) - 1 - k, p) for k, p in enumerate(self._predictors)]
        for i, predictor in steps:
            # Upsample to the next-finer level; x2 because pixel units double.
            v = F.interpolate(v, size=feature_pyramid_a[i].shape[2:4], mode="bilinear").mul_(2)
            v_residual = predictor(feature_pyramid_a[i], warp_fn(feature_pyramid_b[i], v))
            residuals.append(v_residual)
            v = v.add_(v_residual)
        residuals.reverse()
        return residuals
def _get_fusion_channels(level, filters):
# Per direction: multi-scale features + RGB image (3ch) + flow (2ch), doubled for both directions
return (sum(filters << i for i in range(level)) + 3 + 2) * 2
class Fusion(nn.Module):
    """U-Net-style decoder fusing the aligned multi-scale inputs into RGB."""
    def __init__(self, n_layers=4, specialized_layers=3, filters=64, device=None, dtype=None, operations=ops):
        super().__init__()
        self.output_conv = operations.Conv2d(filters, 3, kernel_size=1, device=device, dtype=dtype)
        self.convs = nn.ModuleList()
        in_channels = _get_fusion_channels(n_layers, filters)
        increase = 0
        for i in range(n_layers)[::-1]:
            # Finest `specialized_layers` levels shrink the width; deeper ones are capped.
            num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers)
            self.convs.append(nn.ModuleList([
                # Even (2x2) conv applied right after nearest upsampling, no activation.
                FilmConv2d(in_channels, num_filters, 2, activation=False, device=device, dtype=dtype, operations=operations),
                # Sees the upsampled path concatenated with this level's skip input;
                # `increase or num_filters` picks num_filters only on the first (coarsest) level.
                FilmConv2d(in_channels + (increase or num_filters), num_filters, 3, device=device, dtype=dtype, operations=operations),
                FilmConv2d(num_filters, num_filters, 3, device=device, dtype=dtype, operations=operations)]))
            in_channels = num_filters
            increase = _get_fusion_channels(i, filters) - num_filters // 2
    def forward(self, pyramid):
        """Decode the fusion pyramid (finest level first) into an RGB image."""
        net = pyramid[-1]
        for k, layers in enumerate(self.convs):
            i = len(self.convs) - 1 - k  # walk from coarsest toward finest skip level
            net = layers[0](F.interpolate(net, size=pyramid[i].shape[2:4], mode="nearest"))
            net = layers[2](layers[1](torch.cat([pyramid[i], net], dim=1)))
        return self.output_conv(net)
class FILMNet(nn.Module):
    """FILM (Frame Interpolation for Large Motion) network.

    Pipeline: build image pyramids, extract cascaded features, estimate
    bidirectional flow pyramids, warp both frames toward time t, fuse to RGB.
    """
    def __init__(self, pyramid_levels=7, fusion_pyramid_levels=5, specialized_levels=3, sub_levels=4,
                 filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops):
        super().__init__()
        self.pyramid_levels = pyramid_levels
        self.fusion_pyramid_levels = fusion_pyramid_levels
        self.extract = FeatureExtractor(3, filters, sub_levels, device=device, dtype=dtype, operations=operations)
        self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters, device=device, dtype=dtype, operations=operations)
        self.fuse = Fusion(sub_levels, specialized_levels, filters, device=device, dtype=dtype, operations=operations)
        self._warp_grids = {}  # (H, W) -> (grid_x, grid_y), one entry per pyramid level
    def get_dtype(self):
        # Weight dtype of the first extractor conv, taken as representative.
        return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype
    def _build_warp_grids(self, H, W, device):
        """Pre-compute warp grids for all pyramid levels."""
        if (H, W) in self._warp_grids:
            return
        self._warp_grids = {}  # clear old resolution grids to prevent memory leaks
        for _ in range(self.pyramid_levels):
            # Pixel-center coordinates in [-1, 1] (align_corners=False convention).
            self._warp_grids[(H, W)] = (
                torch.linspace(-(1 - 1 / W), 1 - 1 / W, W, dtype=torch.float32, device=device),
                torch.linspace(-(1 - 1 / H), 1 - 1 / H, H, dtype=torch.float32, device=device),
            )
            H, W = H // 2, W // 2
    def warp(self, image, flow):
        # Grids are keyed by the flow's spatial size, not the image's.
        grid_x, grid_y = self._warp_grids[(flow.shape[2], flow.shape[3])]
        return _warp_core(image, flow, grid_x, grid_y)
    def extract_features(self, img):
        """Extract image and feature pyramids for a single frame. Can be cached across pairs."""
        image_pyramid = build_image_pyramid(img, self.pyramid_levels)
        feature_pyramid = self.extract(image_pyramid)
        return image_pyramid, feature_pyramid
    def forward(self, img0, img1, timestep=0.5, cache=None):
        # FILM uses a scalar timestep per batch element (spatially-varying timesteps not supported)
        t = timestep.mean(dim=(1, 2, 3)).item() if isinstance(timestep, torch.Tensor) else timestep
        return self.forward_multi_timestep(img0, img1, [t], cache=cache)
    def forward_multi_timestep(self, img0, img1, timesteps, cache=None):
        """Compute flow once, synthesize at multiple timesteps. Expects batch=1 inputs."""
        self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device)
        # Reuse caller-provided pyramids when available (see extract_features).
        image_pyr0, feat_pyr0 = cache["img0"] if cache and "img0" in cache else self.extract_features(img0)
        image_pyr1, feat_pyr1 = cache["img1"] if cache and "img1" in cache else self.extract_features(img1)
        fwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr0, feat_pyr1, self.warp))[:self.fusion_pyramid_levels]
        bwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr1, feat_pyr0, self.warp))[:self.fusion_pyramid_levels]
        fpl = self.fusion_pyramid_levels
        # Per direction: RGB image concatenated with its features, per fusion level.
        p2w = [concatenate_pyramids(image_pyr0[:fpl], feat_pyr0[:fpl]),
               concatenate_pyramids(image_pyr1[:fpl], feat_pyr1[:fpl])]
        results = []
        dt_tensors = torch.tensor(timesteps, device=img0.device, dtype=img0.dtype)
        for idx in range(len(timesteps)):
            batch_dt = dt_tensors[idx:idx + 1]
            # Scale flows by t / (1 - t) so both frames are warped toward time t.
            bwd_scaled = multiply_pyramid(bwd_flow, batch_dt)
            fwd_scaled = multiply_pyramid(fwd_flow, 1 - batch_dt)
            fwd_warped = pyramid_warp(p2w[0], bwd_scaled, self.warp)
            bwd_warped = pyramid_warp(p2w[1], fwd_scaled, self.warp)
            aligned = [torch.cat([fw, bw, bf, ff], dim=1)
                       for fw, bw, bf, ff in zip(fwd_warped, bwd_warped, bwd_scaled, fwd_scaled)]
            results.append(self.fuse(aligned))
        return torch.cat(results, dim=0)

View File

@ -0,0 +1,128 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
ops = comfy.ops.disable_weight_init
def _warp(img, flow, warp_grids):
B, _, H, W = img.shape
base_grid, flow_div = warp_grids[(H, W)]
flow_norm = torch.cat([flow[:, 0:1] / flow_div[0], flow[:, 1:2] / flow_div[1]], 1).float()
grid = (base_grid.expand(B, -1, -1, -1) + flow_norm).permute(0, 2, 3, 1)
return F.grid_sample(img.float(), grid, mode="bilinear", padding_mode="border", align_corners=True).to(img.dtype)
class Head(nn.Module):
    """Shallow encoder: stride-2 downsample, two 3x3 convs, transposed-conv upsample."""
    def __init__(self, out_ch=4, device=None, dtype=None, operations=ops):
        super().__init__()
        self.cnn0 = operations.Conv2d(3, 16, 3, 2, 1, device=device, dtype=dtype)
        self.cnn1 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype)
        self.cnn2 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype)
        self.cnn3 = operations.ConvTranspose2d(16, out_ch, 4, 2, 1, device=device, dtype=dtype)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        feat = self.relu(self.cnn0(x))
        feat = self.relu(self.cnn1(feat))
        feat = self.relu(self.cnn2(feat))
        # Final transposed conv restores the input resolution; no activation.
        return self.cnn3(feat)
class ResConv(nn.Module):
    """Residual 3x3 conv with a learned per-channel scale `beta`, LeakyReLU output."""
    def __init__(self, c, device=None, dtype=None, operations=ops):
        super().__init__()
        self.conv = operations.Conv2d(c, c, 3, 1, 1, device=device, dtype=dtype)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1), device=device, dtype=dtype))
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        # x + beta * conv(x), fused via addcmul.
        scaled = torch.addcmul(x, self.conv(x), self.beta)
        return self.relu(scaled)
class IFBlock(nn.Module):
    """One RIFE refinement block, operating at 1/scale of the input resolution."""
    def __init__(self, in_planes, c=64, device=None, dtype=None, operations=ops):
        super().__init__()
        self.conv0 = nn.Sequential(
            nn.Sequential(operations.Conv2d(in_planes, c // 2, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True)),
            nn.Sequential(operations.Conv2d(c // 2, c, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True)))
        self.convblock = nn.Sequential(*(ResConv(c, device=device, dtype=dtype, operations=operations) for _ in range(8)))
        self.lastconv = nn.Sequential(operations.ConvTranspose2d(c, 4 * 13, 4, 2, 1, device=device, dtype=dtype), nn.PixelShuffle(2))

    def forward(self, x, flow=None, scale=1):
        down = 1.0 / scale
        x = F.interpolate(x, scale_factor=down, mode="bilinear")
        if flow is not None:
            # Flow values are rescaled along with the spatial downsampling.
            scaled_flow = F.interpolate(flow, scale_factor=down, mode="bilinear").div_(scale)
            x = torch.cat((x, scaled_flow), 1)
        out = self.lastconv(self.convblock(self.conv0(x)))
        out = F.interpolate(out, scale_factor=scale, mode="bilinear")
        # Channels: 4 flow (restored to full-res pixel units), 1 mask, rest features.
        return out[:, :4] * scale, out[:, 4:5], out[:, 5:]
class IFNet(nn.Module):
    """RIFE IFNet: five coarse-to-fine IFBlocks iteratively refining bidirectional flow."""
    def __init__(self, head_ch=4, channels=(192, 128, 96, 64, 32), device=None, dtype=None, operations=ops):
        super().__init__()
        self.encode = Head(out_ch=head_ch, device=device, dtype=dtype, operations=operations)
        # Block 0 input: img0+img1 (6) + timestep (1) + two head features (2*head_ch).
        # Later blocks: warped imgs + timestep + mask (8) + flow (4) + feat (8) + warped head features.
        block_in = [7 + 2 * head_ch] + [8 + 4 + 8 + 2 * head_ch] * 4
        self.blocks = nn.ModuleList([IFBlock(block_in[i], channels[i], device=device, dtype=dtype, operations=operations) for i in range(5)])
        self.scale_list = [16, 8, 4, 2, 1]  # per-block processing scale, coarse to fine
        self.pad_align = 64  # callers must pad H/W to a multiple of this
        self._warp_grids = {}  # (H, W) -> (base_grid, flow_div), built lazily
    def get_dtype(self):
        # Representative parameter dtype (first encoder conv weight).
        return self.encode.cnn0.weight.dtype
    def _build_warp_grids(self, H, W, device):
        """Cache the identity sampling grid and flow normalizers for (H, W)."""
        if (H, W) in self._warp_grids:
            return
        self._warp_grids = {}  # clear old resolution grids to prevent memory leaks
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(-1.0, 1.0, H, device=device, dtype=torch.float32),
            torch.linspace(-1.0, 1.0, W, device=device, dtype=torch.float32), indexing="ij")
        self._warp_grids[(H, W)] = (
            torch.stack((grid_x, grid_y), dim=0).unsqueeze(0),
            # Divisors converting pixel flow to [-1, 1] units (align_corners=True).
            torch.tensor([(W - 1.0) / 2.0, (H - 1.0) / 2.0], dtype=torch.float32, device=device))
    def warp(self, img, flow):
        return _warp(img, flow, self._warp_grids)
    def extract_features(self, img):
        """Extract head features for a single frame. Can be cached across pairs."""
        return self.encode(img)
    def forward(self, img0, img1, timestep=0.5, cache=None):
        # Broadcast a scalar timestep to a full (B, 1, H, W) map.
        if not isinstance(timestep, torch.Tensor):
            timestep = torch.full((img0.shape[0], 1, img0.shape[2], img0.shape[3]), timestep, device=img0.device, dtype=img0.dtype)
        self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device)
        B = img0.shape[0]
        # Cached features are expanded to the requested batch size.
        f0 = cache["img0"].expand(B, -1, -1, -1) if cache and "img0" in cache else self.encode(img0)
        f1 = cache["img1"].expand(B, -1, -1, -1) if cache and "img1" in cache else self.encode(img1)
        flow = mask = feat = None
        warped_img0, warped_img1 = img0, img1
        for i, block in enumerate(self.blocks):
            if flow is None:
                # First block estimates flow directly from the raw inputs.
                flow, mask, feat = block(torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i])
            else:
                # Later blocks refine: warped images/features plus previous mask and feat.
                fd, mask, feat = block(
                    torch.cat((warped_img0, warped_img1, self.warp(f0, flow[:, :2]), self.warp(f1, flow[:, 2:4]), timestep, mask, feat), 1),
                    flow, scale=self.scale_list[i])
                flow = flow.add_(fd)  # accumulate the residual in place
            warped_img0 = self.warp(img0, flow[:, :2])
            warped_img1 = self.warp(img1, flow[:, 2:4])
        # Blend the two warped frames with the predicted occlusion mask.
        return torch.lerp(warped_img1, warped_img0, torch.sigmoid(mask))
def detect_rife_config(state_dict):
    """Infer IFNet constructor args (head_ch, channels) from a RIFE state dict.

    Raises ValueError when the checkpoint does not contain exactly 5 blocks.
    """
    # ConvTranspose2d weights are (in_ch, out_ch, kH, kW); out_ch is the head width.
    head_ch = state_dict["encode.cnn3.weight"].shape[1]
    channels = []
    for i in range(5):
        weight = state_dict.get(f"blocks.{i}.conv0.1.0.weight")
        if weight is not None:
            channels.append(weight.shape[0])
    if len(channels) != 5:
        raise ValueError(f"Unsupported RIFE model: expected 5 blocks, found {len(channels)}")
    return head_ch, channels

View File

@ -1,5 +1,4 @@
import torch
import torch.nn.functional as F
from tqdm import tqdm
from typing_extensions import override
@ -7,7 +6,8 @@ import comfy.model_patcher
import comfy.utils
import folder_paths
from comfy import model_management
from comfy_extras.rife_model.ifnet import IFNet, detect_rife_config, clear_warp_cache
from comfy_extras.frame_interpolation_models.ifnet import IFNet, detect_rife_config
from comfy_extras.frame_interpolation_models.film_net import FILMNet
from comfy_api.latest import ComfyExtension, io
FrameInterpolationModel = io.Custom("FRAME_INTERPOLATION_MODEL")
@ -33,32 +33,8 @@ class FrameInterpolationModelLoader(io.ComfyNode):
model_path = folder_paths.get_full_path_or_raise("frame_interpolation", model_name)
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
# Strip common prefixes (DataParallel, RIFE model wrapper)
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": "", "flownet.": ""})
# Convert blockN.xxx keys to blocks.N.xxx if needed
key_map = {}
for k in sd:
for i in range(5):
prefix = f"block{i}."
if k.startswith(prefix):
key_map[k] = f"blocks.{i}.{k[len(prefix):]}"
if key_map:
new_sd = {}
for k, v in sd.items():
new_sd[key_map.get(k, k)] = v
sd = new_sd
# Filter out training-only keys (teacher distillation, timestamp calibration)
sd = {k: v for k, v in sd.items()
if not k.startswith(("teacher.", "caltime."))}
head_ch, channels = detect_rife_config(sd)
model = IFNet(head_ch=head_ch, channels=channels)
model.load_state_dict(sd)
# RIFE is a small pixel-space model similar to VAE, bf16 produces artifacts due to low mantissa precision
dtype = model_management.vae_dtype(device=model_management.get_torch_device(),
allowed_dtypes=[torch.float16, torch.float32])
model = cls._detect_and_load(sd)
dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32
model.eval().to(dtype)
patcher = comfy.model_patcher.ModelPatcher(
model,
@ -67,6 +43,33 @@ class FrameInterpolationModelLoader(io.ComfyNode):
)
return io.NodeOutput(patcher)
@classmethod
def _detect_and_load(cls, sd):
    """Instantiate the matching interpolation model (FILM or RIFE) from a state dict.

    Raises ValueError (chained to the underlying error) when the checkpoint
    matches neither format.
    """
    # FILM checkpoints are detected by a key unique to FILMNet's extractor.
    if "extract.extract_sublevels.convs.0.0.conv.weight" in sd:
        model = FILMNet()
        model.load_state_dict(sd)
        return model
    # RIFE: strip wrapper prefixes, then remap raw blockN.* keys to blocks.N.*
    sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": "", "flownet.": ""})
    key_map = {}
    for k in sd:
        for i in range(5):
            prefix = f"block{i}."
            if k.startswith(prefix):
                key_map[k] = f"blocks.{i}.{k[len(prefix):]}"
    if key_map:
        sd = {key_map.get(k, k): v for k, v in sd.items()}
    # Drop training-only tensors (teacher distillation, timestep calibration).
    sd = {k: v for k, v in sd.items() if not k.startswith(("teacher.", "caltime."))}
    try:
        head_ch, channels = detect_rife_config(sd)
    except (KeyError, ValueError) as e:
        # Chain the original error so the real cause stays visible in tracebacks
        # (the original `raise` discarded it).
        raise ValueError("Unrecognized frame interpolation model format") from e
    model = IFNet(head_ch=head_ch, channels=channels)
    model.load_state_dict(sd)
    return model
class FrameInterpolate(io.ComfyNode):
@classmethod
@ -75,11 +78,13 @@ class FrameInterpolate(io.ComfyNode):
node_id="FrameInterpolate",
display_name="Frame Interpolate",
category="image/video",
search_aliases=["rife", "frame interpolation", "slow motion", "interpolate frames"],
search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"],
inputs=[
FrameInterpolationModel.Input("model"),
io.Image.Input("images"),
io.Int.Input("multiplier", default=2, min=2, max=16),
io.Boolean.Input("torch_compile", default=False, optional=True, advanced=True,
tooltip="Requires triton. Compile model submodules for potential speed increase. Adds warmup on first run, recompiles on resolution change."),
],
outputs=[
io.Image.Output(),
@ -87,7 +92,7 @@ class FrameInterpolate(io.ComfyNode):
)
@classmethod
def execute(cls, model, images, multiplier) -> io.NodeOutput:
def execute(cls, model, images, multiplier, torch_compile=False) -> io.NodeOutput:
offload_device = model_management.intermediate_device()
num_frames = images.shape[0]
@ -96,18 +101,27 @@ class FrameInterpolate(io.ComfyNode):
model_management.load_model_gpu(model)
device = model.load_device
inference_model = model.model
dtype = model.model_dtype()
inference_model = model.model
# BHWC -> BCHW
frames = images.movedim(-1, 1).to(dtype=dtype, device=offload_device)
_, C, H, W = frames.shape
# Pad to multiple of 64
pad_h = (64 - H % 64) % 64
pad_w = (64 - W % 64) % 64
if pad_h > 0 or pad_w > 0:
frames = F.pad(frames, (0, pad_w, 0, pad_h), mode="reflect")
# Pad to model's required alignment (RIFE needs 64, FILM handles any size)
align = getattr(inference_model, "pad_align", 1)
if align > 1:
from comfy.ldm.common_dit import pad_to_patch_size
frames = pad_to_patch_size(frames, (align, align), padding_mode="reflect")
if torch_compile:
for name, child in inference_model.named_children():
if isinstance(child, (torch.nn.ModuleList, torch.nn.ModuleDict)):
continue
if not hasattr(child, "_compiled"):
compiled = torch.compile(child)
compiled._compiled = True
setattr(inference_model, name, compiled)
# Count total interpolation passes for progress bar
total_pairs = num_frames - 1
@ -116,7 +130,7 @@ class FrameInterpolate(io.ComfyNode):
pbar = comfy.utils.ProgressBar(total_steps)
tqdm_bar = tqdm(total=total_steps, desc="Frame interpolation")
batch = num_interp
batch = num_interp # reduced on OOM and persists across pairs (same resolution = same limit)
t_values = [t / multiplier for t in range(1, multiplier)]
_, _, pH, pW = frames.shape
@ -131,42 +145,55 @@ class FrameInterpolate(io.ComfyNode):
ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1)
ts_full = ts_full.expand(-1, 1, pH, pW)
multi_fn = getattr(inference_model, "forward_multi_timestep", None)
feat_cache = {}
try:
for i in range(total_pairs):
img0_single = frames[i:i + 1].to(device)
img1_single = frames[i + 1:i + 2].to(device)
j = 0
while j < num_interp:
b = min(batch, num_interp - j)
try:
img0 = img0_single.expand(b, -1, -1, -1)
img1 = img1_single.expand(b, -1, -1, -1)
mids = inference_model(img0, img1, timestep=ts_full[j:j + b])
result[out_idx:out_idx + b].copy_(mids.to(dtype=dtype), non_blocking=use_pin)
out_idx += b
pbar.update(b)
tqdm_bar.update(b)
j += b
except model_management.OOM_EXCEPTION:
if batch <= 1:
raise
batch = max(1, batch // 2)
model_management.soft_empty_cache()
# Cache features: img1 of pair N becomes img0 of pair N+1
feat_cache["img0"] = feat_cache.pop("next") if "next" in feat_cache else inference_model.extract_features(img0_single)
feat_cache["img1"] = inference_model.extract_features(img1_single)
feat_cache["next"] = feat_cache["img1"]
if multi_fn is not None:
# Models with timestep-independent flow can compute it once for all timesteps
mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache)
result[out_idx:out_idx + num_interp].copy_(mids.to(dtype=dtype), non_blocking=use_pin)
out_idx += num_interp
pbar.update(num_interp)
tqdm_bar.update(num_interp)
else:
j = 0
while j < num_interp:
b = min(batch, num_interp - j)
try:
img0 = img0_single.expand(b, -1, -1, -1)
img1 = img1_single.expand(b, -1, -1, -1)
mids = inference_model(img0, img1, timestep=ts_full[j:j + b], cache=feat_cache)
result[out_idx:out_idx + b].copy_(mids.to(dtype=dtype), non_blocking=use_pin)
out_idx += b
pbar.update(b)
tqdm_bar.update(b)
j += b
except model_management.OOM_EXCEPTION:
if batch <= 1:
raise
batch = max(1, batch // 2)
model_management.soft_empty_cache()
result[out_idx].copy_(frames[i + 1])
out_idx += 1
finally:
tqdm_bar.close()
clear_warp_cache()
if use_pin:
model_management.synchronize()
model_management.unpin_memory(result)
# Crop padding and BCHW -> BHWC
if pad_h > 0 or pad_w > 0:
result = result[:, :, :H, :W]
result = result.movedim(1, -1).clamp_(0.0, 1.0).to(dtype=model_management.intermediate_dtype())
result = result[:, :, :H, :W].movedim(1, -1).clamp_(0.0, 1.0).to(dtype=model_management.intermediate_dtype())
return io.NodeOutput(result)

View File

@ -1,168 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
ops = comfy.ops.disable_weight_init
_warp_grid_cache = {}
_WARP_GRID_CACHE_MAX = 4


def clear_warp_cache():
    """Drop all cached base sampling grids."""
    _warp_grid_cache.clear()


def _base_grid(H, W, device):
    """Return (and cache) the normalized identity grid for an HxW image."""
    key = (H, W, device)
    grid = _warp_grid_cache.get(key)
    if grid is None:
        # Evict the oldest entry once full (dicts preserve insertion order).
        if len(_warp_grid_cache) >= _WARP_GRID_CACHE_MAX:
            _warp_grid_cache.pop(next(iter(_warp_grid_cache)))
        ys, xs = torch.meshgrid(
            torch.linspace(-1.0, 1.0, H, device=device, dtype=torch.float32),
            torch.linspace(-1.0, 1.0, W, device=device, dtype=torch.float32),
            indexing="ij",
        )
        grid = torch.stack((xs, ys), dim=0).unsqueeze(0)
        _warp_grid_cache[key] = grid
    return grid


def warp(img, flow):
    """Backward-warp `img` by pixel-space `flow` (bilinear, border padding)."""
    B, _, H, W = img.shape
    out_dtype = img.dtype
    flow = flow.float()
    base = _base_grid(H, W, flow.device).expand(B, -1, -1, -1)
    # Pixel displacement -> [-1, 1] units (align_corners=True convention).
    norm = torch.cat([
        flow[:, 0:1] / ((W - 1) / 2),
        flow[:, 1:2] / ((H - 1) / 2),
    ], dim=1)
    sample = (base + norm).permute(0, 2, 3, 1)
    return F.grid_sample(img.float(), sample, mode="bilinear",
                         padding_mode="border", align_corners=True).to(out_dtype)
class Head(nn.Module):
    """Shallow feature encoder: stride-2 downsample, two convs, learned upsample."""
    def __init__(self, out_ch=4, device=None, dtype=None, operations=ops):
        super().__init__()
        self.cnn0 = operations.Conv2d(3, 16, 3, 2, 1, device=device, dtype=dtype)
        self.cnn1 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype)
        self.cnn2 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype)
        self.cnn3 = operations.ConvTranspose2d(16, out_ch, 4, 2, 1, device=device, dtype=dtype)
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        for conv in (self.cnn0, self.cnn1, self.cnn2):
            x = self.relu(conv(x))
        # Transposed conv restores input resolution; no final activation.
        return self.cnn3(x)
class ResConv(nn.Module):
    """3x3 conv residual unit with a learned per-channel gate `beta`."""
    def __init__(self, c, device=None, dtype=None, operations=ops):
        super().__init__()
        self.conv = operations.Conv2d(c, c, 3, 1, 1, device=device, dtype=dtype)
        self.beta = nn.Parameter(torch.ones((1, c, 1, 1), device=device, dtype=dtype))
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x):
        # x + beta * conv(x), fused via addcmul.
        gated = torch.addcmul(x, self.conv(x), self.beta)
        return self.relu(gated)
class IFBlock(nn.Module):
    """One RIFE refinement block; processes its inputs at 1/scale resolution."""
    def __init__(self, in_planes, c=64, device=None, dtype=None, operations=ops):
        super().__init__()
        self.conv0 = nn.Sequential(
            nn.Sequential(
                operations.Conv2d(in_planes, c // 2, 3, 2, 1, device=device, dtype=dtype),
                nn.LeakyReLU(0.2, True),
            ),
            nn.Sequential(
                operations.Conv2d(c // 2, c, 3, 2, 1, device=device, dtype=dtype),
                nn.LeakyReLU(0.2, True),
            ),
        )
        self.convblock = nn.Sequential(
            *(ResConv(c, device=device, dtype=dtype, operations=operations) for _ in range(8))
        )
        self.lastconv = nn.Sequential(
            operations.ConvTranspose2d(c, 4 * 13, 4, 2, 1, device=device, dtype=dtype),
            nn.PixelShuffle(2),
        )

    def forward(self, x, flow=None, scale=1):
        down = 1.0 / scale
        x = F.interpolate(x, scale_factor=down, mode="bilinear")
        if flow is not None:
            # Flow magnitudes are rescaled together with the spatial size.
            flow = F.interpolate(flow, scale_factor=down, mode="bilinear").div_(scale)
            x = torch.cat((x, flow), 1)
        tmp = self.lastconv(self.convblock(self.conv0(x)))
        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
        # Channels: 4 flow (restored to full-res pixel units), 1 mask, rest features.
        return tmp[:, :4] * scale, tmp[:, 4:5], tmp[:, 5:]
class IFNet(nn.Module):
    """RIFE IFNet: five coarse-to-fine IFBlocks iteratively refining bidirectional flow."""
    def __init__(self, head_ch=4, channels=None, device=None, dtype=None, operations=ops):
        super().__init__()
        if channels is None:
            channels = [192, 128, 96, 64, 32]
        self.encode = Head(out_ch=head_ch, device=device, dtype=dtype, operations=operations)
        # Block 0 input: img0+img1 (6) + timestep (1) + two head features (2*head_ch).
        # Later blocks: warped imgs + timestep + mask (8) + flow (4) + feat (8) + warped head features.
        block_in = [7 + 2 * head_ch] + [8 + 4 + 8 + 2 * head_ch] * 4
        self.blocks = nn.ModuleList([
            IFBlock(block_in[i], channels[i], device=device, dtype=dtype, operations=operations)
            for i in range(5)
        ])
        self.scale_list = [16, 8, 4, 2, 1]  # per-block processing scale, coarse to fine
    def get_dtype(self):
        # Representative parameter dtype (first encoder conv weight).
        return self.encode.cnn0.weight.dtype
    def forward(self, img0, img1, timestep=0.5):
        img0 = img0.clamp(0.0, 1.0)
        img1 = img1.clamp(0.0, 1.0)
        # Broadcast a scalar timestep to a full (B, 1, H, W) map.
        if not isinstance(timestep, torch.Tensor):
            timestep = torch.full((img0.shape[0], 1, img0.shape[2], img0.shape[3]),
                                  timestep, device=img0.device, dtype=img0.dtype)
        f0 = self.encode(img0)
        f1 = self.encode(img1)
        flow = mask = feat = None
        warped_img0 = img0
        warped_img1 = img1
        for i, block in enumerate(self.blocks):
            if flow is None:
                # First block estimates flow directly from the raw inputs.
                flow, mask, feat = block(
                    torch.cat((img0, img1, f0, f1, timestep), 1),
                    None, scale=self.scale_list[i],
                )
            else:
                # Later blocks refine using warped images and warped head features.
                wf0 = warp(f0, flow[:, :2])
                wf1 = warp(f1, flow[:, 2:4])
                fd, mask, feat = block(
                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask, feat), 1),
                    flow, scale=self.scale_list[i],
                )
                flow = flow.add_(fd)  # accumulate the residual in place
            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
        mask = torch.sigmoid(mask)
        # Blend the two warped frames with the predicted occlusion mask.
        return torch.lerp(warped_img1, warped_img0, mask)
def detect_rife_config(state_dict):
    """Infer IFNet constructor args (head_ch, channels) from a RIFE state dict.

    Raises ValueError when the checkpoint does not contain exactly 5 blocks.
    """
    # encode.cnn3 is a ConvTranspose2d whose weight layout is
    # (in_ch, out_ch, kH, kW) -> out_ch gives the head feature width.
    head_ch = state_dict["encode.cnn3.weight"].shape[1]
    # conv0 is Sequential(Sequential(Conv2d, ReLU), Sequential(Conv2d, ReLU));
    # the second conv's out-channel count is the block width.
    channels = []
    for i in range(5):
        weight = state_dict.get(f"blocks.{i}.conv0.1.0.weight")
        if weight is not None:
            channels.append(weight.shape[0])
    if len(channels) != 5:
        raise ValueError(f"Unsupported RIFE model: expected 5 blocks, found {len(channels)}")
    return head_ch, channels