mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-11 08:47:29 +08:00
Clean-up AI generted code
This commit is contained in:
parent
7ae6c41fcf
commit
b95db88d38
@ -289,21 +289,6 @@ class Dino2Embeddings(torch.nn.Module):
|
||||
|
||||
|
||||
class Dinov2Model(torch.nn.Module):
|
||||
"""DINOv2 vision backbone.
|
||||
|
||||
Supports two operating modes:
|
||||
|
||||
* **CLIP-vision DINOv2** (default): vanilla DINOv2-ViT used for
|
||||
``ClipVisionModel`` and SigLIP-style image encoding.
|
||||
* **Depth Anything 3** extensions (opt-in via config keys): 2D RoPE,
|
||||
QK-norm, alternating local/global attention, camera-token injection,
|
||||
``cat_token`` output and multi-layer feature extraction. These are
|
||||
enabled when the corresponding fields (``alt_start``, ``qknorm_start``,
|
||||
``rope_start``, ``cat_token``) are set in ``config_dict``. When all of
|
||||
them are at their disabled defaults this module behaves identically to
|
||||
the historical ``Dinov2Model``.
|
||||
"""
|
||||
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
num_layers = config_dict["num_hidden_layers"]
|
||||
@ -363,21 +348,7 @@ class Dinov2Model(torch.nn.Module):
|
||||
return x, i, pooled_output, None
|
||||
|
||||
def get_intermediate_layers(self, pixel_values, indices, apply_norm=True):
|
||||
"""Single-view multi-layer feature extraction (MoGe / vanilla DINOv2).
|
||||
|
||||
For the multi-view Depth Anything 3 path (RoPE, alt-attention,
|
||||
camera-token injection, ref-view selection, cat_token), use
|
||||
:meth:`get_intermediate_layers_da3` instead.
|
||||
|
||||
Args:
|
||||
pixel_values: ``(B, 3, H, W)`` single-view input.
|
||||
indices: layer indices to extract; supports negative indexing.
|
||||
apply_norm: if True, apply the final layernorm to each output.
|
||||
|
||||
Returns:
|
||||
list of ``(patch_tokens, cls_token)`` tuples with shapes
|
||||
``(B, N_patch, C)`` and ``(B, C)`` (one entry per ``indices``).
|
||||
"""
|
||||
"""Single-view multi-layer feature extraction."""
|
||||
x = self.embeddings(pixel_values)
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
n_layers = len(self.encoder.layer)
|
||||
@ -413,8 +384,7 @@ class Dinov2Model(torch.nn.Module):
|
||||
pos_global = torch.cat([cls_pos, torch.zeros_like(pos) + 1], dim=1)
|
||||
return pos_local, pos_global
|
||||
|
||||
def _inject_camera_token(self, x: torch.Tensor, B: int, S: int,
|
||||
cam_token: "torch.Tensor | None") -> torch.Tensor:
|
||||
def _inject_camera_token(self, x: torch.Tensor, B: int, S: int, cam_token: "torch.Tensor | None") -> torch.Tensor:
|
||||
# x: (B, S, N, C). Replace token at index 0 with the camera token.
|
||||
if cam_token is not None:
|
||||
inj = cam_token
|
||||
@ -427,40 +397,8 @@ class Dinov2Model(torch.nn.Module):
|
||||
x[:, :, 0] = inj
|
||||
return x
|
||||
|
||||
def get_intermediate_layers_da3(self, pixel_values, out_layers, cam_token=None,
|
||||
ref_view_strategy="saddle_balanced",
|
||||
export_feat_layers=None):
|
||||
"""Multi-view multi-layer feature extraction used by Depth Anything 3.
|
||||
|
||||
Adds RoPE positions, alternating local/global attention across views,
|
||||
camera-token injection, reference-view selection/reordering,
|
||||
``cat_token`` output and optional auxiliary feature exports on top of
|
||||
the vanilla DINOv2 path. For the single-view MoGe / CLIP-vision use
|
||||
case, see :meth:`get_intermediate_layers`.
|
||||
|
||||
Args:
|
||||
pixel_values: ``(B, S, 3, H, W)`` views or ``(B, 3, H, W)``.
|
||||
out_layers: indices into ``self.encoder.layer``.
|
||||
cam_token: optional ``(B, S, dim)`` camera token to inject at
|
||||
``alt_start``. If ``None`` and the model has its own
|
||||
``camera_token`` parameter, that is used.
|
||||
ref_view_strategy: when ``S >= 3`` and ``cam_token is None``,
|
||||
pick a reference view via this strategy and move it to
|
||||
position 0 right before the first alt-attention block.
|
||||
The original view order is restored on the way out.
|
||||
export_feat_layers: optional iterable of layer indices whose
|
||||
local attention outputs to also return as auxiliary
|
||||
features (``(B, S, N_patch, C)`` after final norm). Used
|
||||
by the multi-view path to expose intermediate features
|
||||
to the nested-architecture wrapper.
|
||||
|
||||
Returns:
|
||||
``(layer_outputs, aux_outputs)`` where ``layer_outputs`` is a
|
||||
list of ``(patch_tokens, cls_or_cam_token)`` tuples (one per
|
||||
``out_layers`` entry) and ``aux_outputs`` is a list of
|
||||
``(B, S, N_patch, C)`` features for ``export_feat_layers``
|
||||
(empty list when not requested).
|
||||
"""
|
||||
def get_intermediate_layers_da3(self, pixel_values, out_layers, cam_token=None, ref_view_strategy="saddle_balanced", export_feat_layers=None):
|
||||
"""Multi-view multi-layer feature extraction used by Depth Anything 3."""
|
||||
if pixel_values.ndim == 4:
|
||||
pixel_values = pixel_values.unsqueeze(1)
|
||||
assert pixel_values.ndim == 5 and pixel_values.shape[2] == 3, \
|
||||
@ -473,7 +411,7 @@ class Dinov2Model(torch.nn.Module):
|
||||
x = x.reshape(B, S, x.shape[-2], x.shape[-1]) # (B, S, 1+N, C)
|
||||
|
||||
pos_local, pos_global = self._prepare_rope_positions(B, S, H, W, x.device)
|
||||
# ``optimized_attention`` is only used by blocks without QK-norm/RoPE
|
||||
# optimized_attention is only used by blocks without QK-norm/RoPE
|
||||
# (vanilla DINOv2 path); enabling-aware blocks fall through to SDPA.
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
|
||||
@ -492,10 +430,9 @@ class Dinov2Model(torch.nn.Module):
|
||||
g_pos = pos_global if apply_rope else None
|
||||
|
||||
# Reference-view selection threshold: matches the upstream constant
|
||||
# ``THRESH_FOR_REF_SELECTION = 3``. Skipped when a user-supplied
|
||||
# THRESH_FOR_REF_SELECTION = 3. Skipped when a user-supplied
|
||||
# cam_token is provided (camera info already pins the geometry).
|
||||
if (self.alt_start != -1 and i == self.alt_start - 1
|
||||
and S >= THRESH_FOR_REF_SELECTION and cam_token is None):
|
||||
if (self.alt_start != -1 and i == self.alt_start - 1 and S >= THRESH_FOR_REF_SELECTION and cam_token is None):
|
||||
b_idx = select_reference_view(x, strategy=ref_view_strategy)
|
||||
x = reorder_by_reference(x, b_idx)
|
||||
local_x = reorder_by_reference(local_x, b_idx)
|
||||
@ -534,7 +471,7 @@ class Dinov2Model(torch.nn.Module):
|
||||
aux = restore_original_order(aux, b_idx)
|
||||
aux_outputs.append(aux)
|
||||
|
||||
# Apply final norm. When ``cat_token`` is set, only the right half
|
||||
# Apply final norm. When cat_token is set, only the right half
|
||||
# ("global" features) is normalised; the left half is left as-is to
|
||||
# match the upstream DA3 head signature.
|
||||
normed: list[torch.Tensor] = []
|
||||
|
||||
@ -1,7 +0,0 @@
|
||||
# Depth Anything 3 - native ComfyUI port (Apache-2.0).
|
||||
#
|
||||
# Supported variants:
|
||||
# DA3-Small, DA3-Base (vits/vitb backbone, DualDPT head)
|
||||
# DA3Mono-Large, DA3Metric-Large (vitl backbone, DPT head + sky mask)
|
||||
#
|
||||
# Original repo: https://github.com/ByteDance-Seed/Depth-Anything-3
|
||||
@ -1,16 +1,4 @@
|
||||
"""Camera-token encoder and decoder for Depth Anything 3.
|
||||
|
||||
* :class:`CameraEnc` takes per-view extrinsics + intrinsics and produces a
|
||||
per-view camera token that gets injected at the alt-attention boundary
|
||||
in the DINOv2 backbone (block ``alt_start``).
|
||||
* :class:`CameraDec` takes the final-layer camera token output by the
|
||||
backbone and predicts a 9-D pose encoding (translation, quaternion,
|
||||
field-of-view).
|
||||
|
||||
The module/parameter names match the upstream ``cam_enc.py``/``cam_dec.py``
|
||||
so HF safetensors load directly with no key remapping (the upstream uses
|
||||
fused QKV linears, which we replicate here).
|
||||
"""
|
||||
"""Camera-token encoder and decoder for Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -22,30 +10,27 @@ from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from .transform import affine_inverse, extri_intri_to_pose_encoding
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Building blocks (mirror ``depth_anything_3.model.utils.{attention,block}``)
|
||||
# -----------------------------------------------------------------------------
|
||||
# -----------------------------------------------------------------------
|
||||
# Building blocks (mirror depth_anything_3.model.utils.{attention,block})
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
|
||||
class _Mlp(nn.Module):
|
||||
"""Standard 2-layer MLP with GELU. Matches upstream ``utils.attention.Mlp``."""
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None,
|
||||
*, device=None, dtype=None, operations=None):
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, *, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = operations.Linear(in_features, hidden_features, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
self.fc2 = operations.Linear(hidden_features, out_features, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
self.fc1 = operations.Linear(in_features, hidden_features, bias=True, device=device, dtype=dtype)
|
||||
self.fc2 = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(F.gelu(self.fc1(x)))
|
||||
|
||||
|
||||
class _LayerScale(nn.Module):
|
||||
"""Per-channel learnable scaling. Matches upstream ``LayerScale``."""
|
||||
"""Per-channel learnable scaling. Matches upstream LayerScale."""
|
||||
|
||||
def __init__(self, dim, *, device=None, dtype=None):
|
||||
super().__init__()
|
||||
@ -56,22 +41,16 @@ class _LayerScale(nn.Module):
|
||||
|
||||
|
||||
class _Attention(nn.Module):
|
||||
"""Self-attention with fused QKV projection.
|
||||
""" Self-attention with fused QKV projection. Mirrors upstream utils.attention.Attention;
|
||||
Layout matches the HF safetensors (attn.qkv.{weight,bias} and attn.proj.{weight,bias})."""
|
||||
|
||||
Mirrors upstream ``utils.attention.Attention``; layout matches the
|
||||
HF safetensors (``attn.qkv.{weight,bias}`` and ``attn.proj.{weight,bias}``).
|
||||
"""
|
||||
|
||||
def __init__(self, dim, num_heads,
|
||||
*, device=None, dtype=None, operations=None):
|
||||
def __init__(self, dim, num_heads, *, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=True, device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
@ -83,21 +62,15 @@ class _Attention(nn.Module):
|
||||
|
||||
|
||||
class _Block(nn.Module):
|
||||
"""Pre-norm transformer block with LayerScale.
|
||||
"""Pre-norm transformer block with LayerScale. Used by :class:CameraEnc. Layout follows upstream utils.block.Block."""
|
||||
|
||||
Used by :class:`CameraEnc`. Layout follows upstream ``utils.block.Block``.
|
||||
"""
|
||||
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4, init_values=0.01,
|
||||
*, device=None, dtype=None, operations=None):
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4, init_values=0.01, *, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype)
|
||||
self.attn = _Attention(dim, num_heads,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
self.attn = _Attention(dim, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.ls1 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity()
|
||||
self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype)
|
||||
self.mlp = _Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
self.mlp = _Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), device=device, dtype=dtype, operations=operations)
|
||||
self.ls2 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
@ -1,12 +1,4 @@
|
||||
# DPT / DualDPT heads for Depth Anything 3.
|
||||
#
|
||||
# Ported from:
|
||||
# src/depth_anything_3/model/dpt.py (DPT - single main head + sky head)
|
||||
# src/depth_anything_3/model/dualdpt.py (DualDPT - depth + auxiliary "ray" head)
|
||||
#
|
||||
# In the monocular path we always discard the auxiliary "ray" output of
|
||||
# DualDPT. The auxiliary branch is still constructed so that DA3 HF weights
|
||||
# load cleanly without missing-key warnings.
|
||||
"""DPT / DualDPT heads for Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -17,11 +9,6 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helpers (matching upstream head_utils.py)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Permute(nn.Module):
|
||||
def __init__(self, dims: Tuple[int, ...]):
|
||||
super().__init__()
|
||||
@ -50,8 +37,7 @@ def _custom_interpolate(
|
||||
return F.interpolate(x, size=size, mode=mode, align_corners=align_corners)
|
||||
|
||||
|
||||
def _create_uv_grid(width: int, height: int, aspect_ratio: float,
|
||||
dtype, device) -> torch.Tensor:
|
||||
def _create_uv_grid(width: int, height: int, aspect_ratio: float, dtype, device) -> torch.Tensor:
|
||||
"""Normalised UV grid spanning (-x_span, -y_span)..(x_span, y_span)."""
|
||||
diag_factor = (aspect_ratio ** 2 + 1.0) ** 0.5
|
||||
span_x = aspect_ratio / diag_factor
|
||||
@ -74,8 +60,7 @@ def _make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 1
|
||||
return torch.cat([out.sin(), out.cos()], dim=1).float()
|
||||
|
||||
|
||||
def _position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int,
|
||||
omega_0: float = 100.0) -> torch.Tensor:
|
||||
def _position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100.0) -> torch.Tensor:
|
||||
H, W, _ = pos_grid.shape
|
||||
pos_flat = pos_grid.reshape(-1, 2)
|
||||
emb_x = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0)
|
||||
@ -118,13 +103,10 @@ def _apply_activation(x: torch.Tensor, activation: str) -> torch.Tensor:
|
||||
|
||||
|
||||
class ResidualConvUnit(nn.Module):
|
||||
def __init__(self, features: int,
|
||||
device=None, dtype=None, operations=None):
|
||||
def __init__(self, features: int, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv1 = operations.Conv2d(features, features, 3, 1, 1, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
self.conv2 = operations.Conv2d(features, features, 3, 1, 1, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
self.conv1 = operations.Conv2d(features, features, 3, 1, 1, bias=True, device=device, dtype=dtype)
|
||||
self.conv2 = operations.Conv2d(features, features, 3, 1, 1, bias=True, device=device, dtype=dtype)
|
||||
self.activation = nn.ReLU(inplace=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@ -136,9 +118,7 @@ class ResidualConvUnit(nn.Module):
|
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module):
|
||||
def __init__(self, features: int, has_residual: bool = True,
|
||||
align_corners: bool = True,
|
||||
device=None, dtype=None, operations=None):
|
||||
def __init__(self, features: int, has_residual: bool = True, align_corners: bool = True, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.align_corners = align_corners
|
||||
self.has_residual = has_residual
|
||||
@ -147,8 +127,7 @@ class FeatureFusionBlock(nn.Module):
|
||||
else:
|
||||
self.resConfUnit1 = None
|
||||
self.resConfUnit2 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations)
|
||||
self.out_conv = operations.Conv2d(features, features, 1, 1, 0, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
self.out_conv = operations.Conv2d(features, features, 1, 1, 0, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, *xs: torch.Tensor, size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
|
||||
y = xs[0]
|
||||
@ -159,8 +138,7 @@ class FeatureFusionBlock(nn.Module):
|
||||
up_kwargs = {"scale_factor": 2.0}
|
||||
else:
|
||||
up_kwargs = {"size": size}
|
||||
y = _custom_interpolate(y, **up_kwargs, mode="bilinear",
|
||||
align_corners=self.align_corners)
|
||||
y = _custom_interpolate(y, **up_kwargs, mode="bilinear", align_corners=self.align_corners)
|
||||
y = self.out_conv(y)
|
||||
return y
|
||||
|
||||
@ -169,25 +147,17 @@ class _Scratch(nn.Module):
|
||||
"""Container that mirrors upstream ``scratch`` attribute layout."""
|
||||
|
||||
|
||||
def _make_scratch(in_shape: List[int], out_shape: int,
|
||||
device=None, dtype=None, operations=None) -> _Scratch:
|
||||
def _make_scratch(in_shape: List[int], out_shape: int, device=None, dtype=None, operations=None) -> _Scratch:
|
||||
scratch = _Scratch()
|
||||
scratch.layer1_rn = operations.Conv2d(in_shape[0], out_shape, 3, 1, 1, bias=False,
|
||||
device=device, dtype=dtype)
|
||||
scratch.layer2_rn = operations.Conv2d(in_shape[1], out_shape, 3, 1, 1, bias=False,
|
||||
device=device, dtype=dtype)
|
||||
scratch.layer3_rn = operations.Conv2d(in_shape[2], out_shape, 3, 1, 1, bias=False,
|
||||
device=device, dtype=dtype)
|
||||
scratch.layer4_rn = operations.Conv2d(in_shape[3], out_shape, 3, 1, 1, bias=False,
|
||||
device=device, dtype=dtype)
|
||||
scratch.layer1_rn = operations.Conv2d(in_shape[0], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
|
||||
scratch.layer2_rn = operations.Conv2d(in_shape[1], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
|
||||
scratch.layer3_rn = operations.Conv2d(in_shape[2], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
|
||||
scratch.layer4_rn = operations.Conv2d(in_shape[3], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
|
||||
return scratch
|
||||
|
||||
|
||||
def _make_fusion_block(features: int, has_residual: bool = True,
|
||||
device=None, dtype=None, operations=None) -> FeatureFusionBlock:
|
||||
return FeatureFusionBlock(features, has_residual=has_residual,
|
||||
align_corners=True,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
def _make_fusion_block(features: int, has_residual: bool = True, device=None, dtype=None, operations=None) -> FeatureFusionBlock:
|
||||
return FeatureFusionBlock(features, has_residual=has_residual, align_corners=True, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@ -237,27 +207,21 @@ class DPT(nn.Module):
|
||||
|
||||
out_channels = list(out_channels)
|
||||
self.projects = nn.ModuleList([
|
||||
operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0,
|
||||
device=device, dtype=dtype)
|
||||
operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype)
|
||||
for oc in out_channels
|
||||
])
|
||||
self.resize_layers = nn.ModuleList([
|
||||
operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, device=device, dtype=dtype),
|
||||
operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, device=device, dtype=dtype),
|
||||
nn.Identity(),
|
||||
operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1,
|
||||
device=device, dtype=dtype),
|
||||
operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, device=device, dtype=dtype),
|
||||
])
|
||||
|
||||
self.scratch = _make_scratch(out_channels, features,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
self.scratch = _make_scratch(out_channels, features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
head_features_1 = features
|
||||
head_features_2 = 32
|
||||
@ -266,24 +230,19 @@ class DPT(nn.Module):
|
||||
device=device, dtype=dtype,
|
||||
)
|
||||
self.scratch.output_conv2 = nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype),
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
if self.use_sky_head:
|
||||
self.scratch.sky_output_conv2 = nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype),
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
operations.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, feats: List[torch.Tensor], H: int, W: int,
|
||||
patch_start_idx: int = 0, **_kwargs) -> dict:
|
||||
def forward(self, feats: List[torch.Tensor], H: int, W: int, patch_start_idx: int = 0, **_kwargs) -> dict:
|
||||
# feats[i][0] is the patch-token tensor with shape (B, S, N_patch, C)
|
||||
B, S, N, C = feats[0][0].shape
|
||||
feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
|
||||
@ -350,14 +309,7 @@ class DPT(nn.Module):
|
||||
|
||||
|
||||
class DualDPT(nn.Module):
|
||||
"""Two-head DPT used by DA3-Small / DA3-Base.
|
||||
|
||||
The auxiliary "ray" head is constructed so that HF state-dict keys load
|
||||
cleanly. It is only executed when :attr:`enable_aux` is set on the
|
||||
instance (typically by ``DepthAnything3Net`` when running multi-view
|
||||
with ``use_ray_pose=True``); otherwise the monocular path skips it for
|
||||
speed and the auxiliary submodules sit idle.
|
||||
"""
|
||||
"""Two-head DPT used by DA3-Small / DA3-Base."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -386,40 +338,33 @@ class DualDPT(nn.Module):
|
||||
self.head_main, self.head_aux = head_names
|
||||
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
|
||||
# Toggle the auxiliary ray branch at runtime. Default off (mono path).
|
||||
# ``DepthAnything3Net`` flips this on when running multi-view + ray-pose.
|
||||
# DepthAnything3Net flips this on when running multi-view + ray-pose.
|
||||
self.enable_aux: bool = False
|
||||
|
||||
self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
|
||||
out_channels = list(out_channels)
|
||||
self.projects = nn.ModuleList([
|
||||
operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0,
|
||||
device=device, dtype=dtype)
|
||||
operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype)
|
||||
for oc in out_channels
|
||||
])
|
||||
self.resize_layers = nn.ModuleList([
|
||||
operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, device=device, dtype=dtype),
|
||||
operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, device=device, dtype=dtype),
|
||||
nn.Identity(),
|
||||
operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1,
|
||||
device=device, dtype=dtype),
|
||||
operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, device=device, dtype=dtype),
|
||||
])
|
||||
|
||||
self.scratch = _make_scratch(out_channels, features,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
self.scratch = _make_scratch(out_channels, features, device=device, dtype=dtype, operations=operations)
|
||||
# Main fusion chain
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
|
||||
# Auxiliary fusion chain (separate copies)
|
||||
self.scratch.refinenet1_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet2_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet3_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
head_features_1 = features
|
||||
head_features_2 = 32
|
||||
@ -430,11 +375,9 @@ class DualDPT(nn.Module):
|
||||
device=device, dtype=dtype,
|
||||
)
|
||||
self.scratch.output_conv2 = nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype),
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
# Aux pre-head per level (multi-level pyramid)
|
||||
@ -449,12 +392,10 @@ class DualDPT(nn.Module):
|
||||
Permute((0, 3, 1, 2))]
|
||||
self.scratch.output_conv2_aux = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype),
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
|
||||
*ln_seq,
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
operations.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
|
||||
)
|
||||
for _ in range(self.aux_levels)
|
||||
])
|
||||
@ -470,8 +411,7 @@ class DualDPT(nn.Module):
|
||||
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, feats: List[torch.Tensor], H: int, W: int,
|
||||
patch_start_idx: int = 0, **_kwargs) -> dict:
|
||||
def forward(self, feats: List[torch.Tensor], H: int, W: int, patch_start_idx: int = 0, **_kwargs) -> dict:
|
||||
B, S, N, C = feats[0][0].shape
|
||||
feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
|
||||
|
||||
|
||||
@ -1,27 +1,3 @@
|
||||
# DepthAnything3Net: top-level wrapper that combines backbone + head.
|
||||
#
|
||||
# Supports both the monocular and the multi-view + camera path:
|
||||
#
|
||||
# * Monocular: ``S = 1``, no camera encoder/decoder. Mirrors the original
|
||||
# port that only handled ``DA3-MONO/METRIC-LARGE`` and the auxiliary-disabled
|
||||
# ``DA3-SMALL/BASE`` configs.
|
||||
# * Multi-view + camera: ``S > 1``. ``cam_enc`` (optional) maps user-supplied
|
||||
# extrinsics + intrinsics into a per-view camera token; ``cam_dec`` decodes
|
||||
# the final layer's camera token into a 9-D pose encoding. When the
|
||||
# auxiliary "ray" head of ``DualDPT`` is enabled the predicted ray map can
|
||||
# alternatively be used to estimate pose via RANSAC (``use_ray_pose=True``).
|
||||
# The 3D-Gaussian head and the nested-architecture wrapper are intentionally
|
||||
# left out of scope here; their state-dict keys (``gs_head.*``,
|
||||
# ``gs_adapter.*``) are dropped when repackaging the checkpoint with
|
||||
# ``scripts/convert_da3.py``, which also remaps the backbone into the native
|
||||
# ``Dinov2Model`` layout that this module loads directly.
|
||||
#
|
||||
# The backbone is shared with the CLIP-vision DINOv2 path
|
||||
# (``comfy.image_encoders.dino2.Dinov2Model``); the DA3-specific extensions
|
||||
# (RoPE, QK-norm, alternating local/global attention, camera token, multi-
|
||||
# layer feature extraction, reference-view reordering) are opt-in via the
|
||||
# config dict and are all disabled for the Mono/Metric variants.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional, Sequence
|
||||
@ -79,13 +55,6 @@ def _build_backbone_config(
|
||||
|
||||
|
||||
class DepthAnything3Net(nn.Module):
|
||||
"""ComfyUI-side DepthAnything3 network.
|
||||
|
||||
Parameters mirror the variant YAML configs from the upstream repo and
|
||||
are auto-detected from the state dict by ``comfy/model_detection.py``.
|
||||
The kwargs ``device``, ``dtype`` and ``operations`` are injected by
|
||||
``BaseModel``.
|
||||
"""
|
||||
|
||||
PATCH_SIZE = 14
|
||||
|
||||
@ -99,17 +68,17 @@ class DepthAnything3Net(nn.Module):
|
||||
rope_start: int = -1,
|
||||
cat_token: bool = False,
|
||||
# --- Head ---
|
||||
head_type: str = "dpt", # "dpt" or "dualdpt"
|
||||
head_type: str = "dpt", # dpt or dualdpt
|
||||
head_dim_in: int = 1024,
|
||||
head_output_dim: int = 1, # 1 = depth only, 2 = depth+conf
|
||||
head_output_dim: int = 1, # 1 = depth only, 2 = depth+conf
|
||||
head_features: int = 256,
|
||||
head_out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
||||
head_use_sky_head: bool = True, # ignored by DualDPT
|
||||
head_pos_embed: Optional[bool] = None, # default: True for DualDPT, False for DPT
|
||||
head_use_sky_head: bool = True, # ignored by DualDPT
|
||||
head_pos_embed: Optional[bool] = None, # default: True for DualDPT, False for DPT
|
||||
# --- Camera (multi-view) ---
|
||||
has_cam_enc: bool = False,
|
||||
has_cam_dec: bool = False,
|
||||
cam_dim_out: Optional[int] = None, # CameraEnc dim_out (defaults to embed_dim)
|
||||
cam_dim_out: Optional[int] = None, # CameraEnc dim_out (defaults to embed_dim)
|
||||
cam_dec_dim_in: Optional[int] = None, # CameraDec dim_in (defaults to 2*embed_dim with cat_token)
|
||||
# ComfyUI plumbing
|
||||
device=None, dtype=None, operations=None,
|
||||
@ -182,42 +151,7 @@ class DepthAnything3Net(nn.Module):
|
||||
export_feat_layers: Optional[Sequence[int]] = None,
|
||||
**_unused,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Run depth (and optionally pose) prediction.
|
||||
|
||||
Args:
|
||||
image: ``(B, 3, H, W)`` ImageNet-normalised image tensor, or
|
||||
``(B, S, 3, H, W)`` for multi-view inputs. ``H`` and ``W``
|
||||
must be multiples of 14.
|
||||
extrinsics: optional ``(B, S, 4, 4)`` world-to-camera extrinsics.
|
||||
When provided together with ``intrinsics``, ``CameraEnc``
|
||||
converts them into per-view camera tokens that the backbone
|
||||
injects at block ``alt_start``.
|
||||
intrinsics: optional ``(B, S, 3, 3)`` pixel-space intrinsics.
|
||||
use_ray_pose: if True, predict pose from the auxiliary "ray" head
|
||||
(RANSAC over per-pixel rays). Only available on DualDPT
|
||||
variants. If False (default) and ``cam_dec`` is present,
|
||||
the final-layer cam token is decoded into pose instead.
|
||||
ref_view_strategy: reference-view selection strategy used when
|
||||
``S >= 3`` and no extrinsics are supplied. See
|
||||
:mod:`comfy.ldm.depth_anything_3.reference_view_selector`.
|
||||
export_feat_layers: optional list of backbone layer indices whose
|
||||
local features to also return as auxiliary outputs (used by
|
||||
downstream nested-architecture wrappers; empty by default).
|
||||
|
||||
Returns:
|
||||
Dict with a subset of:
|
||||
- ``depth`` ``(B*S, H, W)`` raw depth values.
|
||||
- ``depth_conf`` ``(B*S, H, W)`` confidence (DualDPT only).
|
||||
- ``sky`` ``(B*S, H, W)`` sky probability (DPT + sky head).
|
||||
- ``ray`` ``(B, S, h, w, 6)`` per-pixel cam ray (DualDPT,
|
||||
multi-view, ``use_ray_pose=True`` only).
|
||||
- ``ray_conf`` ``(B, S, h, w)`` ray confidence.
|
||||
- ``extrinsics`` ``(B, S, 4, 4)`` world-to-cam, when pose
|
||||
prediction is active.
|
||||
- ``intrinsics`` ``(B, S, 3, 3)`` pixel-space intrinsics.
|
||||
- ``aux_features`` list of ``(B, S, h_p, w_p, C)`` features
|
||||
when ``export_feat_layers`` is non-empty.
|
||||
"""
|
||||
"""Run depth and optionally pose prediction."""
|
||||
if image.ndim == 4:
|
||||
image = image.unsqueeze(1) # (B, 1, 3, H, W)
|
||||
assert image.ndim == 5 and image.shape[2] == 3, \
|
||||
@ -292,7 +226,7 @@ class DepthAnything3Net(nn.Module):
|
||||
return out
|
||||
|
||||
def _reshape_aux_features(self, aux_feats, H: int, W: int):
|
||||
"""Reshape ``(B, S, N, C)`` aux features into ``(B, S, h_p, w_p, C)``."""
|
||||
"""Reshape (B, S, N, C) aux features into (B, S, h_p, w_p, C)."""
|
||||
ph, pw = H // self.PATCH_SIZE, W // self.PATCH_SIZE
|
||||
out = []
|
||||
for f in aux_feats:
|
||||
|
||||
@ -1,17 +1,4 @@
|
||||
# Input/output preprocessing helpers for Depth Anything 3.
|
||||
#
|
||||
# Ported from:
|
||||
# src/depth_anything_3/utils/io/input_processor.py (image normalisation)
|
||||
# src/depth_anything_3/utils/alignment.py (sky-aware depth clip)
|
||||
# src/depth_anything_3/model/da3.py::_process_mono_sky_estimation
|
||||
#
|
||||
# Resize: ``comfy.utils.common_upscale`` with ``upscale_method="lanczos"``.
|
||||
# Upstream uses cv2 INTER_CUBIC (upscale) / INTER_AREA (downscale); a sweep
|
||||
# across {bilinear, bicubic, area, lanczos, bislerp} on a 768->504 test image
|
||||
# showed lanczos has the lowest max-abs-diff vs the upstream cv2 output
|
||||
# (~0.13 vs 0.21-0.71 for the others), so we use it in both directions for
|
||||
# simplicity. This keeps the path stateless, on-device, and free of any
|
||||
# OpenCV dependency.
|
||||
"""Input/output preprocessing helpers for Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -34,16 +21,11 @@ def _round_to_patch(x: int, patch: int = PATCH_SIZE) -> int:
|
||||
return up if abs(up - x) <= abs(x - down) else down
|
||||
|
||||
|
||||
def compute_target_size(orig_h: int, orig_w: int, process_res: int,
|
||||
method: str = "upper_bound_resize") -> Tuple[int, int]:
|
||||
def compute_target_size(orig_h: int, orig_w: int, process_res: int, method: str = "upper_bound_resize") -> Tuple[int, int]:
|
||||
"""Compute (target_h, target_w) for a single image.
|
||||
upper_bound_resize: scale longest side to process_res, then round each dim to nearest multiple of 14 (default upstream method).
|
||||
lower_bound_resize: scale shortest side to process_res, then round."""
|
||||
|
||||
Methods:
|
||||
- "upper_bound_resize": scale longest side to ``process_res``, then
|
||||
round each dim to nearest multiple of 14 (default upstream method).
|
||||
- "lower_bound_resize": scale shortest side to ``process_res``, then
|
||||
round.
|
||||
"""
|
||||
if method == "upper_bound_resize":
|
||||
longest = max(orig_h, orig_w)
|
||||
scale = process_res / float(longest)
|
||||
@ -58,26 +40,8 @@ def compute_target_size(orig_h: int, orig_w: int, process_res: int,
|
||||
return new_h, new_w
|
||||
|
||||
|
||||
def preprocess_image(
|
||||
image: torch.Tensor,
|
||||
process_res: int = 504,
|
||||
method: str = "upper_bound_resize",
|
||||
) -> torch.Tensor:
|
||||
"""Preprocess a ComfyUI ``IMAGE`` batch for DA3.
|
||||
|
||||
Args:
|
||||
image: ``(B, H, W, 3)`` float in [0, 1] (ComfyUI ``IMAGE`` convention).
|
||||
process_res: target resolution (longest or shortest side, depending
|
||||
on ``method``).
|
||||
method: resize strategy.
|
||||
|
||||
Returns:
|
||||
``(B, 3, H', W')`` tensor with H' and W' multiples of 14, normalised
|
||||
with ImageNet statistics. The tensor lives on the same device as
|
||||
``image``.
|
||||
"""
|
||||
assert image.ndim == 4 and image.shape[-1] == 3, \
|
||||
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
|
||||
def preprocess_image(image: torch.Tensor, process_res: int = 504, method: str = "upper_bound_resize") -> torch.Tensor:
|
||||
assert image.ndim == 4 and image.shape[-1] == 3, f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
|
||||
B, H, W, _ = image.shape
|
||||
target_h, target_w = compute_target_size(H, W, process_res, method)
|
||||
|
||||
@ -88,9 +52,7 @@ def preprocess_image(
|
||||
# Lanczos in ``common_upscale`` is anti-aliased and produces the
|
||||
# closest pixel-wise match in a sweep across {bilinear, bicubic,
|
||||
# area, lanczos, bislerp}. Used in both directions for simplicity.
|
||||
x = comfy.utils.common_upscale(
|
||||
x.float(), target_w, target_h, "lanczos", "disabled",
|
||||
)
|
||||
x = comfy.utils.common_upscale(x.float(), target_w, target_h, "lanczos", "disabled",)
|
||||
x = x.clamp(0.0, 1.0)
|
||||
|
||||
mean = _IMAGENET_MEAN.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
|
||||
@ -109,17 +71,8 @@ def compute_non_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -
|
||||
return sky_prediction < threshold
|
||||
|
||||
|
||||
def apply_sky_aware_clip(
|
||||
depth: torch.Tensor,
|
||||
sky: torch.Tensor,
|
||||
threshold: float = 0.3,
|
||||
quantile: float = 0.99,
|
||||
) -> torch.Tensor:
|
||||
"""Replicates ``_process_mono_sky_estimation`` from upstream.
|
||||
|
||||
Clips sky regions to the 99th percentile of non-sky depth. Returns a new
|
||||
depth tensor; ``depth`` is not modified in place.
|
||||
"""
|
||||
def apply_sky_aware_clip(depth: torch.Tensor, sky: torch.Tensor, threshold: float = 0.3, quantile: float = 0.99) -> torch.Tensor:
|
||||
"""Clips sky regions to the 99th percentile of non-sky depth. Returns a new depth tensor."""
|
||||
non_sky = compute_non_sky_mask(sky, threshold=threshold)
|
||||
if non_sky.sum() <= 10 or (~non_sky).sum() <= 10:
|
||||
return depth.clone()
|
||||
@ -137,17 +90,8 @@ def apply_sky_aware_clip(
|
||||
return out
|
||||
|
||||
|
||||
def normalize_depth_v2_style(
|
||||
depth: torch.Tensor,
|
||||
sky: torch.Tensor | None = None,
|
||||
low_quantile: float = 0.01,
|
||||
high_quantile: float = 0.99,
|
||||
) -> torch.Tensor:
|
||||
"""V2-style normalization for ControlNet workflows.
|
||||
|
||||
Computes percentile bounds over non-sky pixels (when available),
|
||||
then maps depth into [0, 1] with near = white (1.0).
|
||||
"""
|
||||
def normalize_depth_v2_style(depth: torch.Tensor, sky: torch.Tensor | None = None, low_quantile: float = 0.01, high_quantile: float = 0.99) -> torch.Tensor:
|
||||
"""V2-style normalization computes percentile bounds over non-sky pixels (when available), then maps depth into [0, 1] with near = white (1.0)."""
|
||||
if sky is not None:
|
||||
mask = compute_non_sky_mask(sky)
|
||||
if mask.any():
|
||||
@ -167,10 +111,10 @@ def normalize_depth_v2_style(
|
||||
hi = torch.quantile(sample, high_quantile)
|
||||
rng = (hi - lo).clamp(min=1e-6)
|
||||
norm = ((depth - lo) / rng).clamp(0.0, 1.0)
|
||||
# ControlNet convention: nearer pixels are brighter (1.0).
|
||||
# Nearer pixels are brighter (1.0)
|
||||
norm = 1.0 - norm
|
||||
if sky is not None:
|
||||
# Sky pixels become black (far / unknown).
|
||||
# Sky pixels become black (far / unknown)
|
||||
sky_mask = ~compute_non_sky_mask(sky)
|
||||
norm = torch.where(sky_mask, torch.zeros_like(norm), norm)
|
||||
return norm
|
||||
|
||||
@ -1,21 +1,4 @@
|
||||
"""Ray-to-pose conversion for the multi-view path of Depth Anything 3.
|
||||
|
||||
Converts the auxiliary "ray" output of :class:`DualDPT` (per-pixel camera
|
||||
ray vectors, predicted on the per-view local feature map) into per-view
|
||||
extrinsics + intrinsics. Implementation is a 1:1 port of
|
||||
``depth_anything_3.utils.ray_utils`` upstream, using a weighted-RANSAC
|
||||
homography fit followed by a QL decomposition.
|
||||
|
||||
No learned parameters; pure tensor math. Output:
|
||||
|
||||
* ``R`` -- ``(B, S, 3, 3)`` rotation matrix
|
||||
* ``T`` -- ``(B, S, 3)`` camera-space translation
|
||||
* ``focal_lengths`` -- ``(B, S, 2)`` in normalised image space (image=2x2)
|
||||
* ``principal_points`` -- ``(B, S, 2)`` ditto
|
||||
|
||||
:func:`get_extrinsic_from_camray` wraps these into a 4x4 extrinsic matrix
|
||||
that the public node converts back into pixel-space intrinsics.
|
||||
"""
|
||||
"""Ray-to-pose conversion for the multi-view path of Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -28,13 +11,10 @@ import torch
|
||||
|
||||
|
||||
def _ql_decomposition(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Decompose ``A = Q @ L`` with ``Q`` orthogonal and ``L`` lower-triangular.
|
||||
|
||||
"""Decompose A = Q @ L with Q orthogonal and L lower-triangular.
|
||||
Implemented in terms of QR by reversing the columns/rows; the standard
|
||||
trick from the upstream reference. Inputs ``A`` are ``(3, 3)``.
|
||||
"""
|
||||
P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]],
|
||||
device=A.device, dtype=A.dtype)
|
||||
trick from the upstream reference. Inputs A are (3, 3)."""
|
||||
P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device, dtype=A.dtype)
|
||||
A_tilde = A @ P
|
||||
# CUDA QR is not implemented for fp16/bf16; upcast just for this call.
|
||||
Q_tilde, R_tilde = torch.linalg.qr(A_tilde.float())
|
||||
@ -44,8 +24,8 @@ def _ql_decomposition(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
L = P @ R_tilde @ P
|
||||
d = torch.diag(L)
|
||||
sign = torch.sign(d)
|
||||
Q = Q * sign[None, :] # scale columns of Q
|
||||
L = L * sign[:, None] # scale rows of L
|
||||
Q = Q * sign[None, :] # scale columns of Q
|
||||
L = L * sign[:, None] # scale rows of L
|
||||
return Q, L
|
||||
|
||||
|
||||
@ -58,12 +38,8 @@ def _homogenize_points(points: torch.Tensor) -> torch.Tensor:
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _find_homography_weighted_lsq(
|
||||
src_pts: torch.Tensor,
|
||||
dst_pts: torch.Tensor,
|
||||
confident_weight: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Solve a single ``H`` with weighted least-squares (DLT)."""
|
||||
def _find_homography_weighted_lsq(src_pts: torch.Tensor, dst_pts: torch.Tensor, confident_weight: torch.Tensor,) -> torch.Tensor:
|
||||
"""Solve a single H with weighted least-squares (DLT)."""
|
||||
N = src_pts.shape[0]
|
||||
if N < 4:
|
||||
raise ValueError("At least 4 points are required to compute a homography.")
|
||||
@ -83,12 +59,8 @@ def _find_homography_weighted_lsq(
|
||||
return H / H[-1, -1]
|
||||
|
||||
|
||||
def _find_homography_weighted_lsq_batched(
|
||||
src_pts_batch: torch.Tensor,
|
||||
dst_pts_batch: torch.Tensor,
|
||||
confident_weight_batch: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Batched DLT solver. Inputs ``(B, K, 2)`` / ``(B, K)``; output ``(B, 3, 3)``."""
|
||||
def _find_homography_weighted_lsq_batched(src_pts_batch: torch.Tensor, dst_pts_batch: torch.Tensor, confident_weight_batch: torch.Tensor) -> torch.Tensor:
|
||||
"""Batched DLT solver. Inputs (B, K, 2) / (B, K); output (B, 3, 3)."""
|
||||
B, K, _ = src_pts_batch.shape
|
||||
w = confident_weight_batch.sqrt().unsqueeze(2)
|
||||
x = src_pts_batch[:, :, 0:1]
|
||||
@ -117,10 +89,7 @@ def _ransac_find_homography_weighted_batched(
|
||||
max_inlier_num: int = 10000,
|
||||
rand_sample_iters_idx: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Batched weighted-RANSAC homography estimator.
|
||||
|
||||
Returns ``(B, 3, 3)`` homography matrices.
|
||||
"""
|
||||
"""Batched weighted-RANSAC homography estimator. Returns (B, 3, 3) homography matrices."""
|
||||
B, N, _ = src_pts.shape
|
||||
assert N >= 4
|
||||
device = src_pts.device
|
||||
@ -188,15 +157,8 @@ def _ransac_find_homography_weighted_batched(
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _unproject_identity(num_y: int, num_x: int, B: int, S: int,
|
||||
device, dtype) -> torch.Tensor:
|
||||
"""Camera-space unit rays for an identity intrinsic on a 2x2 image plane.
|
||||
|
||||
Replicates ``unproject_depth(..., ixt_normalized=True)`` upstream: pixel
|
||||
coords ``(x, y)`` in ``[dx, 2-dx] x [dy, 2-dy]`` get mapped to
|
||||
camera-space rays ``(x-1, y-1, 1)`` via the identity intrinsic
|
||||
``[[1,0,1],[0,1,1],[0,0,1]]``. Returns ``(B, S, num_y, num_x, 3)``.
|
||||
"""
|
||||
def _unproject_identity(num_y: int, num_x: int, B: int, S: int, device, dtype) -> torch.Tensor:
|
||||
"""Camera-space unit rays for an identity intrinsic on a 2x2 image plane."""
|
||||
dx = 1.0 / num_x
|
||||
dy = 1.0 / num_y
|
||||
# Centered camera-space coords directly (skip the K^-1 step since it's
|
||||
@ -210,7 +172,7 @@ def _unproject_identity(num_y: int, num_x: int, B: int, S: int,
|
||||
|
||||
|
||||
def _camray_to_caminfo(
|
||||
camray: torch.Tensor, # (B, S, h, w, 6)
|
||||
camray: torch.Tensor, # (B, S, h, w, 6)
|
||||
confidence: Optional[torch.Tensor] = None, # (B, S, h, w)
|
||||
reproj_threshold: float = 0.2,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
@ -294,20 +256,12 @@ def _camray_to_caminfo(
|
||||
|
||||
|
||||
def get_extrinsic_from_camray(
|
||||
camray: torch.Tensor, # (B, S, h, w, 6)
|
||||
conf: torch.Tensor, # (B, S, h, w, 1) or (B, S, h, w)
|
||||
camray: torch.Tensor, # (B, S, h, w, 6)
|
||||
conf: torch.Tensor, # (B, S, h, w, 1) or (B, S, h, w)
|
||||
patch_size_y: int,
|
||||
patch_size_x: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Wrap a 4x4 extrinsic + per-view focal + principal-point output.
|
||||
|
||||
Returns:
|
||||
* extrinsic ``(B, S, 4, 4)`` camera-to-world (the inverse is
|
||||
what gets stored in ``output.extrinsics``
|
||||
by the caller).
|
||||
* focals ``(B, S, 2)`` in normalised image space.
|
||||
* pp ``(B, S, 2)`` in normalised image space.
|
||||
"""
|
||||
"""Wrap a 4x4 extrinsic + per-view focal + principal-point output."""
|
||||
if conf.ndim == 5 and conf.shape[-1] == 1:
|
||||
conf = conf.squeeze(-1)
|
||||
R, T, focal, pp = _camray_to_caminfo(camray, confidence=conf)
|
||||
|
||||
@ -1,15 +1,4 @@
|
||||
"""Reference-view selection for the multi-view path of Depth Anything 3.
|
||||
|
||||
Pure tensor math, no learned parameters. Exposed as three free functions:
|
||||
|
||||
* :func:`select_reference_view` -- pick a reference view per batch.
|
||||
* :func:`reorder_by_reference` -- move the reference view to position 0.
|
||||
* :func:`restore_original_order` -- inverse of :func:`reorder_by_reference`.
|
||||
|
||||
Mirrors ``depth_anything_3.model.reference_view_selector`` upstream.
|
||||
The default strategy (``"saddle_balanced"``) selects the view whose CLS
|
||||
token features are closest to the median across multiple metrics.
|
||||
"""
|
||||
"""Reference-view selection for the multi-view path of Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -26,22 +15,8 @@ RefViewStrategy = Literal["first", "middle", "saddle_balanced", "saddle_sim_rang
|
||||
THRESH_FOR_REF_SELECTION: int = 3
|
||||
|
||||
|
||||
def select_reference_view(
|
||||
x: torch.Tensor,
|
||||
strategy: RefViewStrategy = "saddle_balanced",
|
||||
) -> torch.Tensor:
|
||||
"""Pick a reference view index per batch element.
|
||||
|
||||
Args:
|
||||
x: ``(B, S, N, C)`` token tensor. Index 0 along ``N`` is the
|
||||
cls/cam token used by the feature-based strategies.
|
||||
strategy: One of ``"first" | "middle" | "saddle_balanced" |
|
||||
"saddle_sim_range"``.
|
||||
|
||||
Returns:
|
||||
``(B,)`` long tensor with the chosen reference view index for
|
||||
each batch element.
|
||||
"""
|
||||
def select_reference_view(x: torch.Tensor, strategy: RefViewStrategy = "saddle_balanced") -> torch.Tensor:
|
||||
"""Pick a reference view index per batch element."""
|
||||
B, S, _, _ = x.shape
|
||||
if S <= 1:
|
||||
return torch.zeros(B, dtype=torch.long, device=x.device)
|
||||
@ -83,7 +58,7 @@ def select_reference_view(
|
||||
|
||||
|
||||
def reorder_by_reference(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
|
||||
"""Reorder ``x`` so the reference view is at position 0 in axis ``S``."""
|
||||
"""Reorder x so the reference view is at position 0 in axis S."""
|
||||
B, S = x.shape[0], x.shape[1]
|
||||
if S <= 1:
|
||||
return x
|
||||
@ -100,17 +75,13 @@ def reorder_by_reference(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
def restore_original_order(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
|
||||
"""Inverse of :func:`reorder_by_reference`."""
|
||||
"""Inverse of reorder_by_reference."""
|
||||
B, S = x.shape[0], x.shape[1]
|
||||
if S <= 1:
|
||||
return x
|
||||
target_positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
|
||||
b_idx_exp = b_idx.unsqueeze(1)
|
||||
restore = torch.where(target_positions < b_idx_exp,
|
||||
target_positions + 1,
|
||||
target_positions)
|
||||
restore = torch.scatter(
|
||||
restore, dim=1, index=b_idx_exp, src=torch.zeros_like(b_idx_exp),
|
||||
)
|
||||
restore = torch.where(target_positions < b_idx_exp, target_positions + 1, target_positions)
|
||||
restore = torch.scatter(restore, dim=1, index=b_idx_exp, src=torch.zeros_like(b_idx_exp))
|
||||
batch = torch.arange(B, device=x.device).unsqueeze(1)
|
||||
return x[batch, restore]
|
||||
|
||||
@ -1,11 +1,4 @@
|
||||
"""Geometry / camera transform helpers for Depth Anything 3.
|
||||
|
||||
Pure tensor math, no learned parameters. Mirrors the upstream upstream
|
||||
``depth_anything_3.model.utils.transform`` and the parts of
|
||||
``depth_anything_3.utils.geometry`` used at inference time on the
|
||||
multi-view + camera path. Kept self-contained so the DA3 module is fully
|
||||
ported and does not depend on the upstream repo at runtime.
|
||||
"""
|
||||
"""Geometry / camera transform helpers for Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -21,10 +14,7 @@ import torch.nn.functional as F
|
||||
|
||||
|
||||
def as_homogeneous(ext: torch.Tensor) -> torch.Tensor:
|
||||
"""Promote ``(...,3,4)`` extrinsics to ``(...,4,4)`` homogeneous form.
|
||||
|
||||
A no-op when the input is already ``(...,4,4)``.
|
||||
"""
|
||||
"""Promote (...,3,4) extrinsics to (...,4,4) homogeneous form. No-op when the input is already ``(...,4,4)``."""
|
||||
if ext.shape[-2:] == (4, 4):
|
||||
return ext
|
||||
if ext.shape[-2:] == (3, 4):
|
||||
@ -48,7 +38,7 @@ def affine_inverse(A: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
||||
"""``sqrt(max(0, x))`` with a zero subgradient where ``x == 0``."""
|
||||
"""sqrt(max(0, x)) with a zero subgradient where x == 0."""
|
||||
ret = torch.zeros_like(x)
|
||||
positive_mask = x > 0
|
||||
if torch.is_grad_enabled():
|
||||
@ -64,7 +54,7 @@ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert quaternions (xyzw) to ``(...,3,3)`` rotation matrices."""
|
||||
"""Convert quaternions (xyzw) to (...,3,3) rotation matrices."""
|
||||
i, j, k, r = torch.unbind(quaternions, -1)
|
||||
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
||||
o = torch.stack(
|
||||
@ -85,7 +75,7 @@ def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert ``(...,3,3)`` rotation matrices to quaternions (xyzw)."""
|
||||
"""Convert (...,3,3) rotation matrices to quaternions (xyzw)."""
|
||||
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
||||
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
||||
|
||||
@ -132,16 +122,11 @@ def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def extri_intri_to_pose_encoding(
|
||||
extrinsics: torch.Tensor,
|
||||
intrinsics: torch.Tensor,
|
||||
image_size_hw: Tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
"""Pack ``(extr, intr, image_size)`` into the 9-D pose-encoding vector.
|
||||
|
||||
``extrinsics`` are camera-to-world (c2w) ``(B,S,4,4)`` matrices,
|
||||
``intrinsics`` are pixel-space ``(B,S,3,3)`` matrices, ``image_size_hw``
|
||||
is a ``(H, W)`` pair. The encoding is ``[T(3), quat_xyzw(4), fov_h, fov_w]``.
|
||||
def extri_intri_to_pose_encoding(extrinsics: torch.Tensor, intrinsics: torch.Tensor, image_size_hw: Tuple[int, int]) -> torch.Tensor:
|
||||
"""Pack (extr, intr, image_size) into the 9-D pose-encoding vector.
|
||||
extrinsics: camera-to-world (c2w) (B,S,4,4) matrices,
|
||||
intrinsics: pixel-space (B,S,3,3) matrices,
|
||||
image_size_hw: is a (H, W) pair.
|
||||
"""
|
||||
R = extrinsics[..., :3, :3]
|
||||
T = extrinsics[..., :3, 3]
|
||||
@ -152,15 +137,8 @@ def extri_intri_to_pose_encoding(
|
||||
return torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
|
||||
|
||||
|
||||
def pose_encoding_to_extri_intri(
|
||||
pose_encoding: torch.Tensor,
|
||||
image_size_hw: Tuple[int, int],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Inverse of :func:`extri_intri_to_pose_encoding`.
|
||||
|
||||
Returns a ``(B,S,3,4)`` c2w extrinsic matrix and a ``(B,S,3,3)``
|
||||
pixel-space intrinsic matrix.
|
||||
"""
|
||||
def pose_encoding_to_extri_intri(pose_encoding: torch.Tensor, image_size_hw: Tuple[int, int]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Inverse of extri_intri_to_pose_encoding."""
|
||||
T = pose_encoding[..., :3]
|
||||
quat = pose_encoding[..., 3:7]
|
||||
fov_h = pose_encoding[..., 7]
|
||||
@ -173,8 +151,7 @@ def pose_encoding_to_extri_intri(
|
||||
H, W = image_size_hw
|
||||
fy = (H / 2.0) / torch.clamp(torch.tan(fov_h / 2.0), 1e-6)
|
||||
fx = (W / 2.0) / torch.clamp(torch.tan(fov_w / 2.0), 1e-6)
|
||||
intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3),
|
||||
device=pose_encoding.device, dtype=pose_encoding.dtype)
|
||||
intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device, dtype=pose_encoding.dtype)
|
||||
intrinsics[..., 0, 0] = fx
|
||||
intrinsics[..., 1, 1] = fy
|
||||
intrinsics[..., 0, 2] = W / 2
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
"""ComfyUI nodes for Depth Anything 3.
|
||||
Model capability matrix:
|
||||
|
||||
Model capability matrix
|
||||
-----------------------
|
||||
Variant head_type has_sky has_conf cam_dec
|
||||
DA3-Small dualdpt False True yes
|
||||
DA3-Base dualdpt False True yes
|
||||
DA3-Mono-Large dpt True False no
|
||||
DA3-Metric-Large dpt True False no (raw output is metres)
|
||||
Variant head_type has_sky has_conf cam_dec
|
||||
DA3-Small dualdpt False True yes
|
||||
DA3-Base dualdpt False True yes
|
||||
DA3-Mono-Large dpt True False no
|
||||
DA3-Metric-Large dpt True False no (raw output is metres)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -90,15 +89,7 @@ def _da3_get_extrinsic(geometry: dict, b: int) -> torch.Tensor | None:
|
||||
|
||||
|
||||
def _da3_apply_extrinsic(points_cam: torch.Tensor, E: torch.Tensor) -> torch.Tensor:
|
||||
"""Transform (H,W,3) OpenCV camera-space points to world space.
|
||||
|
||||
E is the world-to-camera SE(3) matrix (3×4 or 4×4). The camera-to-world
|
||||
inverse is computed analytically as [Rᵀ | −Rᵀt] rather than via
|
||||
torch.linalg.inv to avoid numerical failures on near-degenerate poses.
|
||||
|
||||
Returns the original camera-space points unchanged if E contains non-finite
|
||||
values (failed pose estimation), so the node can still produce a mesh.
|
||||
"""
|
||||
"""Transform (H,W,3) OpenCV camera-space points to world space."""
|
||||
E = E.to(points_cam.device).float()
|
||||
if not torch.isfinite(E).all():
|
||||
logging.getLogger("comfy").warning(
|
||||
@ -117,11 +108,7 @@ def _da3_apply_extrinsic(points_cam: torch.Tensor, E: torch.Tensor) -> torch.Ten
|
||||
|
||||
|
||||
def _normalize_confidence(conf: torch.Tensor) -> torch.Tensor:
|
||||
"""Map raw confidence (exp(x)+1 activation, range [1, ∞)) to [0, 1] per image.
|
||||
|
||||
Min-max per image preserves the spatial pattern while producing a [0, 1]
|
||||
value suitable for both display and masking.
|
||||
"""
|
||||
"""Map raw confidence to [0, 1] per image."""
|
||||
B = conf.shape[0]
|
||||
out = []
|
||||
for i in range(B):
|
||||
@ -131,8 +118,7 @@ def _normalize_confidence(conf: torch.Tensor) -> torch.Tensor:
|
||||
return torch.stack(out, dim=0)
|
||||
|
||||
|
||||
def _da3_build_mask(geometry: dict, b: int, H: int, W: int,
|
||||
confidence_threshold: float, use_sky_mask: bool) -> torch.Tensor:
|
||||
def _da3_build_mask(geometry: dict, b: int, H: int, W: int, confidence_threshold: float, use_sky_mask: bool) -> torch.Tensor:
|
||||
"""Build (H,W) bool keep-mask from sky probability and confidence."""
|
||||
mask = torch.ones(H, W, dtype=torch.bool)
|
||||
if use_sky_mask and "sky" in geometry:
|
||||
@ -179,11 +165,9 @@ class LoadDA3Model(io.ComfyNode):
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
def _run_da3(model_patcher, image: torch.Tensor, process_res: int,
|
||||
method: str = "upper_bound_resize"):
|
||||
"""Run DA3 on ``(B,H,W,3)`` IMAGE; returns depth/conf/sky at original resolution (or None)."""
|
||||
assert image.ndim == 4 and image.shape[-1] == 3, \
|
||||
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
|
||||
def _run_da3(model_patcher, image: torch.Tensor, process_res: int, method: str = "upper_bound_resize"):
|
||||
"""Run DA3 on (B,H,W,3), returns depth/conf/sky at original resolution (or None)."""
|
||||
assert image.ndim == 4 and image.shape[-1] == 3, f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
|
||||
|
||||
B, H, W, _ = image.shape
|
||||
mm.load_model_gpu(model_patcher)
|
||||
@ -236,56 +220,46 @@ class DA3Inference(io.ComfyNode):
|
||||
description="Run Depth Anything 3 on an image. In multi-view mode each image is treated as a separate view of the same scene.",
|
||||
inputs=[
|
||||
DA3ModelType.Input("da3_model"),
|
||||
io.Image.Input("image",
|
||||
tooltip="In multi-view mode each image is treated as "
|
||||
"a separate view of the same scene."),
|
||||
io.Int.Input("process_res", default=504, min=140, max=2520, step=14,
|
||||
tooltip="Resolution the model runs at (longest side, multiple of 14). "
|
||||
"Lower = faster / less VRAM; higher = more detail. "
|
||||
"Output is upsampled back to the original size."),
|
||||
io.Combo.Input("resize_method",
|
||||
options=["upper_bound_resize", "lower_bound_resize"],
|
||||
default="upper_bound_resize",
|
||||
tooltip="- upper_bound_resize: scale so the longest side = process_res (caps memory, default).\n"
|
||||
"- lower_bound_resize: scale so the shortest side = process_res (preserves more detail on tall/wide images, uses more memory)."),
|
||||
io.DynamicCombo.Input("mode",
|
||||
tooltip="- mono: single view image - works with any model variant.\n"
|
||||
"- multiview: all images processed together for geometric consistency + camera pose, for Small/Base models only.",
|
||||
options=[
|
||||
io.DynamicCombo.Option("mono", []),
|
||||
io.DynamicCombo.Option("multiview", [
|
||||
io.Combo.Input("ref_view_strategy",
|
||||
options=["saddle_balanced", "saddle_sim_range",
|
||||
"first", "middle"],
|
||||
default="saddle_balanced",
|
||||
tooltip="Which view acts as the geometric anchor (only when S >= 3 and no extrinsics provided).\n"
|
||||
"- saddle_balanced: the view most 'average' across all others - best general choice.\n"
|
||||
"- saddle_sim_range: the view most visually distinct from the others.\n"
|
||||
"- first / middle: fixed positional picks."),
|
||||
io.Combo.Input("pose_method",
|
||||
options=["cam_dec", "ray_pose"],
|
||||
default="cam_dec",
|
||||
tooltip="How the camera field-of-view is estimated (for Small/Base models only).\n"
|
||||
"- cam_dec: learned from image features.\n"
|
||||
"- ray_pose: derived geometrically from the model's 3-D ray output.\n"
|
||||
"Affects perspective correctness of the 3-D output. Try both if results look distorted."),
|
||||
io.Image.Input("image"),
|
||||
io.Int.Input("resolution", default=504, min=140, max=2520, step=14,
|
||||
tooltip="Resolution the model runs at (longest side, multiple of 14).\n"
|
||||
"Lower = faster / less VRAM.\n"
|
||||
"Higher = more detail.\n"
|
||||
"Output is upsampled back to the original size."),
|
||||
io.Combo.Input("resize_method", options=["upper_bound_resize", "lower_bound_resize"], default="upper_bound_resize",
|
||||
tooltip="upper_bound_resize: scale so the longest side = resolution (caps memory, default).\n"
|
||||
"lower_bound_resize: scale so the shortest side = resolution (preserves more detail on tall/wide images, uses more memory)."),
|
||||
io.DynamicCombo.Input("mode", tooltip="mono: single view image (works with any model variant).\n"
|
||||
"multiview: all images processed together for geometric consistency + camera pose (for Small/Base models only).",
|
||||
options=[
|
||||
io.DynamicCombo.Option("mono", []),
|
||||
io.DynamicCombo.Option("multiview", [
|
||||
io.Combo.Input("ref_view_strategy", options=["saddle_balanced", "saddle_sim_range", "first", "middle"], default="saddle_balanced",
|
||||
tooltip="Which view acts as the geometric anchor.\n"
|
||||
"- saddle_balanced: the view most 'average' across all others (best general choice).\n"
|
||||
"- saddle_sim_range: the view most visually distinct from the others.\n"
|
||||
"- first / middle: fixed positional picks."),
|
||||
io.Combo.Input("pose_method", options=["cam_dec", "ray_pose"], default="cam_dec",
|
||||
tooltip="How the camera field-of-view is estimated (for Small/Base models only).\n"
|
||||
"- cam_dec: learned from image features.\n"
|
||||
"- ray_pose: derived geometrically from the model's 3D ray output.\n"
|
||||
"Affects perspective correctness of the 3D output. Try both if results look distorted."),
|
||||
]),
|
||||
]),
|
||||
],
|
||||
outputs=[
|
||||
DA3Geometry.Output("da3_geometry",
|
||||
tooltip="Dictionary of non-normalized tensors.\n"
|
||||
"- Always: 'depth', 'image', 'mode'.\n"
|
||||
"- Optional: 'sky' (Mono/Metric), 'confidence' (Small/Base), 'extrinsics' + 'intrinsics' (multi-view)."),
|
||||
DA3Geometry.Output("da3_geometry", tooltip="Dictionary of non-normalized tensors.\n"
|
||||
"Always has the keys: depth, image, mode.\n"
|
||||
"Optional keys: sky (for Mono/Metric), confidence (for Small/Base), extrinsics + intrinsics (for multi-view)."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, da3_model, image, process_res, resize_method, mode) -> io.NodeOutput:
|
||||
def execute(cls, da3_model, image, resolution, resize_method, mode) -> io.NodeOutput:
|
||||
mode_val = mode["mode"] # "mono" or "multiview"
|
||||
|
||||
if mode_val == "mono":
|
||||
return cls._execute_mono(da3_model, image, process_res, resize_method)
|
||||
return cls._execute_mono(da3_model, image, resolution, resize_method)
|
||||
|
||||
# Capability checks for multi-view mode.
|
||||
diffusion = da3_model.model.diffusion_model
|
||||
@ -317,13 +291,13 @@ class DA3Inference(io.ComfyNode):
|
||||
)
|
||||
|
||||
return cls._execute_multiview(
|
||||
da3_model, image, process_res, resize_method,
|
||||
da3_model, image, resolution, resize_method,
|
||||
ref_view_strategy, pose_method,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _execute_mono(cls, model, image, process_res, resize_method) -> io.NodeOutput:
|
||||
depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method)
|
||||
def _execute_mono(cls, model, image, resolution, resize_method) -> io.NodeOutput:
|
||||
depth, confidence, sky = _run_da3(model, image, resolution, method=resize_method)
|
||||
|
||||
geometry: dict = {
|
||||
"depth": depth.contiguous(),
|
||||
@ -337,8 +311,7 @@ class DA3Inference(io.ComfyNode):
|
||||
return io.NodeOutput(geometry)
|
||||
|
||||
@classmethod
|
||||
def _execute_multiview(cls, model, image, process_res, resize_method,
|
||||
ref_view_strategy, pose_method) -> io.NodeOutput:
|
||||
def _execute_multiview(cls, model, image, resolution, resize_method, ref_view_strategy, pose_method) -> io.NodeOutput:
|
||||
assert image.ndim == 4 and image.shape[-1] == 3, \
|
||||
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
|
||||
S, H, W, _ = image.shape
|
||||
@ -350,13 +323,12 @@ class DA3Inference(io.ComfyNode):
|
||||
|
||||
# All views in a single forward pass: (1, S, 3, H', W').
|
||||
x = image.to(device)
|
||||
x = da3_preprocess.preprocess_image(x, process_res=process_res, method=resize_method)
|
||||
x = da3_preprocess.preprocess_image(x, process_res=resolution, method=resize_method)
|
||||
x = x.to(dtype=dtype).unsqueeze(0)
|
||||
|
||||
use_ray_pose = (pose_method == "ray_pose")
|
||||
with torch.no_grad():
|
||||
out = diffusion(x, use_ray_pose=use_ray_pose,
|
||||
ref_view_strategy=ref_view_strategy)
|
||||
out = diffusion(x, use_ray_pose=use_ray_pose, ref_view_strategy=ref_view_strategy)
|
||||
|
||||
depth = torch.nn.functional.interpolate(
|
||||
out["depth"].float().unsqueeze(1), size=(H, W),
|
||||
@ -395,22 +367,19 @@ class DA3Inference(io.ComfyNode):
|
||||
return io.NodeOutput(geometry)
|
||||
|
||||
|
||||
|
||||
|
||||
class DA3Render(io.ComfyNode):
|
||||
"""Render a visualization from a DA3_GEOMETRY packet."""
|
||||
|
||||
_DEPTH_RENDER_INPUTS = [
|
||||
io.Combo.Input("normalization",
|
||||
options=["v2_style", "min_max", "raw"],
|
||||
default="v2_style",
|
||||
tooltip="- v2_style: mean/std normalisation for perceptually balanced results (default).\n"
|
||||
"- min_max: stretches the full depth range to [0, 1] for maximum contrast.\n"
|
||||
"- raw: no scaling - preserves metric units for Metric model."),
|
||||
options=["v2_style", "min_max", "raw"],
|
||||
default="v2_style",
|
||||
tooltip="- v2_style: mean/std normalisation for perceptually balanced results (default).\n"
|
||||
"- min_max: stretches the full depth range to [0, 1] for maximum contrast.\n"
|
||||
"- raw: no scaling,preserves metric units for Metric model."),
|
||||
io.Boolean.Input("apply_sky_clip", default=False,
|
||||
tooltip="Clip sky-region depth to the 99th percentile of foreground depth before "
|
||||
"normalisation. Requires a 'sky' tensor in the da3_geometry input"
|
||||
"provided by Mono/Metric models; raises an error otherwise."),
|
||||
tooltip="Clip sky-region depth to the 99th percentile of foreground depth before normalisation. "
|
||||
"Requires a sky key in the da3_geometry input (for Mono/Metric models only)."),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
@ -419,24 +388,22 @@ class DA3Render(io.ComfyNode):
|
||||
node_id="DA3Render",
|
||||
display_name="Render Depth Anything 3",
|
||||
category="image/geometry estimation",
|
||||
description="Render a depth map, confidence map, or sky mask from DA3 geometry data.",
|
||||
description="Render a depth map, confidence map, or sky mask from Depth Anything 3 geometry data.",
|
||||
inputs=[
|
||||
DA3Geometry.Input("da3_geometry"),
|
||||
io.DynamicCombo.Input("output",
|
||||
tooltip="- depth: normalised greyscale depth image.\n"
|
||||
"- depth_colored: depth mapped through the Turbo colormap.\n"
|
||||
"- sky_mask: sky probability in [0, 1] (for Mono/Metric models only).\n"
|
||||
"- confidence: normalised depth confidence (for Small/Base models only).",
|
||||
options=[
|
||||
tooltip="- depth: normalised greyscale depth image.\n"
|
||||
"- depth_colored: depth mapped through the Turbo colormap.\n"
|
||||
"- sky_mask: sky probability in [0, 1] (for Mono/Metric models only).\n"
|
||||
"- confidence: normalised depth confidence (for Small/Base models only).",
|
||||
options=[
|
||||
io.DynamicCombo.Option("depth", cls._DEPTH_RENDER_INPUTS),
|
||||
io.DynamicCombo.Option("depth_colored", cls._DEPTH_RENDER_INPUTS),
|
||||
io.DynamicCombo.Option("sky_mask", [
|
||||
io.Boolean.Input("colored", default=False,
|
||||
tooltip="Apply the Turbo colormap to the sky mask."),
|
||||
io.Boolean.Input("colored", default=False, tooltip="Apply the Turbo colormap to the sky mask."),
|
||||
]),
|
||||
io.DynamicCombo.Option("confidence", [
|
||||
io.Boolean.Input("colored", default=False,
|
||||
tooltip="Apply the Turbo colormap to the confidence map."),
|
||||
io.Boolean.Input("colored", default=False, tooltip="Apply the Turbo colormap to the confidence map."),
|
||||
]),
|
||||
]),
|
||||
],
|
||||
@ -489,9 +456,9 @@ class DA3Render(io.ComfyNode):
|
||||
return io.NodeOutput(result.float())
|
||||
|
||||
@staticmethod
|
||||
def _depth_to_image(depth: torch.Tensor, sky_for_norm: torch.Tensor | None,
|
||||
normalization: str) -> torch.Tensor:
|
||||
def _depth_to_image(depth: torch.Tensor, sky_for_norm: torch.Tensor | None, normalization: str) -> torch.Tensor:
|
||||
"""Normalise depth and pack as an (B,H,W,3) image tensor."""
|
||||
|
||||
N = depth.shape[0]
|
||||
if normalization == "v2_style":
|
||||
norm = torch.stack([
|
||||
@ -510,8 +477,6 @@ class DA3Render(io.ComfyNode):
|
||||
return out.contiguous()
|
||||
|
||||
|
||||
|
||||
|
||||
class DA3GeometryToMesh(io.ComfyNode):
|
||||
"""Convert a DA3_GEOMETRY packet into a Types.MESH by unprojecting depth and triangulating."""
|
||||
|
||||
@ -525,28 +490,20 @@ class DA3GeometryToMesh(io.ComfyNode):
|
||||
description="Convert a depth map into a triangulated 3D mesh.",
|
||||
inputs=[
|
||||
DA3Geometry.Input("da3_geometry"),
|
||||
io.Int.Input("batch_index", default=0, min=0, max=4096,
|
||||
tooltip="Which image of a batch to convert. "
|
||||
"Per-image vertex counts differ so batches cannot be stacked."),
|
||||
io.Int.Input("decimation", default=1, min=1, max=8,
|
||||
tooltip="Vertex stride; 1 = full resolution, 2 = half, etc."),
|
||||
io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01,
|
||||
tooltip="Drop triangles whose 3×3 depth span exceeds this fraction. 0 = off."),
|
||||
io.Int.Input("batch_index", default=0, min=0, max=4096, tooltip="Which image of a batch to convert. Per-image vertex counts differ so batches cannot be stacked."),
|
||||
io.Int.Input("decimation", default=1, min=1, max=8, tooltip="Vertex stride. 1 = full resolution, 2 = half, etc."),
|
||||
io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01, tooltip="Drop triangles whose 3x3 depth span exceeds this fraction. 0 = off."),
|
||||
io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01,
|
||||
tooltip="Exclude pixels whose per-image normalised confidence is below this value (0 = keep all, 1 = keep only the single most confident pixel). "
|
||||
"Used when the geometry has confidence map (Small/Base models)."),
|
||||
io.Boolean.Input("use_sky_mask", default=True,
|
||||
tooltip="Exclude sky-probability pixels (sky >= 0.5) from the mesh. "
|
||||
"Used when the geometry has sky map (Mono/Metric models)."),
|
||||
io.Boolean.Input("texture", default=True,
|
||||
tooltip="Use the source image as a base color texture."),
|
||||
tooltip="Exclude pixels whose per-image normalised confidence is below this value (0 = keep all, 1 = keep only the single most confident pixel). "
|
||||
"Used when the geometry has a confidence map (Small/Base models)."),
|
||||
io.Boolean.Input("use_sky_mask", default=True, tooltip="Exclude sky-probability pixels (sky >= 0.5) from the mesh. Used when the geometry has a sky map (Mono/Metric models)."),
|
||||
io.Boolean.Input("texture", default=True, tooltip="Use the source image as a base color texture."),
|
||||
],
|
||||
outputs=[io.Mesh.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, da3_geometry, batch_index, decimation, discontinuity_threshold,
|
||||
confidence_threshold, use_sky_mask, texture) -> io.NodeOutput:
|
||||
def execute(cls, da3_geometry, batch_index, decimation, discontinuity_threshold, confidence_threshold, use_sky_mask, texture) -> io.NodeOutput:
|
||||
depth_all = da3_geometry["depth"] # (B, H, W)
|
||||
B = depth_all.shape[0]
|
||||
if batch_index >= B:
|
||||
@ -627,25 +584,20 @@ class DA3GeometryToPointCloud(io.ComfyNode):
|
||||
description="Convert a depth map into a 3D point cloud.",
|
||||
inputs=[
|
||||
DA3Geometry.Input("da3_geometry"),
|
||||
io.Int.Input("batch_index", default=0, min=0, max=4096,
|
||||
tooltip="Which image of a batch to convert."),
|
||||
io.Int.Input("batch_index", default=0, min=0, max=4096, tooltip="Which image of a batch to convert."),
|
||||
io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01,
|
||||
tooltip="Exclude pixels whose per-image normalised confidence is below this value (0 = keep all). "
|
||||
"Used when the geometry has confidence map (Small/Base models)."),
|
||||
tooltip="Exclude pixels whose per-image normalised confidence is below this value (0 = keep all). Used when the geometry has a confidence map (Small/Base models)."),
|
||||
io.Boolean.Input("use_sky_mask", default=True,
|
||||
tooltip="Exclude sky-probability pixels (sky >= 0.5). "
|
||||
"Used when the geometry has sky map (Mono/Metric models)."),
|
||||
tooltip="Exclude sky-probability pixels (sky >= 0.5). Used when the geometry has a sky map (Mono/Metric models)."),
|
||||
io.Int.Input("downsample", default=1, min=1, max=16,
|
||||
tooltip="Take every Nth pixel (1 = full resolution). "
|
||||
"Higher values give fewer points and faster processing."),
|
||||
tooltip="Take every Nth pixel (1 = full resolution). Higher values give fewer points and faster processing."),
|
||||
],
|
||||
# TODO: add a proper PointCloud output type
|
||||
outputs=[DA3PointCloud.Output(display_name="point_cloud")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, da3_geometry, batch_index, confidence_threshold,
|
||||
use_sky_mask, downsample) -> io.NodeOutput:
|
||||
def execute(cls, da3_geometry, batch_index, confidence_threshold, use_sky_mask, downsample) -> io.NodeOutput:
|
||||
depth_all = da3_geometry["depth"] # (B, H, W)
|
||||
B = depth_all.shape[0]
|
||||
if batch_index >= B:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user