Mirror of https://github.com/comfyanonymous/ComfyUI.git
Fix missing keys

commit f287d02419
parent 218ef4c53c
comfy/ops.py (18 changed lines)

@@ -534,6 +534,8 @@ class MixedPrecisionOps(disable_weight_init):
             if weight is None:
                 raise ValueError(f"Missing weight for layer {layer_name}")
 
+            manually_loaded_keys = [weight_key]
+
             if layer_name not in MixedPrecisionOps._layer_quant_config:
                 self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
             else:
@@ -544,23 +546,33 @@ class MixedPrecisionOps(disable_weight_init):
                 mixin = QUANT_FORMAT_MIXINS[quant_format]
                 self.layout_type = mixin["layout_type"]
 
+                scale_key = f"{prefix}weight_scale"
                 layout_params = {
-                    'scale': state_dict.pop(f"{prefix}weight_scale", None),
+                    'scale': state_dict.pop(scale_key, None),
                     'orig_dtype': MixedPrecisionOps._compute_dtype
                 }
+                if layout_params['scale'] is not None:
+                    manually_loaded_keys.append(scale_key)
+
                 self.weight = torch.nn.Parameter(
                     QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
                     requires_grad=False
                 )
 
                 for param_name, param_value in mixin["parameters"].items():
-                    _v = state_dict.pop(f"{prefix}{param_name}", None)
+                    param_key = f"{prefix}{param_name}"
+                    _v = state_dict.pop(param_key, None)
                     if _v is None:
                         continue
-                    setattr(self, param_name, _v.to(device=device))
+                    setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+                    manually_loaded_keys.append(param_key)
 
             super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
 
+            for key in manually_loaded_keys:
+                if key in missing_keys:
+                    missing_keys.remove(key)
+
         def _forward(self, input, weight, bias):
             return torch.nn.functional.linear(input, weight, bias)
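Why the keys went missing (a sketch, not part of the commit): this `_load_from_state_dict` pops the weight, the optional `weight_scale`, and the mixin parameters out of `state_dict` and builds the tensors itself, so the default `torch.nn.Module._load_from_state_dict` never sees those keys and reports them in `missing_keys`; a strict `load_state_dict` then fails even though everything was loaded. The commit tracks each manually handled key in `manually_loaded_keys` and prunes it from `missing_keys` after the `super()` call. The minimal sketch below reproduces the pattern with stock PyTorch; `ManuallyLoadedLinear` is a hypothetical stand-in, not ComfyUI code.

import torch

class ManuallyLoadedLinear(torch.nn.Linear):
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        manually_loaded_keys = []

        weight_key = f"{prefix}weight"
        weight = state_dict.pop(weight_key, None)
        if weight is not None:
            # Build the parameter ourselves instead of letting nn.Module
            # copy the tensor into the existing one.
            self.weight = torch.nn.Parameter(weight, requires_grad=False)
            manually_loaded_keys.append(weight_key)

        # The default implementation no longer finds "weight" in state_dict,
        # so it appends the key to missing_keys.
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)

        # Drop the false positives: these keys were loaded manually above.
        for key in manually_loaded_keys:
            if key in missing_keys:
                missing_keys.remove(key)

layer = ManuallyLoadedLinear(4, 4, bias=False)
result = layer.load_state_dict({"weight": torch.randn(4, 4)}, strict=True)
print(result.missing_keys)  # [] -- the strict load succeeds

Manual loading (and hence the pruning) is needed in ops.py because the quantized path wraps the raw tensor in a `QuantizedTensor` with layout parameters, which the default copy-into-existing-parameter loading path cannot reproduce.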