Refactor MultiGPU scheduler for readability and termination safety (#14001)

Behaviour-equivalent cleanup of _calc_cond_batch_multigpu device
scheduling. No change to batching decisions or memory checks for any
valid input.

Changes:

* Replace re-summed batched_to_run_length with a per-device load
  dict (device_load), so capacity checks are O(1) and use a single
  source of truth.
* Extract device selection into next_available_device(), which scans
  at most len(devices) positions and raises if no device has
  remaining capacity. This makes the 'skip a full device' rule live
  in one place instead of two and guarantees the outer while loop
  cannot spin forever on a scheduling bug.
* Drop the unused current_device assignment before the outer loop
  and the index_device % len(devices) modulo dance (now handled
  inside next_available_device).
* Minor cleanups: list comprehensions for total_conds, conds_to_batch,
  and the devices list.
This commit is contained in:
Jedrzej Kosinski 2026-05-19 21:23:56 -07:00 committed by GitHub
parent 9e3ede1406
commit 819c7c0702
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -388,33 +388,40 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
model.current_patcher.prepare_state(timestep, model_options)
devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()]
devices = list(model_options['multigpu_clones'].keys())
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
# Track conds currently scheduled per device; single source of truth for capacity checks.
device_load: dict[torch.device, int] = {d: 0 for d in devices}
total_conds = 0
for to_run in hooked_to_run.values():
total_conds += len(to_run)
total_conds = sum(len(to_run) for to_run in hooked_to_run.values())
conds_per_device = max(1, math.ceil(total_conds / len(devices)))
index_device = 0
current_device = devices[index_device]
def next_available_device(start: int) -> tuple[int, torch.device]:
"""Return (index, device) for the next device with remaining capacity, starting at `start`.
Scans at most len(devices) positions, so this always terminates. Raises if no device
has remaining capacity, which would indicate a bug in conds_per_device accounting.
"""
for offset in range(len(devices)):
i = (start + offset) % len(devices)
if device_load[devices[i]] < conds_per_device:
return i, devices[i]
raise RuntimeError(
f"MultiGPU scheduler: all {len(devices)} devices at capacity "
f"({conds_per_device}) but conds remain to schedule"
)
# run every hooked_to_run separately
index_device = 0
for hooks, to_run in hooked_to_run.items():
while len(to_run) > 0:
current_device = devices[index_device % len(devices)]
batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
# keep track of conds currently scheduled onto this device
batched_to_run_length = 0
for btr in batched_to_run:
batched_to_run_length += len(btr[1])
remaining_capacity = conds_per_device - batched_to_run_length
if remaining_capacity <= 0:
index_device += 1
continue
index_device, current_device = next_available_device(index_device)
remaining_capacity = conds_per_device - device_load[current_device]
first = to_run[0]
first_shape = first[0][0].shape
# collect candidate indices that can be concatenated with `first`, up to remaining capacity
to_batch_temp = []
# make sure not over conds_per_device limit when creating temp batch
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity:
to_batch_temp += [x]
@ -429,13 +436,12 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
if model.memory_required(input_shape) * 1.5 < free_memory:
to_batch = batch_amount
break
conds_to_batch = []
for x in to_batch:
conds_to_batch.append(to_run.pop(x))
batched_to_run_length += len(conds_to_batch)
batched_to_run.append((hooks, conds_to_batch))
if batched_to_run_length >= conds_per_device:
conds_to_batch = [to_run.pop(x) for x in to_batch]
device_load[current_device] += len(conds_to_batch)
device_batched_hooked_to_run.setdefault(current_device, []).append((hooks, conds_to_batch))
if device_load[current_device] >= conds_per_device:
index_device += 1
class thread_result(NamedTuple):