Merge remote-tracking branch 'upstream/master' into gaussians

This commit is contained in:
kijai 2026-05-31 11:01:29 +03:00
commit 3fd3269a8e
3 changed files with 10 additions and 6 deletions

View File

@ -1716,6 +1716,13 @@ def is_device_xpu(device):
def is_device_cuda(device):
return is_device_type(device, 'cuda')
def set_torch_device(device):
"""Set the current device for the given torch device. Supports CUDA and XPU."""
if is_device_cuda(device):
torch.cuda.set_device(device)
elif is_device_xpu(device):
torch.xpu.set_device(device)
def is_directml_enabled():
global directml_enabled
if directml_enabled:

View File

@ -17,7 +17,7 @@ class MultiGPUThreadPool:
"""Persistent thread pool for multi-GPU work distribution.
Maintains one worker thread per extra GPU device. Each thread calls
torch.cuda.set_device() once at startup so that compiled kernel caches
set_torch_device() once at startup so that compiled kernel caches
(inductor/triton) stay warm across diffusion steps.
"""
@ -37,7 +37,7 @@ class MultiGPUThreadPool:
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
try:
torch.cuda.set_device(device)
comfy.model_management.set_torch_device(device)
except Exception as e:
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
while True:

View File

@ -464,10 +464,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
try:
# TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once
# we extend multigpu QA beyond CUDA. Unconditional call crashes on
# XPU/NPU/MPS/CPU/DirectML backends.
torch.cuda.set_device(device)
comfy.model_management.set_torch_device(device)
model_current: BaseModel = model_options["multigpu_clones"][device].model
# run every hooked_to_run separately
with torch.no_grad():