mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-26 08:57:26 +08:00
Add UPSCALE_MODEL lane to MultiGPU CFG Split
Introduce tiled_scale_multidim_multigpu in comfy/utils.py: a tile scheduler that dispatches per-device tile functions through the existing MultiGPUThreadPool and merges per-device CPU output buffers in deterministic key order. The worker only catches BaseException at the thread boundary to funnel errors to the main thread; bare torch.cuda.set_device and torch.cuda.synchronize calls inside the worker fail loud if the device is not CUDA, which is part of the primitive's contract. Add UPSCALE_MODEL input on the MultiGPU CFG Split node and an upscale-model descriptor deepclone helper in comfy/multigpu.py. Clones stay CPU-resident until execute time and are returned to CPU afterward. ImageUpscaleWithModel dispatches through tiled_scale_multidim_multigpu when a multigpu descriptor is attached; the single-device path runs unchanged when no clones are present.
This commit is contained in:
parent
b649502c9c
commit
74b0a826ea
@ -1,4 +1,5 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import copy
|
||||||
import queue
|
import queue
|
||||||
import threading
|
import threading
|
||||||
import torch
|
import torch
|
||||||
@ -175,6 +176,35 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int):
    """Return a shallow copy of ``upscale_model`` with a ``multigpu_clones`` dict attached.

    The dict maps each selected extra torch device to a deep copy of the descriptor.
    Every clone is put in eval mode, has ``requires_grad`` disabled on all of its
    parameters, and is moved to CPU before being stored.

    Args:
        upscale_model: Descriptor exposing ``.model`` (with ``eval()`` and
            ``parameters()``) and ``.to()``. An existing ``multigpu_clones``
            attribute is reused: devices already present are not cloned again.
        max_gpus: Total number of devices to use, counting the current one.
            At most ``max_gpus - 1`` extra devices are selected; values below 2
            select none.

    Returns:
        ``upscale_model`` itself (unchanged) when no extra device is selected,
        otherwise a shallow copy whose ``multigpu_clones`` attribute holds the
        per-device clones.
    """
    full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True)
    # Clamp at zero: a plain ``[:max_gpus - 1]`` turns into ``[:-1]`` for max_gpus == 0
    # and would select every extra device except the last one.
    limit_extra_devices = full_extra_devices[:max(max_gpus - 1, 0)]
    if len(limit_extra_devices) == 0:
        logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.")
        return upscale_model

    cloned = copy.copy(upscale_model)
    existing = getattr(upscale_model, 'multigpu_clones', None)
    clones = dict(existing) if existing else {}

    # Deep-copying a descriptor that already carries clones would recursively copy
    # every one of them into each new clone. Copy from a shallow template whose
    # clone dict is empty instead; the caller's object is left untouched.
    template = upscale_model
    if existing:
        template = copy.copy(upscale_model)
        template.multigpu_clones = {}

    for device in limit_extra_devices:
        if device in clones:
            continue
        clone_desc = copy.deepcopy(template)
        clone_desc.model.eval()
        for p in clone_desc.model.parameters():
            p.requires_grad_(False)
        # Clones stay on CPU here; the consumer moves them to their device when it runs.
        clone_desc.to("cpu")
        clones[device] = clone_desc
        logging.info(f"Created CPU upscale_model descriptor deepclone for {device}")

    cloned.multigpu_clones = clones
    return cloned
|
||||||
|
|
||||||
|
|
||||||
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
|
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
|
||||||
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
|
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
|
||||||
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
|
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
|
||||||
|
|||||||
151
comfy/utils.py
151
comfy/utils.py
@ -28,13 +28,13 @@ import numpy as np
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import logging
|
import logging
|
||||||
import itertools
|
import itertools
|
||||||
|
import threading
|
||||||
from torch.nn.functional import interpolate
|
from torch.nn.functional import interpolate
|
||||||
from tqdm.auto import trange
|
from tqdm.auto import trange
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import threading
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||||
@ -1186,6 +1186,155 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
|||||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||||
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
|
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
|
||||||
|
|
||||||
|
|
||||||
|
def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
    """Multi-GPU variant of ``tiled_scale_multidim``.

    Tile positions of each batch element are distributed round-robin over the
    devices in ``functions`` and processed by one thread per device. Every thread
    accumulates into its own CPU output and divisor buffer using the feathered
    overlap mask; the per-device buffers are summed and divided once all threads
    have joined. The result is intended to match ``tiled_scale_multidim`` up to
    fp32 summation-order differences.

    Args:
        samples: Input tensor; dim 0 is the batch and dims 2.. are the tiled dims.
        functions: ``dict[torch.device, callable]``; each callable processes one
            tile that has been moved to its device. The first key is the primary
            device.
        tile: Tile size per tiled dimension.
        overlap: Overlap between neighbouring tiles (scalar or per dimension).
        upscale_amount: Scale factor or callable (scalar or per dimension).
        out_channels: Channel count of the output.
        output_device: Device of the returned tensor.
        downscale: Interpret the scale factors as divisors instead of multipliers.
        index_formulas: Optional factor/callable used to map tile positions into
            output coordinates; defaults to ``upscale_amount``.
        pbar: Optional progress bar; ``update(1)`` is called once per tile.

    Returns:
        Tensor on ``output_device`` holding the scaled result.

    Raises:
        ValueError: ``functions`` is empty.
        BaseException: The first error recorded by any worker thread is re-raised
            on the calling thread.

    With fewer than two devices the call is forwarded to ``tiled_scale_multidim``.
    A batch element that fits in a single tile runs on the primary device only.
    Worker threads call ``torch.cuda.set_device`` / ``torch.cuda.synchronize``
    unconditionally, so the multi-device path expects CUDA devices.
    """
    if not functions:
        # Previously an empty mapping forwarded ``None`` as the tile function and
        # failed later with an opaque "NoneType is not callable".
        raise ValueError("tiled_scale_multidim_multigpu requires at least one device function")

    devices = list(functions.keys())
    if len(devices) < 2:
        only_fn = functions[devices[0]]
        return tiled_scale_multidim(samples, only_fn, tile=tile, overlap=overlap,
                                    upscale_amount=upscale_amount, out_channels=out_channels,
                                    output_device=output_device, downscale=downscale,
                                    index_formulas=index_formulas, pbar=pbar)

    dims = len(tile)

    if not (isinstance(upscale_amount, (tuple, list))):
        upscale_amount = [upscale_amount] * dims
    if not (isinstance(overlap, (tuple, list))):
        overlap = [overlap] * dims
    if index_formulas is None:
        index_formulas = upscale_amount
    if not (isinstance(index_formulas, (tuple, list))):
        index_formulas = [index_formulas] * dims

    def get_upscale(dim, val):
        up = upscale_amount[dim]
        return up(val) if callable(up) else up * val

    def get_downscale(dim, val):
        up = upscale_amount[dim]
        return up(val) if callable(up) else val / up

    def get_upscale_pos(dim, val):
        up = index_formulas[dim]
        return up(val) if callable(up) else up * val

    def get_downscale_pos(dim, val):
        up = index_formulas[dim]
        return up(val) if callable(up) else val / up

    if downscale:
        get_scale = get_downscale
        get_pos = get_downscale_pos
    else:
        get_scale = get_upscale
        get_pos = get_upscale_pos

    def mult_list_upscale(a):
        return [round(get_scale(i, v)) for i, v in enumerate(a)]

    output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
    # All per-device accumulation buffers live on CPU and are merged there.
    merge_device = torch.device("cpu")

    pbar_lock = threading.Lock() if pbar is not None else None
    primary_device = devices[0]

    # Stage the input on CPU once; each worker copies only its own tiles to its device.
    samples_staged = samples if samples.device.type == "cpu" else samples.to("cpu", non_blocking=False)

    for b in range(samples_staged.shape[0]):
        s = samples_staged[b:b+1]

        # Whole input fits in one tile: nothing to parallelise, run on the primary device.
        if all(s.shape[d+2] <= tile[d] for d in range(dims)):
            with torch.inference_mode():
                output[b:b+1] = functions[primary_device](s.to(primary_device, non_blocking=True)).to(output_device)
            if pbar is not None:
                pbar.update(1)
            continue

        positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
        all_positions = list(itertools.product(*positions))

        # Round-robin assignment: device i gets positions i, i + n, i + 2n, ...
        split = {devices[i]: all_positions[i::len(devices)] for i in range(len(devices))}

        out_shape = [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:])
        div_shape = [s.shape[0], 1] + mult_list_upscale(s.shape[2:])
        bufs = {d: torch.zeros(out_shape, device=merge_device) for d in devices}
        divs = {d: torch.zeros(div_shape, device=merge_device) for d in devices}

        worker_errors = []
        worker_lock = threading.Lock()

        def worker(device, my_positions):
            # Everything is caught at the thread boundary so the main thread can re-raise it.
            try:
                torch.cuda.set_device(device)
                fn = functions[device]
                local_buf = bufs[device]
                local_div = divs[device]
                with torch.inference_mode():
                    for it in my_positions:
                        s_in = s
                        upscaled = []
                        for d in range(dims):
                            pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
                            l = min(tile[d], s.shape[d + 2] - pos)
                            s_in = s_in.narrow(d + 2, pos, l)
                            upscaled.append(round(get_pos(d, pos)))

                        s_in_dev = s_in.to(device, non_blocking=True)
                        ps = fn(s_in_dev).to(merge_device)
                        mask = torch.ones([1, 1] + list(ps.shape[2:]), device=merge_device)

                        # Linear feather ramp on both edges of every tiled dimension.
                        for d in range(2, dims + 2):
                            feather = round(get_scale(d - 2, overlap[d - 2]))
                            if feather >= mask.shape[d]:
                                continue
                            for t in range(feather):
                                a = (t + 1) / feather
                                mask.narrow(d, t, 1).mul_(a)
                                mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)

                        # Views into this device's private buffers at the tile's output location.
                        o = local_buf
                        o_d = local_div
                        for d in range(dims):
                            o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
                            o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])

                        o.add_(ps * mask)
                        o_d.add_(mask)

                        if pbar is not None:
                            with pbar_lock:
                                pbar.update(1)
                torch.cuda.synchronize(device)
            except BaseException as e:
                with worker_lock:
                    worker_errors.append(e)

        threads = [threading.Thread(target=worker, args=(d, split[d])) for d in devices]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        if worker_errors:
            raise worker_errors[0]

        # Merge in dict (insertion) order so the summation order is deterministic.
        combined_buf = sum(bufs.values())
        # Clamp avoids division by zero for any location no tile contributed to.
        combined_div = sum(divs.values()).clamp_(min=1e-12)
        output[b:b+1] = combined_buf / combined_div

    return output
|
||||||
|
|
||||||
def model_trange(*args, **kwargs):
|
def model_trange(*args, **kwargs):
|
||||||
if not comfy.memory_management.aimdo_enabled:
|
if not comfy.memory_management.aimdo_enabled:
|
||||||
return trange(*args, **kwargs)
|
return trange(*args, **kwargs)
|
||||||
|
|||||||
@ -13,33 +13,38 @@ import comfy.multigpu
|
|||||||
|
|
||||||
class MultiGPUCFGSplitNode(io.ComfyNode):
|
class MultiGPUCFGSplitNode(io.ComfyNode):
|
||||||
"""
|
"""
|
||||||
Prepares model to have sampling accelerated via splitting work units.
|
Attaches per-device deepclones to any connected MODEL and/or UPSCALE_MODEL so downstream
|
||||||
|
nodes that recognize the attached state dispatch their work across multiple GPUs.
|
||||||
|
|
||||||
Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes.
|
Place after nodes that modify the model object itself (compile, attention-switch, etc.).
|
||||||
|
Otherwise position is not order-sensitive.
|
||||||
Other than those exceptions, this node can be placed in any order.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="MultiGPU_WorkUnits",
|
node_id="MultiGPU_WorkUnits",
|
||||||
display_name="MultiGPU CFG Split",
|
display_name="MultiGPU Work Units",
|
||||||
category="advanced/multigpu",
|
category="advanced/multigpu",
|
||||||
description=cleandoc(cls.__doc__),
|
description=cleandoc(cls.__doc__),
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model", optional=True),
|
||||||
|
io.UpscaleModel.Input("upscale_model", optional=True),
|
||||||
io.Int.Input("max_gpus", default=2, min=1, step=1),
|
io.Int.Input("max_gpus", default=2, min=1, step=1),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Model.Output(),
|
io.Model.Output(),
|
||||||
|
io.UpscaleModel.Output(),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput:
|
def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None) -> io.NodeOutput:
|
||||||
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
|
if model is not None:
|
||||||
return io.NodeOutput(model)
|
model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True)
|
||||||
|
if upscale_model is not None:
|
||||||
|
upscale_model = comfy.multigpu.create_upscale_model_multigpu_deepclones(upscale_model, max_gpus)
|
||||||
|
return io.NodeOutput(model, upscale_model)
|
||||||
|
|
||||||
|
|
||||||
class MultiGPUOptionsNode(io.ComfyNode):
|
class MultiGPUOptionsNode(io.ComfyNode):
|
||||||
|
|||||||
@ -81,13 +81,33 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
|||||||
|
|
||||||
output_device = comfy.model_management.intermediate_device()
|
output_device = comfy.model_management.intermediate_device()
|
||||||
|
|
||||||
|
multigpu_clones = getattr(upscale_model, 'multigpu_clones', None)
|
||||||
|
if multigpu_clones:
|
||||||
|
for dev, desc in multigpu_clones.items():
|
||||||
|
model_management.free_memory(memory_required, dev)
|
||||||
|
desc.to(dev)
|
||||||
|
|
||||||
oom = True
|
oom = True
|
||||||
try:
|
try:
|
||||||
while oom:
|
while oom:
|
||||||
try:
|
try:
|
||||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
|
if multigpu_clones:
|
||||||
|
functions = {device: lambda a: upscale_model(a.float())}
|
||||||
|
for dev, desc in multigpu_clones.items():
|
||||||
|
functions[dev] = lambda a, d=desc: d(a.float())
|
||||||
|
s = comfy.utils.tiled_scale_multidim_multigpu(
|
||||||
|
in_img,
|
||||||
|
functions,
|
||||||
|
tile=(tile, tile),
|
||||||
|
overlap=overlap,
|
||||||
|
upscale_amount=upscale_model.scale,
|
||||||
|
pbar=pbar,
|
||||||
|
output_device=output_device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
|
||||||
oom = False
|
oom = False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
model_management.raise_non_oom(e)
|
model_management.raise_non_oom(e)
|
||||||
@ -96,6 +116,9 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
|||||||
raise e
|
raise e
|
||||||
finally:
|
finally:
|
||||||
upscale_model.to("cpu")
|
upscale_model.to("cpu")
|
||||||
|
if multigpu_clones:
|
||||||
|
for desc in multigpu_clones.values():
|
||||||
|
desc.to("cpu")
|
||||||
|
|
||||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
|
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
|
||||||
return io.NodeOutput(s)
|
return io.NodeOutput(s)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user