# ComfyUI/tests-unit/comfy_test/test_model_management.py
# 2026-04-07 13:08:53 -04:00
#
# 316 lines
# 11 KiB
# Python

"""
Unit tests for comfy.model_management helpers and small APIs that are safe to run on CPU.
"""
import sys
import os
import unittest
from unittest.mock import MagicMock, patch
import torch
# Put the directory two levels above this file at the front of sys.path so the
# `comfy` package is importable regardless of the current working directory.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from comfy.cli_args import args
# Force CPU mode on machines without CUDA. This deliberately runs before
# comfy.model_management is imported below; presumably that module derives its
# device state from `args` at import time — TODO confirm.
if not torch.cuda.is_available():
    args.cpu = True
import comfy.model_management as mm
class TestDtypeSize(unittest.TestCase):
    """dtype_size reports the per-element byte width of common float dtypes."""

    def test_float16_bfloat16_float32(self):
        expected_sizes = (
            (torch.float16, 2),
            (torch.bfloat16, 2),
            (torch.float32, 4),
        )
        for dtype, nbytes in expected_sizes:
            self.assertEqual(mm.dtype_size(dtype), nbytes)
class TestDeviceHelpers(unittest.TestCase):
    """Device-type predicates and get_autocast_device name resolution."""

    def test_is_device_type_and_cpu_mps_cuda(self):
        # torch.device objects are plain descriptors, so these checks run even
        # when the corresponding backend is not installed.
        self.assertTrue(mm.is_device_cpu(torch.device("cpu")))
        self.assertFalse(mm.is_device_cpu(torch.device("meta")))
        positive_cases = (
            (mm.is_device_mps, "mps"),
            (mm.is_device_cuda, "cuda:0"),
            (mm.is_device_xpu, "xpu:0"),
        )
        for predicate, spec in positive_cases:
            self.assertTrue(predicate(torch.device(spec)))

    def test_is_device_type_non_device(self):
        # An arbitrary non-device object must not be reported as a cpu device.
        not_a_device = object()
        self.assertFalse(mm.is_device_type(not_a_device, "cpu"))

    def test_get_autocast_device(self):
        expectations = (
            (torch.device("cuda:0"), "cuda"),
            (torch.device("cpu"), "cpu"),
            (0, "cuda"),  # a bare integer index is expected to map to "cuda"
        )
        for dev, name in expectations:
            self.assertEqual(mm.get_autocast_device(dev), name)
class TestSupportsDtype(unittest.TestCase):
    """supports_dtype: fp32 is accepted everywhere, reduced precision is rejected on CPU."""

    def test_float32_always_true(self):
        for device_name in ("cpu", "cuda:0"):
            self.assertTrue(mm.supports_dtype(torch.device(device_name), torch.float32))

    def test_cpu_non_fp32(self):
        cpu = torch.device("cpu")
        for dtype in (torch.float16, torch.bfloat16):
            self.assertFalse(mm.supports_dtype(cpu, dtype))
class TestPickWeightDtype(unittest.TestCase):
    """pick_weight_dtype chooses between a requested dtype and a fallback dtype."""

    @staticmethod
    def _pick(dtype, fallback):
        # All cases here target the CPU device.
        return mm.pick_weight_dtype(dtype, fallback, device=torch.device("cpu"))

    def test_none_uses_fallback(self):
        self.assertEqual(self._pick(None, torch.float16), torch.float16)

    def test_downgrades_when_larger_than_fallback(self):
        # fp32 is wider than the fp16 fallback, so the fallback is expected to win.
        self.assertEqual(self._pick(torch.float32, torch.float16), torch.float16)

    def test_respects_supports_cast(self):
        # fp16 is narrower than the fp32 fallback; only the patched-out
        # supports_cast check can force the fallback in this case.
        with patch.object(mm, "supports_cast", return_value=False):
            picked = self._pick(torch.float16, torch.float32)
        self.assertEqual(picked, torch.float32)
class TestCastTo(unittest.TestCase):
    """cast_to on tensors that already live on the target device."""

    @staticmethod
    def _ones_fp32():
        return torch.ones(2, 3, dtype=torch.float32)

    def test_same_device_no_copy_returns_same_when_dtype_matches(self):
        weight = self._ones_fp32()
        # Same device, same dtype, copy=False: the very same tensor object comes back.
        result = mm.cast_to(weight, dtype=torch.float32, device=weight.device, copy=False)
        self.assertIs(result, weight)

    def test_dtype_conversion_same_device(self):
        weight = self._ones_fp32()
        result = mm.cast_to(weight, dtype=torch.float16, device=weight.device, copy=False)
        self.assertEqual(result.dtype, torch.float16)
        # 1.0 is exactly representable in fp16, so the values survive the round trip.
        self.assertTrue(torch.allclose(result.float(), weight))
class TestModuleSize(unittest.TestCase):
    """module_size totals the bytes held by a module's tensors."""

    def test_sums_parameter_bytes(self):
        layer = torch.nn.Linear(4, 8, bias=True)
        expected = sum(p.nbytes for p in (layer.weight, layer.bias))
        self.assertEqual(mm.module_size(layer), expected)
class TestArchiveModelDtypes(unittest.TestCase):
    """archive_model_dtypes records each tensor's dtype as `<name>_comfy_model_dtype`."""

    def test_records_param_and_buffer_dtypes(self):
        module = torch.nn.Module()
        module.register_parameter("w", torch.nn.Parameter(torch.zeros(1, dtype=torch.float16)))
        module.register_buffer("b", torch.zeros(1, dtype=torch.bfloat16))
        mm.archive_model_dtypes(module)
        expected = (
            ("w_comfy_model_dtype", torch.float16),  # parameter
            ("b_comfy_model_dtype", torch.bfloat16),  # buffer
        )
        for attr, dtype in expected:
            self.assertEqual(getattr(module, attr), dtype)
class TestPlatformHelpers(unittest.TestCase):
    """mac_version / is_wsl parsing, with the `platform` module patched out."""

    MAC_VER_TARGET = "comfy.model_management.platform.mac_ver"
    UNAME_TARGET = "comfy.model_management.platform.uname"

    def _uname_reporting(self, release):
        """Patch platform.uname() so that its `.release` attribute is *release*."""
        return patch(self.UNAME_TARGET, return_value=MagicMock(release=release))

    def test_mac_version_parses(self):
        with patch(self.MAC_VER_TARGET, return_value=("14.2.1", ("", "", ""), "")):
            self.assertEqual(mm.mac_version(), (14, 2, 1))

    def test_mac_version_returns_none_on_error(self):
        with patch(self.MAC_VER_TARGET, side_effect=ValueError("bad")):
            self.assertIsNone(mm.mac_version())

    def test_is_wsl_microsoft_suffix(self):
        with self._uname_reporting("5.10.0-Microsoft"):
            self.assertTrue(mm.is_wsl())

    def test_is_wsl_wsl2_suffix(self):
        with self._uname_reporting("5.15.0-microsoft-standard-WSL2"):
            self.assertTrue(mm.is_wsl())

    def test_is_wsl_false(self):
        with self._uname_reporting("6.8.0-31-generic"):
            self.assertFalse(mm.is_wsl())
class TestOomHelpers(unittest.TestCase):
    """is_oom / raise_non_oom classification of out-of-memory vs. other errors."""

    def _cuda_oom(self, message="OOM"):
        """Return a torch.cuda.OutOfMemoryError, skipping the test if this torch lacks it."""
        oom_cls = getattr(torch.cuda, "OutOfMemoryError", None)
        if oom_cls is None:
            self.skipTest("torch.cuda.OutOfMemoryError not available")
        return oom_cls(message)

    def test_is_oom_torch_oom_subclass(self):
        self.assertTrue(mm.is_oom(self._cuda_oom()))

    # TODO: cover is_oom's accelerator-error path (an exception carrying
    # `error_code == 2`, or one whose message says "CUDA out of memory").
    # Earlier drafts used a plain Exception subclass with an `error_code`
    # attribute, which presumably does not match the exception type is_oom
    # inspects — confirm against comfy.model_management before re-adding.

    def test_is_oom_false_for_generic(self):
        self.assertFalse(mm.is_oom(RuntimeError("other")))

    def test_raise_non_oom_raises_non_oom(self):
        with self.assertRaises(RuntimeError):
            mm.raise_non_oom(RuntimeError("x"))

    def test_raise_non_oom_swallows_oom(self):
        # Passes as long as the call returns normally (nothing re-raised).
        mm.raise_non_oom(self._cuda_oom())
class TestInterruptProcessing(unittest.TestCase):
    """The global interrupt flag can be set, queried and consumed."""

    @staticmethod
    def _reset_flag():
        mm.interrupt_current_processing(False)

    def setUp(self):
        # Each test starts and finishes with the shared flag cleared.
        self._reset_flag()

    def tearDown(self):
        self._reset_flag()

    def test_interrupt_toggle_and_query(self):
        self.assertFalse(mm.processing_interrupted())
        mm.interrupt_current_processing(True)
        self.assertTrue(mm.processing_interrupted())

    def test_throw_exception_if_processing_interrupted(self):
        mm.interrupt_current_processing(True)
        with self.assertRaises(mm.InterruptProcessingException):
            mm.throw_exception_if_processing_interrupted()
        # Raising is expected to consume the flag.
        self.assertFalse(mm.processing_interrupted())
class TestCpuMpsMode(unittest.TestCase):
    """cpu_mode() / mps_mode() follow the module-level `cpu_state`."""

    @staticmethod
    def _state(value):
        return patch.object(mm, "cpu_state", value)

    def test_cpu_mode_follows_cpu_state(self):
        with self._state(mm.CPUState.CPU):
            self.assertTrue(mm.cpu_mode())
        with self._state(mm.CPUState.GPU):
            self.assertFalse(mm.cpu_mode())

    def test_mps_mode(self):
        with self._state(mm.CPUState.MPS):
            self.assertTrue(mm.mps_mode())
        with self._state(mm.CPUState.CPU):
            self.assertFalse(mm.mps_mode())
class TestMemoryBudgetHelpers(unittest.TestCase):
    """Sanity checks on the reserved and minimum inference memory budgets."""

    def test_extra_reserved_memory_returns_int(self):
        reserved = mm.extra_reserved_memory()
        self.assertIsInstance(reserved, int)
        self.assertGreater(reserved, 0)

    def test_minimum_inference_memory(self):
        # The minimum inference budget must exceed the extra reservation alone.
        self.assertGreater(mm.minimum_inference_memory(), mm.extra_reserved_memory())
class TestOffloadHelpers(unittest.TestCase):
    """offloaded_memory / use_more_memory only act on models on the requested device."""

    @staticmethod
    def _fake_model(device):
        """Return a MagicMock model whose `.device` is *device*."""
        model = MagicMock()
        model.device = device
        return model

    def test_offloaded_memory_sums_matching_device(self):
        dev = torch.device("cpu")
        on_dev = self._fake_model(dev)
        on_dev.model_offloaded_memory.return_value = 100
        elsewhere = self._fake_model(torch.device("cuda:0"))
        elsewhere.model_offloaded_memory.return_value = 999
        self.assertEqual(mm.offloaded_memory([on_dev, elsewhere], dev), 100)

    def test_use_more_memory_stops_when_budget_exhausted(self):
        dev = torch.device("cpu")
        other_device = self._fake_model(torch.device("cuda:0"))
        first = self._fake_model(dev)
        # The first matching model consumes the whole amount it is offered...
        first.model_use_more_vram.side_effect = lambda requested, **_kwargs: requested
        second = self._fake_model(dev)
        second.model_use_more_vram.return_value = 0

        mm.use_more_memory(50, [other_device, first, second], dev)

        other_device.model_use_more_vram.assert_not_called()
        first.model_use_more_vram.assert_called_once_with(50)
        # ...so the budget is exhausted and the next model must not be asked.
        second.model_use_more_vram.assert_not_called()
class TestTextEncoderDtype(unittest.TestCase):
    """text_encoder_dtype honours the text-encoder override flags on `args`."""

    # Text-encoder dtype override flags on the global `args` namespace.
    _FLAG_NAMES = (
        "fp8_e4m3fn_text_enc",
        "fp8_e5m2_text_enc",
        "fp16_text_enc",
        "bf16_text_enc",
        "fp32_text_enc",
    )

    def setUp(self):
        # Snapshot the shared args so a test neither depends on nor leaks CLI
        # state, then start every test with all overrides cleared.
        self._saved_flags = {name: getattr(args, name, False) for name in self._FLAG_NAMES}
        for name in self._FLAG_NAMES:
            setattr(args, name, False)

    def tearDown(self):
        for name, value in self._saved_flags.items():
            setattr(args, name, value)

    def test_cli_overrides(self):
        args.fp32_text_enc = True
        self.assertEqual(mm.text_encoder_dtype(), torch.float32)
        args.fp32_text_enc = False
        args.bf16_text_enc = True
        self.assertEqual(mm.text_encoder_dtype(), torch.bfloat16)

    def test_default_fp16_for_cpu_device(self):
        # setUp has cleared every override flag, so the default applies.
        self.assertEqual(mm.text_encoder_dtype(torch.device("cpu")), torch.float16)
class TestIntermediateDtype(unittest.TestCase):
    """intermediate_dtype follows the `args.fp16_intermediates` flag."""

    def setUp(self):
        # Remember the caller's value so the shared args are restored, not clobbered.
        self._saved_flag = getattr(args, "fp16_intermediates", False)

    def tearDown(self):
        args.fp16_intermediates = self._saved_flag

    def test_fp16_intermediates_flag(self):
        args.fp16_intermediates = True
        self.assertEqual(mm.intermediate_dtype(), torch.float16)
        args.fp16_intermediates = False
        self.assertEqual(mm.intermediate_dtype(), torch.float32)
class TestVaeDtype(unittest.TestCase):
    """vae_dtype honours the VAE dtype override flags on `args`."""

    # VAE dtype override flags on the global `args` namespace.
    _FLAG_NAMES = ("fp16_vae", "bf16_vae", "fp32_vae")

    def setUp(self):
        # Snapshot and clear the overrides: a pre-set flag (e.g. bf16_vae) would
        # otherwise leak into the expectations below.
        self._saved_flags = {name: getattr(args, name, False) for name in self._FLAG_NAMES}
        for name in self._FLAG_NAMES:
            setattr(args, name, False)

    def tearDown(self):
        for name, value in self._saved_flags.items():
            setattr(args, name, value)

    def test_cli_overrides(self):
        args.fp16_vae = True
        self.assertEqual(mm.vae_dtype(), torch.float16)
        args.fp16_vae = False
        args.fp32_vae = True
        self.assertEqual(mm.vae_dtype(), torch.float32)
class TestLoraComputeDtype(unittest.TestCase):
    """lora_compute_dtype caches its first answer per device."""

    def setUp(self):
        # Start from an empty cache so an entry computed earlier cannot decide
        # the first lookup; keep a copy so the shared cache is restored after.
        self._saved_cache = dict(mm.LORA_COMPUTE_DTYPES)
        mm.LORA_COMPUTE_DTYPES.clear()

    def tearDown(self):
        mm.LORA_COMPUTE_DTYPES.clear()
        mm.LORA_COMPUTE_DTYPES.update(self._saved_cache)

    def test_caches_per_device(self):
        dev = torch.device("cpu")
        with patch.object(mm, "should_use_fp16", return_value=True):
            d1 = mm.lora_compute_dtype(dev)
        self.assertEqual(d1, torch.float16)
        # The fp16 answer is cached, so flipping should_use_fp16 has no effect.
        with patch.object(mm, "should_use_fp16", return_value=False):
            d2 = mm.lora_compute_dtype(dev)
        self.assertEqual(d2, torch.float16)
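# NOTE(review): disabled test. If, as the loop expects, the returned entries are
# dtype instances (e.g. torch.float8_e4m3fn), `issubclass(t, torch.dtype)` raises
# TypeError because they are not classes; use `isinstance` when re-enabling.
# class TestGetSupportedFloat8Types(unittest.TestCase):
#     def test_returns_list_of_dtypes(self):
#         types = mm.get_supported_float8_types()
#         self.assertIsInstance(types, list)
#         for t in types:
#             self.assertIsInstance(t, torch.dtype)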
class TestLoadedModelEquality(unittest.TestCase):
    """Two LoadedModel wrappers around the same model object compare equal."""

    def test_eq_same_model_reference(self):
        # Minimal stand-in providing only `load_device` and `parent`.
        class Dummy:
            load_device = torch.device("cpu")
            parent = None

        shared = Dummy()
        first, second = mm.LoadedModel(shared), mm.LoadedModel(shared)
        self.assertEqual(first, second)
if __name__ == "__main__":
    # Allow running this file directly, in addition to test-runner discovery.
    unittest.main()