implement dynamic clip saving (#13959)

Fix clip saving by doing the same patching process as diffusion
models.
This commit is contained in:
rattus 2026-05-19 04:46:40 +10:00 committed by GitHub
parent d4c6c9eff8
commit 16f862f02a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 28 additions and 14 deletions

View File

@ -1493,27 +1493,30 @@ class ModelPatcher:
self.unpatch_hooks()
self.clear_cached_hook_weights()
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
original_state_dict = self.model.diffusion_model.state_dict()
unet_state_dict = {}
def model_state_dict_for_saving(self, model=None, prefix=""):
if model is None:
model = self.model
original_state_dict = model.state_dict()
output_state_dict = {}
keys = list(original_state_dict)
while len(keys) > 0:
k = keys.pop(0)
v = original_state_dict[k]
op_keys = k.rsplit('.', 1)
if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
unet_state_dict[k] = v
output_state_dict[k] = v
continue
try:
op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
op = comfy.utils.get_attr(model, op_keys[0])
except:
unet_state_dict[k] = v
output_state_dict[k] = v
continue
if not op or not hasattr(op, "comfy_cast_weights") or \
(hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
unet_state_dict[k] = v
output_state_dict[k] = v
continue
key = "diffusion_model." + k
key = prefix + k
weight = comfy.utils.get_attr(self.model, key)
if isinstance(weight, QuantizedTensor) and k in original_state_dict:
qt_state_dict = weight.state_dict(k)
@ -1521,10 +1524,14 @@ class ModelPatcher:
for group_key in (x for x in qt_state_dict if x in original_state_dict):
if group_key in keys:
keys.remove(group_key)
unet_state_dict.pop(group_key, "")
unet_state_dict[group_key] = LazyCastingParamPiece(caster, "diffusion_model." + group_key, original_state_dict[group_key])
output_state_dict.pop(group_key, "")
output_state_dict[group_key] = LazyCastingParamPiece(caster, prefix + group_key, original_state_dict[group_key])
continue
unet_state_dict[k] = LazyCastingParam(self, key, weight)
output_state_dict[k] = LazyCastingParam(self, key, weight)
return output_state_dict
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
unet_state_dict = self.model_state_dict_for_saving(self.model.diffusion_model, "diffusion_model.")
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
def __del__(self):

View File

@ -423,6 +423,13 @@ class CLIP:
sd_clip[k] = sd_tokenizer[k]
return sd_clip
def state_dict_for_saving(self):
sd_clip = self.patcher.model_state_dict_for_saving()
sd_tokenizer = self.tokenizer.state_dict()
for k in sd_tokenizer:
sd_clip[k] = sd_tokenizer[k]
return sd_clip
def load_model(self, tokens={}):
memory_used = 0
if hasattr(self.cond_stage_model, "memory_estimation_function"):
@ -1908,7 +1915,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
load_models = [model]
if clip is not None:
load_models.append(clip.load_model())
clip_sd = clip.get_sd()
clip_sd = clip.state_dict_for_saving()
vae_sd = None
if vae is not None:
vae_sd = vae.get_sd()

View File

@ -276,8 +276,8 @@ class CLIPSave:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
clip_sd = clip.get_sd()
clip.load_model()
clip_sd = clip.state_dict_for_saving()
for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))