From 2ed396c769bfce4c668a840672cc44c701051dbd Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 21 May 2026 12:47:43 -0700 Subject: [PATCH] Mark non-NVIDIA multigpu gaps with TODOs in _handle_batch Two CodeRabbit findings from #7063 (#13 and #14) are deferred because worksplit-multigpu's initial release scope is NVIDIA-only QA. Leave a TODO at the unconditional torch.cuda.set_device call and at the post-aggregation point so the required guards/synchronize are easy to find when multigpu support is extended to XPU/NPU/MPS/CPU/DirectML. Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082 Co-authored-by: Amp --- comfy/samplers.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/comfy/samplers.py b/comfy/samplers.py index 6fd0387d5..42b05f3ba 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -465,6 +465,9 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]): try: + # TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once + # we extend multigpu QA beyond CUDA. Unconditional call crashes on + # XPU/NPU/MPS/CPU/DirectML backends. torch.cuda.set_device(device) model_current: BaseModel = model_options["multigpu_clones"][device].model # run every hooked_to_run separately @@ -524,6 +527,12 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks) else: output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks) + # TODO: non-NVIDIA support -- the `.to(output_device)` copies + # above are async on CUDA, so the main thread's aggregation + # could race with in-flight transfers. CUDA-only QA has not + # surfaced this in practice, but before extending multigpu + # beyond NVIDIA add a `torch.cuda.synchronize(output_device)` + # here (guarded by `output_device.type == "cuda"`). results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond)) except Exception as e: results.append(thread_result(None, None, None, None, None, error=e))