From d56887ac52d4acf72dd03cefa5de89364ff3c125 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Thu, 30 Apr 2026 15:49:44 +0200 Subject: [PATCH] Add Optical Flow Loader. --- comfy_extras/nodes_void.py | 80 ++++++++++++++++++++++++++++++++- comfy_extras/void_noise_warp.py | 33 +++++++------- folder_paths.py | 2 + 3 files changed, 95 insertions(+), 20 deletions(-) diff --git a/comfy_extras/nodes_void.py b/comfy_extras/nodes_void.py index aeffb3ee2..e7a8f3757 100644 --- a/comfy_extras/nodes_void.py +++ b/comfy_extras/nodes_void.py @@ -4,15 +4,21 @@ import torch import comfy import comfy.model_management +import comfy.model_patcher import comfy.samplers import comfy.utils +import folder_paths import node_helpers import nodes from comfy.utils import model_trange as trange from comfy_api.latest import ComfyExtension, io +from torchvision.models.optical_flow import raft_large from typing_extensions import override -from comfy_extras.void_noise_warp import get_noise_from_video + +from comfy_extras.void_noise_warp import RaftOpticalFlow, get_noise_from_video + +OpticalFlow = io.Custom("OPTICAL_FLOW") TEMPORAL_COMPRESSION = 4 PATCH_SIZE_T = 2 @@ -38,6 +44,67 @@ def _valid_void_length(length: int) -> int: return (target_latent_t - 1) * TEMPORAL_COMPRESSION + 1 +class OpticalFlowLoader(io.ComfyNode): + """Load an optical flow model from ``models/optical_flow/``. + + Only torchvision's RAFT-large format is recognized today (the model used + by VOIDWarpedNoise). The checkpoint must be placed under + ``models/optical_flow/`` — ComfyUI never downloads optical-flow weights + at runtime. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="OpticalFlowLoader", + display_name="Load Optical Flow Model", + category="loaders", + inputs=[ + io.Combo.Input( + "model_name", + options=folder_paths.get_filename_list("optical_flow"), + tooltip=( + "Optical flow model to load. Files must be placed in the " + "'optical_flow' folder. 
Today only torchvision's " "raft_large.pth is supported." ), ), ], outputs=[ OpticalFlow.Output(), ], ) + + @classmethod + def execute(cls, model_name) -> io.NodeOutput: + + model_path = folder_paths.get_full_path_or_raise("optical_flow", model_name) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) + + has_raft_keys = ( + any(k.startswith("feature_encoder.") for k in sd) + and any(k.startswith("context_encoder.") for k in sd) + and any(k.startswith("update_block.") for k in sd) + ) + if not has_raft_keys: + raise ValueError( + "Unrecognized optical flow model format: expected a torchvision " + "RAFT-large state dict with 'feature_encoder.', 'context_encoder.' " + "and 'update_block.' prefixes." + ) + + model = raft_large(weights=None, progress=False) + model.load_state_dict(sd) + model.eval().to(torch.float32) + + patcher = comfy.model_patcher.ModelPatcher( + model, + load_device=comfy.model_management.get_torch_device(), + offload_device=comfy.model_management.unet_offload_device(), + ) + return io.NodeOutput(patcher) + + class VOIDQuadmaskPreprocess(io.ComfyNode): """Preprocess a quadmask video for VOID inpainting. 
@@ -222,6 +289,10 @@ class VOIDWarpedNoise(io.ComfyNode): node_id="VOIDWarpedNoise", category="latent/video", inputs=[ + OpticalFlow.Input( + "optical_flow", + tooltip="Optical flow model from OpticalFlowLoader (RAFT-large).", + ), io.Image.Input("video", tooltip="Pass 1 output video frames [T, H, W, 3]"), io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8), io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8), @@ -236,7 +307,7 @@ class VOIDWarpedNoise(io.ComfyNode): ) @classmethod - def execute(cls, video, width, height, length, batch_size) -> io.NodeOutput: + def execute(cls, optical_flow, video, width, height, length, batch_size) -> io.NodeOutput: adjusted_length = _valid_void_length(length) if adjusted_length != length: @@ -257,6 +328,9 @@ class VOIDWarpedNoise(io.ComfyNode): # rest of the ComfyUI pipeline. device = comfy.model_management.get_torch_device() + comfy.model_management.load_model_gpu(optical_flow) + raft = RaftOpticalFlow(optical_flow.model, device=device) + vid = video[:length].to(device) vid = comfy.utils.common_upscale( vid.movedim(-1, 1), width, height, "bilinear", "center" @@ -269,6 +343,7 @@ class VOIDWarpedNoise(io.ComfyNode): warped = get_noise_from_video( vid_uint8, + raft, noise_channels=16, resize_frames=FRAME, resize_flow=FLOW, @@ -395,6 +470,7 @@ class VOIDExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ + OpticalFlowLoader, VOIDQuadmaskPreprocess, VOIDInpaintConditioning, VOIDWarpedNoise, diff --git a/comfy_extras/void_noise_warp.py b/comfy_extras/void_noise_warp.py index 4f7ff470f..fcc9a5f8b 100644 --- a/comfy_extras/void_noise_warp.py +++ b/comfy_extras/void_noise_warp.py @@ -9,8 +9,10 @@ Adapted from RyannDaGreat/CommonSource (MIT License, Ryan Burgert): Only the code paths that ``comfy_extras/nodes_void.py::VOIDWarpedNoise`` actually uses (torch THWC uint8 input, no background removal, no visualization, no disk I/O, default 
warp/noise params) have been inlined. External ``rp`` utilities -have been replaced with equivalents from torch.nn.functional / einops / -torchvision. +have been replaced with equivalents from torch.nn.functional / einops. The +RAFT optical-flow model itself is loaded offline via ``OpticalFlowLoader`` in +``nodes_void.py`` and passed into ``get_noise_from_video`` by the caller; this +module never downloads weights at runtime. """ import logging @@ -19,7 +21,6 @@ from typing import Optional import torch import torch.nn.functional as F from einops import rearrange -from torchvision.models.optical_flow import raft_large import comfy.model_management @@ -345,14 +346,20 @@ class NoiseWarper: # --------------------------------------------------------------------------- class RaftOpticalFlow: - """Torchvision RAFT-large wrapper. ``__call__`` returns a (2, H, W) flow.""" + """RAFT-large wrapper around a pre-loaded torchvision model. - def __init__(self, device=None): + ``model`` must be the ``torchvision.models.optical_flow.raft_large`` module + with its weights already populated; this class is load-agnostic so the + caller owns downloading/offload concerns (see ``OpticalFlowLoader`` in + ``nodes_void.py``). ``__call__`` returns a ``(2, H, W)`` flow. 
+ """ + + def __init__(self, model, device=None): if device is None: device = comfy.model_management.get_torch_device() device = torch.device(device) if not isinstance(device, torch.device) else device - model = raft_large(weights="DEFAULT", progress=False).to(device) + model = model.to(device) model.eval() self.device = device self.model = model @@ -384,22 +391,13 @@ class RaftOpticalFlow: return flow -_raft_cache: dict = {} - - -def _get_raft_model(device): - key = str(device) - if key not in _raft_cache: - _raft_cache[key] = RaftOpticalFlow(device=device) - return _raft_cache[key] - - # --------------------------------------------------------------------------- # Narrow entry point used by VOIDWarpedNoise # --------------------------------------------------------------------------- def get_noise_from_video( video_frames: torch.Tensor, + raft: RaftOpticalFlow, *, noise_channels: int = 16, resize_frames: float = 0.5, @@ -411,6 +409,7 @@ def get_noise_from_video( Args: video_frames: ``(T, H, W, 3)`` uint8 torch tensor. + raft: Pre-loaded RAFT optical-flow wrapper (see ``RaftOpticalFlow``). noise_channels: Channels in the output noise. resize_frames: Pre-RAFT frame scale factor. 
resize_flow: Post-flow up-scale factor applied to the optical flow; @@ -465,8 +464,6 @@ def get_noise_from_video( internal_h, internal_w, downscale_factor, ) - raft = _get_raft_model(device) - with torch.no_grad(): warper = NoiseWarper( c=noise_channels, h=internal_h, w=internal_w, device=device, diff --git a/folder_paths.py b/folder_paths.py index 80f4b291a..322193aae 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -54,6 +54,8 @@ folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_enc folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions) +folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions) + output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") input_directory = os.path.join(base_path, "input")