diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py
new file mode 100644
index 000000000..d90feb778
--- /dev/null
+++ b/comfy_extras/nodes_frame_interpolation.py
@@ -0,0 +1,196 @@
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+from typing_extensions import override
+
+import comfy.model_patcher
+import comfy.utils
+import folder_paths
+from comfy import model_management
+from comfy_extras.rife_model.ifnet import IFNet, detect_rife_config, clear_warp_cache
+from comfy_api.latest import ComfyExtension, io
+
+FrameInterpolationModel = io.Custom("FRAME_INTERPOLATION_MODEL")
+
+
+class FrameInterpolationModelLoader(io.ComfyNode):
+    """Node that loads a RIFE checkpoint from the ``frame_interpolation`` model folder."""
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="FrameInterpolationModelLoader",
+            display_name="Load Frame Interpolation Model",
+            category="loaders",
+            inputs=[
+                io.Combo.Input("model_name", options=folder_paths.get_filename_list("frame_interpolation")),
+            ],
+            outputs=[
+                FrameInterpolationModel.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model_name) -> io.NodeOutput:
+        """Load ``model_name``, normalize its state dict keys and wrap the model in a ModelPatcher."""
+        model_path = folder_paths.get_full_path_or_raise("frame_interpolation", model_name)
+        sd = comfy.utils.load_torch_file(model_path, safe_load=True)
+
+        # Strip common prefixes (DataParallel, RIFE model wrapper)
+        sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": "", "flownet.": ""})
+
+        # Convert blockN.xxx keys to blocks.N.xxx if needed
+        key_map = {}
+        for k in sd:
+            for i in range(5):
+                prefix = f"block{i}."
+                if k.startswith(prefix):
+                    key_map[k] = f"blocks.{i}.{k[len(prefix):]}"
+                    break
+        if key_map:
+            sd = {key_map.get(k, k): v for k, v in sd.items()}
+
+        # Filter out training-only keys (teacher distillation, timestamp calibration)
+        sd = {k: v for k, v in sd.items()
+              if not k.startswith(("teacher.", "caltime."))}
+
+        head_ch, channels = detect_rife_config(sd)
+        model = IFNet(head_ch=head_ch, channels=channels)
+        model.load_state_dict(sd)
+        # RIFE is a small pixel-space model similar to VAE, bf16 produces artifacts due to low mantissa precision
+        dtype = model_management.vae_dtype(device=model_management.get_torch_device(),
+                                           allowed_dtypes=[torch.float16, torch.float32])
+        model.eval().to(dtype)
+        patcher = comfy.model_patcher.ModelPatcher(
+            model,
+            load_device=model_management.get_torch_device(),
+            offload_device=model_management.unet_offload_device(),
+        )
+        return io.NodeOutput(patcher)
+
+
+class FrameInterpolate(io.ComfyNode):
+    """Node that inserts ``multiplier - 1`` generated frames between each pair of consecutive input frames."""
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="FrameInterpolate",
+            display_name="Frame Interpolate",
+            category="image/video",
+            search_aliases=["rife", "frame interpolation", "slow motion", "interpolate frames"],
+            inputs=[
+                FrameInterpolationModel.Input("model"),
+                io.Image.Input("images"),
+                io.Int.Input("multiplier", default=2, min=2, max=16),
+            ],
+            outputs=[
+                io.Image.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, images, multiplier) -> io.NodeOutput:
+        """Interpolate ``images`` (N frames) and return ``(N - 1) * multiplier + 1`` frames.
+
+        Inputs with fewer than two frames (or ``multiplier < 2``) are returned unchanged.
+        The per-pair batch size is halved on OOM until it reaches 1, after which the error is re-raised.
+        """
+        offload_device = model_management.intermediate_device()
+
+        num_frames = images.shape[0]
+        if num_frames < 2 or multiplier < 2:
+            return io.NodeOutput(images)
+
+        model_management.load_model_gpu(model)
+        device = model.load_device
+        inference_model = model.model
+        dtype = model.model_dtype()
+
+        # BHWC -> BCHW
+        frames = images.movedim(-1, 1).to(dtype=dtype, device=offload_device)
+        _, C, H, W = frames.shape
+
+        # Pad to multiple of 64
+        pad_h = (64 - H % 64) % 64
+        pad_w = (64 - W % 64) % 64
+        if pad_h > 0 or pad_w > 0:
+            # Reflect padding requires the pad to be smaller than the padded dimension,
+            # so fall back to edge replication for very small frames instead of raising.
+            pad_mode = "reflect" if pad_h < H and pad_w < W else "replicate"
+            frames = F.pad(frames, (0, pad_w, 0, pad_h), mode=pad_mode)
+
+        # Count total interpolation passes for progress bar
+        total_pairs = num_frames - 1
+        num_interp = multiplier - 1
+        total_steps = total_pairs * num_interp
+        pbar = comfy.utils.ProgressBar(total_steps)
+        tqdm_bar = tqdm(total=total_steps, desc="Frame interpolation")
+
+        batch = num_interp
+        t_values = [t / multiplier for t in range(1, multiplier)]
+        _, _, pH, pW = frames.shape
+
+        # Pre-allocate output tensor, pin for async GPU->CPU transfer
+        total_out_frames = total_pairs * multiplier + 1
+        result = torch.empty((total_out_frames, C, pH, pW), dtype=dtype, device=offload_device)
+        use_pin = model_management.pin_memory(result)
+        result[0] = frames[0]
+        out_idx = 1
+
+        # Pre-compute timestep tensor on device
+        ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1)
+        ts_full = ts_full.expand(-1, 1, pH, pW)
+
+        try:
+            for i in range(total_pairs):
+                img0_single = frames[i:i + 1].to(device)
+                img1_single = frames[i + 1:i + 2].to(device)
+
+                j = 0
+                while j < num_interp:
+                    b = min(batch, num_interp - j)
+                    try:
+                        img0 = img0_single.expand(b, -1, -1, -1)
+                        img1 = img1_single.expand(b, -1, -1, -1)
+                        mids = inference_model(img0, img1, timestep=ts_full[j:j + b])
+                        result[out_idx:out_idx + b].copy_(mids.to(dtype=dtype), non_blocking=use_pin)
+                        out_idx += b
+                        pbar.update(b)
+                        tqdm_bar.update(b)
+                        j += b
+                    except model_management.OOM_EXCEPTION:
+                        if batch <= 1:
+                            raise
+                        batch = max(1, batch // 2)
+                        model_management.soft_empty_cache()
+
+                result[out_idx].copy_(frames[i + 1])
+                out_idx += 1
+        finally:
+            tqdm_bar.close()
+            clear_warp_cache()
+            if use_pin:
+                model_management.synchronize()
+                model_management.unpin_memory(result)
+
+        # Crop padding and BCHW -> BHWC
+        if pad_h > 0 or pad_w > 0:
+            result = result[:, :, :H, :W]
+        result = result.movedim(1, -1).clamp_(0.0, 1.0).to(dtype=model_management.intermediate_dtype())
+        return io.NodeOutput(result)
+
+
+class FrameInterpolationExtension(ComfyExtension):
+    """Extension exposing the frame interpolation nodes."""
+
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            FrameInterpolationModelLoader,
+            FrameInterpolate,
+        ]
+
+
+async def comfy_entrypoint() -> FrameInterpolationExtension:
+    return FrameInterpolationExtension()
diff --git a/comfy_extras/rife_model/ifnet.py b/comfy_extras/rife_model/ifnet.py
new file mode 100644
index 000000000..6a49d1a9f
--- /dev/null
+++ b/comfy_extras/rife_model/ifnet.py
@@ -0,0 +1,198 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import comfy.ops
+
+ops = comfy.ops.disable_weight_init
+
+_warp_grid_cache = {}
+_WARP_GRID_CACHE_MAX = 4
+
+
+def clear_warp_cache():
+    """Drop every cached sampling grid built by ``warp``."""
+    _warp_grid_cache.clear()
+
+
+def warp(img, flow):
+    """Sample ``img`` at positions displaced by ``flow`` using bilinear ``grid_sample``.
+
+    ``flow[:, 0]`` is the x displacement and ``flow[:, 1]`` the y displacement, in pixels.
+    Sampling runs in float32 and the result is cast back to ``img.dtype``.
+    """
+    B, _, H, W = img.shape
+    dtype = img.dtype
+    img = img.float()
+    flow = flow.float()
+    cache_key = (H, W, flow.device)
+    if cache_key not in _warp_grid_cache:
+        if len(_warp_grid_cache) >= _WARP_GRID_CACHE_MAX:
+            # Evict the oldest entry (dicts preserve insertion order)
+            _warp_grid_cache.pop(next(iter(_warp_grid_cache)))
+        grid_y, grid_x = torch.meshgrid(
+            torch.linspace(-1.0, 1.0, H, device=flow.device, dtype=torch.float32),
+            torch.linspace(-1.0, 1.0, W, device=flow.device, dtype=torch.float32),
+            indexing="ij",
+        )
+        _warp_grid_cache[cache_key] = torch.stack((grid_x, grid_y), dim=0).unsqueeze(0)
+    grid = _warp_grid_cache[cache_key].expand(B, -1, -1, -1)
+    # Convert pixel displacements to the [-1, 1] coordinate range used by grid_sample.
+    # max(..., 1) avoids a division by zero (NaN grid) when a dimension has a single pixel.
+    flow_norm = torch.cat([
+        flow[:, 0:1] / (max(W - 1, 1) / 2),
+        flow[:, 1:2] / (max(H - 1, 1) / 2),
+    ], dim=1)
+    grid = (grid + flow_norm).permute(0, 2, 3, 1)
+    return F.grid_sample(img, grid, mode="bilinear", padding_mode="border", align_corners=True).to(dtype)
+
+
+class Head(nn.Module):
+    """Small convolutional feature encoder that IFNet applies to each 3-channel input frame."""
+
+    def __init__(self, out_ch=4, device=None, dtype=None, operations=ops):
+        super().__init__()
+        self.cnn0 = operations.Conv2d(3, 16, 3, 2, 1, device=device, dtype=dtype)
+        self.cnn1 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype)
+        self.cnn2 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype)
+        self.cnn3 = operations.ConvTranspose2d(16, out_ch, 4, 2, 1, device=device, dtype=dtype)
+        self.relu = nn.LeakyReLU(0.2, True)
+
+    def forward(self, x):
+        x = self.relu(self.cnn0(x))
+        x = self.relu(self.cnn1(x))
+        x = self.relu(self.cnn2(x))
+        return self.cnn3(x)
+
+
+class ResConv(nn.Module):
+    """Residual 3x3 convolution computing ``leaky_relu(x + conv(x) * beta)`` with a learned ``beta``."""
+
+    def __init__(self, c, device=None, dtype=None, operations=ops):
+        super().__init__()
+        self.conv = operations.Conv2d(c, c, 3, 1, 1, device=device, dtype=dtype)
+        self.beta = nn.Parameter(torch.ones((1, c, 1, 1), device=device, dtype=dtype))
+        self.relu = nn.LeakyReLU(0.2, True)
+
+    def forward(self, x):
+        return self.relu(torch.addcmul(x, self.conv(x), self.beta))
+
+
+class IFBlock(nn.Module):
+    """One refinement stage producing flow (4 ch), mask (1 ch) and feature (8 ch) maps."""
+
+    def __init__(self, in_planes, c=64, device=None, dtype=None, operations=ops):
+        super().__init__()
+        self.conv0 = nn.Sequential(
+            nn.Sequential(
+                operations.Conv2d(in_planes, c // 2, 3, 2, 1, device=device, dtype=dtype),
+                nn.LeakyReLU(0.2, True),
+            ),
+            nn.Sequential(
+                operations.Conv2d(c // 2, c, 3, 2, 1, device=device, dtype=dtype),
+                nn.LeakyReLU(0.2, True),
+            ),
+        )
+        self.convblock = nn.Sequential(
+            *(ResConv(c, device=device, dtype=dtype, operations=operations) for _ in range(8))
+        )
+        self.lastconv = nn.Sequential(
+            operations.ConvTranspose2d(c, 4 * 13, 4, 2, 1, device=device, dtype=dtype),
+            nn.PixelShuffle(2),
+        )
+
+    def forward(self, x, flow=None, scale=1):
+        """Run the block on ``x`` resized by ``1 / scale`` and return ``(flow, mask, feat)`` resized back by ``scale``."""
+        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear")
+        if flow is not None:
+            flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear").div_(scale)
+            x = torch.cat((x, flow), 1)
+        feat = self.conv0(x)
+        feat = self.convblock(feat)
+        tmp = self.lastconv(feat)
+        tmp = F.interpolate(tmp, scale_factor=scale, mode="bilinear")
+        flow = tmp[:, :4] * scale
+        mask = tmp[:, 4:5]
+        feat = tmp[:, 5:]
+        return flow, mask, feat
+
+
+class IFNet(nn.Module):
+    """Five-stage network (scales 16, 8, 4, 2, 1) that blends two flow-warped frames into one output frame."""
+
+    def __init__(self, head_ch=4, channels=None, device=None, dtype=None, operations=ops):
+        super().__init__()
+        if channels is None:
+            channels = [192, 128, 96, 64, 32]
+        self.encode = Head(out_ch=head_ch, device=device, dtype=dtype, operations=operations)
+        block_in = [7 + 2 * head_ch] + [8 + 4 + 8 + 2 * head_ch] * 4
+        self.blocks = nn.ModuleList([
+            IFBlock(block_in[i], channels[i], device=device, dtype=dtype, operations=operations)
+            for i in range(5)
+        ])
+        self.scale_list = [16, 8, 4, 2, 1]
+
+    def get_dtype(self):
+        return self.encode.cnn0.weight.dtype
+
+    def forward(self, img0, img1, timestep=0.5):
+        """Warp ``img0`` and ``img1`` by the predicted flows and blend them with the predicted mask.
+
+        ``timestep`` may be a float or a tensor; it is concatenated to the inputs as an extra channel.
+        """
+        img0 = img0.clamp(0.0, 1.0)
+        img1 = img1.clamp(0.0, 1.0)
+        if not isinstance(timestep, torch.Tensor):
+            timestep = torch.full((img0.shape[0], 1, img0.shape[2], img0.shape[3]),
+                                  timestep, device=img0.device, dtype=img0.dtype)
+        f0 = self.encode(img0)
+        f1 = self.encode(img1)
+        flow = mask = feat = None
+        warped_img0 = img0
+        warped_img1 = img1
+        for i, block in enumerate(self.blocks):
+            if flow is None:
+                flow, mask, feat = block(
+                    torch.cat((img0, img1, f0, f1, timestep), 1),
+                    None, scale=self.scale_list[i],
+                )
+            else:
+                wf0 = warp(f0, flow[:, :2])
+                wf1 = warp(f1, flow[:, 2:4])
+                fd, mask, feat = block(
+                    torch.cat((warped_img0, warped_img1, wf0, wf1, timestep, mask, feat), 1),
+                    flow, scale=self.scale_list[i],
+                )
+                flow = flow.add_(fd)
+            warped_img0 = warp(img0, flow[:, :2])
+            warped_img1 = warp(img1, flow[:, 2:4])
+        mask = torch.sigmoid(mask)
+        return torch.lerp(warped_img1, warped_img0, mask)
+
+
+def detect_rife_config(state_dict):
+    """Infer ``(head_ch, channels)`` for ``IFNet`` from a RIFE state dict.
+
+    Raises:
+        ValueError: if the head weight or any of the 5 block weights is missing.
+    """
+    # Determine head output channels from encode.cnn3 (ConvTranspose2d)
+    # ConvTranspose2d weight shape is (in_ch, out_ch, kH, kW)
+    head_key = "encode.cnn3.weight"
+    if head_key not in state_dict:
+        raise ValueError(f"Unsupported RIFE model: missing '{head_key}'")
+    head_ch = state_dict[head_key].shape[1]
+
+    # Read per-block channel widths from conv0 second layer output channels
+    # conv0 is Sequential(Sequential(Conv2d, ReLU), Sequential(Conv2d, ReLU))
+    # conv0.1.0.weight shape is (c, c//2, 3, 3)
+    channels = []
+    for i in range(5):
+        key = f"blocks.{i}.conv0.1.0.weight"
+        if key in state_dict:
+            channels.append(state_dict[key].shape[0])
+
+    if len(channels) != 5:
+        raise ValueError(f"Unsupported RIFE model: expected 5 blocks, found {len(channels)}")
+
+    return head_ch, channels
diff --git a/folder_paths.py b/folder_paths.py
index 9c96540e3..80f4b291a 100644
--- a/folder_paths.py
+++ b/folder_paths.py
@@ -52,6 +52,8 @@ folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patc
 
 folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
 
+folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
+
 output_directory = os.path.join(base_path, "output")
 temp_directory = os.path.join(base_path, "temp")
 input_directory = os.path.join(base_path, "input")
diff --git a/nodes.py b/nodes.py
index 299b3d758..bb38e07b8 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2457,7 +2457,8 @@ async def init_builtin_extra_nodes():
         "nodes_number_convert.py",
         "nodes_painter.py",
         "nodes_curve.py",
-        "nodes_rtdetr.py"
+        "nodes_rtdetr.py",
+        "nodes_frame_interpolation.py",
     ]
 
     import_failed = []