Add Optical Flow Loader.
Some checks failed
Build package / Build Test (3.12) (push) Has been cancelled
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled

This commit is contained in:
Talmaj Marinc 2026-04-30 15:49:44 +02:00
parent 752355991a
commit d56887ac52
3 changed files with 95 additions and 20 deletions

View File

@ -4,15 +4,21 @@ import torch
import comfy import comfy
import comfy.model_management import comfy.model_management
import comfy.model_patcher
import comfy.samplers import comfy.samplers
import comfy.utils import comfy.utils
import folder_paths
import node_helpers import node_helpers
import nodes import nodes
from comfy.utils import model_trange as trange from comfy.utils import model_trange as trange
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
from torchvision.models.optical_flow import raft_large
from typing_extensions import override from typing_extensions import override
from comfy_extras.void_noise_warp import get_noise_from_video
from comfy_extras.void_noise_warp import RaftOpticalFlow, get_noise_from_video
OpticalFlow = io.Custom("OPTICAL_FLOW")
TEMPORAL_COMPRESSION = 4 TEMPORAL_COMPRESSION = 4
PATCH_SIZE_T = 2 PATCH_SIZE_T = 2
@ -38,6 +44,67 @@ def _valid_void_length(length: int) -> int:
return (target_latent_t - 1) * TEMPORAL_COMPRESSION + 1 return (target_latent_t - 1) * TEMPORAL_COMPRESSION + 1
class OpticalFlowLoader(io.ComfyNode):
    """ComfyUI loader node for optical-flow models in ``models/optical_flow/``.

    Only torchvision's RAFT-large checkpoint layout is accepted today (the
    model consumed by VOIDWarpedNoise). Weights are never downloaded at
    runtime; the user must place the checkpoint file under
    ``models/optical_flow/`` themselves.
    """

    @classmethod
    def define_schema(cls):
        # One combo input listing every file found in the 'optical_flow'
        # model folder; the output is the custom OPTICAL_FLOW socket type.
        return io.Schema(
            node_id="OpticalFlowLoader",
            display_name="Load Optical Flow Model",
            category="loaders",
            inputs=[
                io.Combo.Input(
                    "model_name",
                    options=folder_paths.get_filename_list("optical_flow"),
                    tooltip=(
                        "Optical flow model to load. Files must be placed in the "
                        "'optical_flow' folder. Today only torchvision's "
                        "raft_large.pth is supported."
                    ),
                ),
            ],
            outputs=[
                OpticalFlow.Output(),
            ],
        )

    @classmethod
    def execute(cls, model_name) -> io.NodeOutput:
        """Load ``model_name``, verify it is a RAFT-large state dict, and wrap it in a ModelPatcher."""
        checkpoint_path = folder_paths.get_full_path_or_raise("optical_flow", model_name)
        state_dict = comfy.utils.load_torch_file(checkpoint_path, safe_load=True)

        # A torchvision RAFT-large state dict always carries parameters under
        # these three module prefixes; reject anything else up front with a
        # clear error instead of a confusing load_state_dict failure.
        expected_prefixes = ("feature_encoder.", "context_encoder.", "update_block.")
        if not all(
            any(key.startswith(prefix) for key in state_dict)
            for prefix in expected_prefixes
        ):
            raise ValueError(
                "Unrecognized optical flow model format: expected a torchvision "
                "RAFT-large state dict with 'feature_encoder.', 'context_encoder.' "
                "and 'update_block.' prefixes."
            )

        raft = raft_large(weights=None, progress=False)
        raft.load_state_dict(state_dict)
        raft.eval().to(torch.float32)

        # Wrapping in a ModelPatcher lets the standard ComfyUI memory manager
        # move the model between the compute and offload devices on demand.
        patcher = comfy.model_patcher.ModelPatcher(
            raft,
            load_device=comfy.model_management.get_torch_device(),
            offload_device=comfy.model_management.unet_offload_device(),
        )
        return io.NodeOutput(patcher)
class VOIDQuadmaskPreprocess(io.ComfyNode): class VOIDQuadmaskPreprocess(io.ComfyNode):
"""Preprocess a quadmask video for VOID inpainting. """Preprocess a quadmask video for VOID inpainting.
@ -222,6 +289,10 @@ class VOIDWarpedNoise(io.ComfyNode):
node_id="VOIDWarpedNoise", node_id="VOIDWarpedNoise",
category="latent/video", category="latent/video",
inputs=[ inputs=[
OpticalFlow.Input(
"optical_flow",
tooltip="Optical flow model from OpticalFlowLoader (RAFT-large).",
),
io.Image.Input("video", tooltip="Pass 1 output video frames [T, H, W, 3]"), io.Image.Input("video", tooltip="Pass 1 output video frames [T, H, W, 3]"),
io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8), io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8), io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8),
@ -236,7 +307,7 @@ class VOIDWarpedNoise(io.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, video, width, height, length, batch_size) -> io.NodeOutput: def execute(cls, optical_flow, video, width, height, length, batch_size) -> io.NodeOutput:
adjusted_length = _valid_void_length(length) adjusted_length = _valid_void_length(length)
if adjusted_length != length: if adjusted_length != length:
@ -257,6 +328,9 @@ class VOIDWarpedNoise(io.ComfyNode):
# rest of the ComfyUI pipeline. # rest of the ComfyUI pipeline.
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
comfy.model_management.load_model_gpu(optical_flow)
raft = RaftOpticalFlow(optical_flow.model, device=device)
vid = video[:length].to(device) vid = video[:length].to(device)
vid = comfy.utils.common_upscale( vid = comfy.utils.common_upscale(
vid.movedim(-1, 1), width, height, "bilinear", "center" vid.movedim(-1, 1), width, height, "bilinear", "center"
@ -269,6 +343,7 @@ class VOIDWarpedNoise(io.ComfyNode):
warped = get_noise_from_video( warped = get_noise_from_video(
vid_uint8, vid_uint8,
raft,
noise_channels=16, noise_channels=16,
resize_frames=FRAME, resize_frames=FRAME,
resize_flow=FLOW, resize_flow=FLOW,
@ -395,6 +470,7 @@ class VOIDExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [ return [
OpticalFlowLoader,
VOIDQuadmaskPreprocess, VOIDQuadmaskPreprocess,
VOIDInpaintConditioning, VOIDInpaintConditioning,
VOIDWarpedNoise, VOIDWarpedNoise,

View File

@ -9,8 +9,10 @@ Adapted from RyannDaGreat/CommonSource (MIT License, Ryan Burgert):
Only the code paths that ``comfy_extras/nodes_void.py::VOIDWarpedNoise`` actually Only the code paths that ``comfy_extras/nodes_void.py::VOIDWarpedNoise`` actually
uses (torch THWC uint8 input, no background removal, no visualization, no disk uses (torch THWC uint8 input, no background removal, no visualization, no disk
I/O, default warp/noise params) have been inlined. External ``rp`` utilities I/O, default warp/noise params) have been inlined. External ``rp`` utilities
have been replaced with equivalents from torch.nn.functional / einops / have been replaced with equivalents from torch.nn.functional / einops. The
torchvision. RAFT optical-flow model itself is loaded offline via ``OpticalFlowLoader`` in
``nodes_void.py`` and passed into ``get_noise_from_video`` by the caller; this
module never downloads weights at runtime.
""" """
import logging import logging
@ -19,7 +21,6 @@ from typing import Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from torchvision.models.optical_flow import raft_large
import comfy.model_management import comfy.model_management
@ -345,14 +346,20 @@ class NoiseWarper:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class RaftOpticalFlow: class RaftOpticalFlow:
"""Torchvision RAFT-large wrapper. ``__call__`` returns a (2, H, W) flow.""" """RAFT-large wrapper around a pre-loaded torchvision model.
def __init__(self, device=None): ``model`` must be the ``torchvision.models.optical_flow.raft_large`` module
with its weights already populated; this class is load-agnostic so the
caller owns downloading/offload concerns (see ``OpticalFlowLoader`` in
``nodes_void.py``). ``__call__`` returns a ``(2, H, W)`` flow.
"""
def __init__(self, model, device=None):
if device is None: if device is None:
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
device = torch.device(device) if not isinstance(device, torch.device) else device device = torch.device(device) if not isinstance(device, torch.device) else device
model = raft_large(weights="DEFAULT", progress=False).to(device) model = model.to(device)
model.eval() model.eval()
self.device = device self.device = device
self.model = model self.model = model
@ -384,22 +391,13 @@ class RaftOpticalFlow:
return flow return flow
_raft_cache: dict = {}
def _get_raft_model(device):
key = str(device)
if key not in _raft_cache:
_raft_cache[key] = RaftOpticalFlow(device=device)
return _raft_cache[key]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Narrow entry point used by VOIDWarpedNoise # Narrow entry point used by VOIDWarpedNoise
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def get_noise_from_video( def get_noise_from_video(
video_frames: torch.Tensor, video_frames: torch.Tensor,
raft: RaftOpticalFlow,
*, *,
noise_channels: int = 16, noise_channels: int = 16,
resize_frames: float = 0.5, resize_frames: float = 0.5,
@ -411,6 +409,7 @@ def get_noise_from_video(
Args: Args:
video_frames: ``(T, H, W, 3)`` uint8 torch tensor. video_frames: ``(T, H, W, 3)`` uint8 torch tensor.
raft: Pre-loaded RAFT optical-flow wrapper (see ``RaftOpticalFlow``).
noise_channels: Channels in the output noise. noise_channels: Channels in the output noise.
resize_frames: Pre-RAFT frame scale factor. resize_frames: Pre-RAFT frame scale factor.
resize_flow: Post-flow up-scale factor applied to the optical flow; resize_flow: Post-flow up-scale factor applied to the optical flow;
@ -465,8 +464,6 @@ def get_noise_from_video(
internal_h, internal_w, downscale_factor, internal_h, internal_w, downscale_factor,
) )
raft = _get_raft_model(device)
with torch.no_grad(): with torch.no_grad():
warper = NoiseWarper( warper = NoiseWarper(
c=noise_channels, h=internal_h, w=internal_w, device=device, c=noise_channels, h=internal_h, w=internal_w, device=device,

View File

@ -54,6 +54,8 @@ folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_enc
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions) folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
output_directory = os.path.join(base_path, "output") output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp") temp_directory = os.path.join(base_path, "temp")
input_directory = os.path.join(base_path, "input") input_directory = os.path.join(base_path, "input")