mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 16:29:25 +08:00
96 lines
3.5 KiB
Python
96 lines
3.5 KiB
Python
"""Unit tests for ``comfy_extras.nodes_seedvr.SeedVR2ProgressiveSampler``."""
|
|
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from comfy.cli_args import args as cli_args
|
|
|
|
if not torch.cuda.is_available():
|
|
cli_args.cpu = True
|
|
|
|
import comfy.sample # noqa: E402
|
|
import comfy_extras.nodes_seedvr as nodes_seedvr_mod # noqa: E402
|
|
from comfy_extras.nodes_seedvr import SeedVR2ProgressiveSampler # noqa: E402
|
|
|
|
_LAT_C = 16
|
|
_COND_C = 17
|
|
|
|
|
|
def _make_inputs(B: int = 1, T: int = 5, H: int = 8, W: int = 8):
|
|
"""Build minimal SeedVR2-shaped sampling inputs."""
|
|
samples_5d = torch.arange(
|
|
B * _LAT_C * T * H * W, dtype=torch.float32
|
|
).reshape(B, _LAT_C, T, H, W)
|
|
samples = samples_5d.reshape(B, _LAT_C * T, H, W).contiguous()
|
|
|
|
cond_5d = torch.arange(
|
|
B * _COND_C * T * H * W, dtype=torch.float32
|
|
).reshape(B, _COND_C, T, H, W) + 10000.0
|
|
cond = cond_5d.reshape(B, _COND_C * T, H, W).contiguous()
|
|
|
|
text_pos = torch.zeros(1, 4, 32)
|
|
text_neg = torch.zeros(1, 4, 32)
|
|
positive = [[text_pos, {"condition": cond.clone()}]]
|
|
negative = [[text_neg, {"condition": cond.clone()}]]
|
|
latent_image = {"samples": samples}
|
|
return latent_image, positive, negative, samples_5d, cond_5d
|
|
|
|
|
|
def _identity_fix_empty(model, latent_image, downscale_ratio_spacial=None):
|
|
return latent_image
|
|
|
|
|
|
def _fingerprinted_prepare_noise(latent_image, seed, batch_inds=None):
|
|
"""Return a tensor whose values encode ``(seed, position)``."""
|
|
base = torch.arange(
|
|
latent_image.numel(), dtype=torch.float32
|
|
).reshape(latent_image.shape)
|
|
return base + float(seed) * 1e6
|
|
|
|
|
|
def test_progressive_sampler_schema_exposes_manual_default_auto_chunking():
|
|
schema = SeedVR2ProgressiveSampler.define_schema()
|
|
inputs = {item.id: item for item in schema.inputs}
|
|
|
|
assert inputs["chunking_mode"].options == ["manual", "auto"]
|
|
assert inputs["chunking_mode"].default == "manual"
|
|
|
|
|
|
def test_vram_seed_frames_per_chunk_predicts_4n1_clamped_to_t_pixel():
|
|
"""VRAM chunk-size law: seed = nearest 4n+1 to 4*(free_GB - 3), clamped to [1, t_pixel]."""
|
|
gib = 1024 ** 3
|
|
seed = nodes_seedvr_mod._seedvr2_vram_seed_frames_per_chunk
|
|
assert seed(20 * gib, 65) == 65 # 4*(20-3)=68 -> 4n+1 69 -> clamp to t_pixel 65
|
|
assert seed(6 * gib, 97) == 13 # 4*(6-3)=12 -> nearest 4n+1 13
|
|
assert seed(2 * gib, 97) == 1 # below margin -> floor at 1
|
|
|
|
|
|
@pytest.mark.parametrize("bad_chunk", [0, -1, 2])
|
|
def test_t3_invalid_frames_per_chunk_raises_value_error(bad_chunk):
|
|
"""``frames_per_chunk`` violating 4n+1 (or <1) must raise ``ValueError`` before any model invocation."""
|
|
latent, pos, neg, _, _ = _make_inputs(T=5)
|
|
|
|
sampler_called = {"n": 0}
|
|
|
|
def _should_not_be_called(*args, **kwargs):
|
|
sampler_called["n"] += 1
|
|
return torch.zeros(1)
|
|
|
|
with patch.object(comfy.sample, "sample",
|
|
side_effect=_should_not_be_called), \
|
|
patch.object(comfy.sample, "fix_empty_latent_channels",
|
|
side_effect=_identity_fix_empty), \
|
|
patch.object(comfy.sample, "prepare_noise",
|
|
side_effect=_fingerprinted_prepare_noise):
|
|
with pytest.raises(ValueError) as excinfo:
|
|
SeedVR2ProgressiveSampler.execute(
|
|
model=None, seed=0, steps=2, cfg=1.0,
|
|
sampler_name="euler", scheduler="simple",
|
|
positive=pos, negative=neg, latent=latent,
|
|
denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0,
|
|
)
|
|
assert str(bad_chunk) in str(excinfo.value)
|
|
assert sampler_called["n"] == 0
|