Fix missing keys

lspindler 2025-10-27 10:04:57 +01:00
parent 218ef4c53c
commit f287d02419

@@ -534,6 +534,8 @@ class MixedPrecisionOps(disable_weight_init):
         if weight is None:
             raise ValueError(f"Missing weight for layer {layer_name}")
+        manually_loaded_keys = [weight_key]
         if layer_name not in MixedPrecisionOps._layer_quant_config:
             self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
         else:
@@ -544,23 +546,33 @@ class MixedPrecisionOps(disable_weight_init):
             mixin = QUANT_FORMAT_MIXINS[quant_format]
             self.layout_type = mixin["layout_type"]
+            scale_key = f"{prefix}weight_scale"
             layout_params = {
-                'scale': state_dict.pop(f"{prefix}weight_scale", None),
+                'scale': state_dict.pop(scale_key, None),
                 'orig_dtype': MixedPrecisionOps._compute_dtype
             }
+            if layout_params['scale'] is not None:
+                manually_loaded_keys.append(scale_key)
             self.weight = torch.nn.Parameter(
                 QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
                 requires_grad=False
             )
             for param_name, param_value in mixin["parameters"].items():
-                _v = state_dict.pop(f"{prefix}{param_name}", None)
+                param_key = f"{prefix}{param_name}"
+                _v = state_dict.pop(param_key, None)
                 if _v is None:
                     continue
-                setattr(self, param_name, _v.to(device=device))
+                setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+                manually_loaded_keys.append(param_key)
         super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+        for key in manually_loaded_keys:
+            if key in missing_keys:
+                missing_keys.remove(key)

     def _forward(self, input, weight, bias):
         return torch.nn.functional.linear(input, weight, bias)
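For context, the pattern this commit fixes: the override pops the weight and quantization tensors out of state_dict before delegating to super()._load_from_state_dict, so the default loader no longer sees those keys and reports them in missing_keys even though they were loaded. The fix tracks every manually consumed key and strips it from missing_keys after the call. Below is a minimal standalone sketch of the same pattern; DemoQuantLinear and its weight_scale handling are illustrative stand-ins, not the project's actual classes.

import torch

class DemoQuantLinear(torch.nn.Module):
    # Illustrative stand-in for a layer that loads some tensors by hand.
    def __init__(self, features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(features, features), requires_grad=False)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        weight_key = f"{prefix}weight"
        scale_key = f"{prefix}weight_scale"
        weight = state_dict.pop(weight_key, None)
        scale = state_dict.pop(scale_key, None)
        if weight is None:
            raise ValueError(f"Missing weight for {prefix}")
        manually_loaded_keys = [weight_key]
        if scale is not None:
            weight = weight * scale  # stand-in for real dequantization
            manually_loaded_keys.append(scale_key)
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        # The default loader only sees what is left in state_dict, so the
        # popped keys would land in missing_keys without the cleanup below.
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
        for key in manually_loaded_keys:
            if key in missing_keys:
                missing_keys.remove(key)

m = DemoQuantLinear(4)
result = m.load_state_dict({"weight": torch.randn(4, 4), "weight_scale": torch.tensor(0.5)})
assert result.missing_keys == [] and result.unexpected_keys == []

Popping consumed keys is safe here because torch.nn.Module.load_state_dict copies the incoming state dict before handing it to _load_from_state_dict; without the final loop, the popped "weight" key would surface as a spurious entry in missing_keys.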