From c58c13b2bad6df0de93cc0cf107e96522a3cb5b3 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 31 Oct 2025 21:25:17 -0700 Subject: [PATCH] Fix torch compile regression on fp8 ops. (#10580) --- comfy/ops.py | 24 +++++------------ comfy/quant_ops.py | 27 +++++++++++++++---- .../comfy_quant/test_mixed_precision.py | 8 +++--- tests-unit/comfy_quant/test_quant_registry.py | 20 +++++++------- 4 files changed, 43 insertions(+), 36 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index 18f6b804b..279f6b1a7 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -401,15 +401,9 @@ def fp8_linear(self, input): if dtype not in [torch.float8_e4m3fn]: return None - tensor_2d = False - if len(input.shape) == 2: - tensor_2d = True - input = input.unsqueeze(1) - - input_shape = input.shape input_dtype = input.dtype - if len(input.shape) == 3: + if input.ndim == 3 or input.ndim == 2: w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) scale_weight = self.scale_weight @@ -422,24 +416,20 @@ def fp8_linear(self, input): if scale_input is None: scale_input = torch.ones((), device=input.device, dtype=torch.float32) input = torch.clamp(input, min=-448, max=448, out=input) - input = input.reshape(-1, input_shape[2]).to(dtype).contiguous() layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} - quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight) + quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) else: scale_input = scale_input.to(input.device) - quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype) + quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype) # Wrap weight in QuantizedTensor - this enables unified dispatch # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} - quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) + quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) uncast_bias_weight(self, w, bias, offload_stream) - - if tensor_2d: - return o.reshape(input_shape[0], -1) - return o.reshape((-1, input_shape[1], self.weight.shape[0])) + return o return None @@ -540,12 +530,12 @@ if CUBLAS_IS_AVAILABLE: # ============================================================================== # Mixed Precision Operations # ============================================================================== -from .quant_ops import QuantizedTensor, TensorCoreFP8Layout +from .quant_ops import QuantizedTensor QUANT_FORMAT_MIXINS = { "float8_e4m3fn": { "dtype": torch.float8_e4m3fn, - "layout_type": TensorCoreFP8Layout, + "layout_type": "TensorCoreFP8Layout", "parameters": { "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index c822fe53c..873f173ed 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -123,7 +123,7 @@ class QuantizedTensor(torch.Tensor): layout_type: Layout class (subclass of QuantizedLayout) layout_params: Dict with layout-specific parameters """ - return torch.Tensor._make_subclass(cls, qdata, require_grad=False) + return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False) def __init__(self, qdata, layout_type, layout_params): self._qdata = qdata.contiguous() @@ -183,11 +183,11 @@ class QuantizedTensor(torch.Tensor): @classmethod def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': - qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs) + qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs) return cls(qdata, layout_type, layout_params) def dequantize(self) -> torch.Tensor: - return self._layout_type.dequantize(self._qdata, **self._layout_params) + return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params) @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): @@ -379,7 +379,12 @@ class TensorCoreFP8Layout(QuantizedLayout): return qtensor._qdata, qtensor._layout_params['scale'] -@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) +LAYOUTS = { + "TensorCoreFP8Layout": TensorCoreFP8Layout, +} + + +@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout") def fp8_linear(func, args, kwargs): input_tensor = args[0] weight = args[1] @@ -422,7 +427,7 @@ def fp8_linear(func, args, kwargs): 'scale': output_scale, 'orig_dtype': input_tensor._layout_params['orig_dtype'] } - return QuantizedTensor(output, TensorCoreFP8Layout, output_params) + return QuantizedTensor(output, "TensorCoreFP8Layout", output_params) else: return output @@ -436,3 +441,15 @@ def fp8_linear(func, args, kwargs): input_tensor = input_tensor.dequantize() return torch.nn.functional.linear(input_tensor, weight, bias) + + +@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout") +@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout") +def fp8_func(func, args, kwargs): + input_tensor = args[0] + if isinstance(input_tensor, QuantizedTensor): + plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + ar = list(args) + ar[0] = plain_input + return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params) + return func(*args, **kwargs) diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 267bc177b..f8d1fd04e 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -14,7 +14,7 @@ if not has_gpu(): args.cpu = True from comfy import ops -from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout +from comfy.quant_ops import QuantizedTensor class SimpleModel(torch.nn.Module): @@ -104,14 +104,14 @@ class TestMixedPrecisionOps(unittest.TestCase): # Verify weights are wrapped in QuantizedTensor self.assertIsInstance(model.layer1.weight, QuantizedTensor) - self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout) + self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout") # Layer 2 should NOT be quantized self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) # Layer 3 should be quantized self.assertIsInstance(model.layer3.weight, QuantizedTensor) - self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout) + self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout") # Verify scales were loaded self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) @@ -155,7 +155,7 @@ class TestMixedPrecisionOps(unittest.TestCase): # Verify layer1.weight is a QuantizedTensor with scale preserved self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) - self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout) + self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout") # Verify non-quantized layers are standard tensors self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py index 477811029..9cb54ede8 100644 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -25,14 +25,14 @@ class TestQuantizedTensor(unittest.TestCase): scale = torch.tensor(2.0) layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} - qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) self.assertIsInstance(qt, QuantizedTensor) self.assertEqual(qt.shape, (256, 128)) self.assertEqual(qt.dtype, torch.float8_e4m3fn) self.assertEqual(qt._layout_params['scale'], scale) self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) - self.assertEqual(qt._layout_type, TensorCoreFP8Layout) + self.assertEqual(qt._layout_type, "TensorCoreFP8Layout") def test_dequantize(self): """Test explicit dequantization""" @@ -41,7 +41,7 @@ class TestQuantizedTensor(unittest.TestCase): scale = torch.tensor(3.0) layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) dequantized = qt.dequantize() self.assertEqual(dequantized.dtype, torch.float32) @@ -54,7 +54,7 @@ class TestQuantizedTensor(unittest.TestCase): qt = QuantizedTensor.from_float( float_tensor, - TensorCoreFP8Layout, + "TensorCoreFP8Layout", scale=scale, dtype=torch.float8_e4m3fn ) @@ -77,28 +77,28 @@ class TestGenericUtilities(unittest.TestCase): fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) # Detach should return a new QuantizedTensor qt_detached = qt.detach() self.assertIsInstance(qt_detached, QuantizedTensor) self.assertEqual(qt_detached.shape, qt.shape) - self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout) + self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout") def test_clone(self): """Test clone operation on quantized tensor""" fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) # Clone should return a new QuantizedTensor qt_cloned = qt.clone() self.assertIsInstance(qt_cloned, QuantizedTensor) self.assertEqual(qt_cloned.shape, qt.shape) - self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout) + self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout") # Verify it's a deep copy self.assertIsNot(qt_cloned._qdata, qt._qdata) @@ -109,7 +109,7 @@ class TestGenericUtilities(unittest.TestCase): fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) scale = torch.tensor(1.5) layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) # Moving to same device should work (CPU to CPU) qt_cpu = qt.to('cpu') @@ -169,7 +169,7 @@ class TestFallbackMechanism(unittest.TestCase): scale = torch.tensor(1.0) a_q = QuantizedTensor.from_float( a_fp32, - TensorCoreFP8Layout, + "TensorCoreFP8Layout", scale=scale, dtype=torch.float8_e4m3fn )