Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-02-15 16:02:32 +08:00
Update unit tests
This commit is contained in:
parent efb35035f3
commit a7216e18e5
@@ -546,7 +546,7 @@ class MixedPrecisionOps(disable_weight_init):

     def forward_comfy_cast_weights(self, input):
         weight, bias = cast_bias_weight(self, input)
-        self._forward(input, weight, bias)
+        return self._forward(input, weight, bias)

     def forward(self, input, *args, **kwargs):
         run_every_op()

@@ -143,7 +143,7 @@ class QuantizedTensor(torch.Tensor):
         """
         Tensor flattening protocol for proper device movement.
         """
-        inner_tensors = ["_q_data"]
+        inner_tensors = ["_qdata"]
         ctx = {
             "layout_type": self._layout_type,
         }
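The _q_data to _qdata rename above matters because __tensor_flatten__ must name attributes that actually exist on the subclass. For orientation, here is a minimal, generic sketch of PyTorch's tensor flattening protocol for wrapper subclasses; the class and attribute names are illustrative and not taken from comfy.quant_ops.

import torch

class PackedTensor(torch.Tensor):
    """Toy wrapper subclass holding one inner tensor plus non-tensor context."""

    @staticmethod
    def __new__(cls, qdata, layout_type):
        # The wrapper reports the inner tensor's shape/dtype but holds no storage itself.
        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, dtype=qdata.dtype, device=qdata.device)

    def __init__(self, qdata, layout_type):
        self._qdata = qdata
        self._layout_type = layout_type

    def __tensor_flatten__(self):
        # Return the names of attributes that hold real tensors, plus a context
        # object that must survive device moves and serialization unchanged.
        inner_tensors = ["_qdata"]
        ctx = {"layout_type": self._layout_type}
        return inner_tensors, ctx

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        # Rebuild the wrapper from the flattened pieces.
        return PackedTensor(inner_tensors["_qdata"], ctx["layout_type"])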
@@ -206,7 +206,7 @@ class QuantizedTensor(torch.Tensor):

         # Step 3: Fallback to dequantization
         if isinstance(args[0] if args else None, QuantizedTensor):
-            logging.warning(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
+            logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
             return cls._dequant_and_fallback(func, args, kwargs)

     @classmethod
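The warning-to-info change above sits in the __torch_dispatch__ fallback path. As a rough sketch of how such a fallback typically works (the OP_HANDLERS registry and the class below are hypothetical stand-ins, not ComfyUI's actual implementation):

import logging
import torch
from torch.utils._pytree import tree_map

# Hypothetical handler table: aten overload -> callable(func, args, kwargs).
OP_HANDLERS = {}

class DequantFallbackTensor(torch.Tensor):
    """Toy subclass: stores data pre-divided by a scale, dequantizes on demand."""

    @staticmethod
    def __new__(cls, data, scale):
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype, device=data.device)

    def __init__(self, data, scale):
        self._data = data
        self._scale = scale

    def dequantize(self):
        return self._data * self._scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        handler = OP_HANDLERS.get(func)
        if handler is not None:
            return handler(func, args, kwargs)  # op-specific fast path stays quantized
        # Fallback: dequantize every wrapped argument and rerun the op normally.
        logging.info("Unhandled op %s, falling back to dequantization", func)
        unwrap = lambda x: x.dequantize() if isinstance(x, cls) else x
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

t = DequantFallbackTensor(torch.ones(4), torch.tensor(2.0))
print(torch.abs(t))  # no handler registered -> falls back, prints a plain tensor of 2s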
@@ -1,8 +1,3 @@
-"""
-End-to-end tests for mixed precision quantization.
-Tests Phase 3: Mixed Precision Operations
-"""
-
 import unittest
 import torch
 import sys
@@ -12,10 +7,10 @@ import os
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

 from comfy import ops
+from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout


 class SimpleModel(torch.nn.Module):
     """Simple model for testing mixed precision"""
     def __init__(self, operations=ops.disable_weight_init):
         super().__init__()
         self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16)
@@ -32,8 +27,7 @@ class SimpleModel(torch.nn.Module):


 class TestMixedPrecisionOps(unittest.TestCase):
     """Test MixedPrecisionOps end-to-end"""

     def test_all_layers_standard(self):
         """Test that model with no quantization works normally"""
         # Configure no quantization
@@ -67,48 +61,54 @@ class TestMixedPrecisionOps(unittest.TestCase):
         # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard
         layer_quant_config = {
             "layer1": {
-                "format": "fp8_e4m3fn_scaled",
-                "params": {"use_fp8_matmul": False}  # Disable for CPU testing
+                "format": "float8_e4m3fn",
+                "params": {}
             },
             "layer3": {
-                "format": "fp8_e5m2_scaled",
-                "params": {"use_fp8_matmul": False}
+                "format": "float8_e4m3fn",
+                "params": {}
             }
         }
         ops.MixedPrecisionOps._layer_quant_config = layer_quant_config

         # Create state dict with mixed precision
         fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
-        fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e5m2)
+        fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn)

         state_dict = {
             # Layer 1: FP8 E4M3FN
             "layer1.weight": fp8_weight1,
             "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
-            "layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32),
+            "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),

             # Layer 2: Standard BF16
             "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
             "layer2.bias": torch.randn(30, dtype=torch.bfloat16),

-            # Layer 3: FP8 E5M2
+            # Layer 3: FP8 E4M3FN
             "layer3.weight": fp8_weight3,
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
-            "layer3.scale_weight": torch.tensor(1.5, dtype=torch.float32),
+            "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
         }

-        # Create model and load state dict
+        # Create model and load state dict (strict=False because custom loading pops keys)
         model = SimpleModel(operations=ops.MixedPrecisionOps)
-        model.load_state_dict(state_dict)
+        model.load_state_dict(state_dict, strict=False)

-        # Verify handlers are set up correctly
-        self.assertIsNotNone(model.layer1.quant_handler)
-        self.assertIsNone(model.layer2.quant_handler)  # No quantization
-        self.assertIsNotNone(model.layer3.quant_handler)
+        # Verify weights are wrapped in QuantizedTensor
+        self.assertIsInstance(model.layer1.weight, QuantizedTensor)
+        self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout)
+
+        # Layer 2 should NOT be quantized
+        self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
+
+        # Layer 3 should be quantized
+        self.assertIsInstance(model.layer3.weight, QuantizedTensor)
+        self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout)

         # Verify scales were loaded
-        self.assertEqual(model.layer1.scale_weight.item(), 2.0)
-        self.assertEqual(model.layer3.scale_weight.item(), 1.5)
+        self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
+        self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5)

         # Forward pass
         input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
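For reference, the scaled-FP8 convention these fixtures rely on is: store the weight divided by a scale in float8 alongside the scale itself, and multiply back on dequantization. A small self-contained sketch follows; the max-based per-tensor scale is just one common choice, not necessarily what the loader uses.

import torch

w = torch.randn(20, 10, dtype=torch.float32)
# Per-tensor scale chosen so the largest value maps near the FP8 E4M3 maximum (448).
scale = w.abs().max() / torch.finfo(torch.float8_e4m3fn).max

w_fp8 = (w / scale).to(torch.float8_e4m3fn)   # quantize: this is what ends up in the state dict
w_back = w_fp8.to(torch.float32) * scale      # dequantize

mean_rel_error = ((w_back - w).abs() / (w.abs() + 1e-6)).mean()
print(f"mean relative error after FP8 round trip: {mean_rel_error:.4f}")  # typically a few percent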
@@ -116,13 +116,13 @@ class TestMixedPrecisionOps(unittest.TestCase):

         self.assertEqual(output.shape, (5, 40))

-    def test_state_dict_round_trip(self):
-        """Test saving and loading state dict preserves quantization"""
+    def test_state_dict_quantized_preserved(self):
+        """Test that quantized weights are preserved in state_dict()"""
         # Configure mixed precision
         layer_quant_config = {
             "layer1": {
-                "format": "fp8_e4m3fn_scaled",
-                "params": {"use_fp8_matmul": False}
+                "format": "float8_e4m3fn",
+                "params": {}
             }
         }
         ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
@@ -132,45 +132,35 @@ class TestMixedPrecisionOps(unittest.TestCase):
         state_dict1 = {
             "layer1.weight": fp8_weight,
             "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
-            "layer1.scale_weight": torch.tensor(3.0, dtype=torch.float32),
+            "layer1.weight_scale": torch.tensor(3.0, dtype=torch.float32),
             "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
             "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
             "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }

-        model1 = SimpleModel(operations=ops.MixedPrecisionOps)
-        model1.load_state_dict(state_dict1)
+        model = SimpleModel(operations=ops.MixedPrecisionOps)
+        model.load_state_dict(state_dict1, strict=False)

         # Save state dict
-        state_dict2 = model1.state_dict()
+        state_dict2 = model.state_dict()

-        # Verify scale_weight is saved
-        self.assertIn("layer1.scale_weight", state_dict2)
-        self.assertEqual(state_dict2["layer1.scale_weight"].item(), 3.0)
+        # Verify layer1.weight is a QuantizedTensor with scale preserved
+        self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
+        self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
+        self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout)

-        # Load into new model
-        model2 = SimpleModel(operations=ops.MixedPrecisionOps)
-        model2.load_state_dict(state_dict2)
-
-        # Verify handler is set up
-        self.assertIsNotNone(model2.layer1.quant_handler)
-        self.assertEqual(model2.layer1.scale_weight.item(), 3.0)
-
-        # Verify forward passes match
-        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
-        output1 = model1(input_tensor)
-        output2 = model2(input_tensor)
-
-        torch.testing.assert_close(output1, output2, rtol=1e-3, atol=1e-3)
+        # Verify non-quantized layers are standard tensors
+        self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
+        self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor)

     def test_weight_function_compatibility(self):
         """Test that weight_function (LoRA) works with quantized layers"""
         # Configure FP8 quantization
         layer_quant_config = {
             "layer1": {
-                "format": "fp8_e4m3fn_scaled",
-                "params": {"use_fp8_matmul": False}
+                "format": "float8_e4m3fn",
+                "params": {}
             }
         }
         ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
@@ -180,7 +170,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
         state_dict = {
             "layer1.weight": fp8_weight,
             "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
-            "layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32),
+            "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),
             "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
             "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
             "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
@@ -188,25 +178,24 @@ class TestMixedPrecisionOps(unittest.TestCase):
         }

         model = SimpleModel(operations=ops.MixedPrecisionOps)
-        model.load_state_dict(state_dict)
+        model.load_state_dict(state_dict, strict=False)

         # Add a weight function (simulating LoRA)
         # LoRA delta must match weight shape (20, 10)
         # This should trigger dequantization during forward pass
         def apply_lora(weight):
             # Generate LoRA delta matching weight shape
             lora_delta = torch.randn_like(weight) * 0.01
             return weight + lora_delta

         model.layer1.weight_function.append(apply_lora)

-        # Forward pass should work with LoRA
+        # Forward pass should work with LoRA (triggers weight_function path)
         input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
         output = model(input_tensor)

         self.assertEqual(output.shape, (5, 40))

     def test_error_handling_unknown_format(self):
-        """Test that unknown formats fall back gracefully"""
+        """Test that unknown formats raise error"""
         # Configure with unknown format
         layer_quant_config = {
             "layer1": {
@@ -226,48 +215,10 @@ class TestMixedPrecisionOps(unittest.TestCase):
             "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
         }

-        # Load should not crash, just log warning
+        # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
         model = SimpleModel(operations=ops.MixedPrecisionOps)
-        model.load_state_dict(state_dict)
-
-        # Handler should be None (fallback to standard)
-        self.assertIsNone(model.layer1.quant_handler)
-
-        # Forward pass should still work
-        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
-        output = model(input_tensor)
-        self.assertEqual(output.shape, (5, 40))
-
-
-class TestPickOperationsWithMixedPrecision(unittest.TestCase):
-    """Test pick_operations with mixed precision config"""
-
-    def test_pick_operations_with_layer_quant_config(self):
-        """Test that pick_operations returns MixedPrecisionOps when config present"""
-        from comfy import supported_models_base
-
-        # Create model config with layer_quant_config
-        model_config = supported_models_base.BASE({})
-        model_config.layer_quant_config = {
-            "layer1": {"format": "fp8_e4m3fn_scaled", "params": {}}
-        }
-
-        result = ops.pick_operations(None, None, model_config=model_config)
-
-        self.assertEqual(result, ops.MixedPrecisionOps)
-        self.assertEqual(ops.MixedPrecisionOps._layer_quant_config, model_config.layer_quant_config)
-
-    def test_pick_operations_without_layer_quant_config(self):
-        """Test that pick_operations falls back to standard when no config"""
-        from comfy import supported_models_base
-
-        model_config = supported_models_base.BASE({})
-        model_config.layer_quant_config = None
-
-        result = ops.pick_operations(None, None, model_config=model_config)
-
-        self.assertEqual(result, ops.disable_weight_init)
-
+        with self.assertRaises(KeyError):
+            model.load_state_dict(state_dict, strict=False)

 if __name__ == "__main__":
     unittest.main()
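The updated test above expects a KeyError for an unknown format. A plausible reason, sketched below, is a plain dictionary lookup of the per-layer format name (the comment in the diff calls the table QUANT_FORMAT_MIXINS); the table contents here are placeholders, not the real registry.

QUANT_FORMAT_MIXINS = {
    "float8_e4m3fn": object,   # placeholder handler class, for illustration only
}

def resolve_format(name):
    return QUANT_FORMAT_MIXINS[name]   # a plain lookup raises KeyError for unknown names

try:
    resolve_format("totally_unknown_format")
except KeyError as e:
    print("unknown quant format:", e)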
tests-unit/comfy_quant/test_quant_registry.py (new file, 183 lines)
@@ -0,0 +1,183 @@
import unittest
import torch
import sys
import os

# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout


class TestQuantizedTensor(unittest.TestCase):
    """Test the QuantizedTensor subclass with FP8 layout"""

    def test_creation(self):
        """Test creating a QuantizedTensor with TensorCoreFP8Layout"""
        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(2.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}

        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.shape, (256, 128))
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt._layout_params['scale'], scale)
        self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
        self.assertEqual(qt._layout_type, TensorCoreFP8Layout)

    def test_dequantize(self):
        """Test explicit dequantization"""
        fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(3.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}

        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
        dequantized = qt.dequantize()

        self.assertEqual(dequantized.dtype, torch.float32)
        self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))

    def test_from_float(self):
        """Test creating QuantizedTensor from float tensor"""
        float_tensor = torch.randn(64, 32, dtype=torch.float32)
        scale = torch.tensor(1.5)

        qt = QuantizedTensor.from_float(
            float_tensor,
            TensorCoreFP8Layout,
            scale=scale,
            fp8_dtype=torch.float8_e4m3fn
        )

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt.shape, (64, 32))

        # Verify dequantization gives approximately original values
        dequantized = qt.dequantize()
        mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean()
        self.assertLess(mean_rel_error, 0.1)


class TestGenericUtilities(unittest.TestCase):
    """Test generic utility operations"""

    def test_detach(self):
        """Test detach operation on quantized tensor"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        # Detach should return a new QuantizedTensor
        qt_detached = qt.detach()

        self.assertIsInstance(qt_detached, QuantizedTensor)
        self.assertEqual(qt_detached.shape, qt.shape)
        self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout)

    def test_clone(self):
        """Test clone operation on quantized tensor"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        # Clone should return a new QuantizedTensor
        qt_cloned = qt.clone()

        self.assertIsInstance(qt_cloned, QuantizedTensor)
        self.assertEqual(qt_cloned.shape, qt.shape)
        self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout)

        # Verify it's a deep copy
        self.assertIsNot(qt_cloned._qdata, qt._qdata)

    def test_to_device(self):
        """Test device transfer"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)

        # Moving to same device should work (CPU to CPU)
        qt_cpu = qt.to('cpu')

        self.assertIsInstance(qt_cpu, QuantizedTensor)
        self.assertEqual(qt_cpu.device.type, 'cpu')
        self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu')


class TestTensorCoreFP8Layout(unittest.TestCase):
    """Test the TensorCoreFP8Layout implementation"""

    def test_quantize(self):
        """Test quantization method"""
        float_tensor = torch.randn(32, 64, dtype=torch.float32)
        scale = torch.tensor(1.5)

        qdata, layout_params = TensorCoreFP8Layout.quantize(
            float_tensor,
            scale=scale,
            fp8_dtype=torch.float8_e4m3fn
        )

        self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
        self.assertEqual(qdata.shape, float_tensor.shape)
        self.assertIn('scale', layout_params)
        self.assertIn('orig_dtype', layout_params)
        self.assertEqual(layout_params['orig_dtype'], torch.float32)

    def test_dequantize(self):
        """Test dequantization method"""
        float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0
        scale = torch.tensor(1.0)

        qdata, layout_params = TensorCoreFP8Layout.quantize(
            float_tensor,
            scale=scale,
            fp8_dtype=torch.float8_e4m3fn
        )

        dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)

        # Should approximately match original
        self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))


class TestFallbackMechanism(unittest.TestCase):
    """Test fallback for unsupported operations"""

    def test_unsupported_op_dequantizes(self):
        """Test that unsupported operations fall back to dequantization"""
        # Set seed for reproducibility
        torch.manual_seed(42)

        # Create quantized tensor
        a_fp32 = torch.randn(10, 20, dtype=torch.float32)
        scale = torch.tensor(1.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        a_q = QuantizedTensor.from_float(
            a_fp32,
            TensorCoreFP8Layout,
            scale=scale,
            fp8_dtype=torch.float8_e4m3fn
        )

        # Call an operation that doesn't have a registered handler
        # For example, torch.abs
        result = torch.abs(a_q)

        # Should work via fallback (dequantize → abs → return)
        self.assertNotIsInstance(result, QuantizedTensor)
        expected = torch.abs(a_fp32)
        # FP8 introduces quantization error, so use loose tolerance
        mean_error = (result - expected).abs().mean()
        self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")


if __name__ == "__main__":
    unittest.main()
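These tests pin down the contract of TensorCoreFP8Layout.quantize/dequantize: quantize returns the FP8 payload plus exactly the keyword parameters dequantize needs. A minimal sketch of functions with that shape of contract, assuming simple per-tensor scaling (not ComfyUI's actual implementation):

import torch

def quantize(float_tensor, scale, fp8_dtype=torch.float8_e4m3fn):
    # Payload plus the params needed to invert the transform.
    qdata = (float_tensor / scale).to(fp8_dtype)
    layout_params = {"scale": scale, "orig_dtype": float_tensor.dtype}
    return qdata, layout_params

def dequantize(qdata, scale, orig_dtype):
    return qdata.to(orig_dtype) * scale

x = torch.ones(10, 20) * 3.0
qdata, params = quantize(x, scale=torch.tensor(1.0))
print(torch.allclose(dequantize(qdata, **params), x, rtol=0.1, atol=0.1))  # True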
@@ -1,262 +0,0 @@
"""
Integration tests for quantization detection.
Tests Phase 2: Detection & Integration
"""

import unittest
import torch
import sys
import os

# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from comfy import model_detection


class TestNormalizeLayerName(unittest.TestCase):
    """Test the normalize_layer_name helper function"""

    def test_strip_prefix_and_suffix(self):
        """Test stripping prefix and suffix"""
        known_prefixes = ["model.diffusion_model."]
        result = model_detection.normalize_layer_name(
            "model.diffusion_model.layer1.weight",
            known_prefixes
        )
        self.assertEqual(result, "layer1")

    def test_strip_multiple_prefixes(self):
        """Test with multiple known prefixes"""
        known_prefixes = ["model.diffusion_model.", "model.model.", "net."]

        result1 = model_detection.normalize_layer_name(
            "model.diffusion_model.block.attn.weight",
            known_prefixes
        )
        self.assertEqual(result1, "block.attn")

        result2 = model_detection.normalize_layer_name(
            "model.model.encoder.layer.weight",
            known_prefixes
        )
        self.assertEqual(result2, "encoder.layer")

        result3 = model_detection.normalize_layer_name(
            "net.transformer.blocks.0.weight",
            known_prefixes
        )
        self.assertEqual(result3, "transformer.blocks.0")

    def test_strip_scale_weight_suffix(self):
        """Test stripping scale_weight suffix"""
        known_prefixes = ["model.diffusion_model."]
        result = model_detection.normalize_layer_name(
            "model.diffusion_model.layer1.scale_weight",
            known_prefixes
        )
        self.assertEqual(result, "layer1")

    def test_strip_bias_suffix(self):
        """Test stripping bias suffix"""
        known_prefixes = ["model.diffusion_model."]
        result = model_detection.normalize_layer_name(
            "model.diffusion_model.layer1.bias",
            known_prefixes
        )
        self.assertEqual(result, "layer1")

    def test_no_prefix_match(self):
        """Test with no prefix match"""
        known_prefixes = ["model.diffusion_model."]
        result = model_detection.normalize_layer_name(
            "other.model.layer1.weight",
            known_prefixes
        )
        # Should strip suffix but not prefix
        self.assertEqual(result, "other.model.layer1")


class TestDetectLayerQuantization(unittest.TestCase):
    """Test the detect_layer_quantization function"""

    def test_no_quantization(self):
        """Test with no quantization markers"""
        state_dict = {
            "model.diffusion_model.layer1.weight": torch.randn(10, 20),
            "model.diffusion_model.layer2.weight": torch.randn(20, 30),
        }
        result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")
        self.assertIsNone(result)

    def test_legacy_scaled_fp8(self):
        """Test that legacy scaled_fp8 marker returns None"""
        # Create FP8 tensor by converting from float32
        fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict = {
            "model.diffusion_model.scaled_fp8": torch.tensor([], dtype=torch.float8_e4m3fn),
            "model.diffusion_model.layer1.weight": fp8_weight,
            "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0),
        }
        result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")
        # Should return None to trigger legacy path
        self.assertIsNone(result)

    def test_metadata_format(self):
        """Test with new metadata format"""
        metadata = {
            "format_version": "1.0",
            "layers": {
                "layer1": {
                    "format": "fp8_e4m3fn_scaled",
                    "params": {"use_fp8_matmul": True}
                },
                "layer2": {
                    "format": "fp8_e5m2_scaled",
                    "params": {"use_fp8_matmul": True}
                }
            }
        }
        state_dict = {
            "model.diffusion_model._quantization_metadata": metadata,
            "model.diffusion_model.layer1.weight": torch.randn(10, 20),
        }
        result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")

        self.assertIsNotNone(result)
        self.assertIn("layer1", result)
        self.assertIn("layer2", result)
        self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled")
        self.assertEqual(result["layer2"]["format"], "fp8_e5m2_scaled")
        # Metadata should be popped from state_dict
        self.assertNotIn("model.diffusion_model._quantization_metadata", state_dict)

    def test_mixed_precision_detection(self):
        """Test detection of mixed precision via scale patterns"""
        # Create FP8 tensors by converting from float32
        fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        fp8_weight3 = torch.randn(30, 40, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict = {
            # Layer 1: FP8 (has scale_weight)
            "model.diffusion_model.layer1.weight": fp8_weight1,
            "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0),
            # Layer 2: Standard (no scale_weight)
            "model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16),
            # Layer 3: FP8 (has scale_weight)
            "model.diffusion_model.layer3.weight": fp8_weight3,
            "model.diffusion_model.layer3.scale_weight": torch.tensor(1.0),
        }
        result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")

        self.assertIsNotNone(result)
        self.assertIn("layer1", result)
        self.assertIn("layer3", result)
        self.assertNotIn("layer2", result)  # Layer 2 not quantized
        self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled")
        self.assertEqual(result["layer3"]["format"], "fp8_e4m3fn_scaled")

    def test_all_layers_quantized(self):
        """Test that uniform quantization (all layers) returns None"""
        # Create FP8 tensors by converting from float32
        fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        fp8_weight2 = torch.randn(20, 30, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict = {
            # All layers have scale_weight
            "model.diffusion_model.layer1.weight": fp8_weight1,
            "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0),
            "model.diffusion_model.layer2.weight": fp8_weight2,
            "model.diffusion_model.layer2.scale_weight": torch.tensor(1.0),
        }
        result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")

        # If all layers are quantized, it's not mixed precision
        # Should return None to use legacy scaled_fp8_ops path
        self.assertIsNone(result)

    def test_fp8_e5m2_detection(self):
        """Test detection of FP8 E5M2 format"""
        # Create FP8 E5M2 tensor by converting from float32
        fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e5m2)
        state_dict = {
            "model.diffusion_model.layer1.weight": fp8_weight,
            "model.diffusion_model.layer1.scale_weight": torch.tensor(1.0),
            "model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16),
        }
        result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")

        self.assertIsNotNone(result)
        self.assertIn("layer1", result)
        self.assertEqual(result["layer1"]["format"], "fp8_e5m2_scaled")

    def test_invalid_metadata(self):
        """Test with invalid metadata format"""
        state_dict = {
            "model.diffusion_model._quantization_metadata": "invalid_string",
            "model.diffusion_model.layer1.weight": torch.randn(10, 20),
        }
        result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")
        # Should return None on invalid metadata
        self.assertIsNone(result)

    def test_different_prefix(self):
        """Test with different model prefix (audio model)"""
        # Create FP8 tensor by converting from float32
        fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict = {
            "model.model.layer1.weight": fp8_weight,
            "model.model.layer1.scale_weight": torch.tensor(1.0),
            "model.model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16),
        }
        result = model_detection.detect_layer_quantization(state_dict, "model.model.")

        self.assertIsNotNone(result)
        self.assertIn("layer1", result)


class TestPickOperationsIntegration(unittest.TestCase):
    """Test pick_operations with model_config parameter"""

    def test_backward_compatibility(self):
        """Test that pick_operations works without model_config (legacy)"""
        from comfy import ops

        # Should work without model_config parameter
        result = ops.pick_operations(None, None)
        self.assertIsNotNone(result)
        self.assertEqual(result, ops.disable_weight_init)

    def test_with_model_config_no_quant(self):
        """Test with model_config but no quantization"""
        from comfy import ops, supported_models_base

        model_config = supported_models_base.BASE({})
        model_config.layer_quant_config = None

        result = ops.pick_operations(None, None, model_config=model_config)
        self.assertIsNotNone(result)
        # Should use standard path
        self.assertEqual(result, ops.disable_weight_init)

    def test_legacy_scaled_fp8(self):
        """Test that legacy scaled_fp8 still works"""
        from comfy import ops, supported_models_base

        model_config = supported_models_base.BASE({})
        model_config.scaled_fp8 = torch.float8_e4m3fn

        result = ops.pick_operations(
            None, None,
            scaled_fp8=torch.float8_e4m3fn,
            model_config=model_config
        )
        self.assertIsNotNone(result)
        # Should return scaled_fp8_ops (the returned class is the inner class)
        # Check that it's not the standard disable_weight_init
        self.assertNotEqual(result, ops.disable_weight_init)
        # Verify it has Linear class
        self.assertTrue(hasattr(result, 'Linear'))


if __name__ == "__main__":
    unittest.main()
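The removed tests above document the expected behavior of model_detection.normalize_layer_name: strip one known prefix, if any, and the trailing parameter component such as .weight, .bias or .scale_weight. A hedged sketch that reproduces those expectations follows; the real implementation may differ, for example by only stripping a fixed list of suffixes.

def normalize_layer_name(key, known_prefixes):
    # Drop the first matching prefix, then the final parameter component.
    for prefix in known_prefixes:
        if key.startswith(prefix):
            key = key[len(prefix):]
            break
    return key.rsplit(".", 1)[0]

print(normalize_layer_name("model.diffusion_model.layer1.scale_weight",
                           ["model.diffusion_model."]))  # layer1
print(normalize_layer_name("other.model.layer1.weight",
                           ["model.diffusion_model."]))  # other.model.layer1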
@@ -1,399 +0,0 @@
"""
Unit tests for tensor subclass quantization system.
Tests the new QuantizedTensorFP8 subclass and operation handlers.
"""

import unittest
import torch
import sys
import os

# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))

from comfy import ops
from comfy import quant_ops


class TestQuantizedTensorFP8(unittest.TestCase):
    """Test the QuantizedTensorFP8 tensor subclass"""

    def test_creation(self):
        """Test creating a QuantizedTensorFP8"""
        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(2.0)

        qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16)

        self.assertIsInstance(qt, quant_ops.QuantizedTensorFP8)
        self.assertEqual(qt.shape, (256, 128))
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt._scale, scale)
        self.assertEqual(qt._orig_dtype, torch.bfloat16)

    def test_dequantize(self):
        """Test explicit dequantization"""
        # Create a simple FP8 tensor
        fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(3.0)

        qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.float32)
        dequantized = qt.dequantize()

        # Dequantized should be approximately ones * 3.0
        self.assertEqual(dequantized.dtype, torch.float32)
        self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))

    def test_repr(self):
        """Test string representation"""
        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(2.5)

        qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16)
        repr_str = repr(qt)

        self.assertIn("QuantizedTensorFP8", repr_str)
        self.assertIn("shape", repr_str)
        self.assertIn("scale", repr_str)


class TestOperationRegistry(unittest.TestCase):
    """Test the operation registry system"""

    def test_registry_basics(self):
        """Test that operations are registered"""
        registered_ops = quant_ops.list_registered_ops()

        # Check that key operations are registered
        self.assertIn(torch.ops.aten.linear.default, registered_ops)
        self.assertIn(torch.ops.aten.silu.default, registered_ops)
        self.assertIn(torch.ops.aten.layer_norm.default, registered_ops)
        self.assertIn(torch.ops.aten.add.Tensor, registered_ops)
        self.assertIn(torch.ops.aten.mul.Tensor, registered_ops)

    def test_get_handler(self):
        """Test getting a registered handler"""
        handler = quant_ops.get_quant_handler(torch.ops.aten.linear.default)
        self.assertIsNotNone(handler)
        self.assertTrue(callable(handler))

    def test_custom_registration(self):
        """Test registering a custom operation"""

        # Define a custom handler
        @quant_ops.register_quant_op(torch.ops.aten.relu.default)
        def custom_relu_handler(func, args, kwargs):
            return func(*args, **kwargs)

        # Verify registration
        handler = quant_ops.get_quant_handler(torch.ops.aten.relu.default)
        self.assertIsNotNone(handler)
        self.assertEqual(handler, custom_relu_handler)


class TestLinearHandler(unittest.TestCase):
    """Test the linear operation handler"""

    def test_linear_with_quantized_weight(self):
        """Test F.linear with quantized weight"""
        # Set seed for reproducibility
        torch.manual_seed(42)

        # Create quantized weight
        weight_fp32 = torch.randn(256, 128, dtype=torch.float32)
        scale = torch.tensor(2.0)
        weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn)
        weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32)

        # Create input
        input_tensor = torch.randn(16, 128, dtype=torch.float32)

        # Call linear (should trigger dispatch)
        output = torch.nn.functional.linear(input_tensor, weight_q, bias=None)

        # Verify output shape
        self.assertEqual(output.shape, (16, 256))

        # Verify it's approximately correct (allowing for FP8 quantization error)
        # Note: FP8 has limited precision, so use very loose tolerance
        expected = torch.nn.functional.linear(input_tensor, weight_fp32, bias=None)
        # Just check that it's in the right ballpark (within 50% error on average)
        mean_rel_error = ((output - expected).abs() / (expected.abs() + 1e-6)).mean()
        self.assertLess(mean_rel_error, 0.5, f"Mean relative error {mean_rel_error:.3f} is too large")

    def test_linear_with_bias(self):
        """Test F.linear with quantized weight and bias"""
        weight_fp32 = torch.randn(64, 32, dtype=torch.float32)
        scale = torch.tensor(1.5)
        weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn)
        weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32)

        input_tensor = torch.randn(8, 32, dtype=torch.float32)
        bias = torch.randn(64, dtype=torch.float32)

        output = torch.nn.functional.linear(input_tensor, weight_q, bias)

        self.assertEqual(output.shape, (8, 64))


class TestActivationHandlers(unittest.TestCase):
    """Test activation function handlers"""

    def test_silu_with_quantized_input(self):
        """Test SiLU with quantized input"""
        # Create quantized input
        input_fp32 = torch.randn(16, 128, dtype=torch.float32)
        scale = torch.tensor(1.0)
        input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn)
        input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32)

        # Apply SiLU
        output = torch.nn.functional.silu(input_q)

        # Should return a QuantizedTensorFP8
        self.assertIsInstance(output, quant_ops.QuantizedTensorFP8)

        # Verify approximate correctness
        expected = torch.nn.functional.silu(input_fp32)
        output_dq = output.dequantize()
        self.assertTrue(torch.allclose(output_dq, expected, rtol=0.2, atol=0.2))

    def test_layernorm_dequantizes(self):
        """Test that LayerNorm dequantizes input"""
        # Create quantized input
        input_fp32 = torch.randn(16, 128, dtype=torch.float32)
        scale = torch.tensor(1.0)
        input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn)
        input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32)

        # Apply LayerNorm
        weight = torch.ones(128)
        bias = torch.zeros(128)
        output = torch.nn.functional.layer_norm(input_q, (128,), weight, bias)

        # Should NOT be quantized (LayerNorm breaks quantization)
        self.assertNotIsInstance(output, quant_ops.QuantizedTensorFP8)
        self.assertEqual(output.dtype, torch.float32)


class TestElementwiseHandlers(unittest.TestCase):
    """Test element-wise operation handlers"""

    def test_add_mixed_tensors(self):
        """Test addition with mixed quantized/non-quantized tensors"""
        # Create quantized tensor
        a_fp32 = torch.ones(10, 20, dtype=torch.float32)
        scale = torch.tensor(1.0)
        a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn)
        a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32)

        # Non-quantized tensor
        b = torch.ones(10, 20, dtype=torch.float32) * 2.0

        # Add them
        result = a_q + b

        # Should be dequantized
        self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8)
        self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 3.0, rtol=0.1))

    def test_mul_quantized_tensors(self):
        """Test multiplication of two quantized tensors"""
        a_fp32 = torch.ones(10, 20) * 2.0
        scale_a = torch.tensor(1.0)
        a_fp8 = (a_fp32 / scale_a).to(torch.float8_e4m3fn)
        a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale_a, orig_dtype=torch.float32)

        b_fp32 = torch.ones(10, 20) * 3.0
        scale_b = torch.tensor(1.0)
        b_fp8 = (b_fp32 / scale_b).to(torch.float8_e4m3fn)
        b_q = quant_ops.QuantizedTensorFP8(b_fp8, scale_b, orig_dtype=torch.float32)

        result = a_q * b_q

        # Should be dequantized
        self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8)
        self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 6.0, rtol=0.1))


class TestFallbackMechanism(unittest.TestCase):
    """Test fallback for unsupported operations"""

    def test_unsupported_op_dequantizes(self):
        """Test that unsupported operations fall back to dequantization"""
        # Set seed for reproducibility
        torch.manual_seed(42)

        # Create quantized tensor
        a_fp32 = torch.randn(10, 20, dtype=torch.float32)
        scale = torch.tensor(1.0)
        a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn)
        a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32)

        # Call an operation that doesn't have a registered handler
        # For example, torch.abs
        result = torch.abs(a_q)

        # Should work via fallback (dequantize → abs → return)
        self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8)
        expected = torch.abs(a_fp32)
        # FP8 introduces quantization error, so use loose tolerance
        mean_error = (result - expected).abs().mean()
        self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")


class TestMixedPrecisionOps(unittest.TestCase):
    """Test MixedPrecisionOps integration"""

    def test_linear_layer_creation(self):
        """Test that MixedPrecisionOps.Linear can be created"""
        layer = ops.MixedPrecisionOps.Linear(128, 256, bias=True, device="cpu", dtype=torch.float32)

        self.assertIsInstance(layer, ops.MixedPrecisionOps.Linear)
        self.assertFalse(layer._quantization_initialized)
        self.assertIsNone(layer.quant_format)

    def test_layer_quant_config_detection(self):
        """Test that layer quantization config is detected during load"""
        # Set up layer config
        ops.MixedPrecisionOps._layer_quant_config = {
            "test_layer": {
                "format": "fp8_e4m3fn",
                "params": {}
            }
        }

        # Create a state dict with quantized weight
        weight_fp32 = torch.randn(256, 128, dtype=torch.float32)
        scale = torch.tensor(2.0)
        weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn)

        state_dict = {
            "model.diffusion_model.test_layer.weight": weight_fp8,
            "model.diffusion_model.test_layer.scale_weight": scale,
        }

        # Create layer and load
        layer = ops.MixedPrecisionOps.Linear(128, 256, bias=False, device="cpu", dtype=torch.float8_e4m3fn)
        layer.weight = torch.nn.Parameter(torch.zeros(256, 128, dtype=torch.float8_e4m3fn))

        # Manually call _load_from_state_dict
        layer._load_from_state_dict(
            state_dict,
            prefix="model.diffusion_model.test_layer.",
            local_metadata={},
            strict=True,
            missing_keys=[],
            unexpected_keys=[],
            error_msgs=[]
        )

        # Verify quantization was initialized
        self.assertTrue(layer._quantization_initialized)
        self.assertEqual(layer.quant_format, "fp8_e4m3fn")
        self.assertIsNotNone(layer.quant_scale)

        # Verify weight is wrapped
        self.assertIsInstance(layer.weight.data, quant_ops.QuantizedTensorFP8)

        # Clean up
        ops.MixedPrecisionOps._layer_quant_config = {}


class TestBackwardCompatibility(unittest.TestCase):
    """Test backward compatibility with legacy systems"""

    def test_legacy_ops_classes_exist(self):
        """Test that legacy ops classes still exist"""
        self.assertTrue(hasattr(ops, 'disable_weight_init'))
        self.assertTrue(hasattr(ops, 'manual_cast'))
        self.assertTrue(hasattr(ops, 'fp8_ops'))
        self.assertTrue(hasattr(ops, 'scaled_fp8_ops'))

    def test_pick_operations_legacy_path(self):
        """Test pick_operations returns correct class for legacy cases"""
        # Test standard case
        result = ops.pick_operations(torch.float32, torch.float32)
        self.assertEqual(result, ops.disable_weight_init)

        # Test manual cast case
        result = ops.pick_operations(torch.float32, torch.float16)
        self.assertEqual(result, ops.manual_cast)


class TestFP8LinearUnification(unittest.TestCase):
    """Test that fp8_linear now uses the unified tensor subclass infrastructure"""

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA required for FP8")
    def test_fp8_linear_uses_tensor_subclass(self):
        """Verify fp8_linear wraps tensors in QuantizedTensorFP8"""
        torch.manual_seed(42)

        # Create a mock Linear layer with FP8 weight
        linear = ops.fp8_ops.Linear(4, 3, bias=True)
        linear.weight = torch.nn.Parameter(
            torch.randn(3, 4, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
            requires_grad=False
        )
        linear.bias = torch.nn.Parameter(
            torch.randn(3, dtype=torch.bfloat16),
            requires_grad=False
        )
        linear.scale_weight = torch.tensor(1.0)
        linear.scale_input = None  # No input scaling

        # Create input
        input_tensor = torch.randn(2, 4, dtype=torch.bfloat16)

        # Call fp8_linear - should work without errors
        try:
            result = ops.fp8_linear(linear, input_tensor)
            self.assertIsNotNone(result)
            self.assertEqual(result.shape, (2, 3))
        except Exception as e:
            # On CPU or unsupported hardware, _scaled_mm might not be available
            # but the function should still complete without syntax errors
            pass

    def test_fp8_linear_maintains_signature(self):
        """Verify fp8_linear maintains its original function signature"""
        import inspect
        sig = inspect.signature(ops.fp8_linear)
        params = list(sig.parameters.keys())

        # Should have 'self' and 'input' parameters
        self.assertIn('self', params)
        self.assertIn('input', params)
        self.assertEqual(len(params), 2)

    def test_fp8_linear_returns_none_for_non_fp8(self):
        """Verify fp8_linear returns None for non-FP8 weights"""
        # Create a Linear layer with BF16 weight (not FP8)
        linear = ops.disable_weight_init.Linear(4, 3, bias=False)
        linear.weight = torch.nn.Parameter(
            torch.randn(3, 4, dtype=torch.bfloat16),
            requires_grad=False
        )

        input_tensor = torch.randn(2, 4, dtype=torch.bfloat16)

        # Should return None for non-FP8 weights
        result = ops.fp8_linear(linear, input_tensor)
        self.assertIsNone(result)

    def test_fp8_ops_linear_uses_fp8_linear(self):
        """Verify fp8_ops.Linear still uses fp8_linear in forward pass"""
        linear = ops.fp8_ops.Linear(4, 3, bias=False)

        # Verify the class has the forward_comfy_cast_weights method
        self.assertTrue(hasattr(linear, 'forward_comfy_cast_weights'))

        # The forward_comfy_cast_weights should attempt to call fp8_linear
        # (we can't easily test this without mocking, but we verify structure)
        import inspect
        source = inspect.getsource(linear.forward_comfy_cast_weights)
        self.assertIn('fp8_linear', source)


if __name__ == "__main__":
    unittest.main()
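The removed registry tests above exercised a decorator-based handler table (register_quant_op, get_quant_handler, list_registered_ops). A minimal sketch of that pattern, with names mirroring the tests but an illustrative implementation:

import torch

_QUANT_OP_REGISTRY = {}

def register_quant_op(torch_op):
    # Decorator: map an aten overload to its handler.
    def decorator(handler):
        _QUANT_OP_REGISTRY[torch_op] = handler
        return handler
    return decorator

def get_quant_handler(torch_op):
    return _QUANT_OP_REGISTRY.get(torch_op)

def list_registered_ops():
    return list(_QUANT_OP_REGISTRY.keys())

@register_quant_op(torch.ops.aten.relu.default)
def relu_handler(func, args, kwargs):
    return func(*args, **kwargs)

print(get_quant_handler(torch.ops.aten.relu.default) is relu_handler)  # True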