mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-29 02:17:52 +08:00
Fix confidence use in DA3ToMesh by using normalization, fix extrinsic usage.
This commit is contained in:
parent
2dd9d96d4a
commit
15c096aa16
@ -26,6 +26,7 @@ parameters conflict with the loaded model's capabilities (e.g.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
@ -52,7 +53,7 @@ DA3PointCloud = io.Custom("DA3_POINT_CLOUD")
|
||||
# "confidence": torch.Tensor (B, H, W) -- raw model confidence output (Small/Base variants only)
|
||||
#
|
||||
# Multi-view only — S = number of views; the leading 1 is the scene dimension from the model.
|
||||
# "extrinsics": torch.Tensor (1, S, 4, 4) -- world-to-camera matrices
|
||||
# "extrinsics": torch.Tensor (1, S, 3, 4) -- world-to-camera [R|t] matrices
|
||||
# "intrinsics": torch.Tensor (1, S, 3, 3) -- pixel-space intrinsics
|
||||
#
|
||||
# DA3_POINT_CLOUD is a dict:
|
||||
@ -82,7 +83,6 @@ def _da3_default_K(H: int, W: int) -> torch.Tensor:
|
||||
|
||||
def _da3_get_K(geometry: dict, b: int, H: int, W: int) -> torch.Tensor:
|
||||
"""Return pixel-space K for batch element b, falling back to a default estimate."""
|
||||
import logging
|
||||
if "intrinsics" in geometry:
|
||||
# shape (1, S, 3, 3) — leading scene dimension from the multiview head
|
||||
return geometry["intrinsics"][0, b].float()
|
||||
@ -93,14 +93,68 @@ def _da3_get_K(geometry: dict, b: int, H: int, W: int) -> torch.Tensor:
|
||||
return _da3_default_K(H, W)
|
||||
|
||||
|
||||
def _da3_get_extrinsic(geometry: dict, b: int) -> torch.Tensor | None:
    """Look up the world-to-camera extrinsic matrix for batch element ``b``.

    The ``"extrinsics"`` entry is indexed as ``[0, b]`` (leading scene axis,
    then the view axis) and the selected matrix is converted to float32.
    The accompanying DA3_GEOMETRY notes describe the entry as
    (1, S, 3, 4) [R|t] matrices.

    Args:
        geometry: DA3_GEOMETRY dict.
        b: index of the view / batch element to fetch.

    Returns:
        The float32 matrix for view ``b``, or ``None`` when the geometry has
        no ``"extrinsics"`` entry (mono mode).
    """
    if "extrinsics" in geometry:
        all_views = geometry["extrinsics"]
        return all_views[0, b].float()
    # No pose information available: callers keep camera-space coordinates.
    return None
|
||||
|
||||
|
||||
def _da3_apply_extrinsic(points_cam: torch.Tensor, E: torch.Tensor) -> torch.Tensor:
    """Map an (H, W, 3) grid of camera-space points into world space.

    ``E`` is treated as a world-to-camera rigid transform; only its upper
    ``[:3, :3]`` rotation and ``[:3, 3]`` translation are read, so both 3x4
    and 4x4 layouts work.  The camera-to-world transform is formed in closed
    form as ``[R^T | -R^T t]`` instead of calling a general matrix inverse.

    Args:
        points_cam: (H, W, 3) camera-space points.
        E: world-to-camera matrix; moved to ``points_cam``'s device and cast
            to float32 before use.

    Returns:
        (H, W, 3) world-space points.  If ``E`` contains any non-finite
        value, a warning is logged and ``points_cam`` itself is returned
        unchanged so downstream code can still proceed.
    """
    pose = E.to(points_cam.device).float()

    pose_is_valid = torch.isfinite(pose).all()
    if not pose_is_valid:
        logging.getLogger("comfy").warning(
            "DA3 extrinsic matrix contains non-finite values (pose estimation may have failed). "
            "Falling back to camera-space coordinates."
        )
        return points_cam

    rows, cols, _ = points_cam.shape

    rot_w2c = pose[:3, :3]            # (3, 3) world-to-camera rotation
    trans_w2c = pose[:3, 3]           # (3,)   world-to-camera translation

    # For an orthogonal rotation the inverse is just the transpose.
    rot_c2w = rot_w2c.T
    trans_c2w = -(rot_c2w @ trans_w2c)

    # Row-vector form: p_world = p_cam @ rot_c2w^T + trans_c2w.
    flat_cam = points_cam.reshape(-1, 3)
    flat_world = flat_cam @ rot_c2w.T + trans_c2w
    return flat_world.reshape(rows, cols, 3)
|
||||
|
||||
|
||||
def _normalize_confidence(conf: torch.Tensor) -> torch.Tensor:
    """Rescale a batch of raw confidence maps to [0, 1], one image at a time.

    Each ``conf[i]`` is min-max normalised on its own, so the spatial pattern
    of every image is kept while the output range suits both display and
    thresholding.  Raw confidence is described as the output of an
    ``exp(x) + 1`` activation (values in [1, inf)).

    Args:
        conf: batch of confidence maps, iterated along dim 0.

    Returns:
        Tensor stacked along dim 0 with each image mapped to [0, 1].  An
        image whose maximum does not exceed its minimum becomes all ones.
    """

    def _rescale(image: torch.Tensor) -> torch.Tensor:
        lo = image.min()
        hi = image.max()
        if hi > lo:
            return (image - lo) / (hi - lo)
        # Flat image: nothing to rank, treat every pixel as fully confident.
        return torch.ones_like(image)

    batch_size = conf.shape[0]
    return torch.stack([_rescale(conf[idx]) for idx in range(batch_size)], dim=0)
|
||||
|
||||
|
||||
def _da3_build_mask(geometry: dict, b: int, H: int, W: int,
                    confidence_threshold: float, use_sky_mask: bool) -> torch.Tensor:
    """Build the (H, W) boolean keep-mask for batch element ``b``.

    A pixel is kept when it passes every enabled test:

    * sky: ``geometry["sky"][b] < 0.5`` (only if ``use_sky_mask`` is set and
      the geometry has a ``"sky"`` entry);
    * confidence: the per-image normalised confidence is at least
      ``confidence_threshold`` (only if the geometry has a ``"confidence"``
      entry and the threshold is greater than 0).

    The threshold is compared against the normalised [0, 1] confidence only.
    Comparing it against the raw confidence is not meaningful because the
    raw values and the threshold live on different scales, and it would also
    run at threshold 0, which is meant to keep every pixel.

    Args:
        geometry: DA3_GEOMETRY dict.
        b: batch element to build the mask for.
        H: mask height.
        W: mask width.
        confidence_threshold: minimum normalised confidence; 0 disables the
            confidence test.
        use_sky_mask: whether to drop pixels classified as sky.

    Returns:
        Boolean tensor of shape (H, W); ``True`` marks pixels to keep.
    """
    # NOTE(review): the mask starts on the default device; this assumes the
    # geometry tensors live there too -- confirm against the caller.
    mask = torch.ones(H, W, dtype=torch.bool)

    if use_sky_mask and "sky" in geometry:
        mask = mask & (geometry["sky"][b] < 0.5)

    if "confidence" in geometry and confidence_threshold > 0.0:
        # Slice with b:b+1 to keep the batch axis the normaliser iterates over.
        conf_norm = _normalize_confidence(geometry["confidence"][b:b + 1])[0]
        mask = mask & (conf_norm >= confidence_threshold)

    return mask
|
||||
|
||||
|
||||
@ -444,7 +498,7 @@ class DepthAnything3Render(io.ComfyNode):
|
||||
elif output_val == "confidence":
|
||||
if "confidence" not in geometry:
|
||||
raise ValueError("geometry has no confidence output; run with DA3-Small or DA3-Base.")
|
||||
result = cls._normalize_confidence(geometry["confidence"])
|
||||
result = _normalize_confidence(geometry["confidence"])
|
||||
result = result.unsqueeze(-1).expand(*result.shape, 3).contiguous()
|
||||
|
||||
else:
|
||||
@ -473,25 +527,7 @@ class DepthAnything3Render(io.ComfyNode):
|
||||
out = out.clamp(0.0, 1.0)
|
||||
return out.contiguous()
|
||||
|
||||
@staticmethod
def _normalize_confidence(conf: torch.Tensor) -> torch.Tensor:
    """Rescale raw confidence to [0, 1] independently for every image.

    Raw confidence is described as coming from an ``exp(x) + 1`` activation,
    i.e. values in [1, inf).  Min-max normalisation per image keeps the
    spatial pattern (higher confidence = brighter) while yielding values
    usable as a mask in [0, 1].

    Args:
        conf: batch of confidence maps, iterated along dim 0.

    Returns:
        Tensor stacked along dim 0; an image whose maximum does not exceed
        its minimum is returned as all ones.
    """
    per_image = [
        (conf[idx], conf[idx].min(), conf[idx].max())
        for idx in range(conf.shape[0])
    ]
    return torch.stack(
        [
            (image - lo) / (hi - lo) if hi > lo else torch.ones_like(image)
            for image, lo, hi in per_image
        ],
        dim=0,
    )
|
||||
|
||||
|
||||
class DA3GeometryToMesh(io.ComfyNode):
|
||||
@ -515,7 +551,7 @@ class DA3GeometryToMesh(io.ComfyNode):
|
||||
io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01,
|
||||
tooltip="Drop triangles whose 3×3 depth span exceeds this fraction. 0 = off."),
|
||||
io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01,
|
||||
tooltip="Exclude pixels with raw confidence below this value. "
|
||||
tooltip="Exclude pixels whose per-image normalised confidence is below this value (0 = keep all, 1 = keep only the single most confident pixel). "
|
||||
"Ignored when the geometry has no confidence map (Mono/Metric models)."),
|
||||
io.Boolean.Input("use_sky_mask", default=True,
|
||||
tooltip="Exclude sky-probability pixels (sky >= 0.5) from the mesh. "
|
||||
@ -537,11 +573,34 @@ class DA3GeometryToMesh(io.ComfyNode):
|
||||
depth = depth_all[batch_index] # (H, W)
|
||||
H, W = depth.shape
|
||||
|
||||
# NaN/inf depth would propagate silently through unproject and produce an
|
||||
# empty mesh; replace them with 0 here so those pixels are later excluded
|
||||
# by the isfinite check inside triangulate_grid_mesh.
|
||||
depth = depth.clone()
|
||||
n_bad = (~torch.isfinite(depth)).sum().item()
|
||||
if n_bad:
|
||||
logging.getLogger("comfy").warning(
|
||||
f"DA3GeometryToMesh: depth[{batch_index}] has {n_bad} non-finite pixels "
|
||||
f"({100*n_bad/(H*W):.1f}%) — zeroed before unproject."
|
||||
)
|
||||
depth[~torch.isfinite(depth)] = 0.0
|
||||
logging.getLogger("comfy").debug(
|
||||
f"DA3GeometryToMesh: depth[{batch_index}] range "
|
||||
f"[{depth.min():.4g}, {depth.max():.4g}], mean={depth.mean():.4g}"
|
||||
)
|
||||
|
||||
K = _da3_get_K(da3_geometry, batch_index, H, W)
|
||||
points = _da3_unproject(depth, K) # (H, W, 3) in OpenCV space
|
||||
points = _da3_unproject(depth, K) # (H, W, 3) in OpenCV camera space
|
||||
|
||||
# Apply world-to-camera inverse so multi-view frames share a common world frame.
|
||||
E = _da3_get_extrinsic(da3_geometry, batch_index)
|
||||
if E is not None:
|
||||
points = _da3_apply_extrinsic(points, E)
|
||||
|
||||
# Mask invalid pixels by setting them to inf so triangulate_grid_mesh skips them.
|
||||
mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask)
|
||||
# Also exclude pixels where depth was invalid.
|
||||
mask = mask & (depth_all[batch_index] > 0) & torch.isfinite(depth_all[batch_index])
|
||||
points = points.clone()
|
||||
points[~mask] = float('inf')
|
||||
|
||||
@ -589,7 +648,7 @@ class DA3GeometryToPointCloud(io.ComfyNode):
|
||||
io.Int.Input("batch_index", default=0, min=0, max=4096,
|
||||
tooltip="Which frame of a batched DA3_GEOMETRY to convert."),
|
||||
io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01,
|
||||
tooltip="Exclude pixels with raw confidence below this value. "
|
||||
tooltip="Exclude pixels whose per-image normalised confidence is below this value (0 = keep all). "
|
||||
"Ignored when the geometry has no confidence map."),
|
||||
io.Boolean.Input("use_sky_mask", default=True,
|
||||
tooltip="Exclude sky-probability pixels (sky >= 0.5). "
|
||||
@ -610,7 +669,8 @@ class DA3GeometryToPointCloud(io.ComfyNode):
|
||||
if batch_index >= B:
|
||||
raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.")
|
||||
|
||||
depth = depth_all[batch_index] # (H, W)
|
||||
depth = depth_all[batch_index].clone() # (H, W)
|
||||
depth[~torch.isfinite(depth)] = 0.0
|
||||
H, W = depth.shape
|
||||
|
||||
K = _da3_get_K(da3_geometry, batch_index, H, W)
|
||||
@ -623,7 +683,12 @@ class DA3GeometryToPointCloud(io.ComfyNode):
|
||||
K[1, :] /= downsample
|
||||
|
||||
H_ds, W_ds = depth.shape
|
||||
points = _da3_unproject(depth, K) # (H_ds, W_ds, 3)
|
||||
points = _da3_unproject(depth, K) # (H_ds, W_ds, 3) in OpenCV camera space
|
||||
|
||||
# Apply world-to-camera inverse so multi-view frames share a common world frame.
|
||||
E = _da3_get_extrinsic(da3_geometry, batch_index)
|
||||
if E is not None:
|
||||
points = _da3_apply_extrinsic(points, E)
|
||||
|
||||
# Rebuild mask at downsampled resolution.
|
||||
mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user