mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-21 14:37:30 +08:00
Refactor MultiGPU scheduler for readability and termination safety (#14001)
Behaviour-equivalent cleanup of _calc_cond_batch_multigpu device scheduling. No change to batching decisions or memory checks for any valid input. Changes: * Replace re-summed batched_to_run_length with a per-device load dict (device_load), so capacity checks are O(1) and use a single source of truth. * Extract device selection into next_available_device(), which scans at most len(devices) positions and raises if no device has remaining capacity. This makes the 'skip a full device' rule live in one place instead of two and guarantees the outer while loop cannot spin forever on a scheduling bug. * Drop the unused current_device assignment before the outer loop and the index_device % len(devices) modulo dance (now handled inside next_available_device). * Minor cleanups: list comprehensions for total_conds, conds_to_batch, and the devices list.
This commit is contained in:
parent
9e3ede1406
commit
819c7c0702
@ -388,33 +388,40 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
|
||||
|
||||
model.current_patcher.prepare_state(timestep, model_options)
|
||||
|
||||
devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()]
|
||||
devices = list(model_options['multigpu_clones'].keys())
|
||||
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
|
||||
# Track conds currently scheduled per device; single source of truth for capacity checks.
|
||||
device_load: dict[torch.device, int] = {d: 0 for d in devices}
|
||||
|
||||
total_conds = 0
|
||||
for to_run in hooked_to_run.values():
|
||||
total_conds += len(to_run)
|
||||
total_conds = sum(len(to_run) for to_run in hooked_to_run.values())
|
||||
conds_per_device = max(1, math.ceil(total_conds / len(devices)))
|
||||
index_device = 0
|
||||
current_device = devices[index_device]
|
||||
|
||||
def next_available_device(start: int) -> tuple[int, torch.device]:
|
||||
"""Return (index, device) for the next device with remaining capacity, starting at `start`.
|
||||
|
||||
Scans at most len(devices) positions, so this always terminates. Raises if no device
|
||||
has remaining capacity, which would indicate a bug in conds_per_device accounting.
|
||||
"""
|
||||
for offset in range(len(devices)):
|
||||
i = (start + offset) % len(devices)
|
||||
if device_load[devices[i]] < conds_per_device:
|
||||
return i, devices[i]
|
||||
raise RuntimeError(
|
||||
f"MultiGPU scheduler: all {len(devices)} devices at capacity "
|
||||
f"({conds_per_device}) but conds remain to schedule"
|
||||
)
|
||||
|
||||
# run every hooked_to_run separately
|
||||
index_device = 0
|
||||
for hooks, to_run in hooked_to_run.items():
|
||||
while len(to_run) > 0:
|
||||
current_device = devices[index_device % len(devices)]
|
||||
batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
|
||||
# keep track of conds currently scheduled onto this device
|
||||
batched_to_run_length = 0
|
||||
for btr in batched_to_run:
|
||||
batched_to_run_length += len(btr[1])
|
||||
remaining_capacity = conds_per_device - batched_to_run_length
|
||||
if remaining_capacity <= 0:
|
||||
index_device += 1
|
||||
continue
|
||||
index_device, current_device = next_available_device(index_device)
|
||||
remaining_capacity = conds_per_device - device_load[current_device]
|
||||
|
||||
first = to_run[0]
|
||||
first_shape = first[0][0].shape
|
||||
# collect candidate indices that can be concatenated with `first`, up to remaining capacity
|
||||
to_batch_temp = []
|
||||
# make sure not over conds_per_device limit when creating temp batch
|
||||
for x in range(len(to_run)):
|
||||
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity:
|
||||
to_batch_temp += [x]
|
||||
@ -429,13 +436,12 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
|
||||
if model.memory_required(input_shape) * 1.5 < free_memory:
|
||||
to_batch = batch_amount
|
||||
break
|
||||
conds_to_batch = []
|
||||
for x in to_batch:
|
||||
conds_to_batch.append(to_run.pop(x))
|
||||
batched_to_run_length += len(conds_to_batch)
|
||||
|
||||
batched_to_run.append((hooks, conds_to_batch))
|
||||
if batched_to_run_length >= conds_per_device:
|
||||
conds_to_batch = [to_run.pop(x) for x in to_batch]
|
||||
device_load[current_device] += len(conds_to_batch)
|
||||
device_batched_hooked_to_run.setdefault(current_device, []).append((hooks, conds_to_batch))
|
||||
|
||||
if device_load[current_device] >= conds_per_device:
|
||||
index_device += 1
|
||||
|
||||
class thread_result(NamedTuple):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user