Add tiled VAE lane to MultiGPU Work Units
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run

This commit is contained in:
John Pollock 2026-05-22 12:32:30 -05:00
parent 74b0a826ea
commit 4d3d68e473
6 changed files with 366 additions and 21 deletions

View File

@ -182,18 +182,23 @@ def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int):
"""
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
limit_extra_devices = full_extra_devices[:max_gpus - 1]
if len(limit_extra_devices) == 0:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.")
return upscale_model
cloned = copy.copy(upscale_model)
existing = getattr(upscale_model, 'multigpu_clones', None)
clones: dict[torch.device, object] = dict(existing) if existing else {}
limit_extra_device_set = set(limit_extra_devices)
clones: dict[torch.device, object] = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set} if existing else {}
if len(limit_extra_devices) == 0:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.")
if hasattr(cloned, 'multigpu_clones'):
del cloned.multigpu_clones
return cloned
for device in limit_extra_devices:
if device in clones:
continue
clone_desc = copy.deepcopy(upscale_model)
clone_source = copy.copy(upscale_model)
if hasattr(clone_source, 'multigpu_clones'):
del clone_source.multigpu_clones
clone_desc = copy.deepcopy(clone_source)
clone_desc.model.eval()
for p in clone_desc.model.parameters():
p.requires_grad_(False)
@ -205,6 +210,53 @@ def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int):
return cloned
def create_vae_multigpu_deepclones(vae, max_gpus: int):
"""Return a shallow copy of ``vae`` with a ``multigpu_clones`` dict of CPU-resident VAE
deepclones, one per extra CUDA device up to ``max_gpus``.
"""
vae.throw_exception_if_invalid()
vae_device = torch.device(vae.device)
cloned = copy.copy(vae)
if hasattr(cloned, 'multigpu_clones'):
del cloned.multigpu_clones
if vae_device.type == "cpu":
logging.info("CPU VAE selected, skipping initializing MultiGPU VAE clones.")
return cloned
full_extra_devices = comfy.model_management.get_all_torch_devices()
def is_vae_device(device):
return device.type == vae_device.type and device.index == vae_device.index
limit_extra_devices = [d for d in full_extra_devices if not is_vae_device(d)][:max_gpus - 1]
if len(limit_extra_devices) == 0:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU VAE clones.")
return cloned
existing = getattr(vae, 'multigpu_clones', None)
limit_extra_device_set = set(limit_extra_devices)
clones: dict[torch.device, object] = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set} if existing else {}
for device in limit_extra_devices:
if device in clones:
continue
cloned_patcher = vae.patcher.deepclone_multigpu(new_load_device=device)
clone_vae = copy.copy(vae)
if hasattr(clone_vae, 'multigpu_clones'):
del clone_vae.multigpu_clones
clone_vae.first_stage_model = cloned_patcher.model
clone_vae.patcher = cloned_patcher
clone_vae.first_stage_model.eval()
for p in clone_vae.first_stage_model.parameters():
p.requires_grad_(False)
clone_vae.first_stage_model.to("cpu")
clones[device] = clone_vae
logging.info(f"Created CPU VAE deepclone for {device}")
cloned.multigpu_clones = clones
return cloned
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'

View File

@ -972,6 +972,26 @@ class VAE:
pbar = comfy.utils.ProgressBar(steps)
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: decode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
output = self.process_output(
(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar))
/ 3.0)
return output
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
output = self.process_output(
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
@ -981,16 +1001,49 @@ class VAE:
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
if samples.ndim == 3:
memory_shape = samples.shape
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
else:
og_shape = samples.shape
memory_shape = og_shape
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: decode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_decode(memory_shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = clone_decode_fn_factory(c, dev)
return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: decode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@ -1000,6 +1053,25 @@ class VAE:
pbar = comfy.utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: encode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_encode(pixel_samples.shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
samples = comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples /= 3.0
return samples
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@ -1009,6 +1081,7 @@ class VAE:
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
if self.latent_dim == 1:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
out_channels = self.latent_channels
upscale_amount = 1 / self.downscale_ratio
else:
@ -1018,8 +1091,24 @@ class VAE:
overlap = overlap // extra_channel_size
upscale_amount = 1 / self.downscale_ratio
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).reshape(1, out_channels, -1).to(dtype=c.vae_output_dtype()))
multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: encode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = clone_encode_fn_factory(c, dev)
out = comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
else:
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
if self.latent_dim == 1:
return out
else:
@ -1027,6 +1116,21 @@ class VAE:
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: encode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
return comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in, vae_options={}):
@ -1727,8 +1831,14 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
if out[0] is not None:
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
if output_vae and out[2] is not None and hasattr(out[2], "patcher"):
out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options, disable_dynamic))
return out
def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
_, _, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=False, embedding_directory=embedding_directory, output_model=False, model_options=model_options, te_model_options=te_model_options, disable_dynamic=disable_dynamic)
return vae.patcher
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
embedding_directory=embedding_directory,
@ -1954,6 +2064,26 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
return model
def load_vae_patcher(vae_path, metadata=None, device=None):
"""Reload a VAE from disk and return its patcher.
Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so that
:meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a
fresh VAE patcher with no inherited source-device storage tracking. The
optional device matches the source loader's VAE initialization path; the
cloned patcher's load_device still controls the device targeted by the
multigpu clone. Without this, bare ``copy.deepcopy`` of the VAE wrapper
carries dynamic-VRAM allocator state forward to the clone, which causes
per-device worker threads in tiled encode/decode dispatch to access weights
through the source-device buffer."""
if metadata is None:
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
else:
sd = comfy.utils.load_torch_file(vae_path)
vae = VAE(sd=sd, metadata=metadata, device=device)
vae.throw_exception_if_invalid()
return vae.patcher
def load_unet(unet_path, dtype=None):
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
return load_diffusion_model(unet_path, model_options={"dtype": dtype})

View File

@ -1263,9 +1263,7 @@ def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8,
continue
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
all_positions = list(itertools.product(*positions))
split = {devices[i]: all_positions[i::len(devices)] for i in range(len(devices))}
split = {devices[i]: itertools.islice(itertools.product(*positions), i, None, len(devices)) for i in range(len(devices))}
out_shape = [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:])
div_shape = [s.shape[0], 1] + mult_list_upscale(s.shape[2:])
@ -1277,7 +1275,8 @@ def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8,
def worker(device, my_positions):
try:
torch.cuda.set_device(device)
if device.type == "cuda":
torch.cuda.set_device(device)
fn = functions[device]
local_buf = bufs[device]
local_div = divs[device]
@ -1306,17 +1305,24 @@ def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8,
o = local_buf
o_d = local_div
ps_view = ps
mask_view = mask
for d in range(dims):
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
l = min(ps_view.shape[d + 2], o.shape[d + 2] - upscaled[d])
o = o.narrow(d + 2, upscaled[d], l)
o_d = o_d.narrow(d + 2, upscaled[d], l)
if l < ps_view.shape[d + 2]:
ps_view = ps_view.narrow(d + 2, 0, l)
mask_view = mask_view.narrow(d + 2, 0, l)
o.add_(ps * mask)
o_d.add_(mask)
o.add_(ps_view * mask_view)
o_d.add_(mask_view)
if pbar is not None:
with pbar_lock:
pbar.update(1)
torch.cuda.synchronize(device)
if device.type == "cuda":
torch.cuda.synchronize(device)
except BaseException as e:
with worker_lock:
worker_errors.append(e)
@ -1330,7 +1336,7 @@ def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8,
raise worker_errors[0]
combined_buf = sum(bufs.values())
combined_div = sum(divs.values()).clamp_(min=1e-12)
combined_div = sum(divs.values())
output[b:b+1] = combined_buf / combined_div
return output

View File

@ -13,8 +13,8 @@ import comfy.multigpu
class MultiGPUCFGSplitNode(io.ComfyNode):
"""
Attaches per-device deepclones to any connected MODEL and/or UPSCALE_MODEL so downstream
nodes that recognize the attached state dispatch their work across multiple GPUs.
Attaches per-device deepclones to any connected MODEL, UPSCALE_MODEL, and/or VAE so
downstream nodes that recognize the attached state dispatch their work across multiple GPUs.
Place after nodes that modify the model object itself (compile, attention-switch, etc.).
Otherwise position is not order-sensitive.
@ -30,21 +30,25 @@ class MultiGPUCFGSplitNode(io.ComfyNode):
inputs=[
io.Model.Input("model", optional=True),
io.UpscaleModel.Input("upscale_model", optional=True),
io.Vae.Input("vae", optional=True),
io.Int.Input("max_gpus", default=2, min=1, step=1),
],
outputs=[
io.Model.Output(),
io.UpscaleModel.Output(),
io.Vae.Output(),
],
)
@classmethod
def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None) -> io.NodeOutput:
def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None, vae=None) -> io.NodeOutput:
if model is not None:
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
if upscale_model is not None:
upscale_model = comfy.multigpu.create_upscale_model_multigpu_deepclones(upscale_model, max_gpus)
return io.NodeOutput(model, upscale_model)
if vae is not None:
vae = comfy.multigpu.create_vae_multigpu_deepclones(vae, max_gpus)
return io.NodeOutput(model, upscale_model, vae)
class MultiGPUOptionsNode(io.ComfyNode):

View File

@ -869,6 +869,7 @@ class VAELoader:
#TODO: scale factor?
def load_vae(self, vae_name, device="default"):
metadata = None
vae_path = None
if vae_name == "pixel_space":
sd = {}
sd["pixel_space_vae"] = torch.tensor(1.0)
@ -888,6 +889,11 @@ class VAELoader:
resolved = comfy.model_management.resolve_gpu_device_option(device)
vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved)
vae.throw_exception_if_invalid()
# Register a reload factory on the patcher so MultiGPU work-units can use
# ModelPatcher.deepclone_multigpu to produce per-device clones from the
# same loader context (mirrors UNETLoader / CLIPLoader / checkpoint loader).
if vae_path is not None:
vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, resolved))
return (vae,)
class ControlNetLoader:

View File

@ -0,0 +1,147 @@
import importlib
import sys
import types
import torch
import comfy.utils
def install_fake_comfy_aimdo(monkeypatch):
package = types.ModuleType("comfy_aimdo")
package.__path__ = []
monkeypatch.setitem(sys.modules, "comfy_aimdo", package)
for name in ("vram_buffer", "host_buffer", "torch", "model_vbar", "model_mmap", "control"):
module = types.ModuleType(f"comfy_aimdo.{name}")
monkeypatch.setitem(sys.modules, f"comfy_aimdo.{name}", module)
setattr(package, name, module)
def test_tiled_scale_multidim_multigpu_clips_edge_tiles(monkeypatch):
monkeypatch.setattr(torch.cuda, "set_device", lambda device: None)
monkeypatch.setattr(torch.cuda, "synchronize", lambda device: None)
scale = 1.1
def upscale(a):
return torch.ones((a.shape[0], 1, round(a.shape[-1] * scale)), dtype=a.dtype, device=a.device)
samples = torch.ones((1, 1, 11))
devices = [torch.device("cpu:0"), torch.device("cpu:1")]
actual = comfy.utils.tiled_scale_multidim_multigpu(
samples,
{device: upscale for device in devices},
tile=(5,),
overlap=2,
upscale_amount=scale,
out_channels=1,
output_device="cpu",
)
expected = comfy.utils.tiled_scale_multidim(
samples,
upscale,
tile=(5,),
overlap=2,
upscale_amount=scale,
out_channels=1,
output_device="cpu",
)
assert actual.shape == expected.shape == (1, 1, 12)
torch.testing.assert_close(actual, expected)
def test_upscale_model_deepclone_does_not_copy_existing_clone_graph(monkeypatch):
class FakeModel:
def __init__(self):
self.param = torch.nn.Parameter(torch.ones(1))
def eval(self):
return self
def parameters(self):
return [self.param]
class FakeDescriptor:
def __init__(self):
self.model = FakeModel()
self.device = None
def to(self, device):
self.device = device
return self
first_device = torch.device("cpu:0")
second_device = torch.device("cpu:1")
stale_device = torch.device("cpu:2")
existing_clone = FakeDescriptor()
stale_clone = FakeDescriptor()
source = FakeDescriptor()
source.multigpu_clones = {first_device: existing_clone, stale_device: stale_clone}
fake_model_management = types.ModuleType("comfy.model_management")
fake_model_management.get_all_torch_devices = lambda exclude_current=True: [first_device, second_device]
monkeypatch.setitem(sys.modules, "comfy.model_management", fake_model_management)
import comfy
monkeypatch.setattr(comfy, "model_management", fake_model_management, raising=False)
import comfy.multigpu
importlib.reload(comfy.multigpu)
cloned = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=3)
assert cloned is not source
assert cloned.multigpu_clones[first_device] is existing_clone
assert stale_device not in cloned.multigpu_clones
assert second_device in cloned.multigpu_clones
assert not hasattr(cloned.multigpu_clones[second_device], "multigpu_clones")
assert cloned.multigpu_clones[second_device].device == "cpu"
assert not cloned.multigpu_clones[second_device].model.param.requires_grad
single_gpu_clone = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=1)
assert single_gpu_clone is not source
assert not hasattr(single_gpu_clone, "multigpu_clones")
def test_checkpoint_loader_registers_vae_cached_patcher(monkeypatch):
install_fake_comfy_aimdo(monkeypatch)
import comfy.sd
importlib.reload(comfy.sd)
class FakeVAE:
def __init__(self):
self.patcher = types.SimpleNamespace(cached_patcher_init=None)
model_patcher = types.SimpleNamespace(cached_patcher_init=None)
vae = FakeVAE()
metadata = {"format": "checkpoint"}
monkeypatch.setattr(comfy.utils, "load_torch_file", lambda path, return_metadata=False: ({}, metadata))
monkeypatch.setattr(
comfy.sd,
"load_state_dict_guess_config",
lambda *args, **kwargs: (model_patcher, None, vae, None),
)
comfy.sd.load_checkpoint_guess_config("checkpoint.safetensors", output_vae=True)
assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config
assert vae.patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_vae_patcher
assert vae.patcher.cached_patcher_init[1][0] == "checkpoint.safetensors"
def test_checkpoint_loader_skips_cached_patcher_for_placeholder_vae(monkeypatch):
install_fake_comfy_aimdo(monkeypatch)
import comfy.sd
importlib.reload(comfy.sd)
model_patcher = types.SimpleNamespace(cached_patcher_init=None)
placeholder_vae = types.SimpleNamespace()
metadata = {"format": "checkpoint"}
monkeypatch.setattr(comfy.utils, "load_torch_file", lambda path, return_metadata=False: ({}, metadata))
monkeypatch.setattr(
comfy.sd,
"load_state_dict_guess_config",
lambda *args, **kwargs: (model_patcher, None, placeholder_vae, None),
)
assert comfy.sd.load_checkpoint_guess_config("diffusion_only.safetensors", output_vae=True)[2] is placeholder_vae
assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config