From 7ea731ea98445f12c07807102a1f2a4350952786 Mon Sep 17 00:00:00 2001
From: lspindler
Date: Wed, 22 Oct 2025 11:25:39 +0200
Subject: [PATCH] Fix FP8 MM

---
 comfy/ops.py       |  14 +---
 comfy/quant_ops.py | 205 +++++++++++----------------------------------
 2 files changed, 48 insertions(+), 171 deletions(-)

diff --git a/comfy/ops.py b/comfy/ops.py
index 2e6782dbd..060b35137 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -370,19 +370,7 @@ def fp8_linear(self, input):
         # Wrap weight in QuantizedTensorFP8 - this enables unified dispatch
         quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype)
-
-        # Handle input quantization and wrapping
-        if self.scale_input is None:
-            # Clamp input to FP8 range and quantize
-            input = torch.clamp(input, min=-448, max=448, out=input)
-            input_fp8 = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
-        else:
-            # Apply inverse scale and quantize
-            input_fp8 = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
-
-        # Wrap input in QuantizedTensorFP8
-        quantized_input = QuantizedTensorFP8(input_fp8, scale_input, orig_dtype=input_dtype)
-
+        quantized_input = QuantizedTensorFP8.quantize(input.reshape(-1, input_shape[2]), scale_input, fp8_dtype=dtype)
         # Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py!
         # This is the key unification: all FP8 computation goes through one path
         o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)

diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index 681eb9134..8e3bacbaf 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -79,18 +79,47 @@ class QuantizedTensorFP8(torch.Tensor):
         self._scale = scale
         self._orig_dtype = orig_dtype
         # Store a reference to prevent infinite recursion in dequantize
-        self._raw_data = tensor
+        self._raw_data = tensor.contiguous()

     def __repr__(self):
         return (f"QuantizedTensorFP8(shape={self.shape}, "
                 f"scale={self._scale:.4f}, dtype={self._orig_dtype})")

+    @classmethod
+    def quantize(cls, tensor, scale, fp8_dtype=torch.float8_e4m3fn):
+        orig_dtype = tensor.dtype
+
+        if not isinstance(scale, torch.Tensor):
+            scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)
+
+        tensor_fp8 = None
+        if _CK_AVAILABLE:
+            try:
+                tensor_fp8 = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype)
+            except Exception as e:
+                logging.debug(f"comfy_kitchen quantization failed, using PyTorch: {e}")
+
+        if tensor_fp8 is None:
+            lp_amax = torch.finfo(fp8_dtype).max
+            tensor_scaled = tensor.float() / scale
+            torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
+            tensor_fp8 = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)
+
+        return cls(tensor_fp8, scale, orig_dtype=orig_dtype)
+
+    @classmethod
+    def quantize_dynamic(cls, tensor, strategy="amax", fp8_dtype=torch.float8_e4m3fn):
+        if strategy == "amax":
+            scale = torch.amax(tensor) / torch.finfo(fp8_dtype).max
+            scale = scale.to(tensor.device, dtype=torch.float32)
+        else:
+            raise ValueError(f"Unknown quantization strategy: {strategy}. "
+                             f"Supported: 'amax'")
+
+        return cls.quantize(tensor, scale, fp8_dtype=fp8_dtype)
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-        """
-        Intercept ALL torch operations.
-        Routes to registered handlers or falls back to dequantization.
-        """
         kwargs = kwargs or {}

         # Special case: skip dispatch for internal tensor operations
@@ -134,16 +163,11 @@ class QuantizedTensorFP8(torch.Tensor):
         return func(*new_args, **new_kwargs)

     def dequantize(self) -> torch.Tensor:
-        """Explicit dequantization"""
-        # Use the raw data and convert directly
-        # Call aten ops directly to minimize dispatch interference
         plain_tensor = torch.ops.aten._to_copy.default(self._raw_data, dtype=self._orig_dtype)
-        # Multiply by scale
         return plain_tensor * self._scale

     def detach(self):
         """Detach returns a new QuantizedTensorFP8 (required for Parameter)"""
-        # Detach the raw data and create a new QuantizedTensorFP8
         detached_data = self._raw_data.detach()
         return QuantizedTensorFP8(detached_data, self._scale, self._orig_dtype)

@@ -165,48 +189,35 @@ def handle_linear_fp8(func, args, kwargs):
     input_tensor = args[0]
     weight = args[1]
     bias = args[2] if len(args) > 2 else None
-
+    out_dtype = kwargs.get("out_dtype", getattr(input_tensor, "_orig_dtype", input_tensor.dtype))
+
     # Case 1: Both input and weight are FP8
     if isinstance(input_tensor, QuantizedTensorFP8) and isinstance(weight, QuantizedTensorFP8):
-        # Use _scaled_mm for FP8×FP8 matmul
         # Get plain tensors to avoid dispatch recursion
         plain_input = input_tensor._raw_data
         plain_weight = weight._raw_data
-        weight_t = plain_weight.t().contiguous()
+        weight_t = plain_weight.t()  # Keep as column-major for cuBLASLt

         try:
-            if bias is not None:
-                output = torch._scaled_mm(
-                    plain_input,
-                    weight_t,
-                    out_dtype=input_tensor._orig_dtype,
-                    bias=bias,
-                    scale_a=input_tensor._scale,
-                    scale_b=weight._scale
-                )
-            else:
-                output = torch._scaled_mm(
-                    plain_input,
-                    weight_t,
-                    out_dtype=input_tensor._orig_dtype,
-                    scale_a=input_tensor._scale,
-                    scale_b=weight._scale
-                )
-
+            output = torch._scaled_mm(
+                plain_input,
+                weight_t,
+                bias=bias,
+                scale_a=input_tensor._scale,
+                scale_b=weight._scale,
+                out_dtype=out_dtype,
+            )
             if isinstance(output, tuple):
                 output = output[0]
-            # Check if output is FP8 (some architectures support this)
             if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-                # Keep quantized!
                 output_scale = input_tensor._scale * weight._scale
-                return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype)
+                return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype)  # TODO is this correct? Can't cuBLAS return it
             else:
                 return output

         except Exception as e:
             logging.debug(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
-            # Fall through to dequantization path
-
+
     # Case 2: Only weight is quantized
     if isinstance(weight, QuantizedTensorFP8):
         weight_dq = weight.dequantize()
@@ -222,125 +233,3 @@ def handle_linear_fp8(func, args, kwargs):

     else:
         return torch.nn.functional.linear(input_tensor, weight, bias)
-
-@register_quant_op(torch.ops.aten.silu.default)
-def handle_silu_fp8(func, args, kwargs):
-    """
-    SiLU can be computed approximately on FP8.
-    Keeps activations quantized for next layer.
- """ - input_q = args[0] - - if not isinstance(input_q, QuantizedTensorFP8): - # Not quantized, use standard path - return torch.nn.functional.silu(input_q) - - # Compute SiLU while keeping quantized - # SiLU(x) = x * sigmoid(x) - - # Get plain tensor to avoid dispatch recursion - plain_tensor = input_q._raw_data - - # Upcast to FP16 for sigmoid stability - x_fp16 = plain_tensor.to(torch.float16) - sigmoid_fp16 = torch.sigmoid(x_fp16 * input_q._scale) - result_fp16 = x_fp16 * sigmoid_fp16 - - # Convert back to FP8 - result_fp8 = result_fp16.to(plain_tensor.dtype) - - # Return quantized (scale approximately preserved) - return QuantizedTensorFP8(result_fp8, input_q._scale, input_q._orig_dtype) - - -@register_quant_op(torch.ops.aten.layer_norm.default) -def handle_layernorm_fp8(func, args, kwargs): - """ - LayerNorm requires high precision. - Dequantizes input and returns standard tensor. - """ - input_q = args[0] - normalized_shape = args[1] - weight = args[2] if len(args) > 2 else None - bias = args[3] if len(args) > 3 else None - eps = args[4] if len(args) > 4 else 1e-5 - - # Dequantize if needed - if isinstance(input_q, QuantizedTensorFP8): - x = input_q.dequantize() - else: - x = input_q - - # Standard LayerNorm - result = torch.nn.functional.layer_norm(x, normalized_shape, weight, bias, eps) - - # Return dequantized (next layer will quantize if needed) - return result - - -@register_quant_op(torch.ops.aten.group_norm.default) -def handle_groupnorm_fp8(func, args, kwargs): - """ - GroupNorm requires high precision. - Dequantizes input and returns standard tensor. - """ - input_q = args[0] - num_groups = args[1] - weight = args[2] if len(args) > 2 else None - bias = args[3] if len(args) > 3 else None - eps = args[4] if len(args) > 4 else 1e-5 - - # Dequantize if needed - if isinstance(input_q, QuantizedTensorFP8): - x = input_q.dequantize() - else: - x = input_q - - # Standard GroupNorm - result = torch.nn.functional.group_norm(x, num_groups, weight, bias, eps) - - # Return dequantized - return result - - -@register_quant_op(torch.ops.aten.add.Tensor) -def handle_add_fp8(func, args, kwargs): - """ - Handle addition with mixed quantized/non-quantized tensors. - """ - a = args[0] - b = args[1] - - # If both are quantized, dequantize both - if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8): - return a.dequantize() + b.dequantize() - # If only one is quantized, dequantize it - elif isinstance(a, QuantizedTensorFP8): - return a.dequantize() + b - elif isinstance(b, QuantizedTensorFP8): - return a + b.dequantize() - # Neither is quantized - else: - return a + b - - -@register_quant_op(torch.ops.aten.mul.Tensor) -def handle_mul_fp8(func, args, kwargs): - """ - Handle multiplication with mixed quantized/non-quantized tensors. - """ - a = args[0] - b = args[1] - - # If both are quantized, dequantize both - if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8): - return a.dequantize() * b.dequantize() - # If only one is quantized, dequantize it - elif isinstance(a, QuantizedTensorFP8): - return a.dequantize() * b - elif isinstance(b, QuantizedTensorFP8): - return a * b.dequantize() - # Neither is quantized - else: - return a * b -
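
Usage note (illustrative sketch, not part of the patch): how the new QuantizedTensorFP8.quantize() entry point and the unified F.linear dispatch path are meant to be exercised. It assumes ComfyUI's comfy package is importable and a CUDA GPU with FP8 support; on other hardware handle_linear_fp8 logs a debug message and takes the dequantized fallback path. Shapes, scales, and dtypes below are made up for the example.

    import torch
    from comfy.quant_ops import QuantizedTensorFP8

    # Activation and weight in the model's compute dtype; dims are kept multiples
    # of 16 so torch._scaled_mm accepts the FP8 matmul.
    x = torch.randn(16, 64, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda")

    # Per-tensor scales; in fp8_linear() these come from scale_input / scale_weight.
    scale_x = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_w = torch.tensor(1.0, device="cuda", dtype=torch.float32)

    qx = QuantizedTensorFP8.quantize(x, scale_x, fp8_dtype=torch.float8_e4m3fn)
    qw = QuantizedTensorFP8.quantize(w, scale_w, fp8_dtype=torch.float8_e4m3fn)

    # Alternatively, derive the scale from the tensor itself (per-tensor amax):
    qx_dyn = QuantizedTensorFP8.quantize_dynamic(x)

    # __torch_dispatch__ intercepts the linear op and routes it to handle_linear_fp8,
    # which runs torch._scaled_mm on the raw FP8 data with the stored scales.
    y = torch.nn.functional.linear(qx, qw)
    print(y.shape, y.dtype)  # torch.Size([16, 32]) torch.bfloat16 (the original dtype)

This mirrors what fp8_linear() in comfy/ops.py now does for activations, so the sketch and the runtime path share the same quantize -> dispatch -> _scaled_mm sequence.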