From f4b2173cf2b1b8f1546936f6e7ba6134b942f096 Mon Sep 17 00:00:00 2001 From: kijai Date: Tue, 30 Jun 2026 20:15:46 +0300 Subject: [PATCH] Simplify model detection --- comfy/model_detection.py | 53 +++++++++++++-------------------------- comfy/supported_models.py | 3 ++- 2 files changed, 19 insertions(+), 37 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7cff615fa..cf83159e7 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -113,42 +113,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]] return unet_config - def _detect_proj(sub_prefix: str, name: str): - key = '{}{}.blocks.0.cross_attn.proj_linear.weight'.format(key_prefix, sub_prefix) - if key in state_dict_keys: - unet_config["image_attn_mode_{}".format(name)] = "proj" - unet_config["proj_in_channels_{}".format(name)] = int(state_dict[key].shape[1]) - - if '{}img2shape.blocks.0.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys or \ - '{}img2shape.blocks.0.cross_attn.cross_attn_block.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: - unet_config = {} - unet_config["image_model"] = "trellis2" - - unet_config["init_txt_model"] = ( - '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys or - '{}shape2txt.blocks.29.cross_attn.cross_attn_block.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys - ) - - unet_config["resolution"] = 64 - if metadata is not None: - if "is_512" in metadata: - unet_config["resolution"] = 32 - - unet_config["num_heads"] = 12 - - _detect_proj("img2shape", "shape") - _detect_proj("shape2txt", "texture") - _detect_proj("structure_model", "structure") - return unet_config - - if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys or \ - '{}shape2txt.blocks.29.cross_attn.cross_attn_block.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: # trellis2 texture - unet_config = {} - unet_config["image_model"] = "trellis2" - unet_config["resolution"] = 64 - unet_config["num_heads"] = 12 - unet_config["txt_only"] = True - _detect_proj("shape2txt", "texture") + shape_key = '{}img2shape.t_embedder.mlp.0.weight'.format(key_prefix) + tex_key = '{}shape2txt.t_embedder.mlp.0.weight'.format(key_prefix) + if shape_key in state_dict_keys or tex_key in state_dict_keys: # trellis2 / pixal3d + has_shape = shape_key in state_dict_keys + has_tex = tex_key in state_dict_keys + unet_config = { + "image_model": "trellis2", + "resolution": 32 if (metadata is not None and "is_512" in metadata) else 64, + "init_txt_model": has_tex, + "txt_only": has_tex and not has_shape, + } + # Per-submodel projection head (Pixal3D adds `proj_linear`; Trellis2 doesn't). + for sub, name in (("img2shape", "shape"), ("shape2txt", "texture"), ("structure_model", "structure")): + key = '{}{}.blocks.0.cross_attn.proj_linear.weight'.format(key_prefix, sub) + if key in state_dict_keys: + unet_config["image_attn_mode_{}".format(name)] = "proj" + unet_config["proj_in_channels_{}".format(name)] = int(state_dict[key].shape[1]) return unet_config if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 815c0c3f6..69ee9ce5b 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1436,7 +1436,8 @@ class Trellis2(supported_models_base.BASE): unet_config = { "image_model": "trellis2" } - unet_extra_config = {} + + unet_extra_config = {"num_heads": 12} sampling_settings = { "shift": 3.0,