From 257c5312d908f4e8c930a7f7f173ad0b6ad1082c Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sat, 4 Apr 2026 16:09:34 +0300 Subject: [PATCH] Also support FILM --- .../frame_interpolation_models/film_net.py | 251 ++++++++++++++++++ .../frame_interpolation_models/ifnet.py | 128 +++++++++ comfy_extras/nodes_frame_interpolation.py | 143 ++++++---- comfy_extras/rife_model/ifnet.py | 168 ------------ 4 files changed, 464 insertions(+), 226 deletions(-) create mode 100644 comfy_extras/frame_interpolation_models/film_net.py create mode 100644 comfy_extras/frame_interpolation_models/ifnet.py delete mode 100644 comfy_extras/rife_model/ifnet.py diff --git a/comfy_extras/frame_interpolation_models/film_net.py b/comfy_extras/frame_interpolation_models/film_net.py new file mode 100644 index 000000000..552b78b8c --- /dev/null +++ b/comfy_extras/frame_interpolation_models/film_net.py @@ -0,0 +1,251 @@ +"""FILM: Frame Interpolation for Large Motion (ECCV 2022).""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops + +ops = comfy.ops.disable_weight_init + + +class FilmConv2d(nn.Module): + """Conv2d with optional LeakyReLU and FILM-style padding.""" + + def __init__(self, in_channels, out_channels, size, activation=True, device=None, dtype=None, operations=ops): + super().__init__() + self.even_pad = not size % 2 + self.conv = operations.Conv2d(in_channels, out_channels, kernel_size=size, padding=size // 2 if size % 2 else 0, device=device, dtype=dtype) + self.activation = nn.LeakyReLU(0.2) if activation else None + + def forward(self, x): + if self.even_pad: + x = F.pad(x, (0, 1, 0, 1)) + x = self.conv(x) + if self.activation is not None: + x = self.activation(x) + return x + + +def _warp_core(image, flow, grid_x, grid_y): + dtype = image.dtype + H, W = flow.shape[2], flow.shape[3] + dx = flow[:, 0].float() / (W * 0.5) + dy = flow[:, 1].float() / (H * 0.5) + grid = torch.stack([grid_x[None, None, :] + dx, grid_y[None, :, None] + dy], dim=3) + return F.grid_sample(image.float(), grid, mode="bilinear", padding_mode="border", align_corners=False).to(dtype) + + +def build_image_pyramid(image, pyramid_levels): + pyramid = [image] + for _ in range(1, pyramid_levels): + image = F.avg_pool2d(image, 2, 2) + pyramid.append(image) + return pyramid + + +def flow_pyramid_synthesis(residual_pyramid): + flow = residual_pyramid[-1] + flow_pyramid = [flow] + for residual_flow in residual_pyramid[:-1][::-1]: + flow = F.interpolate(flow, size=residual_flow.shape[2:4], mode="bilinear", scale_factor=None).mul_(2).add_(residual_flow) + flow_pyramid.append(flow) + flow_pyramid.reverse() + return flow_pyramid + + +def multiply_pyramid(pyramid, scalar): + return [image * scalar[:, None, None, None] for image in pyramid] + + +def pyramid_warp(feature_pyramid, flow_pyramid, warp_fn): + return [warp_fn(features, flow) for features, flow in zip(feature_pyramid, flow_pyramid)] + + +def concatenate_pyramids(pyramid1, pyramid2): + return [torch.cat([f1, f2], dim=1) for f1, f2 in zip(pyramid1, pyramid2)] + + +class SubTreeExtractor(nn.Module): + def __init__(self, in_channels=3, channels=64, n_layers=4, device=None, dtype=None, operations=ops): + super().__init__() + convs = [] + for i in range(n_layers): + out_ch = channels << i + convs.append(nn.Sequential( + FilmConv2d(in_channels, out_ch, 3, device=device, dtype=dtype, operations=operations), + FilmConv2d(out_ch, out_ch, 3, device=device, dtype=dtype, operations=operations))) + in_channels = out_ch + self.convs = nn.ModuleList(convs) + + def forward(self, image, n): + head = image + pyramid = [] + for i, layer in enumerate(self.convs): + head = layer(head) + pyramid.append(head) + if i < n - 1: + head = F.avg_pool2d(head, 2, 2) + return pyramid + + +class FeatureExtractor(nn.Module): + def __init__(self, in_channels=3, channels=64, sub_levels=4, device=None, dtype=None, operations=ops): + super().__init__() + self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels, device=device, dtype=dtype, operations=operations) + self.sub_levels = sub_levels + + def forward(self, image_pyramid): + sub_pyramids = [self.extract_sublevels(image_pyramid[i], min(len(image_pyramid) - i, self.sub_levels)) + for i in range(len(image_pyramid))] + feature_pyramid = [] + for i in range(len(image_pyramid)): + features = sub_pyramids[i][0] + for j in range(1, self.sub_levels): + if j <= i: + features = torch.cat([features, sub_pyramids[i - j][j]], dim=1) + feature_pyramid.append(features) + return feature_pyramid + + +class FlowEstimator(nn.Module): + def __init__(self, in_channels, num_convs, num_filters, device=None, dtype=None, operations=ops): + super().__init__() + self._convs = nn.ModuleList() + for _ in range(num_convs): + self._convs.append(FilmConv2d(in_channels, num_filters, 3, device=device, dtype=dtype, operations=operations)) + in_channels = num_filters + self._convs.append(FilmConv2d(in_channels, num_filters // 2, 1, device=device, dtype=dtype, operations=operations)) + self._convs.append(FilmConv2d(num_filters // 2, 2, 1, activation=False, device=device, dtype=dtype, operations=operations)) + + def forward(self, features_a, features_b): + net = torch.cat([features_a, features_b], dim=1) + for conv in self._convs: + net = conv(net) + return net + + +class PyramidFlowEstimator(nn.Module): + def __init__(self, filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops): + super().__init__() + in_channels = filters << 1 + predictors = [] + for i in range(len(flow_convs)): + predictors.append(FlowEstimator(in_channels, flow_convs[i], flow_filters[i], device=device, dtype=dtype, operations=operations)) + in_channels += filters << (i + 2) + self._predictor = predictors[-1] + self._predictors = nn.ModuleList(predictors[:-1][::-1]) + + def forward(self, feature_pyramid_a, feature_pyramid_b, warp_fn): + levels = len(feature_pyramid_a) + v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1]) + residuals = [v] + # Coarse-to-fine: shared predictor for deep levels, then specialized predictors for fine levels + steps = [(i, self._predictor) for i in range(levels - 2, len(self._predictors) - 1, -1)] + steps += [(len(self._predictors) - 1 - k, p) for k, p in enumerate(self._predictors)] + for i, predictor in steps: + v = F.interpolate(v, size=feature_pyramid_a[i].shape[2:4], mode="bilinear").mul_(2) + v_residual = predictor(feature_pyramid_a[i], warp_fn(feature_pyramid_b[i], v)) + residuals.append(v_residual) + v = v.add_(v_residual) + residuals.reverse() + return residuals + + +def _get_fusion_channels(level, filters): + # Per direction: multi-scale features + RGB image (3ch) + flow (2ch), doubled for both directions + return (sum(filters << i for i in range(level)) + 3 + 2) * 2 + + +class Fusion(nn.Module): + def __init__(self, n_layers=4, specialized_layers=3, filters=64, device=None, dtype=None, operations=ops): + super().__init__() + self.output_conv = operations.Conv2d(filters, 3, kernel_size=1, device=device, dtype=dtype) + self.convs = nn.ModuleList() + in_channels = _get_fusion_channels(n_layers, filters) + increase = 0 + for i in range(n_layers)[::-1]: + num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers) + self.convs.append(nn.ModuleList([ + FilmConv2d(in_channels, num_filters, 2, activation=False, device=device, dtype=dtype, operations=operations), + FilmConv2d(in_channels + (increase or num_filters), num_filters, 3, device=device, dtype=dtype, operations=operations), + FilmConv2d(num_filters, num_filters, 3, device=device, dtype=dtype, operations=operations)])) + in_channels = num_filters + increase = _get_fusion_channels(i, filters) - num_filters // 2 + + def forward(self, pyramid): + net = pyramid[-1] + for k, layers in enumerate(self.convs): + i = len(self.convs) - 1 - k + net = layers[0](F.interpolate(net, size=pyramid[i].shape[2:4], mode="nearest")) + net = layers[2](layers[1](torch.cat([pyramid[i], net], dim=1))) + return self.output_conv(net) + + +class FILMNet(nn.Module): + def __init__(self, pyramid_levels=7, fusion_pyramid_levels=5, specialized_levels=3, sub_levels=4, + filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops): + super().__init__() + self.pyramid_levels = pyramid_levels + self.fusion_pyramid_levels = fusion_pyramid_levels + self.extract = FeatureExtractor(3, filters, sub_levels, device=device, dtype=dtype, operations=operations) + self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters, device=device, dtype=dtype, operations=operations) + self.fuse = Fusion(sub_levels, specialized_levels, filters, device=device, dtype=dtype, operations=operations) + self._warp_grids = {} + + def get_dtype(self): + return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype + + def _build_warp_grids(self, H, W, device): + """Pre-compute warp grids for all pyramid levels.""" + if (H, W) in self._warp_grids: + return + self._warp_grids = {} # clear old resolution grids to prevent memory leaks + for _ in range(self.pyramid_levels): + self._warp_grids[(H, W)] = ( + torch.linspace(-(1 - 1 / W), 1 - 1 / W, W, dtype=torch.float32, device=device), + torch.linspace(-(1 - 1 / H), 1 - 1 / H, H, dtype=torch.float32, device=device), + ) + H, W = H // 2, W // 2 + + def warp(self, image, flow): + grid_x, grid_y = self._warp_grids[(flow.shape[2], flow.shape[3])] + return _warp_core(image, flow, grid_x, grid_y) + + def extract_features(self, img): + """Extract image and feature pyramids for a single frame. Can be cached across pairs.""" + image_pyramid = build_image_pyramid(img, self.pyramid_levels) + feature_pyramid = self.extract(image_pyramid) + return image_pyramid, feature_pyramid + + def forward(self, img0, img1, timestep=0.5, cache=None): + # FILM uses a scalar timestep per batch element (spatially-varying timesteps not supported) + t = timestep.mean(dim=(1, 2, 3)).item() if isinstance(timestep, torch.Tensor) else timestep + return self.forward_multi_timestep(img0, img1, [t], cache=cache) + + def forward_multi_timestep(self, img0, img1, timesteps, cache=None): + """Compute flow once, synthesize at multiple timesteps. Expects batch=1 inputs.""" + self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device) + + image_pyr0, feat_pyr0 = cache["img0"] if cache and "img0" in cache else self.extract_features(img0) + image_pyr1, feat_pyr1 = cache["img1"] if cache and "img1" in cache else self.extract_features(img1) + + fwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr0, feat_pyr1, self.warp))[:self.fusion_pyramid_levels] + bwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr1, feat_pyr0, self.warp))[:self.fusion_pyramid_levels] + + fpl = self.fusion_pyramid_levels + p2w = [concatenate_pyramids(image_pyr0[:fpl], feat_pyr0[:fpl]), + concatenate_pyramids(image_pyr1[:fpl], feat_pyr1[:fpl])] + + results = [] + dt_tensors = torch.tensor(timesteps, device=img0.device, dtype=img0.dtype) + for idx in range(len(timesteps)): + batch_dt = dt_tensors[idx:idx + 1] + bwd_scaled = multiply_pyramid(bwd_flow, batch_dt) + fwd_scaled = multiply_pyramid(fwd_flow, 1 - batch_dt) + fwd_warped = pyramid_warp(p2w[0], bwd_scaled, self.warp) + bwd_warped = pyramid_warp(p2w[1], fwd_scaled, self.warp) + aligned = [torch.cat([fw, bw, bf, ff], dim=1) + for fw, bw, bf, ff in zip(fwd_warped, bwd_warped, bwd_scaled, fwd_scaled)] + results.append(self.fuse(aligned)) + return torch.cat(results, dim=0) diff --git a/comfy_extras/frame_interpolation_models/ifnet.py b/comfy_extras/frame_interpolation_models/ifnet.py new file mode 100644 index 000000000..03cb34c50 --- /dev/null +++ b/comfy_extras/frame_interpolation_models/ifnet.py @@ -0,0 +1,128 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops + +ops = comfy.ops.disable_weight_init + + +def _warp(img, flow, warp_grids): + B, _, H, W = img.shape + base_grid, flow_div = warp_grids[(H, W)] + flow_norm = torch.cat([flow[:, 0:1] / flow_div[0], flow[:, 1:2] / flow_div[1]], 1).float() + grid = (base_grid.expand(B, -1, -1, -1) + flow_norm).permute(0, 2, 3, 1) + return F.grid_sample(img.float(), grid, mode="bilinear", padding_mode="border", align_corners=True).to(img.dtype) + + +class Head(nn.Module): + def __init__(self, out_ch=4, device=None, dtype=None, operations=ops): + super().__init__() + self.cnn0 = operations.Conv2d(3, 16, 3, 2, 1, device=device, dtype=dtype) + self.cnn1 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype) + self.cnn2 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype) + self.cnn3 = operations.ConvTranspose2d(16, out_ch, 4, 2, 1, device=device, dtype=dtype) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + x = self.relu(self.cnn0(x)) + x = self.relu(self.cnn1(x)) + x = self.relu(self.cnn2(x)) + return self.cnn3(x) + + +class ResConv(nn.Module): + def __init__(self, c, device=None, dtype=None, operations=ops): + super().__init__() + self.conv = operations.Conv2d(c, c, 3, 1, 1, device=device, dtype=dtype) + self.beta = nn.Parameter(torch.ones((1, c, 1, 1), device=device, dtype=dtype)) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + return self.relu(torch.addcmul(x, self.conv(x), self.beta)) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64, device=None, dtype=None, operations=ops): + super().__init__() + self.conv0 = nn.Sequential( + nn.Sequential(operations.Conv2d(in_planes, c // 2, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True)), + nn.Sequential(operations.Conv2d(c // 2, c, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True))) + self.convblock = nn.Sequential(*(ResConv(c, device=device, dtype=dtype, operations=operations) for _ in range(8))) + self.lastconv = nn.Sequential(operations.ConvTranspose2d(c, 4 * 13, 4, 2, 1, device=device, dtype=dtype), nn.PixelShuffle(2)) + + def forward(self, x, flow=None, scale=1): + x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear") + if flow is not None: + flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear").div_(scale) + x = torch.cat((x, flow), 1) + feat = self.convblock(self.conv0(x)) + tmp = F.interpolate(self.lastconv(feat), scale_factor=scale, mode="bilinear") + return tmp[:, :4] * scale, tmp[:, 4:5], tmp[:, 5:] + + +class IFNet(nn.Module): + def __init__(self, head_ch=4, channels=(192, 128, 96, 64, 32), device=None, dtype=None, operations=ops): + super().__init__() + self.encode = Head(out_ch=head_ch, device=device, dtype=dtype, operations=operations) + block_in = [7 + 2 * head_ch] + [8 + 4 + 8 + 2 * head_ch] * 4 + self.blocks = nn.ModuleList([IFBlock(block_in[i], channels[i], device=device, dtype=dtype, operations=operations) for i in range(5)]) + self.scale_list = [16, 8, 4, 2, 1] + self.pad_align = 64 + self._warp_grids = {} + + def get_dtype(self): + return self.encode.cnn0.weight.dtype + + def _build_warp_grids(self, H, W, device): + if (H, W) in self._warp_grids: + return + self._warp_grids = {} # clear old resolution grids to prevent memory leaks + grid_y, grid_x = torch.meshgrid( + torch.linspace(-1.0, 1.0, H, device=device, dtype=torch.float32), + torch.linspace(-1.0, 1.0, W, device=device, dtype=torch.float32), indexing="ij") + self._warp_grids[(H, W)] = ( + torch.stack((grid_x, grid_y), dim=0).unsqueeze(0), + torch.tensor([(W - 1.0) / 2.0, (H - 1.0) / 2.0], dtype=torch.float32, device=device)) + + def warp(self, img, flow): + return _warp(img, flow, self._warp_grids) + + def extract_features(self, img): + """Extract head features for a single frame. Can be cached across pairs.""" + return self.encode(img) + + def forward(self, img0, img1, timestep=0.5, cache=None): + if not isinstance(timestep, torch.Tensor): + timestep = torch.full((img0.shape[0], 1, img0.shape[2], img0.shape[3]), timestep, device=img0.device, dtype=img0.dtype) + + self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device) + + B = img0.shape[0] + f0 = cache["img0"].expand(B, -1, -1, -1) if cache and "img0" in cache else self.encode(img0) + f1 = cache["img1"].expand(B, -1, -1, -1) if cache and "img1" in cache else self.encode(img1) + flow = mask = feat = None + warped_img0, warped_img1 = img0, img1 + for i, block in enumerate(self.blocks): + if flow is None: + flow, mask, feat = block(torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i]) + else: + fd, mask, feat = block( + torch.cat((warped_img0, warped_img1, self.warp(f0, flow[:, :2]), self.warp(f1, flow[:, 2:4]), timestep, mask, feat), 1), + flow, scale=self.scale_list[i]) + flow = flow.add_(fd) + warped_img0 = self.warp(img0, flow[:, :2]) + warped_img1 = self.warp(img1, flow[:, 2:4]) + return torch.lerp(warped_img1, warped_img0, torch.sigmoid(mask)) + + +def detect_rife_config(state_dict): + head_ch = state_dict["encode.cnn3.weight"].shape[1] # ConvTranspose2d: (in_ch, out_ch, kH, kW) + channels = [] + for i in range(5): + key = f"blocks.{i}.conv0.1.0.weight" + if key in state_dict: + channels.append(state_dict[key].shape[0]) + if len(channels) != 5: + raise ValueError(f"Unsupported RIFE model: expected 5 blocks, found {len(channels)}") + return head_ch, channels diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py index d90feb778..723e9c85a 100644 --- a/comfy_extras/nodes_frame_interpolation.py +++ b/comfy_extras/nodes_frame_interpolation.py @@ -1,5 +1,4 @@ import torch -import torch.nn.functional as F from tqdm import tqdm from typing_extensions import override @@ -7,7 +6,8 @@ import comfy.model_patcher import comfy.utils import folder_paths from comfy import model_management -from comfy_extras.rife_model.ifnet import IFNet, detect_rife_config, clear_warp_cache +from comfy_extras.frame_interpolation_models.ifnet import IFNet, detect_rife_config +from comfy_extras.frame_interpolation_models.film_net import FILMNet from comfy_api.latest import ComfyExtension, io FrameInterpolationModel = io.Custom("FRAME_INTERPOLATION_MODEL") @@ -33,32 +33,8 @@ class FrameInterpolationModelLoader(io.ComfyNode): model_path = folder_paths.get_full_path_or_raise("frame_interpolation", model_name) sd = comfy.utils.load_torch_file(model_path, safe_load=True) - # Strip common prefixes (DataParallel, RIFE model wrapper) - sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": "", "flownet.": ""}) - - # Convert blockN.xxx keys to blocks.N.xxx if needed - key_map = {} - for k in sd: - for i in range(5): - prefix = f"block{i}." - if k.startswith(prefix): - key_map[k] = f"blocks.{i}.{k[len(prefix):]}" - if key_map: - new_sd = {} - for k, v in sd.items(): - new_sd[key_map.get(k, k)] = v - sd = new_sd - - # Filter out training-only keys (teacher distillation, timestamp calibration) - sd = {k: v for k, v in sd.items() - if not k.startswith(("teacher.", "caltime."))} - - head_ch, channels = detect_rife_config(sd) - model = IFNet(head_ch=head_ch, channels=channels) - model.load_state_dict(sd) - # RIFE is a small pixel-space model similar to VAE, bf16 produces artifacts due to low mantissa precision - dtype = model_management.vae_dtype(device=model_management.get_torch_device(), - allowed_dtypes=[torch.float16, torch.float32]) + model = cls._detect_and_load(sd) + dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32 model.eval().to(dtype) patcher = comfy.model_patcher.ModelPatcher( model, @@ -67,6 +43,33 @@ class FrameInterpolationModelLoader(io.ComfyNode): ) return io.NodeOutput(patcher) + @classmethod + def _detect_and_load(cls, sd): + # Try FILM + if "extract.extract_sublevels.convs.0.0.conv.weight" in sd: + model = FILMNet() + model.load_state_dict(sd) + return model + + # Try RIFE (needs key remapping for raw checkpoints) + sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": "", "flownet.": ""}) + key_map = {} + for k in sd: + for i in range(5): + if k.startswith(f"block{i}."): + key_map[k] = f"blocks.{i}.{k[len(f'block{i}.'):]}" + if key_map: + sd = {key_map.get(k, k): v for k, v in sd.items()} + sd = {k: v for k, v in sd.items() if not k.startswith(("teacher.", "caltime."))} + + try: + head_ch, channels = detect_rife_config(sd) + except (KeyError, ValueError): + raise ValueError("Unrecognized frame interpolation model format") + model = IFNet(head_ch=head_ch, channels=channels) + model.load_state_dict(sd) + return model + class FrameInterpolate(io.ComfyNode): @classmethod @@ -75,11 +78,13 @@ class FrameInterpolate(io.ComfyNode): node_id="FrameInterpolate", display_name="Frame Interpolate", category="image/video", - search_aliases=["rife", "frame interpolation", "slow motion", "interpolate frames"], + search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"], inputs=[ FrameInterpolationModel.Input("model"), io.Image.Input("images"), io.Int.Input("multiplier", default=2, min=2, max=16), + io.Boolean.Input("torch_compile", default=False, optional=True, advanced=True, + tooltip="Requires triton. Compile model submodules for potential speed increase. Adds warmup on first run, recompiles on resolution change."), ], outputs=[ io.Image.Output(), @@ -87,7 +92,7 @@ class FrameInterpolate(io.ComfyNode): ) @classmethod - def execute(cls, model, images, multiplier) -> io.NodeOutput: + def execute(cls, model, images, multiplier, torch_compile=False) -> io.NodeOutput: offload_device = model_management.intermediate_device() num_frames = images.shape[0] @@ -96,18 +101,27 @@ class FrameInterpolate(io.ComfyNode): model_management.load_model_gpu(model) device = model.load_device - inference_model = model.model dtype = model.model_dtype() + inference_model = model.model # BHWC -> BCHW frames = images.movedim(-1, 1).to(dtype=dtype, device=offload_device) _, C, H, W = frames.shape - # Pad to multiple of 64 - pad_h = (64 - H % 64) % 64 - pad_w = (64 - W % 64) % 64 - if pad_h > 0 or pad_w > 0: - frames = F.pad(frames, (0, pad_w, 0, pad_h), mode="reflect") + # Pad to model's required alignment (RIFE needs 64, FILM handles any size) + align = getattr(inference_model, "pad_align", 1) + if align > 1: + from comfy.ldm.common_dit import pad_to_patch_size + frames = pad_to_patch_size(frames, (align, align), padding_mode="reflect") + + if torch_compile: + for name, child in inference_model.named_children(): + if isinstance(child, (torch.nn.ModuleList, torch.nn.ModuleDict)): + continue + if not hasattr(child, "_compiled"): + compiled = torch.compile(child) + compiled._compiled = True + setattr(inference_model, name, compiled) # Count total interpolation passes for progress bar total_pairs = num_frames - 1 @@ -116,7 +130,7 @@ class FrameInterpolate(io.ComfyNode): pbar = comfy.utils.ProgressBar(total_steps) tqdm_bar = tqdm(total=total_steps, desc="Frame interpolation") - batch = num_interp + batch = num_interp # reduced on OOM and persists across pairs (same resolution = same limit) t_values = [t / multiplier for t in range(1, multiplier)] _, _, pH, pW = frames.shape @@ -131,42 +145,55 @@ class FrameInterpolate(io.ComfyNode): ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1) ts_full = ts_full.expand(-1, 1, pH, pW) + multi_fn = getattr(inference_model, "forward_multi_timestep", None) + feat_cache = {} + try: for i in range(total_pairs): img0_single = frames[i:i + 1].to(device) img1_single = frames[i + 1:i + 2].to(device) - j = 0 - while j < num_interp: - b = min(batch, num_interp - j) - try: - img0 = img0_single.expand(b, -1, -1, -1) - img1 = img1_single.expand(b, -1, -1, -1) - mids = inference_model(img0, img1, timestep=ts_full[j:j + b]) - result[out_idx:out_idx + b].copy_(mids.to(dtype=dtype), non_blocking=use_pin) - out_idx += b - pbar.update(b) - tqdm_bar.update(b) - j += b - except model_management.OOM_EXCEPTION: - if batch <= 1: - raise - batch = max(1, batch // 2) - model_management.soft_empty_cache() + # Cache features: img1 of pair N becomes img0 of pair N+1 + feat_cache["img0"] = feat_cache.pop("next") if "next" in feat_cache else inference_model.extract_features(img0_single) + feat_cache["img1"] = inference_model.extract_features(img1_single) + feat_cache["next"] = feat_cache["img1"] + + if multi_fn is not None: + # Models with timestep-independent flow can compute it once for all timesteps + mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache) + result[out_idx:out_idx + num_interp].copy_(mids.to(dtype=dtype), non_blocking=use_pin) + out_idx += num_interp + pbar.update(num_interp) + tqdm_bar.update(num_interp) + else: + j = 0 + while j < num_interp: + b = min(batch, num_interp - j) + try: + img0 = img0_single.expand(b, -1, -1, -1) + img1 = img1_single.expand(b, -1, -1, -1) + mids = inference_model(img0, img1, timestep=ts_full[j:j + b], cache=feat_cache) + result[out_idx:out_idx + b].copy_(mids.to(dtype=dtype), non_blocking=use_pin) + out_idx += b + pbar.update(b) + tqdm_bar.update(b) + j += b + except model_management.OOM_EXCEPTION: + if batch <= 1: + raise + batch = max(1, batch // 2) + model_management.soft_empty_cache() result[out_idx].copy_(frames[i + 1]) out_idx += 1 finally: tqdm_bar.close() - clear_warp_cache() if use_pin: model_management.synchronize() model_management.unpin_memory(result) # Crop padding and BCHW -> BHWC - if pad_h > 0 or pad_w > 0: - result = result[:, :, :H, :W] - result = result.movedim(1, -1).clamp_(0.0, 1.0).to(dtype=model_management.intermediate_dtype()) + result = result[:, :, :H, :W].movedim(1, -1).clamp_(0.0, 1.0).to(dtype=model_management.intermediate_dtype()) return io.NodeOutput(result) diff --git a/comfy_extras/rife_model/ifnet.py b/comfy_extras/rife_model/ifnet.py deleted file mode 100644 index 6a49d1a9f..000000000 --- a/comfy_extras/rife_model/ifnet.py +++ /dev/null @@ -1,168 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -import comfy.ops - -ops = comfy.ops.disable_weight_init - -_warp_grid_cache = {} -_WARP_GRID_CACHE_MAX = 4 - - -def clear_warp_cache(): - _warp_grid_cache.clear() - - -def warp(img, flow): - B, _, H, W = img.shape - dtype = img.dtype - img = img.float() - flow = flow.float() - cache_key = (H, W, flow.device) - if cache_key not in _warp_grid_cache: - if len(_warp_grid_cache) >= _WARP_GRID_CACHE_MAX: - _warp_grid_cache.pop(next(iter(_warp_grid_cache))) - grid_y, grid_x = torch.meshgrid( - torch.linspace(-1.0, 1.0, H, device=flow.device, dtype=torch.float32), - torch.linspace(-1.0, 1.0, W, device=flow.device, dtype=torch.float32), - indexing="ij", - ) - _warp_grid_cache[cache_key] = torch.stack((grid_x, grid_y), dim=0).unsqueeze(0) - grid = _warp_grid_cache[cache_key].expand(B, -1, -1, -1) - flow_norm = torch.cat([ - flow[:, 0:1] / ((W - 1) / 2), - flow[:, 1:2] / ((H - 1) / 2), - ], dim=1) - grid = (grid + flow_norm).permute(0, 2, 3, 1) - return F.grid_sample(img, grid, mode="bilinear", padding_mode="border", align_corners=True).to(dtype) - - -class Head(nn.Module): - def __init__(self, out_ch=4, device=None, dtype=None, operations=ops): - super().__init__() - self.cnn0 = operations.Conv2d(3, 16, 3, 2, 1, device=device, dtype=dtype) - self.cnn1 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype) - self.cnn2 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype) - self.cnn3 = operations.ConvTranspose2d(16, out_ch, 4, 2, 1, device=device, dtype=dtype) - self.relu = nn.LeakyReLU(0.2, True) - - def forward(self, x): - x = self.relu(self.cnn0(x)) - x = self.relu(self.cnn1(x)) - x = self.relu(self.cnn2(x)) - return self.cnn3(x) - - -class ResConv(nn.Module): - def __init__(self, c, device=None, dtype=None, operations=ops): - super().__init__() - self.conv = operations.Conv2d(c, c, 3, 1, 1, device=device, dtype=dtype) - self.beta = nn.Parameter(torch.ones((1, c, 1, 1), device=device, dtype=dtype)) - self.relu = nn.LeakyReLU(0.2, True) - - def forward(self, x): - return self.relu(torch.addcmul(x, self.conv(x), self.beta)) - - -class IFBlock(nn.Module): - def __init__(self, in_planes, c=64, device=None, dtype=None, operations=ops): - super().__init__() - self.conv0 = nn.Sequential( - nn.Sequential( - operations.Conv2d(in_planes, c // 2, 3, 2, 1, device=device, dtype=dtype), - nn.LeakyReLU(0.2, True), - ), - nn.Sequential( - operations.Conv2d(c // 2, c, 3, 2, 1, device=device, dtype=dtype), - nn.LeakyReLU(0.2, True), - ), - ) - self.convblock = nn.Sequential( - *(ResConv(c, device=device, dtype=dtype, operations=operations) for _ in range(8)) - ) - self.lastconv = nn.Sequential( - operations.ConvTranspose2d(c, 4 * 13, 4, 2, 1, device=device, dtype=dtype), - nn.PixelShuffle(2), - ) - - def forward(self, x, flow=None, scale=1): - x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear") - if flow is not None: - flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear").div_(scale) - x = torch.cat((x, flow), 1) - feat = self.conv0(x) - feat = self.convblock(feat) - tmp = self.lastconv(feat) - tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear") - flow = tmp[:, :4] * scale - mask = tmp[:, 4:5] - feat = tmp[:, 5:] - return flow, mask, feat - - -class IFNet(nn.Module): - def __init__(self, head_ch=4, channels=None, device=None, dtype=None, operations=ops): - super().__init__() - if channels is None: - channels = [192, 128, 96, 64, 32] - self.encode = Head(out_ch=head_ch, device=device, dtype=dtype, operations=operations) - block_in = [7 + 2 * head_ch] + [8 + 4 + 8 + 2 * head_ch] * 4 - self.blocks = nn.ModuleList([ - IFBlock(block_in[i], channels[i], device=device, dtype=dtype, operations=operations) - for i in range(5) - ]) - self.scale_list = [16, 8, 4, 2, 1] - - def get_dtype(self): - return self.encode.cnn0.weight.dtype - - def forward(self, img0, img1, timestep=0.5): - img0 = img0.clamp(0.0, 1.0) - img1 = img1.clamp(0.0, 1.0) - if not isinstance(timestep, torch.Tensor): - timestep = torch.full((img0.shape[0], 1, img0.shape[2], img0.shape[3]), - timestep, device=img0.device, dtype=img0.dtype) - f0 = self.encode(img0) - f1 = self.encode(img1) - flow = mask = feat = None - warped_img0 = img0 - warped_img1 = img1 - for i, block in enumerate(self.blocks): - if flow is None: - flow, mask, feat = block( - torch.cat((img0, img1, f0, f1, timestep), 1), - None, scale=self.scale_list[i], - ) - else: - wf0 = warp(f0, flow[:, :2]) - wf1 = warp(f1, flow[:, 2:4]) - fd, mask, feat = block( - torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask, feat), 1), - flow, scale=self.scale_list[i], - ) - flow = flow.add_(fd) - warped_img0 = warp(img0, flow[:, :2]) - warped_img1 = warp(img1, flow[:, 2:4]) - mask = torch.sigmoid(mask) - return torch.lerp(warped_img1, warped_img0, mask) - - -def detect_rife_config(state_dict): - # Determine head output channels from encode.cnn3 (ConvTranspose2d) - # ConvTranspose2d weight shape is (in_ch, out_ch, kH, kW) - head_ch = state_dict["encode.cnn3.weight"].shape[1] - - # Read per-block channel widths from conv0 second layer output channels - # conv0 is Sequential(Sequential(Conv2d, ReLU), Sequential(Conv2d, ReLU)) - # conv0.1.0.weight shape is (c, c//2, 3, 3) - channels = [] - for i in range(5): - key = f"blocks.{i}.conv0.1.0.weight" - if key in state_dict: - channels.append(state_dict[key].shape[0]) - - if len(channels) != 5: - raise ValueError(f"Unsupported RIFE model: expected 5 blocks, found {len(channels)}") - - return head_ch, channels