Polish imports and modify asserts to raise proper errors with messages.
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled

This commit is contained in:
Talmaj Marinc 2026-04-27 11:42:00 +02:00
parent 60eed34bca
commit 2f3cf495c1

View File

@ -19,6 +19,7 @@ from typing import Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from torchvision.models.optical_flow import raft_large
import comfy.model_management import comfy.model_management
@ -35,7 +36,10 @@ def _torch_resize_chw(image, size, interp, copy=True):
the requested size matches the input, returns the input tensor as is the requested size matches the input, returns the input tensor as is
(faster but callers must not mutate the result). (faster but callers must not mutate the result).
""" """
assert image.ndim == 3, image.shape if image.ndim != 3:
raise ValueError(
f"_torch_resize_chw expects a 3D CHW tensor, got shape {tuple(image.shape)}"
)
_, in_h, in_w = image.shape _, in_h, in_w = image.shape
if isinstance(size, (int, float)) and not isinstance(size, bool): if isinstance(size, (int, float)) and not isinstance(size, bool):
new_h = max(1, int(in_h * size)) new_h = max(1, int(in_h * size))
@ -59,8 +63,14 @@ def _torch_remap_relative(image, dx, dy, interp="bilinear"):
Equivalent to ``rp.torch_remap_image(image, dx, dy, relative=True, interp=interp)`` Equivalent to ``rp.torch_remap_image(image, dx, dy, relative=True, interp=interp)``
for ``interp`` in {"bilinear", "nearest"}. Out-of-bounds samples are 0. for ``interp`` in {"bilinear", "nearest"}. Out-of-bounds samples are 0.
""" """
assert image.ndim == 3 if image.ndim != 3:
assert dx.shape == dy.shape raise ValueError(
f"_torch_remap_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}"
)
if dx.shape != dy.shape:
raise ValueError(
f"_torch_remap_relative: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}"
)
_, h, w = image.shape _, h, w = image.shape
x_abs = dx + torch.arange(w, device=dx.device, dtype=dx.dtype) x_abs = dx + torch.arange(w, device=dx.device, dtype=dx.dtype)
@ -82,9 +92,16 @@ def _torch_scatter_add_relative(image, dx, dy):
Equivalent to ``rp.torch_scatter_add_image(image, dx, dy, relative=True, Equivalent to ``rp.torch_scatter_add_image(image, dx, dy, relative=True,
interp='floor')``. Out-of-bounds targets are dropped. interp='floor')``. Out-of-bounds targets are dropped.
""" """
assert image.ndim == 3 if image.ndim != 3:
raise ValueError(
f"_torch_scatter_add_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}"
)
in_c, in_h, in_w = image.shape in_c, in_h, in_w = image.shape
assert dx.shape == dy.shape == (in_h, in_w) if dx.shape != (in_h, in_w) or dy.shape != (in_h, in_w):
raise ValueError(
f"_torch_scatter_add_relative: dx/dy must be ({in_h}, {in_w}), "
f"got dx={tuple(dx.shape)} dy={tuple(dy.shape)}"
)
x = dx.long() + torch.arange(in_w, device=dx.device, dtype=torch.long) x = dx.long() + torch.arange(in_w, device=dx.device, dtype=torch.long)
y = dy.long() + torch.arange(in_h, device=dy.device, dtype=torch.long)[:, None] y = dy.long() + torch.arange(in_h, device=dy.device, dtype=torch.long)[:, None]
@ -185,11 +202,20 @@ def warp_state(state, flow):
``state`` has shape ``(3+c, h, w)`` (= dx, dy, ω, c noise channels). ``state`` has shape ``(3+c, h, w)`` (= dx, dy, ω, c noise channels).
``flow`` has shape ``(2, h, w)`` (= dx, dy). ``flow`` has shape ``(2, h, w)`` (= dx, dy).
""" """
assert flow.device == state.device if flow.device != state.device:
assert flow.ndim == 3 and flow.shape[0] == 2 raise ValueError(
assert state.ndim == 3 f"warp_state: flow and state must be on the same device, "
f"got flow={flow.device} state={state.device}"
)
if state.ndim != 3:
raise ValueError(
f"warp_state: state must be 3D (3+C, H, W), got shape {tuple(state.shape)}"
)
xyoc, h, w = state.shape xyoc, h, w = state.shape
assert flow.shape == (2, h, w) if flow.shape != (2, h, w):
raise ValueError(
f"warp_state: flow must have shape (2, {h}, {w}), got {tuple(flow.shape)}"
)
device = state.device device = state.device
x_ch, y_ch = 0, 1 x_ch, y_ch = 0, 1
@ -198,8 +224,12 @@ def warp_state(state, flow):
w_ch = 2 # state[w_ch] = ω w_ch = 2 # state[w_ch] = ω
c = xyoc - xyw c = xyoc - xyw
oc = xyoc - xy oc = xyoc - xy
assert c > 0, "state has no noise channels" if c <= 0:
assert (state[w_ch] > 0).all(), "all weights must be > 0" raise ValueError(
f"warp_state: state has no noise channels (expected 3+C with C>0, got {xyoc} channels)"
)
if not (state[w_ch] > 0).all():
raise ValueError("warp_state: all weights in state[2] must be > 0")
grid = xy_meshgrid_like_image(state) grid = xy_meshgrid_like_image(state)
@ -267,7 +297,10 @@ class NoiseWarper:
""" """
def __init__(self, c, h, w, device, dtype=torch.float32): def __init__(self, c, h, w, device, dtype=torch.float32):
assert c > 0 and h > 0 and w > 0 if c <= 0 or h <= 0 or w <= 0:
raise ValueError(
f"NoiseWarper: c/h/w must all be positive, got c={c} h={h} w={w}"
)
self.c = c self.c = c
self.h = h self.h = h
self.w = w self.w = w
@ -287,7 +320,10 @@ class NoiseWarper:
return n * weights / (weights ** 2).sqrt() return n * weights / (weights ** 2).sqrt()
def __call__(self, dx, dy): def __call__(self, dx, dy):
assert dx.shape == dy.shape if dx.shape != dy.shape:
raise ValueError(
f"NoiseWarper: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}"
)
flow = torch.stack([dx, dy]).to(self.device, self.dtype) flow = torch.stack([dx, dy]).to(self.device, self.dtype)
_, oflowh, ofloww = flow.shape _, oflowh, ofloww = flow.shape
@ -312,8 +348,6 @@ class RaftOpticalFlow:
"""Torchvision RAFT-large wrapper. ``__call__`` returns a (2, H, W) flow.""" """Torchvision RAFT-large wrapper. ``__call__`` returns a (2, H, W) flow."""
def __init__(self, device=None): def __init__(self, device=None):
from torchvision.models.optical_flow import raft_large
if device is None: if device is None:
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
device = torch.device(device) if not isinstance(device, torch.device) else device device = torch.device(device) if not isinstance(device, torch.device) else device
@ -334,7 +368,11 @@ class RaftOpticalFlow:
def __call__(self, from_image, to_image): def __call__(self, from_image, to_image):
"""``from_image``, ``to_image``: CHW float tensors in [0, 1].""" """``from_image``, ``to_image``: CHW float tensors in [0, 1]."""
assert from_image.shape == to_image.shape if from_image.shape != to_image.shape:
raise ValueError(
f"RaftOpticalFlow: from_image and to_image must match, "
f"got {tuple(from_image.shape)} vs {tuple(to_image.shape)}"
)
_, h, w = from_image.shape _, h, w = from_image.shape
with torch.no_grad(): with torch.no_grad():
img1 = self._preprocess(from_image) img1 = self._preprocess(from_image)
@ -385,9 +423,20 @@ def get_noise_from_video(
Returns: Returns:
``(T, H', W', noise_channels)`` float32 noise tensor on ``device``. ``(T, H', W', noise_channels)`` float32 noise tensor on ``device``.
""" """
assert isinstance(resize_flow, int) and resize_flow >= 1, resize_flow if not isinstance(resize_flow, int) or resize_flow < 1:
assert video_frames.ndim == 4 and video_frames.shape[-1] == 3, video_frames.shape raise ValueError(
assert video_frames.dtype == torch.uint8, video_frames.dtype f"get_noise_from_video: resize_flow must be a positive int, got {resize_flow!r}"
)
if video_frames.ndim != 4 or video_frames.shape[-1] != 3:
raise ValueError(
"get_noise_from_video: video_frames must have shape (T, H, W, 3), "
f"got {tuple(video_frames.shape)}"
)
if video_frames.dtype != torch.uint8:
raise TypeError(
"get_noise_from_video: video_frames must be uint8 in [0, 255], "
f"got dtype {video_frames.dtype}"
)
if device is None: if device is None:
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()