Implement persistent thread pool for multi-GPU CFG splitting
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled

Replace per-step thread create/destroy in _calc_cond_batch_multigpu with a
persistent MultiGPUThreadPool. Each worker thread calls torch.cuda.set_device()
once at startup, preserving compiled kernel caches across diffusion steps.

- Add MultiGPUThreadPool class in comfy/multigpu.py
- Create pool in CFGGuider.outer_sample(), shut down in finally block
- Main thread handles its own device batch directly for zero overhead
- Falls back to sequential execution if no pool is available

Co-authored-by: Amp <amp@ampcode.com>
Amp-Thread-ID: https://ampcode.com/threads/T-019d3f5c-28c5-72c9-abed-34681f1b54ba
This commit is contained in:
Jedrzej Kosinski 2026-04-08 04:27:26 -07:00
parent da3864436c
commit 9e4749a32f
3 changed files with 108 additions and 13 deletions

View File

@ -1,4 +1,6 @@
from __future__ import annotations
import queue
import threading
import torch
import logging
@ -11,6 +13,67 @@ import comfy.patcher_extension
import comfy.model_management
class MultiGPUThreadPool:
    """Persistent thread pool for multi-GPU work distribution.

    Maintains one worker thread per extra GPU device. Each thread calls
    torch.cuda.set_device() once at startup so that compiled kernel caches
    (inductor/triton) stay warm across diffusion steps.
    """

    def __init__(self, devices: list[torch.device]):
        self._workers: list[threading.Thread] = []
        self._work_queues: dict[torch.device, queue.Queue] = {}
        self._result_queues: dict[torch.device, queue.Queue] = {}
        for device in devices:
            wq: queue.Queue = queue.Queue()
            rq: queue.Queue = queue.Queue()
            self._work_queues[device] = wq
            self._result_queues[device] = rq
            t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
            t.start()
            self._workers.append(t)

    def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
        # Bind this thread to its device once at startup. If binding fails,
        # stay alive and answer every submitted job with the captured error so
        # callers blocked in get_result() never deadlock.
        try:
            torch.cuda.set_device(device)
        except Exception as e:
            logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
            while True:
                item = work_q.get()
                if item is None:  # shutdown sentinel
                    return
                result_q.put((None, e))
        while True:
            item = work_q.get()
            if item is None:  # shutdown sentinel
                break
            fn, args, kwargs = item
            try:
                result = fn(*args, **kwargs)
                result_q.put((result, None))
            except Exception as e:
                result_q.put((None, e))

    def submit(self, device: torch.device, fn, *args, **kwargs):
        """Queue fn(*args, **kwargs) for execution on the worker bound to device."""
        self._work_queues[device].put((fn, args, kwargs))

    def get_result(self, device: torch.device):
        """Block until the worker for device produces its next (result, error) pair."""
        return self._result_queues[device].get()

    @property
    def devices(self) -> list[torch.device]:
        return list(self._work_queues.keys())

    def shutdown(self):
        """Signal all workers to exit and wait (bounded) for them to finish."""
        for wq in self._work_queues.values():
            wq.put(None)  # sentinel
        for t in self._workers:
            t.join(timeout=5.0)
class GPUOptions:
def __init__(self, device_index: int, relative_speed: float):
self.device_index = device_index

View File

@ -11,6 +11,7 @@ import comfy.hooks
import comfy.patcher_extension
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from comfy.controlnet import ControlBase

View File

@ -18,10 +18,10 @@ import comfy.model_patcher
import comfy.patcher_extension
import comfy.hooks
import comfy.context_windows
import comfy.multigpu
import comfy.utils
import scipy.stats
import numpy
import threading
def add_area_dims(area, num_dims):
@ -509,15 +509,38 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
raise
results: list[thread_result] = []
threads: list[threading.Thread] = []
for device, batch_tuple in device_batched_hooked_to_run.items():
new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results))
threads.append(new_thread)
new_thread.start()
def _handle_batch_pooled(device, batch_tuple):
worker_results = []
_handle_batch(device, batch_tuple, worker_results)
return worker_results
for thread in threads:
thread.join()
results: list[thread_result] = []
thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
main_device = output_device
main_batch_tuple = None
# Submit extra GPU work to pool first, then run main device on this thread
pool_devices = []
for device, batch_tuple in device_batched_hooked_to_run.items():
if device == main_device and thread_pool is not None:
main_batch_tuple = batch_tuple
elif thread_pool is not None:
thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
pool_devices.append(device)
else:
# Fallback: no pool, run everything on main thread
_handle_batch(device, batch_tuple, results)
# Run main device batch on this thread (parallel with pool workers)
if main_batch_tuple is not None:
_handle_batch(main_device, main_batch_tuple, results)
# Collect results from pool workers
for device in pool_devices:
worker_results, error = thread_pool.get_result(device)
if error is not None:
raise error
results.extend(worker_results)
for output, mult, area, batch_chunks, cond_or_uncond, error in results:
if error is not None:
@ -1187,17 +1210,25 @@ class CFGGuider:
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device=device, dtype=torch.float32)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
# Create persistent thread pool for extra GPU devices
if multigpu_patchers:
extra_devices = [p.load_device for p in multigpu_patchers]
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(extra_devices)
try:
noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device=device, dtype=torch.float32)
sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
self.model_patcher.pre_run()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
if thread_pool is not None:
thread_pool.shutdown()
self.model_patcher.cleanup()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.cleanup()