Drop the in-code conversion of original DA3 models and add support for auto-detection of repackaged models.
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run

This commit is contained in:
Talmaj Marinc 2026-06-09 13:39:29 +02:00
parent 8cbdd8f72e
commit 9494456d33
3 changed files with 11 additions and 87 deletions

View File

@ -11,8 +11,10 @@
# auxiliary "ray" head of ``DualDPT`` is enabled the predicted ray map can
# alternatively be used to estimate pose via RANSAC (``use_ray_pose=True``).
# The 3D-Gaussian head and the nested-architecture wrapper are intentionally
# left out of scope here; their state-dict keys are filtered in
# ``comfy.supported_models.DepthAnything3.process_unet_state_dict``.
# left out of scope here; their state-dict keys (``gs_head.*``,
# ``gs_adapter.*``) are dropped when repackaging the checkpoint with
# ``scripts/convert_da3.py``, which also remaps the backbone into the native
# ``Dinov2Model`` layout that this module loads directly.
#
# The backbone is shared with the CLIP-vision DINOv2 path
# (``comfy.image_encoders.dino2.Dinov2Model``); the DA3-specific extensions

View File

@ -849,14 +849,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
return dit_config
# Depth Anything 3
if '{}backbone.pretrained.patch_embed.proj.weight'.format(key_prefix) in state_dict_keys:
# Depth Anything 3 (repackaged to ComfyUI's native Dinov2Model layout via scripts/convert_da3.py)
if '{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix) in state_dict_keys:
dit_config = {}
dit_config["image_model"] = "DepthAnything3"
patch_w = state_dict['{}backbone.pretrained.patch_embed.proj.weight'.format(key_prefix)]
patch_w = state_dict['{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix)]
embed_dim = patch_w.shape[0]
depth = count_blocks(state_dict_keys, '{}backbone.pretrained.blocks.'.format(key_prefix) + '{}.')
depth = count_blocks(state_dict_keys, '{}backbone.encoder.layer.'.format(key_prefix) + '{}.')
# Backbone preset is determined by embed_dim (matches vits/vitb/vitl/vitg).
backbone_name = {384: "vits", 768: "vitb", 1024: "vitl", 1536: "vitg"}.get(embed_dim)
@ -865,11 +865,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["backbone_name"] = backbone_name
# Detect DA3 extensions on top of vanilla DINOv2.
has_camera_token = '{}backbone.pretrained.camera_token'.format(key_prefix) in state_dict_keys
# qk-norm shows up as `attn.q_norm.weight` on enabled blocks.
has_camera_token = '{}backbone.embeddings.camera_token'.format(key_prefix) in state_dict_keys
# qk-norm shows up as `attention.q_norm.weight` on enabled blocks.
qknorm_indices = [
i for i in range(depth)
if '{}backbone.pretrained.blocks.{}.attn.q_norm.weight'.format(key_prefix, i) in state_dict_keys
if '{}backbone.encoder.layer.{}.attention.q_norm.weight'.format(key_prefix, i) in state_dict_keys
]
qknorm_start = qknorm_indices[0] if qknorm_indices else -1

View File

@ -2022,84 +2022,6 @@ class DepthAnything3(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return None
def process_unet_state_dict(self, state_dict):
# Drop Gaussian-head weights; remap fused backbone QKV to Dinov2Model layout.
drop_prefixes = ("gs_head.", "gs_adapter.")
for k in list(state_dict.keys()):
if k.startswith(drop_prefixes):
state_dict.pop(k)
return _da3_remap_backbone_keys(state_dict, prefix="backbone.")
def _da3_remap_backbone_keys(state_dict, prefix="backbone."):
"""Map ``backbone.pretrained.*`` (upstream DA3) keys to ``Dinov2Model`` under ``prefix``."""
pre = prefix + "pretrained."
src_keys = [k for k in state_dict.keys() if k.startswith(pre)]
if not src_keys:
return state_dict
static_renames = {
pre + "patch_embed.proj.weight": prefix + "embeddings.patch_embeddings.projection.weight",
pre + "patch_embed.proj.bias": prefix + "embeddings.patch_embeddings.projection.bias",
pre + "pos_embed": prefix + "embeddings.position_embeddings",
pre + "cls_token": prefix + "embeddings.cls_token",
pre + "camera_token": prefix + "embeddings.camera_token",
pre + "norm.weight": prefix + "layernorm.weight",
pre + "norm.bias": prefix + "layernorm.bias",
}
for src, dst in static_renames.items():
if src in state_dict:
state_dict[dst] = state_dict.pop(src)
block_pre = pre + "blocks."
block_keys = [k for k in state_dict.keys() if k.startswith(block_pre)]
for k in block_keys:
rest = k[len(block_pre):] # e.g. "5.attn.qkv.weight"
idx_str, _, sub = rest.partition(".")
target_block = "{}encoder.layer.{}.".format(prefix, idx_str)
# Fused QKV -> split query/key/value linears.
if sub == "attn.qkv.weight":
qkv = state_dict.pop(k)
c = qkv.shape[0] // 3
state_dict[target_block + "attention.attention.query.weight"] = qkv[:c].clone()
state_dict[target_block + "attention.attention.key.weight"] = qkv[c:2 * c].clone()
state_dict[target_block + "attention.attention.value.weight"] = qkv[2 * c:].clone()
continue
if sub == "attn.qkv.bias":
qkv = state_dict.pop(k)
c = qkv.shape[0] // 3
state_dict[target_block + "attention.attention.query.bias"] = qkv[:c].clone()
state_dict[target_block + "attention.attention.key.bias"] = qkv[c:2 * c].clone()
state_dict[target_block + "attention.attention.value.bias"] = qkv[2 * c:].clone()
continue
# Sub-key remap (suffix preserved).
if sub.startswith("attn.proj."):
tail = sub[len("attn.proj."):]
new = "attention.output.dense." + tail
elif sub.startswith("attn.q_norm."):
new = "attention.q_norm." + sub[len("attn.q_norm."):]
elif sub.startswith("attn.k_norm."):
new = "attention.k_norm." + sub[len("attn.k_norm."):]
elif sub == "ls1.gamma":
new = "layer_scale1.lambda1"
elif sub == "ls2.gamma":
new = "layer_scale2.lambda1"
elif sub.startswith("mlp.w12."):
new = "mlp.weights_in." + sub[len("mlp.w12."):]
elif sub.startswith("mlp.w3."):
new = "mlp.weights_out." + sub[len("mlp.w3."):]
elif sub.startswith(("norm1.", "norm2.", "mlp.fc1.", "mlp.fc2.")):
new = sub
else:
# Unrecognised key -- leave as-is so load_state_dict can complain.
continue
state_dict[target_block + new] = state_dict.pop(k)
return state_dict
class ErnieImage(supported_models_base.BASE):
unet_config = {