Defer @pollockjj's tiled-VAE and UPSCALE_MODEL MultiGPU lanes (#14066)

* Revert "Add tiled VAE lane to MultiGPU Work Units"

This reverts commit 4d3d68e473.

The tiled VAE lane will land as part of a follow-up PR alongside the
UPSCALE_MODEL lane, separated from the threaded-loader fix PR (#14052)
to keep the upstream merge focused.

* Revert "Add UPSCALE_MODEL lane to MultiGPU CFG Split"

This reverts commit 74b0a826ea.

The UPSCALE_MODEL lane will land as part of a follow-up PR alongside the
tiled VAE lane, separated from the threaded-loader fix PR (#14052) to
keep the upstream merge focused.

---------

Co-authored-by: John Pollock <pollockjj@gmail.com>
This commit is contained in:
Jedrzej Kosinski 2026-05-22 16:44:29 -07:00 committed by GitHub
parent cb83c41db7
commit 5dc4e38b89
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 12 additions and 564 deletions

View File

@ -1,5 +1,4 @@
from __future__ import annotations from __future__ import annotations
import copy
import queue import queue
import threading import threading
import torch import torch
@ -176,87 +175,6 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options:
return model return model
def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int):
"""Return a shallow copy of ``upscale_model`` with a ``multigpu_clones`` dict of CPU-resident
descriptor deepclones, one per extra CUDA device up to ``max_gpus``.
"""
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
limit_extra_devices = full_extra_devices[:max_gpus - 1]
cloned = copy.copy(upscale_model)
existing = getattr(upscale_model, 'multigpu_clones', None)
limit_extra_device_set = set(limit_extra_devices)
clones: dict[torch.device, object] = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set} if existing else {}
if len(limit_extra_devices) == 0:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.")
if hasattr(cloned, 'multigpu_clones'):
del cloned.multigpu_clones
return cloned
for device in limit_extra_devices:
if device in clones:
continue
clone_source = copy.copy(upscale_model)
if hasattr(clone_source, 'multigpu_clones'):
del clone_source.multigpu_clones
clone_desc = copy.deepcopy(clone_source)
clone_desc.model.eval()
for p in clone_desc.model.parameters():
p.requires_grad_(False)
clone_desc.to("cpu")
clones[device] = clone_desc
logging.info(f"Created CPU upscale_model descriptor deepclone for {device}")
cloned.multigpu_clones = clones
return cloned
def create_vae_multigpu_deepclones(vae, max_gpus: int):
"""Return a shallow copy of ``vae`` with a ``multigpu_clones`` dict of CPU-resident VAE
deepclones, one per extra CUDA device up to ``max_gpus``.
"""
vae.throw_exception_if_invalid()
vae_device = torch.device(vae.device)
cloned = copy.copy(vae)
if hasattr(cloned, 'multigpu_clones'):
del cloned.multigpu_clones
if vae_device.type == "cpu":
logging.info("CPU VAE selected, skipping initializing MultiGPU VAE clones.")
return cloned
full_extra_devices = comfy.model_management.get_all_torch_devices()
def is_vae_device(device):
return device.type == vae_device.type and device.index == vae_device.index
limit_extra_devices = [d for d in full_extra_devices if not is_vae_device(d)][:max_gpus - 1]
if len(limit_extra_devices) == 0:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU VAE clones.")
return cloned
existing = getattr(vae, 'multigpu_clones', None)
limit_extra_device_set = set(limit_extra_devices)
clones: dict[torch.device, object] = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set} if existing else {}
for device in limit_extra_devices:
if device in clones:
continue
cloned_patcher = vae.patcher.deepclone_multigpu(new_load_device=device)
clone_vae = copy.copy(vae)
if hasattr(clone_vae, 'multigpu_clones'):
del clone_vae.multigpu_clones
clone_vae.first_stage_model = cloned_patcher.model
clone_vae.patcher = cloned_patcher
clone_vae.first_stage_model.eval()
for p in clone_vae.first_stage_model.parameters():
p.requires_grad_(False)
clone_vae.first_stage_model.to("cpu")
clones[device] = clone_vae
logging.info(f"Created CPU VAE deepclone for {device}")
cloned.multigpu_clones = clones
return cloned
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time']) LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None): def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.' 'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'

View File

@ -972,26 +972,6 @@ class VAE:
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: decode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
output = self.process_output(
(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) +
comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar))
/ 3.0)
return output
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
output = self.process_output( output = self.process_output(
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
@ -1001,49 +981,16 @@ class VAE:
def decode_tiled_1d(self, samples, tile_x=256, overlap=32): def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
if samples.ndim == 3: if samples.ndim == 3:
memory_shape = samples.shape
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
else: else:
og_shape = samples.shape og_shape = samples.shape
memory_shape = og_shape
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1)) samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: decode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_decode(memory_shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = clone_decode_fn_factory(c, dev)
return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)): def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: decode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@ -1053,25 +1000,6 @@ class VAE:
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: encode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_encode(pixel_samples.shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
samples = comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples /= 3.0
return samples
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@ -1081,7 +1009,6 @@ class VAE:
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048): def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
if self.latent_dim == 1: if self.latent_dim == 1:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype()))
out_channels = self.latent_channels out_channels = self.latent_channels
upscale_amount = 1 / self.downscale_ratio upscale_amount = 1 / self.downscale_ratio
else: else:
@ -1091,24 +1018,8 @@ class VAE:
overlap = overlap // extra_channel_size overlap = overlap // extra_channel_size
upscale_amount = 1 / self.downscale_ratio upscale_amount = 1 / self.downscale_ratio
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype()) encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).reshape(1, out_channels, -1).to(dtype=c.vae_output_dtype()))
multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: encode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = clone_encode_fn_factory(c, dev)
out = comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
else:
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
if self.latent_dim == 1: if self.latent_dim == 1:
return out return out
else: else:
@ -1116,21 +1027,6 @@ class VAE:
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)): def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
multigpu_clones = getattr(self, 'multigpu_clones', None)
if multigpu_clones:
functions = {self.device: encode_fn}
try:
for dev, c in multigpu_clones.items():
model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev)
c.first_stage_model.to(dev)
for dev, c in multigpu_clones.items():
functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype())
return comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
finally:
for c in multigpu_clones.values():
c.first_stage_model.to("cpu")
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in, vae_options={}): def decode(self, samples_in, vae_options={}):
@ -1831,14 +1727,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
if out[0] is not None: if out[0] is not None:
out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0) out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
if output_vae and out[2] is not None and hasattr(out[2], "patcher"):
out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options, disable_dynamic))
return out return out
def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
_, _, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=False, embedding_directory=embedding_directory, output_model=False, model_options=model_options, te_model_options=te_model_options, disable_dynamic=disable_dynamic)
return vae.patcher
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False, model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
embedding_directory=embedding_directory, embedding_directory=embedding_directory,
@ -2064,26 +1954,6 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options)) model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
return model return model
def load_vae_patcher(vae_path, metadata=None, device=None):
"""Reload a VAE from disk and return its patcher.
Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so that
:meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a
fresh VAE patcher with no inherited source-device storage tracking. The
optional device matches the source loader's VAE initialization path; the
cloned patcher's load_device still controls the device targeted by the
multigpu clone. Without this, bare ``copy.deepcopy`` of the VAE wrapper
carries dynamic-VRAM allocator state forward to the clone, which causes
per-device worker threads in tiled encode/decode dispatch to access weights
through the source-device buffer."""
if metadata is None:
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
else:
sd = comfy.utils.load_torch_file(vae_path)
vae = VAE(sd=sd, metadata=metadata, device=device)
vae.throw_exception_if_invalid()
return vae.patcher
def load_unet(unet_path, dtype=None): def load_unet(unet_path, dtype=None):
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model") logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
return load_diffusion_model(unet_path, model_options={"dtype": dtype}) return load_diffusion_model(unet_path, model_options={"dtype": dtype})

View File

@ -28,13 +28,13 @@ import numpy as np
from PIL import Image from PIL import Image
import logging import logging
import itertools import itertools
import threading
from torch.nn.functional import interpolate from torch.nn.functional import interpolate
from tqdm.auto import trange from tqdm.auto import trange
from einops import rearrange from einops import rearrange
from comfy.cli_args import args from comfy.cli_args import args
import json import json
import time import time
import threading
import warnings import warnings
MMAP_TORCH_FILES = args.mmap_torch_files MMAP_TORCH_FILES = args.mmap_torch_files
@ -1187,161 +1187,6 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
"""Multigpu variant of tiled_scale_multidim. ``functions`` is a dict[torch.device, callable].
Round-robin dispatches tile positions across devices via threading. Each thread maintains
its own per-device CPU output and divisor buffer, applying the same feathered overlap mask
formula as the single-device path. Buffers are summed at the end, producing output that is
bit-equivalent to ``tiled_scale_multidim`` within fp32 add-order noise.
Falls back to ``tiled_scale_multidim`` with the only function when ``len(functions) < 2``.
Falls back to single-device on the "whole input fits in one tile" branch (no parallelism
available at that granularity).
"""
devices = list(functions.keys())
if len(devices) < 2:
only_fn = next(iter(functions.values())) if functions else None
return tiled_scale_multidim(samples, only_fn, tile=tile, overlap=overlap,
upscale_amount=upscale_amount, out_channels=out_channels,
output_device=output_device, downscale=downscale,
index_formulas=index_formulas, pbar=pbar)
dims = len(tile)
if not (isinstance(upscale_amount, (tuple, list))):
upscale_amount = [upscale_amount] * dims
if not (isinstance(overlap, (tuple, list))):
overlap = [overlap] * dims
if index_formulas is None:
index_formulas = upscale_amount
if not (isinstance(index_formulas, (tuple, list))):
index_formulas = [index_formulas] * dims
def get_upscale(dim, val):
up = upscale_amount[dim]
return up(val) if callable(up) else up * val
def get_downscale(dim, val):
up = upscale_amount[dim]
return up(val) if callable(up) else val / up
def get_upscale_pos(dim, val):
up = index_formulas[dim]
return up(val) if callable(up) else up * val
def get_downscale_pos(dim, val):
up = index_formulas[dim]
return up(val) if callable(up) else val / up
if downscale:
get_scale = get_downscale
get_pos = get_downscale_pos
else:
get_scale = get_upscale
get_pos = get_upscale_pos
def mult_list_upscale(a):
return [round(get_scale(i, a[i])) for i in range(len(a))]
output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
merge_device = torch.device("cpu")
pbar_lock = threading.Lock() if pbar is not None else None
primary_device = devices[0]
samples_staged = samples if samples.device.type == "cpu" else samples.to("cpu", non_blocking=False)
for b in range(samples_staged.shape[0]):
s = samples_staged[b:b+1]
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
with torch.inference_mode():
output[b:b+1] = functions[primary_device](s.to(primary_device, non_blocking=True)).to(output_device)
if pbar is not None:
pbar.update(1)
continue
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
split = {devices[i]: itertools.islice(itertools.product(*positions), i, None, len(devices)) for i in range(len(devices))}
out_shape = [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:])
div_shape = [s.shape[0], 1] + mult_list_upscale(s.shape[2:])
bufs = {d: torch.zeros(out_shape, device=merge_device) for d in devices}
divs = {d: torch.zeros(div_shape, device=merge_device) for d in devices}
worker_errors: list[BaseException] = []
worker_lock = threading.Lock()
def worker(device, my_positions):
try:
if device.type == "cuda":
torch.cuda.set_device(device)
fn = functions[device]
local_buf = bufs[device]
local_div = divs[device]
with torch.inference_mode():
for it in my_positions:
s_in = s
upscaled = []
for d in range(dims):
pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
l = min(tile[d], s.shape[d + 2] - pos)
s_in = s_in.narrow(d + 2, pos, l)
upscaled.append(round(get_pos(d, pos)))
s_in_dev = s_in.to(device, non_blocking=True)
ps = fn(s_in_dev).to(merge_device)
mask = torch.ones([1, 1] + list(ps.shape[2:]), device=merge_device)
for d in range(2, dims + 2):
feather = round(get_scale(d - 2, overlap[d - 2]))
if feather >= mask.shape[d]:
continue
for t in range(feather):
a = (t + 1) / feather
mask.narrow(d, t, 1).mul_(a)
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
o = local_buf
o_d = local_div
ps_view = ps
mask_view = mask
for d in range(dims):
l = min(ps_view.shape[d + 2], o.shape[d + 2] - upscaled[d])
o = o.narrow(d + 2, upscaled[d], l)
o_d = o_d.narrow(d + 2, upscaled[d], l)
if l < ps_view.shape[d + 2]:
ps_view = ps_view.narrow(d + 2, 0, l)
mask_view = mask_view.narrow(d + 2, 0, l)
o.add_(ps_view * mask_view)
o_d.add_(mask_view)
if pbar is not None:
with pbar_lock:
pbar.update(1)
if device.type == "cuda":
torch.cuda.synchronize(device)
except BaseException as e:
with worker_lock:
worker_errors.append(e)
threads = [threading.Thread(target=worker, args=(d, split[d])) for d in devices]
for t in threads:
t.start()
for t in threads:
t.join()
if worker_errors:
raise worker_errors[0]
combined_buf = sum(bufs.values())
combined_div = sum(divs.values())
output[b:b+1] = combined_buf / combined_div
return output
def model_trange(*args, **kwargs): def model_trange(*args, **kwargs):
if not comfy.memory_management.aimdo_enabled: if not comfy.memory_management.aimdo_enabled:
return trange(*args, **kwargs) return trange(*args, **kwargs)

View File

@ -13,42 +13,33 @@ import comfy.multigpu
class MultiGPUCFGSplitNode(io.ComfyNode): class MultiGPUCFGSplitNode(io.ComfyNode):
""" """
Attaches per-device deepclones to any connected MODEL, UPSCALE_MODEL, and/or VAE so Prepares model to have sampling accelerated via splitting work units.
downstream nodes that recognize the attached state dispatch their work across multiple GPUs.
Place after nodes that modify the model object itself (compile, attention-switch, etc.). Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
Otherwise position is not order-sensitive.
Other than those exceptions, this node can be placed in any order.
""" """
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="MultiGPU_WorkUnits", node_id="MultiGPU_WorkUnits",
display_name="MultiGPU Work Units", display_name="MultiGPU CFG Split",
category="advanced/multigpu", category="advanced/multigpu",
description=cleandoc(cls.__doc__), description=cleandoc(cls.__doc__),
inputs=[ inputs=[
io.Model.Input("model", optional=True), io.Model.Input("model"),
io.UpscaleModel.Input("upscale_model", optional=True),
io.Vae.Input("vae", optional=True),
io.Int.Input("max_gpus", default=2, min=1, step=1), io.Int.Input("max_gpus", default=2, min=1, step=1),
], ],
outputs=[ outputs=[
io.Model.Output(), io.Model.Output(),
io.UpscaleModel.Output(),
io.Vae.Output(),
], ],
) )
@classmethod @classmethod
def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None, vae=None) -> io.NodeOutput: def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
if model is not None: model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True) return io.NodeOutput(model)
if upscale_model is not None:
upscale_model = comfy.multigpu.create_upscale_model_multigpu_deepclones(upscale_model, max_gpus)
if vae is not None:
vae = comfy.multigpu.create_vae_multigpu_deepclones(vae, max_gpus)
return io.NodeOutput(model, upscale_model, vae)
class MultiGPUOptionsNode(io.ComfyNode): class MultiGPUOptionsNode(io.ComfyNode):

View File

@ -81,33 +81,13 @@ class ImageUpscaleWithModel(io.ComfyNode):
output_device = comfy.model_management.intermediate_device() output_device = comfy.model_management.intermediate_device()
multigpu_clones = getattr(upscale_model, 'multigpu_clones', None)
if multigpu_clones:
for dev, desc in multigpu_clones.items():
model_management.free_memory(memory_required, dev)
desc.to(dev)
oom = True oom = True
try: try:
while oom: while oom:
try: try:
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
if multigpu_clones: s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
functions = {device: lambda a: upscale_model(a.float())}
for dev, desc in multigpu_clones.items():
functions[dev] = lambda a, d=desc: d(a.float())
s = comfy.utils.tiled_scale_multidim_multigpu(
in_img,
functions,
tile=(tile, tile),
overlap=overlap,
upscale_amount=upscale_model.scale,
pbar=pbar,
output_device=output_device,
)
else:
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
oom = False oom = False
except Exception as e: except Exception as e:
model_management.raise_non_oom(e) model_management.raise_non_oom(e)
@ -116,9 +96,6 @@ class ImageUpscaleWithModel(io.ComfyNode):
raise e raise e
finally: finally:
upscale_model.to("cpu") upscale_model.to("cpu")
if multigpu_clones:
for desc in multigpu_clones.values():
desc.to("cpu")
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype()) s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
return io.NodeOutput(s) return io.NodeOutput(s)

View File

@ -869,7 +869,6 @@ class VAELoader:
#TODO: scale factor? #TODO: scale factor?
def load_vae(self, vae_name, device="default"): def load_vae(self, vae_name, device="default"):
metadata = None metadata = None
vae_path = None
if vae_name == "pixel_space": if vae_name == "pixel_space":
sd = {} sd = {}
sd["pixel_space_vae"] = torch.tensor(1.0) sd["pixel_space_vae"] = torch.tensor(1.0)
@ -889,11 +888,6 @@ class VAELoader:
resolved = comfy.model_management.resolve_gpu_device_option(device) resolved = comfy.model_management.resolve_gpu_device_option(device)
vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved) vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved)
vae.throw_exception_if_invalid() vae.throw_exception_if_invalid()
# Register a reload factory on the patcher so MultiGPU work-units can use
# ModelPatcher.deepclone_multigpu to produce per-device clones from the
# same loader context (mirrors UNETLoader / CLIPLoader / checkpoint loader).
if vae_path is not None:
vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, resolved))
return (vae,) return (vae,)
class ControlNetLoader: class ControlNetLoader:

View File

@ -1,147 +0,0 @@
import importlib
import sys
import types
import torch
import comfy.utils
def install_fake_comfy_aimdo(monkeypatch):
package = types.ModuleType("comfy_aimdo")
package.__path__ = []
monkeypatch.setitem(sys.modules, "comfy_aimdo", package)
for name in ("vram_buffer", "host_buffer", "torch", "model_vbar", "model_mmap", "control"):
module = types.ModuleType(f"comfy_aimdo.{name}")
monkeypatch.setitem(sys.modules, f"comfy_aimdo.{name}", module)
setattr(package, name, module)
def test_tiled_scale_multidim_multigpu_clips_edge_tiles(monkeypatch):
monkeypatch.setattr(torch.cuda, "set_device", lambda device: None)
monkeypatch.setattr(torch.cuda, "synchronize", lambda device: None)
scale = 1.1
def upscale(a):
return torch.ones((a.shape[0], 1, round(a.shape[-1] * scale)), dtype=a.dtype, device=a.device)
samples = torch.ones((1, 1, 11))
devices = [torch.device("cpu:0"), torch.device("cpu:1")]
actual = comfy.utils.tiled_scale_multidim_multigpu(
samples,
{device: upscale for device in devices},
tile=(5,),
overlap=2,
upscale_amount=scale,
out_channels=1,
output_device="cpu",
)
expected = comfy.utils.tiled_scale_multidim(
samples,
upscale,
tile=(5,),
overlap=2,
upscale_amount=scale,
out_channels=1,
output_device="cpu",
)
assert actual.shape == expected.shape == (1, 1, 12)
torch.testing.assert_close(actual, expected)
def test_upscale_model_deepclone_does_not_copy_existing_clone_graph(monkeypatch):
class FakeModel:
def __init__(self):
self.param = torch.nn.Parameter(torch.ones(1))
def eval(self):
return self
def parameters(self):
return [self.param]
class FakeDescriptor:
def __init__(self):
self.model = FakeModel()
self.device = None
def to(self, device):
self.device = device
return self
first_device = torch.device("cpu:0")
second_device = torch.device("cpu:1")
stale_device = torch.device("cpu:2")
existing_clone = FakeDescriptor()
stale_clone = FakeDescriptor()
source = FakeDescriptor()
source.multigpu_clones = {first_device: existing_clone, stale_device: stale_clone}
fake_model_management = types.ModuleType("comfy.model_management")
fake_model_management.get_all_torch_devices = lambda exclude_current=True: [first_device, second_device]
monkeypatch.setitem(sys.modules, "comfy.model_management", fake_model_management)
import comfy
monkeypatch.setattr(comfy, "model_management", fake_model_management, raising=False)
import comfy.multigpu
importlib.reload(comfy.multigpu)
cloned = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=3)
assert cloned is not source
assert cloned.multigpu_clones[first_device] is existing_clone
assert stale_device not in cloned.multigpu_clones
assert second_device in cloned.multigpu_clones
assert not hasattr(cloned.multigpu_clones[second_device], "multigpu_clones")
assert cloned.multigpu_clones[second_device].device == "cpu"
assert not cloned.multigpu_clones[second_device].model.param.requires_grad
single_gpu_clone = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=1)
assert single_gpu_clone is not source
assert not hasattr(single_gpu_clone, "multigpu_clones")
def test_checkpoint_loader_registers_vae_cached_patcher(monkeypatch):
install_fake_comfy_aimdo(monkeypatch)
import comfy.sd
importlib.reload(comfy.sd)
class FakeVAE:
def __init__(self):
self.patcher = types.SimpleNamespace(cached_patcher_init=None)
model_patcher = types.SimpleNamespace(cached_patcher_init=None)
vae = FakeVAE()
metadata = {"format": "checkpoint"}
monkeypatch.setattr(comfy.utils, "load_torch_file", lambda path, return_metadata=False: ({}, metadata))
monkeypatch.setattr(
comfy.sd,
"load_state_dict_guess_config",
lambda *args, **kwargs: (model_patcher, None, vae, None),
)
comfy.sd.load_checkpoint_guess_config("checkpoint.safetensors", output_vae=True)
assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config
assert vae.patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_vae_patcher
assert vae.patcher.cached_patcher_init[1][0] == "checkpoint.safetensors"
def test_checkpoint_loader_skips_cached_patcher_for_placeholder_vae(monkeypatch):
install_fake_comfy_aimdo(monkeypatch)
import comfy.sd
importlib.reload(comfy.sd)
model_patcher = types.SimpleNamespace(cached_patcher_init=None)
placeholder_vae = types.SimpleNamespace()
metadata = {"format": "checkpoint"}
monkeypatch.setattr(comfy.utils, "load_torch_file", lambda path, return_metadata=False: ({}, metadata))
monkeypatch.setattr(
comfy.sd,
"load_state_dict_guess_config",
lambda *args, **kwargs: (model_patcher, None, placeholder_vae, None),
)
assert comfy.sd.load_checkpoint_guess_config("diffusion_only.safetensors", output_vae=True)[2] is placeholder_vae
assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config