diff --git a/comfy/lora.py b/comfy/lora.py
index 360cd128f..3a9077869 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -316,10 +316,11 @@ def model_lora_keys_unet(model, key_map={}):
     if isinstance(model, comfy.model_base.Lumina2):
         diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
         for k in diffusers_keys:
-            to = diffusers_keys[k]
-            key_lora = k[:-len(".weight")]
-            key_map["diffusion_model.{}".format(key_lora)] = to
-            key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
+            if k.endswith(".weight"):
+                to = diffusers_keys[k]
+                key_lora = k[:-len(".weight")]
+                key_map["diffusion_model.{}".format(key_lora)] = to
+                key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
 
     return key_map
 
diff --git a/comfy/utils.py b/comfy/utils.py
index 21bd6e8cf..37485e497 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -678,17 +678,14 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
 
 def z_image_to_diffusers(mmdit_config, output_prefix=""):
     n_layers = mmdit_config.get("n_layers", 0)
     hidden_size = mmdit_config.get("dim", 0)
-
+    n_context_refiner = mmdit_config.get("n_refiner_layers", 2)
+    n_noise_refiner = mmdit_config.get("n_refiner_layers", 2)
     key_map = {}
-    for index in range(n_layers):
-        prefix_from = "layers.{}".format(index)
-        prefix_to = "{}layers.{}".format(output_prefix, index)
-
+    def add_block_keys(prefix_from, prefix_to, has_adaln=True):
         for end in ("weight", "bias"):
             k = "{}.attention.".format(prefix_from)
             qkv = "{}.attention.qkv.{}".format(prefix_to, end)
-
             key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
             key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
             key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
@@ -698,28 +695,52 @@ def z_image_to_diffusers(mmdit_config, output_prefix=""):
             "attention.norm_k.weight": "attention.k_norm.weight",
             "attention.to_out.0.weight": "attention.out.weight",
             "attention.to_out.0.bias": "attention.out.bias",
+            "attention_norm1.weight": "attention_norm1.weight",
+            "attention_norm2.weight": "attention_norm2.weight",
+            "feed_forward.w1.weight": "feed_forward.w1.weight",
+            "feed_forward.w2.weight": "feed_forward.w2.weight",
+            "feed_forward.w3.weight": "feed_forward.w3.weight",
+            "ffn_norm1.weight": "ffn_norm1.weight",
+            "ffn_norm2.weight": "ffn_norm2.weight",
         }
+        if has_adaln:
+            block_map["adaLN_modulation.0.weight"] = "adaLN_modulation.0.weight"
+            block_map["adaLN_modulation.0.bias"] = "adaLN_modulation.0.bias"
+        for k, v in block_map.items():
+            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, v)
 
-        for k in block_map:
-            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
+    for i in range(n_layers):
+        add_block_keys("layers.{}".format(i), "{}layers.{}".format(output_prefix, i))
 
-    MAP_BASIC = {
-        # Final layer
+    for i in range(n_context_refiner):
+        add_block_keys("context_refiner.{}".format(i), "{}context_refiner.{}".format(output_prefix, i))
+
+    for i in range(n_noise_refiner):
+        add_block_keys("noise_refiner.{}".format(i), "{}noise_refiner.{}".format(output_prefix, i))
+
+    MAP_BASIC = [
         ("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"),
         ("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"),
         ("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"),
         ("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"),
-        # X embedder
         ("x_embedder.weight", "all_x_embedder.2-1.weight"),
         ("x_embedder.bias", "all_x_embedder.2-1.bias"),
-    }
+        ("x_pad_token", "x_pad_token"),
+        ("cap_embedder.0.weight", "cap_embedder.0.weight"),
+        ("cap_embedder.1.weight", "cap_embedder.1.weight"),
+        ("cap_embedder.1.bias", "cap_embedder.1.bias"),
+        ("cap_pad_token", "cap_pad_token"),
+        ("t_embedder.mlp.0.weight", "t_embedder.mlp.0.weight"),
+        ("t_embedder.mlp.0.bias", "t_embedder.mlp.0.bias"),
+        ("t_embedder.mlp.2.weight", "t_embedder.mlp.2.weight"),
+        ("t_embedder.mlp.2.bias", "t_embedder.mlp.2.bias"),
+    ]
 
-    for k in MAP_BASIC:
-        key_map[k[1]] = "{}{}".format(output_prefix, k[0])
+    for c, diffusers in MAP_BASIC:
+        key_map[diffusers] = "{}{}".format(output_prefix, c)
 
     return key_map
 
-
 def repeat_to_batch_size(tensor, batch_size, dim=0):
     if tensor.shape[dim] > batch_size:
        return tensor.narrow(dim, 0, batch_size)