mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-20 03:23:00 +08:00
Working ModelSave node.
This commit is contained in:
parent
b440a2bde9
commit
e3861b193b
@ -341,6 +341,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
|
|
||||||
meta_out = {"_quantization_metadata": json.dumps(meta)}
|
meta_out = {"_quantization_metadata": json.dumps(meta)}
|
||||||
meta_out = self.model_config.process_unet_state_dict_for_saving(meta_out)
|
meta_out = self.model_config.process_unet_state_dict_for_saving(meta_out)
|
||||||
|
metadata.update(meta_out)
|
||||||
|
|
||||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||||
|
|
||||||
|
|||||||
@ -562,7 +562,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
|
|||||||
_v = state_dict.pop(param_key, None)
|
_v = state_dict.pop(param_key, None)
|
||||||
if _v is None:
|
if _v is None:
|
||||||
continue
|
continue
|
||||||
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||||
manually_loaded_keys.append(param_key)
|
manually_loaded_keys.append(param_key)
|
||||||
|
|
||||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||||
@ -571,6 +571,12 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
|
|||||||
if key in missing_keys:
|
if key in missing_keys:
|
||||||
missing_keys.remove(key)
|
missing_keys.remove(key)
|
||||||
|
|
||||||
|
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||||
|
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
|
||||||
|
if isinstance(self.weight, QuantizedTensor):
|
||||||
|
sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
|
||||||
|
return sd
|
||||||
|
|
||||||
def _forward(self, input, weight, bias):
|
def _forward(self, input, weight, bias):
|
||||||
return torch.nn.functional.linear(input, weight, bias)
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
|
|||||||
@ -238,6 +238,12 @@ class QuantizedTensor(torch.Tensor):
|
|||||||
def is_contiguous(self, *arg, **kwargs):
|
def is_contiguous(self, *arg, **kwargs):
|
||||||
return self._qdata.is_contiguous(*arg, **kwargs)
|
return self._qdata.is_contiguous(*arg, **kwargs)
|
||||||
|
|
||||||
|
def storage(self):
|
||||||
|
return self._qdata.storage()
|
||||||
|
|
||||||
|
def untyped_storage(self):
|
||||||
|
return self._qdata.untyped_storage()
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Generic Utilities (Layout-Agnostic Operations)
|
# Generic Utilities (Layout-Agnostic Operations)
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user