"""
Optical-flow-warped noise for VOID Pass 2 refinement.

Adapted from RyannDaGreat/CommonSource (MIT License, Ryan Burgert):
https://github.com/RyannDaGreat/CommonSource
    - noise_warp.py (NoiseWarper / warp_xyωc / regaussianize / get_noise_from_video)
    - raft.py (RaftOpticalFlow)

Only the code paths that ``comfy_extras/nodes_void.py::VOIDWarpedNoise``
actually uses have been inlined: torch THWC uint8 input, no background
removal, no visualization, no disk I/O, and default warp/noise params.
External ``rp`` utilities have been replaced with equivalents from
torch.nn.functional / einops / torchvision.

Deliberate deviations from upstream:
    - ``NoiseWarper.__call__`` scales x-flow by the width ratio and y-flow by
      the height ratio. Upstream paired them the other way round. The results
      are identical whenever both axes are resized by the same factor, which
      is always the case in ``get_noise_from_video``.
    - ``RaftOpticalFlow.__call__`` rescales the flow vectors when it resizes
      the flow field back from the multiple-of-8 RAFT resolution. The
      returned flow is then in pixels of the input frames.
    - ``get_noise_from_video`` area-pools to an explicit output size, so the
      pooled noise always matches the preallocated output buffer.
"""

import logging
from typing import Optional

import torch
import torch.nn.functional as F
from einops import rearrange

import comfy.model_management

# ---------------------------------------------------------------------------
# Low-level torch image helpers (drop-in replacements for rp.torch_* primitives)
# ---------------------------------------------------------------------------


def _torch_resize_chw(image, size, interp, copy=True):
    """Resize a CHW tensor with ``F.interpolate``.

    Args:
        image: ``(C, H, W)`` tensor.
        size: Either a scalar scale factor or an explicit ``(h, w)`` tuple.
            A scalar factor gives ``max(1, int(dim * size))`` per axis, i.e.
            the size is truncated.
        interp: An ``F.interpolate`` mode; this file uses ``"bilinear"``,
            ``"nearest"`` and ``"area"``. ``"bilinear"`` and ``"bicubic"`` run
            with ``align_corners=False``.
        copy: Controls the case where the requested size equals the input
            size. If True, a clone is returned; if False, the input tensor
            itself is returned, which is faster but callers must not mutate
            the result.

    Returns:
        The ``(C, h, w)`` tensor.
    """
    assert image.ndim == 3, image.shape
    _, in_h, in_w = image.shape
    if isinstance(size, (int, float)) and not isinstance(size, bool):
        new_h = max(1, int(in_h * size))
        new_w = max(1, int(in_w * size))
    else:
        new_h, new_w = size
    if (new_h, new_w) == (in_h, in_w):
        return image.clone() if copy else image
    kwargs = {}
    if interp in ("bilinear", "bicubic"):
        kwargs["align_corners"] = False
    return F.interpolate(image[None], size=(new_h, new_w), mode=interp, **kwargs)[0]


def _torch_remap_relative(image, dx, dy, interp="bilinear"):
    """Sample a CHW image at per-pixel relative offsets via ``F.grid_sample``.

    Output pixel ``(y, x)`` reads ``image`` at ``(y + dy[y, x], x + dx[y, x])``.
    This is equivalent to ``rp.torch_remap_image(image, dx, dy, relative=True,
    interp=interp)`` for ``interp`` in {"bilinear", "nearest"}. Samples that
    fall outside the image read as 0.
    """
    assert image.ndim == 3
    assert dx.shape == dy.shape
    _, h, w = image.shape
    x_abs = dx + torch.arange(w, device=dx.device, dtype=dx.dtype)
    y_abs = dy + torch.arange(h, device=dy.device, dtype=dy.dtype)[:, None]
    # align_corners=False normalization: pixel p maps to (2p + 1) / size - 1,
    # which grid_sample unnormalizes back to exactly p. The align_corners=True
    # form p / (size - 1) divides by zero for 1-pixel-wide or -tall images.
    x_norm = (2 * x_abs + 1) / w - 1
    y_norm = (2 * y_abs + 1) / h - 1
    grid = torch.stack([x_norm, y_norm], dim=-1)[None].to(image.dtype)
    return F.grid_sample(
        image[None], grid, mode=interp, align_corners=False, padding_mode="zeros"
    )[0]


def _torch_scatter_add_relative(image, dx, dy):
    """Scatter-add a CHW image to per-pixel relative, floor-rounded offsets.

    Source pixel ``(y, x)`` is added into
    ``(y + floor(dy[y, x]), x + floor(dx[y, x]))``. This is equivalent to
    ``rp.torch_scatter_add_image(image, dx, dy, relative=True, interp='floor')``.
    Targets outside the image are dropped. A target hit by several sources
    receives their sum.
    """
    assert image.ndim == 3
    in_c, in_h, in_w = image.shape
    assert dx.shape == dy.shape == (in_h, in_w)

    def to_offsets(t):
        # A bare .long() truncates toward zero, so dx = -0.5 would become 0
        # rather than floor(-0.5) = -1.
        return (t.floor() if t.is_floating_point() else t).long()

    x = to_offsets(dx) + torch.arange(in_w, device=dx.device, dtype=torch.long)
    y = to_offsets(dy) + torch.arange(in_h, device=dy.device, dtype=torch.long)[:, None]
    valid = ((y >= 0) & (y < in_h) & (x >= 0) & (x < in_w)).reshape(-1)
    indices = (y * in_w + x).reshape(-1)[valid]
    flat_image = rearrange(image, "c h w -> (h w) c")[valid]
    out = torch.zeros((in_h * in_w, in_c), dtype=image.dtype, device=image.device)
    out.index_add_(0, indices, flat_image)
    return rearrange(out, "(h w) c -> c h w", h=in_h, w=in_w)


# ---------------------------------------------------------------------------
# Noise warping primitives (ported from noise_warp.py)
# ---------------------------------------------------------------------------


def unique_pixels(image):
    """Find unique pixel values in a CHW tensor.

    Returns:
        ``(unique_colors [U, C], counts [U], index_matrix [H, W])``, where
        ``index_matrix[i, j]`` is the index of the unique color at that pixel.
    """
    _, h, w = image.shape
    flat = rearrange(image, "c h w -> (h w) c")
    unique_colors, inverse_indices, counts = torch.unique(
        flat,
        dim=0,
        return_inverse=True,
        return_counts=True,
        sorted=False,
    )
    index_matrix = rearrange(inverse_indices, "(h w) -> h w", h=h, w=w)
    return unique_colors, counts, index_matrix


def sum_indexed_values(image, index_matrix):
    """Sum the CHW image values over each index group.

    Returns:
        A ``(U, C)`` tensor, where ``U = index_matrix.max() + 1``. Row ``u``
        is the sum of ``image`` over all pixels whose index is ``u``.
    """
    u = int(index_matrix.max().item()) + 1
    flat = rearrange(image, "c h w -> (h w) c")
    out = torch.zeros((u, flat.shape[1]), dtype=flat.dtype, device=flat.device)
    out.index_add_(0, index_matrix.reshape(-1), flat)
    return out


def indexed_to_image(index_matrix, unique_colors):
    """Build a CHW image from an ``(H, W)`` index matrix and a ``(U, C)`` color table."""
    h, w = index_matrix.shape
    flat = unique_colors[index_matrix.reshape(-1)]
    return rearrange(flat, "(h w) c -> c h w", h=h, w=w)


def regaussianize(noise):
    """Variance-preserving re-sampling of a CHW noise tensor.

    Pixels are grouped by their value in the first channel only. Such groups
    arise, for example, after a nearest-neighbor warp that duplicated source
    pixels. Within each group of size ``count``, every pixel is divided by
    ``sqrt(count)`` and then receives foreign gaussian noise whose group mean
    has been subtracted, so the group is no longer identical and has
    unit-variance statistics again.

    Returns:
        ``(output, counts_image)``. ``counts_image`` is ``(1, H, W)`` and holds
        each pixel's group size.
    """
    _, counts, index_matrix = unique_pixels(noise[:1])
    counts_col = rearrange(counts, "u -> u 1")
    foreign_noise = torch.randn_like(noise)
    summed = sum_indexed_values(foreign_noise, index_matrix)
    meaned = indexed_to_image(index_matrix, summed / counts_col)
    zeroed_foreign = foreign_noise - meaned
    counts_image = indexed_to_image(index_matrix, counts_col)
    output = noise / counts_image ** 0.5 + zeroed_foreign
    return output, counts_image


def xy_meshgrid_like_image(image):
    """Return a ``(2, H, W)`` tensor of ``(x, y)`` pixel coordinates matching ``image``."""
    _, h, w = image.shape
    y, x = torch.meshgrid(
        torch.arange(h, device=image.device, dtype=image.dtype),
        torch.arange(w, device=image.device, dtype=image.dtype),
        indexing="ij",
    )
    return torch.stack([x, y])


def noise_to_state(noise):
    """Pack ``(C, H, W)`` noise into a ``(3 + C, H, W)`` state = [dx=0, dy=0, ω=1, noise]."""
    zeros = torch.zeros_like(noise[:1])
    ones = torch.ones_like(noise[:1])
    return torch.cat([zeros, zeros, ones, noise])


def state_to_noise(state):
    """Unpack the noise channels (everything after dx, dy, ω) from a state tensor."""
    return state[3:]


def warp_state(state, flow):
    """Warp a noise-warper state tensor along the given optical flow.

    ``state`` has shape ``(3 + c, h, w)`` = (dx, dy, ω, c noise channels).
    ``flow`` has shape ``(2, h, w)`` = (dx, dy) in pixels. Two candidate
    outputs are built and merged:

    * expansion: each output pixel pulls (nearest-neighbor) from the pixel the
      flow came from, so a source may be duplicated;
    * shrink: each input pixel is pushed (scatter-added) to where the flow
      sends it, so several sources may land on one target.

    Output pixels hit by at least one pushed source take the shrink result.
    All other pixels take the expansion result.
    """
    assert flow.device == state.device
    assert flow.ndim == 3 and flow.shape[0] == 2
    assert state.ndim == 3
    xyoc, h, w = state.shape
    assert flow.shape == (2, h, w)
    device = state.device

    x_ch, y_ch = 0, 1
    xy = 2  # state[:xy] = [dx, dy]
    xyw = 3  # state[:xyw] = [dx, dy, ω]
    w_ch = 2  # state[w_ch] = ω
    c = xyoc - xyw  # number of noise channels
    oc = xyoc - xy  # ω + noise channels
    assert c > 0, "state has no noise channels"
    assert (state[w_ch] > 0).all(), "all weights must be > 0"

    grid = xy_meshgrid_like_image(state)

    # Default state (zero offset, weight 1, zero noise) for pixels without a
    # valid source.
    init = torch.empty_like(state)
    init[:xy] = 0
    init[w_ch] = 1
    init[-c:] = 0

    # --- Expansion branch: nearest-neighbor remap with negated flow ---
    pre_expand = torch.empty_like(state)
    pre_expand[:xy] = _torch_remap_relative(state[:xy], -flow[0], -flow[1], "nearest")
    pre_expand[-oc:] = _torch_remap_relative(state[-oc:], -flow[0], -flow[1], "nearest")
    # Out-of-bounds pulls read back as zero weight; reset them to weight 1.
    pre_expand[w_ch][pre_expand[w_ch] == 0] = 1

    # --- Shrink branch: scatter-add state into new positions ---
    pre_shrink = state.clone()
    pre_shrink[:xy] += flow
    pos = (grid + pre_shrink[:xy]).round()
    in_bounds = (pos[x_ch] >= 0) & (pos[x_ch] < w) & (pos[y_ch] >= 0) & (pos[y_ch] < h)
    pre_shrink = torch.where(~in_bounds[None], init, pre_shrink)
    scat_xy = pre_shrink[:xy].round()
    # xy_mode='none' in upstream: residual sub-pixel offsets are discarded.
    pre_shrink[:xy] = 0

    def scat(tensor):
        return _torch_scatter_add_relative(tensor, scat_xy[0], scat_xy[1])

    # rp.torch_scatter_add_image on a bool tensor errors on modern torch;
    # scatter-sum a float ones tensor and threshold to get the mask instead.
    shrink_mask = scat(torch.ones(1, h, w, dtype=state.dtype, device=device)) > 0

    # Drop expansion samples at positions that will be filled by shrink.
    pre_expand = torch.where(shrink_mask, init, pre_expand)

    # Regaussianize both branches together so duplicated-source groups are
    # counted globally, then split back apart.
    concat = torch.cat([pre_shrink, pre_expand], dim=2)  # along width
    concat[-c:], counts_image = regaussianize(concat[-c:])
    concat[w_ch] = concat[w_ch] / counts_image[0]
    concat[w_ch] = concat[w_ch].nan_to_num()
    pre_shrink, expand = torch.chunk(concat, chunks=2, dim=2)

    # Weighted merge of all sources landing on the same target. Noise is
    # divided by sqrt(sum of squared weights) so a weighted sum of unit
    # gaussians stays unit-variance.
    shrink = torch.empty_like(pre_shrink)
    shrink[w_ch] = scat(pre_shrink[w_ch][None])[0]
    shrink[:xy] = scat(pre_shrink[:xy] * pre_shrink[w_ch][None]) / shrink[w_ch][None]
    shrink[-c:] = scat(pre_shrink[-c:] * pre_shrink[w_ch][None]) / scat(
        pre_shrink[w_ch][None] ** 2
    ).sqrt()

    output = torch.where(shrink_mask, shrink, expand)
    output[w_ch] = output[w_ch] / output[w_ch].mean()
    output[w_ch] += 1e-5
    output[w_ch] **= 0.9999
    return output


class NoiseWarper:
    """Maintain a warpable noise state and emit gaussian noise per frame.

    Simplified from RyannDaGreat/CommonSource/noise_warp.py::NoiseWarper.
    ``scale_factor``, ``post_noise_alpha``, ``progressive_noise_alpha`` and
    ``warp_kwargs`` are all dropped, since VOIDWarpedNoise always uses the
    defaults.
    """

    def __init__(self, c, h, w, device, dtype=torch.float32):
        assert c > 0 and h > 0 and w > 0
        self.c = c
        self.h = h
        self.w = w
        self.device = device
        self.dtype = dtype
        noise = torch.randn(c, h, w, dtype=dtype, device=device)
        self._state = noise_to_state(noise)

    @property
    def noise(self):
        """Current ``(c, h, w)`` noise tensor."""
        # With scale_factor=1 the "downsample to respect weights" step is a
        # size-preserving no-op; the weight-variance correction math still
        # runs to stay faithful to upstream.
        n = state_to_noise(self._state)
        weights = self._state[2:3]
        return n * weights / (weights ** 2).sqrt()

    def __call__(self, dx, dy):
        """Advance the noise state by one frame of optical flow.

        Args:
            dx, dy: Same-shaped 2D flow components, in pixels of their own
                resolution. The flow field is bilinearly resized to
                ``(self.h, self.w)`` and its vectors rescaled to match.

        Returns:
            ``self``.
        """
        assert dx.shape == dy.shape
        flow = torch.stack([dx, dy]).to(self.device, self.dtype)
        _, in_h, in_w = flow.shape
        flow = _torch_resize_chw(flow, (self.h, self.w), "bilinear", copy=True)
        # Resizing the field leaves the vectors in source-pixel units: x-flow
        # scales with the width ratio, y-flow with the height ratio.
        flow[0] *= self.w / in_w
        flow[1] *= self.h / in_h
        self._state = warp_state(self._state, flow)
        return self


# ---------------------------------------------------------------------------
# RAFT optical flow wrapper (ported from raft.py)
# ---------------------------------------------------------------------------


class RaftOpticalFlow:
    """Torchvision RAFT-large wrapper. ``__call__`` returns a (2, H, W) flow."""

    def __init__(self, device=None):
        from torchvision.models.optical_flow import raft_large

        if device is None:
            device = comfy.model_management.get_torch_device()
        device = torch.device(device) if not isinstance(device, torch.device) else device
        model = raft_large(weights="DEFAULT", progress=False).to(device)
        model.eval()
        self.device = device
        self.model = model

    def _preprocess(self, image_chw):
        """Return a ``(1, C, h8, w8)`` batch scaled from [0, 1] to [-1, 1].

        ``h8`` and ``w8`` are the input dimensions rounded down to multiples
        of 8.
        """
        image = image_chw.to(self.device, torch.float32)
        _, h, w = image.shape
        new_h = (h // 8) * 8
        new_w = (w // 8) * 8
        image = _torch_resize_chw(image, (new_h, new_w), "bilinear", copy=False)
        image = image * 2 - 1
        return image[None]

    def __call__(self, from_image, to_image):
        """Estimate flow from ``from_image`` to ``to_image``.

        Args:
            from_image, to_image: Same-shaped CHW float tensors in [0, 1].

        Returns:
            A ``(2, H, W)`` flow (x, y) in pixels of the input images.
        """
        assert from_image.shape == to_image.shape
        _, h, w = from_image.shape
        with torch.no_grad():
            img1 = self._preprocess(from_image)
            img2 = self._preprocess(to_image)
            list_of_flows = self.model(img1, img2)
            flow = list_of_flows[-1][0]  # (2, new_h, new_w), in new_h x new_w pixels
            new_h, new_w = flow.shape[-2:]
            if (new_h, new_w) != (h, w):
                flow = _torch_resize_chw(flow, (h, w), "bilinear", copy=False)
                # The vectors were measured on the resized frames; convert
                # them to input-image pixels. flow is a fresh tensor here, so
                # in-place scaling is safe.
                flow[0] *= w / new_w
                flow[1] *= h / new_h
        return flow


# One RAFT model per device string, kept for the life of the process.
_raft_cache: dict = {}


def _get_raft_model(device):
    """Return the cached ``RaftOpticalFlow`` for ``device``, creating it on first use."""
    key = str(device)
    if key not in _raft_cache:
        _raft_cache[key] = RaftOpticalFlow(device=device)
    return _raft_cache[key]


# ---------------------------------------------------------------------------
# Narrow entry point used by VOIDWarpedNoise
# ---------------------------------------------------------------------------


def get_noise_from_video(
    video_frames: torch.Tensor,
    *,
    noise_channels: int = 16,
    resize_frames: float = 0.5,
    resize_flow: int = 8,
    downscale_factor: int = 32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Produce optical-flow-warped gaussian noise from a video.

    Args:
        video_frames: ``(T, H, W, 3)`` uint8 torch tensor with ``T >= 1``.
        noise_channels: Channels in the output noise.
        resize_frames: Pre-RAFT frame scale factor.
        resize_flow: Post-flow up-scale factor applied to the optical flow.
            The internal noise state is allocated at
            ``(resize_flow * resize_frames * H, resize_flow * resize_frames * W)``.
        downscale_factor: Area-pool factor applied to the noise before it is
            returned. It should evenly divide the internal noise resolution;
            a warning is logged otherwise.
        device: Target device. Defaults to
            ``comfy.model_management.get_torch_device()``.

    Returns:
        ``(T, H', W', noise_channels)`` float32 noise tensor on ``device``,
        where ``H' = internal_h // downscale_factor`` and
        ``W' = internal_w // downscale_factor``.

    Raises:
        ValueError: If ``video_frames`` has no frames, or if
            ``downscale_factor`` is < 1 or larger than the internal noise size.
    """
    assert isinstance(resize_flow, int) and resize_flow >= 1, resize_flow
    assert video_frames.ndim == 4 and video_frames.shape[-1] == 3, video_frames.shape
    assert video_frames.dtype == torch.uint8, video_frames.dtype
    T = video_frames.shape[0]
    if T == 0:
        raise ValueError("get_noise_from_video: video_frames contains no frames")

    if device is None:
        device = comfy.model_management.get_torch_device()
    device = torch.device(device) if not isinstance(device, torch.device) else device
    if device.type == "cpu":
        logging.warning(
            "VOIDWarpedNoise: running get_noise_from_video on CPU; this will be "
            "slow (minutes for ~45 frames). Use CUDA for interactive use."
        )

    frames = video_frames.to(device).permute(0, 3, 1, 2).to(torch.float32) / 255.0
    if resize_frames != 1.0:
        new_h = max(1, int(frames.shape[2] * resize_frames))
        new_w = max(1, int(frames.shape[3] * resize_frames))
        frames = F.interpolate(frames, size=(new_h, new_w), mode="area")

    _, _, H, W = frames.shape
    internal_h = resize_flow * H
    internal_w = resize_flow * W
    if downscale_factor < 1 or internal_h < downscale_factor or internal_w < downscale_factor:
        raise ValueError(
            f"get_noise_from_video: downscale_factor {downscale_factor} must be >= 1 "
            f"and no larger than the internal noise size {internal_h}x{internal_w}"
        )
    if internal_h % downscale_factor or internal_w % downscale_factor:
        logging.warning(
            "VOIDWarpedNoise: internal noise size %dx%d is not divisible by "
            "downscale_factor %d; output noise may have artifacts.",
            internal_h,
            internal_w,
            downscale_factor,
        )
    down_h = internal_h // downscale_factor
    down_w = internal_w // downscale_factor

    raft = _get_raft_model(device)

    with torch.no_grad():
        warper = NoiseWarper(
            c=noise_channels,
            h=internal_h,
            w=internal_w,
            device=device,
        )
        output = torch.empty(
            (T, down_h, down_w, noise_channels),
            dtype=torch.float32,
            device=device,
        )

        def downscale(noise_chw):
            # Area-pool to the exact preallocated size rather than a
            # 1/downscale_factor scale, so float truncation can never make the
            # shapes disagree. Averaging a df x df block of iid unit gaussians
            # has std 1/df, so multiplying by df restores unit std.
            pooled = _torch_resize_chw(noise_chw, (down_h, down_w), "area", copy=False)
            return pooled * downscale_factor

        output[0] = downscale(warper.noise).permute(1, 2, 0)
        prev = frames[0]
        for i, curr in enumerate(frames[1:], start=1):
            flow = raft(prev, curr).to(device)
            warper(flow[0], flow[1])
            output[i] = downscale(warper.noise).permute(1, 2, 0)
            prev = curr
    return output