Fix get_all_torch_devices for XPU/NPU and guard remove()

torch.device(i) defaults to CUDA, so the XPU/NPU branches were producing 'cuda:N' devices that don't match the get_torch_device() output ('xpu:N'/'npu:N'). This caused devices.remove(get_torch_device()) to raise ValueError when exclude_current=True on non-NVIDIA hardware. Construct each device with an explicit device type and index (e.g. torch.device("xpu", i)), and guard the remove() with a membership check for safety.

Amp-Thread-ID: https://ampcode.com/threads/T-019e43b8-8258-70fd-ab3a-53e4c97f85d5
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Jedrzej Kosinski 2026-05-20 16:46:38 -07:00
parent 9a681ccfc9
commit ba417750a7

View File

@ -215,17 +215,19 @@ def get_all_torch_devices(exclude_current=False):
if cpu_state == CPUState.GPU:
if is_nvidia():
for i in range(torch.cuda.device_count()):
devices.append(torch.device(i))
devices.append(torch.device("cuda", i))
elif is_intel_xpu():
for i in range(torch.xpu.device_count()):
devices.append(torch.device(i))
devices.append(torch.device("xpu", i))
elif is_ascend_npu():
for i in range(torch.npu.device_count()):
devices.append(torch.device(i))
devices.append(torch.device("npu", i))
else:
devices.append(get_torch_device())
if exclude_current:
devices.remove(get_torch_device())
current = get_torch_device()
if current in devices:
devices.remove(current)
return devices
def get_gpu_device_options():