mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-16 22:58:19 +08:00
fix: per-device fp8/nvfp4 compute detection for multi-GPU setups
supports_fp8_compute() and supports_nvfp4_compute() used the global is_nvidia() check which ignores the device argument, then defaulted to cuda:0 when device was None. In heterogeneous multi-GPU setups (e.g. RTX 5070 + RTX 3090 Ti) this causes the wrong GPU's compute capability to be checked, incorrectly disabling fp8 on capable devices. Replace the global is_nvidia() gate with per-device checks: - Default device=None to get_torch_device() explicitly - Early-return False for CPU/MPS devices - Use is_device_cuda(device) + torch.version.cuda instead of the global is_nvidia() Fixes #4589, relates to #4577, #12405 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
8086468d2a
commit
bb31f8b707
@ -1627,7 +1627,21 @@ def supports_fp8_compute(device=None):
|
||||
if SUPPORT_FP8_OPS:
|
||||
return True
|
||||
|
||||
if not is_nvidia():
|
||||
if device is None:
|
||||
device = get_torch_device()
|
||||
|
||||
if is_device_cpu(device) or is_device_mps(device):
|
||||
return False
|
||||
|
||||
# Per-device check instead of the global is_nvidia(). On ROCm builds,
|
||||
# is_device_cuda() returns True (AMD GPUs appear as cuda:N via HIP) but
|
||||
# torch.version.cuda is None, so this correctly returns False for AMD.
|
||||
# If PyTorch ever supports mixed-vendor GPUs in one process, these
|
||||
# per-device checks remain correct unlike the global is_nvidia().
|
||||
if not is_device_cuda(device):
|
||||
return False
|
||||
|
||||
if not torch.version.cuda:
|
||||
return False
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
@ -1648,7 +1662,10 @@ def supports_fp8_compute(device=None):
|
||||
return True
|
||||
|
||||
def supports_nvfp4_compute(device=None):
|
||||
if not is_nvidia():
|
||||
if device is None:
|
||||
device = get_torch_device()
|
||||
|
||||
if not is_device_cuda(device) or not torch.version.cuda:
|
||||
return False
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
|
||||
109
tests-unit/comfy_test/model_management_test.py
Normal file
109
tests-unit/comfy_test/model_management_test.py
Normal file
@ -0,0 +1,109 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import torch
|
||||
|
||||
import comfy.model_management as mm
|
||||
|
||||
|
||||
class FakeDeviceProps:
    """Stand-in for the object returned by torch.cuda.get_device_properties().

    Exposes only the attributes the detection code inspects: the compute
    capability (``major``/``minor``) and the device ``name`` used in
    assertion messages.
    """

    def __init__(self, major, minor, name="FakeGPU"):
        self.major, self.minor, self.name = major, minor, name
|
||||
|
||||
|
||||
class TestSupportsFp8Compute:
    """Per-device fp8 compute capability detection, exercised device by device."""

    def test_cpu_device_returns_false(self):
        # A CPU device can never do fp8 compute.
        assert mm.supports_fp8_compute(torch.device("cpu")) is False

    @pytest.mark.skipif(not hasattr(torch.backends, "mps"), reason="MPS backend not available")
    def test_mps_device_returns_false(self):
        # Apple MPS devices are explicitly excluded.
        assert mm.supports_fp8_compute(torch.device("mps")) is False

    def test_cli_override_returns_true(self):
        # The SUPPORT_FP8_OPS override wins regardless of the device passed in.
        with patch("comfy.model_management.SUPPORT_FP8_OPS", True):
            assert mm.supports_fp8_compute(torch.device("cpu")) is True

    def test_none_device_defaults_to_get_torch_device(self):
        # device=None must be resolved through get_torch_device(), not cuda:0.
        with patch("comfy.model_management.get_torch_device",
                   return_value=torch.device("cpu")) as mock_get:
            outcome = mm.supports_fp8_compute(None)
        mock_get.assert_called_once()
        assert outcome is False

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    def test_each_cuda_device_checked_independently(self):
        """On a multi-GPU system, each device should be queried for its own capabilities."""
        n_devices = torch.cuda.device_count()
        if n_devices < 2:
            pytest.skip("Need 2+ CUDA devices for multi-GPU test")
        outcomes = {}
        for idx in range(n_devices):
            dev = torch.device(f"cuda:{idx}")
            outcomes[idx] = mm.supports_fp8_compute(dev)
            props = torch.cuda.get_device_properties(dev)
            # The result must line up with this specific device's SM version,
            # not with whatever cuda:0 happens to be.
            if props.major >= 9:
                assert outcomes[idx] is True, f"cuda:{idx} ({props.name}) has SM {props.major}.{props.minor}, should support fp8"
            elif props.major < 8 or props.minor < 9:
                assert outcomes[idx] is False, f"cuda:{idx} ({props.name}) has SM {props.major}.{props.minor}, should not support fp8"

    def test_rocm_build_returns_false(self):
        """On ROCm, devices appear as cuda:N via HIP but torch.version.cuda is None."""
        fake_dev = MagicMock()
        fake_dev.type = "cuda"
        with patch("torch.version.cuda", None), \
             patch("comfy.model_management.SUPPORT_FP8_OPS", False):
            assert mm.supports_fp8_compute(fake_dev) is False

    def test_sm89_supports_fp8(self):
        """Ada Lovelace (SM 8.9, e.g. RTX 4080) should support fp8."""
        with patch("torch.version.cuda", "12.4"), \
             patch("comfy.model_management.SUPPORT_FP8_OPS", False), \
             patch("torch.cuda.get_device_properties",
                   return_value=FakeDeviceProps(major=8, minor=9)):
            assert mm.supports_fp8_compute(torch.device("cuda:0")) is True

    def test_sm86_does_not_support_fp8(self):
        """Ampere (SM 8.6, e.g. RTX 3090) should not support fp8."""
        with patch("torch.version.cuda", "12.4"), \
             patch("comfy.model_management.SUPPORT_FP8_OPS", False), \
             patch("torch.cuda.get_device_properties",
                   return_value=FakeDeviceProps(major=8, minor=6)):
            assert mm.supports_fp8_compute(torch.device("cuda:0")) is False

    def test_sm90_supports_fp8(self):
        """Hopper (SM 9.0) and above should support fp8."""
        with patch("torch.version.cuda", "12.4"), \
             patch("comfy.model_management.SUPPORT_FP8_OPS", False), \
             patch("torch.cuda.get_device_properties",
                   return_value=FakeDeviceProps(major=9, minor=0)):
            assert mm.supports_fp8_compute(torch.device("cuda:0")) is True
|
||||
|
||||
|
||||
class TestSupportsNvfp4Compute:
    """Per-device nvfp4 compute capability detection."""

    def test_cpu_device_returns_false(self):
        # CPU devices never support nvfp4.
        assert mm.supports_nvfp4_compute(torch.device("cpu")) is False

    def test_sm100_supports_nvfp4(self):
        """Blackwell (SM 10.0) should support nvfp4."""
        with patch("torch.version.cuda", "12.4"), \
             patch("torch.cuda.get_device_properties",
                   return_value=FakeDeviceProps(major=10, minor=0)):
            assert mm.supports_nvfp4_compute(torch.device("cuda:0")) is True

    def test_sm89_does_not_support_nvfp4(self):
        """Ada Lovelace (SM 8.9) should not support nvfp4."""
        with patch("torch.version.cuda", "12.4"), \
             patch("torch.cuda.get_device_properties",
                   return_value=FakeDeviceProps(major=8, minor=9)):
            assert mm.supports_nvfp4_compute(torch.device("cuda:0")) is False
|
||||
Loading…
Reference in New Issue
Block a user