Mark non-NVIDIA multigpu gaps with TODOs in _handle_batch

Two CodeRabbit findings from #7063 (#13 and #14) are deferred because
worksplit-multigpu's initial release scope is NVIDIA-only QA. Leave a
TODO at the unconditional torch.cuda.set_device call and at the
post-aggregation point so the required guards/synchronize are easy to
find when multigpu support is extended to XPU/NPU/MPS/CPU/DirectML.

Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Jedrzej Kosinski 2026-05-21 12:47:43 -07:00
parent d0b9dbb5a6
commit 2ed396c769

View File

@ -465,6 +465,9 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]): def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
try: try:
# TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once
# we extend multigpu QA beyond CUDA. Unconditional call crashes on
# XPU/NPU/MPS/CPU/DirectML backends.
torch.cuda.set_device(device) torch.cuda.set_device(device)
model_current: BaseModel = model_options["multigpu_clones"][device].model model_current: BaseModel = model_options["multigpu_clones"][device].model
# run every hooked_to_run separately # run every hooked_to_run separately
@ -524,6 +527,12 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks) output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
else: else:
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks) output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
# TODO: non-NVIDIA support -- the `.to(output_device)` copies
# above are async on CUDA, so the main thread's aggregation
# could race with in-flight transfers. CUDA-only QA has not
# surfaced this in practice, but before extending multigpu
# beyond NVIDIA add a `torch.cuda.synchronize(output_device)`
# here (guarded by `output_device.type == "cuda"`).
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond)) results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
except Exception as e: except Exception as e:
results.append(thread_result(None, None, None, None, None, error=e)) results.append(thread_result(None, None, None, None, None, error=e))