ComfyUI/tests-unit/comfy_test/test_quant_registry.py
2025-10-27 08:57:54 +01:00

400 lines
16 KiB
Python

"""
Unit tests for tensor subclass quantization system.
Tests the new QuantizedTensorFP8 subclass and operation handlers.
"""
import unittest
import torch
import sys
import os
# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from comfy import ops
from comfy import quant_ops
class TestQuantizedTensorFP8(unittest.TestCase):
"""Test the QuantizedTensorFP8 tensor subclass"""
def test_creation(self):
"""Test creating a QuantizedTensorFP8"""
fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(2.0)
qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16)
self.assertIsInstance(qt, quant_ops.QuantizedTensorFP8)
self.assertEqual(qt.shape, (256, 128))
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
self.assertEqual(qt._scale, scale)
self.assertEqual(qt._orig_dtype, torch.bfloat16)
def test_dequantize(self):
"""Test explicit dequantization"""
# Create a simple FP8 tensor
fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(3.0)
qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.float32)
dequantized = qt.dequantize()
# Dequantized should be approximately ones * 3.0
self.assertEqual(dequantized.dtype, torch.float32)
self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))
def test_repr(self):
"""Test string representation"""
fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(2.5)
qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16)
repr_str = repr(qt)
self.assertIn("QuantizedTensorFP8", repr_str)
self.assertIn("shape", repr_str)
self.assertIn("scale", repr_str)
class TestOperationRegistry(unittest.TestCase):
"""Test the operation registry system"""
def test_registry_basics(self):
"""Test that operations are registered"""
registered_ops = quant_ops.list_registered_ops()
# Check that key operations are registered
self.assertIn(torch.ops.aten.linear.default, registered_ops)
self.assertIn(torch.ops.aten.silu.default, registered_ops)
self.assertIn(torch.ops.aten.layer_norm.default, registered_ops)
self.assertIn(torch.ops.aten.add.Tensor, registered_ops)
self.assertIn(torch.ops.aten.mul.Tensor, registered_ops)
def test_get_handler(self):
"""Test getting a registered handler"""
handler = quant_ops.get_quant_handler(torch.ops.aten.linear.default)
self.assertIsNotNone(handler)
self.assertTrue(callable(handler))
def test_custom_registration(self):
"""Test registering a custom operation"""
# Define a custom handler
@quant_ops.register_quant_op(torch.ops.aten.relu.default)
def custom_relu_handler(func, args, kwargs):
return func(*args, **kwargs)
# Verify registration
handler = quant_ops.get_quant_handler(torch.ops.aten.relu.default)
self.assertIsNotNone(handler)
self.assertEqual(handler, custom_relu_handler)
class TestLinearHandler(unittest.TestCase):
"""Test the linear operation handler"""
def test_linear_with_quantized_weight(self):
"""Test F.linear with quantized weight"""
# Set seed for reproducibility
torch.manual_seed(42)
# Create quantized weight
weight_fp32 = torch.randn(256, 128, dtype=torch.float32)
scale = torch.tensor(2.0)
weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn)
weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32)
# Create input
input_tensor = torch.randn(16, 128, dtype=torch.float32)
# Call linear (should trigger dispatch)
output = torch.nn.functional.linear(input_tensor, weight_q, bias=None)
# Verify output shape
self.assertEqual(output.shape, (16, 256))
# Verify it's approximately correct (allowing for FP8 quantization error)
# Note: FP8 has limited precision, so use very loose tolerance
expected = torch.nn.functional.linear(input_tensor, weight_fp32, bias=None)
# Just check that it's in the right ballpark (within 50% error on average)
mean_rel_error = ((output - expected).abs() / (expected.abs() + 1e-6)).mean()
self.assertLess(mean_rel_error, 0.5, f"Mean relative error {mean_rel_error:.3f} is too large")
def test_linear_with_bias(self):
"""Test F.linear with quantized weight and bias"""
weight_fp32 = torch.randn(64, 32, dtype=torch.float32)
scale = torch.tensor(1.5)
weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn)
weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32)
input_tensor = torch.randn(8, 32, dtype=torch.float32)
bias = torch.randn(64, dtype=torch.float32)
output = torch.nn.functional.linear(input_tensor, weight_q, bias)
self.assertEqual(output.shape, (8, 64))
class TestActivationHandlers(unittest.TestCase):
"""Test activation function handlers"""
def test_silu_with_quantized_input(self):
"""Test SiLU with quantized input"""
# Create quantized input
input_fp32 = torch.randn(16, 128, dtype=torch.float32)
scale = torch.tensor(1.0)
input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn)
input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32)
# Apply SiLU
output = torch.nn.functional.silu(input_q)
# Should return a QuantizedTensorFP8
self.assertIsInstance(output, quant_ops.QuantizedTensorFP8)
# Verify approximate correctness
expected = torch.nn.functional.silu(input_fp32)
output_dq = output.dequantize()
self.assertTrue(torch.allclose(output_dq, expected, rtol=0.2, atol=0.2))
def test_layernorm_dequantizes(self):
"""Test that LayerNorm dequantizes input"""
# Create quantized input
input_fp32 = torch.randn(16, 128, dtype=torch.float32)
scale = torch.tensor(1.0)
input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn)
input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32)
# Apply LayerNorm
weight = torch.ones(128)
bias = torch.zeros(128)
output = torch.nn.functional.layer_norm(input_q, (128,), weight, bias)
# Should NOT be quantized (LayerNorm breaks quantization)
self.assertNotIsInstance(output, quant_ops.QuantizedTensorFP8)
self.assertEqual(output.dtype, torch.float32)
class TestElementwiseHandlers(unittest.TestCase):
"""Test element-wise operation handlers"""
def test_add_mixed_tensors(self):
"""Test addition with mixed quantized/non-quantized tensors"""
# Create quantized tensor
a_fp32 = torch.ones(10, 20, dtype=torch.float32)
scale = torch.tensor(1.0)
a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn)
a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32)
# Non-quantized tensor
b = torch.ones(10, 20, dtype=torch.float32) * 2.0
# Add them
result = a_q + b
# Should be dequantized
self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8)
self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 3.0, rtol=0.1))
def test_mul_quantized_tensors(self):
"""Test multiplication of two quantized tensors"""
a_fp32 = torch.ones(10, 20) * 2.0
scale_a = torch.tensor(1.0)
a_fp8 = (a_fp32 / scale_a).to(torch.float8_e4m3fn)
a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale_a, orig_dtype=torch.float32)
b_fp32 = torch.ones(10, 20) * 3.0
scale_b = torch.tensor(1.0)
b_fp8 = (b_fp32 / scale_b).to(torch.float8_e4m3fn)
b_q = quant_ops.QuantizedTensorFP8(b_fp8, scale_b, orig_dtype=torch.float32)
result = a_q * b_q
# Should be dequantized
self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8)
self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 6.0, rtol=0.1))
class TestFallbackMechanism(unittest.TestCase):
"""Test fallback for unsupported operations"""
def test_unsupported_op_dequantizes(self):
"""Test that unsupported operations fall back to dequantization"""
# Set seed for reproducibility
torch.manual_seed(42)
# Create quantized tensor
a_fp32 = torch.randn(10, 20, dtype=torch.float32)
scale = torch.tensor(1.0)
a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn)
a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32)
# Call an operation that doesn't have a registered handler
# For example, torch.abs
result = torch.abs(a_q)
# Should work via fallback (dequantize → abs → return)
self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8)
expected = torch.abs(a_fp32)
# FP8 introduces quantization error, so use loose tolerance
mean_error = (result - expected).abs().mean()
self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")
class TestMixedPrecisionOps(unittest.TestCase):
"""Test MixedPrecisionOps integration"""
def test_linear_layer_creation(self):
"""Test that MixedPrecisionOps.Linear can be created"""
layer = ops.MixedPrecisionOps.Linear(128, 256, bias=True, device="cpu", dtype=torch.float32)
self.assertIsInstance(layer, ops.MixedPrecisionOps.Linear)
self.assertFalse(layer._quantization_initialized)
self.assertIsNone(layer.quant_format)
def test_layer_quant_config_detection(self):
"""Test that layer quantization config is detected during load"""
# Set up layer config
ops.MixedPrecisionOps._layer_quant_config = {
"test_layer": {
"format": "fp8_e4m3fn",
"params": {}
}
}
# Create a state dict with quantized weight
weight_fp32 = torch.randn(256, 128, dtype=torch.float32)
scale = torch.tensor(2.0)
weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn)
state_dict = {
"model.diffusion_model.test_layer.weight": weight_fp8,
"model.diffusion_model.test_layer.scale_weight": scale,
}
# Create layer and load
layer = ops.MixedPrecisionOps.Linear(128, 256, bias=False, device="cpu", dtype=torch.float8_e4m3fn)
layer.weight = torch.nn.Parameter(torch.zeros(256, 128, dtype=torch.float8_e4m3fn))
# Manually call _load_from_state_dict
layer._load_from_state_dict(
state_dict,
prefix="model.diffusion_model.test_layer.",
local_metadata={},
strict=True,
missing_keys=[],
unexpected_keys=[],
error_msgs=[]
)
# Verify quantization was initialized
self.assertTrue(layer._quantization_initialized)
self.assertEqual(layer.quant_format, "fp8_e4m3fn")
self.assertIsNotNone(layer.quant_scale)
# Verify weight is wrapped
self.assertIsInstance(layer.weight.data, quant_ops.QuantizedTensorFP8)
# Clean up
ops.MixedPrecisionOps._layer_quant_config = {}
class TestBackwardCompatibility(unittest.TestCase):
"""Test backward compatibility with legacy systems"""
def test_legacy_ops_classes_exist(self):
"""Test that legacy ops classes still exist"""
self.assertTrue(hasattr(ops, 'disable_weight_init'))
self.assertTrue(hasattr(ops, 'manual_cast'))
self.assertTrue(hasattr(ops, 'fp8_ops'))
self.assertTrue(hasattr(ops, 'scaled_fp8_ops'))
def test_pick_operations_legacy_path(self):
"""Test pick_operations returns correct class for legacy cases"""
# Test standard case
result = ops.pick_operations(torch.float32, torch.float32)
self.assertEqual(result, ops.disable_weight_init)
# Test manual cast case
result = ops.pick_operations(torch.float32, torch.float16)
self.assertEqual(result, ops.manual_cast)
class TestFP8LinearUnification(unittest.TestCase):
"""Test that fp8_linear now uses the unified tensor subclass infrastructure"""
@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for FP8")
def test_fp8_linear_uses_tensor_subclass(self):
"""Verify fp8_linear wraps tensors in QuantizedTensorFP8"""
torch.manual_seed(42)
# Create a mock Linear layer with FP8 weight
linear = ops.fp8_ops.Linear(4, 3, bias=True)
linear.weight = torch.nn.Parameter(
torch.randn(3, 4, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
requires_grad=False
)
linear.bias = torch.nn.Parameter(
torch.randn(3, dtype=torch.bfloat16),
requires_grad=False
)
linear.scale_weight = torch.tensor(1.0)
linear.scale_input = None # No input scaling
# Create input
input_tensor = torch.randn(2, 4, dtype=torch.bfloat16)
# Call fp8_linear - should work without errors
try:
result = ops.fp8_linear(linear, input_tensor)
self.assertIsNotNone(result)
self.assertEqual(result.shape, (2, 3))
except Exception as e:
# On CPU or unsupported hardware, _scaled_mm might not be available
# but the function should still complete without syntax errors
pass
def test_fp8_linear_maintains_signature(self):
"""Verify fp8_linear maintains its original function signature"""
import inspect
sig = inspect.signature(ops.fp8_linear)
params = list(sig.parameters.keys())
# Should have 'self' and 'input' parameters
self.assertIn('self', params)
self.assertIn('input', params)
self.assertEqual(len(params), 2)
def test_fp8_linear_returns_none_for_non_fp8(self):
"""Verify fp8_linear returns None for non-FP8 weights"""
# Create a Linear layer with BF16 weight (not FP8)
linear = ops.disable_weight_init.Linear(4, 3, bias=False)
linear.weight = torch.nn.Parameter(
torch.randn(3, 4, dtype=torch.bfloat16),
requires_grad=False
)
input_tensor = torch.randn(2, 4, dtype=torch.bfloat16)
# Should return None for non-FP8 weights
result = ops.fp8_linear(linear, input_tensor)
self.assertIsNone(result)
def test_fp8_ops_linear_uses_fp8_linear(self):
"""Verify fp8_ops.Linear still uses fp8_linear in forward pass"""
linear = ops.fp8_ops.Linear(4, 3, bias=False)
# Verify the class has the forward_comfy_cast_weights method
self.assertTrue(hasattr(linear, 'forward_comfy_cast_weights'))
# The forward_comfy_cast_weights should attempt to call fp8_linear
# (we can't easily test this without mocking, but we verify structure)
import inspect
source = inspect.getsource(linear.forward_comfy_cast_weights)
self.assertIn('fp8_linear', source)
if __name__ == "__main__":
unittest.main()