diff --git a/comfy/multigpu.py b/comfy/multigpu.py
index 90995a5ab..096270c12 100644
--- a/comfy/multigpu.py
+++ b/comfy/multigpu.py
@@ -1,4 +1,6 @@
 from __future__ import annotations
+import queue
+import threading
 import torch
 import logging
 
@@ -11,6 +13,67 @@ import comfy.patcher_extension
 import comfy.model_management
 
 
+class MultiGPUThreadPool:
+    """Persistent thread pool for multi-GPU work distribution.
+
+    Maintains one worker thread per extra GPU device. Each thread calls
+    torch.cuda.set_device() once at startup so that compiled kernel caches
+    (inductor/triton) stay warm across diffusion steps.
+    """
+
+    def __init__(self, devices: list[torch.device]):
+        self._workers: list[threading.Thread] = []
+        self._work_queues: dict[torch.device, queue.Queue] = {}
+        self._result_queues: dict[torch.device, queue.Queue] = {}
+
+        for device in devices:
+            wq = queue.Queue()
+            rq = queue.Queue()
+            self._work_queues[device] = wq
+            self._result_queues[device] = rq
+            t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
+            t.start()
+            self._workers.append(t)
+
+    def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
+        try:
+            torch.cuda.set_device(device)
+        except Exception as e:
+            logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
+            # Device init failed: reply to every submission with the error (so get_result() callers are never left blocked) until the shutdown sentinel arrives.
+            while True:
+                item = work_q.get()
+                if item is None:
+                    return
+                result_q.put((None, e))
+        while True:
+            item = work_q.get()
+            if item is None:
+                break
+            fn, args, kwargs = item
+            try:
+                result = fn(*args, **kwargs)
+                result_q.put((result, None))
+            except Exception as e:
+                result_q.put((None, e))
+
+    def submit(self, device: torch.device, fn, *args, **kwargs):
+        self._work_queues[device].put((fn, args, kwargs))
+
+    def get_result(self, device: torch.device):
+        return self._result_queues[device].get()
+
+    @property
+    def devices(self) -> list[torch.device]:
+        return list(self._work_queues.keys())
+
+    def shutdown(self):
+        for wq in self._work_queues.values():
+            wq.put(None)  # sentinel
+        for t in self._workers:
+            t.join(timeout=5.0)
+
+
 class GPUOptions:
     def __init__(self, device_index: int, relative_speed: float):
         self.device_index = device_index
diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index 844fadacd..6f5447d95 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -11,6 +11,7 @@ import comfy.hooks
 import comfy.patcher_extension
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
+    from comfy.model_base import BaseModel
     from comfy.model_patcher import ModelPatcher
     from comfy.controlnet import ControlBase
 
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 1ff50f51d..68f093749 100755
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -18,10 +18,10 @@ import comfy.model_patcher
 import comfy.patcher_extension
 import comfy.hooks
 import comfy.context_windows
+import comfy.multigpu
 import comfy.utils
 import scipy.stats
 import numpy
-import threading
 
 
 def add_area_dims(area, num_dims):
@@ -509,15 +509,38 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict):
                 raise
 
-    results: list[thread_result] = []
-    threads: list[threading.Thread] = []
-    for device, batch_tuple in device_batched_hooked_to_run.items():
-        new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results))
-        threads.append(new_thread)
-        new_thread.start()
+    def _handle_batch_pooled(device, batch_tuple):
+        worker_results = []
+        _handle_batch(device, batch_tuple, worker_results)
+        return worker_results
 
-    for thread in threads:
-        thread.join()
+    results: list[thread_result] = []
+    thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
+    main_device = output_device
+    main_batch_tuple = None
+
+    # Submit extra GPU work to pool first, then run main device on this thread
+    pool_devices = []
+    for device, batch_tuple in device_batched_hooked_to_run.items():
+        if device == main_device and thread_pool is not None:
+            main_batch_tuple = batch_tuple
+        elif thread_pool is not None:
+            thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
+            pool_devices.append(device)
+        else:
+            # Fallback: no pool, run everything on main thread
+            _handle_batch(device, batch_tuple, results)
+
+    # Run main device batch on this thread (parallel with pool workers)
+    if main_batch_tuple is not None:
+        _handle_batch(main_device, main_batch_tuple, results)
+
+    # Collect results from pool workers
+    for device in pool_devices:
+        worker_results, error = thread_pool.get_result(device)
+        if error is not None:
+            raise error
+        results.extend(worker_results)
 
     for output, mult, area, batch_chunks, cond_or_uncond, error in results:
         if error is not None:
             raise error
@@ -1187,17 +1210,25 @@
 
         multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
 
-        noise = noise.to(device=device, dtype=torch.float32)
-        latent_image = latent_image.to(device=device, dtype=torch.float32)
-        sigmas = sigmas.to(device)
-        cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
+        # Create persistent thread pool for extra GPU devices
+        if multigpu_patchers:
+            extra_devices = [p.load_device for p in multigpu_patchers]
+            self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(extra_devices)
 
         try:
+            noise = noise.to(device=device, dtype=torch.float32)
+            latent_image = latent_image.to(device=device, dtype=torch.float32)
+            sigmas = sigmas.to(device)
+            cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
+
             self.model_patcher.pre_run()
             for multigpu_patcher in multigpu_patchers:
                 multigpu_patcher.pre_run()
             output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
         finally:
+            thread_pool = self.model_options.pop("multigpu_thread_pool", None)
+            if thread_pool is not None:
+                thread_pool.shutdown()
             self.model_patcher.cleanup()
             for multigpu_patcher in multigpu_patchers:
                 multigpu_patcher.cleanup()