mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-28 01:47:32 +08:00
106 lines
3.7 KiB
Python
106 lines
3.7 KiB
Python
"""Regression tests for the diffusers-format guard inside ``comfy.sd.VAE.__init__``.
|
|
|
|
The guard previously indexed ``metadata["keep_diffusers_format"]`` directly,
raising ``KeyError`` when ``metadata`` was non-``None`` but lacked that key.
The fixed guard is
``metadata is None or metadata.get("keep_diffusers_format") != "true"``:
a missing key now flows through to ``convert_vae_state_dict``, and only the
explicit ``"true"`` value bypasses it.
|
|
|
|
Five test cases cover every reachable shape of the guard input — missing key,
explicit ``"true"``, ``None``, an explicit non-``"true"`` value, and an empty
dict — and each halts the constructor at the first call after the guard
(``model_management.is_amd``).
|
|
``_make_standin`` borrows ``__init__`` onto a bare class, mirroring
|
|
``seedvr_model_test.py::_make_standin`` (#109). ``_exercise_guard`` single-
|
|
sources the patched-constructor harness so the cells stay synchronised.
|
|
"""
|
|
|
|
from comfy.cli_args import args
|
|
import torch
|
|
|
|
# Fall back to CPU mode when CUDA is unavailable. This must run before
# ``comfy.sd`` is imported below (hence the ``noqa: E402`` markers on the later
# imports) — presumably because model_management reads ``args`` at import
# time; TODO confirm.
if not torch.cuda.is_available():
    args.cpu = True
|
|
|
|
import contextlib # noqa: E402
|
|
import unittest.mock # noqa: E402
|
|
|
|
import comfy.sd # noqa: E402
|
|
|
|
|
|
# State-dict key whose presence makes ``VAE.__init__`` take its diffusers-format
# path, so the ``keep_diffusers_format`` guard under test is actually reached.
# NOTE(review): presumably matched verbatim by a key check in comfy.sd — keep
# this string in sync with it.
_DIFFUSERS_TRIGGER_KEY = "decoder.up_blocks.0.resnets.0.norm1.weight"
|
|
|
|
|
|
class _PostGuardReached(Exception):
|
|
"""Sentinel raised by the patched ``is_amd`` to halt ``__init__`` at the first post-guard statement."""
|
|
|
|
|
|
def _make_standin():
    """Return a fresh, otherwise-empty class whose ``__init__`` is ``VAE.__init__``.

    Borrowing only the constructor lets a test run ``comfy.sd.VAE.__init__``
    against a bare object (same technique as
    ``seedvr_model_test.py::_make_standin``).
    """
    borrowed_init = comfy.sd.VAE.__init__

    class _StandIn:
        __init__ = borrowed_init

    return _StandIn
|
|
|
|
|
|
def _exercise_guard(metadata):
    """Push ``VAE.__init__`` through the diffusers-format guard with ``metadata``.

    The borrowed constructor receives a one-entry state dict keyed by
    ``_DIFFUSERS_TRIGGER_KEY``. While it runs, ``convert_vae_state_dict`` is an
    autospecced identity passthrough and ``model_management.is_amd`` raises
    ``_PostGuardReached``. That sentinel is swallowed here, so construction
    simply stops at the ``is_amd`` call. Any other exception (e.g. a
    ``KeyError`` raised by the guard itself) propagates to the caller.

    Returns:
        ``(mock_convert, mock_is_amd)``. ``mock_convert.call_count`` shows
        which branch the guard took, and ``mock_is_amd.called`` confirms that
        execution got past the guard.
    """
    stand_in_cls = _make_standin()
    trigger_sd = {_DIFFUSERS_TRIGGER_KEY: torch.zeros(1)}

    convert_patcher = unittest.mock.patch.object(
        comfy.sd.diffusers_convert,
        "convert_vae_state_dict",
        autospec=True,
        side_effect=lambda state_dict: state_dict,  # identity: hand the dict back unchanged
    )
    is_amd_patcher = unittest.mock.patch.object(
        comfy.sd.model_management,
        "is_amd",
        autospec=True,
        side_effect=_PostGuardReached("post-guard reached"),
    )

    with contextlib.ExitStack() as stack:
        # Patch convert first, then is_amd; ExitStack restores them in reverse.
        convert_mock = stack.enter_context(convert_patcher)
        is_amd_mock = stack.enter_context(is_amd_patcher)
        try:
            stand_in_cls(sd=trigger_sd, metadata=metadata)
        except _PostGuardReached:
            # Expected halt: is_amd is the first call after the guard.
            pass

    return convert_mock, is_amd_mock
|
|
|
|
|
|
def test_diffusers_guard_invokes_convert_when_metadata_missing_key():
    """AC1: ``metadata`` is present but has no ``keep_diffusers_format`` → the state dict is converted."""
    convert_mock, is_amd_mock = _exercise_guard(metadata={"unrelated_key": "value"})

    convert_mock.assert_called_once()
    is_amd_mock.assert_called()
|
|
|
|
|
|
def test_diffusers_guard_skips_convert_when_metadata_pins_keep_true():
    """AC2: ``keep_diffusers_format`` is exactly ``"true"`` → conversion is bypassed."""
    convert_mock, is_amd_mock = _exercise_guard(metadata={"keep_diffusers_format": "true"})

    convert_mock.assert_not_called()
    is_amd_mock.assert_called()
|
|
|
|
|
|
def test_diffusers_guard_invokes_convert_when_metadata_is_none():
    """AC3: ``metadata is None`` → the guard's first disjunct holds and conversion runs."""
    convert_mock, is_amd_mock = _exercise_guard(metadata=None)

    convert_mock.assert_called_once()
    is_amd_mock.assert_called()
|
|
|
|
|
|
def test_diffusers_guard_invokes_convert_when_metadata_pins_keep_false():
    """AC4: ``keep_diffusers_format`` holds a value other than ``"true"`` → the second disjunct holds and conversion runs."""
    convert_mock, is_amd_mock = _exercise_guard(
        metadata={"keep_diffusers_format": "false"}
    )

    convert_mock.assert_called_once()
    is_amd_mock.assert_called()
|
|
|
|
|
|
def test_diffusers_guard_invokes_convert_when_metadata_is_empty_dict():
    """AC5: ``metadata == {}`` (the shape ``convert_old_quants`` normalizes ``None`` into) → conversion runs."""
    convert_mock, is_amd_mock = _exercise_guard(metadata={})

    convert_mock.assert_called_once()
    is_amd_mock.assert_called()
|