Mirror of https://github.com/comfyanonymous/ComfyUI.git
Fix FP8 MM

commit 19ce6b056d
parent 388294677e

comfy/ops.py  (14 changed lines)
@@ -390,19 +390,7 @@ def fp8_linear(self, input):
         # Wrap weight in QuantizedTensorFP8 - this enables unified dispatch
         quantized_weight = QuantizedTensorFP8(w, scale_weight, orig_dtype=input_dtype)
+        quantized_input = QuantizedTensorFP8.quantize(input.reshape(-1, input_shape[2]), scale_input, fp8_dtype=dtype)
-        # Handle input quantization and wrapping
-        if self.scale_input is None:
-            # Clamp input to FP8 range and quantize
-            input = torch.clamp(input, min=-448, max=448, out=input)
-            input_fp8 = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
-        else:
-            # Apply inverse scale and quantize
-            input_fp8 = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
-
-        # Wrap input in QuantizedTensorFP8
-        quantized_input = QuantizedTensorFP8(input_fp8, scale_input, orig_dtype=input_dtype)
-
         # Call F.linear - __torch_dispatch__ routes to handle_linear_fp8 in quant_ops.py!
         # This is the key unification: all FP8 computation goes through one path
         o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
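Conceptually, the removed input-handling branch and the new one-line QuantizedTensorFP8.quantize call do the same thing: divide by a per-tensor scale (1.0 when no input scale is available), clamp to the FP8 representable range, and cast. A minimal plain-PyTorch sketch of that equivalence; quantize_ref is an illustrative helper, not code from this commit:

    import torch

    def quantize_ref(x, scale, fp8_dtype=torch.float8_e4m3fn):
        # scale -> clamp to the FP8 representable range -> cast
        amax = torch.finfo(fp8_dtype).max          # 448.0 for float8_e4m3fn
        x_scaled = torch.clamp(x.float() / scale, min=-amax, max=amax)
        return x_scaled.to(fp8_dtype)

    x = torch.randn(2, 16) * 1000                  # values outside the FP8 range

    # Old path when scale_input is None: clamp the raw input to [-448, 448], then cast.
    old = torch.clamp(x, min=-448, max=448).to(torch.float8_e4m3fn)

    # Same operation expressed as quantization with scale = 1.0.
    new = quantize_ref(x, scale=torch.tensor(1.0))

    assert torch.equal(old.to(torch.float32), new.to(torch.float32))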
@@ -79,18 +79,47 @@ class QuantizedTensorFP8(torch.Tensor):
         self._scale = scale
         self._orig_dtype = orig_dtype
         # Store a reference to prevent infinite recursion in dequantize
-        self._raw_data = tensor
+        self._raw_data = tensor.contiguous()
 
     def __repr__(self):
         return (f"QuantizedTensorFP8(shape={self.shape}, "
                 f"scale={self._scale:.4f}, dtype={self._orig_dtype})")
 
+    @classmethod
+    def quantize(cls, tensor, scale, fp8_dtype=torch.float8_e4m3fn):
+        orig_dtype = tensor.dtype
+
+        if not isinstance(scale, torch.Tensor):
+            scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)
+
+        tensor_fp8 = None
+        if _CK_AVAILABLE:
+            try:
+                tensor_fp8 = ck.quantize_per_tensor_fp8(tensor, scale, fp8_dtype)
+            except Exception as e:
+                logging.debug(f"comfy_kitchen quantization failed, using PyTorch: {e}")
+
+        if tensor_fp8 is None:
+            lp_amax = torch.finfo(fp8_dtype).max
+            tensor_scaled = tensor.float() / scale
+            torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
+            tensor_fp8 = tensor_scaled.to(fp8_dtype, memory_format=torch.contiguous_format)
+
+        return cls(tensor_fp8, scale, orig_dtype=orig_dtype)
+
+    @classmethod
+    def quantize_dynamic(cls, tensor, strategy="amax", fp8_dtype=torch.float8_e4m3fn):
+        if strategy == "amax":
+            scale = torch.amax(tensor) / torch.finfo(fp8_dtype).max
+            scale = scale.to(tensor.device, dtype=torch.float32)
+        else:
+            raise ValueError(f"Unknown quantization strategy: {strategy}. "
+                             f"Supported: 'amax'")
+
+        return cls.quantize(tensor, scale, fp8_dtype=fp8_dtype)
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-        """
-        Intercept ALL torch operations.
-        Routes to registered handlers or falls back to dequantization.
-        """
         kwargs = kwargs or {}
 
         # Special case: skip dispatch for internal tensor operations
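The new quantize_dynamic "amax" strategy picks the per-tensor scale so that the largest observed value maps to the top of the FP8 range (scale = amax(tensor) / finfo(fp8_dtype).max). A self-contained round-trip check of that idea in plain PyTorch, assuming per-tensor scaling of non-negative activations; this is a sketch, not the commit's code:

    import torch

    fp8 = torch.float8_e4m3fn
    x = torch.rand(64, 64) * 300.0                     # non-negative activations

    # "amax" strategy: map the largest value onto the FP8 maximum (448 for e4m3fn).
    scale = torch.amax(x) / torch.finfo(fp8).max

    x_fp8 = torch.clamp(x / scale, min=-torch.finfo(fp8).max, max=torch.finfo(fp8).max).to(fp8)
    x_back = x_fp8.to(torch.float32) * scale           # dequantize

    rel_err = ((x - x_back).abs() / x.abs().clamp_min(1e-6)).mean()
    print(f"mean relative error: {rel_err.item():.3%}")   # roughly a few percent for e4m3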
@@ -134,16 +163,11 @@ class QuantizedTensorFP8(torch.Tensor):
             return func(*new_args, **new_kwargs)
 
     def dequantize(self) -> torch.Tensor:
-        """Explicit dequantization"""
-        # Use the raw data and convert directly
-        # Call aten ops directly to minimize dispatch interference
         plain_tensor = torch.ops.aten._to_copy.default(self._raw_data, dtype=self._orig_dtype)
-        # Multiply by scale
         return plain_tensor * self._scale
 
     def detach(self):
         """Detach returns a new QuantizedTensorFP8 (required for Parameter)"""
-        # Detach the raw data and create a new QuantizedTensorFP8
         detached_data = self._raw_data.detach()
        return QuantizedTensorFP8(detached_data, self._scale, self._orig_dtype)
 
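dequantize() calls the aten op directly (torch.ops.aten._to_copy.default) rather than Tensor.to so that the cast does not re-enter the subclass's __torch_dispatch__; on a plain tensor the two are equivalent. A small illustrative sketch in plain PyTorch, not code from this commit:

    import torch

    x_fp8 = (torch.randn(4, 4) * 10).to(torch.float8_e4m3fn)
    scale = torch.tensor(2.0)

    # Direct aten call, as dequantize() uses to minimize dispatch interference...
    via_aten = torch.ops.aten._to_copy.default(x_fp8, dtype=torch.float32) * scale

    # ...is equivalent to the usual Tensor.to on a plain (non-subclassed) tensor.
    via_to = x_fp8.to(torch.float32) * scale

    assert torch.equal(via_aten, via_to)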
@@ -165,47 +189,34 @@ def handle_linear_fp8(func, args, kwargs):
     input_tensor = args[0]
     weight = args[1]
     bias = args[2] if len(args) > 2 else None
+    out_dtype = kwargs.get("out_dtype", input_tensor._orig_dtype)
 
     # Case 1: Both input and weight are FP8
     if isinstance(input_tensor, QuantizedTensorFP8) and isinstance(weight, QuantizedTensorFP8):
-        # Use _scaled_mm for FP8×FP8 matmul
         # Get plain tensors to avoid dispatch recursion
         plain_input = input_tensor._raw_data
         plain_weight = weight._raw_data
-        weight_t = plain_weight.t().contiguous()
+        weight_t = plain_weight.t()  # Keep as column-major for cuBLASLt
 
         try:
-            if bias is not None:
             output = torch._scaled_mm(
                 plain_input,
                 weight_t,
-                out_dtype=input_tensor._orig_dtype,
                 bias=bias,
                 scale_a=input_tensor._scale,
-                scale_b=weight._scale
+                scale_b=weight._scale,
+                out_dtype=out_dtype,
             )
-            else:
-                output = torch._scaled_mm(
-                    plain_input,
-                    weight_t,
-                    out_dtype=input_tensor._orig_dtype,
-                    scale_a=input_tensor._scale,
-                    scale_b=weight._scale
-                )
-
             if isinstance(output, tuple):
                 output = output[0]
 
-            # Check if output is FP8 (some architectures support this)
             if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-                # Keep quantized!
                 output_scale = input_tensor._scale * weight._scale
-                return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype)
+                return QuantizedTensorFP8(output, output_scale, input_tensor._orig_dtype)  # TODO is this correct? Can't cuBLAS return it
             else:
                 return output
         except Exception as e:
             logging.debug(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
-            # Fall through to dequantization path
 
    # Case 2: Only weight is quantized
    if isinstance(weight, QuantizedTensorFP8):
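The layout fix here is the core of "Fix FP8 MM": torch._scaled_mm expects its second operand in column-major layout, and .t() on a row-major tensor already gives exactly that as a zero-copy stride swap, whereas the old .t().contiguous() re-materialized it row-major. A quick stride check in plain PyTorch (illustrative only, not code from this commit):

    import torch

    w = torch.randn(128, 64).to(torch.float8_e4m3fn)   # row-major (contiguous) weight

    wt_view = w.t()                 # stride swap only: a column-major view
    wt_copy = w.t().contiguous()    # old code: materializes a new row-major tensor

    print(wt_view.shape, wt_view.stride(), wt_view.is_contiguous())   # (64, 128), (1, 64), False
    print(wt_copy.shape, wt_copy.stride(), wt_copy.is_contiguous())   # (64, 128), (128, 1), True

The two _scaled_mm branches are also collapsed into one call, since bias is an optional argument and the output dtype now comes from the new out_dtype keyword instead of always being the input's original dtype.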
@@ -222,125 +233,3 @@ def handle_linear_fp8(func, args, kwargs):
     else:
         return torch.nn.functional.linear(input_tensor, weight, bias)
 
-
-@register_quant_op(torch.ops.aten.silu.default)
-def handle_silu_fp8(func, args, kwargs):
-    """
-    SiLU can be computed approximately on FP8.
-    Keeps activations quantized for next layer.
-    """
-    input_q = args[0]
-
-    if not isinstance(input_q, QuantizedTensorFP8):
-        # Not quantized, use standard path
-        return torch.nn.functional.silu(input_q)
-
-    # Compute SiLU while keeping quantized
-    # SiLU(x) = x * sigmoid(x)
-
-    # Get plain tensor to avoid dispatch recursion
-    plain_tensor = input_q._raw_data
-
-    # Upcast to FP16 for sigmoid stability
-    x_fp16 = plain_tensor.to(torch.float16)
-    sigmoid_fp16 = torch.sigmoid(x_fp16 * input_q._scale)
-    result_fp16 = x_fp16 * sigmoid_fp16
-
-    # Convert back to FP8
-    result_fp8 = result_fp16.to(plain_tensor.dtype)
-
-    # Return quantized (scale approximately preserved)
-    return QuantizedTensorFP8(result_fp8, input_q._scale, input_q._orig_dtype)
-
-
-@register_quant_op(torch.ops.aten.layer_norm.default)
-def handle_layernorm_fp8(func, args, kwargs):
-    """
-    LayerNorm requires high precision.
-    Dequantizes input and returns standard tensor.
-    """
-    input_q = args[0]
-    normalized_shape = args[1]
-    weight = args[2] if len(args) > 2 else None
-    bias = args[3] if len(args) > 3 else None
-    eps = args[4] if len(args) > 4 else 1e-5
-
-    # Dequantize if needed
-    if isinstance(input_q, QuantizedTensorFP8):
-        x = input_q.dequantize()
-    else:
-        x = input_q
-
-    # Standard LayerNorm
-    result = torch.nn.functional.layer_norm(x, normalized_shape, weight, bias, eps)
-
-    # Return dequantized (next layer will quantize if needed)
-    return result
-
-
-@register_quant_op(torch.ops.aten.group_norm.default)
-def handle_groupnorm_fp8(func, args, kwargs):
-    """
-    GroupNorm requires high precision.
-    Dequantizes input and returns standard tensor.
-    """
-    input_q = args[0]
-    num_groups = args[1]
-    weight = args[2] if len(args) > 2 else None
-    bias = args[3] if len(args) > 3 else None
-    eps = args[4] if len(args) > 4 else 1e-5
-
-    # Dequantize if needed
-    if isinstance(input_q, QuantizedTensorFP8):
-        x = input_q.dequantize()
-    else:
-        x = input_q
-
-    # Standard GroupNorm
-    result = torch.nn.functional.group_norm(x, num_groups, weight, bias, eps)
-
-    # Return dequantized
-    return result
-
-
-@register_quant_op(torch.ops.aten.add.Tensor)
-def handle_add_fp8(func, args, kwargs):
-    """
-    Handle addition with mixed quantized/non-quantized tensors.
-    """
-    a = args[0]
-    b = args[1]
-
-    # If both are quantized, dequantize both
-    if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8):
-        return a.dequantize() + b.dequantize()
-    # If only one is quantized, dequantize it
-    elif isinstance(a, QuantizedTensorFP8):
-        return a.dequantize() + b
-    elif isinstance(b, QuantizedTensorFP8):
-        return a + b.dequantize()
-    # Neither is quantized
-    else:
-        return a + b
-
-
-@register_quant_op(torch.ops.aten.mul.Tensor)
-def handle_mul_fp8(func, args, kwargs):
-    """
-    Handle multiplication with mixed quantized/non-quantized tensors.
-    """
-    a = args[0]
-    b = args[1]
-
-    # If both are quantized, dequantize both
-    if isinstance(a, QuantizedTensorFP8) and isinstance(b, QuantizedTensorFP8):
-        return a.dequantize() * b.dequantize()
-    # If only one is quantized, dequantize it
-    elif isinstance(a, QuantizedTensorFP8):
-        return a.dequantize() * b
-    elif isinstance(b, QuantizedTensorFP8):
-        return a * b.dequantize()
-    # Neither is quantized
-    else:
-        return a * b
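With the per-op silu/layer_norm/group_norm/add/mul handlers deleted, those operations presumably reach the subclass's generic __torch_dispatch__ path, which (judging from the retained "return func(*new_args, **new_kwargs)" context above) unwraps or dequantizes the quantized arguments and runs the normal high-precision kernel. A tiny sketch of that dequantize-then-compute fallback for a mixed add, in plain PyTorch and not taken from this repo:

    import torch

    def dequant(t_fp8, scale, orig_dtype=torch.float32):
        # Cast back to high precision and undo the per-tensor scale.
        return t_fp8.to(orig_dtype) * scale

    x = torch.randn(8, 16)
    scale = x.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)

    residual = torch.randn(8, 16)

    # Generic fallback for an unregistered op (e.g. aten.add.Tensor after this commit):
    # dequantize the quantized operand, then use the standard kernel.
    out = dequant(x_fp8, scale) + residual
    print(out.dtype, out.shape)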