ComfyUI/comfy_extras/nodes_depth_anything_3.py
2026-05-26 13:47:35 +02:00

683 lines
31 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""ComfyUI nodes for Depth Anything 3.
Adds these nodes:
* ``LoadDepthAnything3`` -- load a DA3 ``.safetensors`` file from the
``models/geometry_estimation/`` folder.
* ``DepthAnything3`` -- unified depth estimation node supporting both mono and
multi-view modes via a DynamicCombo selector. Returns a DA3_GEOMETRY dict of
raw tensors (depth, sky, confidence, camera). Feed into ``DepthAnything3Render``
to produce display images, or directly into ``MoGeRender`` for depth / mask views.
* ``DepthAnything3Render`` -- post-processes a DA3_GEOMETRY dict: applies optional
sky clipping, normalises depth and confidence, and returns display images.
Model capability matrix
-----------------------
Variant head_type has_sky has_conf cam_dec
DA3-Small dualdpt False True yes
DA3-Base dualdpt False True yes
DA3-Mono-Large dpt True False no
DA3-Metric-Large dpt True False no (raw output is metres)
The node raises a ``ValueError`` at execution time when the selected
parameters conflict with the loaded model's capabilities (e.g.
``apply_sky_clip=True`` on a model with no sky head).
"""
from __future__ import annotations
from typing_extensions import override
import torch
import comfy.model_management as mm
import comfy.sd
import folder_paths
from comfy.ldm.colormap import turbo as _turbo
from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess
from comfy_api.latest import ComfyExtension, Types, io
from comfy.ldm.moge.geometry import triangulate_grid_mesh
# Custom socket types used to pass DA3 objects between the nodes in this file.
DA3ModelType = io.Custom("DA3_MODEL")
DA3Geometry = io.Custom("DA3_GEOMETRY")
DA3PointCloud = io.Custom("DA3_POINT_CLOUD")
# DA3_GEOMETRY is a dict. Keys marked "always present" are set by every inference path;
# the remaining keys are absent when the upstream model didn't produce them.
#
# Per-frame tensors — B = batch size in mono mode; B = S (number of views) in multi-view mode.
#   "depth":      torch.Tensor (B, H, W)    -- raw model depth (always present; matches MoGe convention)
#   "image":      torch.Tensor (B, H, W, 3) -- source image in [0, 1], CPU (always present)
#   "mode":       str                       -- "mono" or "multiview" (always present)
#   "sky":        torch.Tensor (B, H, W)    -- sky probability in [0, 1] (Mono/Metric variants only)
#   "confidence": torch.Tensor (B, H, W)    -- raw model confidence output (Small/Base variants only)
#
# Multi-view only — S = number of views; the leading 1 is the scene dimension from the model.
# Both keys are always written in multi-view mode; identity matrices are stored when the
# model output contains no camera estimate.
#   "extrinsics": torch.Tensor (1, S, 4, 4) -- world-to-camera matrices
#   "intrinsics": torch.Tensor (1, S, 3, 3) -- pixel-space intrinsics
#
# DA3_POINT_CLOUD is a dict:
#   "points":     torch.Tensor (N, 3) -- 3-D coords in glTF convention (Y-up, Z-back)
#   "colors":     torch.Tensor (N, 3) -- RGB in [0, 1], or None
#   "confidence": torch.Tensor (N,)   -- raw confidence per point, or None
def _da3_unproject(depth: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """Back-project a depth map through the inverse of pixel-space intrinsics.

    Args:
        depth: (H, W) depth map.
        K: (3, 3) pixel-space intrinsics; moved to ``depth``'s device before
            being inverted.

    Returns:
        (H, W, 3) point map. Each pixel's homogeneous coordinate ``(u, v, 1)``
        is mapped through ``K⁻¹`` and scaled by that pixel's depth, so the
        last channel equals ``depth`` whenever the bottom row of ``K`` is
        ``(0, 0, 1)``. Callers in this file treat the result as OpenCV camera
        space (X right, Y down, Z forward).
    """
    n_rows, n_cols = depth.shape
    target = depth.device
    row_idx = torch.arange(n_rows, dtype=torch.float32, device=target)
    col_idx = torch.arange(n_cols, dtype=torch.float32, device=target)
    # 'ij' indexing over (rows, cols) yields (H, W) grids directly:
    # grid_v varies down the rows, grid_u across the columns.
    grid_v, grid_u = torch.meshgrid(row_idx, col_idx, indexing='ij')
    homogeneous = torch.stack((grid_u, grid_v, torch.ones_like(grid_u)), dim=-1)  # (H, W, 3)
    K_inv = torch.linalg.inv(K.to(target))
    directions = torch.einsum('ij,hwj->hwi', K_inv, homogeneous)
    return directions * depth[..., None]  # (H, W, 3)
def _da3_default_K(H: int, W: int) -> torch.Tensor:
    """Build a fallback pinhole intrinsics matrix for geometry without intrinsics.

    The focal length is ``0.7 * W`` on both axes and the principal point sits
    at the pixel-grid centre ``((W - 1) / 2, (H - 1) / 2)``. The surrounding
    code describes this as a "~60° FOV" guess; a focal of ``0.7 * W`` works
    out to a horizontal field of view of roughly 71°.

    Args:
        H: image height in pixels.
        W: image width in pixels.

    Returns:
        (3, 3) float32 intrinsics matrix on the CPU.
    """
    focal = float(W) * 0.7
    centre_x = (W - 1) / 2.0
    centre_y = (H - 1) / 2.0
    rows = [
        [focal, 0.0, centre_x],
        [0.0, focal, centre_y],
        [0.0, 0.0, 1.0],
    ]
    return torch.tensor(rows, dtype=torch.float32)
def _da3_get_K(geometry: dict, b: int, H: int, W: int) -> torch.Tensor:
    """Pick the pixel-space intrinsics for frame ``b`` of a DA3_GEOMETRY dict.

    Args:
        geometry: DA3_GEOMETRY dict; its optional ``"intrinsics"`` entry is
            indexed as ``[0, b]`` (leading scene dimension, then view).
        b: frame / view index.
        H: image height, used only for the fallback matrix.
        W: image width, used only for the fallback matrix.

    Returns:
        (3, 3) float32 intrinsics. When the dict has no ``"intrinsics"`` key a
        warning is logged on the "comfy" logger and the default estimate from
        ``_da3_default_K`` is returned instead.

    NOTE(review): stored intrinsics are returned unchanged — confirm they are
    expressed at the same resolution as the (upsampled) depth map.
    """
    import logging

    if "intrinsics" not in geometry:
        logging.getLogger("comfy").warning(
            "DA3_GEOMETRY has no intrinsics (mono-mode model). "
            "Using a ~60° FOV estimate; 3-D reconstruction may be inaccurate."
        )
        return _da3_default_K(H, W)

    # Stored as (1, S, 3, 3): drop the scene dimension, then select the view.
    per_view = geometry["intrinsics"][0]
    return per_view[b].float()
def _da3_build_mask(geometry: dict, b: int, H: int, W: int,
                    confidence_threshold: float, use_sky_mask: bool) -> torch.Tensor:
    """Combine the sky and confidence tests into one boolean keep-mask.

    Args:
        geometry: DA3_GEOMETRY dict; only its optional ``"sky"`` and
            ``"confidence"`` entries are read.
        b: frame index into those per-frame tensors.
        H: mask height.
        W: mask width.
        confidence_threshold: pixels whose confidence is below this are dropped.
        use_sky_mask: when true (and a sky map exists) pixels with sky
            probability >= 0.5 are dropped.

    Returns:
        (H, W) bool tensor, True where the pixel passes every available test.
        With neither map present the mask is all True.

    NOTE(review): ``DepthAnything3Render._normalize_confidence`` documents the
    raw confidence as >= 1, while the node inputs cap ``confidence_threshold``
    at 1.0 — the confidence test may therefore never reject a pixel; confirm
    the intended scale.
    """
    keep = torch.ones(H, W, dtype=torch.bool)

    criteria = []
    if use_sky_mask and "sky" in geometry:
        criteria.append(geometry["sky"][b] < 0.5)
    if "confidence" in geometry:
        criteria.append(geometry["confidence"][b] >= confidence_threshold)

    for passed in criteria:
        keep = keep & passed
    return keep
class LoadDepthAnything3Model(io.ComfyNode):
    """Loader node that reads a DA3 checkpoint from the ``geometry_estimation`` folder."""

    @classmethod
    def define_schema(cls):
        """Describe the two combo inputs (file name, weight dtype) and the model output."""
        model_choice = io.Combo.Input(
            "model_name",
            options=folder_paths.get_filename_list("geometry_estimation"),
        )
        dtype_choice = io.Combo.Input(
            "weight_dtype",
            options=["default", "fp16", "bf16", "fp32"],
            default="default",
        )
        return io.Schema(
            node_id="LoadDepthAnything3Model",
            display_name="Load Depth Anything 3",
            category="loaders",
            inputs=[model_choice, dtype_choice],
            outputs=[DA3ModelType.Output()],
        )

    @classmethod
    def execute(cls, model_name, weight_dtype) -> io.NodeOutput:
        """Load the selected checkpoint, optionally forcing its weight dtype.

        ``weight_dtype == "default"`` (or any unrecognised value) passes no
        dtype override to the loader.
        """
        forced_dtypes = {
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "fp32": torch.float32,
        }
        model_options = {}
        if weight_dtype in forced_dtypes:
            model_options["dtype"] = forced_dtypes[weight_dtype]

        checkpoint = folder_paths.get_full_path_or_raise("geometry_estimation", model_name)
        loaded = comfy.sd.load_diffusion_model(checkpoint, model_options=model_options)
        return io.NodeOutput(loaded)
def _da3_resize_to(t: torch.Tensor, H: int, W: int) -> torch.Tensor:
    """Bilinearly resize a (B, h, w) map to (B, H, W), returned as float32 on the CPU."""
    return torch.nn.functional.interpolate(
        t.unsqueeze(1).float(), size=(H, W),
        mode="bilinear", align_corners=False,
    ).squeeze(1).cpu()


def _run_da3(model_patcher, image: torch.Tensor, process_res: int,
             method: str = "upper_bound_resize"):
    """Run DA3 frame by frame on a ``(B,H,W,3)`` IMAGE batch.

    Each frame is preprocessed, pushed through the model on its own, and the
    outputs are resized back to the input resolution.

    Args:
        model_patcher: loaded DA3 model patcher; its
            ``model.diffusion_model`` is the network that gets called.
        image: ``(B, H, W, 3)`` image batch.
        process_res: resolution handed to ``preprocess_image``.
        method: resize method handed to ``preprocess_image``.

    Returns:
        ``(depth, confidence, sky)``. ``depth`` is a ``(B, H, W)`` CPU tensor;
        ``confidence`` and ``sky`` have the same shape, or are ``None`` when
        the model output has no ``"depth_conf"`` / ``"sky"`` entry.

    Raises:
        ValueError: if ``image`` is not ``(B, H, W, 3)`` or the batch is empty.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if image.ndim != 4 or image.shape[-1] != 3:
        raise ValueError(f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}")
    B, H, W, _ = image.shape
    if B == 0:
        # torch.cat on an empty list would fail later with an opaque message.
        raise ValueError("expected a non-empty IMAGE batch; got batch size 0")

    mm.load_model_gpu(model_patcher)
    diffusion = model_patcher.model.diffusion_model
    device = mm.get_torch_device()
    dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32

    depths, confs, skies = [], [], []
    for i in range(B):
        # Keep the batch dimension: the model is fed one (1, H, W, 3) frame at a time.
        single = image[i:i + 1].to(device)
        x = da3_preprocess.preprocess_image(single, process_res=process_res, method=method)
        x = x.to(dtype=dtype)
        with torch.no_grad():
            out = diffusion(x)
        depths.append(_da3_resize_to(out["depth"], H, W))
        if "depth_conf" in out:
            confs.append(_da3_resize_to(out["depth_conf"], H, W))
        if "sky" in out:
            skies.append(_da3_resize_to(out["sky"], H, W))

    depth = torch.cat(depths, dim=0)
    confidence = torch.cat(confs, dim=0) if confs else None
    sky = torch.cat(skies, dim=0) if skies else None
    return depth, confidence, sky
class DepthAnything3Inference(io.ComfyNode):
    """Raw Depth Anything 3 inference node.

    Outputs a DA3_GEOMETRY dict of raw tensors. All display normalization
    (sky clipping, depth scaling, confidence normalisation) is handled by
    the companion ``DepthAnything3Render`` node.

    Mono mode: each batch element is processed independently.
    Multi-view mode: all frames share a single forward pass with cross-view
    attention; adds ``extrinsics`` and ``intrinsics`` to the geometry dict.
    """

    @classmethod
    def define_schema(cls):
        """Declare the model / image / resolution / mode inputs and the geometry output."""
        return io.Schema(
            node_id="DepthAnything3Inference",
            search_aliases=["depth", "geometry", "da3", "depth anything", "monocular", "pointmap", "sky", "3d", "metric depth", "disparity"],
            display_name="Run Depth Anything 3",
            category="image/geometry_estimation",
            description="Run Depth Anything 3 on an image or image batch. In multi-view mode each frame is treated as a separate view of the same scene.",
            inputs=[
                DA3ModelType.Input("da3_model"),
                io.Image.Input("image",
                               tooltip="Single image or image batch. "
                                       "In multi-view mode each frame is treated as "
                                       "a separate view of the same scene."),
                io.Int.Input("process_res", default=504, min=140, max=2520, step=14,
                             tooltip="Resolution the model runs at (longest side, multiple of 14). "
                                     "Lower = faster / less VRAM; higher = more detail. "
                                     "Output is upsampled back to the original size."),
                io.Combo.Input("resize_method",
                               options=["upper_bound_resize", "lower_bound_resize"],
                               default="upper_bound_resize",
                               tooltip="- upper_bound_resize: scale so the longest side = process_res (caps memory, default).\n"
                                       "- lower_bound_resize: scale so the shortest side = process_res (preserves more detail on tall/wide images, uses more memory)."),
                io.DynamicCombo.Input("mode",
                                      tooltip="- mono: single image or independent batch — works with any model variant.\n"
                                              "- multiview: all frames processed together for geometric consistency + camera pose — requires DA3-Small or DA3-Base (DA3-Mono-Large / DA3-Metric-Large do NOT support this mode).",
                                      options=[
                                          io.DynamicCombo.Option("mono", []),
                                          io.DynamicCombo.Option("multiview", [
                                              io.Combo.Input("ref_view_strategy",
                                                             options=["saddle_balanced", "saddle_sim_range",
                                                                      "first", "middle"],
                                                             default="saddle_balanced",
                                                             tooltip="Which view to use as the geometric anchor (only applied when S >= 3 and no extrinsics are provided).\n"
                                                                     "- saddle_balanced: picks the view whose CLS-token features are closest to the median across similarity, norm and variance — best general choice.\n"
                                                                     "- saddle_sim_range: picks the view with the widest similarity spread to other views — favours the most distinct viewpoint.\n"
                                                                     "- first / middle: deterministic positional fallbacks."),
                                              io.Combo.Input("pose_method",
                                                             options=["cam_dec", "ray_pose"],
                                                             default="cam_dec",
                                                             tooltip="- cam_dec: small MLP on the final camera token — works on DA3-Small and DA3-Base.\n"
                                                                     "- ray_pose: RANSAC over the DualDPT ray output — works on DA3-Small and DA3-Base.\n"
                                                                     "Both methods require DA3-Small or DA3-Base; this setting is ignored on Mono/Metric-Large."),
                                          ]),
                                      ]),
            ],
            outputs=[
                # The tooltip lists only keys this node actually writes; no
                # 'mask' entry is ever added to the geometry dict.
                DA3Geometry.Output("geometry",
                                   tooltip="DA3_GEOMETRY dict of raw tensors.\n"
                                           "- Always: 'depth' (B,H,W), 'image', 'mode'.\n"
                                           "- Optional: 'sky' (Mono/Metric), 'confidence' raw (Small/Base), 'extrinsics' + 'intrinsics' (multi-view)."),
            ],
        )

    @classmethod
    def execute(cls, da3_model, image, process_res, resize_method, mode) -> io.NodeOutput:
        """Dispatch to mono or multi-view inference.

        Args:
            da3_model: loaded DA3 model patcher.
            image: IMAGE batch.
            process_res: resolution the model runs at.
            resize_method: preprocessing resize method.
            mode: DynamicCombo payload; ``mode["mode"]`` is ``"mono"`` or
                ``"multiview"``, and the multi-view variant also carries
                ``"pose_method"`` and ``"ref_view_strategy"``.

        Raises:
            ValueError: in multi-view mode, when the loaded model lacks the
                head / camera decoder the selected options need.
        """
        mode_val = mode["mode"]  # "mono" or "multiview"
        if mode_val == "mono":
            return cls._execute_mono(da3_model, image, process_res, resize_method)

        # Capability checks for multi-view mode.
        diffusion = da3_model.model.diffusion_model
        pose_method = mode["pose_method"]
        ref_view_strategy = mode["ref_view_strategy"]
        has_cam_dec = diffusion.cam_dec is not None
        has_dualdpt = diffusion.head_type == "dualdpt"
        if not has_cam_dec and not has_dualdpt:
            raise ValueError(
                "multiview mode requires DA3-Small or DA3-Base — the loaded model "
                f"(head_type='{diffusion.head_type}') does not support cross-view "
                "attention or camera pose estimation. Switch mode to 'mono', or "
                "load DA3-Small / DA3-Base for multiview."
            )
        if pose_method == "cam_dec" and not has_cam_dec:
            raise ValueError(
                "pose_method='cam_dec' requires a camera decoder, but the loaded "
                f"model (head_type='{diffusion.head_type}') does not have one. "
                "Use pose_method='ray_pose' instead."
            )
        if pose_method == "ray_pose" and not has_dualdpt:
            raise ValueError(
                "pose_method='ray_pose' requires a DualDPT head, but the loaded "
                f"model has a '{diffusion.head_type}' head. "
                "Use pose_method='cam_dec' instead."
            )
        return cls._execute_multiview(
            da3_model, image, process_res, resize_method,
            ref_view_strategy, pose_method,
        )

    @classmethod
    def _execute_mono(cls, model, image, process_res, resize_method) -> io.NodeOutput:
        """Run per-frame inference and pack the results into a mono geometry dict."""
        depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method)
        geometry: dict = {
            "depth": depth.contiguous(),
            "image": image[..., :3].cpu(),
            "mode": "mono",
        }
        if sky is not None:
            geometry["sky"] = sky.contiguous()
        if confidence is not None:
            geometry["confidence"] = confidence.contiguous()
        return io.NodeOutput(geometry)

    @staticmethod
    def _to_full_res(t: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """Bilinearly resize a (S, h, w) map to (S, H, W), returned as float32 on the CPU."""
        return torch.nn.functional.interpolate(
            t.unsqueeze(1).float(), size=(H, W),
            mode="bilinear", align_corners=False,
        ).squeeze(1).cpu()

    @classmethod
    def _execute_multiview(cls, model, image, process_res, resize_method,
                           ref_view_strategy, pose_method) -> io.NodeOutput:
        """Run all frames as views of one scene and pack a multi-view geometry dict.

        Raises:
            ValueError: if ``image`` is not ``(B, H, W, 3)``.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if image.ndim != 4 or image.shape[-1] != 3:
            raise ValueError(f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}")
        S, H, W, _ = image.shape

        mm.load_model_gpu(model)
        diffusion = model.model.diffusion_model
        device = mm.get_torch_device()
        dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32

        # All views in a single forward pass: (1, S, 3, H', W').
        x = image.to(device)
        x = da3_preprocess.preprocess_image(x, process_res=process_res, method=resize_method)
        x = x.to(dtype=dtype).unsqueeze(0)

        use_ray_pose = (pose_method == "ray_pose")
        with torch.no_grad():
            out = diffusion(x, use_ray_pose=use_ray_pose,
                            ref_view_strategy=ref_view_strategy)

        depth = cls._to_full_res(out["depth"], H, W)
        sky = cls._to_full_res(out["sky"], H, W) if "sky" in out else None

        if "extrinsics" in out and "intrinsics" in out:
            extrinsics = out["extrinsics"].float().cpu()
            intrinsics = out["intrinsics"].float().cpu()
        else:
            # No camera estimate in the model output: store identity matrices
            # so both keys are still present in the geometry dict.
            # NOTE(review): identity intrinsics are placeholders, not a usable
            # pinhole model — confirm downstream unprojection expects this.
            extrinsics = torch.eye(4)[None, None].expand(1, S, 4, 4).clone()
            intrinsics = torch.eye(3)[None, None].expand(1, S, 3, 3).clone()

        geometry: dict = {
            "depth": depth.contiguous(),
            "image": image[..., :3].cpu(),
            "mode": "multiview",
            "extrinsics": extrinsics.contiguous(),
            "intrinsics": intrinsics.contiguous(),
        }
        if sky is not None:
            geometry["sky"] = sky.contiguous()
        if "depth_conf" in out:
            geometry["confidence"] = cls._to_full_res(out["depth_conf"], H, W).contiguous()
        return io.NodeOutput(geometry)
class DepthAnything3Render(io.ComfyNode):
    """Turn one aspect of a DA3_GEOMETRY packet into a displayable image.

    Follows the MoGeRender layout: a single ``output`` selector and a single
    IMAGE output. Place several of these nodes side by side to view depth,
    sky and confidence at the same time.
    """

    # Extra inputs shown for both depth-style outputs ("depth" and "depth_colored").
    _DEPTH_RENDER_INPUTS = [
        io.Combo.Input("normalization",
                       options=["v2_style", "min_max", "raw"],
                       default="v2_style",
                       tooltip="- v2_style: mean/std normalisation for perceptually balanced results (default).\n"
                               "- min_max: stretches the full depth range to [0, 1] for maximum contrast.\n"
                               "- raw: no scaling — preserves metric units for DA3-Metric-Large."),
        io.Boolean.Input("apply_sky_clip", default=False,
                         tooltip="Clip sky-region depth to the 99th percentile of foreground depth before "
                                 "normalisation. Requires a 'sky' tensor in the geometry "
                                 "(DA3-Mono-Large or DA3-Metric-Large); raises an error otherwise."),
    ]

    @classmethod
    def define_schema(cls):
        """Declare the geometry input, the output selector and the IMAGE output."""
        selector = io.DynamicCombo.Input(
            "output",
            tooltip="- depth: normalised greyscale depth image.\n"
                    "- depth_colored: depth mapped through the Turbo colormap.\n"
                    "- sky_mask: sky probability in [0, 1] (Mono/Metric variants only).\n"
                    "- confidence: normalised depth confidence (Small/Base variants only).",
            options=[
                io.DynamicCombo.Option("depth", cls._DEPTH_RENDER_INPUTS),
                io.DynamicCombo.Option("depth_colored", cls._DEPTH_RENDER_INPUTS),
                io.DynamicCombo.Option("sky_mask", []),
                io.DynamicCombo.Option("confidence", []),
            ],
        )
        return io.Schema(
            node_id="DepthAnything3Render",
            display_name="Depth Anything 3 Render",
            category="image/geometry_estimation",
            description="Visualise a DA3_GEOMETRY packet. Drop multiple nodes to get different views simultaneously.",
            inputs=[DA3Geometry.Input("geometry"), selector],
            outputs=[io.Image.Output()],
        )

    @classmethod
    def execute(cls, geometry, output) -> io.NodeOutput:
        """Render the view named by ``output["output"]``.

        Raises:
            ValueError: if the geometry lacks the tensor the chosen view needs,
                or the selector value is not one of the four known views.
        """
        selected = output["output"]

        if selected in ("depth", "depth_colored"):
            normalization = output["normalization"]
            apply_sky_clip = output["apply_sky_clip"]
            if apply_sky_clip and "sky" not in geometry:
                raise ValueError(
                    "apply_sky_clip=True requires a sky tensor in the geometry, but none is present. "
                    "Run with DA3-Mono-Large or DA3-Metric-Large, or set apply_sky_clip=False."
                )
            depth = geometry["depth"]
            sky = geometry.get("sky")
            if apply_sky_clip and sky is not None:
                clipped = [
                    da3_preprocess.apply_sky_aware_clip(depth[i], sky[i])
                    for i in range(depth.shape[0])
                ]
                depth = torch.stack(clipped, dim=0)
            grey = cls._depth_to_image(depth, sky, normalization)  # (B,H,W,3) greyscale
            if selected == "depth_colored":
                # The three grey channels are identical; colour-map one of them.
                rendered = _turbo(grey[..., 0])
            else:
                rendered = grey
        elif selected == "sky_mask":
            if "sky" not in geometry:
                raise ValueError("geometry has no sky output; run with DA3-Mono-Large or DA3-Metric-Large.")
            sky_prob = geometry["sky"]
            rendered = sky_prob.unsqueeze(-1).expand(*sky_prob.shape, 3).contiguous()
        elif selected == "confidence":
            if "confidence" not in geometry:
                raise ValueError("geometry has no confidence output; run with DA3-Small or DA3-Base.")
            scaled = cls._normalize_confidence(geometry["confidence"])
            rendered = scaled.unsqueeze(-1).expand(*scaled.shape, 3).contiguous()
        else:
            raise ValueError(f"Unknown output mode: {selected}")

        return io.NodeOutput(rendered.float())

    @staticmethod
    def _depth_to_image(depth: torch.Tensor, sky_for_norm: torch.Tensor | None,
                        normalization: str) -> torch.Tensor:
        """Normalise a depth batch and replicate it into three channels.

        ``"v2_style"`` normalises frame by frame (passing the matching sky map
        when one is given), ``"min_max"`` hands the whole batch to the min/max
        normaliser, and any other value leaves the depth untouched. Every mode
        except ``"raw"`` is clamped to [0, 1].

        Returns:
            (B, H, W, 3) contiguous tensor.
        """
        if normalization == "v2_style":
            per_frame = []
            for i in range(depth.shape[0]):
                frame_sky = None if sky_for_norm is None else sky_for_norm[i]
                per_frame.append(da3_preprocess.normalize_depth_v2_style(depth[i], frame_sky))
            scaled = torch.stack(per_frame, dim=0)
        elif normalization == "min_max":
            scaled = da3_preprocess.normalize_depth_min_max(depth)
        else:
            scaled = depth

        image = scaled.unsqueeze(-1).repeat(1, 1, 1, 3)
        if normalization != "raw":
            image = image.clamp(0.0, 1.0)
        return image.contiguous()

    @staticmethod
    def _normalize_confidence(conf: torch.Tensor) -> torch.Tensor:
        """Rescale raw confidence (expp1 activation, range [1, ∞)) to [0, 1] per image.

        The model uses ``exp(x) + 1`` so every pixel is guaranteed to be ≥ 1.
        A per-image min-max stretch keeps the spatial pattern (brighter = more
        confident) while yielding a valid mask in [0, 1]. A constant image has
        no range to stretch and is mapped to all ones.
        """
        def _stretch(frame: torch.Tensor) -> torch.Tensor:
            lo = frame.min()
            hi = frame.max()
            if hi > lo:
                return (frame - lo) / (hi - lo)
            return torch.ones_like(frame)

        return torch.stack([_stretch(conf[i]) for i in range(conf.shape[0])], dim=0)
class DA3GeometryToMesh(io.ComfyNode):
    """Convert a DA3_GEOMETRY packet into a Types.MESH by unprojecting depth and triangulating."""

    @classmethod
    def define_schema(cls):
        """Declare the geometry input, the meshing controls and the MESH output."""
        return io.Schema(
            node_id="DA3GeometryToMesh",
            search_aliases=["da3", "depth anything", "mesh", "geometry", "3d", "triangulate"],
            display_name="DA3 Geometry to Mesh",
            category="image/geometry_estimation",
            description="Convert a DA3_GEOMETRY depth map into a triangulated 3D mesh (Types.MESH).",
            inputs=[
                DA3Geometry.Input("da3_geometry"),
                io.Int.Input("batch_index", default=0, min=0, max=4096,
                             tooltip="Which frame of a batched DA3_GEOMETRY to mesh. "
                                     "Per-frame vertex counts differ so batches cannot be stacked."),
                io.Int.Input("decimation", default=1, min=1, max=8,
                             tooltip="Vertex stride; 1 = full resolution, 2 = half, etc."),
                io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01,
                               tooltip="Drop triangles whose 3×3 depth span exceeds this fraction. 0 = off."),
                io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01,
                               tooltip="Exclude pixels with raw confidence below this value. "
                                       "Ignored when the geometry has no confidence map (Mono/Metric models)."),
                io.Boolean.Input("use_sky_mask", default=True,
                                 tooltip="Exclude sky-probability pixels (sky >= 0.5) from the mesh. "
                                         "Ignored when the geometry has no sky map (Small/Base models)."),
                io.Boolean.Input("texture", default=True,
                                 tooltip="Carry the source image through as the baseColor texture."),
            ],
            outputs=[io.Mesh.Output()],
        )

    @classmethod
    def execute(cls, da3_geometry, batch_index, decimation, discontinuity_threshold,
                confidence_threshold, use_sky_mask, texture) -> io.NodeOutput:
        """Unproject one frame, mask invalid pixels and triangulate the grid.

        Args:
            da3_geometry: DA3_GEOMETRY dict.
            batch_index: frame to mesh.
            decimation: vertex stride forwarded to the triangulator.
            discontinuity_threshold: depth-span limit forwarded to the triangulator.
            confidence_threshold: pixels with confidence below this are dropped.
            use_sky_mask: drop pixels whose sky probability is >= 0.5.
            texture: attach the frame's source image as the mesh texture.

        Raises:
            ValueError: if ``batch_index`` is out of range or the resulting
                mesh has no vertices or faces.
        """
        depth_all = da3_geometry["depth"]  # (B, H, W)
        B = depth_all.shape[0]
        if batch_index >= B:
            raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.")

        depth = depth_all[batch_index]  # (H, W)
        H, W = depth.shape
        K = _da3_get_K(da3_geometry, batch_index, H, W)
        points = _da3_unproject(depth, K)  # (H, W, 3) in OpenCV space

        # Mask invalid pixels by setting them to inf so triangulate_grid_mesh skips them.
        mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask)
        # Also reject non-finite depth, matching DA3GeometryToPointCloud; without
        # this such pixels would reach the triangulator as a NaN/inf mixture.
        mask = mask & torch.isfinite(depth)
        points = points.clone()
        points[~mask] = float('inf')

        verts, faces, uvs = triangulate_grid_mesh(
            points,
            decimation=decimation,
            discontinuity_threshold=discontinuity_threshold,
            depth=depth,
        )
        if verts.shape[0] == 0 or faces.shape[0] == 0:
            raise ValueError(
                "DA3GeometryToMesh produced an empty mesh. "
                "Try raising discontinuity_threshold, lowering confidence_threshold, "
                "or disabling use_sky_mask."
            )

        # OpenCV (X right, Y down, Z forward) → glTF (X right, Y up, Z back).
        # Same transform as MoGePointMapToMesh perspective branch.
        # new_tensor keeps the flip vector on the same device and dtype as verts.
        verts = verts * verts.new_tensor([1.0, -1.0, -1.0])
        # Mirroring two axes is paired with a swap of each triangle's winding order.
        faces = faces[:, [0, 2, 1]].contiguous()

        tex = da3_geometry["image"][batch_index:batch_index + 1] if texture else None
        mesh = Types.MESH(
            vertices=verts.unsqueeze(0),
            faces=faces.unsqueeze(0),
            uvs=uvs.unsqueeze(0),
            texture=tex,
        )
        return io.NodeOutput(mesh)
class DA3GeometryToPointCloud(io.ComfyNode):
    """Turn one frame of a DA3_GEOMETRY depth map into a filtered DA3_POINT_CLOUD."""

    @classmethod
    def define_schema(cls):
        """Declare the geometry input, the filtering controls and the point-cloud output."""
        node_inputs = [
            DA3Geometry.Input("da3_geometry"),
            io.Int.Input("batch_index", default=0, min=0, max=4096,
                         tooltip="Which frame of a batched DA3_GEOMETRY to convert."),
            io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01,
                           tooltip="Exclude pixels with raw confidence below this value. "
                                   "Ignored when the geometry has no confidence map."),
            io.Boolean.Input("use_sky_mask", default=True,
                             tooltip="Exclude sky-probability pixels (sky >= 0.5). "
                                     "Ignored when the geometry has no sky map."),
            io.Int.Input("downsample", default=1, min=1, max=16,
                         tooltip="Take every Nth pixel (1 = full resolution). "
                                 "Higher values give fewer points and faster processing."),
        ]
        return io.Schema(
            node_id="DA3GeometryToPointCloud",
            search_aliases=["da3", "depth anything", "point cloud", "pointcloud", "3d", "geometry"],
            display_name="DA3 Geometry to Point Cloud",
            category="image/geometry_estimation",
            description="Unproject a DA3_GEOMETRY depth map into a 3D point cloud (DA3_POINT_CLOUD).",
            inputs=node_inputs,
            # TODO: add a proper PointCloud output type
            outputs=[DA3PointCloud.Output(display_name="point_cloud")],
        )

    @classmethod
    def execute(cls, da3_geometry, batch_index, confidence_threshold,
                use_sky_mask, downsample) -> io.NodeOutput:
        """Unproject the chosen frame and keep only the pixels that pass the mask.

        Returns a dict with ``"points"`` (N, 3) in glTF axes, plus ``"colors"``
        and ``"confidence"`` for the same N pixels (each ``None`` when the
        geometry has no ``"image"`` / ``"confidence"`` entry).

        Raises:
            ValueError: if ``batch_index`` is out of range or no point survives
                the filtering.
        """
        all_depth = da3_geometry["depth"]  # (B, H, W)
        batch = all_depth.shape[0]
        if batch_index >= batch:
            raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {batch}.")

        depth = all_depth[batch_index]  # (H, W)
        H, W = depth.shape
        K = _da3_get_K(da3_geometry, batch_index, H, W)
        strided = downsample > 1

        def _thin(grid):
            """Keep every ``downsample``-th row and column (no-op at stride 1)."""
            return grid[::downsample, ::downsample] if strided else grid

        if strided:
            depth = _thin(depth).contiguous()
            # Pixel coordinates shrink by the stride, so the first two rows of K do too.
            K = K.clone()
            K[0, :] /= downsample
            K[1, :] /= downsample

        points = _da3_unproject(depth, K)  # (H_ds, W_ds, 3)

        # The mask is built at full resolution, then thinned like the depth.
        full_mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask)
        keep = (_thin(full_mask) & torch.isfinite(depth)).reshape(-1)

        # OpenCV → glTF: flip Y and Z.
        px, py, pz = points.unbind(dim=-1)
        points_gltf = torch.stack((px, -py, -pz), dim=-1)
        pts_flat = points_gltf.reshape(-1, 3)[keep]

        colors_flat = None
        if "image" in da3_geometry:
            frame = _thin(da3_geometry["image"][batch_index])  # (H, W, 3) before thinning
            colors_flat = frame.reshape(-1, 3)[keep]

        conf_flat = None
        if "confidence" in da3_geometry:
            frame_conf = _thin(da3_geometry["confidence"][batch_index])  # (H, W) before thinning
            conf_flat = frame_conf.reshape(-1)[keep]

        if pts_flat.shape[0] == 0:
            raise ValueError(
                "DA3GeometryToPointCloud produced zero points after filtering. "
                "Try lowering confidence_threshold or disabling use_sky_mask."
            )

        return io.NodeOutput({
            "points": pts_flat,
            "colors": colors_flat,
            "confidence": conf_flat,
        })
class DepthAnything3Extension(ComfyExtension):
    """Extension object that exposes the Depth Anything 3 nodes to ComfyUI."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes this extension registers."""
        # DA3GeometryToPointCloud is deliberately left out until there is a
        # proper PointCloud output type.
        registered = [
            LoadDepthAnything3Model,
            DepthAnything3Inference,
            DepthAnything3Render,
            DA3GeometryToMesh,
        ]
        return registered
async def comfy_entrypoint() -> DepthAnything3Extension:
    """Create and return this module's extension instance."""
    extension = DepthAnything3Extension()
    return extension