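"""Tests for streaming (lazy) safetensors loading in ComfyUI.

Covers lazy metadata access, single-tensor reads, non-materializing
state-dict views, RSS-bounded file opening, GPUDirect Storage (GDS)
error handling, and on-demand materialization of disk-backed weights
via comfy.disk_weights.
"""
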
import os
import importlib
import importlib.util

import pytest

torch = pytest.importorskip("torch")


def _write_safetensors(tmp_path, tensors):
    """Serialize ``tensors`` to a .safetensors file under ``tmp_path`` and return its path."""
    import safetensors.torch

    path = os.path.join(tmp_path, "test.safetensors")
    safetensors.torch.save_file(tensors, path)
    return path
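

# The tests below each wrap ``sd._file.read_tensor`` with an identical
# call-recording spy. A shared helper along these lines (a suggested sketch,
# not used by the tests as written, which keep their inline copies) could
# deduplicate that pattern:
def _spy_read_tensor(monkeypatch, sd):
    """Patch ``sd._file.read_tensor`` and return the list of recorded metas."""
    calls = []
    original = sd._file.read_tensor

    def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
        calls.append(meta)
        return original(meta, device, dtype, allow_gds, pin_if_cpu)

    monkeypatch.setattr(sd._file, "read_tensor", wrapped)
    return calls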


def test_stream_state_dict_meta_is_lazy(tmp_path, monkeypatch):
    """meta() exposes shape/dtype/numel without triggering a tensor read."""
    import comfy.utils

    path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32), "b": torch.ones((4,), dtype=torch.float16)})
    sd = comfy.utils.load_torch_file(path, safe_load=True)
    calls = []

    original = sd._file.read_tensor

    def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
        calls.append(meta)
        return original(meta, device, dtype, allow_gds, pin_if_cpu)

    monkeypatch.setattr(sd._file, "read_tensor", wrapped)
    meta = sd.meta("a")
    assert meta.shape == (2, 3)
    assert meta.dtype == torch.float32
    assert meta.numel == 6
    assert calls == []


def test_stream_state_dict_getitem_loads_single_tensor(tmp_path, monkeypatch):
    """Indexing the state dict reads exactly the one requested tensor."""
    import comfy.utils

    path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32), "b": torch.ones((4,), dtype=torch.float16)})
    sd = comfy.utils.load_torch_file(path, safe_load=True)
    calls = []

    original = sd._file.read_tensor

    def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
        calls.append(meta)
        return original(meta, device, dtype, allow_gds, pin_if_cpu)

    monkeypatch.setattr(sd._file, "read_tensor", wrapped)
    _ = sd["a"]
    assert len(calls) == 1
    assert calls[0].shape == (2, 3)


def test_stream_views_do_not_materialize(tmp_path, monkeypatch):
    """Prefix-replaced views can be enumerated without reading tensor data."""
    import comfy.utils

    path = _write_safetensors(tmp_path, {"prefix.a": torch.zeros((2, 3)), "other": torch.ones((4,))})
    sd = comfy.utils.load_torch_file(path, safe_load=True)
    calls = []

    original = sd._file.read_tensor

    def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
        calls.append(meta)
        return original(meta, device, dtype, allow_gds, pin_if_cpu)

    monkeypatch.setattr(sd._file, "read_tensor", wrapped)
    view = comfy.utils.state_dict_prefix_replace(sd, {"prefix.": ""}, filter_keys=True)
    _ = list(view.keys())
    assert calls == []


def test_stream_load_rss_small(tmp_path):
    """Opening a file must not grow RSS by anywhere near the tensor payload size."""
    import comfy.utils

    psutil = pytest.importorskip("psutil")
    process = psutil.Process()
    size_elems = 4_000_000  # ~16MB float32
    tensor = torch.zeros((size_elems,), dtype=torch.float32)
    path = _write_safetensors(tmp_path, {"big": tensor})
    rss_before = process.memory_info().rss
    sd = comfy.utils.load_torch_file(path, safe_load=True)
    rss_after = process.memory_info().rss
    expected_size = tensor.numel() * tensor.element_size()
    assert (rss_after - rss_before) < expected_size
    _ = sd.meta("big")
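

# GPUDirect Storage (GDS) lets cuFile-enabled readers DMA file contents
# directly into GPU memory. The next test probes availability via the
# fastsafetensors cpp bindings and asserts that requesting GDS without
# support raises instead of silently falling back.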


def test_gds_path_errors_without_support(tmp_path, monkeypatch):
    """allow_gds=True raises without GDS support and avoids the nogds fallback with it."""
    import comfy.utils

    path = _write_safetensors(tmp_path, {"a": torch.zeros((2, 3), dtype=torch.float32)})
    sd = comfy.utils.load_torch_file(path, safe_load=True)
    device = torch.device("cuda")

    if importlib.util.find_spec("fastsafetensors") is None:
        fst = None
    else:
        fst = importlib.import_module("fastsafetensors")

    gds_available = False
    if fst is not None and torch.cuda.is_available():
        gds_supported = fst.cpp.is_gds_supported(torch.cuda.current_device())
        gds_available = bool(fst.cpp.is_cufile_found()) and gds_supported == 1

    if not gds_available:
        with pytest.raises(RuntimeError, match="GPUDirect requested"):
            sd.get_tensor("a", device=device, allow_gds=True)
    else:
        def fail_nogds(*args, **kwargs):
            raise AssertionError("nogds path used during GDS request")

        monkeypatch.setattr(sd._file, "_read_tensor_nogds", fail_nogds)
        t = sd.get_tensor("a", device=device, allow_gds=True)
        assert t.device.type == "cuda"


def test_stream_load_without_disk_cache_keeps_cpu_weights(tmp_path):
    """With the disk-weight cache sized to zero, loaded weights stay real CPU tensors."""
    import comfy.utils
    import comfy.disk_weights

    prev_cache = comfy.disk_weights.CACHE.max_bytes
    prev_gds = comfy.disk_weights.ALLOW_GDS
    prev_pin = comfy.disk_weights.PIN_IF_CPU
    prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
    comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)

    try:
        path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.float32), "bias": torch.zeros((4,), dtype=torch.float32)})
        sd = comfy.utils.load_torch_file(path, safe_load=True)
        model = torch.nn.Linear(4, 4, bias=True)
        comfy.utils.load_state_dict(model, sd, strict=False)
        assert model.weight.device.type == "cpu"
        assert model.weight.device.type != "meta"
    finally:
        comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)
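

# The three disk-weights tests share the same save/configure/restore dance
# around comfy.disk_weights globals. A fixture along these lines (a suggested
# sketch; the tests as written keep their inline try/finally blocks) could
# encapsulate it:
@pytest.fixture
def zero_cache_disk_weights():
    import comfy.disk_weights

    prev = (comfy.disk_weights.CACHE.max_bytes, comfy.disk_weights.ALLOW_GDS,
            comfy.disk_weights.PIN_IF_CPU, comfy.disk_weights.DISK_WEIGHTS_ENABLED)
    comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)
    try:
        yield
    finally:
        comfy.disk_weights.configure(prev[0], allow_gds=prev[1], pin_if_cpu=prev[2], enabled=prev[3])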


def test_lazy_disk_weights_loads_on_demand(tmp_path, monkeypatch):
    """load_state_dict leaves params on meta; materialization reads each tensor once."""
    if importlib.util.find_spec("fastsafetensors") is None:
        pytest.skip("fastsafetensors not installed")
    import comfy.utils
    import comfy.disk_weights

    prev_cache = comfy.disk_weights.CACHE.max_bytes
    prev_gds = comfy.disk_weights.ALLOW_GDS
    prev_pin = comfy.disk_weights.PIN_IF_CPU
    prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
    comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)

    try:
        path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.float32), "bias": torch.zeros((4,), dtype=torch.float32)})
        sd = comfy.utils.load_torch_file(path, safe_load=True)
        model = torch.nn.Linear(4, 4, bias=True)
        calls = []

        original = sd._file.read_tensor

        def wrapped(meta, device, dtype, allow_gds, pin_if_cpu):
            calls.append(meta)
            return original(meta, device, dtype, allow_gds, pin_if_cpu)

        monkeypatch.setattr(sd._file, "read_tensor", wrapped)
        comfy.utils.load_state_dict(model, sd, strict=True)
        assert model.weight.device.type == "meta"
        assert calls == []

        comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu"))
        assert model.weight.device.type == "cpu"
        assert len(calls) == 2  # weight and bias each read exactly once
    finally:
        comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)


def test_lazy_disk_weights_respects_dtype_override(tmp_path):
    """Materialization keeps the stored dtype by default and honors dtype_override."""
    if importlib.util.find_spec("fastsafetensors") is None:
        pytest.skip("fastsafetensors not installed")
    import comfy.utils
    import comfy.disk_weights

    prev_cache = comfy.disk_weights.CACHE.max_bytes
    prev_gds = comfy.disk_weights.ALLOW_GDS
    prev_pin = comfy.disk_weights.PIN_IF_CPU
    prev_enabled = comfy.disk_weights.DISK_WEIGHTS_ENABLED
    comfy.disk_weights.configure(0, allow_gds=False, pin_if_cpu=False, enabled=True)

    try:
        path = _write_safetensors(tmp_path, {"weight": torch.zeros((4, 4), dtype=torch.bfloat16), "bias": torch.zeros((4,), dtype=torch.bfloat16)})
        sd = comfy.utils.load_torch_file(path, safe_load=True)
        model = torch.nn.Linear(4, 4, bias=True)
        comfy.utils.load_state_dict(model, sd, strict=True)
        assert model.weight.device.type == "meta"

        comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu"))
        assert model.weight.dtype == torch.bfloat16

        comfy.disk_weights.ensure_module_materialized(model, torch.device("cpu"), dtype_override=torch.float16)
        assert model.weight.dtype == torch.float16
    finally:
        comfy.disk_weights.configure(prev_cache, allow_gds=prev_gds, pin_if_cpu=prev_pin, enabled=prev_enabled)