diff --git a/comfy/multigpu.py b/comfy/multigpu.py index 7f90b7db7..eff7d0649 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -1,5 +1,4 @@ from __future__ import annotations -import copy import queue import threading import torch @@ -176,35 +175,6 @@ def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: return model -def create_upscale_model_multigpu_deepclones(upscale_model, max_gpus: int): - """Return a shallow copy of ``upscale_model`` with a ``multigpu_clones`` dict of CPU-resident - descriptor deepclones, one per extra CUDA device up to ``max_gpus``. - """ - full_extra_devices = comfy.model_management.get_all_torch_devices(exclude_current=True) - limit_extra_devices = full_extra_devices[:max_gpus - 1] - if len(limit_extra_devices) == 0: - logging.info("No extra torch devices need initialization, skipping initializing MultiGPU upscale clones.") - return upscale_model - - cloned = copy.copy(upscale_model) - existing = getattr(upscale_model, 'multigpu_clones', None) - clones: dict[torch.device, object] = dict(existing) if existing else {} - - for device in limit_extra_devices: - if device in clones: - continue - clone_desc = copy.deepcopy(upscale_model) - clone_desc.model.eval() - for p in clone_desc.model.parameters(): - p.requires_grad_(False) - clone_desc.to("cpu") - clones[device] = clone_desc - logging.info(f"Created CPU upscale_model descriptor deepclone for {device}") - - cloned.multigpu_clones = clones - return cloned - - LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time']) def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None): 'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.' diff --git a/comfy/utils.py b/comfy/utils.py index 60b69324b..49ae12b06 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -28,13 +28,13 @@ import numpy as np from PIL import Image import logging import itertools -import threading from torch.nn.functional import interpolate from tqdm.auto import trange from einops import rearrange from comfy.cli_args import args import json import time +import threading import warnings MMAP_TORCH_FILES = args.mmap_torch_files @@ -1187,155 +1187,6 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) - -def tiled_scale_multidim_multigpu(samples, functions, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None): - """Multigpu variant of tiled_scale_multidim. ``functions`` is a dict[torch.device, callable]. - - Round-robin dispatches tile positions across devices via threading. Each thread maintains - its own per-device CPU output and divisor buffer, applying the same feathered overlap mask - formula as the single-device path. Buffers are summed at the end, producing output that is - bit-equivalent to ``tiled_scale_multidim`` within fp32 add-order noise. - - Falls back to ``tiled_scale_multidim`` with the only function when ``len(functions) < 2``. - Falls back to single-device on the "whole input fits in one tile" branch (no parallelism - available at that granularity). - """ - devices = list(functions.keys()) - if len(devices) < 2: - only_fn = next(iter(functions.values())) if functions else None - return tiled_scale_multidim(samples, only_fn, tile=tile, overlap=overlap, - upscale_amount=upscale_amount, out_channels=out_channels, - output_device=output_device, downscale=downscale, - index_formulas=index_formulas, pbar=pbar) - - dims = len(tile) - - if not (isinstance(upscale_amount, (tuple, list))): - upscale_amount = [upscale_amount] * dims - if not (isinstance(overlap, (tuple, list))): - overlap = [overlap] * dims - if index_formulas is None: - index_formulas = upscale_amount - if not (isinstance(index_formulas, (tuple, list))): - index_formulas = [index_formulas] * dims - - def get_upscale(dim, val): - up = upscale_amount[dim] - return up(val) if callable(up) else up * val - - def get_downscale(dim, val): - up = upscale_amount[dim] - return up(val) if callable(up) else val / up - - def get_upscale_pos(dim, val): - up = index_formulas[dim] - return up(val) if callable(up) else up * val - - def get_downscale_pos(dim, val): - up = index_formulas[dim] - return up(val) if callable(up) else val / up - - if downscale: - get_scale = get_downscale - get_pos = get_downscale_pos - else: - get_scale = get_upscale - get_pos = get_upscale_pos - - def mult_list_upscale(a): - return [round(get_scale(i, a[i])) for i in range(len(a))] - - output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device) - merge_device = torch.device("cpu") - - pbar_lock = threading.Lock() if pbar is not None else None - primary_device = devices[0] - - samples_staged = samples if samples.device.type == "cpu" else samples.to("cpu", non_blocking=False) - - for b in range(samples_staged.shape[0]): - s = samples_staged[b:b+1] - - if all(s.shape[d+2] <= tile[d] for d in range(dims)): - with torch.inference_mode(): - output[b:b+1] = functions[primary_device](s.to(primary_device, non_blocking=True)).to(output_device) - if pbar is not None: - pbar.update(1) - continue - - positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] - all_positions = list(itertools.product(*positions)) - - split = {devices[i]: all_positions[i::len(devices)] for i in range(len(devices))} - - out_shape = [s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]) - div_shape = [s.shape[0], 1] + mult_list_upscale(s.shape[2:]) - bufs = {d: torch.zeros(out_shape, device=merge_device) for d in devices} - divs = {d: torch.zeros(div_shape, device=merge_device) for d in devices} - - worker_errors: list[BaseException] = [] - worker_lock = threading.Lock() - - def worker(device, my_positions): - try: - torch.cuda.set_device(device) - fn = functions[device] - local_buf = bufs[device] - local_div = divs[device] - with torch.inference_mode(): - for it in my_positions: - s_in = s - upscaled = [] - for d in range(dims): - pos = max(0, min(s.shape[d + 2] - overlap[d], it[d])) - l = min(tile[d], s.shape[d + 2] - pos) - s_in = s_in.narrow(d + 2, pos, l) - upscaled.append(round(get_pos(d, pos))) - - s_in_dev = s_in.to(device, non_blocking=True) - ps = fn(s_in_dev).to(merge_device) - mask = torch.ones([1, 1] + list(ps.shape[2:]), device=merge_device) - - for d in range(2, dims + 2): - feather = round(get_scale(d - 2, overlap[d - 2])) - if feather >= mask.shape[d]: - continue - for t in range(feather): - a = (t + 1) / feather - mask.narrow(d, t, 1).mul_(a) - mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a) - - o = local_buf - o_d = local_div - for d in range(dims): - o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2]) - o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2]) - - o.add_(ps * mask) - o_d.add_(mask) - - if pbar is not None: - with pbar_lock: - pbar.update(1) - torch.cuda.synchronize(device) - except BaseException as e: - with worker_lock: - worker_errors.append(e) - - threads = [threading.Thread(target=worker, args=(d, split[d])) for d in devices] - for t in threads: - t.start() - for t in threads: - t.join() - if worker_errors: - raise worker_errors[0] - - combined_buf = sum(bufs.values()) - combined_div = sum(divs.values()).clamp_(min=1e-12) - output[b:b+1] = combined_buf / combined_div - - return output - def model_trange(*args, **kwargs): if not comfy.memory_management.aimdo_enabled: return trange(*args, **kwargs) diff --git a/comfy_extras/nodes_multigpu.py b/comfy_extras/nodes_multigpu.py index 021dfca3f..fedafef71 100644 --- a/comfy_extras/nodes_multigpu.py +++ b/comfy_extras/nodes_multigpu.py @@ -13,38 +13,33 @@ import comfy.multigpu class MultiGPUCFGSplitNode(io.ComfyNode): """ - Attaches per-device deepclones to any connected MODEL and/or UPSCALE_MODEL so downstream - nodes that recognize the attached state dispatch their work across multiple GPUs. + Prepares model to have sampling accelerated via splitting work units. - Place after nodes that modify the model object itself (compile, attention-switch, etc.). - Otherwise position is not order-sensitive. + Should be placed after nodes that modify the model object itself, such as compile or attention-switch nodes. + + Other than those exceptions, this node can be placed in any order. """ @classmethod def define_schema(cls): return io.Schema( node_id="MultiGPU_WorkUnits", - display_name="MultiGPU Work Units", + display_name="MultiGPU CFG Split", category="advanced/multigpu", description=cleandoc(cls.__doc__), inputs=[ - io.Model.Input("model", optional=True), - io.UpscaleModel.Input("upscale_model", optional=True), + io.Model.Input("model"), io.Int.Input("max_gpus", default=2, min=1, step=1), ], outputs=[ io.Model.Output(), - io.UpscaleModel.Output(), ], ) @classmethod - def execute(cls, max_gpus: int, model: ModelPatcher = None, upscale_model=None) -> io.NodeOutput: - if model is not None: - model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True) - if upscale_model is not None: - upscale_model = comfy.multigpu.create_upscale_model_multigpu_deepclones(upscale_model, max_gpus) - return io.NodeOutput(model, upscale_model) + def execute(cls, model: ModelPatcher, max_gpus: int) -> io.NodeOutput: + model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, reuse_loaded=True) + return io.NodeOutput(model) class MultiGPUOptionsNode(io.ComfyNode): diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 3a4e3926c..d3ee3f1c1 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -81,33 +81,13 @@ class ImageUpscaleWithModel(io.ComfyNode): output_device = comfy.model_management.intermediate_device() - multigpu_clones = getattr(upscale_model, 'multigpu_clones', None) - if multigpu_clones: - for dev, desc in multigpu_clones.items(): - model_management.free_memory(memory_required, dev) - desc.to(dev) - oom = True try: while oom: try: steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) pbar = comfy.utils.ProgressBar(steps) - if multigpu_clones: - functions = {device: lambda a: upscale_model(a.float())} - for dev, desc in multigpu_clones.items(): - functions[dev] = lambda a, d=desc: d(a.float()) - s = comfy.utils.tiled_scale_multidim_multigpu( - in_img, - functions, - tile=(tile, tile), - overlap=overlap, - upscale_amount=upscale_model.scale, - pbar=pbar, - output_device=output_device, - ) - else: - s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device) + s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device) oom = False except Exception as e: model_management.raise_non_oom(e) @@ -116,9 +96,6 @@ class ImageUpscaleWithModel(io.ComfyNode): raise e finally: upscale_model.to("cpu") - if multigpu_clones: - for desc in multigpu_clones.values(): - desc.to("cpu") s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype()) return io.NodeOutput(s)