"""Unit tests for ``comfy_extras.nodes_seedvr.SeedVR2ProgressiveSampler``."""

import re
from unittest.mock import patch

import pytest
import torch

from comfy.cli_args import args as cli_args

# Force CPU mode on machines without CUDA. This is done before the comfy
# modules below are imported (hence the E402 suppressions); presumably they
# read ``cli_args.cpu`` at import time — NOTE(review): confirm.
if not torch.cuda.is_available():
    cli_args.cpu = True

import comfy.sample  # noqa: E402
import comfy_extras.nodes_seedvr as nodes_seedvr_mod  # noqa: E402
from comfy_extras.nodes_seedvr import SeedVR2ProgressiveSampler  # noqa: E402

# Channel counts for the synthetic inputs: latent channels and conditioning
# channels (one more than the latent). NOTE(review): assumed to mirror the
# real SeedVR2 layout — confirm against nodes_seedvr.
_LAT_C = 16
_COND_C = 17


def _make_inputs(B: int = 1, T: int = 5, H: int = 8, W: int = 8):
    """Build minimal SeedVR2-shaped sampling inputs.

    Latent and conditioning tensors are filled with ``arange`` values so a
    test can tell exactly which elements ended up where.

    Args:
        B: Batch size.
        T: Temporal length of the latent.
        H: Latent height.
        W: Latent width.

    Returns:
        ``(latent_image, positive, negative, samples_5d, cond_5d)`` where
        ``latent_image`` is ``{"samples": samples}`` with ``samples`` flattened
        to ``(B, _LAT_C * T, H, W)``; ``positive`` / ``negative`` are
        single-entry conditioning lists whose ``"condition"`` is an independent
        clone of the flattened ``(B, _COND_C * T, H, W)`` conditioning tensor;
        and ``samples_5d`` / ``cond_5d`` are the un-flattened
        ``(B, C, T, H, W)`` reference tensors.
    """
    samples_5d = torch.arange(
        B * _LAT_C * T * H * W, dtype=torch.float32
    ).reshape(B, _LAT_C, T, H, W)
    samples = samples_5d.reshape(B, _LAT_C * T, H, W).contiguous()

    # Offset the conditioning values so they cannot be mistaken for latent
    # values. The two ranges stay disjoint only while the latent has at most
    # 10000 elements (true for the defaults: 16 * 5 * 8 * 8 = 5120).
    cond_5d = torch.arange(
        B * _COND_C * T * H * W, dtype=torch.float32
    ).reshape(B, _COND_C, T, H, W) + 10000.0
    cond = cond_5d.reshape(B, _COND_C * T, H, W).contiguous()

    text_pos = torch.zeros(1, 4, 32)
    text_neg = torch.zeros(1, 4, 32)
    # Separate clones so a mutation through one conditioning list cannot leak
    # into the other.
    positive = [[text_pos, {"condition": cond.clone()}]]
    negative = [[text_neg, {"condition": cond.clone()}]]
    latent_image = {"samples": samples}
    return latent_image, positive, negative, samples_5d, cond_5d


def _identity_fix_empty(model, latent_image, downscale_ratio_spacial=None):
    """Stand-in for ``comfy.sample.fix_empty_latent_channels``; returns the latent unchanged.

    The parameter list (including the ``downscale_ratio_spacial`` spelling)
    presumably mirrors the patched function's signature — keep them in sync.
    """
    return latent_image


def _fingerprinted_prepare_noise(latent_image, seed, batch_inds=None):
    """Stand-in for ``comfy.sample.prepare_noise`` whose values encode ``(seed, position)``.

    Element ``i`` (in flat order) equals ``i + seed * 1e6``, so a test can
    recover which seed produced a noise slice and where it came from. The
    result is float32, so the encoding is exact only while values stay below
    2**24 (small seeds, small latents). ``batch_inds`` is accepted for
    signature compatibility and ignored.
    """
    base = torch.arange(
        latent_image.numel(), dtype=torch.float32
    ).reshape(latent_image.shape)
    return base + float(seed) * 1e6


def test_progressive_sampler_schema_exposes_manual_default_auto_chunking():
    """The node schema offers ``manual``/``auto`` chunking with ``manual`` as default."""
    schema = SeedVR2ProgressiveSampler.define_schema()
    inputs = {item.id: item for item in schema.inputs}
    assert inputs["chunking_mode"].options == ["manual", "auto"]
    assert inputs["chunking_mode"].default == "manual"


def test_vram_seed_frames_per_chunk_predicts_4n1_clamped_to_t_pixel():
    """VRAM chunk-size law: seed = nearest 4n+1 to 4*(free_GB - 3), clamped to [1, t_pixel]."""
    gib = 1024 ** 3
    predict = nodes_seedvr_mod._seedvr2_vram_seed_frames_per_chunk
    assert predict(20 * gib, 65) == 65  # 4*(20-3)=68 -> nearest 4n+1 is 69 -> clamp to t_pixel 65
    assert predict(6 * gib, 97) == 13  # 4*(6-3)=12 -> nearest 4n+1 is 13
    assert predict(2 * gib, 97) == 1  # below the 3 GiB margin -> floor at 1


@pytest.mark.parametrize("bad_chunk", [0, -1, 2])
def test_t3_invalid_frames_per_chunk_raises_value_error(bad_chunk):
    """``frames_per_chunk`` violating 4n+1 (or <1) must raise ``ValueError`` before any model invocation."""
    latent, pos, neg, _, _ = _make_inputs(T=5)

    # ``return_value`` keeps the mock returning a tensor should it ever be
    # called; the MagicMock records calls so no manual counter is needed.
    with patch.object(comfy.sample, "sample", return_value=torch.zeros(1)) as mock_sample, \
            patch.object(comfy.sample, "fix_empty_latent_channels", side_effect=_identity_fix_empty), \
            patch.object(comfy.sample, "prepare_noise", side_effect=_fingerprinted_prepare_noise):
        with pytest.raises(ValueError) as excinfo:
            SeedVR2ProgressiveSampler.execute(
                model=None,
                seed=0,
                steps=2,
                cfg=1.0,
                sampler_name="euler",
                scheduler="simple",
                positive=pos,
                negative=neg,
                latent=latent,
                denoise=1.0,
                frames_per_chunk=bad_chunk,
                temporal_overlap=0,
            )

    # The message must name the offending value as a standalone number. A
    # plain substring test would also accept e.g. "10" for 0 or "20" for 2.
    message = str(excinfo.value)
    bad_value = re.escape(str(bad_chunk))
    assert re.search(rf"(?<![\d-]){bad_value}(?!\d)", message), message
    mock_sample.assert_not_called()