mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-12 01:07:30 +08:00
237 lines
9.0 KiB
Python
237 lines
9.0 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Dict, Optional, Sequence
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from comfy.image_encoders.dino2 import Dinov2Model
|
|
|
|
from .camera import CameraDec, CameraEnc
|
|
from .dpt import DPT, DualDPT
|
|
from .ray_pose import get_extrinsic_from_camray
|
|
from .transform import affine_inverse, pose_encoding_to_extri_intri
|
|
|
|
|
|
# Lookup table from a lower-cased ``head_type`` name to the head class that
# is instantiated for it.
_HEAD_REGISTRY = dict(dpt=DPT, dualdpt=DualDPT)
|
|
|
|
|
|
# DINOv2 ViT size presets keyed by variant name (mirrors the upstream DINOv2
# ViT variants).  Each entry holds only the size-dependent fields; the shared
# fields are layered on top when the backbone config is built.
_BACKBONE_PRESETS = {
    "vits": {
        "hidden_size": 384,
        "num_hidden_layers": 12,
        "num_attention_heads": 6,
        "use_swiglu_ffn": False,
    },
    "vitb": {
        "hidden_size": 768,
        "num_hidden_layers": 12,
        "num_attention_heads": 12,
        "use_swiglu_ffn": False,
    },
    "vitl": {
        "hidden_size": 1024,
        "num_hidden_layers": 24,
        "num_attention_heads": 16,
        "use_swiglu_ffn": False,
    },
    "vitg": {
        "hidden_size": 1536,
        "num_hidden_layers": 40,
        "num_attention_heads": 24,
        "use_swiglu_ffn": True,
    },
}
|
|
|
|
|
|
def _build_backbone_config(
    backbone_name: str,
    *,
    alt_start: int,
    qknorm_start: int,
    rope_start: int,
    cat_token: bool,
) -> dict:
    """Assemble the backbone config dict for one DINOv2 variant.

    The size preset registered under ``backbone_name`` is copied and the
    shared fields plus the caller-supplied options are layered on top, so the
    preset table itself is never mutated.

    Args:
        backbone_name: Key into ``_BACKBONE_PRESETS``.
        alt_start: Stored under ``"alt_start"`` unchanged.
        qknorm_start: Stored under ``"qknorm_start"`` unchanged.
        rope_start: Stored under ``"rope_start"`` unchanged.
        cat_token: Stored under ``"cat_token"`` unchanged.

    Returns:
        A new dict: preset fields first, then the shared/option fields.

    Raises:
        ValueError: If ``backbone_name`` is not a known preset.
    """
    preset = _BACKBONE_PRESETS.get(backbone_name)
    if preset is None:
        raise ValueError(f"Unknown DINOv2 backbone variant: {backbone_name!r}")

    return {
        **preset,
        "layer_norm_eps": 1e-6,
        "patch_size": 14,
        "image_size": 518,
        # No mask_token in DA3 weights; omit param to avoid load warnings.
        "use_mask_token": False,
        "alt_start": alt_start,
        "qknorm_start": qknorm_start,
        "rope_start": rope_start,
        "cat_token": cat_token,
        "rope_freq": 100.0,
    }
|
|
|
|
|
|
class DepthAnything3Net(nn.Module):
    """Depth Anything 3 network: DINOv2 backbone, DPT-style head, optional camera modules.

    The backbone produces intermediate features for the configured
    ``out_layers``; the head (``DPT`` or ``DualDPT``) turns them into the
    per-pixel outputs.  ``CameraEnc`` / ``CameraDec`` are only attached when
    requested through ``has_cam_enc`` / ``has_cam_dec``.
    """

    # Patch size handed to the head; ``forward`` requires H and W to be
    # multiples of it.
    PATCH_SIZE = 14

    def __init__(
        self,
        # --- Backbone ---
        backbone_name: str = "vitl",
        out_layers: Sequence[int] = (4, 11, 17, 23),
        alt_start: int = -1,
        qknorm_start: int = -1,
        rope_start: int = -1,
        cat_token: bool = False,
        # --- Head ---
        head_type: str = "dpt",                     # dpt or dualdpt
        head_dim_in: int = 1024,
        head_output_dim: int = 1,                   # 1 = depth only, 2 = depth+conf
        head_features: int = 256,
        head_out_channels: Sequence[int] = (256, 512, 1024, 1024),
        head_use_sky_head: bool = True,             # ignored by DualDPT
        head_pos_embed: Optional[bool] = None,      # default: True for DualDPT, False for DPT
        # --- Camera (multi-view) ---
        has_cam_enc: bool = False,
        has_cam_dec: bool = False,
        cam_dim_out: Optional[int] = None,          # CameraEnc dim_out (defaults to embed_dim)
        cam_dec_dim_in: Optional[int] = None,       # CameraDec dim_in (defaults to 2*embed_dim with cat_token)
        # ComfyUI plumbing
        device=None, dtype=None, operations=None,
        **_ignored,
    ):
        """Build the backbone, the head and the optional camera modules.

        Args:
            backbone_name: DINOv2 variant key (see ``_BACKBONE_PRESETS``).
            out_layers: Backbone layer indices handed to
                ``get_intermediate_layers_da3`` in ``forward``.
            alt_start, qknorm_start, rope_start, cat_token: Forwarded into
                the backbone config unchanged.
            head_type: ``"dpt"`` or ``"dualdpt"`` (case-insensitive).
            head_dim_in, head_output_dim, head_features, head_out_channels:
                Forwarded to the head constructor.
            head_use_sky_head: Passed to the ``DPT`` head only.
            head_pos_embed: ``None`` selects the per-head default (False for
                ``dpt``, True otherwise).
            has_cam_enc: Attach a ``CameraEnc`` module.
            has_cam_dec: Attach a ``CameraDec`` module.
            cam_dim_out: ``CameraEnc`` output width; defaults to the
                backbone hidden size.
            cam_dec_dim_in: ``CameraDec`` input width; defaults to the
                backbone hidden size, doubled when ``cat_token`` is set.
            device, dtype, operations: Passed through to every submodule.
            **_ignored: Extra keyword arguments are accepted and dropped.

        Raises:
            KeyError: If ``head_type`` is not a registered head.
            ValueError: If ``backbone_name`` is not a known preset.
        """
        super().__init__()
        head_cls = _HEAD_REGISTRY[head_type.lower()]
        self.head_type = head_type.lower()
        self.has_sky = (self.head_type == "dpt") and head_use_sky_head
        self.has_conf = head_output_dim > 1
        self.out_layers = list(out_layers)

        backbone_cfg = _build_backbone_config(
            backbone_name,
            alt_start=alt_start,
            qknorm_start=qknorm_start,
            rope_start=rope_start,
            cat_token=cat_token,
        )
        self.backbone = Dinov2Model(backbone_cfg, dtype, device, operations)

        head_kwargs = dict(
            dim_in=head_dim_in,
            patch_size=self.PATCH_SIZE,
            output_dim=head_output_dim,
            features=head_features,
            out_channels=tuple(head_out_channels),
            device=device, dtype=dtype, operations=operations,
        )
        if self.head_type == "dpt":
            head_kwargs.update(
                use_sky_head=head_use_sky_head,
                pos_embed=(False if head_pos_embed is None else head_pos_embed),
            )
        else:  # dualdpt
            head_kwargs.update(
                pos_embed=(True if head_pos_embed is None else head_pos_embed),
            )
        self.head = head_cls(**head_kwargs)

        # Built only if checkpoint has weights; cam_enc output dim == embed_dim.
        embed_dim = backbone_cfg["hidden_size"]
        if has_cam_enc:
            self.cam_enc = CameraEnc(
                dim_out=cam_dim_out if cam_dim_out is not None else embed_dim,
                num_heads=max(1, embed_dim // 64),
                device=device, dtype=dtype, operations=operations,
            )
        else:
            self.cam_enc = None
        if has_cam_dec:
            default_dim = embed_dim * (2 if cat_token else 1)
            self.cam_dec = CameraDec(
                dim_in=cam_dec_dim_in if cam_dec_dim_in is not None else default_dim,
                device=device, dtype=dtype, operations=operations,
            )
        else:
            self.cam_dec = None

        self.dtype = dtype

    def forward(
        self,
        image: torch.Tensor,
        extrinsics: Optional[torch.Tensor] = None,
        intrinsics: Optional[torch.Tensor] = None,
        *,
        use_ray_pose: bool = False,
        ref_view_strategy: str = "saddle_balanced",
        export_feat_layers: Optional[Sequence[int]] = None,
        **_unused,
    ) -> Dict[str, torch.Tensor]:
        """Run depth and optionally pose prediction.

        Args:
            image: ``(B, 3, H, W)`` or ``(B, S, 3, H, W)``; a 4-D input is
                treated as a single view per batch item.  ``H`` and ``W``
                must be multiples of ``PATCH_SIZE``.
            extrinsics, intrinsics: Optional camera inputs.  They are encoded
                into a camera token only when both are given and the model
                has a ``cam_enc``.  Expected layout is defined by
                ``CameraEnc`` (not visible here).
            use_ray_pose: Ask a ``DualDPT`` head for its auxiliary ray output
                and derive the pose from it when the head returns
                ``"ray"`` / ``"ray_conf"``.
            ref_view_strategy: Forwarded to the backbone.
            export_feat_layers: Forwarded to the backbone; when truthy the
                returned auxiliary features are added as ``"aux_features"``.
            **_unused: Extra keyword arguments are accepted and dropped.

        Returns:
            Dict holding, in order: ``"extrinsics"`` / ``"intrinsics"`` when
            a pose was predicted, every head output (tensors whose two
            leading dims are ``(B, S)`` are flattened to ``B * S``;
            ``"ray"`` / ``"ray_conf"`` are passed through untouched), and
            ``"aux_features"`` when requested.  Note that ``"aux_features"``
            is a list of tensors, so the return annotation is only a loose
            description.

        Raises:
            AssertionError: If ``image`` has the wrong rank/channel count or
                its spatial size is not a multiple of ``PATCH_SIZE``.
        """
        if image.ndim == 4:
            image = image.unsqueeze(1)  # (B, 1, 3, H, W)
        assert image.ndim == 5 and image.shape[2] == 3, \
            f"image must be (B,3,H,W) or (B,S,3,H,W); got {tuple(image.shape)}"

        B, S, _, H, W = image.shape
        assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \
            f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}"

        # Camera-token preparation (multi-view path).
        cam_token = None
        if extrinsics is not None and intrinsics is not None and self.cam_enc is not None:
            cam_token = self.cam_enc(extrinsics, intrinsics, (H, W))

        # Toggle aux ray output on/off depending on what the caller asked for.
        if isinstance(self.head, DualDPT):
            self.head.enable_aux = bool(use_ray_pose)

        feats, aux_feats = self.backbone.get_intermediate_layers_da3(
            image, self.out_layers, cam_token=cam_token,
            ref_view_strategy=ref_view_strategy,
            export_feat_layers=export_feat_layers,
        )
        head_out = self.head(feats, H=H, W=W, patch_start_idx=0)

        # Pose prediction: prefer the ray-based estimate when it was requested
        # and produced, otherwise fall back to the camera decoder (multi-view
        # inputs only).
        out: Dict[str, torch.Tensor] = {}
        if use_ray_pose and "ray" in head_out and "ray_conf" in head_out:
            out.update(self._pose_from_rays(head_out["ray"], head_out["ray_conf"], H, W))
        elif self.cam_dec is not None and S > 1:
            # Decode the cam-token of the final out_layer into a pose encoding.
            cam_feat = feats[-1][1]  # (B, S, dim_in_to_cam_dec)
            out.update(self._pose_from_cam_token(cam_feat, B, S, H, W))

        out.update(self._flatten_views(head_out, B, S))

        if export_feat_layers:
            out["aux_features"] = self._reshape_aux_features(aux_feats, H, W)
        return out

    def _pose_from_rays(self, ray: torch.Tensor, ray_conf: torch.Tensor, H: int, W: int):
        """Derive ``extrinsics`` / ``intrinsics`` from the head's ray output."""
        extr_c2w, focal, pp = get_extrinsic_from_camray(
            ray, ray_conf, ray.shape[-3], ray.shape[-2],
        )
        # Match the upstream output: w2c, drop the homogeneous row.
        extr_w2c = affine_inverse(extr_c2w)[:, :, :3, :]
        # Build pixel-space intrinsics from the normalised focal/pp output.
        intr = torch.eye(3, device=ray.device, dtype=ray.dtype)
        # ``clone`` materialises the expanded view so it can be written to.
        intr = intr[None, None].expand(extr_c2w.shape[0], extr_c2w.shape[1], 3, 3).clone()
        intr[:, :, 0, 0] = focal[:, :, 0] / 2 * W
        intr[:, :, 1, 1] = focal[:, :, 1] / 2 * H
        intr[:, :, 0, 2] = pp[:, :, 0] * W * 0.5
        intr[:, :, 1, 2] = pp[:, :, 1] * H * 0.5
        return {"extrinsics": extr_w2c, "intrinsics": intr}

    def _pose_from_cam_token(self, cam_feat: torch.Tensor, B: int, S: int, H: int, W: int):
        """Decode a camera-token feature into ``extrinsics`` / ``intrinsics``."""
        pose_enc = self.cam_dec(cam_feat)
        c2w_3x4, intr = pose_encoding_to_extri_intri(pose_enc, (H, W))
        # Match the upstream output convention: w2c (world->camera), 3x4.
        # Append the homogeneous row so the affine inverse can be taken.
        bottom_row = torch.tensor(
            [0, 0, 0, 1], device=c2w_3x4.device, dtype=c2w_3x4.dtype,
        ).view(1, 1, 1, 4).expand(B, S, 1, 4)
        c2w_4x4 = torch.cat([c2w_3x4, bottom_row], dim=-2)
        return {
            "extrinsics": affine_inverse(c2w_4x4)[:, :, :3, :],
            "intrinsics": intr,
        }

    @staticmethod
    def _flatten_views(head_out, B: int, S: int):
        """Merge the leading ``(B, S)`` axes of per-pixel head outputs into ``B * S``."""
        # Flatten the views axis for per-pixel outputs (depth/conf/sky) so the
        # per-image consumer keeps its (B*S, H, W) interface.
        flat = {}
        for key, value in head_out.items():
            if key in ("ray", "ray_conf"):
                # Keep multi-view shape for downstream pose work.
                flat[key] = value
            elif value.ndim >= 3 and value.shape[0] == B and value.shape[1] == S:
                flat[key] = value.reshape(B * S, *value.shape[2:])
            else:
                flat[key] = value
        return flat

    def _reshape_aux_features(self, aux_feats, H: int, W: int):
        """Reshape (B, S, N, C) aux features into (B, S, h_p, w_p, C).

        ``h_p`` / ``w_p`` are ``H`` / ``W`` divided by ``PATCH_SIZE``; every
        feature must have exactly ``h_p * w_p`` tokens.

        Raises:
            AssertionError: If a feature's token count does not match the
                patch grid.
        """
        ph, pw = H // self.PATCH_SIZE, W // self.PATCH_SIZE
        out = []
        for f in aux_feats:
            B, S, N, C = f.shape
            assert N == ph * pw, f"aux feature seq mismatch: {N} != {ph}*{pw}"
            out.append(f.reshape(B, S, ph, pw, C))
        return out
|