enable text encoders to load Fp8 mixed precision
parent b95c05d095
commit 277035bb2f
comfy/sd.py (16 changed lines)
@@ -23,6 +23,7 @@ import comfy.pixel_space_convert
 import yaml
 import math
 import os
+import json

 import comfy.utils

@@ -917,7 +918,20 @@ class CLIPType(Enum):
 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = []
     for p in ckpt_paths:
-        clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
+        clip_data.append(comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True))
+        if type(clip_data[-1]) == tuple:
+            model, metadata = clip_data.pop()
+            if metadata is not None and "_quantization_metadata" in metadata:
+                try:
+                    quant_metadata = metadata.pop("_quantization_metadata")
+                    quant_metadata = json.loads(quant_metadata)
+                    if "layers" in quant_metadata:
+                        layer_quant_config = quant_metadata["layers"]
+                        model_options["layer_quant_config"] = layer_quant_config
+                        logging.info(f"Detected quantized text encoder: {len(layer_quant_config)} layers with quantization")
+                except Exception as e:
+                    logging.warning(f"Failed to parse quantization metadata: {e}")
+            clip_data.append(model)
     return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)

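The branch above keys off a "_quantization_metadata" entry in the checkpoint's safetensors metadata, whose JSON payload carries a "layers" mapping. Below is a minimal sketch of producing a checkpoint this code path would detect; the per-layer schema (the "format" field) and the "scale_weight" tensor name are illustrative assumptions, since the diff only shows that "layers" is read and stashed as layer_quant_config.

import json
import torch
from safetensors.torch import save_file

state_dict = {
    # hypothetical FP8 weight plus a dequantization scale for one layer
    "text_model.encoder.layers.0.mlp.fc1.weight":
        torch.zeros(128, 64, dtype=torch.float8_e4m3fn),
    "text_model.encoder.layers.0.mlp.fc1.scale_weight":
        torch.ones(1, dtype=torch.float32),
}

quant_metadata = {
    "layers": {  # keys mirror state-dict layer names (assumed schema)
        "text_model.encoder.layers.0.mlp.fc1": {"format": "float8_e4m3fn"},
    }
}

# safetensors metadata values must be strings, hence the json.dumps
save_file(state_dict, "te_fp8.safetensors",
          metadata={"_quantization_metadata": json.dumps(quant_metadata)})

A file written this way comes back from load_torch_file(..., return_metadata=True) as a (state_dict, metadata) tuple, which is exactly the shape the new type(clip_data[-1]) == tuple check handles.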
comfy/sd1_clip.py

@@ -109,13 +109,22 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):

         operations = model_options.get("custom_operations", None)
         scaled_fp8 = None
+        layer_quant_config = model_options.get("layer_quant_config", None)

         if operations is None:
-            scaled_fp8 = model_options.get("scaled_fp8", None)
-            if scaled_fp8 is not None:
-                operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
+            # Use MixedPrecisionOps if layer_quant_config is present (for FP8 text encoders)
+            if layer_quant_config is not None:
+                operations = comfy.ops.MixedPrecisionOps
+                comfy.ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
+                comfy.ops.MixedPrecisionOps._compute_dtype = dtype
+                logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
             else:
-                operations = comfy.ops.manual_cast
+                # Fallback to scaled_fp8_ops for backward compatibility
+                scaled_fp8 = model_options.get("scaled_fp8", None)
+                if scaled_fp8 is not None:
+                    operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
+                else:
+                    operations = comfy.ops.manual_cast

         self.operations = operations
         self.transformer = model_class(config, dtype, device, self.operations)
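Note that the diff configures MixedPrecisionOps through class attributes rather than per instance, so every layer built by the ops factory sees the same table; that is why load_clip only needs to put layer_quant_config into model_options once. A rough sketch of the per-layer dispatch idea those attributes drive, assuming a lookup by layer name (this is not ComfyUI's actual MixedPrecisionOps, whose implementation in comfy.ops is not shown in this diff):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MixedPrecisionLinear(nn.Linear):
    # class-level state, mirroring how the diff assigns
    # _layer_quant_config and _compute_dtype on MixedPrecisionOps
    _layer_quant_config = {}          # layer name -> quant info
    _compute_dtype = torch.float16    # dtype used for the matmul
    layer_name = ""                   # assumed to be set during model construction

    def forward(self, x):
        if self.layer_name in self._layer_quant_config:
            # quantized layer: upcast the stored FP8 weight to the compute dtype
            w = self.weight.to(self._compute_dtype)
        else:
            # unquantized layer: cast on the fly, as the manual_cast fallback does
            w = self.weight.to(x.dtype)
        b = self.bias.to(w.dtype) if self.bias is not None else None
        return F.linear(x.to(w.dtype), w, b)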