mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 17:37:39 +08:00
Add DA3GeometryToMesh and DA3GeometryToPointCloud
This commit is contained in:
parent
7cb2394630
commit
2ed1f36471
@ -35,10 +35,12 @@ import comfy.sd
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
from comfy.ldm.colormap import turbo as _turbo
|
from comfy.ldm.colormap import turbo as _turbo
|
||||||
from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess
|
from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, Types, io
|
||||||
|
from comfy.ldm.moge.geometry import triangulate_grid_mesh
|
||||||
|
|
||||||
DA3ModelType = io.Custom("DA3_MODEL")
|
DA3ModelType = io.Custom("DA3_MODEL")
|
||||||
DA3Geometry = io.Custom("DA3_GEOMETRY")
|
DA3Geometry = io.Custom("DA3_GEOMETRY")
|
||||||
|
DA3PointCloud = io.Custom("DA3_POINT_CLOUD")
|
||||||
|
|
||||||
# DA3_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them):
|
# DA3_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them):
|
||||||
#
|
#
|
||||||
@ -52,6 +54,54 @@ DA3Geometry = io.Custom("DA3_GEOMETRY")
|
|||||||
# Multi-view only — S = number of views; the leading 1 is the scene dimension from the model.
|
# Multi-view only — S = number of views; the leading 1 is the scene dimension from the model.
|
||||||
# "extrinsics": torch.Tensor (1, S, 4, 4) -- world-to-camera matrices
|
# "extrinsics": torch.Tensor (1, S, 4, 4) -- world-to-camera matrices
|
||||||
# "intrinsics": torch.Tensor (1, S, 3, 3) -- pixel-space intrinsics
|
# "intrinsics": torch.Tensor (1, S, 3, 3) -- pixel-space intrinsics
|
||||||
|
#
|
||||||
|
# DA3_POINT_CLOUD is a dict:
|
||||||
|
# "points": torch.Tensor (N, 3) -- 3-D coords in glTF convention (Y-up, Z-back)
|
||||||
|
# "colors": torch.Tensor (N, 3) -- RGB in [0, 1], or None
|
||||||
|
# "confidence": torch.Tensor (N,) -- raw confidence per point, or None
|
||||||
|
|
||||||
|
|
||||||
|
def _da3_unproject(depth: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Pixel-space K⁻¹ unprojection: (H,W) depth → (H,W,3) point map in OpenCV space."""
|
||||||
|
H, W = depth.shape
|
||||||
|
u = torch.arange(W, dtype=torch.float32, device=depth.device)
|
||||||
|
v = torch.arange(H, dtype=torch.float32, device=depth.device)
|
||||||
|
u, v = torch.meshgrid(u, v, indexing='xy') # both (H, W)
|
||||||
|
pix = torch.stack([u, v, torch.ones_like(u)], dim=-1) # (H, W, 3)
|
||||||
|
rays = torch.einsum('ij,hwj->hwi', torch.linalg.inv(K.to(depth.device)), pix)
|
||||||
|
return rays * depth.unsqueeze(-1) # (H, W, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def _da3_default_K(H: int, W: int) -> torch.Tensor:
|
||||||
|
"""Fallback ~60° FOV pinhole K for mono-mode DA3 (no intrinsics in geometry)."""
|
||||||
|
fx = fy = float(W) * 0.7
|
||||||
|
return torch.tensor([[fx, 0.0, (W - 1) / 2.0],
|
||||||
|
[0.0, fy, (H - 1) / 2.0],
|
||||||
|
[0.0, 0.0, 1.0]], dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def _da3_get_K(geometry: dict, b: int, H: int, W: int) -> torch.Tensor:
|
||||||
|
"""Return pixel-space K for batch element b, falling back to a default estimate."""
|
||||||
|
import logging
|
||||||
|
if "intrinsics" in geometry:
|
||||||
|
# shape (1, S, 3, 3) — leading scene dimension from the multiview head
|
||||||
|
return geometry["intrinsics"][0, b].float()
|
||||||
|
logging.getLogger("comfy").warning(
|
||||||
|
"DA3_GEOMETRY has no intrinsics (mono-mode model). "
|
||||||
|
"Using a ~60° FOV estimate; 3-D reconstruction may be inaccurate."
|
||||||
|
)
|
||||||
|
return _da3_default_K(H, W)
|
||||||
|
|
||||||
|
|
||||||
|
def _da3_build_mask(geometry: dict, b: int, H: int, W: int,
|
||||||
|
confidence_threshold: float, use_sky_mask: bool) -> torch.Tensor:
|
||||||
|
"""Build (H,W) bool keep-mask from sky probability and confidence."""
|
||||||
|
mask = torch.ones(H, W, dtype=torch.bool)
|
||||||
|
if use_sky_mask and "sky" in geometry:
|
||||||
|
mask = mask & (geometry["sky"][b] < 0.5)
|
||||||
|
if "confidence" in geometry:
|
||||||
|
mask = mask & (geometry["confidence"][b] >= confidence_threshold)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
class LoadDepthAnything3Model(io.ComfyNode):
|
class LoadDepthAnything3Model(io.ComfyNode):
|
||||||
@ -444,6 +494,178 @@ class DepthAnything3Render(io.ComfyNode):
|
|||||||
return torch.stack(out, dim=0)
|
return torch.stack(out, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
class DA3GeometryToMesh(io.ComfyNode):
|
||||||
|
"""Convert a DA3_GEOMETRY packet into a Types.MESH by unprojecting depth and triangulating."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="DA3GeometryToMesh",
|
||||||
|
search_aliases=["da3", "depth anything", "mesh", "geometry", "3d", "triangulate"],
|
||||||
|
display_name="DA3 Geometry to Mesh",
|
||||||
|
category="image/geometry_estimation",
|
||||||
|
description="Convert a DA3_GEOMETRY depth map into a triangulated 3D mesh (Types.MESH).",
|
||||||
|
inputs=[
|
||||||
|
DA3Geometry.Input("da3_geometry"),
|
||||||
|
io.Int.Input("batch_index", default=0, min=0, max=4096,
|
||||||
|
tooltip="Which frame of a batched DA3_GEOMETRY to mesh. "
|
||||||
|
"Per-frame vertex counts differ so batches cannot be stacked."),
|
||||||
|
io.Int.Input("decimation", default=1, min=1, max=8,
|
||||||
|
tooltip="Vertex stride; 1 = full resolution, 2 = half, etc."),
|
||||||
|
io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01,
|
||||||
|
tooltip="Drop triangles whose 3×3 depth span exceeds this fraction. 0 = off."),
|
||||||
|
io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01,
|
||||||
|
tooltip="Exclude pixels with raw confidence below this value. "
|
||||||
|
"Ignored when the geometry has no confidence map (Mono/Metric models)."),
|
||||||
|
io.Boolean.Input("use_sky_mask", default=True,
|
||||||
|
tooltip="Exclude sky-probability pixels (sky >= 0.5) from the mesh. "
|
||||||
|
"Ignored when the geometry has no sky map (Small/Base models)."),
|
||||||
|
io.Boolean.Input("texture", default=True,
|
||||||
|
tooltip="Carry the source image through as the baseColor texture."),
|
||||||
|
],
|
||||||
|
outputs=[io.Mesh.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, da3_geometry, batch_index, decimation, discontinuity_threshold,
|
||||||
|
confidence_threshold, use_sky_mask, texture) -> io.NodeOutput:
|
||||||
|
depth_all = da3_geometry["depth"] # (B, H, W)
|
||||||
|
B = depth_all.shape[0]
|
||||||
|
if batch_index >= B:
|
||||||
|
raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.")
|
||||||
|
|
||||||
|
depth = depth_all[batch_index] # (H, W)
|
||||||
|
H, W = depth.shape
|
||||||
|
|
||||||
|
K = _da3_get_K(da3_geometry, batch_index, H, W)
|
||||||
|
points = _da3_unproject(depth, K) # (H, W, 3) in OpenCV space
|
||||||
|
|
||||||
|
# Mask invalid pixels by setting them to inf so triangulate_grid_mesh skips them.
|
||||||
|
mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask)
|
||||||
|
points = points.clone()
|
||||||
|
points[~mask] = float('inf')
|
||||||
|
|
||||||
|
verts, faces, uvs = triangulate_grid_mesh(
|
||||||
|
points,
|
||||||
|
decimation=decimation,
|
||||||
|
discontinuity_threshold=discontinuity_threshold,
|
||||||
|
depth=depth,
|
||||||
|
)
|
||||||
|
if verts.shape[0] == 0 or faces.shape[0] == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"DA3GeometryToMesh produced an empty mesh. "
|
||||||
|
"Try raising discontinuity_threshold, lowering confidence_threshold, "
|
||||||
|
"or disabling use_sky_mask."
|
||||||
|
)
|
||||||
|
|
||||||
|
# OpenCV (X right, Y down, Z forward) → glTF (X right, Y up, Z back).
|
||||||
|
# Same transform as MoGePointMapToMesh perspective branch.
|
||||||
|
verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype)
|
||||||
|
faces = faces[:, [0, 2, 1]].contiguous()
|
||||||
|
|
||||||
|
tex = da3_geometry["image"][batch_index:batch_index + 1] if texture else None
|
||||||
|
mesh = Types.MESH(
|
||||||
|
vertices=verts.unsqueeze(0),
|
||||||
|
faces=faces.unsqueeze(0),
|
||||||
|
uvs=uvs.unsqueeze(0),
|
||||||
|
texture=tex,
|
||||||
|
)
|
||||||
|
return io.NodeOutput(mesh)
|
||||||
|
|
||||||
|
|
||||||
|
class DA3GeometryToPointCloud(io.ComfyNode):
|
||||||
|
"""Unproject a DA3_GEOMETRY depth map into a filtered DA3_POINT_CLOUD."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="DA3GeometryToPointCloud",
|
||||||
|
search_aliases=["da3", "depth anything", "point cloud", "pointcloud", "3d", "geometry"],
|
||||||
|
display_name="DA3 Geometry to Point Cloud",
|
||||||
|
category="image/geometry_estimation",
|
||||||
|
description="Unproject a DA3_GEOMETRY depth map into a 3D point cloud (DA3_POINT_CLOUD).",
|
||||||
|
inputs=[
|
||||||
|
DA3Geometry.Input("da3_geometry"),
|
||||||
|
io.Int.Input("batch_index", default=0, min=0, max=4096,
|
||||||
|
tooltip="Which frame of a batched DA3_GEOMETRY to convert."),
|
||||||
|
io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01,
|
||||||
|
tooltip="Exclude pixels with raw confidence below this value. "
|
||||||
|
"Ignored when the geometry has no confidence map."),
|
||||||
|
io.Boolean.Input("use_sky_mask", default=True,
|
||||||
|
tooltip="Exclude sky-probability pixels (sky >= 0.5). "
|
||||||
|
"Ignored when the geometry has no sky map."),
|
||||||
|
io.Int.Input("downsample", default=1, min=1, max=16,
|
||||||
|
tooltip="Take every Nth pixel (1 = full resolution). "
|
||||||
|
"Higher values give fewer points and faster processing."),
|
||||||
|
],
|
||||||
|
# TODO: add a proper PointCloud output type
|
||||||
|
outputs=[DA3PointCloud.Output(display_name="point_cloud")],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, da3_geometry, batch_index, confidence_threshold,
|
||||||
|
use_sky_mask, downsample) -> io.NodeOutput:
|
||||||
|
depth_all = da3_geometry["depth"] # (B, H, W)
|
||||||
|
B = depth_all.shape[0]
|
||||||
|
if batch_index >= B:
|
||||||
|
raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.")
|
||||||
|
|
||||||
|
depth = depth_all[batch_index] # (H, W)
|
||||||
|
H, W = depth.shape
|
||||||
|
|
||||||
|
K = _da3_get_K(da3_geometry, batch_index, H, W)
|
||||||
|
|
||||||
|
if downsample > 1:
|
||||||
|
depth = depth[::downsample, ::downsample].contiguous()
|
||||||
|
# Scale intrinsics to the downsampled grid.
|
||||||
|
K = K.clone()
|
||||||
|
K[0, :] /= downsample
|
||||||
|
K[1, :] /= downsample
|
||||||
|
|
||||||
|
H_ds, W_ds = depth.shape
|
||||||
|
points = _da3_unproject(depth, K) # (H_ds, W_ds, 3)
|
||||||
|
|
||||||
|
# Rebuild mask at downsampled resolution.
|
||||||
|
mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask)
|
||||||
|
if downsample > 1:
|
||||||
|
mask = mask[::downsample, ::downsample]
|
||||||
|
|
||||||
|
mask = mask & torch.isfinite(depth)
|
||||||
|
|
||||||
|
# OpenCV → glTF: flip Y and Z.
|
||||||
|
points_gltf = points.clone()
|
||||||
|
points_gltf[..., 1] *= -1.0
|
||||||
|
points_gltf[..., 2] *= -1.0
|
||||||
|
|
||||||
|
pts_flat = points_gltf.reshape(-1, 3)[mask.reshape(-1)]
|
||||||
|
|
||||||
|
colors_flat = None
|
||||||
|
if "image" in da3_geometry:
|
||||||
|
img = da3_geometry["image"][batch_index] # (H, W, 3)
|
||||||
|
if downsample > 1:
|
||||||
|
img = img[::downsample, ::downsample]
|
||||||
|
colors_flat = img.reshape(-1, 3)[mask.reshape(-1)]
|
||||||
|
|
||||||
|
conf_flat = None
|
||||||
|
if "confidence" in da3_geometry:
|
||||||
|
conf = da3_geometry["confidence"][batch_index] # (H, W)
|
||||||
|
if downsample > 1:
|
||||||
|
conf = conf[::downsample, ::downsample]
|
||||||
|
conf_flat = conf.reshape(-1)[mask.reshape(-1)]
|
||||||
|
|
||||||
|
if pts_flat.shape[0] == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"DA3GeometryToPointCloud produced zero points after filtering. "
|
||||||
|
"Try lowering confidence_threshold or disabling use_sky_mask."
|
||||||
|
)
|
||||||
|
|
||||||
|
return io.NodeOutput({
|
||||||
|
"points": pts_flat,
|
||||||
|
"colors": colors_flat,
|
||||||
|
"confidence": conf_flat,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
class DepthAnything3Extension(ComfyExtension):
|
class DepthAnything3Extension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
@ -451,6 +673,8 @@ class DepthAnything3Extension(ComfyExtension):
|
|||||||
LoadDepthAnything3Model,
|
LoadDepthAnything3Model,
|
||||||
DepthAnything3Inference,
|
DepthAnything3Inference,
|
||||||
DepthAnything3Render,
|
DepthAnything3Render,
|
||||||
|
DA3GeometryToMesh,
|
||||||
|
# DA3GeometryToPointCloud, # Keep this commented out for now until we have a proper PointCloud output type
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user