diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 33bdedfb1..2ea14bc2c 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -242,6 +242,37 @@ class LazyCastingParam(torch.nn.Parameter): return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu") +class LazyCastingQuantizedParam: + def __init__(self, model, key): + self.model = model + self.key = key + self.cpu_state_dict = None + + def state_dict_tensor(self, state_dict_key): + if self.cpu_state_dict is None: + weight = self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True) + self.cpu_state_dict = {k: v.to("cpu") for k, v in weight.state_dict(self.key).items()} + return self.cpu_state_dict[state_dict_key] + + +class LazyCastingParamPiece(torch.nn.Parameter): + def __new__(cls, caster, state_dict_key, tensor): + return super().__new__(cls, tensor) + + def __init__(self, caster, state_dict_key, tensor): + self.caster = caster + self.state_dict_key = state_dict_key + + @property + def device(self): + return CustomTorchDevice + + def to(self, *args, **kwargs): + caster = self.caster + del self.caster + return caster.state_dict_tensor(self.state_dict_key) + + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size @@ -1463,20 +1494,37 @@ class ModelPatcher: self.clear_cached_hook_weights() def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): - unet_state_dict = self.model.diffusion_model.state_dict() - for k, v in unet_state_dict.items(): + original_state_dict = self.model.diffusion_model.state_dict() + unet_state_dict = {} + keys = list(original_state_dict) + while len(keys) > 0: + k = keys.pop(0) + v = original_state_dict[k] op_keys = k.rsplit('.', 1) if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]: + unet_state_dict[k] = v continue try: op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0]) except: + unet_state_dict[k] = v continue if not op or not hasattr(op, "comfy_cast_weights") or \ (hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True): + unet_state_dict[k] = v continue key = "diffusion_model." + k - unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key)) + weight = comfy.utils.get_attr(self.model, key) + if isinstance(weight, QuantizedTensor) and k in original_state_dict: + qt_state_dict = weight.state_dict(k) + caster = LazyCastingQuantizedParam(self, key) + for group_key in (x for x in qt_state_dict if x in original_state_dict): + if group_key in keys: + keys.remove(group_key) + unet_state_dict.pop(group_key, "") + unet_state_dict[group_key] = LazyCastingParamPiece(caster, "diffusion_model." + group_key, original_state_dict[group_key]) + continue + unet_state_dict[k] = LazyCastingParam(self, key, weight) return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) def __del__(self):