Skip clamping when normalization is raw and refactor reusable code into _apply_sky_clip and _depth_to_image

This commit is contained in:
Talmaj Marinc 2026-05-19 20:47:28 +02:00
parent 893ab85a8f
commit c0253de43c

View File

@ -296,29 +296,48 @@ class DepthAnything3(io.ComfyNode):
ref_view_strategy, pose_method,
)
@classmethod
def _execute_mono(cls, model, image, process_res, resize_method,
normalization, apply_sky_clip) -> io.NodeOutput:
depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method)
@staticmethod
def _apply_sky_clip(depth: torch.Tensor, sky: torch.Tensor) -> torch.Tensor:
    """Run ``da3_preprocess.apply_sky_aware_clip`` on each batch entry.

    ``depth`` and ``sky`` are indexed in lockstep along dim 0 (one
    ``depth[i]`` / ``sky[i]`` pair per call); the per-entry results are
    stacked back together along a new leading dimension.
    """
    batch = depth.shape[0]
    clipped = [
        da3_preprocess.apply_sky_aware_clip(depth[idx], sky[idx])
        for idx in range(batch)
    ]
    return torch.stack(clipped, dim=0)
if apply_sky_clip and sky is not None:
depth = torch.stack([
da3_preprocess.apply_sky_aware_clip(depth[i], sky[i])
for i in range(depth.shape[0])
], dim=0)
@staticmethod
def _depth_to_image(depth: torch.Tensor, sky_for_norm: torch.Tensor | None,
normalization: str) -> torch.Tensor:
"""Normalise depth and pack as an (N,H,W,3) image tensor.
Preserves metric units when normalization is 'raw' (no clamping).
"""
N = depth.shape[0]
if normalization == "v2_style":
norm = torch.stack([
da3_preprocess.normalize_depth_v2_style(
depth[i], sky[i] if sky is not None else None)
for i in range(depth.shape[0])
depth[i], sky_for_norm[i] if sky_for_norm is not None else None)
for i in range(N)
], dim=0)
elif normalization == "min_max":
norm = da3_preprocess.normalize_depth_min_max(depth)
else:
norm = depth
out_image = norm.unsqueeze(-1).repeat(1, 1, 1, 3).clamp(0.0, 1.0).contiguous()
# Preserve metric units when normalization is raw.
out = norm.unsqueeze(-1).repeat(1, 1, 1, 3)
if normalization != "raw":
out = out.clamp(0.0, 1.0)
return out.contiguous()
@classmethod
def _execute_mono(cls, model, image, process_res, resize_method,
normalization, apply_sky_clip) -> io.NodeOutput:
depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method)
if apply_sky_clip and sky is not None:
depth = cls._apply_sky_clip(depth, sky)
out_image = cls._depth_to_image(depth, sky, normalization)
sky_mask = sky if sky is not None else torch.zeros_like(depth)
conf_mask = (_normalize_confidence(confidence)
if confidence is not None else torch.zeros_like(depth))
@ -367,18 +386,15 @@ class DepthAnything3(io.ComfyNode):
conf_mask = _normalize_confidence(conf_raw) if conf_raw.any() else conf_raw
sky = torch.zeros_like(depth)
sky = None
if "sky" in out:
sky = torch.nn.functional.interpolate(
out["sky"].unsqueeze(1).float(), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
if apply_sky_clip and "sky" in out:
depth = torch.stack([
da3_preprocess.apply_sky_aware_clip(depth[i], sky[i])
for i in range(S)
], dim=0)
if apply_sky_clip and sky is not None:
depth = cls._apply_sky_clip(depth, sky)
if "extrinsics" in out and "intrinsics" in out:
extrinsics = out["extrinsics"].float().cpu()
@ -388,19 +404,9 @@ class DepthAnything3(io.ComfyNode):
intrinsics = torch.eye(3)[None, None].expand(1, S, 3, 3).clone()
sky_for_norm = sky if diffusion.has_sky else None
if normalization == "v2_style":
norm = torch.stack([
da3_preprocess.normalize_depth_v2_style(
depth[i], sky_for_norm[i] if sky_for_norm is not None else None)
for i in range(S)
], dim=0)
elif normalization == "min_max":
norm = da3_preprocess.normalize_depth_min_max(depth)
else:
norm = depth
depth_image = norm.unsqueeze(-1).repeat(1, 1, 1, 3).clamp(0.0, 1.0).contiguous()
depth_image = cls._depth_to_image(depth, sky_for_norm, normalization)
sky_mask = sky if sky is not None else torch.zeros_like(depth)
camera_latent = {
"samples": depth.unsqueeze(0).unsqueeze(2).contiguous(), # (1, S, 1, H, W)
"type": "da3_multiview",
@ -411,7 +417,7 @@ class DepthAnything3(io.ComfyNode):
}
return io.NodeOutput(
depth_image,
sky.contiguous(),
sky_mask.contiguous(),
conf_mask.contiguous(),
camera_latent,
)