"""Tests for the SeedVR2 post-processing node in ``comfy_extras.nodes_seedvr``."""

from unittest.mock import patch

import torch

from comfy.cli_args import args as cli_args

# Switch comfy to CPU mode when CUDA is unavailable. This must happen before
# nodes_seedvr is imported (hence the E402 suppression below) -- presumably
# because the cpu flag is consulted at import time; confirm if reordering.
if not torch.cuda.is_available():
    cli_args.cpu = True

from comfy_extras import nodes_seedvr  # noqa: E402


def _schema_ids(items):
    """Return the ``id`` of every schema item, preserving order."""
    return [item.id for item in items]


def test_seedvr2_post_processing_schema():
    """The node schema exposes the expected inputs, defaults and output type."""
    schema = nodes_seedvr.SeedVR2PostProcessing.define_schema()

    assert _schema_ids(schema.inputs) == [
        "decoded",
        "original_image",
        "upscaled_shorter_edge",
        "color_correction_method",
    ]

    # Order is pinned above; look inputs up by id so the remaining assertions
    # read by name instead of by positional index.
    inputs = {item.id: item for item in schema.inputs}

    shorter_edge = inputs["upscaled_shorter_edge"]
    assert shorter_edge.default is None
    assert shorter_edge.min == 2
    assert shorter_edge.force_input is True

    method = inputs["color_correction_method"]
    assert method.options == ["lab", "wavelet", "adain", "none"]
    assert method.default == "lab"

    assert schema.outputs[0].get_io_type() == "IMAGE"


def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch):
    """An OOM during one-frame LAB colour transfer must surface as a RuntimeError
    whose message names the ``color_correction_method`` input, not a bare
    ``method`` parameter.
    """
    decoded = torch.full((1, 3, 4, 4), 0.25)
    reference = torch.full((1, 3, 4, 4), 0.75)

    def _lab(content, style):
        # Stand-in for lab_color_transfer that always runs out of memory.
        raise torch.cuda.OutOfMemoryError("CUDA out of memory")

    mm = nodes_seedvr.comfy.model_management
    monkeypatch.setattr(mm, "vae_device", lambda: torch.device("cpu"))
    monkeypatch.setattr(mm, "get_free_memory", lambda device: 1_000_000)
    monkeypatch.setattr(mm, "soft_empty_cache", lambda: None)

    with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
        try:
            nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
                decoded,
                reference,
                torch.device("cpu"),
                "lab",
            )
        except RuntimeError as exc:
            # torch.cuda.OutOfMemoryError subclasses RuntimeError, so a raw,
            # unwrapped OOM would also land here; the message checks are what
            # prove the node produced its own user-facing error.
            message = str(exc)
            assert "color_correction_method=lab" in message
            assert " method=lab" not in message
        else:
            raise AssertionError("expected RuntimeError for one-frame LAB OOM")


def test_seedvr2_post_processing_unknown_color_correction_method_raises():
    """``execute`` rejects an unknown colour-correction method with a ValueError
    whose message names the ``color_correction_method`` input.
    """
    decoded = torch.zeros(1, 2, 4, 4, 3)
    original = torch.zeros(1, 2, 4, 4, 3)

    try:
        nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 4, "bogus")
    except ValueError as exc:
        assert "color_correction_method" in str(exc)
    else:
        raise AssertionError("expected ValueError for unknown color_correction_method")