Add native RaftOpticalFlow code.

This commit is contained in:
Talmaj Marinc 2026-04-27 11:13:33 +02:00
parent 12809a0c8d
commit 60eed34bca
2 changed files with 464 additions and 29 deletions

View File

@ -12,6 +12,8 @@ from comfy.utils import model_trange as trange
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
from typing_extensions import override from typing_extensions import override
from comfy_extras.void_noise_warp import get_noise_from_video
TEMPORAL_COMPRESSION = 4 TEMPORAL_COMPRESSION = 4
PATCH_SIZE_T = 2 PATCH_SIZE_T = 2
@ -212,8 +214,6 @@ class VOIDWarpedNoise(io.ComfyNode):
Takes the Pass 1 output video and produces temporally-correlated noise Takes the Pass 1 output video and produces temporally-correlated noise
by warping Gaussian noise along optical flow vectors. This noise is used by warping Gaussian noise along optical flow vectors. This noise is used
as the initial latent for Pass 2, resulting in better temporal consistency. as the initial latent for Pass 2, resulting in better temporal consistency.
Requires: pip install rp (auto-installs Go-with-the-Flow dependencies)
""" """
@classmethod @classmethod
@ -237,15 +237,6 @@ class VOIDWarpedNoise(io.ComfyNode):
@classmethod @classmethod
def execute(cls, video, width, height, length, batch_size) -> io.NodeOutput: def execute(cls, video, width, height, length, batch_size) -> io.NodeOutput:
try:
import rp
rp.r._pip_import_autoyes = True
rp.git_import('CommonSource')
import rp.git.CommonSource.noise_warp as nw
except ImportError:
raise RuntimeError(
"VOIDWarpedNoise requires the 'rp' package. Install with: pip install rp"
)
adjusted_length = _valid_void_length(length) adjusted_length = _valid_void_length(length)
if adjusted_length != length: if adjusted_length != length:
@ -260,36 +251,31 @@ class VOIDWarpedNoise(io.ComfyNode):
latent_h = height // 8 latent_h = height // 8
latent_w = width // 8 latent_w = width // 8
# rp.get_noise_from_video expects uint8 numpy frames; everything # RAFT + noise warp is real compute, not an "intermediate" buffer, so
# downstream of the warp stays on torch. # we want the actual torch device (CUDA/MPS). The final latent is
vid_uint8 = (video[:length].clamp(0, 1) * 255).to(torch.uint8).cpu().numpy() # moved back to intermediate_device() before returning to match the
# rest of the ComfyUI pipeline.
device = comfy.model_management.get_torch_device()
frames = [vid_uint8[i] for i in range(vid_uint8.shape[0])] vid = video[:length].to(device)
frames = rp.resize_images_to_hold(frames, height=height, width=width) vid = comfy.utils.common_upscale(
frames = rp.crop_images(frames, height=height, width=width, origin='center') vid.movedim(-1, 1), width, height, "bilinear", "center"
frames = rp.as_numpy_array(frames) ).movedim(1, -1)
vid_uint8 = (vid.clamp(0, 1) * 255).to(torch.uint8)
FRAME = 2**-1 FRAME = 2**-1
FLOW = 2**3 FLOW = 2**3
LATENT_SCALE = 8 LATENT_SCALE = 8
warp_output = nw.get_noise_from_video( warped = get_noise_from_video(
frames, vid_uint8,
remove_background=False,
visualize=False,
save_files=False,
noise_channels=16, noise_channels=16,
output_folder=None,
resize_frames=FRAME, resize_frames=FRAME,
resize_flow=FLOW, resize_flow=FLOW,
downscale_factor=round(FRAME * FLOW) * LATENT_SCALE, downscale_factor=round(FRAME * FLOW) * LATENT_SCALE,
device=device,
) )
# (T, H, W, C) → torch on intermediate device for torchified resize.
warped = torch.from_numpy(warp_output.numpy_noises).float()
device = comfy.model_management.intermediate_device()
warped = warped.to(device)
if warped.shape[0] != latent_t: if warped.shape[0] != latent_t:
indices = torch.linspace(0, warped.shape[0] - 1, latent_t, indices = torch.linspace(0, warped.shape[0] - 1, latent_t,
device=device).long() device=device).long()
@ -309,6 +295,7 @@ class VOIDWarpedNoise(io.ComfyNode):
if batch_size > 1: if batch_size > 1:
warped_tensor = warped_tensor.repeat(batch_size, 1, 1, 1, 1) warped_tensor = warped_tensor.repeat(batch_size, 1, 1, 1, 1)
warped_tensor = warped_tensor.to(comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": warped_tensor}) return io.NodeOutput({"samples": warped_tensor})

View File

@ -0,0 +1,448 @@
"""
Optical-flow-warped noise for VOID Pass 2 refinement.
Adapted from RyannDaGreat/CommonSource (MIT License, Ryan Burgert):
https://github.com/RyannDaGreat/CommonSource
- noise_warp.py (NoiseWarper / warp_xyωc / regaussianize / get_noise_from_video)
- raft.py (RaftOpticalFlow)
Only the code paths that ``comfy_extras/nodes_void.py::VOIDWarpedNoise`` actually
uses (torch THWC uint8 input, no background removal, no visualization, no disk
I/O, default warp/noise params) have been inlined. External ``rp`` utilities
have been replaced with equivalents from torch.nn.functional / einops /
torchvision.
"""
import logging
from typing import Optional
import torch
import torch.nn.functional as F
from einops import rearrange
import comfy.model_management
# ---------------------------------------------------------------------------
# Low-level torch image helpers (drop-in replacements for rp.torch_* primitives)
# ---------------------------------------------------------------------------
def _torch_resize_chw(image, size, interp, copy=True):
"""Resize a CHW tensor.
``size`` is either a scalar factor or a (h, w) tuple. ``interp`` is one
of ``"bilinear"``, ``"nearest"``, ``"area"``. When ``copy`` is False and
the requested size matches the input, returns the input tensor as is
(faster but callers must not mutate the result).
"""
assert image.ndim == 3, image.shape
_, in_h, in_w = image.shape
if isinstance(size, (int, float)) and not isinstance(size, bool):
new_h = max(1, int(in_h * size))
new_w = max(1, int(in_w * size))
else:
new_h, new_w = size
if (new_h, new_w) == (in_h, in_w):
return image.clone() if copy else image
kwargs = {}
if interp in ("bilinear", "bicubic"):
kwargs["align_corners"] = False
out = F.interpolate(image[None], size=(new_h, new_w), mode=interp, **kwargs)[0]
return out
def _torch_remap_relative(image, dx, dy, interp="bilinear"):
"""Relative remap of a CHW image via ``F.grid_sample``.
Equivalent to ``rp.torch_remap_image(image, dx, dy, relative=True, interp=interp)``
for ``interp`` in {"bilinear", "nearest"}. Out-of-bounds samples are 0.
"""
assert image.ndim == 3
assert dx.shape == dy.shape
_, h, w = image.shape
x_abs = dx + torch.arange(w, device=dx.device, dtype=dx.dtype)
y_abs = dy + torch.arange(h, device=dy.device, dtype=dy.dtype)[:, None]
x_norm = (x_abs / (w - 1)) * 2 - 1
y_norm = (y_abs / (h - 1)) * 2 - 1
grid = torch.stack([x_norm, y_norm], dim=-1)[None].to(image.dtype)
out = F.grid_sample(
image[None], grid, mode=interp, align_corners=True, padding_mode="zeros"
)[0]
return out
def _torch_scatter_add_relative(image, dx, dy):
"""Scatter-add a CHW image using relative floor-rounded (dx, dy) offsets.
Equivalent to ``rp.torch_scatter_add_image(image, dx, dy, relative=True,
interp='floor')``. Out-of-bounds targets are dropped.
"""
assert image.ndim == 3
in_c, in_h, in_w = image.shape
assert dx.shape == dy.shape == (in_h, in_w)
x = dx.long() + torch.arange(in_w, device=dx.device, dtype=torch.long)
y = dy.long() + torch.arange(in_h, device=dy.device, dtype=torch.long)[:, None]
valid = ((y >= 0) & (y < in_h) & (x >= 0) & (x < in_w)).reshape(-1)
indices = (y * in_w + x).reshape(-1)[valid]
flat_image = rearrange(image, "c h w -> (h w) c")[valid]
out = torch.zeros((in_h * in_w, in_c), dtype=image.dtype, device=image.device)
out.index_add_(0, indices, flat_image)
return rearrange(out, "(h w) c -> c h w", h=in_h, w=in_w)
# ---------------------------------------------------------------------------
# Noise warping primitives (ported from noise_warp.py)
# ---------------------------------------------------------------------------
def unique_pixels(image):
    """Find unique pixel values in a CHW tensor.

    Returns ``(unique_colors [U, C], counts [U], index_matrix [H, W])`` where
    ``index_matrix[i, j]`` indexes the unique color appearing at pixel (i, j).
    """
    channels, h, w = image.shape
    # Flatten to one row per pixel so torch.unique can dedupe whole pixels.
    pixels = image.permute(1, 2, 0).reshape(h * w, channels)
    unique_colors, inverse, counts = torch.unique(
        pixels, dim=0, return_inverse=True, return_counts=True, sorted=False,
    )
    return unique_colors, counts, inverse.reshape(h, w)
def sum_indexed_values(image, index_matrix):
    """For each unique index, sum the CHW image values at its pixels."""
    channels = image.shape[0]
    num_groups = int(index_matrix.max().item()) + 1
    # One row per pixel, summed into its group's row via index_add_.
    pixels = image.reshape(channels, -1).t()
    totals = torch.zeros(
        (num_groups, channels), dtype=pixels.dtype, device=pixels.device
    )
    totals.index_add_(0, index_matrix.reshape(-1), pixels)
    return totals
def indexed_to_image(index_matrix, unique_colors):
    """Build a CHW image from an index matrix and a (U, C) color table."""
    h, w = index_matrix.shape
    pixels = unique_colors[index_matrix.reshape(-1)]  # (h*w, c)
    return pixels.t().reshape(pixels.shape[1], h, w)
def regaussianize(noise):
    """Variance-preserving re-sampling of a CHW noise tensor.

    Groups of identical pixel values (e.g. duplicates left behind by a
    nearest-neighbor warp) are perturbed with fresh zero-mean noise and
    rescaled by ``1/sqrt(group size)`` so every output pixel is unit-variance
    gaussian again.

    Returns ``(output, counts_image)`` where ``counts_image`` is a
    ``(1, H, W)`` map of each pixel's group size.
    """
    # Group pixels by the value of the first noise channel only.
    _, counts, index_matrix = unique_pixels(noise[:1])
    fresh = torch.randn_like(noise)
    # Subtract each group's mean so the added noise is zero-mean per group.
    group_means = indexed_to_image(
        index_matrix, sum_indexed_values(fresh, index_matrix) / counts[:, None]
    )
    centered_fresh = fresh - group_means
    counts_image = indexed_to_image(index_matrix, counts[:, None])
    output = noise / counts_image ** 0.5 + centered_fresh
    return output, counts_image
def xy_meshgrid_like_image(image):
    """Return a (2, H, W) tensor of (x, y) pixel coordinates matching ``image``."""
    _, h, w = image.shape
    rows = torch.arange(h, device=image.device, dtype=image.dtype)
    cols = torch.arange(w, device=image.device, dtype=image.dtype)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing="ij")
    # Channel 0 is x (column index), channel 1 is y (row index).
    return torch.stack((grid_x, grid_y))
def noise_to_state(noise):
    """Pack a (C, H, W) noise tensor into a (3+C, H, W) state = [dx, dy, ω, noise]."""
    dx = torch.zeros_like(noise[:1])
    dy = torch.zeros_like(noise[:1])
    weight = torch.ones_like(noise[:1])
    return torch.cat((dx, dy, weight, noise))
def state_to_noise(state):
    """Unpack the noise channels from a [dx, dy, ω, noise] state tensor."""
    meta_channels = 3  # dx, dy, ω
    return state[meta_channels:]
def warp_state(state, flow):
    """Warp a noise-warper state tensor along the given optical flow.

    ``state`` has shape ``(3+c, h, w)`` (= dx, dy, ω, c noise channels).
    ``flow`` has shape ``(2, h, w)`` (= dx, dy).

    Two candidate warps are built per pixel — a "shrink" branch that
    scatter-adds state forward along the flow, and an "expansion" branch
    that gathers state backward with the negated flow — then merged: pixels
    reached by the scatter take the shrink result, everyone else takes the
    expansion result. Calls ``regaussianize`` internally, so this function
    consumes RNG state (``torch.randn_like``).
    """
    assert flow.device == state.device
    assert flow.ndim == 3 and flow.shape[0] == 2
    assert state.ndim == 3
    xyoc, h, w = state.shape
    assert flow.shape == (2, h, w)
    device = state.device
    # Channel-layout constants for readability below.
    x_ch, y_ch = 0, 1
    xy = 2   # state[:xy] = [dx, dy]
    xyw = 3  # state[:xyw] = [dx, dy, ω]
    w_ch = 2  # state[w_ch] = ω
    c = xyoc - xyw   # number of noise channels
    oc = xyoc - xy   # ω + noise channels together
    assert c > 0, "state has no noise channels"
    assert (state[w_ch] > 0).all(), "all weights must be > 0"
    grid = xy_meshgrid_like_image(state)
    # "Blank" state used for pixels with no valid source: zero offsets,
    # unit weight, zero noise (covers all 3+c channels).
    init = torch.empty_like(state)
    init[:xy] = 0
    init[w_ch] = 1
    init[-c:] = 0
    # --- Expansion branch: nearest-neighbor remap with negated flow ---
    pre_expand = torch.empty_like(state)
    pre_expand[:xy] = _torch_remap_relative(state[:xy], -flow[0], -flow[1], "nearest")
    pre_expand[-oc:] = _torch_remap_relative(state[-oc:], -flow[0], -flow[1], "nearest")
    # Out-of-bounds remap samples come back 0; force their weight to 1 so
    # the (state[w_ch] > 0) invariant survives into later divisions.
    pre_expand[w_ch][pre_expand[w_ch] == 0] = 1
    # --- Shrink branch: scatter-add state into new positions ---
    pre_shrink = state.clone()
    pre_shrink[:xy] += flow
    pos = (grid + pre_shrink[:xy]).round()
    in_bounds = (pos[x_ch] >= 0) & (pos[x_ch] < w) & (pos[y_ch] >= 0) & (pos[y_ch] < h)
    # Pixels whose target lands off-image are reset to the blank state.
    pre_shrink = torch.where(~in_bounds[None], init, pre_shrink)
    scat_xy = pre_shrink[:xy].round()
    pre_shrink[:xy] -= scat_xy
    pre_shrink[:xy] = 0  # xy_mode='none' in upstream
    def scat(tensor):
        # Scatter `tensor` by the integer offsets chosen above.
        return _torch_scatter_add_relative(tensor, scat_xy[0], scat_xy[1])
    # rp.torch_scatter_add_image on a bool tensor errors on modern torch;
    # scatter-sum a float ones tensor and threshold to get the mask instead.
    shrink_mask = scat(torch.ones(1, h, w, dtype=state.dtype, device=device)) > 0
    # Drop expansion samples at positions that will be filled by shrink.
    pre_expand = torch.where(shrink_mask, init, pre_expand)
    # Regaussianize both branches together so duplicated-source groups are
    # counted globally, then split back apart.
    concat = torch.cat([pre_shrink, pre_expand], dim=2)  # along width
    concat[-c:], counts_image = regaussianize(concat[-c:])
    # Spread each pixel's weight across its duplicate group.
    concat[w_ch] = concat[w_ch] / counts_image[0]
    concat[w_ch] = concat[w_ch].nan_to_num()
    pre_shrink, expand = torch.chunk(concat, chunks=2, dim=2)
    # Weighted scatter: weights sum directly; offsets are weight-averaged;
    # noise is scaled by 1/sqrt(sum of squared weights) to stay unit-variance.
    shrink = torch.empty_like(pre_shrink)
    shrink[w_ch] = scat(pre_shrink[w_ch][None])[0]
    shrink[:xy] = scat(pre_shrink[:xy] * pre_shrink[w_ch][None]) / shrink[w_ch][None]
    shrink[-c:] = scat(pre_shrink[-c:] * pre_shrink[w_ch][None]) / scat(
        pre_shrink[w_ch][None] ** 2
    ).sqrt()
    # Merge: scatter-reached pixels take the shrink branch, others expand.
    output = torch.where(shrink_mask, shrink, expand)
    # Renormalize weights to mean 1, then nudge/decay them so they stay
    # strictly positive and bounded over many iterations.
    output[w_ch] = output[w_ch] / output[w_ch].mean()
    output[w_ch] += 1e-5
    output[w_ch] **= 0.9999
    return output
class NoiseWarper:
    """Holds a warpable gaussian-noise state and yields per-frame noise.

    Pared down from RyannDaGreat/CommonSource/noise_warp.py::NoiseWarper —
    ``scale_factor``, ``post_noise_alpha``, ``progressive_noise_alpha`` and
    ``warp_kwargs`` are omitted because VOIDWarpedNoise always runs with the
    upstream defaults.
    """

    def __init__(self, c, h, w, device, dtype=torch.float32):
        assert c > 0 and h > 0 and w > 0
        self.c = c
        self.h = h
        self.w = w
        self.device = device
        self.dtype = dtype
        initial = torch.randn(c, h, w, dtype=dtype, device=device)
        self._state = noise_to_state(initial)

    @property
    def noise(self):
        # With scale_factor=1 the upstream "downsample to respect weights"
        # step is size-preserving; the weight-variance correction math is
        # kept anyway so the result matches upstream exactly.
        raw = state_to_noise(self._state)
        weight = self._state[2:3]
        return raw * weight / (weight ** 2).sqrt()

    def __call__(self, dx, dy):
        assert dx.shape == dy.shape
        flow = torch.stack([dx, dy]).to(self.device, self.dtype)
        _, src_h, src_w = flow.shape
        flow = _torch_resize_chw(flow, (self.h, self.w), "bilinear", copy=True)
        dst_h, dst_w = flow.shape[-2:]
        # Upstream multiplies flow[0] by the H ratio and flow[1] by the W
        # ratio (looks channel-swapped, but harmless: H and W are always
        # scaled by the same factor for our callers).
        flow[0] *= dst_h / src_h
        flow[1] *= dst_w / src_w
        self._state = warp_state(self._state, flow)
        return self
# ---------------------------------------------------------------------------
# RAFT optical flow wrapper (ported from raft.py)
# ---------------------------------------------------------------------------
class RaftOpticalFlow:
    """Torchvision RAFT-large wrapper. ``__call__`` returns a (2, H, W) flow.

    Weights are loaded once per instance; use ``_get_raft_model`` to share
    one instance per device.
    """

    def __init__(self, device=None):
        # Imported lazily so torchvision is only required when optical flow
        # is actually computed.
        from torchvision.models.optical_flow import raft_large
        if device is None:
            device = comfy.model_management.get_torch_device()
        device = torch.device(device) if not isinstance(device, torch.device) else device
        model = raft_large(weights="DEFAULT", progress=False).to(device)
        model.eval()
        self.device = device
        self.model = model

    def _preprocess(self, image_chw):
        """Move to device, snap H/W to multiples of 8, map [0, 1] -> [-1, 1]."""
        image = image_chw.to(self.device, torch.float32)
        _, h, w = image.shape
        # RAFT requires spatial dims divisible by 8. Clamp to at least 8 so
        # tiny inputs (h or w < 8) don't collapse to a zero-sized dimension,
        # which would crash F.interpolate / the model.
        new_h = max(8, (h // 8) * 8)
        new_w = max(8, (w // 8) * 8)
        image = _torch_resize_chw(image, (new_h, new_w), "bilinear", copy=False)
        image = image * 2 - 1
        return image[None]

    def __call__(self, from_image, to_image):
        """``from_image``, ``to_image``: CHW float tensors in [0, 1]."""
        assert from_image.shape == to_image.shape
        _, h, w = from_image.shape
        with torch.no_grad():
            img1 = self._preprocess(from_image)
            img2 = self._preprocess(to_image)
            # RAFT returns an iterative list of refinements; keep the last.
            list_of_flows = self.model(img1, img2)
            flow = list_of_flows[-1][0]  # (2, new_h, new_w)
        if flow.shape[-2:] != (h, w):
            # NOTE(review): flow values are not rescaled after resizing back
            # to (h, w); magnitudes stay in the snapped-resolution pixel
            # units (off by at most a few percent) — matches upstream.
            flow = _torch_resize_chw(flow, (h, w), "bilinear", copy=False)
        return flow
_raft_cache: dict = {}


def _get_raft_model(device):
    """Return a per-device cached ``RaftOpticalFlow`` instance."""
    cache_key = str(device)
    model = _raft_cache.get(cache_key)
    if model is None:
        model = RaftOpticalFlow(device=device)
        _raft_cache[cache_key] = model
    return model
# ---------------------------------------------------------------------------
# Narrow entry point used by VOIDWarpedNoise
# ---------------------------------------------------------------------------
def get_noise_from_video(
    video_frames: torch.Tensor,
    *,
    noise_channels: int = 16,
    resize_frames: float = 0.5,
    resize_flow: int = 8,
    downscale_factor: int = 32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Produce optical-flow-warped gaussian noise from a video.

    Args:
        video_frames: ``(T, H, W, 3)`` uint8 torch tensor.
        noise_channels: Channels in the output noise.
        resize_frames: Pre-RAFT frame scale factor.
        resize_flow: Post-flow up-scale factor applied to the optical flow;
            the internal noise state is allocated at
            ``(resize_flow * resize_frames * H, resize_flow * resize_frames * W)``.
        downscale_factor: Area-pool factor applied to the noise before return;
            should evenly divide the internal noise resolution.
        device: Target device. Defaults to ``comfy.model_management.get_torch_device()``.

    Returns:
        ``(T, H', W', noise_channels)`` float32 noise tensor on ``device``,
        where ``H' = resize_flow * resize_frames * H / downscale_factor``
        (likewise for ``W'``).
    """
    assert isinstance(resize_flow, int) and resize_flow >= 1, resize_flow
    assert video_frames.ndim == 4 and video_frames.shape[-1] == 3, video_frames.shape
    assert video_frames.dtype == torch.uint8, video_frames.dtype
    if device is None:
        device = comfy.model_management.get_torch_device()
    device = torch.device(device) if not isinstance(device, torch.device) else device
    if device.type == "cpu":
        logging.warning(
            "VOIDWarpedNoise: running get_noise_from_video on CPU; this will be "
            "slow (minutes for ~45 frames). Use CUDA for interactive use."
        )
    T = video_frames.shape[0]
    # THWC uint8 -> TCHW float in [0, 1] on the target device.
    frames = video_frames.to(device).permute(0, 3, 1, 2).to(torch.float32) / 255.0
    if resize_frames != 1.0:
        # Shrink frames before RAFT; area mode averages, max(1, ...) guards
        # against degenerate zero-sized dims for extreme scale factors.
        new_h = max(1, int(frames.shape[2] * resize_frames))
        new_w = max(1, int(frames.shape[3] * resize_frames))
        frames = F.interpolate(frames, size=(new_h, new_w), mode="area")
    _, _, H, W = frames.shape
    # The noise state lives at resize_flow x the (resized) frame resolution.
    internal_h = resize_flow * H
    internal_w = resize_flow * W
    if internal_h % downscale_factor or internal_w % downscale_factor:
        logging.warning(
            "VOIDWarpedNoise: internal noise size %dx%d is not divisible by "
            "downscale_factor %d; output noise may have artifacts.",
            internal_h, internal_w, downscale_factor,
        )
    raft = _get_raft_model(device)
    with torch.no_grad():
        warper = NoiseWarper(
            c=noise_channels, h=internal_h, w=internal_w, device=device,
        )
        down_h = warper.h // downscale_factor
        down_w = warper.w // downscale_factor
        output = torch.empty(
            (T, down_h, down_w, noise_channels), dtype=torch.float32, device=device,
        )
        def downscale(noise_chw):
            # Area-pool to 1/downscale_factor then multiply by downscale_factor
            # to adjust std (sqrt of pool area == downscale_factor for a
            # square pool).
            down = _torch_resize_chw(noise_chw, 1.0 / downscale_factor, "area", copy=False)
            return down * downscale_factor
        # Frame 0 gets the unwarped initial noise; each later frame warps the
        # state along the RAFT flow from the previous frame before sampling.
        output[0] = downscale(warper.noise).permute(1, 2, 0)
        prev = frames[0]
        for i in range(1, T):
            curr = frames[i]
            flow = raft(prev, curr).to(device)
            # NoiseWarper upsamples this flow to the internal resolution and
            # scales its magnitudes accordingly (see NoiseWarper.__call__).
            warper(flow[0], flow[1])
            output[i] = downscale(warper.noise).permute(1, 2, 0)
            prev = curr
    return output