mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-30 20:32:45 +08:00
Polish imports and modify asserts to raise proper errors with messages.
This commit is contained in:
parent
60eed34bca
commit
2f3cf495c1
@ -19,6 +19,7 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
from torchvision.models.optical_flow import raft_large
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
|
||||||
@ -35,7 +36,10 @@ def _torch_resize_chw(image, size, interp, copy=True):
|
|||||||
the requested size matches the input, returns the input tensor as is
|
the requested size matches the input, returns the input tensor as is
|
||||||
(faster but callers must not mutate the result).
|
(faster but callers must not mutate the result).
|
||||||
"""
|
"""
|
||||||
assert image.ndim == 3, image.shape
|
if image.ndim != 3:
|
||||||
|
raise ValueError(
|
||||||
|
f"_torch_resize_chw expects a 3D CHW tensor, got shape {tuple(image.shape)}"
|
||||||
|
)
|
||||||
_, in_h, in_w = image.shape
|
_, in_h, in_w = image.shape
|
||||||
if isinstance(size, (int, float)) and not isinstance(size, bool):
|
if isinstance(size, (int, float)) and not isinstance(size, bool):
|
||||||
new_h = max(1, int(in_h * size))
|
new_h = max(1, int(in_h * size))
|
||||||
@ -59,8 +63,14 @@ def _torch_remap_relative(image, dx, dy, interp="bilinear"):
|
|||||||
Equivalent to ``rp.torch_remap_image(image, dx, dy, relative=True, interp=interp)``
|
Equivalent to ``rp.torch_remap_image(image, dx, dy, relative=True, interp=interp)``
|
||||||
for ``interp`` in {"bilinear", "nearest"}. Out-of-bounds samples are 0.
|
for ``interp`` in {"bilinear", "nearest"}. Out-of-bounds samples are 0.
|
||||||
"""
|
"""
|
||||||
assert image.ndim == 3
|
if image.ndim != 3:
|
||||||
assert dx.shape == dy.shape
|
raise ValueError(
|
||||||
|
f"_torch_remap_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}"
|
||||||
|
)
|
||||||
|
if dx.shape != dy.shape:
|
||||||
|
raise ValueError(
|
||||||
|
f"_torch_remap_relative: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}"
|
||||||
|
)
|
||||||
_, h, w = image.shape
|
_, h, w = image.shape
|
||||||
|
|
||||||
x_abs = dx + torch.arange(w, device=dx.device, dtype=dx.dtype)
|
x_abs = dx + torch.arange(w, device=dx.device, dtype=dx.dtype)
|
||||||
@ -82,9 +92,16 @@ def _torch_scatter_add_relative(image, dx, dy):
|
|||||||
Equivalent to ``rp.torch_scatter_add_image(image, dx, dy, relative=True,
|
Equivalent to ``rp.torch_scatter_add_image(image, dx, dy, relative=True,
|
||||||
interp='floor')``. Out-of-bounds targets are dropped.
|
interp='floor')``. Out-of-bounds targets are dropped.
|
||||||
"""
|
"""
|
||||||
assert image.ndim == 3
|
if image.ndim != 3:
|
||||||
|
raise ValueError(
|
||||||
|
f"_torch_scatter_add_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}"
|
||||||
|
)
|
||||||
in_c, in_h, in_w = image.shape
|
in_c, in_h, in_w = image.shape
|
||||||
assert dx.shape == dy.shape == (in_h, in_w)
|
if dx.shape != (in_h, in_w) or dy.shape != (in_h, in_w):
|
||||||
|
raise ValueError(
|
||||||
|
f"_torch_scatter_add_relative: dx/dy must be ({in_h}, {in_w}), "
|
||||||
|
f"got dx={tuple(dx.shape)} dy={tuple(dy.shape)}"
|
||||||
|
)
|
||||||
|
|
||||||
x = dx.long() + torch.arange(in_w, device=dx.device, dtype=torch.long)
|
x = dx.long() + torch.arange(in_w, device=dx.device, dtype=torch.long)
|
||||||
y = dy.long() + torch.arange(in_h, device=dy.device, dtype=torch.long)[:, None]
|
y = dy.long() + torch.arange(in_h, device=dy.device, dtype=torch.long)[:, None]
|
||||||
@ -185,11 +202,20 @@ def warp_state(state, flow):
|
|||||||
``state`` has shape ``(3+c, h, w)`` (= dx, dy, ω, c noise channels).
|
``state`` has shape ``(3+c, h, w)`` (= dx, dy, ω, c noise channels).
|
||||||
``flow`` has shape ``(2, h, w)`` (= dx, dy).
|
``flow`` has shape ``(2, h, w)`` (= dx, dy).
|
||||||
"""
|
"""
|
||||||
assert flow.device == state.device
|
if flow.device != state.device:
|
||||||
assert flow.ndim == 3 and flow.shape[0] == 2
|
raise ValueError(
|
||||||
assert state.ndim == 3
|
f"warp_state: flow and state must be on the same device, "
|
||||||
|
f"got flow={flow.device} state={state.device}"
|
||||||
|
)
|
||||||
|
if state.ndim != 3:
|
||||||
|
raise ValueError(
|
||||||
|
f"warp_state: state must be 3D (3+C, H, W), got shape {tuple(state.shape)}"
|
||||||
|
)
|
||||||
xyoc, h, w = state.shape
|
xyoc, h, w = state.shape
|
||||||
assert flow.shape == (2, h, w)
|
if flow.shape != (2, h, w):
|
||||||
|
raise ValueError(
|
||||||
|
f"warp_state: flow must have shape (2, {h}, {w}), got {tuple(flow.shape)}"
|
||||||
|
)
|
||||||
device = state.device
|
device = state.device
|
||||||
|
|
||||||
x_ch, y_ch = 0, 1
|
x_ch, y_ch = 0, 1
|
||||||
@ -198,8 +224,12 @@ def warp_state(state, flow):
|
|||||||
w_ch = 2 # state[w_ch] = ω
|
w_ch = 2 # state[w_ch] = ω
|
||||||
c = xyoc - xyw
|
c = xyoc - xyw
|
||||||
oc = xyoc - xy
|
oc = xyoc - xy
|
||||||
assert c > 0, "state has no noise channels"
|
if c <= 0:
|
||||||
assert (state[w_ch] > 0).all(), "all weights must be > 0"
|
raise ValueError(
|
||||||
|
f"warp_state: state has no noise channels (expected 3+C with C>0, got {xyoc} channels)"
|
||||||
|
)
|
||||||
|
if not (state[w_ch] > 0).all():
|
||||||
|
raise ValueError("warp_state: all weights in state[2] must be > 0")
|
||||||
|
|
||||||
grid = xy_meshgrid_like_image(state)
|
grid = xy_meshgrid_like_image(state)
|
||||||
|
|
||||||
@ -267,7 +297,10 @@ class NoiseWarper:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, c, h, w, device, dtype=torch.float32):
|
def __init__(self, c, h, w, device, dtype=torch.float32):
|
||||||
assert c > 0 and h > 0 and w > 0
|
if c <= 0 or h <= 0 or w <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"NoiseWarper: c/h/w must all be positive, got c={c} h={h} w={w}"
|
||||||
|
)
|
||||||
self.c = c
|
self.c = c
|
||||||
self.h = h
|
self.h = h
|
||||||
self.w = w
|
self.w = w
|
||||||
@ -287,7 +320,10 @@ class NoiseWarper:
|
|||||||
return n * weights / (weights ** 2).sqrt()
|
return n * weights / (weights ** 2).sqrt()
|
||||||
|
|
||||||
def __call__(self, dx, dy):
|
def __call__(self, dx, dy):
|
||||||
assert dx.shape == dy.shape
|
if dx.shape != dy.shape:
|
||||||
|
raise ValueError(
|
||||||
|
f"NoiseWarper: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}"
|
||||||
|
)
|
||||||
flow = torch.stack([dx, dy]).to(self.device, self.dtype)
|
flow = torch.stack([dx, dy]).to(self.device, self.dtype)
|
||||||
_, oflowh, ofloww = flow.shape
|
_, oflowh, ofloww = flow.shape
|
||||||
|
|
||||||
@ -312,8 +348,6 @@ class RaftOpticalFlow:
|
|||||||
"""Torchvision RAFT-large wrapper. ``__call__`` returns a (2, H, W) flow."""
|
"""Torchvision RAFT-large wrapper. ``__call__`` returns a (2, H, W) flow."""
|
||||||
|
|
||||||
def __init__(self, device=None):
|
def __init__(self, device=None):
|
||||||
from torchvision.models.optical_flow import raft_large
|
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
device = torch.device(device) if not isinstance(device, torch.device) else device
|
device = torch.device(device) if not isinstance(device, torch.device) else device
|
||||||
@ -334,7 +368,11 @@ class RaftOpticalFlow:
|
|||||||
|
|
||||||
def __call__(self, from_image, to_image):
|
def __call__(self, from_image, to_image):
|
||||||
"""``from_image``, ``to_image``: CHW float tensors in [0, 1]."""
|
"""``from_image``, ``to_image``: CHW float tensors in [0, 1]."""
|
||||||
assert from_image.shape == to_image.shape
|
if from_image.shape != to_image.shape:
|
||||||
|
raise ValueError(
|
||||||
|
f"RaftOpticalFlow: from_image and to_image must match, "
|
||||||
|
f"got {tuple(from_image.shape)} vs {tuple(to_image.shape)}"
|
||||||
|
)
|
||||||
_, h, w = from_image.shape
|
_, h, w = from_image.shape
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
img1 = self._preprocess(from_image)
|
img1 = self._preprocess(from_image)
|
||||||
@ -385,9 +423,20 @@ def get_noise_from_video(
|
|||||||
Returns:
|
Returns:
|
||||||
``(T, H', W', noise_channels)`` float32 noise tensor on ``device``.
|
``(T, H', W', noise_channels)`` float32 noise tensor on ``device``.
|
||||||
"""
|
"""
|
||||||
assert isinstance(resize_flow, int) and resize_flow >= 1, resize_flow
|
if not isinstance(resize_flow, int) or resize_flow < 1:
|
||||||
assert video_frames.ndim == 4 and video_frames.shape[-1] == 3, video_frames.shape
|
raise ValueError(
|
||||||
assert video_frames.dtype == torch.uint8, video_frames.dtype
|
f"get_noise_from_video: resize_flow must be a positive int, got {resize_flow!r}"
|
||||||
|
)
|
||||||
|
if video_frames.ndim != 4 or video_frames.shape[-1] != 3:
|
||||||
|
raise ValueError(
|
||||||
|
"get_noise_from_video: video_frames must have shape (T, H, W, 3), "
|
||||||
|
f"got {tuple(video_frames.shape)}"
|
||||||
|
)
|
||||||
|
if video_frames.dtype != torch.uint8:
|
||||||
|
raise TypeError(
|
||||||
|
"get_noise_from_video: video_frames must be uint8 in [0, 255], "
|
||||||
|
f"got dtype {video_frames.dtype}"
|
||||||
|
)
|
||||||
|
|
||||||
if device is None:
|
if device is None:
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user