mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-24 07:57:29 +08:00
Mark non-NVIDIA multigpu gaps with TODOs in _handle_batch
Two CodeRabbit findings from #7063 (#13 and #14) are deferred because worksplit-multigpu's initial release scope is NVIDIA-only QA. Leave a TODO at the unconditional torch.cuda.set_device call and at the post-aggregation point so the required guards/synchronize are easy to find when multigpu support is extended to XPU/NPU/MPS/CPU/DirectML. Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
d0b9dbb5a6
commit
2ed396c769
@ -465,6 +465,9 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
|
||||
|
||||
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
|
||||
try:
|
||||
# TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once
|
||||
# we extend multigpu QA beyond CUDA. Unconditional call crashes on
|
||||
# XPU/NPU/MPS/CPU/DirectML backends.
|
||||
torch.cuda.set_device(device)
|
||||
model_current: BaseModel = model_options["multigpu_clones"][device].model
|
||||
# run every hooked_to_run separately
|
||||
@ -524,6 +527,12 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
|
||||
output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
|
||||
else:
|
||||
output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
|
||||
# TODO: non-NVIDIA support -- the `.to(output_device)` copies
|
||||
# above are async on CUDA, so the main thread's aggregation
|
||||
# could race with in-flight transfers. CUDA-only QA has not
|
||||
# surfaced this in practice, but before extending multigpu
|
||||
# beyond NVIDIA add a `torch.cuda.synchronize(output_device)`
|
||||
# here (guarded by `output_device.type == "cuda"`).
|
||||
results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
|
||||
except Exception as e:
|
||||
results.append(thread_result(None, None, None, None, None, error=e))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user