diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index ec6ac2013..50f988bfd 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -505,7 +505,7 @@ def fp8_linear(func, args, kwargs):
             return output
 
         except Exception as e:
-            raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
+            logger.warning(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
 
     # Case 2: DQ Fallback
     if isinstance(weight, QuantizedTensor):
@@ -542,7 +542,10 @@ def fp8_addmm(func, args, kwargs):
     bias = args[0]
 
     if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
-        return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
+        try:
+            return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
+        except Exception as e:
+            logger.warning(f"FP8 addmm failed, falling back to dequantization: {e}")
 
     a = list(args)
     if isinstance(args[0], QuantizedTensor):
@@ -560,7 +563,10 @@ def fp8_mm(func, args, kwargs):
     weight = args[1]
 
     if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
-        return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
+        try:
+            return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
+        except Exception as e:
+            logger.warning(f"FP8 mm failed, falling back to dequantization: {e}")
 
     a = list(args)
     if isinstance(args[0], QuantizedTensor):
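
All three hunks apply the same pattern: a failure in the fast FP8 path is logged as a warning and execution falls through to the dequantization branch below, rather than raising and aborting the forward pass. A minimal sketch of that control flow, with hypothetical helpers `fast_fp8_mm` and `dequantize` standing in for the real kernels in `quant_ops.py`:

```python
import logging

logger = logging.getLogger(__name__)

def mm_with_fallback(a_q, b_q, fast_fp8_mm, dequantize):
    """Try the fast FP8 kernel; on failure, warn and fall back to a dequantized matmul."""
    try:
        return fast_fp8_mm(a_q, b_q)
    except Exception as e:
        # As in the patch: demote the hard error to a warning so the slower
        # dequantized path below still produces a result.
        logger.warning(f"FP8 mm failed, falling back to dequantization: {e}")
    return dequantize(a_q) @ dequantize(b_q)
```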