Fix tests.

This commit is contained in:
comfyanonymous 2025-11-25 00:57:49 -05:00
parent cb2b4f282e
commit 0d6791be86

View File

@ -37,11 +37,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
def test_all_layers_standard(self): def test_all_layers_standard(self):
"""Test that model with no quantization works normally""" """Test that model with no quantization works normally"""
# Configure no quantization
ops.MixedPrecisionOps._layer_quant_config = {}
# Create model # Create model
model = SimpleModel(operations=ops.MixedPrecisionOps) model = SimpleModel(operations=ops.mixed_precision_ops({}))
# Initialize weights manually # Initialize weights manually
model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16)) model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
@ -76,7 +73,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
"params": {} "params": {}
} }
} }
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
# Create state dict with mixed precision # Create state dict with mixed precision
fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
@ -99,7 +95,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
} }
# Create model and load state dict (strict=False because custom loading pops keys) # Create model and load state dict (strict=False because custom loading pops keys)
model = SimpleModel(operations=ops.MixedPrecisionOps) model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
# Verify weights are wrapped in QuantizedTensor # Verify weights are wrapped in QuantizedTensor
@ -132,7 +128,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
"params": {} "params": {}
} }
} }
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
# Create and load model # Create and load model
fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
@ -146,7 +141,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
"layer3.bias": torch.randn(40, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
} }
model = SimpleModel(operations=ops.MixedPrecisionOps) model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
model.load_state_dict(state_dict1, strict=False) model.load_state_dict(state_dict1, strict=False)
# Save state dict # Save state dict
@ -170,7 +165,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
"params": {} "params": {}
} }
} }
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
# Create and load model # Create and load model
fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
@ -184,7 +178,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
"layer3.bias": torch.randn(40, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
} }
model = SimpleModel(operations=ops.MixedPrecisionOps) model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
# Add a weight function (simulating LoRA) # Add a weight function (simulating LoRA)
@ -210,7 +204,6 @@ class TestMixedPrecisionOps(unittest.TestCase):
"params": {} "params": {}
} }
} }
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
# Create state dict # Create state dict
state_dict = { state_dict = {
@ -223,7 +216,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
} }
# Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
model = SimpleModel(operations=ops.MixedPrecisionOps) model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)