ComfyUI/comfy/ldm/depth_anything_3/model.py
2026-05-13 10:59:31 +02:00

310 lines
13 KiB
Python

# DepthAnything3Net: top-level wrapper that combines backbone + head.
#
# Supports both the monocular and the multi-view + camera path:
#
# * Monocular: ``S = 1``, no camera encoder/decoder. Mirrors the original
# port that only handled ``DA3-MONO/METRIC-LARGE`` and the auxiliary-disabled
# ``DA3-SMALL/BASE`` configs.
# * Multi-view + camera: ``S > 1``. ``cam_enc`` (optional) maps user-supplied
# extrinsics + intrinsics into a per-view camera token; ``cam_dec`` decodes
# the final layer's camera token into a 9-D pose encoding. When the
# auxiliary "ray" head of ``DualDPT`` is enabled the predicted ray map can
# alternatively be used to estimate pose via RANSAC (``use_ray_pose=True``).
# The 3D-Gaussian head and the nested-architecture wrapper are intentionally
# left out of scope here; their state-dict keys are filtered in
# ``comfy.supported_models.DepthAnything3.process_unet_state_dict``.
#
# The backbone is shared with the CLIP-vision DINOv2 path
# (``comfy.image_encoders.dino2.Dinov2Model``); the DA3-specific extensions
# (RoPE, QK-norm, alternating local/global attention, camera token, multi-
# layer feature extraction, reference-view reordering) are opt-in via the
# config dict and are all disabled for the Mono/Metric variants.
from __future__ import annotations
from typing import Dict, Optional, Sequence
import torch
import torch.nn as nn
from comfy.image_encoders.dino2 import Dinov2Model
from .camera import CameraDec, CameraEnc
from .dpt import DPT, DualDPT
from .ray_pose import get_extrinsic_from_camray
from .transform import affine_inverse, pose_encoding_to_extri_intri
_HEAD_REGISTRY = {
"dpt": DPT,
"dualdpt": DualDPT,
}
# Backbone presets (mirror the upstream DINOv2 ViT variants).
_BACKBONE_PRESETS = {
"vits": dict(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, use_swiglu_ffn=False),
"vitb": dict(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, use_swiglu_ffn=False),
"vitl": dict(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, use_swiglu_ffn=False),
"vitg": dict(hidden_size=1536, num_hidden_layers=40, num_attention_heads=24, use_swiglu_ffn=True),
}
def _build_backbone_config(
backbone_name: str,
*,
alt_start: int,
qknorm_start: int,
rope_start: int,
cat_token: bool,
) -> dict:
if backbone_name not in _BACKBONE_PRESETS:
raise ValueError(f"Unknown DINOv2 backbone variant: {backbone_name!r}")
cfg = dict(_BACKBONE_PRESETS[backbone_name])
cfg.update(dict(
layer_norm_eps=1e-6,
patch_size=14,
image_size=518,
# DA3 weights have no mask_token; skip registering it to avoid spurious
# missing-key warnings on load.
use_mask_token=False,
alt_start=alt_start,
qknorm_start=qknorm_start,
rope_start=rope_start,
cat_token=cat_token,
rope_freq=100.0,
))
return cfg
class DepthAnything3Net(nn.Module):
"""ComfyUI-side DepthAnything3 network.
Parameters mirror the variant YAML configs from the upstream repo and
are auto-detected from the state dict by ``comfy/model_detection.py``.
The kwargs ``device``, ``dtype`` and ``operations`` are injected by
``BaseModel``.
"""
PATCH_SIZE = 14
def __init__(
self,
# --- Backbone ---
backbone_name: str = "vitl",
out_layers: Sequence[int] = (4, 11, 17, 23),
alt_start: int = -1,
qknorm_start: int = -1,
rope_start: int = -1,
cat_token: bool = False,
# --- Head ---
head_type: str = "dpt", # "dpt" or "dualdpt"
head_dim_in: int = 1024,
head_output_dim: int = 1, # 1 = depth only, 2 = depth+conf
head_features: int = 256,
head_out_channels: Sequence[int] = (256, 512, 1024, 1024),
head_use_sky_head: bool = True, # ignored by DualDPT
head_pos_embed: Optional[bool] = None, # default: True for DualDPT, False for DPT
# --- Camera (multi-view) ---
has_cam_enc: bool = False,
has_cam_dec: bool = False,
cam_dim_out: Optional[int] = None, # CameraEnc dim_out (defaults to embed_dim)
cam_dec_dim_in: Optional[int] = None, # CameraDec dim_in (defaults to 2*embed_dim with cat_token)
# ComfyUI plumbing
device=None, dtype=None, operations=None,
**_ignored,
):
super().__init__()
head_cls = _HEAD_REGISTRY[head_type.lower()]
self.head_type = head_type.lower()
self.has_sky = (self.head_type == "dpt") and head_use_sky_head
self.has_conf = head_output_dim > 1
self.out_layers = list(out_layers)
backbone_cfg = _build_backbone_config(
backbone_name,
alt_start=alt_start,
qknorm_start=qknorm_start,
rope_start=rope_start,
cat_token=cat_token,
)
self.backbone = Dinov2Model(backbone_cfg, dtype, device, operations)
head_kwargs = dict(
dim_in=head_dim_in,
patch_size=self.PATCH_SIZE,
output_dim=head_output_dim,
features=head_features,
out_channels=tuple(head_out_channels),
device=device, dtype=dtype, operations=operations,
)
if self.head_type == "dpt":
head_kwargs.update(
use_sky_head=head_use_sky_head,
pos_embed=(False if head_pos_embed is None else head_pos_embed),
)
else: # dualdpt
head_kwargs.update(
pos_embed=(True if head_pos_embed is None else head_pos_embed),
)
self.head = head_cls(**head_kwargs)
# Camera encoder / decoder are only constructed when their weights are
# present in the checkpoint; the multi-view / pose forward path becomes
# available accordingly. ``cam_enc.dim_out`` matches the backbone's
# ``embed_dim`` so the cam token slots into block ``alt_start``.
embed_dim = backbone_cfg["hidden_size"]
if has_cam_enc:
self.cam_enc = CameraEnc(
dim_out=cam_dim_out if cam_dim_out is not None else embed_dim,
num_heads=max(1, embed_dim // 64),
device=device, dtype=dtype, operations=operations,
)
else:
self.cam_enc = None
if has_cam_dec:
# Default cam_dec dim_in is 2*embed_dim when cat_token is on
# (the cls/cam token in the output is the cat'd version).
default_dim = embed_dim * (2 if cat_token else 1)
self.cam_dec = CameraDec(
dim_in=cam_dec_dim_in if cam_dec_dim_in is not None else default_dim,
device=device, dtype=dtype, operations=operations,
)
else:
self.cam_dec = None
self.dtype = dtype
# ------------------------------------------------------------------
# Forward
# ------------------------------------------------------------------
def forward(
self,
image: torch.Tensor,
extrinsics: Optional[torch.Tensor] = None,
intrinsics: Optional[torch.Tensor] = None,
*,
use_ray_pose: bool = False,
ref_view_strategy: str = "saddle_balanced",
export_feat_layers: Optional[Sequence[int]] = None,
**_unused,
) -> Dict[str, torch.Tensor]:
"""Run depth (and optionally pose) prediction.
Args:
image: ``(B, 3, H, W)`` ImageNet-normalised image tensor, or
``(B, S, 3, H, W)`` for multi-view inputs. ``H`` and ``W``
must be multiples of 14.
extrinsics: optional ``(B, S, 4, 4)`` world-to-camera extrinsics.
When provided together with ``intrinsics``, ``CameraEnc``
converts them into per-view camera tokens that the backbone
injects at block ``alt_start``.
intrinsics: optional ``(B, S, 3, 3)`` pixel-space intrinsics.
use_ray_pose: if True, predict pose from the auxiliary "ray" head
(RANSAC over per-pixel rays). Only available on DualDPT
variants. If False (default) and ``cam_dec`` is present,
the final-layer cam token is decoded into pose instead.
ref_view_strategy: reference-view selection strategy used when
``S >= 3`` and no extrinsics are supplied. See
:mod:`comfy.ldm.depth_anything_3.reference_view_selector`.
export_feat_layers: optional list of backbone layer indices whose
local features to also return as auxiliary outputs (used by
downstream nested-architecture wrappers; empty by default).
Returns:
Dict with a subset of:
- ``depth`` ``(B*S, H, W)`` raw depth values.
- ``depth_conf`` ``(B*S, H, W)`` confidence (DualDPT only).
- ``sky`` ``(B*S, H, W)`` sky probability (DPT + sky head).
- ``ray`` ``(B, S, h, w, 6)`` per-pixel cam ray (DualDPT,
multi-view, ``use_ray_pose=True`` only).
- ``ray_conf`` ``(B, S, h, w)`` ray confidence.
- ``extrinsics`` ``(B, S, 4, 4)`` world-to-cam, when pose
prediction is active.
- ``intrinsics`` ``(B, S, 3, 3)`` pixel-space intrinsics.
- ``aux_features`` list of ``(B, S, h_p, w_p, C)`` features
when ``export_feat_layers`` is non-empty.
"""
if image.ndim == 4:
image = image.unsqueeze(1) # (B, 1, 3, H, W)
assert image.ndim == 5 and image.shape[2] == 3, \
f"image must be (B,3,H,W) or (B,S,3,H,W); got {tuple(image.shape)}"
B, S, _, H, W = image.shape
assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \
f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}"
# Camera-token preparation (multi-view path).
cam_token = None
if extrinsics is not None and intrinsics is not None and self.cam_enc is not None:
cam_token = self.cam_enc(extrinsics, intrinsics, (H, W))
# Toggle aux ray output on/off depending on what the caller asked for.
if isinstance(self.head, DualDPT):
self.head.enable_aux = bool(use_ray_pose)
feats, aux_feats = self.backbone.get_intermediate_layers(
image, self.out_layers, cam_token=cam_token,
ref_view_strategy=ref_view_strategy,
export_feat_layers=export_feat_layers,
)
head_out = self.head(feats, H=H, W=W, patch_start_idx=0)
# Pose prediction.
out: Dict[str, torch.Tensor] = {}
if use_ray_pose and "ray" in head_out and "ray_conf" in head_out:
ray = head_out["ray"]
ray_conf = head_out["ray_conf"]
extr_c2w, focal, pp = get_extrinsic_from_camray(
ray, ray_conf, ray.shape[-3], ray.shape[-2],
)
# Match the upstream output: w2c, drop the homogeneous row.
extr_w2c = affine_inverse(extr_c2w)[:, :, :3, :]
# Build pixel-space intrinsics from the normalised focal/pp output.
intr = torch.eye(3, device=ray.device, dtype=ray.dtype)
intr = intr[None, None].expand(extr_c2w.shape[0], extr_c2w.shape[1], 3, 3).clone()
intr[:, :, 0, 0] = focal[:, :, 0] / 2 * W
intr[:, :, 1, 1] = focal[:, :, 1] / 2 * H
intr[:, :, 0, 2] = pp[:, :, 0] * W * 0.5
intr[:, :, 1, 2] = pp[:, :, 1] * H * 0.5
out["extrinsics"] = extr_w2c
out["intrinsics"] = intr
elif self.cam_dec is not None and S > 1:
# Decode the cam-token of the final out_layer into a pose encoding.
cam_feat = feats[-1][1] # (B, S, dim_in_to_cam_dec)
pose_enc = self.cam_dec(cam_feat)
c2w_3x4, intr = pose_encoding_to_extri_intri(pose_enc, (H, W))
# Match the upstream output convention: w2c (world->camera), 3x4.
c2w_4x4 = torch.cat([
c2w_3x4,
torch.tensor([0, 0, 0, 1], device=c2w_3x4.device, dtype=c2w_3x4.dtype)
.view(1, 1, 1, 4).expand(B, S, 1, 4),
], dim=-2)
out["extrinsics"] = affine_inverse(c2w_4x4)[:, :, :3, :]
out["intrinsics"] = intr
# Flatten the views axis for per-pixel outputs (depth/conf/sky) so the
# per-image consumer keeps its (B*S, H, W) interface.
for k, v in head_out.items():
if k in ("ray", "ray_conf"):
# Keep multi-view shape for downstream pose work.
out[k] = v
elif v.ndim >= 3 and v.shape[0] == B and v.shape[1] == S:
out[k] = v.reshape(B * S, *v.shape[2:])
else:
out[k] = v
if export_feat_layers:
out["aux_features"] = self._reshape_aux_features(aux_feats, H, W)
return out
def _reshape_aux_features(self, aux_feats, H: int, W: int):
"""Reshape ``(B, S, N, C)`` aux features into ``(B, S, h_p, w_p, C)``."""
ph, pw = H // self.PATCH_SIZE, W // self.PATCH_SIZE
out = []
for f in aux_feats:
B, S, N, C = f.shape
assert N == ph * pw, f"aux feature seq mismatch: {N} != {ph}*{pw}"
out.append(f.reshape(B, S, ph, pw, C))
return out