diff --git a/comfy/ops.py b/comfy/ops.py index b46e7553d..911228b51 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -566,7 +566,7 @@ class MixedPrecisionOps(disable_weight_init): def forward_comfy_cast_weights(self, input): weight, bias = cast_bias_weight(self, input) - self._forward(input, weight, bias) + return self._forward(input, weight, bias) def forward(self, input, *args, **kwargs): run_every_op() diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 8d7f6480a..96d2fa03f 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -143,7 +143,7 @@ class QuantizedTensor(torch.Tensor): """ Tensor flattening protocol for proper device movement. """ - inner_tensors = ["_q_data"] + inner_tensors = ["_qdata"] ctx = { "layout_type": self._layout_type, } @@ -206,7 +206,7 @@ class QuantizedTensor(torch.Tensor): # Step 3: Fallback to dequantization if isinstance(args[0] if args else None, QuantizedTensor): - logging.warning(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") + logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") return cls._dequant_and_fallback(func, args, kwargs) @classmethod diff --git a/tests-unit/comfy_test/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py similarity index 60% rename from tests-unit/comfy_test/test_mixed_precision.py rename to tests-unit/comfy_quant/test_mixed_precision.py index cbfa2866d..e34552760 100644 --- a/tests-unit/comfy_test/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -1,8 +1,3 @@ -""" -End-to-end tests for mixed precision quantization. -Tests Phase 3: Mixed Precision Operations -""" - import unittest import torch import sys @@ -12,10 +7,10 @@ import os sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) from comfy import ops +from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout class SimpleModel(torch.nn.Module): - """Simple model for testing mixed precision""" def __init__(self, operations=ops.disable_weight_init): super().__init__() self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16) @@ -32,8 +27,7 @@ class SimpleModel(torch.nn.Module): class TestMixedPrecisionOps(unittest.TestCase): - """Test MixedPrecisionOps end-to-end""" - + def test_all_layers_standard(self): """Test that model with no quantization works normally""" # Configure no quantization @@ -67,48 +61,54 @@ class TestMixedPrecisionOps(unittest.TestCase): # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard layer_quant_config = { "layer1": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": False} # Disable for CPU testing + "format": "float8_e4m3fn", + "params": {} }, "layer3": { - "format": "fp8_e5m2_scaled", - "params": {"use_fp8_matmul": False} + "format": "float8_e4m3fn", + "params": {} } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config # Create state dict with mixed precision fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) - fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e5m2) + fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn) state_dict = { # Layer 1: FP8 E4M3FN "layer1.weight": fp8_weight1, "layer1.bias": torch.randn(20, dtype=torch.bfloat16), - "layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), # Layer 2: Standard BF16 "layer2.weight": 
torch.randn(30, 20, dtype=torch.bfloat16), "layer2.bias": torch.randn(30, dtype=torch.bfloat16), - # Layer 3: FP8 E5M2 + # Layer 3: FP8 E4M3FN "layer3.weight": fp8_weight3, "layer3.bias": torch.randn(40, dtype=torch.bfloat16), - "layer3.scale_weight": torch.tensor(1.5, dtype=torch.float32), + "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32), } - # Create model and load state dict + # Create model and load state dict (strict=False because custom loading pops keys) model = SimpleModel(operations=ops.MixedPrecisionOps) - model.load_state_dict(state_dict) + model.load_state_dict(state_dict, strict=False) - # Verify handlers are set up correctly - self.assertIsNotNone(model.layer1.quant_handler) - self.assertIsNone(model.layer2.quant_handler) # No quantization - self.assertIsNotNone(model.layer3.quant_handler) + # Verify weights are wrapped in QuantizedTensor + self.assertIsInstance(model.layer1.weight, QuantizedTensor) + self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout) + + # Layer 2 should NOT be quantized + self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) + + # Layer 3 should be quantized + self.assertIsInstance(model.layer3.weight, QuantizedTensor) + self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout) # Verify scales were loaded - self.assertEqual(model.layer1.scale_weight.item(), 2.0) - self.assertEqual(model.layer3.scale_weight.item(), 1.5) + self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) + self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5) # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) @@ -116,13 +116,13 @@ class TestMixedPrecisionOps(unittest.TestCase): self.assertEqual(output.shape, (5, 40)) - def test_state_dict_round_trip(self): - """Test saving and loading state dict preserves quantization""" + def test_state_dict_quantized_preserved(self): + """Test that quantized weights are preserved in state_dict()""" # Configure mixed precision layer_quant_config = { "layer1": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": False} + "format": "float8_e4m3fn", + "params": {} } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config @@ -132,45 +132,35 @@ class TestMixedPrecisionOps(unittest.TestCase): state_dict1 = { "layer1.weight": fp8_weight, "layer1.bias": torch.randn(20, dtype=torch.bfloat16), - "layer1.scale_weight": torch.tensor(3.0, dtype=torch.float32), + "layer1.weight_scale": torch.tensor(3.0, dtype=torch.float32), "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), "layer2.bias": torch.randn(30, dtype=torch.bfloat16), "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model1 = SimpleModel(operations=ops.MixedPrecisionOps) - model1.load_state_dict(state_dict1) + model = SimpleModel(operations=ops.MixedPrecisionOps) + model.load_state_dict(state_dict1, strict=False) # Save state dict - state_dict2 = model1.state_dict() + state_dict2 = model.state_dict() - # Verify scale_weight is saved - self.assertIn("layer1.scale_weight", state_dict2) - self.assertEqual(state_dict2["layer1.scale_weight"].item(), 3.0) + # Verify layer1.weight is a QuantizedTensor with scale preserved + self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) + self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) + self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout) - # Load into new model - model2 = 
SimpleModel(operations=ops.MixedPrecisionOps) - model2.load_state_dict(state_dict2) - - # Verify handler is set up - self.assertIsNotNone(model2.layer1.quant_handler) - self.assertEqual(model2.layer1.scale_weight.item(), 3.0) - - # Verify forward passes match - input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) - output1 = model1(input_tensor) - output2 = model2(input_tensor) - - torch.testing.assert_close(output1, output2, rtol=1e-3, atol=1e-3) + # Verify non-quantized layers are standard tensors + self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) + self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor) def test_weight_function_compatibility(self): """Test that weight_function (LoRA) works with quantized layers""" # Configure FP8 quantization layer_quant_config = { "layer1": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": False} + "format": "float8_e4m3fn", + "params": {} } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config @@ -180,7 +170,7 @@ class TestMixedPrecisionOps(unittest.TestCase): state_dict = { "layer1.weight": fp8_weight, "layer1.bias": torch.randn(20, dtype=torch.bfloat16), - "layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), "layer2.bias": torch.randn(30, dtype=torch.bfloat16), "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), @@ -188,25 +178,24 @@ class TestMixedPrecisionOps(unittest.TestCase): } model = SimpleModel(operations=ops.MixedPrecisionOps) - model.load_state_dict(state_dict) + model.load_state_dict(state_dict, strict=False) # Add a weight function (simulating LoRA) - # LoRA delta must match weight shape (20, 10) + # This should trigger dequantization during forward pass def apply_lora(weight): - # Generate LoRA delta matching weight shape lora_delta = torch.randn_like(weight) * 0.01 return weight + lora_delta model.layer1.weight_function.append(apply_lora) - # Forward pass should work with LoRA + # Forward pass should work with LoRA (triggers weight_function path) input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) output = model(input_tensor) self.assertEqual(output.shape, (5, 40)) def test_error_handling_unknown_format(self): - """Test that unknown formats fall back gracefully""" + """Test that unknown formats raise error""" # Configure with unknown format layer_quant_config = { "layer1": { @@ -226,48 +215,10 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - # Load should not crash, just log warning + # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS model = SimpleModel(operations=ops.MixedPrecisionOps) - model.load_state_dict(state_dict) - - # Handler should be None (fallback to standard) - self.assertIsNone(model.layer1.quant_handler) - - # Forward pass should still work - input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) - output = model(input_tensor) - self.assertEqual(output.shape, (5, 40)) - - -class TestPickOperationsWithMixedPrecision(unittest.TestCase): - """Test pick_operations with mixed precision config""" - - def test_pick_operations_with_layer_quant_config(self): - """Test that pick_operations returns MixedPrecisionOps when config present""" - from comfy import supported_models_base - - # Create model config with layer_quant_config - model_config = supported_models_base.BASE({}) - model_config.layer_quant_config = { - "layer1": {"format": 
"fp8_e4m3fn_scaled", "params": {}} - } - - result = ops.pick_operations(None, None, model_config=model_config) - - self.assertEqual(result, ops.MixedPrecisionOps) - self.assertEqual(ops.MixedPrecisionOps._layer_quant_config, model_config.layer_quant_config) - - def test_pick_operations_without_layer_quant_config(self): - """Test that pick_operations falls back to standard when no config""" - from comfy import supported_models_base - - model_config = supported_models_base.BASE({}) - model_config.layer_quant_config = None - - result = ops.pick_operations(None, None, model_config=model_config) - - self.assertEqual(result, ops.disable_weight_init) - + with self.assertRaises(KeyError): + model.load_state_dict(state_dict, strict=False) if __name__ == "__main__": unittest.main() diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py new file mode 100644 index 000000000..263581417 --- /dev/null +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -0,0 +1,183 @@ +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout + + +class TestQuantizedTensor(unittest.TestCase): + """Test the QuantizedTensor subclass with FP8 layout""" + + def test_creation(self): + """Test creating a QuantizedTensor with TensorCoreFP8Layout""" + fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(2.0) + layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} + + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.shape, (256, 128)) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt._layout_params['scale'], scale) + self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) + self.assertEqual(qt._layout_type, TensorCoreFP8Layout) + + def test_dequantize(self): + """Test explicit dequantization""" + + fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(3.0) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + dequantized = qt.dequantize() + + self.assertEqual(dequantized.dtype, torch.float32) + self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) + + def test_from_float(self): + """Test creating QuantizedTensor from float tensor""" + float_tensor = torch.randn(64, 32, dtype=torch.float32) + scale = torch.tensor(1.5) + + qt = QuantizedTensor.from_float( + float_tensor, + TensorCoreFP8Layout, + scale=scale, + fp8_dtype=torch.float8_e4m3fn + ) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt.shape, (64, 32)) + + # Verify dequantization gives approximately original values + dequantized = qt.dequantize() + mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean() + self.assertLess(mean_rel_error, 0.1) + + +class TestGenericUtilities(unittest.TestCase): + """Test generic utility operations""" + + def test_detach(self): + """Test detach operation on quantized tensor""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, 
layout_params) + + # Detach should return a new QuantizedTensor + qt_detached = qt.detach() + + self.assertIsInstance(qt_detached, QuantizedTensor) + self.assertEqual(qt_detached.shape, qt.shape) + self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout) + + def test_clone(self): + """Test clone operation on quantized tensor""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Clone should return a new QuantizedTensor + qt_cloned = qt.clone() + + self.assertIsInstance(qt_cloned, QuantizedTensor) + self.assertEqual(qt_cloned.shape, qt.shape) + self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout) + + # Verify it's a deep copy + self.assertIsNot(qt_cloned._qdata, qt._qdata) + + def test_to_device(self): + """Test device transfer""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + + # Moving to same device should work (CPU to CPU) + qt_cpu = qt.to('cpu') + + self.assertIsInstance(qt_cpu, QuantizedTensor) + self.assertEqual(qt_cpu.device.type, 'cpu') + self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') + + +class TestTensorCoreFP8Layout(unittest.TestCase): + """Test the TensorCoreFP8Layout implementation""" + + def test_quantize(self): + """Test quantization method""" + float_tensor = torch.randn(32, 64, dtype=torch.float32) + scale = torch.tensor(1.5) + + qdata, layout_params = TensorCoreFP8Layout.quantize( + float_tensor, + scale=scale, + fp8_dtype=torch.float8_e4m3fn + ) + + self.assertEqual(qdata.dtype, torch.float8_e4m3fn) + self.assertEqual(qdata.shape, float_tensor.shape) + self.assertIn('scale', layout_params) + self.assertIn('orig_dtype', layout_params) + self.assertEqual(layout_params['orig_dtype'], torch.float32) + + def test_dequantize(self): + """Test dequantization method""" + float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0 + scale = torch.tensor(1.0) + + qdata, layout_params = TensorCoreFP8Layout.quantize( + float_tensor, + scale=scale, + fp8_dtype=torch.float8_e4m3fn + ) + + dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) + + # Should approximately match original + self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) + + +class TestFallbackMechanism(unittest.TestCase): + """Test fallback for unsupported operations""" + + def test_unsupported_op_dequantizes(self): + """Test that unsupported operations fall back to dequantization""" + # Set seed for reproducibility + torch.manual_seed(42) + + # Create quantized tensor + a_fp32 = torch.randn(10, 20, dtype=torch.float32) + scale = torch.tensor(1.0) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + a_q = QuantizedTensor.from_float( + a_fp32, + TensorCoreFP8Layout, + scale=scale, + fp8_dtype=torch.float8_e4m3fn + ) + + # Call an operation that doesn't have a registered handler + # For example, torch.abs + result = torch.abs(a_q) + + # Should work via fallback (dequantize → abs → return) + self.assertNotIsInstance(result, QuantizedTensor) + expected = torch.abs(a_fp32) + # FP8 introduces quantization error, so use loose tolerance + mean_error = (result - expected).abs().mean() + self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too 
large") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests-unit/comfy_test/test_quant_detection.py b/tests-unit/comfy_test/test_quant_detection.py deleted file mode 100644 index bb952a81b..000000000 --- a/tests-unit/comfy_test/test_quant_detection.py +++ /dev/null @@ -1,262 +0,0 @@ -""" -Integration tests for quantization detection. -Tests Phase 2: Detection & Integration -""" - -import unittest -import torch -import sys -import os - -# Add comfy to path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) - -from comfy import model_detection - - -class TestNormalizeLayerName(unittest.TestCase): - """Test the normalize_layer_name helper function""" - - def test_strip_prefix_and_suffix(self): - """Test stripping prefix and suffix""" - known_prefixes = ["model.diffusion_model."] - result = model_detection.normalize_layer_name( - "model.diffusion_model.layer1.weight", - known_prefixes - ) - self.assertEqual(result, "layer1") - - def test_strip_multiple_prefixes(self): - """Test with multiple known prefixes""" - known_prefixes = ["model.diffusion_model.", "model.model.", "net."] - - result1 = model_detection.normalize_layer_name( - "model.diffusion_model.block.attn.weight", - known_prefixes - ) - self.assertEqual(result1, "block.attn") - - result2 = model_detection.normalize_layer_name( - "model.model.encoder.layer.weight", - known_prefixes - ) - self.assertEqual(result2, "encoder.layer") - - result3 = model_detection.normalize_layer_name( - "net.transformer.blocks.0.weight", - known_prefixes - ) - self.assertEqual(result3, "transformer.blocks.0") - - def test_strip_scale_weight_suffix(self): - """Test stripping scale_weight suffix""" - known_prefixes = ["model.diffusion_model."] - result = model_detection.normalize_layer_name( - "model.diffusion_model.layer1.scale_weight", - known_prefixes - ) - self.assertEqual(result, "layer1") - - def test_strip_bias_suffix(self): - """Test stripping bias suffix""" - known_prefixes = ["model.diffusion_model."] - result = model_detection.normalize_layer_name( - "model.diffusion_model.layer1.bias", - known_prefixes - ) - self.assertEqual(result, "layer1") - - def test_no_prefix_match(self): - """Test with no prefix match""" - known_prefixes = ["model.diffusion_model."] - result = model_detection.normalize_layer_name( - "other.model.layer1.weight", - known_prefixes - ) - # Should strip suffix but not prefix - self.assertEqual(result, "other.model.layer1") - - -class TestDetectLayerQuantization(unittest.TestCase): - """Test the detect_layer_quantization function""" - - def test_no_quantization(self): - """Test with no quantization markers""" - state_dict = { - "model.diffusion_model.layer1.weight": torch.randn(10, 20), - "model.diffusion_model.layer2.weight": torch.randn(20, 30), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - self.assertIsNone(result) - - def test_legacy_scaled_fp8(self): - """Test that legacy scaled_fp8 marker returns None""" - # Create FP8 tensor by converting from float32 - fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - state_dict = { - "model.diffusion_model.scaled_fp8": torch.tensor([], dtype=torch.float8_e4m3fn), - "model.diffusion_model.layer1.weight": fp8_weight, - "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - # Should return None to trigger legacy path - self.assertIsNone(result) - - def 
test_metadata_format(self): - """Test with new metadata format""" - metadata = { - "format_version": "1.0", - "layers": { - "layer1": { - "format": "fp8_e4m3fn_scaled", - "params": {"use_fp8_matmul": True} - }, - "layer2": { - "format": "fp8_e5m2_scaled", - "params": {"use_fp8_matmul": True} - } - } - } - state_dict = { - "model.diffusion_model._quantization_metadata": metadata, - "model.diffusion_model.layer1.weight": torch.randn(10, 20), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - - self.assertIsNotNone(result) - self.assertIn("layer1", result) - self.assertIn("layer2", result) - self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled") - self.assertEqual(result["layer2"]["format"], "fp8_e5m2_scaled") - # Metadata should be popped from state_dict - self.assertNotIn("model.diffusion_model._quantization_metadata", state_dict) - - def test_mixed_precision_detection(self): - """Test detection of mixed precision via scale patterns""" - # Create FP8 tensors by converting from float32 - fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - fp8_weight3 = torch.randn(30, 40, dtype=torch.float32).to(torch.float8_e4m3fn) - state_dict = { - # Layer 1: FP8 (has scale_weight) - "model.diffusion_model.layer1.weight": fp8_weight1, - "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), - # Layer 2: Standard (no scale_weight) - "model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), - # Layer 3: FP8 (has scale_weight) - "model.diffusion_model.layer3.weight": fp8_weight3, - "model.diffusion_model.layer3.scale_weight": torch.tensor(1.0), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - - self.assertIsNotNone(result) - self.assertIn("layer1", result) - self.assertIn("layer3", result) - self.assertNotIn("layer2", result) # Layer 2 not quantized - self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled") - self.assertEqual(result["layer3"]["format"], "fp8_e4m3fn_scaled") - - def test_all_layers_quantized(self): - """Test that uniform quantization (all layers) returns None""" - # Create FP8 tensors by converting from float32 - fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - fp8_weight2 = torch.randn(20, 30, dtype=torch.float32).to(torch.float8_e4m3fn) - state_dict = { - # All layers have scale_weight - "model.diffusion_model.layer1.weight": fp8_weight1, - "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), - "model.diffusion_model.layer2.weight": fp8_weight2, - "model.diffusion_model.layer2.scale_weight": torch.tensor(1.0), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - - # If all layers are quantized, it's not mixed precision - # Should return None to use legacy scaled_fp8_ops path - self.assertIsNone(result) - - def test_fp8_e5m2_detection(self): - """Test detection of FP8 E5M2 format""" - # Create FP8 E5M2 tensor by converting from float32 - fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e5m2) - state_dict = { - "model.diffusion_model.layer1.weight": fp8_weight, - "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0), - "model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - - self.assertIsNotNone(result) - self.assertIn("layer1", result) - self.assertEqual(result["layer1"]["format"], 
"fp8_e5m2_scaled") - - def test_invalid_metadata(self): - """Test with invalid metadata format""" - state_dict = { - "model.diffusion_model._quantization_metadata": "invalid_string", - "model.diffusion_model.layer1.weight": torch.randn(10, 20), - } - result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.") - # Should return None on invalid metadata - self.assertIsNone(result) - - def test_different_prefix(self): - """Test with different model prefix (audio model)""" - # Create FP8 tensor by converting from float32 - fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - state_dict = { - "model.model.layer1.weight": fp8_weight, - "model.model.layer1.scale_weight": torch.tensor(1.0), - "model.model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16), - } - result = model_detection.detect_layer_quantization(state_dict, "model.model.") - - self.assertIsNotNone(result) - self.assertIn("layer1", result) - - -class TestPickOperationsIntegration(unittest.TestCase): - """Test pick_operations with model_config parameter""" - - def test_backward_compatibility(self): - """Test that pick_operations works without model_config (legacy)""" - from comfy import ops - - # Should work without model_config parameter - result = ops.pick_operations(None, None) - self.assertIsNotNone(result) - self.assertEqual(result, ops.disable_weight_init) - - def test_with_model_config_no_quant(self): - """Test with model_config but no quantization""" - from comfy import ops, supported_models_base - - model_config = supported_models_base.BASE({}) - model_config.layer_quant_config = None - - result = ops.pick_operations(None, None, model_config=model_config) - self.assertIsNotNone(result) - # Should use standard path - self.assertEqual(result, ops.disable_weight_init) - - def test_legacy_scaled_fp8(self): - """Test that legacy scaled_fp8 still works""" - from comfy import ops, supported_models_base - - model_config = supported_models_base.BASE({}) - model_config.scaled_fp8 = torch.float8_e4m3fn - - result = ops.pick_operations( - None, None, - scaled_fp8=torch.float8_e4m3fn, - model_config=model_config - ) - self.assertIsNotNone(result) - # Should return scaled_fp8_ops (the returned class is the inner class) - # Check that it's not the standard disable_weight_init - self.assertNotEqual(result, ops.disable_weight_init) - # Verify it has Linear class - self.assertTrue(hasattr(result, 'Linear')) - - -if __name__ == "__main__": - unittest.main() - diff --git a/tests-unit/comfy_test/test_quant_registry.py b/tests-unit/comfy_test/test_quant_registry.py deleted file mode 100644 index 5c624b1db..000000000 --- a/tests-unit/comfy_test/test_quant_registry.py +++ /dev/null @@ -1,399 +0,0 @@ -""" -Unit tests for tensor subclass quantization system. -Tests the new QuantizedTensorFP8 subclass and operation handlers. 
-""" - -import unittest -import torch -import sys -import os - -# Add comfy to path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) - -from comfy import ops -from comfy import quant_ops - - -class TestQuantizedTensorFP8(unittest.TestCase): - """Test the QuantizedTensorFP8 tensor subclass""" - - def test_creation(self): - """Test creating a QuantizedTensorFP8""" - fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(2.0) - - qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16) - - self.assertIsInstance(qt, quant_ops.QuantizedTensorFP8) - self.assertEqual(qt.shape, (256, 128)) - self.assertEqual(qt.dtype, torch.float8_e4m3fn) - self.assertEqual(qt._scale, scale) - self.assertEqual(qt._orig_dtype, torch.bfloat16) - - def test_dequantize(self): - """Test explicit dequantization""" - # Create a simple FP8 tensor - fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(3.0) - - qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.float32) - dequantized = qt.dequantize() - - # Dequantized should be approximately ones * 3.0 - self.assertEqual(dequantized.dtype, torch.float32) - self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) - - def test_repr(self): - """Test string representation""" - fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(2.5) - - qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16) - repr_str = repr(qt) - - self.assertIn("QuantizedTensorFP8", repr_str) - self.assertIn("shape", repr_str) - self.assertIn("scale", repr_str) - - -class TestOperationRegistry(unittest.TestCase): - """Test the operation registry system""" - - def test_registry_basics(self): - """Test that operations are registered""" - registered_ops = quant_ops.list_registered_ops() - - # Check that key operations are registered - self.assertIn(torch.ops.aten.linear.default, registered_ops) - self.assertIn(torch.ops.aten.silu.default, registered_ops) - self.assertIn(torch.ops.aten.layer_norm.default, registered_ops) - self.assertIn(torch.ops.aten.add.Tensor, registered_ops) - self.assertIn(torch.ops.aten.mul.Tensor, registered_ops) - - def test_get_handler(self): - """Test getting a registered handler""" - handler = quant_ops.get_quant_handler(torch.ops.aten.linear.default) - self.assertIsNotNone(handler) - self.assertTrue(callable(handler)) - - def test_custom_registration(self): - """Test registering a custom operation""" - - # Define a custom handler - @quant_ops.register_quant_op(torch.ops.aten.relu.default) - def custom_relu_handler(func, args, kwargs): - return func(*args, **kwargs) - - # Verify registration - handler = quant_ops.get_quant_handler(torch.ops.aten.relu.default) - self.assertIsNotNone(handler) - self.assertEqual(handler, custom_relu_handler) - - -class TestLinearHandler(unittest.TestCase): - """Test the linear operation handler""" - - def test_linear_with_quantized_weight(self): - """Test F.linear with quantized weight""" - # Set seed for reproducibility - torch.manual_seed(42) - - # Create quantized weight - weight_fp32 = torch.randn(256, 128, dtype=torch.float32) - scale = torch.tensor(2.0) - weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) - weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32) - - # Create input - input_tensor = torch.randn(16, 128, dtype=torch.float32) - - # Call linear (should 
trigger dispatch) - output = torch.nn.functional.linear(input_tensor, weight_q, bias=None) - - # Verify output shape - self.assertEqual(output.shape, (16, 256)) - - # Verify it's approximately correct (allowing for FP8 quantization error) - # Note: FP8 has limited precision, so use very loose tolerance - expected = torch.nn.functional.linear(input_tensor, weight_fp32, bias=None) - # Just check that it's in the right ballpark (within 50% error on average) - mean_rel_error = ((output - expected).abs() / (expected.abs() + 1e-6)).mean() - self.assertLess(mean_rel_error, 0.5, f"Mean relative error {mean_rel_error:.3f} is too large") - - def test_linear_with_bias(self): - """Test F.linear with quantized weight and bias""" - weight_fp32 = torch.randn(64, 32, dtype=torch.float32) - scale = torch.tensor(1.5) - weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) - weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32) - - input_tensor = torch.randn(8, 32, dtype=torch.float32) - bias = torch.randn(64, dtype=torch.float32) - - output = torch.nn.functional.linear(input_tensor, weight_q, bias) - - self.assertEqual(output.shape, (8, 64)) - - -class TestActivationHandlers(unittest.TestCase): - """Test activation function handlers""" - - def test_silu_with_quantized_input(self): - """Test SiLU with quantized input""" - # Create quantized input - input_fp32 = torch.randn(16, 128, dtype=torch.float32) - scale = torch.tensor(1.0) - input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn) - input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32) - - # Apply SiLU - output = torch.nn.functional.silu(input_q) - - # Should return a QuantizedTensorFP8 - self.assertIsInstance(output, quant_ops.QuantizedTensorFP8) - - # Verify approximate correctness - expected = torch.nn.functional.silu(input_fp32) - output_dq = output.dequantize() - self.assertTrue(torch.allclose(output_dq, expected, rtol=0.2, atol=0.2)) - - def test_layernorm_dequantizes(self): - """Test that LayerNorm dequantizes input""" - # Create quantized input - input_fp32 = torch.randn(16, 128, dtype=torch.float32) - scale = torch.tensor(1.0) - input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn) - input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32) - - # Apply LayerNorm - weight = torch.ones(128) - bias = torch.zeros(128) - output = torch.nn.functional.layer_norm(input_q, (128,), weight, bias) - - # Should NOT be quantized (LayerNorm breaks quantization) - self.assertNotIsInstance(output, quant_ops.QuantizedTensorFP8) - self.assertEqual(output.dtype, torch.float32) - - -class TestElementwiseHandlers(unittest.TestCase): - """Test element-wise operation handlers""" - - def test_add_mixed_tensors(self): - """Test addition with mixed quantized/non-quantized tensors""" - # Create quantized tensor - a_fp32 = torch.ones(10, 20, dtype=torch.float32) - scale = torch.tensor(1.0) - a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn) - a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32) - - # Non-quantized tensor - b = torch.ones(10, 20, dtype=torch.float32) * 2.0 - - # Add them - result = a_q + b - - # Should be dequantized - self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) - self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 3.0, rtol=0.1)) - - def test_mul_quantized_tensors(self): - """Test multiplication of two quantized tensors""" - a_fp32 = torch.ones(10, 20) * 2.0 - scale_a = torch.tensor(1.0) - a_fp8 = (a_fp32 / 
scale_a).to(torch.float8_e4m3fn) - a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale_a, orig_dtype=torch.float32) - - b_fp32 = torch.ones(10, 20) * 3.0 - scale_b = torch.tensor(1.0) - b_fp8 = (b_fp32 / scale_b).to(torch.float8_e4m3fn) - b_q = quant_ops.QuantizedTensorFP8(b_fp8, scale_b, orig_dtype=torch.float32) - - result = a_q * b_q - - # Should be dequantized - self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) - self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 6.0, rtol=0.1)) - - -class TestFallbackMechanism(unittest.TestCase): - """Test fallback for unsupported operations""" - - def test_unsupported_op_dequantizes(self): - """Test that unsupported operations fall back to dequantization""" - # Set seed for reproducibility - torch.manual_seed(42) - - # Create quantized tensor - a_fp32 = torch.randn(10, 20, dtype=torch.float32) - scale = torch.tensor(1.0) - a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn) - a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32) - - # Call an operation that doesn't have a registered handler - # For example, torch.abs - result = torch.abs(a_q) - - # Should work via fallback (dequantize → abs → return) - self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8) - expected = torch.abs(a_fp32) - # FP8 introduces quantization error, so use loose tolerance - mean_error = (result - expected).abs().mean() - self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") - - -class TestMixedPrecisionOps(unittest.TestCase): - """Test MixedPrecisionOps integration""" - - def test_linear_layer_creation(self): - """Test that MixedPrecisionOps.Linear can be created""" - layer = ops.MixedPrecisionOps.Linear(128, 256, bias=True, device="cpu", dtype=torch.float32) - - self.assertIsInstance(layer, ops.MixedPrecisionOps.Linear) - self.assertFalse(layer._quantization_initialized) - self.assertIsNone(layer.quant_format) - - def test_layer_quant_config_detection(self): - """Test that layer quantization config is detected during load""" - # Set up layer config - ops.MixedPrecisionOps._layer_quant_config = { - "test_layer": { - "format": "fp8_e4m3fn", - "params": {} - } - } - - # Create a state dict with quantized weight - weight_fp32 = torch.randn(256, 128, dtype=torch.float32) - scale = torch.tensor(2.0) - weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn) - - state_dict = { - "model.diffusion_model.test_layer.weight": weight_fp8, - "model.diffusion_model.test_layer.scale_weight": scale, - } - - # Create layer and load - layer = ops.MixedPrecisionOps.Linear(128, 256, bias=False, device="cpu", dtype=torch.float8_e4m3fn) - layer.weight = torch.nn.Parameter(torch.zeros(256, 128, dtype=torch.float8_e4m3fn)) - - # Manually call _load_from_state_dict - layer._load_from_state_dict( - state_dict, - prefix="model.diffusion_model.test_layer.", - local_metadata={}, - strict=True, - missing_keys=[], - unexpected_keys=[], - error_msgs=[] - ) - - # Verify quantization was initialized - self.assertTrue(layer._quantization_initialized) - self.assertEqual(layer.quant_format, "fp8_e4m3fn") - self.assertIsNotNone(layer.quant_scale) - - # Verify weight is wrapped - self.assertIsInstance(layer.weight.data, quant_ops.QuantizedTensorFP8) - - # Clean up - ops.MixedPrecisionOps._layer_quant_config = {} - - -class TestBackwardCompatibility(unittest.TestCase): - """Test backward compatibility with legacy systems""" - - def test_legacy_ops_classes_exist(self): - """Test that legacy ops classes still exist""" - 
self.assertTrue(hasattr(ops, 'disable_weight_init')) - self.assertTrue(hasattr(ops, 'manual_cast')) - self.assertTrue(hasattr(ops, 'fp8_ops')) - self.assertTrue(hasattr(ops, 'scaled_fp8_ops')) - - def test_pick_operations_legacy_path(self): - """Test pick_operations returns correct class for legacy cases""" - # Test standard case - result = ops.pick_operations(torch.float32, torch.float32) - self.assertEqual(result, ops.disable_weight_init) - - # Test manual cast case - result = ops.pick_operations(torch.float32, torch.float16) - self.assertEqual(result, ops.manual_cast) - - -class TestFP8LinearUnification(unittest.TestCase): - """Test that fp8_linear now uses the unified tensor subclass infrastructure""" - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA required for FP8") - def test_fp8_linear_uses_tensor_subclass(self): - """Verify fp8_linear wraps tensors in QuantizedTensorFP8""" - torch.manual_seed(42) - - # Create a mock Linear layer with FP8 weight - linear = ops.fp8_ops.Linear(4, 3, bias=True) - linear.weight = torch.nn.Parameter( - torch.randn(3, 4, dtype=torch.bfloat16).to(torch.float8_e4m3fn), - requires_grad=False - ) - linear.bias = torch.nn.Parameter( - torch.randn(3, dtype=torch.bfloat16), - requires_grad=False - ) - linear.scale_weight = torch.tensor(1.0) - linear.scale_input = None # No input scaling - - # Create input - input_tensor = torch.randn(2, 4, dtype=torch.bfloat16) - - # Call fp8_linear - should work without errors - try: - result = ops.fp8_linear(linear, input_tensor) - self.assertIsNotNone(result) - self.assertEqual(result.shape, (2, 3)) - except Exception as e: - # On CPU or unsupported hardware, _scaled_mm might not be available - # but the function should still complete without syntax errors - pass - - def test_fp8_linear_maintains_signature(self): - """Verify fp8_linear maintains its original function signature""" - import inspect - sig = inspect.signature(ops.fp8_linear) - params = list(sig.parameters.keys()) - - # Should have 'self' and 'input' parameters - self.assertIn('self', params) - self.assertIn('input', params) - self.assertEqual(len(params), 2) - - def test_fp8_linear_returns_none_for_non_fp8(self): - """Verify fp8_linear returns None for non-FP8 weights""" - # Create a Linear layer with BF16 weight (not FP8) - linear = ops.disable_weight_init.Linear(4, 3, bias=False) - linear.weight = torch.nn.Parameter( - torch.randn(3, 4, dtype=torch.bfloat16), - requires_grad=False - ) - - input_tensor = torch.randn(2, 4, dtype=torch.bfloat16) - - # Should return None for non-FP8 weights - result = ops.fp8_linear(linear, input_tensor) - self.assertIsNone(result) - - def test_fp8_ops_linear_uses_fp8_linear(self): - """Verify fp8_ops.Linear still uses fp8_linear in forward pass""" - linear = ops.fp8_ops.Linear(4, 3, bias=False) - - # Verify the class has the forward_comfy_cast_weights method - self.assertTrue(hasattr(linear, 'forward_comfy_cast_weights')) - - # The forward_comfy_cast_weights should attempt to call fp8_linear - # (we can't easily test this without mocking, but we verify structure) - import inspect - source = inspect.getsource(linear.forward_comfy_cast_weights) - self.assertIn('fp8_linear', source) - - -if __name__ == "__main__": - unittest.main()
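
A minimal usage sketch of the QuantizedTensor API that the new tests above exercise. This is illustrative only and not part of the diff; it assumes comfy is importable from the repository root, as in the test files.

    import torch
    from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout

    # Quantize a float tensor to FP8 with a per-tensor scale
    # (same call pattern as test_from_float above).
    w = torch.randn(64, 32, dtype=torch.float32)
    qt = QuantizedTensor.from_float(
        w,
        TensorCoreFP8Layout,
        scale=torch.tensor(1.5),
        fp8_dtype=torch.float8_e4m3fn,
    )

    assert isinstance(qt, QuantizedTensor)
    assert qt.dtype == torch.float8_e4m3fn
    assert qt._layout_type == TensorCoreFP8Layout
    assert 'scale' in qt._layout_params

    # Explicit dequantization recovers an approximation of the original values.
    w_approx = qt.dequantize()

    # Ops without a registered handler (e.g. torch.abs) fall back to
    # dequantize-then-compute and return a plain tensor.
    result = torch.abs(qt)
    assert not isinstance(result, QuantizedTensor)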