From 0d6791be86b02ac494156ced5ab37cadea37251f Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 25 Nov 2025 00:57:49 -0500
Subject: [PATCH] Fix tests.

---
 tests-unit/comfy_quant/test_mixed_precision.py | 17 +++++------------
 1 file changed, 5 insertions(+), 12 deletions(-)

diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py
index f8d1fd04e..63361309f 100644
--- a/tests-unit/comfy_quant/test_mixed_precision.py
+++ b/tests-unit/comfy_quant/test_mixed_precision.py
@@ -37,11 +37,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
 
     def test_all_layers_standard(self):
         """Test that model with no quantization works normally"""
-        # Configure no quantization
-        ops.MixedPrecisionOps._layer_quant_config = {}
-
         # Create model
-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops({}))
 
         # Initialize weights manually
         model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
@@ -76,7 +73,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
                 "params": {}
             }
         }
-        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
         # Create state dict with mixed precision
         fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
         fp8_weight2 = torch.randn(30, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
@@ -99,7 +95,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
         }
 
         # Create model and load state dict (strict=False because custom loading pops keys)
-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
         model.load_state_dict(state_dict, strict=False)
 
         # Verify weights are wrapped in QuantizedTensor
@@ -132,7 +128,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
                 "params": {}
             }
         }
-        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
 
         # Create and load model
         fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
@@ -146,7 +141,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }
 
-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
         model.load_state_dict(state_dict1, strict=False)
 
         # Save state dict
@@ -170,7 +165,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
                 "params": {}
             }
         }
-        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
 
         # Create and load model
         fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
@@ -184,7 +178,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }
 
-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
         model.load_state_dict(state_dict, strict=False)
 
         # Add a weight function (simulating LoRA)
@@ -210,7 +204,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
                 "params": {}
            }
         }
-        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
 
         # Create state dict
         state_dict = {
@@ -223,7 +216,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
         }
 
         # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
-        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
         with self.assertRaises(KeyError):
             model.load_state_dict(state_dict, strict=False)
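
For context: every hunk above makes the same change, replacing configuration
through the shared class attribute ops.MixedPrecisionOps._layer_quant_config
with a call to the ops.mixed_precision_ops(...) factory, which presumably
returns an operations class with the config bound to it. A minimal sketch of
the two patterns follows; `ops` and `SimpleModel` are the names already used
in the test module, and the config contents are illustrative, not taken from
this patch:

    # Illustrative per-layer quantization config, shaped like the ones in the tests.
    layer_quant_config = {
        "layer2": {"format": "float8_e4m3fn", "params": {}},
    }

    # Old pattern (removed by this patch): mutate shared class state, then
    # pass the class itself. State set here leaked into later tests.
    # ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
    # model = SimpleModel(operations=ops.MixedPrecisionOps)

    # New pattern: hand the config to the factory and pass its result, so
    # each test builds a self-contained operations class.
    model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))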