diff --git a/comfy_extras/nodes_depth_anything_3.py b/comfy_extras/nodes_depth_anything_3.py index a80ad4afd..ee54c3171 100644 --- a/comfy_extras/nodes_depth_anything_3.py +++ b/comfy_extras/nodes_depth_anything_3.py @@ -35,10 +35,12 @@ import comfy.sd import folder_paths from comfy.ldm.colormap import turbo as _turbo from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess -from comfy_api.latest import ComfyExtension, io +from comfy_api.latest import ComfyExtension, Types, io +from comfy.ldm.moge.geometry import triangulate_grid_mesh DA3ModelType = io.Custom("DA3_MODEL") DA3Geometry = io.Custom("DA3_GEOMETRY") +DA3PointCloud = io.Custom("DA3_POINT_CLOUD") # DA3_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them): # @@ -52,6 +54,54 @@ DA3Geometry = io.Custom("DA3_GEOMETRY") # Multi-view only — S = number of views; the leading 1 is the scene dimension from the model. # "extrinsics": torch.Tensor (1, S, 4, 4) -- world-to-camera matrices # "intrinsics": torch.Tensor (1, S, 3, 3) -- pixel-space intrinsics +# +# DA3_POINT_CLOUD is a dict: +# "points": torch.Tensor (N, 3) -- 3-D coords in glTF convention (Y-up, Z-back) +# "colors": torch.Tensor (N, 3) -- RGB in [0, 1], or None +# "confidence": torch.Tensor (N,) -- raw confidence per point, or None + + +def _da3_unproject(depth: torch.Tensor, K: torch.Tensor) -> torch.Tensor: + """Pixel-space K⁻¹ unprojection: (H,W) depth → (H,W,3) point map in OpenCV space.""" + H, W = depth.shape + u = torch.arange(W, dtype=torch.float32, device=depth.device) + v = torch.arange(H, dtype=torch.float32, device=depth.device) + u, v = torch.meshgrid(u, v, indexing='xy') # both (H, W) + pix = torch.stack([u, v, torch.ones_like(u)], dim=-1) # (H, W, 3) + rays = torch.einsum('ij,hwj->hwi', torch.linalg.inv(K.to(depth.device)), pix) + return rays * depth.unsqueeze(-1) # (H, W, 3) + + +def _da3_default_K(H: int, W: int) -> torch.Tensor: + """Fallback ~60° FOV pinhole K for mono-mode DA3 (no intrinsics in geometry).""" + fx = fy = float(W) * 0.7 + return torch.tensor([[fx, 0.0, (W - 1) / 2.0], + [0.0, fy, (H - 1) / 2.0], + [0.0, 0.0, 1.0]], dtype=torch.float32) + + +def _da3_get_K(geometry: dict, b: int, H: int, W: int) -> torch.Tensor: + """Return pixel-space K for batch element b, falling back to a default estimate.""" + import logging + if "intrinsics" in geometry: + # shape (1, S, 3, 3) — leading scene dimension from the multiview head + return geometry["intrinsics"][0, b].float() + logging.getLogger("comfy").warning( + "DA3_GEOMETRY has no intrinsics (mono-mode model). " + "Using a ~60° FOV estimate; 3-D reconstruction may be inaccurate." + ) + return _da3_default_K(H, W) + + +def _da3_build_mask(geometry: dict, b: int, H: int, W: int, + confidence_threshold: float, use_sky_mask: bool) -> torch.Tensor: + """Build (H,W) bool keep-mask from sky probability and confidence.""" + mask = torch.ones(H, W, dtype=torch.bool) + if use_sky_mask and "sky" in geometry: + mask = mask & (geometry["sky"][b] < 0.5) + if "confidence" in geometry: + mask = mask & (geometry["confidence"][b] >= confidence_threshold) + return mask class LoadDepthAnything3Model(io.ComfyNode): @@ -444,6 +494,178 @@ class DepthAnything3Render(io.ComfyNode): return torch.stack(out, dim=0) +class DA3GeometryToMesh(io.ComfyNode): + """Convert a DA3_GEOMETRY packet into a Types.MESH by unprojecting depth and triangulating.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DA3GeometryToMesh", + search_aliases=["da3", "depth anything", "mesh", "geometry", "3d", "triangulate"], + display_name="DA3 Geometry to Mesh", + category="image/geometry_estimation", + description="Convert a DA3_GEOMETRY depth map into a triangulated 3D mesh (Types.MESH).", + inputs=[ + DA3Geometry.Input("da3_geometry"), + io.Int.Input("batch_index", default=0, min=0, max=4096, + tooltip="Which frame of a batched DA3_GEOMETRY to mesh. " + "Per-frame vertex counts differ so batches cannot be stacked."), + io.Int.Input("decimation", default=1, min=1, max=8, + tooltip="Vertex stride; 1 = full resolution, 2 = half, etc."), + io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01, + tooltip="Drop triangles whose 3×3 depth span exceeds this fraction. 0 = off."), + io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01, + tooltip="Exclude pixels with raw confidence below this value. " + "Ignored when the geometry has no confidence map (Mono/Metric models)."), + io.Boolean.Input("use_sky_mask", default=True, + tooltip="Exclude sky-probability pixels (sky >= 0.5) from the mesh. " + "Ignored when the geometry has no sky map (Small/Base models)."), + io.Boolean.Input("texture", default=True, + tooltip="Carry the source image through as the baseColor texture."), + ], + outputs=[io.Mesh.Output()], + ) + + @classmethod + def execute(cls, da3_geometry, batch_index, decimation, discontinuity_threshold, + confidence_threshold, use_sky_mask, texture) -> io.NodeOutput: + depth_all = da3_geometry["depth"] # (B, H, W) + B = depth_all.shape[0] + if batch_index >= B: + raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.") + + depth = depth_all[batch_index] # (H, W) + H, W = depth.shape + + K = _da3_get_K(da3_geometry, batch_index, H, W) + points = _da3_unproject(depth, K) # (H, W, 3) in OpenCV space + + # Mask invalid pixels by setting them to inf so triangulate_grid_mesh skips them. + mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask) + points = points.clone() + points[~mask] = float('inf') + + verts, faces, uvs = triangulate_grid_mesh( + points, + decimation=decimation, + discontinuity_threshold=discontinuity_threshold, + depth=depth, + ) + if verts.shape[0] == 0 or faces.shape[0] == 0: + raise ValueError( + "DA3GeometryToMesh produced an empty mesh. " + "Try raising discontinuity_threshold, lowering confidence_threshold, " + "or disabling use_sky_mask." + ) + + # OpenCV (X right, Y down, Z forward) → glTF (X right, Y up, Z back). + # Same transform as MoGePointMapToMesh perspective branch. + verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype) + faces = faces[:, [0, 2, 1]].contiguous() + + tex = da3_geometry["image"][batch_index:batch_index + 1] if texture else None + mesh = Types.MESH( + vertices=verts.unsqueeze(0), + faces=faces.unsqueeze(0), + uvs=uvs.unsqueeze(0), + texture=tex, + ) + return io.NodeOutput(mesh) + + +class DA3GeometryToPointCloud(io.ComfyNode): + """Unproject a DA3_GEOMETRY depth map into a filtered DA3_POINT_CLOUD.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DA3GeometryToPointCloud", + search_aliases=["da3", "depth anything", "point cloud", "pointcloud", "3d", "geometry"], + display_name="DA3 Geometry to Point Cloud", + category="image/geometry_estimation", + description="Unproject a DA3_GEOMETRY depth map into a 3D point cloud (DA3_POINT_CLOUD).", + inputs=[ + DA3Geometry.Input("da3_geometry"), + io.Int.Input("batch_index", default=0, min=0, max=4096, + tooltip="Which frame of a batched DA3_GEOMETRY to convert."), + io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01, + tooltip="Exclude pixels with raw confidence below this value. " + "Ignored when the geometry has no confidence map."), + io.Boolean.Input("use_sky_mask", default=True, + tooltip="Exclude sky-probability pixels (sky >= 0.5). " + "Ignored when the geometry has no sky map."), + io.Int.Input("downsample", default=1, min=1, max=16, + tooltip="Take every Nth pixel (1 = full resolution). " + "Higher values give fewer points and faster processing."), + ], + # TODO: add a proper PointCloud output type + outputs=[DA3PointCloud.Output(display_name="point_cloud")], + ) + + @classmethod + def execute(cls, da3_geometry, batch_index, confidence_threshold, + use_sky_mask, downsample) -> io.NodeOutput: + depth_all = da3_geometry["depth"] # (B, H, W) + B = depth_all.shape[0] + if batch_index >= B: + raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.") + + depth = depth_all[batch_index] # (H, W) + H, W = depth.shape + + K = _da3_get_K(da3_geometry, batch_index, H, W) + + if downsample > 1: + depth = depth[::downsample, ::downsample].contiguous() + # Scale intrinsics to the downsampled grid. + K = K.clone() + K[0, :] /= downsample + K[1, :] /= downsample + + H_ds, W_ds = depth.shape + points = _da3_unproject(depth, K) # (H_ds, W_ds, 3) + + # Rebuild mask at downsampled resolution. + mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask) + if downsample > 1: + mask = mask[::downsample, ::downsample] + + mask = mask & torch.isfinite(depth) + + # OpenCV → glTF: flip Y and Z. + points_gltf = points.clone() + points_gltf[..., 1] *= -1.0 + points_gltf[..., 2] *= -1.0 + + pts_flat = points_gltf.reshape(-1, 3)[mask.reshape(-1)] + + colors_flat = None + if "image" in da3_geometry: + img = da3_geometry["image"][batch_index] # (H, W, 3) + if downsample > 1: + img = img[::downsample, ::downsample] + colors_flat = img.reshape(-1, 3)[mask.reshape(-1)] + + conf_flat = None + if "confidence" in da3_geometry: + conf = da3_geometry["confidence"][batch_index] # (H, W) + if downsample > 1: + conf = conf[::downsample, ::downsample] + conf_flat = conf.reshape(-1)[mask.reshape(-1)] + + if pts_flat.shape[0] == 0: + raise ValueError( + "DA3GeometryToPointCloud produced zero points after filtering. " + "Try lowering confidence_threshold or disabling use_sky_mask." + ) + + return io.NodeOutput({ + "points": pts_flat, + "colors": colors_flat, + "confidence": conf_flat, + }) + + class DepthAnything3Extension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -451,6 +673,8 @@ class DepthAnything3Extension(ComfyExtension): LoadDepthAnything3Model, DepthAnything3Inference, DepthAnything3Render, + DA3GeometryToMesh, + # DA3GeometryToPointCloud, # Keep this commented out for now until we have a proper PointCloud output type ]