Polish imports and modify asserts to raise proper errors with messages.
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled

This commit is contained in:
Talmaj Marinc 2026-04-27 11:42:00 +02:00
parent 60eed34bca
commit 2f3cf495c1

View File

@ -19,6 +19,7 @@ from typing import Optional
import torch
import torch.nn.functional as F
from einops import rearrange
from torchvision.models.optical_flow import raft_large
import comfy.model_management
@ -35,7 +36,10 @@ def _torch_resize_chw(image, size, interp, copy=True):
the requested size matches the input, returns the input tensor as is
(faster but callers must not mutate the result).
"""
assert image.ndim == 3, image.shape
if image.ndim != 3:
raise ValueError(
f"_torch_resize_chw expects a 3D CHW tensor, got shape {tuple(image.shape)}"
)
_, in_h, in_w = image.shape
if isinstance(size, (int, float)) and not isinstance(size, bool):
new_h = max(1, int(in_h * size))
@ -59,8 +63,14 @@ def _torch_remap_relative(image, dx, dy, interp="bilinear"):
Equivalent to ``rp.torch_remap_image(image, dx, dy, relative=True, interp=interp)``
for ``interp`` in {"bilinear", "nearest"}. Out-of-bounds samples are 0.
"""
assert image.ndim == 3
assert dx.shape == dy.shape
if image.ndim != 3:
raise ValueError(
f"_torch_remap_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}"
)
if dx.shape != dy.shape:
raise ValueError(
f"_torch_remap_relative: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}"
)
_, h, w = image.shape
x_abs = dx + torch.arange(w, device=dx.device, dtype=dx.dtype)
@ -82,9 +92,16 @@ def _torch_scatter_add_relative(image, dx, dy):
Equivalent to ``rp.torch_scatter_add_image(image, dx, dy, relative=True,
interp='floor')``. Out-of-bounds targets are dropped.
"""
assert image.ndim == 3
if image.ndim != 3:
raise ValueError(
f"_torch_scatter_add_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}"
)
in_c, in_h, in_w = image.shape
assert dx.shape == dy.shape == (in_h, in_w)
if dx.shape != (in_h, in_w) or dy.shape != (in_h, in_w):
raise ValueError(
f"_torch_scatter_add_relative: dx/dy must be ({in_h}, {in_w}), "
f"got dx={tuple(dx.shape)} dy={tuple(dy.shape)}"
)
x = dx.long() + torch.arange(in_w, device=dx.device, dtype=torch.long)
y = dy.long() + torch.arange(in_h, device=dy.device, dtype=torch.long)[:, None]
@ -185,11 +202,20 @@ def warp_state(state, flow):
``state`` has shape ``(3+c, h, w)`` (= dx, dy, ω, c noise channels).
``flow`` has shape ``(2, h, w)`` (= dx, dy).
"""
assert flow.device == state.device
assert flow.ndim == 3 and flow.shape[0] == 2
assert state.ndim == 3
if flow.device != state.device:
raise ValueError(
f"warp_state: flow and state must be on the same device, "
f"got flow={flow.device} state={state.device}"
)
if state.ndim != 3:
raise ValueError(
f"warp_state: state must be 3D (3+C, H, W), got shape {tuple(state.shape)}"
)
xyoc, h, w = state.shape
assert flow.shape == (2, h, w)
if flow.shape != (2, h, w):
raise ValueError(
f"warp_state: flow must have shape (2, {h}, {w}), got {tuple(flow.shape)}"
)
device = state.device
x_ch, y_ch = 0, 1
@ -198,8 +224,12 @@ def warp_state(state, flow):
w_ch = 2 # state[w_ch] = ω
c = xyoc - xyw
oc = xyoc - xy
assert c > 0, "state has no noise channels"
assert (state[w_ch] > 0).all(), "all weights must be > 0"
if c <= 0:
raise ValueError(
f"warp_state: state has no noise channels (expected 3+C with C>0, got {xyoc} channels)"
)
if not (state[w_ch] > 0).all():
raise ValueError("warp_state: all weights in state[2] must be > 0")
grid = xy_meshgrid_like_image(state)
@ -267,7 +297,10 @@ class NoiseWarper:
"""
def __init__(self, c, h, w, device, dtype=torch.float32):
assert c > 0 and h > 0 and w > 0
if c <= 0 or h <= 0 or w <= 0:
raise ValueError(
f"NoiseWarper: c/h/w must all be positive, got c={c} h={h} w={w}"
)
self.c = c
self.h = h
self.w = w
@ -287,7 +320,10 @@ class NoiseWarper:
return n * weights / (weights ** 2).sqrt()
def __call__(self, dx, dy):
assert dx.shape == dy.shape
if dx.shape != dy.shape:
raise ValueError(
f"NoiseWarper: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}"
)
flow = torch.stack([dx, dy]).to(self.device, self.dtype)
_, oflowh, ofloww = flow.shape
@ -312,8 +348,6 @@ class RaftOpticalFlow:
"""Torchvision RAFT-large wrapper. ``__call__`` returns a (2, H, W) flow."""
def __init__(self, device=None):
from torchvision.models.optical_flow import raft_large
if device is None:
device = comfy.model_management.get_torch_device()
device = torch.device(device) if not isinstance(device, torch.device) else device
@ -334,7 +368,11 @@ class RaftOpticalFlow:
def __call__(self, from_image, to_image):
"""``from_image``, ``to_image``: CHW float tensors in [0, 1]."""
assert from_image.shape == to_image.shape
if from_image.shape != to_image.shape:
raise ValueError(
f"RaftOpticalFlow: from_image and to_image must match, "
f"got {tuple(from_image.shape)} vs {tuple(to_image.shape)}"
)
_, h, w = from_image.shape
with torch.no_grad():
img1 = self._preprocess(from_image)
@ -385,9 +423,20 @@ def get_noise_from_video(
Returns:
``(T, H', W', noise_channels)`` float32 noise tensor on ``device``.
"""
assert isinstance(resize_flow, int) and resize_flow >= 1, resize_flow
assert video_frames.ndim == 4 and video_frames.shape[-1] == 3, video_frames.shape
assert video_frames.dtype == torch.uint8, video_frames.dtype
if not isinstance(resize_flow, int) or resize_flow < 1:
raise ValueError(
f"get_noise_from_video: resize_flow must be a positive int, got {resize_flow!r}"
)
if video_frames.ndim != 4 or video_frames.shape[-1] != 3:
raise ValueError(
"get_noise_from_video: video_frames must have shape (T, H, W, 3), "
f"got {tuple(video_frames.shape)}"
)
if video_frames.dtype != torch.uint8:
raise TypeError(
"get_noise_from_video: video_frames must be uint8 in [0, 255], "
f"got dtype {video_frames.dtype}"
)
if device is None:
device = comfy.model_management.get_torch_device()