mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-14 02:57:24 +08:00
Support anima TE lora kohya format. (#13847)
This commit is contained in:
parent
20e439419c
commit
0a7d2ffd68
@ -97,12 +97,14 @@ def load_lora(lora, to_load, log_missing=True):
|
|||||||
|
|
||||||
def model_lora_keys_clip(model, key_map={}):
|
def model_lora_keys_clip(model, key_map={}):
|
||||||
sdk = model.state_dict().keys()
|
sdk = model.state_dict().keys()
|
||||||
|
prefix_set = set()
|
||||||
for k in sdk:
|
for k in sdk:
|
||||||
if k.endswith(".weight"):
|
if k.endswith(".weight"):
|
||||||
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||||
tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models
|
tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models
|
||||||
if tp > 0 and not k.startswith("clip_"):
|
if tp > 0 and not k.startswith("clip_"):
|
||||||
key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k
|
key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k
|
||||||
|
prefix_set.add(k.split('.')[0])
|
||||||
|
|
||||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||||
clip_l_present = False
|
clip_l_present = False
|
||||||
@ -163,6 +165,13 @@ def model_lora_keys_clip(model, key_map={}):
|
|||||||
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
|
||||||
key_map[lora_key] = k
|
key_map[lora_key] = k
|
||||||
|
|
||||||
|
if len(prefix_set) == 1:
|
||||||
|
full_prefix = "{}.transformer.model.".format(next(iter(prefix_set))) # kohya anima and maybe other single TE models that use a single llama arch based te
|
||||||
|
for k in sdk:
|
||||||
|
if k.endswith(".weight"):
|
||||||
|
if k.startswith(full_prefix):
|
||||||
|
l_key = k[len(full_prefix):-len(".weight")]
|
||||||
|
key_map["lora_te_{}".format(l_key.replace(".", "_"))] = k
|
||||||
|
|
||||||
k = "clip_g.transformer.text_projection.weight"
|
k = "clip_g.transformer.text_projection.weight"
|
||||||
if k in sdk:
|
if k in sdk:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user