mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-24 16:07:30 +08:00
implement dynamic clip saving (#13959)
Fix clip saving by doing the same patching process as diffusion models.
This commit is contained in:
parent
d4c6c9eff8
commit
16f862f02a
@ -1493,27 +1493,30 @@ class ModelPatcher:
|
|||||||
self.unpatch_hooks()
|
self.unpatch_hooks()
|
||||||
self.clear_cached_hook_weights()
|
self.clear_cached_hook_weights()
|
||||||
|
|
||||||
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
def model_state_dict_for_saving(self, model=None, prefix=""):
|
||||||
original_state_dict = self.model.diffusion_model.state_dict()
|
if model is None:
|
||||||
unet_state_dict = {}
|
model = self.model
|
||||||
|
|
||||||
|
original_state_dict = model.state_dict()
|
||||||
|
output_state_dict = {}
|
||||||
keys = list(original_state_dict)
|
keys = list(original_state_dict)
|
||||||
while len(keys) > 0:
|
while len(keys) > 0:
|
||||||
k = keys.pop(0)
|
k = keys.pop(0)
|
||||||
v = original_state_dict[k]
|
v = original_state_dict[k]
|
||||||
op_keys = k.rsplit('.', 1)
|
op_keys = k.rsplit('.', 1)
|
||||||
if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
|
if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
|
||||||
unet_state_dict[k] = v
|
output_state_dict[k] = v
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
|
op = comfy.utils.get_attr(model, op_keys[0])
|
||||||
except:
|
except:
|
||||||
unet_state_dict[k] = v
|
output_state_dict[k] = v
|
||||||
continue
|
continue
|
||||||
if not op or not hasattr(op, "comfy_cast_weights") or \
|
if not op or not hasattr(op, "comfy_cast_weights") or \
|
||||||
(hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
|
(hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
|
||||||
unet_state_dict[k] = v
|
output_state_dict[k] = v
|
||||||
continue
|
continue
|
||||||
key = "diffusion_model." + k
|
key = prefix + k
|
||||||
weight = comfy.utils.get_attr(self.model, key)
|
weight = comfy.utils.get_attr(self.model, key)
|
||||||
if isinstance(weight, QuantizedTensor) and k in original_state_dict:
|
if isinstance(weight, QuantizedTensor) and k in original_state_dict:
|
||||||
qt_state_dict = weight.state_dict(k)
|
qt_state_dict = weight.state_dict(k)
|
||||||
@ -1521,10 +1524,14 @@ class ModelPatcher:
|
|||||||
for group_key in (x for x in qt_state_dict if x in original_state_dict):
|
for group_key in (x for x in qt_state_dict if x in original_state_dict):
|
||||||
if group_key in keys:
|
if group_key in keys:
|
||||||
keys.remove(group_key)
|
keys.remove(group_key)
|
||||||
unet_state_dict.pop(group_key, "")
|
output_state_dict.pop(group_key, "")
|
||||||
unet_state_dict[group_key] = LazyCastingParamPiece(caster, "diffusion_model." + group_key, original_state_dict[group_key])
|
output_state_dict[group_key] = LazyCastingParamPiece(caster, prefix + group_key, original_state_dict[group_key])
|
||||||
continue
|
continue
|
||||||
unet_state_dict[k] = LazyCastingParam(self, key, weight)
|
output_state_dict[k] = LazyCastingParam(self, key, weight)
|
||||||
|
return output_state_dict
|
||||||
|
|
||||||
|
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
|
||||||
|
unet_state_dict = self.model_state_dict_for_saving(self.model.diffusion_model, "diffusion_model.")
|
||||||
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
|||||||
@ -423,6 +423,13 @@ class CLIP:
|
|||||||
sd_clip[k] = sd_tokenizer[k]
|
sd_clip[k] = sd_tokenizer[k]
|
||||||
return sd_clip
|
return sd_clip
|
||||||
|
|
||||||
|
def state_dict_for_saving(self):
|
||||||
|
sd_clip = self.patcher.model_state_dict_for_saving()
|
||||||
|
sd_tokenizer = self.tokenizer.state_dict()
|
||||||
|
for k in sd_tokenizer:
|
||||||
|
sd_clip[k] = sd_tokenizer[k]
|
||||||
|
return sd_clip
|
||||||
|
|
||||||
def load_model(self, tokens={}):
|
def load_model(self, tokens={}):
|
||||||
memory_used = 0
|
memory_used = 0
|
||||||
if hasattr(self.cond_stage_model, "memory_estimation_function"):
|
if hasattr(self.cond_stage_model, "memory_estimation_function"):
|
||||||
@ -1908,7 +1915,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
|||||||
load_models = [model]
|
load_models = [model]
|
||||||
if clip is not None:
|
if clip is not None:
|
||||||
load_models.append(clip.load_model())
|
load_models.append(clip.load_model())
|
||||||
clip_sd = clip.get_sd()
|
clip_sd = clip.state_dict_for_saving()
|
||||||
vae_sd = None
|
vae_sd = None
|
||||||
if vae is not None:
|
if vae is not None:
|
||||||
vae_sd = vae.get_sd()
|
vae_sd = vae.get_sd()
|
||||||
|
|||||||
@ -276,8 +276,8 @@ class CLIPSave:
|
|||||||
for x in extra_pnginfo:
|
for x in extra_pnginfo:
|
||||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||||
|
|
||||||
comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
|
clip.load_model()
|
||||||
clip_sd = clip.get_sd()
|
clip_sd = clip.state_dict_for_saving()
|
||||||
|
|
||||||
for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
|
for prefix in ["clip_l.", "clip_g.", "clip_h.", "t5xxl.", "pile_t5xl.", "mt5xl.", "umt5xxl.", "t5base.", "gemma2_2b.", "llama.", "hydit_clip.", ""]:
|
||||||
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
|
k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user