Reduce RAM and compute time in model saving with LoRAs

Move the model saving logic away from force_patch_weights and instead do
the patching just-in-time during safetensors saving.

First, switch off force_patch_weights in the load-for-save, which avoids
creating CPU-side tensors with LoRAs already calculated.

Then, at save time, wrap each tensor to catch safetensors' call to .to()
and patch it live.

This avoids ever holding a LoRA-patched copy of offloaded weights on the
CPU.

Also take advantage of the presence of the GPU when doing this LoRA
calculation. The former force_patch_weights path did everything on the
CPU; it is generally faster to go to the GPU and back, even if it is just
a LoRA application.
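
The core mechanism, as a minimal standalone sketch (LazyParam, apply_lora_roundtrip, and the variable names here are illustrative stand-ins, not the actual ComfyUI API):

    import torch

    def apply_lora_roundtrip(weight, lora_down, lora_up):
        # Do the LoRA math on the GPU when one is present, then come back
        # to the CPU: the round trip is usually cheaper than CPU matmuls.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        w = weight.to(device, dtype=torch.float32)
        w = w + lora_up.to(device, dtype=torch.float32) @ lora_down.to(device, dtype=torch.float32)
        return w.to("cpu", dtype=weight.dtype)

    class LazyParam(torch.nn.Parameter):
        def __new__(cls, tensor, patch_fn):
            return super().__new__(cls, tensor, requires_grad=False)

        def __init__(self, tensor, patch_fn):
            self.patch_fn = patch_fn

        def to(self, *args, **kwargs):
            # safetensors moves each weight to the CPU one at a time while
            # serializing; intercept that call and patch on demand, so at
            # most one LoRA-patched copy is alive at any moment.
            return self.patch_fn(self.detach())

    w = torch.randn(16, 16)
    down, up = torch.randn(4, 16), torch.randn(16, 4)
    p = LazyParam(w, lambda t: apply_lora_roundtrip(t, down, up))
    cpu_copy = p.to("cpu")  # LoRA applied only now; nothing is cached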
Rattus 2026-01-01 00:36:06 +10:00
parent 5ac1372533
commit caf6c6aada
3 changed files with 53 additions and 15 deletions

File 1 of 3

@@ -321,7 +321,7 @@ class BaseModel(torch.nn.Module):
     def process_latent_out(self, latent):
         return self.latent_format.process_out(latent)
 
-    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
+    def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
         extra_sds = []
         if clip_state_dict is not None:
             extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
@@ -329,10 +329,7 @@ class BaseModel(torch.nn.Module):
             extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
         if clip_vision_state_dict is not None:
             extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
-        unet_state_dict = self.diffusion_model.state_dict()
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
         if self.model_type == ModelType.V_PREDICTION:
             unet_state_dict["v_pred"] = torch.tensor([])
@@ -775,8 +772,8 @@ class StableAudio1(BaseModel):
         out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         return out
 
-    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
-        sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
+    def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
+        sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
         d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
         for k in d:
             s = d[k]
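
The effect of the signature change above: BaseModel.state_dict_for_saving no longer pulls the UNet state dict itself, so it never forces plain CPU copies into existence. The caller hands the dict in, which lets the patcher substitute lazy wrappers first. Roughly, as a hedged sketch of the new call shape (patcher, base_model, clip_sd, vae_sd, and clip_vision_sd are assumed names):

    # before: the model grabbed its own weights, which force_patch_weights
    # had already patched into plain CPU tensors
    #   sd = base_model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)

    # after: the caller builds unet_state_dict (possibly containing lazy
    # wrappers) and passes it in
    unet_sd = patcher.model.diffusion_model.state_dict()
    sd = base_model.state_dict_for_saving(unet_sd,
                                          clip_state_dict=clip_sd,
                                          vae_state_dict=vae_sd,
                                          clip_vision_state_dict=clip_vision_sd)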

File 2 of 3

@@ -24,6 +24,7 @@ import inspect
 import logging
 import math
 import uuid
+import types
 from typing import Callable, Optional
 
 import torch
@@ -212,6 +213,27 @@ class MemoryCounter:
     def decrement(self, used: int):
         self.value -= used
 
+CustomTorchDevice = collections.namedtuple("FakeDevice", ["type", "index"])("comfy-lazy-caster", 0)
+
+class LazyCastingParam(torch.nn.Parameter):
+    def __new__(cls, model, key, tensor):
+        return super().__new__(cls, tensor)
+
+    def __init__(self, model, key, tensor):
+        self.model = model
+        self.key = key
+
+    @property
+    def device(self):
+        return CustomTorchDevice
+
+    # safetensors will .to() us to the CPU, which we catch here to cast on demand. The returned
+    # tensor is then just a short-lived object inside the safetensors serialization loop over
+    # all weights, garbage collected per-weight.
+    def to(self, *args, **kwargs):
+        return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu")
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
         self.size = size
@@ -611,14 +633,14 @@ class ModelPatcher:
                 sd.pop(k)
         return sd
 
-    def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
-        if key not in self.patches:
-            return
-
+    def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False):
         weight, set_func, convert_func = get_key_weight(self.model, key)
+        if key not in self.patches:
+            return weight
+
         inplace_update = self.weight_inplace_update or inplace_update
 
-        if key not in self.backup:
+        if key not in self.backup and not return_weight:
             self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
 
         temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
@@ -632,12 +654,14 @@ class ModelPatcher:
         out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
         if set_func is None:
             out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
-            if inplace_update:
+            if return_weight:
+                return out_weight
+            elif inplace_update:
                 comfy.utils.copy_to_param(self.model, key, out_weight)
             else:
                 comfy.utils.set_attr_param(self.model, key, out_weight)
         else:
-            set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
+            return set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key), return_weight=return_weight)
 
     def pin_weight_to_device(self, key):
         weight, set_func, convert_func = get_key_weight(self.model, key)
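
Two things are worth noting in the patch_weight_to_device changes: with return_weight=True the function becomes side-effect free. It skips the self.backup bookkeeping (nothing in the live model is modified, so there is nothing to restore later) and returns the patched tensor instead of writing it back into the model. A hedged usage sketch (the key string is illustrative):

    # mutating path (unchanged): patch in place and keep a backup for unpatching
    patcher.patch_weight_to_device("diffusion_model.out.weight",
                                   device_to=patcher.load_device)

    # pure path (new): compute the LoRA-patched weight and just return it;
    # the model, its backup table, and the offloaded weights stay untouched
    w = patcher.patch_weight_to_device("diffusion_model.out.weight",
                                       device_to=patcher.load_device,
                                       return_weight=True)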
@@ -1355,6 +1379,23 @@ class ModelPatcher:
         self.unpatch_hooks()
         self.clear_cached_hook_weights()
 
+    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
+        unet_state_dict = self.model.diffusion_model.state_dict()
+        for k, v in unet_state_dict.items():
+            op_keys = k.rsplit('.', 1)
+            if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
+                continue
+            try:
+                op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
+            except Exception:
+                continue
+            if not op or not hasattr(op, "comfy_cast_weights") or \
+                    (hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
+                continue
+            key = "diffusion_model." + k
+            unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key))
+        return self.model.state_dict_for_saving(unet_state_dict)
+
     def __del__(self):
         self.unpin_all_weights()
         self.detach(unpatch_all=False)
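
The new ModelPatcher.state_dict_for_saving only wraps weights whose op uses comfy cast weights and has not already been patched in place; everything else is saved as-is. On the consuming side, the safetensors save loop effectively does the following per tensor, which is the behaviour the wrapper relies on (a rough sketch, not safetensors' actual code; write_bytes is a hypothetical serializer step):

    for name, tensor in sd.items():
        cpu_tensor = tensor.to("cpu")   # LazyCastingParam patches here, on demand
        write_bytes(name, cpu_tensor)   # hypothetical serializer step
        # cpu_tensor drops out of scope and is garbage collected, so only
        # one patched weight exists at a time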

File 3 of 3

@@ -1575,9 +1575,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
     if metadata is None:
         metadata = {}
 
-    model_management.load_models_gpu(load_models, force_patch_weights=True)
+    model_management.load_models_gpu(load_models)
     clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
-    sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
+    sd = model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
 
     for k in extra_keys:
         sd[k] = extra_keys[k]