Update DA3 to use dino2.get_intermediate_layers_da3

This commit is contained in:
Talmaj Marinc 2026-05-19 12:11:17 +02:00
parent e26ba849e6
commit aa4eef71d0
2 changed files with 23 additions and 2 deletions

View File

@ -363,6 +363,21 @@ class Dinov2Model(torch.nn.Module):
return x, i, pooled_output, None
def get_intermediate_layers(self, pixel_values, indices, apply_norm=True):
"""Single-view multi-layer feature extraction (MoGe / vanilla DINOv2).
For the multi-view Depth Anything 3 path (RoPE, alt-attention,
camera-token injection, ref-view selection, cat_token), use
:meth:`get_intermediate_layers_da3` instead.
Args:
pixel_values: ``(B, 3, H, W)`` single-view input.
indices: layer indices to extract; supports negative indexing.
apply_norm: if True, apply the final layernorm to each output.
Returns:
list of ``(patch_tokens, cls_token)`` tuples with shapes
``(B, N_patch, C)`` and ``(B, C)`` (one entry per ``indices``).
"""
x = self.embeddings(pixel_values)
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
n_layers = len(self.encoder.layer)
@ -415,7 +430,13 @@ class Dinov2Model(torch.nn.Module):
def get_intermediate_layers_da3(self, pixel_values, out_layers, cam_token=None,
ref_view_strategy="saddle_balanced",
export_feat_layers=None):
"""Multi-layer DINOv2 feature extraction used by Depth Anything 3.
"""Multi-view multi-layer feature extraction used by Depth Anything 3.
Adds RoPE positions, alternating local/global attention across views,
camera-token injection, reference-view selection/reordering,
``cat_token`` output and optional auxiliary feature exports on top of
the vanilla DINOv2 path. For the single-view MoGe / CLIP-vision use
case, see :meth:`get_intermediate_layers`.
Args:
pixel_values: ``(B, S, 3, H, W)`` views or ``(B, 3, H, W)``.

View File

@ -243,7 +243,7 @@ class DepthAnything3Net(nn.Module):
if isinstance(self.head, DualDPT):
self.head.enable_aux = bool(use_ray_pose)
feats, aux_feats = self.backbone.get_intermediate_layers(
feats, aux_feats = self.backbone.get_intermediate_layers_da3(
image, self.out_layers, cam_token=cam_token,
ref_view_strategy=ref_view_strategy,
export_feat_layers=export_feat_layers,