diff --git a/comfy/lora.py b/comfy/lora.py index 3760f79f0..427cf98aa 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -327,32 +327,15 @@ def model_lora_keys_unet(model, key_map={}): key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format if isinstance(model, comfy.model_base.Krea2): - krea2_rename = { - "first": "img_in", "tmlp.0": "time_embed.linear_1", "tmlp.2": "time_embed.linear_2", - "tproj.1": "time_mod_proj", "txtmlp.1": "txt_in.linear_1", "txtmlp.3": "txt_in.linear_2", - "last.linear": "final_layer.linear", - } - krea2_sub = { - "txtfusion.": "text_fusion.", ".mlp.": ".ff.", - ".attn.wq": ".attn.to_q", ".attn.wk": ".attn.to_k", ".attn.wv": ".attn.to_v", - ".attn.gate": ".attn.to_gate", ".attn.wo": ".attn.to_out.0", - } - for k in sdk: - if not (k.startswith("diffusion_model.") and k.endswith(".weight")): - continue - name = k[len("diffusion_model."):-len(".weight")] - if name in krea2_rename: - name = krea2_rename[name] - else: - if name.startswith("blocks."): - name = "transformer_" + name # only the top-level blocks; not txtfusion's *_blocks - for a, b in krea2_sub.items(): - name = name.replace(a, b) - names = [name, name[:-2]] if name.endswith(".attn.to_out.0") else [name] # some tools drop the ".0" - for n in names: - key_map[n] = k # bare diffusers name - key_map["transformer.{}".format(n)] = k # diffusers "transformer." 
def krea2_to_diffusers(mmdit_config, output_prefix="", n_txt_layerwise=2, n_txt_refiner=2):
    """Build a lookup from diffusers-style Krea2 weight names to native weight names.

    Every key of the returned dict is a diffusers-style name ending in
    ``.weight`` (for example ``transformer_blocks.0.attn.to_q.weight``); every
    value is the corresponding native name with ``output_prefix`` prepended
    (for example ``<output_prefix>blocks.0.attn.wq.weight``).

    Args:
        mmdit_config: Mapping holding the model configuration. Only the
            ``"layers"`` entry is read; a missing or ``None`` value is treated
            as zero main transformer blocks.
        output_prefix: String prepended verbatim to every native name.
        n_txt_layerwise: Number of ``text_fusion.layerwise_blocks`` to emit.
        n_txt_refiner: Number of ``text_fusion.refiner_blocks`` to emit.

    Returns:
        dict mapping diffusers-style ``*.weight`` names to prefixed native
        ``*.weight`` names.
    """
    # `or 0` keeps range() below from failing when "layers" is present but None.
    n_layers = mmdit_config.get("layers") or 0

    # Per-block submodule renames (diffusers name -> native name). The table is
    # the same for every block, so it is built once instead of once per block.
    block_map = {
        "attn.to_q": "attn.wq",
        "attn.to_k": "attn.wk",
        "attn.to_v": "attn.wv",
        "attn.to_gate": "attn.gate",
        "attn.to_out.0": "attn.wo",
        "attn.to_out": "attn.wo",  # some tools drop the ".0" on to_out
        "ff.gate": "mlp.gate",
        "ff.up": "mlp.up",
        "ff.down": "mlp.down",
    }
    key_map = {}

    def add_block(prefix_to, prefix_from):
        # Register every submodule of one block under both naming schemes.
        for d, c in block_map.items():
            key_map["{}.{}.weight".format(prefix_to, d)] = "{}{}.{}.weight".format(output_prefix, prefix_from, c)

    for i in range(n_layers):
        add_block("transformer_blocks.{}".format(i), "blocks.{}".format(i))
    for i in range(n_txt_layerwise):
        add_block("text_fusion.layerwise_blocks.{}".format(i), "txtfusion.layerwise_blocks.{}".format(i))
    for i in range(n_txt_refiner):
        add_block("text_fusion.refiner_blocks.{}".format(i), "txtfusion.refiner_blocks.{}".format(i))

    # Top-level (non-block) modules: (diffusers name, native name).
    MAP_BASIC = [
        ("img_in", "first"),
        ("time_embed.linear_1", "tmlp.0"),
        ("time_embed.linear_2", "tmlp.2"),
        ("time_mod_proj", "tproj.1"),
        ("txt_in.linear_1", "txtmlp.1"),
        ("txt_in.linear_2", "txtmlp.3"),
        ("text_fusion.projector", "txtfusion.projector"),
        ("final_layer.linear", "last.linear"),
    ]
    for d, c in MAP_BASIC:
        key_map["{}.weight".format(d)] = "{}{}.weight".format(output_prefix, c)

    return key_map