Create a unified DepthAnything3 node.

This commit is contained in:
Talmaj Marinc 2026-05-19 14:37:17 +02:00
parent ad94b046ee
commit 2686038f94

View File

@ -5,13 +5,23 @@ Adds these nodes:
* ``LoadDepthAnything3`` -- load a DA3 ``.safetensors`` file from the
``models/depth_estimation/`` folder. Falls back to ``models/diffusion_models/``
so existing installations keep working.
* ``DepthAnything3Depth`` -- run depth estimation and return a normalised
depth map as a ComfyUI ``IMAGE`` (visualisation / ControlNet input).
* ``DepthAnything3DepthRaw`` -- run depth estimation and return the raw depth,
confidence and sky channels as ``MASK`` outputs.
* ``DepthAnything3MultiView`` -- multi-view path: depth + per-view extrinsics
+ intrinsics. Pose is decoded either from the camera-decoder MLP (default)
or from the auxiliary ray output via RANSAC (DA3-Small/Base only).
* ``DepthAnything3`` -- unified depth estimation node supporting both mono and
multi-view modes via a DynamicCombo selector. Both modes return a normalised
depth image, sky/confidence masks and a ``LATENT``. In multi-view mode the
``LATENT`` carries per-view extrinsics, intrinsics and raw depth; in mono
mode it is an empty placeholder.
Model capability matrix
-----------------------
Variant head_type has_sky has_conf cam_dec
DA3-Small dualdpt False True yes
DA3-Base dualdpt False True yes
DA3-Mono-Large dpt True False no
DA3-Metric-Large dpt True False no (raw output is metres)
The node raises a ``ValueError`` at execution time when the selected
parameters conflict with the loaded model's capabilities (e.g.
``apply_sky_clip=True`` on a model with no sky head).
"""
from __future__ import annotations
@ -26,12 +36,6 @@ import folder_paths
from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess
from comfy_api.latest import ComfyExtension, io
# -----------------------------------------------------------------------------
# Loader
# -----------------------------------------------------------------------------
class LoadDepthAnything3(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -75,11 +79,11 @@ class LoadDepthAnything3(io.ComfyNode):
def _run_da3(model_patcher, image: torch.Tensor, process_res: int,
method: str = "upper_bound_resize"):
"""Run the DA3 network on a (B, H, W, 3) ``IMAGE`` batch.
"""Run the DA3 network on a (B, H, W, 3) IMAGE batch.
Returns ``(depth, confidence, sky)`` tensors with the original image
resolution. Any of ``confidence`` / ``sky`` may be ``None`` depending on
the variant.
Returns ``(depth, confidence, sky)`` tensors at the original image
resolution. ``confidence`` / ``sky`` are ``None`` when the variant does
not produce them.
"""
assert image.ndim == 4 and image.shape[-1] == 3, \
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
@ -91,8 +95,7 @@ def _run_da3(model_patcher, image: torch.Tensor, process_res: int,
dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32
depths, confs, skies = [], [], []
# Process one image at a time to keep peak memory predictable; DA3 is
# an inference-only model and per-sample latency dominates anyway.
# Process one image at a time to keep peak memory predictable.
for i in range(B):
single = image[i:i + 1].to(device)
x = da3_preprocess.preprocess_image(single, process_res=process_res, method=method)
@ -101,7 +104,6 @@ def _run_da3(model_patcher, image: torch.Tensor, process_res: int,
out = diffusion(x)
depth_lr = out["depth"]
# Resize back to the original (H, W).
depth_full = torch.nn.functional.interpolate(
depth_lr.unsqueeze(1).float(), size=(H, W),
mode="bilinear", align_corners=False,
@ -127,21 +129,50 @@ def _run_da3(model_patcher, image: torch.Tensor, process_res: int,
return depth, confidence, sky
# -----------------------------------------------------------------------------
# Depth -> visualisation IMAGE
# -----------------------------------------------------------------------------
class DepthAnything3(io.ComfyNode):
"""Unified Depth Anything 3 node.
Mono mode
---------
Runs the model on each batch element independently and returns a
normalised depth image together with sky and confidence masks.
Multi-view mode
---------------
Treats every batch element as a separate view of the same scene.
Runs all views in a single forward pass so cross-view attention can
establish geometric consistency. Additionally returns a ``LATENT``
dict with per-view camera extrinsics, intrinsics and raw depth.
Capability errors
-----------------
A ``ValueError`` is raised immediately when a parameter requires a
model feature that is absent in the loaded checkpoint (e.g.
``apply_sky_clip=True`` on DA3-Small/Base which has no sky head,
or ``pose_method='cam_dec'`` on a monocular model).
Camera LATENT structure (multi-view only)
-----------------------------------------
samples: (1, S, 1, H, W) -- raw depth packed as latent samples
type: "da3_multiview"
extrinsics: (1, S, 4, 4) -- world-to-camera matrices
intrinsics: (1, S, 3, 3) -- pixel-space intrinsics
depth_raw: (S, H, W) -- un-normalised depth
confidence: (S, H, W) -- per-pixel confidence (zeros if N/A)
"""
class DepthAnything3Depth(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="DepthAnything3Depth",
display_name="Depth Anything 3 (Depth)",
node_id="DepthAnything3",
display_name="Depth Anything 3",
category="image/depth",
inputs=[
io.Model.Input("model"),
io.Image.Input("image"),
io.Image.Input("image",
tooltip="Single image or image batch. "
"In multi-view mode each frame is treated as "
"a separate view of the same scene."),
io.Int.Input("process_res", default=504, min=140, max=2520, step=14,
tooltip="Longest-side target resolution (multiple of 14)."),
io.Combo.Input("resize_method",
@ -150,22 +181,93 @@ class DepthAnything3Depth(io.ComfyNode):
io.Combo.Input("normalization",
options=["v2_style", "min_max", "raw"],
default="v2_style",
tooltip="How to map raw depth -> [0, 1] image."),
tooltip="How to map raw depth to [0, 1] for the output image. "
"'raw' preserves absolute values — use this to keep "
"metric units when running DA3-Metric-Large."),
io.Boolean.Input("apply_sky_clip", default=True,
tooltip="(Mono/Metric only) clip sky depth to 99th percentile."),
tooltip="Clip sky-region depth to the 99th percentile before "
"normalisation. Requires a sky segmentation head "
"(DA3-Mono-Large or DA3-Metric-Large). "
"Raises an error on DA3-Small/Base."),
io.DynamicCombo.Input("mode", options=[
io.DynamicCombo.Option("mono", []),
io.DynamicCombo.Option("multiview", [
io.Combo.Input("ref_view_strategy",
options=["saddle_balanced", "saddle_sim_range",
"first", "middle"],
default="saddle_balanced",
tooltip="Reference view selection strategy (applied when "
"S >= 3 and no extrinsics are provided)."),
io.Combo.Input("pose_method",
options=["cam_dec", "ray_pose"],
default="cam_dec",
tooltip="cam_dec: small MLP on the final camera token "
"(DA3-Small/Base). "
"ray_pose: RANSAC over the DualDPT ray output "
"(DA3-Small/Base only)."),
]),
]),
],
outputs=[
io.Image.Output("depth_image"),
io.Mask.Output("sky_mask",
tooltip="Sky probability (Mono/Metric variants), else zeros."),
tooltip="Sky probability mask (Mono/Metric variants). "
"Zeros for Small/Base."),
io.Mask.Output("confidence",
tooltip="Depth confidence (Small/Base/DualDPT variants), else zeros."),
tooltip="Depth confidence (Small/Base variants). "
"Zeros for Mono/Metric."),
io.Latent.Output("camera",
tooltip="Multi-view: per-view extrinsics + intrinsics + raw depth. "
"In mono mode this is an empty placeholder."),
],
)
@classmethod
def execute(cls, model, image, process_res, resize_method, normalization,
            apply_sky_clip, mode) -> io.NodeOutput:
    """Validate the request against the loaded model, then dispatch by mode.

    ``mode`` is the value of the ``mode`` DynamicCombo input: ``mode["mode"]``
    is ``"mono"`` or ``"multiview"``, and the multi-view option additionally
    carries ``"pose_method"`` and ``"ref_view_strategy"``.

    Returns:
        The ``io.NodeOutput`` produced by ``_execute_mono`` or
        ``_execute_multiview``.

    Raises:
        ValueError: ``apply_sky_clip`` is set but the model has no sky head,
            or the selected ``pose_method`` needs a camera decoder / DualDPT
            head that the loaded model lacks.
    """
    diffusion = model.model.diffusion_model
    mode_val = mode["mode"]  # "mono" or "multiview"

    # Sky clipping is requested independently of the mode, so reject an
    # unsupported request before dispatching to either path.
    if apply_sky_clip and not diffusion.has_sky:
        raise ValueError(
            "apply_sky_clip=True requires a sky segmentation head, but the loaded "
            "model does not have one. Set apply_sky_clip=False, or load a model "
            "that includes a sky head (e.g. DA3-Mono-Large or DA3-Metric-Large)."
        )

    if mode_val == "mono":
        return cls._execute_mono(
            model, image, process_res, resize_method,
            normalization, apply_sky_clip,
        )

    # Multi-view only: the pose settings live inside the DynamicCombo option.
    pose_method = mode["pose_method"]
    ref_view_strategy = mode["ref_view_strategy"]
    has_cam_dec = diffusion.cam_dec is not None
    has_dualdpt = diffusion.head_type == "dualdpt"

    if pose_method == "cam_dec" and not has_cam_dec:
        # Recommend ray_pose only when this model can actually run it;
        # otherwise the user would just be bounced into the error below.
        if has_dualdpt:
            hint = ", or set pose_method='ray_pose'."
        else:
            hint = (
                "; pose_method='ray_pose' is not available on this model "
                "either because it has no DualDPT head."
            )
        raise ValueError(
            "pose_method='cam_dec' requires a camera decoder, but the loaded "
            "model does not have one. Load a model with a camera decoder "
            "(e.g. DA3-Small or DA3-Base)" + hint
        )

    if pose_method == "ray_pose" and not has_dualdpt:
        # Same reasoning as above, mirrored for the camera decoder.
        if has_cam_dec:
            hint = ", or set pose_method='cam_dec'."
        else:
            hint = (
                "; pose_method='cam_dec' is not available on this model "
                "either because it has no camera decoder."
            )
        raise ValueError(
            "pose_method='ray_pose' requires a DualDPT head, but the loaded "
            "model has a DPT head. Load a model with a DualDPT head "
            "(e.g. DA3-Small or DA3-Base)" + hint
        )

    return cls._execute_multiview(
        model, image, process_res, resize_method,
        normalization, apply_sky_clip,
        ref_view_strategy, pose_method,
    )
@classmethod
def _execute_mono(cls, model, image, process_res, resize_method,
normalization, apply_sky_clip) -> io.NodeOutput:
depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method)
if apply_sky_clip and sky is not None:
@ -176,8 +278,8 @@ class DepthAnything3Depth(io.ComfyNode):
if normalization == "v2_style":
norm = torch.stack([
da3_preprocess.normalize_depth_v2_style(depth[i],
sky[i] if sky is not None else None)
da3_preprocess.normalize_depth_v2_style(
depth[i], sky[i] if sky is not None else None)
for i in range(depth.shape[0])
], dim=0)
elif normalization == "min_max":
@ -185,83 +287,21 @@ class DepthAnything3Depth(io.ComfyNode):
else:
norm = depth
# (B, H, W) -> (B, H, W, 3) grayscale IMAGE.
out_image = norm.unsqueeze(-1).repeat(1, 1, 1, 3).clamp(0.0, 1.0).contiguous()
sky_mask = sky if sky is not None else torch.zeros_like(depth)
conf_mask = confidence if confidence is not None else torch.zeros_like(depth)
return io.NodeOutput(out_image, sky_mask.contiguous(), conf_mask.contiguous())
# -----------------------------------------------------------------------------
# Raw depth output (useful for downstream metric work)
# -----------------------------------------------------------------------------
class DepthAnything3MultiView(io.ComfyNode):
"""Multi-view depth + pose estimation for DA3-Small / DA3-Base / DA3-Large.
Treats each batch element of the input ``IMAGE`` as a separate view of
the same scene. The selected reference view is auto-chosen by the
backbone via ``ref_view_strategy`` (when at least 3 views are
supplied), unless camera extrinsics are provided -- in which case the
geometry is pinned by the user and no reordering is done.
Output structure:
* ``depth_image`` -- per-view normalised depth as a stacked ``IMAGE``
batch (one frame per view, original input order).
* ``confidence`` / ``sky`` -- per-view masks (zero when the variant
does not produce them).
* ``camera`` -- ``LATENT`` dict with keys::
samples: (1, S, 1, h_p, w_p) -- raw depth packed as latent
type: "da3_multiview"
extrinsics: (1, S, 4, 4) world-to-camera matrices
intrinsics: (1, S, 3, 3) pixel-space intrinsics
depth_raw: (S, H, W) un-normalised depth
confidence: (S, H, W)
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="DepthAnything3MultiView",
display_name="Depth Anything 3 (Multi-View)",
category="image/depth",
inputs=[
io.Model.Input("model"),
io.Image.Input("image",
tooltip="Image batch where each frame is a view of the same scene."),
io.Int.Input("process_res", default=504, min=140, max=2520, step=14,
tooltip="Longest-side target resolution (multiple of 14)."),
io.Combo.Input("resize_method",
options=["upper_bound_resize", "lower_bound_resize"],
default="upper_bound_resize"),
io.Combo.Input("ref_view_strategy",
options=["saddle_balanced", "saddle_sim_range", "first", "middle"],
default="saddle_balanced",
tooltip="Reference view selection (only applied when "
"S>=3 and no extrinsics are provided)."),
io.Combo.Input("pose_method",
options=["cam_dec", "ray_pose"],
default="cam_dec",
tooltip="cam_dec: small MLP on the final cam token (works for "
"all variants with cam_dec). ray_pose: RANSAC over the "
"DualDPT auxiliary ray output (DA3-Small/Base only)."),
io.Combo.Input("normalization",
options=["v2_style", "min_max", "raw"],
default="v2_style"),
],
outputs=[
io.Image.Output("depth_image"),
io.Mask.Output("confidence"),
io.Mask.Output("sky_mask"),
io.Latent.Output("camera",
tooltip="Per-view extrinsics + intrinsics + raw depth."),
],
camera = {"samples": torch.zeros(1, 1, 1, 1, 1), "type": "mono"}
return io.NodeOutput(
out_image,
sky_mask.contiguous(),
conf_mask.contiguous(),
camera,
)
@classmethod
def execute(cls, model, image, process_res, resize_method, ref_view_strategy,
pose_method, normalization) -> io.NodeOutput:
def _execute_multiview(cls, model, image, process_res, resize_method,
normalization, apply_sky_clip,
ref_view_strategy, pose_method) -> io.NodeOutput:
assert image.ndim == 4 and image.shape[-1] == 3, \
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
S, H, W, _ = image.shape
@ -271,40 +311,41 @@ class DepthAnything3MultiView(io.ComfyNode):
device = mm.get_torch_device()
dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32
# Stack all views as a single batch element with views axis = S.
# All views in a single forward pass: (1, S, 3, H', W').
x = image.to(device)
x = da3_preprocess.preprocess_image(x, process_res=process_res, method=resize_method)
x = x.to(dtype=dtype).unsqueeze(0) # (1, S, 3, H', W')
x = x.to(dtype=dtype).unsqueeze(0)
use_ray_pose = (pose_method == "ray_pose")
with torch.no_grad():
out = diffusion(x, use_ray_pose=use_ray_pose,
ref_view_strategy=ref_view_strategy)
# ``out["depth"]`` is (S, h_p, w_p); resize back to (S, H, W).
depth_lr = out["depth"].float()
depth = torch.nn.functional.interpolate(
depth_lr.unsqueeze(1), size=(H, W),
out["depth"].float().unsqueeze(1), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
conf = torch.zeros_like(depth)
if "depth_conf" in out:
conf = torch.nn.functional.interpolate(
out["depth_conf"].unsqueeze(1).float(), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
else:
conf = torch.zeros_like(depth)
sky = torch.zeros_like(depth)
if "sky" in out:
sky = torch.nn.functional.interpolate(
out["sky"].unsqueeze(1).float(), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
else:
sky = torch.zeros_like(depth)
# Pose. Defaults to identity when neither cam_dec nor ray_pose is wired up.
if apply_sky_clip and "sky" in out:
depth = torch.stack([
da3_preprocess.apply_sky_aware_clip(depth[i], sky[i])
for i in range(S)
], dim=0)
if "extrinsics" in out and "intrinsics" in out:
extrinsics = out["extrinsics"].float().cpu()
intrinsics = out["intrinsics"].float().cpu()
@ -312,11 +353,11 @@ class DepthAnything3MultiView(io.ComfyNode):
extrinsics = torch.eye(4)[None, None].expand(1, S, 4, 4).clone()
intrinsics = torch.eye(3)[None, None].expand(1, S, 3, 3).clone()
# Normalised depth viz per view (same path as the mono node).
sky_for_norm = sky if diffusion.has_sky else None
if normalization == "v2_style":
norm = torch.stack([
da3_preprocess.normalize_depth_v2_style(depth[i],
sky[i] if "sky" in out else None)
da3_preprocess.normalize_depth_v2_style(
depth[i], sky_for_norm[i] if sky_for_norm is not None else None)
for i in range(S)
], dim=0)
elif normalization == "min_max":
@ -327,8 +368,6 @@ class DepthAnything3MultiView(io.ComfyNode):
depth_image = norm.unsqueeze(-1).repeat(1, 1, 1, 3).clamp(0.0, 1.0).contiguous()
camera_latent = {
# The Latent contract requires a ``samples`` field; pack the raw
# depth there so a downstream node still has a tensor to chain on.
"samples": depth.unsqueeze(0).unsqueeze(2).contiguous(), # (1, S, 1, H, W)
"type": "da3_multiview",
"extrinsics": extrinsics.contiguous(),
@ -338,59 +377,18 @@ class DepthAnything3MultiView(io.ComfyNode):
}
return io.NodeOutput(
depth_image,
conf.contiguous(),
sky.contiguous(),
conf.contiguous(),
camera_latent,
)
class DepthAnything3DepthRaw(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="DepthAnything3DepthRaw",
display_name="Depth Anything 3 (Raw Depth)",
category="image/depth",
inputs=[
io.Model.Input("model"),
io.Image.Input("image"),
io.Int.Input("process_res", default=504, min=140, max=2520, step=14),
io.Combo.Input("resize_method",
options=["upper_bound_resize", "lower_bound_resize"],
default="upper_bound_resize"),
],
outputs=[
io.Mask.Output("depth",
tooltip="Raw depth values (no normalisation, no clipping)."),
io.Mask.Output("confidence"),
io.Mask.Output("sky"),
],
)
@classmethod
def execute(cls, model, image, process_res, resize_method) -> io.NodeOutput:
depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method)
zeros = torch.zeros_like(depth)
return io.NodeOutput(
depth.contiguous(),
(confidence if confidence is not None else zeros).contiguous(),
(sky if sky is not None else zeros).contiguous(),
)
# -----------------------------------------------------------------------------
# Extension registration
# -----------------------------------------------------------------------------
class DepthAnything3Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LoadDepthAnything3,
DepthAnything3Depth,
DepthAnything3DepthRaw,
DepthAnything3MultiView,
DepthAnything3,
]