diff --git a/comfy/model_base.py b/comfy/model_base.py index 9fef5d4ce..7ce5c7b6e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -341,6 +341,7 @@ class BaseModel(torch.nn.Module): meta_out = {"_quantization_metadata": json.dumps(meta)} meta_out = self.model_config.process_unet_state_dict_for_saving(meta_out) + metadata.update(meta_out) unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) diff --git a/comfy/ops.py b/comfy/ops.py index b34c80a7f..93dcd5972 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -562,7 +562,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful _v = state_dict.pop(param_key, None) if _v is None: continue - setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) manually_loaded_keys.append(param_key) super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) @@ -571,6 +571,12 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful if key in missing_keys: missing_keys.remove(key) + def state_dict(self, *args, destination=None, prefix="", **kwargs): + sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs) + if isinstance(self.weight, QuantizedTensor): + sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale'] + return sd + def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 82148bd2b..a565162db 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -238,6 +238,12 @@ class QuantizedTensor(torch.Tensor): def is_contiguous(self, *arg, **kwargs): return self._qdata.is_contiguous(*arg, **kwargs) + def storage(self): + return self._qdata.storage() + + def untyped_storage(self): + return self._qdata.untyped_storage() + # ============================================================================== # Generic Utilities (Layout-Agnostic Operations) # ==============================================================================