From 277035bb2f2b9b10d50f9564d1729dd6d783a629 Mon Sep 17 00:00:00 2001 From: Sandro Cavallari Date: Tue, 25 Nov 2025 01:23:57 +0100 Subject: [PATCH] enable test encoders to load Fp8 mixed precision --- comfy/sd.py | 16 +++++++++++++++- comfy/sd1_clip.py | 17 +++++++++++++---- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index dc0905ada..d7a8e07c7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -23,6 +23,7 @@ import comfy.pixel_space_convert import yaml import math import os +import json import comfy.utils @@ -917,7 +918,20 @@ class CLIPType(Enum): def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): clip_data = [] for p in ckpt_paths: - clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) + clip_data.append(comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)) + if type(clip_data[-1]) == tuple: + model, metadata = clip_data.pop() + if metadata is not None and "_quantization_metadata" in metadata: + try: + quant_metadata = metadata.pop("_quantization_metadata") + quant_metadata = json.loads(quant_metadata) + if "layers" in quant_metadata: + layer_quant_config = quant_metadata["layers"] + model_options["layer_quant_config"] = layer_quant_config + logging.info(f"Detected quantized text encoder: {len(layer_quant_config)} layers with quantization") + except Exception as e: + logging.warning(f"Failed to parse quantization metadata: {e}") + clip_data.append(model) return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 3066de2d7..a04ba74b5 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -109,13 +109,22 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): operations = model_options.get("custom_operations", None) scaled_fp8 = None + layer_quant_config = model_options.get("layer_quant_config", None) if operations is None: - scaled_fp8 = model_options.get("scaled_fp8", None) - if scaled_fp8 is not None: - operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) + # Use MixedPrecisionOps if layer_quant_config is present (for FP8 text encoders) + if layer_quant_config is not None: + operations = comfy.ops.MixedPrecisionOps + comfy.ops.MixedPrecisionOps._layer_quant_config = layer_quant_config + comfy.ops.MixedPrecisionOps._compute_dtype = dtype + logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers") else: - operations = comfy.ops.manual_cast + # Fallback to scaled_fp8_ops for backward compatibility + scaled_fp8 = model_options.get("scaled_fp8", None) + if scaled_fp8 is not None: + operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) + else: + operations = comfy.ops.manual_cast self.operations = operations self.transformer = model_class(config, dtype, device, self.operations)