Update unit tests

This commit is contained in:
lspindler 2025-10-27 08:41:23 +01:00
parent 0d201540ca
commit 77d307049f
6 changed files with 235 additions and 762 deletions

View File

@ -566,7 +566,7 @@ class MixedPrecisionOps(disable_weight_init):
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
self._forward(input, weight, bias)
return self._forward(input, weight, bias)
def forward(self, input, *args, **kwargs):
run_every_op()

View File

@ -143,7 +143,7 @@ class QuantizedTensor(torch.Tensor):
"""
Tensor flattening protocol for proper device movement.
"""
inner_tensors = ["_q_data"]
inner_tensors = ["_qdata"]
ctx = {
"layout_type": self._layout_type,
}
@ -206,7 +206,7 @@ class QuantizedTensor(torch.Tensor):
# Step 3: Fallback to dequantization
if isinstance(args[0] if args else None, QuantizedTensor):
logging.warning(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
return cls._dequant_and_fallback(func, args, kwargs)
@classmethod

View File

@ -1,8 +1,3 @@
"""
End-to-end tests for mixed precision quantization.
Tests Phase 3: Mixed Precision Operations
"""
import unittest
import torch
import sys
@ -12,10 +7,10 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from comfy import ops
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
class SimpleModel(torch.nn.Module):
"""Simple model for testing mixed precision"""
def __init__(self, operations=ops.disable_weight_init):
super().__init__()
self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16)
@ -32,8 +27,7 @@ class SimpleModel(torch.nn.Module):
class TestMixedPrecisionOps(unittest.TestCase):
"""Test MixedPrecisionOps end-to-end"""
def test_all_layers_standard(self):
"""Test that model with no quantization works normally"""
# Configure no quantization
@ -67,48 +61,54 @@ class TestMixedPrecisionOps(unittest.TestCase):
# Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard
layer_quant_config = {
"layer1": {
"format": "fp8_e4m3fn_scaled",
"params": {"use_fp8_matmul": False} # Disable for CPU testing
"format": "float8_e4m3fn",
"params": {}
},
"layer3": {
"format": "fp8_e5m2_scaled",
"params": {"use_fp8_matmul": False}
"format": "float8_e4m3fn",
"params": {}
}
}
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
# Create state dict with mixed precision
fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e5m2)
fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn)
state_dict = {
# Layer 1: FP8 E4M3FN
"layer1.weight": fp8_weight1,
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
"layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32),
"layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),
# Layer 2: Standard BF16
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
# Layer 3: FP8 E5M2
# Layer 3: FP8 E4M3FN
"layer3.weight": fp8_weight3,
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
"layer3.scale_weight": torch.tensor(1.5, dtype=torch.float32),
"layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
}
# Create model and load state dict
# Create model and load state dict (strict=False because custom loading pops keys)
model = SimpleModel(operations=ops.MixedPrecisionOps)
model.load_state_dict(state_dict)
model.load_state_dict(state_dict, strict=False)
# Verify handlers are set up correctly
self.assertIsNotNone(model.layer1.quant_handler)
self.assertIsNone(model.layer2.quant_handler) # No quantization
self.assertIsNotNone(model.layer3.quant_handler)
# Verify weights are wrapped in QuantizedTensor
self.assertIsInstance(model.layer1.weight, QuantizedTensor)
self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout)
# Layer 2 should NOT be quantized
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
# Layer 3 should be quantized
self.assertIsInstance(model.layer3.weight, QuantizedTensor)
self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout)
# Verify scales were loaded
self.assertEqual(model.layer1.scale_weight.item(), 2.0)
self.assertEqual(model.layer3.scale_weight.item(), 1.5)
self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5)
# Forward pass
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
@ -116,13 +116,13 @@ class TestMixedPrecisionOps(unittest.TestCase):
self.assertEqual(output.shape, (5, 40))
def test_state_dict_round_trip(self):
"""Test saving and loading state dict preserves quantization"""
def test_state_dict_quantized_preserved(self):
"""Test that quantized weights are preserved in state_dict()"""
# Configure mixed precision
layer_quant_config = {
"layer1": {
"format": "fp8_e4m3fn_scaled",
"params": {"use_fp8_matmul": False}
"format": "float8_e4m3fn",
"params": {}
}
}
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
@ -132,45 +132,35 @@ class TestMixedPrecisionOps(unittest.TestCase):
state_dict1 = {
"layer1.weight": fp8_weight,
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
"layer1.scale_weight": torch.tensor(3.0, dtype=torch.float32),
"layer1.weight_scale": torch.tensor(3.0, dtype=torch.float32),
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
"layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
}
model1 = SimpleModel(operations=ops.MixedPrecisionOps)
model1.load_state_dict(state_dict1)
model = SimpleModel(operations=ops.MixedPrecisionOps)
model.load_state_dict(state_dict1, strict=False)
# Save state dict
state_dict2 = model1.state_dict()
state_dict2 = model.state_dict()
# Verify scale_weight is saved
self.assertIn("layer1.scale_weight", state_dict2)
self.assertEqual(state_dict2["layer1.scale_weight"].item(), 3.0)
# Verify layer1.weight is a QuantizedTensor with scale preserved
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout)
# Load into new model
model2 = SimpleModel(operations=ops.MixedPrecisionOps)
model2.load_state_dict(state_dict2)
# Verify handler is set up
self.assertIsNotNone(model2.layer1.quant_handler)
self.assertEqual(model2.layer1.scale_weight.item(), 3.0)
# Verify forward passes match
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
output1 = model1(input_tensor)
output2 = model2(input_tensor)
torch.testing.assert_close(output1, output2, rtol=1e-3, atol=1e-3)
# Verify non-quantized layers are standard tensors
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor)
def test_weight_function_compatibility(self):
"""Test that weight_function (LoRA) works with quantized layers"""
# Configure FP8 quantization
layer_quant_config = {
"layer1": {
"format": "fp8_e4m3fn_scaled",
"params": {"use_fp8_matmul": False}
"format": "float8_e4m3fn",
"params": {}
}
}
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
@ -180,7 +170,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
state_dict = {
"layer1.weight": fp8_weight,
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
"layer1.scale_weight": torch.tensor(2.0, dtype=torch.float32),
"layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
"layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
@ -188,25 +178,24 @@ class TestMixedPrecisionOps(unittest.TestCase):
}
model = SimpleModel(operations=ops.MixedPrecisionOps)
model.load_state_dict(state_dict)
model.load_state_dict(state_dict, strict=False)
# Add a weight function (simulating LoRA)
# LoRA delta must match weight shape (20, 10)
# This should trigger dequantization during forward pass
def apply_lora(weight):
# Generate LoRA delta matching weight shape
lora_delta = torch.randn_like(weight) * 0.01
return weight + lora_delta
model.layer1.weight_function.append(apply_lora)
# Forward pass should work with LoRA
# Forward pass should work with LoRA (triggers weight_function path)
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
output = model(input_tensor)
self.assertEqual(output.shape, (5, 40))
def test_error_handling_unknown_format(self):
"""Test that unknown formats fall back gracefully"""
"""Test that unknown formats raise error"""
# Configure with unknown format
layer_quant_config = {
"layer1": {
@ -226,48 +215,10 @@ class TestMixedPrecisionOps(unittest.TestCase):
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
}
# Load should not crash, just log warning
# Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
model = SimpleModel(operations=ops.MixedPrecisionOps)
model.load_state_dict(state_dict)
# Handler should be None (fallback to standard)
self.assertIsNone(model.layer1.quant_handler)
# Forward pass should still work
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
output = model(input_tensor)
self.assertEqual(output.shape, (5, 40))
class TestPickOperationsWithMixedPrecision(unittest.TestCase):
"""Test pick_operations with mixed precision config"""
def test_pick_operations_with_layer_quant_config(self):
"""Test that pick_operations returns MixedPrecisionOps when config present"""
from comfy import supported_models_base
# Create model config with layer_quant_config
model_config = supported_models_base.BASE({})
model_config.layer_quant_config = {
"layer1": {"format": "fp8_e4m3fn_scaled", "params": {}}
}
result = ops.pick_operations(None, None, model_config=model_config)
self.assertEqual(result, ops.MixedPrecisionOps)
self.assertEqual(ops.MixedPrecisionOps._layer_quant_config, model_config.layer_quant_config)
def test_pick_operations_without_layer_quant_config(self):
"""Test that pick_operations falls back to standard when no config"""
from comfy import supported_models_base
model_config = supported_models_base.BASE({})
model_config.layer_quant_config = None
result = ops.pick_operations(None, None, model_config=model_config)
self.assertEqual(result, ops.disable_weight_init)
with self.assertRaises(KeyError):
model.load_state_dict(state_dict, strict=False)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,183 @@
import unittest
import torch
import sys
import os
# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
class TestQuantizedTensor(unittest.TestCase):
"""Test the QuantizedTensor subclass with FP8 layout"""
def test_creation(self):
"""Test creating a QuantizedTensor with TensorCoreFP8Layout"""
fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(2.0)
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
self.assertIsInstance(qt, QuantizedTensor)
self.assertEqual(qt.shape, (256, 128))
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
self.assertEqual(qt._layout_params['scale'], scale)
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
self.assertEqual(qt._layout_type, TensorCoreFP8Layout)
def test_dequantize(self):
"""Test explicit dequantization"""
fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(3.0)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
dequantized = qt.dequantize()
self.assertEqual(dequantized.dtype, torch.float32)
self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))
def test_from_float(self):
"""Test creating QuantizedTensor from float tensor"""
float_tensor = torch.randn(64, 32, dtype=torch.float32)
scale = torch.tensor(1.5)
qt = QuantizedTensor.from_float(
float_tensor,
TensorCoreFP8Layout,
scale=scale,
fp8_dtype=torch.float8_e4m3fn
)
self.assertIsInstance(qt, QuantizedTensor)
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
self.assertEqual(qt.shape, (64, 32))
# Verify dequantization gives approximately original values
dequantized = qt.dequantize()
mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean()
self.assertLess(mean_rel_error, 0.1)
class TestGenericUtilities(unittest.TestCase):
"""Test generic utility operations"""
def test_detach(self):
"""Test detach operation on quantized tensor"""
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(1.5)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
# Detach should return a new QuantizedTensor
qt_detached = qt.detach()
self.assertIsInstance(qt_detached, QuantizedTensor)
self.assertEqual(qt_detached.shape, qt.shape)
self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout)
def test_clone(self):
"""Test clone operation on quantized tensor"""
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(1.5)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
# Clone should return a new QuantizedTensor
qt_cloned = qt.clone()
self.assertIsInstance(qt_cloned, QuantizedTensor)
self.assertEqual(qt_cloned.shape, qt.shape)
self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout)
# Verify it's a deep copy
self.assertIsNot(qt_cloned._qdata, qt._qdata)
def test_to_device(self):
"""Test device transfer"""
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(1.5)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
# Moving to same device should work (CPU to CPU)
qt_cpu = qt.to('cpu')
self.assertIsInstance(qt_cpu, QuantizedTensor)
self.assertEqual(qt_cpu.device.type, 'cpu')
self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu')
class TestTensorCoreFP8Layout(unittest.TestCase):
"""Test the TensorCoreFP8Layout implementation"""
def test_quantize(self):
"""Test quantization method"""
float_tensor = torch.randn(32, 64, dtype=torch.float32)
scale = torch.tensor(1.5)
qdata, layout_params = TensorCoreFP8Layout.quantize(
float_tensor,
scale=scale,
fp8_dtype=torch.float8_e4m3fn
)
self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
self.assertEqual(qdata.shape, float_tensor.shape)
self.assertIn('scale', layout_params)
self.assertIn('orig_dtype', layout_params)
self.assertEqual(layout_params['orig_dtype'], torch.float32)
def test_dequantize(self):
"""Test dequantization method"""
float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0
scale = torch.tensor(1.0)
qdata, layout_params = TensorCoreFP8Layout.quantize(
float_tensor,
scale=scale,
fp8_dtype=torch.float8_e4m3fn
)
dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)
# Should approximately match original
self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))
class TestFallbackMechanism(unittest.TestCase):
"""Test fallback for unsupported operations"""
def test_unsupported_op_dequantizes(self):
"""Test that unsupported operations fall back to dequantization"""
# Set seed for reproducibility
torch.manual_seed(42)
# Create quantized tensor
a_fp32 = torch.randn(10, 20, dtype=torch.float32)
scale = torch.tensor(1.0)
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
a_q = QuantizedTensor.from_float(
a_fp32,
TensorCoreFP8Layout,
scale=scale,
fp8_dtype=torch.float8_e4m3fn
)
# Call an operation that doesn't have a registered handler
# For example, torch.abs
result = torch.abs(a_q)
# Should work via fallback (dequantize → abs → return)
self.assertNotIsInstance(result, QuantizedTensor)
expected = torch.abs(a_fp32)
# FP8 introduces quantization error, so use loose tolerance
mean_error = (result - expected).abs().mean()
self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")
if __name__ == "__main__":
unittest.main()

View File

@ -1,262 +0,0 @@
"""
Integration tests for quantization detection.
Tests Phase 2: Detection & Integration
"""
import unittest
import torch
import sys
import os
# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from comfy import model_detection
class TestNormalizeLayerName(unittest.TestCase):
"""Test the normalize_layer_name helper function"""
def test_strip_prefix_and_suffix(self):
"""Test stripping prefix and suffix"""
known_prefixes = ["model.diffusion_model."]
result = model_detection.normalize_layer_name(
"model.diffusion_model.layer1.weight",
known_prefixes
)
self.assertEqual(result, "layer1")
def test_strip_multiple_prefixes(self):
"""Test with multiple known prefixes"""
known_prefixes = ["model.diffusion_model.", "model.model.", "net."]
result1 = model_detection.normalize_layer_name(
"model.diffusion_model.block.attn.weight",
known_prefixes
)
self.assertEqual(result1, "block.attn")
result2 = model_detection.normalize_layer_name(
"model.model.encoder.layer.weight",
known_prefixes
)
self.assertEqual(result2, "encoder.layer")
result3 = model_detection.normalize_layer_name(
"net.transformer.blocks.0.weight",
known_prefixes
)
self.assertEqual(result3, "transformer.blocks.0")
def test_strip_scale_weight_suffix(self):
"""Test stripping scale_weight suffix"""
known_prefixes = ["model.diffusion_model."]
result = model_detection.normalize_layer_name(
"model.diffusion_model.layer1.scale_weight",
known_prefixes
)
self.assertEqual(result, "layer1")
def test_strip_bias_suffix(self):
"""Test stripping bias suffix"""
known_prefixes = ["model.diffusion_model."]
result = model_detection.normalize_layer_name(
"model.diffusion_model.layer1.bias",
known_prefixes
)
self.assertEqual(result, "layer1")
def test_no_prefix_match(self):
"""Test with no prefix match"""
known_prefixes = ["model.diffusion_model."]
result = model_detection.normalize_layer_name(
"other.model.layer1.weight",
known_prefixes
)
# Should strip suffix but not prefix
self.assertEqual(result, "other.model.layer1")
class TestDetectLayerQuantization(unittest.TestCase):
"""Test the detect_layer_quantization function"""
def test_no_quantization(self):
"""Test with no quantization markers"""
state_dict = {
"model.diffusion_model.layer1.weight": torch.randn(10, 20),
"model.diffusion_model.layer2.weight": torch.randn(20, 30),
}
result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")
self.assertIsNone(result)
def test_legacy_scaled_fp8(self):
"""Test that legacy scaled_fp8 marker returns None"""
# Create FP8 tensor by converting from float32
fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
state_dict = {
"model.diffusion_model.scaled_fp8": torch.tensor([], dtype=torch.float8_e4m3fn),
"model.diffusion_model.layer1.weight": fp8_weight,
"model.diffusion_model.layer1.scale_weight": torch.tensor(1.0),
}
result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")
# Should return None to trigger legacy path
self.assertIsNone(result)
def test_metadata_format(self):
"""Test with new metadata format"""
metadata = {
"format_version": "1.0",
"layers": {
"layer1": {
"format": "fp8_e4m3fn_scaled",
"params": {"use_fp8_matmul": True}
},
"layer2": {
"format": "fp8_e5m2_scaled",
"params": {"use_fp8_matmul": True}
}
}
}
state_dict = {
"model.diffusion_model._quantization_metadata": metadata,
"model.diffusion_model.layer1.weight": torch.randn(10, 20),
}
result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")
self.assertIsNotNone(result)
self.assertIn("layer1", result)
self.assertIn("layer2", result)
self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled")
self.assertEqual(result["layer2"]["format"], "fp8_e5m2_scaled")
# Metadata should be popped from state_dict
self.assertNotIn("model.diffusion_model._quantization_metadata", state_dict)
def test_mixed_precision_detection(self):
"""Test detection of mixed precision via scale patterns"""
# Create FP8 tensors by converting from float32
fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
fp8_weight3 = torch.randn(30, 40, dtype=torch.float32).to(torch.float8_e4m3fn)
state_dict = {
# Layer 1: FP8 (has scale_weight)
"model.diffusion_model.layer1.weight": fp8_weight1,
"model.diffusion_model.layer1.scale_weight": torch.tensor(1.0),
# Layer 2: Standard (no scale_weight)
"model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16),
# Layer 3: FP8 (has scale_weight)
"model.diffusion_model.layer3.weight": fp8_weight3,
"model.diffusion_model.layer3.scale_weight": torch.tensor(1.0),
}
result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")
self.assertIsNotNone(result)
self.assertIn("layer1", result)
self.assertIn("layer3", result)
self.assertNotIn("layer2", result) # Layer 2 not quantized
self.assertEqual(result["layer1"]["format"], "fp8_e4m3fn_scaled")
self.assertEqual(result["layer3"]["format"], "fp8_e4m3fn_scaled")
def test_all_layers_quantized(self):
"""Test that uniform quantization (all layers) returns None"""
# Create FP8 tensors by converting from float32
fp8_weight1 = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
fp8_weight2 = torch.randn(20, 30, dtype=torch.float32).to(torch.float8_e4m3fn)
state_dict = {
# All layers have scale_weight
"model.diffusion_model.layer1.weight": fp8_weight1,
"model.diffusion_model.layer1.scale_weight": torch.tensor(1.0),
"model.diffusion_model.layer2.weight": fp8_weight2,
"model.diffusion_model.layer2.scale_weight": torch.tensor(1.0),
}
result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")
# If all layers are quantized, it's not mixed precision
# Should return None to use legacy scaled_fp8_ops path
self.assertIsNone(result)
def test_fp8_e5m2_detection(self):
"""Test detection of FP8 E5M2 format"""
# Create FP8 E5M2 tensor by converting from float32
fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e5m2)
state_dict = {
"model.diffusion_model.layer1.weight": fp8_weight,
"model.diffusion_model.layer1.scale_weight": torch.tensor(1.0),
"model.diffusion_model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16),
}
result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")
self.assertIsNotNone(result)
self.assertIn("layer1", result)
self.assertEqual(result["layer1"]["format"], "fp8_e5m2_scaled")
def test_invalid_metadata(self):
"""Test with invalid metadata format"""
state_dict = {
"model.diffusion_model._quantization_metadata": "invalid_string",
"model.diffusion_model.layer1.weight": torch.randn(10, 20),
}
result = model_detection.detect_layer_quantization(state_dict, "model.diffusion_model.")
# Should return None on invalid metadata
self.assertIsNone(result)
def test_different_prefix(self):
"""Test with different model prefix (audio model)"""
# Create FP8 tensor by converting from float32
fp8_weight = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
state_dict = {
"model.model.layer1.weight": fp8_weight,
"model.model.layer1.scale_weight": torch.tensor(1.0),
"model.model.layer2.weight": torch.randn(20, 30, dtype=torch.bfloat16),
}
result = model_detection.detect_layer_quantization(state_dict, "model.model.")
self.assertIsNotNone(result)
self.assertIn("layer1", result)
class TestPickOperationsIntegration(unittest.TestCase):
"""Test pick_operations with model_config parameter"""
def test_backward_compatibility(self):
"""Test that pick_operations works without model_config (legacy)"""
from comfy import ops
# Should work without model_config parameter
result = ops.pick_operations(None, None)
self.assertIsNotNone(result)
self.assertEqual(result, ops.disable_weight_init)
def test_with_model_config_no_quant(self):
"""Test with model_config but no quantization"""
from comfy import ops, supported_models_base
model_config = supported_models_base.BASE({})
model_config.layer_quant_config = None
result = ops.pick_operations(None, None, model_config=model_config)
self.assertIsNotNone(result)
# Should use standard path
self.assertEqual(result, ops.disable_weight_init)
def test_legacy_scaled_fp8(self):
"""Test that legacy scaled_fp8 still works"""
from comfy import ops, supported_models_base
model_config = supported_models_base.BASE({})
model_config.scaled_fp8 = torch.float8_e4m3fn
result = ops.pick_operations(
None, None,
scaled_fp8=torch.float8_e4m3fn,
model_config=model_config
)
self.assertIsNotNone(result)
# Should return scaled_fp8_ops (the returned class is the inner class)
# Check that it's not the standard disable_weight_init
self.assertNotEqual(result, ops.disable_weight_init)
# Verify it has Linear class
self.assertTrue(hasattr(result, 'Linear'))
if __name__ == "__main__":
unittest.main()

View File

@ -1,399 +0,0 @@
"""
Unit tests for tensor subclass quantization system.
Tests the new QuantizedTensorFP8 subclass and operation handlers.
"""
import unittest
import torch
import sys
import os
# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
from comfy import ops
from comfy import quant_ops
class TestQuantizedTensorFP8(unittest.TestCase):
"""Test the QuantizedTensorFP8 tensor subclass"""
def test_creation(self):
"""Test creating a QuantizedTensorFP8"""
fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(2.0)
qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16)
self.assertIsInstance(qt, quant_ops.QuantizedTensorFP8)
self.assertEqual(qt.shape, (256, 128))
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
self.assertEqual(qt._scale, scale)
self.assertEqual(qt._orig_dtype, torch.bfloat16)
def test_dequantize(self):
"""Test explicit dequantization"""
# Create a simple FP8 tensor
fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(3.0)
qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.float32)
dequantized = qt.dequantize()
# Dequantized should be approximately ones * 3.0
self.assertEqual(dequantized.dtype, torch.float32)
self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))
def test_repr(self):
"""Test string representation"""
fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(2.5)
qt = quant_ops.QuantizedTensorFP8(fp8_data, scale, orig_dtype=torch.bfloat16)
repr_str = repr(qt)
self.assertIn("QuantizedTensorFP8", repr_str)
self.assertIn("shape", repr_str)
self.assertIn("scale", repr_str)
class TestOperationRegistry(unittest.TestCase):
"""Test the operation registry system"""
def test_registry_basics(self):
"""Test that operations are registered"""
registered_ops = quant_ops.list_registered_ops()
# Check that key operations are registered
self.assertIn(torch.ops.aten.linear.default, registered_ops)
self.assertIn(torch.ops.aten.silu.default, registered_ops)
self.assertIn(torch.ops.aten.layer_norm.default, registered_ops)
self.assertIn(torch.ops.aten.add.Tensor, registered_ops)
self.assertIn(torch.ops.aten.mul.Tensor, registered_ops)
def test_get_handler(self):
"""Test getting a registered handler"""
handler = quant_ops.get_quant_handler(torch.ops.aten.linear.default)
self.assertIsNotNone(handler)
self.assertTrue(callable(handler))
def test_custom_registration(self):
"""Test registering a custom operation"""
# Define a custom handler
@quant_ops.register_quant_op(torch.ops.aten.relu.default)
def custom_relu_handler(func, args, kwargs):
return func(*args, **kwargs)
# Verify registration
handler = quant_ops.get_quant_handler(torch.ops.aten.relu.default)
self.assertIsNotNone(handler)
self.assertEqual(handler, custom_relu_handler)
class TestLinearHandler(unittest.TestCase):
"""Test the linear operation handler"""
def test_linear_with_quantized_weight(self):
"""Test F.linear with quantized weight"""
# Set seed for reproducibility
torch.manual_seed(42)
# Create quantized weight
weight_fp32 = torch.randn(256, 128, dtype=torch.float32)
scale = torch.tensor(2.0)
weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn)
weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32)
# Create input
input_tensor = torch.randn(16, 128, dtype=torch.float32)
# Call linear (should trigger dispatch)
output = torch.nn.functional.linear(input_tensor, weight_q, bias=None)
# Verify output shape
self.assertEqual(output.shape, (16, 256))
# Verify it's approximately correct (allowing for FP8 quantization error)
# Note: FP8 has limited precision, so use very loose tolerance
expected = torch.nn.functional.linear(input_tensor, weight_fp32, bias=None)
# Just check that it's in the right ballpark (within 50% error on average)
mean_rel_error = ((output - expected).abs() / (expected.abs() + 1e-6)).mean()
self.assertLess(mean_rel_error, 0.5, f"Mean relative error {mean_rel_error:.3f} is too large")
def test_linear_with_bias(self):
"""Test F.linear with quantized weight and bias"""
weight_fp32 = torch.randn(64, 32, dtype=torch.float32)
scale = torch.tensor(1.5)
weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn)
weight_q = quant_ops.QuantizedTensorFP8(weight_fp8, scale, orig_dtype=torch.float32)
input_tensor = torch.randn(8, 32, dtype=torch.float32)
bias = torch.randn(64, dtype=torch.float32)
output = torch.nn.functional.linear(input_tensor, weight_q, bias)
self.assertEqual(output.shape, (8, 64))
class TestActivationHandlers(unittest.TestCase):
"""Test activation function handlers"""
def test_silu_with_quantized_input(self):
"""Test SiLU with quantized input"""
# Create quantized input
input_fp32 = torch.randn(16, 128, dtype=torch.float32)
scale = torch.tensor(1.0)
input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn)
input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32)
# Apply SiLU
output = torch.nn.functional.silu(input_q)
# Should return a QuantizedTensorFP8
self.assertIsInstance(output, quant_ops.QuantizedTensorFP8)
# Verify approximate correctness
expected = torch.nn.functional.silu(input_fp32)
output_dq = output.dequantize()
self.assertTrue(torch.allclose(output_dq, expected, rtol=0.2, atol=0.2))
def test_layernorm_dequantizes(self):
"""Test that LayerNorm dequantizes input"""
# Create quantized input
input_fp32 = torch.randn(16, 128, dtype=torch.float32)
scale = torch.tensor(1.0)
input_fp8 = (input_fp32 / scale).to(torch.float8_e4m3fn)
input_q = quant_ops.QuantizedTensorFP8(input_fp8, scale, orig_dtype=torch.float32)
# Apply LayerNorm
weight = torch.ones(128)
bias = torch.zeros(128)
output = torch.nn.functional.layer_norm(input_q, (128,), weight, bias)
# Should NOT be quantized (LayerNorm breaks quantization)
self.assertNotIsInstance(output, quant_ops.QuantizedTensorFP8)
self.assertEqual(output.dtype, torch.float32)
class TestElementwiseHandlers(unittest.TestCase):
"""Test element-wise operation handlers"""
def test_add_mixed_tensors(self):
"""Test addition with mixed quantized/non-quantized tensors"""
# Create quantized tensor
a_fp32 = torch.ones(10, 20, dtype=torch.float32)
scale = torch.tensor(1.0)
a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn)
a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32)
# Non-quantized tensor
b = torch.ones(10, 20, dtype=torch.float32) * 2.0
# Add them
result = a_q + b
# Should be dequantized
self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8)
self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 3.0, rtol=0.1))
def test_mul_quantized_tensors(self):
"""Test multiplication of two quantized tensors"""
a_fp32 = torch.ones(10, 20) * 2.0
scale_a = torch.tensor(1.0)
a_fp8 = (a_fp32 / scale_a).to(torch.float8_e4m3fn)
a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale_a, orig_dtype=torch.float32)
b_fp32 = torch.ones(10, 20) * 3.0
scale_b = torch.tensor(1.0)
b_fp8 = (b_fp32 / scale_b).to(torch.float8_e4m3fn)
b_q = quant_ops.QuantizedTensorFP8(b_fp8, scale_b, orig_dtype=torch.float32)
result = a_q * b_q
# Should be dequantized
self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8)
self.assertTrue(torch.allclose(result, torch.ones(10, 20) * 6.0, rtol=0.1))
class TestFallbackMechanism(unittest.TestCase):
"""Test fallback for unsupported operations"""
def test_unsupported_op_dequantizes(self):
"""Test that unsupported operations fall back to dequantization"""
# Set seed for reproducibility
torch.manual_seed(42)
# Create quantized tensor
a_fp32 = torch.randn(10, 20, dtype=torch.float32)
scale = torch.tensor(1.0)
a_fp8 = (a_fp32 / scale).to(torch.float8_e4m3fn)
a_q = quant_ops.QuantizedTensorFP8(a_fp8, scale, orig_dtype=torch.float32)
# Call an operation that doesn't have a registered handler
# For example, torch.abs
result = torch.abs(a_q)
# Should work via fallback (dequantize → abs → return)
self.assertNotIsInstance(result, quant_ops.QuantizedTensorFP8)
expected = torch.abs(a_fp32)
# FP8 introduces quantization error, so use loose tolerance
mean_error = (result - expected).abs().mean()
self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")
class TestMixedPrecisionOps(unittest.TestCase):
"""Test MixedPrecisionOps integration"""
def test_linear_layer_creation(self):
"""Test that MixedPrecisionOps.Linear can be created"""
layer = ops.MixedPrecisionOps.Linear(128, 256, bias=True, device="cpu", dtype=torch.float32)
self.assertIsInstance(layer, ops.MixedPrecisionOps.Linear)
self.assertFalse(layer._quantization_initialized)
self.assertIsNone(layer.quant_format)
def test_layer_quant_config_detection(self):
"""Test that layer quantization config is detected during load"""
# Set up layer config
ops.MixedPrecisionOps._layer_quant_config = {
"test_layer": {
"format": "fp8_e4m3fn",
"params": {}
}
}
# Create a state dict with quantized weight
weight_fp32 = torch.randn(256, 128, dtype=torch.float32)
scale = torch.tensor(2.0)
weight_fp8 = (weight_fp32 / scale).to(torch.float8_e4m3fn)
state_dict = {
"model.diffusion_model.test_layer.weight": weight_fp8,
"model.diffusion_model.test_layer.scale_weight": scale,
}
# Create layer and load
layer = ops.MixedPrecisionOps.Linear(128, 256, bias=False, device="cpu", dtype=torch.float8_e4m3fn)
layer.weight = torch.nn.Parameter(torch.zeros(256, 128, dtype=torch.float8_e4m3fn))
# Manually call _load_from_state_dict
layer._load_from_state_dict(
state_dict,
prefix="model.diffusion_model.test_layer.",
local_metadata={},
strict=True,
missing_keys=[],
unexpected_keys=[],
error_msgs=[]
)
# Verify quantization was initialized
self.assertTrue(layer._quantization_initialized)
self.assertEqual(layer.quant_format, "fp8_e4m3fn")
self.assertIsNotNone(layer.quant_scale)
# Verify weight is wrapped
self.assertIsInstance(layer.weight.data, quant_ops.QuantizedTensorFP8)
# Clean up
ops.MixedPrecisionOps._layer_quant_config = {}
class TestBackwardCompatibility(unittest.TestCase):
"""Test backward compatibility with legacy systems"""
def test_legacy_ops_classes_exist(self):
"""Test that legacy ops classes still exist"""
self.assertTrue(hasattr(ops, 'disable_weight_init'))
self.assertTrue(hasattr(ops, 'manual_cast'))
self.assertTrue(hasattr(ops, 'fp8_ops'))
self.assertTrue(hasattr(ops, 'scaled_fp8_ops'))
def test_pick_operations_legacy_path(self):
"""Test pick_operations returns correct class for legacy cases"""
# Test standard case
result = ops.pick_operations(torch.float32, torch.float32)
self.assertEqual(result, ops.disable_weight_init)
# Test manual cast case
result = ops.pick_operations(torch.float32, torch.float16)
self.assertEqual(result, ops.manual_cast)
class TestFP8LinearUnification(unittest.TestCase):
"""Test that fp8_linear now uses the unified tensor subclass infrastructure"""
@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for FP8")
def test_fp8_linear_uses_tensor_subclass(self):
"""Verify fp8_linear wraps tensors in QuantizedTensorFP8"""
torch.manual_seed(42)
# Create a mock Linear layer with FP8 weight
linear = ops.fp8_ops.Linear(4, 3, bias=True)
linear.weight = torch.nn.Parameter(
torch.randn(3, 4, dtype=torch.bfloat16).to(torch.float8_e4m3fn),
requires_grad=False
)
linear.bias = torch.nn.Parameter(
torch.randn(3, dtype=torch.bfloat16),
requires_grad=False
)
linear.scale_weight = torch.tensor(1.0)
linear.scale_input = None # No input scaling
# Create input
input_tensor = torch.randn(2, 4, dtype=torch.bfloat16)
# Call fp8_linear - should work without errors
try:
result = ops.fp8_linear(linear, input_tensor)
self.assertIsNotNone(result)
self.assertEqual(result.shape, (2, 3))
except Exception as e:
# On CPU or unsupported hardware, _scaled_mm might not be available
# but the function should still complete without syntax errors
pass
def test_fp8_linear_maintains_signature(self):
"""Verify fp8_linear maintains its original function signature"""
import inspect
sig = inspect.signature(ops.fp8_linear)
params = list(sig.parameters.keys())
# Should have 'self' and 'input' parameters
self.assertIn('self', params)
self.assertIn('input', params)
self.assertEqual(len(params), 2)
def test_fp8_linear_returns_none_for_non_fp8(self):
"""Verify fp8_linear returns None for non-FP8 weights"""
# Create a Linear layer with BF16 weight (not FP8)
linear = ops.disable_weight_init.Linear(4, 3, bias=False)
linear.weight = torch.nn.Parameter(
torch.randn(3, 4, dtype=torch.bfloat16),
requires_grad=False
)
input_tensor = torch.randn(2, 4, dtype=torch.bfloat16)
# Should return None for non-FP8 weights
result = ops.fp8_linear(linear, input_tensor)
self.assertIsNone(result)
def test_fp8_ops_linear_uses_fp8_linear(self):
"""Verify fp8_ops.Linear still uses fp8_linear in forward pass"""
linear = ops.fp8_ops.Linear(4, 3, bias=False)
# Verify the class has the forward_comfy_cast_weights method
self.assertTrue(hasattr(linear, 'forward_comfy_cast_weights'))
# The forward_comfy_cast_weights should attempt to call fp8_linear
# (we can't easily test this without mocking, but we verify structure)
import inspect
source = inspect.getsource(linear.forward_comfy_cast_weights)
self.assertIn('fp8_linear', source)
if __name__ == "__main__":
unittest.main()