Simplify multigpu dispatch: run all devices on pool threads (#13340)
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled

Benchmarked hybrid (main thread + pool) vs all-pool on 2x RTX 4090
with SD1.5 and NetaYume models. No meaningful performance difference
(within noise). All-pool is simpler: eliminates the main_device
special case, main_batch_tuple deferred execution, and the 3-way
branch in the dispatch loop.
This commit is contained in:
Jedrzej Kosinski 2026-04-08 22:15:57 -10:00 committed by GitHub
parent 4b93c4360f
commit 48deb15c0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -516,25 +516,17 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
results: list[thread_result] = []
thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")
main_device = output_device
main_batch_tuple = None
# Submit extra GPU work to pool first, then run main device on this thread
# Submit all GPU work to pool threads
pool_devices = []
for device, batch_tuple in device_batched_hooked_to_run.items():
if device == main_device and thread_pool is not None:
main_batch_tuple = batch_tuple
elif thread_pool is not None:
if thread_pool is not None:
thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
pool_devices.append(device)
else:
# Fallback: no pool, run everything on main thread
_handle_batch(device, batch_tuple, results)
# Run main device batch on this thread (parallel with pool workers)
if main_batch_tuple is not None:
_handle_batch(main_device, main_batch_tuple, results)
# Collect results from pool workers
for device in pool_devices:
worker_results, error = thread_pool.get_result(device)
@ -1210,10 +1202,11 @@ class CFGGuider:
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
# Create persistent thread pool for extra GPU devices
# Create persistent thread pool for all GPU devices (main + extras)
if multigpu_patchers:
extra_devices = [p.load_device for p in multigpu_patchers]
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(extra_devices)
all_devices = [device] + extra_devices
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
try:
noise = noise.to(device=device, dtype=torch.float32)