fix(multigpu): replace hardcoded torch.cuda.set_device with device-agnostic set_torch_device (#14191)
Some checks failed
Detect Unreviewed Merge / detect (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled

This commit is contained in:
savvadesogle 2026-05-31 04:18:42 +03:00 committed by GitHub
parent 81aa5a38b2
commit cd45f42a83
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 10 additions and 6 deletions

View File

@ -1716,6 +1716,13 @@ def is_device_xpu(device):
def is_device_cuda(device): def is_device_cuda(device):
return is_device_type(device, 'cuda') return is_device_type(device, 'cuda')
def set_torch_device(device):
    """Make ``device`` the current device of its torch backend.

    CUDA devices are forwarded to ``torch.cuda.set_device`` and XPU devices
    to ``torch.xpu.set_device``. A device that matches neither check falls
    through and no backend call is made.
    """
    if is_device_cuda(device):
        torch.cuda.set_device(device)
        return
    if is_device_xpu(device):
        torch.xpu.set_device(device)
def is_directml_enabled(): def is_directml_enabled():
global directml_enabled global directml_enabled
if directml_enabled: if directml_enabled:

View File

@ -17,7 +17,7 @@ class MultiGPUThreadPool:
"""Persistent thread pool for multi-GPU work distribution. """Persistent thread pool for multi-GPU work distribution.
Maintains one worker thread per extra GPU device. Each thread calls Maintains one worker thread per extra GPU device. Each thread calls
torch.cuda.set_device() once at startup so that compiled kernel caches set_torch_device() once at startup so that compiled kernel caches
(inductor/triton) stay warm across diffusion steps. (inductor/triton) stay warm across diffusion steps.
""" """
@ -37,7 +37,7 @@ class MultiGPUThreadPool:
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue): def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
try: try:
torch.cuda.set_device(device) comfy.model_management.set_torch_device(device)
except Exception as e: except Exception as e:
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}") logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
while True: while True:

View File

@ -464,10 +464,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]): def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
try: try:
# TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once comfy.model_management.set_torch_device(device)
# we extend multigpu QA beyond CUDA. Unconditional call crashes on
# XPU/NPU/MPS/CPU/DirectML backends.
torch.cuda.set_device(device)
model_current: BaseModel = model_options["multigpu_clones"][device].model model_current: BaseModel = model_options["multigpu_clones"][device].model
# run every hooked_to_run separately # run every hooked_to_run separately
with torch.no_grad(): with torch.no_grad():