From caf6c6aada3125afa3cfd73144b98ac025e6b2b8 Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 1 Jan 2026 00:36:06 +1000 Subject: [PATCH] Reduce RAM and compute time in model saving with Loras Get the model saving logic away from force_patch_weights and instead do the patching JIT during safetensors saving. Firstly switch off force_patch_weights in the load for save which avoids creating CPU side tensors with loras calculated. Then at save time, wrap the tensor to catch safetensors call to .to() and patch it live. This avoids having to ever have a lora-calculated copy of offloaded weights on the CPU. Also take advantage of the presence of the GPU when doing this Lora calculation. The former force_patch_weights would just do eveyrthing on the CPU. Its generally faster to go the GPU and back even if its just a Lora application. --- comfy/model_base.py | 9 +++---- comfy/model_patcher.py | 55 ++++++++++++++++++++++++++++++++++++------ comfy/sd.py | 4 +-- 3 files changed, 53 insertions(+), 15 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 49efd700b..b9e2abdb9 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -321,7 +321,7 @@ class BaseModel(torch.nn.Module): def process_latent_out(self, latent): return self.latent_format.process_out(latent) - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): extra_sds = [] if clip_state_dict is not None: extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict)) @@ -329,10 +329,7 @@ class BaseModel(torch.nn.Module): extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict)) if clip_vision_state_dict is not None: extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) - - unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) - if self.model_type == ModelType.V_PREDICTION: unet_state_dict["v_pred"] = torch.tensor([]) @@ -775,8 +772,8 @@ class StableAudio1(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): - sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()} for k in d: s = d[k] diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f6b80a40f..30ca39b2a 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -24,6 +24,7 @@ import inspect import logging import math import uuid +import types from typing import Callable, Optional import torch @@ -212,6 +213,27 @@ class MemoryCounter: def decrement(self, used: int): self.value -= used +CustomTorchDevice = collections.namedtuple("FakeDevice", ["type", "index"])("comfy-lazy-caster", 0) + +class LazyCastingParam(torch.nn.Parameter): + def __new__(cls, model, key, tensor): + return super().__new__(cls, tensor) + + def __init__(self, model, key, tensor): + self.model = model + self.key = key + + @property + def device(self): + return CustomTorchDevice + + #safetensors will .to() us to the cpu which we catch here to cast on demand. The returned tensor is + #then just a short lived thing in the safetensors serialization logic inside its big for loop over + #all weights getting garbage collected per-weight + def to(self, *args, **kwargs): + return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu") + + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size @@ -611,14 +633,14 @@ class ModelPatcher: sd.pop(k) return sd - def patch_weight_to_device(self, key, device_to=None, inplace_update=False): - if key not in self.patches: - return - + def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False): weight, set_func, convert_func = get_key_weight(self.model, key) + if key not in self.patches: + return weight + inplace_update = self.weight_inplace_update or inplace_update - if key not in self.backup: + if key not in self.backup and not return_weight: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) temp_dtype = comfy.model_management.lora_compute_dtype(device_to) @@ -632,12 +654,14 @@ class ModelPatcher: out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if set_func is None: out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) - if inplace_update: + if return_weight: + return out_weight + elif inplace_update: comfy.utils.copy_to_param(self.model, key, out_weight) else: comfy.utils.set_attr_param(self.model, key, out_weight) else: - set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) + return set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key), return_weight=return_weight) def pin_weight_to_device(self, key): weight, set_func, convert_func = get_key_weight(self.model, key) @@ -1355,6 +1379,23 @@ class ModelPatcher: self.unpatch_hooks() self.clear_cached_hook_weights() + def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + unet_state_dict = self.model.diffusion_model.state_dict() + for k, v in unet_state_dict.items(): + op_keys = k.rsplit('.', 1) + if (len(op_keys) < 2) or not op_keys[1] in ["weight", "bias"]: + continue + try: + op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0]) + except: + continue + if not op or not hasattr(op, "comfy_cast_weights") or \ + (hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True): + continue + key = "diffusion_model." + k + unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key)) + return self.model.state_dict_for_saving(unet_state_dict) + def __del__(self): self.unpin_all_weights() self.detach(unpatch_all=False) diff --git a/comfy/sd.py b/comfy/sd.py index b689c0dfc..b71157343 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1575,9 +1575,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m if metadata is None: metadata = {} - model_management.load_models_gpu(load_models, force_patch_weights=True) + model_management.load_models_gpu(load_models) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None - sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) + sd = model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) for k in extra_keys: sd[k] = extra_keys[k]