mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-04 13:27:41 +08:00
Merge remote-tracking branch 'upstream/master' into gaussians
This commit is contained in:
commit
3fd3269a8e
@ -1716,6 +1716,13 @@ def is_device_xpu(device):
|
||||
def is_device_cuda(device):
    """Return the result of ``is_device_type`` for ``device`` and the ``'cuda'`` type."""
    return is_device_type(device, 'cuda')
|
||||
|
||||
def set_torch_device(device):
    """Make ``device`` the current device of its torch backend.

    A device recognised by ``is_device_cuda`` is forwarded to
    ``torch.cuda.set_device``; one recognised by ``is_device_xpu`` is
    forwarded to ``torch.xpu.set_device``. Any other device falls through
    both checks and nothing is called for it.
    """
    # CUDA is tested first; the XPU check only runs when the CUDA check fails.
    if is_device_cuda(device):
        torch.cuda.set_device(device)
        return

    if is_device_xpu(device):
        torch.xpu.set_device(device)
|
||||
|
||||
def is_directml_enabled():
|
||||
global directml_enabled
|
||||
if directml_enabled:
|
||||
|
||||
@ -17,7 +17,7 @@ class MultiGPUThreadPool:
|
||||
"""Persistent thread pool for multi-GPU work distribution.
|
||||
|
||||
Maintains one worker thread per extra GPU device. Each thread calls
|
||||
torch.cuda.set_device() once at startup so that compiled kernel caches
|
||||
set_torch_device() once at startup so that compiled kernel caches
|
||||
(inductor/triton) stay warm across diffusion steps.
|
||||
"""
|
||||
|
||||
@ -37,7 +37,7 @@ class MultiGPUThreadPool:
|
||||
|
||||
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
|
||||
try:
|
||||
torch.cuda.set_device(device)
|
||||
comfy.model_management.set_torch_device(device)
|
||||
except Exception as e:
|
||||
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
|
||||
while True:
|
||||
|
||||
@ -464,10 +464,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
|
||||
|
||||
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
|
||||
try:
|
||||
# TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once
|
||||
# we extend multigpu QA beyond CUDA. Unconditional call crashes on
|
||||
# XPU/NPU/MPS/CPU/DirectML backends.
|
||||
torch.cuda.set_device(device)
|
||||
comfy.model_management.set_torch_device(device)
|
||||
model_current: BaseModel = model_options["multigpu_clones"][device].model
|
||||
# run every hooked_to_run separately
|
||||
with torch.no_grad():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user