Don't convert quants when custom ops are used.

comfyanonymous 2025-12-05 02:16:34 -05:00 committed by GitHub
parent 129f3e7db1
commit 3193d3aa53


@@ -964,7 +964,8 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
     clip_data = []
     for p in ckpt_paths:
         sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
-        sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
+        if model_options.get("custom_operations", None) is None:
+            sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
         clip_data.append(sd)
     return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
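
The guard is the same in every hunk of this commit: convert_old_quants only runs when the caller has not supplied its own operations through model_options. A minimal stand-alone sketch of that pattern, assuming only the comfy.utils calls shown in this diff (the helper name load_checkpoint_state_dict is made up):

    import comfy.utils

    def load_checkpoint_state_dict(path, model_options={}):
        # Hypothetical helper restating the guard this commit adds.
        sd, metadata = comfy.utils.load_torch_file(path, safe_load=True, return_metadata=True)
        if model_options.get("custom_operations", None) is None:
            # No custom ops supplied: rewrite legacy quant layouts into the current format.
            sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
        # With custom ops, the state dict is left untouched so those operations
        # see the original (possibly old-format) quantized tensors.
        return sd, metadata
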
@@ -1286,7 +1287,9 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
     load_device = model_management.get_torch_device()
 
-    sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
+    custom_operations = model_options.get("custom_operations", None)
+    if custom_operations is None:
+        sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
 
     model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
     if model_config is None:
@@ -1300,7 +1303,9 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     if model_config.quant_config is not None:
         weight_dtype = None
 
-    model_config.custom_operations = model_options.get("custom_operations", None)
+    if custom_operations is not None:
+        model_config.custom_operations = custom_operations
+
     unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
 
     if unet_dtype is None:
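
One subtlety in the hunk above: the old line assigned model_options.get("custom_operations", None) unconditionally, so a caller that passed no custom ops still wrote None into model_config.custom_operations; the new form leaves the field untouched in that case. Restated as a sketch (model_config stands in for the detected config object):

    custom_operations = model_options.get("custom_operations", None)
    if custom_operations is not None:
        # Only override when the caller actually supplied an operations class;
        # otherwise whatever the detected model_config already holds is kept.
        model_config.custom_operations = custom_operations
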
@@ -1327,25 +1332,26 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
         vae = VAE(sd=vae_sd, metadata=metadata)
 
     if output_clip:
-        scaled_fp8_list = []
-        for k in list(sd.keys()): # Convert scaled fp8 to mixed ops
-            if k.endswith(".scaled_fp8"):
-                scaled_fp8_list.append(k[:-len("scaled_fp8")])
+        if te_model_options.get("custom_operations", None) is None:
+            scaled_fp8_list = []
+            for k in list(sd.keys()): # Convert scaled fp8 to mixed ops
+                if k.endswith(".scaled_fp8"):
+                    scaled_fp8_list.append(k[:-len("scaled_fp8")])
 
-        if len(scaled_fp8_list) > 0:
-            out_sd = {}
-            for k in sd:
-                skip = False
-                for pref in scaled_fp8_list:
-                    skip = skip or k.startswith(pref)
-                if not skip:
-                    out_sd[k] = sd[k]
+            if len(scaled_fp8_list) > 0:
+                out_sd = {}
+                for k in sd:
+                    skip = False
+                    for pref in scaled_fp8_list:
+                        skip = skip or k.startswith(pref)
+                    if not skip:
+                        out_sd[k] = sd[k]
 
-            for pref in scaled_fp8_list:
-                quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
-                for k in quant_sd:
-                    out_sd[k] = quant_sd[k]
-            sd = out_sd
+                for pref in scaled_fp8_list:
+                    quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
+                    for k in quant_sd:
+                        out_sd[k] = quant_sd[k]
+                sd = out_sd
 
         clip_target = model_config.clip_target(state_dict=sd)
         if clip_target is not None:
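
The block this hunk wraps is the densest part of the change, so here it is restated as a stand-alone sketch; only the comfy.utils.convert_old_quants call is taken from the diff, the function name is hypothetical:

    import comfy.utils

    def convert_scaled_fp8_text_encoders(sd):
        # Collect the prefix of every sub-model stored in the old scaled-fp8 format.
        scaled_fp8_list = []
        for k in list(sd.keys()):
            if k.endswith(".scaled_fp8"):
                scaled_fp8_list.append(k[:-len("scaled_fp8")])

        if len(scaled_fp8_list) == 0:
            return sd

        # Keep every tensor that does not belong to one of those prefixes...
        out_sd = {}
        for k in sd:
            if not any(k.startswith(pref) for pref in scaled_fp8_list):
                out_sd[k] = sd[k]

        # ...then re-add each scaled-fp8 sub-model after converting it to the
        # current quant layout ("mixed ops" in the original comment).
        for pref in scaled_fp8_list:
            quant_sd, _ = comfy.utils.convert_old_quants(sd, pref, metadata={})
            out_sd.update(quant_sd)
        return out_sd

After this commit, the in-tree version of this logic only runs when te_model_options carries no "custom_operations" entry.
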
@@ -1409,7 +1415,9 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
     if len(temp_sd) > 0:
         sd = temp_sd
 
-    sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
+    custom_operations = model_options.get("custom_operations", None)
+    if custom_operations is None:
+        sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
 
     parameters = comfy.utils.calculate_parameters(sd)
     weight_dtype = comfy.utils.weight_dtype(sd)
@ -1453,7 +1461,10 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
else: else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
if custom_operations is not None:
model_config.custom_operations = custom_operations
if model_options.get("fp8_optimizations", False): if model_options.get("fp8_optimizations", False):
model_config.optimizations["fp8"] = True model_config.optimizations["fp8"] = True
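
Finally, a hedged usage sketch from the caller's side; it assumes these loaders live in comfy.sd (the file path is not shown on this page), the checkpoint filename is a placeholder, and my_ops stands in for whatever operations class a caller would inject:

    import comfy.sd
    import comfy.utils

    sd, metadata = comfy.utils.load_torch_file("model.safetensors", safe_load=True, return_metadata=True)

    # Default path: no custom ops, so old quant formats are converted on load.
    model_patcher = comfy.sd.load_diffusion_model_state_dict(sd, model_options={}, metadata=metadata)

    # "my_ops" is a placeholder; any non-None value for "custom_operations" now
    # skips convert_old_quants and hands the raw state dict to those operations.
    my_ops = object()
    model_patcher = comfy.sd.load_diffusion_model_state_dict(
        sd, model_options={"custom_operations": my_ops}, metadata=metadata)
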