mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-21 14:37:30 +08:00
Normalize confidence output.
This commit is contained in:
parent
dfe4124f77
commit
5b1490315b
@ -72,6 +72,26 @@ class LoadDepthAnything3(io.ComfyNode):
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
def _normalize_confidence(conf: torch.Tensor) -> torch.Tensor:
|
||||
"""Map raw confidence (expp1 activaton, range [1, ∞)) to [0, 1] per image.
|
||||
|
||||
The model uses ``exp(x) + 1`` so every pixel is guaranteed to be ≥ 1.
|
||||
Min-max normalization per image preserves the spatial pattern (high
|
||||
confidence = brighter) while producing a valid mask in [0, 1].
|
||||
"""
|
||||
B = conf.shape[0]
|
||||
out = []
|
||||
for i in range(B):
|
||||
c = conf[i]
|
||||
c_min = c.min()
|
||||
c_max = c.max()
|
||||
if c_max > c_min:
|
||||
out.append((c - c_min) / (c_max - c_min))
|
||||
else:
|
||||
out.append(torch.ones_like(c))
|
||||
return torch.stack(out, dim=0)
|
||||
|
||||
|
||||
def _run_da3(model_patcher, image: torch.Tensor, process_res: int,
|
||||
method: str = "upper_bound_resize"):
|
||||
"""Run DA3 on ``(B,H,W,3)`` IMAGE; returns depth/conf/sky at original resolution (or None)."""
|
||||
@ -299,7 +319,8 @@ class DepthAnything3(io.ComfyNode):
|
||||
|
||||
out_image = norm.unsqueeze(-1).repeat(1, 1, 1, 3).clamp(0.0, 1.0).contiguous()
|
||||
sky_mask = sky if sky is not None else torch.zeros_like(depth)
|
||||
conf_mask = confidence if confidence is not None else torch.zeros_like(depth)
|
||||
conf_mask = (_normalize_confidence(confidence)
|
||||
if confidence is not None else torch.zeros_like(depth))
|
||||
camera = {"samples": torch.zeros(1, 1, 1, 1, 1), "type": "mono"}
|
||||
return io.NodeOutput(
|
||||
out_image,
|
||||
@ -336,13 +357,15 @@ class DepthAnything3(io.ComfyNode):
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
|
||||
conf = torch.zeros_like(depth)
|
||||
conf_raw = torch.zeros_like(depth)
|
||||
if "depth_conf" in out:
|
||||
conf = torch.nn.functional.interpolate(
|
||||
conf_raw = torch.nn.functional.interpolate(
|
||||
out["depth_conf"].unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
|
||||
conf_mask = _normalize_confidence(conf_raw) if conf_raw.any() else conf_raw
|
||||
|
||||
sky = torch.zeros_like(depth)
|
||||
if "sky" in out:
|
||||
sky = torch.nn.functional.interpolate(
|
||||
@ -383,12 +406,12 @@ class DepthAnything3(io.ComfyNode):
|
||||
"extrinsics": extrinsics.contiguous(),
|
||||
"intrinsics": intrinsics.contiguous(),
|
||||
"depth_raw": depth.contiguous(),
|
||||
"confidence": conf.contiguous(),
|
||||
"confidence": conf_raw.contiguous(),
|
||||
}
|
||||
return io.NodeOutput(
|
||||
depth_image,
|
||||
sky.contiguous(),
|
||||
conf.contiguous(),
|
||||
conf_mask.contiguous(),
|
||||
camera_latent,
|
||||
)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user