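"""Unit tests for ComfyUI's quantized tensor support (comfy.quant_ops).

Covers the QuantizedTensor subclass and its layouts (TensorCoreFP8Layout,
AWQQuantLayout, SVDQuantLayout), the dequantization fallback for unsupported
ops, AWQ checkpoint conversion, and the AWQ linear kernels (GPU/nunchaku only).
"""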
import os
import sys
import unittest
from pathlib import Path

import torch
from safetensors.torch import load_file

# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))


def has_gpu():
    return torch.cuda.is_available()


from comfy.cli_args import args
if not has_gpu():
    args.cpu = True

from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout, AWQQuantLayout, SVDQuantLayout
from comfy.ops import mixed_precision_ops
from comfy.svdquant_converter import convert_svdquant_state_dict, convert_awq_state_dict


class TestQuantizedTensor(unittest.TestCase):
    """Test the QuantizedTensor subclass with FP8 layout"""

    def test_creation(self):
        """Test creating a QuantizedTensor with TensorCoreFP8Layout"""
        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(2.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}

        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.shape, (256, 128))
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt._layout_params['scale'], scale)
        self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
        self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")

    def test_dequantize(self):
        """Test explicit dequantization"""

        fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(3.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}

        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
        dequantized = qt.dequantize()

        self.assertEqual(dequantized.dtype, torch.float32)
        self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))

    def test_from_float(self):
        """Test creating QuantizedTensor from float tensor"""
        float_tensor = torch.randn(64, 32, dtype=torch.float32)
        scale = torch.tensor(1.5)

        qt = QuantizedTensor.from_float(
            float_tensor,
            "TensorCoreFP8Layout",
            scale=scale,
            dtype=torch.float8_e4m3fn
        )

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt.shape, (64, 32))

        # Verify dequantization gives approximately original values
        dequantized = qt.dequantize()
        mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean()
        self.assertLess(mean_rel_error, 0.1)
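

# For reference, a minimal sketch of the FP8 round trip exercised by the tests
# above (uses only the QuantizedTensor API shown in this file):
#
#     w = torch.randn(64, 32)
#     qt = QuantizedTensor.from_float(w, "TensorCoreFP8Layout",
#                                     scale=torch.tensor(1.0),
#                                     dtype=torch.float8_e4m3fn)
#     w_approx = qt.dequantize()  # approximately w, up to FP8 rounding error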


class TestGenericUtilities(unittest.TestCase):
    """Test generic utility operations"""

    def test_detach(self):
        """Test detach operation on quantized tensor"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

        # Detach should return a new QuantizedTensor
        qt_detached = qt.detach()

        self.assertIsInstance(qt_detached, QuantizedTensor)
        self.assertEqual(qt_detached.shape, qt.shape)
        self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")

    def test_clone(self):
        """Test clone operation on quantized tensor"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

        # Clone should return a new QuantizedTensor
        qt_cloned = qt.clone()

        self.assertIsInstance(qt_cloned, QuantizedTensor)
        self.assertEqual(qt_cloned.shape, qt.shape)
        self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")

        # Verify it's a deep copy
        self.assertIsNot(qt_cloned._qdata, qt._qdata)

    @unittest.skipUnless(has_gpu(), "GPU not available")
    def test_to_device(self):
        """Test device transfer"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

        # Moving to same device should work (CPU to CPU)
        qt_cpu = qt.to('cpu')

        self.assertIsInstance(qt_cpu, QuantizedTensor)
        self.assertEqual(qt_cpu.device.type, 'cpu')
        self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu')


class TestTensorCoreFP8Layout(unittest.TestCase):
    """Test the TensorCoreFP8Layout implementation"""

    def test_quantize(self):
        """Test quantization method"""
        float_tensor = torch.randn(32, 64, dtype=torch.float32)
        scale = torch.tensor(1.5)

        qdata, layout_params = TensorCoreFP8Layout.quantize(
            float_tensor,
            scale=scale,
            dtype=torch.float8_e4m3fn
        )

        self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
        self.assertEqual(qdata.shape, float_tensor.shape)
        self.assertIn('scale', layout_params)
        self.assertIn('orig_dtype', layout_params)
        self.assertEqual(layout_params['orig_dtype'], torch.float32)

    def test_dequantize(self):
        """Test dequantization method"""
        float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0
        scale = torch.tensor(1.0)

        qdata, layout_params = TensorCoreFP8Layout.quantize(
            float_tensor,
            scale=scale,
            dtype=torch.float8_e4m3fn
        )

        dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)

        # Should approximately match original
        self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))
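

# Note on the AWQ tensor shapes constructed in the tests below:
#   qweight: int32, shape (out_features // 4, in_features // 2)  -- packed 4-bit weights
#   wscales: bfloat16, shape (in_features // group_size, out_features)  -- per-group scales
#   wzeros:  bfloat16, shape (in_features // group_size, out_features)  -- per-group zero points
# AWQ weights arrive pre-quantized from checkpoints, so AWQQuantLayout.quantize()
# is expected to raise NotImplementedError (see test_awq_quantize_not_supported).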


class TestAWQQuantLayout(unittest.TestCase):
    """Test the AWQQuantLayout implementation"""

    def test_awq_layout_creation(self):
        """Test creating an AWQ quantized tensor"""
        # AWQ uses pre-quantized weights loaded from checkpoints
        # Create dummy AWQ quantized weights
        out_features, in_features = 256, 128
        group_size = 64

        qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
        wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)

        layout_params = {
            'wscales': wscales,
            'wzeros': wzeros,
            'group_size': group_size,
            'orig_dtype': torch.bfloat16,
            'is_weight': True,
        }

        qt = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.shape, qweight.shape)
        self.assertEqual(qt.dtype, torch.int32)
        self.assertEqual(qt._layout_type, "AWQQuantLayout")
        self.assertEqual(qt._layout_params['group_size'], group_size)

    def test_awq_quantize_not_supported(self):
        """Test that online quantization raises NotImplementedError for AWQ"""
        # AWQ doesn't support online quantization - weights must be pre-quantized
        float_tensor = torch.randn(32, 64, dtype=torch.float32)

        with self.assertRaises(NotImplementedError):
            AWQQuantLayout.quantize(float_tensor, is_weight=True)

    def test_awq_get_plain_tensors(self):
        """Test extracting plain tensors from AWQ quantized tensor"""
        out_features, in_features = 256, 128
        group_size = 64

        qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
        wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)

        layout_params = {
            'wscales': wscales,
            'wzeros': wzeros,
            'group_size': group_size,
            'orig_dtype': torch.bfloat16,
            'is_weight': True,
        }

        qt = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)
        plain_tensors = AWQQuantLayout.get_plain_tensors(qt)

        # Verify we can extract all necessary components
        self.assertIsInstance(plain_tensors, dict)
        self.assertIn('qweight', plain_tensors)
        self.assertIn('wscales', plain_tensors)
        self.assertIn('wzeros', plain_tensors)
        self.assertIn('group_size', plain_tensors)
        self.assertTrue(torch.equal(plain_tensors['qweight'], qweight))
        self.assertTrue(torch.equal(plain_tensors['wscales'], wscales))
        self.assertTrue(torch.equal(plain_tensors['wzeros'], wzeros))


class TestSVDQuantLayout(unittest.TestCase):
    """Test the SVDQuantLayout implementation"""

    def test_svdquant_layout_creation(self):
        """Test creating an SVDQuant quantized tensor"""
        # SVDQuant uses pre-quantized weights loaded from checkpoints
        out_features, in_features = 256, 128
        rank = 32
        group_size = 64
        precision = "int4"

        # Create dummy SVDQuant quantized weights (int8 range is -128 to 127)
        qweight = torch.randint(-128, 127, (out_features, in_features // 2), dtype=torch.int8)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
        smooth_factor = torch.randn(in_features, dtype=torch.bfloat16)
        smooth_factor_orig = torch.randn(in_features, dtype=torch.bfloat16)
        proj_down = torch.randn(in_features, rank, dtype=torch.bfloat16)
        proj_up = torch.randn(out_features, rank, dtype=torch.bfloat16)

        layout_params = {
            'wscales': wscales,
            'smooth_factor': smooth_factor,
            'smooth_factor_orig': smooth_factor_orig,
            'proj_down': proj_down,
            'proj_up': proj_up,
            'group_size': group_size,
            'precision': precision,
            'orig_dtype': torch.bfloat16,
            'is_weight': True,
            'act_unsigned': False,
            'wtscale': None,
            'wcscales': None,
        }

        qt = QuantizedTensor(qweight, "SVDQuantLayout", layout_params)

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.shape, qweight.shape)
        self.assertEqual(qt.dtype, torch.int8)
        self.assertEqual(qt._layout_type, "SVDQuantLayout")
        self.assertEqual(qt._layout_params['group_size'], group_size)
        self.assertEqual(qt._layout_params['precision'], precision)

    def test_svdquant_quantize_not_supported(self):
        """Test that online quantization raises NotImplementedError for SVDQuant"""
        # SVDQuant doesn't support online quantization - weights must be pre-quantized
        float_tensor = torch.randn(32, 64, dtype=torch.float32)

        with self.assertRaises(NotImplementedError):
            SVDQuantLayout.quantize(float_tensor, is_weight=True)

    def test_svdquant_dequantize_not_supported(self):
        """Test that weight dequantization raises NotImplementedError for SVDQuant"""
        # Full weight dequantization is not supported (complex operation)
        out_features, in_features = 256, 128
        rank = 32
        group_size = 64

        qweight = torch.randint(-128, 127, (out_features, in_features // 2), dtype=torch.int8)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
        smooth_factor = torch.randn(in_features, dtype=torch.bfloat16)
        proj_down = torch.randn(in_features, rank, dtype=torch.bfloat16)
        proj_up = torch.randn(out_features, rank, dtype=torch.bfloat16)

        with self.assertRaises(NotImplementedError):
            SVDQuantLayout.dequantize(
                qweight,
                is_weight=True,
                wscales=wscales,
                smooth_factor=smooth_factor,
                proj_down=proj_down,
                proj_up=proj_up,
                group_size=group_size,
                precision="int4",
                orig_dtype=torch.bfloat16
            )

    def test_svdquant_get_plain_tensors(self):
        """Test extracting plain tensors from SVDQuant quantized tensor"""
        out_features, in_features = 256, 128
        rank = 32
        group_size = 64

        qweight = torch.randint(-128, 127, (out_features, in_features // 2), dtype=torch.int8)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16)
        smooth_factor = torch.randn(in_features, dtype=torch.bfloat16)
        smooth_factor_orig = torch.randn(in_features, dtype=torch.bfloat16)
        proj_down = torch.randn(in_features, rank, dtype=torch.bfloat16)
        proj_up = torch.randn(out_features, rank, dtype=torch.bfloat16)

        layout_params = {
            'wscales': wscales,
            'smooth_factor': smooth_factor,
            'smooth_factor_orig': smooth_factor_orig,
            'proj_down': proj_down,
            'proj_up': proj_up,
            'group_size': group_size,
            'precision': "int4",
            'orig_dtype': torch.bfloat16,
            'is_weight': True,
            'act_unsigned': False,
            'wtscale': None,
            'wcscales': None,
        }

        qt = QuantizedTensor(qweight, "SVDQuantLayout", layout_params)
        plain_tensors = SVDQuantLayout.get_plain_tensors(qt)

        # Verify we can extract all necessary components
        self.assertIsInstance(plain_tensors, dict)
        self.assertIn('qweight', plain_tensors)
        self.assertIn('wscales', plain_tensors)
        self.assertIn('smooth_factor', plain_tensors)
        self.assertIn('proj_down', plain_tensors)
        self.assertIn('proj_up', plain_tensors)
        self.assertIn('group_size', plain_tensors)
        self.assertIn('precision', plain_tensors)
        self.assertTrue(torch.equal(plain_tensors['qweight'], qweight))
        self.assertTrue(torch.equal(plain_tensors['wscales'], wscales))
        self.assertTrue(torch.equal(plain_tensors['smooth_factor'], smooth_factor))
        self.assertTrue(torch.equal(plain_tensors['proj_down'], proj_down))
        self.assertTrue(torch.equal(plain_tensors['proj_up'], proj_up))


class TestFallbackMechanism(unittest.TestCase):
    """Test fallback for unsupported operations"""

    def test_unsupported_op_dequantizes(self):
        """Test that unsupported operations fall back to dequantization"""
        # Set seed for reproducibility
        torch.manual_seed(42)

        # Create quantized tensor
        a_fp32 = torch.randn(10, 20, dtype=torch.float32)
        scale = torch.tensor(1.0)
        a_q = QuantizedTensor.from_float(
            a_fp32,
            "TensorCoreFP8Layout",
            scale=scale,
            dtype=torch.float8_e4m3fn
        )

        # Call an operation that doesn't have a registered handler
        # For example, torch.abs
        result = torch.abs(a_q)

        # Should work via fallback (dequantize → abs → return)
        self.assertNotIsInstance(result, QuantizedTensor)
        expected = torch.abs(a_fp32)
        # FP8 introduces quantization error, so use loose tolerance
        mean_error = (result - expected).abs().mean()
        self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")
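

# Expected key mapping for convert_awq_state_dict, as asserted in the tests
# below (illustrative; other layer prefixes follow the same pattern):
#   "layer.qweight" -> tensors["layer.weight"]   (renamed)
#   "layer.wscales" -> tensors["layer.wscales"]  (preserved)
#   "layer.wzeros"  -> tensors["layer.wzeros"]   (preserved)
#   "layer.bias"    -> tensors["layer.bias"]     (preserved)
#   quant_layers["layer"] == "awq_int4"          (quantization metadata)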


class TestAWQConversion(unittest.TestCase):
    """Test AWQ checkpoint conversion"""

    def test_awq_single_layer_conversion(self):
        """Test converting a single AWQ layer"""
        in_features, out_features = 128, 256
        group_size = 64

        # Create AWQ checkpoint format
        state_dict = {
            "layer.qweight": torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32),
            "layer.wscales": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
            "layer.wzeros": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
            "layer.bias": torch.randn(out_features, dtype=torch.bfloat16),
        }

        converted = convert_awq_state_dict(state_dict)

        # Check that qweight was renamed to weight
        self.assertIn("layer.weight", converted.tensors)
        self.assertNotIn("layer.qweight", converted.tensors)

        # Check other parameters preserved
        self.assertIn("layer.wscales", converted.tensors)
        self.assertIn("layer.wzeros", converted.tensors)
        self.assertIn("layer.bias", converted.tensors)

        # Check quantization metadata
        self.assertIn("layer", converted.quant_layers)
        self.assertEqual(converted.quant_layers["layer"], "awq_int4")

    def test_awq_tensor_shapes(self):
        """Test that converted AWQ tensors have correct shapes"""
        in_features, out_features = 3072, 18432
        group_size = 64

        state_dict = {
            "layer.qweight": torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32),
            "layer.wscales": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
            "layer.wzeros": torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16),
        }

        converted = convert_awq_state_dict(state_dict)

        # Check qweight shape (packed 4-bit)
        qweight = converted.tensors["layer.weight"]
        self.assertEqual(qweight.shape, (out_features // 4, in_features // 2))
        self.assertEqual(qweight.dtype, torch.int32)

        # Check wscales shape
        wscales = converted.tensors["layer.wscales"]
        self.assertEqual(wscales.shape, (in_features // group_size, out_features))
        self.assertEqual(wscales.dtype, torch.bfloat16)

        # Check wzeros shape
        wzeros = converted.tensors["layer.wzeros"]
        self.assertEqual(wzeros.shape, (in_features // group_size, out_features))
        self.assertEqual(wzeros.dtype, torch.bfloat16)


class TestAWQLinearOperation(unittest.TestCase):
    """Test AWQ linear operations with actual nunchaku kernels"""

    @unittest.skipUnless(has_gpu(), "GPU required for AWQ operations")
    def test_awq_linear_basic(self):
        """Test basic AWQ linear operation by calling kernel directly"""
        try:
            from nunchaku.ops.gemv import awq_gemv_w4a16_cuda
        except ImportError:
            self.skipTest("nunchaku package not available")

        device = torch.device("cuda")
        in_features, out_features = 128, 256
        group_size = 64
        batch_size = 4

        # Create AWQ quantized weight tensors
        qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32, device=device)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)
        wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)
        bias = torch.randn(out_features, dtype=torch.bfloat16, device=device)

        # Create layout params
        layout_params = {
            'wscales': wscales,
            'wzeros': wzeros,
            'group_size': group_size,
            'orig_dtype': torch.bfloat16,
            'is_weight': True,
        }

        weight = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)

        # Check that weight is a QuantizedTensor
        self.assertIsInstance(weight, QuantizedTensor)
        self.assertEqual(weight._layout_type, "AWQQuantLayout")

        # Create input
        x = torch.randn(batch_size, in_features, dtype=torch.bfloat16, device=device)

        # Call AWQ linear handler directly
        from comfy.quant_ops import awq_linear
        output = awq_linear(torch.ops.aten.linear.default, (x, weight, bias), {})

        # Check output shape and dtype
        self.assertEqual(output.shape, (batch_size, out_features))
        self.assertEqual(output.dtype, torch.bfloat16)

    @unittest.skipUnless(has_gpu(), "GPU required for AWQ operations")
    def test_awq_linear_2d_input(self):
        """Test AWQ linear with 2D input (batch, features) by calling kernel directly"""
        try:
            from nunchaku.ops.gemv import awq_gemv_w4a16_cuda
        except ImportError:
            self.skipTest("nunchaku package not available")

        device = torch.device("cuda")
        in_features, out_features = 128, 256
        group_size = 64
        batch_size = 4

        # Create AWQ quantized weight tensors
        qweight = torch.randint(0, 255, (out_features // 4, in_features // 2), dtype=torch.int32, device=device)
        wscales = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)
        wzeros = torch.randn(in_features // group_size, out_features, dtype=torch.bfloat16, device=device)

        # Create layout params
        layout_params = {
            'wscales': wscales,
            'wzeros': wzeros,
            'group_size': group_size,
            'orig_dtype': torch.bfloat16,
            'is_weight': True,
        }

        weight = QuantizedTensor(qweight, "AWQQuantLayout", layout_params)

        # Check that weight is a QuantizedTensor
        self.assertIsInstance(weight, QuantizedTensor)
        self.assertEqual(weight._layout_type, "AWQQuantLayout")

        # Create 2D input
        x = torch.randn(batch_size, in_features, dtype=torch.bfloat16, device=device)

        # Call AWQ linear handler directly
        from comfy.quant_ops import awq_linear
        output = awq_linear(torch.ops.aten.linear.default, (x, weight, None), {})

        # Check output shape
        self.assertEqual(output.shape, (batch_size, out_features))
        self.assertEqual(output.dtype, torch.bfloat16)
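

# The awq_linear handler above is invoked with the same (input, weight, bias)
# argument order as torch.ops.aten.linear.default, plus a kwargs dict; bias may
# be None, as the 2D-input test demonstrates.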


if __name__ == "__main__":
    unittest.main()