From 5b1490315b04cdf3e66bd2d25441f9e63940d4e7 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Tue, 19 May 2026 15:51:41 +0200 Subject: [PATCH] Normalize confidence output. --- comfy_extras/nodes_depth_anything_3.py | 33 ++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/comfy_extras/nodes_depth_anything_3.py b/comfy_extras/nodes_depth_anything_3.py index e49bb6f55..10f9faf61 100644 --- a/comfy_extras/nodes_depth_anything_3.py +++ b/comfy_extras/nodes_depth_anything_3.py @@ -72,6 +72,26 @@ class LoadDepthAnything3(io.ComfyNode): return io.NodeOutput(model) +def _normalize_confidence(conf: torch.Tensor) -> torch.Tensor: + """Map raw confidence (expp1 activaton, range [1, ∞)) to [0, 1] per image. + + The model uses ``exp(x) + 1`` so every pixel is guaranteed to be ≥ 1. + Min-max normalization per image preserves the spatial pattern (high + confidence = brighter) while producing a valid mask in [0, 1]. + """ + B = conf.shape[0] + out = [] + for i in range(B): + c = conf[i] + c_min = c.min() + c_max = c.max() + if c_max > c_min: + out.append((c - c_min) / (c_max - c_min)) + else: + out.append(torch.ones_like(c)) + return torch.stack(out, dim=0) + + def _run_da3(model_patcher, image: torch.Tensor, process_res: int, method: str = "upper_bound_resize"): """Run DA3 on ``(B,H,W,3)`` IMAGE; returns depth/conf/sky at original resolution (or None).""" @@ -299,7 +319,8 @@ class DepthAnything3(io.ComfyNode): out_image = norm.unsqueeze(-1).repeat(1, 1, 1, 3).clamp(0.0, 1.0).contiguous() sky_mask = sky if sky is not None else torch.zeros_like(depth) - conf_mask = confidence if confidence is not None else torch.zeros_like(depth) + conf_mask = (_normalize_confidence(confidence) + if confidence is not None else torch.zeros_like(depth)) camera = {"samples": torch.zeros(1, 1, 1, 1, 1), "type": "mono"} return io.NodeOutput( out_image, @@ -336,13 +357,15 @@ class DepthAnything3(io.ComfyNode): mode="bilinear", align_corners=False, ).squeeze(1).cpu() - conf = torch.zeros_like(depth) + conf_raw = torch.zeros_like(depth) if "depth_conf" in out: - conf = torch.nn.functional.interpolate( + conf_raw = torch.nn.functional.interpolate( out["depth_conf"].unsqueeze(1).float(), size=(H, W), mode="bilinear", align_corners=False, ).squeeze(1).cpu() + conf_mask = _normalize_confidence(conf_raw) if conf_raw.any() else conf_raw + sky = torch.zeros_like(depth) if "sky" in out: sky = torch.nn.functional.interpolate( @@ -383,12 +406,12 @@ class DepthAnything3(io.ComfyNode): "extrinsics": extrinsics.contiguous(), "intrinsics": intrinsics.contiguous(), "depth_raw": depth.contiguous(), - "confidence": conf.contiguous(), + "confidence": conf_raw.contiguous(), } return io.NodeOutput( depth_image, sky.contiguous(), - conf.contiguous(), + conf_mask.contiguous(), camera_latent, )