# ComfyUI/tests-unit/comfy_test/test_mixed_precision.py
"""
End-to-end tests for mixed precision quantization.
Tests Phase 3: Mixed Precision Operations
"""
import unittest
import torch
import sys
import os
# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from comfy import ops
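
# Note: these tests are written to run on CPU only. Every quantized layer config
# below sets "use_fp8_matmul": False, and the fp8 weights are created by casting
# float32 tensors, so no GPU or hardware fp8 matmul support is assumed.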


class SimpleModel(torch.nn.Module):
    """Simple model for testing mixed precision"""

    def __init__(self, operations=ops.disable_weight_init):
        super().__init__()
        self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16)
        self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16)
        self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16)

    def forward(self, x):
        x = self.layer1(x)
        x = torch.nn.functional.relu(x)
        x = self.layer2(x)
        x = torch.nn.functional.relu(x)
        x = self.layer3(x)
        return x


class TestMixedPrecisionOps(unittest.TestCase):
    """Test MixedPrecisionOps end-to-end"""

    def test_all_layers_standard(self):
        """Test that model with no quantization works normally"""
        # Configure no quantization
        ops.MixedPrecisionOps._layer_quant_config = {}

        # Create model
        model = SimpleModel(operations=ops.MixedPrecisionOps)

        # Initialize weights manually
        model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
        model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16))
        model.layer2.weight = torch.nn.Parameter(torch.randn(30, 20, dtype=torch.bfloat16))
        model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16))
        model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16))
        model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16))

        # Initialize weight_function and bias_function
        for layer in [model.layer1, model.layer2, model.layer3]:
            layer.weight_function = []
            layer.bias_function = []

        # Forward pass
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)

        self.assertEqual(output.shape, (5, 40))
        self.assertEqual(output.dtype, torch.bfloat16)

    def test_mixed_precision_load(self):
        """Test loading a mixed precision model from state dict"""
        # Configure mixed precision: layer1 and layer3 are FP8, layer2 is standard
        layer_quant_config = {
            "layer1": {
                "format": "fp8_e4m3fn_scaled",
                "params": {"use_fp8_matmul": False}  # Disable for CPU testing
            },
            "layer3": {
                "format": "fp8_e5m2_scaled",
                "params": {"use_fp8_matmul": False}
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create state dict with mixed precision
        fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e5m2)
        state_dict = {
            # Layer 1: FP8 E4M3FN
            "layer1.weight": fp8_weight1,
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32),
            # Layer 2: Standard BF16
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
            # Layer 3: FP8 E5M2
            "layer3.weight": fp8_weight3,
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
            "layer3.scale_weight": torch.tensor(1.5, dtype=torch.float32),
        }

        # Create model and load state dict
        model = SimpleModel(operations=ops.MixedPrecisionOps)
        model.load_state_dict(state_dict)

        # Verify handlers are set up correctly
        self.assertIsNotNone(model.layer1.quant_handler)
        self.assertIsNone(model.layer2.quant_handler)  # No quantization
        self.assertIsNotNone(model.layer3.quant_handler)

        # Verify scales were loaded
        self.assertEqual(model.layer1.scale_weight.item(), 2.0)
        self.assertEqual(model.layer3.scale_weight.item(), 1.5)

        # Forward pass
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)
        self.assertEqual(output.shape, (5, 40))
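
    # Note (assumption about the "*_scaled" formats, not asserted by these tests):
    # scale_weight is expected to act as a per-tensor dequantization factor, i.e.
    # the effective weight is roughly weight.to(torch.bfloat16) * scale_weight,
    # which is why it is stored alongside the fp8 tensor in the state dict above.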

    def test_state_dict_round_trip(self):
        """Test saving and loading state dict preserves quantization"""
        # Configure mixed precision
        layer_quant_config = {
            "layer1": {
                "format": "fp8_e4m3fn_scaled",
                "params": {"use_fp8_matmul": False}
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create and load model
        fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict1 = {
            "layer1.weight": fp8_weight,
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer1.scale_weight": torch.tensor(3.0, dtype=torch.float32),
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }
        model1 = SimpleModel(operations=ops.MixedPrecisionOps)
        model1.load_state_dict(state_dict1)

        # Save state dict
        state_dict2 = model1.state_dict()

        # Verify scale_weight is saved
        self.assertIn("layer1.scale_weight", state_dict2)
        self.assertEqual(state_dict2["layer1.scale_weight"].item(), 3.0)

        # Load into new model
        model2 = SimpleModel(operations=ops.MixedPrecisionOps)
        model2.load_state_dict(state_dict2)

        # Verify handler is set up
        self.assertIsNotNone(model2.layer1.quant_handler)
        self.assertEqual(model2.layer1.scale_weight.item(), 3.0)

        # Verify forward passes match
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output1 = model1(input_tensor)
        output2 = model2(input_tensor)
        torch.testing.assert_close(output1, output2, rtol=1e-3, atol=1e-3)

    def test_weight_function_compatibility(self):
        """Test that weight_function (LoRA) works with quantized layers"""
        # Configure FP8 quantization
        layer_quant_config = {
            "layer1": {
                "format": "fp8_e4m3fn_scaled",
                "params": {"use_fp8_matmul": False}
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create and load model
        fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict = {
            "layer1.weight": fp8_weight,
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32),
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }
        model = SimpleModel(operations=ops.MixedPrecisionOps)
        model.load_state_dict(state_dict)

        # Add a weight function (simulating LoRA)
        # LoRA delta must match weight shape (20, 10)
        def apply_lora(weight):
            # Generate LoRA delta matching weight shape
            lora_delta = torch.randn_like(weight) * 0.01
            return weight + lora_delta

        model.layer1.weight_function.append(apply_lora)

        # Forward pass should work with LoRA
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)
        self.assertEqual(output.shape, (5, 40))

    def test_error_handling_unknown_format(self):
        """Test that unknown formats fall back gracefully"""
        # Configure with unknown format
        layer_quant_config = {
            "layer1": {
                "format": "unknown_format_xyz",
                "params": {}
            }
        }
        ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

        # Create state dict
        state_dict = {
            "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16),
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

        # Load should not crash, just log warning
        model = SimpleModel(operations=ops.MixedPrecisionOps)
        model.load_state_dict(state_dict)

        # Handler should be None (fallback to standard)
        self.assertIsNone(model.layer1.quant_handler)

        # Forward pass should still work
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)
        self.assertEqual(output.shape, (5, 40))


class TestPickOperationsWithMixedPrecision(unittest.TestCase):
    """Test pick_operations with mixed precision config"""

    def test_pick_operations_with_layer_quant_config(self):
        """Test that pick_operations returns MixedPrecisionOps when config present"""
        from comfy import supported_models_base

        # Create model config with layer_quant_config
        model_config = supported_models_base.BASE({})
        model_config.layer_quant_config = {
            "layer1": {"format": "fp8_e4m3fn_scaled", "params": {}}
        }

        result = ops.pick_operations(None, None, model_config=model_config)

        self.assertEqual(result, ops.MixedPrecisionOps)
        self.assertEqual(ops.MixedPrecisionOps._layer_quant_config, model_config.layer_quant_config)

    def test_pick_operations_without_layer_quant_config(self):
        """Test that pick_operations falls back to standard when no config"""
        from comfy import supported_models_base

        model_config = supported_models_base.BASE({})
        model_config.layer_quant_config = None

        result = ops.pick_operations(None, None, model_config=model_config)

        self.assertEqual(result, ops.disable_weight_init)
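

# Usage note (an addition, not part of the original file): run these tests from
# the ComfyUI repository root, either directly with
#   python tests-unit/comfy_test/test_mixed_precision.py
# or by pointing pytest at the same path. The float8 dtypes used above assume a
# PyTorch build that exposes torch.float8_e4m3fn and torch.float8_e5m2.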


if __name__ == "__main__":
    unittest.main()