mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 21:42:37 +08:00
fix fp8 mixed-precision loading issue
This commit is contained in:
parent
b2ef58e2b1
commit
db730ee283
@ -599,6 +599,7 @@ class MixedPrecisionOps(disable_weight_init):
|
|||||||
'block_size': qconfig.get("group_size", None),
|
'block_size': qconfig.get("group_size", None),
|
||||||
}
|
}
|
||||||
if layout_params['scale'] is not None:
|
if layout_params['scale'] is not None:
|
||||||
|
layout_params['scale'] = layout_params['scale'].to(device=device)
|
||||||
manually_loaded_keys.append(weight_scale_key)
|
manually_loaded_keys.append(weight_scale_key)
|
||||||
|
|
||||||
self.weight = torch.nn.Parameter(
|
self.weight = torch.nn.Parameter(
|
||||||
@ -611,7 +612,7 @@ class MixedPrecisionOps(disable_weight_init):
|
|||||||
_v = state_dict.pop(param_key, None)
|
_v = state_dict.pop(param_key, None)
|
||||||
if _v is None:
|
if _v is None:
|
||||||
continue
|
continue
|
||||||
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
self.register_buffer(param_name, _v.to(device=device))
|
||||||
manually_loaded_keys.append(param_key)
|
manually_loaded_keys.append(param_key)
|
||||||
|
|
||||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user