diff --git a/comfy/model_base.py b/comfy/model_base.py index 7b4651f8e..f850cc402 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -326,7 +326,7 @@ class BaseModel(torch.nn.Module): if self.model_config.scaled_fp8 is not None: unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8) - + # Save mixed precision metadata if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config: metadata = { diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 335ccbd17..c4fc27742 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -16,7 +16,7 @@ def detect_layer_quantization(metadata): logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})") return quant_metadata["layers"] else: - raise ValueError(f"Invalid quantization metadata format") + raise ValueError("Invalid quantization metadata format") return None diff --git a/comfy/ops.py b/comfy/ops.py index 8af1e949d..e2d76d7a9 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -325,7 +325,7 @@ class manual_cast(disable_weight_init): def fp8_linear(self, input): """ - Legacy FP8 linear function for backward compatibility. + Legacy FP8 linear function for backward compatibility. Uses QuantizedTensor subclass for dispatch. """ dtype = self.weight.dtype @@ -339,7 +339,7 @@ def fp8_linear(self, input): input_shape = input.shape input_dtype = input.dtype - + if len(input.shape) == 3: w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) @@ -354,14 +354,14 @@ def fp8_linear(self, input): scale_input = torch.ones((), device=input.device, dtype=torch.float32) else: scale_input = scale_input.to(input.device) - + # Wrap weight in QuantizedTensor - this enables unified dispatch # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, fp8_dtype=dtype) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) - + if tensor_2d: return o.reshape(input_shape[0], -1) return o.reshape((-1, input_shape[1], self.weight.shape[0])) @@ -503,8 +503,8 @@ class MixedPrecisionOps(disable_weight_init): def reset_parameters(self): return None - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): device = self.factory_kwargs["device"] @@ -520,10 +520,10 @@ class MixedPrecisionOps(disable_weight_init): quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) if quant_format is None: raise ValueError(f"Unknown quantization format for layer {layer_name}") - + mixin = QUANT_FORMAT_MIXINS[quant_format] self.layout_type = mixin["layout_type"] - + layout_params = { 'scale': state_dict.pop(f"{prefix}weight_scale", None), 'orig_dtype': MixedPrecisionOps._compute_dtype @@ -558,7 +558,7 @@ class MixedPrecisionOps(disable_weight_init): not isinstance(input, QuantizedTensor)): input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype) return self._forward(input, self.weight, self.bias) - + def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: @@ -566,7 +566,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_ MixedPrecisionOps._compute_dtype = compute_dtype logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") return MixedPrecisionOps - + fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 96d2fa03f..aa1a231bd 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -31,7 +31,7 @@ def register_generic_util(torch_op): Decorator to register a generic utility that works for all layouts. Args: torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default) - + Example: @register_generic_util(torch.ops.aten.detach.default) def generic_detach(func, args, kwargs): @@ -78,10 +78,10 @@ def _copy_layout_params(params): class QuantizedLayout: """ Base class for quantization layouts. - + A layout encapsulates the format-specific logic for quantization/dequantization and provides a uniform interface for extracting raw tensors needed for computation. - + New quantization formats should subclass this and implement the required methods. """ @classmethod @@ -90,8 +90,8 @@ class QuantizedLayout: @staticmethod def dequantize(qdata, **layout_params) -> torch.Tensor: - raise NotImplementedError(f"TensorLayout must implement dequantize()") - + raise NotImplementedError("TensorLayout must implement dequantize()") + @classmethod def get_plain_tensors(cls, qtensor) -> torch.Tensor: raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()") @@ -100,45 +100,45 @@ class QuantizedLayout: class QuantizedTensor(torch.Tensor): """ Universal quantized tensor that works with any layout. - + This tensor subclass uses a pluggable layout system to support multiple quantization formats (FP8, INT4, INT8, etc.) without code duplication. - + The layout_type determines format-specific behavior, while common operations (detach, clone, to) are handled generically. - + Attributes: _qdata: The quantized tensor data _layout_type: Layout class (e.g., TensorCoreFP8Layout) _layout_params: Dict with layout-specific params (scale, zero_point, etc.) """ - + @staticmethod def __new__(cls, qdata, layout_type, layout_params): """ Create a quantized tensor. - + Args: qdata: The quantized data tensor layout_type: Layout class (subclass of QuantizedLayout) layout_params: Dict with layout-specific parameters """ return torch.Tensor._make_subclass(cls, qdata, require_grad=False) - + def __init__(self, qdata, layout_type, layout_params): self._qdata = qdata.contiguous() self._layout_type = layout_type self._layout_params = layout_params - + def __repr__(self): layout_name = self._layout_type.__name__ param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" - + @property def layout_type(self): return self._layout_type - + def __tensor_flatten__(self): """ Tensor flattening protocol for proper device movement. @@ -147,7 +147,7 @@ class QuantizedTensor(torch.Tensor): ctx = { "layout_type": self._layout_type, } - + tensor_params = {} non_tensor_params = {} for k, v in self._layout_params.items(): @@ -155,17 +155,17 @@ class QuantizedTensor(torch.Tensor): tensor_params[k] = v else: non_tensor_params[k] = v - + ctx["tensor_param_keys"] = list(tensor_params.keys()) ctx["non_tensor_params"] = non_tensor_params - + for k, v in tensor_params.items(): attr_name = f"_layout_param_{k}" object.__setattr__(self, attr_name, v) inner_tensors.append(attr_name) - + return inner_tensors, ctx - + @staticmethod def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): """ @@ -174,41 +174,41 @@ class QuantizedTensor(torch.Tensor): """ layout_type = ctx["layout_type"] layout_params = dict(ctx["non_tensor_params"]) - + for key in ctx["tensor_param_keys"]: attr_name = f"_layout_param_{key}" layout_params[key] = inner_tensors[attr_name] - + return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params) - + @classmethod def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs) return cls(qdata, layout_type, layout_params) - + def dequantize(self) -> torch.Tensor: return self._layout_type.dequantize(self._qdata, **self._layout_params) - + @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): kwargs = kwargs or {} - + # Step 1: Check generic utilities first (detach, clone, to, etc.) if func in _GENERIC_UTILS: return _GENERIC_UTILS[func](func, args, kwargs) - + # Step 2: Check layout-specific handlers (linear, matmul, etc.) layout_type = _get_layout_from_args(args) if layout_type and func in _LAYOUT_REGISTRY: handler = _LAYOUT_REGISTRY[func].get(layout_type) if handler: return handler(func, args, kwargs) - + # Step 3: Fallback to dequantization if isinstance(args[0] if args else None, QuantizedTensor): logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") return cls._dequant_and_fallback(func, args, kwargs) - + @classmethod def _dequant_and_fallback(cls, func, args, kwargs): def dequant_arg(arg): @@ -217,7 +217,7 @@ class QuantizedTensor(torch.Tensor): elif isinstance(arg, (list, tuple)): return type(arg)(dequant_arg(a) for a in arg) return arg - + new_args = dequant_arg(args) new_kwargs = dequant_arg(kwargs) return func(*new_args, **new_kwargs) @@ -239,13 +239,13 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout= f"QuantizedTensor: dtype conversion requested to {target_dtype}, " f"but not supported for quantized tensors. Ignoring dtype." ) - + if target_layout is not None and target_layout != torch.strided: logging.warning( f"QuantizedTensor: layout change requested to {target_layout}, " f"but not supported. Ignoring layout." ) - + # Handle device transfer current_device = qt._qdata.device if target_device is not None: @@ -254,7 +254,7 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout= target_device = torch.device(target_device) if isinstance(current_device, str): current_device = torch.device(current_device) - + if target_device != current_device: logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") new_q_data = qt._qdata.to(device=target_device) @@ -262,7 +262,7 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout= new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") return new_qt - + logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original") return qt @@ -318,7 +318,7 @@ def generic_to_dtype_layout(func, args, kwargs): def generic_copy_(func, args, kwargs): qt_dest = args[0] src = args[1] - + if isinstance(qt_dest, QuantizedTensor): if isinstance(src, QuantizedTensor): # Copy from another quantized tensor @@ -383,15 +383,15 @@ def fp8_linear(func, args, kwargs): input_tensor = args[0] weight = args[1] bias = args[2] if len(args) > 2 else None - + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) - + out_dtype = kwargs.get("out_dtype") if out_dtype is None: out_dtype = input_tensor._layout_params['orig_dtype'] - + weight_t = plain_weight.t() tensor_2d = False @@ -424,7 +424,7 @@ def fp8_linear(func, args, kwargs): return QuantizedTensor(output, TensorCoreFP8Layout, output_params) else: return output - + except Exception as e: raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index e34552760..1102f9bd4 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -16,7 +16,7 @@ class SimpleModel(torch.nn.Module): self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16) self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16) self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16) - + def forward(self, x): x = self.layer1(x) x = torch.nn.functional.relu(x) @@ -32,10 +32,10 @@ class TestMixedPrecisionOps(unittest.TestCase): """Test that model with no quantization works normally""" # Configure no quantization ops.MixedPrecisionOps._layer_quant_config = {} - + # Create model model = SimpleModel(operations=ops.MixedPrecisionOps) - + # Initialize weights manually model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16)) model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16)) @@ -43,19 +43,19 @@ class TestMixedPrecisionOps(unittest.TestCase): model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16)) model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16)) model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16)) - + # Initialize weight_function and bias_function for layer in [model.layer1, model.layer2, model.layer3]: layer.weight_function = [] layer.bias_function = [] - + # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) output = model(input_tensor) - + self.assertEqual(output.shape, (5, 40)) self.assertEqual(output.dtype, torch.bfloat16) - + def test_mixed_precision_load(self): """Test loading a mixed precision model from state dict""" # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard @@ -70,52 +70,52 @@ class TestMixedPrecisionOps(unittest.TestCase): } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config - + # Create state dict with mixed precision fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn) - + state_dict = { # Layer 1: FP8 E4M3FN "layer1.weight": fp8_weight1, "layer1.bias": torch.randn(20, dtype=torch.bfloat16), "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), - + # Layer 2: Standard BF16 "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), "layer2.bias": torch.randn(30, dtype=torch.bfloat16), - + # Layer 3: FP8 E4M3FN "layer3.weight": fp8_weight3, "layer3.bias": torch.randn(40, dtype=torch.bfloat16), "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32), } - + # Create model and load state dict (strict=False because custom loading pops keys) model = SimpleModel(operations=ops.MixedPrecisionOps) model.load_state_dict(state_dict, strict=False) - + # Verify weights are wrapped in QuantizedTensor self.assertIsInstance(model.layer1.weight, QuantizedTensor) self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout) - + # Layer 2 should NOT be quantized self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) - + # Layer 3 should be quantized self.assertIsInstance(model.layer3.weight, QuantizedTensor) self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout) - + # Verify scales were loaded self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5) - + # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) output = model(input_tensor) - + self.assertEqual(output.shape, (5, 40)) - + def test_state_dict_quantized_preserved(self): """Test that quantized weights are preserved in state_dict()""" # Configure mixed precision @@ -126,7 +126,7 @@ class TestMixedPrecisionOps(unittest.TestCase): } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config - + # Create and load model fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) state_dict1 = { @@ -138,22 +138,22 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - + model = SimpleModel(operations=ops.MixedPrecisionOps) model.load_state_dict(state_dict1, strict=False) - + # Save state dict state_dict2 = model.state_dict() - + # Verify layer1.weight is a QuantizedTensor with scale preserved self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout) - + # Verify non-quantized layers are standard tensors self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor) - + def test_weight_function_compatibility(self): """Test that weight_function (LoRA) works with quantized layers""" # Configure FP8 quantization @@ -164,7 +164,7 @@ class TestMixedPrecisionOps(unittest.TestCase): } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config - + # Create and load model fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) state_dict = { @@ -176,24 +176,24 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - + model = SimpleModel(operations=ops.MixedPrecisionOps) model.load_state_dict(state_dict, strict=False) - + # Add a weight function (simulating LoRA) # This should trigger dequantization during forward pass def apply_lora(weight): lora_delta = torch.randn_like(weight) * 0.01 return weight + lora_delta - + model.layer1.weight_function.append(apply_lora) - + # Forward pass should work with LoRA (triggers weight_function path) input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) output = model(input_tensor) - + self.assertEqual(output.shape, (5, 40)) - + def test_error_handling_unknown_format(self): """Test that unknown formats raise error""" # Configure with unknown format @@ -204,7 +204,7 @@ class TestMixedPrecisionOps(unittest.TestCase): } } ops.MixedPrecisionOps._layer_quant_config = layer_quant_config - + # Create state dict state_dict = { "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16), @@ -214,7 +214,7 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - + # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS model = SimpleModel(operations=ops.MixedPrecisionOps) with self.assertRaises(KeyError): diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py index 263581417..26e91a7ee 100644 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -11,51 +11,51 @@ from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout class TestQuantizedTensor(unittest.TestCase): """Test the QuantizedTensor subclass with FP8 layout""" - + def test_creation(self): """Test creating a QuantizedTensor with TensorCoreFP8Layout""" fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(2.0) layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} - + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) - + self.assertIsInstance(qt, QuantizedTensor) self.assertEqual(qt.shape, (256, 128)) self.assertEqual(qt.dtype, torch.float8_e4m3fn) self.assertEqual(qt._layout_params['scale'], scale) self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) self.assertEqual(qt._layout_type, TensorCoreFP8Layout) - + def test_dequantize(self): """Test explicit dequantization""" fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(3.0) layout_params = {'scale': scale, 'orig_dtype': torch.float32} - + qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) dequantized = qt.dequantize() - + self.assertEqual(dequantized.dtype, torch.float32) self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) - + def test_from_float(self): """Test creating QuantizedTensor from float tensor""" float_tensor = torch.randn(64, 32, dtype=torch.float32) scale = torch.tensor(1.5) - + qt = QuantizedTensor.from_float( - float_tensor, - TensorCoreFP8Layout, + float_tensor, + TensorCoreFP8Layout, scale=scale, fp8_dtype=torch.float8_e4m3fn ) - + self.assertIsInstance(qt, QuantizedTensor) self.assertEqual(qt.dtype, torch.float8_e4m3fn) self.assertEqual(qt.shape, (64, 32)) - + # Verify dequantization gives approximately original values dequantized = qt.dequantize() mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean() @@ -64,48 +64,48 @@ class TestQuantizedTensor(unittest.TestCase): class TestGenericUtilities(unittest.TestCase): """Test generic utility operations""" - + def test_detach(self): """Test detach operation on quantized tensor""" fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) - + # Detach should return a new QuantizedTensor qt_detached = qt.detach() - + self.assertIsInstance(qt_detached, QuantizedTensor) self.assertEqual(qt_detached.shape, qt.shape) self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout) - + def test_clone(self): """Test clone operation on quantized tensor""" fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) - + # Clone should return a new QuantizedTensor qt_cloned = qt.clone() - + self.assertIsInstance(qt_cloned, QuantizedTensor) self.assertEqual(qt_cloned.shape, qt.shape) self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout) - + # Verify it's a deep copy self.assertIsNot(qt_cloned._qdata, qt._qdata) - + def test_to_device(self): """Test device transfer""" fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) - + # Moving to same device should work (CPU to CPU) qt_cpu = qt.to('cpu') - + self.assertIsInstance(qt_cpu, QuantizedTensor) self.assertEqual(qt_cpu.device.type, 'cpu') self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') @@ -113,64 +113,63 @@ class TestGenericUtilities(unittest.TestCase): class TestTensorCoreFP8Layout(unittest.TestCase): """Test the TensorCoreFP8Layout implementation""" - + def test_quantize(self): """Test quantization method""" float_tensor = torch.randn(32, 64, dtype=torch.float32) scale = torch.tensor(1.5) - + qdata, layout_params = TensorCoreFP8Layout.quantize( float_tensor, scale=scale, fp8_dtype=torch.float8_e4m3fn ) - + self.assertEqual(qdata.dtype, torch.float8_e4m3fn) self.assertEqual(qdata.shape, float_tensor.shape) self.assertIn('scale', layout_params) self.assertIn('orig_dtype', layout_params) self.assertEqual(layout_params['orig_dtype'], torch.float32) - + def test_dequantize(self): """Test dequantization method""" float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0 scale = torch.tensor(1.0) - + qdata, layout_params = TensorCoreFP8Layout.quantize( float_tensor, scale=scale, fp8_dtype=torch.float8_e4m3fn ) - + dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) - + # Should approximately match original self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) class TestFallbackMechanism(unittest.TestCase): """Test fallback for unsupported operations""" - + def test_unsupported_op_dequantizes(self): """Test that unsupported operations fall back to dequantization""" # Set seed for reproducibility torch.manual_seed(42) - + # Create quantized tensor a_fp32 = torch.randn(10, 20, dtype=torch.float32) scale = torch.tensor(1.0) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} a_q = QuantizedTensor.from_float( a_fp32, TensorCoreFP8Layout, scale=scale, fp8_dtype=torch.float8_e4m3fn ) - + # Call an operation that doesn't have a registered handler # For example, torch.abs result = torch.abs(a_q) - + # Should work via fallback (dequantize → abs → return) self.assertNotIsInstance(result, QuantizedTensor) expected = torch.abs(a_fp32)