Fix MultiGPU scheduler capacity accounting (#14000)

Fixes _calc_cond_batch_multigpu so that:

1. conds_per_device uses real division before math.ceil. The previous
   expression math.ceil(total_conds // len(devices)) applied integer
   floor division first, making ceil a no-op. For 3 conds across 2
   devices this produced conds_per_device=1 instead of 2.

2. The scheduling loop skips devices that have already reached
   capacity instead of appending empty batch groups. Without this
   guard, the loop could repeatedly emit zero-length groups for a
   full device, leaving sampling stuck at 0/N until timeout.

Reproduces with an Omnigen2 image workflow that produces three
condition entries scheduled across two CUDA devices. With the fix
the scheduler assigns conds_per_device=2 and splits the batches as
2 + 1 across the two devices, allowing sampling to complete.

Original fix authored and validated by @pollockjj in
pollockjj/ComfyUI#64.

Co-authored-by: John Pollock <pollockjj@gmail.com>
This commit is contained in:
Jedrzej Kosinski 2026-05-19 20:11:53 -07:00 committed by GitHub
parent a61e2bbb85
commit 9e3ede1406
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -394,7 +394,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
total_conds = 0
for to_run in hooked_to_run.values():
total_conds += len(to_run)
conds_per_device = max(1, math.ceil(total_conds//len(devices)))
conds_per_device = max(1, math.ceil(total_conds / len(devices)))
index_device = 0
current_device = devices[index_device]
# run every hooked_to_run separately
@ -406,13 +406,17 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
batched_to_run_length = 0
for btr in batched_to_run:
batched_to_run_length += len(btr[1])
remaining_capacity = conds_per_device - batched_to_run_length
if remaining_capacity <= 0:
index_device += 1
continue
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
# make sure not over conds_per_device limit when creating temp batch
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < (conds_per_device - batched_to_run_length):
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity:
to_batch_temp += [x]
to_batch_temp.reverse()