fix: move convert_old_quants before state_dict_prefix_replace in load_diffusion_model_state_dict

When loading NVFP4 quantized models (e.g. LTX-Video 2.3 NVFP4),
convert_old_quants() was called with an empty prefix ("") BEFORE the
diffusion_model_prefix was determined, and then state_dict_prefix_replace()
stripped the "model.diffusion_model." prefix from weight keys. However,
_quantization_metadata keys retained the full prefix, causing quant
markers to not match weight keys. This resulted in layout_type=None
and a shape mismatch error during loading.

Fix: determine diffusion_model_prefix first, pass it to
convert_old_quants() so metadata keys match, then call
state_dict_prefix_replace() afterwards. The redundant second call to
convert_old_quants inside the len(temp_sd)>0 block is also removed
since conversion now happens correctly before prefix stripping.

This mirrors the correct ordering already used in
load_checkpoint_guess_config().
This commit is contained in:
codeman101 2026-04-08 00:33:59 -07:00
parent b615af1c65
commit 52eec2910f

View File

@ -1743,17 +1743,16 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
""" """
dtype = model_options.get("dtype", None) dtype = model_options.get("dtype", None)
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
#Allow loading unets from checkpoint files #Allow loading unets from checkpoint files
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True) temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
if len(temp_sd) > 0: if len(temp_sd) > 0:
sd = temp_sd sd = temp_sd
if custom_operations is None:
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
parameters = comfy.utils.calculate_parameters(sd) parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd) weight_dtype = comfy.utils.weight_dtype(sd)