From a859152817d625115c017146ca0b0d0c3604c448 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 2 Apr 2026 19:12:21 +0300 Subject: [PATCH 1/6] initial RIFE support --- comfy_extras/nodes_frame_interpolation.py | 183 ++++++++++++++++++++++ comfy_extras/rife_model/ifnet.py | 168 ++++++++++++++++++++ folder_paths.py | 2 + nodes.py | 3 +- 4 files changed, 355 insertions(+), 1 deletion(-) create mode 100644 comfy_extras/nodes_frame_interpolation.py create mode 100644 comfy_extras/rife_model/ifnet.py diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py new file mode 100644 index 000000000..d90feb778 --- /dev/null +++ b/comfy_extras/nodes_frame_interpolation.py @@ -0,0 +1,183 @@ +import torch +import torch.nn.functional as F +from tqdm import tqdm +from typing_extensions import override + +import comfy.model_patcher +import comfy.utils +import folder_paths +from comfy import model_management +from comfy_extras.rife_model.ifnet import IFNet, detect_rife_config, clear_warp_cache +from comfy_api.latest import ComfyExtension, io + +FrameInterpolationModel = io.Custom("FRAME_INTERPOLATION_MODEL") + + +class FrameInterpolationModelLoader(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FrameInterpolationModelLoader", + display_name="Load Frame Interpolation Model", + category="loaders", + inputs=[ + io.Combo.Input("model_name", options=folder_paths.get_filename_list("frame_interpolation")), + ], + outputs=[ + FrameInterpolationModel.Output(), + ], + ) + + @classmethod + def execute(cls, model_name) -> io.NodeOutput: + model_path = folder_paths.get_full_path_or_raise("frame_interpolation", model_name) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) + + # Strip common prefixes (DataParallel, RIFE model wrapper) + sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": "", "flownet.": ""}) + + # Convert blockN.xxx keys to blocks.N.xxx if needed + key_map = {} + for k in sd: + for i in range(5): + prefix = f"block{i}." + if k.startswith(prefix): + key_map[k] = f"blocks.{i}.{k[len(prefix):]}" + if key_map: + new_sd = {} + for k, v in sd.items(): + new_sd[key_map.get(k, k)] = v + sd = new_sd + + # Filter out training-only keys (teacher distillation, timestamp calibration) + sd = {k: v for k, v in sd.items() + if not k.startswith(("teacher.", "caltime."))} + + head_ch, channels = detect_rife_config(sd) + model = IFNet(head_ch=head_ch, channels=channels) + model.load_state_dict(sd) + # RIFE is a small pixel-space model similar to VAE, bf16 produces artifacts due to low mantissa precision + dtype = model_management.vae_dtype(device=model_management.get_torch_device(), + allowed_dtypes=[torch.float16, torch.float32]) + model.eval().to(dtype) + patcher = comfy.model_patcher.ModelPatcher( + model, + load_device=model_management.get_torch_device(), + offload_device=model_management.unet_offload_device(), + ) + return io.NodeOutput(patcher) + + +class FrameInterpolate(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FrameInterpolate", + display_name="Frame Interpolate", + category="image/video", + search_aliases=["rife", "frame interpolation", "slow motion", "interpolate frames"], + inputs=[ + FrameInterpolationModel.Input("model"), + io.Image.Input("images"), + io.Int.Input("multiplier", default=2, min=2, max=16), + ], + outputs=[ + io.Image.Output(), + ], + ) + + @classmethod + def execute(cls, model, images, multiplier) -> io.NodeOutput: + offload_device = model_management.intermediate_device() + + num_frames = images.shape[0] + if num_frames < 2 or multiplier < 2: + return io.NodeOutput(images) + + model_management.load_model_gpu(model) + device = model.load_device + inference_model = model.model + dtype = model.model_dtype() + + # BHWC -> BCHW + frames = images.movedim(-1, 1).to(dtype=dtype, device=offload_device) + _, C, H, W = frames.shape + + # Pad to multiple of 64 + pad_h = (64 - H % 64) % 64 + pad_w = (64 - W % 64) % 64 + if pad_h > 0 or pad_w > 0: + frames = F.pad(frames, (0, pad_w, 0, pad_h), mode="reflect") + + # Count total interpolation passes for progress bar + total_pairs = num_frames - 1 + num_interp = multiplier - 1 + total_steps = total_pairs * num_interp + pbar = comfy.utils.ProgressBar(total_steps) + tqdm_bar = tqdm(total=total_steps, desc="Frame interpolation") + + batch = num_interp + t_values = [t / multiplier for t in range(1, multiplier)] + _, _, pH, pW = frames.shape + + # Pre-allocate output tensor, pin for async GPU->CPU transfer + total_out_frames = total_pairs * multiplier + 1 + result = torch.empty((total_out_frames, C, pH, pW), dtype=dtype, device=offload_device) + use_pin = model_management.pin_memory(result) + result[0] = frames[0] + out_idx = 1 + + # Pre-compute timestep tensor on device + ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1) + ts_full = ts_full.expand(-1, 1, pH, pW) + + try: + for i in range(total_pairs): + img0_single = frames[i:i + 1].to(device) + img1_single = frames[i + 1:i + 2].to(device) + + j = 0 + while j < num_interp: + b = min(batch, num_interp - j) + try: + img0 = img0_single.expand(b, -1, -1, -1) + img1 = img1_single.expand(b, -1, -1, -1) + mids = inference_model(img0, img1, timestep=ts_full[j:j + b]) + result[out_idx:out_idx + b].copy_(mids.to(dtype=dtype), non_blocking=use_pin) + out_idx += b + pbar.update(b) + tqdm_bar.update(b) + j += b + except model_management.OOM_EXCEPTION: + if batch <= 1: + raise + batch = max(1, batch // 2) + model_management.soft_empty_cache() + + result[out_idx].copy_(frames[i + 1]) + out_idx += 1 + finally: + tqdm_bar.close() + clear_warp_cache() + if use_pin: + model_management.synchronize() + model_management.unpin_memory(result) + + # Crop padding and BCHW -> BHWC + if pad_h > 0 or pad_w > 0: + result = result[:, :, :H, :W] + result = result.movedim(1, -1).clamp_(0.0, 1.0).to(dtype=model_management.intermediate_dtype()) + return io.NodeOutput(result) + + +class FrameInterpolationExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + FrameInterpolationModelLoader, + FrameInterpolate, + ] + + +async def comfy_entrypoint() -> FrameInterpolationExtension: + return FrameInterpolationExtension() diff --git a/comfy_extras/rife_model/ifnet.py b/comfy_extras/rife_model/ifnet.py new file mode 100644 index 000000000..6a49d1a9f --- /dev/null +++ b/comfy_extras/rife_model/ifnet.py @@ -0,0 +1,168 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops + +ops = comfy.ops.disable_weight_init + +_warp_grid_cache = {} +_WARP_GRID_CACHE_MAX = 4 + + +def clear_warp_cache(): + _warp_grid_cache.clear() + + +def warp(img, flow): + B, _, H, W = img.shape + dtype = img.dtype + img = img.float() + flow = flow.float() + cache_key = (H, W, flow.device) + if cache_key not in _warp_grid_cache: + if len(_warp_grid_cache) >= _WARP_GRID_CACHE_MAX: + _warp_grid_cache.pop(next(iter(_warp_grid_cache))) + grid_y, grid_x = torch.meshgrid( + torch.linspace(-1.0, 1.0, H, device=flow.device, dtype=torch.float32), + torch.linspace(-1.0, 1.0, W, device=flow.device, dtype=torch.float32), + indexing="ij", + ) + _warp_grid_cache[cache_key] = torch.stack((grid_x, grid_y), dim=0).unsqueeze(0) + grid = _warp_grid_cache[cache_key].expand(B, -1, -1, -1) + flow_norm = torch.cat([ + flow[:, 0:1] / ((W - 1) / 2), + flow[:, 1:2] / ((H - 1) / 2), + ], dim=1) + grid = (grid + flow_norm).permute(0, 2, 3, 1) + return F.grid_sample(img, grid, mode="bilinear", padding_mode="border", align_corners=True).to(dtype) + + +class Head(nn.Module): + def __init__(self, out_ch=4, device=None, dtype=None, operations=ops): + super().__init__() + self.cnn0 = operations.Conv2d(3, 16, 3, 2, 1, device=device, dtype=dtype) + self.cnn1 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype) + self.cnn2 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype) + self.cnn3 = operations.ConvTranspose2d(16, out_ch, 4, 2, 1, device=device, dtype=dtype) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + x = self.relu(self.cnn0(x)) + x = self.relu(self.cnn1(x)) + x = self.relu(self.cnn2(x)) + return self.cnn3(x) + + +class ResConv(nn.Module): + def __init__(self, c, device=None, dtype=None, operations=ops): + super().__init__() + self.conv = operations.Conv2d(c, c, 3, 1, 1, device=device, dtype=dtype) + self.beta = nn.Parameter(torch.ones((1, c, 1, 1), device=device, dtype=dtype)) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + return self.relu(torch.addcmul(x, self.conv(x), self.beta)) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64, device=None, dtype=None, operations=ops): + super().__init__() + self.conv0 = nn.Sequential( + nn.Sequential( + operations.Conv2d(in_planes, c // 2, 3, 2, 1, device=device, dtype=dtype), + nn.LeakyReLU(0.2, True), + ), + nn.Sequential( + operations.Conv2d(c // 2, c, 3, 2, 1, device=device, dtype=dtype), + nn.LeakyReLU(0.2, True), + ), + ) + self.convblock = nn.Sequential( + *(ResConv(c, device=device, dtype=dtype, operations=operations) for _ in range(8)) + ) + self.lastconv = nn.Sequential( + operations.ConvTranspose2d(c, 4 * 13, 4, 2, 1, device=device, dtype=dtype), + nn.PixelShuffle(2), + ) + + def forward(self, x, flow=None, scale=1): + x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear") + if flow is not None: + flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear").div_(scale) + x = torch.cat((x, flow), 1) + feat = self.conv0(x) + feat = self.convblock(feat) + tmp = self.lastconv(feat) + tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear") + flow = tmp[:, :4] * scale + mask = tmp[:, 4:5] + feat = tmp[:, 5:] + return flow, mask, feat + + +class IFNet(nn.Module): + def __init__(self, head_ch=4, channels=None, device=None, dtype=None, operations=ops): + super().__init__() + if channels is None: + channels = [192, 128, 96, 64, 32] + self.encode = Head(out_ch=head_ch, device=device, dtype=dtype, operations=operations) + block_in = [7 + 2 * head_ch] + [8 + 4 + 8 + 2 * head_ch] * 4 + self.blocks = nn.ModuleList([ + IFBlock(block_in[i], channels[i], device=device, dtype=dtype, operations=operations) + for i in range(5) + ]) + self.scale_list = [16, 8, 4, 2, 1] + + def get_dtype(self): + return self.encode.cnn0.weight.dtype + + def forward(self, img0, img1, timestep=0.5): + img0 = img0.clamp(0.0, 1.0) + img1 = img1.clamp(0.0, 1.0) + if not isinstance(timestep, torch.Tensor): + timestep = torch.full((img0.shape[0], 1, img0.shape[2], img0.shape[3]), + timestep, device=img0.device, dtype=img0.dtype) + f0 = self.encode(img0) + f1 = self.encode(img1) + flow = mask = feat = None + warped_img0 = img0 + warped_img1 = img1 + for i, block in enumerate(self.blocks): + if flow is None: + flow, mask, feat = block( + torch.cat((img0, img1, f0, f1, timestep), 1), + None, scale=self.scale_list[i], + ) + else: + wf0 = warp(f0, flow[:, :2]) + wf1 = warp(f1, flow[:, 2:4]) + fd, mask, feat = block( + torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask, feat), 1), + flow, scale=self.scale_list[i], + ) + flow = flow.add_(fd) + warped_img0 = warp(img0, flow[:, :2]) + warped_img1 = warp(img1, flow[:, 2:4]) + mask = torch.sigmoid(mask) + return torch.lerp(warped_img1, warped_img0, mask) + + +def detect_rife_config(state_dict): + # Determine head output channels from encode.cnn3 (ConvTranspose2d) + # ConvTranspose2d weight shape is (in_ch, out_ch, kH, kW) + head_ch = state_dict["encode.cnn3.weight"].shape[1] + + # Read per-block channel widths from conv0 second layer output channels + # conv0 is Sequential(Sequential(Conv2d, ReLU), Sequential(Conv2d, ReLU)) + # conv0.1.0.weight shape is (c, c//2, 3, 3) + channels = [] + for i in range(5): + key = f"blocks.{i}.conv0.1.0.weight" + if key in state_dict: + channels.append(state_dict[key].shape[0]) + + if len(channels) != 5: + raise ValueError(f"Unsupported RIFE model: expected 5 blocks, found {len(channels)}") + + return head_ch, channels diff --git a/folder_paths.py b/folder_paths.py index 9c96540e3..80f4b291a 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -52,6 +52,8 @@ folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patc folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions) +folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions) + output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") input_directory = os.path.join(base_path, "input") diff --git a/nodes.py b/nodes.py index 299b3d758..bb38e07b8 100644 --- a/nodes.py +++ b/nodes.py @@ -2457,7 +2457,8 @@ async def init_builtin_extra_nodes(): "nodes_number_convert.py", "nodes_painter.py", "nodes_curve.py", - "nodes_rtdetr.py" + "nodes_rtdetr.py", + "nodes_frame_interpolation.py", ] import_failed = [] From 257c5312d908f4e8c930a7f7f173ad0b6ad1082c Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sat, 4 Apr 2026 16:09:34 +0300 Subject: [PATCH 2/6] Also support FILM --- .../frame_interpolation_models/film_net.py | 251 ++++++++++++++++++ .../frame_interpolation_models/ifnet.py | 128 +++++++++ comfy_extras/nodes_frame_interpolation.py | 143 ++++++---- comfy_extras/rife_model/ifnet.py | 168 ------------ 4 files changed, 464 insertions(+), 226 deletions(-) create mode 100644 comfy_extras/frame_interpolation_models/film_net.py create mode 100644 comfy_extras/frame_interpolation_models/ifnet.py delete mode 100644 comfy_extras/rife_model/ifnet.py diff --git a/comfy_extras/frame_interpolation_models/film_net.py b/comfy_extras/frame_interpolation_models/film_net.py new file mode 100644 index 000000000..552b78b8c --- /dev/null +++ b/comfy_extras/frame_interpolation_models/film_net.py @@ -0,0 +1,251 @@ +"""FILM: Frame Interpolation for Large Motion (ECCV 2022).""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops + +ops = comfy.ops.disable_weight_init + + +class FilmConv2d(nn.Module): + """Conv2d with optional LeakyReLU and FILM-style padding.""" + + def __init__(self, in_channels, out_channels, size, activation=True, device=None, dtype=None, operations=ops): + super().__init__() + self.even_pad = not size % 2 + self.conv = operations.Conv2d(in_channels, out_channels, kernel_size=size, padding=size // 2 if size % 2 else 0, device=device, dtype=dtype) + self.activation = nn.LeakyReLU(0.2) if activation else None + + def forward(self, x): + if self.even_pad: + x = F.pad(x, (0, 1, 0, 1)) + x = self.conv(x) + if self.activation is not None: + x = self.activation(x) + return x + + +def _warp_core(image, flow, grid_x, grid_y): + dtype = image.dtype + H, W = flow.shape[2], flow.shape[3] + dx = flow[:, 0].float() / (W * 0.5) + dy = flow[:, 1].float() / (H * 0.5) + grid = torch.stack([grid_x[None, None, :] + dx, grid_y[None, :, None] + dy], dim=3) + return F.grid_sample(image.float(), grid, mode="bilinear", padding_mode="border", align_corners=False).to(dtype) + + +def build_image_pyramid(image, pyramid_levels): + pyramid = [image] + for _ in range(1, pyramid_levels): + image = F.avg_pool2d(image, 2, 2) + pyramid.append(image) + return pyramid + + +def flow_pyramid_synthesis(residual_pyramid): + flow = residual_pyramid[-1] + flow_pyramid = [flow] + for residual_flow in residual_pyramid[:-1][::-1]: + flow = F.interpolate(flow, size=residual_flow.shape[2:4], mode="bilinear", scale_factor=None).mul_(2).add_(residual_flow) + flow_pyramid.append(flow) + flow_pyramid.reverse() + return flow_pyramid + + +def multiply_pyramid(pyramid, scalar): + return [image * scalar[:, None, None, None] for image in pyramid] + + +def pyramid_warp(feature_pyramid, flow_pyramid, warp_fn): + return [warp_fn(features, flow) for features, flow in zip(feature_pyramid, flow_pyramid)] + + +def concatenate_pyramids(pyramid1, pyramid2): + return [torch.cat([f1, f2], dim=1) for f1, f2 in zip(pyramid1, pyramid2)] + + +class SubTreeExtractor(nn.Module): + def __init__(self, in_channels=3, channels=64, n_layers=4, device=None, dtype=None, operations=ops): + super().__init__() + convs = [] + for i in range(n_layers): + out_ch = channels << i + convs.append(nn.Sequential( + FilmConv2d(in_channels, out_ch, 3, device=device, dtype=dtype, operations=operations), + FilmConv2d(out_ch, out_ch, 3, device=device, dtype=dtype, operations=operations))) + in_channels = out_ch + self.convs = nn.ModuleList(convs) + + def forward(self, image, n): + head = image + pyramid = [] + for i, layer in enumerate(self.convs): + head = layer(head) + pyramid.append(head) + if i < n - 1: + head = F.avg_pool2d(head, 2, 2) + return pyramid + + +class FeatureExtractor(nn.Module): + def __init__(self, in_channels=3, channels=64, sub_levels=4, device=None, dtype=None, operations=ops): + super().__init__() + self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels, device=device, dtype=dtype, operations=operations) + self.sub_levels = sub_levels + + def forward(self, image_pyramid): + sub_pyramids = [self.extract_sublevels(image_pyramid[i], min(len(image_pyramid) - i, self.sub_levels)) + for i in range(len(image_pyramid))] + feature_pyramid = [] + for i in range(len(image_pyramid)): + features = sub_pyramids[i][0] + for j in range(1, self.sub_levels): + if j <= i: + features = torch.cat([features, sub_pyramids[i - j][j]], dim=1) + feature_pyramid.append(features) + return feature_pyramid + + +class FlowEstimator(nn.Module): + def __init__(self, in_channels, num_convs, num_filters, device=None, dtype=None, operations=ops): + super().__init__() + self._convs = nn.ModuleList() + for _ in range(num_convs): + self._convs.append(FilmConv2d(in_channels, num_filters, 3, device=device, dtype=dtype, operations=operations)) + in_channels = num_filters + self._convs.append(FilmConv2d(in_channels, num_filters // 2, 1, device=device, dtype=dtype, operations=operations)) + self._convs.append(FilmConv2d(num_filters // 2, 2, 1, activation=False, device=device, dtype=dtype, operations=operations)) + + def forward(self, features_a, features_b): + net = torch.cat([features_a, features_b], dim=1) + for conv in self._convs: + net = conv(net) + return net + + +class PyramidFlowEstimator(nn.Module): + def __init__(self, filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops): + super().__init__() + in_channels = filters << 1 + predictors = [] + for i in range(len(flow_convs)): + predictors.append(FlowEstimator(in_channels, flow_convs[i], flow_filters[i], device=device, dtype=dtype, operations=operations)) + in_channels += filters << (i + 2) + self._predictor = predictors[-1] + self._predictors = nn.ModuleList(predictors[:-1][::-1]) + + def forward(self, feature_pyramid_a, feature_pyramid_b, warp_fn): + levels = len(feature_pyramid_a) + v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1]) + residuals = [v] + # Coarse-to-fine: shared predictor for deep levels, then specialized predictors for fine levels + steps = [(i, self._predictor) for i in range(levels - 2, len(self._predictors) - 1, -1)] + steps += [(len(self._predictors) - 1 - k, p) for k, p in enumerate(self._predictors)] + for i, predictor in steps: + v = F.interpolate(v, size=feature_pyramid_a[i].shape[2:4], mode="bilinear").mul_(2) + v_residual = predictor(feature_pyramid_a[i], warp_fn(feature_pyramid_b[i], v)) + residuals.append(v_residual) + v = v.add_(v_residual) + residuals.reverse() + return residuals + + +def _get_fusion_channels(level, filters): + # Per direction: multi-scale features + RGB image (3ch) + flow (2ch), doubled for both directions + return (sum(filters << i for i in range(level)) + 3 + 2) * 2 + + +class Fusion(nn.Module): + def __init__(self, n_layers=4, specialized_layers=3, filters=64, device=None, dtype=None, operations=ops): + super().__init__() + self.output_conv = operations.Conv2d(filters, 3, kernel_size=1, device=device, dtype=dtype) + self.convs = nn.ModuleList() + in_channels = _get_fusion_channels(n_layers, filters) + increase = 0 + for i in range(n_layers)[::-1]: + num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers) + self.convs.append(nn.ModuleList([ + FilmConv2d(in_channels, num_filters, 2, activation=False, device=device, dtype=dtype, operations=operations), + FilmConv2d(in_channels + (increase or num_filters), num_filters, 3, device=device, dtype=dtype, operations=operations), + FilmConv2d(num_filters, num_filters, 3, device=device, dtype=dtype, operations=operations)])) + in_channels = num_filters + increase = _get_fusion_channels(i, filters) - num_filters // 2 + + def forward(self, pyramid): + net = pyramid[-1] + for k, layers in enumerate(self.convs): + i = len(self.convs) - 1 - k + net = layers[0](F.interpolate(net, size=pyramid[i].shape[2:4], mode="nearest")) + net = layers[2](layers[1](torch.cat([pyramid[i], net], dim=1))) + return self.output_conv(net) + + +class FILMNet(nn.Module): + def __init__(self, pyramid_levels=7, fusion_pyramid_levels=5, specialized_levels=3, sub_levels=4, + filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops): + super().__init__() + self.pyramid_levels = pyramid_levels + self.fusion_pyramid_levels = fusion_pyramid_levels + self.extract = FeatureExtractor(3, filters, sub_levels, device=device, dtype=dtype, operations=operations) + self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters, device=device, dtype=dtype, operations=operations) + self.fuse = Fusion(sub_levels, specialized_levels, filters, device=device, dtype=dtype, operations=operations) + self._warp_grids = {} + + def get_dtype(self): + return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype + + def _build_warp_grids(self, H, W, device): + """Pre-compute warp grids for all pyramid levels.""" + if (H, W) in self._warp_grids: + return + self._warp_grids = {} # clear old resolution grids to prevent memory leaks + for _ in range(self.pyramid_levels): + self._warp_grids[(H, W)] = ( + torch.linspace(-(1 - 1 / W), 1 - 1 / W, W, dtype=torch.float32, device=device), + torch.linspace(-(1 - 1 / H), 1 - 1 / H, H, dtype=torch.float32, device=device), + ) + H, W = H // 2, W // 2 + + def warp(self, image, flow): + grid_x, grid_y = self._warp_grids[(flow.shape[2], flow.shape[3])] + return _warp_core(image, flow, grid_x, grid_y) + + def extract_features(self, img): + """Extract image and feature pyramids for a single frame. Can be cached across pairs.""" + image_pyramid = build_image_pyramid(img, self.pyramid_levels) + feature_pyramid = self.extract(image_pyramid) + return image_pyramid, feature_pyramid + + def forward(self, img0, img1, timestep=0.5, cache=None): + # FILM uses a scalar timestep per batch element (spatially-varying timesteps not supported) + t = timestep.mean(dim=(1, 2, 3)).item() if isinstance(timestep, torch.Tensor) else timestep + return self.forward_multi_timestep(img0, img1, [t], cache=cache) + + def forward_multi_timestep(self, img0, img1, timesteps, cache=None): + """Compute flow once, synthesize at multiple timesteps. Expects batch=1 inputs.""" + self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device) + + image_pyr0, feat_pyr0 = cache["img0"] if cache and "img0" in cache else self.extract_features(img0) + image_pyr1, feat_pyr1 = cache["img1"] if cache and "img1" in cache else self.extract_features(img1) + + fwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr0, feat_pyr1, self.warp))[:self.fusion_pyramid_levels] + bwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr1, feat_pyr0, self.warp))[:self.fusion_pyramid_levels] + + fpl = self.fusion_pyramid_levels + p2w = [concatenate_pyramids(image_pyr0[:fpl], feat_pyr0[:fpl]), + concatenate_pyramids(image_pyr1[:fpl], feat_pyr1[:fpl])] + + results = [] + dt_tensors = torch.tensor(timesteps, device=img0.device, dtype=img0.dtype) + for idx in range(len(timesteps)): + batch_dt = dt_tensors[idx:idx + 1] + bwd_scaled = multiply_pyramid(bwd_flow, batch_dt) + fwd_scaled = multiply_pyramid(fwd_flow, 1 - batch_dt) + fwd_warped = pyramid_warp(p2w[0], bwd_scaled, self.warp) + bwd_warped = pyramid_warp(p2w[1], fwd_scaled, self.warp) + aligned = [torch.cat([fw, bw, bf, ff], dim=1) + for fw, bw, bf, ff in zip(fwd_warped, bwd_warped, bwd_scaled, fwd_scaled)] + results.append(self.fuse(aligned)) + return torch.cat(results, dim=0) diff --git a/comfy_extras/frame_interpolation_models/ifnet.py b/comfy_extras/frame_interpolation_models/ifnet.py new file mode 100644 index 000000000..03cb34c50 --- /dev/null +++ b/comfy_extras/frame_interpolation_models/ifnet.py @@ -0,0 +1,128 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops + +ops = comfy.ops.disable_weight_init + + +def _warp(img, flow, warp_grids): + B, _, H, W = img.shape + base_grid, flow_div = warp_grids[(H, W)] + flow_norm = torch.cat([flow[:, 0:1] / flow_div[0], flow[:, 1:2] / flow_div[1]], 1).float() + grid = (base_grid.expand(B, -1, -1, -1) + flow_norm).permute(0, 2, 3, 1) + return F.grid_sample(img.float(), grid, mode="bilinear", padding_mode="border", align_corners=True).to(img.dtype) + + +class Head(nn.Module): + def __init__(self, out_ch=4, device=None, dtype=None, operations=ops): + super().__init__() + self.cnn0 = operations.Conv2d(3, 16, 3, 2, 1, device=device, dtype=dtype) + self.cnn1 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype) + self.cnn2 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype) + self.cnn3 = operations.ConvTranspose2d(16, out_ch, 4, 2, 1, device=device, dtype=dtype) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + x = self.relu(self.cnn0(x)) + x = self.relu(self.cnn1(x)) + x = self.relu(self.cnn2(x)) + return self.cnn3(x) + + +class ResConv(nn.Module): + def __init__(self, c, device=None, dtype=None, operations=ops): + super().__init__() + self.conv = operations.Conv2d(c, c, 3, 1, 1, device=device, dtype=dtype) + self.beta = nn.Parameter(torch.ones((1, c, 1, 1), device=device, dtype=dtype)) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + return self.relu(torch.addcmul(x, self.conv(x), self.beta)) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64, device=None, dtype=None, operations=ops): + super().__init__() + self.conv0 = nn.Sequential( + nn.Sequential(operations.Conv2d(in_planes, c // 2, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True)), + nn.Sequential(operations.Conv2d(c // 2, c, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True))) + self.convblock = nn.Sequential(*(ResConv(c, device=device, dtype=dtype, operations=operations) for _ in range(8))) + self.lastconv = nn.Sequential(operations.ConvTranspose2d(c, 4 * 13, 4, 2, 1, device=device, dtype=dtype), nn.PixelShuffle(2)) + + def forward(self, x, flow=None, scale=1): + x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear") + if flow is not None: + flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear").div_(scale) + x = torch.cat((x, flow), 1) + feat = self.convblock(self.conv0(x)) + tmp = F.interpolate(self.lastconv(feat), scale_factor=scale, mode="bilinear") + return tmp[:, :4] * scale, tmp[:, 4:5], tmp[:, 5:] + + +class IFNet(nn.Module): + def __init__(self, head_ch=4, channels=(192, 128, 96, 64, 32), device=None, dtype=None, operations=ops): + super().__init__() + self.encode = Head(out_ch=head_ch, device=device, dtype=dtype, operations=operations) + block_in = [7 + 2 * head_ch] + [8 + 4 + 8 + 2 * head_ch] * 4 + self.blocks = nn.ModuleList([IFBlock(block_in[i], channels[i], device=device, dtype=dtype, operations=operations) for i in range(5)]) + self.scale_list = [16, 8, 4, 2, 1] + self.pad_align = 64 + self._warp_grids = {} + + def get_dtype(self): + return self.encode.cnn0.weight.dtype + + def _build_warp_grids(self, H, W, device): + if (H, W) in self._warp_grids: + return + self._warp_grids = {} # clear old resolution grids to prevent memory leaks + grid_y, grid_x = torch.meshgrid( + torch.linspace(-1.0, 1.0, H, device=device, dtype=torch.float32), + torch.linspace(-1.0, 1.0, W, device=device, dtype=torch.float32), indexing="ij") + self._warp_grids[(H, W)] = ( + torch.stack((grid_x, grid_y), dim=0).unsqueeze(0), + torch.tensor([(W - 1.0) / 2.0, (H - 1.0) / 2.0], dtype=torch.float32, device=device)) + + def warp(self, img, flow): + return _warp(img, flow, self._warp_grids) + + def extract_features(self, img): + """Extract head features for a single frame. Can be cached across pairs.""" + return self.encode(img) + + def forward(self, img0, img1, timestep=0.5, cache=None): + if not isinstance(timestep, torch.Tensor): + timestep = torch.full((img0.shape[0], 1, img0.shape[2], img0.shape[3]), timestep, device=img0.device, dtype=img0.dtype) + + self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device) + + B = img0.shape[0] + f0 = cache["img0"].expand(B, -1, -1, -1) if cache and "img0" in cache else self.encode(img0) + f1 = cache["img1"].expand(B, -1, -1, -1) if cache and "img1" in cache else self.encode(img1) + flow = mask = feat = None + warped_img0, warped_img1 = img0, img1 + for i, block in enumerate(self.blocks): + if flow is None: + flow, mask, feat = block(torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i]) + else: + fd, mask, feat = block( + torch.cat((warped_img0, warped_img1, self.warp(f0, flow[:, :2]), self.warp(f1, flow[:, 2:4]), timestep, mask, feat), 1), + flow, scale=self.scale_list[i]) + flow = flow.add_(fd) + warped_img0 = self.warp(img0, flow[:, :2]) + warped_img1 = self.warp(img1, flow[:, 2:4]) + return torch.lerp(warped_img1, warped_img0, torch.sigmoid(mask)) + + +def detect_rife_config(state_dict): + head_ch = state_dict["encode.cnn3.weight"].shape[1] # ConvTranspose2d: (in_ch, out_ch, kH, kW) + channels = [] + for i in range(5): + key = f"blocks.{i}.conv0.1.0.weight" + if key in state_dict: + channels.append(state_dict[key].shape[0]) + if len(channels) != 5: + raise ValueError(f"Unsupported RIFE model: expected 5 blocks, found {len(channels)}") + return head_ch, channels diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py index d90feb778..723e9c85a 100644 --- a/comfy_extras/nodes_frame_interpolation.py +++ b/comfy_extras/nodes_frame_interpolation.py @@ -1,5 +1,4 @@ import torch -import torch.nn.functional as F from tqdm import tqdm from typing_extensions import override @@ -7,7 +6,8 @@ import comfy.model_patcher import comfy.utils import folder_paths from comfy import model_management -from comfy_extras.rife_model.ifnet import IFNet, detect_rife_config, clear_warp_cache +from comfy_extras.frame_interpolation_models.ifnet import IFNet, detect_rife_config +from comfy_extras.frame_interpolation_models.film_net import FILMNet from comfy_api.latest import ComfyExtension, io FrameInterpolationModel = io.Custom("FRAME_INTERPOLATION_MODEL") @@ -33,32 +33,8 @@ class FrameInterpolationModelLoader(io.ComfyNode): model_path = folder_paths.get_full_path_or_raise("frame_interpolation", model_name) sd = comfy.utils.load_torch_file(model_path, safe_load=True) - # Strip common prefixes (DataParallel, RIFE model wrapper) - sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": "", "flownet.": ""}) - - # Convert blockN.xxx keys to blocks.N.xxx if needed - key_map = {} - for k in sd: - for i in range(5): - prefix = f"block{i}." - if k.startswith(prefix): - key_map[k] = f"blocks.{i}.{k[len(prefix):]}" - if key_map: - new_sd = {} - for k, v in sd.items(): - new_sd[key_map.get(k, k)] = v - sd = new_sd - - # Filter out training-only keys (teacher distillation, timestamp calibration) - sd = {k: v for k, v in sd.items() - if not k.startswith(("teacher.", "caltime."))} - - head_ch, channels = detect_rife_config(sd) - model = IFNet(head_ch=head_ch, channels=channels) - model.load_state_dict(sd) - # RIFE is a small pixel-space model similar to VAE, bf16 produces artifacts due to low mantissa precision - dtype = model_management.vae_dtype(device=model_management.get_torch_device(), - allowed_dtypes=[torch.float16, torch.float32]) + model = cls._detect_and_load(sd) + dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32 model.eval().to(dtype) patcher = comfy.model_patcher.ModelPatcher( model, @@ -67,6 +43,33 @@ class FrameInterpolationModelLoader(io.ComfyNode): ) return io.NodeOutput(patcher) + @classmethod + def _detect_and_load(cls, sd): + # Try FILM + if "extract.extract_sublevels.convs.0.0.conv.weight" in sd: + model = FILMNet() + model.load_state_dict(sd) + return model + + # Try RIFE (needs key remapping for raw checkpoints) + sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": "", "flownet.": ""}) + key_map = {} + for k in sd: + for i in range(5): + if k.startswith(f"block{i}."): + key_map[k] = f"blocks.{i}.{k[len(f'block{i}.'):]}" + if key_map: + sd = {key_map.get(k, k): v for k, v in sd.items()} + sd = {k: v for k, v in sd.items() if not k.startswith(("teacher.", "caltime."))} + + try: + head_ch, channels = detect_rife_config(sd) + except (KeyError, ValueError): + raise ValueError("Unrecognized frame interpolation model format") + model = IFNet(head_ch=head_ch, channels=channels) + model.load_state_dict(sd) + return model + class FrameInterpolate(io.ComfyNode): @classmethod @@ -75,11 +78,13 @@ class FrameInterpolate(io.ComfyNode): node_id="FrameInterpolate", display_name="Frame Interpolate", category="image/video", - search_aliases=["rife", "frame interpolation", "slow motion", "interpolate frames"], + search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"], inputs=[ FrameInterpolationModel.Input("model"), io.Image.Input("images"), io.Int.Input("multiplier", default=2, min=2, max=16), + io.Boolean.Input("torch_compile", default=False, optional=True, advanced=True, + tooltip="Requires triton. Compile model submodules for potential speed increase. Adds warmup on first run, recompiles on resolution change."), ], outputs=[ io.Image.Output(), @@ -87,7 +92,7 @@ class FrameInterpolate(io.ComfyNode): ) @classmethod - def execute(cls, model, images, multiplier) -> io.NodeOutput: + def execute(cls, model, images, multiplier, torch_compile=False) -> io.NodeOutput: offload_device = model_management.intermediate_device() num_frames = images.shape[0] @@ -96,18 +101,27 @@ class FrameInterpolate(io.ComfyNode): model_management.load_model_gpu(model) device = model.load_device - inference_model = model.model dtype = model.model_dtype() + inference_model = model.model # BHWC -> BCHW frames = images.movedim(-1, 1).to(dtype=dtype, device=offload_device) _, C, H, W = frames.shape - # Pad to multiple of 64 - pad_h = (64 - H % 64) % 64 - pad_w = (64 - W % 64) % 64 - if pad_h > 0 or pad_w > 0: - frames = F.pad(frames, (0, pad_w, 0, pad_h), mode="reflect") + # Pad to model's required alignment (RIFE needs 64, FILM handles any size) + align = getattr(inference_model, "pad_align", 1) + if align > 1: + from comfy.ldm.common_dit import pad_to_patch_size + frames = pad_to_patch_size(frames, (align, align), padding_mode="reflect") + + if torch_compile: + for name, child in inference_model.named_children(): + if isinstance(child, (torch.nn.ModuleList, torch.nn.ModuleDict)): + continue + if not hasattr(child, "_compiled"): + compiled = torch.compile(child) + compiled._compiled = True + setattr(inference_model, name, compiled) # Count total interpolation passes for progress bar total_pairs = num_frames - 1 @@ -116,7 +130,7 @@ class FrameInterpolate(io.ComfyNode): pbar = comfy.utils.ProgressBar(total_steps) tqdm_bar = tqdm(total=total_steps, desc="Frame interpolation") - batch = num_interp + batch = num_interp # reduced on OOM and persists across pairs (same resolution = same limit) t_values = [t / multiplier for t in range(1, multiplier)] _, _, pH, pW = frames.shape @@ -131,42 +145,55 @@ class FrameInterpolate(io.ComfyNode): ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1) ts_full = ts_full.expand(-1, 1, pH, pW) + multi_fn = getattr(inference_model, "forward_multi_timestep", None) + feat_cache = {} + try: for i in range(total_pairs): img0_single = frames[i:i + 1].to(device) img1_single = frames[i + 1:i + 2].to(device) - j = 0 - while j < num_interp: - b = min(batch, num_interp - j) - try: - img0 = img0_single.expand(b, -1, -1, -1) - img1 = img1_single.expand(b, -1, -1, -1) - mids = inference_model(img0, img1, timestep=ts_full[j:j + b]) - result[out_idx:out_idx + b].copy_(mids.to(dtype=dtype), non_blocking=use_pin) - out_idx += b - pbar.update(b) - tqdm_bar.update(b) - j += b - except model_management.OOM_EXCEPTION: - if batch <= 1: - raise - batch = max(1, batch // 2) - model_management.soft_empty_cache() + # Cache features: img1 of pair N becomes img0 of pair N+1 + feat_cache["img0"] = feat_cache.pop("next") if "next" in feat_cache else inference_model.extract_features(img0_single) + feat_cache["img1"] = inference_model.extract_features(img1_single) + feat_cache["next"] = feat_cache["img1"] + + if multi_fn is not None: + # Models with timestep-independent flow can compute it once for all timesteps + mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache) + result[out_idx:out_idx + num_interp].copy_(mids.to(dtype=dtype), non_blocking=use_pin) + out_idx += num_interp + pbar.update(num_interp) + tqdm_bar.update(num_interp) + else: + j = 0 + while j < num_interp: + b = min(batch, num_interp - j) + try: + img0 = img0_single.expand(b, -1, -1, -1) + img1 = img1_single.expand(b, -1, -1, -1) + mids = inference_model(img0, img1, timestep=ts_full[j:j + b], cache=feat_cache) + result[out_idx:out_idx + b].copy_(mids.to(dtype=dtype), non_blocking=use_pin) + out_idx += b + pbar.update(b) + tqdm_bar.update(b) + j += b + except model_management.OOM_EXCEPTION: + if batch <= 1: + raise + batch = max(1, batch // 2) + model_management.soft_empty_cache() result[out_idx].copy_(frames[i + 1]) out_idx += 1 finally: tqdm_bar.close() - clear_warp_cache() if use_pin: model_management.synchronize() model_management.unpin_memory(result) # Crop padding and BCHW -> BHWC - if pad_h > 0 or pad_w > 0: - result = result[:, :, :H, :W] - result = result.movedim(1, -1).clamp_(0.0, 1.0).to(dtype=model_management.intermediate_dtype()) + result = result[:, :, :H, :W].movedim(1, -1).clamp_(0.0, 1.0).to(dtype=model_management.intermediate_dtype()) return io.NodeOutput(result) diff --git a/comfy_extras/rife_model/ifnet.py b/comfy_extras/rife_model/ifnet.py deleted file mode 100644 index 6a49d1a9f..000000000 --- a/comfy_extras/rife_model/ifnet.py +++ /dev/null @@ -1,168 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -import comfy.ops - -ops = comfy.ops.disable_weight_init - -_warp_grid_cache = {} -_WARP_GRID_CACHE_MAX = 4 - - -def clear_warp_cache(): - _warp_grid_cache.clear() - - -def warp(img, flow): - B, _, H, W = img.shape - dtype = img.dtype - img = img.float() - flow = flow.float() - cache_key = (H, W, flow.device) - if cache_key not in _warp_grid_cache: - if len(_warp_grid_cache) >= _WARP_GRID_CACHE_MAX: - _warp_grid_cache.pop(next(iter(_warp_grid_cache))) - grid_y, grid_x = torch.meshgrid( - torch.linspace(-1.0, 1.0, H, device=flow.device, dtype=torch.float32), - torch.linspace(-1.0, 1.0, W, device=flow.device, dtype=torch.float32), - indexing="ij", - ) - _warp_grid_cache[cache_key] = torch.stack((grid_x, grid_y), dim=0).unsqueeze(0) - grid = _warp_grid_cache[cache_key].expand(B, -1, -1, -1) - flow_norm = torch.cat([ - flow[:, 0:1] / ((W - 1) / 2), - flow[:, 1:2] / ((H - 1) / 2), - ], dim=1) - grid = (grid + flow_norm).permute(0, 2, 3, 1) - return F.grid_sample(img, grid, mode="bilinear", padding_mode="border", align_corners=True).to(dtype) - - -class Head(nn.Module): - def __init__(self, out_ch=4, device=None, dtype=None, operations=ops): - super().__init__() - self.cnn0 = operations.Conv2d(3, 16, 3, 2, 1, device=device, dtype=dtype) - self.cnn1 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype) - self.cnn2 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype) - self.cnn3 = operations.ConvTranspose2d(16, out_ch, 4, 2, 1, device=device, dtype=dtype) - self.relu = nn.LeakyReLU(0.2, True) - - def forward(self, x): - x = self.relu(self.cnn0(x)) - x = self.relu(self.cnn1(x)) - x = self.relu(self.cnn2(x)) - return self.cnn3(x) - - -class ResConv(nn.Module): - def __init__(self, c, device=None, dtype=None, operations=ops): - super().__init__() - self.conv = operations.Conv2d(c, c, 3, 1, 1, device=device, dtype=dtype) - self.beta = nn.Parameter(torch.ones((1, c, 1, 1), device=device, dtype=dtype)) - self.relu = nn.LeakyReLU(0.2, True) - - def forward(self, x): - return self.relu(torch.addcmul(x, self.conv(x), self.beta)) - - -class IFBlock(nn.Module): - def __init__(self, in_planes, c=64, device=None, dtype=None, operations=ops): - super().__init__() - self.conv0 = nn.Sequential( - nn.Sequential( - operations.Conv2d(in_planes, c // 2, 3, 2, 1, device=device, dtype=dtype), - nn.LeakyReLU(0.2, True), - ), - nn.Sequential( - operations.Conv2d(c // 2, c, 3, 2, 1, device=device, dtype=dtype), - nn.LeakyReLU(0.2, True), - ), - ) - self.convblock = nn.Sequential( - *(ResConv(c, device=device, dtype=dtype, operations=operations) for _ in range(8)) - ) - self.lastconv = nn.Sequential( - operations.ConvTranspose2d(c, 4 * 13, 4, 2, 1, device=device, dtype=dtype), - nn.PixelShuffle(2), - ) - - def forward(self, x, flow=None, scale=1): - x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear") - if flow is not None: - flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear").div_(scale) - x = torch.cat((x, flow), 1) - feat = self.conv0(x) - feat = self.convblock(feat) - tmp = self.lastconv(feat) - tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear") - flow = tmp[:, :4] * scale - mask = tmp[:, 4:5] - feat = tmp[:, 5:] - return flow, mask, feat - - -class IFNet(nn.Module): - def __init__(self, head_ch=4, channels=None, device=None, dtype=None, operations=ops): - super().__init__() - if channels is None: - channels = [192, 128, 96, 64, 32] - self.encode = Head(out_ch=head_ch, device=device, dtype=dtype, operations=operations) - block_in = [7 + 2 * head_ch] + [8 + 4 + 8 + 2 * head_ch] * 4 - self.blocks = nn.ModuleList([ - IFBlock(block_in[i], channels[i], device=device, dtype=dtype, operations=operations) - for i in range(5) - ]) - self.scale_list = [16, 8, 4, 2, 1] - - def get_dtype(self): - return self.encode.cnn0.weight.dtype - - def forward(self, img0, img1, timestep=0.5): - img0 = img0.clamp(0.0, 1.0) - img1 = img1.clamp(0.0, 1.0) - if not isinstance(timestep, torch.Tensor): - timestep = torch.full((img0.shape[0], 1, img0.shape[2], img0.shape[3]), - timestep, device=img0.device, dtype=img0.dtype) - f0 = self.encode(img0) - f1 = self.encode(img1) - flow = mask = feat = None - warped_img0 = img0 - warped_img1 = img1 - for i, block in enumerate(self.blocks): - if flow is None: - flow, mask, feat = block( - torch.cat((img0, img1, f0, f1, timestep), 1), - None, scale=self.scale_list[i], - ) - else: - wf0 = warp(f0, flow[:, :2]) - wf1 = warp(f1, flow[:, 2:4]) - fd, mask, feat = block( - torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask, feat), 1), - flow, scale=self.scale_list[i], - ) - flow = flow.add_(fd) - warped_img0 = warp(img0, flow[:, :2]) - warped_img1 = warp(img1, flow[:, 2:4]) - mask = torch.sigmoid(mask) - return torch.lerp(warped_img1, warped_img0, mask) - - -def detect_rife_config(state_dict): - # Determine head output channels from encode.cnn3 (ConvTranspose2d) - # ConvTranspose2d weight shape is (in_ch, out_ch, kH, kW) - head_ch = state_dict["encode.cnn3.weight"].shape[1] - - # Read per-block channel widths from conv0 second layer output channels - # conv0 is Sequential(Sequential(Conv2d, ReLU), Sequential(Conv2d, ReLU)) - # conv0.1.0.weight shape is (c, c//2, 3, 3) - channels = [] - for i in range(5): - key = f"blocks.{i}.conv0.1.0.weight" - if key in state_dict: - channels.append(state_dict[key].shape[0]) - - if len(channels) != 5: - raise ValueError(f"Unsupported RIFE model: expected 5 blocks, found {len(channels)}") - - return head_ch, channels From 3cbd1d5f714e33f32b5a53af6073f11458580431 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sat, 4 Apr 2026 17:28:37 +0300 Subject: [PATCH 3/6] Better RAM usage, reduce FILM VRAM peak --- .../frame_interpolation_models/film_net.py | 7 ++ comfy_extras/nodes_frame_interpolation.py | 66 +++++++++++-------- 2 files changed, 45 insertions(+), 28 deletions(-) diff --git a/comfy_extras/frame_interpolation_models/film_net.py b/comfy_extras/frame_interpolation_models/film_net.py index 552b78b8c..cf4f6e1e1 100644 --- a/comfy_extras/frame_interpolation_models/film_net.py +++ b/comfy_extras/frame_interpolation_models/film_net.py @@ -105,6 +105,9 @@ class FeatureExtractor(nn.Module): if j <= i: features = torch.cat([features, sub_pyramids[i - j][j]], dim=1) feature_pyramid.append(features) + # Free sub-pyramids no longer needed by future levels + if i >= self.sub_levels - 1: + sub_pyramids[i - self.sub_levels + 1] = None return feature_pyramid @@ -233,9 +236,11 @@ class FILMNet(nn.Module): fwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr0, feat_pyr1, self.warp))[:self.fusion_pyramid_levels] bwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr1, feat_pyr0, self.warp))[:self.fusion_pyramid_levels] + # Build warp targets and free full pyramids (only first fpl levels needed from here) fpl = self.fusion_pyramid_levels p2w = [concatenate_pyramids(image_pyr0[:fpl], feat_pyr0[:fpl]), concatenate_pyramids(image_pyr1[:fpl], feat_pyr1[:fpl])] + del image_pyr0, image_pyr1, feat_pyr0, feat_pyr1 results = [] dt_tensors = torch.tensor(timesteps, device=img0.device, dtype=img0.dtype) @@ -247,5 +252,7 @@ class FILMNet(nn.Module): bwd_warped = pyramid_warp(p2w[1], fwd_scaled, self.warp) aligned = [torch.cat([fw, bw, bf, ff], dim=1) for fw, bw, bf, ff in zip(fwd_warped, bwd_warped, bwd_scaled, fwd_scaled)] + del fwd_warped, bwd_warped, bwd_scaled, fwd_scaled results.append(self.fuse(aligned)) + del aligned return torch.cat(results, dim=0) diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py index 723e9c85a..f0e1cf61f 100644 --- a/comfy_extras/nodes_frame_interpolation.py +++ b/comfy_extras/nodes_frame_interpolation.py @@ -104,15 +104,19 @@ class FrameInterpolate(io.ComfyNode): dtype = model.model_dtype() inference_model = model.model - # BHWC -> BCHW - frames = images.movedim(-1, 1).to(dtype=dtype, device=offload_device) - _, C, H, W = frames.shape - - # Pad to model's required alignment (RIFE needs 64, FILM handles any size) + # Free VRAM for inference activations (model weights + ~20x a single frame's worth) + H, W = images.shape[1], images.shape[2] + activation_mem = H * W * 3 * images.element_size() * 20 + model_management.free_memory(activation_mem, device) align = getattr(inference_model, "pad_align", 1) - if align > 1: - from comfy.ldm.common_dit import pad_to_patch_size - frames = pad_to_patch_size(frames, (align, align), padding_mode="reflect") + + # Prepare a single padded frame on device for determining output dimensions + def prepare_frame(idx): + frame = images[idx:idx + 1].movedim(-1, 1).to(dtype=dtype, device=device) + if align > 1: + from comfy.ldm.common_dit import pad_to_patch_size + frame = pad_to_patch_size(frame, (align, align), padding_mode="reflect") + return frame if torch_compile: for name, child in inference_model.named_children(): @@ -132,26 +136,29 @@ class FrameInterpolate(io.ComfyNode): batch = num_interp # reduced on OOM and persists across pairs (same resolution = same limit) t_values = [t / multiplier for t in range(1, multiplier)] - _, _, pH, pW = frames.shape - # Pre-allocate output tensor, pin for async GPU->CPU transfer + out_dtype = model_management.intermediate_dtype() total_out_frames = total_pairs * multiplier + 1 - result = torch.empty((total_out_frames, C, pH, pW), dtype=dtype, device=offload_device) - use_pin = model_management.pin_memory(result) - result[0] = frames[0] + result = torch.empty((total_out_frames, 3, H, W), dtype=out_dtype, device=offload_device) + result[0] = images[0].movedim(-1, 0).to(out_dtype) out_idx = 1 - # Pre-compute timestep tensor on device + # Pre-compute timestep tensor on device (padded dimensions needed) + sample = prepare_frame(0) + pH, pW = sample.shape[2], sample.shape[3] ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1) ts_full = ts_full.expand(-1, 1, pH, pW) + del sample multi_fn = getattr(inference_model, "forward_multi_timestep", None) feat_cache = {} + prev_frame = None try: for i in range(total_pairs): - img0_single = frames[i:i + 1].to(device) - img1_single = frames[i + 1:i + 2].to(device) + img0_single = prev_frame if prev_frame is not None else prepare_frame(i) + img1_single = prepare_frame(i + 1) + prev_frame = img1_single # Cache features: img1 of pair N becomes img0 of pair N+1 feat_cache["img0"] = feat_cache.pop("next") if "next" in feat_cache else inference_model.extract_features(img0_single) @@ -160,11 +167,17 @@ class FrameInterpolate(io.ComfyNode): if multi_fn is not None: # Models with timestep-independent flow can compute it once for all timesteps - mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache) - result[out_idx:out_idx + num_interp].copy_(mids.to(dtype=dtype), non_blocking=use_pin) - out_idx += num_interp - pbar.update(num_interp) - tqdm_bar.update(num_interp) + try: + mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache) + result[out_idx:out_idx + num_interp] = mids[:, :, :H, :W].to(out_dtype) + out_idx += num_interp + pbar.update(num_interp) + tqdm_bar.update(num_interp) + except model_management.OOM_EXCEPTION: + # Fall back to single-timestep calls + model_management.soft_empty_cache() + multi_fn = None + continue else: j = 0 while j < num_interp: @@ -173,7 +186,7 @@ class FrameInterpolate(io.ComfyNode): img0 = img0_single.expand(b, -1, -1, -1) img1 = img1_single.expand(b, -1, -1, -1) mids = inference_model(img0, img1, timestep=ts_full[j:j + b], cache=feat_cache) - result[out_idx:out_idx + b].copy_(mids.to(dtype=dtype), non_blocking=use_pin) + result[out_idx:out_idx + b] = mids[:, :, :H, :W].to(out_dtype) out_idx += b pbar.update(b) tqdm_bar.update(b) @@ -184,16 +197,13 @@ class FrameInterpolate(io.ComfyNode): batch = max(1, batch // 2) model_management.soft_empty_cache() - result[out_idx].copy_(frames[i + 1]) + result[out_idx] = images[i + 1].movedim(-1, 0).to(out_dtype) out_idx += 1 finally: tqdm_bar.close() - if use_pin: - model_management.synchronize() - model_management.unpin_memory(result) - # Crop padding and BCHW -> BHWC - result = result[:, :, :H, :W].movedim(1, -1).clamp_(0.0, 1.0).to(dtype=model_management.intermediate_dtype()) + # BCHW -> BHWC + result = result.movedim(1, -1).clamp_(0.0, 1.0) return io.NodeOutput(result) From 390798718c5c4d9b0a7e73b9127c339dd5fa7d54 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 9 Apr 2026 01:07:19 +0300 Subject: [PATCH 4/6] Add model folder placeholder --- comfy_extras/nodes_frame_interpolation.py | 3 ++- models/frame_interpolation/put_frame_interpolation_models_here | 0 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 models/frame_interpolation/put_frame_interpolation_models_here diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py index f0e1cf61f..995df6ed1 100644 --- a/comfy_extras/nodes_frame_interpolation.py +++ b/comfy_extras/nodes_frame_interpolation.py @@ -21,7 +21,8 @@ class FrameInterpolationModelLoader(io.ComfyNode): display_name="Load Frame Interpolation Model", category="loaders", inputs=[ - io.Combo.Input("model_name", options=folder_paths.get_filename_list("frame_interpolation")), + io.Combo.Input("model_name", options=folder_paths.get_filename_list("frame_interpolation"), + tooltip="Select a frame interpolation model to load. Models must be placed in the 'frame_interpolation' folder."), ], outputs=[ FrameInterpolationModel.Output(), diff --git a/models/frame_interpolation/put_frame_interpolation_models_here b/models/frame_interpolation/put_frame_interpolation_models_here new file mode 100644 index 000000000..e69de29bb From 2637aad79628450654f3634572ba0c8d67271d40 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Thu, 9 Apr 2026 01:20:18 +0300 Subject: [PATCH 5/6] Fix oom fallback frame loss --- comfy_extras/nodes_frame_interpolation.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py index 995df6ed1..4d5f5a08a 100644 --- a/comfy_extras/nodes_frame_interpolation.py +++ b/comfy_extras/nodes_frame_interpolation.py @@ -166,6 +166,7 @@ class FrameInterpolate(io.ComfyNode): feat_cache["img1"] = inference_model.extract_features(img1_single) feat_cache["next"] = feat_cache["img1"] + used_multi = False if multi_fn is not None: # Models with timestep-independent flow can compute it once for all timesteps try: @@ -174,12 +175,12 @@ class FrameInterpolate(io.ComfyNode): out_idx += num_interp pbar.update(num_interp) tqdm_bar.update(num_interp) + used_multi = True except model_management.OOM_EXCEPTION: - # Fall back to single-timestep calls model_management.soft_empty_cache() - multi_fn = None - continue - else: + multi_fn = None # fall through to single-timestep path + + if not used_multi: j = 0 while j < num_interp: b = min(batch, num_interp - j) From fe32b376c35fb4588ac766c0636eb7fa9a212d40 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Mon, 13 Apr 2026 23:08:15 +0300 Subject: [PATCH 6/6] Remove torch.compile for now --- comfy_extras/nodes_frame_interpolation.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py index 4d5f5a08a..34d6dea11 100644 --- a/comfy_extras/nodes_frame_interpolation.py +++ b/comfy_extras/nodes_frame_interpolation.py @@ -84,8 +84,6 @@ class FrameInterpolate(io.ComfyNode): FrameInterpolationModel.Input("model"), io.Image.Input("images"), io.Int.Input("multiplier", default=2, min=2, max=16), - io.Boolean.Input("torch_compile", default=False, optional=True, advanced=True, - tooltip="Requires triton. Compile model submodules for potential speed increase. Adds warmup on first run, recompiles on resolution change."), ], outputs=[ io.Image.Output(), @@ -93,7 +91,7 @@ class FrameInterpolate(io.ComfyNode): ) @classmethod - def execute(cls, model, images, multiplier, torch_compile=False) -> io.NodeOutput: + def execute(cls, model, images, multiplier) -> io.NodeOutput: offload_device = model_management.intermediate_device() num_frames = images.shape[0] @@ -119,15 +117,6 @@ class FrameInterpolate(io.ComfyNode): frame = pad_to_patch_size(frame, (align, align), padding_mode="reflect") return frame - if torch_compile: - for name, child in inference_model.named_children(): - if isinstance(child, (torch.nn.ModuleList, torch.nn.ModuleDict)): - continue - if not hasattr(child, "_compiled"): - compiled = torch.compile(child) - compiled._compiled = True - setattr(inference_model, name, compiled) - # Count total interpolation passes for progress bar total_pairs = num_frames - 1 num_interp = multiplier - 1