mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-19 21:47:24 +08:00
implement dynamic clip saving (#13959)
Fix clip saving by doing the same patching process and diffusion models.
This commit is contained in:
parent
d4c6c9eff8
commit
16f862f02a
@ -1493,27 +1493,30 @@ class ModelPatcher:
|
||||
self.unpatch_hooks()
|
||||
self.clear_cached_hook_weights()
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
original_state_dict = self.model.diffusion_model.state_dict()
|
||||
unet_state_dict = {}
|
||||
def model_state_dict_for_saving(self, model=None, prefix=""):
|
||||
if model is None:
|
||||
model = self.model
|
||||
|
||||
original_state_dict = model.state_dict()
|
||||
output_state_dict = {}
|
||||
keys = list(original_state_dict)
|
||||
while len(keys) > 0:
|
||||
k = keys.pop(0)
|
||||
v = original_state_dict[k]
|
||||
op_keys = k.rsplit('.', 1)
|
||||
if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
|
||||
unet_state_dict[k] = v
|
||||
output_state_dict[k] = v
|
||||
continue
|
||||
try:
|
||||
op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
|
||||
op = comfy.utils.get_attr(model, op_keys[0])
|
||||
except:
|
||||
unet_state_dict[k] = v
|
||||
output_state_dict[k] = v
|
||||
continue
|
||||
if not op or not hasattr(op, "comfy_cast_weights") or \
|
||||
(hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
|
||||
unet_state_dict[k] = v
|
||||
output_state_dict[k] = v
|
||||
continue
|
||||
key = "diffusion_model." + k
|
||||
key = prefix + k
|
||||
weight = comfy.utils.get_attr(self.model, key)
|
||||
if isinstance(weight, QuantizedTensor) and k in original_state_dict:
|
||||
qt_state_dict = weight.state_dict(k)
|
||||
@ -1521,10 +1524,14 @@ class ModelPatcher:
|
||||
for group_key in (x for x in qt_state_dict if x in original_state_dict):
|
||||
if group_key in keys:
|
||||
keys.remove(group_key)
|
||||
unet_state_dict.pop(group_key, "")
|
||||
unet_state_dict[group_key] = LazyCastingParamPiece(caster, "diffusion_model." + group_key, original_state_dict[group_key])
|
||||
output_state_dict.pop(group_key, "")
|
||||
output_state_dict[group_key] = LazyCastingParamPiece(caster, prefix + group_key, original_state_dict[group_key])
|
||||
continue
|
||||
unet_state_dict[k] = LazyCastingParam(self, key, weight)
|
||||
output_state_dict[k] = LazyCastingParam(self, key, weight)
|
||||
return output_state_dict
|
||||
|
||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||
unet_state_dict = self.model_state_dict_for_saving(self.model.diffusion_model, "diffusion_model.")
|
||||
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
||||
|
||||
def __del__(self):
|
||||
|
||||
@ -423,6 +423,13 @@ class CLIP:
|
||||
sd_clip[k] = sd_tokenizer[k]
|
||||
return sd_clip
|
||||
|
||||
def state_dict_for_saving(self):
|
||||
sd_clip = self.patcher.model_state_dict_for_saving()
|
||||
sd_tokenizer = self.tokenizer.state_dict()
|
||||
for k in sd_tokenizer:
|
||||
sd_clip[k] = sd_tokenizer[k]
|
||||
return sd_clip
|
||||
|
||||
def load_model(self, tokens={}):
|
||||
memory_used = 0
|
||||
if hasattr(self.cond_stage_model, "memory_estimation_function"):
|
||||
@ -1908,7 +1915,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
||||
load_models = [model]
|
||||
if clip is not None:
|
||||
load_models.append(clip.load_model())
|
||||
clip_sd = clip.get_sd()
|
||||
clip_sd = clip.state_dict_for_saving()
|
||||
vae_sd = None
|
||||
if vae is not None:
|
||||
vae_sd = vae.get_sd()
|
||||
|
||||
@ -276,8 +276,8 @@ class CLIPSave:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
|
||||
clip_sd = clip.get_sd()
|
||||
clip.load_model()
|
||||
clip_sd = clip.state_dict_for_saving()
|
||||
|
||||
for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
|
||||
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user