mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-21 06:27:24 +08:00
Update DA3 to use dino2.get_intermediate_layers_da3
This commit is contained in:
parent
e26ba849e6
commit
aa4eef71d0
@ -363,6 +363,21 @@ class Dinov2Model(torch.nn.Module):
|
||||
return x, i, pooled_output, None
|
||||
|
||||
def get_intermediate_layers(self, pixel_values, indices, apply_norm=True):
|
||||
"""Single-view multi-layer feature extraction (MoGe / vanilla DINOv2).
|
||||
|
||||
For the multi-view Depth Anything 3 path (RoPE, alt-attention,
|
||||
camera-token injection, ref-view selection, cat_token), use
|
||||
:meth:`get_intermediate_layers_da3` instead.
|
||||
|
||||
Args:
|
||||
pixel_values: ``(B, 3, H, W)`` single-view input.
|
||||
indices: layer indices to extract; supports negative indexing.
|
||||
apply_norm: if True, apply the final layernorm to each output.
|
||||
|
||||
Returns:
|
||||
list of ``(patch_tokens, cls_token)`` tuples with shapes
|
||||
``(B, N_patch, C)`` and ``(B, C)`` (one entry per ``indices``).
|
||||
"""
|
||||
x = self.embeddings(pixel_values)
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
n_layers = len(self.encoder.layer)
|
||||
@ -415,7 +430,13 @@ class Dinov2Model(torch.nn.Module):
|
||||
def get_intermediate_layers_da3(self, pixel_values, out_layers, cam_token=None,
|
||||
ref_view_strategy="saddle_balanced",
|
||||
export_feat_layers=None):
|
||||
"""Multi-layer DINOv2 feature extraction used by Depth Anything 3.
|
||||
"""Multi-view multi-layer feature extraction used by Depth Anything 3.
|
||||
|
||||
Adds RoPE positions, alternating local/global attention across views,
|
||||
camera-token injection, reference-view selection/reordering,
|
||||
``cat_token`` output and optional auxiliary feature exports on top of
|
||||
the vanilla DINOv2 path. For the single-view MoGe / CLIP-vision use
|
||||
case, see :meth:`get_intermediate_layers`.
|
||||
|
||||
Args:
|
||||
pixel_values: ``(B, S, 3, H, W)`` views or ``(B, 3, H, W)``.
|
||||
|
||||
@ -243,7 +243,7 @@ class DepthAnything3Net(nn.Module):
|
||||
if isinstance(self.head, DualDPT):
|
||||
self.head.enable_aux = bool(use_ray_pose)
|
||||
|
||||
feats, aux_feats = self.backbone.get_intermediate_layers(
|
||||
feats, aux_feats = self.backbone.get_intermediate_layers_da3(
|
||||
image, self.out_layers, cam_token=cam_token,
|
||||
ref_view_strategy=ref_view_strategy,
|
||||
export_feat_layers=export_feat_layers,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user