From 2f3cf495c10aaf1985ecf934b3a7825e5c05f581 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Mon, 27 Apr 2026 11:42:00 +0200 Subject: [PATCH] Polish imports and modify asserts to raise proper errors with messages. --- comfy_extras/void_noise_warp.py | 87 ++++++++++++++++++++++++++------- 1 file changed, 68 insertions(+), 19 deletions(-) diff --git a/comfy_extras/void_noise_warp.py b/comfy_extras/void_noise_warp.py index 358ff388e..4f7ff470f 100644 --- a/comfy_extras/void_noise_warp.py +++ b/comfy_extras/void_noise_warp.py @@ -19,6 +19,7 @@ from typing import Optional import torch import torch.nn.functional as F from einops import rearrange +from torchvision.models.optical_flow import raft_large import comfy.model_management @@ -35,7 +36,10 @@ def _torch_resize_chw(image, size, interp, copy=True): the requested size matches the input, returns the input tensor as is (faster but callers must not mutate the result). """ - assert image.ndim == 3, image.shape + if image.ndim != 3: + raise ValueError( + f"_torch_resize_chw expects a 3D CHW tensor, got shape {tuple(image.shape)}" + ) _, in_h, in_w = image.shape if isinstance(size, (int, float)) and not isinstance(size, bool): new_h = max(1, int(in_h * size)) @@ -59,8 +63,14 @@ def _torch_remap_relative(image, dx, dy, interp="bilinear"): Equivalent to ``rp.torch_remap_image(image, dx, dy, relative=True, interp=interp)`` for ``interp`` in {"bilinear", "nearest"}. Out-of-bounds samples are 0. """ - assert image.ndim == 3 - assert dx.shape == dy.shape + if image.ndim != 3: + raise ValueError( + f"_torch_remap_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}" + ) + if dx.shape != dy.shape: + raise ValueError( + f"_torch_remap_relative: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}" + ) _, h, w = image.shape x_abs = dx + torch.arange(w, device=dx.device, dtype=dx.dtype) @@ -82,9 +92,16 @@ def _torch_scatter_add_relative(image, dx, dy): Equivalent to ``rp.torch_scatter_add_image(image, dx, dy, relative=True, interp='floor')``. Out-of-bounds targets are dropped. """ - assert image.ndim == 3 + if image.ndim != 3: + raise ValueError( + f"_torch_scatter_add_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}" + ) in_c, in_h, in_w = image.shape - assert dx.shape == dy.shape == (in_h, in_w) + if dx.shape != (in_h, in_w) or dy.shape != (in_h, in_w): + raise ValueError( + f"_torch_scatter_add_relative: dx/dy must be ({in_h}, {in_w}), " + f"got dx={tuple(dx.shape)} dy={tuple(dy.shape)}" + ) x = dx.long() + torch.arange(in_w, device=dx.device, dtype=torch.long) y = dy.long() + torch.arange(in_h, device=dy.device, dtype=torch.long)[:, None] @@ -185,11 +202,20 @@ def warp_state(state, flow): ``state`` has shape ``(3+c, h, w)`` (= dx, dy, ω, c noise channels). ``flow`` has shape ``(2, h, w)`` (= dx, dy). """ - assert flow.device == state.device - assert flow.ndim == 3 and flow.shape[0] == 2 - assert state.ndim == 3 + if flow.device != state.device: + raise ValueError( + f"warp_state: flow and state must be on the same device, " + f"got flow={flow.device} state={state.device}" + ) + if state.ndim != 3: + raise ValueError( + f"warp_state: state must be 3D (3+C, H, W), got shape {tuple(state.shape)}" + ) xyoc, h, w = state.shape - assert flow.shape == (2, h, w) + if flow.shape != (2, h, w): + raise ValueError( + f"warp_state: flow must have shape (2, {h}, {w}), got {tuple(flow.shape)}" + ) device = state.device x_ch, y_ch = 0, 1 @@ -198,8 +224,12 @@ def warp_state(state, flow): w_ch = 2 # state[w_ch] = ω c = xyoc - xyw oc = xyoc - xy - assert c > 0, "state has no noise channels" - assert (state[w_ch] > 0).all(), "all weights must be > 0" + if c <= 0: + raise ValueError( + f"warp_state: state has no noise channels (expected 3+C with C>0, got {xyoc} channels)" + ) + if not (state[w_ch] > 0).all(): + raise ValueError("warp_state: all weights in state[2] must be > 0") grid = xy_meshgrid_like_image(state) @@ -267,7 +297,10 @@ class NoiseWarper: """ def __init__(self, c, h, w, device, dtype=torch.float32): - assert c > 0 and h > 0 and w > 0 + if c <= 0 or h <= 0 or w <= 0: + raise ValueError( + f"NoiseWarper: c/h/w must all be positive, got c={c} h={h} w={w}" + ) self.c = c self.h = h self.w = w @@ -287,7 +320,10 @@ class NoiseWarper: return n * weights / (weights ** 2).sqrt() def __call__(self, dx, dy): - assert dx.shape == dy.shape + if dx.shape != dy.shape: + raise ValueError( + f"NoiseWarper: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}" + ) flow = torch.stack([dx, dy]).to(self.device, self.dtype) _, oflowh, ofloww = flow.shape @@ -312,8 +348,6 @@ class RaftOpticalFlow: """Torchvision RAFT-large wrapper. ``__call__`` returns a (2, H, W) flow.""" def __init__(self, device=None): - from torchvision.models.optical_flow import raft_large - if device is None: device = comfy.model_management.get_torch_device() device = torch.device(device) if not isinstance(device, torch.device) else device @@ -334,7 +368,11 @@ class RaftOpticalFlow: def __call__(self, from_image, to_image): """``from_image``, ``to_image``: CHW float tensors in [0, 1].""" - assert from_image.shape == to_image.shape + if from_image.shape != to_image.shape: + raise ValueError( + f"RaftOpticalFlow: from_image and to_image must match, " + f"got {tuple(from_image.shape)} vs {tuple(to_image.shape)}" + ) _, h, w = from_image.shape with torch.no_grad(): img1 = self._preprocess(from_image) @@ -385,9 +423,20 @@ def get_noise_from_video( Returns: ``(T, H', W', noise_channels)`` float32 noise tensor on ``device``. """ - assert isinstance(resize_flow, int) and resize_flow >= 1, resize_flow - assert video_frames.ndim == 4 and video_frames.shape[-1] == 3, video_frames.shape - assert video_frames.dtype == torch.uint8, video_frames.dtype + if not isinstance(resize_flow, int) or resize_flow < 1: + raise ValueError( + f"get_noise_from_video: resize_flow must be a positive int, got {resize_flow!r}" + ) + if video_frames.ndim != 4 or video_frames.shape[-1] != 3: + raise ValueError( + "get_noise_from_video: video_frames must have shape (T, H, W, 3), " + f"got {tuple(video_frames.shape)}" + ) + if video_frames.dtype != torch.uint8: + raise TypeError( + "get_noise_from_video: video_frames must be uint8 in [0, 255], " + f"got dtype {video_frames.dtype}" + ) if device is None: device = comfy.model_management.get_torch_device()