Simplify model detection

This commit is contained in:
kijai 2026-06-30 20:15:46 +03:00
parent 7e39dea988
commit f4b2173cf2
2 changed files with 19 additions and 37 deletions

View File

@ -113,42 +113,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
return unet_config
def _detect_proj(sub_prefix: str, name: str):
key = '{}{}.blocks.0.cross_attn.proj_linear.weight'.format(key_prefix, sub_prefix)
if key in state_dict_keys:
unet_config["image_attn_mode_{}".format(name)] = "proj"
unet_config["proj_in_channels_{}".format(name)] = int(state_dict[key].shape[1])
if '{}img2shape.blocks.0.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys or \
'{}img2shape.blocks.0.cross_attn.cross_attn_block.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
unet_config = {}
unet_config["image_model"] = "trellis2"
unet_config["init_txt_model"] = (
'{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys or
'{}shape2txt.blocks.29.cross_attn.cross_attn_block.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys
)
unet_config["resolution"] = 64
if metadata is not None:
if "is_512" in metadata:
unet_config["resolution"] = 32
unet_config["num_heads"] = 12
_detect_proj("img2shape", "shape")
_detect_proj("shape2txt", "texture")
_detect_proj("structure_model", "structure")
return unet_config
if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys or \
'{}shape2txt.blocks.29.cross_attn.cross_attn_block.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: # trellis2 texture
unet_config = {}
unet_config["image_model"] = "trellis2"
unet_config["resolution"] = 64
unet_config["num_heads"] = 12
unet_config["txt_only"] = True
_detect_proj("shape2txt", "texture")
shape_key = '{}img2shape.t_embedder.mlp.0.weight'.format(key_prefix)
tex_key = '{}shape2txt.t_embedder.mlp.0.weight'.format(key_prefix)
if shape_key in state_dict_keys or tex_key in state_dict_keys: # trellis2 / pixal3d
has_shape = shape_key in state_dict_keys
has_tex = tex_key in state_dict_keys
unet_config = {
"image_model": "trellis2",
"resolution": 32 if (metadata is not None and "is_512" in metadata) else 64,
"init_txt_model": has_tex,
"txt_only": has_tex and not has_shape,
}
# Per-submodel projection head (Pixal3D adds `proj_linear`; Trellis2 doesn't).
for sub, name in (("img2shape", "shape"), ("shape2txt", "texture"), ("structure_model", "structure")):
key = '{}{}.blocks.0.cross_attn.proj_linear.weight'.format(key_prefix, sub)
if key in state_dict_keys:
unet_config["image_attn_mode_{}".format(name)] = "proj"
unet_config["proj_in_channels_{}".format(name)] = int(state_dict[key].shape[1])
return unet_config
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit

View File

@ -1436,7 +1436,8 @@ class Trellis2(supported_models_base.BASE):
unet_config = {
"image_model": "trellis2"
}
unet_extra_config = {}
unet_extra_config = {"num_heads": 12}
sampling_settings = {
"shift": 3.0,