mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-12 03:22:37 +08:00
Implement persistent thread pool for multi-GPU CFG splitting (#13329)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Replace per-step thread create/destroy in _calc_cond_batch_multigpu with a persistent MultiGPUThreadPool. Each worker thread calls torch.cuda.set_device() once at startup, preserving compiled kernel caches across diffusion steps. - Add MultiGPUThreadPool class in comfy/multigpu.py - Create pool in CFGGuider.outer_sample(), shut down in finally block - Main thread handles its own device batch directly for zero overhead - Falls back to sequential execution if no pool is available
This commit is contained in:
parent
da3864436c
commit
4b93c4360f
@ -1,4 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -11,6 +13,67 @@ import comfy.patcher_extension
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
|
||||||
|
|
||||||
|
class MultiGPUThreadPool:
    """Persistent thread pool for multi-GPU work distribution.

    Maintains one worker thread per extra GPU device. Each thread calls
    torch.cuda.set_device() once at startup so that compiled kernel caches
    (inductor/triton) stay warm across diffusion steps.
    """

    def __init__(self, devices: list[torch.device]):
        """Spawn one daemon worker thread per device in *devices*."""
        self._workers: list[threading.Thread] = []
        self._work_queues: dict[torch.device, queue.Queue] = {}
        self._result_queues: dict[torch.device, queue.Queue] = {}

        for device in devices:
            wq = queue.Queue()
            rq = queue.Queue()
            self._work_queues[device] = wq
            self._result_queues[device] = rq
            # daemon=True: workers must not keep the process alive if
            # shutdown() is never reached (e.g. a hard crash mid-sample).
            t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
            t.start()
            self._workers.append(t)

    def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
        """Worker thread body: pin the CUDA device once, then serve work items.

        Work items are (fn, args, kwargs) tuples; each produces exactly one
        (result, error) tuple on result_q. A None work item is the shutdown
        sentinel.
        """
        try:
            # Set the device once per thread lifetime so compiled kernel
            # caches stay warm across diffusion steps.
            torch.cuda.set_device(device)
        except Exception as e:
            logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
            # Device is unusable. Answer EVERY subsequent submission with the
            # error (instead of returning after the first one) so a caller's
            # get_result() never blocks forever; exit only on the sentinel.
            while True:
                item = work_q.get()
                if item is None:
                    return
                result_q.put((None, e))
        while True:
            item = work_q.get()
            if item is None:
                break
            fn, args, kwargs = item
            try:
                result = fn(*args, **kwargs)
                result_q.put((result, None))
            except Exception as e:
                # Ship the exception back to the submitting thread rather
                # than dying; the caller decides whether to re-raise.
                result_q.put((None, e))

    def submit(self, device: torch.device, fn, *args, **kwargs):
        """Queue fn(*args, **kwargs) for execution on *device*'s worker thread."""
        self._work_queues[device].put((fn, args, kwargs))

    def get_result(self, device: torch.device):
        """Block until *device*'s worker finishes one item.

        Returns a (result, error) tuple; exactly one of the two is non-None
        (error carries the exception raised by the submitted fn, if any).
        """
        return self._result_queues[device].get()

    @property
    def devices(self) -> list[torch.device]:
        """Devices this pool has workers for, in construction order."""
        return list(self._work_queues.keys())

    def shutdown(self):
        """Signal all workers to exit and join them (bounded 5s wait each)."""
        for wq in self._work_queues.values():
            wq.put(None)  # sentinel
        for t in self._workers:
            t.join(timeout=5.0)
||||||
|
|
||||||
class GPUOptions:
|
class GPUOptions:
|
||||||
def __init__(self, device_index: int, relative_speed: float):
|
def __init__(self, device_index: int, relative_speed: float):
|
||||||
self.device_index = device_index
|
self.device_index = device_index
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import comfy.hooks
|
|||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from comfy.model_base import BaseModel
|
||||||
from comfy.model_patcher import ModelPatcher
|
from comfy.model_patcher import ModelPatcher
|
||||||
from comfy.controlnet import ControlBase
|
from comfy.controlnet import ControlBase
|
||||||
|
|
||||||
|
|||||||
@ -18,10 +18,10 @@ import comfy.model_patcher
|
|||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
import comfy.hooks
|
import comfy.hooks
|
||||||
import comfy.context_windows
|
import comfy.context_windows
|
||||||
|
import comfy.multigpu
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import scipy.stats
|
import scipy.stats
|
||||||
import numpy
|
import numpy
|
||||||
import threading
|
|
||||||
|
|
||||||
|
|
||||||
def add_area_dims(area, num_dims):
|
def add_area_dims(area, num_dims):
|
||||||
@ -509,15 +509,38 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
results: list[thread_result] = []
|
def _handle_batch_pooled(device, batch_tuple):
|
||||||
threads: list[threading.Thread] = []
|
worker_results = []
|
||||||
for device, batch_tuple in device_batched_hooked_to_run.items():
|
_handle_batch(device, batch_tuple, worker_results)
|
||||||
new_thread = threading.Thread(target=_handle_batch, args=(device, batch_tuple, results))
|
return worker_results
|
||||||
threads.append(new_thread)
|
|
||||||
new_thread.start()
|
|
||||||
|
|
||||||
for thread in threads:
|
results: list[thread_result] = []
|
||||||
thread.join()
|
thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
|
||||||
|
main_device = output_device
|
||||||
|
main_batch_tuple = None
|
||||||
|
|
||||||
|
# Submit extra GPU work to pool first, then run main device on this thread
|
||||||
|
pool_devices = []
|
||||||
|
for device, batch_tuple in device_batched_hooked_to_run.items():
|
||||||
|
if device == main_device and thread_pool is not None:
|
||||||
|
main_batch_tuple = batch_tuple
|
||||||
|
elif thread_pool is not None:
|
||||||
|
thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
|
||||||
|
pool_devices.append(device)
|
||||||
|
else:
|
||||||
|
# Fallback: no pool, run everything on main thread
|
||||||
|
_handle_batch(device, batch_tuple, results)
|
||||||
|
|
||||||
|
# Run main device batch on this thread (parallel with pool workers)
|
||||||
|
if main_batch_tuple is not None:
|
||||||
|
_handle_batch(main_device, main_batch_tuple, results)
|
||||||
|
|
||||||
|
# Collect results from pool workers
|
||||||
|
for device in pool_devices:
|
||||||
|
worker_results, error = thread_pool.get_result(device)
|
||||||
|
if error is not None:
|
||||||
|
raise error
|
||||||
|
results.extend(worker_results)
|
||||||
|
|
||||||
for output, mult, area, batch_chunks, cond_or_uncond, error in results:
|
for output, mult, area, batch_chunks, cond_or_uncond, error in results:
|
||||||
if error is not None:
|
if error is not None:
|
||||||
@ -1187,17 +1210,25 @@ class CFGGuider:
|
|||||||
|
|
||||||
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
|
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
|
||||||
|
|
||||||
noise = noise.to(device=device, dtype=torch.float32)
|
# Create persistent thread pool for extra GPU devices
|
||||||
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
if multigpu_patchers:
|
||||||
sigmas = sigmas.to(device)
|
extra_devices = [p.load_device for p in multigpu_patchers]
|
||||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(extra_devices)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
noise = noise.to(device=device, dtype=torch.float32)
|
||||||
|
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||||
|
sigmas = sigmas.to(device)
|
||||||
|
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||||
|
|
||||||
self.model_patcher.pre_run()
|
self.model_patcher.pre_run()
|
||||||
for multigpu_patcher in multigpu_patchers:
|
for multigpu_patcher in multigpu_patchers:
|
||||||
multigpu_patcher.pre_run()
|
multigpu_patcher.pre_run()
|
||||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||||
finally:
|
finally:
|
||||||
|
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
|
||||||
|
if thread_pool is not None:
|
||||||
|
thread_pool.shutdown()
|
||||||
self.model_patcher.cleanup()
|
self.model_patcher.cleanup()
|
||||||
for multigpu_patcher in multigpu_patchers:
|
for multigpu_patcher in multigpu_patchers:
|
||||||
multigpu_patcher.cleanup()
|
multigpu_patcher.cleanup()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user