mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-20 22:17:31 +08:00
Fix raw normalization being clamped to [0, 1] (raw depth output is no longer clamped) and refactor reusable code into _apply_sky_clip and _depth_to_image
This commit is contained in:
parent
893ab85a8f
commit
c0253de43c
@ -296,29 +296,48 @@ class DepthAnything3(io.ComfyNode):
|
||||
ref_view_strategy, pose_method,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _execute_mono(cls, model, image, process_res, resize_method,
|
||||
normalization, apply_sky_clip) -> io.NodeOutput:
|
||||
depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method)
|
||||
@staticmethod
|
||||
def _apply_sky_clip(depth: torch.Tensor, sky: torch.Tensor) -> torch.Tensor:
|
||||
return torch.stack([
|
||||
da3_preprocess.apply_sky_aware_clip(depth[i], sky[i])
|
||||
for i in range(depth.shape[0])
|
||||
], dim=0)
|
||||
|
||||
if apply_sky_clip and sky is not None:
|
||||
depth = torch.stack([
|
||||
da3_preprocess.apply_sky_aware_clip(depth[i], sky[i])
|
||||
for i in range(depth.shape[0])
|
||||
], dim=0)
|
||||
@staticmethod
|
||||
def _depth_to_image(depth: torch.Tensor, sky_for_norm: torch.Tensor | None,
|
||||
normalization: str) -> torch.Tensor:
|
||||
"""Normalise depth and pack as an (N,H,W,3) image tensor.
|
||||
|
||||
Preserves metric units when normalization is 'raw' (no clamping).
|
||||
"""
|
||||
N = depth.shape[0]
|
||||
if normalization == "v2_style":
|
||||
norm = torch.stack([
|
||||
da3_preprocess.normalize_depth_v2_style(
|
||||
depth[i], sky[i] if sky is not None else None)
|
||||
for i in range(depth.shape[0])
|
||||
depth[i], sky_for_norm[i] if sky_for_norm is not None else None)
|
||||
for i in range(N)
|
||||
], dim=0)
|
||||
elif normalization == "min_max":
|
||||
norm = da3_preprocess.normalize_depth_min_max(depth)
|
||||
else:
|
||||
norm = depth
|
||||
|
||||
out_image = norm.unsqueeze(-1).repeat(1, 1, 1, 3).clamp(0.0, 1.0).contiguous()
|
||||
# Preserve metric units when normalization is raw.
|
||||
out = norm.unsqueeze(-1).repeat(1, 1, 1, 3)
|
||||
if normalization != "raw":
|
||||
out = out.clamp(0.0, 1.0)
|
||||
return out.contiguous()
|
||||
|
||||
@classmethod
|
||||
def _execute_mono(cls, model, image, process_res, resize_method,
|
||||
normalization, apply_sky_clip) -> io.NodeOutput:
|
||||
depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method)
|
||||
|
||||
if apply_sky_clip and sky is not None:
|
||||
depth = cls._apply_sky_clip(depth, sky)
|
||||
|
||||
out_image = cls._depth_to_image(depth, sky, normalization)
|
||||
|
||||
sky_mask = sky if sky is not None else torch.zeros_like(depth)
|
||||
conf_mask = (_normalize_confidence(confidence)
|
||||
if confidence is not None else torch.zeros_like(depth))
|
||||
@ -367,18 +386,15 @@ class DepthAnything3(io.ComfyNode):
|
||||
|
||||
conf_mask = _normalize_confidence(conf_raw) if conf_raw.any() else conf_raw
|
||||
|
||||
sky = torch.zeros_like(depth)
|
||||
sky = None
|
||||
if "sky" in out:
|
||||
sky = torch.nn.functional.interpolate(
|
||||
out["sky"].unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
|
||||
if apply_sky_clip and "sky" in out:
|
||||
depth = torch.stack([
|
||||
da3_preprocess.apply_sky_aware_clip(depth[i], sky[i])
|
||||
for i in range(S)
|
||||
], dim=0)
|
||||
if apply_sky_clip and sky is not None:
|
||||
depth = cls._apply_sky_clip(depth, sky)
|
||||
|
||||
if "extrinsics" in out and "intrinsics" in out:
|
||||
extrinsics = out["extrinsics"].float().cpu()
|
||||
@ -388,19 +404,9 @@ class DepthAnything3(io.ComfyNode):
|
||||
intrinsics = torch.eye(3)[None, None].expand(1, S, 3, 3).clone()
|
||||
|
||||
sky_for_norm = sky if diffusion.has_sky else None
|
||||
if normalization == "v2_style":
|
||||
norm = torch.stack([
|
||||
da3_preprocess.normalize_depth_v2_style(
|
||||
depth[i], sky_for_norm[i] if sky_for_norm is not None else None)
|
||||
for i in range(S)
|
||||
], dim=0)
|
||||
elif normalization == "min_max":
|
||||
norm = da3_preprocess.normalize_depth_min_max(depth)
|
||||
else:
|
||||
norm = depth
|
||||
|
||||
depth_image = norm.unsqueeze(-1).repeat(1, 1, 1, 3).clamp(0.0, 1.0).contiguous()
|
||||
depth_image = cls._depth_to_image(depth, sky_for_norm, normalization)
|
||||
|
||||
sky_mask = sky if sky is not None else torch.zeros_like(depth)
|
||||
camera_latent = {
|
||||
"samples": depth.unsqueeze(0).unsqueeze(2).contiguous(), # (1, S, 1, H, W)
|
||||
"type": "da3_multiview",
|
||||
@ -411,7 +417,7 @@ class DepthAnything3(io.ComfyNode):
|
||||
}
|
||||
return io.NodeOutput(
|
||||
depth_image,
|
||||
sky.contiguous(),
|
||||
sky_mask.contiguous(),
|
||||
conf_mask.contiguous(),
|
||||
camera_latent,
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user