ComfyUI/tests-unit/comfy_test/multigpu_test.py
John Pollock 4d3d68e473
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Add tiled VAE lane to MultiGPU Work Units
2026-05-22 13:42:21 -05:00

148 lines
5.2 KiB
Python

import importlib
import sys
import types
import torch
import comfy.utils
def install_fake_comfy_aimdo(monkeypatch):
    """Register an empty stand-in ``comfy_aimdo`` package in ``sys.modules``.

    The bare package module is registered first. Then one empty submodule per
    name in ``children`` is registered under its dotted name. Every entry is
    written through ``monkeypatch.setitem``, presumably pytest's fixture, which
    restores ``sys.modules`` at teardown. Each submodule is also set as an
    attribute of the package, so both ``import comfy_aimdo.<name>`` and
    ``from comfy_aimdo import <name>`` find the fakes.

    Args:
        monkeypatch: object providing ``setitem(mapping, key, value)``.
    """
    root = "comfy_aimdo"
    fake_package = types.ModuleType(root)
    # Having a ``__path__`` attribute is what makes Python treat a module as a package.
    fake_package.__path__ = []
    monkeypatch.setitem(sys.modules, root, fake_package)

    children = ("vram_buffer", "host_buffer", "torch", "model_vbar", "model_mmap", "control")
    for child in children:
        dotted = f"{root}.{child}"
        fake_child = types.ModuleType(dotted)
        monkeypatch.setitem(sys.modules, dotted, fake_child)
        setattr(fake_package, child, fake_child)
def test_tiled_scale_multidim_multigpu_clips_edge_tiles(monkeypatch):
    """``tiled_scale_multidim_multigpu`` must agree with ``tiled_scale_multidim``.

    An 11-wide input is tiled with 5-wide tiles (overlap 2) and scaled by the
    non-integer factor 1.1. The fake upscaler rounds each scaled width, so
    tile outputs do not scale exactly; per the test name, this exercises
    clipping of edge tiles. The multi-device and single-device paths must
    return the same ``(1, 1, 12)`` tensor.
    """
    # Only CPU "devices" are used here, so neutralise any CUDA calls the
    # multi-device path may make. The stubs follow the real signatures
    # (``set_device(device)``, ``synchronize(device=None)``) so a bare
    # ``torch.cuda.synchronize()`` call does not fail with a TypeError.
    monkeypatch.setattr(torch.cuda, "set_device", lambda device: None)
    monkeypatch.setattr(torch.cuda, "synchronize", lambda device=None: None)

    scale = 1.1

    def upscale(a):
        # Fake upscaler: a ones tensor whose last dim is scaled (and rounded) by ``scale``.
        return torch.ones((a.shape[0], 1, round(a.shape[-1] * scale)), dtype=a.dtype, device=a.device)

    samples = torch.ones((1, 1, 11))
    devices = [torch.device("cpu:0"), torch.device("cpu:1")]
    # Identical tiling settings for both code paths so their outputs are comparable.
    tiling = {
        "tile": (5,),
        "overlap": 2,
        "upscale_amount": scale,
        "out_channels": 1,
        "output_device": "cpu",
    }

    actual = comfy.utils.tiled_scale_multidim_multigpu(samples, dict.fromkeys(devices, upscale), **tiling)
    expected = comfy.utils.tiled_scale_multidim(samples, upscale, **tiling)

    # round(11 * 1.1) == 12
    assert actual.shape == expected.shape == (1, 1, 12)
    torch.testing.assert_close(actual, expected)
def test_upscale_model_deepclone_does_not_copy_existing_clone_graph(monkeypatch):
    """Check how ``create_upscale_model_multigpu_deepclones`` reuses, drops and adds clones.

    The source descriptor already has a ``multigpu_clones`` map with two
    entries. One is for ``cpu:0``, a device the fake device list still reports.
    The other is for ``cpu:2``, which it no longer reports. With ``max_gpus=3``
    the returned descriptor must be a new object whose clone map:

    * keeps the existing ``cpu:0`` clone by identity;
    * drops the stale ``cpu:2`` entry;
    * adds a clone for ``cpu:1``.

    That new ``cpu:1`` clone must carry no nested ``multigpu_clones`` map. Its
    recorded device must be ``"cpu"``, and its parameter must not require grad.
    With ``max_gpus=1`` the returned descriptor has no clone map at all.
    """
    class FakeModel:
        # Minimal model stand-in: a single parameter plus eval()/parameters().
        def __init__(self):
            self.param = torch.nn.Parameter(torch.ones(1))

        def eval(self):
            return self

        def parameters(self):
            return [self.param]

    class FakeDescriptor:
        # Upscale-model descriptor stand-in; to() records its argument in
        # ``self.device`` so the test can see where a clone was moved.
        def __init__(self):
            self.model = FakeModel()
            self.device = None

        def to(self, device):
            self.device = device
            return self

    first_device = torch.device("cpu:0")
    second_device = torch.device("cpu:1")
    stale_device = torch.device("cpu:2")
    existing_clone = FakeDescriptor()
    stale_clone = FakeDescriptor()
    source = FakeDescriptor()
    # Pre-existing clone map: one entry for a device that will still be reported
    # (first_device) and one for a device that will not (stale_device).
    source.multigpu_clones = {first_device: existing_clone, stale_device: stale_clone}
    # Fake device enumeration that reports only cpu:0 and cpu:1.
    fake_model_management = types.ModuleType("comfy.model_management")
    fake_model_management.get_all_torch_devices = lambda exclude_current=True: [first_device, second_device]
    # Install the fake both in sys.modules and as the ``model_management``
    # attribute of the comfy package, then reload comfy.multigpu so it is
    # re-imported against the fake.
    monkeypatch.setitem(sys.modules, "comfy.model_management", fake_model_management)
    import comfy
    monkeypatch.setattr(comfy, "model_management", fake_model_management, raising=False)
    import comfy.multigpu
    # NOTE(review): monkeypatch restores comfy.model_management after the test,
    # but comfy.multigpu stays reloaded. If it binds model_management at import
    # time, later tests may still see the fake -- confirm.
    importlib.reload(comfy.multigpu)
    cloned = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=3)
    # A new descriptor is returned, not the source itself.
    assert cloned is not source
    # The clone already present for a still-reported device is reused by identity.
    assert cloned.multigpu_clones[first_device] is existing_clone
    # Clones for devices that are no longer reported are dropped.
    assert stale_device not in cloned.multigpu_clones
    # A clone is created for the newly reported device ...
    assert second_device in cloned.multigpu_clones
    # ... without the source's own clone map being copied into it ...
    assert not hasattr(cloned.multigpu_clones[second_device], "multigpu_clones")
    # ... with "cpu" as its recorded device ...
    assert cloned.multigpu_clones[second_device].device == "cpu"
    # ... and with its parameter no longer requiring grad.
    assert not cloned.multigpu_clones[second_device].model.param.requires_grad
    # With only one GPU allowed, a new descriptor without any clone map is returned.
    single_gpu_clone = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=1)
    assert single_gpu_clone is not source
    assert not hasattr(single_gpu_clone, "multigpu_clones")
def test_checkpoint_loader_registers_vae_cached_patcher(monkeypatch):
    """``load_checkpoint_guess_config`` records cached-patcher factories for model and VAE.

    ``comfy.utils.load_torch_file`` and ``comfy.sd.load_state_dict_guess_config``
    are stubbed, so no checkpoint is read. After loading, the model patcher's
    ``cached_patcher_init`` must start with ``load_checkpoint_guess_config``
    itself. The VAE patcher's must start with ``load_checkpoint_vae_patcher``,
    followed by arguments whose first item is the checkpoint path.
    """
    # comfy.sd is reloaded against a fake comfy_aimdo package.
    install_fake_comfy_aimdo(monkeypatch)
    import comfy.sd
    importlib.reload(comfy.sd)

    ckpt_path = "checkpoint.safetensors"
    model_stub = types.SimpleNamespace(cached_patcher_init=None)
    # Only the ``patcher`` attribute of the VAE is touched by the assertions.
    vae_stub = types.SimpleNamespace(patcher=types.SimpleNamespace(cached_patcher_init=None))
    header = {"format": "checkpoint"}

    def fake_load_torch_file(path, return_metadata=False):
        return {}, header

    def fake_guess_config(*args, **kwargs):
        return model_stub, None, vae_stub, None

    monkeypatch.setattr(comfy.utils, "load_torch_file", fake_load_torch_file)
    monkeypatch.setattr(comfy.sd, "load_state_dict_guess_config", fake_guess_config)

    comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True)

    assert model_stub.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config
    vae_init = vae_stub.patcher.cached_patcher_init
    assert vae_init[0] is comfy.sd.load_checkpoint_vae_patcher
    assert vae_init[1][0] == ckpt_path
def test_checkpoint_loader_skips_cached_patcher_for_placeholder_vae(monkeypatch):
    """A VAE object lacking ``patcher`` must not break checkpoint loading.

    ``load_state_dict_guess_config`` is stubbed to return a bare namespace as
    the VAE. The loader must return that same object in slot 2 of its result.
    It must also still register the model patcher's ``cached_patcher_init``
    with ``load_checkpoint_guess_config`` as the first item.
    """
    # comfy.sd is reloaded against a fake comfy_aimdo package.
    install_fake_comfy_aimdo(monkeypatch)
    import comfy.sd
    importlib.reload(comfy.sd)

    model_stub = types.SimpleNamespace(cached_patcher_init=None)
    # Empty namespace: deliberately has no ``patcher`` attribute.
    vae_stub = types.SimpleNamespace()
    header = {"format": "checkpoint"}

    def fake_load_torch_file(path, return_metadata=False):
        return {}, header

    def fake_guess_config(*args, **kwargs):
        return model_stub, None, vae_stub, None

    monkeypatch.setattr(comfy.utils, "load_torch_file", fake_load_torch_file)
    monkeypatch.setattr(comfy.sd, "load_state_dict_guess_config", fake_guess_config)

    outputs = comfy.sd.load_checkpoint_guess_config("diffusion_only.safetensors", output_vae=True)

    assert outputs[2] is vae_stub
    assert model_stub.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config