Fix unittests for CPU build

This commit is contained in:
lspindler 2025-10-28 08:02:26 +01:00
parent 135d3025ea
commit 9d9f98cb72
2 changed files with 17 additions and 2 deletions

View File

@@ -6,6 +6,13 @@ import os
# Add comfy to path # Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
def has_gpu():
return torch.cuda.is_available()
from comfy.cli_args import args
if not has_gpu():
args.cpu = True
from comfy import ops from comfy import ops
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout

View File

@@ -6,6 +6,13 @@ import os
# Add comfy to path # Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
def has_gpu():
return torch.cuda.is_available()
from comfy.cli_args import args
if not has_gpu():
args.cpu = True
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
@@ -49,7 +56,7 @@ class TestQuantizedTensor(unittest.TestCase):
float_tensor, float_tensor,
TensorCoreFP8Layout, TensorCoreFP8Layout,
scale=scale, scale=scale,
fp8_dtype=torch.float8_e4m3fn dtype=torch.float8_e4m3fn
) )
self.assertIsInstance(qt, QuantizedTensor) self.assertIsInstance(qt, QuantizedTensor)
@@ -96,6 +103,7 @@ class TestGenericUtilities(unittest.TestCase):
# Verify it's a deep copy # Verify it's a deep copy
self.assertIsNot(qt_cloned._qdata, qt._qdata) self.assertIsNot(qt_cloned._qdata, qt._qdata)
@unittest.skipUnless(has_gpu(), "GPU not available")
def test_to_device(self): def test_to_device(self):
"""Test device transfer""" """Test device transfer"""
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
@@ -163,7 +171,7 @@ class TestFallbackMechanism(unittest.TestCase):
a_fp32, a_fp32,
TensorCoreFP8Layout, TensorCoreFP8Layout,
scale=scale, scale=scale,
fp8_dtype=torch.float8_e4m3fn dtype=torch.float8_e4m3fn
) )
# Call an operation that doesn't have a registered handler # Call an operation that doesn't have a registered handler