Mirror of https://github.com/comfyanonymous/ComfyUI.git
Fix tests and ruff.
commit 134d163ab6
parent 3c7b599222
@@ -1,7 +1,5 @@
 import torch
 import logging
-import dataclasses
-from typing import Dict
 
 try:
     import comfy_kitchen as ck
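
As a minimal sketch of the optional-import guard this hunk keeps (the except branch and the HAS_COMFY_KITCHEN flag are assumptions; the diff shows only the try line and the import):

try:
    import comfy_kitchen as ck  # optional dependency; guarded so its absence does not break import
    HAS_COMFY_KITCHEN = True    # hypothetical flag, not part of this commit
except ImportError:
    ck = None
    HAS_COMFY_KITCHEN = False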
@@ -130,7 +128,6 @@ __all__ = [
     "TensorCoreFP8E4M3Layout",
     "TensorCoreFP8E5M2Layout",
     "TensorCoreNVFP4Layout",
-    "LAYOUTS",
     "QUANT_ALGOS",
     "register_layout_op",
 ]
@@ -103,18 +103,18 @@ class TestMixedPrecisionOps(unittest.TestCase):
 
         # Verify weights are wrapped in QuantizedTensor
         self.assertIsInstance(model.layer1.weight, QuantizedTensor)
-        self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")
+        self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout")
 
         # Layer 2 should NOT be quantized
         self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
 
         # Layer 3 should be quantized
         self.assertIsInstance(model.layer3.weight, QuantizedTensor)
-        self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")
+        self.assertEqual(model.layer3.weight._layout_cls, "TensorCoreFP8E4M3Layout")
 
         # Verify scales were loaded
-        self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
-        self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5)
+        self.assertEqual(model.layer1.weight._params.scale.item(), 2.0)
+        self.assertEqual(model.layer3.weight._params.scale.item(), 1.5)
 
         # Forward pass
         input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
@@ -154,8 +154,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
 
         # Verify layer1.weight is a QuantizedTensor with scale preserved
        self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
-        self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
-        self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")
+        self.assertEqual(state_dict2["layer1.weight"]._params.scale.item(), 3.0)
+        self.assertEqual(state_dict2["layer1.weight"]._layout_cls, "TensorCoreFP8E4M3Layout")
 
         # Verify non-quantized layers are standard tensors
         self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
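
For reference, a minimal sketch of how calling code would read the renamed fields these assertions exercise. Only the attribute names (_layout_cls, _params.scale) come from this diff; the import path is taken from the old test file below, and describe_weight is a hypothetical helper, not part of the commit:

from comfy.quant_ops import QuantizedTensor

def describe_weight(weight):
    # Hypothetical helper for illustration only.
    if not isinstance(weight, QuantizedTensor):
        return "not quantized"
    # The layout name now lives on _layout_cls (was _layout_type), and per-layout
    # parameters moved from the _layout_params dict to attribute access on _params
    # (weight._params.scale instead of weight._layout_params['scale']).
    return f"{weight._layout_cls}, scale={weight._params.scale.item()}"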
@@ -1,190 +0,0 @@
-import unittest
-import torch
-import sys
-import os
-
-# Add comfy to path
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
-
-def has_gpu():
-    return torch.cuda.is_available()
-
-from comfy.cli_args import args
-if not has_gpu():
-    args.cpu = True
-
-from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
-
-
-class TestQuantizedTensor(unittest.TestCase):
-    """Test the QuantizedTensor subclass with FP8 layout"""
-
-    def test_creation(self):
-        """Test creating a QuantizedTensor with TensorCoreFP8Layout"""
-        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
-        scale = torch.tensor(2.0)
-        layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
-
-        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
-
-        self.assertIsInstance(qt, QuantizedTensor)
-        self.assertEqual(qt.shape, (256, 128))
-        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
-        self.assertEqual(qt._layout_params['scale'], scale)
-        self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
-        self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
-
-    def test_dequantize(self):
-        """Test explicit dequantization"""
-
-        fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
-        scale = torch.tensor(3.0)
-        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
-
-        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
-        dequantized = qt.dequantize()
-
-        self.assertEqual(dequantized.dtype, torch.float32)
-        self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))
-
-    def test_from_float(self):
-        """Test creating QuantizedTensor from float tensor"""
-        float_tensor = torch.randn(64, 32, dtype=torch.float32)
-        scale = torch.tensor(1.5)
-
-        qt = QuantizedTensor.from_float(
-            float_tensor,
-            "TensorCoreFP8Layout",
-            scale=scale,
-            dtype=torch.float8_e4m3fn
-        )
-
-        self.assertIsInstance(qt, QuantizedTensor)
-        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
-        self.assertEqual(qt.shape, (64, 32))
-
-        # Verify dequantization gives approximately original values
-        dequantized = qt.dequantize()
-        mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean()
-        self.assertLess(mean_rel_error, 0.1)
-
-
-class TestGenericUtilities(unittest.TestCase):
-    """Test generic utility operations"""
-
-    def test_detach(self):
-        """Test detach operation on quantized tensor"""
-        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
-        scale = torch.tensor(1.5)
-        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
-        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
-
-        # Detach should return a new QuantizedTensor
-        qt_detached = qt.detach()
-
-        self.assertIsInstance(qt_detached, QuantizedTensor)
-        self.assertEqual(qt_detached.shape, qt.shape)
-        self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")
-
-    def test_clone(self):
-        """Test clone operation on quantized tensor"""
-        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
-        scale = torch.tensor(1.5)
-        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
-        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
-
-        # Clone should return a new QuantizedTensor
-        qt_cloned = qt.clone()
-
-        self.assertIsInstance(qt_cloned, QuantizedTensor)
-        self.assertEqual(qt_cloned.shape, qt.shape)
-        self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")
-
-        # Verify it's a deep copy
-        self.assertIsNot(qt_cloned._qdata, qt._qdata)
-
-    @unittest.skipUnless(has_gpu(), "GPU not available")
-    def test_to_device(self):
-        """Test device transfer"""
-        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
-        scale = torch.tensor(1.5)
-        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
-        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
-
-        # Moving to same device should work (CPU to CPU)
-        qt_cpu = qt.to('cpu')
-
-        self.assertIsInstance(qt_cpu, QuantizedTensor)
-        self.assertEqual(qt_cpu.device.type, 'cpu')
-        self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu')
-
-
-class TestTensorCoreFP8Layout(unittest.TestCase):
-    """Test the TensorCoreFP8Layout implementation"""
-
-    def test_quantize(self):
-        """Test quantization method"""
-        float_tensor = torch.randn(32, 64, dtype=torch.float32)
-        scale = torch.tensor(1.5)
-
-        qdata, layout_params = TensorCoreFP8Layout.quantize(
-            float_tensor,
-            scale=scale,
-            dtype=torch.float8_e4m3fn
-        )
-
-        self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
-        self.assertEqual(qdata.shape, float_tensor.shape)
-        self.assertIn('scale', layout_params)
-        self.assertIn('orig_dtype', layout_params)
-        self.assertEqual(layout_params['orig_dtype'], torch.float32)
-
-    def test_dequantize(self):
-        """Test dequantization method"""
-        float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0
-        scale = torch.tensor(1.0)
-
-        qdata, layout_params = TensorCoreFP8Layout.quantize(
-            float_tensor,
-            scale=scale,
-            dtype=torch.float8_e4m3fn
-        )
-
-        dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)
-
-        # Should approximately match original
-        self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))
-
-
-class TestFallbackMechanism(unittest.TestCase):
-    """Test fallback for unsupported operations"""
-
-    def test_unsupported_op_dequantizes(self):
-        """Test that unsupported operations fall back to dequantization"""
-        # Set seed for reproducibility
-        torch.manual_seed(42)
-
-        # Create quantized tensor
-        a_fp32 = torch.randn(10, 20, dtype=torch.float32)
-        scale = torch.tensor(1.0)
-        a_q = QuantizedTensor.from_float(
-            a_fp32,
-            "TensorCoreFP8Layout",
-            scale=scale,
-            dtype=torch.float8_e4m3fn
-        )
-
-        # Call an operation that doesn't have a registered handler
-        # For example, torch.abs
-        result = torch.abs(a_q)
-
-        # Should work via fallback (dequantize → abs → return)
-        self.assertNotIsInstance(result, QuantizedTensor)
-        expected = torch.abs(a_fp32)
-        # FP8 introduces quantization error, so use loose tolerance
-        mean_error = (result - expected).abs().mean()
-        self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")
-
-
-if __name__ == "__main__":
-    unittest.main()