diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py index 9b6dace9d..c8738dc8c 100644 --- a/comfy/image_encoders/dino2.py +++ b/comfy/image_encoders/dino2.py @@ -1,7 +1,15 @@ +import math + import torch +import torch.nn.functional as F + from comfy.text_encoders.bert import BertAttention import comfy.model_management from comfy.ldm.modules.attention import optimized_attention_for_device +from comfy.ldm.depth_anything_3.reference_view_selector import ( + select_reference_view, reorder_by_reference, restore_original_order, + THRESH_FOR_REF_SELECTION, +) class Dino2AttentionOutput(torch.nn.Module): @@ -14,13 +22,42 @@ class Dino2AttentionOutput(torch.nn.Module): class Dino2AttentionBlock(torch.nn.Module): - def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations): + def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations, + qk_norm=False): super().__init__() + self.heads = heads + self.head_dim = embed_dim // heads self.attention = BertAttention(embed_dim, heads, dtype, device, operations) self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations) + if qk_norm: + self.q_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device) + self.k_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device) + else: + self.q_norm = None + self.k_norm = None - def forward(self, x, mask, optimized_attention): - return self.output(self.attention(x, mask, optimized_attention)) + def forward(self, x, mask, optimized_attention, pos=None, rope=None): + # Fast path used by the existing CLIP-vision DINOv2 (no DA3 extensions). + if self.q_norm is None and rope is None: + return self.output(self.attention(x, mask, optimized_attention)) + + # DA3 path: do QKV manually so we can apply per-head QK-norm and 2D RoPE. 
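+        # x is (B, N, C); q/k/v are reshaped to (B, heads, N, head_dim), so
+        # the QK-norm LayerNorm acts on head_dim only and RoPE rotates q/k by
+        # each token's 2D grid position before plain SDPA.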
+ attn = self.attention + B, N, C = x.shape + h = self.heads + d = self.head_dim + q = attn.query(x).view(B, N, h, d).transpose(1, 2) + k = attn.key(x).view(B, N, h, d).transpose(1, 2) + v = attn.value(x).view(B, N, h, d).transpose(1, 2) + if self.q_norm is not None: + q = self.q_norm(q) + k = self.k_norm(k) + if rope is not None and pos is not None: + q = rope(q, pos) + k = rope(k, pos) + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) + out = out.transpose(1, 2).reshape(B, N, C) + return self.output(out) class LayerScale(torch.nn.Module): @@ -64,9 +101,11 @@ class SwiGLUFFN(torch.nn.Module): class Dino2Block(torch.nn.Module): - def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn): + def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn, + qk_norm=False): super().__init__() - self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations) + self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations, + qk_norm=qk_norm) self.layer_scale1 = LayerScale(dim, dtype, device, operations) self.layer_scale2 = LayerScale(dim, dtype, device, operations) if use_swiglu_ffn: @@ -76,19 +115,93 @@ class Dino2Block(torch.nn.Module): self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) - def forward(self, x, optimized_attention): - x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention)) + def forward(self, x, optimized_attention, pos=None, rope=None, attn_mask=None): + x = x + self.layer_scale1(self.attention(self.norm1(x), attn_mask, optimized_attention, + pos=pos, rope=rope)) x = x + self.layer_scale2(self.mlp(self.norm2(x))) return x -class Dino2Encoder(torch.nn.Module): - def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn): +# ----------------------------------------------------------------------------- +# 2D Rotary position embedding (DA3 extension) +# ----------------------------------------------------------------------------- + + +class _PositionGetter: + """Cache (h, w) -> flat (y, x) position grid used to feed ``rope``.""" + + def __init__(self): + self._cache: dict = {} + + def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor: + key = (height, width, device) + if key not in self._cache: + y = torch.arange(height, device=device) + x = torch.arange(width, device=device) + self._cache[key] = torch.cartesian_prod(y, x) + cached = self._cache[key] + return cached.view(1, height * width, 2).expand(batch_size, -1, -1).clone() + + +class RotaryPositionEmbedding2D(torch.nn.Module): + """2D RoPE used by DA3-Small/Base. 
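+
+    The head dimension is split in half: the first half is rotated by the
+    token's row (y) index, the second half by its column (x) index, each via
+    the standard 1D rotation ``t * cos + rotate_half(t) * sin``.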
No learnable parameters.""" + + def __init__(self, frequency: float = 100.0): super().__init__() - self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn) - for _ in range(num_layers)]) + self.base_frequency = frequency + self._freq_cache: dict = {} + + def _components(self, dim: int, seq_len: int, device, dtype): + key = (dim, seq_len, device, dtype) + if key not in self._freq_cache: + exp = torch.arange(0, dim, 2, device=device).float() / dim + inv_freq = 1.0 / (self.base_frequency ** exp) + pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + ang = torch.einsum("i,j->ij", pos, inv_freq) + ang = ang.to(dtype) + ang = torch.cat((ang, ang), dim=-1) + self._freq_cache[key] = (ang.cos().to(dtype), ang.sin().to(dtype)) + return self._freq_cache[key] + + @staticmethod + def _rotate(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] + x1, x2 = x[..., : d // 2], x[..., d // 2:] + return torch.cat((-x2, x1), dim=-1) + + def _apply_1d(self, tokens, positions, cos_c, sin_c): + cos = F.embedding(positions, cos_c)[:, None, :, :] + sin = F.embedding(positions, sin_c)[:, None, :, :] + return (tokens * cos) + (self._rotate(tokens) * sin) + + def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + feature_dim = tokens.size(-1) // 2 + max_pos = int(positions.max()) + 1 + cos_c, sin_c = self._components(feature_dim, max_pos, tokens.device, tokens.dtype) + v, h = tokens.chunk(2, dim=-1) + v = self._apply_1d(v, positions[..., 0], cos_c, sin_c) + h = self._apply_1d(h, positions[..., 1], cos_c, sin_c) + return torch.cat((v, h), dim=-1) + + +class Dino2Encoder(torch.nn.Module): + def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn, + qknorm_start: int = -1, rope: "RotaryPositionEmbedding2D | None" = None, + rope_start: int = -1): + super().__init__() + self.layer = torch.nn.ModuleList([ + Dino2Block( + dim, num_heads, layer_norm_eps, dtype, device, operations, + use_swiglu_ffn=use_swiglu_ffn, + qk_norm=(qknorm_start != -1 and i >= qknorm_start), + ) + for i in range(num_layers) + ]) + self.rope = rope + self.rope_start = rope_start def forward(self, x, intermediate_output=None): + # Backward-compat path used by ``ClipVisionModel`` (no DA3 extensions). 
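+        # pos/rope are never threaded through this path, so with the DA3
+        # config keys at their defaults every block takes the fast path in
+        # ``Dino2AttentionBlock.forward``.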
         optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
 
         if intermediate_output is not None:
@@ -121,25 +234,84 @@ class Dino2PatchEmbeddings(torch.nn.Module):
 
 
 class Dino2Embeddings(torch.nn.Module):
-    def __init__(self, dim, dtype, device, operations):
+    def __init__(self, dim, dtype, device, operations,
+                 patch_size: int = 14, image_size: int = 518,
+                 use_mask_token: bool = True,
+                 num_camera_tokens: int = 0):
         super().__init__()
-        patch_size = 14
-        image_size = 518
+        self.patch_size = patch_size
+        self.image_size = image_size
         self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
         self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
         self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
-        self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
+        if use_mask_token:
+            self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
+        else:
+            self.mask_token = None
+        if num_camera_tokens > 0:
+            # DA3 stores (ref_token, src_token) pairs that get injected at the
+            # alt-attn boundary; see ``Dinov2Model._inject_camera_token``.
+            self.camera_token = torch.nn.Parameter(torch.empty(1, num_camera_tokens, dim, dtype=dtype, device=device))
+        else:
+            self.camera_token = None
+
+    def _interpolate_pos_encoding(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
+        previous_dtype = x.dtype
+        npatch = x.shape[1] - 1
+        N = self.position_embeddings.shape[1] - 1
+        pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype).float()
+        if npatch == N and w == h:
+            # Cast back so the fast path does not silently promote fp16
+            # activations to fp32 (the common 518x518 square-input case).
+            return pos_embed.to(previous_dtype)
+        class_pos_embed = pos_embed[:, 0]
+        patch_pos_embed = pos_embed[:, 1:]
+        dim = x.shape[-1]
+        ph = h // self.patch_size  # patch grid height
+        pw = w // self.patch_size  # patch grid width
+        M = int(math.sqrt(N))
+        assert N == M * M
+        # Historical 0.1 offset preserves bicubic resample compatibility with
+        # the original DINOv2 release; see the upstream PR for context.
+        # ``scale_factor`` is interpreted as (height_scale, width_scale) by
+        # ``F.interpolate`` so we must put the height scale FIRST. Earlier
+        # revisions of this function had it swapped, which only worked for
+        # square inputs (e.g. CLIP-vision square crops); non-square inputs
+        # like DA3-Small / DA3-Base multi-view paths exposed the bug.
+        sh = float(ph + 0.1) / M
+        sw = float(pw + 0.1) / M
+        patch_pos_embed = F.interpolate(
+            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+            scale_factor=(sh, sw), mode="bicubic", antialias=False,
+        )
+        assert (ph, pw) == patch_pos_embed.shape[-2:]
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
 
     def forward(self, pixel_values):
+        _, _, H, W = pixel_values.shape
         x = self.patch_embeddings(pixel_values)
         # TODO: mask_token?
         x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
-        x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
+        x = x + self._interpolate_pos_encoding(x, H, W)
         return x
 
 
 class Dinov2Model(torch.nn.Module):
+    """DINOv2 vision backbone.
+
+    Supports two operating modes:
+
+    * **CLIP-vision DINOv2** (default): vanilla DINOv2-ViT used for
+      ``ClipVisionModel`` and SigLIP-style image encoding.
+ * **Depth Anything 3** extensions (opt-in via config keys): 2D RoPE, + QK-norm, alternating local/global attention, camera-token injection, + ``cat_token`` output and multi-layer feature extraction. These are + enabled when the corresponding fields (``alt_start``, ``qknorm_start``, + ``rope_start``, ``cat_token``) are set in ``config_dict``. When all of + them are at their disabled defaults this module behaves identically to + the historical ``Dinov2Model``. + """ + def __init__(self, config_dict, dtype, device, operations): super().__init__() num_layers = config_dict["num_hidden_layers"] @@ -147,14 +319,209 @@ class Dinov2Model(torch.nn.Module): heads = config_dict["num_attention_heads"] layer_norm_eps = config_dict["layer_norm_eps"] use_swiglu_ffn = config_dict["use_swiglu_ffn"] + patch_size = config_dict.get("patch_size", 14) + image_size = config_dict.get("image_size", 518) + use_mask_token = config_dict.get("use_mask_token", True) - self.embeddings = Dino2Embeddings(dim, dtype, device, operations) - self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn) + # DA3 extensions (all default to disabled). + self.alt_start = config_dict.get("alt_start", -1) + self.qknorm_start = config_dict.get("qknorm_start", -1) + self.rope_start = config_dict.get("rope_start", -1) + self.cat_token = config_dict.get("cat_token", False) + rope_freq = config_dict.get("rope_freq", 100.0) + + self.embed_dim = dim + self.patch_size = patch_size + self.num_register_tokens = 0 + self.patch_start_idx = 1 + + if self.rope_start != -1 and rope_freq > 0: + self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) + self._position_getter = _PositionGetter() + else: + self.rope = None + self._position_getter = None + + # camera_token shape: (1, 2, dim) -> (ref_token, src_token). + num_cam_tokens = 2 if self.alt_start != -1 else 0 + + self.embeddings = Dino2Embeddings( + dim, dtype, device, operations, + patch_size=patch_size, image_size=image_size, + use_mask_token=use_mask_token, num_camera_tokens=num_cam_tokens, + ) + self.encoder = Dino2Encoder( + dim, heads, layer_norm_eps, num_layers, dtype, device, operations, + use_swiglu_ffn=use_swiglu_ffn, + qknorm_start=self.qknorm_start, + rope=self.rope, rope_start=self.rope_start, + ) self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) + # ------------------------------------------------------------------ + # CLIP-vision-style forward (no DA3 extensions, no multi-layer output). + # Kept for backward compatibility with ``ClipVisionModel.encode_image``. + # ------------------------------------------------------------------ def forward(self, pixel_values, attention_mask=None, intermediate_output=None): x = self.embeddings(pixel_values) x, i = self.encoder(x, intermediate_output=intermediate_output) x = self.layernorm(x) pooled_output = x[:, 0, :] return x, i, pooled_output, None + + # ------------------------------------------------------------------ + # Depth Anything 3 forward + # ------------------------------------------------------------------ + def _prepare_rope_positions(self, B, S, H, W, device): + if self.rope is None: + return None, None + ph, pw = H // self.patch_size, W // self.patch_size + pos = self._position_getter(B * S, ph, pw, device=device) + # Shift so the cls/cam token at position 0 is reserved for "no diff". 
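+        # (after the shift no patch sits at position 0, so the all-zero
+        # cls/cam position cannot collide with a real grid cell).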
+        pos = pos + 1
+        cls_pos = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype)
+        # Per-view local: real grid positions for patches, 0 for cls token.
+        pos_local = torch.cat([cls_pos, pos], dim=1)
+        # Global (across views): every patch gets the constant position 1, so
+        # RoPE adds no relative offset between patches; cls stays at 0.
+        pos_global = torch.cat([cls_pos, torch.zeros_like(pos) + 1], dim=1)
+        return pos_local, pos_global
+
+    def _inject_camera_token(self, x: torch.Tensor, B: int, S: int,
+                             cam_token: "torch.Tensor | None") -> torch.Tensor:
+        # x: (B, S, N, C). Replace the token at index 0 with the camera token.
+        if cam_token is not None:
+            inj = cam_token
+        else:
+            ct = comfy.model_management.cast_to_device(self.embeddings.camera_token, x.device, x.dtype)
+            ref_token = ct[:, :1].expand(B, -1, -1)
+            src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1)
+            inj = torch.cat([ref_token, src_token], dim=1)
+        x = x.clone()
+        x[:, :, 0] = inj
+        return x
+
+    def get_intermediate_layers(self, pixel_values, out_layers, cam_token=None,
+                                ref_view_strategy="saddle_balanced",
+                                export_feat_layers=None):
+        """Multi-layer DINOv2 feature extraction used by Depth Anything 3.
+
+        Args:
+            pixel_values: ``(B, S, 3, H, W)`` views or ``(B, 3, H, W)``.
+            out_layers: indices into ``self.encoder.layer``.
+            cam_token: optional ``(B, S, dim)`` camera token to inject at
+                ``alt_start``. If ``None`` and the model has its own
+                ``camera_token`` parameter, that is used.
+            ref_view_strategy: when ``S >= 3`` and ``cam_token is None``,
+                pick a reference view via this strategy and move it to
+                position 0 right before the first alt-attention block.
+                The original view order is restored on the way out.
+            export_feat_layers: optional iterable of layer indices whose
+                block outputs are also returned as auxiliary
+                features (``(B, S, N_patch, C)`` after final norm). Used
+                by the multi-view path to expose intermediate features
+                to the nested-architecture wrapper.
+
+        Returns:
+            ``(layer_outputs, aux_outputs)`` where ``layer_outputs`` is a
+            list of ``(patch_tokens, cls_or_cam_token)`` tuples (one per
+            ``out_layers`` entry) and ``aux_outputs`` is a list of
+            ``(B, S, N_patch, C)`` features for ``export_feat_layers``
+            (empty list when not requested).
+        """
+        if pixel_values.ndim == 4:
+            pixel_values = pixel_values.unsqueeze(1)
+        assert pixel_values.ndim == 5 and pixel_values.shape[2] == 3, \
+            f"expected (B,3,H,W) or (B,S,3,H,W); got {tuple(pixel_values.shape)}"
+        B, S, _, H, W = pixel_values.shape
+
+        # Patch + cls + (interpolated) pos embed for each view.
+        x = pixel_values.reshape(B * S, 3, H, W)
+        x = self.embeddings(x)  # (B*S, 1+N, C)
+        x = x.reshape(B, S, x.shape[-2], x.shape[-1])  # (B, S, 1+N, C)
+
+        pos_local, pos_global = self._prepare_rope_positions(B, S, H, W, x.device)
+        # ``optimized_attention`` is only used by blocks without QK-norm/RoPE
+        # (vanilla DINOv2 path); blocks with either enabled fall back to SDPA.
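+        # Layer schedule: blocks before ``alt_start`` (and even-indexed ones
+        # after it) run per-view local attention; odd-indexed blocks from
+        # ``alt_start`` onwards attend globally across all views.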
+ optimized_attention = optimized_attention_for_device(x.device, False, small_input=True) + + out_set = set(out_layers) + export_set = set(export_feat_layers) if export_feat_layers else set() + outputs: list[torch.Tensor] = [] + aux_outputs: list[torch.Tensor] = [] + local_x = x + b_idx = None + + + for i, blk in enumerate(self.encoder.layer): + apply_rope = self.rope is not None and i >= self.rope_start + block_rope = self.rope if apply_rope else None + l_pos = pos_local if apply_rope else None + g_pos = pos_global if apply_rope else None + + # Reference-view selection threshold: matches the upstream constant + # ``THRESH_FOR_REF_SELECTION = 3``. Skipped when a user-supplied + # cam_token is provided (camera info already pins the geometry). + if (self.alt_start != -1 and i == self.alt_start - 1 + and S >= THRESH_FOR_REF_SELECTION and cam_token is None): + b_idx = select_reference_view(x, strategy=ref_view_strategy) + x = reorder_by_reference(x, b_idx) + local_x = reorder_by_reference(local_x, b_idx) + + if self.alt_start != -1 and i == self.alt_start: + x = self._inject_camera_token(x, B, S, cam_token) + + if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1): + # Global attention across views: flatten S into the seq dim. + t = x.reshape(B, S * x.shape[-2], x.shape[-1]) + p = g_pos.reshape(B, S * g_pos.shape[-2], g_pos.shape[-1]) if g_pos is not None else None + t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope) + x = t.reshape(B, S, x.shape[-2], x.shape[-1]) + else: + # Per-view local attention. + t = x.reshape(B * S, x.shape[-2], x.shape[-1]) + p = l_pos.reshape(B * S, l_pos.shape[-2], l_pos.shape[-1]) if l_pos is not None else None + t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope) + x = t.reshape(B, S, x.shape[-2], x.shape[-1]) + local_x = x + + if i in out_set: + if self.cat_token: + out_x = torch.cat([local_x, x], dim=-1) + else: + out_x = x + # Restore original view order on the way out so heads see views + # in the user's expected order. + if b_idx is not None and self.alt_start != -1: + out_x = restore_original_order(out_x, b_idx) + outputs.append(out_x) + + if i in export_set: + aux = x + if b_idx is not None and self.alt_start != -1: + aux = restore_original_order(aux, b_idx) + aux_outputs.append(aux) + + # Apply final norm. When ``cat_token`` is set, only the right half + # ("global" features) is normalised; the left half is left as-is to + # match the upstream DA3 head signature. + normed: list[torch.Tensor] = [] + cls_tokens: list[torch.Tensor] = [] + for out_x in outputs: + cls_tokens.append(out_x[:, :, 0]) + if out_x.shape[-1] == self.embed_dim: + normed.append(self.layernorm(out_x)) + elif out_x.shape[-1] == self.embed_dim * 2: + left = out_x[..., :self.embed_dim] + right = self.layernorm(out_x[..., self.embed_dim:]) + normed.append(torch.cat([left, right], dim=-1)) + else: + raise ValueError(f"Unexpected token width: {out_x.shape[-1]}") + + # Drop cls/cam token from the patch sequence. + normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed] + + # Final layernorm + drop cls token from auxiliary features too. 
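+        # (callers such as ``DepthAnything3Net._reshape_aux_features`` then
+        # unflatten the patch axis back into an (h_p, w_p) grid).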
+ aux_normed = [self.layernorm(o)[..., 1 + self.num_register_tokens:, :] + for o in aux_outputs] + return list(zip(normed, cls_tokens)), aux_normed diff --git a/comfy/ldm/depth_anything_3/__init__.py b/comfy/ldm/depth_anything_3/__init__.py new file mode 100644 index 000000000..0bc31314e --- /dev/null +++ b/comfy/ldm/depth_anything_3/__init__.py @@ -0,0 +1,7 @@ +# Depth Anything 3 - native ComfyUI port (Apache-2.0 monocular variants only). +# +# Supported variants: +# DA3-Small, DA3-Base (vits/vitb backbone, DualDPT head) +# DA3Mono-Large, DA3Metric-Large (vitl backbone, DPT head + sky mask) +# +# Original repo: https://github.com/ByteDance-Seed/Depth-Anything-3 diff --git a/comfy/ldm/depth_anything_3/camera.py b/comfy/ldm/depth_anything_3/camera.py new file mode 100644 index 000000000..045d88c60 --- /dev/null +++ b/comfy/ldm/depth_anything_3/camera.py @@ -0,0 +1,214 @@ +"""Camera-token encoder and decoder for Depth Anything 3. + +* :class:`CameraEnc` takes per-view extrinsics + intrinsics and produces a + per-view camera token that gets injected at the alt-attention boundary + in the DINOv2 backbone (block ``alt_start``). +* :class:`CameraDec` takes the final-layer camera token output by the + backbone and predicts a 9-D pose encoding (translation, quaternion, + field-of-view). + +The module/parameter names match the upstream ``cam_enc.py``/``cam_dec.py`` +so HF safetensors load directly with no key remapping (the upstream uses +fused QKV linears, which we replicate here). +""" + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .transform import affine_inverse, extri_intri_to_pose_encoding + + +# ----------------------------------------------------------------------------- +# Building blocks (mirror ``depth_anything_3.model.utils.{attention,block}``) +# ----------------------------------------------------------------------------- + + +class _Mlp(nn.Module): + """Standard 2-layer MLP with GELU. Matches upstream ``utils.attention.Mlp``.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, + *, device=None, dtype=None, operations=None): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = operations.Linear(in_features, hidden_features, bias=True, + device=device, dtype=dtype) + self.fc2 = operations.Linear(hidden_features, out_features, bias=True, + device=device, dtype=dtype) + + def forward(self, x): + return self.fc2(F.gelu(self.fc1(x))) + + +class _LayerScale(nn.Module): + """Per-channel learnable scaling. Matches upstream ``LayerScale``.""" + + def __init__(self, dim, *, device=None, dtype=None): + super().__init__() + self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + + def forward(self, x): + return x * self.gamma.to(dtype=x.dtype, device=x.device) + + +class _Attention(nn.Module): + """Self-attention with fused QKV projection. + + Mirrors upstream ``utils.attention.Attention``; layout matches the + HF safetensors (``attn.qkv.{weight,bias}`` and ``attn.proj.{weight,bias}``). 
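+
+    Plain bidirectional SDPA with no mask, RoPE or QK-norm; in
+    :class:`CameraEnc` the sequence is just the ``S`` per-view tokens.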
+ """ + + def __init__(self, dim, num_heads, + *, device=None, dtype=None, operations=None): + super().__init__() + assert dim % num_heads == 0 + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = operations.Linear(dim, dim * 3, bias=True, + device=device, dtype=dtype) + self.proj = operations.Linear(dim, dim, bias=True, + device=device, dtype=dtype) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) + qkv = qkv.permute(2, 0, 3, 1, 4) # 3, B, h, N, d + q, k, v = qkv.unbind(0) + out = F.scaled_dot_product_attention(q, k, v) + out = out.transpose(1, 2).reshape(B, N, C) + return self.proj(out) + + +class _Block(nn.Module): + """Pre-norm transformer block with LayerScale. + + Used by :class:`CameraEnc`. Layout follows upstream ``utils.block.Block``. + """ + + def __init__(self, dim, num_heads, mlp_ratio=4, init_values=0.01, + *, device=None, dtype=None, operations=None): + super().__init__() + self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype) + self.attn = _Attention(dim, num_heads, + device=device, dtype=dtype, operations=operations) + self.ls1 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity() + self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype) + self.mlp = _Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), + device=device, dtype=dtype, operations=operations) + self.ls2 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity() + + def forward(self, x): + x = x + self.ls1(self.attn(self.norm1(x))) + x = x + self.ls2(self.mlp(self.norm2(x))) + return x + + +# ----------------------------------------------------------------------------- +# Camera encoder +# ----------------------------------------------------------------------------- + + +class CameraEnc(nn.Module): + """Encode per-view (extrinsics, intrinsics) into a camera token. + + Maps a 9-D pose-encoding vector through a small MLP up to the backbone's + ``embed_dim``, then runs ``trunk_depth`` transformer blocks. The output + has shape ``(B, S, embed_dim)`` and is injected at block ``alt_start`` + of the DINOv2 backbone in place of the cls token. + + Parameters mirror the upstream ``cam_enc.py`` so HF weights load directly. 
+ """ + + def __init__( + self, + dim_out: int = 1024, + dim_in: int = 9, + trunk_depth: int = 4, + target_dim: int = 9, + num_heads: int = 16, + mlp_ratio: int = 4, + init_values: float = 0.01, + *, + device=None, dtype=None, operations=None, + **_kwargs, + ): + super().__init__() + self.target_dim = target_dim + self.trunk_depth = trunk_depth + self.trunk = nn.Sequential(*[ + _Block(dim_out, num_heads=num_heads, mlp_ratio=mlp_ratio, + init_values=init_values, + device=device, dtype=dtype, operations=operations) + for _ in range(trunk_depth) + ]) + self.token_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype) + self.trunk_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype) + self.pose_branch = _Mlp( + in_features=dim_in, + hidden_features=dim_out // 2, + out_features=dim_out, + device=device, dtype=dtype, operations=operations, + ) + + def forward(self, extrinsics: torch.Tensor, intrinsics: torch.Tensor, + image_size_hw) -> torch.Tensor: + """Encode camera parameters into ``(B, S, dim_out)`` tokens.""" + c2ws = affine_inverse(extrinsics) + pose_encoding = extri_intri_to_pose_encoding(c2ws, intrinsics, image_size_hw) + tokens = self.pose_branch(pose_encoding.to(self.pose_branch.fc1.weight.dtype)) + tokens = self.token_norm(tokens) + tokens = self.trunk(tokens) + tokens = self.trunk_norm(tokens) + return tokens + + +# ----------------------------------------------------------------------------- +# Camera decoder +# ----------------------------------------------------------------------------- + + +class CameraDec(nn.Module): + """Decode the final cam token into a 9-D pose encoding. + + Output layout: ``[T(3), quat_xyzw(4), fov_h, fov_w]``. The translation is + always predicted by the network; the quaternion and FoV can either be + predicted or supplied via ``camera_encoding`` (used at training time + when GT cameras are available -- not exercised at inference here). + + Parameters mirror the upstream ``cam_dec.py`` so HF weights load directly. + """ + + def __init__(self, dim_in: int = 1536, + *, device=None, dtype=None, operations=None, **_kwargs): + super().__init__() + d = dim_in + self.backbone = nn.Sequential( + operations.Linear(d, d, device=device, dtype=dtype), + nn.ReLU(), + operations.Linear(d, d, device=device, dtype=dtype), + nn.ReLU(), + ) + self.fc_t = operations.Linear(d, 3, device=device, dtype=dtype) + self.fc_qvec = operations.Linear(d, 4, device=device, dtype=dtype) + self.fc_fov = nn.Sequential( + operations.Linear(d, 2, device=device, dtype=dtype), + nn.ReLU(), + ) + + def forward(self, feat: torch.Tensor, + camera_encoding: "torch.Tensor | None" = None) -> torch.Tensor: + """Decode ``(B, N, dim_in)`` cam tokens into ``(B, N, 9)`` pose enc.""" + B, N = feat.shape[:2] + feat = feat.reshape(B * N, -1) + feat = self.backbone(feat) + out_t = self.fc_t(feat.float()).reshape(B, N, 3) + if camera_encoding is None: + out_qvec = self.fc_qvec(feat.float()).reshape(B, N, 4) + out_fov = self.fc_fov(feat.float()).reshape(B, N, 2) + else: + out_qvec = camera_encoding[..., 3:7] + out_fov = camera_encoding[..., -2:] + return torch.cat([out_t, out_qvec, out_fov], dim=-1) diff --git a/comfy/ldm/depth_anything_3/dpt.py b/comfy/ldm/depth_anything_3/dpt.py new file mode 100644 index 000000000..b186b7167 --- /dev/null +++ b/comfy/ldm/depth_anything_3/dpt.py @@ -0,0 +1,549 @@ +# DPT / DualDPT heads for Depth Anything 3. 
+# +# Ported from: +# src/depth_anything_3/model/dpt.py (DPT - single main head + sky head) +# src/depth_anything_3/model/dualdpt.py (DualDPT - depth + auxiliary "ray" head) +# +# In the monocular path we always discard the auxiliary "ray" output of +# DualDPT. The auxiliary branch is still constructed so that DA3 HF weights +# load cleanly without missing-key warnings. + +from __future__ import annotations + +from typing import List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# ----------------------------------------------------------------------------- +# Helpers (matching upstream head_utils.py) +# ----------------------------------------------------------------------------- + + +class Permute(nn.Module): + def __init__(self, dims: Tuple[int, ...]): + super().__init__() + self.dims = dims + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute(*self.dims) + + +def _custom_interpolate( + x: torch.Tensor, + size: Optional[Tuple[int, int]] = None, + scale_factor: Optional[float] = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + if size is None: + assert scale_factor is not None + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + INT_MAX = 1610612736 + total = size[0] * size[1] * x.shape[0] * x.shape[1] + if total > INT_MAX: + chunks = torch.chunk(x, chunks=(total // INT_MAX) + 1, dim=0) + outs = [F.interpolate(c, size=size, mode=mode, align_corners=align_corners) for c in chunks] + return torch.cat(outs, dim=0).contiguous() + return F.interpolate(x, size=size, mode=mode, align_corners=align_corners) + + +def _create_uv_grid(width: int, height: int, aspect_ratio: float, + dtype, device) -> torch.Tensor: + """Normalised UV grid spanning (-x_span, -y_span)..(x_span, y_span).""" + diag_factor = (aspect_ratio ** 2 + 1.0) ** 0.5 + span_x = aspect_ratio / diag_factor + span_y = 1.0 / diag_factor + left_x = -span_x * (width - 1) / width + right_x = span_x * (width - 1) / width + top_y = -span_y * (height - 1) / height + bottom_y = span_y * (height - 1) / height + x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) + y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) + uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") + return torch.stack((uu, vv), dim=-1) # (H, W, 2) + + +def _make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100.0) -> torch.Tensor: + omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device) + omega = 1.0 / omega_0 ** (omega / (embed_dim / 2.0)) + pos = pos.reshape(-1) + out = torch.einsum("m,d->md", pos, omega) + return torch.cat([out.sin(), out.cos()], dim=1).float() + + +def _position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, + omega_0: float = 100.0) -> torch.Tensor: + H, W, _ = pos_grid.shape + pos_flat = pos_grid.reshape(-1, 2) + emb_x = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) + emb_y = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) + emb = torch.cat([emb_x, emb_y], dim=-1) + return emb.view(H, W, embed_dim) + + +def _add_pos_embed(x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: + """Stateless UV positional embedding added to a feature map (B, C, h, w).""" + pw, ph = x.shape[-1], x.shape[-2] + pe = _create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device) + pe = _position_grid_to_embed(pe, x.shape[1]) * ratio + 
pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1).to(dtype=x.dtype) + return x + pe + + +def _apply_activation(x: torch.Tensor, activation: str) -> torch.Tensor: + act = (activation or "linear").lower() + if act == "exp": + return torch.exp(x) + if act == "expp1": + return torch.exp(x) + 1 + if act == "expm1": + return torch.expm1(x) + if act == "relu": + return torch.relu(x) + if act == "sigmoid": + return torch.sigmoid(x) + if act == "softplus": + return F.softplus(x) + if act == "tanh": + return torch.tanh(x) + return x + + +# ----------------------------------------------------------------------------- +# Fusion building blocks +# ----------------------------------------------------------------------------- + + +class ResidualConvUnit(nn.Module): + def __init__(self, features: int, + device=None, dtype=None, operations=None): + super().__init__() + self.conv1 = operations.Conv2d(features, features, 3, 1, 1, bias=True, + device=device, dtype=dtype) + self.conv2 = operations.Conv2d(features, features, 3, 1, 1, bias=True, + device=device, dtype=dtype) + self.activation = nn.ReLU(inplace=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.activation(x) + out = self.conv1(out) + out = self.activation(out) + out = self.conv2(out) + return out + x + + +class FeatureFusionBlock(nn.Module): + def __init__(self, features: int, has_residual: bool = True, + align_corners: bool = True, + device=None, dtype=None, operations=None): + super().__init__() + self.align_corners = align_corners + self.has_residual = has_residual + if has_residual: + self.resConfUnit1 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations) + else: + self.resConfUnit1 = None + self.resConfUnit2 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations) + self.out_conv = operations.Conv2d(features, features, 1, 1, 0, bias=True, + device=device, dtype=dtype) + + def forward(self, *xs: torch.Tensor, size: Optional[Tuple[int, int]] = None) -> torch.Tensor: + y = xs[0] + if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None: + y = y + self.resConfUnit1(xs[1]) + y = self.resConfUnit2(y) + if size is None: + up_kwargs = {"scale_factor": 2.0} + else: + up_kwargs = {"size": size} + y = _custom_interpolate(y, **up_kwargs, mode="bilinear", + align_corners=self.align_corners) + y = self.out_conv(y) + return y + + +class _Scratch(nn.Module): + """Container that mirrors upstream ``scratch`` attribute layout.""" + + +def _make_scratch(in_shape: List[int], out_shape: int, + device=None, dtype=None, operations=None) -> _Scratch: + scratch = _Scratch() + scratch.layer1_rn = operations.Conv2d(in_shape[0], out_shape, 3, 1, 1, bias=False, + device=device, dtype=dtype) + scratch.layer2_rn = operations.Conv2d(in_shape[1], out_shape, 3, 1, 1, bias=False, + device=device, dtype=dtype) + scratch.layer3_rn = operations.Conv2d(in_shape[2], out_shape, 3, 1, 1, bias=False, + device=device, dtype=dtype) + scratch.layer4_rn = operations.Conv2d(in_shape[3], out_shape, 3, 1, 1, bias=False, + device=device, dtype=dtype) + return scratch + + +def _make_fusion_block(features: int, has_residual: bool = True, + device=None, dtype=None, operations=None) -> FeatureFusionBlock: + return FeatureFusionBlock(features, has_residual=has_residual, + align_corners=True, + device=device, dtype=dtype, operations=operations) + + +# ----------------------------------------------------------------------------- +# DPT (single head + optional sky head) -- used by DA3Mono/Metric +# 
----------------------------------------------------------------------------- + + +class DPT(nn.Module): + """Single-head DPT used by DA3Mono-Large and DA3Metric-Large.""" + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 1, + activation: str = "exp", + conf_activation: str = "expp1", + features: int = 256, + out_channels: Sequence[int] = (256, 512, 1024, 1024), + pos_embed: bool = False, + down_ratio: int = 1, + head_name: str = "depth", + use_sky_head: bool = True, + sky_name: str = "sky", + sky_activation: str = "relu", + norm_type: str = "idt", + device=None, dtype=None, operations=None, + ): + super().__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.down_ratio = down_ratio + self.head_main = head_name + self.sky_name = sky_name + self.out_dim = output_dim + self.has_conf = output_dim > 1 + self.use_sky_head = use_sky_head + self.sky_activation = sky_activation + self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3) + + if norm_type == "layer": + self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype) + else: + self.norm = nn.Identity() + + out_channels = list(out_channels) + self.projects = nn.ModuleList([ + operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, + device=device, dtype=dtype) + for oc in out_channels + ]) + self.resize_layers = nn.ModuleList([ + operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, + device=device, dtype=dtype), + operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, + device=device, dtype=dtype), + nn.Identity(), + operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, + device=device, dtype=dtype), + ]) + + self.scratch = _make_scratch(out_channels, features, + device=device, dtype=dtype, operations=operations) + self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, + device=device, dtype=dtype, operations=operations) + + head_features_1 = features + head_features_2 = 32 + self.scratch.output_conv1 = operations.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1, + device=device, dtype=dtype, + ) + self.scratch.output_conv2 = nn.Sequential( + operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, + device=device, dtype=dtype), + nn.ReLU(inplace=False), + operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, + device=device, dtype=dtype), + ) + + if self.use_sky_head: + self.scratch.sky_output_conv2 = nn.Sequential( + operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, + device=device, dtype=dtype), + nn.ReLU(inplace=False), + operations.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0, + device=device, dtype=dtype), + ) + + def forward(self, feats: List[torch.Tensor], H: int, W: int, + patch_start_idx: int = 0, **_kwargs) -> dict: + # feats[i][0] is the patch-token tensor with shape (B, S, N_patch, C) + B, S, N, C = feats[0][0].shape + feats_flat = [feat[0].reshape(B * S, 
N, C) for feat in feats] + + ph, pw = H // self.patch_size, W // self.patch_size + resized = [] + for stage_idx, take_idx in enumerate(self.intermediate_layer_idx): + x = feats_flat[take_idx][:, patch_start_idx:] + x = self.norm(x) + x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw) + x = self.projects[stage_idx](x) + if self.pos_embed: + x = _add_pos_embed(x, W, H) + x = self.resize_layers[stage_idx](x) + resized.append(x) + + l1_rn = self.scratch.layer1_rn(resized[0]) + l2_rn = self.scratch.layer2_rn(resized[1]) + l3_rn = self.scratch.layer3_rn(resized[2]) + l4_rn = self.scratch.layer4_rn(resized[3]) + + out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:]) + out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:]) + out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:]) + out = self.scratch.refinenet1(out, l1_rn) + + h_out = int(ph * self.patch_size / self.down_ratio) + w_out = int(pw * self.patch_size / self.down_ratio) + + fused = self.scratch.output_conv1(out) + fused = _custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True) + if self.pos_embed: + fused = _add_pos_embed(fused, W, H) + feat = fused + + main_logits = self.scratch.output_conv2(feat) + outs = {} + if self.has_conf: + fmap = main_logits.permute(0, 2, 3, 1) + pred = _apply_activation(fmap[..., :-1], self.activation) + conf = _apply_activation(fmap[..., -1], self.conf_activation) + outs[self.head_main] = pred.squeeze(-1).view(B, S, *pred.shape[1:-1]) + outs[f"{self.head_main}_conf"] = conf.view(B, S, *conf.shape[1:]) + else: + pred = _apply_activation(main_logits, self.activation) + outs[self.head_main] = pred.squeeze(1).view(B, S, *pred.shape[2:]) + + if self.use_sky_head: + sky_logits = self.scratch.sky_output_conv2(feat) + if self.sky_activation.lower() == "sigmoid": + sky = torch.sigmoid(sky_logits) + elif self.sky_activation.lower() == "relu": + sky = F.relu(sky_logits) + else: + sky = sky_logits + outs[self.sky_name] = sky.squeeze(1).view(B, S, *sky.shape[2:]) + + return outs + + +# ----------------------------------------------------------------------------- +# DualDPT (depth + auxiliary "ray" head) -- used by DA3-Small / DA3-Base +# ----------------------------------------------------------------------------- + + +class DualDPT(nn.Module): + """Two-head DPT used by DA3-Small / DA3-Base. + + The auxiliary "ray" head is constructed so that HF state-dict keys load + cleanly. It is only executed when :attr:`enable_aux` is set on the + instance (typically by ``DepthAnything3Net`` when running multi-view + with ``use_ray_pose=True``); otherwise the monocular path skips it for + speed and the auxiliary submodules sit idle. 
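+
+    Output channels: the main head emits ``output_dim`` maps (depth plus a
+    confidence channel when ``output_dim == 2``); each aux level ends in a
+    7-channel projection read as ``[ray(6), ray_conf(1)]``.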
+ """ + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 2, + activation: str = "exp", + conf_activation: str = "expp1", + features: int = 256, + out_channels: Sequence[int] = (256, 512, 1024, 1024), + pos_embed: bool = True, + down_ratio: int = 1, + aux_pyramid_levels: int = 4, + aux_out1_conv_num: int = 5, + head_names: Tuple[str, str] = ("depth", "ray"), + device=None, dtype=None, operations=None, + ): + super().__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.down_ratio = down_ratio + self.aux_levels = aux_pyramid_levels + self.aux_out1_conv_num = aux_out1_conv_num + self.head_main, self.head_aux = head_names + self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3) + # Toggle the auxiliary ray branch at runtime. Default off (mono path). + # ``DepthAnything3Net`` flips this on when running multi-view + ray-pose. + self.enable_aux: bool = False + + self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype) + out_channels = list(out_channels) + self.projects = nn.ModuleList([ + operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, + device=device, dtype=dtype) + for oc in out_channels + ]) + self.resize_layers = nn.ModuleList([ + operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, + device=device, dtype=dtype), + operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, + device=device, dtype=dtype), + nn.Identity(), + operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, + device=device, dtype=dtype), + ]) + + self.scratch = _make_scratch(out_channels, features, + device=device, dtype=dtype, operations=operations) + # Main fusion chain + self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, + device=device, dtype=dtype, operations=operations) + # Auxiliary fusion chain (separate copies) + self.scratch.refinenet1_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet2_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet3_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False, + device=device, dtype=dtype, operations=operations) + + head_features_1 = features + head_features_2 = 32 + + # Main head neck + final projection + self.scratch.output_conv1 = operations.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1, + device=device, dtype=dtype, + ) + self.scratch.output_conv2 = nn.Sequential( + operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, + device=device, dtype=dtype), + nn.ReLU(inplace=False), + operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, + device=device, dtype=dtype), + ) + + # Aux pre-head per level (multi-level pyramid) + self.scratch.output_conv1_aux = nn.ModuleList([ + self._make_aux_out1_block(head_features_1, device=device, 
dtype=dtype, operations=operations) + for _ in range(self.aux_levels) + ]) + + # Aux final projection per level (includes LayerNorm permute path). + ln_seq = [Permute((0, 2, 3, 1)), + operations.LayerNorm(head_features_2, device=device, dtype=dtype), + Permute((0, 3, 1, 2))] + self.scratch.output_conv2_aux = nn.ModuleList([ + nn.Sequential( + operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, + device=device, dtype=dtype), + *ln_seq, + nn.ReLU(inplace=False), + operations.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0, + device=device, dtype=dtype), + ) + for _ in range(self.aux_levels) + ]) + + @staticmethod + def _make_aux_out1_block(in_ch: int, *, device=None, dtype=None, operations=None) -> nn.Sequential: + # aux_out1_conv_num=5 in all Apache-2.0 variants. + return nn.Sequential( + operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype), + operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype), + operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype), + operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype), + operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype), + ) + + def forward(self, feats: List[torch.Tensor], H: int, W: int, + patch_start_idx: int = 0, **_kwargs) -> dict: + B, S, N, C = feats[0][0].shape + feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats] + + ph, pw = H // self.patch_size, W // self.patch_size + resized = [] + for stage_idx, take_idx in enumerate(self.intermediate_layer_idx): + x = feats_flat[take_idx][:, patch_start_idx:] + x = self.norm(x) + x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw) + x = self.projects[stage_idx](x) + if self.pos_embed: + x = _add_pos_embed(x, W, H) + x = self.resize_layers[stage_idx](x) + resized.append(x) + + l1_rn = self.scratch.layer1_rn(resized[0]) + l2_rn = self.scratch.layer2_rn(resized[1]) + l3_rn = self.scratch.layer3_rn(resized[2]) + l4_rn = self.scratch.layer4_rn(resized[3]) + + # Main pyramid (output_conv1 is applied inside the upstream `_fuse`, + # before interpolation -- replicate that order here). 
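+        # Main and aux chains are interleaved so each aux level consumes the
+        # same refined skip feature (l3_rn..l1_rn) at the same resolution as
+        # the main chain.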
+ m = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:]) + if self.enable_aux: + a4 = self.scratch.refinenet4_aux(l4_rn, size=l3_rn.shape[2:]) + aux_pyr = [a4] + m = self.scratch.refinenet3(m, l3_rn, size=l2_rn.shape[2:]) + if self.enable_aux: + aux_pyr.append(self.scratch.refinenet3_aux(aux_pyr[-1], l3_rn, size=l2_rn.shape[2:])) + m = self.scratch.refinenet2(m, l2_rn, size=l1_rn.shape[2:]) + if self.enable_aux: + aux_pyr.append(self.scratch.refinenet2_aux(aux_pyr[-1], l2_rn, size=l1_rn.shape[2:])) + m = self.scratch.refinenet1(m, l1_rn) + if self.enable_aux: + aux_pyr.append(self.scratch.refinenet1_aux(aux_pyr[-1], l1_rn)) + m = self.scratch.output_conv1(m) + + h_out = int(ph * self.patch_size / self.down_ratio) + w_out = int(pw * self.patch_size / self.down_ratio) + + m = _custom_interpolate(m, (h_out, w_out), mode="bilinear", align_corners=True) + if self.pos_embed: + m = _add_pos_embed(m, W, H) + main_logits = self.scratch.output_conv2(m) + fmap = main_logits.permute(0, 2, 3, 1) + depth_pred = _apply_activation(fmap[..., :-1], self.activation) + depth_conf = _apply_activation(fmap[..., -1], self.conf_activation) + + outs = { + self.head_main: depth_pred.squeeze(-1).view(B, S, *depth_pred.shape[1:-1]), + f"{self.head_main}_conf": depth_conf.view(B, S, *depth_conf.shape[1:]), + } + + if self.enable_aux: + # Auxiliary "ray" head (multi-level inside) -- only the last level + # is returned. Mirrors upstream ``DualDPT._fuse`` + ``_forward_impl``: + # each aux pyramid level goes through ``output_conv1_aux[i]`` + # (5-layer conv stack that ends at ``features // 2`` channels), + # then the last level optionally gets a pos-embed and finally + # ``output_conv2_aux[-1]``. + aux_processed = [ + self.scratch.output_conv1_aux[i](a) for i, a in enumerate(aux_pyr) + ] + last_aux = aux_processed[-1] + if self.pos_embed: + last_aux = _add_pos_embed(last_aux, W, H) + last_aux_logits = self.scratch.output_conv2_aux[-1](last_aux) + fmap_last = last_aux_logits.permute(0, 2, 3, 1) + # Channels: [ray(6), ray_conf(1)]; ray uses 'linear' activation. + aux_pred = fmap_last[..., :-1] + aux_conf = _apply_activation(fmap_last[..., -1], self.conf_activation) + outs[self.head_aux] = aux_pred.view(B, S, *aux_pred.shape[1:]) + outs[f"{self.head_aux}_conf"] = aux_conf.view(B, S, *aux_conf.shape[1:]) + + return outs diff --git a/comfy/ldm/depth_anything_3/model.py b/comfy/ldm/depth_anything_3/model.py new file mode 100644 index 000000000..782517bac --- /dev/null +++ b/comfy/ldm/depth_anything_3/model.py @@ -0,0 +1,309 @@ +# DepthAnything3Net: top-level wrapper that combines backbone + head. +# +# Supports both the monocular and the multi-view + camera path: +# +# * Monocular: ``S = 1``, no camera encoder/decoder. Mirrors the original +# port that only handled ``DA3-MONO/METRIC-LARGE`` and the auxiliary-disabled +# ``DA3-SMALL/BASE`` configs. +# * Multi-view + camera: ``S > 1``. ``cam_enc`` (optional) maps user-supplied +# extrinsics + intrinsics into a per-view camera token; ``cam_dec`` decodes +# the final layer's camera token into a 9-D pose encoding. When the +# auxiliary "ray" head of ``DualDPT`` is enabled the predicted ray map can +# alternatively be used to estimate pose via RANSAC (``use_ray_pose=True``). +# The 3D-Gaussian head and the nested-architecture wrapper are intentionally +# left out of scope here; their state-dict keys are filtered in +# ``comfy.supported_models.DepthAnything3.process_unet_state_dict``. 
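+#
+# Variant -> (backbone, head) mapping: DA3-Small/Base pair vits/vitb with the
+# DualDPT head; DA3Mono/Metric-Large pair vitl with the DPT head + sky mask.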
+# +# The backbone is shared with the CLIP-vision DINOv2 path +# (``comfy.image_encoders.dino2.Dinov2Model``); the DA3-specific extensions +# (RoPE, QK-norm, alternating local/global attention, camera token, multi- +# layer feature extraction, reference-view reordering) are opt-in via the +# config dict and are all disabled for the Mono/Metric variants. + +from __future__ import annotations + +from typing import Dict, Optional, Sequence + +import torch +import torch.nn as nn + +from comfy.image_encoders.dino2 import Dinov2Model + +from .camera import CameraDec, CameraEnc +from .dpt import DPT, DualDPT +from .ray_pose import get_extrinsic_from_camray +from .transform import affine_inverse, pose_encoding_to_extri_intri + + +_HEAD_REGISTRY = { + "dpt": DPT, + "dualdpt": DualDPT, +} + + +# Backbone presets (mirror the upstream DINOv2 ViT variants). +_BACKBONE_PRESETS = { + "vits": dict(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, use_swiglu_ffn=False), + "vitb": dict(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, use_swiglu_ffn=False), + "vitl": dict(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, use_swiglu_ffn=False), + "vitg": dict(hidden_size=1536, num_hidden_layers=40, num_attention_heads=24, use_swiglu_ffn=True), +} + + +def _build_backbone_config( + backbone_name: str, + *, + alt_start: int, + qknorm_start: int, + rope_start: int, + cat_token: bool, +) -> dict: + if backbone_name not in _BACKBONE_PRESETS: + raise ValueError(f"Unknown DINOv2 backbone variant: {backbone_name!r}") + cfg = dict(_BACKBONE_PRESETS[backbone_name]) + cfg.update(dict( + layer_norm_eps=1e-6, + patch_size=14, + image_size=518, + # DA3 weights have no mask_token; skip registering it to avoid spurious + # missing-key warnings on load. + use_mask_token=False, + alt_start=alt_start, + qknorm_start=qknorm_start, + rope_start=rope_start, + cat_token=cat_token, + rope_freq=100.0, + )) + return cfg + + +class DepthAnything3Net(nn.Module): + """ComfyUI-side DepthAnything3 network. + + Parameters mirror the variant YAML configs from the upstream repo and + are auto-detected from the state dict by ``comfy/model_detection.py``. + The kwargs ``device``, ``dtype`` and ``operations`` are injected by + ``BaseModel``. 
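+
+    Shape contract for the monocular path (illustrative, ``S = 1``)::
+
+        out = net(image)   # image: (B, 3, H, W), H/W multiples of 14
+        out["depth"]       # (B, H, W)
+        out["sky"]         # (B, H, W), DPT variants with the sky head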
+ """ + + PATCH_SIZE = 14 + + def __init__( + self, + # --- Backbone --- + backbone_name: str = "vitl", + out_layers: Sequence[int] = (4, 11, 17, 23), + alt_start: int = -1, + qknorm_start: int = -1, + rope_start: int = -1, + cat_token: bool = False, + # --- Head --- + head_type: str = "dpt", # "dpt" or "dualdpt" + head_dim_in: int = 1024, + head_output_dim: int = 1, # 1 = depth only, 2 = depth+conf + head_features: int = 256, + head_out_channels: Sequence[int] = (256, 512, 1024, 1024), + head_use_sky_head: bool = True, # ignored by DualDPT + head_pos_embed: Optional[bool] = None, # default: True for DualDPT, False for DPT + # --- Camera (multi-view) --- + has_cam_enc: bool = False, + has_cam_dec: bool = False, + cam_dim_out: Optional[int] = None, # CameraEnc dim_out (defaults to embed_dim) + cam_dec_dim_in: Optional[int] = None, # CameraDec dim_in (defaults to 2*embed_dim with cat_token) + # ComfyUI plumbing + device=None, dtype=None, operations=None, + **_ignored, + ): + super().__init__() + head_cls = _HEAD_REGISTRY[head_type.lower()] + self.head_type = head_type.lower() + self.has_sky = (self.head_type == "dpt") and head_use_sky_head + self.has_conf = head_output_dim > 1 + self.out_layers = list(out_layers) + + backbone_cfg = _build_backbone_config( + backbone_name, + alt_start=alt_start, + qknorm_start=qknorm_start, + rope_start=rope_start, + cat_token=cat_token, + ) + self.backbone = Dinov2Model(backbone_cfg, dtype, device, operations) + + head_kwargs = dict( + dim_in=head_dim_in, + patch_size=self.PATCH_SIZE, + output_dim=head_output_dim, + features=head_features, + out_channels=tuple(head_out_channels), + device=device, dtype=dtype, operations=operations, + ) + if self.head_type == "dpt": + head_kwargs.update( + use_sky_head=head_use_sky_head, + pos_embed=(False if head_pos_embed is None else head_pos_embed), + ) + else: # dualdpt + head_kwargs.update( + pos_embed=(True if head_pos_embed is None else head_pos_embed), + ) + self.head = head_cls(**head_kwargs) + + # Camera encoder / decoder are only constructed when their weights are + # present in the checkpoint; the multi-view / pose forward path becomes + # available accordingly. ``cam_enc.dim_out`` matches the backbone's + # ``embed_dim`` so the cam token slots into block ``alt_start``. + embed_dim = backbone_cfg["hidden_size"] + if has_cam_enc: + self.cam_enc = CameraEnc( + dim_out=cam_dim_out if cam_dim_out is not None else embed_dim, + num_heads=max(1, embed_dim // 64), + device=device, dtype=dtype, operations=operations, + ) + else: + self.cam_enc = None + if has_cam_dec: + # Default cam_dec dim_in is 2*embed_dim when cat_token is on + # (the cls/cam token in the output is the cat'd version). + default_dim = embed_dim * (2 if cat_token else 1) + self.cam_dec = CameraDec( + dim_in=cam_dec_dim_in if cam_dec_dim_in is not None else default_dim, + device=device, dtype=dtype, operations=operations, + ) + else: + self.cam_dec = None + + self.dtype = dtype + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ + def forward( + self, + image: torch.Tensor, + extrinsics: Optional[torch.Tensor] = None, + intrinsics: Optional[torch.Tensor] = None, + *, + use_ray_pose: bool = False, + ref_view_strategy: str = "saddle_balanced", + export_feat_layers: Optional[Sequence[int]] = None, + **_unused, + ) -> Dict[str, torch.Tensor]: + """Run depth (and optionally pose) prediction. 
+ + Args: + image: ``(B, 3, H, W)`` ImageNet-normalised image tensor, or + ``(B, S, 3, H, W)`` for multi-view inputs. ``H`` and ``W`` + must be multiples of 14. + extrinsics: optional ``(B, S, 4, 4)`` world-to-camera extrinsics. + When provided together with ``intrinsics``, ``CameraEnc`` + converts them into per-view camera tokens that the backbone + injects at block ``alt_start``. + intrinsics: optional ``(B, S, 3, 3)`` pixel-space intrinsics. + use_ray_pose: if True, predict pose from the auxiliary "ray" head + (RANSAC over per-pixel rays). Only available on DualDPT + variants. If False (default) and ``cam_dec`` is present, + the final-layer cam token is decoded into pose instead. + ref_view_strategy: reference-view selection strategy used when + ``S >= 3`` and no extrinsics are supplied. See + :mod:`comfy.ldm.depth_anything_3.reference_view_selector`. + export_feat_layers: optional list of backbone layer indices whose + local features to also return as auxiliary outputs (used by + downstream nested-architecture wrappers; empty by default). + + Returns: + Dict with a subset of: + - ``depth`` ``(B*S, H, W)`` raw depth values. + - ``depth_conf`` ``(B*S, H, W)`` confidence (DualDPT only). + - ``sky`` ``(B*S, H, W)`` sky probability (DPT + sky head). + - ``ray`` ``(B, S, h, w, 6)`` per-pixel cam ray (DualDPT, + multi-view, ``use_ray_pose=True`` only). + - ``ray_conf`` ``(B, S, h, w)`` ray confidence. + - ``extrinsics`` ``(B, S, 4, 4)`` world-to-cam, when pose + prediction is active. + - ``intrinsics`` ``(B, S, 3, 3)`` pixel-space intrinsics. + - ``aux_features`` list of ``(B, S, h_p, w_p, C)`` features + when ``export_feat_layers`` is non-empty. + """ + if image.ndim == 4: + image = image.unsqueeze(1) # (B, 1, 3, H, W) + assert image.ndim == 5 and image.shape[2] == 3, \ + f"image must be (B,3,H,W) or (B,S,3,H,W); got {tuple(image.shape)}" + + B, S, _, H, W = image.shape + assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \ + f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}" + + # Camera-token preparation (multi-view path). + cam_token = None + if extrinsics is not None and intrinsics is not None and self.cam_enc is not None: + cam_token = self.cam_enc(extrinsics, intrinsics, (H, W)) + + # Toggle aux ray output on/off depending on what the caller asked for. + if isinstance(self.head, DualDPT): + self.head.enable_aux = bool(use_ray_pose) + + feats, aux_feats = self.backbone.get_intermediate_layers( + image, self.out_layers, cam_token=cam_token, + ref_view_strategy=ref_view_strategy, + export_feat_layers=export_feat_layers, + ) + head_out = self.head(feats, H=H, W=W, patch_start_idx=0) + + # Pose prediction. + out: Dict[str, torch.Tensor] = {} + if use_ray_pose and "ray" in head_out and "ray_conf" in head_out: + ray = head_out["ray"] + ray_conf = head_out["ray_conf"] + extr_c2w, focal, pp = get_extrinsic_from_camray( + ray, ray_conf, ray.shape[-3], ray.shape[-2], + ) + # Match the upstream output: w2c, drop the homogeneous row. + extr_w2c = affine_inverse(extr_c2w)[:, :, :3, :] + # Build pixel-space intrinsics from the normalised focal/pp output. 
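+            # (Convention from ray_pose: focal/pp live on a 2x2 normalised
+            # plane, so fx_pix = focal_x * W / 2 and cx_pix = pp_x * W / 2;
+            # the assignments below are exactly that mapping.)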
+ intr = torch.eye(3, device=ray.device, dtype=ray.dtype) + intr = intr[None, None].expand(extr_c2w.shape[0], extr_c2w.shape[1], 3, 3).clone() + intr[:, :, 0, 0] = focal[:, :, 0] / 2 * W + intr[:, :, 1, 1] = focal[:, :, 1] / 2 * H + intr[:, :, 0, 2] = pp[:, :, 0] * W * 0.5 + intr[:, :, 1, 2] = pp[:, :, 1] * H * 0.5 + out["extrinsics"] = extr_w2c + out["intrinsics"] = intr + elif self.cam_dec is not None and S > 1: + # Decode the cam-token of the final out_layer into a pose encoding. + cam_feat = feats[-1][1] # (B, S, dim_in_to_cam_dec) + pose_enc = self.cam_dec(cam_feat) + c2w_3x4, intr = pose_encoding_to_extri_intri(pose_enc, (H, W)) + # Match the upstream output convention: w2c (world->camera), 3x4. + c2w_4x4 = torch.cat([ + c2w_3x4, + torch.tensor([0, 0, 0, 1], device=c2w_3x4.device, dtype=c2w_3x4.dtype) + .view(1, 1, 1, 4).expand(B, S, 1, 4), + ], dim=-2) + out["extrinsics"] = affine_inverse(c2w_4x4)[:, :, :3, :] + out["intrinsics"] = intr + + # Flatten the views axis for per-pixel outputs (depth/conf/sky) so the + # per-image consumer keeps its (B*S, H, W) interface. + for k, v in head_out.items(): + if k in ("ray", "ray_conf"): + # Keep multi-view shape for downstream pose work. + out[k] = v + elif v.ndim >= 3 and v.shape[0] == B and v.shape[1] == S: + out[k] = v.reshape(B * S, *v.shape[2:]) + else: + out[k] = v + + if export_feat_layers: + out["aux_features"] = self._reshape_aux_features(aux_feats, H, W) + return out + + def _reshape_aux_features(self, aux_feats, H: int, W: int): + """Reshape ``(B, S, N, C)`` aux features into ``(B, S, h_p, w_p, C)``.""" + ph, pw = H // self.PATCH_SIZE, W // self.PATCH_SIZE + out = [] + for f in aux_feats: + B, S, N, C = f.shape + assert N == ph * pw, f"aux feature seq mismatch: {N} != {ph}*{pw}" + out.append(f.reshape(B, S, ph, pw, C)) + return out diff --git a/comfy/ldm/depth_anything_3/preprocess.py b/comfy/ldm/depth_anything_3/preprocess.py new file mode 100644 index 000000000..b32ad0d10 --- /dev/null +++ b/comfy/ldm/depth_anything_3/preprocess.py @@ -0,0 +1,184 @@ +# Input/output preprocessing helpers for Depth Anything 3. +# +# Ported from: +# src/depth_anything_3/utils/io/input_processor.py (image normalisation) +# src/depth_anything_3/utils/alignment.py (sky-aware depth clip) +# src/depth_anything_3/model/da3.py::_process_mono_sky_estimation +# +# Resize: ``comfy.utils.common_upscale`` with ``upscale_method="lanczos"``. +# Upstream uses cv2 INTER_CUBIC (upscale) / INTER_AREA (downscale); a sweep +# across {bilinear, bicubic, area, lanczos, bislerp} on a 768->504 test image +# showed lanczos has the lowest max-abs-diff vs the upstream cv2 output +# (~0.13 vs 0.21-0.71 for the others), so we use it in both directions for +# simplicity. This keeps the path stateless, on-device, and free of any +# OpenCV dependency. + +from __future__ import annotations + +from typing import Tuple + +import torch + +import comfy.utils + +PATCH_SIZE = 14 + +# ImageNet normalization constants used during DA3 training. +_IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]) +_IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]) + + +def _round_to_patch(x: int, patch: int = PATCH_SIZE) -> int: + down = (x // patch) * patch + up = down + patch + return up if abs(up - x) <= abs(x - down) else down + + +def compute_target_size(orig_h: int, orig_w: int, process_res: int, + method: str = "upper_bound_resize") -> Tuple[int, int]: + """Compute (target_h, target_w) for a single image. 
+ + Methods: + - "upper_bound_resize": scale longest side to ``process_res``, then + round each dim to nearest multiple of 14 (default upstream method). + - "lower_bound_resize": scale shortest side to ``process_res``, then + round. + """ + if method == "upper_bound_resize": + longest = max(orig_h, orig_w) + scale = process_res / float(longest) + elif method == "lower_bound_resize": + shortest = min(orig_h, orig_w) + scale = process_res / float(shortest) + else: + raise ValueError(f"Unsupported process_res_method: {method}") + + new_w = max(1, _round_to_patch(int(round(orig_w * scale)))) + new_h = max(1, _round_to_patch(int(round(orig_h * scale)))) + return new_h, new_w + + +def preprocess_image( + image: torch.Tensor, + process_res: int = 504, + method: str = "upper_bound_resize", +) -> torch.Tensor: + """Preprocess a ComfyUI ``IMAGE`` batch for DA3. + + Args: + image: ``(B, H, W, 3)`` float in [0, 1] (ComfyUI ``IMAGE`` convention). + process_res: target resolution (longest or shortest side, depending + on ``method``). + method: resize strategy. + + Returns: + ``(B, 3, H', W')`` tensor with H' and W' multiples of 14, normalised + with ImageNet statistics. The tensor lives on the same device as + ``image``. + """ + assert image.ndim == 4 and image.shape[-1] == 3, \ + f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}" + B, H, W, _ = image.shape + target_h, target_w = compute_target_size(H, W, process_res, method) + + # (B, H, W, 3) -> (B, 3, H, W) + x = image.movedim(-1, 1).contiguous() + if (target_h, target_w) != (H, W): + # Upstream uses cv2 INTER_CUBIC (upscale) / INTER_AREA (downscale). + # Lanczos in ``common_upscale`` is anti-aliased and produces the + # closest pixel-wise match in a sweep across {bilinear, bicubic, + # area, lanczos, bislerp}. Used in both directions for simplicity. + x = comfy.utils.common_upscale( + x.float(), target_w, target_h, "lanczos", "disabled", + ) + x = x.clamp(0.0, 1.0) + + mean = _IMAGENET_MEAN.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1) + std = _IMAGENET_STD.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1) + x = (x - mean) / std + return x + + +# ----------------------------------------------------------------------------- +# Output post-processing (sky-aware clipping for Mono/Metric variants) +# ----------------------------------------------------------------------------- + + +def compute_non_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor: + """Boolean mask: True for non-sky pixels (sky probability < threshold).""" + return sky_prediction < threshold + + +def apply_sky_aware_clip( + depth: torch.Tensor, + sky: torch.Tensor, + threshold: float = 0.3, + quantile: float = 0.99, +) -> torch.Tensor: + """Replicates ``_process_mono_sky_estimation`` from upstream. + + Clips sky regions to the 99th percentile of non-sky depth. Returns a new + depth tensor; ``depth`` is not modified in place. 
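+
+    Example (hypothetical tensors)::
+
+        depth = torch.rand(518, 518) * 80.0   # raw depth
+        sky = torch.rand(518, 518)            # sky probability in [0, 1]
+        clipped = apply_sky_aware_clip(depth, sky)
+        # pixels with sky >= 0.3 are set to the 99th percentile of the
+        # (sampled) non-sky depth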
+ """ + non_sky = compute_non_sky_mask(sky, threshold=threshold) + if non_sky.sum() <= 10 or (~non_sky).sum() <= 10: + return depth.clone() + + non_sky_depth = depth[non_sky] + if non_sky_depth.numel() > 100_000: + idx = torch.randint(0, non_sky_depth.numel(), (100_000,), device=non_sky_depth.device) + sampled = non_sky_depth[idx] + else: + sampled = non_sky_depth + + max_depth = torch.quantile(sampled, quantile) + out = depth.clone() + out[~non_sky] = max_depth + return out + + +def normalize_depth_v2_style( + depth: torch.Tensor, + sky: torch.Tensor | None = None, + low_quantile: float = 0.01, + high_quantile: float = 0.99, +) -> torch.Tensor: + """V2-style normalization for ControlNet workflows. + + Computes percentile bounds over non-sky pixels (when available), + then maps depth into [0, 1] with near = white (1.0). + """ + if sky is not None: + mask = compute_non_sky_mask(sky) + if mask.any(): + valid = depth[mask] + else: + valid = depth.flatten() + else: + valid = depth.flatten() + + if valid.numel() > 100_000: + idx = torch.randint(0, valid.numel(), (100_000,), device=valid.device) + sample = valid[idx] + else: + sample = valid + + lo = torch.quantile(sample, low_quantile) + hi = torch.quantile(sample, high_quantile) + rng = (hi - lo).clamp(min=1e-6) + norm = ((depth - lo) / rng).clamp(0.0, 1.0) + # ControlNet convention: nearer pixels are brighter (1.0). + norm = 1.0 - norm + if sky is not None: + # Sky pixels become black (far / unknown). + sky_mask = ~compute_non_sky_mask(sky) + norm = torch.where(sky_mask, torch.zeros_like(norm), norm) + return norm + + +def normalize_depth_min_max(depth: torch.Tensor) -> torch.Tensor: + """Simple per-frame min/max normalization with near=1.0 convention.""" + lo = depth.amin(dim=(-2, -1), keepdim=True) + hi = depth.amax(dim=(-2, -1), keepdim=True) + rng = (hi - lo).clamp(min=1e-6) + return 1.0 - ((depth - lo) / rng).clamp(0.0, 1.0) diff --git a/comfy/ldm/depth_anything_3/ray_pose.py b/comfy/ldm/depth_anything_3/ray_pose.py new file mode 100644 index 000000000..e72264dd5 --- /dev/null +++ b/comfy/ldm/depth_anything_3/ray_pose.py @@ -0,0 +1,320 @@ +"""Ray-to-pose conversion for the multi-view path of Depth Anything 3. + +Converts the auxiliary "ray" output of :class:`DualDPT` (per-pixel camera +ray vectors, predicted on the per-view local feature map) into per-view +extrinsics + intrinsics. Implementation is a 1:1 port of +``depth_anything_3.utils.ray_utils`` upstream, using a weighted-RANSAC +homography fit followed by a QL decomposition. + +No learned parameters; pure tensor math. Output: + +* ``R`` -- ``(B, S, 3, 3)`` rotation matrix +* ``T`` -- ``(B, S, 3)`` camera-space translation +* ``focal_lengths`` -- ``(B, S, 2)`` in normalised image space (image=2x2) +* ``principal_points`` -- ``(B, S, 2)`` ditto + +:func:`get_extrinsic_from_camray` wraps these into a 4x4 extrinsic matrix +that the public node converts back into pixel-space intrinsics. +""" + +from __future__ import annotations + +from typing import Optional, Tuple + +import torch + + +# ----------------------------------------------------------------------------- +# Linear-algebra helpers +# ----------------------------------------------------------------------------- + + +def _ql_decomposition(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Decompose ``A = Q @ L`` with ``Q`` orthogonal and ``L`` lower-triangular. + + Implemented in terms of QR by reversing the columns/rows; the standard + trick from the upstream reference. Inputs ``A`` are ``(3, 3)``. 
+ """ + P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], + device=A.device, dtype=A.dtype) + A_tilde = A @ P + # CUDA QR is not implemented for fp16/bf16; upcast just for this call. + Q_tilde, R_tilde = torch.linalg.qr(A_tilde.float()) + Q_tilde = Q_tilde.to(A.dtype) + R_tilde = R_tilde.to(A.dtype) + Q = Q_tilde @ P + L = P @ R_tilde @ P + d = torch.diag(L) + sign = torch.sign(d) + Q = Q * sign[None, :] # scale columns of Q + L = L * sign[:, None] # scale rows of L + return Q, L + + +def _homogenize_points(points: torch.Tensor) -> torch.Tensor: + return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + + +# ----------------------------------------------------------------------------- +# Weighted-LSQ + RANSAC homography (batched) +# ----------------------------------------------------------------------------- + + +def _find_homography_weighted_lsq( + src_pts: torch.Tensor, + dst_pts: torch.Tensor, + confident_weight: torch.Tensor, +) -> torch.Tensor: + """Solve a single ``H`` with weighted least-squares (DLT).""" + N = src_pts.shape[0] + if N < 4: + raise ValueError("At least 4 points are required to compute a homography.") + w = confident_weight.sqrt().unsqueeze(1) # (N, 1) + x = src_pts[:, 0:1] + y = src_pts[:, 1:2] + u = dst_pts[:, 0:1] + v = dst_pts[:, 1:2] + zeros = torch.zeros_like(x) + A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=1) + A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=1) + A = torch.cat([A1, A2], dim=0) # (2N, 9) + # CUDA SVD is not implemented for fp16/bf16; upcast just for this call. + _, _, Vh = torch.linalg.svd(A.float()) + Vh = Vh.to(A.dtype) + H = Vh[-1].reshape(3, 3) + return H / H[-1, -1] + + +def _find_homography_weighted_lsq_batched( + src_pts_batch: torch.Tensor, + dst_pts_batch: torch.Tensor, + confident_weight_batch: torch.Tensor, +) -> torch.Tensor: + """Batched DLT solver. Inputs ``(B, K, 2)`` / ``(B, K)``; output ``(B, 3, 3)``.""" + B, K, _ = src_pts_batch.shape + w = confident_weight_batch.sqrt().unsqueeze(2) + x = src_pts_batch[:, :, 0:1] + y = src_pts_batch[:, :, 1:2] + u = dst_pts_batch[:, :, 0:1] + v = dst_pts_batch[:, :, 1:2] + zeros = torch.zeros_like(x) + A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=2) + A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=2) + A = torch.cat([A1, A2], dim=1) # (B, 2K, 9) + # CUDA SVD is not implemented for fp16/bf16; upcast just for this call. + _, _, Vh = torch.linalg.svd(A.float()) + Vh = Vh.to(A.dtype) + H = Vh[:, -1].reshape(B, 3, 3) + return H / H[:, 2:3, 2:3] + + +def _ransac_find_homography_weighted_batched( + src_pts: torch.Tensor, # (B, N, 2) + dst_pts: torch.Tensor, # (B, N, 2) + confident_weight: torch.Tensor, # (B, N) + n_sample: int, + n_iter: int = 100, + reproj_threshold: float = 3.0, + num_sample_for_ransac: int = 8, + max_inlier_num: int = 10000, + rand_sample_iters_idx: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Batched weighted-RANSAC homography estimator. + + Returns ``(B, 3, 3)`` homography matrices. 
+ """ + B, N, _ = src_pts.shape + assert N >= 4 + device = src_pts.device + + sorted_idx = torch.argsort(confident_weight, descending=True, dim=1) + candidate_idx = sorted_idx[:, :n_sample] # (B, n_sample) + + if rand_sample_iters_idx is None: + rand_sample_iters_idx = torch.stack( + [torch.randperm(n_sample, device=device)[:num_sample_for_ransac] + for _ in range(n_iter)], + dim=0, + ) + + rand_idx = candidate_idx[:, rand_sample_iters_idx] # (B, n_iter, k) + b_idx = ( + torch.arange(B, device=device) + .view(B, 1, 1) + .expand(B, n_iter, num_sample_for_ransac) + ) + src_b = src_pts[b_idx, rand_idx] + dst_b = dst_pts[b_idx, rand_idx] + w_b = confident_weight[b_idx, rand_idx] + + cB, cN = src_b.shape[:2] + H_batch = _find_homography_weighted_lsq_batched( + src_b.flatten(0, 1), dst_b.flatten(0, 1), w_b.flatten(0, 1), + ).unflatten(0, (cB, cN)) # (B, n_iter, 3, 3) + + src_homo = torch.cat([src_pts, torch.ones(B, N, 1, device=device, dtype=src_pts.dtype)], dim=2) + proj = torch.bmm( + src_homo.unsqueeze(1).expand(B, n_iter, N, 3).reshape(-1, N, 3), + H_batch.reshape(-1, 3, 3).transpose(1, 2), + ) # (B*n_iter, N, 3) + proj_xy = (proj[:, :, :2] / proj[:, :, 2:3]).reshape(B, n_iter, N, 2) + err = ((proj_xy - dst_pts.unsqueeze(1)) ** 2).sum(-1).sqrt() # (B, n_iter, N) + inlier_mask = err < reproj_threshold + score = (inlier_mask * confident_weight.unsqueeze(1)).sum(dim=2) + best_idx = torch.argmax(score, dim=1) + best_inlier_mask = inlier_mask[torch.arange(B, device=device), best_idx] + + # Refit with the inlier set (per-batch, since the inlier counts vary). + H_inlier_list = [] + for b in range(B): + mask = best_inlier_mask[b] + in_src = src_pts[b][mask] + in_dst = dst_pts[b][mask] + in_w = confident_weight[b][mask] + if in_src.shape[0] < 4: + # Fall back to identity when RANSAC fails to find enough inliers. + H_inlier_list.append(torch.eye(3, device=device, dtype=src_pts.dtype)) + continue + sorted_w = torch.argsort(in_w, descending=True) + if len(sorted_w) > max_inlier_num: + keep = max(int(len(sorted_w) * 0.95), max_inlier_num) + sorted_w = sorted_w[:keep][torch.randperm(keep, device=device)[:max_inlier_num]] + H_inlier_list.append( + _find_homography_weighted_lsq(in_src[sorted_w], in_dst[sorted_w], in_w[sorted_w]) + ) + return torch.stack(H_inlier_list, dim=0) + + +# ----------------------------------------------------------------------------- +# Camera-ray utilities +# ----------------------------------------------------------------------------- + + +def _unproject_identity(num_y: int, num_x: int, B: int, S: int, + device, dtype) -> torch.Tensor: + """Camera-space unit rays for an identity intrinsic on a 2x2 image plane. + + Replicates ``unproject_depth(..., ixt_normalized=True)`` upstream: pixel + coords ``(x, y)`` in ``[dx, 2-dx] x [dy, 2-dy]`` get mapped to + camera-space rays ``(x-1, y-1, 1)`` via the identity intrinsic + ``[[1,0,1],[0,1,1],[0,0,1]]``. Returns ``(B, S, num_y, num_x, 3)``. + """ + dx = 1.0 / num_x + dy = 1.0 / num_y + # Centered camera-space coords directly (skip the K^-1 step since it's + # just a translation by -1 on x and y when K is identity-with-center=1). 
+ y = torch.linspace(-(1 - dy), (1 - dy), num_y, device=device, dtype=dtype) + x = torch.linspace(-(1 - dx), (1 - dx), num_x, device=device, dtype=dtype) + yy, xx = torch.meshgrid(y, x, indexing="ij") + grid = torch.stack((xx, yy), dim=-1) # (h, w, 2) + grid = grid.unsqueeze(0).unsqueeze(0).expand(B, S, num_y, num_x, 2) + return torch.cat([grid, torch.ones_like(grid[..., :1])], dim=-1) + + +def _camray_to_caminfo( + camray: torch.Tensor, # (B, S, h, w, 6) + confidence: Optional[torch.Tensor] = None, # (B, S, h, w) + reproj_threshold: float = 0.2, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert per-pixel camera rays to per-view (R, T, focal, principal).""" + if confidence is None: + confidence = torch.ones_like(camray[..., 0]) + B, S, h, w, _ = camray.shape + device = camray.device + dtype = camray.dtype + + rays_target = camray[..., :3] # (B, S, h, w, 3) + rays_origin = _unproject_identity(h, w, B, S, device, dtype) + + # Flatten (B*S, h*w, *) for the RANSAC routine. + rays_target = rays_target.flatten(0, 1).flatten(1, 2) + rays_origin = rays_origin.flatten(0, 1).flatten(1, 2) + weights = confidence.flatten(0, 1).flatten(1, 2).clone() + + # Project to 2D in homogeneous form (the upstream calls this "perspective division"). + z_thresh = 1e-4 + mask = (rays_target[:, :, 2].abs() > z_thresh) & (rays_origin[:, :, 2].abs() > z_thresh) + weights = torch.where(mask, weights, torch.zeros_like(weights)) + src = rays_origin.clone() + dst = rays_target.clone() + src[..., 0] = torch.where(mask, src[..., 0] / src[..., 2], src[..., 0]) + src[..., 1] = torch.where(mask, src[..., 1] / src[..., 2], src[..., 1]) + dst[..., 0] = torch.where(mask, dst[..., 0] / dst[..., 2], dst[..., 0]) + dst[..., 1] = torch.where(mask, dst[..., 1] / dst[..., 2], dst[..., 1]) + src = src[..., :2] + dst = dst[..., :2] + + N = src.shape[1] + n_iter = 100 + sample_ratio = 0.3 + num_sample_for_ransac = 8 + n_sample = max(num_sample_for_ransac, int(N * sample_ratio)) + rand_idx = torch.stack( + [torch.randperm(n_sample, device=device)[:num_sample_for_ransac] for _ in range(n_iter)], + dim=0, + ) + + # Chunk along the view axis to keep peak memory predictable. + chunk = 2 + A_list = [] + for i in range(0, src.shape[0], chunk): + A = _ransac_find_homography_weighted_batched( + src[i:i + chunk], dst[i:i + chunk], weights[i:i + chunk], + n_sample=n_sample, n_iter=n_iter, + num_sample_for_ransac=num_sample_for_ransac, + reproj_threshold=reproj_threshold, + rand_sample_iters_idx=rand_idx, + max_inlier_num=8000, + ) + # Flip sign on dets that come out < 0 (so that the QL produces a + # right-handed rotation). ``det`` lacks fp16/bf16 CUDA kernels, so + # do the comparison in fp32. + flip = torch.linalg.det(A.float()) < 0 + A = torch.where(flip[:, None, None], -A, A) + A_list.append(A) + A = torch.cat(A_list, dim=0) # (B*S, 3, 3) + + R_list, f_list, pp_list = [], [], [] + for i in range(A.shape[0]): + R, L = _ql_decomposition(A[i]) + L = L / L[2][2] + f_list.append(torch.stack((L[0][0], L[1][1]))) + pp_list.append(torch.stack((L[2][0], L[2][1]))) + R_list.append(R) + R = torch.stack(R_list).reshape(B, S, 3, 3) + focal = torch.stack(f_list).reshape(B, S, 2) + pp = torch.stack(pp_list).reshape(B, S, 2) + + # Translation: confidence-weighted average of camray direction(s). 
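+    # (camray layout: channels [..., :3] are the ray directions consumed by
+    #  the homography fit above; channels [..., 3:] carry the per-pixel
+    #  translation estimate that is averaged here.)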
+    cf = confidence.flatten(0, 1).flatten(1, 2)
+    T = (camray.flatten(0, 1).flatten(1, 2)[..., 3:] * cf.unsqueeze(-1)).sum(dim=1)
+    T = T / cf.sum(dim=-1, keepdim=True)
+    T = T.reshape(B, S, 3)
+
+    # Match upstream output convention: focal -> 1/focal, pp + 1.
+    return R, T, 1.0 / focal, pp + 1.0
+
+
+def get_extrinsic_from_camray(
+    camray: torch.Tensor,  # (B, S, h, w, 6)
+    conf: torch.Tensor,  # (B, S, h, w, 1) or (B, S, h, w)
+    patch_size_y: int,  # unused; kept for parity with the upstream signature
+    patch_size_x: int,  # unused; kept for parity with the upstream signature
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Assemble the per-view ``(R, T)`` from :func:`_camray_to_caminfo` into a
+    4x4 extrinsic and return it with the per-view focal + principal point.
+
+    Returns:
+        * extrinsic ``(B, S, 4, 4)`` camera-to-world (the inverse is
+          what gets stored in ``output.extrinsics`` by the caller).
+        * focals ``(B, S, 2)`` in normalised image space.
+        * pp ``(B, S, 2)`` in normalised image space.
+    """
+    if conf.ndim == 5 and conf.shape[-1] == 1:
+        conf = conf.squeeze(-1)
+    R, T, focal, pp = _camray_to_caminfo(camray, confidence=conf)
+    extr = torch.cat([R, T.unsqueeze(-1)], dim=-1)  # (B, S, 3, 4)
+    homo_row = torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device)
+    homo_row = homo_row.view(1, 1, 1, 4).expand(R.shape[0], R.shape[1], 1, 4)
+    extr = torch.cat([extr, homo_row], dim=-2)  # (B, S, 4, 4)
+    return extr, focal, pp
diff --git a/comfy/ldm/depth_anything_3/reference_view_selector.py b/comfy/ldm/depth_anything_3/reference_view_selector.py
new file mode 100644
index 000000000..968cd9d14
--- /dev/null
+++ b/comfy/ldm/depth_anything_3/reference_view_selector.py
@@ -0,0 +1,116 @@
+"""Reference-view selection for the multi-view path of Depth Anything 3.
+
+Pure tensor math, no learned parameters. Exposed as three free functions:
+
+* :func:`select_reference_view` -- pick a reference view per batch.
+* :func:`reorder_by_reference` -- move the reference view to position 0.
+* :func:`restore_original_order` -- inverse of :func:`reorder_by_reference`.
+
+Mirrors ``depth_anything_3.model.reference_view_selector`` upstream.
+The default strategy (``"saddle_balanced"``) selects the view whose
+CLS-token similarity, feature norm and feature variance all sit mid-range
+across the views: a balanced "saddle" view that is neither an outlier nor
+a near-duplicate of its neighbours.
+"""
+
+from __future__ import annotations
+
+from typing import Literal
+
+import torch
+
+
+RefViewStrategy = Literal["first", "middle", "saddle_balanced", "saddle_sim_range"]
+
+
+# Per the upstream constants module: ``THRESH_FOR_REF_SELECTION = 3``.
+# Reference selection only runs when there are at least this many views.
+THRESH_FOR_REF_SELECTION: int = 3
+
+
+def select_reference_view(
+    x: torch.Tensor,
+    strategy: RefViewStrategy = "saddle_balanced",
+) -> torch.Tensor:
+    """Pick a reference view index per batch element.
+
+    Args:
+        x: ``(B, S, N, C)`` token tensor. Index 0 along ``N`` is the
+            cls/cam token used by the feature-based strategies.
+        strategy: One of ``"first" | "middle" | "saddle_balanced" |
+            "saddle_sim_range"``.
+
+    Returns:
+        ``(B,)`` long tensor with the chosen reference view index for
+        each batch element.
+    """
+    B, S, _, _ = x.shape
+    if S <= 1:
+        return torch.zeros(B, dtype=torch.long, device=x.device)
+    if strategy == "first":
+        return torch.zeros(B, dtype=torch.long, device=x.device)
+    if strategy == "middle":
+        return torch.full((B,), S // 2, dtype=torch.long, device=x.device)
+
+    # Feature-based strategies: normalised cls/cam token per view.
+    img_class_feat = x[:, :, 0] / x[:, :, 0].norm(dim=-1, keepdim=True)  # (B,S,C)
+
+    if strategy == "saddle_balanced":
+        sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2))  # (B,S,S)
+        sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
+        sim_score = sim_no_diag.sum(dim=-1) / (S - 1)  # (B,S)
+        feat_norm = x[:, :, 0].norm(dim=-1)  # (B,S)
+        feat_var = img_class_feat.var(dim=-1)  # (B,S)
+
+        def _normalize(metric):
+            mn = metric.min(dim=1, keepdim=True).values
+            mx = metric.max(dim=1, keepdim=True).values
+            return (metric - mn) / (mx - mn + 1e-8)
+
+        sim_n, norm_n, var_n = _normalize(sim_score), _normalize(feat_norm), _normalize(feat_var)
+        balance = (sim_n - 0.5).abs() + (norm_n - 0.5).abs() + (var_n - 0.5).abs()
+        return balance.argmin(dim=1)
+
+    if strategy == "saddle_sim_range":
+        sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2))
+        sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
+        sim_max = sim_no_diag.max(dim=-1).values
+        sim_min = sim_no_diag.min(dim=-1).values
+        return (sim_max - sim_min).argmax(dim=1)
+
+    raise ValueError(
+        f"Unknown reference view selection strategy: {strategy!r}. "
+        f"Must be one of: 'first', 'middle', 'saddle_balanced', 'saddle_sim_range'"
+    )
+
+
+def reorder_by_reference(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
+    """Reorder ``x`` so the reference view is at position 0 in axis ``S``."""
+    B, S = x.shape[0], x.shape[1]
+    if S <= 1:
+        return x
+    positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
+    b_idx_exp = b_idx.unsqueeze(1)
+    reorder = torch.where(
+        (positions > 0) & (positions <= b_idx_exp),
+        positions - 1,
+        positions,
+    )
+    reorder[:, 0] = b_idx
+    batch = torch.arange(B, device=x.device).unsqueeze(1)
+    return x[batch, reorder]
+
+
+def restore_original_order(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
+    """Inverse of :func:`reorder_by_reference`."""
+    B, S = x.shape[0], x.shape[1]
+    if S <= 1:
+        return x
+    target_positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
+    b_idx_exp = b_idx.unsqueeze(1)
+    restore = torch.where(target_positions < b_idx_exp,
+                          target_positions + 1,
+                          target_positions)
+    restore = torch.scatter(
+        restore, dim=1, index=b_idx_exp, src=torch.zeros_like(b_idx_exp),
+    )
+    batch = torch.arange(B, device=x.device).unsqueeze(1)
+    return x[batch, restore]
diff --git a/comfy/ldm/depth_anything_3/transform.py b/comfy/ldm/depth_anything_3/transform.py
new file mode 100644
index 000000000..f3ce8f6be
--- /dev/null
+++ b/comfy/ldm/depth_anything_3/transform.py
@@ -0,0 +1,180 @@
+"""Geometry / camera transform helpers for Depth Anything 3.
+
+Pure tensor math, no learned parameters. Mirrors the upstream
+``depth_anything_3.model.utils.transform`` and the parts of
+``depth_anything_3.utils.geometry`` used at inference time on the
+multi-view + camera path. Kept self-contained so the DA3 module is fully
+ported and does not depend on the upstream repo at runtime.
+"""
+
+from __future__ import annotations
+
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+
+
+# -----------------------------------------------------------------------------
+# Affine 4x4 helpers
+# -----------------------------------------------------------------------------
+
+
+def as_homogeneous(ext: torch.Tensor) -> torch.Tensor:
+    """Promote ``(...,3,4)`` extrinsics to ``(...,4,4)`` homogeneous form.
+
+    A no-op when the input is already ``(...,4,4)``.
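+
+    Example::
+
+        ext = torch.randn(2, 8, 3, 4)    # (B, S, 3, 4) extrinsics
+        as_homogeneous(ext).shape        # torch.Size([2, 8, 4, 4])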
+ """ + if ext.shape[-2:] == (4, 4): + return ext + if ext.shape[-2:] == (3, 4): + ones = torch.zeros_like(ext[..., :1, :4]) + ones[..., 0, 3] = 1.0 + return torch.cat([ext, ones], dim=-2) + raise ValueError(f"Invalid affine shape: {ext.shape}") + + +def affine_inverse(A: torch.Tensor) -> torch.Tensor: + """Inverse of an affine matrix ``[R|T; 0 0 0 1]``.""" + R = A[..., :3, :3] + T = A[..., :3, 3:] + P = A[..., 3:, :] + return torch.cat([torch.cat([R.mT, -R.mT @ T], dim=-1), P], dim=-2) + + +# ----------------------------------------------------------------------------- +# Quaternion <-> rotation matrix (xyzw / scalar-last) +# ----------------------------------------------------------------------------- + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """``sqrt(max(0, x))`` with a zero subgradient where ``x == 0``.""" + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """Force the real part of a unit quaternion (xyzw) to be non-negative.""" + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) + + +def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: + """Convert quaternions (xyzw) to ``(...,3,3)`` rotation matrices.""" + i, j, k, r = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: + """Convert ``(...,3,3)`` rotation matrices to quaternions (xyzw).""" + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape( + batch_dim + (4,) + ) + # Reorder rijk -> xyzw (i.e. ijkr). 
+ out = out[..., [1, 2, 3, 0]] + return standardize_quaternion(out) + + +# ----------------------------------------------------------------------------- +# Pose-encoding <-> extrinsics + intrinsics +# ----------------------------------------------------------------------------- + + +def extri_intri_to_pose_encoding( + extrinsics: torch.Tensor, + intrinsics: torch.Tensor, + image_size_hw: Tuple[int, int], +) -> torch.Tensor: + """Pack ``(extr, intr, image_size)`` into the 9-D pose-encoding vector. + + ``extrinsics`` are camera-to-world (c2w) ``(B,S,4,4)`` matrices, + ``intrinsics`` are pixel-space ``(B,S,3,3)`` matrices, ``image_size_hw`` + is a ``(H, W)`` pair. The encoding is ``[T(3), quat_xyzw(4), fov_h, fov_w]``. + """ + R = extrinsics[..., :3, :3] + T = extrinsics[..., :3, 3] + quat = mat_to_quat(R) + H, W = image_size_hw + fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) + fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) + return torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float() + + +def pose_encoding_to_extri_intri( + pose_encoding: torch.Tensor, + image_size_hw: Tuple[int, int], +) -> Tuple[torch.Tensor, torch.Tensor]: + """Inverse of :func:`extri_intri_to_pose_encoding`. + + Returns a ``(B,S,3,4)`` c2w extrinsic matrix and a ``(B,S,3,3)`` + pixel-space intrinsic matrix. + """ + T = pose_encoding[..., :3] + quat = pose_encoding[..., 3:7] + fov_h = pose_encoding[..., 7] + fov_w = pose_encoding[..., 8] + R = quat_to_mat(quat) + extrinsics = torch.cat([R, T[..., None]], dim=-1) + H, W = image_size_hw + fy = (H / 2.0) / torch.clamp(torch.tan(fov_h / 2.0), 1e-6) + fx = (W / 2.0) / torch.clamp(torch.tan(fov_w / 2.0), 1e-6) + intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), + device=pose_encoding.device, dtype=pose_encoding.dtype) + intrinsics[..., 0, 0] = fx + intrinsics[..., 1, 1] = fy + intrinsics[..., 0, 2] = W / 2 + intrinsics[..., 1, 2] = H / 2 + intrinsics[..., 2, 2] = 1.0 + return extrinsics, intrinsics diff --git a/comfy/model_base.py b/comfy/model_base.py index 0736321b3..ad040b244 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -60,6 +60,7 @@ import comfy.ldm.ernie.model import comfy.ldm.sam3.detector import comfy.ldm.hidream_o1.model from comfy.ldm.hidream_o1.conditioning import build_extra_conds +import comfy.ldm.depth_anything_3.model import comfy.model_management import comfy.patcher_extension @@ -2035,6 +2036,12 @@ class RT_DETR_v4(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4) + +class DepthAnything3(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, + unet_model=comfy.ldm.depth_anything_3.model.DepthAnything3Net) + class ErnieImage(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index bc0b933bc..04a812267 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -766,6 +766,108 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0] return dit_config + # Depth Anything 3 (Apache-2.0 monocular variants: Small/Base/Mono-Large/Metric-Large). 
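+    # Detection is purely key/shape based: the patch-embed width selects the
+    # ViT preset, per-block ``attn.q_norm`` keys locate where the DA3
+    # extensions start, and the auxiliary refinenet branch separates DualDPT
+    # from DPT checkpoints.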
+ if '{}backbone.pretrained.patch_embed.proj.weight'.format(key_prefix) in state_dict_keys: + dit_config = {} + dit_config["image_model"] = "DepthAnything3" + + patch_w = state_dict['{}backbone.pretrained.patch_embed.proj.weight'.format(key_prefix)] + embed_dim = patch_w.shape[0] + depth = count_blocks(state_dict_keys, '{}backbone.pretrained.blocks.'.format(key_prefix) + '{}.') + + # Backbone preset is determined by embed_dim (matches vits/vitb/vitl/vitg). + backbone_name = {384: "vits", 768: "vitb", 1024: "vitl", 1536: "vitg"}.get(embed_dim) + if backbone_name is None: + return None + dit_config["backbone_name"] = backbone_name + + # Detect DA3 extensions on top of vanilla DINOv2. + has_camera_token = '{}backbone.pretrained.camera_token'.format(key_prefix) in state_dict_keys + # qk-norm shows up as `attn.q_norm.weight` on enabled blocks. + qknorm_indices = [ + i for i in range(depth) + if '{}backbone.pretrained.blocks.{}.attn.q_norm.weight'.format(key_prefix, i) in state_dict_keys + ] + qknorm_start = qknorm_indices[0] if qknorm_indices else -1 + + # The DA3 main-series configs always set alt_start == qknorm_start == rope_start. + # cat_token=True is implied by the presence of camera_token. + if has_camera_token: + dit_config["alt_start"] = qknorm_start + dit_config["rope_start"] = qknorm_start + dit_config["qknorm_start"] = qknorm_start + dit_config["cat_token"] = True + else: + dit_config["alt_start"] = -1 + dit_config["rope_start"] = -1 + dit_config["qknorm_start"] = -1 + dit_config["cat_token"] = False + + # Detect head type and config. + has_aux = '{}head.scratch.refinenet1_aux.out_conv.weight'.format(key_prefix) in state_dict_keys + if has_aux: + dit_config["head_type"] = "dualdpt" + # DualDPT: dim_in = 2 * embed_dim (because cat_token doubles token width). + head_dim_in = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1] + out_channels = [ + state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0] + for i in range(4) + ] + features = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0] + dit_config["head_dim_in"] = head_dim_in + dit_config["head_output_dim"] = 2 + dit_config["head_features"] = features + dit_config["head_out_channels"] = out_channels + dit_config["head_use_sky_head"] = False + else: + dit_config["head_type"] = "dpt" + head_dim_in = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1] + out_channels = [ + state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0] + for i in range(4) + ] + features = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0] + output_dim = state_dict[ + '{}head.scratch.output_conv2.2.weight'.format(key_prefix) + ].shape[0] + dit_config["head_dim_in"] = head_dim_in + dit_config["head_output_dim"] = output_dim + dit_config["head_features"] = features + dit_config["head_out_channels"] = out_channels + dit_config["head_use_sky_head"] = ( + '{}head.scratch.sky_output_conv2.0.weight'.format(key_prefix) in state_dict_keys + ) + + # out_layers: hard-coded per upstream YAML config (depth-aware default). + if depth >= 24: + # vitl: depths used vary between DA3-Large (DualDPT) and Mono/Metric (DPT). + if has_aux: + dit_config["out_layers"] = [11, 15, 19, 23] + else: + dit_config["out_layers"] = [4, 11, 17, 23] + else: + # vits/vitb: 12 blocks + dit_config["out_layers"] = [5, 7, 9, 11] + + # Camera encoder/decoder presence (multi-view + pose path). 
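+    # ``token_norm`` / ``fc_t`` are taken to be unique to CameraEnc /
+    # CameraDec respectively, so key presence doubles as a module probe.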
+ has_cam_enc = '{}cam_enc.token_norm.weight'.format(key_prefix) in state_dict_keys + has_cam_dec = '{}cam_dec.fc_t.weight'.format(key_prefix) in state_dict_keys + dit_config["has_cam_enc"] = has_cam_enc + dit_config["has_cam_dec"] = has_cam_dec + if has_cam_enc: + cam_enc_w = state_dict.get( + '{}cam_enc.pose_branch.fc2.weight'.format(key_prefix) + ) + if cam_enc_w is not None: + dit_config["cam_dim_out"] = cam_enc_w.shape[0] + if has_cam_dec: + cam_dec_w = state_dict.get( + '{}cam_dec.fc_t.weight'.format(key_prefix) + ) + if cam_dec_w is not None: + dit_config["cam_dec_dim_in"] = cam_dec_w.shape[1] + return dit_config + if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image dit_config = {} dit_config["image_model"] = "ernie" diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1e4434fd5..e5cc953d2 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1847,6 +1847,126 @@ class RT_DETR_v4(supported_models_base.BASE): return None +class DepthAnything3(supported_models_base.BASE): + unet_config = { + "image_model": "DepthAnything3", + } + + # Mono path: no num_heads / num_head_channels needed. + unet_extra_config = {} + + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + def get_model(self, state_dict, prefix="", device=None): + return model_base.DepthAnything3(self, device=device) + + def clip_target(self, state_dict={}): + return None + + def process_unet_state_dict(self, state_dict): + # Drop weights for components we do not build (3D Gaussian heads). + # ``cam_enc.*`` / ``cam_dec.*`` are kept and consumed by the multi-view + # forward path -- their layouts in our ``camera.py`` mirror the + # upstream ``cam_enc.py`` / ``cam_dec.py`` so HF safetensors load + # directly without any key remap. + drop_prefixes = ("gs_head.", "gs_adapter.") + for k in list(state_dict.keys()): + if k.startswith(drop_prefixes): + state_dict.pop(k) + # Remap upstream DA3 backbone keys (``backbone.pretrained.*`` with + # fused QKV) to the layout used by ``comfy.image_encoders.dino2.Dinov2Model``. + return _da3_remap_backbone_keys(state_dict, prefix="backbone.") + + +def _da3_remap_backbone_keys(state_dict, prefix="backbone."): + """Rewrite upstream DA3 DINOv2 keys to the shared ``Dinov2Model`` layout. + + Upstream layout (under ``{prefix}pretrained.``): + patch_embed.proj.{weight,bias}, pos_embed, cls_token, camera_token, norm.*, + blocks.{i}.norm{1,2}.*, blocks.{i}.attn.qkv.{weight,bias}, + blocks.{i}.attn.q_norm.*, blocks.{i}.attn.k_norm.*, + blocks.{i}.attn.proj.*, blocks.{i}.ls{1,2}.gamma, + blocks.{i}.mlp.fc{1,2}.* (or w12/w3 for SwiGLU) + + Target layout (Dinov2Model under ``{prefix}``): + embeddings.patch_embeddings.projection.*, + embeddings.position_embeddings, embeddings.cls_token, embeddings.camera_token, + layernorm.*, + encoder.layer.{i}.norm{1,2}.*, + encoder.layer.{i}.attention.attention.{query,key,value}.*, + encoder.layer.{i}.attention.q_norm.*, encoder.layer.{i}.attention.k_norm.*, + encoder.layer.{i}.attention.output.dense.*, + encoder.layer.{i}.layer_scale{1,2}.lambda1, + encoder.layer.{i}.mlp.fc{1,2}.* (or weights_in/weights_out for SwiGLU) + """ + pre = prefix + "pretrained." 
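+    # e.g. "backbone.pretrained.blocks.5.attn.qkv.weight" is split into
+    # "backbone.encoder.layer.5.attention.attention.{query,key,value}.weight".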
+ src_keys = [k for k in state_dict.keys() if k.startswith(pre)] + if not src_keys: + return state_dict + + static_renames = { + pre + "patch_embed.proj.weight": prefix + "embeddings.patch_embeddings.projection.weight", + pre + "patch_embed.proj.bias": prefix + "embeddings.patch_embeddings.projection.bias", + pre + "pos_embed": prefix + "embeddings.position_embeddings", + pre + "cls_token": prefix + "embeddings.cls_token", + pre + "camera_token": prefix + "embeddings.camera_token", + pre + "norm.weight": prefix + "layernorm.weight", + pre + "norm.bias": prefix + "layernorm.bias", + } + for src, dst in static_renames.items(): + if src in state_dict: + state_dict[dst] = state_dict.pop(src) + + block_pre = pre + "blocks." + block_keys = [k for k in state_dict.keys() if k.startswith(block_pre)] + for k in block_keys: + rest = k[len(block_pre):] # e.g. "5.attn.qkv.weight" + idx_str, _, sub = rest.partition(".") + target_block = "{}encoder.layer.{}.".format(prefix, idx_str) + + # Fused QKV -> split query/key/value linears. + if sub == "attn.qkv.weight": + qkv = state_dict.pop(k) + c = qkv.shape[0] // 3 + state_dict[target_block + "attention.attention.query.weight"] = qkv[:c].clone() + state_dict[target_block + "attention.attention.key.weight"] = qkv[c:2 * c].clone() + state_dict[target_block + "attention.attention.value.weight"] = qkv[2 * c:].clone() + continue + if sub == "attn.qkv.bias": + qkv = state_dict.pop(k) + c = qkv.shape[0] // 3 + state_dict[target_block + "attention.attention.query.bias"] = qkv[:c].clone() + state_dict[target_block + "attention.attention.key.bias"] = qkv[c:2 * c].clone() + state_dict[target_block + "attention.attention.value.bias"] = qkv[2 * c:].clone() + continue + + # Sub-key remap (suffix preserved). + if sub.startswith("attn.proj."): + tail = sub[len("attn.proj."):] + new = "attention.output.dense." + tail + elif sub.startswith("attn.q_norm."): + new = "attention.q_norm." + sub[len("attn.q_norm."):] + elif sub.startswith("attn.k_norm."): + new = "attention.k_norm." + sub[len("attn.k_norm."):] + elif sub == "ls1.gamma": + new = "layer_scale1.lambda1" + elif sub == "ls2.gamma": + new = "layer_scale2.lambda1" + elif sub.startswith("mlp.w12."): + new = "mlp.weights_in." + sub[len("mlp.w12."):] + elif sub.startswith("mlp.w3."): + new = "mlp.weights_out." + sub[len("mlp.w3."):] + elif sub.startswith(("norm1.", "norm2.", "mlp.fc1.", "mlp.fc2.")): + new = sub + else: + # Unrecognised key -- leave as-is so load_state_dict can complain. + continue + + state_dict[target_block + new] = state_dict.pop(k) + + return state_dict + + class ErnieImage(supported_models_base.BASE): unet_config = { "image_model": "ernie", @@ -2082,4 +2202,5 @@ models = [ CogVideoX_I2V, CogVideoX_T2V, SVD_img2vid, + DepthAnything3, ] diff --git a/comfy_extras/nodes_depth_anything_3.py b/comfy_extras/nodes_depth_anything_3.py new file mode 100644 index 000000000..9e0bee9aa --- /dev/null +++ b/comfy_extras/nodes_depth_anything_3.py @@ -0,0 +1,398 @@ +"""ComfyUI nodes for Depth Anything 3. + +Adds these nodes: + +* ``LoadDepthAnything3`` -- load a DA3 ``.safetensors`` file from the + ``models/depth_estimation/`` folder. Falls back to ``models/diffusion_models/`` + so existing installations keep working. +* ``DepthAnything3Depth`` -- run depth estimation and return a normalised + depth map as a ComfyUI ``IMAGE`` (visualisation / ControlNet input). +* ``DepthAnything3DepthRaw`` -- run depth estimation and return the raw depth, + confidence and sky channels as ``MASK`` outputs. 
+* ``DepthAnything3MultiView`` -- multi-view path: depth + per-view extrinsics + + intrinsics. Pose is decoded either from the camera-decoder MLP (default) + or from the auxiliary ray output via RANSAC (DA3-Small/Base only). +""" + +from __future__ import annotations + +from typing_extensions import override + +import torch + +import comfy.model_management as mm +import comfy.sd +import folder_paths +from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess +from comfy_api.latest import ComfyExtension, io + + +# ----------------------------------------------------------------------------- +# Loader +# ----------------------------------------------------------------------------- + + +class LoadDepthAnything3(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadDepthAnything3", + display_name="Load Depth Anything 3", + category="loaders/depth_estimation", + inputs=[ + io.Combo.Input( + "model_name", + options=folder_paths.get_filename_list("depth_estimation"), + ), + io.Combo.Input( + "weight_dtype", + options=["default", "fp16", "bf16", "fp32"], + default="default", + ), + ], + outputs=[io.Model.Output("model")], + ) + + @classmethod + def execute(cls, model_name, weight_dtype) -> io.NodeOutput: + model_options = {} + if weight_dtype == "fp16": + model_options["dtype"] = torch.float16 + elif weight_dtype == "bf16": + model_options["dtype"] = torch.bfloat16 + elif weight_dtype == "fp32": + model_options["dtype"] = torch.float32 + + path = folder_paths.get_full_path_or_raise("depth_estimation", model_name) + model = comfy.sd.load_diffusion_model(path, model_options=model_options) + return io.NodeOutput(model) + + +# ----------------------------------------------------------------------------- +# Inference helpers +# ----------------------------------------------------------------------------- + + +def _run_da3(model_patcher, image: torch.Tensor, process_res: int, + method: str = "upper_bound_resize"): + """Run the DA3 network on a (B, H, W, 3) ``IMAGE`` batch. + + Returns ``(depth, confidence, sky)`` tensors with the original image + resolution. Any of ``confidence`` / ``sky`` may be ``None`` depending on + the variant. + """ + assert image.ndim == 4 and image.shape[-1] == 3, \ + f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}" + + B, H, W, _ = image.shape + mm.load_model_gpu(model_patcher) + diffusion = model_patcher.model.diffusion_model + device = mm.get_torch_device() + dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32 + + depths, confs, skies = [], [], [] + # Process one image at a time to keep peak memory predictable; DA3 is + # an inference-only model and per-sample latency dominates anyway. + for i in range(B): + single = image[i:i + 1].to(device) + x = da3_preprocess.preprocess_image(single, process_res=process_res, method=method) + x = x.to(dtype=dtype) + with torch.no_grad(): + out = diffusion(x) + + depth_lr = out["depth"] + # Resize back to the original (H, W). 
+ depth_full = torch.nn.functional.interpolate( + depth_lr.unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + depths.append(depth_full) + + if "depth_conf" in out: + conf_full = torch.nn.functional.interpolate( + out["depth_conf"].unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + confs.append(conf_full) + if "sky" in out: + sky_full = torch.nn.functional.interpolate( + out["sky"].unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + skies.append(sky_full) + + depth = torch.cat(depths, dim=0) + confidence = torch.cat(confs, dim=0) if confs else None + sky = torch.cat(skies, dim=0) if skies else None + return depth, confidence, sky + + +# ----------------------------------------------------------------------------- +# Depth -> visualisation IMAGE +# ----------------------------------------------------------------------------- + + +class DepthAnything3Depth(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DepthAnything3Depth", + display_name="Depth Anything 3 (Depth)", + category="image/depth", + inputs=[ + io.Model.Input("model"), + io.Image.Input("image"), + io.Int.Input("process_res", default=504, min=140, max=2520, step=14, + tooltip="Longest-side target resolution (multiple of 14)."), + io.Combo.Input("resize_method", + options=["upper_bound_resize", "lower_bound_resize"], + default="upper_bound_resize"), + io.Combo.Input("normalization", + options=["v2_style", "min_max", "raw"], + default="v2_style", + tooltip="How to map raw depth -> [0, 1] image."), + io.Boolean.Input("apply_sky_clip", default=True, + tooltip="(Mono/Metric only) clip sky depth to 99th percentile."), + ], + outputs=[ + io.Image.Output("depth_image"), + io.Mask.Output("sky_mask", + tooltip="Sky probability (Mono/Metric variants), else zeros."), + io.Mask.Output("confidence", + tooltip="Depth confidence (Small/Base/DualDPT variants), else zeros."), + ], + ) + + @classmethod + def execute(cls, model, image, process_res, resize_method, normalization, + apply_sky_clip) -> io.NodeOutput: + depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method) + + if apply_sky_clip and sky is not None: + depth = torch.stack([ + da3_preprocess.apply_sky_aware_clip(depth[i], sky[i]) + for i in range(depth.shape[0]) + ], dim=0) + + if normalization == "v2_style": + norm = torch.stack([ + da3_preprocess.normalize_depth_v2_style(depth[i], + sky[i] if sky is not None else None) + for i in range(depth.shape[0]) + ], dim=0) + elif normalization == "min_max": + norm = da3_preprocess.normalize_depth_min_max(depth) + else: + norm = depth + + # (B, H, W) -> (B, H, W, 3) grayscale IMAGE. + out_image = norm.unsqueeze(-1).repeat(1, 1, 1, 3).clamp(0.0, 1.0).contiguous() + sky_mask = sky if sky is not None else torch.zeros_like(depth) + conf_mask = confidence if confidence is not None else torch.zeros_like(depth) + return io.NodeOutput(out_image, sky_mask.contiguous(), conf_mask.contiguous()) + + +# ----------------------------------------------------------------------------- +# Raw depth output (useful for downstream metric work) +# ----------------------------------------------------------------------------- + + +class DepthAnything3MultiView(io.ComfyNode): + """Multi-view depth + pose estimation for DA3-Small / DA3-Base / DA3-Large. + + Treats each batch element of the input ``IMAGE`` as a separate view of + the same scene. 
The selected reference view is auto-chosen by the
+    backbone via ``ref_view_strategy`` (when at least 3 views are
+    supplied), unless camera extrinsics are provided -- in which case the
+    geometry is pinned by the user and no reordering is done.
+
+    Output structure:
+      * ``depth_image`` -- per-view normalised depth as a stacked ``IMAGE``
+        batch (one frame per view, original input order).
+      * ``confidence`` / ``sky`` -- per-view masks (zero when the variant
+        does not produce them).
+      * ``camera`` -- ``LATENT`` dict with keys::
+            samples: (1, S, 1, H, W) -- raw depth packed as latent
+            type: "da3_multiview"
+            extrinsics: (1, S, 4, 4) world-to-camera matrices
+            intrinsics: (1, S, 3, 3) pixel-space intrinsics
+            depth_raw: (S, H, W) un-normalised depth
+            confidence: (S, H, W)
+    """
+
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="DepthAnything3MultiView",
+            display_name="Depth Anything 3 (Multi-View)",
+            category="image/depth",
+            inputs=[
+                io.Model.Input("model"),
+                io.Image.Input("image",
+                               tooltip="Image batch where each frame is a view of the same scene."),
+                io.Int.Input("process_res", default=504, min=140, max=2520, step=14,
+                             tooltip="Longest-side target resolution (multiple of 14)."),
+                io.Combo.Input("resize_method",
+                               options=["upper_bound_resize", "lower_bound_resize"],
+                               default="upper_bound_resize"),
+                io.Combo.Input("ref_view_strategy",
+                               options=["saddle_balanced", "saddle_sim_range", "first", "middle"],
+                               default="saddle_balanced",
+                               tooltip="Reference view selection (only applied when "
+                                       "S>=3 and no extrinsics are provided)."),
+                io.Combo.Input("pose_method",
+                               options=["cam_dec", "ray_pose"],
+                               default="cam_dec",
+                               tooltip="cam_dec: small MLP on the final cam token (works for "
+                                       "all variants with cam_dec). ray_pose: RANSAC over the "
+                                       "DualDPT auxiliary ray output (DA3-Small/Base only)."),
+                io.Combo.Input("normalization",
+                               options=["v2_style", "min_max", "raw"],
+                               default="v2_style"),
+            ],
+            outputs=[
+                io.Image.Output("depth_image"),
+                io.Mask.Output("confidence"),
+                io.Mask.Output("sky_mask"),
+                io.Latent.Output("camera",
+                                 tooltip="Per-view extrinsics + intrinsics + raw depth."),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, image, process_res, resize_method, ref_view_strategy,
+                pose_method, normalization) -> io.NodeOutput:
+        assert image.ndim == 4 and image.shape[-1] == 3, \
+            f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
+        S, H, W, _ = image.shape
+
+        mm.load_model_gpu(model)
+        diffusion = model.model.diffusion_model
+        device = mm.get_torch_device()
+        dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32
+
+        # Stack all views as a single batch element with views axis = S.
+        x = image.to(device)
+        x = da3_preprocess.preprocess_image(x, process_res=process_res, method=resize_method)
+        x = x.to(dtype=dtype).unsqueeze(0)  # (1, S, 3, H', W')
+
+        use_ray_pose = (pose_method == "ray_pose")
+        with torch.no_grad():
+            out = diffusion(x, use_ray_pose=use_ray_pose,
+                            ref_view_strategy=ref_view_strategy)
+
+        # ``out["depth"]`` comes back at the processed resolution (S, H', W');
+        # resize it back to the original (S, H, W).
+ depth_lr = out["depth"].float() + depth = torch.nn.functional.interpolate( + depth_lr.unsqueeze(1), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + + if "depth_conf" in out: + conf = torch.nn.functional.interpolate( + out["depth_conf"].unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + else: + conf = torch.zeros_like(depth) + + if "sky" in out: + sky = torch.nn.functional.interpolate( + out["sky"].unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + else: + sky = torch.zeros_like(depth) + + # Pose. Defaults to identity when neither cam_dec nor ray_pose is wired up. + if "extrinsics" in out and "intrinsics" in out: + extrinsics = out["extrinsics"].float().cpu() + intrinsics = out["intrinsics"].float().cpu() + else: + extrinsics = torch.eye(4)[None, None].expand(1, S, 4, 4).clone() + intrinsics = torch.eye(3)[None, None].expand(1, S, 3, 3).clone() + + # Normalised depth viz per view (same path as the mono node). + if normalization == "v2_style": + norm = torch.stack([ + da3_preprocess.normalize_depth_v2_style(depth[i], + sky[i] if "sky" in out else None) + for i in range(S) + ], dim=0) + elif normalization == "min_max": + norm = da3_preprocess.normalize_depth_min_max(depth) + else: + norm = depth + + depth_image = norm.unsqueeze(-1).repeat(1, 1, 1, 3).clamp(0.0, 1.0).contiguous() + + camera_latent = { + # The Latent contract requires a ``samples`` field; pack the raw + # depth there so a downstream node still has a tensor to chain on. + "samples": depth.unsqueeze(0).unsqueeze(2).contiguous(), # (1, S, 1, H, W) + "type": "da3_multiview", + "extrinsics": extrinsics.contiguous(), + "intrinsics": intrinsics.contiguous(), + "depth_raw": depth.contiguous(), + "confidence": conf.contiguous(), + } + return io.NodeOutput( + depth_image, + conf.contiguous(), + sky.contiguous(), + camera_latent, + ) + + +class DepthAnything3DepthRaw(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DepthAnything3DepthRaw", + display_name="Depth Anything 3 (Raw Depth)", + category="image/depth", + inputs=[ + io.Model.Input("model"), + io.Image.Input("image"), + io.Int.Input("process_res", default=504, min=140, max=2520, step=14), + io.Combo.Input("resize_method", + options=["upper_bound_resize", "lower_bound_resize"], + default="upper_bound_resize"), + ], + outputs=[ + io.Mask.Output("depth", + tooltip="Raw depth values (no normalisation, no clipping)."), + io.Mask.Output("confidence"), + io.Mask.Output("sky"), + ], + ) + + @classmethod + def execute(cls, model, image, process_res, resize_method) -> io.NodeOutput: + depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method) + zeros = torch.zeros_like(depth) + return io.NodeOutput( + depth.contiguous(), + (confidence if confidence is not None else zeros).contiguous(), + (sky if sky is not None else zeros).contiguous(), + ) + + +# ----------------------------------------------------------------------------- +# Extension registration +# ----------------------------------------------------------------------------- + + +class DepthAnything3Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LoadDepthAnything3, + DepthAnything3Depth, + DepthAnything3DepthRaw, + DepthAnything3MultiView, + ] + + +async def comfy_entrypoint() -> DepthAnything3Extension: + return DepthAnything3Extension() diff --git a/folder_paths.py b/folder_paths.py index 
92e8df3cf..bc95e9a8c 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -58,6 +58,13 @@ folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "fram folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions) +folder_names_and_paths["depth_estimation"] = ( + [os.path.join(models_dir, "depth_estimation"), + os.path.join(models_dir, "diffusion_models")], + supported_pt_extensions, +) + + output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") input_directory = os.path.join(base_path, "input") diff --git a/nodes.py b/nodes.py index 78aaaef74..d5a05445c 100644 --- a/nodes.py +++ b/nodes.py @@ -2436,6 +2436,7 @@ async def init_builtin_extra_nodes(): "nodes_void.py", "nodes_wandancer.py", "nodes_hidream_o1.py", + "nodes_depth_anything_3.py", ] import_failed = []