mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-24 07:57:29 +08:00
Add UPSCALE_MODEL lane to MultiGPU CFG Split
Introduce tiled_scale_multidim_multigpu in comfy/utils.py: a tile scheduler that dispatches per-device tile functions through the existing MultiGPUThreadPool and merges per-device CPU output buffers in deterministic key order. The worker only catches BaseException at the thread boundary to funnel errors to the main thread; bare torch.cuda.set_device and torch.cuda.synchronize calls inside the worker fail loud if the device is not CUDA, which is part of the primitive's contract. Add UPSCALE_MODEL input on the MultiGPU CFG Split node and an upscale-model descriptor deepclone helper in comfy/multigpu.py. Clones stay CPU-resident until execute time and are returned to CPU afterward. ImageUpscaleWithModel dispatches through tiled_scale_multidim_multigpu when a multigpu descriptor is attached; the single-device path runs unchanged when no clones are present.
This commit is contained in:
parent
b649502c9c
commit
74b0a826ea
@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import copy
|
||||
import queue
|
||||
import threading
|
||||
import torch
|
||||
@ -175,6 +176,35 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options:
|
||||
return model
|
||||
|
||||
|
||||
def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int):
|
||||
"""Return a shallow copy of ``upscale_model`` with a ``multigpu_clones`` dict of CPU-resident
|
||||
descriptor deepclones, one per extra CUDA device up to ``max_gpus``.
|
||||
"""
|
||||
full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
|
||||
limit_extra_devices = full_extra_devices[:max_gpus - 1]
|
||||
if len(limit_extra_devices) == 0:
|
||||
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.")
|
||||
return upscale_model
|
||||
|
||||
cloned = copy.copy(upscale_model)
|
||||
existing = getattr(upscale_model, 'multigpu_clones', None)
|
||||
clones: dict[torch.device, object] = dict(existing) if existing else {}
|
||||
|
||||
for device in limit_extra_devices:
|
||||
if device in clones:
|
||||
continue
|
||||
clone_desc = copy.deepcopy(upscale_model)
|
||||
clone_desc.model.eval()
|
||||
for p in clone_desc.model.parameters():
|
||||
p.requires_grad_(False)
|
||||
clone_desc.to("cpu")
|
||||
clones[device] = clone_desc
|
||||
logging.info(f"Created CPU upscale_model descriptor deepclone for {device}")
|
||||
|
||||
cloned.multigpu_clones = clones
|
||||
return cloned
|
||||
|
||||
|
||||
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
|
||||
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
|
||||
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
|
||||
|
||||
151
comfy/utils.py
151
comfy/utils.py
@ -28,13 +28,13 @@ import numpy as np
|
||||
from PIL import Image
|
||||
import logging
|
||||
import itertools
|
||||
import threading
|
||||
from torch.nn.functional import interpolate
|
||||
from tqdm.auto import trange
|
||||
from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
@ -1186,6 +1186,155 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
|
||||
|
||||
|
||||
def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
|
||||
"""Multigpu variant of tiled_scale_multidim. ``functions`` is a dict[torch.device, callable].
|
||||
|
||||
Round-robin dispatches tile positions across devices via threading. Each thread maintains
|
||||
its own per-device CPU output and divisor buffer, applying the same feathered overlap mask
|
||||
formula as the single-device path. Buffers are summed at the end, producing output that is
|
||||
bit-equivalent to ``tiled_scale_multidim`` within fp32 add-order noise.
|
||||
|
||||
Falls back to ``tiled_scale_multidim`` with the only function when ``len(functions) < 2``.
|
||||
Falls back to single-device on the "whole input fits in one tile" branch (no parallelism
|
||||
available at that granularity).
|
||||
"""
|
||||
devices = list(functions.keys())
|
||||
if len(devices) < 2:
|
||||
only_fn = next(iter(functions.values())) if functions else None
|
||||
return tiled_scale_multidim(samples, only_fn, tile=tile, overlap=overlap,
|
||||
upscale_amount=upscale_amount, out_channels=out_channels,
|
||||
output_device=output_device, downscale=downscale,
|
||||
index_formulas=index_formulas, pbar=pbar)
|
||||
|
||||
dims = len(tile)
|
||||
|
||||
if not (isinstance(upscale_amount, (tuple, list))):
|
||||
upscale_amount = [upscale_amount] * dims
|
||||
if not (isinstance(overlap, (tuple, list))):
|
||||
overlap = [overlap] * dims
|
||||
if index_formulas is None:
|
||||
index_formulas = upscale_amount
|
||||
if not (isinstance(index_formulas, (tuple, list))):
|
||||
index_formulas = [index_formulas] * dims
|
||||
|
||||
def get_upscale(dim, val):
|
||||
up = upscale_amount[dim]
|
||||
return up(val) if callable(up) else up * val
|
||||
|
||||
def get_downscale(dim, val):
|
||||
up = upscale_amount[dim]
|
||||
return up(val) if callable(up) else val / up
|
||||
|
||||
def get_upscale_pos(dim, val):
|
||||
up = index_formulas[dim]
|
||||
return up(val) if callable(up) else up * val
|
||||
|
||||
def get_downscale_pos(dim, val):
|
||||
up = index_formulas[dim]
|
||||
return up(val) if callable(up) else val / up
|
||||
|
||||
if downscale:
|
||||
get_scale = get_downscale
|
||||
get_pos = get_downscale_pos
|
||||
else:
|
||||
get_scale = get_upscale
|
||||
get_pos = get_upscale_pos
|
||||
|
||||
def mult_list_upscale(a):
|
||||
return [round(get_scale(i, a[i])) for i in range(len(a))]
|
||||
|
||||
output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
|
||||
merge_device = torch.device("cpu")
|
||||
|
||||
pbar_lock = threading.Lock() if pbar is not None else None
|
||||
primary_device = devices[0]
|
||||
|
||||
samples_staged = samples if samples.device.type == "cpu" else samples.to("cpu", non_blocking=False)
|
||||
|
||||
for b in range(samples_staged.shape[0]):
|
||||
s = samples_staged[b:b+1]
|
||||
|
||||
if all(s.shape[d+2] <= tile[d] for d in range(dims)):
|
||||
with torch.inference_mode():
|
||||
output[b:b+1] = functions[primary_device](s.to(primary_device, non_blocking=True)).to(output_device)
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||
all_positions = list(itertools.product(*positions))
|
||||
|
||||
split = {devices[i]: all_positions[i::len(devices)] for i in range(len(devices))}
|
||||
|
||||
out_shape = [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:])
|
||||
div_shape = [s.shape[0], 1] + mult_list_upscale(s.shape[2:])
|
||||
bufs = {d: torch.zeros(out_shape, device=merge_device) for d in devices}
|
||||
divs = {d: torch.zeros(div_shape, device=merge_device) for d in devices}
|
||||
|
||||
worker_errors: list[BaseException] = []
|
||||
worker_lock = threading.Lock()
|
||||
|
||||
def worker(device, my_positions):
|
||||
try:
|
||||
torch.cuda.set_device(device)
|
||||
fn = functions[device]
|
||||
local_buf = bufs[device]
|
||||
local_div = divs[device]
|
||||
with torch.inference_mode():
|
||||
for it in my_positions:
|
||||
s_in = s
|
||||
upscaled = []
|
||||
for d in range(dims):
|
||||
pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
|
||||
l = min(tile[d], s.shape[d + 2] - pos)
|
||||
s_in = s_in.narrow(d + 2, pos, l)
|
||||
upscaled.append(round(get_pos(d, pos)))
|
||||
|
||||
s_in_dev = s_in.to(device, non_blocking=True)
|
||||
ps = fn(s_in_dev).to(merge_device)
|
||||
mask = torch.ones([1, 1] + list(ps.shape[2:]), device=merge_device)
|
||||
|
||||
for d in range(2, dims + 2):
|
||||
feather = round(get_scale(d - 2, overlap[d - 2]))
|
||||
if feather >= mask.shape[d]:
|
||||
continue
|
||||
for t in range(feather):
|
||||
a = (t + 1) / feather
|
||||
mask.narrow(d, t, 1).mul_(a)
|
||||
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
|
||||
|
||||
o = local_buf
|
||||
o_d = local_div
|
||||
for d in range(dims):
|
||||
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||
|
||||
o.add_(ps * mask)
|
||||
o_d.add_(mask)
|
||||
|
||||
if pbar is not None:
|
||||
with pbar_lock:
|
||||
pbar.update(1)
|
||||
torch.cuda.synchronize(device)
|
||||
except BaseException as e:
|
||||
with worker_lock:
|
||||
worker_errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=worker, args=(d, split[d])) for d in devices]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
if worker_errors:
|
||||
raise worker_errors[0]
|
||||
|
||||
combined_buf = sum(bufs.values())
|
||||
combined_div = sum(divs.values()).clamp_(min=1e-12)
|
||||
output[b:b+1] = combined_buf / combined_div
|
||||
|
||||
return output
|
||||
|
||||
def model_trange(*args, **kwargs):
|
||||
if not comfy.memory_management.aimdo_enabled:
|
||||
return trange(*args, **kwargs)
|
||||
|
||||
@ -13,33 +13,38 @@ import comfy.multigpu
|
||||
|
||||
class MultiGPUCFGSplitNode(io.ComfyNode):
|
||||
"""
|
||||
Prepares model to have sampling accelerated via splitting work units.
|
||||
Attaches per-device deepclones to any connected MODEL and/or UPSCALE_MODEL so downstream
|
||||
nodes that recognize the attached state dispatch their work across multiple GPUs.
|
||||
|
||||
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
|
||||
|
||||
Other than those exceptions, this node can be placed in any order.
|
||||
Place after nodes that modify the model object itself (compile, attention-switch, etc.).
|
||||
Otherwise position is not order-sensitive.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MultiGPU_WorkUnits",
|
||||
display_name="MultiGPU CFG Split",
|
||||
display_name="MultiGPU Work Units",
|
||||
category="advanced/multigpu",
|
||||
description=cleandoc(cls.__doc__),
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Model.Input("model", optional=True),
|
||||
io.UpscaleModel.Input("upscale_model", optional=True),
|
||||
io.Int.Input("max_gpus", default=2, min=1, step=1),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
io.UpscaleModel.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
|
||||
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
|
||||
return io.NodeOutput(model)
|
||||
def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None) -> io.NodeOutput:
|
||||
if model is not None:
|
||||
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
|
||||
if upscale_model is not None:
|
||||
upscale_model = comfy.multigpu.create_upscale_model_multigpu_deepclones(upscale_model, max_gpus)
|
||||
return io.NodeOutput(model, upscale_model)
|
||||
|
||||
|
||||
class MultiGPUOptionsNode(io.ComfyNode):
|
||||
|
||||
@ -81,13 +81,33 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
|
||||
output_device = comfy.model_management.intermediate_device()
|
||||
|
||||
multigpu_clones = getattr(upscale_model, 'multigpu_clones', None)
|
||||
if multigpu_clones:
|
||||
for dev, desc in multigpu_clones.items():
|
||||
model_management.free_memory(memory_required, dev)
|
||||
desc.to(dev)
|
||||
|
||||
oom = True
|
||||
try:
|
||||
while oom:
|
||||
try:
|
||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
|
||||
if multigpu_clones:
|
||||
functions = {device: lambda a: upscale_model(a.float())}
|
||||
for dev, desc in multigpu_clones.items():
|
||||
functions[dev] = lambda a, d=desc: d(a.float())
|
||||
s = comfy.utils.tiled_scale_multidim_multigpu(
|
||||
in_img,
|
||||
functions,
|
||||
tile=(tile, tile),
|
||||
overlap=overlap,
|
||||
upscale_amount=upscale_model.scale,
|
||||
pbar=pbar,
|
||||
output_device=output_device,
|
||||
)
|
||||
else:
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
|
||||
oom = False
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
@ -96,6 +116,9 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
raise e
|
||||
finally:
|
||||
upscale_model.to("cpu")
|
||||
if multigpu_clones:
|
||||
for desc in multigpu_clones.values():
|
||||
desc.to("cpu")
|
||||
|
||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
|
||||
return io.NodeOutput(s)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user