fix fp8 mixed-precision loading issue

Sandro Cavallari 2025-11-25 01:22:27 +01:00
parent b2ef58e2b1
commit db730ee283


@@ -599,6 +599,7 @@ class MixedPrecisionOps(disable_weight_init):
     'block_size': qconfig.get("group_size", None),
 }
 if layout_params['scale'] is not None:
+    layout_params['scale'] = layout_params['scale'].to(device=device)
     manually_loaded_keys.append(weight_scale_key)
 self.weight = torch.nn.Parameter(
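Why the added line matters, shown in a standalone sketch (the tensors and names below are illustrative, not the project's code): when the fp8 weight is moved to the compute device but its dequantization scale stays wherever the state dict produced it (typically CPU), the first dequantization fails with a device mismatch. Moving the scale alongside the weight, as the new line does, avoids that.

# Minimal, self-contained sketch of the device-mismatch problem
# (hypothetical tensors, not the code touched by this commit).
# Requires a PyTorch build that ships the fp8 dtypes.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

weight_fp8 = torch.randn(4, 4).to(torch.float8_e4m3fn).to(device=device)
scale = torch.tensor(0.05)  # still on CPU, as when read straight from a state dict

# If `device` is CUDA, combining the CUDA weight with a CPU scale raises a
# device-mismatch error; moving the scale first (what the commit adds) avoids it.
scale = scale.to(device=device)
dequantized = weight_fp8.to(torch.float32) * scale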
@@ -611,7 +612,7 @@ class MixedPrecisionOps(disable_weight_init):
     _v = state_dict.pop(param_key, None)
     if _v is None:
         continue
-    setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+    self.register_buffer(param_name, _v.to(device=device))
     manually_loaded_keys.append(param_key)
 super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
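The second hunk swaps a non-trainable Parameter for a registered buffer for the extra quantization tensors. A toy module below (not MixedPrecisionOps itself) illustrates the practical difference: a buffer stays out of .parameters(), so optimizers and parameter-only code paths skip it, while it still follows the module through .to() and still appears in the state dict.

# Toy illustration of Parameter vs. register_buffer
# (hypothetical module, not the class touched by this commit).
import torch

class ToyQuantLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4), requires_grad=False)
        # Quantization metadata is fixed and non-trainable, but should still be
        # device-tracked and serialized, so a buffer is the natural fit.
        self.register_buffer("weight_scale", torch.tensor(0.05))

m = ToyQuantLinear()
print([name for name, _ in m.named_parameters()])  # ['weight'] only
print(sorted(m.state_dict().keys()))                # ['weight', 'weight_scale']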