From f57ebb4cf42af46e1936daea949e0db7eacadcd8 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Mon, 1 Jun 2026 15:59:23 +0200 Subject: [PATCH] Simplify/remove duplication in detect_unet_config for Depth Anything 3 --- comfy/model_detection.py | 29 ++++++++--------------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 2db394d07..c9c123bb8 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -888,35 +888,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): # Detect head type and config. has_aux = '{}head.scratch.refinenet1_aux.out_conv.weight'.format(key_prefix) in state_dict_keys + dit_config["head_dim_in"] = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1] + dit_config["head_features"] = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0] + dit_config["head_out_channels"] = [ + state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0] + for i in range(4) + ] if has_aux: - dit_config["head_type"] = "dualdpt" # DualDPT: dim_in = 2 * embed_dim (because cat_token doubles token width). - head_dim_in = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1] - out_channels = [ - state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0] - for i in range(4) - ] - features = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0] - dit_config["head_dim_in"] = head_dim_in + dit_config["head_type"] = "dualdpt" dit_config["head_output_dim"] = 2 - dit_config["head_features"] = features - dit_config["head_out_channels"] = out_channels dit_config["head_use_sky_head"] = False else: dit_config["head_type"] = "dpt" - head_dim_in = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1] - out_channels = [ - state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0] - for i in range(4) - ] - features = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0] - output_dim = state_dict[ + dit_config["head_output_dim"] = state_dict[ '{}head.scratch.output_conv2.2.weight'.format(key_prefix) ].shape[0] - dit_config["head_dim_in"] = head_dim_in - dit_config["head_output_dim"] = output_dim - dit_config["head_features"] = features - dit_config["head_out_channels"] = out_channels dit_config["head_use_sky_head"] = ( '{}head.scratch.sky_output_conv2.0.weight'.format(key_prefix) in state_dict_keys )