Implement twinflow_z_image_key_mapping function

Add twinflow_z_image_key_mapping function for key mapping.
This commit is contained in:
azazeal04 2026-04-04 16:06:41 +02:00 committed by GitHub
parent b76510b549
commit e978c61a79
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -818,6 +818,16 @@ def z_image_to_diffusers(mmdit_config, output_prefix=""):
return key_map
def twinflow_z_image_key_mapping(state_dict, key):
"""
TwinFlow-Z-Image key mapping.
Maps t_embedder_2 keys to t_embedder for weight loading.
"""
if key.startswith("t_embedder_2."):
new_key = key.replace("t_embedder_2.", "t_embedder.", 1)
state_dict[new_key] = state_dict.pop(key)
return state_dict
def repeat_to_batch_size(tensor, batch_size, dim=0):
if tensor.shape[dim] > batch_size:
return tensor.narrow(dim, 0, batch_size)