mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-09 07:47:33 +08:00
Simplify/remove duplication in detect_unet_config for Depth Anything 3
This commit is contained in:
parent
5861890010
commit
f57ebb4cf4
@ -888,35 +888,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
# Detect head type and config.
|
||||
has_aux = '{}head.scratch.refinenet1_aux.out_conv.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["head_dim_in"] = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["head_features"] = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["head_out_channels"] = [
|
||||
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
|
||||
for i in range(4)
|
||||
]
|
||||
if has_aux:
|
||||
dit_config["head_type"] = "dualdpt"
|
||||
# DualDPT: dim_in = 2 * embed_dim (because cat_token doubles token width).
|
||||
head_dim_in = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
|
||||
out_channels = [
|
||||
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
|
||||
for i in range(4)
|
||||
]
|
||||
features = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["head_dim_in"] = head_dim_in
|
||||
dit_config["head_type"] = "dualdpt"
|
||||
dit_config["head_output_dim"] = 2
|
||||
dit_config["head_features"] = features
|
||||
dit_config["head_out_channels"] = out_channels
|
||||
dit_config["head_use_sky_head"] = False
|
||||
else:
|
||||
dit_config["head_type"] = "dpt"
|
||||
head_dim_in = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
|
||||
out_channels = [
|
||||
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
|
||||
for i in range(4)
|
||||
]
|
||||
features = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
|
||||
output_dim = state_dict[
|
||||
dit_config["head_output_dim"] = state_dict[
|
||||
'{}head.scratch.output_conv2.2.weight'.format(key_prefix)
|
||||
].shape[0]
|
||||
dit_config["head_dim_in"] = head_dim_in
|
||||
dit_config["head_output_dim"] = output_dim
|
||||
dit_config["head_features"] = features
|
||||
dit_config["head_out_channels"] = out_channels
|
||||
dit_config["head_use_sky_head"] = (
|
||||
'{}head.scratch.sky_output_conv2.0.weight'.format(key_prefix) in state_dict_keys
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user