mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-30 02:47:24 +08:00
Refactor MultiGPU scheduler for readability and termination safety (#14001)
Behaviour-equivalent cleanup of _calc_cond_batch_multigpu device scheduling. No change to batching decisions or memory checks for any valid input. Changes: * Replace re-summed batched_to_run_length with a per-device load dict (device_load), so capacity checks are O(1) and use a single source of truth. * Extract device selection into next_available_device(), which scans at most len(devices) positions and raises if no device has remaining capacity. This makes the 'skip a full device' rule live in one place instead of two and guarantees the outer while loop cannot spin forever on a scheduling bug. * Drop the unused current_device assignment before the outer loop and the index_device % len(devices) modulo dance (now handled inside next_available_device). * Minor cleanups: list comprehensions for total_conds, conds_to_batch, and the devices list.
This commit is contained in:
parent
9e3ede1406
commit
819c7c0702
@ -388,33 +388,40 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
|
|||||||
|
|
||||||
model.current_patcher.prepare_state(timestep, model_options)
|
model.current_patcher.prepare_state(timestep, model_options)
|
||||||
|
|
||||||
devices = [dev_m for dev_m in model_options['multigpu_clones'].keys()]
|
devices = list(model_options['multigpu_clones'].keys())
|
||||||
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
|
device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
|
||||||
|
# Track conds currently scheduled per device; single source of truth for capacity checks.
|
||||||
|
device_load: dict[torch.device, int] = {d: 0 for d in devices}
|
||||||
|
|
||||||
total_conds = 0
|
total_conds = sum(len(to_run) for to_run in hooked_to_run.values())
|
||||||
for to_run in hooked_to_run.values():
|
|
||||||
total_conds += len(to_run)
|
|
||||||
conds_per_device = max(1, math.ceil(total_conds / len(devices)))
|
conds_per_device = max(1, math.ceil(total_conds / len(devices)))
|
||||||
index_device = 0
|
|
||||||
current_device = devices[index_device]
|
def next_available_device(start: int) -> tuple[int, torch.device]:
|
||||||
|
"""Return (index, device) for the next device with remaining capacity, starting at `start`.
|
||||||
|
|
||||||
|
Scans at most len(devices) positions, so this always terminates. Raises if no device
|
||||||
|
has remaining capacity, which would indicate a bug in conds_per_device accounting.
|
||||||
|
"""
|
||||||
|
for offset in range(len(devices)):
|
||||||
|
i = (start + offset) % len(devices)
|
||||||
|
if device_load[devices[i]] < conds_per_device:
|
||||||
|
return i, devices[i]
|
||||||
|
raise RuntimeError(
|
||||||
|
f"MultiGPU scheduler: all {len(devices)} devices at capacity "
|
||||||
|
f"({conds_per_device}) but conds remain to schedule"
|
||||||
|
)
|
||||||
|
|
||||||
# run every hooked_to_run separately
|
# run every hooked_to_run separately
|
||||||
|
index_device = 0
|
||||||
for hooks, to_run in hooked_to_run.items():
|
for hooks, to_run in hooked_to_run.items():
|
||||||
while len(to_run) > 0:
|
while len(to_run) > 0:
|
||||||
current_device = devices[index_device % len(devices)]
|
index_device, current_device = next_available_device(index_device)
|
||||||
batched_to_run = device_batched_hooked_to_run.setdefault(current_device, [])
|
remaining_capacity = conds_per_device - device_load[current_device]
|
||||||
# keep track of conds currently scheduled onto this device
|
|
||||||
batched_to_run_length = 0
|
|
||||||
for btr in batched_to_run:
|
|
||||||
batched_to_run_length += len(btr[1])
|
|
||||||
remaining_capacity = conds_per_device - batched_to_run_length
|
|
||||||
if remaining_capacity <= 0:
|
|
||||||
index_device += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
first = to_run[0]
|
first = to_run[0]
|
||||||
first_shape = first[0][0].shape
|
first_shape = first[0][0].shape
|
||||||
|
# collect candidate indices that can be concatenated with `first`, up to remaining capacity
|
||||||
to_batch_temp = []
|
to_batch_temp = []
|
||||||
# make sure not over conds_per_device limit when creating temp batch
|
|
||||||
for x in range(len(to_run)):
|
for x in range(len(to_run)):
|
||||||
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity:
|
if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity:
|
||||||
to_batch_temp += [x]
|
to_batch_temp += [x]
|
||||||
@ -429,13 +436,12 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
|
|||||||
if model.memory_required(input_shape) * 1.5 < free_memory:
|
if model.memory_required(input_shape) * 1.5 < free_memory:
|
||||||
to_batch = batch_amount
|
to_batch = batch_amount
|
||||||
break
|
break
|
||||||
conds_to_batch = []
|
|
||||||
for x in to_batch:
|
|
||||||
conds_to_batch.append(to_run.pop(x))
|
|
||||||
batched_to_run_length += len(conds_to_batch)
|
|
||||||
|
|
||||||
batched_to_run.append((hooks, conds_to_batch))
|
conds_to_batch = [to_run.pop(x) for x in to_batch]
|
||||||
if batched_to_run_length >= conds_per_device:
|
device_load[current_device] += len(conds_to_batch)
|
||||||
|
device_batched_hooked_to_run.setdefault(current_device, []).append((hooks, conds_to_batch))
|
||||||
|
|
||||||
|
if device_load[current_device] >= conds_per_device:
|
||||||
index_device += 1
|
index_device += 1
|
||||||
|
|
||||||
class thread_result(NamedTuple):
|
class thread_result(NamedTuple):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user