diff --git a/comfy/model_base.py b/comfy/model_base.py index 49efd700b..b9e2abdb9 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -321,7 +321,7 @@ class BaseModel(torch.nn.Module): def process_latent_out(self, latent): return self.latent_format.process_out(latent) - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): extra_sds = [] if clip_state_dict is not None: extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict)) @@ -329,10 +329,7 @@ class BaseModel(torch.nn.Module): extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict)) if clip_vision_state_dict is not None: extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) - - unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) - if self.model_type == ModelType.V_PREDICTION: unet_state_dict["v_pred"] = torch.tensor([]) @@ -775,8 +772,8 @@ class StableAudio1(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): - sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()} for k in d: s = d[k] diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f6b80a40f..30ca39b2a 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -24,6 +24,7 @@ import inspect import logging import math import uuid +import types from typing import Callable, Optional import torch @@ -212,6 +213,27 @@ class MemoryCounter: def decrement(self, used: int): self.value -= used +CustomTorchDevice = collections.namedtuple("FakeDevice", ["type", "index"])("comfy-lazy-caster", 0) + +class LazyCastingParam(torch.nn.Parameter): + def __new__(cls, model, key, tensor): + return super().__new__(cls, tensor) + + def __init__(self, model, key, tensor): + self.model = model + self.key = key + + @property + def device(self): + return CustomTorchDevice + + #safetensors will .to() us to the cpu which we catch here to cast on demand. The returned tensor is + #then just a short lived thing in the safetensors serialization logic inside its big for loop over + #all weights getting garbage collected per-weight + def to(self, *args, **kwargs): + return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu") + + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size @@ -611,14 +633,14 @@ class ModelPatcher: sd.pop(k) return sd - def patch_weight_to_device(self, key, device_to=None, inplace_update=False): - if key not in self.patches: - return - + def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False): weight, set_func, convert_func = get_key_weight(self.model, key) + if key not in self.patches: + return weight + inplace_update = self.weight_inplace_update or inplace_update - if key not in self.backup: + if key not in self.backup and not return_weight: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) temp_dtype = comfy.model_management.lora_compute_dtype(device_to) @@ -632,12 +654,14 @@ class ModelPatcher: out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if set_func is None: out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) - if inplace_update: + if return_weight: + return out_weight + elif inplace_update: comfy.utils.copy_to_param(self.model, key, out_weight) else: comfy.utils.set_attr_param(self.model, key, out_weight) else: - set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) + return set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key), return_weight=return_weight) def pin_weight_to_device(self, key): weight, set_func, convert_func = get_key_weight(self.model, key) @@ -1355,6 +1379,23 @@ class ModelPatcher: self.unpatch_hooks() self.clear_cached_hook_weights() + def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + unet_state_dict = self.model.diffusion_model.state_dict() + for k, v in unet_state_dict.items(): + op_keys = k.rsplit('.', 1) + if (len(op_keys) < 2) or not op_keys[1] in ["weight", "bias"]: + continue + try: + op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0]) + except: + continue + if not op or not hasattr(op, "comfy_cast_weights") or \ + (hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True): + continue + key = "diffusion_model." + k + unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key)) + return self.model.state_dict_for_saving(unet_state_dict) + def __del__(self): self.unpin_all_weights() self.detach(unpatch_all=False) diff --git a/comfy/sd.py b/comfy/sd.py index b689c0dfc..b71157343 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1575,9 +1575,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m if metadata is None: metadata = {} - model_management.load_models_gpu(load_models, force_patch_weights=True) + model_management.load_models_gpu(load_models) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None - sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) + sd = model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) for k in extra_keys: sd[k] = extra_keys[k]