Add support for pose estimation CORE-135.

This commit is contained in:
Talmaj Marinc 2026-05-12 12:14:57 +02:00
parent 4ad749ab17
commit b911c99fda
10 changed files with 1254 additions and 54 deletions

View File

@ -6,6 +6,10 @@ import torch.nn.functional as F
from comfy.text_encoders.bert import BertAttention
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.ldm.depth_anything_3.reference_view_selector import (
select_reference_view, reorder_by_reference, restore_original_order,
THRESH_FOR_REF_SELECTION,
)
class Dino2AttentionOutput(torch.nn.Module):
@ -262,19 +266,24 @@ class Dino2Embeddings(torch.nn.Module):
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
ph = h // self.patch_size # patch grid height
pw = w // self.patch_size # patch grid width
M = int(math.sqrt(N))
assert N == M * M
# Historical 0.1 offset preserves bicubic resample compatibility with
# the original DINOv2 release; see the upstream PR for context.
sx = float(w0 + 0.1) / M
sy = float(h0 + 0.1) / M
# ``scale_factor`` is interpreted as (height_scale, width_scale) by
# ``F.interpolate`` so we must put the height scale FIRST. Earlier
# revisions of this function had it swapped which only worked for
# square inputs (e.g. CLIP-vision square crops); non-square inputs
# like DA3-Small / DA3-Base multi-view paths exposed the bug.
sh = float(ph + 0.1) / M
sw = float(pw + 0.1) / M
patch_pos_embed = F.interpolate(
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
scale_factor=(sx, sy), mode="bicubic", antialias=False,
scale_factor=(sh, sw), mode="bicubic", antialias=False,
)
assert (w0, h0) == patch_pos_embed.shape[-2:]
assert (ph, pw) == patch_pos_embed.shape[-2:]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
@ -392,7 +401,9 @@ class Dinov2Model(torch.nn.Module):
x[:, :, 0] = inj
return x
def get_intermediate_layers(self, pixel_values, out_layers, cam_token=None):
def get_intermediate_layers(self, pixel_values, out_layers, cam_token=None,
ref_view_strategy="saddle_balanced",
export_feat_layers=None):
"""Multi-layer DINOv2 feature extraction used by Depth Anything 3.
Args:
@ -401,13 +412,22 @@ class Dinov2Model(torch.nn.Module):
cam_token: optional ``(B, S, dim)`` camera token to inject at
``alt_start``. If ``None`` and the model has its own
``camera_token`` parameter, that is used.
ref_view_strategy: when ``S >= 3`` and ``cam_token is None``,
pick a reference view via this strategy and move it to
position 0 right before the first alt-attention block.
The original view order is restored on the way out.
export_feat_layers: optional iterable of layer indices whose
local attention outputs to also return as auxiliary
features (``(B, S, N_patch, C)`` after final norm). Used
by the multi-view path to expose intermediate features
to the nested-architecture wrapper.
Returns:
List of ``(patch_tokens, cls_or_cam_token)`` tuples, one per
requested ``out_layers`` entry. ``patch_tokens`` has shape
``(B, S, N_patch, C)`` (or ``(B, S, N_patch, 2*C)`` when the
model was configured with ``cat_token=True``); the second item
has shape ``(B, S, C)``.
``(layer_outputs, aux_outputs)`` where ``layer_outputs`` is a
list of ``(patch_tokens, cls_or_cam_token)`` tuples (one per
``out_layers`` entry) and ``aux_outputs`` is a list of
``(B, S, N_patch, C)`` features for ``export_feat_layers``
(empty list when not requested).
"""
if pixel_values.ndim == 4:
pixel_values = pixel_values.unsqueeze(1)
@ -426,8 +446,12 @@ class Dinov2Model(torch.nn.Module):
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
out_set = set(out_layers)
export_set = set(export_feat_layers) if export_feat_layers else set()
outputs: list[torch.Tensor] = []
aux_outputs: list[torch.Tensor] = []
local_x = x
b_idx = None
for i, blk in enumerate(self.encoder.layer):
apply_rope = self.rope is not None and i >= self.rope_start
@ -435,6 +459,15 @@ class Dinov2Model(torch.nn.Module):
l_pos = pos_local if apply_rope else None
g_pos = pos_global if apply_rope else None
# Reference-view selection threshold: matches the upstream constant
# ``THRESH_FOR_REF_SELECTION = 3``. Skipped when a user-supplied
# cam_token is provided (camera info already pins the geometry).
if (self.alt_start != -1 and i == self.alt_start - 1
and S >= THRESH_FOR_REF_SELECTION and cam_token is None):
b_idx = select_reference_view(x, strategy=ref_view_strategy)
x = reorder_by_reference(x, b_idx)
local_x = reorder_by_reference(local_x, b_idx)
if self.alt_start != -1 and i == self.alt_start:
x = self._inject_camera_token(x, B, S, cam_token)
@ -457,8 +490,18 @@ class Dinov2Model(torch.nn.Module):
out_x = torch.cat([local_x, x], dim=-1)
else:
out_x = x
# Restore original view order on the way out so heads see views
# in the user's expected order.
if b_idx is not None and self.alt_start != -1:
out_x = restore_original_order(out_x, b_idx)
outputs.append(out_x)
if i in export_set:
aux = x
if b_idx is not None and self.alt_start != -1:
aux = restore_original_order(aux, b_idx)
aux_outputs.append(aux)
# Apply final norm. When ``cat_token`` is set, only the right half
# ("global" features) is normalised; the left half is left as-is to
# match the upstream DA3 head signature.
@ -477,4 +520,8 @@ class Dinov2Model(torch.nn.Module):
# Drop cls/cam token from the patch sequence.
normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed]
return list(zip(normed, cls_tokens))
# Final layernorm + drop cls token from auxiliary features too.
aux_normed = [self.layernorm(o)[..., 1 + self.num_register_tokens:, :]
for o in aux_outputs]
return list(zip(normed, cls_tokens)), aux_normed

View File

@ -0,0 +1,214 @@
"""Camera-token encoder and decoder for Depth Anything 3.
* :class:`CameraEnc` takes per-view extrinsics + intrinsics and produces a
per-view camera token that gets injected at the alt-attention boundary
in the DINOv2 backbone (block ``alt_start``).
* :class:`CameraDec` takes the final-layer camera token output by the
backbone and predicts a 9-D pose encoding (translation, quaternion,
field-of-view).
The module/parameter names match the upstream ``cam_enc.py``/``cam_dec.py``
so HF safetensors load directly with no key remapping (the upstream uses
fused QKV linears, which we replicate here).
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from .transform import affine_inverse, extri_intri_to_pose_encoding
# -----------------------------------------------------------------------------
# Building blocks (mirror ``depth_anything_3.model.utils.{attention,block}``)
# -----------------------------------------------------------------------------
class _Mlp(nn.Module):
"""Standard 2-layer MLP with GELU. Matches upstream ``utils.attention.Mlp``."""
def __init__(self, in_features, hidden_features=None, out_features=None,
*, device=None, dtype=None, operations=None):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = operations.Linear(in_features, hidden_features, bias=True,
device=device, dtype=dtype)
self.fc2 = operations.Linear(hidden_features, out_features, bias=True,
device=device, dtype=dtype)
def forward(self, x):
return self.fc2(F.gelu(self.fc1(x)))
class _LayerScale(nn.Module):
"""Per-channel learnable scaling. Matches upstream ``LayerScale``."""
def __init__(self, dim, *, device=None, dtype=None):
super().__init__()
self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
def forward(self, x):
return x * self.gamma.to(dtype=x.dtype, device=x.device)
class _Attention(nn.Module):
"""Self-attention with fused QKV projection.
Mirrors upstream ``utils.attention.Attention``; layout matches the
HF safetensors (``attn.qkv.{weight,bias}`` and ``attn.proj.{weight,bias}``).
"""
def __init__(self, dim, num_heads,
*, device=None, dtype=None, operations=None):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = operations.Linear(dim, dim * 3, bias=True,
device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, bias=True,
device=device, dtype=dtype)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # 3, B, h, N, d
q, k, v = qkv.unbind(0)
out = F.scaled_dot_product_attention(q, k, v)
out = out.transpose(1, 2).reshape(B, N, C)
return self.proj(out)
class _Block(nn.Module):
"""Pre-norm transformer block with LayerScale.
Used by :class:`CameraEnc`. Layout follows upstream ``utils.block.Block``.
"""
def __init__(self, dim, num_heads, mlp_ratio=4, init_values=0.01,
*, device=None, dtype=None, operations=None):
super().__init__()
self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype)
self.attn = _Attention(dim, num_heads,
device=device, dtype=dtype, operations=operations)
self.ls1 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity()
self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype)
self.mlp = _Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
device=device, dtype=dtype, operations=operations)
self.ls2 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity()
def forward(self, x):
x = x + self.ls1(self.attn(self.norm1(x)))
x = x + self.ls2(self.mlp(self.norm2(x)))
return x
# -----------------------------------------------------------------------------
# Camera encoder
# -----------------------------------------------------------------------------
class CameraEnc(nn.Module):
"""Encode per-view (extrinsics, intrinsics) into a camera token.
Maps a 9-D pose-encoding vector through a small MLP up to the backbone's
``embed_dim``, then runs ``trunk_depth`` transformer blocks. The output
has shape ``(B, S, embed_dim)`` and is injected at block ``alt_start``
of the DINOv2 backbone in place of the cls token.
Parameters mirror the upstream ``cam_enc.py`` so HF weights load directly.
"""
def __init__(
self,
dim_out: int = 1024,
dim_in: int = 9,
trunk_depth: int = 4,
target_dim: int = 9,
num_heads: int = 16,
mlp_ratio: int = 4,
init_values: float = 0.01,
*,
device=None, dtype=None, operations=None,
**_kwargs,
):
super().__init__()
self.target_dim = target_dim
self.trunk_depth = trunk_depth
self.trunk = nn.Sequential(*[
_Block(dim_out, num_heads=num_heads, mlp_ratio=mlp_ratio,
init_values=init_values,
device=device, dtype=dtype, operations=operations)
for _ in range(trunk_depth)
])
self.token_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype)
self.trunk_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype)
self.pose_branch = _Mlp(
in_features=dim_in,
hidden_features=dim_out // 2,
out_features=dim_out,
device=device, dtype=dtype, operations=operations,
)
def forward(self, extrinsics: torch.Tensor, intrinsics: torch.Tensor,
image_size_hw) -> torch.Tensor:
"""Encode camera parameters into ``(B, S, dim_out)`` tokens."""
c2ws = affine_inverse(extrinsics)
pose_encoding = extri_intri_to_pose_encoding(c2ws, intrinsics, image_size_hw)
tokens = self.pose_branch(pose_encoding.to(self.pose_branch.fc1.weight.dtype))
tokens = self.token_norm(tokens)
tokens = self.trunk(tokens)
tokens = self.trunk_norm(tokens)
return tokens
# -----------------------------------------------------------------------------
# Camera decoder
# -----------------------------------------------------------------------------
class CameraDec(nn.Module):
"""Decode the final cam token into a 9-D pose encoding.
Output layout: ``[T(3), quat_xyzw(4), fov_h, fov_w]``. The translation is
always predicted by the network; the quaternion and FoV can either be
predicted or supplied via ``camera_encoding`` (used at training time
when GT cameras are available -- not exercised at inference here).
Parameters mirror the upstream ``cam_dec.py`` so HF weights load directly.
"""
def __init__(self, dim_in: int = 1536,
*, device=None, dtype=None, operations=None, **_kwargs):
super().__init__()
d = dim_in
self.backbone = nn.Sequential(
operations.Linear(d, d, device=device, dtype=dtype),
nn.ReLU(),
operations.Linear(d, d, device=device, dtype=dtype),
nn.ReLU(),
)
self.fc_t = operations.Linear(d, 3, device=device, dtype=dtype)
self.fc_qvec = operations.Linear(d, 4, device=device, dtype=dtype)
self.fc_fov = nn.Sequential(
operations.Linear(d, 2, device=device, dtype=dtype),
nn.ReLU(),
)
def forward(self, feat: torch.Tensor,
camera_encoding: "torch.Tensor | None" = None) -> torch.Tensor:
"""Decode ``(B, N, dim_in)`` cam tokens into ``(B, N, 9)`` pose enc."""
B, N = feat.shape[:2]
feat = feat.reshape(B * N, -1)
feat = self.backbone(feat)
out_t = self.fc_t(feat.float()).reshape(B, N, 3)
if camera_encoding is None:
out_qvec = self.fc_qvec(feat.float()).reshape(B, N, 4)
out_fov = self.fc_fov(feat.float()).reshape(B, N, 2)
else:
out_qvec = camera_encoding[..., 3:7]
out_fov = camera_encoding[..., -2:]
return torch.cat([out_t, out_qvec, out_fov], dim=-1)

View File

@ -353,7 +353,10 @@ class DualDPT(nn.Module):
"""Two-head DPT used by DA3-Small / DA3-Base.
The auxiliary "ray" head is constructed so that HF state-dict keys load
cleanly, but its outputs are unused on the monocular path.
cleanly. It is only executed when :attr:`enable_aux` is set on the
instance (typically by ``DepthAnything3Net`` when running multi-view
with ``use_ray_pose=True``); otherwise the monocular path skips it for
speed and the auxiliary submodules sit idle.
"""
def __init__(
@ -382,6 +385,9 @@ class DualDPT(nn.Module):
self.aux_out1_conv_num = aux_out1_conv_num
self.head_main, self.head_aux = head_names
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
# Toggle the auxiliary ray branch at runtime. Default off (mono path).
# ``DepthAnything3Net`` flips this on when running multi-view + ray-pose.
self.enable_aux: bool = False
self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
out_channels = list(out_channels)
@ -489,9 +495,18 @@ class DualDPT(nn.Module):
# Main pyramid (output_conv1 is applied inside the upstream `_fuse`,
# before interpolation -- replicate that order here).
m = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
if self.enable_aux:
a4 = self.scratch.refinenet4_aux(l4_rn, size=l3_rn.shape[2:])
aux_pyr = [a4]
m = self.scratch.refinenet3(m, l3_rn, size=l2_rn.shape[2:])
if self.enable_aux:
aux_pyr.append(self.scratch.refinenet3_aux(aux_pyr[-1], l3_rn, size=l2_rn.shape[2:]))
m = self.scratch.refinenet2(m, l2_rn, size=l1_rn.shape[2:])
if self.enable_aux:
aux_pyr.append(self.scratch.refinenet2_aux(aux_pyr[-1], l2_rn, size=l1_rn.shape[2:]))
m = self.scratch.refinenet1(m, l1_rn)
if self.enable_aux:
aux_pyr.append(self.scratch.refinenet1_aux(aux_pyr[-1], l1_rn))
m = self.scratch.output_conv1(m)
h_out = int(ph * self.patch_size / self.down_ratio)
@ -510,8 +525,25 @@ class DualDPT(nn.Module):
f"{self.head_main}_conf": depth_conf.view(B, S, *depth_conf.shape[1:]),
}
# NOTE: we intentionally do not run the auxiliary "ray" branch — it is
# only needed for pose/ray-conditioned outputs which are out of scope
# for this port. The aux submodules are still built so HF weights load.
if self.enable_aux:
# Auxiliary "ray" head (multi-level inside) -- only the last level
# is returned. Mirrors upstream ``DualDPT._fuse`` + ``_forward_impl``:
# each aux pyramid level goes through ``output_conv1_aux[i]``
# (5-layer conv stack that ends at ``features // 2`` channels),
# then the last level optionally gets a pos-embed and finally
# ``output_conv2_aux[-1]``.
aux_processed = [
self.scratch.output_conv1_aux[i](a) for i, a in enumerate(aux_pyr)
]
last_aux = aux_processed[-1]
if self.pos_embed:
last_aux = _add_pos_embed(last_aux, W, H)
last_aux_logits = self.scratch.output_conv2_aux[-1](last_aux)
fmap_last = last_aux_logits.permute(0, 2, 3, 1)
# Channels: [ray(6), ray_conf(1)]; ray uses 'linear' activation.
aux_pred = fmap_last[..., :-1]
aux_conf = _apply_activation(fmap_last[..., -1], self.conf_activation)
outs[self.head_aux] = aux_pred.view(B, S, *aux_pred.shape[1:])
outs[f"{self.head_aux}_conf"] = aux_conf.view(B, S, *aux_conf.shape[1:])
return outs

View File

@ -1,23 +1,24 @@
# DepthAnything3Net: top-level wrapper that combines backbone + head.
#
# This wrapper covers the monocular forward path only (single image -> depth).
# Camera encoder/decoder, ray-pose head, 3D Gaussians and the Nested
# architecture are intentionally omitted. The HF state dict for those
# components is filtered out before loading -- see
# ``comfy.supported_models.DepthAnything3.process_unet_state_dict``.
# Supports both the monocular and the multi-view + camera path:
#
# The class signature mirrors the upstream YAML config so a single dit_config
# detected from the state dict in ``comfy/model_detection.py`` is sufficient
# to construct the right variant.
# * Monocular: ``S = 1``, no camera encoder/decoder. Mirrors the original
# port that only handled ``DA3-MONO/METRIC-LARGE`` and the auxiliary-disabled
# ``DA3-SMALL/BASE`` configs.
# * Multi-view + camera: ``S > 1``. ``cam_enc`` (optional) maps user-supplied
# extrinsics + intrinsics into a per-view camera token; ``cam_dec`` decodes
# the final layer's camera token into a 9-D pose encoding. When the
# auxiliary "ray" head of ``DualDPT`` is enabled the predicted ray map can
# alternatively be used to estimate pose via RANSAC (``use_ray_pose=True``).
# The 3D-Gaussian head and the nested-architecture wrapper are intentionally
# left out of scope here; their state-dict keys are filtered in
# ``comfy.supported_models.DepthAnything3.process_unet_state_dict``.
#
# Backbone: ``comfy.image_encoders.dino2.Dinov2Model`` is shared with the
# CLIP-vision DINOv2 path. DA3-specific extensions (RoPE, QK-norm,
# alternating local/global attention, camera token, multi-layer feature
# extraction, pos-embed interpolation) are opt-in via the config dict and are
# all disabled for the Mono/Metric variants. The upstream DA3 weight layout
# (``backbone.pretrained.*`` with fused QKV) is converted to the
# ``Dinov2Model`` layout in
# ``comfy.supported_models.DepthAnything3.process_unet_state_dict``.
# The backbone is shared with the CLIP-vision DINOv2 path
# (``comfy.image_encoders.dino2.Dinov2Model``); the DA3-specific extensions
# (RoPE, QK-norm, alternating local/global attention, camera token, multi-
# layer feature extraction, reference-view reordering) are opt-in via the
# config dict and are all disabled for the Mono/Metric variants.
from __future__ import annotations
@ -28,7 +29,10 @@ import torch.nn as nn
from comfy.image_encoders.dino2 import Dinov2Model
from .camera import CameraDec, CameraEnc
from .dpt import DPT, DualDPT
from .ray_pose import get_extrinsic_from_camray
from .transform import affine_inverse, pose_encoding_to_extri_intri
_HEAD_REGISTRY = {
@ -74,11 +78,11 @@ def _build_backbone_config(
class DepthAnything3Net(nn.Module):
"""ComfyUI-side DepthAnything3 network (monocular path only).
"""ComfyUI-side DepthAnything3 network.
Parameters mirror the variant YAML configs from the upstream repo.
Values are auto-detected by ``comfy/model_detection.py`` from the state
dict. The kwargs ``device``, ``dtype`` and ``operations`` are injected by
Parameters mirror the variant YAML configs from the upstream repo and
are auto-detected from the state dict by ``comfy/model_detection.py``.
The kwargs ``device``, ``dtype`` and ``operations`` are injected by
``BaseModel``.
"""
@ -101,6 +105,11 @@ class DepthAnything3Net(nn.Module):
head_out_channels: Sequence[int] = (256, 512, 1024, 1024),
head_use_sky_head: bool = True, # ignored by DualDPT
head_pos_embed: Optional[bool] = None, # default: True for DualDPT, False for DPT
# --- Camera (multi-view) ---
has_cam_enc: bool = False,
has_cam_dec: bool = False,
cam_dim_out: Optional[int] = None, # CameraEnc dim_out (defaults to embed_dim)
cam_dec_dim_in: Optional[int] = None, # CameraDec dim_in (defaults to 2*embed_dim with cat_token)
# ComfyUI plumbing
device=None, dtype=None, operations=None,
**_ignored,
@ -139,25 +148,82 @@ class DepthAnything3Net(nn.Module):
pos_embed=(True if head_pos_embed is None else head_pos_embed),
)
self.head = head_cls(**head_kwargs)
# Camera encoder / decoder are only constructed when their weights are
# present in the checkpoint; the multi-view / pose forward path becomes
# available accordingly. ``cam_enc.dim_out`` matches the backbone's
# ``embed_dim`` so the cam token slots into block ``alt_start``.
embed_dim = backbone_cfg["hidden_size"]
if has_cam_enc:
self.cam_enc = CameraEnc(
dim_out=cam_dim_out if cam_dim_out is not None else embed_dim,
num_heads=max(1, embed_dim // 64),
device=device, dtype=dtype, operations=operations,
)
else:
self.cam_enc = None
if has_cam_dec:
# Default cam_dec dim_in is 2*embed_dim when cat_token is on
# (the cls/cam token in the output is the cat'd version).
default_dim = embed_dim * (2 if cat_token else 1)
self.cam_dec = CameraDec(
dim_in=cam_dec_dim_in if cam_dec_dim_in is not None else default_dim,
device=device, dtype=dtype, operations=operations,
)
else:
self.cam_dec = None
self.dtype = dtype
# ------------------------------------------------------------------
# Forward
# ------------------------------------------------------------------
def forward(self, image: torch.Tensor, **_unused) -> Dict[str, torch.Tensor]:
"""Run monocular forward.
def forward(
self,
image: torch.Tensor,
extrinsics: Optional[torch.Tensor] = None,
intrinsics: Optional[torch.Tensor] = None,
*,
use_ray_pose: bool = False,
ref_view_strategy: str = "saddle_balanced",
export_feat_layers: Optional[Sequence[int]] = None,
**_unused,
) -> Dict[str, torch.Tensor]:
"""Run depth (and optionally pose) prediction.
Args:
image: ``(B, 3, H, W)`` ImageNet-normalised image tensor, or
``(B, S, 3, H, W)`` if a fake "views" axis is supplied.
H and W must be multiples of 14.
``(B, S, 3, H, W)`` for multi-view inputs. ``H`` and ``W``
must be multiples of 14.
extrinsics: optional ``(B, S, 4, 4)`` world-to-camera extrinsics.
When provided together with ``intrinsics``, ``CameraEnc``
converts them into per-view camera tokens that the backbone
injects at block ``alt_start``.
intrinsics: optional ``(B, S, 3, 3)`` pixel-space intrinsics.
use_ray_pose: if True, predict pose from the auxiliary "ray" head
(RANSAC over per-pixel rays). Only available on DualDPT
variants. If False (default) and ``cam_dec`` is present,
the final-layer cam token is decoded into pose instead.
ref_view_strategy: reference-view selection strategy used when
``S >= 3`` and no extrinsics are supplied. See
:mod:`comfy.ldm.depth_anything_3.reference_view_selector`.
export_feat_layers: optional list of backbone layer indices whose
local features to also return as auxiliary outputs (used by
downstream nested-architecture wrappers; empty by default).
Returns:
Dict with:
- ``depth``: ``(B, H, W)`` raw depth values.
- ``depth_conf``: ``(B, H, W)`` confidence (DualDPT variants only).
- ``sky``: ``(B, H, W)`` sky probability/logit
(DPT variants only).
Dict with a subset of:
- ``depth`` ``(B*S, H, W)`` raw depth values.
- ``depth_conf`` ``(B*S, H, W)`` confidence (DualDPT only).
- ``sky`` ``(B*S, H, W)`` sky probability (DPT + sky head).
- ``ray`` ``(B, S, h, w, 6)`` per-pixel cam ray (DualDPT,
multi-view, ``use_ray_pose=True`` only).
- ``ray_conf`` ``(B, S, h, w)`` ray confidence.
- ``extrinsics`` ``(B, S, 4, 4)`` world-to-cam, when pose
prediction is active.
- ``intrinsics`` ``(B, S, 3, 3)`` pixel-space intrinsics.
- ``aux_features`` list of ``(B, S, h_p, w_p, C)`` features
when ``export_feat_layers`` is non-empty.
"""
if image.ndim == 4:
image = image.unsqueeze(1) # (B, 1, 3, H, W)
@ -168,14 +234,76 @@ class DepthAnything3Net(nn.Module):
assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \
f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}"
feats = self.backbone.get_intermediate_layers(image, self.out_layers)
# Camera-token preparation (multi-view path).
cam_token = None
if extrinsics is not None and intrinsics is not None and self.cam_enc is not None:
cam_token = self.cam_enc(extrinsics, intrinsics, (H, W))
# Toggle aux ray output on/off depending on what the caller asked for.
if isinstance(self.head, DualDPT):
self.head.enable_aux = bool(use_ray_pose)
feats, aux_feats = self.backbone.get_intermediate_layers(
image, self.out_layers, cam_token=cam_token,
ref_view_strategy=ref_view_strategy,
export_feat_layers=export_feat_layers,
)
head_out = self.head(feats, H=H, W=W, patch_start_idx=0)
# Flatten the views axis (S=1 in mono inference path).
# Pose prediction.
out: Dict[str, torch.Tensor] = {}
if use_ray_pose and "ray" in head_out and "ray_conf" in head_out:
ray = head_out["ray"]
ray_conf = head_out["ray_conf"]
extr_c2w, focal, pp = get_extrinsic_from_camray(
ray, ray_conf, ray.shape[-3], ray.shape[-2],
)
# Match the upstream output: w2c, drop the homogeneous row.
extr_w2c = affine_inverse(extr_c2w)[:, :, :3, :]
# Build pixel-space intrinsics from the normalised focal/pp output.
intr = torch.eye(3, device=ray.device, dtype=ray.dtype)
intr = intr[None, None].expand(extr_c2w.shape[0], extr_c2w.shape[1], 3, 3).clone()
intr[:, :, 0, 0] = focal[:, :, 0] / 2 * W
intr[:, :, 1, 1] = focal[:, :, 1] / 2 * H
intr[:, :, 0, 2] = pp[:, :, 0] * W * 0.5
intr[:, :, 1, 2] = pp[:, :, 1] * H * 0.5
out["extrinsics"] = extr_w2c
out["intrinsics"] = intr
elif self.cam_dec is not None and S > 1:
# Decode the cam-token of the final out_layer into a pose encoding.
cam_feat = feats[-1][1] # (B, S, dim_in_to_cam_dec)
pose_enc = self.cam_dec(cam_feat)
c2w_3x4, intr = pose_encoding_to_extri_intri(pose_enc, (H, W))
# Match the upstream output convention: w2c (world->camera), 3x4.
c2w_4x4 = torch.cat([
c2w_3x4,
torch.tensor([0, 0, 0, 1], device=c2w_3x4.device, dtype=c2w_3x4.dtype)
.view(1, 1, 1, 4).expand(B, S, 1, 4),
], dim=-2)
out["extrinsics"] = affine_inverse(c2w_4x4)[:, :, :3, :]
out["intrinsics"] = intr
# Flatten the views axis for per-pixel outputs (depth/conf/sky) so the
# per-image consumer keeps its (B*S, H, W) interface.
for k, v in head_out.items():
if v.ndim >= 3 and v.shape[0] == B and v.shape[1] == S:
if k in ("ray", "ray_conf"):
# Keep multi-view shape for downstream pose work.
out[k] = v
elif v.ndim >= 3 and v.shape[0] == B and v.shape[1] == S:
out[k] = v.reshape(B * S, *v.shape[2:])
else:
out[k] = v
if export_feat_layers:
out["aux_features"] = self._reshape_aux_features(aux_feats, H, W)
return out
def _reshape_aux_features(self, aux_feats, H: int, W: int):
"""Reshape ``(B, S, N, C)`` aux features into ``(B, S, h_p, w_p, C)``."""
ph, pw = H // self.PATCH_SIZE, W // self.PATCH_SIZE
out = []
for f in aux_feats:
B, S, N, C = f.shape
assert N == ph * pw, f"aux feature seq mismatch: {N} != {ph}*{pw}"
out.append(f.reshape(B, S, ph, pw, C))
return out

View File

@ -0,0 +1,312 @@
"""Ray-to-pose conversion for the multi-view path of Depth Anything 3.
Converts the auxiliary "ray" output of :class:`DualDPT` (per-pixel camera
ray vectors, predicted on the per-view local feature map) into per-view
extrinsics + intrinsics. Implementation is a 1:1 port of
``depth_anything_3.utils.ray_utils`` upstream, using a weighted-RANSAC
homography fit followed by a QL decomposition.
No learned parameters; pure tensor math. Output:
* ``R`` -- ``(B, S, 3, 3)`` rotation matrix
* ``T`` -- ``(B, S, 3)`` camera-space translation
* ``focal_lengths`` -- ``(B, S, 2)`` in normalised image space (image=2x2)
* ``principal_points`` -- ``(B, S, 2)`` ditto
:func:`get_extrinsic_from_camray` wraps these into a 4x4 extrinsic matrix
that the public node converts back into pixel-space intrinsics.
"""
from __future__ import annotations
from typing import Optional, Tuple
import torch
# -----------------------------------------------------------------------------
# Linear-algebra helpers
# -----------------------------------------------------------------------------
def _ql_decomposition(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Decompose ``A = Q @ L`` with ``Q`` orthogonal and ``L`` lower-triangular.
Implemented in terms of QR by reversing the columns/rows; the standard
trick from the upstream reference. Inputs ``A`` are ``(3, 3)``.
"""
P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]],
device=A.device, dtype=A.dtype)
A_tilde = A @ P
Q_tilde, R_tilde = torch.linalg.qr(A_tilde)
Q = Q_tilde @ P
L = P @ R_tilde @ P
d = torch.diag(L)
sign = torch.sign(d)
Q = Q * sign[None, :] # scale columns of Q
L = L * sign[:, None] # scale rows of L
return Q, L
def _homogenize_points(points: torch.Tensor) -> torch.Tensor:
return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
# -----------------------------------------------------------------------------
# Weighted-LSQ + RANSAC homography (batched)
# -----------------------------------------------------------------------------
def _find_homography_weighted_lsq(
src_pts: torch.Tensor,
dst_pts: torch.Tensor,
confident_weight: torch.Tensor,
) -> torch.Tensor:
"""Solve a single ``H`` with weighted least-squares (DLT)."""
N = src_pts.shape[0]
if N < 4:
raise ValueError("At least 4 points are required to compute a homography.")
w = confident_weight.sqrt().unsqueeze(1) # (N, 1)
x = src_pts[:, 0:1]
y = src_pts[:, 1:2]
u = dst_pts[:, 0:1]
v = dst_pts[:, 1:2]
zeros = torch.zeros_like(x)
A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=1)
A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=1)
A = torch.cat([A1, A2], dim=0) # (2N, 9)
_, _, Vh = torch.linalg.svd(A)
H = Vh[-1].reshape(3, 3)
return H / H[-1, -1]
def _find_homography_weighted_lsq_batched(
src_pts_batch: torch.Tensor,
dst_pts_batch: torch.Tensor,
confident_weight_batch: torch.Tensor,
) -> torch.Tensor:
"""Batched DLT solver. Inputs ``(B, K, 2)`` / ``(B, K)``; output ``(B, 3, 3)``."""
B, K, _ = src_pts_batch.shape
w = confident_weight_batch.sqrt().unsqueeze(2)
x = src_pts_batch[:, :, 0:1]
y = src_pts_batch[:, :, 1:2]
u = dst_pts_batch[:, :, 0:1]
v = dst_pts_batch[:, :, 1:2]
zeros = torch.zeros_like(x)
A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=2)
A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=2)
A = torch.cat([A1, A2], dim=1) # (B, 2K, 9)
_, _, Vh = torch.linalg.svd(A)
H = Vh[:, -1].reshape(B, 3, 3)
return H / H[:, 2:3, 2:3]
def _ransac_find_homography_weighted_batched(
src_pts: torch.Tensor, # (B, N, 2)
dst_pts: torch.Tensor, # (B, N, 2)
confident_weight: torch.Tensor, # (B, N)
n_sample: int,
n_iter: int = 100,
reproj_threshold: float = 3.0,
num_sample_for_ransac: int = 8,
max_inlier_num: int = 10000,
rand_sample_iters_idx: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Batched weighted-RANSAC homography estimator.
Returns ``(B, 3, 3)`` homography matrices.
"""
B, N, _ = src_pts.shape
assert N >= 4
device = src_pts.device
sorted_idx = torch.argsort(confident_weight, descending=True, dim=1)
candidate_idx = sorted_idx[:, :n_sample] # (B, n_sample)
if rand_sample_iters_idx is None:
rand_sample_iters_idx = torch.stack(
[torch.randperm(n_sample, device=device)[:num_sample_for_ransac]
for _ in range(n_iter)],
dim=0,
)
rand_idx = candidate_idx[:, rand_sample_iters_idx] # (B, n_iter, k)
b_idx = (
torch.arange(B, device=device)
.view(B, 1, 1)
.expand(B, n_iter, num_sample_for_ransac)
)
src_b = src_pts[b_idx, rand_idx]
dst_b = dst_pts[b_idx, rand_idx]
w_b = confident_weight[b_idx, rand_idx]
cB, cN = src_b.shape[:2]
H_batch = _find_homography_weighted_lsq_batched(
src_b.flatten(0, 1), dst_b.flatten(0, 1), w_b.flatten(0, 1),
).unflatten(0, (cB, cN)) # (B, n_iter, 3, 3)
src_homo = torch.cat([src_pts, torch.ones(B, N, 1, device=device, dtype=src_pts.dtype)], dim=2)
proj = torch.bmm(
src_homo.unsqueeze(1).expand(B, n_iter, N, 3).reshape(-1, N, 3),
H_batch.reshape(-1, 3, 3).transpose(1, 2),
) # (B*n_iter, N, 3)
proj_xy = (proj[:, :, :2] / proj[:, :, 2:3]).reshape(B, n_iter, N, 2)
err = ((proj_xy - dst_pts.unsqueeze(1)) ** 2).sum(-1).sqrt() # (B, n_iter, N)
inlier_mask = err < reproj_threshold
score = (inlier_mask * confident_weight.unsqueeze(1)).sum(dim=2)
best_idx = torch.argmax(score, dim=1)
best_inlier_mask = inlier_mask[torch.arange(B, device=device), best_idx]
# Refit with the inlier set (per-batch, since the inlier counts vary).
H_inlier_list = []
for b in range(B):
mask = best_inlier_mask[b]
in_src = src_pts[b][mask]
in_dst = dst_pts[b][mask]
in_w = confident_weight[b][mask]
if in_src.shape[0] < 4:
# Fall back to identity when RANSAC fails to find enough inliers.
H_inlier_list.append(torch.eye(3, device=device, dtype=src_pts.dtype))
continue
sorted_w = torch.argsort(in_w, descending=True)
if len(sorted_w) > max_inlier_num:
keep = max(int(len(sorted_w) * 0.95), max_inlier_num)
sorted_w = sorted_w[:keep][torch.randperm(keep, device=device)[:max_inlier_num]]
H_inlier_list.append(
_find_homography_weighted_lsq(in_src[sorted_w], in_dst[sorted_w], in_w[sorted_w])
)
return torch.stack(H_inlier_list, dim=0)
# -----------------------------------------------------------------------------
# Camera-ray utilities
# -----------------------------------------------------------------------------
def _unproject_identity(num_y: int, num_x: int, B: int, S: int,
device, dtype) -> torch.Tensor:
"""Camera-space unit rays for an identity intrinsic on a 2x2 image plane.
Replicates ``unproject_depth(..., ixt_normalized=True)`` upstream: pixel
coords ``(x, y)`` in ``[dx, 2-dx] x [dy, 2-dy]`` get mapped to
camera-space rays ``(x-1, y-1, 1)`` via the identity intrinsic
``[[1,0,1],[0,1,1],[0,0,1]]``. Returns ``(B, S, num_y, num_x, 3)``.
"""
dx = 1.0 / num_x
dy = 1.0 / num_y
# Centered camera-space coords directly (skip the K^-1 step since it's
# just a translation by -1 on x and y when K is identity-with-center=1).
y = torch.linspace(-(1 - dy), (1 - dy), num_y, device=device, dtype=dtype)
x = torch.linspace(-(1 - dx), (1 - dx), num_x, device=device, dtype=dtype)
yy, xx = torch.meshgrid(y, x, indexing="ij")
grid = torch.stack((xx, yy), dim=-1) # (h, w, 2)
grid = grid.unsqueeze(0).unsqueeze(0).expand(B, S, num_y, num_x, 2)
return torch.cat([grid, torch.ones_like(grid[..., :1])], dim=-1)
def _camray_to_caminfo(
camray: torch.Tensor, # (B, S, h, w, 6)
confidence: Optional[torch.Tensor] = None, # (B, S, h, w)
reproj_threshold: float = 0.2,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Convert per-pixel camera rays to per-view (R, T, focal, principal)."""
if confidence is None:
confidence = torch.ones_like(camray[..., 0])
B, S, h, w, _ = camray.shape
device = camray.device
dtype = camray.dtype
rays_target = camray[..., :3] # (B, S, h, w, 3)
rays_origin = _unproject_identity(h, w, B, S, device, dtype)
# Flatten (B*S, h*w, *) for the RANSAC routine.
rays_target = rays_target.flatten(0, 1).flatten(1, 2)
rays_origin = rays_origin.flatten(0, 1).flatten(1, 2)
weights = confidence.flatten(0, 1).flatten(1, 2).clone()
# Project to 2D in homogeneous form (the upstream calls this "perspective division").
z_thresh = 1e-4
mask = (rays_target[:, :, 2].abs() > z_thresh) & (rays_origin[:, :, 2].abs() > z_thresh)
weights = torch.where(mask, weights, torch.zeros_like(weights))
src = rays_origin.clone()
dst = rays_target.clone()
src[..., 0] = torch.where(mask, src[..., 0] / src[..., 2], src[..., 0])
src[..., 1] = torch.where(mask, src[..., 1] / src[..., 2], src[..., 1])
dst[..., 0] = torch.where(mask, dst[..., 0] / dst[..., 2], dst[..., 0])
dst[..., 1] = torch.where(mask, dst[..., 1] / dst[..., 2], dst[..., 1])
src = src[..., :2]
dst = dst[..., :2]
N = src.shape[1]
n_iter = 100
sample_ratio = 0.3
num_sample_for_ransac = 8
n_sample = max(num_sample_for_ransac, int(N * sample_ratio))
rand_idx = torch.stack(
[torch.randperm(n_sample, device=device)[:num_sample_for_ransac] for _ in range(n_iter)],
dim=0,
)
# Chunk along the view axis to keep peak memory predictable.
chunk = 2
A_list = []
for i in range(0, src.shape[0], chunk):
A = _ransac_find_homography_weighted_batched(
src[i:i + chunk], dst[i:i + chunk], weights[i:i + chunk],
n_sample=n_sample, n_iter=n_iter,
num_sample_for_ransac=num_sample_for_ransac,
reproj_threshold=reproj_threshold,
rand_sample_iters_idx=rand_idx,
max_inlier_num=8000,
)
# Flip sign on dets that come out < 0 (so that the QL produces a
# right-handed rotation).
flip = torch.linalg.det(A) < 0
A = torch.where(flip[:, None, None], -A, A)
A_list.append(A)
A = torch.cat(A_list, dim=0) # (B*S, 3, 3)
R_list, f_list, pp_list = [], [], []
for i in range(A.shape[0]):
R, L = _ql_decomposition(A[i])
L = L / L[2][2]
f_list.append(torch.stack((L[0][0], L[1][1])))
pp_list.append(torch.stack((L[2][0], L[2][1])))
R_list.append(R)
R = torch.stack(R_list).reshape(B, S, 3, 3)
focal = torch.stack(f_list).reshape(B, S, 2)
pp = torch.stack(pp_list).reshape(B, S, 2)
# Translation: confidence-weighted average of camray direction(s).
cf = confidence.flatten(0, 1).flatten(1, 2)
T = (camray.flatten(0, 1).flatten(1, 2)[..., 3:] * cf.unsqueeze(-1)).sum(dim=1)
T = T / cf.sum(dim=-1, keepdim=True)
T = T.reshape(B, S, 3)
# Match upstream output convention: focal -> 1/focal, pp + 1.
return R, T, 1.0 / focal, pp + 1.0
def get_extrinsic_from_camray(
camray: torch.Tensor, # (B, S, h, w, 6)
conf: torch.Tensor, # (B, S, h, w, 1) or (B, S, h, w)
patch_size_y: int,
patch_size_x: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Wrap a 4x4 extrinsic + per-view focal + principal-point output.
Returns:
* extrinsic ``(B, S, 4, 4)`` camera-to-world (the inverse is
what gets stored in ``output.extrinsics``
by the caller).
* focals ``(B, S, 2)`` in normalised image space.
* pp ``(B, S, 2)`` in normalised image space.
"""
if conf.ndim == 5 and conf.shape[-1] == 1:
conf = conf.squeeze(-1)
R, T, focal, pp = _camray_to_caminfo(camray, confidence=conf)
extr = torch.cat([R, T.unsqueeze(-1)], dim=-1) # (B, S, 3, 4)
homo_row = torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device)
homo_row = homo_row.view(1, 1, 1, 4).expand(R.shape[0], R.shape[1], 1, 4)
extr = torch.cat([extr, homo_row], dim=-2) # (B, S, 4, 4)
return extr, focal, pp

View File

@ -0,0 +1,116 @@
"""Reference-view selection for the multi-view path of Depth Anything 3.
Pure tensor math, no learned parameters. Exposed as three free functions:
* :func:`select_reference_view` -- pick a reference view per batch.
* :func:`reorder_by_reference` -- move the reference view to position 0.
* :func:`restore_original_order` -- inverse of :func:`reorder_by_reference`.
Mirrors ``depth_anything_3.model.reference_view_selector`` upstream.
The default strategy (``"saddle_balanced"``) selects the view whose CLS
token features are closest to the median across multiple metrics.
"""
from __future__ import annotations
from typing import Literal
import torch
RefViewStrategy = Literal["first", "middle", "saddle_balanced", "saddle_sim_range"]
# Per the upstream constants module: ``THRESH_FOR_REF_SELECTION = 3``.
# Reference selection only runs when there are at least this many views.
THRESH_FOR_REF_SELECTION: int = 3
def select_reference_view(
x: torch.Tensor,
strategy: RefViewStrategy = "saddle_balanced",
) -> torch.Tensor:
"""Pick a reference view index per batch element.
Args:
x: ``(B, S, N, C)`` token tensor. Index 0 along ``N`` is the
cls/cam token used by the feature-based strategies.
strategy: One of ``"first" | "middle" | "saddle_balanced" |
"saddle_sim_range"``.
Returns:
``(B,)`` long tensor with the chosen reference view index for
each batch element.
"""
B, S, _, _ = x.shape
if S <= 1:
return torch.zeros(B, dtype=torch.long, device=x.device)
if strategy == "first":
return torch.zeros(B, dtype=torch.long, device=x.device)
if strategy == "middle":
return torch.full((B,), S // 2, dtype=torch.long, device=x.device)
# Feature-based strategies: normalised cls/cam token per view.
img_class_feat = x[:, :, 0] / x[:, :, 0].norm(dim=-1, keepdim=True) # (B,S,C)
if strategy == "saddle_balanced":
sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2)) # (B,S,S)
sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
sim_score = sim_no_diag.sum(dim=-1) / (S - 1) # (B,S)
feat_norm = x[:, :, 0].norm(dim=-1) # (B,S)
feat_var = img_class_feat.var(dim=-1) # (B,S)
def _normalize(metric):
mn = metric.min(dim=1, keepdim=True).values
mx = metric.max(dim=1, keepdim=True).values
return (metric - mn) / (mx - mn + 1e-8)
sim_n, norm_n, var_n = _normalize(sim_score), _normalize(feat_norm), _normalize(feat_var)
balance = (sim_n - 0.5).abs() + (norm_n - 0.5).abs() + (var_n - 0.5).abs()
return balance.argmin(dim=1)
if strategy == "saddle_sim_range":
sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2))
sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
sim_max = sim_no_diag.max(dim=-1).values
sim_min = sim_no_diag.min(dim=-1).values
return (sim_max - sim_min).argmax(dim=1)
raise ValueError(
f"Unknown reference view selection strategy: {strategy!r}. "
f"Must be one of: 'first', 'middle', 'saddle_balanced', 'saddle_sim_range'"
)
def reorder_by_reference(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
"""Reorder ``x`` so the reference view is at position 0 in axis ``S``."""
B, S = x.shape[0], x.shape[1]
if S <= 1:
return x
positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
b_idx_exp = b_idx.unsqueeze(1)
reorder = torch.where(
(positions > 0) & (positions <= b_idx_exp),
positions - 1,
positions,
)
reorder[:, 0] = b_idx
batch = torch.arange(B, device=x.device).unsqueeze(1)
return x[batch, reorder]
def restore_original_order(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
"""Inverse of :func:`reorder_by_reference`."""
B, S = x.shape[0], x.shape[1]
if S <= 1:
return x
target_positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
b_idx_exp = b_idx.unsqueeze(1)
restore = torch.where(target_positions < b_idx_exp,
target_positions + 1,
target_positions)
restore = torch.scatter(
restore, dim=1, index=b_idx_exp, src=torch.zeros_like(b_idx_exp),
)
batch = torch.arange(B, device=x.device).unsqueeze(1)
return x[batch, restore]

View File

@ -0,0 +1,180 @@
"""Geometry / camera transform helpers for Depth Anything 3.
Pure tensor math, no learned parameters. Mirrors the upstream upstream
``depth_anything_3.model.utils.transform`` and the parts of
``depth_anything_3.utils.geometry`` used at inference time on the
multi-view + camera path. Kept self-contained so the DA3 module is fully
ported and does not depend on the upstream repo at runtime.
"""
from __future__ import annotations
from typing import Tuple
import torch
import torch.nn.functional as F
# -----------------------------------------------------------------------------
# Affine 4x4 helpers
# -----------------------------------------------------------------------------
def as_homogeneous(ext: torch.Tensor) -> torch.Tensor:
"""Promote ``(...,3,4)`` extrinsics to ``(...,4,4)`` homogeneous form.
A no-op when the input is already ``(...,4,4)``.
"""
if ext.shape[-2:] == (4, 4):
return ext
if ext.shape[-2:] == (3, 4):
ones = torch.zeros_like(ext[..., :1, :4])
ones[..., 0, 3] = 1.0
return torch.cat([ext, ones], dim=-2)
raise ValueError(f"Invalid affine shape: {ext.shape}")
def affine_inverse(A: torch.Tensor) -> torch.Tensor:
"""Inverse of an affine matrix ``[R|T; 0 0 0 1]``."""
R = A[..., :3, :3]
T = A[..., :3, 3:]
P = A[..., 3:, :]
return torch.cat([torch.cat([R.mT, -R.mT @ T], dim=-1), P], dim=-2)
# -----------------------------------------------------------------------------
# Quaternion <-> rotation matrix (xyzw / scalar-last)
# -----------------------------------------------------------------------------
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""``sqrt(max(0, x))`` with a zero subgradient where ``x == 0``."""
ret = torch.zeros_like(x)
positive_mask = x > 0
if torch.is_grad_enabled():
ret[positive_mask] = torch.sqrt(x[positive_mask])
else:
ret = torch.where(positive_mask, torch.sqrt(x), ret)
return ret
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
"""Force the real part of a unit quaternion (xyzw) to be non-negative."""
return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
"""Convert quaternions (xyzw) to ``(...,3,3)`` rotation matrices."""
i, j, k, r = torch.unbind(quaternions, -1)
two_s = 2.0 / (quaternions * quaternions).sum(-1)
o = torch.stack(
(
1 - two_s * (j * j + k * k),
two_s * (i * j - k * r),
two_s * (i * k + j * r),
two_s * (i * j + k * r),
1 - two_s * (i * i + k * k),
two_s * (j * k - i * r),
two_s * (i * k - j * r),
two_s * (j * k + i * r),
1 - two_s * (i * i + j * j),
),
-1,
)
return o.reshape(quaternions.shape[:-1] + (3, 3))
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
"""Convert ``(...,3,3)`` rotation matrices to quaternions (xyzw)."""
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
batch_dim = matrix.shape[:-2]
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
matrix.reshape(batch_dim + (9,)), dim=-1
)
q_abs = _sqrt_positive_part(
torch.stack(
[
1.0 + m00 + m11 + m22,
1.0 + m00 - m11 - m22,
1.0 - m00 + m11 - m22,
1.0 - m00 - m11 + m22,
],
dim=-1,
)
)
quat_by_rijk = torch.stack(
[
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
],
dim=-2,
)
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(
batch_dim + (4,)
)
# Reorder rijk -> xyzw (i.e. ijkr).
out = out[..., [1, 2, 3, 0]]
return standardize_quaternion(out)
# -----------------------------------------------------------------------------
# Pose-encoding <-> extrinsics + intrinsics
# -----------------------------------------------------------------------------
def extri_intri_to_pose_encoding(
extrinsics: torch.Tensor,
intrinsics: torch.Tensor,
image_size_hw: Tuple[int, int],
) -> torch.Tensor:
"""Pack ``(extr, intr, image_size)`` into the 9-D pose-encoding vector.
``extrinsics`` are camera-to-world (c2w) ``(B,S,4,4)`` matrices,
``intrinsics`` are pixel-space ``(B,S,3,3)`` matrices, ``image_size_hw``
is a ``(H, W)`` pair. The encoding is ``[T(3), quat_xyzw(4), fov_h, fov_w]``.
"""
R = extrinsics[..., :3, :3]
T = extrinsics[..., :3, 3]
quat = mat_to_quat(R)
H, W = image_size_hw
fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
return torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
def pose_encoding_to_extri_intri(
pose_encoding: torch.Tensor,
image_size_hw: Tuple[int, int],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Inverse of :func:`extri_intri_to_pose_encoding`.
Returns a ``(B,S,3,4)`` c2w extrinsic matrix and a ``(B,S,3,3)``
pixel-space intrinsic matrix.
"""
T = pose_encoding[..., :3]
quat = pose_encoding[..., 3:7]
fov_h = pose_encoding[..., 7]
fov_w = pose_encoding[..., 8]
R = quat_to_mat(quat)
extrinsics = torch.cat([R, T[..., None]], dim=-1)
H, W = image_size_hw
fy = (H / 2.0) / torch.clamp(torch.tan(fov_h / 2.0), 1e-6)
fx = (W / 2.0) / torch.clamp(torch.tan(fov_w / 2.0), 1e-6)
intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3),
device=pose_encoding.device, dtype=pose_encoding.dtype)
intrinsics[..., 0, 0] = fx
intrinsics[..., 1, 1] = fy
intrinsics[..., 0, 2] = W / 2
intrinsics[..., 1, 2] = H / 2
intrinsics[..., 2, 2] = 1.0
return extrinsics, intrinsics

View File

@ -848,6 +848,24 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
# vits/vitb: 12 blocks
dit_config["out_layers"] = [5, 7, 9, 11]
# Camera encoder/decoder presence (multi-view + pose path).
has_cam_enc = '{}cam_enc.token_norm.weight'.format(key_prefix) in state_dict_keys
has_cam_dec = '{}cam_dec.fc_t.weight'.format(key_prefix) in state_dict_keys
dit_config["has_cam_enc"] = has_cam_enc
dit_config["has_cam_dec"] = has_cam_dec
if has_cam_enc:
cam_enc_w = state_dict.get(
'{}cam_enc.pose_branch.fc2.weight'.format(key_prefix)
)
if cam_enc_w is not None:
dit_config["cam_dim_out"] = cam_enc_w.shape[0]
if has_cam_dec:
cam_dec_w = state_dict.get(
'{}cam_dec.fc_t.weight'.format(key_prefix)
)
if cam_dec_w is not None:
dit_config["cam_dec_dim_in"] = cam_dec_w.shape[1]
return dit_config
if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image

View File

@ -1864,10 +1864,12 @@ class DepthAnything3(supported_models_base.BASE):
return None
def process_unet_state_dict(self, state_dict):
# Drop weights for components we do not build (camera encoder/decoder,
# 3D Gaussian heads). Keeping unrelated keys around triggers spurious
# "unet unexpected" warnings on load.
drop_prefixes = ("cam_enc.", "cam_dec.", "gs_head.", "gs_adapter.")
# Drop weights for components we do not build (3D Gaussian heads).
# ``cam_enc.*`` / ``cam_dec.*`` are kept and consumed by the multi-view
# forward path -- their layouts in our ``camera.py`` mirror the
# upstream ``cam_enc.py`` / ``cam_dec.py`` so HF safetensors load
# directly without any key remap.
drop_prefixes = ("gs_head.", "gs_adapter.")
for k in list(state_dict.keys()):
if k.startswith(drop_prefixes):
state_dict.pop(k)

View File

@ -1,6 +1,6 @@
"""ComfyUI nodes for Depth Anything 3.
Adds three nodes:
Adds these nodes:
* ``LoadDepthAnything3`` -- load a DA3 ``.safetensors`` file from the
``models/depth_estimation/`` folder. Falls back to ``models/diffusion_models/``
@ -9,6 +9,9 @@ Adds three nodes:
depth map as a ComfyUI ``IMAGE`` (visualisation / ControlNet input).
* ``DepthAnything3DepthRaw`` -- run depth estimation and return the raw depth,
confidence and sky channels as ``MASK`` outputs.
* ``DepthAnything3MultiView`` -- multi-view path: depth + per-view extrinsics
+ intrinsics. Pose is decoded either from the camera-decoder MLP (default)
or from the auxiliary ray output via RANSAC (DA3-Small/Base only).
"""
from __future__ import annotations
@ -194,6 +197,153 @@ class DepthAnything3Depth(io.ComfyNode):
# -----------------------------------------------------------------------------
class DepthAnything3MultiView(io.ComfyNode):
"""Multi-view depth + pose estimation for DA3-Small / DA3-Base / DA3-Large.
Treats each batch element of the input ``IMAGE`` as a separate view of
the same scene. The selected reference view is auto-chosen by the
backbone via ``ref_view_strategy`` (when at least 3 views are
supplied), unless camera extrinsics are provided -- in which case the
geometry is pinned by the user and no reordering is done.
Output structure:
* ``depth_image`` -- per-view normalised depth as a stacked ``IMAGE``
batch (one frame per view, original input order).
* ``confidence`` / ``sky`` -- per-view masks (zero when the variant
does not produce them).
* ``camera`` -- ``LATENT`` dict with keys::
samples: (1, S, 1, h_p, w_p) -- raw depth packed as latent
type: "da3_multiview"
extrinsics: (1, S, 4, 4) world-to-camera matrices
intrinsics: (1, S, 3, 3) pixel-space intrinsics
depth_raw: (S, H, W) un-normalised depth
confidence: (S, H, W)
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="DepthAnything3MultiView",
display_name="Depth Anything 3 (Multi-View)",
category="image/depth",
inputs=[
io.Model.Input("model"),
io.Image.Input("image",
tooltip="Image batch where each frame is a view of the same scene."),
io.Int.Input("process_res", default=504, min=140, max=2520, step=14,
tooltip="Longest-side target resolution (multiple of 14)."),
io.Combo.Input("resize_method",
options=["upper_bound_resize", "lower_bound_resize"],
default="upper_bound_resize"),
io.Combo.Input("ref_view_strategy",
options=["saddle_balanced", "saddle_sim_range", "first", "middle"],
default="saddle_balanced",
tooltip="Reference view selection (only applied when "
"S>=3 and no extrinsics are provided)."),
io.Combo.Input("pose_method",
options=["cam_dec", "ray_pose"],
default="cam_dec",
tooltip="cam_dec: small MLP on the final cam token (works for "
"all variants with cam_dec). ray_pose: RANSAC over the "
"DualDPT auxiliary ray output (DA3-Small/Base only)."),
io.Combo.Input("normalization",
options=["v2_style", "min_max", "raw"],
default="v2_style"),
],
outputs=[
io.Image.Output("depth_image"),
io.Mask.Output("confidence"),
io.Mask.Output("sky_mask"),
io.Latent.Output("camera",
tooltip="Per-view extrinsics + intrinsics + raw depth."),
],
)
@classmethod
def execute(cls, model, image, process_res, resize_method, ref_view_strategy,
pose_method, normalization) -> io.NodeOutput:
assert image.ndim == 4 and image.shape[-1] == 3, \
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
S, H, W, _ = image.shape
mm.load_model_gpu(model)
diffusion = model.model.diffusion_model
device = mm.get_torch_device()
dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32
# Stack all views as a single batch element with views axis = S.
x = image.to(device)
x = da3_preprocess.preprocess_image(x, process_res=process_res, method=resize_method)
x = x.to(dtype=dtype).unsqueeze(0) # (1, S, 3, H', W')
use_ray_pose = (pose_method == "ray_pose")
with torch.no_grad():
out = diffusion(x, use_ray_pose=use_ray_pose,
ref_view_strategy=ref_view_strategy)
# ``out["depth"]`` is (S, h_p, w_p); resize back to (S, H, W).
depth_lr = out["depth"].float()
depth = torch.nn.functional.interpolate(
depth_lr.unsqueeze(1), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
if "depth_conf" in out:
conf = torch.nn.functional.interpolate(
out["depth_conf"].unsqueeze(1).float(), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
else:
conf = torch.zeros_like(depth)
if "sky" in out:
sky = torch.nn.functional.interpolate(
out["sky"].unsqueeze(1).float(), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
else:
sky = torch.zeros_like(depth)
# Pose. Defaults to identity when neither cam_dec nor ray_pose is wired up.
if "extrinsics" in out and "intrinsics" in out:
extrinsics = out["extrinsics"].float().cpu()
intrinsics = out["intrinsics"].float().cpu()
else:
extrinsics = torch.eye(4)[None, None].expand(1, S, 4, 4).clone()
intrinsics = torch.eye(3)[None, None].expand(1, S, 3, 3).clone()
# Normalised depth viz per view (same path as the mono node).
if normalization == "v2_style":
norm = torch.stack([
da3_preprocess.normalize_depth_v2_style(depth[i],
sky[i] if "sky" in out else None)
for i in range(S)
], dim=0)
elif normalization == "min_max":
norm = da3_preprocess.normalize_depth_min_max(depth)
else:
norm = depth
depth_image = norm.unsqueeze(-1).repeat(1, 1, 1, 3).clamp(0.0, 1.0).contiguous()
camera_latent = {
# The Latent contract requires a ``samples`` field; pack the raw
# depth there so a downstream node still has a tensor to chain on.
"samples": depth.unsqueeze(0).unsqueeze(2).contiguous(), # (1, S, 1, H, W)
"type": "da3_multiview",
"extrinsics": extrinsics.contiguous(),
"intrinsics": intrinsics.contiguous(),
"depth_raw": depth.contiguous(),
"confidence": conf.contiguous(),
}
return io.NodeOutput(
depth_image,
conf.contiguous(),
sky.contiguous(),
camera_latent,
)
class DepthAnything3DepthRaw(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -240,6 +390,7 @@ class DepthAnything3Extension(ComfyExtension):
LoadDepthAnything3,
DepthAnything3Depth,
DepthAnything3DepthRaw,
DepthAnything3MultiView,
]