From e1e50e017cf950bc0368dc3ae8e2ed9023901370 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 30 Nov 2025 00:33:01 -0500 Subject: [PATCH] Make old scaled fp8 format use the new mixed quant ops system. Since model saving is not implemented for the mixed quant system this breaks model saving for every scaled fp8 model which needs to be fixed before this gets merged. --- comfy/model_base.py | 28 ++++---- comfy/model_detection.py | 17 +---- comfy/model_patcher.py | 20 +----- comfy/ops.py | 97 ++++++---------------------- comfy/quant_ops.py | 17 +++-- comfy/sd.py | 94 ++++++++++++++++++++++++--- comfy/sd1_clip.py | 27 +++----- comfy/supported_models_base.py | 1 - comfy/text_encoders/cosmos.py | 12 ++-- comfy/text_encoders/flux.py | 14 ++-- comfy/text_encoders/genmo.py | 6 +- comfy/text_encoders/hidream.py | 10 +-- comfy/text_encoders/hunyuan_image.py | 12 ++-- comfy/text_encoders/hunyuan_video.py | 21 +++--- comfy/text_encoders/lumina2.py | 6 +- comfy/text_encoders/omnigen2.py | 6 +- comfy/text_encoders/pixart_t5.py | 6 +- comfy/text_encoders/qwen_image.py | 6 +- comfy/text_encoders/sd3_clip.py | 18 +++--- comfy/text_encoders/wan.py | 6 +- comfy/text_encoders/z_image.py | 5 +- 21 files changed, 204 insertions(+), 225 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 9b76c285e..9fef5d4ce 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -57,6 +57,7 @@ from . import utils import comfy.latent_formats import comfy.model_sampling import math +import json from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher @@ -134,7 +135,7 @@ class BaseModel(torch.nn.Module): if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: fp8 = model_config.optimizations.get("fp8", False) - operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config) + operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config) else: operations = model_config.custom_operations self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) @@ -319,7 +320,7 @@ class BaseModel(torch.nn.Module): def process_latent_out(self, latent): return self.latent_format.process_out(latent) - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None, metadata=None): extra_sds = [] if clip_state_dict is not None: extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict)) @@ -330,16 +331,16 @@ class BaseModel(torch.nn.Module): unet_state_dict = self.diffusion_model.state_dict() - if self.model_config.scaled_fp8 is not None: - unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8) - # Save mixed precision metadata - if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config: - metadata = { - "format_version": "1.0", - "layers": self.model_config.layer_quant_config - } - unet_state_dict["_quantization_metadata"] = metadata + if metadata is not None: + if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config: + meta = { + "format_version": "1.0", + "layers": self.model_config.layer_quant_config + } + + meta_out = {"_quantization_metadata": json.dumps(meta)} + meta_out = self.model_config.process_unet_state_dict_for_saving(meta_out) unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) @@ -349,7 +350,10 @@ class BaseModel(torch.nn.Module): for sd in extra_sds: unet_state_dict.update(sd) - return unet_state_dict + if metadata is not None: + return unet_state_dict, metadata + else: + return unet_state_dict def set_inpaint(self): self.concat_keys = ("mask", "masked_image") diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7afe4a798..dfcb6621c 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -7,8 +7,8 @@ import logging import torch -def detect_layer_quantization(metadata): - quant_key = "_quantization_metadata" +def detect_layer_quantization(metadata, key_prefix=""): + quant_key = "{}_quantization_metadata".format(key_prefix) if metadata is not None and quant_key in metadata: quant_metadata = metadata.pop(quant_key) quant_metadata = json.loads(quant_metadata) @@ -759,19 +759,8 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal if model_config is None and use_base_if_no_match: model_config = comfy.supported_models_base.BASE(unet_config) - scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix) - if scaled_fp8_key in state_dict: - scaled_fp8_weight = state_dict.pop(scaled_fp8_key) - model_config.scaled_fp8 = scaled_fp8_weight.dtype - if model_config.scaled_fp8 == torch.float32: - model_config.scaled_fp8 = torch.float8_e4m3fn - if scaled_fp8_weight.nelement() == 2: - model_config.optimizations["fp8"] = False - else: - model_config.optimizations["fp8"] = True - # Detect per-layer quantization (mixed precision) - layer_quant_config = detect_layer_quantization(metadata) + layer_quant_config = detect_layer_quantization(metadata, unet_key_prefix) if layer_quant_config: model_config.layer_quant_config = layer_quant_config logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized") diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3eac77275..d2762d0d0 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -126,27 +126,11 @@ class LowVramPatch: def __init__(self, key, patches, convert_func=None, set_func=None): self.key = key self.patches = patches - self.convert_func = convert_func + self.convert_func = convert_func # TODO: remove self.set_func = set_func def __call__(self, weight): - intermediate_dtype = weight.dtype - if self.convert_func is not None: - weight = self.convert_func(weight, inplace=False) - - if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops - intermediate_dtype = torch.float32 - out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype) - if self.set_func is None: - return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key)) - else: - return self.set_func(out, seed=string_to_seed(self.key), return_weight=True) - - out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype) - if self.set_func is not None: - return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype) - else: - return out + return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) #The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3 LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3 diff --git a/comfy/ops.py b/comfy/ops.py index 61a2f0754..b34c80a7f 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -417,22 +417,12 @@ def fp8_linear(self, input): if input.ndim == 3 or input.ndim == 2: w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) + scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - scale_weight = self.scale_weight - scale_input = self.scale_input - if scale_weight is None: - scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - else: - scale_weight = scale_weight.to(input.device) - - if scale_input is None: - scale_input = torch.ones((), device=input.device, dtype=torch.float32) - input = torch.clamp(input, min=-448, max=448, out=input) - layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} - quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) - else: - scale_input = scale_input.to(input.device) - quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype) + scale_input = torch.ones((), device=input.device, dtype=torch.float32) + input = torch.clamp(input, min=-448, max=448, out=input) + layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} + quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) # Wrap weight in QuantizedTensor - this enables unified dispatch # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! @@ -466,59 +456,6 @@ class fp8_ops(manual_cast): uncast_bias_weight(self, weight, bias, offload_stream) return x -def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None): - logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input)) - class scaled_fp8_op(manual_cast): - class Linear(manual_cast.Linear): - def __init__(self, *args, **kwargs): - if override_dtype is not None: - kwargs['dtype'] = override_dtype - super().__init__(*args, **kwargs) - - def reset_parameters(self): - if not hasattr(self, 'scale_weight'): - self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False) - - if not scale_input: - self.scale_input = None - - if not hasattr(self, 'scale_input'): - self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False) - return None - - def forward_comfy_cast_weights(self, input): - if fp8_matrix_mult: - out = fp8_linear(self, input) - if out is not None: - return out - - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) - - if weight.numel() < input.numel(): #TODO: optimize - x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias) - else: - x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias) - uncast_bias_weight(self, weight, bias, offload_stream) - return x - - def convert_weight(self, weight, inplace=False, **kwargs): - if inplace: - weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype) - return weight - else: - return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32) - - def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): - weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed) - if return_weight: - return weight - if inplace_update: - self.weight.data.copy_(weight) - else: - self.weight = torch.nn.Parameter(weight, requires_grad=False) - - return scaled_fp8_op - CUBLAS_IS_AVAILABLE = False try: from cublas_ops import CublasLinear @@ -593,7 +530,11 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful if layer_name not in MixedPrecisionOps._layer_quant_config: self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) else: - quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) + layer_conf = MixedPrecisionOps._layer_quant_config[layer_name] + quant_format = layer_conf.get("format", None) + if not self._full_precision_mm: + self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False) + if quant_format is None: raise ValueError(f"Unknown quantization format for layer {layer_name}") @@ -601,16 +542,18 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful self.layout_type = qconfig["comfy_tensor_layout"] weight_scale_key = f"{prefix}weight_scale" + scale = state_dict.pop(weight_scale_key, None) layout_params = { - 'scale': state_dict.pop(weight_scale_key, None), + 'scale': scale, 'orig_dtype': MixedPrecisionOps._compute_dtype, 'block_size': qconfig.get("group_size", None), } - if layout_params['scale'] is not None: + + if scale is not None: manually_loaded_keys.append(weight_scale_key) self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device), self.layout_type, layout_params), + QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params), requires_grad=False ) @@ -643,9 +586,8 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(input, *args, **kwargs) if (getattr(self, 'layout_type', None) is not None and - getattr(self, 'input_scale', None) is not None and not isinstance(input, QuantizedTensor)): - input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) + input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype) return self._forward(input, self.weight, self.bias) def convert_weight(self, weight, inplace=False, **kwargs): @@ -656,7 +598,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): if getattr(self, 'layout_type', None) is not None: - weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) + weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) else: weight = weight.to(self.weight.dtype) if return_weight: @@ -667,16 +609,13 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful return MixedPrecisionOps -def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): +def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute) - if scaled_fp8 is not None: - return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) - if ( fp8_compute and (fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast) and diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index bb1fb860c..82148bd2b 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -397,17 +397,20 @@ class TensorCoreFP8Layout(QuantizedLayout): def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False): orig_dtype = tensor.dtype - if scale is None: + if scale == "recalculate": scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max - if not isinstance(scale, torch.Tensor): - scale = torch.tensor(scale) - scale = scale.to(device=tensor.device, dtype=torch.float32) + if scale is not None: + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) - if inplace_ops: - tensor *= (1.0 / scale).to(tensor.dtype) + if inplace_ops: + tensor *= (1.0 / scale).to(tensor.dtype) + else: + tensor = tensor * (1.0 / scale).to(tensor.dtype) else: - tensor = tensor * (1.0 / scale).to(tensor.dtype) + scale = torch.ones((), device=tensor.device, dtype=torch.float32) if stochastic_rounding > 0: tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding) diff --git a/comfy/sd.py b/comfy/sd.py index 9eeb0c45a..d44d88edb 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -962,6 +962,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI clip_data = [] for p in ckpt_paths: sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True) + sd, metadata = preprocess_diffusion_state_dict(sd, diffusion_model_prefix="", metadata=metadata) if metadata is not None: quant_metadata = metadata.get("_quantization_metadata", None) if quant_metadata is not None: @@ -1078,7 +1079,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer elif clip_type == CLIPType.HIDREAM: - clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sdxl_clip.SDXLRefinerClipModel @@ -1102,7 +1103,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.HIDREAM: clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), - clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None) + clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: #CLIPType.MOCHI clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) @@ -1131,7 +1132,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model == TEModel.LLAMA3_8: clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data), - clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None) + clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer elif te_model == TEModel.QWEN25_3B: clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data)) @@ -1156,7 +1157,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer elif clip_type == CLIPType.HIDREAM: - clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sd1_clip.SD1ClipModel @@ -1270,6 +1271,56 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) return out +def preprocess_diffusion_state_dict(state_dict, diffusion_model_prefix="", metadata={}, model_options={}, output_prefix=""): + if metadata is None: + metadata = {} + + if "_quantization_metadata" not in metadata: + scaled_fp8_key = "{}scaled_fp8".format(diffusion_model_prefix) + + if scaled_fp8_key in state_dict: + scaled_fp8_weight = state_dict[scaled_fp8_key] + scaled_fp8_dtype = scaled_fp8_weight.dtype + if scaled_fp8_dtype == torch.float32: + scaled_fp8_dtype = torch.float8_e4m3fn + + if scaled_fp8_weight.nelement() == 2: + full_precision_matrix_mult = True + else: + full_precision_matrix_mult = False + + out_sd = {} + layers = {} + for k in list(state_dict.keys()): + if not k.startswith(diffusion_model_prefix): + out_sd[k] = state_dict[k] + continue + k_out = k + w = state_dict.pop(k) + layer = None + if k_out.endswith(".scale_weight"): + layer = k_out[:-len(".scale_weight")] + k_out = "{}.weight_scale".format(layer) + + if layer is not None: + layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints + if full_precision_matrix_mult: + layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult + layers["{}{}".format(output_prefix, layer[len(diffusion_model_prefix):])] = layer_conf + + if k_out.endswith(".scale_input"): + layer = k_out[:-len(".scale_input")] + k_out = "{}.input_scale".format(layer) + if w.item() == 1.0: + continue + + out_sd[k_out] = w + + state_dict = out_sd + metadata["{}_quantization_metadata".format(diffusion_model_prefix)] = json.dumps({"format_version": "1.0", "layers": layers}) + + return state_dict, metadata + def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None): clip = None clipvision = None @@ -1282,6 +1333,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix) load_device = model_management.get_torch_device() + sd, metadata = preprocess_diffusion_state_dict(sd, diffusion_model_prefix, metadata=metadata, model_options=model_options) + model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata) if model_config is None: logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.") @@ -1292,7 +1345,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c unet_weight_dtype = list(model_config.supported_inference_dtypes) - if model_config.scaled_fp8 is not None: + if model_config.layer_quant_config is not None: weight_dtype = None model_config.custom_operations = model_options.get("custom_operations", None) @@ -1301,7 +1354,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if unet_dtype is None: unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype) - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + if model_config.layer_quant_config is not None: + manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) + else: + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) if model_config.clip_vision_prefix is not None: @@ -1319,7 +1375,22 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c vae = VAE(sd=vae_sd, metadata=metadata) if output_clip: + to_pop = [] + for k in list(sd.keys()): # Convert scaled fp8 to mixed ops + if k.endswith(".scaled_fp8"): + pref = k[:-len("scaled_fp8")] + out_pref = ".".join(pref.split(".")[1:]) + sd, qmetadata = preprocess_diffusion_state_dict(sd, pref, metadata={}, model_options=model_options, output_prefix=out_pref) + for mk in qmetadata: + sd[mk] = qmetadata[mk] + to_pop.append(mk) + to_pop.append(k) + clip_target = model_config.clip_target(state_dict=sd) + + for k in to_pop: # pop the keys for mixed ops + sd.pop(k) + if clip_target is not None: clip_sd = model_config.process_clip_state_dict(sd) if len(clip_sd) > 0: @@ -1380,7 +1451,11 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True) if len(temp_sd) > 0: sd = temp_sd + quant_key = "{}_quantization_metadata".format(diffusion_model_prefix) + if metadata is not None and quant_key in metadata: + metadata["_quantization_metadata"] = metadata.pop(quant_key) + sd, metadata = preprocess_diffusion_state_dict(sd, "", metadata=metadata, model_options=model_options) parameters = comfy.utils.calculate_parameters(sd) weight_dtype = comfy.utils.weight_dtype(sd) @@ -1411,7 +1486,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): offload_device = model_management.unet_offload_device() unet_weight_dtype = list(model_config.supported_inference_dtypes) - if model_config.scaled_fp8 is not None: + if model_config.layer_quant_config is not None: weight_dtype = None if dtype is None: @@ -1463,9 +1538,12 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m if vae is not None: vae_sd = vae.get_sd() + if metadata is None: + metadata = {} + model_management.load_models_gpu(load_models, force_patch_weights=True) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None - sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) + sd, metadata = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd, metadata=metadata) for k in extra_keys: sd[k] = extra_keys[k] diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 0fc9ab3db..2d3343513 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -90,6 +90,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False, return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32 super().__init__() + assert layer in self.LAYERS if textmodel_json_config is None: textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") @@ -107,29 +108,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): config[k] = v operations = model_options.get("custom_operations", None) - scaled_fp8 = None quantization_metadata = model_options.get("quantization_metadata", None) + self.layer_quant_config = None if operations is None: - layer_quant_config = None if quantization_metadata is not None: - layer_quant_config = json.loads(quantization_metadata).get("layers", None) + self.layer_quant_config = json.loads(quantization_metadata).get("layers", None) - if layer_quant_config is not None: - operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True) - logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers") + if self.layer_quant_config is not None: + operations = comfy.ops.mixed_precision_ops(self.layer_quant_config, dtype, full_precision_mm=True) + logging.info(f"Using MixedPrecisionOps for text encoder: {len(self.layer_quant_config)} quantized layers") else: - # Fallback to scaled_fp8_ops for backward compatibility - scaled_fp8 = model_options.get("scaled_fp8", None) - if scaled_fp8 is not None: - operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) - else: - operations = comfy.ops.manual_cast + operations = comfy.ops.manual_cast self.operations = operations self.transformer = model_class(config, dtype, device, self.operations) - if scaled_fp8 is not None: - self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8)) self.num_layers = self.transformer.num_layers @@ -163,7 +156,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def set_clip_options(self, options): layer_idx = options.get("layer", self.layer_idx) self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) - if isinstance(self.layer, list) or self.layer == "all": + if self.layer == "all": pass elif layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" @@ -265,9 +258,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): if self.enable_attention_masks: attention_mask_model = attention_mask - if isinstance(self.layer, list): - intermediate_output = self.layer - elif self.layer == "all": + if self.layer == "all": intermediate_output = "all" else: intermediate_output = self.layer_idx diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index e4bd74514..7ba7ef645 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -49,7 +49,6 @@ class BASE: manual_cast_dtype = None custom_operations = None - scaled_fp8 = None layer_quant_config = None # Per-layer quantization configuration for mixed precision optimizations = {"fp8": False} diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py index a1adb5242..448381fa9 100644 --- a/comfy/text_encoders/cosmos.py +++ b/comfy/text_encoders/cosmos.py @@ -7,10 +7,10 @@ from transformers import T5TokenizerFast class T5XXLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json") - t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None) - if t5xxl_scaled_fp8 is not None: + t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None) + if t5xxl_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5xxl_quantization_metadata super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options) @@ -30,12 +30,12 @@ class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) -def te(dtype_t5=None, t5xxl_scaled_fp8=None): +def te(dtype_t5=None, t5_quantization_metadata=None): class CosmosTEModel_(CosmosT5XXL): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 99f4812bb..913ad5ac2 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -63,12 +63,12 @@ class FluxClipModel(torch.nn.Module): else: return self.t5xxl.load_sd(sd) -def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None): +def flux_clip(dtype_t5=None, t5_quantization_metadata=None): class FluxClipModel_(FluxClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options) return FluxClipModel_ @@ -159,16 +159,14 @@ class Flux2TEModel(sd1_clip.SD1ClipModel): out = out.reshape(out.shape[0], out.shape[1], -1) return out, pooled, extra -def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False): +def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False): class Flux2TEModel_(Flux2TEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: - model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: - model_options["quantization_metadata"] = llama_quantization_metadata + model_options = model_options.copy() + model_options["llama_quantization_metadata"] = llama_quantization_metadata if pruned: model_options = model_options.copy() model_options["num_layers"] = 30 diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py index 9dcf190a2..5daea8135 100644 --- a/comfy/text_encoders/genmo.py +++ b/comfy/text_encoders/genmo.py @@ -26,12 +26,12 @@ class MochiT5Tokenizer(sd1_clip.SD1Tokenizer): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) -def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None): +def mochi_te(dtype_t5=None, t5_quantization_metadata=None): class MochiTEModel_(MochiT5XXL): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py index dbcf52784..600b34480 100644 --- a/comfy/text_encoders/hidream.py +++ b/comfy/text_encoders/hidream.py @@ -142,14 +142,14 @@ class HiDreamTEModel(torch.nn.Module): return self.llama.load_sd(sd) -def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None): +def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5_quantization_metadata=None, llama_quantization_metadata=None): class HiDreamTEModel_(HiDreamTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 - if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["llama_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) return HiDreamTEModel_ diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py index ff04726e1..cd198036c 100644 --- a/comfy/text_encoders/hunyuan_image.py +++ b/comfy/text_encoders/hunyuan_image.py @@ -40,10 +40,10 @@ class HunyuanImageTokenizer(QwenImageTokenizer): class Qwen25_7BVLIModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}): - llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None) - if llama_scaled_fp8 is not None: + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -91,12 +91,12 @@ class HunyuanImageTEModel(QwenImageTEModel): else: return super().load_sd(sd) -def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None): +def te(byt5=True, dtype_llama=None, llama_quantization_metadata=None): class QwenImageTEModel_(HunyuanImageTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["qwen_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index 0110517bb..f126434fd 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -14,12 +14,9 @@ def llama_detect(state_dict, prefix=""): if t5_key in state_dict: out["dtype_llama"] = state_dict[t5_key].dtype - scaled_fp8_key = "{}scaled_fp8".format(prefix) - if scaled_fp8_key in state_dict: - out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype - - if "_quantization_metadata" in state_dict: - out["llama_quantization_metadata"] = state_dict["_quantization_metadata"] + quant_key = "{}_quantization_metadata".format(prefix) + if quant_key in state_dict: + out["llama_quantization_metadata"] = state_dict[quant_key] return out @@ -31,10 +28,10 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer): class LLAMAModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}): - llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None) - if llama_scaled_fp8 is not None: + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata textmodel_json_config = {} vocab_size = model_options.get("vocab_size", None) @@ -161,11 +158,11 @@ class HunyuanVideoClipModel(torch.nn.Module): return self.llama.load_sd(sd) -def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None): +def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None): class HunyuanVideoClipModel_(HunyuanVideoClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["llama_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) return HunyuanVideoClipModel_ diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index fd986e2c1..7a6cfdab2 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -40,7 +40,7 @@ class LuminaModel(sd1_clip.SD1ClipModel): super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"): +def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"): if model_type == "gemma2_2b": model = Gemma2_2BModel elif model_type == "gemma3_4b": @@ -48,9 +48,9 @@ def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"): class LuminaTEModel_(LuminaModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model) diff --git a/comfy/text_encoders/omnigen2.py b/comfy/text_encoders/omnigen2.py index 1a01b2dd4..50aa4121f 100644 --- a/comfy/text_encoders/omnigen2.py +++ b/comfy/text_encoders/omnigen2.py @@ -32,12 +32,12 @@ class Omnigen2Model(sd1_clip.SD1ClipModel): super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class Omnigen2TEModel_(Omnigen2Model): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py index 5f383de07..e5e5f18be 100644 --- a/comfy/text_encoders/pixart_t5.py +++ b/comfy/text_encoders/pixart_t5.py @@ -30,12 +30,12 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) -def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None): +def pixart_te(dtype_t5=None, t5_quantization_metadata=None): class PixArtTEModel_(PixArtT5XXL): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py index c0d32a6ef..5c14dec23 100644 --- a/comfy/text_encoders/qwen_image.py +++ b/comfy/text_encoders/qwen_image.py @@ -85,12 +85,12 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel): return out, pooled, extra -def te(dtype_llama=None, llama_scaled_fp8=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class QwenImageTEModel_(QwenImageTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index ff5d412db..fae72a9c2 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -10,10 +10,10 @@ import logging class T5XXLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json") - t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None) - if t5xxl_scaled_fp8 is not None: + t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None) + if t5xxl_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5xxl_quantization_metadata model_options = {**model_options, "model_name": "t5xxl"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -25,9 +25,9 @@ def t5_xxl_detect(state_dict, prefix=""): if t5_key in state_dict: out["dtype_t5"] = state_dict[t5_key].dtype - scaled_fp8_key = "{}scaled_fp8".format(prefix) - if scaled_fp8_key in state_dict: - out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype + quant_key = "{}_quantization_metadata".format(prefix) + if quant_key in state_dict: + out["t5_quantization_metadata"] = state_dict[quant_key] return out @@ -156,11 +156,11 @@ class SD3ClipModel(torch.nn.Module): else: return self.t5xxl.load_sd(sd) -def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False): +def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_quantization_metadata=None, t5_attention_mask=False): class SD3ClipModel_(SD3ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options) return SD3ClipModel_ diff --git a/comfy/text_encoders/wan.py b/comfy/text_encoders/wan.py index d50fa4b28..164a57edd 100644 --- a/comfy/text_encoders/wan.py +++ b/comfy/text_encoders/wan.py @@ -25,12 +25,12 @@ class WanT5Model(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs) -def te(dtype_t5=None, t5xxl_scaled_fp8=None): +def te(dtype_t5=None, t5_quantization_metadata=None): class WanTEModel(WanT5Model): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5_quantization_metadata if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py index bb9273b20..19adde0b7 100644 --- a/comfy/text_encoders/z_image.py +++ b/comfy/text_encoders/z_image.py @@ -34,12 +34,9 @@ class ZImageTEModel(sd1_clip.SD1ClipModel): super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class ZImageTEModel_(ZImageTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: - model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: