Add DA3GeometryToMesh and DA3GeometryToPointCloud

This commit is contained in:
Talmaj Marinc 2026-05-26 13:46:22 +02:00
parent 7cb2394630
commit 2ed1f36471

View File

@ -35,10 +35,12 @@ import comfy.sd
import folder_paths import folder_paths
from comfy.ldm.colormap import turbo as _turbo from comfy.ldm.colormap import turbo as _turbo
from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, Types, io
from comfy.ldm.moge.geometry import triangulate_grid_mesh
DA3ModelType = io.Custom("DA3_MODEL") DA3ModelType = io.Custom("DA3_MODEL")
DA3Geometry = io.Custom("DA3_GEOMETRY") DA3Geometry = io.Custom("DA3_GEOMETRY")
DA3PointCloud = io.Custom("DA3_POINT_CLOUD")
# DA3_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them): # DA3_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them):
# #
@ -52,6 +54,54 @@ DA3Geometry = io.Custom("DA3_GEOMETRY")
# Multi-view only — S = number of views; the leading 1 is the scene dimension from the model. # Multi-view only — S = number of views; the leading 1 is the scene dimension from the model.
# "extrinsics": torch.Tensor (1, S, 4, 4) -- world-to-camera matrices # "extrinsics": torch.Tensor (1, S, 4, 4) -- world-to-camera matrices
# "intrinsics": torch.Tensor (1, S, 3, 3) -- pixel-space intrinsics # "intrinsics": torch.Tensor (1, S, 3, 3) -- pixel-space intrinsics
#
# DA3_POINT_CLOUD is a dict:
# "points": torch.Tensor (N, 3) -- 3-D coords in glTF convention (Y-up, Z-back)
# "colors": torch.Tensor (N, 3) -- RGB in [0, 1], or None
# "confidence": torch.Tensor (N,) -- raw confidence per point, or None
def _da3_unproject(depth: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
"""Pixel-space K⁻¹ unprojection: (H,W) depth → (H,W,3) point map in OpenCV space."""
H, W = depth.shape
u = torch.arange(W, dtype=torch.float32, device=depth.device)
v = torch.arange(H, dtype=torch.float32, device=depth.device)
u, v = torch.meshgrid(u, v, indexing='xy') # both (H, W)
pix = torch.stack([u, v, torch.ones_like(u)], dim=-1) # (H, W, 3)
rays = torch.einsum('ij,hwj->hwi', torch.linalg.inv(K.to(depth.device)), pix)
return rays * depth.unsqueeze(-1) # (H, W, 3)
def _da3_default_K(H: int, W: int) -> torch.Tensor:
"""Fallback ~60° FOV pinhole K for mono-mode DA3 (no intrinsics in geometry)."""
fx = fy = float(W) * 0.7
return torch.tensor([[fx, 0.0, (W - 1) / 2.0],
[0.0, fy, (H - 1) / 2.0],
[0.0, 0.0, 1.0]], dtype=torch.float32)
def _da3_get_K(geometry: dict, b: int, H: int, W: int) -> torch.Tensor:
"""Return pixel-space K for batch element b, falling back to a default estimate."""
import logging
if "intrinsics" in geometry:
# shape (1, S, 3, 3) — leading scene dimension from the multiview head
return geometry["intrinsics"][0, b].float()
logging.getLogger("comfy").warning(
"DA3_GEOMETRY has no intrinsics (mono-mode model). "
"Using a ~60° FOV estimate; 3-D reconstruction may be inaccurate."
)
return _da3_default_K(H, W)
def _da3_build_mask(geometry: dict, b: int, H: int, W: int,
confidence_threshold: float, use_sky_mask: bool) -> torch.Tensor:
"""Build (H,W) bool keep-mask from sky probability and confidence."""
mask = torch.ones(H, W, dtype=torch.bool)
if use_sky_mask and "sky" in geometry:
mask = mask & (geometry["sky"][b] < 0.5)
if "confidence" in geometry:
mask = mask & (geometry["confidence"][b] >= confidence_threshold)
return mask
class LoadDepthAnything3Model(io.ComfyNode): class LoadDepthAnything3Model(io.ComfyNode):
@ -444,6 +494,178 @@ class DepthAnything3Render(io.ComfyNode):
return torch.stack(out, dim=0) return torch.stack(out, dim=0)
class DA3GeometryToMesh(io.ComfyNode):
"""Convert a DA3_GEOMETRY packet into a Types.MESH by unprojecting depth and triangulating."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="DA3GeometryToMesh",
search_aliases=["da3", "depth anything", "mesh", "geometry", "3d", "triangulate"],
display_name="DA3 Geometry to Mesh",
category="image/geometry_estimation",
description="Convert a DA3_GEOMETRY depth map into a triangulated 3D mesh (Types.MESH).",
inputs=[
DA3Geometry.Input("da3_geometry"),
io.Int.Input("batch_index", default=0, min=0, max=4096,
tooltip="Which frame of a batched DA3_GEOMETRY to mesh. "
"Per-frame vertex counts differ so batches cannot be stacked."),
io.Int.Input("decimation", default=1, min=1, max=8,
tooltip="Vertex stride; 1 = full resolution, 2 = half, etc."),
io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01,
tooltip="Drop triangles whose 3×3 depth span exceeds this fraction. 0 = off."),
io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01,
tooltip="Exclude pixels with raw confidence below this value. "
"Ignored when the geometry has no confidence map (Mono/Metric models)."),
io.Boolean.Input("use_sky_mask", default=True,
tooltip="Exclude sky-probability pixels (sky >= 0.5) from the mesh. "
"Ignored when the geometry has no sky map (Small/Base models)."),
io.Boolean.Input("texture", default=True,
tooltip="Carry the source image through as the baseColor texture."),
],
outputs=[io.Mesh.Output()],
)
@classmethod
def execute(cls, da3_geometry, batch_index, decimation, discontinuity_threshold,
confidence_threshold, use_sky_mask, texture) -> io.NodeOutput:
depth_all = da3_geometry["depth"] # (B, H, W)
B = depth_all.shape[0]
if batch_index >= B:
raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.")
depth = depth_all[batch_index] # (H, W)
H, W = depth.shape
K = _da3_get_K(da3_geometry, batch_index, H, W)
points = _da3_unproject(depth, K) # (H, W, 3) in OpenCV space
# Mask invalid pixels by setting them to inf so triangulate_grid_mesh skips them.
mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask)
points = points.clone()
points[~mask] = float('inf')
verts, faces, uvs = triangulate_grid_mesh(
points,
decimation=decimation,
discontinuity_threshold=discontinuity_threshold,
depth=depth,
)
if verts.shape[0] == 0 or faces.shape[0] == 0:
raise ValueError(
"DA3GeometryToMesh produced an empty mesh. "
"Try raising discontinuity_threshold, lowering confidence_threshold, "
"or disabling use_sky_mask."
)
# OpenCV (X right, Y down, Z forward) → glTF (X right, Y up, Z back).
# Same transform as MoGePointMapToMesh perspective branch.
verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype)
faces = faces[:, [0, 2, 1]].contiguous()
tex = da3_geometry["image"][batch_index:batch_index + 1] if texture else None
mesh = Types.MESH(
vertices=verts.unsqueeze(0),
faces=faces.unsqueeze(0),
uvs=uvs.unsqueeze(0),
texture=tex,
)
return io.NodeOutput(mesh)
class DA3GeometryToPointCloud(io.ComfyNode):
"""Unproject a DA3_GEOMETRY depth map into a filtered DA3_POINT_CLOUD."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="DA3GeometryToPointCloud",
search_aliases=["da3", "depth anything", "point cloud", "pointcloud", "3d", "geometry"],
display_name="DA3 Geometry to Point Cloud",
category="image/geometry_estimation",
description="Unproject a DA3_GEOMETRY depth map into a 3D point cloud (DA3_POINT_CLOUD).",
inputs=[
DA3Geometry.Input("da3_geometry"),
io.Int.Input("batch_index", default=0, min=0, max=4096,
tooltip="Which frame of a batched DA3_GEOMETRY to convert."),
io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01,
tooltip="Exclude pixels with raw confidence below this value. "
"Ignored when the geometry has no confidence map."),
io.Boolean.Input("use_sky_mask", default=True,
tooltip="Exclude sky-probability pixels (sky >= 0.5). "
"Ignored when the geometry has no sky map."),
io.Int.Input("downsample", default=1, min=1, max=16,
tooltip="Take every Nth pixel (1 = full resolution). "
"Higher values give fewer points and faster processing."),
],
# TODO: add a proper PointCloud output type
outputs=[DA3PointCloud.Output(display_name="point_cloud")],
)
@classmethod
def execute(cls, da3_geometry, batch_index, confidence_threshold,
use_sky_mask, downsample) -> io.NodeOutput:
depth_all = da3_geometry["depth"] # (B, H, W)
B = depth_all.shape[0]
if batch_index >= B:
raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.")
depth = depth_all[batch_index] # (H, W)
H, W = depth.shape
K = _da3_get_K(da3_geometry, batch_index, H, W)
if downsample > 1:
depth = depth[::downsample, ::downsample].contiguous()
# Scale intrinsics to the downsampled grid.
K = K.clone()
K[0, :] /= downsample
K[1, :] /= downsample
H_ds, W_ds = depth.shape
points = _da3_unproject(depth, K) # (H_ds, W_ds, 3)
# Rebuild mask at downsampled resolution.
mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask)
if downsample > 1:
mask = mask[::downsample, ::downsample]
mask = mask & torch.isfinite(depth)
# OpenCV → glTF: flip Y and Z.
points_gltf = points.clone()
points_gltf[..., 1] *= -1.0
points_gltf[..., 2] *= -1.0
pts_flat = points_gltf.reshape(-1, 3)[mask.reshape(-1)]
colors_flat = None
if "image" in da3_geometry:
img = da3_geometry["image"][batch_index] # (H, W, 3)
if downsample > 1:
img = img[::downsample, ::downsample]
colors_flat = img.reshape(-1, 3)[mask.reshape(-1)]
conf_flat = None
if "confidence" in da3_geometry:
conf = da3_geometry["confidence"][batch_index] # (H, W)
if downsample > 1:
conf = conf[::downsample, ::downsample]
conf_flat = conf.reshape(-1)[mask.reshape(-1)]
if pts_flat.shape[0] == 0:
raise ValueError(
"DA3GeometryToPointCloud produced zero points after filtering. "
"Try lowering confidence_threshold or disabling use_sky_mask."
)
return io.NodeOutput({
"points": pts_flat,
"colors": colors_flat,
"confidence": conf_flat,
})
class DepthAnything3Extension(ComfyExtension): class DepthAnything3Extension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -451,6 +673,8 @@ class DepthAnything3Extension(ComfyExtension):
LoadDepthAnything3Model, LoadDepthAnything3Model,
DepthAnything3Inference, DepthAnything3Inference,
DepthAnything3Render, DepthAnything3Render,
DA3GeometryToMesh,
# DA3GeometryToPointCloud, # Keep this commented out for now until we have a proper PointCloud output type
] ]