Simplify/remove duplication in detect_unet_config for Depth Anything 3

This commit is contained in:
Talmaj Marinc 2026-06-01 15:59:23 +02:00
parent 5861890010
commit f57ebb4cf4

View File

@ -888,35 +888,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
# Detect head type and config. # Detect head type and config.
has_aux = '{}head.scratch.refinenet1_aux.out_conv.weight'.format(key_prefix) in state_dict_keys has_aux = '{}head.scratch.refinenet1_aux.out_conv.weight'.format(key_prefix) in state_dict_keys
if has_aux: dit_config["head_dim_in"] = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
dit_config["head_type"] = "dualdpt" dit_config["head_features"] = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
# DualDPT: dim_in = 2 * embed_dim (because cat_token doubles token width). dit_config["head_out_channels"] = [
head_dim_in = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
out_channels = [
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0] state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
for i in range(4) for i in range(4)
] ]
features = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0] if has_aux:
dit_config["head_dim_in"] = head_dim_in # DualDPT: dim_in = 2 * embed_dim (because cat_token doubles token width).
dit_config["head_type"] = "dualdpt"
dit_config["head_output_dim"] = 2 dit_config["head_output_dim"] = 2
dit_config["head_features"] = features
dit_config["head_out_channels"] = out_channels
dit_config["head_use_sky_head"] = False dit_config["head_use_sky_head"] = False
else: else:
dit_config["head_type"] = "dpt" dit_config["head_type"] = "dpt"
head_dim_in = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1] dit_config["head_output_dim"] = state_dict[
out_channels = [
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
for i in range(4)
]
features = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
output_dim = state_dict[
'{}head.scratch.output_conv2.2.weight'.format(key_prefix) '{}head.scratch.output_conv2.2.weight'.format(key_prefix)
].shape[0] ].shape[0]
dit_config["head_dim_in"] = head_dim_in
dit_config["head_output_dim"] = output_dim
dit_config["head_features"] = features
dit_config["head_out_channels"] = out_channels
dit_config["head_use_sky_head"] = ( dit_config["head_use_sky_head"] = (
'{}head.scratch.sky_output_conv2.0.weight'.format(key_prefix) in state_dict_keys '{}head.scratch.sky_output_conv2.0.weight'.format(key_prefix) in state_dict_keys
) )