From 60eed34bca8b91282e8ce540135884f28e5633b4 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Mon, 27 Apr 2026 11:13:33 +0200 Subject: [PATCH] Add native RaftOpticalFlow code. --- comfy_extras/nodes_void.py | 45 ++-- comfy_extras/void_noise_warp.py | 448 ++++++++++++++++++++++++++++++++ 2 files changed, 464 insertions(+), 29 deletions(-) create mode 100644 comfy_extras/void_noise_warp.py diff --git a/comfy_extras/nodes_void.py b/comfy_extras/nodes_void.py index 301163269..aeffb3ee2 100644 --- a/comfy_extras/nodes_void.py +++ b/comfy_extras/nodes_void.py @@ -12,6 +12,8 @@ from comfy.utils import model_trange as trange from comfy_api.latest import ComfyExtension, io from typing_extensions import override +from comfy_extras.void_noise_warp import get_noise_from_video + TEMPORAL_COMPRESSION = 4 PATCH_SIZE_T = 2 @@ -212,8 +214,6 @@ class VOIDWarpedNoise(io.ComfyNode): Takes the Pass 1 output video and produces temporally-correlated noise by warping Gaussian noise along optical flow vectors. This noise is used as the initial latent for Pass 2, resulting in better temporal consistency. - - Requires: pip install rp (auto-installs Go-with-the-Flow dependencies) """ @classmethod @@ -237,15 +237,6 @@ class VOIDWarpedNoise(io.ComfyNode): @classmethod def execute(cls, video, width, height, length, batch_size) -> io.NodeOutput: - try: - import rp - rp.r._pip_import_autoyes = True - rp.git_import('CommonSource') - import rp.git.CommonSource.noise_warp as nw - except ImportError: - raise RuntimeError( - "VOIDWarpedNoise requires the 'rp' package. Install with: pip install rp" - ) adjusted_length = _valid_void_length(length) if adjusted_length != length: @@ -260,36 +251,31 @@ class VOIDWarpedNoise(io.ComfyNode): latent_h = height // 8 latent_w = width // 8 - # rp.get_noise_from_video expects uint8 numpy frames; everything - # downstream of the warp stays on torch. - vid_uint8 = (video[:length].clamp(0, 1) * 255).to(torch.uint8).cpu().numpy() + # RAFT + noise warp is real compute, not an "intermediate" buffer, so + # we want the actual torch device (CUDA/MPS). The final latent is + # moved back to intermediate_device() before returning to match the + # rest of the ComfyUI pipeline. + device = comfy.model_management.get_torch_device() - frames = [vid_uint8[i] for i in range(vid_uint8.shape[0])] - frames = rp.resize_images_to_hold(frames, height=height, width=width) - frames = rp.crop_images(frames, height=height, width=width, origin='center') - frames = rp.as_numpy_array(frames) + vid = video[:length].to(device) + vid = comfy.utils.common_upscale( + vid.movedim(-1, 1), width, height, "bilinear", "center" + ).movedim(1, -1) + vid_uint8 = (vid.clamp(0, 1) * 255).to(torch.uint8) FRAME = 2**-1 FLOW = 2**3 LATENT_SCALE = 8 - warp_output = nw.get_noise_from_video( - frames, - remove_background=False, - visualize=False, - save_files=False, + warped = get_noise_from_video( + vid_uint8, noise_channels=16, - output_folder=None, resize_frames=FRAME, resize_flow=FLOW, downscale_factor=round(FRAME * FLOW) * LATENT_SCALE, + device=device, ) - # (T, H, W, C) → torch on intermediate device for torchified resize. - warped = torch.from_numpy(warp_output.numpy_noises).float() - device = comfy.model_management.intermediate_device() - warped = warped.to(device) - if warped.shape[0] != latent_t: indices = torch.linspace(0, warped.shape[0] - 1, latent_t, device=device).long() @@ -309,6 +295,7 @@ class VOIDWarpedNoise(io.ComfyNode): if batch_size > 1: warped_tensor = warped_tensor.repeat(batch_size, 1, 1, 1, 1) + warped_tensor = warped_tensor.to(comfy.model_management.intermediate_device()) return io.NodeOutput({"samples": warped_tensor}) diff --git a/comfy_extras/void_noise_warp.py b/comfy_extras/void_noise_warp.py new file mode 100644 index 000000000..358ff388e --- /dev/null +++ b/comfy_extras/void_noise_warp.py @@ -0,0 +1,448 @@ +""" +Optical-flow-warped noise for VOID Pass 2 refinement. + +Adapted from RyannDaGreat/CommonSource (MIT License, Ryan Burgert): + https://github.com/RyannDaGreat/CommonSource + - noise_warp.py (NoiseWarper / warp_xyωc / regaussianize / get_noise_from_video) + - raft.py (RaftOpticalFlow) + +Only the code paths that ``comfy_extras/nodes_void.py::VOIDWarpedNoise`` actually +uses (torch THWC uint8 input, no background removal, no visualization, no disk +I/O, default warp/noise params) have been inlined. External ``rp`` utilities +have been replaced with equivalents from torch.nn.functional / einops / +torchvision. +""" + +import logging +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange + +import comfy.model_management + + +# --------------------------------------------------------------------------- +# Low-level torch image helpers (drop-in replacements for rp.torch_* primitives) +# --------------------------------------------------------------------------- + +def _torch_resize_chw(image, size, interp, copy=True): + """Resize a CHW tensor. + + ``size`` is either a scalar factor or a (h, w) tuple. ``interp`` is one + of ``"bilinear"``, ``"nearest"``, ``"area"``. When ``copy`` is False and + the requested size matches the input, returns the input tensor as is + (faster but callers must not mutate the result). + """ + assert image.ndim == 3, image.shape + _, in_h, in_w = image.shape + if isinstance(size, (int, float)) and not isinstance(size, bool): + new_h = max(1, int(in_h * size)) + new_w = max(1, int(in_w * size)) + else: + new_h, new_w = size + + if (new_h, new_w) == (in_h, in_w): + return image.clone() if copy else image + + kwargs = {} + if interp in ("bilinear", "bicubic"): + kwargs["align_corners"] = False + out = F.interpolate(image[None], size=(new_h, new_w), mode=interp, **kwargs)[0] + return out + + +def _torch_remap_relative(image, dx, dy, interp="bilinear"): + """Relative remap of a CHW image via ``F.grid_sample``. + + Equivalent to ``rp.torch_remap_image(image, dx, dy, relative=True, interp=interp)`` + for ``interp`` in {"bilinear", "nearest"}. Out-of-bounds samples are 0. + """ + assert image.ndim == 3 + assert dx.shape == dy.shape + _, h, w = image.shape + + x_abs = dx + torch.arange(w, device=dx.device, dtype=dx.dtype) + y_abs = dy + torch.arange(h, device=dy.device, dtype=dy.dtype)[:, None] + + x_norm = (x_abs / (w - 1)) * 2 - 1 + y_norm = (y_abs / (h - 1)) * 2 - 1 + + grid = torch.stack([x_norm, y_norm], dim=-1)[None].to(image.dtype) + out = F.grid_sample( + image[None], grid, mode=interp, align_corners=True, padding_mode="zeros" + )[0] + return out + + +def _torch_scatter_add_relative(image, dx, dy): + """Scatter-add a CHW image using relative floor-rounded (dx, dy) offsets. + + Equivalent to ``rp.torch_scatter_add_image(image, dx, dy, relative=True, + interp='floor')``. Out-of-bounds targets are dropped. + """ + assert image.ndim == 3 + in_c, in_h, in_w = image.shape + assert dx.shape == dy.shape == (in_h, in_w) + + x = dx.long() + torch.arange(in_w, device=dx.device, dtype=torch.long) + y = dy.long() + torch.arange(in_h, device=dy.device, dtype=torch.long)[:, None] + + valid = ((y >= 0) & (y < in_h) & (x >= 0) & (x < in_w)).reshape(-1) + indices = (y * in_w + x).reshape(-1)[valid] + + flat_image = rearrange(image, "c h w -> (h w) c")[valid] + out = torch.zeros((in_h * in_w, in_c), dtype=image.dtype, device=image.device) + out.index_add_(0, indices, flat_image) + return rearrange(out, "(h w) c -> c h w", h=in_h, w=in_w) + + +# --------------------------------------------------------------------------- +# Noise warping primitives (ported from noise_warp.py) +# --------------------------------------------------------------------------- + +def unique_pixels(image): + """Find unique pixel values in a CHW tensor. + + Returns ``(unique_colors [U, C], counts [U], index_matrix [H, W])`` where + ``index_matrix[i, j]`` is the index of the unique color at that pixel. + """ + _, h, w = image.shape + flat = rearrange(image, "c h w -> (h w) c") + unique_colors, inverse_indices, counts = torch.unique( + flat, dim=0, return_inverse=True, return_counts=True, sorted=False, + ) + index_matrix = rearrange(inverse_indices, "(h w) -> h w", h=h, w=w) + return unique_colors, counts, index_matrix + + +def sum_indexed_values(image, index_matrix): + """For each unique index, sum the CHW image values at its pixels.""" + _, h, w = image.shape + u = int(index_matrix.max().item()) + 1 + flat = rearrange(image, "c h w -> (h w) c") + out = torch.zeros((u, flat.shape[1]), dtype=flat.dtype, device=flat.device) + out.index_add_(0, index_matrix.view(-1), flat) + return out + + +def indexed_to_image(index_matrix, unique_colors): + """Build a CHW image from an index matrix and a (U, C) color table.""" + h, w = index_matrix.shape + flat = unique_colors[index_matrix.view(-1)] + return rearrange(flat, "(h w) c -> c h w", h=h, w=w) + + +def regaussianize(noise): + """Variance-preserving re-sampling of a CHW noise tensor. + + Wherever the noise contains groups of identical pixel values (e.g. after + a nearest-neighbor warp that duplicated source pixels), adds zero-mean + foreign noise within each group and scales by ``1/sqrt(count)`` so the + output is unit-variance gaussian again. + """ + _, hs, ws = noise.shape + _, counts, index_matrix = unique_pixels(noise[:1]) + + foreign_noise = torch.randn_like(noise) + summed = sum_indexed_values(foreign_noise, index_matrix) + meaned = indexed_to_image(index_matrix, summed / rearrange(counts, "u -> u 1")) + zeroed_foreign = foreign_noise - meaned + + counts_image = indexed_to_image(index_matrix, rearrange(counts, "u -> u 1")) + + output = noise / counts_image ** 0.5 + zeroed_foreign + return output, counts_image + + +def xy_meshgrid_like_image(image): + """Return a (2, H, W) tensor of (x, y) pixel coordinates matching ``image``.""" + _, h, w = image.shape + y, x = torch.meshgrid( + torch.arange(h, device=image.device, dtype=image.dtype), + torch.arange(w, device=image.device, dtype=image.dtype), + indexing="ij", + ) + return torch.stack([x, y]) + + +def noise_to_state(noise): + """Pack a (C, H, W) noise tensor into a state tensor (3+C, H, W) = [dx, dy, ω, noise].""" + zeros = torch.zeros_like(noise[:1]) + ones = torch.ones_like(noise[:1]) + return torch.cat([zeros, zeros, ones, noise]) + + +def state_to_noise(state): + """Unpack the noise channels from a state tensor.""" + return state[3:] + + +def warp_state(state, flow): + """Warp a noise-warper state tensor along the given optical flow. + + ``state`` has shape ``(3+c, h, w)`` (= dx, dy, ω, c noise channels). + ``flow`` has shape ``(2, h, w)`` (= dx, dy). + """ + assert flow.device == state.device + assert flow.ndim == 3 and flow.shape[0] == 2 + assert state.ndim == 3 + xyoc, h, w = state.shape + assert flow.shape == (2, h, w) + device = state.device + + x_ch, y_ch = 0, 1 + xy = 2 # state[:xy] = [dx, dy] + xyw = 3 # state[:xyw] = [dx, dy, ω] + w_ch = 2 # state[w_ch] = ω + c = xyoc - xyw + oc = xyoc - xy + assert c > 0, "state has no noise channels" + assert (state[w_ch] > 0).all(), "all weights must be > 0" + + grid = xy_meshgrid_like_image(state) + + init = torch.empty_like(state) + init[:xy] = 0 + init[w_ch] = 1 + init[-c:] = 0 + + # --- Expansion branch: nearest-neighbor remap with negated flow --- + pre_expand = torch.empty_like(state) + pre_expand[:xy] = _torch_remap_relative(state[:xy], -flow[0], -flow[1], "nearest") + pre_expand[-oc:] = _torch_remap_relative(state[-oc:], -flow[0], -flow[1], "nearest") + pre_expand[w_ch][pre_expand[w_ch] == 0] = 1 + + # --- Shrink branch: scatter-add state into new positions --- + pre_shrink = state.clone() + pre_shrink[:xy] += flow + + pos = (grid + pre_shrink[:xy]).round() + in_bounds = (pos[x_ch] >= 0) & (pos[x_ch] < w) & (pos[y_ch] >= 0) & (pos[y_ch] < h) + pre_shrink = torch.where(~in_bounds[None], init, pre_shrink) + + scat_xy = pre_shrink[:xy].round() + pre_shrink[:xy] -= scat_xy + pre_shrink[:xy] = 0 # xy_mode='none' in upstream + + def scat(tensor): + return _torch_scatter_add_relative(tensor, scat_xy[0], scat_xy[1]) + + # rp.torch_scatter_add_image on a bool tensor errors on modern torch; + # scatter-sum a float ones tensor and threshold to get the mask instead. + shrink_mask = scat(torch.ones(1, h, w, dtype=state.dtype, device=device)) > 0 + + # Drop expansion samples at positions that will be filled by shrink. + pre_expand = torch.where(shrink_mask, init, pre_expand) + + # Regaussianize both branches together so duplicated-source groups are + # counted globally, then split back apart. + concat = torch.cat([pre_shrink, pre_expand], dim=2) # along width + concat[-c:], counts_image = regaussianize(concat[-c:]) + concat[w_ch] = concat[w_ch] / counts_image[0] + concat[w_ch] = concat[w_ch].nan_to_num() + pre_shrink, expand = torch.chunk(concat, chunks=2, dim=2) + + shrink = torch.empty_like(pre_shrink) + shrink[w_ch] = scat(pre_shrink[w_ch][None])[0] + shrink[:xy] = scat(pre_shrink[:xy] * pre_shrink[w_ch][None]) / shrink[w_ch][None] + shrink[-c:] = scat(pre_shrink[-c:] * pre_shrink[w_ch][None]) / scat( + pre_shrink[w_ch][None] ** 2 + ).sqrt() + + output = torch.where(shrink_mask, shrink, expand) + output[w_ch] = output[w_ch] / output[w_ch].mean() + output[w_ch] += 1e-5 + output[w_ch] **= 0.9999 + return output + + +class NoiseWarper: + """Maintain a warpable noise state and emit gaussian noise per frame. + + Simplified from RyannDaGreat/CommonSource/noise_warp.py::NoiseWarper: + ``scale_factor``, ``post_noise_alpha``, ``progressive_noise_alpha``, and + ``warp_kwargs`` are all dropped since VOIDWarpedNoise always uses defaults. + """ + + def __init__(self, c, h, w, device, dtype=torch.float32): + assert c > 0 and h > 0 and w > 0 + self.c = c + self.h = h + self.w = w + self.device = device + self.dtype = dtype + + noise = torch.randn(c, h, w, dtype=dtype, device=device) + self._state = noise_to_state(noise) + + @property + def noise(self): + # With scale_factor=1 the "downsample to respect weights" step is a + # size-preserving no-op; the weight-variance correction math still + # runs to stay faithful to upstream. + n = state_to_noise(self._state) + weights = self._state[2:3] + return n * weights / (weights ** 2).sqrt() + + def __call__(self, dx, dy): + assert dx.shape == dy.shape + flow = torch.stack([dx, dy]).to(self.device, self.dtype) + _, oflowh, ofloww = flow.shape + + flow = _torch_resize_chw(flow, (self.h, self.w), "bilinear", copy=True) + flowh, floww = flow.shape[-2:] + + # Upstream scales flow[0] by flowh/oflowh and flow[1] by floww/ofloww + # (channel-order appears swapped but harmless when H and W are scaled + # by the same factor, which is always the case for our callers). + flow[0] *= flowh / oflowh + flow[1] *= floww / ofloww + + self._state = warp_state(self._state, flow) + return self + + +# --------------------------------------------------------------------------- +# RAFT optical flow wrapper (ported from raft.py) +# --------------------------------------------------------------------------- + +class RaftOpticalFlow: + """Torchvision RAFT-large wrapper. ``__call__`` returns a (2, H, W) flow.""" + + def __init__(self, device=None): + from torchvision.models.optical_flow import raft_large + + if device is None: + device = comfy.model_management.get_torch_device() + device = torch.device(device) if not isinstance(device, torch.device) else device + + model = raft_large(weights="DEFAULT", progress=False).to(device) + model.eval() + self.device = device + self.model = model + + def _preprocess(self, image_chw): + image = image_chw.to(self.device, torch.float32) + _, h, w = image.shape + new_h = (h // 8) * 8 + new_w = (w // 8) * 8 + image = _torch_resize_chw(image, (new_h, new_w), "bilinear", copy=False) + image = image * 2 - 1 + return image[None] + + def __call__(self, from_image, to_image): + """``from_image``, ``to_image``: CHW float tensors in [0, 1].""" + assert from_image.shape == to_image.shape + _, h, w = from_image.shape + with torch.no_grad(): + img1 = self._preprocess(from_image) + img2 = self._preprocess(to_image) + list_of_flows = self.model(img1, img2) + flow = list_of_flows[-1][0] # (2, new_h, new_w) + if flow.shape[-2:] != (h, w): + flow = _torch_resize_chw(flow, (h, w), "bilinear", copy=False) + return flow + + +_raft_cache: dict = {} + + +def _get_raft_model(device): + key = str(device) + if key not in _raft_cache: + _raft_cache[key] = RaftOpticalFlow(device=device) + return _raft_cache[key] + + +# --------------------------------------------------------------------------- +# Narrow entry point used by VOIDWarpedNoise +# --------------------------------------------------------------------------- + +def get_noise_from_video( + video_frames: torch.Tensor, + *, + noise_channels: int = 16, + resize_frames: float = 0.5, + resize_flow: int = 8, + downscale_factor: int = 32, + device: Optional[torch.device] = None, +) -> torch.Tensor: + """Produce optical-flow-warped gaussian noise from a video. + + Args: + video_frames: ``(T, H, W, 3)`` uint8 torch tensor. + noise_channels: Channels in the output noise. + resize_frames: Pre-RAFT frame scale factor. + resize_flow: Post-flow up-scale factor applied to the optical flow; + the internal noise state is allocated at + ``(resize_flow * resize_frames * H, resize_flow * resize_frames * W)``. + downscale_factor: Area-pool factor applied to the noise before return; + should evenly divide the internal noise resolution. + device: Target device. Defaults to ``comfy.model_management.get_torch_device()``. + + Returns: + ``(T, H', W', noise_channels)`` float32 noise tensor on ``device``. + """ + assert isinstance(resize_flow, int) and resize_flow >= 1, resize_flow + assert video_frames.ndim == 4 and video_frames.shape[-1] == 3, video_frames.shape + assert video_frames.dtype == torch.uint8, video_frames.dtype + + if device is None: + device = comfy.model_management.get_torch_device() + device = torch.device(device) if not isinstance(device, torch.device) else device + + if device.type == "cpu": + logging.warning( + "VOIDWarpedNoise: running get_noise_from_video on CPU; this will be " + "slow (minutes for ~45 frames). Use CUDA for interactive use." + ) + + T = video_frames.shape[0] + frames = video_frames.to(device).permute(0, 3, 1, 2).to(torch.float32) / 255.0 + if resize_frames != 1.0: + new_h = max(1, int(frames.shape[2] * resize_frames)) + new_w = max(1, int(frames.shape[3] * resize_frames)) + frames = F.interpolate(frames, size=(new_h, new_w), mode="area") + + _, _, H, W = frames.shape + internal_h = resize_flow * H + internal_w = resize_flow * W + if internal_h % downscale_factor or internal_w % downscale_factor: + logging.warning( + "VOIDWarpedNoise: internal noise size %dx%d is not divisible by " + "downscale_factor %d; output noise may have artifacts.", + internal_h, internal_w, downscale_factor, + ) + + raft = _get_raft_model(device) + + with torch.no_grad(): + warper = NoiseWarper( + c=noise_channels, h=internal_h, w=internal_w, device=device, + ) + down_h = warper.h // downscale_factor + down_w = warper.w // downscale_factor + output = torch.empty( + (T, down_h, down_w, noise_channels), dtype=torch.float32, device=device, + ) + + def downscale(noise_chw): + # Area-pool to 1/downscale_factor then multiply by downscale_factor + # to adjust std (sqrt of pool area == downscale_factor for a + # square pool). + down = _torch_resize_chw(noise_chw, 1.0 / downscale_factor, "area", copy=False) + return down * downscale_factor + + output[0] = downscale(warper.noise).permute(1, 2, 0) + + prev = frames[0] + for i in range(1, T): + curr = frames[i] + flow = raft(prev, curr).to(device) + warper(flow[0], flow[1]) + output[i] = downscale(warper.noise).permute(1, 2, 0) + prev = curr + + return output