mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-16 16:32:34 +08:00)

Fix FP8 MM

This commit is contained in:
parent 388294677e
commit 19ce6b056d
comfy/ops.py (14 changed lines)
@@ -390,19 +390,7 @@ def fp8_linear(self, input):
         # Wrap weight in QuantizedTensorFP8 - this enables unified dispatch
         quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype)
 
-        # Handle input quantization and wrapping
-        if self.scale_input is None:
-            # Clamp input to FP8 range and quantize
-            input = torch.clamp(input, min=-448, max=448, out=input)
-            input_fp8 = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
-        else:
-            # Apply inverse scale and quantize
-            input_fp8 = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
-
-        # Wrap input in QuantizedTensorFP8
-        quantized_input = QuantizedTensorFP8(input_fp8, scale_input, orig_dtype=input_dtype)
-
+        quantized_input = QuantizedTensorFP8.quantize(input.reshape(-1, input_shape[2]), scale_input, fp8_dtype=dtype)
         # Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py!
         # This is the key unification: all FP8 computation goes through one path
         o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
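After this change fp8_linear no longer hand-rolls clamping, scaling and casting for the input: QuantizedTensorFP8.quantize produces the wrapped tensor, and the plain F.linear call is picked up by __torch_dispatch__. A rough caller-side sketch of that path follows; it is not part of the commit, and the import path, CUDA device and amax-based scale choice are assumptions.

# Hypothetical usage sketch; module path, device and scales are assumptions.
import torch
from comfy.quant_ops import QuantizedTensorFP8  # assumed location of the class

x = torch.randn(4, 64, dtype=torch.bfloat16, device="cuda")
w = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")

fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
scale_x = (x.abs().amax() / fp8_max).float()
scale_w = (w.abs().amax() / fp8_max).float()

qx = QuantizedTensorFP8.quantize(x, scale_x, fp8_dtype=torch.float8_e4m3fn)
qw = QuantizedTensorFP8.quantize(w, scale_w, fp8_dtype=torch.float8_e4m3fn)

# __torch_dispatch__ intercepts the linear op and routes it to handle_linear_fp8.
y = torch.nn.functional.linear(qx, qw)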
@@ -79,18 +79,47 @@ class QuantizedTensorFP8(torch.Tensor):
         self._scale = scale
         self._orig_dtype = orig_dtype
         # Store a reference to prevent infinite recursion in dequantize
-        self._raw_data = tensor
+        self._raw_data = tensor.contiguous()
 
     def __repr__(self):
         return (f"QuantizedTensorFP8(shape={self.shape}, "
                 f"scale={self._scale:.4f}, dtype={self._orig_dtype})")
 
+    @classmethod
+    def quantize(cls, tensor, scale, fp8_dtype=torch.float8_e4m3fn):
+        orig_dtype = tensor.dtype
+
+        if not isinstance(scale, torch.Tensor):
+            scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)
+
+        tensor_fp8 = None
+        if _CK_AVAILABLE:
+            try:
+                tensor_fp8 = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype)
+            except Exception as e:
+                logging.debug(f"comfy_kitchen quantization failed, using PyTorch: {e}")
+
+        if tensor_fp8 is None:
+            lp_amax = torch.finfo(fp8_dtype).max
+            tensor_scaled = tensor.float() / scale
+            torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
+            tensor_fp8 = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)
+
+        return cls(tensor_fp8, scale, orig_dtype=orig_dtype)
+
+    @classmethod
+    def quantize_dynamic(cls, tensor, strategy="amax", fp8_dtype=torch.float8_e4m3fn):
+        if strategy == "amax":
+            scale = torch.amax(tensor) / torch.finfo(fp8_dtype).max
+            scale = scale.to(tensor.device, dtype=torch.float32)
+        else:
+            raise ValueError(f"Unknown quantization strategy: {strategy}. "
+                             f"Supported: 'amax'")
+
+        return cls.quantize(tensor, scale, fp8_dtype=fp8_dtype)
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         """
         Intercept ALL torch operations.
         Routes to registered handlers or falls back to dequantization.
         """
         kwargs = kwargs or {}
 
         # Special case: skip dispatch for internal tensor operations
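The new quantize classmethod prefers a comfy_kitchen kernel and otherwise falls back to plain PyTorch: divide by the scale, clamp to the FP8 range, cast. Below is a self-contained sketch of that fallback arithmetic and of an amax-style scale, here taken over absolute values; the helper name is made up for illustration and is not defined in the commit.

# Standalone sketch of the pure-PyTorch fallback path; fp8_round_trip is illustrative.
import torch

def fp8_round_trip(t, fp8_dtype=torch.float8_e4m3fn):
    lp_amax = torch.finfo(fp8_dtype).max              # 448.0 for e4m3fn
    scale = (t.abs().amax() / lp_amax).float()        # per-tensor scale from |amax|
    t_scaled = torch.clamp(t.float() / scale, min=-lp_amax, max=lp_amax)
    t_fp8 = t_scaled.to(fp8_dtype)                    # quantized storage
    t_back = t_fp8.to(t.dtype) * scale                # what dequantize() recovers
    return t_fp8, scale, t_back

x = torch.randn(8, 16)
x_fp8, s, x_hat = fp8_round_trip(x)
print((x - x_hat).abs().max())                        # small quantization error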
@@ -134,16 +163,11 @@ class QuantizedTensorFP8(torch.Tensor):
         return func(*new_args, **new_kwargs)
 
     def dequantize(self) -> torch.Tensor:
         """Explicit dequantization"""
         # Use the raw data and convert directly
         # Call aten ops directly to minimize dispatch interference
         plain_tensor = torch.ops.aten._to_copy.default(self._raw_data, dtype=self._orig_dtype)
         # Multiply by scale
         return plain_tensor * self._scale
 
     def detach(self):
         """Detach returns a new QuantizedTensorFP8 (required for Parameter)"""
         # Detach the raw data and create a new QuantizedTensorFP8
         detached_data = self._raw_data.detach()
         return QuantizedTensorFP8(detached_data, self._scale, self._orig_dtype)
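dequantize() converts the raw FP8 payload back to the original dtype and multiplies by the scale, while detach() must hand back another QuantizedTensorFP8 because nn.Parameter detaches whatever it wraps. A sketch of why that matters; the import path is an assumption and the behavior noted in the comments is the expected one, not verified against the commit.

# Sketch: wrapping a quantized weight in nn.Parameter goes through detach(),
# so a detach() that returned a plain tensor would silently drop the scale.
# The comfy.quant_ops import path is assumed.
import torch
from comfy.quant_ops import QuantizedTensorFP8  # assumed module path

w = torch.randn(128, 64)
scale = (w.abs().amax() / torch.finfo(torch.float8_e4m3fn).max).float()
qw = QuantizedTensorFP8.quantize(w, scale)

p = torch.nn.Parameter(qw, requires_grad=False)   # Parameter.__new__ calls detach()
print(isinstance(p, QuantizedTensorFP8))          # expected: True (subclass and scale kept)
print((p.dequantize() - w).abs().max())           # small quantization error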
@@ -165,48 +189,35 @@ def handle_linear_fp8(func, args, kwargs):
     input_tensor = args[0]
     weight = args[1]
     bias = args[2] if len(args) > 2 else None
 
+    out_dtype = kwargs.get("out_dtype", input_tensor._orig_dtype)
+
     # Case 1: Both input and weight are FP8
     if isinstance(input_tensor, QuantizedTensorFP8) and isinstance(weight, QuantizedTensorFP8):
         # Use _scaled_mm for FP8×FP8 matmul
         # Get plain tensors to avoid dispatch recursion
         plain_input = input_tensor._raw_data
         plain_weight = weight._raw_data
-        weight_t = plain_weight.t().contiguous()
+        weight_t = plain_weight.t()  # Keep as column-major for cuBLASLt
 
         try:
-            if bias is not None:
-                output = torch._scaled_mm(
-                    plain_input,
-                    weight_t,
-                    out_dtype=input_tensor._orig_dtype,
-                    bias=bias,
-                    scale_a=input_tensor._scale,
-                    scale_b=weight._scale
-                )
-            else:
-                output = torch._scaled_mm(
-                    plain_input,
-                    weight_t,
-                    out_dtype=input_tensor._orig_dtype,
-                    scale_a=input_tensor._scale,
-                    scale_b=weight._scale
-                )
-
+            output = torch._scaled_mm(
+                plain_input,
+                weight_t,
+                bias=bias,
+                scale_a=input_tensor._scale,
+                scale_b=weight._scale,
+                out_dtype=out_dtype,
+            )
             if isinstance(output, tuple):
                 output = output[0]
 
             # Check if output is FP8 (some architectures support this)
             if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
                 # Keep quantized!
                 output_scale = input_tensor._scale * weight._scale
-                return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype)
+                return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype)  # TODO is this correct? Can't cuBLAS return it
             else:
                 return output
         except Exception as e:
             logging.debug(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
             # Fall through to dequantization path
 
     # Case 2: Only weight is quantized
     if isinstance(weight, QuantizedTensorFP8):
         weight_dq = weight.dequantize()
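The dropped .contiguous() matters because torch._scaled_mm (backed by cuBLASLt) expects its second operand in column-major layout: transposing a row-major weight already yields such a view, and materializing it back to row-major is what the old code did. A quick layout check with plain tensors, illustrative only:

# Layout check: .t() on a contiguous (row-major) matrix gives a column-major view,
# while .t().contiguous() copies it back to row-major.
import torch

w = torch.randn(128, 64)            # row-major, stride (64, 1)
wt_view = w.t()                     # column-major view of shape (64, 128), stride (1, 64)
wt_copy = w.t().contiguous()        # row-major copy of shape (64, 128), stride (128, 1)

print(wt_view.stride(), wt_view.is_contiguous())   # (1, 64) False
print(wt_copy.stride(), wt_copy.is_contiguous())   # (128, 1) True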
@@ -222,125 +233,3 @@ def handle_linear_fp8(func, args, kwargs):
     else:
         return torch.nn.functional.linear(input_tensor, weight, bias)
-
-
-@register_quant_op(torch.ops.aten.silu.default)
-def handle_silu_fp8(func, args, kwargs):
-    """
-    SiLU can be computed approximately on FP8.
-    Keeps activations quantized for next layer.
-    """
-    input_q = args[0]
-
-    if not isinstance(input_q, QuantizedTensorFP8):
-        # Not quantized, use standard path
-        return torch.nn.functional.silu(input_q)
-
-    # Compute SiLU while keeping quantized
-    # SiLU(x) = x * sigmoid(x)
-
-    # Get plain tensor to avoid dispatch recursion
-    plain_tensor = input_q._raw_data
-
-    # Upcast to FP16 for sigmoid stability
-    x_fp16 = plain_tensor.to(torch.float16)
-    sigmoid_fp16 = torch.sigmoid(x_fp16 * input_q._scale)
-    result_fp16 = x_fp16 * sigmoid_fp16
-
-    # Convert back to FP8
-    result_fp8 = result_fp16.to(plain_tensor.dtype)
-
-    # Return quantized (scale approximately preserved)
-    return QuantizedTensorFP8(result_fp8, input_q._scale, input_q._orig_dtype)
-
-
-@register_quant_op(torch.ops.aten.layer_norm.default)
-def handle_layernorm_fp8(func, args, kwargs):
-    """
-    LayerNorm requires high precision.
-    Dequantizes input and returns standard tensor.
-    """
-    input_q = args[0]
-    normalized_shape = args[1]
-    weight = args[2] if len(args) > 2 else None
-    bias = args[3] if len(args) > 3 else None
-    eps = args[4] if len(args) > 4 else 1e-5
-
-    # Dequantize if needed
-    if isinstance(input_q, QuantizedTensorFP8):
-        x = input_q.dequantize()
-    else:
-        x = input_q
-
-    # Standard LayerNorm
-    result = torch.nn.functional.layer_norm(x, normalized_shape, weight, bias, eps)
-
-    # Return dequantized (next layer will quantize if needed)
-    return result
-
-
-@register_quant_op(torch.ops.aten.group_norm.default)
-def handle_groupnorm_fp8(func, args, kwargs):
-    """
-    GroupNorm requires high precision.
-    Dequantizes input and returns standard tensor.
-    """
-    input_q = args[0]
-    num_groups = args[1]
-    weight = args[2] if len(args) > 2 else None
-    bias = args[3] if len(args) > 3 else None
-    eps = args[4] if len(args) > 4 else 1e-5
-
-    # Dequantize if needed
-    if isinstance(input_q, QuantizedTensorFP8):
-        x = input_q.dequantize()
-    else:
-        x = input_q
-
-    # Standard GroupNorm
-    result = torch.nn.functional.group_norm(x, num_groups, weight, bias, eps)
-
-    # Return dequantized
-    return result
-
-
-@register_quant_op(torch.ops.aten.add.Tensor)
-def handle_add_fp8(func, args, kwargs):
-    """
-    Handle addition with mixed quantized/non-quantized tensors.
-    """
-    a = args[0]
-    b = args[1]
-
-    # If both are quantized, dequantize both
-    if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8):
-        return a.dequantize() + b.dequantize()
-    # If only one is quantized, dequantize it
-    elif isinstance(a, QuantizedTensorFP8):
-        return a.dequantize() + b
-    elif isinstance(b, QuantizedTensorFP8):
-        return a + b.dequantize()
-    # Neither is quantized
-    else:
-        return a + b
-
-
-@register_quant_op(torch.ops.aten.mul.Tensor)
-def handle_mul_fp8(func, args, kwargs):
-    """
-    Handle multiplication with mixed quantized/non-quantized tensors.
-    """
-    a = args[0]
-    b = args[1]
-
-    # If both are quantized, dequantize both
-    if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8):
-        return a.dequantize() * b.dequantize()
-    # If only one is quantized, dequantize it
-    elif isinstance(a, QuantizedTensorFP8):
-        return a.dequantize() * b
-    elif isinstance(b, QuantizedTensorFP8):
-        return a * b.dequantize()
-    # Neither is quantized
-    else:
-        return a * b
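This hunk deletes the dedicated SiLU, LayerNorm, GroupNorm, add and mul handlers. With them gone, those ops presumably hit the generic dequantize-and-run fallback in __torch_dispatch__, which does the same thing the removed handlers did in their non-fused branches. A sketch of the expected behavior; the import path is assumed and the equivalence is not verified against the commit.

# Sketch: without handle_add_fp8, adding a quantized tensor to a plain one is
# expected to route through the generic dispatch fallback, i.e. dequantize then add.
# The comfy.quant_ops import path is an assumption.
import torch
from comfy.quant_ops import QuantizedTensorFP8  # assumed module path

x = torch.randn(4, 64)
scale = (x.abs().amax() / torch.finfo(torch.float8_e4m3fn).max).float()
qx = QuantizedTensorFP8.quantize(x, scale)

y = qx + torch.ones(4, 64)                  # no dedicated aten.add.Tensor handler anymore
y_ref = qx.dequantize() + torch.ones(4, 64)
print((y - y_ref).abs().max())              # expected: 0, same dequantize-then-add math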