enable test encoders to load Fp8 mixed precision

This commit is contained in:
Sandro Cavallari 2025-11-25 01:23:57 +01:00
parent b95c05d095
commit 277035bb2f
2 changed files with 28 additions and 5 deletions

View File

@ -23,6 +23,7 @@ import comfy.pixel_space_convert
import yaml import yaml
import math import math
import os import os
import json
import comfy.utils import comfy.utils
@ -917,7 +918,20 @@ class CLIPType(Enum):
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = [] clip_data = []
for p in ckpt_paths: for p in ckpt_paths:
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True)) clip_data.append(comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True))
if type(clip_data[-1]) == tuple:
model, metadata = clip_data.pop()
if metadata is not None and "_quantization_metadata" in metadata:
try:
quant_metadata = metadata.pop("_quantization_metadata")
quant_metadata = json.loads(quant_metadata)
if "layers" in quant_metadata:
layer_quant_config = quant_metadata["layers"]
model_options["layer_quant_config"] = layer_quant_config
logging.info(f"Detected quantized text encoder: {len(layer_quant_config)} layers with quantization")
except Exception as e:
logging.warning(f"Failed to parse quantization metadata: {e}")
clip_data.append(model)
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)

View File

@ -109,13 +109,22 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
operations = model_options.get("custom_operations", None) operations = model_options.get("custom_operations", None)
scaled_fp8 = None scaled_fp8 = None
layer_quant_config = model_options.get("layer_quant_config", None)
if operations is None: if operations is None:
scaled_fp8 = model_options.get("scaled_fp8", None) # Use MixedPrecisionOps if layer_quant_config is present (for FP8 text encoders)
if scaled_fp8 is not None: if layer_quant_config is not None:
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) operations = comfy.ops.MixedPrecisionOps
comfy.ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
comfy.ops.MixedPrecisionOps._compute_dtype = dtype
logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
else: else:
operations = comfy.ops.manual_cast # Fallback to scaled_fp8_ops for backward compatibility
scaled_fp8 = model_options.get("scaled_fp8", None)
if scaled_fp8 is not None:
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
else:
operations = comfy.ops.manual_cast
self.operations = operations self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations) self.transformer = model_class(config, dtype, device, self.operations)