mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 16:29:25 +08:00
Add SeedVR2 VAE coverage
This commit is contained in:
parent
0fdbc5d260
commit
bed0cd2b8c
86
tests-unit/comfy_test/seedvr_vae_forward_test.py
Normal file
86
tests-unit/comfy_test/seedvr_vae_forward_test.py
Normal file
@ -0,0 +1,86 @@
|
||||
"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must
|
||||
honor the actual tensor/tuple return contract of ``encode()`` and
|
||||
``decode_()`` and must NOT dereference diffusers-style ``.latent_dist``
|
||||
or ``.sample`` attributes on those returns.
|
||||
|
||||
The pre-fix body raised ``AttributeError: 'Tensor' object has no
|
||||
attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and
|
||||
``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'``
|
||||
for ``mode == "decode"`` (the class only defines ``decode_`` with a
|
||||
trailing underscore). The post-fix body unwraps the optional one-element
|
||||
tuple shape that ``return_dict=False`` produces and returns the tensor
|
||||
directly.
|
||||
|
||||
Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses
|
||||
the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and
|
||||
overrides ``encode``/``decode_`` with known tensors so the contract can
|
||||
be probed without loading any real VAE weights.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402
|
||||
|
||||
|
||||
# Shapes of the canned tensors handed back by the stubbed ``encode``/``decode_``.
_LATENT_SHAPE = (1, 16, 2, 2, 2)
_DECODED_SHAPE = (1, 3, 5, 16, 16)

# Shapes of the dummy tensors the tests feed into ``forward``; they mirror the
# canned output of the opposite direction.
_INPUT_ENCODE_SHAPE = tuple(_DECODED_SHAPE)
_INPUT_DECODE_SHAPE = tuple(_LATENT_SHAPE)
|
||||
|
||||
|
||||
class _StubVAE(VideoAutoencoderKL):
    """Stub whose ``encode``/``decode_`` hand back fixed, bare tensors.

    Only ``nn.Module.__init__`` is run, so the parent class constructor is
    never executed.
    """

    def __init__(self):
        nn.Module.__init__(self)
        # Canned results returned regardless of the input passed in.
        self._encode_out = torch.zeros(_LATENT_SHAPE)
        self._decode_out = torch.zeros(_DECODED_SHAPE)

    def encode(self, x, return_dict=True):
        """Ignore both arguments and return the canned latent tensor."""
        return self._encode_out

    def decode_(self, z, return_dict=True):
        """Ignore both arguments and return the canned decoded tensor."""
        return self._decode_out
|
||||
|
||||
|
||||
def test_forward_encode_returns_tensor():
    """``forward(mode="encode")`` must return a plain tensor of latent shape."""
    stub = _StubVAE()
    pixels = torch.zeros(_INPUT_ENCODE_SHAPE)

    encoded = stub.forward(pixels, mode="encode")

    # Exact type check on purpose: a tuple or wrapper object must not pass.
    assert type(encoded) is torch.Tensor
    assert tuple(encoded.shape) == tuple(_LATENT_SHAPE)
|
||||
|
||||
|
||||
def test_forward_decode_returns_tensor():
    """``forward(mode="decode")`` must return a plain tensor of decoded shape."""
    stub = _StubVAE()
    latent = torch.zeros(_INPUT_DECODE_SHAPE)

    decoded = stub.forward(latent, mode="decode")

    # Exact type check on purpose: a tuple or wrapper object must not pass.
    assert type(decoded) is torch.Tensor
    assert tuple(decoded.shape) == tuple(_DECODED_SHAPE)
|
||||
|
||||
|
||||
class _TupleReturningStubVAE(VideoAutoencoderKL):
    """Stub whose ``encode``/``decode_`` return a one-element ``(tensor,)`` tuple.

    Used to drive the tuple-unwrapping path of ``VideoAutoencoderKL.forward``.
    Only ``nn.Module.__init__`` is run, so the parent constructor is skipped.
    """

    def __init__(self):
        nn.Module.__init__(self)
        self._encode_tensor = torch.zeros(_LATENT_SHAPE)
        self._decode_tensor = torch.zeros(_DECODED_SHAPE)

    def encode(self, x, return_dict=True):
        """Ignore the arguments; return the canned latent wrapped in a tuple."""
        wrapped = (self._encode_tensor,)
        return wrapped

    def decode_(self, z, return_dict=True):
        """Ignore the arguments; return the canned decode wrapped in a tuple."""
        wrapped = (self._decode_tensor,)
        return wrapped
|
||||
|
||||
|
||||
def test_forward_all_unwraps_one_tuple_at_each_step():
    """``forward(mode="all")`` must yield a bare tensor even when both the
    encode and decode steps return one-element tuples."""
    stub = _TupleReturningStubVAE()
    pixels = torch.zeros(_INPUT_ENCODE_SHAPE)

    roundtrip = stub.forward(pixels, mode="all")

    assert type(roundtrip) is torch.Tensor
    assert tuple(roundtrip.shape) == tuple(_DECODED_SHAPE)
|
||||
91
tests-unit/comfy_test/test_seedvr2_vae_decode.py
Normal file
91
tests-unit/comfy_test/test_seedvr2_vae_decode.py
Normal file
@ -0,0 +1,91 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
||||
from comfy_extras import nodes_seedvr # noqa: E402
|
||||
|
||||
|
||||
def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper:
    """Create a wrapper instance without running the wrapper's own ``__init__``.

    Only ``nn.Module.__init__`` is applied to the freshly allocated object.
    """
    wrapper_cls = vae_mod.VideoAutoencoderKLWrapper
    instance = wrapper_cls.__new__(wrapper_cls)
    nn.Module.__init__(instance)
    return instance
|
||||
|
||||
|
||||
def _fingerprint_decode_(self, z, return_dict=True):
    """Stand-in for ``decode_`` that tags every batch item with its 1-based index.

    The output has 3 channels, keeps axis 2 of ``z`` and scales axes 3 and 4
    by 8; batch item ``i`` is filled with the constant ``i + 1``.
    """
    batch, frames, height, width = (int(z.shape[axis]) for axis in (0, 2, 3, 4))
    out = torch.empty(batch, 3, frames, height * 8, width * 8)
    # Iterating a tensor yields views along dim 0, so the fill is in place.
    for marker, sample in enumerate(out, start=1):
        sample.fill_(float(marker))
    return out
|
||||
|
||||
|
||||
def _decode_with_patches(wrapper, z):
    """Run ``wrapper.decode(z)`` while ``VideoAutoencoderKL.decode_`` is
    replaced by ``_fingerprint_decode_``."""
    decode_patch = patch.object(
        vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_
    )
    with decode_patch:
        decoded = wrapper.decode(z)
    return decoded
|
||||
|
||||
|
||||
def test_decode_b2_t3_multi_frame_batch_unchanged():
    """A ``(2, 16 * 3, 2, 2)`` latent must decode to shape ``(2, 3, 3, 16, 16)``."""
    collapsed_latent = torch.zeros(2, 16 * 3, 2, 2)

    decoded = _decode_with_patches(_make_wrapper(), collapsed_latent)

    assert tuple(decoded.shape) == (2, 3, 3, 16, 16)
|
||||
|
||||
|
||||
class _Wrapper(vae_mod.VideoAutoencoderKLWrapper):
    """Wrapper subclass that skips the parent constructor and keeps a ``calls`` log."""

    def __init__(self):
        nn.Module.__init__(self)
        # Filled by ``_decode_stub`` with the shape of every latent it receives.
        self.calls = []

    def parameters(self):
        """Return an iterator over a single, freshly created zero-dim parameter."""
        placeholder = torch.nn.Parameter(torch.zeros(()))
        return iter([placeholder])
|
||||
|
||||
def _decode_stub(self, latent):
    """Record the latent's shape on ``self.calls`` and return a zero tensor.

    The result has 3 channels, keeps axis 2 of ``latent`` and scales axes 3
    and 4 by 8.
    """
    shape = tuple(latent.shape)
    self.calls.append(shape)
    batch, frames = latent.shape[0], latent.shape[2]
    height, width = latent.shape[3] * 8, latent.shape[4] * 8
    return torch.zeros(batch, 3, frames, height, width)
|
||||
|
||||
|
||||
def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preprocessor_state():
    """A 5-D ``(1, 16, 2, 4, 5)`` latent must reach ``decode_`` with that exact
    shape and come back as ``(1, 3, 2, 32, 40)``."""
    wrapper = _Wrapper()
    latent_shape = (1, 16, 2, 4, 5)

    with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
        decoded = wrapper.decode(torch.zeros(*latent_shape))

    assert tuple(decoded.shape) == (1, 3, 2, 32, 40)
    assert wrapper.calls == [latent_shape]
|
||||
|
||||
|
||||
def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents():
    """A 3-D latent must be rejected with the documented ``RuntimeError``."""
    wrapper = _Wrapper()
    wrong_rank_latent = torch.zeros(1, 16, 4)
    expected_message = r"latent input must be 4-D collapsed .* or 5-D"

    with pytest.raises(RuntimeError, match=expected_message):
        wrapper.decode(wrong_rank_latent)
|
||||
|
||||
|
||||
def _t_padded(t_in: int) -> int:
    """Return the frame count the test expects after temporal padding.

    ``1`` stays ``1``; any other value up to ``4`` becomes ``5``; larger
    values are rounded up to the next number of the form ``4 * k + 1``.
    """
    if t_in == 1:
        return 1
    if t_in <= 4:
        return 5
    overshoot = (t_in - 1) % 4
    return t_in if overshoot == 0 else t_in + (4 - overshoot)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("t_in", [1, 5, 9])
def test_t_padded_matches_cut_videos(t_in):
    """``nodes_seedvr.cut_videos`` must produce the axis-1 length that
    ``_t_padded`` predicts for the same input length."""
    clip = torch.zeros(1, t_in, 1, 1, 1)
    padded_length = nodes_seedvr.cut_videos(clip).shape[1]
    assert padded_length == _t_padded(t_in)
|
||||
348
tests-unit/comfy_test/test_seedvr2_vae_tiled.py
Normal file
348
tests-unit/comfy_test/test_seedvr2_vae_tiled.py
Normal file
@ -0,0 +1,348 @@
|
||||
from contextlib import ExitStack
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
||||
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
|
||||
import comfy.sd as sd_mod # noqa: E402
|
||||
from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# From test_seedvr_vae_tiled_decode_latent_min_size_override.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
    """Decode with ``temporal_size=0``.

    During the call the stub's ``decode_`` must observe
    ``slicing_latent_min_size == 5``, ``_decode`` must run once with
    ``MemoryState.DISABLED``, and afterwards the attribute must read ``2`` again.
    """
    from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae

    class StubVAEModel(torch.nn.Module):
        """Minimal model exposing only the attributes and hooks used here."""

        def __init__(self):
            super().__init__()
            self.device = torch.device("cpu")
            self.use_slicing = True
            self.slicing_latent_min_size = 2
            self.spatial_downsample_factor = 8
            self.temporal_downsample_factor = 4
            self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
            # Observations collected while ``tiled_vae`` runs.
            self.decode_min_sizes = []
            self.memory_states = []

        def decode_(self, t_chunk):
            # Snapshot the min size in effect at call time, then delegate to
            # the real slicing implementation.
            self.decode_min_sizes.append(self.slicing_latent_min_size)
            return VideoAutoencoderKL.slicing_decode(self, t_chunk)

        def _decode(self, z, memory_state=MemoryState.DISABLED):
            self.memory_states.append(memory_state)
            batch, _channels, depth, height, width = z.shape
            return torch.zeros((batch, 3, depth, height * 8, width * 8), dtype=z.dtype)

    model = StubVAEModel()
    latent = torch.zeros(1, 16, 5, 8, 8, dtype=torch.float32)

    tiled_vae(
        latent,
        model,
        tile_size=(64, 64),
        tile_overlap=(0, 0),
        temporal_size=0,
        temporal_overlap=0,
        encode=False,
    )

    assert model.decode_min_sizes == [5]
    assert model.memory_states == [MemoryState.DISABLED]
    assert model.slicing_latent_min_size == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# From test_seedvr_vae_tiled_encode_runt_slice_override.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_zero_temporal_size_preserves_min_size_when_encode_raises():
    """Encode with ``temporal_size=0`` while the model's ``encode`` raises.

    The simulated ``RuntimeError`` must propagate out of ``tiled_vae`` and
    ``slicing_sample_min_size`` must still read ``4`` afterwards.
    """
    # Local import: this test module does not import pytest at the top.
    import pytest

    class RaisingVAEModel(torch.nn.Module):
        """Minimal model whose ``encode`` always fails."""

        def __init__(self):
            super().__init__()
            self.slicing_sample_min_size = 4
            self.spatial_downsample_factor = 8
            self.temporal_downsample_factor = 4
            self.device = torch.device("cpu")
            self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))

        def encode(self, t_chunk):
            raise RuntimeError("simulated encode failure")

    vae = RaisingVAEModel()
    x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32)

    # ``pytest.raises`` replaces the former try/except + flag bookkeeping: it
    # fails the test if nothing is raised or if the message does not match.
    with pytest.raises(RuntimeError, match="simulated encode failure"):
        tiled_vae(
            x,
            vae,
            tile_size=(64, 64),
            tile_overlap=(0, 0),
            temporal_size=0,
            temporal_overlap=0,
            encode=True,
        )

    assert vae.slicing_sample_min_size == 4
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# From test_seedvr_vae_tiled_temporal_slicing.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _SlicingDecodeVAE(nn.Module):
    """Minimal decode-side model that records how ``tiled_vae`` drives it."""

    def __init__(self, slicing_latent_min_size):
        super().__init__()
        self.device = torch.device("cpu")
        self.use_slicing = True
        self.slicing_latent_min_size = slicing_latent_min_size
        self.spatial_downsample_factor = 8
        self.temporal_downsample_factor = 4
        self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32))
        # Observations collected during decoding.
        self.decode_min_sizes = []
        self.memory_states = []

    def decode_(self, z):
        """Log the current min slice size, then run the real ``slicing_decode``."""
        self.decode_min_sizes.append(self.slicing_latent_min_size)
        return vae_mod.VideoAutoencoderKL.slicing_decode(self, z)

    def _decode(self, z, memory_state=MemoryState.DISABLED):
        """Log ``memory_state`` and tile the first channel of ``z`` into 3
        channels, repeating the last two axes by the spatial factor."""
        self.memory_states.append(memory_state)
        factor = self.spatial_downsample_factor
        first_channel = z[:, :1]
        return first_channel.repeat(1, 3, 1, factor, factor)
|
||||
|
||||
|
||||
def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size():
    """Two checks in sequence.

    1. A direct ``tiled_vae`` decode with ``temporal_size=12`` must make the
       stub see min size ``2``, go through ``INITIALIZING`` then ``ACTIVE``
       memory states, and leave the min size at ``2``.
    2. ``wrapper.decode`` with a ``seedvr2_tiling`` dict must forward
       ``temporal_overlap=7`` to ``tiled_vae``.
    """
    stub = _SlicingDecodeVAE(slicing_latent_min_size=2)
    latent_volume = torch.arange(
        1 * 16 * 5 * 8 * 8, dtype=torch.float32
    ).reshape(1, 16, 5, 8, 8)

    tiled_vae(
        latent_volume,
        stub,
        tile_size=(64, 64),
        tile_overlap=(0, 0),
        temporal_size=12,
        temporal_overlap=4,
        encode=False,
    )

    assert stub.decode_min_sizes == [2]
    assert stub.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE]
    assert stub.slicing_latent_min_size == 2

    # Second part: build a wrapper without running its own constructor.
    wrapper_cls = vae_mod.VideoAutoencoderKLWrapper
    wrapper = wrapper_cls.__new__(wrapper_cls)
    nn.Module.__init__(wrapper)

    seedvr2_tiling = dict(
        enable_tiling=True,
        tile_size=(64, 64),
        tile_overlap=(0, 0),
        temporal_size=8,
        temporal_overlap=7,
    )
    seen_kwargs = {}

    def _fake_tiled_vae(latent, model, **kwargs):
        # Remember the keyword arguments the wrapper forwarded.
        seen_kwargs.update(kwargs)
        return torch.zeros(1, 3, 1, 16, 16)

    with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae):
        wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling)

    assert seen_kwargs["temporal_overlap"] == 7
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# From test_vae_decode_tiled_dispatcher_seedvr2_4d.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _force_oom(*_args, **_kwargs):
    """Accept any arguments and always raise a CUDA out-of-memory error."""
    oom_error = torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
    raise oom_error
|
||||
|
||||
|
||||
def _make_vae(first_stage_model, latent_channels, latent_dim):
    """Build an ``sd_mod.VAE`` without running its constructor.

    The instance is populated by hand with the given first-stage model,
    latent settings, a mocked patcher and no-op hooks.
    """
    vae = sd_mod.VAE.__new__(sd_mod.VAE)

    cpu = torch.device("cpu")
    patcher = MagicMock()
    patcher.get_free_memory = MagicMock(return_value=8 * 1024 ** 3)

    attributes = {
        "first_stage_model": first_stage_model,
        "patcher": patcher,
        "device": cpu,
        "output_device": cpu,
        "vae_dtype": torch.float32,
        "disable_offload": True,
        "extra_1d_channel": None,
        "upscale_ratio": 8,
        "downscale_ratio": 8,
        "upscale_index_formula": None,
        "downscale_index_formula": None,
        "output_channels": 3,
        "latent_channels": latent_channels,
        "latent_dim": latent_dim,
        # Instance-level callables standing in for the real methods.
        "vae_output_dtype": lambda: torch.float32,
        "spacial_compression_decode": lambda: 8,
        "process_input": lambda x: x,
        "process_output": lambda x: x,
        "throw_exception_if_invalid": lambda: None,
        "memory_used_decode": lambda *a, **k: 1,
    }
    for attr_name, attr_value in attributes.items():
        setattr(vae, attr_name, attr_value)
    return vae
|
||||
|
||||
|
||||
def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode):
    """Call ``vae.decode(samples)`` with model management neutralised.

    ``VAE._decode_tiled_owned`` and ``VAE.decode_tiled_`` are replaced by the
    given callables. When ``patch_wrapper_decode`` is true, the SeedVR2
    wrapper's ``decode`` is additionally patched to raise a forced OOM.
    """
    mm = sd_mod.model_management
    pending_patches = [
        patch.object(mm, "raise_non_oom", lambda e: None),
        patch.object(mm, "load_models_gpu", lambda *a, **k: None),
        patch.object(mm, "soft_empty_cache", lambda: None),
        patch.object(sd_mod.VAE, "_decode_tiled_owned", seedvr2_call),
        patch.object(sd_mod.VAE, "decode_tiled_", generic_call),
    ]
    if patch_wrapper_decode:
        pending_patches.append(
            patch.object(
                seedvr_vae_mod.VideoAutoencoderKLWrapper,
                "decode",
                side_effect=_force_oom,
            )
        )

    with ExitStack() as stack:
        for pending in pending_patches:
            stack.enter_context(pending)
        vae.decode(samples)
|
||||
|
||||
|
||||
def test_4d_seedvr2_latent_routes_to_owned_decode_tiled():
    """With a SeedVR2 wrapper as first stage and a forced OOM, a 4-D latent
    must reach ``_decode_tiled_owned`` exactly once and never ``decode_tiled_``."""
    wrapper_cls = seedvr_vae_mod.VideoAutoencoderKLWrapper
    vae = _make_vae(wrapper_cls.__new__(wrapper_cls), latent_channels=16, latent_dim=3)
    owned_decode = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
    generic_decode = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
    collapsed_latent = torch.zeros(1, 16 * 3, 8, 8)

    _dispatch(vae, collapsed_latent, owned_decode, generic_decode, True)

    assert owned_decode.call_count == 1
    assert generic_decode.call_count == 0
|
||||
|
||||
|
||||
def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled():
    """With a generic first stage whose ``decode`` raises a forced OOM, a 4-D
    latent must reach ``decode_tiled_`` once and never ``_decode_tiled_owned``."""
    generic_first_stage = MagicMock()
    generic_first_stage.comfy_handles_tiling = False
    generic_first_stage.decode = MagicMock(side_effect=_force_oom)
    vae = _make_vae(generic_first_stage, latent_channels=4, latent_dim=2)

    owned_decode = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
    generic_decode = MagicMock(return_value=torch.zeros(1, 3, 64, 64))

    _dispatch(vae, torch.zeros(1, 4, 8, 8), owned_decode, generic_decode, False)

    assert generic_decode.call_count == 1
    assert owned_decode.call_count == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# From test_vae_encode_tiled_fallback_dispatcher_seedvr2.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _populate_common_vae_attrs_fallback(vae):
    """Set, in place, the attributes and no-op hooks the encode dispatcher
    tests rely on; returns nothing."""
    patcher = MagicMock()
    patcher.get_free_memory = MagicMock(return_value=8 * 1024 ** 3)

    plain_attributes = {
        "patcher": patcher,
        "device": torch.device("cpu"),
        "output_device": torch.device("cpu"),
        "vae_dtype": torch.float32,
        "disable_offload": True,
        "extra_1d_channel": None,
        "upscale_ratio": 8,
        "upscale_index_formula": None,
        "output_channels": 3,
        "latent_channels": 16,
        "latent_dim": 3,
        "downscale_ratio": 8,
        "downscale_index_formula": None,
        "not_video": False,
        "crop_input": False,
        "pad_channel_value": None,
    }
    # Instance-level callables standing in for the real methods.
    stub_hooks = {
        "vae_output_dtype": lambda: torch.float32,
        "spacial_compression_encode": lambda: 8,
        "process_input": lambda x: x,
        "process_output": lambda x: x,
        "throw_exception_if_invalid": lambda: None,
        "memory_used_encode": lambda *a, **k: 1,
    }
    for group in (plain_attributes, stub_hooks):
        for attr_name, attr_value in group.items():
            setattr(vae, attr_name, attr_value)
|
||||
|
||||
|
||||
def _make_seedvr2_vae_fallback():
    """Return a constructor-bypassed ``sd_mod.VAE`` whose first stage is a
    constructor-bypassed SeedVR2 wrapper, with the common attributes filled in."""
    wrapper_cls = seedvr_vae_mod.VideoAutoencoderKLWrapper
    vae = sd_mod.VAE.__new__(sd_mod.VAE)
    vae.first_stage_model = wrapper_cls.__new__(wrapper_cls)
    _populate_common_vae_attrs_fallback(vae)
    return vae
|
||||
|
||||
|
||||
def _make_non_seedvr2_vae_fallback():
    """Return a constructor-bypassed ``sd_mod.VAE`` whose first stage is a
    mock with ``comfy_handles_tiling = False``, with the common attributes
    filled in."""
    generic_first_stage = MagicMock()
    generic_first_stage.comfy_handles_tiling = False

    vae = sd_mod.VAE.__new__(sd_mod.VAE)
    vae.first_stage_model = generic_first_stage
    _populate_common_vae_attrs_fallback(vae)
    return vae
|
||||
|
||||
|
||||
def _force_regular_encode_oom(*_args, **_kwargs):
    """Accept any arguments and always raise a CUDA out-of-memory error."""
    oom_error = torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
    raise oom_error
|
||||
|
||||
|
||||
def test_seedvr2_3d_routes_to_owned_encode_tiled_on_oom():
    """When the wrapper's regular ``encode`` raises a forced OOM, ``VAE.encode``
    must call ``_encode_tiled_owned`` once and never ``encode_tiled_3d``."""
    vae = _make_seedvr2_vae_fallback()
    pixel_samples = torch.zeros(1, 8, 64, 64, 3)

    seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
    generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))

    mm = sd_mod.model_management
    wrapper_cls = seedvr_vae_mod.VideoAutoencoderKLWrapper
    pending_patches = (
        patch.object(mm, "raise_non_oom", lambda e: None),
        patch.object(mm, "load_models_gpu", lambda *a, **k: None),
        patch.object(mm, "soft_empty_cache", lambda: None),
        patch.object(wrapper_cls, "encode", side_effect=_force_regular_encode_oom),
        patch.object(sd_mod.VAE, "_encode_tiled_owned", seedvr2_call),
        patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call),
    )
    with ExitStack() as stack:
        for pending in pending_patches:
            stack.enter_context(pending)
        vae.encode(pixel_samples)

    owned_calls = seedvr2_call.call_count
    generic_calls = generic_call.call_count
    assert owned_calls == 1, (
        "Expected _encode_tiled_owned to be called once for a SeedVR2 3D "
        f"input under OOM fallback; got {owned_calls} calls."
    )
    assert generic_calls == 0, (
        "encode_tiled_3d must NOT be called for a SeedVR2 input; got "
        f"{generic_calls} calls."
    )
|
||||
|
||||
|
||||
def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete():
    """``encode_tiled`` on a non-SeedVR2 VAE must pass ``overlap=(1, 64, 64)``
    to ``encode_tiled_3d`` when no overlap is supplied."""
    vae = _make_non_seedvr2_vae_fallback()
    # Ratios given as (callable for the first axis, 8, 8) tuples.
    vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8)
    vae.upscale_ratio = (lambda a: a * 4, 8, 8)

    encode_3d = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
    pixel_samples = torch.zeros(1, 8, 64, 64, 3)

    load_patch = patch.object(
        sd_mod.model_management, "load_models_gpu", lambda *a, **k: None
    )
    encode_patch = patch.object(sd_mod.VAE, "encode_tiled_3d", encode_3d)
    with load_patch, encode_patch:
        vae.encode_tiled(pixel_samples)

    forwarded_overlap = encode_3d.call_args.kwargs["overlap"]
    assert forwarded_overlap == (1, 64, 64)
|
||||
Loading…
Reference in New Issue
Block a user