Fix quantization fallback

doctorpangloss 2025-12-12 13:07:15 -08:00
parent 5a1dda8bd0
commit 1fd1c2c7fc


@@ -505,7 +505,7 @@ def fp8_linear(func, args, kwargs):
             return output
         except Exception as e:
-            raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
+            logger.warning(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
 
     # Case 2: DQ Fallback
     if isinstance(weight, QuantizedTensor):
@@ -542,7 +542,10 @@ def fp8_addmm(func, args, kwargs):
     bias = args[0]
     if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
-        return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
+        try:
+            return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
+        except Exception as e:
+            logger.warning(f"FP8 addmm failed, falling back to dequantization: {e}")
 
     a = list(args)
     if isinstance(args[0], QuantizedTensor):
@@ -560,7 +563,10 @@ def fp8_mm(func, args, kwargs):
     weight = args[1]
     if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
-        return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
+        try:
+            return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
+        except Exception as e:
+            logger.warning(f"FP8 mm failed, falling back to dequantization: {e}")
 
     a = list(args)
     if isinstance(args[0], QuantizedTensor):
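
For context, here is a minimal, self-contained sketch of the fallback pattern these hunks apply: the FP8 fast path is attempted first, and any failure is logged with logger.warning so the op falls through to a dequantize-and-matmul path instead of raising. The QuantizedTensor, fast_fp8_mm, and fp8_mm names below are illustrative stand-ins under that assumption, not the repository's actual classes or helpers.

# Minimal sketch of the try-fast-path / dequantize-fallback pattern.
# All names here are hypothetical stand-ins, not the repository's code.
import logging

import torch

logger = logging.getLogger(__name__)


class QuantizedTensor:
    """Toy stand-in: a low-precision payload plus a per-tensor scale."""

    def __init__(self, data: torch.Tensor, scale: float):
        self.data = data
        self.scale = scale

    def dequantize(self) -> torch.Tensor:
        # Recover an ordinary float tensor for the slow path.
        return self.data.float() * self.scale


def fast_fp8_mm(a: QuantizedTensor, b: QuantizedTensor) -> torch.Tensor:
    # Stand-in for a fused low-precision kernel; such kernels can raise
    # on unsupported shapes, dtypes, or devices.
    raise RuntimeError("fast path unavailable on this device")


def fp8_mm(a: QuantizedTensor, b: QuantizedTensor) -> torch.Tensor:
    # Case 1: try the fused low-precision kernel first.
    try:
        return fast_fp8_mm(a, b)
    except Exception as e:
        # The behavior this commit introduces: warn and fall through
        # instead of re-raising, so the op still produces a result.
        logger.warning(f"FP8 mm failed, falling back to dequantization: {e}")

    # Case 2: DQ fallback -- dequantize both operands and use plain matmul.
    return torch.mm(a.dequantize(), b.dequantize())


if __name__ == "__main__":
    x = QuantizedTensor(torch.randint(-8, 8, (4, 8), dtype=torch.int8), scale=0.1)
    w = QuantizedTensor(torch.randint(-8, 8, (8, 3), dtype=torch.int8), scale=0.05)
    print(fp8_mm(x, w).shape)  # torch.Size([4, 3])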