Working ModelSave node.

This commit is contained in:
comfyanonymous 2025-11-30 03:29:50 -05:00
parent b440a2bde9
commit e3861b193b
3 changed files with 14 additions and 1 deletion

View File

@@ -341,6 +341,7 @@ class BaseModel(torch.nn.Module):
meta_out = {"_quantization_metadata": json.dumps(meta)} meta_out = {"_quantization_metadata": json.dumps(meta)}
meta_out = self.model_config.process_unet_state_dict_for_saving(meta_out) meta_out = self.model_config.process_unet_state_dict_for_saving(meta_out)
metadata.update(meta_out)
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

View File

@@ -562,7 +562,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
_v = state_dict.pop(param_key, None) _v = state_dict.pop(param_key, None)
if _v is None: if _v is None:
continue continue
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key) manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
@@ -571,6 +571,12 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
if key in missing_keys: if key in missing_keys:
missing_keys.remove(key) missing_keys.remove(key)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
    """Collect this layer's state dict.

    When the weight is a QuantizedTensor, additionally emit its scale
    parameter (from the tensor's layout params) under the key
    '<prefix>weight_scale' so it is persisted alongside the weight.
    """
    result = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
    current_weight = self.weight
    if isinstance(current_weight, QuantizedTensor):
        result[f"{prefix}weight_scale"] = current_weight._layout_params["scale"]
    return result
def _forward(self, input, weight, bias): def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias) return torch.nn.functional.linear(input, weight, bias)

View File

@@ -238,6 +238,12 @@ class QuantizedTensor(torch.Tensor):
def is_contiguous(self, *arg, **kwargs):
    """Forward contiguity queries to the backing quantized data tensor."""
    backing = self._qdata
    return backing.is_contiguous(*arg, **kwargs)
def storage(self):
    """Expose the typed storage of the backing quantized data tensor."""
    backing = self._qdata
    return backing.storage()
def untyped_storage(self):
    """Expose the untyped (byte-level) storage of the backing quantized data tensor."""
    backing = self._qdata
    return backing.untyped_storage()
# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
# ==============================================================================