mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-10 00:07:33 +08:00
Drop the in-code conversion of original DA3 models and add support for auto-detection of repackaged models.
This commit is contained in:
parent
8cbdd8f72e
commit
9494456d33
@ -11,8 +11,10 @@
|
||||
# auxiliary "ray" head of ``DualDPT`` is enabled the predicted ray map can
|
||||
# alternatively be used to estimate pose via RANSAC (``use_ray_pose=True``).
|
||||
# The 3D-Gaussian head and the nested-architecture wrapper are intentionally
|
||||
# left out of scope here; their state-dict keys are filtered in
|
||||
# ``comfy.supported_models.DepthAnything3.process_unet_state_dict``.
|
||||
# left out of scope here; their state-dict keys (``gs_head.*``,
|
||||
# ``gs_adapter.*``) are dropped when repackaging the checkpoint with
|
||||
# ``scripts/convert_da3.py``, which also remaps the backbone into the native
|
||||
# ``Dinov2Model`` layout that this module loads directly.
|
||||
#
|
||||
# The backbone is shared with the CLIP-vision DINOv2 path
|
||||
# (``comfy.image_encoders.dino2.Dinov2Model``); the DA3-specific extensions
|
||||
|
||||
@ -849,14 +849,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
|
||||
return dit_config
|
||||
|
||||
# Depth Anything 3
|
||||
if '{}backbone.pretrained.patch_embed.proj.weight'.format(key_prefix) in state_dict_keys:
|
||||
# Depth Anything 3 (repackaged to ComfyUI's native Dinov2Model layout via scripts/convert_da3.py)
|
||||
if '{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "DepthAnything3"
|
||||
|
||||
patch_w = state_dict['{}backbone.pretrained.patch_embed.proj.weight'.format(key_prefix)]
|
||||
patch_w = state_dict['{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix)]
|
||||
embed_dim = patch_w.shape[0]
|
||||
depth = count_blocks(state_dict_keys, '{}backbone.pretrained.blocks.'.format(key_prefix) + '{}.')
|
||||
depth = count_blocks(state_dict_keys, '{}backbone.encoder.layer.'.format(key_prefix) + '{}.')
|
||||
|
||||
# Backbone preset is determined by embed_dim (matches vits/vitb/vitl/vitg).
|
||||
backbone_name = {384: "vits", 768: "vitb", 1024: "vitl", 1536: "vitg"}.get(embed_dim)
|
||||
@ -865,11 +865,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["backbone_name"] = backbone_name
|
||||
|
||||
# Detect DA3 extensions on top of vanilla DINOv2.
|
||||
has_camera_token = '{}backbone.pretrained.camera_token'.format(key_prefix) in state_dict_keys
|
||||
# qk-norm shows up as `attn.q_norm.weight` on enabled blocks.
|
||||
has_camera_token = '{}backbone.embeddings.camera_token'.format(key_prefix) in state_dict_keys
|
||||
# qk-norm shows up as `attention.q_norm.weight` on enabled blocks.
|
||||
qknorm_indices = [
|
||||
i for i in range(depth)
|
||||
if '{}backbone.pretrained.blocks.{}.attn.q_norm.weight'.format(key_prefix, i) in state_dict_keys
|
||||
if '{}backbone.encoder.layer.{}.attention.q_norm.weight'.format(key_prefix, i) in state_dict_keys
|
||||
]
|
||||
qknorm_start = qknorm_indices[0] if qknorm_indices else -1
|
||||
|
||||
|
||||
@ -2022,84 +2022,6 @@ class DepthAnything3(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
    """Report the text-encoder target for this model: always ``None``.

    ``state_dict`` is accepted for signature compatibility and is never
    read or modified.
    """
    # No lookup is performed; the result does not depend on the argument.
    return None
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
    """Prepare an upstream DA3 checkpoint for loading.

    Removes every entry whose key starts with ``gs_head.`` or
    ``gs_adapter.`` and then hands the dict to
    ``_da3_remap_backbone_keys`` so that ``backbone.pretrained.*`` keys are
    renamed under ``backbone.``.  ``state_dict`` is mutated in place; the
    value returned by the remap helper is returned unchanged.
    """
    gaussian_prefixes = ("gs_head.", "gs_adapter.")
    # Snapshot the matching names first so the dict is not resized while
    # it is being iterated.
    unwanted = [name for name in state_dict.keys() if name.startswith(gaussian_prefixes)]
    for name in unwanted:
        state_dict.pop(name)
    return _da3_remap_backbone_keys(state_dict, prefix="backbone.")
|
||||
|
||||
|
||||
def _da3_remap_backbone_keys(state_dict, prefix="backbone."):
    """Rename upstream DA3 ``<prefix>pretrained.*`` keys to the ``Dinov2Model`` layout.

    The mapping is applied in place and the same dict object is returned.

    * Top-level tensors (patch embedding, position embedding, cls/camera
      tokens, final norm) are moved to ``<prefix>embeddings.*`` /
      ``<prefix>layernorm.*``.
    * ``<prefix>pretrained.blocks.<i>.*`` entries are moved to
      ``<prefix>encoder.layer.<i>.*``; the fused ``attn.qkv`` weight/bias is
      cut into three equal slices along dim 0 (query, key, value), each
      cloned into its own entry.
    * Block sub-keys that match none of the known patterns are left under
      their original name.

    If no key starts with ``<prefix>pretrained.`` the dict is returned
    untouched.
    """
    upstream = prefix + "pretrained."
    if not any(name.startswith(upstream) for name in state_dict.keys()):
        return state_dict

    # (upstream tail, native tail) for tensors living outside the blocks.
    top_level = (
        ("patch_embed.proj.weight", "embeddings.patch_embeddings.projection.weight"),
        ("patch_embed.proj.bias", "embeddings.patch_embeddings.projection.bias"),
        ("pos_embed", "embeddings.position_embeddings"),
        ("cls_token", "embeddings.cls_token"),
        ("camera_token", "embeddings.camera_token"),
        ("norm.weight", "layernorm.weight"),
        ("norm.bias", "layernorm.bias"),
    )
    for old_tail, new_tail in top_level:
        old_name = upstream + old_tail
        if old_name in state_dict:
            state_dict[prefix + new_tail] = state_dict.pop(old_name)

    # Fused QKV parameters, keyed by sub-key, mapped to the parameter kind.
    fused_qkv = {"attn.qkv.weight": "weight", "attn.qkv.bias": "bias"}
    # Sub-keys renamed as a whole.
    exact_swaps = {
        "ls1.gamma": "layer_scale1.lambda1",
        "ls2.gamma": "layer_scale2.lambda1",
    }
    # Sub-key heads that are swapped while the remaining suffix is preserved.
    head_swaps = (
        ("attn.proj.", "attention.output.dense."),
        ("attn.q_norm.", "attention.q_norm."),
        ("attn.k_norm.", "attention.k_norm."),
        ("mlp.w12.", "mlp.weights_in."),
        ("mlp.w3.", "mlp.weights_out."),
    )
    # Sub-keys that already carry their final name.
    unchanged_heads = ("norm1.", "norm2.", "mlp.fc1.", "mlp.fc2.")

    blocks_root = upstream + "blocks."
    # Snapshot the block keys: the dict is mutated inside the loop.
    pending = [name for name in state_dict.keys() if name.startswith(blocks_root)]
    for old_name in pending:
        # e.g. "5.attn.qkv.weight" -> ("5", ".", "attn.qkv.weight")
        layer_idx, _, inner = old_name[len(blocks_root):].partition(".")
        layer_root = prefix + "encoder.layer." + layer_idx + "."

        param_kind = fused_qkv.get(inner)
        if param_kind is not None:
            packed = state_dict.pop(old_name)
            width = packed.shape[0] // 3
            pieces = (
                ("query", slice(None, width)),
                ("key", slice(width, 2 * width)),
                ("value", slice(2 * width, None)),
            )
            for role, span in pieces:
                target = layer_root + "attention.attention." + role + "." + param_kind
                state_dict[target] = packed[span].clone()
            continue

        renamed = exact_swaps.get(inner)
        if renamed is None and inner.startswith(unchanged_heads):
            renamed = inner
        if renamed is None:
            renamed = next(
                (
                    new_head + inner[len(old_head):]
                    for old_head, new_head in head_swaps
                    if inner.startswith(old_head)
                ),
                None,
            )
        if renamed is None:
            # Unrecognised key -- leave as-is so load_state_dict can complain.
            continue

        state_dict[layer_root + renamed] = state_dict.pop(old_name)

    return state_dict
|
||||
|
||||
|
||||
class ErnieImage(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user