Fix FP8 MM

commit 19ce6b056d (parent 388294677e)
Author: lspindler
Date: 2025-10-22 11:25:39 +02:00
2 changed files with 48 additions and 171 deletions

@@ -390,19 +390,7 @@ def fp8_linear(self, input):
         # Wrap weight in QuantizedTensorFP8 - this enables unified dispatch
         quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype)
-
-        # Handle input quantization and wrapping
-        if self.scale_input is None:
-            # Clamp input to FP8 range and quantize
-            input = torch.clamp(input, min=-448, max=448, out=input)
-            input_fp8 = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
-        else:
-            # Apply inverse scale and quantize
-            input_fp8 = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
-
-        # Wrap input in QuantizedTensorFP8
-        quantized_input = QuantizedTensorFP8(input_fp8, scale_input, orig_dtype=input_dtype)
+        quantized_input = QuantizedTensorFP8.quantize(input.reshape(-1, input_shape[2]), scale_input, fp8_dtype=dtype)

         # Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py!
         # This is the key unification: all FP8 computation goes through one path
         o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
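
Note: the hunk above collapses the two input-quantization branches into one QuantizedTensorFP8.quantize call, so fp8_linear now only wraps the operands and calls a plain F.linear; the subclass dispatch then routes the op to handle_linear_fp8. A minimal sketch of that call pattern follows (the import path, shapes, and scale values are illustrative assumptions, and an FP8-capable GPU is assumed):

    import torch
    from quant_ops import QuantizedTensorFP8  # assumed import path for this branch

    x = torch.randn(2, 77, 4096, dtype=torch.bfloat16, device="cuda")
    w = (torch.randn(4096, 4096, device="cuda") * 0.02).to(torch.float8_e4m3fn)  # pre-quantized weight
    scale_weight = torch.tensor(0.02, device="cuda")
    scale_input = torch.tensor(0.01, device="cuda")

    quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=x.dtype)
    quantized_input = QuantizedTensorFP8.quantize(x.reshape(-1, x.shape[2]), scale_input,
                                                  fp8_dtype=torch.float8_e4m3fn)

    # A plain F.linear call; dispatch on the tensor subclass routes it to handle_linear_fp8.
    o = torch.nn.functional.linear(quantized_input, quantized_weight, None)
    o = o.reshape(2, 77, -1)
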

@@ -79,18 +79,47 @@ class QuantizedTensorFP8(torch.Tensor):
         self._scale = scale
         self._orig_dtype = orig_dtype
         # Store a reference to prevent infinite recursion in dequantize
-        self._raw_data = tensor
+        self._raw_data = tensor.contiguous()

     def __repr__(self):
         return (f"QuantizedTensorFP8(shape={self.shape}, "
                 f"scale={self._scale:.4f}, dtype={self._orig_dtype})")

+    @classmethod
+    def quantize(cls, tensor, scale, fp8_dtype=torch.float8_e4m3fn):
+        orig_dtype = tensor.dtype
+        if not isinstance(scale, torch.Tensor):
+            scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)
+        tensor_fp8 = None
+        if _CK_AVAILABLE:
+            try:
+                tensor_fp8 = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype)
+            except Exception as e:
+                logging.debug(f"comfy_kitchen quantization failed, using PyTorch: {e}")
+        if tensor_fp8 is None:
+            lp_amax = torch.finfo(fp8_dtype).max
+            tensor_scaled = tensor.float() / scale
+            torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
+            tensor_fp8 = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)
+        return cls(tensor_fp8, scale, orig_dtype=orig_dtype)
+
+    @classmethod
+    def quantize_dynamic(cls, tensor, strategy="amax", fp8_dtype=torch.float8_e4m3fn):
+        if strategy == "amax":
+            scale = torch.amax(tensor) / torch.finfo(fp8_dtype).max
+            scale = scale.to(tensor.device, dtype=torch.float32)
+        else:
+            raise ValueError(f"Unknown quantization strategy: {strategy}. "
+                             f"Supported: 'amax'")
+        return cls.quantize(tensor, scale, fp8_dtype=fp8_dtype)
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-        """
-        Intercept ALL torch operations.
-        Routes to registered handlers or falls back to dequantization.
-        """
         kwargs = kwargs or {}

         # Special case: skip dispatch for internal tensor operations
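
Note: the two constructors added above differ only in where the scale comes from: quantize takes a precomputed per-tensor scale (preferring the comfy_kitchen kernel when _CK_AVAILABLE), while quantize_dynamic derives the scale from the tensor with the "amax" strategy. A small usage sketch, with the import path assumed:

    import torch
    from quant_ops import QuantizedTensorFP8  # assumed import path

    t = torch.randn(1024, 1024, dtype=torch.bfloat16)

    q_static = QuantizedTensorFP8.quantize(t, scale=0.05)    # caller-supplied per-tensor scale
    q_dynamic = QuantizedTensorFP8.quantize_dynamic(t)       # scale = amax(t) / finfo(e4m3fn).max

    print(q_dynamic._raw_data.dtype)   # torch.float8_e4m3fn payload
    print(q_dynamic._scale.dtype)      # float32 scale tensor
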
@@ -134,16 +163,11 @@ class QuantizedTensorFP8(torch.Tensor):
         return func(*new_args, **new_kwargs)

     def dequantize(self) -> torch.Tensor:
-        """Explicit dequantization"""
-        # Use the raw data and convert directly
-        # Call aten ops directly to minimize dispatch interference
         plain_tensor = torch.ops.aten._to_copy.default(self._raw_data, dtype=self._orig_dtype)
-        # Multiply by scale
         return plain_tensor * self._scale

     def detach(self):
         """Detach returns a new QuantizedTensorFP8 (required for Parameter)"""
-        # Detach the raw data and create a new QuantizedTensorFP8
         detached_data = self._raw_data.detach()
         return QuantizedTensorFP8(detached_data, self._scale, self._orig_dtype)
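
Note: dequantize() remains the inverse of the constructors above: the FP8 payload is copied back to the original dtype and multiplied by the scale, so a quantize/dequantize round trip only loses FP8 rounding precision. A quick sketch, with the import path assumed:

    import torch
    from quant_ops import QuantizedTensorFP8  # assumed import path

    x = torch.randn(64, 64)
    q = QuantizedTensorFP8.quantize_dynamic(x)
    x_hat = q.dequantize()                 # _to_copy(raw_data, orig_dtype) * scale
    max_err = (x - x_hat).abs().max()      # bounded by e4m3 rounding at this scale
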
@@ -165,47 +189,34 @@ def handle_linear_fp8(func, args, kwargs):
     input_tensor = args[0]
     weight = args[1]
     bias = args[2] if len(args) > 2 else None
+    out_dtype = kwargs.get("out_dtype", input_tensor._orig_dtype)

     # Case 1: Both input and weight are FP8
     if isinstance(input_tensor, QuantizedTensorFP8) and isinstance(weight, QuantizedTensorFP8):
-        # Use _scaled_mm for FP8×FP8 matmul
         # Get plain tensors to avoid dispatch recursion
         plain_input = input_tensor._raw_data
        plain_weight = weight._raw_data
-        weight_t = plain_weight.t().contiguous()
+        weight_t = plain_weight.t()  # Keep as column-major for cuBLASLt

         try:
-            if bias is not None:
-                output = torch._scaled_mm(
-                    plain_input,
-                    weight_t,
-                    out_dtype=input_tensor._orig_dtype,
-                    bias=bias,
-                    scale_a=input_tensor._scale,
-                    scale_b=weight._scale
-                )
-            else:
-                output = torch._scaled_mm(
-                    plain_input,
-                    weight_t,
-                    out_dtype=input_tensor._orig_dtype,
-                    scale_a=input_tensor._scale,
-                    scale_b=weight._scale
-                )
+            output = torch._scaled_mm(
+                plain_input,
+                weight_t,
+                bias=bias,
+                scale_a=input_tensor._scale,
+                scale_b=weight._scale,
+                out_dtype=out_dtype,
+            )

             if isinstance(output, tuple):
                 output = output[0]

-            # Check if output is FP8 (some architectures support this)
             if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-                # Keep quantized!
                 output_scale = input_tensor._scale * weight._scale
-                return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype)
+                return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype)  # TODO is this correct? Can't cuBLAS return it
             else:
                 return output
         except Exception as e:
             logging.debug(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
-            # Fall through to dequantization path

     # Case 2: Only weight is quantized
     if isinstance(weight, QuantizedTensorFP8):
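
Note: the hunk above drops the .contiguous() on the transposed weight because torch._scaled_mm (backed by cuBLASLt) wants the first operand row-major and the second operand column-major; a transposed view of a row-major nn.Linear weight already has exactly that layout. A sketch of the call, assuming a GPU where torch._scaled_mm is supported and with illustrative shapes (multiples of 16) and unit scales:

    import torch

    a = torch.randn(16, 64, device="cuda").to(torch.float8_e4m3fn)   # activations, row-major (M, K)
    w = torch.randn(32, 64, device="cuda").to(torch.float8_e4m3fn)   # nn.Linear-style weight, row-major (N, K)
    scale_a = torch.tensor(1.0, device="cuda")
    scale_b = torch.tensor(1.0, device="cuda")

    # w.t() is a (K, N) view with column-major strides, which is the layout expected
    # for the second operand; calling .contiguous() here would make it row-major
    # again and the kernel would typically reject that layout.
    out = torch._scaled_mm(a, w.t(), scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
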
@@ -222,125 +233,3 @@ def handle_linear_fp8(func, args, kwargs):
     else:
         return torch.nn.functional.linear(input_tensor, weight, bias)
-
-
-@register_quant_op(torch.ops.aten.silu.default)
-def handle_silu_fp8(func, args, kwargs):
-    """
-    SiLU can be computed approximately on FP8.
-    Keeps activations quantized for next layer.
-    """
-    input_q = args[0]
-
-    if not isinstance(input_q, QuantizedTensorFP8):
-        # Not quantized, use standard path
-        return torch.nn.functional.silu(input_q)
-
-    # Compute SiLU while keeping quantized
-    # SiLU(x) = x * sigmoid(x)
-    # Get plain tensor to avoid dispatch recursion
-    plain_tensor = input_q._raw_data
-
-    # Upcast to FP16 for sigmoid stability
-    x_fp16 = plain_tensor.to(torch.float16)
-    sigmoid_fp16 = torch.sigmoid(x_fp16 * input_q._scale)
-    result_fp16 = x_fp16 * sigmoid_fp16
-
-    # Convert back to FP8
-    result_fp8 = result_fp16.to(plain_tensor.dtype)
-
-    # Return quantized (scale approximately preserved)
-    return QuantizedTensorFP8(result_fp8, input_q._scale, input_q._orig_dtype)
-
-
-@register_quant_op(torch.ops.aten.layer_norm.default)
-def handle_layernorm_fp8(func, args, kwargs):
-    """
-    LayerNorm requires high precision.
-    Dequantizes input and returns standard tensor.
-    """
-    input_q = args[0]
-    normalized_shape = args[1]
-    weight = args[2] if len(args) > 2 else None
-    bias = args[3] if len(args) > 3 else None
-    eps = args[4] if len(args) > 4 else 1e-5
-
-    # Dequantize if needed
-    if isinstance(input_q, QuantizedTensorFP8):
-        x = input_q.dequantize()
-    else:
-        x = input_q
-
-    # Standard LayerNorm
-    result = torch.nn.functional.layer_norm(x, normalized_shape, weight, bias, eps)
-
-    # Return dequantized (next layer will quantize if needed)
-    return result
-
-
-@register_quant_op(torch.ops.aten.group_norm.default)
-def handle_groupnorm_fp8(func, args, kwargs):
-    """
-    GroupNorm requires high precision.
-    Dequantizes input and returns standard tensor.
-    """
-    input_q = args[0]
-    num_groups = args[1]
-    weight = args[2] if len(args) > 2 else None
-    bias = args[3] if len(args) > 3 else None
-    eps = args[4] if len(args) > 4 else 1e-5
-
-    # Dequantize if needed
-    if isinstance(input_q, QuantizedTensorFP8):
-        x = input_q.dequantize()
-    else:
-        x = input_q
-
-    # Standard GroupNorm
-    result = torch.nn.functional.group_norm(x, num_groups, weight, bias, eps)
-
-    # Return dequantized
-    return result
-
-
-@register_quant_op(torch.ops.aten.add.Tensor)
-def handle_add_fp8(func, args, kwargs):
-    """
-    Handle addition with mixed quantized/non-quantized tensors.
-    """
-    a = args[0]
-    b = args[1]
-
-    # If both are quantized, dequantize both
-    if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8):
-        return a.dequantize() + b.dequantize()
-    # If only one is quantized, dequantize it
-    elif isinstance(a, QuantizedTensorFP8):
-        return a.dequantize() + b
-    elif isinstance(b, QuantizedTensorFP8):
-        return a + b.dequantize()
-    # Neither is quantized
-    else:
-        return a + b
-
-
-@register_quant_op(torch.ops.aten.mul.Tensor)
-def handle_mul_fp8(func, args, kwargs):
-    """
-    Handle multiplication with mixed quantized/non-quantized tensors.
-    """
-    a = args[0]
-    b = args[1]
-
-    # If both are quantized, dequantize both
-    if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8):
-        return a.dequantize() * b.dequantize()
-    # If only one is quantized, dequantize it
-    elif isinstance(a, QuantizedTensorFP8):
-        return a.dequantize() * b
-    elif isinstance(b, QuantizedTensorFP8):
-        return a * b.dequantize()
-    # Neither is quantized
-    else:
-        return a * b
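
Note: with the SiLU/LayerNorm/GroupNorm/add/mul handlers deleted above, those ops no longer have dedicated FP8 paths; they now reach the generic __torch_dispatch__ fallback, which per the class code dequantizes QuantizedTensorFP8 arguments and runs the op in the original dtype. A sketch of the resulting behavior, with the import path assumed and the exact fallback semantics depending on dispatch code not shown in this diff:

    import torch
    from quant_ops import QuantizedTensorFP8  # assumed import path

    x = torch.randn(4, 4)
    q = QuantizedTensorFP8.quantize_dynamic(x)

    # No registered handlers for these ops any more, so the generic fallback
    # dequantizes q and the results come back as plain tensors.
    y = q + torch.ones(4, 4)
    z = torch.nn.functional.silu(q)
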