diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index b1f8ac010..cd737726f 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -1,7 +1,5 @@ import torch import logging -import dataclasses -from typing import Dict try: import comfy_kitchen as ck @@ -130,7 +128,6 @@ __all__ = [ "TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreNVFP4Layout", - "LAYOUTS", "QUANT_ALGOS", "register_layout_op", ] diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 3a54941e6..7b2eac940 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -103,18 +103,18 @@ class TestMixedPrecisionOps(unittest.TestCase): # Verify weights are wrapped in QuantizedTensor self.assertIsInstance(model.layer1.weight, QuantizedTensor) - self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout") + self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout") # Layer 2 should NOT be quantized self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) # Layer 3 should be quantized self.assertIsInstance(model.layer3.weight, QuantizedTensor) - self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout") + self.assertEqual(model.layer3.weight._layout_cls, "TensorCoreFP8E4M3Layout") # Verify scales were loaded - self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) - self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5) + self.assertEqual(model.layer1.weight._params.scale.item(), 2.0) + self.assertEqual(model.layer3.weight._params.scale.item(), 1.5) # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) @@ -154,8 +154,8 @@ class TestMixedPrecisionOps(unittest.TestCase): # Verify layer1.weight is a QuantizedTensor with scale preserved self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) - self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) - self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout") + self.assertEqual(state_dict2["layer1.weight"]._params.scale.item(), 3.0) + self.assertEqual(state_dict2["layer1.weight"]._layout_cls, "TensorCoreFP8E4M3Layout") # Verify non-quantized layers are standard tensors self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py deleted file mode 100644 index 9cb54ede8..000000000 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ /dev/null @@ -1,190 +0,0 @@ -import unittest -import torch -import sys -import os - -# Add comfy to path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) - -def has_gpu(): - return torch.cuda.is_available() - -from comfy.cli_args import args -if not has_gpu(): - args.cpu = True - -from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout - - -class TestQuantizedTensor(unittest.TestCase): - """Test the QuantizedTensor subclass with FP8 layout""" - - def test_creation(self): - """Test creating a QuantizedTensor with TensorCoreFP8Layout""" - fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(2.0) - layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} - - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - self.assertIsInstance(qt, QuantizedTensor) - self.assertEqual(qt.shape, (256, 128)) - self.assertEqual(qt.dtype, torch.float8_e4m3fn) - self.assertEqual(qt._layout_params['scale'], scale) - self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) - self.assertEqual(qt._layout_type, "TensorCoreFP8Layout") - - def test_dequantize(self): - """Test explicit dequantization""" - - fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(3.0) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - dequantized = qt.dequantize() - - self.assertEqual(dequantized.dtype, torch.float32) - self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) - - def test_from_float(self): - """Test creating QuantizedTensor from float tensor""" - float_tensor = torch.randn(64, 32, dtype=torch.float32) - scale = torch.tensor(1.5) - - qt = QuantizedTensor.from_float( - float_tensor, - "TensorCoreFP8Layout", - scale=scale, - dtype=torch.float8_e4m3fn - ) - - self.assertIsInstance(qt, QuantizedTensor) - self.assertEqual(qt.dtype, torch.float8_e4m3fn) - self.assertEqual(qt.shape, (64, 32)) - - # Verify dequantization gives approximately original values - dequantized = qt.dequantize() - mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean() - self.assertLess(mean_rel_error, 0.1) - - -class TestGenericUtilities(unittest.TestCase): - """Test generic utility operations""" - - def test_detach(self): - """Test detach operation on quantized tensor""" - fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(1.5) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - # Detach should return a new QuantizedTensor - qt_detached = qt.detach() - - self.assertIsInstance(qt_detached, QuantizedTensor) - self.assertEqual(qt_detached.shape, qt.shape) - self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout") - - def test_clone(self): - """Test clone operation on quantized tensor""" - fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(1.5) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - # Clone should return a new QuantizedTensor - qt_cloned = qt.clone() - - self.assertIsInstance(qt_cloned, QuantizedTensor) - self.assertEqual(qt_cloned.shape, qt.shape) - self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout") - - # Verify it's a deep copy - self.assertIsNot(qt_cloned._qdata, qt._qdata) - - @unittest.skipUnless(has_gpu(), "GPU not available") - def test_to_device(self): - """Test device transfer""" - fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(1.5) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - # Moving to same device should work (CPU to CPU) - qt_cpu = qt.to('cpu') - - self.assertIsInstance(qt_cpu, QuantizedTensor) - self.assertEqual(qt_cpu.device.type, 'cpu') - self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') - - -class TestTensorCoreFP8Layout(unittest.TestCase): - """Test the TensorCoreFP8Layout implementation""" - - def test_quantize(self): - """Test quantization method""" - float_tensor = torch.randn(32, 64, dtype=torch.float32) - scale = torch.tensor(1.5) - - qdata, layout_params = TensorCoreFP8Layout.quantize( - float_tensor, - scale=scale, - dtype=torch.float8_e4m3fn - ) - - self.assertEqual(qdata.dtype, torch.float8_e4m3fn) - self.assertEqual(qdata.shape, float_tensor.shape) - self.assertIn('scale', layout_params) - self.assertIn('orig_dtype', layout_params) - self.assertEqual(layout_params['orig_dtype'], torch.float32) - - def test_dequantize(self): - """Test dequantization method""" - float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0 - scale = torch.tensor(1.0) - - qdata, layout_params = TensorCoreFP8Layout.quantize( - float_tensor, - scale=scale, - dtype=torch.float8_e4m3fn - ) - - dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) - - # Should approximately match original - self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) - - -class TestFallbackMechanism(unittest.TestCase): - """Test fallback for unsupported operations""" - - def test_unsupported_op_dequantizes(self): - """Test that unsupported operations fall back to dequantization""" - # Set seed for reproducibility - torch.manual_seed(42) - - # Create quantized tensor - a_fp32 = torch.randn(10, 20, dtype=torch.float32) - scale = torch.tensor(1.0) - a_q = QuantizedTensor.from_float( - a_fp32, - "TensorCoreFP8Layout", - scale=scale, - dtype=torch.float8_e4m3fn - ) - - # Call an operation that doesn't have a registered handler - # For example, torch.abs - result = torch.abs(a_q) - - # Should work via fallback (dequantize → abs → return) - self.assertNotIsInstance(result, QuantizedTensor) - expected = torch.abs(a_fp32) - # FP8 introduces quantization error, so use loose tolerance - mean_error = (result - expected).abs().mean() - self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") - - -if __name__ == "__main__": - unittest.main()