From 819c7c0702107511b4d08a7de5e7f03007b53799 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 19 May 2026 21:23:56 -0700 Subject: [PATCH] Refactor MultiGPU scheduler for readability and termination safety (#14001) Behaviour-equivalent cleanup of _calc_cond_batch_multigpu device scheduling. No change to batching decisions or memory checks for any valid input. Changes: * Replace re-summed batched_to_run_length with a per-device load dict (device_load), so capacity checks are O(1) and use a single source of truth. * Extract device selection into next_available_device(), which scans at most len(devices) positions and raises if no device has remaining capacity. This makes the 'skip a full device' rule live in one place instead of two and guarantees the outer while loop cannot spin forever on a scheduling bug. * Drop the unused current_device assignment before the outer loop and the index_device % len(devices) modulo dance (now handled inside next_available_device). * Minor cleanups: list comprehensions for total_conds, conds_to_batch, and the devices list. --- comfy/samplers.py | 52 ++++++++++++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 83fa2e609..f0d67cb7e 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -388,33 +388,40 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t model.current_patcher.prepare_state(timestep, model_options) - devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()] + devices = list(model_options['multigpu_clones'].keys()) device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {} + # Track conds currently scheduled per device; single source of truth for capacity checks. + device_load: dict[torch.device, int] = {d: 0 for d in devices} - total_conds = 0 - for to_run in hooked_to_run.values(): - total_conds += len(to_run) + total_conds = sum(len(to_run) for to_run in hooked_to_run.values()) conds_per_device = max(1, math.ceil(total_conds / len(devices))) - index_device = 0 - current_device = devices[index_device] + + def next_available_device(start: int) -> tuple[int, torch.device]: + """Return (index, device) for the next device with remaining capacity, starting at `start`. + + Scans at most len(devices) positions, so this always terminates. Raises if no device + has remaining capacity, which would indicate a bug in conds_per_device accounting. + """ + for offset in range(len(devices)): + i = (start + offset) % len(devices) + if device_load[devices[i]] < conds_per_device: + return i, devices[i] + raise RuntimeError( + f"MultiGPU scheduler: all {len(devices)} devices at capacity " + f"({conds_per_device}) but conds remain to schedule" + ) + # run every hooked_to_run separately + index_device = 0 for hooks, to_run in hooked_to_run.items(): while len(to_run) > 0: - current_device = devices[index_device % len(devices)] - batched_to_run = device_batched_hooked_to_run.setdefault(current_device, []) - # keep track of conds currently scheduled onto this device - batched_to_run_length = 0 - for btr in batched_to_run: - batched_to_run_length += len(btr[1]) - remaining_capacity = conds_per_device - batched_to_run_length - if remaining_capacity <= 0: - index_device += 1 - continue + index_device, current_device = next_available_device(index_device) + remaining_capacity = conds_per_device - device_load[current_device] first = to_run[0] first_shape = first[0][0].shape + # collect candidate indices that can be concatenated with `first`, up to remaining capacity to_batch_temp = [] - # make sure not over conds_per_device limit when creating temp batch for x in range(len(to_run)): if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity: to_batch_temp += [x] @@ -429,13 +436,12 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t if model.memory_required(input_shape) * 1.5 < free_memory: to_batch = batch_amount break - conds_to_batch = [] - for x in to_batch: - conds_to_batch.append(to_run.pop(x)) - batched_to_run_length += len(conds_to_batch) - batched_to_run.append((hooks, conds_to_batch)) - if batched_to_run_length >= conds_per_device: + conds_to_batch = [to_run.pop(x) for x in to_batch] + device_load[current_device] += len(conds_to_batch) + device_batched_hooked_to_run.setdefault(current_device, []).append((hooks, conds_to_batch)) + + if device_load[current_device] >= conds_per_device: index_device += 1 class thread_result(NamedTuple):