diff --git a/comfy/model_management.py b/comfy/model_management.py index 8fb1d7fbc..dfd58bf1b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1716,6 +1716,13 @@ def is_device_xpu(device): def is_device_cuda(device): return is_device_type(device, 'cuda') +def set_torch_device(device): + """Set the current device for the given torch device. Supports CUDA and XPU.""" + if is_device_cuda(device): + torch.cuda.set_device(device) + elif is_device_xpu(device): + torch.xpu.set_device(device) + def is_directml_enabled(): global directml_enabled if directml_enabled: diff --git a/comfy/multigpu.py b/comfy/multigpu.py index e7f5b3d6f..bb9d334d3 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -17,7 +17,7 @@ class MultiGPUThreadPool: """Persistent thread pool for multi-GPU work distribution. Maintains one worker thread per extra GPU device. Each thread calls - torch.cuda.set_device() once at startup so that compiled kernel caches + set_torch_device() once at startup so that compiled kernel caches (inductor/triton) stay warm across diffusion steps. """ @@ -37,7 +37,7 @@ class MultiGPUThreadPool: def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue): try: - torch.cuda.set_device(device) + comfy.model_management.set_torch_device(device) except Exception as e: logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}") while True: diff --git a/comfy/samplers.py b/comfy/samplers.py index e31277f7b..25c5a855f 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -464,10 +464,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]): try: - # TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once - # we extend multigpu QA beyond CUDA. Unconditional call crashes on - # XPU/NPU/MPS/CPU/DirectML backends. - torch.cuda.set_device(device) + comfy.model_management.set_torch_device(device) model_current: BaseModel = model_options["multigpu_clones"][device].model # run every hooked_to_run separately with torch.no_grad():