ComfyUI/tests-unit/comfy_test/model_management_test.py
Tsondo bb31f8b707 fix: per-device fp8/nvfp4 compute detection for multi-GPU setups
supports_fp8_compute() and supports_nvfp4_compute() used the global
is_nvidia() check which ignores the device argument, then defaulted
to cuda:0 when device was None. In heterogeneous multi-GPU setups
(e.g. RTX 5070 + RTX 3090 Ti) this causes the wrong GPU's compute
capability to be checked, incorrectly disabling fp8 on capable
devices.

Replace the global is_nvidia() gate with per-device checks:
- Default device=None to get_torch_device() explicitly
- Early-return False for CPU/MPS devices
- Use is_device_cuda(device) + torch.version.cuda instead of
  the global is_nvidia()

Fixes #4589, relates to #4577, #12405

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-14 22:56:42 +01:00

110 lines
4.7 KiB
Python

import pytest
from unittest.mock import patch, MagicMock
import torch
import comfy.model_management as mm
class FakeDeviceProps:
    """Tiny substitute for the object returned by torch.cuda.get_device_properties.

    Only the three attributes these tests read are provided: the compute
    capability ``major``/``minor`` pair and a human-readable ``name``.
    """

    def __init__(self, major, minor, name="FakeGPU"):
        # Plain public attributes so assertions and the code under test can
        # read them exactly like a real device-properties object.
        self.major, self.minor, self.name = major, minor, name
class TestSupportsFp8Compute:
    """Tests for per-device fp8 compute capability detection.

    Every test that expects a hardware-derived answer patches
    ``SUPPORT_FP8_OPS`` to False. When that override is on,
    ``supports_fp8_compute`` returns True even for a CPU device (see
    ``test_cli_override_returns_true``), so leaving it unpinned would let a
    globally enabled override mask the per-device logic under test.
    """

    @patch("comfy.model_management.SUPPORT_FP8_OPS", False)
    def test_cpu_device_returns_false(self):
        """A CPU device does not report fp8 compute support."""
        assert mm.supports_fp8_compute(torch.device("cpu")) is False

    @pytest.mark.skipif(not hasattr(torch.backends, "mps"), reason="MPS backend not available")
    @patch("comfy.model_management.SUPPORT_FP8_OPS", False)
    def test_mps_device_returns_false(self):
        """An MPS device does not report fp8 compute support."""
        assert mm.supports_fp8_compute(torch.device("mps")) is False

    @patch("comfy.model_management.SUPPORT_FP8_OPS", True)
    def test_cli_override_returns_true(self):
        """The SUPPORT_FP8_OPS override wins even for a device without fp8 hardware."""
        assert mm.supports_fp8_compute(torch.device("cpu")) is True

    @patch("comfy.model_management.SUPPORT_FP8_OPS", False)
    @patch("comfy.model_management.get_torch_device", return_value=torch.device("cpu"))
    def test_none_device_defaults_to_get_torch_device(self, mock_get):
        """device=None is resolved through get_torch_device() instead of assuming cuda:0."""
        result = mm.supports_fp8_compute(None)
        mock_get.assert_called_once()
        assert result is False

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    @pytest.mark.skipif(
        torch.version.cuda is None,
        reason="Needs an NVIDIA CUDA build: ROCm exposes cuda:N devices but reports no fp8 support",
    )
    @patch("comfy.model_management.SUPPORT_FP8_OPS", False)
    def test_each_cuda_device_checked_independently(self):
        """On a multi-GPU system, each device should be queried for its own capabilities."""
        count = torch.cuda.device_count()
        if count < 2:
            pytest.skip("Need 2+ CUDA devices for multi-GPU test")
        for i in range(count):
            dev = torch.device(f"cuda:{i}")
            supported = mm.supports_fp8_compute(dev)
            props = torch.cuda.get_device_properties(dev)
            # Verify the result is consistent with the device's compute capability.
            if props.major >= 9:
                assert supported is True, f"cuda:{i} ({props.name}) has SM {props.major}.{props.minor}, should support fp8"
            elif props.major < 8 or props.minor < 9:
                # Inside this elif major < 9, so this is "below SM 8.9".
                assert supported is False, f"cuda:{i} ({props.name}) has SM {props.major}.{props.minor}, should not support fp8"
            # SM 8.9 is deliberately left unasserted here; presumably its answer
            # can depend on more than compute capability.
            # NOTE(review): confirm against supports_fp8_compute.

    @patch("torch.version.cuda", None)
    @patch("comfy.model_management.SUPPORT_FP8_OPS", False)
    def test_rocm_build_returns_false(self):
        """On ROCm, devices appear as cuda:N via HIP but torch.version.cuda is None."""
        dev = MagicMock()
        dev.type = "cuda"
        assert mm.supports_fp8_compute(dev) is False

    @patch("torch.version.cuda", "12.4")
    @patch("comfy.model_management.SUPPORT_FP8_OPS", False)
    @patch("torch.cuda.get_device_properties")
    def test_sm89_supports_fp8(self, mock_props):
        """Ada Lovelace (SM 8.9, e.g. RTX 4080) should support fp8."""
        mock_props.return_value = FakeDeviceProps(major=8, minor=9)
        dev = torch.device("cuda:0")
        assert mm.supports_fp8_compute(dev) is True

    @patch("torch.version.cuda", "12.4")
    @patch("comfy.model_management.SUPPORT_FP8_OPS", False)
    @patch("torch.cuda.get_device_properties")
    def test_sm86_does_not_support_fp8(self, mock_props):
        """Ampere (SM 8.6, e.g. RTX 3090) should not support fp8."""
        mock_props.return_value = FakeDeviceProps(major=8, minor=6)
        dev = torch.device("cuda:0")
        assert mm.supports_fp8_compute(dev) is False

    @patch("torch.version.cuda", "12.4")
    @patch("comfy.model_management.SUPPORT_FP8_OPS", False)
    @patch("torch.cuda.get_device_properties")
    def test_sm90_supports_fp8(self, mock_props):
        """Hopper (SM 9.0) and above should support fp8."""
        mock_props.return_value = FakeDeviceProps(major=9, minor=0)
        dev = torch.device("cuda:0")
        assert mm.supports_fp8_compute(dev) is True
class TestSupportsNvfp4Compute:
    """Tests for per-device nvfp4 compute capability detection.

    These mirror the fp8 tests. They cover the CPU/MPS early return, the
    device=None default, the ROCm (``torch.version.cuda is None``) gate, and
    compute-capability checks using mocked device properties.
    """

    def test_cpu_device_returns_false(self):
        """A CPU device does not report nvfp4 compute support."""
        assert mm.supports_nvfp4_compute(torch.device("cpu")) is False

    @pytest.mark.skipif(not hasattr(torch.backends, "mps"), reason="MPS backend not available")
    def test_mps_device_returns_false(self):
        """An MPS device does not report nvfp4 compute support."""
        assert mm.supports_nvfp4_compute(torch.device("mps")) is False

    @patch("comfy.model_management.get_torch_device", return_value=torch.device("cpu"))
    def test_none_device_defaults_to_get_torch_device(self, mock_get):
        """device=None is resolved through get_torch_device() instead of assuming cuda:0."""
        result = mm.supports_nvfp4_compute(None)
        mock_get.assert_called_once()
        assert result is False

    @patch("torch.version.cuda", None)
    def test_rocm_build_returns_false(self):
        """On ROCm, devices appear as cuda:N via HIP but torch.version.cuda is None."""
        dev = MagicMock()
        dev.type = "cuda"
        assert mm.supports_nvfp4_compute(dev) is False

    @patch("torch.version.cuda", "12.4")
    @patch("torch.cuda.get_device_properties")
    def test_sm100_supports_nvfp4(self, mock_props):
        """Blackwell (SM 10.0) should support nvfp4."""
        mock_props.return_value = FakeDeviceProps(major=10, minor=0)
        dev = torch.device("cuda:0")
        assert mm.supports_nvfp4_compute(dev) is True

    @patch("torch.version.cuda", "12.4")
    @patch("torch.cuda.get_device_properties")
    def test_sm89_does_not_support_nvfp4(self, mock_props):
        """Ada Lovelace (SM 8.9) should not support nvfp4."""
        mock_props.return_value = FakeDeviceProps(major=8, minor=9)
        dev = torch.device("cuda:0")
        assert mm.supports_nvfp4_compute(dev) is False