diff --git a/comfy/ldm/depth_anything_3/model.py b/comfy/ldm/depth_anything_3/model.py index f939417c3..2ca472a7d 100644 --- a/comfy/ldm/depth_anything_3/model.py +++ b/comfy/ldm/depth_anything_3/model.py @@ -11,8 +11,10 @@ # auxiliary "ray" head of ``DualDPT`` is enabled the predicted ray map can # alternatively be used to estimate pose via RANSAC (``use_ray_pose=True``). # The 3D-Gaussian head and the nested-architecture wrapper are intentionally -# left out of scope here; their state-dict keys are filtered in -# ``comfy.supported_models.DepthAnything3.process_unet_state_dict``. +# left out of scope here; their state-dict keys (``gs_head.*``, +# ``gs_adapter.*``) are dropped when repackaging the checkpoint with +# ``scripts/convert_da3.py``, which also remaps the backbone into the native +# ``Dinov2Model`` layout that this module loads directly. # # The backbone is shared with the CLIP-vision DINOv2 path # (``comfy.image_encoders.dino2.Dinov2Model``); the DA3-specific extensions diff --git a/comfy/model_detection.py b/comfy/model_detection.py index c9c123bb8..0e48edb13 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -849,14 +849,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0] return dit_config - # Depth Anything 3 - if '{}backbone.pretrained.patch_embed.proj.weight'.format(key_prefix) in state_dict_keys: + # Depth Anything 3 (repackaged to ComfyUI's native Dinov2Model layout via scripts/convert_da3.py) + if '{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix) in state_dict_keys: dit_config = {} dit_config["image_model"] = "DepthAnything3" - patch_w = state_dict['{}backbone.pretrained.patch_embed.proj.weight'.format(key_prefix)] + patch_w = state_dict['{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix)] embed_dim = patch_w.shape[0] - depth = count_blocks(state_dict_keys, '{}backbone.pretrained.blocks.'.format(key_prefix) + '{}.') + depth = count_blocks(state_dict_keys, '{}backbone.encoder.layer.'.format(key_prefix) + '{}.') # Backbone preset is determined by embed_dim (matches vits/vitb/vitl/vitg). backbone_name = {384: "vits", 768: "vitb", 1024: "vitl", 1536: "vitg"}.get(embed_dim) @@ -865,11 +865,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["backbone_name"] = backbone_name # Detect DA3 extensions on top of vanilla DINOv2. - has_camera_token = '{}backbone.pretrained.camera_token'.format(key_prefix) in state_dict_keys - # qk-norm shows up as `attn.q_norm.weight` on enabled blocks. + has_camera_token = '{}backbone.embeddings.camera_token'.format(key_prefix) in state_dict_keys + # qk-norm shows up as `attention.q_norm.weight` on enabled blocks. qknorm_indices = [ i for i in range(depth) - if '{}backbone.pretrained.blocks.{}.attn.q_norm.weight'.format(key_prefix, i) in state_dict_keys + if '{}backbone.encoder.layer.{}.attention.q_norm.weight'.format(key_prefix, i) in state_dict_keys ] qknorm_start = qknorm_indices[0] if qknorm_indices else -1 diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 55f96ba9e..4059cd1a1 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -2022,84 +2022,6 @@ class DepthAnything3(supported_models_base.BASE): def clip_target(self, state_dict={}): return None - def process_unet_state_dict(self, state_dict): - # Drop Gaussian-head weights; remap fused backbone QKV to Dinov2Model layout. - drop_prefixes = ("gs_head.", "gs_adapter.") - for k in list(state_dict.keys()): - if k.startswith(drop_prefixes): - state_dict.pop(k) - return _da3_remap_backbone_keys(state_dict, prefix="backbone.") - - -def _da3_remap_backbone_keys(state_dict, prefix="backbone."): - """Map ``backbone.pretrained.*`` (upstream DA3) keys to ``Dinov2Model`` under ``prefix``.""" - pre = prefix + "pretrained." - src_keys = [k for k in state_dict.keys() if k.startswith(pre)] - if not src_keys: - return state_dict - - static_renames = { - pre + "patch_embed.proj.weight": prefix + "embeddings.patch_embeddings.projection.weight", - pre + "patch_embed.proj.bias": prefix + "embeddings.patch_embeddings.projection.bias", - pre + "pos_embed": prefix + "embeddings.position_embeddings", - pre + "cls_token": prefix + "embeddings.cls_token", - pre + "camera_token": prefix + "embeddings.camera_token", - pre + "norm.weight": prefix + "layernorm.weight", - pre + "norm.bias": prefix + "layernorm.bias", - } - for src, dst in static_renames.items(): - if src in state_dict: - state_dict[dst] = state_dict.pop(src) - - block_pre = pre + "blocks." - block_keys = [k for k in state_dict.keys() if k.startswith(block_pre)] - for k in block_keys: - rest = k[len(block_pre):] # e.g. "5.attn.qkv.weight" - idx_str, _, sub = rest.partition(".") - target_block = "{}encoder.layer.{}.".format(prefix, idx_str) - - # Fused QKV -> split query/key/value linears. - if sub == "attn.qkv.weight": - qkv = state_dict.pop(k) - c = qkv.shape[0] // 3 - state_dict[target_block + "attention.attention.query.weight"] = qkv[:c].clone() - state_dict[target_block + "attention.attention.key.weight"] = qkv[c:2 * c].clone() - state_dict[target_block + "attention.attention.value.weight"] = qkv[2 * c:].clone() - continue - if sub == "attn.qkv.bias": - qkv = state_dict.pop(k) - c = qkv.shape[0] // 3 - state_dict[target_block + "attention.attention.query.bias"] = qkv[:c].clone() - state_dict[target_block + "attention.attention.key.bias"] = qkv[c:2 * c].clone() - state_dict[target_block + "attention.attention.value.bias"] = qkv[2 * c:].clone() - continue - - # Sub-key remap (suffix preserved). - if sub.startswith("attn.proj."): - tail = sub[len("attn.proj."):] - new = "attention.output.dense." + tail - elif sub.startswith("attn.q_norm."): - new = "attention.q_norm." + sub[len("attn.q_norm."):] - elif sub.startswith("attn.k_norm."): - new = "attention.k_norm." + sub[len("attn.k_norm."):] - elif sub == "ls1.gamma": - new = "layer_scale1.lambda1" - elif sub == "ls2.gamma": - new = "layer_scale2.lambda1" - elif sub.startswith("mlp.w12."): - new = "mlp.weights_in." + sub[len("mlp.w12."):] - elif sub.startswith("mlp.w3."): - new = "mlp.weights_out." + sub[len("mlp.w3."):] - elif sub.startswith(("norm1.", "norm2.", "mlp.fc1.", "mlp.fc2.")): - new = sub - else: - # Unrecognised key -- leave as-is so load_state_dict can complain. - continue - - state_dict[target_block + new] = state_dict.pop(k) - - return state_dict - class ErnieImage(supported_models_base.BASE): unet_config = {