mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-16 00:12:33 +08:00
400 lines
16 KiB
Python
400 lines
16 KiB
Python
"""
|
|
Unit tests for tensor subclass quantization system.
|
|
Tests the new QuantizedTensorFP8 subclass and operation handlers.
|
|
"""
|
|
|
|
import unittest
|
|
import torch
|
|
import sys
|
|
import os
|
|
|
|
# Add comfy to path
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
|
|
|
from comfy import ops
|
|
from comfy import quant_ops
|
|
|
|
|
|
class TestQuantizedTensorFP8(unittest.TestCase):
|
|
"""Test the QuantizedTensorFP8 tensor subclass"""
|
|
|
|
def test_creation(self):
|
|
"""Test creating a QuantizedTensorFP8"""
|
|
fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
|
|
scale = torch.tensor(2.0)
|
|
|
|
qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16)
|
|
|
|
self.assertIsInstance(qt, quant_ops.QuantizedTensorFP8)
|
|
self.assertEqual(qt.shape, (256, 128))
|
|
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
|
|
self.assertEqual(qt._scale, scale)
|
|
self.assertEqual(qt._orig_dtype, torch.bfloat16)
|
|
|
|
def test_dequantize(self):
|
|
"""Test explicit dequantization"""
|
|
# Create a simple FP8 tensor
|
|
fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
|
scale = torch.tensor(3.0)
|
|
|
|
qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.float32)
|
|
dequantized = qt.dequantize()
|
|
|
|
# Dequantized should be approximately ones * 3.0
|
|
self.assertEqual(dequantized.dtype, torch.float32)
|
|
self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))
|
|
|
|
def test_repr(self):
|
|
"""Test string representation"""
|
|
fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
|
|
scale = torch.tensor(2.5)
|
|
|
|
qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16)
|
|
repr_str = repr(qt)
|
|
|
|
self.assertIn("QuantizedTensorFP8", repr_str)
|
|
self.assertIn("shape", repr_str)
|
|
self.assertIn("scale", repr_str)
|
|
|
|
|
|
class TestOperationRegistry(unittest.TestCase):
|
|
"""Test the operation registry system"""
|
|
|
|
def test_registry_basics(self):
|
|
"""Test that operations are registered"""
|
|
registered_ops = quant_ops.list_registered_ops()
|
|
|
|
# Check that key operations are registered
|
|
self.assertIn(torch.ops.aten.linear.default, registered_ops)
|
|
self.assertIn(torch.ops.aten.silu.default, registered_ops)
|
|
self.assertIn(torch.ops.aten.layer_norm.default, registered_ops)
|
|
self.assertIn(torch.ops.aten.add.Tensor, registered_ops)
|
|
self.assertIn(torch.ops.aten.mul.Tensor, registered_ops)
|
|
|
|
def test_get_handler(self):
|
|
"""Test getting a registered handler"""
|
|
handler = quant_ops.get_quant_handler(torch.ops.aten.linear.default)
|
|
self.assertIsNotNone(handler)
|
|
self.assertTrue(callable(handler))
|
|
|
|
def test_custom_registration(self):
|
|
"""Test registering a custom operation"""
|
|
|
|
# Define a custom handler
|
|
@quant_ops.register_quant_op(torch.ops.aten.relu.default)
|
|
def custom_relu_handler(func, args, kwargs):
|
|
return func(*args, **kwargs)
|
|
|
|
# Verify registration
|
|
handler = quant_ops.get_quant_handler(torch.ops.aten.relu.default)
|
|
self.assertIsNotNone(handler)
|
|
self.assertEqual(handler, custom_relu_handler)
|
|
|
|
|
|
class TestLinearHandler(unittest.TestCase):
|
|
"""Test the linear operation handler"""
|
|
|
|
def test_linear_with_quantized_weight(self):
|
|
"""Test F.linear with quantized weight"""
|
|
# Set seed for reproducibility
|
|
torch.manual_seed(42)
|
|
|
|
# Create quantized weight
|
|
weight_fp32 = torch.randn(256, 128, dtype=torch.float32)
|
|
scale = torch.tensor(2.0)
|
|
weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn)
|
|
weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32)
|
|
|
|
# Create input
|
|
input_tensor = torch.randn(16, 128, dtype=torch.float32)
|
|
|
|
# Call linear (should trigger dispatch)
|
|
output = torch.nn.functional.linear(input_tensor, weight_q, bias=None)
|
|
|
|
# Verify output shape
|
|
self.assertEqual(output.shape, (16, 256))
|
|
|
|
# Verify it's approximately correct (allowing for FP8 quantization error)
|
|
# Note: FP8 has limited precision, so use very loose tolerance
|
|
expected = torch.nn.functional.linear(input_tensor, weight_fp32, bias=None)
|
|
# Just check that it's in the right ballpark (within 50% error on average)
|
|
mean_rel_error = ((output - expected).abs() / (expected.abs() + 1e-6)).mean()
|
|
self.assertLess(mean_rel_error, 0.5, f"Mean relative error {mean_rel_error:.3f} is too large")
|
|
|
|
def test_linear_with_bias(self):
|
|
"""Test F.linear with quantized weight and bias"""
|
|
weight_fp32 = torch.randn(64, 32, dtype=torch.float32)
|
|
scale = torch.tensor(1.5)
|
|
weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn)
|
|
weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32)
|
|
|
|
input_tensor = torch.randn(8, 32, dtype=torch.float32)
|
|
bias = torch.randn(64, dtype=torch.float32)
|
|
|
|
output = torch.nn.functional.linear(input_tensor, weight_q, bias)
|
|
|
|
self.assertEqual(output.shape, (8, 64))
|
|
|
|
|
|
class TestActivationHandlers(unittest.TestCase):
|
|
"""Test activation function handlers"""
|
|
|
|
def test_silu_with_quantized_input(self):
|
|
"""Test SiLU with quantized input"""
|
|
# Create quantized input
|
|
input_fp32 = torch.randn(16, 128, dtype=torch.float32)
|
|
scale = torch.tensor(1.0)
|
|
input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn)
|
|
input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32)
|
|
|
|
# Apply SiLU
|
|
output = torch.nn.functional.silu(input_q)
|
|
|
|
# Should return a QuantizedTensorFP8
|
|
self.assertIsInstance(output, quant_ops.QuantizedTensorFP8)
|
|
|
|
# Verify approximate correctness
|
|
expected = torch.nn.functional.silu(input_fp32)
|
|
output_dq = output.dequantize()
|
|
self.assertTrue(torch.allclose(output_dq, expected, rtol=0.2, atol=0.2))
|
|
|
|
def test_layernorm_dequantizes(self):
|
|
"""Test that LayerNorm dequantizes input"""
|
|
# Create quantized input
|
|
input_fp32 = torch.randn(16, 128, dtype=torch.float32)
|
|
scale = torch.tensor(1.0)
|
|
input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn)
|
|
input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32)
|
|
|
|
# Apply LayerNorm
|
|
weight = torch.ones(128)
|
|
bias = torch.zeros(128)
|
|
output = torch.nn.functional.layer_norm(input_q, (128,), weight, bias)
|
|
|
|
# Should NOT be quantized (LayerNorm breaks quantization)
|
|
self.assertNotIsInstance(output, quant_ops.QuantizedTensorFP8)
|
|
self.assertEqual(output.dtype, torch.float32)
|
|
|
|
|
|
class TestElementwiseHandlers(unittest.TestCase):
|
|
"""Test element-wise operation handlers"""
|
|
|
|
def test_add_mixed_tensors(self):
|
|
"""Test addition with mixed quantized/non-quantized tensors"""
|
|
# Create quantized tensor
|
|
a_fp32 = torch.ones(10, 20, dtype=torch.float32)
|
|
scale = torch.tensor(1.0)
|
|
a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn)
|
|
a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32)
|
|
|
|
# Non-quantized tensor
|
|
b = torch.ones(10, 20, dtype=torch.float32) * 2.0
|
|
|
|
# Add them
|
|
result = a_q + b
|
|
|
|
# Should be dequantized
|
|
self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8)
|
|
self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 3.0, rtol=0.1))
|
|
|
|
def test_mul_quantized_tensors(self):
|
|
"""Test multiplication of two quantized tensors"""
|
|
a_fp32 = torch.ones(10, 20) * 2.0
|
|
scale_a = torch.tensor(1.0)
|
|
a_fp8 = (a_fp32 / scale_a).to(torch.float8_e4m3fn)
|
|
a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale_a, orig_dtype=torch.float32)
|
|
|
|
b_fp32 = torch.ones(10, 20) * 3.0
|
|
scale_b = torch.tensor(1.0)
|
|
b_fp8 = (b_fp32 / scale_b).to(torch.float8_e4m3fn)
|
|
b_q = quant_ops.QuantizedTensorFP8(b_fp8, scale_b, orig_dtype=torch.float32)
|
|
|
|
result = a_q * b_q
|
|
|
|
# Should be dequantized
|
|
self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8)
|
|
self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 6.0, rtol=0.1))
|
|
|
|
|
|
class TestFallbackMechanism(unittest.TestCase):
|
|
"""Test fallback for unsupported operations"""
|
|
|
|
def test_unsupported_op_dequantizes(self):
|
|
"""Test that unsupported operations fall back to dequantization"""
|
|
# Set seed for reproducibility
|
|
torch.manual_seed(42)
|
|
|
|
# Create quantized tensor
|
|
a_fp32 = torch.randn(10, 20, dtype=torch.float32)
|
|
scale = torch.tensor(1.0)
|
|
a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn)
|
|
a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32)
|
|
|
|
# Call an operation that doesn't have a registered handler
|
|
# For example, torch.abs
|
|
result = torch.abs(a_q)
|
|
|
|
# Should work via fallback (dequantize → abs → return)
|
|
self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8)
|
|
expected = torch.abs(a_fp32)
|
|
# FP8 introduces quantization error, so use loose tolerance
|
|
mean_error = (result - expected).abs().mean()
|
|
self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")
|
|
|
|
|
|
class TestMixedPrecisionOps(unittest.TestCase):
|
|
"""Test MixedPrecisionOps integration"""
|
|
|
|
def test_linear_layer_creation(self):
|
|
"""Test that MixedPrecisionOps.Linear can be created"""
|
|
layer = ops.MixedPrecisionOps.Linear(128, 256, bias=True, device="cpu", dtype=torch.float32)
|
|
|
|
self.assertIsInstance(layer, ops.MixedPrecisionOps.Linear)
|
|
self.assertFalse(layer._quantization_initialized)
|
|
self.assertIsNone(layer.quant_format)
|
|
|
|
def test_layer_quant_config_detection(self):
|
|
"""Test that layer quantization config is detected during load"""
|
|
# Set up layer config
|
|
ops.MixedPrecisionOps._layer_quant_config = {
|
|
"test_layer": {
|
|
"format": "fp8_e4m3fn",
|
|
"params": {}
|
|
}
|
|
}
|
|
|
|
# Create a state dict with quantized weight
|
|
weight_fp32 = torch.randn(256, 128, dtype=torch.float32)
|
|
scale = torch.tensor(2.0)
|
|
weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn)
|
|
|
|
state_dict = {
|
|
"model.diffusion_model.test_layer.weight": weight_fp8,
|
|
"model.diffusion_model.test_layer.scale_weight": scale,
|
|
}
|
|
|
|
# Create layer and load
|
|
layer = ops.MixedPrecisionOps.Linear(128, 256, bias=False, device="cpu", dtype=torch.float8_e4m3fn)
|
|
layer.weight = torch.nn.Parameter(torch.zeros(256, 128, dtype=torch.float8_e4m3fn))
|
|
|
|
# Manually call _load_from_state_dict
|
|
layer._load_from_state_dict(
|
|
state_dict,
|
|
prefix="model.diffusion_model.test_layer.",
|
|
local_metadata={},
|
|
strict=True,
|
|
missing_keys=[],
|
|
unexpected_keys=[],
|
|
error_msgs=[]
|
|
)
|
|
|
|
# Verify quantization was initialized
|
|
self.assertTrue(layer._quantization_initialized)
|
|
self.assertEqual(layer.quant_format, "fp8_e4m3fn")
|
|
self.assertIsNotNone(layer.quant_scale)
|
|
|
|
# Verify weight is wrapped
|
|
self.assertIsInstance(layer.weight.data, quant_ops.QuantizedTensorFP8)
|
|
|
|
# Clean up
|
|
ops.MixedPrecisionOps._layer_quant_config = {}
|
|
|
|
|
|
class TestBackwardCompatibility(unittest.TestCase):
|
|
"""Test backward compatibility with legacy systems"""
|
|
|
|
def test_legacy_ops_classes_exist(self):
|
|
"""Test that legacy ops classes still exist"""
|
|
self.assertTrue(hasattr(ops, 'disable_weight_init'))
|
|
self.assertTrue(hasattr(ops, 'manual_cast'))
|
|
self.assertTrue(hasattr(ops, 'fp8_ops'))
|
|
self.assertTrue(hasattr(ops, 'scaled_fp8_ops'))
|
|
|
|
def test_pick_operations_legacy_path(self):
|
|
"""Test pick_operations returns correct class for legacy cases"""
|
|
# Test standard case
|
|
result = ops.pick_operations(torch.float32, torch.float32)
|
|
self.assertEqual(result, ops.disable_weight_init)
|
|
|
|
# Test manual cast case
|
|
result = ops.pick_operations(torch.float32, torch.float16)
|
|
self.assertEqual(result, ops.manual_cast)
|
|
|
|
|
|
class TestFP8LinearUnification(unittest.TestCase):
|
|
"""Test that fp8_linear now uses the unified tensor subclass infrastructure"""
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for FP8")
|
|
def test_fp8_linear_uses_tensor_subclass(self):
|
|
"""Verify fp8_linear wraps tensors in QuantizedTensorFP8"""
|
|
torch.manual_seed(42)
|
|
|
|
# Create a mock Linear layer with FP8 weight
|
|
linear = ops.fp8_ops.Linear(4, 3, bias=True)
|
|
linear.weight = torch.nn.Parameter(
|
|
torch.randn(3, 4, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
|
|
requires_grad=False
|
|
)
|
|
linear.bias = torch.nn.Parameter(
|
|
torch.randn(3, dtype=torch.bfloat16),
|
|
requires_grad=False
|
|
)
|
|
linear.scale_weight = torch.tensor(1.0)
|
|
linear.scale_input = None # No input scaling
|
|
|
|
# Create input
|
|
input_tensor = torch.randn(2, 4, dtype=torch.bfloat16)
|
|
|
|
# Call fp8_linear - should work without errors
|
|
try:
|
|
result = ops.fp8_linear(linear, input_tensor)
|
|
self.assertIsNotNone(result)
|
|
self.assertEqual(result.shape, (2, 3))
|
|
except Exception as e:
|
|
# On CPU or unsupported hardware, _scaled_mm might not be available
|
|
# but the function should still complete without syntax errors
|
|
pass
|
|
|
|
def test_fp8_linear_maintains_signature(self):
|
|
"""Verify fp8_linear maintains its original function signature"""
|
|
import inspect
|
|
sig = inspect.signature(ops.fp8_linear)
|
|
params = list(sig.parameters.keys())
|
|
|
|
# Should have 'self' and 'input' parameters
|
|
self.assertIn('self', params)
|
|
self.assertIn('input', params)
|
|
self.assertEqual(len(params), 2)
|
|
|
|
def test_fp8_linear_returns_none_for_non_fp8(self):
|
|
"""Verify fp8_linear returns None for non-FP8 weights"""
|
|
# Create a Linear layer with BF16 weight (not FP8)
|
|
linear = ops.disable_weight_init.Linear(4, 3, bias=False)
|
|
linear.weight = torch.nn.Parameter(
|
|
torch.randn(3, 4, dtype=torch.bfloat16),
|
|
requires_grad=False
|
|
)
|
|
|
|
input_tensor = torch.randn(2, 4, dtype=torch.bfloat16)
|
|
|
|
# Should return None for non-FP8 weights
|
|
result = ops.fp8_linear(linear, input_tensor)
|
|
self.assertIsNone(result)
|
|
|
|
def test_fp8_ops_linear_uses_fp8_linear(self):
|
|
"""Verify fp8_ops.Linear still uses fp8_linear in forward pass"""
|
|
linear = ops.fp8_ops.Linear(4, 3, bias=False)
|
|
|
|
# Verify the class has the forward_comfy_cast_weights method
|
|
self.assertTrue(hasattr(linear, 'forward_comfy_cast_weights'))
|
|
|
|
# The forward_comfy_cast_weights should attempt to call fp8_linear
|
|
# (we can't easily test this without mocking, but we verify structure)
|
|
import inspect
|
|
source = inspect.getsource(linear.forward_comfy_cast_weights)
|
|
self.assertIn('fp8_linear', source)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|