diff --git a/comfy/ops.py b/comfy/ops.py index af185ec24..785aa1c9f 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -646,11 +646,12 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful return MixedPrecisionOps def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): + fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular + if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") - return mixed_precision_ops(model_config.layer_quant_config, compute_dtype) + return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute) - fp8_compute = comfy.model_management.supports_fp8_compute(load_device) if scaled_fp8 is not None: return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)