mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-16 16:32:34 +08:00
Fix unittests for CPU build
This commit is contained in:
parent
135d3025ea
commit
9d9f98cb72
@ -6,6 +6,13 @@ import os
|
|||||||
# Add comfy to path
|
# Add comfy to path
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||||
|
|
||||||
|
def has_gpu():
|
||||||
|
return torch.cuda.is_available()
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
if not has_gpu():
|
||||||
|
args.cpu = True
|
||||||
|
|
||||||
from comfy import ops
|
from comfy import ops
|
||||||
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
|
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
|
||||||
|
|
||||||
|
|||||||
@ -6,6 +6,13 @@ import os
|
|||||||
# Add comfy to path
|
# Add comfy to path
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||||
|
|
||||||
|
def has_gpu():
|
||||||
|
return torch.cuda.is_available()
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
if not has_gpu():
|
||||||
|
args.cpu = True
|
||||||
|
|
||||||
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
|
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
|
||||||
|
|
||||||
|
|
||||||
@ -49,7 +56,7 @@ class TestQuantizedTensor(unittest.TestCase):
|
|||||||
float_tensor,
|
float_tensor,
|
||||||
TensorCoreFP8Layout,
|
TensorCoreFP8Layout,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
fp8_dtype=torch.float8_e4m3fn
|
dtype=torch.float8_e4m3fn
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertIsInstance(qt, QuantizedTensor)
|
self.assertIsInstance(qt, QuantizedTensor)
|
||||||
@ -96,6 +103,7 @@ class TestGenericUtilities(unittest.TestCase):
|
|||||||
# Verify it's a deep copy
|
# Verify it's a deep copy
|
||||||
self.assertIsNot(qt_cloned._qdata, qt._qdata)
|
self.assertIsNot(qt_cloned._qdata, qt._qdata)
|
||||||
|
|
||||||
|
@unittest.skipUnless(has_gpu(), "GPU not available")
|
||||||
def test_to_device(self):
|
def test_to_device(self):
|
||||||
"""Test device transfer"""
|
"""Test device transfer"""
|
||||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||||
@ -163,7 +171,7 @@ class TestFallbackMechanism(unittest.TestCase):
|
|||||||
a_fp32,
|
a_fp32,
|
||||||
TensorCoreFP8Layout,
|
TensorCoreFP8Layout,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
fp8_dtype=torch.float8_e4m3fn
|
dtype=torch.float8_e4m3fn
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call an operation that doesn't have a registered handler
|
# Call an operation that doesn't have a registered handler
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user