mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-04 22:32:32 +08:00
Add Optical Flow Loader.
Some checks failed
Build package / Build Test (3.12) (push) Has been cancelled
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Some checks failed
Build package / Build Test (3.12) (push) Has been cancelled
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
This commit is contained in:
parent
752355991a
commit
d56887ac52
@ -4,15 +4,21 @@ import torch
|
|||||||
|
|
||||||
import comfy
|
import comfy
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
import comfy.model_patcher
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import folder_paths
|
||||||
import node_helpers
|
import node_helpers
|
||||||
import nodes
|
import nodes
|
||||||
from comfy.utils import model_trange as trange
|
from comfy.utils import model_trange as trange
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from torchvision.models.optical_flow import raft_large
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from comfy_extras.void_noise_warp import get_noise_from_video
|
|
||||||
|
from comfy_extras.void_noise_warp import RaftOpticalFlow, get_noise_from_video
|
||||||
|
|
||||||
|
OpticalFlow = io.Custom("OPTICAL_FLOW")
|
||||||
|
|
||||||
TEMPORAL_COMPRESSION = 4
|
TEMPORAL_COMPRESSION = 4
|
||||||
PATCH_SIZE_T = 2
|
PATCH_SIZE_T = 2
|
||||||
@ -38,6 +44,67 @@ def _valid_void_length(length: int) -> int:
|
|||||||
return (target_latent_t - 1) * TEMPORAL_COMPRESSION + 1
|
return (target_latent_t - 1) * TEMPORAL_COMPRESSION + 1
|
||||||
|
|
||||||
|
|
||||||
|
class OpticalFlowLoader(io.ComfyNode):
    """Load an optical flow model from ``models/optical_flow/``.

    Only torchvision's RAFT-large format is recognized today (the model used
    by VOIDWarpedNoise). The checkpoint must be placed under
    ``models/optical_flow/`` — ComfyUI never downloads optical-flow weights
    at runtime.

    The node outputs a ``comfy.model_patcher.ModelPatcher`` wrapping the
    eval-mode, float32 RAFT module.
    """

    @classmethod
    def define_schema(cls):
        """Describe the node: one checkpoint combo input, one OPTICAL_FLOW output."""
        return io.Schema(
            node_id="OpticalFlowLoader",
            display_name="Load Optical Flow Model",
            category="loaders",
            inputs=[
                io.Combo.Input(
                    "model_name",
                    # Choices are whatever files are registered under the
                    # "optical_flow" folder key.
                    options=folder_paths.get_filename_list("optical_flow"),
                    tooltip=(
                        "Optical flow model to load. Files must be placed in the "
                        "'optical_flow' folder. Today only torchvision's "
                        "raft_large.pth is supported."
                    ),
                ),
            ],
            outputs=[
                OpticalFlow.Output(),
            ],
        )

    @classmethod
    def execute(cls, model_name) -> io.NodeOutput:
        """Load ``model_name`` as a RAFT-large model wrapped in a ModelPatcher.

        Args:
            model_name: File name selected from the ``optical_flow`` folder.

        Returns:
            ``io.NodeOutput`` carrying a ``ModelPatcher`` around the loaded
            RAFT-large module.

        Raises:
            ValueError: If the state dict lacks the RAFT key prefixes.
            RuntimeError: If the state dict has RAFT-style keys but cannot be
                loaded into the RAFT-large architecture.
        """
        model_path = folder_paths.get_full_path_or_raise("optical_flow", model_name)
        sd = comfy.utils.load_torch_file(model_path, safe_load=True)

        # Cheap format sniffing before building the network: every prefix
        # must appear on at least one key of the state dict.
        required_prefixes = ("feature_encoder.", "context_encoder.", "update_block.")
        has_raft_keys = all(
            any(key.startswith(prefix) for key in sd)
            for prefix in required_prefixes
        )
        if not has_raft_keys:
            raise ValueError(
                "Unrecognized optical flow model format: expected a torchvision "
                "RAFT-large state dict with 'feature_encoder.', 'context_encoder.' "
                "and 'update_block.' prefixes."
            )

        # weights=None: build the architecture only; the weights come from
        # the local checkpoint loaded above.
        model = raft_large(weights=None, progress=False)
        try:
            model.load_state_dict(sd)
        except RuntimeError as err:
            # The prefix check above does not prove the tensors fit RAFT-large
            # (a checkpoint with the same prefixes but different shapes or
            # keys passes it). Surface an actionable message and keep torch's
            # detailed mismatch report as the chained cause.
            raise RuntimeError(
                f"Optical flow checkpoint '{model_name}' has RAFT-style keys but "
                "could not be loaded into torchvision's RAFT-large architecture. "
                "Only raft_large weights are supported."
            ) from err
        model.eval().to(torch.float32)

        patcher = comfy.model_patcher.ModelPatcher(
            model,
            load_device=comfy.model_management.get_torch_device(),
            offload_device=comfy.model_management.unet_offload_device(),
        )
        return io.NodeOutput(patcher)
|
||||||
|
|
||||||
|
|
||||||
class VOIDQuadmaskPreprocess(io.ComfyNode):
|
class VOIDQuadmaskPreprocess(io.ComfyNode):
|
||||||
"""Preprocess a quadmask video for VOID inpainting.
|
"""Preprocess a quadmask video for VOID inpainting.
|
||||||
|
|
||||||
@ -222,6 +289,10 @@ class VOIDWarpedNoise(io.ComfyNode):
|
|||||||
node_id="VOIDWarpedNoise",
|
node_id="VOIDWarpedNoise",
|
||||||
category="latent/video",
|
category="latent/video",
|
||||||
inputs=[
|
inputs=[
|
||||||
|
OpticalFlow.Input(
|
||||||
|
"optical_flow",
|
||||||
|
tooltip="Optical flow model from OpticalFlowLoader (RAFT-large).",
|
||||||
|
),
|
||||||
io.Image.Input("video", tooltip="Pass 1 output video frames [T, H, W, 3]"),
|
io.Image.Input("video", tooltip="Pass 1 output video frames [T, H, W, 3]"),
|
||||||
io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8),
|
io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8),
|
||||||
io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8),
|
io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8),
|
||||||
@ -236,7 +307,7 @@ class VOIDWarpedNoise(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, video, width, height, length, batch_size) -> io.NodeOutput:
|
def execute(cls, optical_flow, video, width, height, length, batch_size) -> io.NodeOutput:
|
||||||
|
|
||||||
adjusted_length = _valid_void_length(length)
|
adjusted_length = _valid_void_length(length)
|
||||||
if adjusted_length != length:
|
if adjusted_length != length:
|
||||||
@ -257,6 +328,9 @@ class VOIDWarpedNoise(io.ComfyNode):
|
|||||||
# rest of the ComfyUI pipeline.
|
# rest of the ComfyUI pipeline.
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
|
|
||||||
|
comfy.model_management.load_model_gpu(optical_flow)
|
||||||
|
raft = RaftOpticalFlow(optical_flow.model, device=device)
|
||||||
|
|
||||||
vid = video[:length].to(device)
|
vid = video[:length].to(device)
|
||||||
vid = comfy.utils.common_upscale(
|
vid = comfy.utils.common_upscale(
|
||||||
vid.movedim(-1, 1), width, height, "bilinear", "center"
|
vid.movedim(-1, 1), width, height, "bilinear", "center"
|
||||||
@ -269,6 +343,7 @@ class VOIDWarpedNoise(io.ComfyNode):
|
|||||||
|
|
||||||
warped = get_noise_from_video(
|
warped = get_noise_from_video(
|
||||||
vid_uint8,
|
vid_uint8,
|
||||||
|
raft,
|
||||||
noise_channels=16,
|
noise_channels=16,
|
||||||
resize_frames=FRAME,
|
resize_frames=FRAME,
|
||||||
resize_flow=FLOW,
|
resize_flow=FLOW,
|
||||||
@ -395,6 +470,7 @@ class VOIDExtension(ComfyExtension):
|
|||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
|
OpticalFlowLoader,
|
||||||
VOIDQuadmaskPreprocess,
|
VOIDQuadmaskPreprocess,
|
||||||
VOIDInpaintConditioning,
|
VOIDInpaintConditioning,
|
||||||
VOIDWarpedNoise,
|
VOIDWarpedNoise,
|
||||||
|
|||||||
@ -9,8 +9,10 @@ Adapted from RyannDaGreat/CommonSource (MIT License, Ryan Burgert):
|
|||||||
Only the code paths that ``comfy_extras/nodes_void.py::VOIDWarpedNoise`` actually
|
Only the code paths that ``comfy_extras/nodes_void.py::VOIDWarpedNoise`` actually
|
||||||
uses (torch THWC uint8 input, no background removal, no visualization, no disk
|
uses (torch THWC uint8 input, no background removal, no visualization, no disk
|
||||||
I/O, default warp/noise params) have been inlined. External ``rp`` utilities
|
I/O, default warp/noise params) have been inlined. External ``rp`` utilities
|
||||||
have been replaced with equivalents from torch.nn.functional / einops /
|
have been replaced with equivalents from torch.nn.functional / einops. The
|
||||||
torchvision.
|
RAFT optical-flow model itself is loaded offline via ``OpticalFlowLoader`` in
|
||||||
|
``nodes_void.py`` and passed into ``get_noise_from_video`` by the caller; this
|
||||||
|
module never downloads weights at runtime.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@ -19,7 +21,6 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torchvision.models.optical_flow import raft_large
|
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
|
||||||
@ -345,14 +346,20 @@ class NoiseWarper:
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
class RaftOpticalFlow:
|
class RaftOpticalFlow:
|
||||||
"""Torchvision RAFT-large wrapper. ``__call__`` returns a (2, H, W) flow."""
|
"""RAFT-large wrapper around a pre-loaded torchvision model.
|
||||||
|
|
||||||
def __init__(self, device=None):
|
``model`` must be the ``torchvision.models.optical_flow.raft_large`` module
|
||||||
|
with its weights already populated; this class is load-agnostic so the
|
||||||
|
caller owns downloading/offload concerns (see ``OpticalFlowLoader`` in
|
||||||
|
``nodes_void.py``). ``__call__`` returns a ``(2, H, W)`` flow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model, device=None):
|
||||||
if device is None:
|
if device is None:
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
device = torch.device(device) if not isinstance(device, torch.device) else device
|
device = torch.device(device) if not isinstance(device, torch.device) else device
|
||||||
|
|
||||||
model = raft_large(weights="DEFAULT", progress=False).to(device)
|
model = model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
self.device = device
|
self.device = device
|
||||||
self.model = model
|
self.model = model
|
||||||
@ -384,22 +391,13 @@ class RaftOpticalFlow:
|
|||||||
return flow
|
return flow
|
||||||
|
|
||||||
|
|
||||||
_raft_cache: dict = {}
|
|
||||||
|
|
||||||
|
|
||||||
def _get_raft_model(device):
|
|
||||||
key = str(device)
|
|
||||||
if key not in _raft_cache:
|
|
||||||
_raft_cache[key] = RaftOpticalFlow(device=device)
|
|
||||||
return _raft_cache[key]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Narrow entry point used by VOIDWarpedNoise
|
# Narrow entry point used by VOIDWarpedNoise
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def get_noise_from_video(
|
def get_noise_from_video(
|
||||||
video_frames: torch.Tensor,
|
video_frames: torch.Tensor,
|
||||||
|
raft: RaftOpticalFlow,
|
||||||
*,
|
*,
|
||||||
noise_channels: int = 16,
|
noise_channels: int = 16,
|
||||||
resize_frames: float = 0.5,
|
resize_frames: float = 0.5,
|
||||||
@ -411,6 +409,7 @@ def get_noise_from_video(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
video_frames: ``(T, H, W, 3)`` uint8 torch tensor.
|
video_frames: ``(T, H, W, 3)`` uint8 torch tensor.
|
||||||
|
raft: Pre-loaded RAFT optical-flow wrapper (see ``RaftOpticalFlow``).
|
||||||
noise_channels: Channels in the output noise.
|
noise_channels: Channels in the output noise.
|
||||||
resize_frames: Pre-RAFT frame scale factor.
|
resize_frames: Pre-RAFT frame scale factor.
|
||||||
resize_flow: Post-flow up-scale factor applied to the optical flow;
|
resize_flow: Post-flow up-scale factor applied to the optical flow;
|
||||||
@ -465,8 +464,6 @@ def get_noise_from_video(
|
|||||||
internal_h, internal_w, downscale_factor,
|
internal_h, internal_w, downscale_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
raft = _get_raft_model(device)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
warper = NoiseWarper(
|
warper = NoiseWarper(
|
||||||
c=noise_channels, h=internal_h, w=internal_w, device=device,
|
c=noise_channels, h=internal_h, w=internal_w, device=device,
|
||||||
|
|||||||
@ -54,6 +54,8 @@ folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_enc
|
|||||||
|
|
||||||
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
|
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
|
||||||
|
|
||||||
|
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
|
||||||
|
|
||||||
output_directory = os.path.join(base_path, "output")
|
output_directory = os.path.join(base_path, "output")
|
||||||
temp_directory = os.path.join(base_path, "temp")
|
temp_directory = os.path.join(base_path, "temp")
|
||||||
input_directory = os.path.join(base_path, "input")
|
input_directory = os.path.join(base_path, "input")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user