from contextlib import ExitStack from unittest.mock import MagicMock, patch import torch import torch.nn as nn from comfy.cli_args import args as cli_args if not torch.cuda.is_available(): cli_args.cpu = True import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 import comfy.sd as sd_mod # noqa: E402 from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402 # --------------------------------------------------------------------------- # From test_seedvr_vae_tiled_decode_latent_min_size_override.py # --------------------------------------------------------------------------- def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae class StubVAEModel(torch.nn.Module): def __init__(self): super().__init__() self.slicing_latent_min_size = 2 self.spatial_downsample_factor = 8 self.temporal_downsample_factor = 4 self.device = torch.device("cpu") self.use_slicing = True self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) self.decode_min_sizes = [] self.memory_states = [] def decode_(self, t_chunk): self.decode_min_sizes.append(self.slicing_latent_min_size) return VideoAutoencoderKL.slicing_decode(self, t_chunk) def _decode(self, z, memory_state=MemoryState.DISABLED): self.memory_states.append(memory_state) b, c, d, h, w = z.shape return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype) vae = StubVAEModel() z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32) tiled_vae( z, vae, tile_size=(64, 64), tile_overlap=(0, 0), temporal_size=0, temporal_overlap=0, encode=False, ) assert vae.decode_min_sizes == [5] assert vae.memory_states == [MemoryState.DISABLED] assert vae.slicing_latent_min_size == 2 # --------------------------------------------------------------------------- # From test_seedvr_vae_tiled_encode_runt_slice_override.py # --------------------------------------------------------------------------- def test_zero_temporal_size_preserves_min_size_when_encode_raises(): from comfy.ldm.seedvr.vae import tiled_vae class RaisingVAEModel(torch.nn.Module): def __init__(self): super().__init__() self.slicing_sample_min_size = 4 self.spatial_downsample_factor = 8 self.temporal_downsample_factor = 4 self.device = torch.device("cpu") self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) def encode(self, t_chunk): raise RuntimeError("simulated encode failure") vae = RaisingVAEModel() x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32) raised = False try: tiled_vae( x, vae, tile_size=(64, 64), tile_overlap=(0, 0), temporal_size=0, temporal_overlap=0, encode=True, ) except RuntimeError as exc: if "simulated encode failure" not in str(exc): raise raised = True assert raised assert vae.slicing_sample_min_size == 4 # --------------------------------------------------------------------------- # From test_seedvr_vae_tiled_temporal_slicing.py # --------------------------------------------------------------------------- class _SlicingDecodeVAE(nn.Module): def __init__(self, slicing_latent_min_size): super().__init__() self.slicing_latent_min_size = slicing_latent_min_size self.spatial_downsample_factor = 8 self.temporal_downsample_factor = 4 self.device = torch.device("cpu") self.use_slicing = True self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32)) self.decode_min_sizes = [] self.memory_states = [] def decode_(self, z): self.decode_min_sizes.append(self.slicing_latent_min_size) return vae_mod.VideoAutoencoderKL.slicing_decode(self, z) def _decode(self, z, memory_state=MemoryState.DISABLED): self.memory_states.append(memory_state) x = z[:, :1].repeat( 1, 3, 1, self.spatial_downsample_factor, self.spatial_downsample_factor, ) return x def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size(): vae = _SlicingDecodeVAE(slicing_latent_min_size=2) z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8) tiled_vae( z, vae, tile_size=(64, 64), tile_overlap=(0, 0), temporal_size=12, temporal_overlap=4, encode=False, ) assert vae.decode_min_sizes == [2] assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE] assert vae.slicing_latent_min_size == 2 wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__( vae_mod.VideoAutoencoderKLWrapper ) nn.Module.__init__(wrapper) seedvr2_tiling = { "enable_tiling": True, "tile_size": (64, 64), "tile_overlap": (0, 0), "temporal_size": 8, "temporal_overlap": 7, } captured = {} def _fake_tiled_vae(latent, model, **kwargs): captured.update(kwargs) return torch.zeros(1, 3, 1, 16, 16) with ( patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae), patch.object(vae_mod, "lab_color_transfer", side_effect=lambda content, style: content), ): wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling) assert captured["temporal_overlap"] == 7 # --------------------------------------------------------------------------- # From test_vae_decode_tiled_dispatcher_seedvr2_4d.py # --------------------------------------------------------------------------- def _force_oom(*a, **k): raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") def _make_vae(first_stage_model, latent_channels, latent_dim): vae = sd_mod.VAE.__new__(sd_mod.VAE) vae.first_stage_model = first_stage_model vae.patcher = MagicMock() vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) vae.device = vae.output_device = torch.device("cpu") vae.vae_dtype = torch.float32 vae.disable_offload = True vae.extra_1d_channel = None vae.upscale_ratio = vae.downscale_ratio = 8 vae.upscale_index_formula = vae.downscale_index_formula = None vae.output_channels = 3 vae.latent_channels = latent_channels vae.latent_dim = latent_dim vae.vae_output_dtype = lambda: torch.float32 vae.spacial_compression_decode = lambda: 8 vae.process_input = lambda x: x vae.process_output = lambda x: x vae.throw_exception_if_invalid = lambda: None vae.memory_used_decode = lambda *a, **k: 1 return vae def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode): mm = sd_mod.model_management with ExitStack() as stack: stack.enter_context(patch.object(mm, "raise_non_oom", lambda e: None)) stack.enter_context(patch.object(mm, "load_models_gpu", lambda *a, **k: None)) stack.enter_context(patch.object(mm, "soft_empty_cache", lambda: None)) stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_seedvr2", seedvr2_call)) stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_", generic_call)) if patch_wrapper_decode: stack.enter_context(patch.object( seedvr_vae_mod.VideoAutoencoderKLWrapper, "decode", side_effect=_force_oom)) vae.decode(samples) def test_4d_seedvr2_latent_routes_to_decode_tiled_seedvr2(): wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( seedvr_vae_mod.VideoAutoencoderKLWrapper) vae = _make_vae(wrapper, latent_channels=16, latent_dim=3) seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) _dispatch(vae, torch.zeros(1, 16 * 3, 8, 8), seedvr2_call, generic_call, True) assert seedvr2_call.call_count == 1 assert generic_call.call_count == 0 def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled(): first_stage = MagicMock() first_stage.decode = MagicMock(side_effect=_force_oom) vae = _make_vae(first_stage, latent_channels=4, latent_dim=2) seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) _dispatch(vae, torch.zeros(1, 4, 8, 8), seedvr2_call, generic_call, False) assert generic_call.call_count == 1 assert seedvr2_call.call_count == 0 # --------------------------------------------------------------------------- # From test_vae_encode_tiled_fallback_dispatcher_seedvr2.py # --------------------------------------------------------------------------- def _populate_common_vae_attrs_fallback(vae): vae.patcher = MagicMock() vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) vae.device = torch.device("cpu") vae.output_device = torch.device("cpu") vae.vae_dtype = torch.float32 vae.disable_offload = True vae.extra_1d_channel = None vae.upscale_ratio = 8 vae.upscale_index_formula = None vae.output_channels = 3 vae.latent_channels = 16 vae.latent_dim = 3 vae.downscale_ratio = 8 vae.downscale_index_formula = None vae.not_video = False vae.crop_input = False vae.pad_channel_value = None vae.vae_output_dtype = lambda: torch.float32 vae.spacial_compression_encode = lambda: 8 vae.process_input = lambda x: x vae.process_output = lambda x: x vae.throw_exception_if_invalid = lambda: None vae.memory_used_encode = lambda *a, **k: 1 def _make_seedvr2_vae_fallback(): vae = sd_mod.VAE.__new__(sd_mod.VAE) wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( seedvr_vae_mod.VideoAutoencoderKLWrapper ) vae.first_stage_model = wrapper _populate_common_vae_attrs_fallback(vae) return vae def _make_non_seedvr2_vae_fallback(): vae = sd_mod.VAE.__new__(sd_mod.VAE) vae.first_stage_model = MagicMock() _populate_common_vae_attrs_fallback(vae) return vae def _force_regular_encode_oom(*args, **kwargs): raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") def test_seedvr2_3d_routes_to_encode_tiled_seedvr2_on_oom(): vae = _make_seedvr2_vae_fallback() pixel_samples = torch.zeros((1, 8, 64, 64, 3)) seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) with patch.object(sd_mod.model_management, "raise_non_oom", lambda e: None), \ patch.object(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None), \ patch.object(sd_mod.model_management, "soft_empty_cache", lambda: None), \ patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode", side_effect=_force_regular_encode_oom), \ patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call, create=True), \ patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call): vae.encode(pixel_samples) assert seedvr2_call.call_count == 1, ( f"Expected encode_tiled_seedvr2 to be called once for a SeedVR2 3D " f"input under OOM fallback; got {seedvr2_call.call_count} calls." ) assert generic_call.call_count == 0, ( f"encode_tiled_3d must NOT be called for a SeedVR2 input; got " f"{generic_call.call_count} calls." ) def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete(): vae = _make_non_seedvr2_vae_fallback() vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8) vae.upscale_ratio = (lambda a: a * 4, 8, 8) generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) pixel_samples = torch.zeros((1, 8, 64, 64, 3)) with patch.object(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None), \ patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call): vae.encode_tiled(pixel_samples) assert generic_call.call_args.kwargs["overlap"] == (1, 64, 64)