diff --git a/comfy/lora.py b/comfy/lora.py index 2c8d0f0bf..3760f79f0 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -326,6 +326,34 @@ def model_lora_keys_unet(model, key_map={}): key_map["transformer.{}".format(key_lora)] = k key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format + if isinstance(model, comfy.model_base.Krea2): + krea2_rename = { + "first": "img_in", "tmlp.0": "time_embed.linear_1", "tmlp.2": "time_embed.linear_2", + "tproj.1": "time_mod_proj", "txtmlp.1": "txt_in.linear_1", "txtmlp.3": "txt_in.linear_2", + "last.linear": "final_layer.linear", + } + krea2_sub = { + "txtfusion.": "text_fusion.", ".mlp.": ".ff.", + ".attn.wq": ".attn.to_q", ".attn.wk": ".attn.to_k", ".attn.wv": ".attn.to_v", + ".attn.gate": ".attn.to_gate", ".attn.wo": ".attn.to_out.0", + } + for k in sdk: + if not (k.startswith("diffusion_model.") and k.endswith(".weight")): + continue + name = k[len("diffusion_model."):-len(".weight")] + if name in krea2_rename: + name = krea2_rename[name] + else: + if name.startswith("blocks."): + name = "transformer_" + name # only the top-level blocks; not txtfusion's *_blocks + for a, b in krea2_sub.items(): + name = name.replace(a, b) + names = [name, name[:-2]] if name.endswith(".attn.to_out.0") else [name] # some tools drop the ".0" + for n in names: + key_map[n] = k # bare diffusers name + key_map["transformer.{}".format(n)] = k # diffusers "transformer." prefix + key_map["lycoris_{}".format(n.replace(".", "_"))] = k # SimpleTuner lycoris format + if isinstance(model, comfy.model_base.Lumina2): diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") for k in diffusers_keys: