mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-09 07:47:33 +08:00
Simplify/remove duplication in detect_unet_config for Depth Anything 3
This commit is contained in:
parent
5861890010
commit
f57ebb4cf4
@ -888,35 +888,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
|
|
||||||
# Detect head type and config.
|
# Detect head type and config.
|
||||||
has_aux = '{}head.scratch.refinenet1_aux.out_conv.weight'.format(key_prefix) in state_dict_keys
|
has_aux = '{}head.scratch.refinenet1_aux.out_conv.weight'.format(key_prefix) in state_dict_keys
|
||||||
if has_aux:
|
dit_config["head_dim_in"] = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
|
||||||
dit_config["head_type"] = "dualdpt"
|
dit_config["head_features"] = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
|
||||||
# DualDPT: dim_in = 2 * embed_dim (because cat_token doubles token width).
|
dit_config["head_out_channels"] = [
|
||||||
head_dim_in = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
|
|
||||||
out_channels = [
|
|
||||||
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
|
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
|
||||||
for i in range(4)
|
for i in range(4)
|
||||||
]
|
]
|
||||||
features = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
|
if has_aux:
|
||||||
dit_config["head_dim_in"] = head_dim_in
|
# DualDPT: dim_in = 2 * embed_dim (because cat_token doubles token width).
|
||||||
|
dit_config["head_type"] = "dualdpt"
|
||||||
dit_config["head_output_dim"] = 2
|
dit_config["head_output_dim"] = 2
|
||||||
dit_config["head_features"] = features
|
|
||||||
dit_config["head_out_channels"] = out_channels
|
|
||||||
dit_config["head_use_sky_head"] = False
|
dit_config["head_use_sky_head"] = False
|
||||||
else:
|
else:
|
||||||
dit_config["head_type"] = "dpt"
|
dit_config["head_type"] = "dpt"
|
||||||
head_dim_in = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
|
dit_config["head_output_dim"] = state_dict[
|
||||||
out_channels = [
|
|
||||||
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
|
|
||||||
for i in range(4)
|
|
||||||
]
|
|
||||||
features = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
|
|
||||||
output_dim = state_dict[
|
|
||||||
'{}head.scratch.output_conv2.2.weight'.format(key_prefix)
|
'{}head.scratch.output_conv2.2.weight'.format(key_prefix)
|
||||||
].shape[0]
|
].shape[0]
|
||||||
dit_config["head_dim_in"] = head_dim_in
|
|
||||||
dit_config["head_output_dim"] = output_dim
|
|
||||||
dit_config["head_features"] = features
|
|
||||||
dit_config["head_out_channels"] = out_channels
|
|
||||||
dit_config["head_use_sky_head"] = (
|
dit_config["head_use_sky_head"] = (
|
||||||
'{}head.scratch.sky_output_conv2.0.weight'.format(key_prefix) in state_dict_keys
|
'{}head.scratch.sky_output_conv2.0.weight'.format(key_prefix) in state_dict_keys
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user