Add native RaftOpticalFlow code.

This commit is contained in:
Talmaj Marinc 2026-04-27 11:13:33 +02:00
parent 12809a0c8d
commit 60eed34bca
2 changed files with 464 additions and 29 deletions

View File

@ -12,6 +12,8 @@ from comfy.utils import model_trange as trange
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override
from comfy_extras.void_noise_warp import get_noise_from_video
TEMPORAL_COMPRESSION = 4
PATCH_SIZE_T = 2
@ -212,8 +214,6 @@ class VOIDWarpedNoise(io.ComfyNode):
Takes the Pass 1 output video and produces temporally-correlated noise
by warping Gaussian noise along optical flow vectors. This noise is used
as the initial latent for Pass 2, resulting in better temporal consistency.
Requires: pip install rp (auto-installs Go-with-the-Flow dependencies)
"""
@classmethod
@ -237,15 +237,6 @@ class VOIDWarpedNoise(io.ComfyNode):
@classmethod
def execute(cls, video, width, height, length, batch_size) -> io.NodeOutput:
try:
import rp
rp.r._pip_import_autoyes = True
rp.git_import('CommonSource')
import rp.git.CommonSource.noise_warp as nw
except ImportError:
raise RuntimeError(
"VOIDWarpedNoise requires the 'rp' package. Install with: pip install rp"
)
adjusted_length = _valid_void_length(length)
if adjusted_length != length:
@ -260,36 +251,31 @@ class VOIDWarpedNoise(io.ComfyNode):
latent_h = height // 8
latent_w = width // 8
# rp.get_noise_from_video expects uint8 numpy frames; everything
# downstream of the warp stays on torch.
vid_uint8 = (video[:length].clamp(0, 1) * 255).to(torch.uint8).cpu().numpy()
# RAFT + noise warp is real compute, not an "intermediate" buffer, so
# we want the actual torch device (CUDA/MPS). The final latent is
# moved back to intermediate_device() before returning to match the
# rest of the ComfyUI pipeline.
device = comfy.model_management.get_torch_device()
frames = [vid_uint8[i] for i in range(vid_uint8.shape[0])]
frames = rp.resize_images_to_hold(frames, height=height, width=width)
frames = rp.crop_images(frames, height=height, width=width, origin='center')
frames = rp.as_numpy_array(frames)
vid = video[:length].to(device)
vid = comfy.utils.common_upscale(
vid.movedim(-1, 1), width, height, "bilinear", "center"
).movedim(1, -1)
vid_uint8 = (vid.clamp(0, 1) * 255).to(torch.uint8)
FRAME = 2**-1
FLOW = 2**3
LATENT_SCALE = 8
warp_output = nw.get_noise_from_video(
frames,
remove_background=False,
visualize=False,
save_files=False,
warped = get_noise_from_video(
vid_uint8,
noise_channels=16,
output_folder=None,
resize_frames=FRAME,
resize_flow=FLOW,
downscale_factor=round(FRAME * FLOW) * LATENT_SCALE,
device=device,
)
# (T, H, W, C) → torch on intermediate device for torchified resize.
warped = torch.from_numpy(warp_output.numpy_noises).float()
device = comfy.model_management.intermediate_device()
warped = warped.to(device)
if warped.shape[0] != latent_t:
indices = torch.linspace(0, warped.shape[0] - 1, latent_t,
device=device).long()
@ -309,6 +295,7 @@ class VOIDWarpedNoise(io.ComfyNode):
if batch_size > 1:
warped_tensor = warped_tensor.repeat(batch_size, 1, 1, 1, 1)
warped_tensor = warped_tensor.to(comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": warped_tensor})

View File

@ -0,0 +1,448 @@
"""
Optical-flow-warped noise for VOID Pass 2 refinement.
Adapted from RyannDaGreat/CommonSource (MIT License, Ryan Burgert):
https://github.com/RyannDaGreat/CommonSource
- noise_warp.py (NoiseWarper / warp_xyωc / regaussianize / get_noise_from_video)
- raft.py (RaftOpticalFlow)
Only the code paths that ``comfy_extras/nodes_void.py::VOIDWarpedNoise`` actually
uses (torch THWC uint8 input, no background removal, no visualization, no disk
I/O, default warp/noise params) have been inlined. External ``rp`` utilities
have been replaced with equivalents from torch.nn.functional / einops /
torchvision.
"""
import logging
from typing import Optional
import torch
import torch.nn.functional as F
from einops import rearrange
import comfy.model_management
# ---------------------------------------------------------------------------
# Low-level torch image helpers (drop-in replacements for rp.torch_* primitives)
# ---------------------------------------------------------------------------
def _torch_resize_chw(image, size, interp, copy=True):
"""Resize a CHW tensor.
``size`` is either a scalar factor or a (h, w) tuple. ``interp`` is one
of ``"bilinear"``, ``"nearest"``, ``"area"``. When ``copy`` is False and
the requested size matches the input, returns the input tensor as is
(faster but callers must not mutate the result).
"""
assert image.ndim == 3, image.shape
_, in_h, in_w = image.shape
if isinstance(size, (int, float)) and not isinstance(size, bool):
new_h = max(1, int(in_h * size))
new_w = max(1, int(in_w * size))
else:
new_h, new_w = size
if (new_h, new_w) == (in_h, in_w):
return image.clone() if copy else image
kwargs = {}
if interp in ("bilinear", "bicubic"):
kwargs["align_corners"] = False
out = F.interpolate(image[None], size=(new_h, new_w), mode=interp, **kwargs)[0]
return out
def _torch_remap_relative(image, dx, dy, interp="bilinear"):
"""Relative remap of a CHW image via ``F.grid_sample``.
Equivalent to ``rp.torch_remap_image(image, dx, dy, relative=True, interp=interp)``
for ``interp`` in {"bilinear", "nearest"}. Out-of-bounds samples are 0.
"""
assert image.ndim == 3
assert dx.shape == dy.shape
_, h, w = image.shape
x_abs = dx + torch.arange(w, device=dx.device, dtype=dx.dtype)
y_abs = dy + torch.arange(h, device=dy.device, dtype=dy.dtype)[:, None]
x_norm = (x_abs / (w - 1)) * 2 - 1
y_norm = (y_abs / (h - 1)) * 2 - 1
grid = torch.stack([x_norm, y_norm], dim=-1)[None].to(image.dtype)
out = F.grid_sample(
image[None], grid, mode=interp, align_corners=True, padding_mode="zeros"
)[0]
return out
def _torch_scatter_add_relative(image, dx, dy):
"""Scatter-add a CHW image using relative floor-rounded (dx, dy) offsets.
Equivalent to ``rp.torch_scatter_add_image(image, dx, dy, relative=True,
interp='floor')``. Out-of-bounds targets are dropped.
"""
assert image.ndim == 3
in_c, in_h, in_w = image.shape
assert dx.shape == dy.shape == (in_h, in_w)
x = dx.long() + torch.arange(in_w, device=dx.device, dtype=torch.long)
y = dy.long() + torch.arange(in_h, device=dy.device, dtype=torch.long)[:, None]
valid = ((y >= 0) & (y < in_h) & (x >= 0) & (x < in_w)).reshape(-1)
indices = (y * in_w + x).reshape(-1)[valid]
flat_image = rearrange(image, "c h w -> (h w) c")[valid]
out = torch.zeros((in_h * in_w, in_c), dtype=image.dtype, device=image.device)
out.index_add_(0, indices, flat_image)
return rearrange(out, "(h w) c -> c h w", h=in_h, w=in_w)
# ---------------------------------------------------------------------------
# Noise warping primitives (ported from noise_warp.py)
# ---------------------------------------------------------------------------
def unique_pixels(image):
"""Find unique pixel values in a CHW tensor.
Returns ``(unique_colors [U, C], counts [U], index_matrix [H, W])`` where
``index_matrix[i, j]`` is the index of the unique color at that pixel.
"""
_, h, w = image.shape
flat = rearrange(image, "c h w -> (h w) c")
unique_colors, inverse_indices, counts = torch.unique(
flat, dim=0, return_inverse=True, return_counts=True, sorted=False,
)
index_matrix = rearrange(inverse_indices, "(h w) -> h w", h=h, w=w)
return unique_colors, counts, index_matrix
def sum_indexed_values(image, index_matrix):
"""For each unique index, sum the CHW image values at its pixels."""
_, h, w = image.shape
u = int(index_matrix.max().item()) + 1
flat = rearrange(image, "c h w -> (h w) c")
out = torch.zeros((u, flat.shape[1]), dtype=flat.dtype, device=flat.device)
out.index_add_(0, index_matrix.view(-1), flat)
return out
def indexed_to_image(index_matrix, unique_colors):
"""Build a CHW image from an index matrix and a (U, C) color table."""
h, w = index_matrix.shape
flat = unique_colors[index_matrix.view(-1)]
return rearrange(flat, "(h w) c -> c h w", h=h, w=w)
def regaussianize(noise):
"""Variance-preserving re-sampling of a CHW noise tensor.
Wherever the noise contains groups of identical pixel values (e.g. after
a nearest-neighbor warp that duplicated source pixels), adds zero-mean
foreign noise within each group and scales by ``1/sqrt(count)`` so the
output is unit-variance gaussian again.
"""
_, hs, ws = noise.shape
_, counts, index_matrix = unique_pixels(noise[:1])
foreign_noise = torch.randn_like(noise)
summed = sum_indexed_values(foreign_noise, index_matrix)
meaned = indexed_to_image(index_matrix, summed / rearrange(counts, "u -> u 1"))
zeroed_foreign = foreign_noise - meaned
counts_image = indexed_to_image(index_matrix, rearrange(counts, "u -> u 1"))
output = noise / counts_image ** 0.5 + zeroed_foreign
return output, counts_image
def xy_meshgrid_like_image(image):
"""Return a (2, H, W) tensor of (x, y) pixel coordinates matching ``image``."""
_, h, w = image.shape
y, x = torch.meshgrid(
torch.arange(h, device=image.device, dtype=image.dtype),
torch.arange(w, device=image.device, dtype=image.dtype),
indexing="ij",
)
return torch.stack([x, y])
def noise_to_state(noise):
"""Pack a (C, H, W) noise tensor into a state tensor (3+C, H, W) = [dx, dy, ω, noise]."""
zeros = torch.zeros_like(noise[:1])
ones = torch.ones_like(noise[:1])
return torch.cat([zeros, zeros, ones, noise])
def state_to_noise(state):
"""Unpack the noise channels from a state tensor."""
return state[3:]
def warp_state(state, flow):
"""Warp a noise-warper state tensor along the given optical flow.
``state`` has shape ``(3+c, h, w)`` (= dx, dy, ω, c noise channels).
``flow`` has shape ``(2, h, w)`` (= dx, dy).
"""
assert flow.device == state.device
assert flow.ndim == 3 and flow.shape[0] == 2
assert state.ndim == 3
xyoc, h, w = state.shape
assert flow.shape == (2, h, w)
device = state.device
x_ch, y_ch = 0, 1
xy = 2 # state[:xy] = [dx, dy]
xyw = 3 # state[:xyw] = [dx, dy, ω]
w_ch = 2 # state[w_ch] = ω
c = xyoc - xyw
oc = xyoc - xy
assert c > 0, "state has no noise channels"
assert (state[w_ch] > 0).all(), "all weights must be > 0"
grid = xy_meshgrid_like_image(state)
init = torch.empty_like(state)
init[:xy] = 0
init[w_ch] = 1
init[-c:] = 0
# --- Expansion branch: nearest-neighbor remap with negated flow ---
pre_expand = torch.empty_like(state)
pre_expand[:xy] = _torch_remap_relative(state[:xy], -flow[0], -flow[1], "nearest")
pre_expand[-oc:] = _torch_remap_relative(state[-oc:], -flow[0], -flow[1], "nearest")
pre_expand[w_ch][pre_expand[w_ch] == 0] = 1
# --- Shrink branch: scatter-add state into new positions ---
pre_shrink = state.clone()
pre_shrink[:xy] += flow
pos = (grid + pre_shrink[:xy]).round()
in_bounds = (pos[x_ch] >= 0) & (pos[x_ch] < w) & (pos[y_ch] >= 0) & (pos[y_ch] < h)
pre_shrink = torch.where(~in_bounds[None], init, pre_shrink)
scat_xy = pre_shrink[:xy].round()
pre_shrink[:xy] -= scat_xy
pre_shrink[:xy] = 0 # xy_mode='none' in upstream
def scat(tensor):
return _torch_scatter_add_relative(tensor, scat_xy[0], scat_xy[1])
# rp.torch_scatter_add_image on a bool tensor errors on modern torch;
# scatter-sum a float ones tensor and threshold to get the mask instead.
shrink_mask = scat(torch.ones(1, h, w, dtype=state.dtype, device=device)) > 0
# Drop expansion samples at positions that will be filled by shrink.
pre_expand = torch.where(shrink_mask, init, pre_expand)
# Regaussianize both branches together so duplicated-source groups are
# counted globally, then split back apart.
concat = torch.cat([pre_shrink, pre_expand], dim=2) # along width
concat[-c:], counts_image = regaussianize(concat[-c:])
concat[w_ch] = concat[w_ch] / counts_image[0]
concat[w_ch] = concat[w_ch].nan_to_num()
pre_shrink, expand = torch.chunk(concat, chunks=2, dim=2)
shrink = torch.empty_like(pre_shrink)
shrink[w_ch] = scat(pre_shrink[w_ch][None])[0]
shrink[:xy] = scat(pre_shrink[:xy] * pre_shrink[w_ch][None]) / shrink[w_ch][None]
shrink[-c:] = scat(pre_shrink[-c:] * pre_shrink[w_ch][None]) / scat(
pre_shrink[w_ch][None] ** 2
).sqrt()
output = torch.where(shrink_mask, shrink, expand)
output[w_ch] = output[w_ch] / output[w_ch].mean()
output[w_ch] += 1e-5
output[w_ch] **= 0.9999
return output
class NoiseWarper:
"""Maintain a warpable noise state and emit gaussian noise per frame.
Simplified from RyannDaGreat/CommonSource/noise_warp.py::NoiseWarper:
``scale_factor``, ``post_noise_alpha``, ``progressive_noise_alpha``, and
``warp_kwargs`` are all dropped since VOIDWarpedNoise always uses defaults.
"""
def __init__(self, c, h, w, device, dtype=torch.float32):
assert c > 0 and h > 0 and w > 0
self.c = c
self.h = h
self.w = w
self.device = device
self.dtype = dtype
noise = torch.randn(c, h, w, dtype=dtype, device=device)
self._state = noise_to_state(noise)
@property
def noise(self):
# With scale_factor=1 the "downsample to respect weights" step is a
# size-preserving no-op; the weight-variance correction math still
# runs to stay faithful to upstream.
n = state_to_noise(self._state)
weights = self._state[2:3]
return n * weights / (weights ** 2).sqrt()
def __call__(self, dx, dy):
assert dx.shape == dy.shape
flow = torch.stack([dx, dy]).to(self.device, self.dtype)
_, oflowh, ofloww = flow.shape
flow = _torch_resize_chw(flow, (self.h, self.w), "bilinear", copy=True)
flowh, floww = flow.shape[-2:]
# Upstream scales flow[0] by flowh/oflowh and flow[1] by floww/ofloww
# (channel-order appears swapped but harmless when H and W are scaled
# by the same factor, which is always the case for our callers).
flow[0] *= flowh / oflowh
flow[1] *= floww / ofloww
self._state = warp_state(self._state, flow)
return self
# ---------------------------------------------------------------------------
# RAFT optical flow wrapper (ported from raft.py)
# ---------------------------------------------------------------------------
class RaftOpticalFlow:
"""Torchvision RAFT-large wrapper. ``__call__`` returns a (2, H, W) flow."""
def __init__(self, device=None):
from torchvision.models.optical_flow import raft_large
if device is None:
device = comfy.model_management.get_torch_device()
device = torch.device(device) if not isinstance(device, torch.device) else device
model = raft_large(weights="DEFAULT", progress=False).to(device)
model.eval()
self.device = device
self.model = model
def _preprocess(self, image_chw):
image = image_chw.to(self.device, torch.float32)
_, h, w = image.shape
new_h = (h // 8) * 8
new_w = (w // 8) * 8
image = _torch_resize_chw(image, (new_h, new_w), "bilinear", copy=False)
image = image * 2 - 1
return image[None]
def __call__(self, from_image, to_image):
"""``from_image``, ``to_image``: CHW float tensors in [0, 1]."""
assert from_image.shape == to_image.shape
_, h, w = from_image.shape
with torch.no_grad():
img1 = self._preprocess(from_image)
img2 = self._preprocess(to_image)
list_of_flows = self.model(img1, img2)
flow = list_of_flows[-1][0] # (2, new_h, new_w)
if flow.shape[-2:] != (h, w):
flow = _torch_resize_chw(flow, (h, w), "bilinear", copy=False)
return flow
_raft_cache: dict = {}
def _get_raft_model(device):
key = str(device)
if key not in _raft_cache:
_raft_cache[key] = RaftOpticalFlow(device=device)
return _raft_cache[key]
# ---------------------------------------------------------------------------
# Narrow entry point used by VOIDWarpedNoise
# ---------------------------------------------------------------------------
def get_noise_from_video(
video_frames: torch.Tensor,
*,
noise_channels: int = 16,
resize_frames: float = 0.5,
resize_flow: int = 8,
downscale_factor: int = 32,
device: Optional[torch.device] = None,
) -> torch.Tensor:
"""Produce optical-flow-warped gaussian noise from a video.
Args:
video_frames: ``(T, H, W, 3)`` uint8 torch tensor.
noise_channels: Channels in the output noise.
resize_frames: Pre-RAFT frame scale factor.
resize_flow: Post-flow up-scale factor applied to the optical flow;
the internal noise state is allocated at
``(resize_flow * resize_frames * H, resize_flow * resize_frames * W)``.
downscale_factor: Area-pool factor applied to the noise before return;
should evenly divide the internal noise resolution.
device: Target device. Defaults to ``comfy.model_management.get_torch_device()``.
Returns:
``(T, H', W', noise_channels)`` float32 noise tensor on ``device``.
"""
assert isinstance(resize_flow, int) and resize_flow >= 1, resize_flow
assert video_frames.ndim == 4 and video_frames.shape[-1] == 3, video_frames.shape
assert video_frames.dtype == torch.uint8, video_frames.dtype
if device is None:
device = comfy.model_management.get_torch_device()
device = torch.device(device) if not isinstance(device, torch.device) else device
if device.type == "cpu":
logging.warning(
"VOIDWarpedNoise: running get_noise_from_video on CPU; this will be "
"slow (minutes for ~45 frames). Use CUDA for interactive use."
)
T = video_frames.shape[0]
frames = video_frames.to(device).permute(0, 3, 1, 2).to(torch.float32) / 255.0
if resize_frames != 1.0:
new_h = max(1, int(frames.shape[2] * resize_frames))
new_w = max(1, int(frames.shape[3] * resize_frames))
frames = F.interpolate(frames, size=(new_h, new_w), mode="area")
_, _, H, W = frames.shape
internal_h = resize_flow * H
internal_w = resize_flow * W
if internal_h % downscale_factor or internal_w % downscale_factor:
logging.warning(
"VOIDWarpedNoise: internal noise size %dx%d is not divisible by "
"downscale_factor %d; output noise may have artifacts.",
internal_h, internal_w, downscale_factor,
)
raft = _get_raft_model(device)
with torch.no_grad():
warper = NoiseWarper(
c=noise_channels, h=internal_h, w=internal_w, device=device,
)
down_h = warper.h // downscale_factor
down_w = warper.w // downscale_factor
output = torch.empty(
(T, down_h, down_w, noise_channels), dtype=torch.float32, device=device,
)
def downscale(noise_chw):
# Area-pool to 1/downscale_factor then multiply by downscale_factor
# to adjust std (sqrt of pool area == downscale_factor for a
# square pool).
down = _torch_resize_chw(noise_chw, 1.0 / downscale_factor, "area", copy=False)
return down * downscale_factor
output[0] = downscale(warper.noise).permute(1, 2, 0)
prev = frames[0]
for i in range(1, T):
curr = frames[i]
flow = raft(prev, curr).to(device)
warper(flow[0], flow[1])
output[i] = downscale(warper.noise).permute(1, 2, 0)
prev = curr
return output