import torch from tqdm import tqdm from typing_extensions import override import comfy.model_patcher import comfy.utils import folder_paths from comfy import model_management from comfy_extras.frame_interpolation_models.ifnet import IFNet, detect_rife_config from comfy_extras.frame_interpolation_models.film_net import FILMNet from comfy_api.latest import ComfyExtension, io FrameInterpolationModel = io.Custom("FRAME_INTERPOLATION_MODEL") class FrameInterpolationModelLoader(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="FrameInterpolationModelLoader", display_name="Load Frame Interpolation Model", category="loaders", inputs=[ io.Combo.Input("model_name", options=folder_paths.get_filename_list("frame_interpolation"), tooltip="Select a frame interpolation model to load. Models must be placed in the 'frame_interpolation' folder."), ], outputs=[ FrameInterpolationModel.Output(), ], ) @classmethod def execute(cls, model_name) -> io.NodeOutput: model_path = folder_paths.get_full_path_or_raise("frame_interpolation", model_name) sd = comfy.utils.load_torch_file(model_path, safe_load=True) model = cls._detect_and_load(sd) dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32 model.eval().to(dtype) patcher = comfy.model_patcher.ModelPatcher( model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), ) return io.NodeOutput(patcher) @classmethod def _detect_and_load(cls, sd): # Try FILM if "extract.extract_sublevels.convs.0.0.conv.weight" in sd: model = FILMNet() model.load_state_dict(sd) return model # Try RIFE (needs key remapping for raw checkpoints) sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": "", "flownet.": ""}) key_map = {} for k in sd: for i in range(5): if k.startswith(f"block{i}."): key_map[k] = f"blocks.{i}.{k[len(f'block{i}.'):]}" if key_map: sd = {key_map.get(k, k): v for k, v in sd.items()} sd = {k: v for k, v in sd.items() if not k.startswith(("teacher.", "caltime."))} try: head_ch, channels = detect_rife_config(sd) except (KeyError, ValueError): raise ValueError("Unrecognized frame interpolation model format") model = IFNet(head_ch=head_ch, channels=channels) model.load_state_dict(sd) return model class FrameInterpolate(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="FrameInterpolate", display_name="Frame Interpolate", category="image/video", search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"], inputs=[ FrameInterpolationModel.Input("model"), io.Image.Input("images"), io.Int.Input("multiplier", default=2, min=2, max=16), io.Boolean.Input("torch_compile", default=False, optional=True, advanced=True, tooltip="Requires triton. Compile model submodules for potential speed increase. Adds warmup on first run, recompiles on resolution change."), ], outputs=[ io.Image.Output(), ], ) @classmethod def execute(cls, model, images, multiplier, torch_compile=False) -> io.NodeOutput: offload_device = model_management.intermediate_device() num_frames = images.shape[0] if num_frames < 2 or multiplier < 2: return io.NodeOutput(images) model_management.load_model_gpu(model) device = model.load_device dtype = model.model_dtype() inference_model = model.model # Free VRAM for inference activations (model weights + ~20x a single frame's worth) H, W = images.shape[1], images.shape[2] activation_mem = H * W * 3 * images.element_size() * 20 model_management.free_memory(activation_mem, device) align = getattr(inference_model, "pad_align", 1) # Prepare a single padded frame on device for determining output dimensions def prepare_frame(idx): frame = images[idx:idx + 1].movedim(-1, 1).to(dtype=dtype, device=device) if align > 1: from comfy.ldm.common_dit import pad_to_patch_size frame = pad_to_patch_size(frame, (align, align), padding_mode="reflect") return frame if torch_compile: for name, child in inference_model.named_children(): if isinstance(child, (torch.nn.ModuleList, torch.nn.ModuleDict)): continue if not hasattr(child, "_compiled"): compiled = torch.compile(child) compiled._compiled = True setattr(inference_model, name, compiled) # Count total interpolation passes for progress bar total_pairs = num_frames - 1 num_interp = multiplier - 1 total_steps = total_pairs * num_interp pbar = comfy.utils.ProgressBar(total_steps) tqdm_bar = tqdm(total=total_steps, desc="Frame interpolation") batch = num_interp # reduced on OOM and persists across pairs (same resolution = same limit) t_values = [t / multiplier for t in range(1, multiplier)] out_dtype = model_management.intermediate_dtype() total_out_frames = total_pairs * multiplier + 1 result = torch.empty((total_out_frames, 3, H, W), dtype=out_dtype, device=offload_device) result[0] = images[0].movedim(-1, 0).to(out_dtype) out_idx = 1 # Pre-compute timestep tensor on device (padded dimensions needed) sample = prepare_frame(0) pH, pW = sample.shape[2], sample.shape[3] ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1) ts_full = ts_full.expand(-1, 1, pH, pW) del sample multi_fn = getattr(inference_model, "forward_multi_timestep", None) feat_cache = {} prev_frame = None try: for i in range(total_pairs): img0_single = prev_frame if prev_frame is not None else prepare_frame(i) img1_single = prepare_frame(i + 1) prev_frame = img1_single # Cache features: img1 of pair N becomes img0 of pair N+1 feat_cache["img0"] = feat_cache.pop("next") if "next" in feat_cache else inference_model.extract_features(img0_single) feat_cache["img1"] = inference_model.extract_features(img1_single) feat_cache["next"] = feat_cache["img1"] used_multi = False if multi_fn is not None: # Models with timestep-independent flow can compute it once for all timesteps try: mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache) result[out_idx:out_idx + num_interp] = mids[:, :, :H, :W].to(out_dtype) out_idx += num_interp pbar.update(num_interp) tqdm_bar.update(num_interp) used_multi = True except model_management.OOM_EXCEPTION: model_management.soft_empty_cache() multi_fn = None # fall through to single-timestep path if not used_multi: j = 0 while j < num_interp: b = min(batch, num_interp - j) try: img0 = img0_single.expand(b, -1, -1, -1) img1 = img1_single.expand(b, -1, -1, -1) mids = inference_model(img0, img1, timestep=ts_full[j:j + b], cache=feat_cache) result[out_idx:out_idx + b] = mids[:, :, :H, :W].to(out_dtype) out_idx += b pbar.update(b) tqdm_bar.update(b) j += b except model_management.OOM_EXCEPTION: if batch <= 1: raise batch = max(1, batch // 2) model_management.soft_empty_cache() result[out_idx] = images[i + 1].movedim(-1, 0).to(out_dtype) out_idx += 1 finally: tqdm_bar.close() # BCHW -> BHWC result = result.movedim(1, -1).clamp_(0.0, 1.0) return io.NodeOutput(result) class FrameInterpolationExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ FrameInterpolationModelLoader, FrameInterpolate, ] async def comfy_entrypoint() -> FrameInterpolationExtension: return FrameInterpolationExtension()