From 74b0a826eaa7962e5093d83a27e13c20d4acfadf Mon Sep 17 00:00:00 2001
From: John Pollock
Date: Wed, 20 May 2026 15:37:09 -0500
Subject: [PATCH] Add UPSCALE_MODEL lane to MultiGPU CFG Split

Introduce tiled_scale_multidim_multigpu in comfy/utils.py: a tile
scheduler that dispatches per-device tile functions on one
threading.Thread per device and merges per-device CPU output buffers in
deterministic key order. The worker only catches BaseException at the
thread boundary to funnel errors to the main thread; bare
torch.cuda.set_device and torch.cuda.synchronize calls inside the worker
fail loud if the device is not CUDA, which is part of the primitive's
contract.

Add UPSCALE_MODEL input on the MultiGPU CFG Split node and an
upscale-model descriptor deepclone helper in comfy/multigpu.py. Clones
stay CPU-resident until execute time and are returned to CPU afterward.

ImageUpscaleWithModel dispatches through tiled_scale_multidim_multigpu
when a multigpu descriptor is attached; the single-device path runs
unchanged when no clones are present.
---
 comfy/multigpu.py                   |  45 ++++++++
 comfy/utils.py                      | 151 +++++++++++++++++++++++++++-
 comfy_extras/nodes_multigpu.py      |  23 +++--
 comfy_extras/nodes_upscale_model.py |  25 ++++-
 4 files changed, 233 insertions(+), 11 deletions(-)

diff --git a/comfy/multigpu.py b/comfy/multigpu.py
index eff7d0649..7f90b7db7 100644
--- a/comfy/multigpu.py
+++ b/comfy/multigpu.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import copy
 import queue
 import threading
 import torch
@@ -175,6 +176,50 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options:
     return model
 
 
+def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int):
+    """Attach CPU-resident deepclones of ``upscale_model`` for additional torch devices.
+
+    Args:
+        upscale_model: Upscale model descriptor. It must expose ``.model`` (a torch module)
+            and ``.to()``, both of which are called on every clone created here.
+        max_gpus: Total number of devices to use, counting the current one; at most
+            ``max_gpus - 1`` extra devices receive a clone.
+
+    Returns:
+        ``upscale_model`` itself when no extra device is selected; otherwise a shallow copy
+        whose ``multigpu_clones`` attribute maps each selected extra device to a deepclone
+        that was put in eval mode, had ``requires_grad`` disabled and was moved to CPU.
+        Clones already attached to ``upscale_model`` are reused rather than rebuilt.
+    """
+    full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
+    # Clamp so max_gpus < 1 selects nothing instead of turning into a negative slice bound.
+    limit_extra_devices = full_extra_devices[:max(0, max_gpus - 1)]
+    if not limit_extra_devices:
+        logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.")
+        return upscale_model
+
+    cloned = copy.copy(upscale_model)
+    existing = getattr(upscale_model, 'multigpu_clones', None)
+    clones: dict[torch.device, object] = dict(existing) if existing else {}
+    # Detach previously attached clones from the deepcopy source; otherwise every new clone
+    # would carry its own nested deep copy of all clones that already exist.
+    cloned.multigpu_clones = None
+
+    for device in limit_extra_devices:
+        if device in clones:
+            continue
+        clone_desc = copy.deepcopy(cloned)
+        clone_desc.model.eval()
+        for p in clone_desc.model.parameters():
+            p.requires_grad_(False)
+        clone_desc.to("cpu")
+        clones[device] = clone_desc
+        logging.info(f"Created CPU upscale_model descriptor deepclone for {device}")
+
+    cloned.multigpu_clones = clones
+    return cloned
+
+
 LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
 def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
     'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
diff --git a/comfy/utils.py b/comfy/utils.py
index 31052714a..c53e0cb91 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -28,13 +28,13 @@ import numpy as np
 from PIL import Image
 import logging
 import itertools
+import threading
 from torch.nn.functional import interpolate
 from tqdm.auto import trange
 from einops import rearrange
 from comfy.cli_args import args
 import json
 import time
-import threading
 import warnings
 
 MMAP_TORCH_FILES = args.mmap_torch_files
@@ -1186,6 +1186,173 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
 
 def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
     return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
+
+def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
+    """Multigpu variant of tiled_scale_multidim. ``functions`` is a dict[torch.device, callable].
+
+    Tile positions are dealt round-robin to the devices and one thread per device runs that
+    device's callable over its tiles. Each thread accumulates into its own CPU output and
+    divisor buffer, applying the same feathered overlap mask formula as the single-device path.
+    The buffers are summed in ``functions`` key order and divided at the end, so the result is
+    intended to match ``tiled_scale_multidim`` up to fp32 add-order noise (not bit-for-bit).
+
+    Falls back to ``tiled_scale_multidim`` with the only function when ``len(functions) == 1``.
+    A batch item that fits in one tile runs on the first device only (no parallelism
+    available at that granularity).
+
+    Worker threads call ``torch.cuda.set_device`` and ``torch.cuda.synchronize`` without a
+    guard, so every device that is handed tiles is expected to be a CUDA device.
+
+    Raises:
+        ValueError: if ``functions`` is empty.
+        BaseException: the first error recorded by a worker thread, re-raised on the calling
+            thread once all workers have been joined.
+    """
+    if not functions:
+        # Fail early with a clear message instead of handing function=None to the fallback.
+        raise ValueError("tiled_scale_multidim_multigpu requires at least one device function")
+
+    devices = list(functions.keys())
+    if len(devices) < 2:
+        return tiled_scale_multidim(samples, functions[devices[0]], tile=tile, overlap=overlap,
+                                    upscale_amount=upscale_amount, out_channels=out_channels,
+                                    output_device=output_device, downscale=downscale,
+                                    index_formulas=index_formulas, pbar=pbar)
+
+    dims = len(tile)
+
+    if not (isinstance(upscale_amount, (tuple, list))):
+        upscale_amount = [upscale_amount] * dims
+    if not (isinstance(overlap, (tuple, list))):
+        overlap = [overlap] * dims
+    if index_formulas is None:
+        index_formulas = upscale_amount
+    if not (isinstance(index_formulas, (tuple, list))):
+        index_formulas = [index_formulas] * dims
+
+    def get_upscale(dim, val):
+        up = upscale_amount[dim]
+        return up(val) if callable(up) else up * val
+
+    def get_downscale(dim, val):
+        up = upscale_amount[dim]
+        return up(val) if callable(up) else val / up
+
+    def get_upscale_pos(dim, val):
+        up = index_formulas[dim]
+        return up(val) if callable(up) else up * val
+
+    def get_downscale_pos(dim, val):
+        up = index_formulas[dim]
+        return up(val) if callable(up) else val / up
+
+    if downscale:
+        get_scale = get_downscale
+        get_pos = get_downscale_pos
+    else:
+        get_scale = get_upscale
+        get_pos = get_upscale_pos
+
+    def mult_list_upscale(a):
+        return [round(get_scale(i, a[i])) for i in range(len(a))]
+
+    output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
+    merge_device = torch.device("cpu")
+
+    pbar_lock = threading.Lock() if pbar is not None else None
+    primary_device = devices[0]
+
+    # Stage the input on CPU once; every worker copies only its own tiles to its device.
+    samples_staged = samples if samples.device.type == "cpu" else samples.to("cpu", non_blocking=False)
+
+    for b in range(samples_staged.shape[0]):
+        s = samples_staged[b:b+1]
+
+        if all(s.shape[d+2] <= tile[d] for d in range(dims)):
+            with torch.inference_mode():
+                output[b:b+1] = functions[primary_device](s.to(primary_device, non_blocking=True)).to(output_device)
+            if pbar is not None:
+                pbar.update(1)
+            continue
+
+        positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
+        all_positions = list(itertools.product(*positions))
+
+        split = {dev: all_positions[i::len(devices)] for i, dev in enumerate(devices)}
+        # Devices that received no tiles get neither buffers nor a thread: every buffer is as
+        # large as the full output of this batch item.
+        active = [d for d in devices if split[d]]
+
+        out_shape = [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:])
+        div_shape = [s.shape[0], 1] + mult_list_upscale(s.shape[2:])
+        bufs = {d: torch.zeros(out_shape, device=merge_device) for d in active}
+        divs = {d: torch.zeros(div_shape, device=merge_device) for d in active}
+
+        worker_errors: list[BaseException] = []
+        worker_lock = threading.Lock()
+
+        def worker(device, my_positions):
+            try:
+                torch.cuda.set_device(device)
+                fn = functions[device]
+                local_buf = bufs[device]
+                local_div = divs[device]
+                with torch.inference_mode():
+                    for it in my_positions:
+                        s_in = s
+                        upscaled = []
+                        for d in range(dims):
+                            pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
+                            length = min(tile[d], s.shape[d + 2] - pos)
+                            s_in = s_in.narrow(d + 2, pos, length)
+                            upscaled.append(round(get_pos(d, pos)))
+
+                        s_in_dev = s_in.to(device, non_blocking=True)
+                        ps = fn(s_in_dev).to(merge_device)
+                        mask = torch.ones([1, 1] + list(ps.shape[2:]), device=merge_device)
+
+                        for d in range(2, dims + 2):
+                            feather = round(get_scale(d - 2, overlap[d - 2]))
+                            if feather >= mask.shape[d]:
+                                continue
+                            for t in range(feather):
+                                a = (t + 1) / feather
+                                mask.narrow(d, t, 1).mul_(a)
+                                mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
+
+                        o = local_buf
+                        o_d = local_div
+                        for d in range(dims):
+                            o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
+                            o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
+
+                        o.add_(ps * mask)
+                        o_d.add_(mask)
+
+                        if pbar is not None:
+                            with pbar_lock:
+                                pbar.update(1)
+                torch.cuda.synchronize(device)
+            except BaseException as e:
+                # Thread boundary: record the error so the main thread can re-raise it.
+                with worker_lock:
+                    worker_errors.append(e)
+
+        threads = [threading.Thread(target=worker, args=(d, split[d])) for d in active]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+        if worker_errors:
+            raise worker_errors[0]
+
+        combined_buf = sum(bufs.values())
+        # Keep the divisor strictly positive so the division below cannot divide by zero.
+        combined_div = sum(divs.values()).clamp_(min=1e-12)
+        output[b:b+1] = combined_buf / combined_div
+
+    return output
+
 def model_trange(*args, **kwargs):
     if not comfy.memory_management.aimdo_enabled:
         return trange(*args, **kwargs)
diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py
index fedafef71..021dfca3f 100644
--- a/comfy_extras/nodes_multigpu.py
+++ b/comfy_extras/nodes_multigpu.py
@@ -13,33 +13,38 @@ import comfy.multigpu
 
 class MultiGPUCFGSplitNode(io.ComfyNode):
     """
-    Prepares model to have sampling accelerated via splitting work units.
+    Attaches per-device deepclones to any connected MODEL and/or UPSCALE_MODEL so downstream
+    nodes that recognize the attached state dispatch their work across multiple GPUs.
 
-    Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
-
-    Other than those exceptions, this node can be placed in any order.
+    Place after nodes that modify the model object itself (compile, attention-switch, etc.).
+    Otherwise position is not order-sensitive.
     """
 
     @classmethod
     def define_schema(cls):
         return io.Schema(
             node_id="MultiGPU_WorkUnits",
-            display_name="MultiGPU CFG Split",
+            display_name="MultiGPU Work Units",
             category="advanced/multigpu",
             description=cleandoc(cls.__doc__),
             inputs=[
-                io.Model.Input("model"),
+                io.Model.Input("model", optional=True),
+                io.UpscaleModel.Input("upscale_model", optional=True),
                 io.Int.Input("max_gpus", default=2, min=1, step=1),
             ],
             outputs=[
                 io.Model.Output(),
+                io.UpscaleModel.Output(),
             ],
         )
 
     @classmethod
-    def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
-        model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
-        return io.NodeOutput(model)
+    def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None) -> io.NodeOutput:
+        if model is not None:
+            model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
+        if upscale_model is not None:
+            upscale_model = comfy.multigpu.create_upscale_model_multigpu_deepclones(upscale_model, max_gpus)
+        return io.NodeOutput(model, upscale_model)
 
 
 class MultiGPUOptionsNode(io.ComfyNode):
diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py
index d3ee3f1c1..3a4e3926c 100644
--- a/comfy_extras/nodes_upscale_model.py
+++ b/comfy_extras/nodes_upscale_model.py
@@ -81,13 +81,36 @@ class ImageUpscaleWithModel(io.ComfyNode):
 
         output_device = comfy.model_management.intermediate_device()
 
+        multigpu_clones = getattr(upscale_model, 'multigpu_clones', None)
+
         oom = True
         try:
+            # Load the clones inside the try block so the finally clause below returns every
+            # clone to CPU even when freeing memory or moving one of them to its device fails.
+            if multigpu_clones:
+                for dev, desc in multigpu_clones.items():
+                    model_management.free_memory(memory_required, dev)
+                    desc.to(dev)
+
             while oom:
                 try:
                     steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
                     pbar = comfy.utils.ProgressBar(steps)
-                    s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
+                    if multigpu_clones:
+                        functions = {device: lambda a: upscale_model(a.float())}
+                        for dev, desc in multigpu_clones.items():
+                            functions[dev] = lambda a, d=desc: d(a.float())
+                        s = comfy.utils.tiled_scale_multidim_multigpu(
+                            in_img,
+                            functions,
+                            tile=(tile, tile),
+                            overlap=overlap,
+                            upscale_amount=upscale_model.scale,
+                            pbar=pbar,
+                            output_device=output_device,
+                        )
+                    else:
+                        s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
                     oom = False
                 except Exception as e:
                     model_management.raise_non_oom(e)
@@ -96,6 +119,9 @@ class ImageUpscaleWithModel(io.ComfyNode):
                         raise e
         finally:
             upscale_model.to("cpu")
+            if multigpu_clones:
+                for desc in multigpu_clones.values():
+                    desc.to("cpu")
 
         s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
         return io.NodeOutput(s)