mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Move lora key remap
This commit is contained in:
parent
c2a95725c0
commit
facb31ed10
@ -327,32 +327,15 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
|
||||
|
||||
if isinstance(model, comfy.model_base.Krea2):
|
||||
krea2_rename = {
|
||||
"first": "img_in", "tmlp.0": "time_embed.linear_1", "tmlp.2": "time_embed.linear_2",
|
||||
"tproj.1": "time_mod_proj", "txtmlp.1": "txt_in.linear_1", "txtmlp.3": "txt_in.linear_2",
|
||||
"last.linear": "final_layer.linear",
|
||||
}
|
||||
krea2_sub = {
|
||||
"txtfusion.": "text_fusion.", ".mlp.": ".ff.",
|
||||
".attn.wq": ".attn.to_q", ".attn.wk": ".attn.to_k", ".attn.wv": ".attn.to_v",
|
||||
".attn.gate": ".attn.to_gate", ".attn.wo": ".attn.to_out.0",
|
||||
}
|
||||
for k in sdk:
|
||||
if not (k.startswith("diffusion_model.") and k.endswith(".weight")):
|
||||
continue
|
||||
name = k[len("diffusion_model."):-len(".weight")]
|
||||
if name in krea2_rename:
|
||||
name = krea2_rename[name]
|
||||
else:
|
||||
if name.startswith("blocks."):
|
||||
name = "transformer_" + name # only the top-level blocks; not txtfusion's *_blocks
|
||||
for a, b in krea2_sub.items():
|
||||
name = name.replace(a, b)
|
||||
names = [name, name[:-2]] if name.endswith(".attn.to_out.0") else [name] # some tools drop the ".0"
|
||||
for n in names:
|
||||
key_map[n] = k # bare diffusers name
|
||||
key_map["transformer.{}".format(n)] = k # diffusers "transformer." prefix
|
||||
key_map["lycoris_{}".format(n.replace(".", "_"))] = k # SimpleTuner lycoris format
|
||||
diffusers_keys = comfy.utils.krea2_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = k[:-len(".weight")]
|
||||
key_map["diffusion_model.{}".format(key_lora)] = to
|
||||
key_map["transformer.{}".format(key_lora)] = to
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
||||
key_map[key_lora] = to
|
||||
|
||||
if isinstance(model, comfy.model_base.Lumina2):
|
||||
diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
|
||||
@ -818,6 +818,44 @@ def z_image_to_diffusers(mmdit_config, output_prefix=""):
|
||||
|
||||
return key_map
|
||||
|
||||
def krea2_to_diffusers(mmdit_config, output_prefix=""):
|
||||
n_layers = mmdit_config.get("layers", 0)
|
||||
n_txt_layerwise = 2 # TextFusionTransformer hardcodes 2 layerwise + 2 refiner blocks
|
||||
n_txt_refiner = 2
|
||||
key_map = {}
|
||||
|
||||
def add_block(prefix_to, prefix_from):
|
||||
block_map = {
|
||||
"attn.to_q": "attn.wq", "attn.to_k": "attn.wk", "attn.to_v": "attn.wv",
|
||||
"attn.to_gate": "attn.gate", "attn.to_out.0": "attn.wo",
|
||||
"attn.to_out": "attn.wo", # some tools drop the ".0" on to_out
|
||||
"ff.gate": "mlp.gate", "ff.up": "mlp.up", "ff.down": "mlp.down",
|
||||
}
|
||||
for d, c in block_map.items():
|
||||
key_map["{}.{}.weight".format(prefix_to, d)] = "{}{}.{}.weight".format(output_prefix, prefix_from, c)
|
||||
|
||||
for i in range(n_layers):
|
||||
add_block("transformer_blocks.{}".format(i), "blocks.{}".format(i))
|
||||
for i in range(n_txt_layerwise):
|
||||
add_block("text_fusion.layerwise_blocks.{}".format(i), "txtfusion.layerwise_blocks.{}".format(i))
|
||||
for i in range(n_txt_refiner):
|
||||
add_block("text_fusion.refiner_blocks.{}".format(i), "txtfusion.refiner_blocks.{}".format(i))
|
||||
|
||||
MAP_BASIC = [
|
||||
("img_in", "first"),
|
||||
("time_embed.linear_1", "tmlp.0"),
|
||||
("time_embed.linear_2", "tmlp.2"),
|
||||
("time_mod_proj", "tproj.1"),
|
||||
("txt_in.linear_1", "txtmlp.1"),
|
||||
("txt_in.linear_2", "txtmlp.3"),
|
||||
("text_fusion.projector", "txtfusion.projector"),
|
||||
("final_layer.linear", "last.linear"),
|
||||
]
|
||||
for d, c in MAP_BASIC:
|
||||
key_map["{}.weight".format(d)] = "{}{}.weight".format(output_prefix, c)
|
||||
|
||||
return key_map
|
||||
|
||||
def repeat_to_batch_size(tensor, batch_size, dim=0):
|
||||
if tensor.shape[dim] > batch_size:
|
||||
return tensor.narrow(dim, 0, batch_size)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user