mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-04 05:31:03 +08:00
Simplify model detection
This commit is contained in:
parent
7e39dea988
commit
f4b2173cf2
@ -113,42 +113,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
|
||||
return unet_config
|
||||
|
||||
def _detect_proj(sub_prefix: str, name: str):
|
||||
key = '{}{}.blocks.0.cross_attn.proj_linear.weight'.format(key_prefix, sub_prefix)
|
||||
if key in state_dict_keys:
|
||||
unet_config["image_attn_mode_{}".format(name)] = "proj"
|
||||
unet_config["proj_in_channels_{}".format(name)] = int(state_dict[key].shape[1])
|
||||
|
||||
if '{}img2shape.blocks.0.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys or \
|
||||
'{}img2shape.blocks.0.cross_attn.cross_attn_block.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
|
||||
unet_config = {}
|
||||
unet_config["image_model"] = "trellis2"
|
||||
|
||||
unet_config["init_txt_model"] = (
|
||||
'{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys or
|
||||
'{}shape2txt.blocks.29.cross_attn.cross_attn_block.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys
|
||||
)
|
||||
|
||||
unet_config["resolution"] = 64
|
||||
if metadata is not None:
|
||||
if "is_512" in metadata:
|
||||
unet_config["resolution"] = 32
|
||||
|
||||
unet_config["num_heads"] = 12
|
||||
|
||||
_detect_proj("img2shape", "shape")
|
||||
_detect_proj("shape2txt", "texture")
|
||||
_detect_proj("structure_model", "structure")
|
||||
return unet_config
|
||||
|
||||
if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys or \
|
||||
'{}shape2txt.blocks.29.cross_attn.cross_attn_block.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: # trellis2 texture
|
||||
unet_config = {}
|
||||
unet_config["image_model"] = "trellis2"
|
||||
unet_config["resolution"] = 64
|
||||
unet_config["num_heads"] = 12
|
||||
unet_config["txt_only"] = True
|
||||
_detect_proj("shape2txt", "texture")
|
||||
shape_key = '{}img2shape.t_embedder.mlp.0.weight'.format(key_prefix)
|
||||
tex_key = '{}shape2txt.t_embedder.mlp.0.weight'.format(key_prefix)
|
||||
if shape_key in state_dict_keys or tex_key in state_dict_keys: # trellis2 / pixal3d
|
||||
has_shape = shape_key in state_dict_keys
|
||||
has_tex = tex_key in state_dict_keys
|
||||
unet_config = {
|
||||
"image_model": "trellis2",
|
||||
"resolution": 32 if (metadata is not None and "is_512" in metadata) else 64,
|
||||
"init_txt_model": has_tex,
|
||||
"txt_only": has_tex and not has_shape,
|
||||
}
|
||||
# Per-submodel projection head (Pixal3D adds `proj_linear`; Trellis2 doesn't).
|
||||
for sub, name in (("img2shape", "shape"), ("shape2txt", "texture"), ("structure_model", "structure")):
|
||||
key = '{}{}.blocks.0.cross_attn.proj_linear.weight'.format(key_prefix, sub)
|
||||
if key in state_dict_keys:
|
||||
unet_config["image_attn_mode_{}".format(name)] = "proj"
|
||||
unet_config["proj_in_channels_{}".format(name)] = int(state_dict[key].shape[1])
|
||||
return unet_config
|
||||
|
||||
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
|
||||
|
||||
@ -1436,7 +1436,8 @@ class Trellis2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "trellis2"
|
||||
}
|
||||
unet_extra_config = {}
|
||||
|
||||
unet_extra_config = {"num_heads": 12}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user