mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 07:22:34 +08:00
Add detection for TwinFlow-Z-Image checkpoints
This commit is contained in:
parent
21ed4a2242
commit
84c983a779
@ -44,6 +44,12 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
|||||||
def detect_unet_config(state_dict, key_prefix, metadata=None):
|
def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||||
state_dict_keys = list(state_dict.keys())
|
state_dict_keys = list(state_dict.keys())
|
||||||
|
|
||||||
|
# TwinFlow-Z-Image: detect dual timestep embedder checkpoints first.
|
||||||
|
if any(k.startswith('{}t_embedder_2.'.format(key_prefix)) for k in state_dict_keys):
|
||||||
|
return {
|
||||||
|
"image_model": "twinflow_z_image",
|
||||||
|
"architecture": "TwinFlow_Z_Image",
|
||||||
|
}
|
||||||
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
|
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
|
||||||
unet_config = {}
|
unet_config = {}
|
||||||
unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
|
unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user