# ComfyUI/tests-unit/comfy_quant/test_quant_registry.py

import os
import sys
import unittest
from pathlib import Path
import torch
from safetensors.torch import load_file
# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))


def has_gpu():
    return torch.cuda.is_available()


from comfy.cli_args import args

# Force CPU mode when no GPU is available; done before the comfy imports
# below so module initialization does not assume a GPU.
if not has_gpu():
    args.cpu = True

from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout, AWQQuantLayout, SVDQuantLayout
from comfy.ops import mixed_precision_ops
from comfy.svdquant_converter import convert_svdquant_state_dict, convert_awq_state_dict


class TestQuantizedTensor(unittest.TestCase):
    """Test the QuantizedTensor subclass with FP8 layout"""

    def test_creation(self):
        """Test creating a QuantizedTensor with TensorCoreFP8Layout"""
        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(2.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.shape, (256, 128))
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt._layout_params['scale'], scale)
        self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
        self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")

    def test_dequantize(self):
        """Test explicit dequantization"""
        fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(3.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
        dequantized = qt.dequantize()
        self.assertEqual(dequantized.dtype, torch.float32)
        self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))

    def test_from_float(self):
        """Test creating QuantizedTensor from float tensor"""
        float_tensor = torch.randn(64, 32, dtype=torch.float32)
        scale = torch.tensor(1.5)
        qt = QuantizedTensor.from_float(
            float_tensor,
            "TensorCoreFP8Layout",
            scale=scale,
            dtype=torch.float8_e4m3fn
        )
        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt.shape, (64, 32))
        # Verify dequantization gives approximately original values
        dequantized = qt.dequantize()
        mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean()
        self.assertLess(mean_rel_error, 0.1)


class TestGenericUtilities(unittest.TestCase):
    """Test generic utility operations"""

    def test_detach(self):
        """Test detach operation on quantized tensor"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
        # Detach should return a new QuantizedTensor
        qt_detached = qt.detach()
        self.assertIsInstance(qt_detached, QuantizedTensor)
        self.assertEqual(qt_detached.shape, qt.shape)
        self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")

    def test_clone(self):
        """Test clone operation on quantized tensor"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
        # Clone should return a new QuantizedTensor
        qt_cloned = qt.clone()
        self.assertIsInstance(qt_cloned, QuantizedTensor)
        self.assertEqual(qt_cloned.shape, qt.shape)
        self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")
        # Verify it's a deep copy
        self.assertIsNot(qt_cloned._qdata, qt._qdata)

    @unittest.skipUnless(has_gpu(), "GPU not available")
    def test_to_device(self):
        """Test device transfer"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
        # Moving to same device should work (CPU to CPU)
        qt_cpu = qt.to('cpu')
        self.assertIsInstance(qt_cpu, QuantizedTensor)
        self.assertEqual(qt_cpu.device.type, 'cpu')
        self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu')


class TestTensorCoreFP8Layout(unittest.TestCase):
    """Test the TensorCoreFP8Layout implementation"""

    def test_quantize(self):
        """Test quantization method"""
        float_tensor = torch.randn(32, 64, dtype=torch.float32)
        scale = torch.tensor(1.5)
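        # quantize() returns the raw quantized payload plus the layout params
        # needed to dequantize it later (see test_dequantize below).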
        qdata, layout_params = TensorCoreFP8Layout.quantize(
            float_tensor,
            scale=scale,
            dtype=torch.float8_e4m3fn
        )
        self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
        self.assertEqual(qdata.shape, float_tensor.shape)
        self.assertIn('scale', layout_params)
        self.assertIn('orig_dtype', layout_params)
        self.assertEqual(layout_params['orig_dtype'], torch.float32)

    def test_dequantize(self):
        """Test dequantization method"""
        float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0
        scale = torch.tensor(1.0)
        qdata, layout_params = TensorCoreFP8Layout.quantize(
            float_tensor,
            scale=scale,
            dtype=torch.float8_e4m3fn
        )
        dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)
        # Should approximately match original
        self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))


class TestAWQQuantLayout(unittest.TestCase):
    """Test the AWQQuantLayout implementation"""

    def test_awq_layout_creation(self):
        """Test creating an AWQ quantized tensor"""
        # AWQ uses pre-quantized weights loaded from checkpoints
        # Create dummy AWQ quantized weights
        out_features, in_features = 256, 128
        group_size = 64
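        # Packed 4-bit layout (inferred from the shapes used in these tests):
        # each int32 holds eight 4-bit values, so an (out_features, in_features)
        # weight packs into (out_features // 4, in_features // 2) int32 elements.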
        qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
        wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
        layout_params = {
            'wscales': wscales,
            'wzeros': wzeros,
            'group_size': group_size,
            'orig_dtype': torch.bfloat16,
            'is_weight': True,
        }
        qt = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)
        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.shape, qweight.shape)
        self.assertEqual(qt.dtype, torch.int32)
        self.assertEqual(qt._layout_type, "AWQQuantLayout")
        self.assertEqual(qt._layout_params['group_size'], group_size)

    def test_awq_quantize_not_supported(self):
        """Test that online quantization raises NotImplementedError for AWQ"""
        # AWQ doesn't support online quantization - weights must be pre-quantized
        float_tensor = torch.randn(32, 64, dtype=torch.float32)
        with self.assertRaises(NotImplementedError):
            AWQQuantLayout.quantize(float_tensor, is_weight=True)

    def test_awq_get_plain_tensors(self):
        """Test extracting plain tensors from AWQ quantized tensor"""
        out_features, in_features = 256, 128
        group_size = 64
        qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
        wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
        layout_params = {
            'wscales': wscales,
            'wzeros': wzeros,
            'group_size': group_size,
            'orig_dtype': torch.bfloat16,
            'is_weight': True,
        }
        qt = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)
        plain_tensors = AWQQuantLayout.get_plain_tensors(qt)
        # Verify we can extract all necessary components
        self.assertIsInstance(plain_tensors, dict)
        self.assertIn('qweight', plain_tensors)
        self.assertIn('wscales', plain_tensors)
        self.assertIn('wzeros', plain_tensors)
        self.assertIn('group_size', plain_tensors)
        self.assertTrue(torch.equal(plain_tensors['qweight'], qweight))
        self.assertTrue(torch.equal(plain_tensors['wscales'], wscales))
        self.assertTrue(torch.equal(plain_tensors['wzeros'], wzeros))


class TestSVDQuantLayout(unittest.TestCase):
    """Test the SVDQuantLayout implementation"""

    def test_svdquant_layout_creation(self):
        """Test creating an SVDQuant quantized tensor"""
        # SVDQuant uses pre-quantized weights loaded from checkpoints
        out_features, in_features = 256, 128
        rank = 32
        group_size = 64
        precision = "int4"
        # Create dummy SVDQuant quantized weights (int8 range is -128 to 127;
        # randint's upper bound is exclusive)
        qweight = torch.randint(-128, 128, (out_features, in_features // 2), dtype=torch.int8)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
        smooth_factor = torch.randn(in_features, dtype=torch.bfloat16)
        smooth_factor_orig = torch.randn(in_features, dtype=torch.bfloat16)
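        # Low-rank branch (assumed SVDQuant semantics): proj_up @ proj_down.T
        # reconstructs an (out_features, in_features) component, while qweight
        # holds the 4-bit residual packed two values per int8.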
        proj_down = torch.randn(in_features, rank, dtype=torch.bfloat16)
        proj_up = torch.randn(out_features, rank, dtype=torch.bfloat16)
        layout_params = {
            'wscales': wscales,
            'smooth_factor': smooth_factor,
            'smooth_factor_orig': smooth_factor_orig,
            'proj_down': proj_down,
            'proj_up': proj_up,
            'group_size': group_size,
            'precision': precision,
            'orig_dtype': torch.bfloat16,
            'is_weight': True,
            'act_unsigned': False,
            'wtscale': None,
            'wcscales': None,
        }
        qt = QuantizedTensor(qweight, "SVDQuantLayout", layout_params)
        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.shape, qweight.shape)
        self.assertEqual(qt.dtype, torch.int8)
        self.assertEqual(qt._layout_type, "SVDQuantLayout")
        self.assertEqual(qt._layout_params['group_size'], group_size)
        self.assertEqual(qt._layout_params['precision'], precision)

    def test_svdquant_quantize_not_supported(self):
        """Test that online quantization raises NotImplementedError for SVDQuant"""
        # SVDQuant doesn't support online quantization - weights must be pre-quantized
        float_tensor = torch.randn(32, 64, dtype=torch.float32)
        with self.assertRaises(NotImplementedError):
            SVDQuantLayout.quantize(float_tensor, is_weight=True)

    def test_svdquant_dequantize_not_supported(self):
        """Test that weight dequantization raises NotImplementedError for SVDQuant"""
        # Full weight dequantization is not supported (complex operation)
        out_features, in_features = 256, 128
        rank = 32
        group_size = 64
        qweight = torch.randint(-128, 128, (out_features, in_features // 2), dtype=torch.int8)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
        smooth_factor = torch.randn(in_features, dtype=torch.bfloat16)
        proj_down = torch.randn(in_features, rank, dtype=torch.bfloat16)
        proj_up = torch.randn(out_features, rank, dtype=torch.bfloat16)
        with self.assertRaises(NotImplementedError):
            SVDQuantLayout.dequantize(
                qweight,
                is_weight=True,
                wscales=wscales,
                smooth_factor=smooth_factor,
                proj_down=proj_down,
                proj_up=proj_up,
                group_size=group_size,
                precision="int4",
                orig_dtype=torch.bfloat16
            )

    def test_svdquant_get_plain_tensors(self):
        """Test extracting plain tensors from SVDQuant quantized tensor"""
        out_features, in_features = 256, 128
        rank = 32
        group_size = 64
        qweight = torch.randint(-128, 128, (out_features, in_features // 2), dtype=torch.int8)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
        smooth_factor = torch.randn(in_features, dtype=torch.bfloat16)
        smooth_factor_orig = torch.randn(in_features, dtype=torch.bfloat16)
        proj_down = torch.randn(in_features, rank, dtype=torch.bfloat16)
        proj_up = torch.randn(out_features, rank, dtype=torch.bfloat16)
        layout_params = {
            'wscales': wscales,
            'smooth_factor': smooth_factor,
            'smooth_factor_orig': smooth_factor_orig,
            'proj_down': proj_down,
            'proj_up': proj_up,
            'group_size': group_size,
            'precision': "int4",
            'orig_dtype': torch.bfloat16,
            'is_weight': True,
            'act_unsigned': False,
            'wtscale': None,
            'wcscales': None,
        }
        qt = QuantizedTensor(qweight, "SVDQuantLayout", layout_params)
        plain_tensors = SVDQuantLayout.get_plain_tensors(qt)
        # Verify we can extract all necessary components
        self.assertIsInstance(plain_tensors, dict)
        self.assertIn('qweight', plain_tensors)
        self.assertIn('wscales', plain_tensors)
        self.assertIn('smooth_factor', plain_tensors)
        self.assertIn('proj_down', plain_tensors)
        self.assertIn('proj_up', plain_tensors)
        self.assertIn('group_size', plain_tensors)
        self.assertIn('precision', plain_tensors)
        self.assertTrue(torch.equal(plain_tensors['qweight'], qweight))
        self.assertTrue(torch.equal(plain_tensors['wscales'], wscales))
        self.assertTrue(torch.equal(plain_tensors['smooth_factor'], smooth_factor))
        self.assertTrue(torch.equal(plain_tensors['proj_down'], proj_down))
        self.assertTrue(torch.equal(plain_tensors['proj_up'], proj_up))


class TestFallbackMechanism(unittest.TestCase):
    """Test fallback for unsupported operations"""

    def test_unsupported_op_dequantizes(self):
        """Test that unsupported operations fall back to dequantization"""
        # Set seed for reproducibility
        torch.manual_seed(42)
        # Create quantized tensor
        a_fp32 = torch.randn(10, 20, dtype=torch.float32)
        scale = torch.tensor(1.0)
        a_q = QuantizedTensor.from_float(
            a_fp32,
            "TensorCoreFP8Layout",
            scale=scale,
            dtype=torch.float8_e4m3fn
        )
        # Call an operation that doesn't have a registered handler
        # For example, torch.abs
        result = torch.abs(a_q)
        # Should work via fallback (dequantize → abs → return)
        self.assertNotIsInstance(result, QuantizedTensor)
        expected = torch.abs(a_fp32)
        # FP8 introduces quantization error, so use loose tolerance
        mean_error = (result - expected).abs().mean()
        self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")


class TestAWQConversion(unittest.TestCase):
    """Test AWQ checkpoint conversion"""

    def test_awq_single_layer_conversion(self):
        """Test converting a single AWQ layer"""
        in_features, out_features = 128, 256
        group_size = 64
        # Create AWQ checkpoint format
        state_dict = {
            "layer.qweight": torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32),
            "layer.wscales": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
            "layer.wzeros": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
            "layer.bias": torch.randn(out_features, dtype=torch.bfloat16),
        }
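        # The converter returns remapped tensors plus per-layer quantization
        # metadata, exposed as .tensors and .quant_layers (checked below).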
        converted = convert_awq_state_dict(state_dict)
        # Check that qweight was renamed to weight
        self.assertIn("layer.weight", converted.tensors)
        self.assertNotIn("layer.qweight", converted.tensors)
        # Check other parameters preserved
        self.assertIn("layer.wscales", converted.tensors)
        self.assertIn("layer.wzeros", converted.tensors)
        self.assertIn("layer.bias", converted.tensors)
        # Check quantization metadata
        self.assertIn("layer", converted.quant_layers)
        self.assertEqual(converted.quant_layers["layer"], "awq_int4")

    def test_awq_tensor_shapes(self):
        """Test that converted AWQ tensors have correct shapes"""
        in_features, out_features = 3072, 18432
        group_size = 64
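        # Large, realistic layer dimensions (on the order of a big DiT
        # feed-forward block) to check the packing arithmetic at scale.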
        state_dict = {
            "layer.qweight": torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32),
            "layer.wscales": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
            "layer.wzeros": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
        }
        converted = convert_awq_state_dict(state_dict)
        # Check qweight shape (packed 4-bit)
        qweight = converted.tensors["layer.weight"]
        self.assertEqual(qweight.shape, (out_features // 4, in_features // 2))
        self.assertEqual(qweight.dtype, torch.int32)
        # Check wscales shape
        wscales = converted.tensors["layer.wscales"]
        self.assertEqual(wscales.shape, (in_features // group_size, out_features))
        self.assertEqual(wscales.dtype, torch.bfloat16)
        # Check wzeros shape
        wzeros = converted.tensors["layer.wzeros"]
        self.assertEqual(wzeros.shape, (in_features // group_size, out_features))
        self.assertEqual(wzeros.dtype, torch.bfloat16)


class TestAWQLinearOperation(unittest.TestCase):
    """Test AWQ linear operations with actual nunchaku kernels"""

    @unittest.skipUnless(has_gpu(), "GPU required for AWQ operations")
    def test_awq_linear_basic(self):
        """Test basic AWQ linear operation by calling the kernel directly"""
        try:
            from nunchaku.ops.gemv import awq_gemv_w4a16_cuda  # noqa: F401 (availability probe)
        except ImportError:
            self.skipTest("nunchaku package not available")
        device = torch.device("cuda")
        in_features, out_features = 128, 256
        group_size = 64
        batch_size = 4
        # Create AWQ quantized weight tensors
        qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32, device=device)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)
        wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)
        bias = torch.randn(out_features, dtype=torch.bfloat16, device=device)
        # Create layout params
        layout_params = {
            'wscales': wscales,
            'wzeros': wzeros,
            'group_size': group_size,
            'orig_dtype': torch.bfloat16,
            'is_weight': True,
        }
        weight = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)
        # Check that weight is a QuantizedTensor
        self.assertIsInstance(weight, QuantizedTensor)
        self.assertEqual(weight._layout_type, "AWQQuantLayout")
        # Create input
        x = torch.randn(batch_size, in_features, dtype=torch.bfloat16, device=device)
        # Call AWQ linear handler directly
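        # The handler takes (op, args, kwargs), mirroring the dispatch-hook
        # signature used by the quant op registry (assumed from its call form).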
        from comfy.quant_ops import awq_linear
        output = awq_linear(torch.ops.aten.linear.default, (x, weight, bias), {})
        # Check output shape and dtype
        self.assertEqual(output.shape, (batch_size, out_features))
        self.assertEqual(output.dtype, torch.bfloat16)

    @unittest.skipUnless(has_gpu(), "GPU required for AWQ operations")
    def test_awq_linear_2d_input(self):
        """Test AWQ linear with 2D input (batch, features) by calling the kernel directly"""
        try:
            from nunchaku.ops.gemv import awq_gemv_w4a16_cuda  # noqa: F401 (availability probe)
        except ImportError:
            self.skipTest("nunchaku package not available")
        device = torch.device("cuda")
        in_features, out_features = 128, 256
        group_size = 64
        batch_size = 4
        # Create AWQ quantized weight tensors
        qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32, device=device)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)
        wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)
        # Create layout params
        layout_params = {
            'wscales': wscales,
            'wzeros': wzeros,
            'group_size': group_size,
            'orig_dtype': torch.bfloat16,
            'is_weight': True,
        }
        weight = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)
        # Check that weight is a QuantizedTensor
        self.assertIsInstance(weight, QuantizedTensor)
        self.assertEqual(weight._layout_type, "AWQQuantLayout")
        # Create 2D input
        x = torch.randn(batch_size, in_features, dtype=torch.bfloat16, device=device)
        # Call AWQ linear handler directly (bias=None for this case)
        from comfy.quant_ops import awq_linear
        output = awq_linear(torch.ops.aten.linear.default, (x, weight, None), {})
        # Check output shape
        self.assertEqual(output.shape, (batch_size, out_features))
        self.assertEqual(output.dtype, torch.bfloat16)


if __name__ == "__main__":
    unittest.main()