"""Camera-token encoder and decoder for Depth Anything 3. * :class:`CameraEnc` takes per-view extrinsics + intrinsics and produces a per-view camera token that gets injected at the alt-attention boundary in the DINOv2 backbone (block ``alt_start``). * :class:`CameraDec` takes the final-layer camera token output by the backbone and predicts a 9-D pose encoding (translation, quaternion, field-of-view). The module/parameter names match the upstream ``cam_enc.py``/``cam_dec.py`` so HF safetensors load directly with no key remapping (the upstream uses fused QKV linears, which we replicate here). """ from __future__ import annotations import torch import torch.nn as nn import torch.nn.functional as F from .transform import affine_inverse, extri_intri_to_pose_encoding # ----------------------------------------------------------------------------- # Building blocks (mirror ``depth_anything_3.model.utils.{attention,block}``) # ----------------------------------------------------------------------------- class _Mlp(nn.Module): """Standard 2-layer MLP with GELU. Matches upstream ``utils.attention.Mlp``.""" def __init__(self, in_features, hidden_features=None, out_features=None, *, device=None, dtype=None, operations=None): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = operations.Linear(in_features, hidden_features, bias=True, device=device, dtype=dtype) self.fc2 = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype) def forward(self, x): return self.fc2(F.gelu(self.fc1(x))) class _LayerScale(nn.Module): """Per-channel learnable scaling. Matches upstream ``LayerScale``.""" def __init__(self, dim, *, device=None, dtype=None): super().__init__() self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) def forward(self, x): return x * self.gamma.to(dtype=x.dtype, device=x.device) class _Attention(nn.Module): """Self-attention with fused QKV projection. Mirrors upstream ``utils.attention.Attention``; layout matches the HF safetensors (``attn.qkv.{weight,bias}`` and ``attn.proj.{weight,bias}``). """ def __init__(self, dim, num_heads, *, device=None, dtype=None, operations=None): super().__init__() assert dim % num_heads == 0 self.num_heads = num_heads self.head_dim = dim // num_heads self.qkv = operations.Linear(dim, dim * 3, bias=True, device=device, dtype=dtype) self.proj = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) def forward(self, x): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) qkv = qkv.permute(2, 0, 3, 1, 4) # 3, B, h, N, d q, k, v = qkv.unbind(0) out = F.scaled_dot_product_attention(q, k, v) out = out.transpose(1, 2).reshape(B, N, C) return self.proj(out) class _Block(nn.Module): """Pre-norm transformer block with LayerScale. Used by :class:`CameraEnc`. Layout follows upstream ``utils.block.Block``. """ def __init__(self, dim, num_heads, mlp_ratio=4, init_values=0.01, *, device=None, dtype=None, operations=None): super().__init__() self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype) self.attn = _Attention(dim, num_heads, device=device, dtype=dtype, operations=operations) self.ls1 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity() self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype) self.mlp = _Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), device=device, dtype=dtype, operations=operations) self.ls2 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity() def forward(self, x): x = x + self.ls1(self.attn(self.norm1(x))) x = x + self.ls2(self.mlp(self.norm2(x))) return x # ----------------------------------------------------------------------------- # Camera encoder # ----------------------------------------------------------------------------- class CameraEnc(nn.Module): """Encode per-view (extrinsics, intrinsics) into a camera token. Maps a 9-D pose-encoding vector through a small MLP up to the backbone's ``embed_dim``, then runs ``trunk_depth`` transformer blocks. The output has shape ``(B, S, embed_dim)`` and is injected at block ``alt_start`` of the DINOv2 backbone in place of the cls token. Parameters mirror the upstream ``cam_enc.py`` so HF weights load directly. """ def __init__( self, dim_out: int = 1024, dim_in: int = 9, trunk_depth: int = 4, target_dim: int = 9, num_heads: int = 16, mlp_ratio: int = 4, init_values: float = 0.01, *, device=None, dtype=None, operations=None, **_kwargs, ): super().__init__() self.target_dim = target_dim self.trunk_depth = trunk_depth self.trunk = nn.Sequential(*[ _Block(dim_out, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values, device=device, dtype=dtype, operations=operations) for _ in range(trunk_depth) ]) self.token_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype) self.trunk_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype) self.pose_branch = _Mlp( in_features=dim_in, hidden_features=dim_out // 2, out_features=dim_out, device=device, dtype=dtype, operations=operations, ) def forward(self, extrinsics: torch.Tensor, intrinsics: torch.Tensor, image_size_hw) -> torch.Tensor: """Encode camera parameters into ``(B, S, dim_out)`` tokens.""" c2ws = affine_inverse(extrinsics) pose_encoding = extri_intri_to_pose_encoding(c2ws, intrinsics, image_size_hw) tokens = self.pose_branch(pose_encoding.to(self.pose_branch.fc1.weight.dtype)) tokens = self.token_norm(tokens) tokens = self.trunk(tokens) tokens = self.trunk_norm(tokens) return tokens # ----------------------------------------------------------------------------- # Camera decoder # ----------------------------------------------------------------------------- class CameraDec(nn.Module): """Decode the final cam token into a 9-D pose encoding. Output layout: ``[T(3), quat_xyzw(4), fov_h, fov_w]``. The translation is always predicted by the network; the quaternion and FoV can either be predicted or supplied via ``camera_encoding`` (used at training time when GT cameras are available -- not exercised at inference here). Parameters mirror the upstream ``cam_dec.py`` so HF weights load directly. """ def __init__(self, dim_in: int = 1536, *, device=None, dtype=None, operations=None, **_kwargs): super().__init__() d = dim_in self.backbone = nn.Sequential( operations.Linear(d, d, device=device, dtype=dtype), nn.ReLU(), operations.Linear(d, d, device=device, dtype=dtype), nn.ReLU(), ) self.fc_t = operations.Linear(d, 3, device=device, dtype=dtype) self.fc_qvec = operations.Linear(d, 4, device=device, dtype=dtype) self.fc_fov = nn.Sequential( operations.Linear(d, 2, device=device, dtype=dtype), nn.ReLU(), ) def forward(self, feat: torch.Tensor, camera_encoding: "torch.Tensor | None" = None) -> torch.Tensor: """Decode ``(B, N, dim_in)`` cam tokens into ``(B, N, 9)`` pose enc.""" B, N = feat.shape[:2] feat = feat.reshape(B * N, -1) feat = self.backbone(feat) out_t = self.fc_t(feat.float()).reshape(B, N, 3) if camera_encoding is None: out_qvec = self.fc_qvec(feat.float()).reshape(B, N, 4) out_fov = self.fc_fov(feat.float()).reshape(B, N, 2) else: out_qvec = camera_encoding[..., 3:7] out_fov = camera_encoding[..., -2:] return torch.cat([out_t, out_qvec, out_fov], dim=-1)