# DepthAnything3Net: top-level wrapper that combines backbone + head.
#
# Supports both the monocular and the multi-view + camera path:
#
# * Monocular: ``S = 1``, no camera encoder/decoder. Mirrors the original
#   port that only handled ``DA3-MONO/METRIC-LARGE`` and the auxiliary-disabled
#   ``DA3-SMALL/BASE`` configs.
# * Multi-view + camera: ``S > 1``. ``cam_enc`` (optional) maps user-supplied
#   extrinsics + intrinsics into a per-view camera token; ``cam_dec`` decodes
#   the final layer's camera token into a 9-D pose encoding. When the
#   auxiliary "ray" head of ``DualDPT`` is enabled the predicted ray map can
#   alternatively be used to estimate pose via RANSAC (``use_ray_pose=True``).
#
# The 3D-Gaussian head and the nested-architecture wrapper are intentionally
# left out of scope here; their state-dict keys are filtered in
# ``comfy.supported_models.DepthAnything3.process_unet_state_dict``.
#
# The backbone is shared with the CLIP-vision DINOv2 path
# (``comfy.image_encoders.dino2.Dinov2Model``); the DA3-specific extensions
# (RoPE, QK-norm, alternating local/global attention, camera token, multi-
# layer feature extraction, reference-view reordering) are opt-in via the
# config dict and are all disabled for the Mono/Metric variants.
from __future__ import annotations

from typing import Dict, Optional, Sequence

import torch
import torch.nn as nn

from comfy.image_encoders.dino2 import Dinov2Model

from .camera import CameraDec, CameraEnc
from .dpt import DPT, DualDPT
from .ray_pose import get_extrinsic_from_camray
from .transform import affine_inverse, pose_encoding_to_extri_intri

# Maps the config's "head_type" string onto the class that implements it.
_HEAD_REGISTRY = dict(dpt=DPT, dualdpt=DualDPT)

# Backbone presets (mirror the upstream DINOv2 ViT variants).
# Each tuple is (hidden_size, num_hidden_layers, num_attention_heads, use_swiglu_ffn).
_BACKBONE_PRESETS = {
    name: {
        "hidden_size": width,
        "num_hidden_layers": depth,
        "num_attention_heads": heads,
        "use_swiglu_ffn": swiglu,
    }
    for name, (width, depth, heads, swiglu) in {
        "vits": (384, 12, 6, False),
        "vitb": (768, 12, 12, False),
        "vitl": (1024, 24, 16, False),
        "vitg": (1536, 40, 24, True),
    }.items()
}

def _build_backbone_config(
    backbone_name: str,
    *,
    alt_start: int,
    qknorm_start: int,
    rope_start: int,
    cat_token: bool,
) -> dict:
    """Assemble the ``Dinov2Model`` config dict for one DA3 backbone variant.

    Starts from the ViT preset selected by ``backbone_name`` and layers the
    fixed DA3 settings (patch/image size, RoPE frequency) plus the opt-in
    extension switches on top.

    Raises:
        ValueError: if ``backbone_name`` is not a known preset.
    """
    try:
        preset = _BACKBONE_PRESETS[backbone_name]
    except KeyError:
        raise ValueError(f"Unknown DINOv2 backbone variant: {backbone_name!r}") from None
    return {
        **preset,
        "layer_norm_eps": 1e-6,
        "patch_size": 14,
        "image_size": 518,
        # DA3 weights have no mask_token; skip registering it to avoid spurious
        # missing-key warnings on load.
        "use_mask_token": False,
        "alt_start": alt_start,
        "qknorm_start": qknorm_start,
        "rope_start": rope_start,
        "cat_token": cat_token,
        "rope_freq": 100.0,
    }

class DepthAnything3Net(nn.Module):
    """ComfyUI-side DepthAnything3 network.

    Parameters mirror the variant YAML configs from the upstream repo and
    are auto-detected from the state dict by ``comfy/model_detection.py``.
    The kwargs ``device``, ``dtype`` and ``operations`` are injected by
    ``BaseModel``.
    """

    # DINOv2 ViT patch size; ``forward`` asserts H and W are multiples of it.
    PATCH_SIZE = 14

    def __init__(
        self,
        # --- Backbone ---
        backbone_name: str = "vitl",
        out_layers: Sequence[int] = (4, 11, 17, 23),
        alt_start: int = -1,
        qknorm_start: int = -1,
        rope_start: int = -1,
        cat_token: bool = False,
        # --- Head ---
        head_type: str = "dpt",  # "dpt" or "dualdpt"
        head_dim_in: int = 1024,
        head_output_dim: int = 1,  # 1 = depth only, 2 = depth+conf
        head_features: int = 256,
        head_out_channels: Sequence[int] = (256, 512, 1024, 1024),
        head_use_sky_head: bool = True,  # ignored by DualDPT
        head_pos_embed: Optional[bool] = None,  # default: True for DualDPT, False for DPT
        # --- Camera (multi-view) ---
        has_cam_enc: bool = False,
        has_cam_dec: bool = False,
        cam_dim_out: Optional[int] = None,  # CameraEnc dim_out (defaults to embed_dim)
        cam_dec_dim_in: Optional[int] = None,  # CameraDec dim_in (defaults to 2*embed_dim with cat_token)
        # ComfyUI plumbing
        device=None, dtype=None, operations=None,
        **_ignored,
    ):
        super().__init__()
        # Unknown head_type raises KeyError here; only "dpt"/"dualdpt" exist.
        head_cls = _HEAD_REGISTRY[head_type.lower()]
        self.head_type = head_type.lower()
        # Sky probability is produced by the DPT head only (and only when its
        # sky sub-head is enabled); DualDPT ignores head_use_sky_head.
        self.has_sky = (self.head_type == "dpt") and head_use_sky_head
        # output_dim > 1 means the head emits a confidence channel alongside depth.
        self.has_conf = head_output_dim > 1
        self.out_layers = list(out_layers)

        backbone_cfg = _build_backbone_config(
            backbone_name,
            alt_start=alt_start,
            qknorm_start=qknorm_start,
            rope_start=rope_start,
            cat_token=cat_token,
        )
        self.backbone = Dinov2Model(backbone_cfg, dtype, device, operations)

        head_kwargs = dict(
            dim_in=head_dim_in,
            patch_size=self.PATCH_SIZE,
            output_dim=head_output_dim,
            features=head_features,
            out_channels=tuple(head_out_channels),
            device=device, dtype=dtype, operations=operations,
        )
        # pos_embed defaults differ per head type (see head_pos_embed above).
        if self.head_type == "dpt":
            head_kwargs.update(
                use_sky_head=head_use_sky_head,
                pos_embed=(False if head_pos_embed is None else head_pos_embed),
            )
        else:  # dualdpt
            head_kwargs.update(
                pos_embed=(True if head_pos_embed is None else head_pos_embed),
            )
        self.head = head_cls(**head_kwargs)

        # Camera encoder / decoder are only constructed when their weights are
        # present in the checkpoint; the multi-view / pose forward path becomes
        # available accordingly. ``cam_enc.dim_out`` matches the backbone's
        # ``embed_dim`` so the cam token slots into block ``alt_start``.
        embed_dim = backbone_cfg["hidden_size"]
        if has_cam_enc:
            self.cam_enc = CameraEnc(
                dim_out=cam_dim_out if cam_dim_out is not None else embed_dim,
                # One head per 64 channels, clamped to at least 1 for tiny dims.
                num_heads=max(1, embed_dim // 64),
                device=device, dtype=dtype, operations=operations,
            )
        else:
            self.cam_enc = None
        if has_cam_dec:
            # Default cam_dec dim_in is 2*embed_dim when cat_token is on
            # (the cls/cam token in the output is the cat'd version).
            default_dim = embed_dim * (2 if cat_token else 1)
            self.cam_dec = CameraDec(
                dim_in=cam_dec_dim_in if cam_dec_dim_in is not None else default_dim,
                device=device, dtype=dtype, operations=operations,
            )
        else:
            self.cam_dec = None

        # Stored so ComfyUI's model management can query the compute dtype.
        self.dtype = dtype

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def forward(
        self,
        image: torch.Tensor,
        extrinsics: Optional[torch.Tensor] = None,
        intrinsics: Optional[torch.Tensor] = None,
        *,
        use_ray_pose: bool = False,
        ref_view_strategy: str = "saddle_balanced",
        export_feat_layers: Optional[Sequence[int]] = None,
        **_unused,
    ) -> Dict[str, torch.Tensor]:
        """Run depth (and optionally pose) prediction.

        Args:
            image: ``(B, 3, H, W)`` ImageNet-normalised image tensor, or
                ``(B, S, 3, H, W)`` for multi-view inputs. ``H`` and ``W``
                must be multiples of 14.
            extrinsics: optional ``(B, S, 4, 4)`` world-to-camera extrinsics.
                When provided together with ``intrinsics``, ``CameraEnc``
                converts them into per-view camera tokens that the backbone
                injects at block ``alt_start``.
            intrinsics: optional ``(B, S, 3, 3)`` pixel-space intrinsics.
            use_ray_pose: if True, predict pose from the auxiliary "ray" head
                (RANSAC over per-pixel rays). Only available on DualDPT
                variants. If False (default) and ``cam_dec`` is present,
                the final-layer cam token is decoded into pose instead.
            ref_view_strategy: reference-view selection strategy used when
                ``S >= 3`` and no extrinsics are supplied. See
                :mod:`comfy.ldm.depth_anything_3.reference_view_selector`.
            export_feat_layers: optional list of backbone layer indices whose
                local features to also return as auxiliary outputs (used by
                downstream nested-architecture wrappers; empty by default).

        Returns:
            Dict with a subset of:
                - ``depth`` ``(B*S, H, W)`` raw depth values.
                - ``depth_conf`` ``(B*S, H, W)`` confidence (DualDPT only).
                - ``sky`` ``(B*S, H, W)`` sky probability (DPT + sky head).
                - ``ray`` ``(B, S, h, w, 6)`` per-pixel cam ray (DualDPT,
                  multi-view, ``use_ray_pose=True`` only).
                - ``ray_conf`` ``(B, S, h, w)`` ray confidence.
                - ``extrinsics`` ``(B, S, 4, 4)`` world-to-cam, when pose
                  prediction is active.
                - ``intrinsics`` ``(B, S, 3, 3)`` pixel-space intrinsics.
                - ``aux_features`` list of ``(B, S, h_p, w_p, C)`` features
                  when ``export_feat_layers`` is non-empty.
        """
        # Promote monocular input to a single-view multi-view batch.
        if image.ndim == 4:
            image = image.unsqueeze(1)  # (B, 1, 3, H, W)
        assert image.ndim == 5 and image.shape[2] == 3, \
            f"image must be (B,3,H,W) or (B,S,3,H,W); got {tuple(image.shape)}"

        B, S, _, H, W = image.shape
        assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \
            f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}"

        # Camera-token preparation (multi-view path). Requires both camera
        # matrices AND a constructed encoder; otherwise the backbone runs
        # without an injected cam token.
        cam_token = None
        if extrinsics is not None and intrinsics is not None and self.cam_enc is not None:
            cam_token = self.cam_enc(extrinsics, intrinsics, (H, W))

        # Toggle aux ray output on/off depending on what the caller asked for.
        if isinstance(self.head, DualDPT):
            self.head.enable_aux = bool(use_ray_pose)

        feats, aux_feats = self.backbone.get_intermediate_layers(
            image, self.out_layers, cam_token=cam_token,
            ref_view_strategy=ref_view_strategy,
            export_feat_layers=export_feat_layers,
        )
        # NOTE(review): patch_start_idx=0 presumably means feats carry patch
        # tokens only (cls/cam token returned separately) — confirm against
        # Dinov2Model.get_intermediate_layers.
        head_out = self.head(feats, H=H, W=W, patch_start_idx=0)

        # Pose prediction.
        out: Dict[str, torch.Tensor] = {}
        if use_ray_pose and "ray" in head_out and "ray_conf" in head_out:
            ray = head_out["ray"]
            ray_conf = head_out["ray_conf"]
            extr_c2w, focal, pp = get_extrinsic_from_camray(
                ray, ray_conf, ray.shape[-3], ray.shape[-2],
            )
            # Match the upstream output: w2c, drop the homogeneous row.
            extr_w2c = affine_inverse(extr_c2w)[:, :, :3, :]
            # Build pixel-space intrinsics from the normalised focal/pp output.
            intr = torch.eye(3, device=ray.device, dtype=ray.dtype)
            # .clone() because expand() returns a view and we write per-view
            # entries in place below.
            intr = intr[None, None].expand(extr_c2w.shape[0], extr_c2w.shape[1], 3, 3).clone()
            intr[:, :, 0, 0] = focal[:, :, 0] / 2 * W
            intr[:, :, 1, 1] = focal[:, :, 1] / 2 * H
            intr[:, :, 0, 2] = pp[:, :, 0] * W * 0.5
            intr[:, :, 1, 2] = pp[:, :, 1] * H * 0.5
            out["extrinsics"] = extr_w2c
            out["intrinsics"] = intr
        elif self.cam_dec is not None and S > 1:
            # Decode the cam-token of the final out_layer into a pose encoding.
            cam_feat = feats[-1][1]  # (B, S, dim_in_to_cam_dec)
            pose_enc = self.cam_dec(cam_feat)
            c2w_3x4, intr = pose_encoding_to_extri_intri(pose_enc, (H, W))
            # Match the upstream output convention: w2c (world->camera), 3x4.
            # Re-append the homogeneous row so affine_inverse gets 4x4 input.
            c2w_4x4 = torch.cat([
                c2w_3x4,
                torch.tensor([0, 0, 0, 1], device=c2w_3x4.device, dtype=c2w_3x4.dtype)
                .view(1, 1, 1, 4).expand(B, S, 1, 4),
            ], dim=-2)
            out["extrinsics"] = affine_inverse(c2w_4x4)[:, :, :3, :]
            out["intrinsics"] = intr

        # Flatten the views axis for per-pixel outputs (depth/conf/sky) so the
        # per-image consumer keeps its (B*S, H, W) interface.
        for k, v in head_out.items():
            if k in ("ray", "ray_conf"):
                # Keep multi-view shape for downstream pose work.
                out[k] = v
            elif v.ndim >= 3 and v.shape[0] == B and v.shape[1] == S:
                out[k] = v.reshape(B * S, *v.shape[2:])
            else:
                out[k] = v

        if export_feat_layers:
            out["aux_features"] = self._reshape_aux_features(aux_feats, H, W)
        return out

    def _reshape_aux_features(self, aux_feats: Sequence[torch.Tensor], H: int, W: int) -> list:
        """Reshape ``(B, S, N, C)`` aux features into ``(B, S, h_p, w_p, C)``."""
        # Patch-grid dimensions implied by the input resolution.
        ph, pw = H // self.PATCH_SIZE, W // self.PATCH_SIZE
        out = []
        for f in aux_feats:
            B, S, N, C = f.shape
            assert N == ph * pw, f"aux feature seq mismatch: {N} != {ph}*{pw}"
            out.append(f.reshape(B, S, ph, pw, C))
        return out