diff --git a/comfy/sd.py b/comfy/sd.py
index e264c1ee5..9eea7d63b 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -964,7 +964,8 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
     clip_data = []
     for p in ckpt_paths:
         sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
-        sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
+        if model_options.get("custom_operations", None) is None:
+            sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
         clip_data.append(sd)
     return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
 
@@ -1286,7 +1287,9 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
     load_device = model_management.get_torch_device()
 
-    sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
+    custom_operations = model_options.get("custom_operations", None)
+    if custom_operations is None:
+        sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
     model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
 
     if model_config is None:
@@ -1300,7 +1303,9 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     if model_config.quant_config is not None:
         weight_dtype = None
 
-    model_config.custom_operations = model_options.get("custom_operations", None)
+    if custom_operations is not None:
+        model_config.custom_operations = custom_operations
+
     unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
 
     if unet_dtype is None:
@@ -1327,25 +1332,26 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
         vae = VAE(sd=vae_sd, metadata=metadata)
 
     if output_clip:
-        scaled_fp8_list = []
-        for k in list(sd.keys()): # Convert scaled fp8 to mixed ops
-            if k.endswith(".scaled_fp8"):
-                scaled_fp8_list.append(k[:-len("scaled_fp8")])
+        if te_model_options.get("custom_operations", None) is None:
+            scaled_fp8_list = []
+            for k in list(sd.keys()): # Convert scaled fp8 to mixed ops
+                if k.endswith(".scaled_fp8"):
+                    scaled_fp8_list.append(k[:-len("scaled_fp8")])
+
+            if len(scaled_fp8_list) > 0:
+                out_sd = {}
+                for k in sd:
+                    skip = False
+                    for pref in scaled_fp8_list:
+                        skip = skip or k.startswith(pref)
+                    if not skip:
+                        out_sd[k] = sd[k]
 
-        if len(scaled_fp8_list) > 0:
-            out_sd = {}
-            for k in sd:
-                skip = False
                 for pref in scaled_fp8_list:
-                    skip = skip or k.startswith(pref)
-                if not skip:
-                    out_sd[k] = sd[k]
-
-            for pref in scaled_fp8_list:
-                quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
-                for k in quant_sd:
-                    out_sd[k] = quant_sd[k]
-            sd = out_sd
+                    quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
+                    for k in quant_sd:
+                        out_sd[k] = quant_sd[k]
+                sd = out_sd
 
         clip_target = model_config.clip_target(state_dict=sd)
         if clip_target is not None:
@@ -1409,7 +1415,9 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
     if len(temp_sd) > 0:
         sd = temp_sd
 
-    sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
+    custom_operations = model_options.get("custom_operations", None)
+    if custom_operations is None:
+        sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
     parameters = comfy.utils.calculate_parameters(sd)
     weight_dtype = comfy.utils.weight_dtype(sd)
 
@@ -1453,7 +1461,10 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
     else:
         manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
-    model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
+
+    if custom_operations is not None:
+        model_config.custom_operations = custom_operations
+
     if model_options.get("fp8_optimizations", False):
         model_config.optimizations["fp8"] = True
 
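
Note (not part of the diff): a minimal usage sketch of what this change enables. When "custom_operations" is present in model_options, convert_old_quants is skipped, so the state dict reaches the custom ops in its original quantized layout. MyCustomOps and the checkpoint path are hypothetical placeholders; only load_diffusion_model_state_dict, load_torch_file, and the "custom_operations" key come from the code above.

import comfy.sd
import comfy.utils
import comfy.ops

# Hypothetical ops class; anything passed as "custom_operations" is assumed
# to handle the raw (unconverted) quantized weights itself.
class MyCustomOps(comfy.ops.manual_cast):
    pass

# Load a quantized checkpoint; the path is illustrative.
sd = comfy.utils.load_torch_file("model.safetensors", safe_load=True)

# With this change, passing "custom_operations" skips convert_old_quants and
# assigns the supplied ops class to model_config.custom_operations.
model_patcher = comfy.sd.load_diffusion_model_state_dict(
    sd, model_options={"custom_operations": MyCustomOps})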