"""
Unit tests for comfy.model_management helpers and small APIs that are safe to run on CPU.
"""
import sys
import os
import unittest
from unittest.mock import MagicMock, patch

import torch

# Make the directory two levels above this file importable so ``comfy`` resolves
# when the tests are started from another working directory.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from comfy.cli_args import args

# Must happen before the import below.
# NOTE(review): presumably model_management reads ``args`` at import time - confirm.
if not torch.cuda.is_available():
    args.cpu = True

import comfy.model_management as mm


# CLI switches exercised by the text-encoder dtype tests.
_TEXT_ENCODER_FLAGS = (
    "fp8_e4m3fn_text_enc",
    "fp8_e5m2_text_enc",
    "fp16_text_enc",
    "bf16_text_enc",
    "fp32_text_enc",
)

# CLI switches exercised by the VAE dtype tests.
_VAE_FLAGS = ("fp16_vae", "bf16_vae", "fp32_vae")


def _isolate_flags(test_case, names):
    """Force each named ``args`` flag to False for one test, restoring it afterwards.

    ``args`` is process-global, so the previous value of every flag is put back
    through ``addCleanup`` instead of being left at False for later tests.
    A flag that does not exist yet is treated as False.
    """
    for name in names:
        test_case.addCleanup(setattr, args, name, getattr(args, name, False))
        setattr(args, name, False)


class TestDtypeSize(unittest.TestCase):
    """Byte width reported by ``dtype_size``."""

    def test_float16_bfloat16_float32(self):
        self.assertEqual(mm.dtype_size(torch.float16), 2)
        self.assertEqual(mm.dtype_size(torch.bfloat16), 2)
        self.assertEqual(mm.dtype_size(torch.float32), 4)


class TestDeviceHelpers(unittest.TestCase):
    """Device-type predicates; only ``torch.device`` objects are built, no hardware is touched."""

    def test_is_device_type_and_cpu_mps_cuda(self):
        self.assertTrue(mm.is_device_cpu(torch.device("cpu")))
        self.assertFalse(mm.is_device_cpu(torch.device("meta")))
        self.assertTrue(mm.is_device_mps(torch.device("mps")))
        self.assertTrue(mm.is_device_cuda(torch.device("cuda:0")))
        self.assertTrue(mm.is_device_xpu(torch.device("xpu:0")))

    def test_is_device_type_non_device(self):
        # An arbitrary object is not a device of any type.
        self.assertFalse(mm.is_device_type(object(), "cpu"))

    def test_get_autocast_device(self):
        self.assertEqual(mm.get_autocast_device(torch.device("cuda:0")), "cuda")
        self.assertEqual(mm.get_autocast_device(torch.device("cpu")), "cpu")
        # A bare integer index is expected to map to "cuda".
        self.assertEqual(mm.get_autocast_device(0), "cuda")


class TestSupportsDtype(unittest.TestCase):
    """``supports_dtype`` for float32 and for reduced precision on CPU."""

    def test_float32_always_true(self):
        self.assertTrue(mm.supports_dtype(torch.device("cpu"), torch.float32))
        self.assertTrue(mm.supports_dtype(torch.device("cuda:0"), torch.float32))

    def test_cpu_non_fp32(self):
        self.assertFalse(mm.supports_dtype(torch.device("cpu"), torch.float16))
        self.assertFalse(mm.supports_dtype(torch.device("cpu"), torch.bfloat16))


class TestPickWeightDtype(unittest.TestCase):
    """``pick_weight_dtype`` fallback selection."""

    def test_none_uses_fallback(self):
        d = mm.pick_weight_dtype(None, torch.float16, device=torch.device("cpu"))
        self.assertEqual(d, torch.float16)

    def test_downgrades_when_larger_than_fallback(self):
        d = mm.pick_weight_dtype(torch.float32, torch.float16, device=torch.device("cpu"))
        self.assertEqual(d, torch.float16)

    def test_respects_supports_cast(self):
        # With casting reported as unsupported, the fallback dtype must be chosen.
        with patch.object(mm, "supports_cast", return_value=False):
            d = mm.pick_weight_dtype(torch.float16, torch.float32, device=torch.device("cpu"))
        self.assertEqual(d, torch.float32)


class TestCastTo(unittest.TestCase):
    """``cast_to`` on a same-device tensor."""

    def test_same_device_no_copy_returns_same_when_dtype_matches(self):
        w = torch.ones(2, 3, dtype=torch.float32)
        out = mm.cast_to(w, dtype=torch.float32, device=w.device, copy=False)
        # Nothing to convert and no copy requested: the very same tensor comes back.
        self.assertIs(out, w)

    def test_dtype_conversion_same_device(self):
        w = torch.ones(2, 3, dtype=torch.float32)
        out = mm.cast_to(w, dtype=torch.float16, device=w.device, copy=False)
        self.assertEqual(out.dtype, torch.float16)
        self.assertTrue(torch.allclose(out.float(), w))


class TestModuleSize(unittest.TestCase):
    """``module_size`` against the summed byte size of a module's parameters."""

    def test_sums_parameter_bytes(self):
        m = torch.nn.Linear(4, 8, bias=True)
        expected = m.weight.nbytes + m.bias.nbytes
        self.assertEqual(mm.module_size(m), expected)


class TestArchiveModelDtypes(unittest.TestCase):
    """``archive_model_dtypes`` records ``<name>_comfy_model_dtype`` attributes."""

    def test_records_param_and_buffer_dtypes(self):
        m = torch.nn.Module()
        m.register_parameter("w", torch.nn.Parameter(torch.zeros(1, dtype=torch.float16)))
        m.register_buffer("b", torch.zeros(1, dtype=torch.bfloat16))
        mm.archive_model_dtypes(m)
        self.assertEqual(m.w_comfy_model_dtype, torch.float16)
        self.assertEqual(m.b_comfy_model_dtype, torch.bfloat16)


class TestPlatformHelpers(unittest.TestCase):
    """``mac_version`` and ``is_wsl`` with the ``platform`` calls mocked out."""

    @patch("comfy.model_management.platform.mac_ver", return_value=("14.2.1", ("", "", ""), ""))
    def test_mac_version_parses(self, _mac_ver):
        self.assertEqual(mm.mac_version(), (14, 2, 1))

    @patch("comfy.model_management.platform.mac_ver", side_effect=ValueError("bad"))
    def test_mac_version_returns_none_on_error(self, _mac_ver):
        self.assertIsNone(mm.mac_version())

    @patch("comfy.model_management.platform.uname")
    def test_is_wsl_microsoft_suffix(self, mock_uname):
        mock_uname.return_value.release = "5.10.0-Microsoft"
        self.assertTrue(mm.is_wsl())

    @patch("comfy.model_management.platform.uname")
    def test_is_wsl_wsl2_suffix(self, mock_uname):
        mock_uname.return_value.release = "5.15.0-microsoft-standard-WSL2"
        self.assertTrue(mm.is_wsl())

    @patch("comfy.model_management.platform.uname")
    def test_is_wsl_false(self, mock_uname):
        mock_uname.return_value.release = "6.8.0-31-generic"
        self.assertFalse(mm.is_wsl())


class TestOomHelpers(unittest.TestCase):
    """``is_oom`` classification and ``raise_non_oom`` re-raise behaviour."""

    def test_is_oom_torch_oom_subclass(self):
        if not hasattr(torch.cuda, "OutOfMemoryError"):
            self.skipTest("torch.cuda.OutOfMemoryError not available")
        err = torch.cuda.OutOfMemoryError("OOM")
        self.assertTrue(mm.is_oom(err))

    # NOTE(review): the two tests below were already disabled in the original
    # patch; the reason is not recorded - confirm before re-enabling.
    #
    # def test_is_oom_accelerator_error_code_2(self):
    #     class FakeAccel(Exception):
    #         error_code = 2
    #
    #     self.assertTrue(mm.is_oom(FakeAccel()))
    #
    # def test_is_oom_accelerator_error_message(self):
    #     class FakeAccel(Exception):
    #         error_code = 0
    #
    #     self.assertTrue(mm.is_oom(FakeAccel("CUDA out of memory")))

    def test_is_oom_false_for_generic(self):
        self.assertFalse(mm.is_oom(RuntimeError("other")))

    def test_raise_non_oom_raises_non_oom(self):
        with self.assertRaises(RuntimeError):
            mm.raise_non_oom(RuntimeError("x"))

    def test_raise_non_oom_swallows_oom(self):
        if not hasattr(torch.cuda, "OutOfMemoryError"):
            self.skipTest("torch.cuda.OutOfMemoryError not available")
        # Passes as long as no exception propagates.
        mm.raise_non_oom(torch.cuda.OutOfMemoryError("OOM"))


class TestInterruptProcessing(unittest.TestCase):
    """Interrupt flag set / query / raise cycle."""

    def setUp(self):
        # Start every test from a cleared flag.
        mm.interrupt_current_processing(False)

    def tearDown(self):
        # Do not leak a raised flag into later tests.
        mm.interrupt_current_processing(False)

    def test_interrupt_toggle_and_query(self):
        self.assertFalse(mm.processing_interrupted())
        mm.interrupt_current_processing(True)
        self.assertTrue(mm.processing_interrupted())

    def test_throw_exception_if_processing_interrupted(self):
        mm.interrupt_current_processing(True)
        with self.assertRaises(mm.InterruptProcessingException):
            mm.throw_exception_if_processing_interrupted()
        # Raising is expected to consume the flag.
        self.assertFalse(mm.processing_interrupted())


class TestCpuMpsMode(unittest.TestCase):
    """``cpu_mode`` / ``mps_mode`` follow the module-level ``cpu_state``."""

    def test_cpu_mode_follows_cpu_state(self):
        with patch.object(mm, "cpu_state", mm.CPUState.CPU):
            self.assertTrue(mm.cpu_mode())
        with patch.object(mm, "cpu_state", mm.CPUState.GPU):
            self.assertFalse(mm.cpu_mode())

    def test_mps_mode(self):
        with patch.object(mm, "cpu_state", mm.CPUState.MPS):
            self.assertTrue(mm.mps_mode())
        with patch.object(mm, "cpu_state", mm.CPUState.CPU):
            self.assertFalse(mm.mps_mode())


class TestMemoryBudgetHelpers(unittest.TestCase):
    """Sanity checks on the reserved / minimum memory figures."""

    def test_extra_reserved_memory_returns_int(self):
        v = mm.extra_reserved_memory()
        self.assertIsInstance(v, int)
        self.assertGreater(v, 0)

    def test_minimum_inference_memory(self):
        v = mm.minimum_inference_memory()
        self.assertGreater(v, mm.extra_reserved_memory())


class TestOffloadHelpers(unittest.TestCase):
    """``offloaded_memory`` and ``use_more_memory`` driven with mocked loaded models."""

    def test_offloaded_memory_sums_matching_device(self):
        dev = torch.device("cpu")
        m1 = MagicMock()
        m1.device = dev
        m1.model_offloaded_memory.return_value = 100
        m2 = MagicMock()
        m2.device = torch.device("cuda:0")
        m2.model_offloaded_memory.return_value = 999
        # Only the model sitting on ``dev`` may contribute.
        self.assertEqual(mm.offloaded_memory([m1, m2], dev), 100)

    def test_use_more_memory_stops_when_budget_exhausted(self):
        dev = torch.device("cpu")
        m = MagicMock()
        m.device = dev

        def use_vram(n, **_kwargs):
            # Fails on any second call; reports the whole request as consumed.
            m.model_use_more_vram.assert_called_once()
            return n

        m.model_use_more_vram.side_effect = use_vram
        mm.use_more_memory(50, [m], dev)
        m.model_use_more_vram.assert_called_once_with(50)


class TestTextEncoderDtype(unittest.TestCase):
    """``text_encoder_dtype`` under the text-encoder CLI flags."""

    def setUp(self):
        # All flags start False and are restored to their prior values afterwards.
        _isolate_flags(self, _TEXT_ENCODER_FLAGS)

    def test_cli_overrides(self):
        args.fp32_text_enc = True
        self.assertEqual(mm.text_encoder_dtype(), torch.float32)
        args.fp32_text_enc = False
        args.bf16_text_enc = True
        self.assertEqual(mm.text_encoder_dtype(), torch.bfloat16)

    def test_default_fp16_for_cpu_device(self):
        # No flag set (see setUp): a CPU device is expected to yield float16.
        self.assertEqual(mm.text_encoder_dtype(torch.device("cpu")), torch.float16)


class TestIntermediateDtype(unittest.TestCase):
    """``intermediate_dtype`` under ``args.fp16_intermediates``."""

    def setUp(self):
        _isolate_flags(self, ("fp16_intermediates",))

    def test_fp16_intermediates_flag(self):
        args.fp16_intermediates = True
        self.assertEqual(mm.intermediate_dtype(), torch.float16)
        args.fp16_intermediates = False
        self.assertEqual(mm.intermediate_dtype(), torch.float32)


class TestVaeDtype(unittest.TestCase):
    """``vae_dtype`` under the VAE precision CLI flags."""

    def setUp(self):
        _isolate_flags(self, _VAE_FLAGS)

    def test_cli_overrides(self):
        args.fp16_vae = True
        self.assertEqual(mm.vae_dtype(), torch.float16)
        args.fp16_vae = False
        args.fp32_vae = True
        self.assertEqual(mm.vae_dtype(), torch.float32)


class TestLoraComputeDtype(unittest.TestCase):
    """``lora_compute_dtype`` caches its answer per device."""

    def setUp(self):
        # Start from an empty cache so an entry left by earlier code cannot
        # decide the first lookup; the previous contents are restored on cleanup.
        # NOTE(review): assumes LORA_COMPUTE_DTYPES is a dict-like mapping - confirm.
        cache_patch = patch.dict(mm.LORA_COMPUTE_DTYPES, clear=True)
        cache_patch.start()
        self.addCleanup(cache_patch.stop)

    def test_caches_per_device(self):
        dev = torch.device("cpu")
        with patch.object(mm, "should_use_fp16", return_value=True):
            d1 = mm.lora_compute_dtype(dev)
        self.assertEqual(d1, torch.float16)
        # The second lookup must come from the cache even though the probe now says False.
        with patch.object(mm, "should_use_fp16", return_value=False):
            d2 = mm.lora_compute_dtype(dev)
        self.assertEqual(d2, torch.float16)


# NOTE(review): already disabled in the original patch. As written it would raise
# TypeError, because issubclass() needs a class and torch dtypes are instances of
# torch.dtype; isinstance(t, torch.dtype) is probably what was intended - confirm.
#
# class TestGetSupportedFloat8Types(unittest.TestCase):
#     def test_returns_list_of_dtypes(self):
#         types = mm.get_supported_float8_types()
#         self.assertIsInstance(types, list)
#         for t in types:
#             self.assertTrue(issubclass(t, torch.dtype))


class TestLoadedModelEquality(unittest.TestCase):
    """Two ``LoadedModel`` wrappers around one model object compare equal."""

    def test_eq_same_model_reference(self):
        class Dummy:
            # Attributes given to the stand-in model passed to LoadedModel.
            load_device = torch.device("cpu")
            parent = None

        d = Dummy()
        a = mm.LoadedModel(d)
        b = mm.LoadedModel(d)
        self.assertEqual(a, b)


if __name__ == "__main__":
    unittest.main()