From 5dc4e38b89503ba77d58ae450d3f3fff30f57fa8 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 22 May 2026 16:44:29 -0700 Subject: [PATCH] Defer @pollockjj's tiled-VAE and UPSCALE_MODEL MultiGPU lanes (#14066) * Revert "Add tiled VAE lane to MultiGPU Work Units" This reverts commit 4d3d68e4731cf366289f9f4ca11242f4a78956df. The tiled VAE lane will land as part of a follow-up PR alongside the UPSCALE_MODEL lane, separated from the threaded-loader fix PR (#14052) to keep the upstream merge focused. * Revert "Add UPSCALE_MODEL lane to MultiGPU CFG Split" This reverts commit 74b0a826eaa7962e5093d83a27e13c20d4acfadf. The UPSCALE_MODEL lane will land as part of a follow-up PR alongside the tiled VAE lane, separated from the threaded-loader fix PR (#14052) to keep the upstream merge focused. --------- Co-authored-by: John Pollock --- comfy/multigpu.py | 82 ------------- comfy/sd.py | 132 +-------------------- comfy/utils.py | 157 +------------------------ comfy_extras/nodes_multigpu.py | 27 ++--- comfy_extras/nodes_upscale_model.py | 25 +--- nodes.py | 6 - tests-unit/comfy_test/multigpu_test.py | 147 ----------------------- 7 files changed, 12 insertions(+), 564 deletions(-) delete mode 100644 tests-unit/comfy_test/multigpu_test.py diff --git a/comfy/multigpu.py b/comfy/multigpu.py index 2573185de..eff7d0649 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -1,5 +1,4 @@ from __future__ import annotations -import copy import queue import threading import torch @@ -176,87 +175,6 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: return model -def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int): - """Return a shallow copy of ``upscale_model`` with a ``multigpu_clones`` dict of CPU-resident - descriptor deepclones, one per extra CUDA device up to ``max_gpus``. - """ - full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) - limit_extra_devices = full_extra_devices[:max_gpus - 1] - cloned = copy.copy(upscale_model) - existing = getattr(upscale_model, 'multigpu_clones', None) - limit_extra_device_set = set(limit_extra_devices) - clones: dict[torch.device, object] = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set} if existing else {} - if len(limit_extra_devices) == 0: - logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.") - if hasattr(cloned, 'multigpu_clones'): - del cloned.multigpu_clones - return cloned - - for device in limit_extra_devices: - if device in clones: - continue - clone_source = copy.copy(upscale_model) - if hasattr(clone_source, 'multigpu_clones'): - del clone_source.multigpu_clones - clone_desc = copy.deepcopy(clone_source) - clone_desc.model.eval() - for p in clone_desc.model.parameters(): - p.requires_grad_(False) - clone_desc.to("cpu") - clones[device] = clone_desc - logging.info(f"Created CPU upscale_model descriptor deepclone for {device}") - - cloned.multigpu_clones = clones - return cloned - - -def create_vae_multigpu_deepclones(vae, max_gpus: int): - """Return a shallow copy of ``vae`` with a ``multigpu_clones`` dict of CPU-resident VAE - deepclones, one per extra CUDA device up to ``max_gpus``. - """ - vae.throw_exception_if_invalid() - vae_device = torch.device(vae.device) - cloned = copy.copy(vae) - if hasattr(cloned, 'multigpu_clones'): - del cloned.multigpu_clones - if vae_device.type == "cpu": - logging.info("CPU VAE selected, skipping initializing MultiGPU VAE clones.") - return cloned - - full_extra_devices = comfy.model_management.get_all_torch_devices() - - def is_vae_device(device): - return device.type == vae_device.type and device.index == vae_device.index - - limit_extra_devices = [d for d in full_extra_devices if not is_vae_device(d)][:max_gpus - 1] - if len(limit_extra_devices) == 0: - logging.info("No extra torch devices need initialization, skipping initializing MultiGPU VAE clones.") - return cloned - - existing = getattr(vae, 'multigpu_clones', None) - limit_extra_device_set = set(limit_extra_devices) - clones: dict[torch.device, object] = {d: c for d, c in dict(existing).items() if d in limit_extra_device_set} if existing else {} - - for device in limit_extra_devices: - if device in clones: - continue - cloned_patcher = vae.patcher.deepclone_multigpu(new_load_device=device) - clone_vae = copy.copy(vae) - if hasattr(clone_vae, 'multigpu_clones'): - del clone_vae.multigpu_clones - clone_vae.first_stage_model = cloned_patcher.model - clone_vae.patcher = cloned_patcher - clone_vae.first_stage_model.eval() - for p in clone_vae.first_stage_model.parameters(): - p.requires_grad_(False) - clone_vae.first_stage_model.to("cpu") - clones[device] = clone_vae - logging.info(f"Created CPU VAE deepclone for {device}") - - cloned.multigpu_clones = clones - return cloned - - LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time']) def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None): 'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.' diff --git a/comfy/sd.py b/comfy/sd.py index 6401fdb14..1670a0486 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -972,26 +972,6 @@ class VAE: pbar = comfy.utils.ProgressBar(steps) decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) - - multigpu_clones = getattr(self, 'multigpu_clones', None) - if multigpu_clones: - functions = {self.device: decode_fn} - try: - for dev, c in multigpu_clones.items(): - model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev) - c.first_stage_model.to(dev) - for dev, c in multigpu_clones.items(): - functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype()) - output = self.process_output( - (comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) + - comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar) + - comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=self.upscale_ratio, output_device=self.output_device, pbar=pbar)) - / 3.0) - return output - finally: - for c in multigpu_clones.values(): - c.first_stage_model.to("cpu") - output = self.process_output( (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + @@ -1001,49 +981,16 @@ class VAE: def decode_tiled_1d(self, samples, tile_x=256, overlap=32): if samples.ndim == 3: - memory_shape = samples.shape decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) - clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype())) else: og_shape = samples.shape - memory_shape = og_shape samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1)) decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) - clone_decode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype())) - - multigpu_clones = getattr(self, 'multigpu_clones', None) - if multigpu_clones: - functions = {self.device: decode_fn} - try: - for dev, c in multigpu_clones.items(): - model_management.free_memory(c.model_size() + c.memory_used_decode(memory_shape, c.vae_dtype), dev) - c.first_stage_model.to(dev) - for dev, c in multigpu_clones.items(): - functions[dev] = clone_decode_fn_factory(c, dev) - return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) - finally: - for c in multigpu_clones.values(): - c.first_stage_model.to("cpu") return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)): decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) - - multigpu_clones = getattr(self, 'multigpu_clones', None) - if multigpu_clones: - functions = {self.device: decode_fn} - try: - for dev, c in multigpu_clones.items(): - model_management.free_memory(c.model_size() + c.memory_used_decode(samples.shape, c.vae_dtype), dev) - c.first_stage_model.to(dev) - for dev, c in multigpu_clones.items(): - functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.decode(a.to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype()) - return self.process_output(comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) - finally: - for c in multigpu_clones.values(): - c.first_stage_model.to("cpu") - return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): @@ -1053,25 +1000,6 @@ class VAE: pbar = comfy.utils.ProgressBar(steps) encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) - - multigpu_clones = getattr(self, 'multigpu_clones', None) - if multigpu_clones: - functions = {self.device: encode_fn} - try: - for dev, c in multigpu_clones.items(): - model_management.free_memory(c.model_size() + c.memory_used_encode(pixel_samples.shape, c.vae_dtype), dev) - c.first_stage_model.to(dev) - for dev, c in multigpu_clones.items(): - functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype()) - samples = comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y, tile_x), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) - samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y // 2, tile_x * 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) - samples += comfy.utils.tiled_scale_multidim_multigpu(pixel_samples, functions, tile=(tile_y * 2, tile_x // 2), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) - samples /= 3.0 - return samples - finally: - for c in multigpu_clones.values(): - c.first_stage_model.to("cpu") - samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) @@ -1081,7 +1009,6 @@ class VAE: def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048): if self.latent_dim == 1: encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) - clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).to(dtype=c.vae_output_dtype())) out_channels = self.latent_channels upscale_amount = 1 / self.downscale_ratio else: @@ -1091,24 +1018,8 @@ class VAE: overlap = overlap // extra_channel_size upscale_amount = 1 / self.downscale_ratio encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype()) - clone_encode_fn_factory = lambda c, dev: (lambda a: c.first_stage_model.encode((c.process_input(a)).to(c.vae_dtype).to(dev)).reshape(1, out_channels, -1).to(dtype=c.vae_output_dtype())) - - multigpu_clones = getattr(self, 'multigpu_clones', None) - if multigpu_clones: - functions = {self.device: encode_fn} - try: - for dev, c in multigpu_clones.items(): - model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev) - c.first_stage_model.to(dev) - for dev, c in multigpu_clones.items(): - functions[dev] = clone_encode_fn_factory(c, dev) - out = comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device) - finally: - for c in multigpu_clones.values(): - c.first_stage_model.to("cpu") - else: - out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device) + out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device) if self.latent_dim == 1: return out else: @@ -1116,21 +1027,6 @@ class VAE: def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)): encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) - - multigpu_clones = getattr(self, 'multigpu_clones', None) - if multigpu_clones: - functions = {self.device: encode_fn} - try: - for dev, c in multigpu_clones.items(): - model_management.free_memory(c.model_size() + c.memory_used_encode(samples.shape, c.vae_dtype), dev) - c.first_stage_model.to(dev) - for dev, c in multigpu_clones.items(): - functions[dev] = lambda a, _c=c, _dev=dev: _c.first_stage_model.encode((_c.process_input(a)).to(_c.vae_dtype).to(_dev)).to(dtype=_c.vae_output_dtype()) - return comfy.utils.tiled_scale_multidim_multigpu(samples, functions, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) - finally: - for c in multigpu_clones.values(): - c.first_stage_model.to("cpu") - return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) def decode(self, samples_in, vae_options={}): @@ -1831,14 +1727,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) if out[0] is not None: out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0) - if output_vae and out[2] is not None and hasattr(out[2], "patcher"): - out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options, disable_dynamic)) return out -def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): - _, _, vae, _ = load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=False, embedding_directory=embedding_directory, output_model=False, model_options=model_options, te_model_options=te_model_options, disable_dynamic=disable_dynamic) - return vae.patcher - def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False, embedding_directory=embedding_directory, @@ -2064,26 +1954,6 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False): model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options)) return model -def load_vae_patcher(vae_path, metadata=None, device=None): - """Reload a VAE from disk and return its patcher. - - Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so that - :meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a - fresh VAE patcher with no inherited source-device storage tracking. The - optional device matches the source loader's VAE initialization path; the - cloned patcher's load_device still controls the device targeted by the - multigpu clone. Without this, bare ``copy.deepcopy`` of the VAE wrapper - carries dynamic-VRAM allocator state forward to the clone, which causes - per-device worker threads in tiled encode/decode dispatch to access weights - through the source-device buffer.""" - if metadata is None: - sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True) - else: - sd = comfy.utils.load_torch_file(vae_path) - vae = VAE(sd=sd, metadata=metadata, device=device) - vae.throw_exception_if_invalid() - return vae.patcher - def load_unet(unet_path, dtype=None): logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model") return load_diffusion_model(unet_path, model_options={"dtype": dtype}) diff --git a/comfy/utils.py b/comfy/utils.py index abfd4079d..49ae12b06 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -28,13 +28,13 @@ import numpy as np from PIL import Image import logging import itertools -import threading from torch.nn.functional import interpolate from tqdm.auto import trange from einops import rearrange from comfy.cli_args import args import json import time +import threading import warnings MMAP_TORCH_FILES = args.mmap_torch_files @@ -1187,161 +1187,6 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) - -def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None): - """Multigpu variant of tiled_scale_multidim. ``functions`` is a dict[torch.device, callable]. - - Round-robin dispatches tile positions across devices via threading. Each thread maintains - its own per-device CPU output and divisor buffer, applying the same feathered overlap mask - formula as the single-device path. Buffers are summed at the end, producing output that is - bit-equivalent to ``tiled_scale_multidim`` within fp32 add-order noise. - - Falls back to ``tiled_scale_multidim`` with the only function when ``len(functions) < 2``. - Falls back to single-device on the "whole input fits in one tile" branch (no parallelism - available at that granularity). - """ - devices = list(functions.keys()) - if len(devices) < 2: - only_fn = next(iter(functions.values())) if functions else None - return tiled_scale_multidim(samples, only_fn, tile=tile, overlap=overlap, - upscale_amount=upscale_amount, out_channels=out_channels, - output_device=output_device, downscale=downscale, - index_formulas=index_formulas, pbar=pbar) - - dims = len(tile) - - if not (isinstance(upscale_amount, (tuple, list))): - upscale_amount = [upscale_amount] * dims - if not (isinstance(overlap, (tuple, list))): - overlap = [overlap] * dims - if index_formulas is None: - index_formulas = upscale_amount - if not (isinstance(index_formulas, (tuple, list))): - index_formulas = [index_formulas] * dims - - def get_upscale(dim, val): - up = upscale_amount[dim] - return up(val) if callable(up) else up * val - - def get_downscale(dim, val): - up = upscale_amount[dim] - return up(val) if callable(up) else val / up - - def get_upscale_pos(dim, val): - up = index_formulas[dim] - return up(val) if callable(up) else up * val - - def get_downscale_pos(dim, val): - up = index_formulas[dim] - return up(val) if callable(up) else val / up - - if downscale: - get_scale = get_downscale - get_pos = get_downscale_pos - else: - get_scale = get_upscale - get_pos = get_upscale_pos - - def mult_list_upscale(a): - return [round(get_scale(i, a[i])) for i in range(len(a))] - - output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device) - merge_device = torch.device("cpu") - - pbar_lock = threading.Lock() if pbar is not None else None - primary_device = devices[0] - - samples_staged = samples if samples.device.type == "cpu" else samples.to("cpu", non_blocking=False) - - for b in range(samples_staged.shape[0]): - s = samples_staged[b:b+1] - - if all(s.shape[d+2] <= tile[d] for d in range(dims)): - with torch.inference_mode(): - output[b:b+1] = functions[primary_device](s.to(primary_device, non_blocking=True)).to(output_device) - if pbar is not None: - pbar.update(1) - continue - - positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] - split = {devices[i]: itertools.islice(itertools.product(*positions), i, None, len(devices)) for i in range(len(devices))} - - out_shape = [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]) - div_shape = [s.shape[0], 1] + mult_list_upscale(s.shape[2:]) - bufs = {d: torch.zeros(out_shape, device=merge_device) for d in devices} - divs = {d: torch.zeros(div_shape, device=merge_device) for d in devices} - - worker_errors: list[BaseException] = [] - worker_lock = threading.Lock() - - def worker(device, my_positions): - try: - if device.type == "cuda": - torch.cuda.set_device(device) - fn = functions[device] - local_buf = bufs[device] - local_div = divs[device] - with torch.inference_mode(): - for it in my_positions: - s_in = s - upscaled = [] - for d in range(dims): - pos = max(0, min(s.shape[d + 2] - overlap[d], it[d])) - l = min(tile[d], s.shape[d + 2] - pos) - s_in = s_in.narrow(d + 2, pos, l) - upscaled.append(round(get_pos(d, pos))) - - s_in_dev = s_in.to(device, non_blocking=True) - ps = fn(s_in_dev).to(merge_device) - mask = torch.ones([1, 1] + list(ps.shape[2:]), device=merge_device) - - for d in range(2, dims + 2): - feather = round(get_scale(d - 2, overlap[d - 2])) - if feather >= mask.shape[d]: - continue - for t in range(feather): - a = (t + 1) / feather - mask.narrow(d, t, 1).mul_(a) - mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a) - - o = local_buf - o_d = local_div - ps_view = ps - mask_view = mask - for d in range(dims): - l = min(ps_view.shape[d + 2], o.shape[d + 2] - upscaled[d]) - o = o.narrow(d + 2, upscaled[d], l) - o_d = o_d.narrow(d + 2, upscaled[d], l) - if l < ps_view.shape[d + 2]: - ps_view = ps_view.narrow(d + 2, 0, l) - mask_view = mask_view.narrow(d + 2, 0, l) - - o.add_(ps_view * mask_view) - o_d.add_(mask_view) - - if pbar is not None: - with pbar_lock: - pbar.update(1) - if device.type == "cuda": - torch.cuda.synchronize(device) - except BaseException as e: - with worker_lock: - worker_errors.append(e) - - threads = [threading.Thread(target=worker, args=(d, split[d])) for d in devices] - for t in threads: - t.start() - for t in threads: - t.join() - if worker_errors: - raise worker_errors[0] - - combined_buf = sum(bufs.values()) - combined_div = sum(divs.values()) - output[b:b+1] = combined_buf / combined_div - - return output - def model_trange(*args, **kwargs): if not comfy.memory_management.aimdo_enabled: return trange(*args, **kwargs) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index dd0f76798..fedafef71 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -13,42 +13,33 @@ import comfy.multigpu class MultiGPUCFGSplitNode(io.ComfyNode): """ - Attaches per-device deepclones to any connected MODEL, UPSCALE_MODEL, and/or VAE so - downstream nodes that recognize the attached state dispatch their work across multiple GPUs. + Prepares model to have sampling accelerated via splitting work units. - Place after nodes that modify the model object itself (compile, attention-switch, etc.). - Otherwise position is not order-sensitive. + Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes. + + Other than those exceptions, this node can be placed in any order. """ @classmethod def define_schema(cls): return io.Schema( node_id="MultiGPU_WorkUnits", - display_name="MultiGPU Work Units", + display_name="MultiGPU CFG Split", category="advanced/multigpu", description=cleandoc(cls.__doc__), inputs=[ - io.Model.Input("model", optional=True), - io.UpscaleModel.Input("upscale_model", optional=True), - io.Vae.Input("vae", optional=True), + io.Model.Input("model"), io.Int.Input("max_gpus", default=2, min=1, step=1), ], outputs=[ io.Model.Output(), - io.UpscaleModel.Output(), - io.Vae.Output(), ], ) @classmethod - def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None, vae=None) -> io.NodeOutput: - if model is not None: - model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True) - if upscale_model is not None: - upscale_model = comfy.multigpu.create_upscale_model_multigpu_deepclones(upscale_model, max_gpus) - if vae is not None: - vae = comfy.multigpu.create_vae_multigpu_deepclones(vae, max_gpus) - return io.NodeOutput(model, upscale_model, vae) + def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput: + model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True) + return io.NodeOutput(model) class MultiGPUOptionsNode(io.ComfyNode): diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 3a4e3926c..d3ee3f1c1 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -81,33 +81,13 @@ class ImageUpscaleWithModel(io.ComfyNode): output_device = comfy.model_management.intermediate_device() - multigpu_clones = getattr(upscale_model, 'multigpu_clones', None) - if multigpu_clones: - for dev, desc in multigpu_clones.items(): - model_management.free_memory(memory_required, dev) - desc.to(dev) - oom = True try: while oom: try: steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) pbar = comfy.utils.ProgressBar(steps) - if multigpu_clones: - functions = {device: lambda a: upscale_model(a.float())} - for dev, desc in multigpu_clones.items(): - functions[dev] = lambda a, d=desc: d(a.float()) - s = comfy.utils.tiled_scale_multidim_multigpu( - in_img, - functions, - tile=(tile, tile), - overlap=overlap, - upscale_amount=upscale_model.scale, - pbar=pbar, - output_device=output_device, - ) - else: - s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device) + s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device) oom = False except Exception as e: model_management.raise_non_oom(e) @@ -116,9 +96,6 @@ class ImageUpscaleWithModel(io.ComfyNode): raise e finally: upscale_model.to("cpu") - if multigpu_clones: - for desc in multigpu_clones.values(): - desc.to("cpu") s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype()) return io.NodeOutput(s) diff --git a/nodes.py b/nodes.py index 9193e9ddb..2f3856330 100644 --- a/nodes.py +++ b/nodes.py @@ -869,7 +869,6 @@ class VAELoader: #TODO: scale factor? def load_vae(self, vae_name, device="default"): metadata = None - vae_path = None if vae_name == "pixel_space": sd = {} sd["pixel_space_vae"] = torch.tensor(1.0) @@ -889,11 +888,6 @@ class VAELoader: resolved = comfy.model_management.resolve_gpu_device_option(device) vae = comfy.sd.VAE(sd=sd, metadata=metadata, device=resolved) vae.throw_exception_if_invalid() - # Register a reload factory on the patcher so MultiGPU work-units can use - # ModelPatcher.deepclone_multigpu to produce per-device clones from the - # same loader context (mirrors UNETLoader / CLIPLoader / checkpoint loader). - if vae_path is not None: - vae.patcher.cached_patcher_init = (comfy.sd.load_vae_patcher, (vae_path, metadata, resolved)) return (vae,) class ControlNetLoader: diff --git a/tests-unit/comfy_test/multigpu_test.py b/tests-unit/comfy_test/multigpu_test.py deleted file mode 100644 index e7ba15df7..000000000 --- a/tests-unit/comfy_test/multigpu_test.py +++ /dev/null @@ -1,147 +0,0 @@ -import importlib -import sys -import types - -import torch - -import comfy.utils - - -def install_fake_comfy_aimdo(monkeypatch): - package = types.ModuleType("comfy_aimdo") - package.__path__ = [] - monkeypatch.setitem(sys.modules, "comfy_aimdo", package) - for name in ("vram_buffer", "host_buffer", "torch", "model_vbar", "model_mmap", "control"): - module = types.ModuleType(f"comfy_aimdo.{name}") - monkeypatch.setitem(sys.modules, f"comfy_aimdo.{name}", module) - setattr(package, name, module) - - -def test_tiled_scale_multidim_multigpu_clips_edge_tiles(monkeypatch): - monkeypatch.setattr(torch.cuda, "set_device", lambda device: None) - monkeypatch.setattr(torch.cuda, "synchronize", lambda device: None) - - scale = 1.1 - - def upscale(a): - return torch.ones((a.shape[0], 1, round(a.shape[-1] * scale)), dtype=a.dtype, device=a.device) - - samples = torch.ones((1, 1, 11)) - devices = [torch.device("cpu:0"), torch.device("cpu:1")] - - actual = comfy.utils.tiled_scale_multidim_multigpu( - samples, - {device: upscale for device in devices}, - tile=(5,), - overlap=2, - upscale_amount=scale, - out_channels=1, - output_device="cpu", - ) - expected = comfy.utils.tiled_scale_multidim( - samples, - upscale, - tile=(5,), - overlap=2, - upscale_amount=scale, - out_channels=1, - output_device="cpu", - ) - - assert actual.shape == expected.shape == (1, 1, 12) - torch.testing.assert_close(actual, expected) - - -def test_upscale_model_deepclone_does_not_copy_existing_clone_graph(monkeypatch): - class FakeModel: - def __init__(self): - self.param = torch.nn.Parameter(torch.ones(1)) - - def eval(self): - return self - - def parameters(self): - return [self.param] - - class FakeDescriptor: - def __init__(self): - self.model = FakeModel() - self.device = None - - def to(self, device): - self.device = device - return self - - first_device = torch.device("cpu:0") - second_device = torch.device("cpu:1") - stale_device = torch.device("cpu:2") - existing_clone = FakeDescriptor() - stale_clone = FakeDescriptor() - source = FakeDescriptor() - source.multigpu_clones = {first_device: existing_clone, stale_device: stale_clone} - fake_model_management = types.ModuleType("comfy.model_management") - fake_model_management.get_all_torch_devices = lambda exclude_current=True: [first_device, second_device] - monkeypatch.setitem(sys.modules, "comfy.model_management", fake_model_management) - import comfy - monkeypatch.setattr(comfy, "model_management", fake_model_management, raising=False) - import comfy.multigpu - importlib.reload(comfy.multigpu) - - cloned = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=3) - - assert cloned is not source - assert cloned.multigpu_clones[first_device] is existing_clone - assert stale_device not in cloned.multigpu_clones - assert second_device in cloned.multigpu_clones - assert not hasattr(cloned.multigpu_clones[second_device], "multigpu_clones") - assert cloned.multigpu_clones[second_device].device == "cpu" - assert not cloned.multigpu_clones[second_device].model.param.requires_grad - - single_gpu_clone = comfy.multigpu.create_upscale_model_multigpu_deepclones(source, max_gpus=1) - assert single_gpu_clone is not source - assert not hasattr(single_gpu_clone, "multigpu_clones") - - -def test_checkpoint_loader_registers_vae_cached_patcher(monkeypatch): - install_fake_comfy_aimdo(monkeypatch) - import comfy.sd - importlib.reload(comfy.sd) - - class FakeVAE: - def __init__(self): - self.patcher = types.SimpleNamespace(cached_patcher_init=None) - - model_patcher = types.SimpleNamespace(cached_patcher_init=None) - vae = FakeVAE() - metadata = {"format": "checkpoint"} - monkeypatch.setattr(comfy.utils, "load_torch_file", lambda path, return_metadata=False: ({}, metadata)) - monkeypatch.setattr( - comfy.sd, - "load_state_dict_guess_config", - lambda *args, **kwargs: (model_patcher, None, vae, None), - ) - - comfy.sd.load_checkpoint_guess_config("checkpoint.safetensors", output_vae=True) - - assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config - assert vae.patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_vae_patcher - assert vae.patcher.cached_patcher_init[1][0] == "checkpoint.safetensors" - - -def test_checkpoint_loader_skips_cached_patcher_for_placeholder_vae(monkeypatch): - install_fake_comfy_aimdo(monkeypatch) - import comfy.sd - importlib.reload(comfy.sd) - - model_patcher = types.SimpleNamespace(cached_patcher_init=None) - placeholder_vae = types.SimpleNamespace() - metadata = {"format": "checkpoint"} - monkeypatch.setattr(comfy.utils, "load_torch_file", lambda path, return_metadata=False: ({}, metadata)) - monkeypatch.setattr( - comfy.sd, - "load_state_dict_guess_config", - lambda *args, **kwargs: (model_patcher, None, placeholder_vae, None), - ) - - assert comfy.sd.load_checkpoint_guess_config("diffusion_only.safetensors", output_vae=True)[2] is placeholder_vae - assert model_patcher.cached_patcher_init[0] is comfy.sd.load_checkpoint_guess_config