Add UPSCALE_MODEL lane to MultiGPU CFG Split

Introduce tiled_scale_multidim_multigpu in comfy/utils.py: a tile scheduler
that dispatches per-device tile functions through the existing
MultiGPUThreadPool and merges per-device CPU output buffers in deterministic
key order. The worker only catches BaseException at the thread boundary to
funnel errors to the main thread; bare torch.cuda.set_device and
torch.cuda.synchronize calls inside the worker fail loud if the device is
not CUDA, which is part of the primitive's contract.

Add UPSCALE_MODEL input on the MultiGPU CFG Split node and an upscale-model
descriptor deepclone helper in comfy/multigpu.py. Clones stay CPU-resident
until execute time and are returned to CPU afterward.

ImageUpscaleWithModel dispatches through tiled_scale_multidim_multigpu when
a multigpu descriptor is attached; the single-device path runs unchanged
when no clones are present.
This commit is contained in:
John Pollock 2026-05-20 15:37:09 -05:00
parent b649502c9c
commit 74b0a826ea
4 changed files with 218 additions and 11 deletions

View File

@ -1,4 +1,5 @@
from __future__ import annotations
import copy
import queue
import threading
import torch
@ -175,6 +176,35 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options:
return model
def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int):
"""Return a shallow copy of ``upscale_model`` with a ``multigpu_clones`` dict of CPU-resident
descriptor deepclones, one per extra CUDA device up to ``max_gpus``.
"""
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
limit_extra_devices = full_extra_devices[:max_gpus - 1]
if len(limit_extra_devices) == 0:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.")
return upscale_model
cloned = copy.copy(upscale_model)
existing = getattr(upscale_model, 'multigpu_clones', None)
clones: dict[torch.device, object] = dict(existing) if existing else {}
for device in limit_extra_devices:
if device in clones:
continue
clone_desc = copy.deepcopy(upscale_model)
clone_desc.model.eval()
for p in clone_desc.model.parameters():
p.requires_grad_(False)
clone_desc.to("cpu")
clones[device] = clone_desc
logging.info(f"Created CPU upscale_model descriptor deepclone for {device}")
cloned.multigpu_clones = clones
return cloned
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'

View File

@ -28,13 +28,13 @@ import numpy as np
from PIL import Image
import logging
import itertools
import threading
from torch.nn.functional import interpolate
from tqdm.auto import trange
from einops import rearrange
from comfy.cli_args import args
import json
import time
import threading
import warnings
MMAP_TORCH_FILES = args.mmap_torch_files
@ -1186,6 +1186,155 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
"""Multigpu variant of tiled_scale_multidim. ``functions`` is a dict[torch.device, callable].
Round-robin dispatches tile positions across devices via threading. Each thread maintains
its own per-device CPU output and divisor buffer, applying the same feathered overlap mask
formula as the single-device path. Buffers are summed at the end, producing output that is
bit-equivalent to ``tiled_scale_multidim`` within fp32 add-order noise.
Falls back to ``tiled_scale_multidim`` with the only function when ``len(functions) < 2``.
Falls back to single-device on the "whole input fits in one tile" branch (no parallelism
available at that granularity).
"""
devices = list(functions.keys())
if len(devices) < 2:
only_fn = next(iter(functions.values())) if functions else None
return tiled_scale_multidim(samples, only_fn, tile=tile, overlap=overlap,
upscale_amount=upscale_amount, out_channels=out_channels,
output_device=output_device, downscale=downscale,
index_formulas=index_formulas, pbar=pbar)
dims = len(tile)
if not (isinstance(upscale_amount, (tuple, list))):
upscale_amount = [upscale_amount] * dims
if not (isinstance(overlap, (tuple, list))):
overlap = [overlap] * dims
if index_formulas is None:
index_formulas = upscale_amount
if not (isinstance(index_formulas, (tuple, list))):
index_formulas = [index_formulas] * dims
def get_upscale(dim, val):
up = upscale_amount[dim]
return up(val) if callable(up) else up * val
def get_downscale(dim, val):
up = upscale_amount[dim]
return up(val) if callable(up) else val / up
def get_upscale_pos(dim, val):
up = index_formulas[dim]
return up(val) if callable(up) else up * val
def get_downscale_pos(dim, val):
up = index_formulas[dim]
return up(val) if callable(up) else val / up
if downscale:
get_scale = get_downscale
get_pos = get_downscale_pos
else:
get_scale = get_upscale
get_pos = get_upscale_pos
def mult_list_upscale(a):
return [round(get_scale(i, a[i])) for i in range(len(a))]
output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
merge_device = torch.device("cpu")
pbar_lock = threading.Lock() if pbar is not None else None
primary_device = devices[0]
samples_staged = samples if samples.device.type == "cpu" else samples.to("cpu", non_blocking=False)
for b in range(samples_staged.shape[0]):
s = samples_staged[b:b+1]
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
with torch.inference_mode():
output[b:b+1] = functions[primary_device](s.to(primary_device, non_blocking=True)).to(output_device)
if pbar is not None:
pbar.update(1)
continue
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
all_positions = list(itertools.product(*positions))
split = {devices[i]: all_positions[i::len(devices)] for i in range(len(devices))}
out_shape = [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:])
div_shape = [s.shape[0], 1] + mult_list_upscale(s.shape[2:])
bufs = {d: torch.zeros(out_shape, device=merge_device) for d in devices}
divs = {d: torch.zeros(div_shape, device=merge_device) for d in devices}
worker_errors: list[BaseException] = []
worker_lock = threading.Lock()
def worker(device, my_positions):
try:
torch.cuda.set_device(device)
fn = functions[device]
local_buf = bufs[device]
local_div = divs[device]
with torch.inference_mode():
for it in my_positions:
s_in = s
upscaled = []
for d in range(dims):
pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
l = min(tile[d], s.shape[d + 2] - pos)
s_in = s_in.narrow(d + 2, pos, l)
upscaled.append(round(get_pos(d, pos)))
s_in_dev = s_in.to(device, non_blocking=True)
ps = fn(s_in_dev).to(merge_device)
mask = torch.ones([1, 1] + list(ps.shape[2:]), device=merge_device)
for d in range(2, dims + 2):
feather = round(get_scale(d - 2, overlap[d - 2]))
if feather >= mask.shape[d]:
continue
for t in range(feather):
a = (t + 1) / feather
mask.narrow(d, t, 1).mul_(a)
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
o = local_buf
o_d = local_div
for d in range(dims):
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
o.add_(ps * mask)
o_d.add_(mask)
if pbar is not None:
with pbar_lock:
pbar.update(1)
torch.cuda.synchronize(device)
except BaseException as e:
with worker_lock:
worker_errors.append(e)
threads = [threading.Thread(target=worker, args=(d, split[d])) for d in devices]
for t in threads:
t.start()
for t in threads:
t.join()
if worker_errors:
raise worker_errors[0]
combined_buf = sum(bufs.values())
combined_div = sum(divs.values()).clamp_(min=1e-12)
output[b:b+1] = combined_buf / combined_div
return output
def model_trange(*args, **kwargs):
if not comfy.memory_management.aimdo_enabled:
return trange(*args, **kwargs)

View File

@ -13,33 +13,38 @@ import comfy.multigpu
class MultiGPUCFGSplitNode(io.ComfyNode):
"""
Prepares model to have sampling accelerated via splitting work units.
Attaches per-device deepclones to any connected MODEL and/or UPSCALE_MODEL so downstream
nodes that recognize the attached state dispatch their work across multiple GPUs.
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
Other than those exceptions, this node can be placed in any order.
Place after nodes that modify the model object itself (compile, attention-switch, etc.).
Otherwise position is not order-sensitive.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MultiGPU_WorkUnits",
display_name="MultiGPU CFG Split",
display_name="MultiGPU Work Units",
category="advanced/multigpu",
description=cleandoc(cls.__doc__),
inputs=[
io.Model.Input("model"),
io.Model.Input("model", optional=True),
io.UpscaleModel.Input("upscale_model", optional=True),
io.Int.Input("max_gpus", default=2, min=1, step=1),
],
outputs=[
io.Model.Output(),
io.UpscaleModel.Output(),
],
)
@classmethod
def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
return io.NodeOutput(model)
def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None) -> io.NodeOutput:
if model is not None:
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
if upscale_model is not None:
upscale_model = comfy.multigpu.create_upscale_model_multigpu_deepclones(upscale_model, max_gpus)
return io.NodeOutput(model, upscale_model)
class MultiGPUOptionsNode(io.ComfyNode):

View File

@ -81,13 +81,33 @@ class ImageUpscaleWithModel(io.ComfyNode):
output_device = comfy.model_management.intermediate_device()
multigpu_clones = getattr(upscale_model, 'multigpu_clones', None)
if multigpu_clones:
for dev, desc in multigpu_clones.items():
model_management.free_memory(memory_required, dev)
desc.to(dev)
oom = True
try:
while oom:
try:
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
if multigpu_clones:
functions = {device: lambda a: upscale_model(a.float())}
for dev, desc in multigpu_clones.items():
functions[dev] = lambda a, d=desc: d(a.float())
s = comfy.utils.tiled_scale_multidim_multigpu(
in_img,
functions,
tile=(tile, tile),
overlap=overlap,
upscale_amount=upscale_model.scale,
pbar=pbar,
output_device=output_device,
)
else:
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
oom = False
except Exception as e:
model_management.raise_non_oom(e)
@ -96,6 +116,9 @@ class ImageUpscaleWithModel(io.ComfyNode):
raise e
finally:
upscale_model.to("cpu")
if multigpu_clones:
for desc in multigpu_clones.values():
desc.to("cpu")
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
return io.NodeOutput(s)