"""
|
|
Unit tests for comfy.model_management helpers and small APIs that are safe to run on CPU.
|
|
"""
|
|
import sys
|
|
import os
|
|
import unittest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import torch
|
|
|
|
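# Make the directory two levels up (the repository root, relative to this
# file) importable so that `comfy` resolves when running from a checkout.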
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from comfy.cli_args import args

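# comfy.model_management picks its device state from `args` when it is first
# imported, so force CPU mode up front on machines without CUDA.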
if not torch.cuda.is_available():
    args.cpu = True

import comfy.model_management as mm


class TestDtypeSize(unittest.TestCase):
    def test_float16_bfloat16_float32(self):
        self.assertEqual(mm.dtype_size(torch.float16), 2)
        self.assertEqual(mm.dtype_size(torch.bfloat16), 2)
        self.assertEqual(mm.dtype_size(torch.float32), 4)


class TestDeviceHelpers(unittest.TestCase):
    def test_is_device_type_and_cpu_mps_cuda(self):
        self.assertTrue(mm.is_device_cpu(torch.device("cpu")))
        self.assertFalse(mm.is_device_cpu(torch.device("meta")))
        self.assertTrue(mm.is_device_mps(torch.device("mps")))
        self.assertTrue(mm.is_device_cuda(torch.device("cuda:0")))
        self.assertTrue(mm.is_device_xpu(torch.device("xpu:0")))

    def test_is_device_type_non_device(self):
        self.assertFalse(mm.is_device_type(object(), "cpu"))

    def test_get_autocast_device(self):
        self.assertEqual(mm.get_autocast_device(torch.device("cuda:0")), "cuda")
        self.assertEqual(mm.get_autocast_device(torch.device("cpu")), "cpu")
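        # A bare device index has no .type attribute, so the helper is
        # expected to fall back to "cuda".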
        self.assertEqual(mm.get_autocast_device(0), "cuda")


class TestSupportsDtype(unittest.TestCase):
    def test_float32_always_true(self):
        self.assertTrue(mm.supports_dtype(torch.device("cpu"), torch.float32))
        self.assertTrue(mm.supports_dtype(torch.device("cuda:0"), torch.float32))

    def test_cpu_non_fp32(self):
        self.assertFalse(mm.supports_dtype(torch.device("cpu"), torch.float16))
        self.assertFalse(mm.supports_dtype(torch.device("cpu"), torch.bfloat16))


class TestPickWeightDtype(unittest.TestCase):
    def test_none_uses_fallback(self):
        d = mm.pick_weight_dtype(None, torch.float16, device=torch.device("cpu"))
        self.assertEqual(d, torch.float16)

    def test_downgrades_when_larger_than_fallback(self):
        d = mm.pick_weight_dtype(torch.float32, torch.float16, device=torch.device("cpu"))
        self.assertEqual(d, torch.float16)

    def test_respects_supports_cast(self):
        with patch.object(mm, "supports_cast", return_value=False):
            d = mm.pick_weight_dtype(torch.float16, torch.float32, device=torch.device("cpu"))
        self.assertEqual(d, torch.float32)


class TestCastTo(unittest.TestCase):
    def test_same_device_no_copy_returns_same_when_dtype_matches(self):
        w = torch.ones(2, 3, dtype=torch.float32)
        out = mm.cast_to(w, dtype=torch.float32, device=w.device, copy=False)
        self.assertIs(out, w)

    def test_dtype_conversion_same_device(self):
        w = torch.ones(2, 3, dtype=torch.float32)
        out = mm.cast_to(w, dtype=torch.float16, device=w.device, copy=False)
        self.assertEqual(out.dtype, torch.float16)
        self.assertTrue(torch.allclose(out.float(), w))


class TestModuleSize(unittest.TestCase):
    def test_sums_parameter_bytes(self):
        m = torch.nn.Linear(4, 8, bias=True)
        expected = m.weight.nbytes + m.bias.nbytes
        self.assertEqual(mm.module_size(m), expected)


class TestArchiveModelDtypes(unittest.TestCase):
    def test_records_param_and_buffer_dtypes(self):
        m = torch.nn.Module()
        m.register_parameter("w", torch.nn.Parameter(torch.zeros(1, dtype=torch.float16)))
        m.register_buffer("b", torch.zeros(1, dtype=torch.bfloat16))
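        # archive_model_dtypes is expected to record each tensor's original
        # dtype in a matching "<name>_comfy_model_dtype" attribute.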
        mm.archive_model_dtypes(m)
        self.assertEqual(m.w_comfy_model_dtype, torch.float16)
        self.assertEqual(m.b_comfy_model_dtype, torch.bfloat16)


class TestPlatformHelpers(unittest.TestCase):
    @patch("comfy.model_management.platform.mac_ver", return_value=("14.2.1", ("", "", ""), ""))
    def test_mac_version_parses(self, _mac_ver):
        self.assertEqual(mm.mac_version(), (14, 2, 1))

    @patch("comfy.model_management.platform.mac_ver", side_effect=ValueError("bad"))
    def test_mac_version_returns_none_on_error(self, _mac_ver):
        self.assertIsNone(mm.mac_version())

    @patch("comfy.model_management.platform.uname")
    def test_is_wsl_microsoft_suffix(self, mock_uname):
        mock_uname.return_value.release = "5.10.0-Microsoft"
        self.assertTrue(mm.is_wsl())

    @patch("comfy.model_management.platform.uname")
    def test_is_wsl_wsl2_suffix(self, mock_uname):
        mock_uname.return_value.release = "5.15.0-microsoft-standard-WSL2"
        self.assertTrue(mm.is_wsl())

    @patch("comfy.model_management.platform.uname")
    def test_is_wsl_false(self, mock_uname):
        mock_uname.return_value.release = "6.8.0-31-generic"
        self.assertFalse(mm.is_wsl())


class TestOomHelpers(unittest.TestCase):
    def test_is_oom_torch_oom_subclass(self):
        if not hasattr(torch.cuda, "OutOfMemoryError"):
            self.skipTest("torch.cuda.OutOfMemoryError not available")
        err = torch.cuda.OutOfMemoryError("OOM")
        self.assertTrue(mm.is_oom(err))

    # def test_is_oom_accelerator_error_code_2(self):
    #     class FakeAccel(Exception):
    #         error_code = 2
    #
    #     self.assertTrue(mm.is_oom(FakeAccel()))

    # def test_is_oom_accelerator_error_message(self):
    #     class FakeAccel(Exception):
    #         error_code = 0
    #
    #     self.assertTrue(mm.is_oom(FakeAccel("CUDA out of memory")))

    def test_is_oom_false_for_generic(self):
        self.assertFalse(mm.is_oom(RuntimeError("other")))

    def test_raise_non_oom_raises_non_oom(self):
        with self.assertRaises(RuntimeError):
            mm.raise_non_oom(RuntimeError("x"))

    def test_raise_non_oom_swallows_oom(self):
        if not hasattr(torch.cuda, "OutOfMemoryError"):
            self.skipTest("torch.cuda.OutOfMemoryError not available")
        mm.raise_non_oom(torch.cuda.OutOfMemoryError("OOM"))


class TestInterruptProcessing(unittest.TestCase):
    def setUp(self):
        mm.interrupt_current_processing(False)

    def tearDown(self):
        mm.interrupt_current_processing(False)

    def test_interrupt_toggle_and_query(self):
        self.assertFalse(mm.processing_interrupted())
        mm.interrupt_current_processing(True)
        self.assertTrue(mm.processing_interrupted())

    def test_throw_exception_if_processing_interrupted(self):
        mm.interrupt_current_processing(True)
        with self.assertRaises(mm.InterruptProcessingException):
            mm.throw_exception_if_processing_interrupted()
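        # Raising the exception is expected to clear the interrupt flag.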
        self.assertFalse(mm.processing_interrupted())


class TestCpuMpsMode(unittest.TestCase):
    def test_cpu_mode_follows_cpu_state(self):
        with patch.object(mm, "cpu_state", mm.CPUState.CPU):
            self.assertTrue(mm.cpu_mode())
        with patch.object(mm, "cpu_state", mm.CPUState.GPU):
            self.assertFalse(mm.cpu_mode())

    def test_mps_mode(self):
        with patch.object(mm, "cpu_state", mm.CPUState.MPS):
            self.assertTrue(mm.mps_mode())
        with patch.object(mm, "cpu_state", mm.CPUState.CPU):
            self.assertFalse(mm.mps_mode())


class TestMemoryBudgetHelpers(unittest.TestCase):
    def test_extra_reserved_memory_returns_int(self):
        v = mm.extra_reserved_memory()
        self.assertIsInstance(v, int)
        self.assertGreater(v, 0)

    def test_minimum_inference_memory(self):
        v = mm.minimum_inference_memory()
        self.assertGreater(v, mm.extra_reserved_memory())


class TestOffloadHelpers(unittest.TestCase):
    def test_offloaded_memory_sums_matching_device(self):
        dev = torch.device("cpu")
        m1 = MagicMock()
        m1.device = dev
        m1.model_offloaded_memory.return_value = 100
        m2 = MagicMock()
        m2.device = torch.device("cuda:0")
        m2.model_offloaded_memory.return_value = 999
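        # Only m1 sits on the requested device, so only its 100 bytes count.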
        self.assertEqual(mm.offloaded_memory([m1, m2], dev), 100)

    def test_use_more_memory_stops_when_budget_exhausted(self):
        dev = torch.device("cpu")
        m = MagicMock()
        m.device = dev

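        # The side effect claims the whole budget on its first call; a second
        # call would trip the assert_called_once() inside it.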
        def use_vram(n, **_kwargs):
            m.model_use_more_vram.assert_called_once()
            return n

        m.model_use_more_vram.side_effect = use_vram
        mm.use_more_memory(50, [m], dev)
        m.model_use_more_vram.assert_called_once_with(50)


class TestTextEncoderDtype(unittest.TestCase):
    def tearDown(self):
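        # Clear every text-encoder dtype override so CLI flags do not leak
        # between tests.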
        for name in (
            "fp8_e4m3fn_text_enc",
            "fp8_e5m2_text_enc",
            "fp16_text_enc",
            "bf16_text_enc",
            "fp32_text_enc",
        ):
            setattr(args, name, False)

    def test_cli_overrides(self):
        args.fp32_text_enc = True
        self.assertEqual(mm.text_encoder_dtype(), torch.float32)
        args.fp32_text_enc = False
        args.bf16_text_enc = True
        self.assertEqual(mm.text_encoder_dtype(), torch.bfloat16)

    def test_default_fp16_for_cpu_device(self):
        for name in (
            "fp8_e4m3fn_text_enc",
            "fp8_e5m2_text_enc",
            "fp16_text_enc",
            "bf16_text_enc",
            "fp32_text_enc",
        ):
            setattr(args, name, False)
        self.assertEqual(mm.text_encoder_dtype(torch.device("cpu")), torch.float16)


class TestIntermediateDtype(unittest.TestCase):
    def tearDown(self):
        args.fp16_intermediates = False

    def test_fp16_intermediates_flag(self):
        args.fp16_intermediates = True
        self.assertEqual(mm.intermediate_dtype(), torch.float16)
        args.fp16_intermediates = False
        self.assertEqual(mm.intermediate_dtype(), torch.float32)


class TestVaeDtype(unittest.TestCase):
    def tearDown(self):
        for name in ("fp16_vae", "bf16_vae", "fp32_vae"):
            setattr(args, name, False)

    def test_cli_overrides(self):
        args.fp16_vae = True
        self.assertEqual(mm.vae_dtype(), torch.float16)
        args.fp16_vae = False
        args.fp32_vae = True
        self.assertEqual(mm.vae_dtype(), torch.float32)


class TestLoraComputeDtype(unittest.TestCase):
    def tearDown(self):
        mm.LORA_COMPUTE_DTYPES.clear()

    def test_caches_per_device(self):
        dev = torch.device("cpu")
        with patch.object(mm, "should_use_fp16", return_value=True):
            d1 = mm.lora_compute_dtype(dev)
        self.assertEqual(d1, torch.float16)
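        # lora_compute_dtype caches its per-device result, so the patched
        # should_use_fp16 below must not change the answer.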
        with patch.object(mm, "should_use_fp16", return_value=False):
            d2 = mm.lora_compute_dtype(dev)
        self.assertEqual(d2, torch.float16)

# class TestGetSupportedFloat8Types(unittest.TestCase):
#     def test_returns_list_of_dtypes(self):
#         types = mm.get_supported_float8_types()
#         self.assertIsInstance(types, list)
#         for t in types:
#             self.assertIsInstance(t, torch.dtype)

class TestLoadedModelEquality(unittest.TestCase):
    def test_eq_same_model_reference(self):
        class Dummy:
            load_device = torch.device("cpu")
            parent = None

        d = Dummy()
        a = mm.LoadedModel(d)
        b = mm.LoadedModel(d)
        self.assertEqual(a, b)


if __name__ == "__main__":
    unittest.main()