mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-24 07:57:29 +08:00
Refactor DepthAnything3 into DepthAnything3Inference and DepthAnything3Render.
This commit is contained in:
parent
5d970fb4f1
commit
83e2321d50
@ -5,10 +5,11 @@ Adds these nodes:
|
||||
* ``LoadDepthAnything3`` -- load a DA3 ``.safetensors`` file from the
|
||||
``models/geometry_estimation/`` folder.
|
||||
* ``DepthAnything3`` -- unified depth estimation node supporting both mono and
|
||||
multi-view modes via a DynamicCombo selector. Returns a single DA3_GEOMETRY
|
||||
dict containing raw depth, normalised depth image, source image, and
|
||||
optionally sky/mask (Mono/Metric), confidence (Small/Base), and
|
||||
extrinsics/intrinsics (multi-view). Compatible with MoGe Render.
|
||||
multi-view modes via a DynamicCombo selector. Returns a DA3_GEOMETRY dict of
|
||||
raw tensors (depth, sky, confidence, camera). Feed into ``DepthAnything3Render``
|
||||
to produce display images, or directly into ``MoGeRender`` for depth / mask views.
|
||||
* ``DepthAnything3Render`` -- post-processes a DA3_GEOMETRY dict: applies optional
|
||||
sky clipping, normalises depth and confidence, and returns display images.
|
||||
|
||||
Model capability matrix
|
||||
-----------------------
|
||||
@ -41,24 +42,22 @@ DA3Geometry = io.Custom("DA3_GEOMETRY")
|
||||
# DA3_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them):
|
||||
#
|
||||
# Per-frame tensors — B = batch size in mono mode; B = S (number of views) in multi-view mode.
|
||||
# "depth": torch.Tensor (B, H, W) -- raw depth (always present)
|
||||
# "depth_image": torch.Tensor (B, H, W, 3) -- normalised depth for display (always present)
|
||||
# "depth": torch.Tensor (B, H, W) -- raw model depth (always present; matches MoGe convention)
|
||||
# "image": torch.Tensor (B, H, W, 3) -- source image in [0, 1], CPU (always present)
|
||||
# "mode": str -- "mono" or "multiview" (always present)
|
||||
# "sky": torch.Tensor (B, H, W) -- sky probability in [0, 1] (Mono/Metric variants only)
|
||||
# "mask": torch.Tensor (B, H, W) bool -- True = valid foreground / False = sky (present when sky head available)
|
||||
# "confidence": torch.Tensor (B, H, W) -- normalised depth confidence in [0, 1] (Small/Base variants only)
|
||||
# "confidence": torch.Tensor (B, H, W) -- raw model confidence output (Small/Base variants only)
|
||||
#
|
||||
# Multi-view only — S = number of views; the leading 1 is the scene dimension from the model.
|
||||
# "extrinsics": torch.Tensor (1, S, 4, 4) -- world-to-camera matrices
|
||||
# "intrinsics": torch.Tensor (1, S, 3, 3) -- pixel-space intrinsics
|
||||
|
||||
|
||||
class LoadDepthAnything3(io.ComfyNode):
|
||||
class LoadDepthAnything3Model(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadDepthAnything3",
|
||||
node_id="LoadDepthAnything3Model",
|
||||
display_name="Load Depth Anything 3",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
@ -90,26 +89,6 @@ class LoadDepthAnything3(io.ComfyNode):
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
def _normalize_confidence(conf: torch.Tensor) -> torch.Tensor:
|
||||
"""Map raw confidence (expp1 activaton, range [1, ∞)) to [0, 1] per image.
|
||||
|
||||
The model uses ``exp(x) + 1`` so every pixel is guaranteed to be ≥ 1.
|
||||
Min-max normalization per image preserves the spatial pattern (high
|
||||
confidence = brighter) while producing a valid mask in [0, 1].
|
||||
"""
|
||||
B = conf.shape[0]
|
||||
out = []
|
||||
for i in range(B):
|
||||
c = conf[i]
|
||||
c_min = c.min()
|
||||
c_max = c.max()
|
||||
if c_max > c_min:
|
||||
out.append((c - c_min) / (c_max - c_min))
|
||||
else:
|
||||
out.append(torch.ones_like(c))
|
||||
return torch.stack(out, dim=0)
|
||||
|
||||
|
||||
def _run_da3(model_patcher, image: torch.Tensor, process_res: int,
|
||||
method: str = "upper_bound_resize"):
|
||||
"""Run DA3 on ``(B,H,W,3)`` IMAGE; returns depth/conf/sky at original resolution (or None)."""
|
||||
@ -156,29 +135,16 @@ def _run_da3(model_patcher, image: torch.Tensor, process_res: int,
|
||||
return depth, confidence, sky
|
||||
|
||||
|
||||
class DepthAnything3(io.ComfyNode):
|
||||
"""Unified Depth Anything 3 node.
|
||||
class DepthAnything3Inference(io.ComfyNode):
|
||||
"""Raw Depth Anything 3 inference node.
|
||||
|
||||
Returns a single DA3_GEOMETRY dict containing all useful outputs.
|
||||
See the DA3_GEOMETRY comment block near the top of this module for the full key listing.
|
||||
Outputs a DA3_GEOMETRY dict of raw tensors. All display normalization
|
||||
(sky clipping, depth scaling, confidence normalisation) is handled by
|
||||
the companion ``DepthAnything3Render`` node.
|
||||
|
||||
Mono mode
|
||||
---------
|
||||
Runs the model on each batch element independently.
|
||||
|
||||
Multi-view mode
|
||||
---------------
|
||||
Treats every batch element as a separate view of the same scene.
|
||||
Runs all views in a single forward pass so cross-view attention can
|
||||
establish geometric consistency. Adds ``extrinsics`` and ``intrinsics``
|
||||
to the geometry dict.
|
||||
|
||||
Capability errors
|
||||
-----------------
|
||||
A ``ValueError`` is raised immediately when a parameter requires a
|
||||
model feature that is absent in the loaded checkpoint (e.g.
|
||||
``apply_sky_clip=True`` on DA3-Small/Base which has no sky head,
|
||||
or ``pose_method='cam_dec'`` on a monocular model).
|
||||
Mono mode: each batch element is processed independently.
|
||||
Multi-view mode: all frames share a single forward pass with cross-view
|
||||
attention; adds ``extrinsics`` and ``intrinsics`` to the geometry dict.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@ -206,18 +172,6 @@ class DepthAnything3(io.ComfyNode):
|
||||
"(caps memory, default). "
|
||||
"lower_bound_resize: scale so the shortest side = process_res "
|
||||
"(preserves more detail on tall/wide images, uses more memory)."),
|
||||
io.Combo.Input("normalization",
|
||||
options=["v2_style", "min_max", "raw"],
|
||||
default="v2_style",
|
||||
tooltip="How to map raw depth to [0, 1] for the output image. "
|
||||
"'v2_style': normalizes using mean and std for perceptually balanced results (default). "
|
||||
"'min_max': stretches the full depth range to [0, 1] for maximum contrast. "
|
||||
"'raw': preserves absolute values — use this to keep metric units when running DA3-Metric-Large."),
|
||||
io.Boolean.Input("apply_sky_clip", default=False,
|
||||
tooltip="Clip sky-region depth to the 99th percentile before "
|
||||
"normalisation. Requires a sky segmentation head "
|
||||
"(DA3-Mono-Large or DA3-Metric-Large). "
|
||||
"Raises an error on DA3-Small/Base."),
|
||||
io.DynamicCombo.Input("mode",
|
||||
tooltip="mono: single image or independent batch — "
|
||||
"use with any model. "
|
||||
@ -254,36 +208,24 @@ class DepthAnything3(io.ComfyNode):
|
||||
],
|
||||
outputs=[
|
||||
DA3Geometry.Output("geometry",
|
||||
tooltip="DA3_GEOMETRY dict. Always contains: "
|
||||
"'depth' (raw), 'depth_image' (normalised), 'image' (source), 'mode'. "
|
||||
"Optional: 'sky' + 'mask' (Mono/Metric variants), "
|
||||
"'confidence' (Small/Base variants), "
|
||||
"'extrinsics' + 'intrinsics' (multi-view only). "
|
||||
"Compatible with MoGe Render for depth and mask visualisation."),
|
||||
tooltip="DA3_GEOMETRY dict of raw tensors. "
|
||||
"Always: 'depth' (B,H,W), 'image', 'mode'. "
|
||||
"Optional: 'sky' + 'mask' (Mono/Metric), "
|
||||
"'confidence' raw (Small/Base), "
|
||||
"'extrinsics' + 'intrinsics' (multi-view). "
|
||||
"Feed into DepthAnything3Render or MoGeRender."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, image, process_res, resize_method, normalization,
|
||||
apply_sky_clip, mode) -> io.NodeOutput:
|
||||
diffusion = model.model.diffusion_model
|
||||
def execute(cls, model, image, process_res, resize_method, mode) -> io.NodeOutput:
|
||||
mode_val = mode["mode"] # "mono" or "multiview"
|
||||
|
||||
# Capability check for sky clip — fires in both modes.
|
||||
if apply_sky_clip and not diffusion.has_sky:
|
||||
raise ValueError(
|
||||
"apply_sky_clip=True requires a sky segmentation head, but the loaded "
|
||||
"model does not have one. Set apply_sky_clip=False, or load a model "
|
||||
"that includes a sky head (e.g. DA3-Mono-Large or DA3-Metric-Large)."
|
||||
)
|
||||
|
||||
if mode_val == "mono":
|
||||
return cls._execute_mono(
|
||||
model, image, process_res, resize_method,
|
||||
normalization, apply_sky_clip,
|
||||
)
|
||||
return cls._execute_mono(model, image, process_res, resize_method)
|
||||
|
||||
# Capability checks for multi-view pose.
|
||||
diffusion = model.model.diffusion_model
|
||||
pose_method = mode["pose_method"]
|
||||
ref_view_strategy = mode["ref_view_strategy"]
|
||||
|
||||
@ -302,70 +244,26 @@ class DepthAnything3(io.ComfyNode):
|
||||
|
||||
return cls._execute_multiview(
|
||||
model, image, process_res, resize_method,
|
||||
normalization, apply_sky_clip,
|
||||
ref_view_strategy, pose_method,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _apply_sky_clip(depth: torch.Tensor, sky: torch.Tensor) -> torch.Tensor:
|
||||
return torch.stack([
|
||||
da3_preprocess.apply_sky_aware_clip(depth[i], sky[i])
|
||||
for i in range(depth.shape[0])
|
||||
], dim=0)
|
||||
|
||||
@staticmethod
|
||||
def _depth_to_image(depth: torch.Tensor, sky_for_norm: torch.Tensor | None,
|
||||
normalization: str) -> torch.Tensor:
|
||||
"""Normalise depth and pack as an (N,H,W,3) image tensor.
|
||||
|
||||
Preserves metric units when normalization is 'raw' (no clamping).
|
||||
"""
|
||||
N = depth.shape[0]
|
||||
if normalization == "v2_style":
|
||||
norm = torch.stack([
|
||||
da3_preprocess.normalize_depth_v2_style(
|
||||
depth[i], sky_for_norm[i] if sky_for_norm is not None else None)
|
||||
for i in range(N)
|
||||
], dim=0)
|
||||
elif normalization == "min_max":
|
||||
norm = da3_preprocess.normalize_depth_min_max(depth)
|
||||
else:
|
||||
norm = depth
|
||||
|
||||
# Preserve metric units when normalization is raw.
|
||||
out = norm.unsqueeze(-1).repeat(1, 1, 1, 3)
|
||||
if normalization != "raw":
|
||||
out = out.clamp(0.0, 1.0)
|
||||
return out.contiguous()
|
||||
|
||||
@classmethod
|
||||
def _execute_mono(cls, model, image, process_res, resize_method,
|
||||
normalization, apply_sky_clip) -> io.NodeOutput:
|
||||
def _execute_mono(cls, model, image, process_res, resize_method) -> io.NodeOutput:
|
||||
depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method)
|
||||
|
||||
if apply_sky_clip and sky is not None:
|
||||
depth = cls._apply_sky_clip(depth, sky)
|
||||
|
||||
depth_image = cls._depth_to_image(depth, sky, normalization)
|
||||
|
||||
geometry: dict = {
|
||||
"depth": depth.contiguous(),
|
||||
"depth_image": depth_image,
|
||||
"image": image[..., :3].cpu(),
|
||||
"mode": "mono",
|
||||
}
|
||||
if sky is not None:
|
||||
geometry["sky"] = sky.contiguous()
|
||||
# True = valid foreground, False = sky/invalid — matches MoGe mask semantics.
|
||||
geometry["mask"] = (sky < 0.5).contiguous()
|
||||
if confidence is not None:
|
||||
geometry["confidence"] = confidence.contiguous()
|
||||
geometry["confidence_image"] = _normalize_confidence(confidence).contiguous()
|
||||
return io.NodeOutput(geometry)
|
||||
|
||||
@classmethod
|
||||
def _execute_multiview(cls, model, image, process_res, resize_method,
|
||||
normalization, apply_sky_clip,
|
||||
ref_view_strategy, pose_method) -> io.NodeOutput:
|
||||
assert image.ndim == 4 and image.shape[-1] == 3, \
|
||||
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
|
||||
@ -391,15 +289,6 @@ class DepthAnything3(io.ComfyNode):
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
|
||||
conf_raw = torch.zeros_like(depth)
|
||||
if "depth_conf" in out:
|
||||
conf_raw = torch.nn.functional.interpolate(
|
||||
out["depth_conf"].unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
|
||||
conf_mask = _normalize_confidence(conf_raw) if conf_raw.any() else conf_raw
|
||||
|
||||
sky = None
|
||||
if "sky" in out:
|
||||
sky = torch.nn.functional.interpolate(
|
||||
@ -407,9 +296,6 @@ class DepthAnything3(io.ComfyNode):
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
|
||||
if apply_sky_clip and sky is not None:
|
||||
depth = cls._apply_sky_clip(depth, sky)
|
||||
|
||||
if "extrinsics" in out and "intrinsics" in out:
|
||||
extrinsics = out["extrinsics"].float().cpu()
|
||||
intrinsics = out["intrinsics"].float().cpu()
|
||||
@ -417,12 +303,8 @@ class DepthAnything3(io.ComfyNode):
|
||||
extrinsics = torch.eye(4)[None, None].expand(1, S, 4, 4).clone()
|
||||
intrinsics = torch.eye(3)[None, None].expand(1, S, 3, 3).clone()
|
||||
|
||||
sky_for_norm = sky if diffusion.has_sky else None
|
||||
depth_image = cls._depth_to_image(depth, sky_for_norm, normalization)
|
||||
|
||||
geometry: dict = {
|
||||
"depth": depth.contiguous(),
|
||||
"depth_image": depth_image,
|
||||
"image": image[..., :3].cpu(),
|
||||
"mode": "multiview",
|
||||
"extrinsics": extrinsics.contiguous(),
|
||||
@ -430,20 +312,142 @@ class DepthAnything3(io.ComfyNode):
|
||||
}
|
||||
if sky is not None:
|
||||
geometry["sky"] = sky.contiguous()
|
||||
# True = valid foreground, False = sky/invalid — matches MoGe mask semantics.
|
||||
geometry["mask"] = (sky < 0.5).contiguous()
|
||||
if conf_raw.any():
|
||||
geometry["confidence"] = conf_mask.contiguous()
|
||||
geometry["confidence_image"] = _normalize_confidence(conf_mask).contiguous()
|
||||
if "depth_conf" in out:
|
||||
conf = torch.nn.functional.interpolate(
|
||||
out["depth_conf"].unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
geometry["confidence"] = conf.contiguous()
|
||||
return io.NodeOutput(geometry)
|
||||
|
||||
|
||||
class DepthAnything3Render(io.ComfyNode):
|
||||
"""Visualise a DA3_GEOMETRY packet as a single image.
|
||||
|
||||
Mirrors the MoGeRender interface: one ``output`` selector, one IMAGE out.
|
||||
Use multiple nodes in parallel to get depth + sky + confidence simultaneously.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DepthAnything3Render",
|
||||
display_name="Depth Anything 3 Render",
|
||||
category="image/geometry_estimation",
|
||||
description="Visualise a DA3_GEOMETRY packet. Drop multiple nodes to get different views simultaneously.",
|
||||
inputs=[
|
||||
DA3Geometry.Input("geometry"),
|
||||
io.DynamicCombo.Input("output",
|
||||
tooltip="depth: normalised depth image. "
|
||||
"sky_mask: sky probability in [0, 1] (Mono/Metric variants only). "
|
||||
"confidence: normalised depth confidence (Small/Base variants only).",
|
||||
options=[
|
||||
io.DynamicCombo.Option("depth", [
|
||||
io.Combo.Input("normalization",
|
||||
options=["v2_style", "min_max", "raw"],
|
||||
default="v2_style",
|
||||
tooltip="'v2_style': mean/std normalisation for perceptually balanced results (default). "
|
||||
"'min_max': stretches the full depth range to [0, 1] for maximum contrast. "
|
||||
"'raw': no scaling — preserves metric units for DA3-Metric-Large."),
|
||||
io.Boolean.Input("apply_sky_clip", default=False,
|
||||
tooltip="Clip sky-region depth to the 99th percentile of foreground depth before "
|
||||
"normalisation. Requires a 'sky' tensor in the geometry "
|
||||
"(DA3-Mono-Large or DA3-Metric-Large); raises an error otherwise."),
|
||||
]),
|
||||
io.DynamicCombo.Option("sky_mask", []),
|
||||
io.DynamicCombo.Option("confidence", []),
|
||||
]),
|
||||
],
|
||||
outputs=[io.Image.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, geometry, output) -> io.NodeOutput:
|
||||
output_val = output["output"]
|
||||
|
||||
if output_val == "depth":
|
||||
normalization = output["normalization"]
|
||||
apply_sky_clip = output["apply_sky_clip"]
|
||||
if apply_sky_clip and "sky" not in geometry:
|
||||
raise ValueError(
|
||||
"apply_sky_clip=True requires a sky tensor in the geometry, but none is present. "
|
||||
"Run with DA3-Mono-Large or DA3-Metric-Large, or set apply_sky_clip=False."
|
||||
)
|
||||
depth = geometry["depth"]
|
||||
sky = geometry.get("sky")
|
||||
if apply_sky_clip and sky is not None:
|
||||
depth = torch.stack([
|
||||
da3_preprocess.apply_sky_aware_clip(depth[i], sky[i])
|
||||
for i in range(depth.shape[0])
|
||||
], dim=0)
|
||||
result = cls._depth_to_image(depth, sky, normalization)
|
||||
|
||||
elif output_val == "sky_mask":
|
||||
if "sky" not in geometry:
|
||||
raise ValueError("geometry has no sky output; run with DA3-Mono-Large or DA3-Metric-Large.")
|
||||
sky = geometry["sky"]
|
||||
result = sky.unsqueeze(-1).expand(*sky.shape, 3).contiguous()
|
||||
|
||||
elif output_val == "confidence":
|
||||
if "confidence" not in geometry:
|
||||
raise ValueError("geometry has no confidence output; run with DA3-Small or DA3-Base.")
|
||||
result = cls._normalize_confidence(geometry["confidence"])
|
||||
result = result.unsqueeze(-1).expand(*result.shape, 3).contiguous()
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown output mode: {output_val}")
|
||||
|
||||
return io.NodeOutput(result.float())
|
||||
|
||||
@staticmethod
|
||||
def _depth_to_image(depth: torch.Tensor, sky_for_norm: torch.Tensor | None,
|
||||
normalization: str) -> torch.Tensor:
|
||||
"""Normalise depth and pack as an (B,H,W,3) image tensor."""
|
||||
N = depth.shape[0]
|
||||
if normalization == "v2_style":
|
||||
norm = torch.stack([
|
||||
da3_preprocess.normalize_depth_v2_style(
|
||||
depth[i], sky_for_norm[i] if sky_for_norm is not None else None)
|
||||
for i in range(N)
|
||||
], dim=0)
|
||||
elif normalization == "min_max":
|
||||
norm = da3_preprocess.normalize_depth_min_max(depth)
|
||||
else:
|
||||
norm = depth
|
||||
|
||||
out = norm.unsqueeze(-1).repeat(1, 1, 1, 3)
|
||||
if normalization != "raw":
|
||||
out = out.clamp(0.0, 1.0)
|
||||
return out.contiguous()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_confidence(conf: torch.Tensor) -> torch.Tensor:
|
||||
"""Map raw confidence (expp1 activaton, range [1, ∞)) to [0, 1] per image.
|
||||
|
||||
The model uses ``exp(x) + 1`` so every pixel is guaranteed to be ≥ 1.
|
||||
Min-max normalization per image preserves the spatial pattern (high
|
||||
confidence = brighter) while producing a valid mask in [0, 1].
|
||||
"""
|
||||
B = conf.shape[0]
|
||||
out = []
|
||||
for i in range(B):
|
||||
c = conf[i]
|
||||
c_min = c.min()
|
||||
c_max = c.max()
|
||||
if c_max > c_min:
|
||||
out.append((c - c_min) / (c_max - c_min))
|
||||
else:
|
||||
out.append(torch.ones_like(c))
|
||||
return torch.stack(out, dim=0)
|
||||
|
||||
|
||||
class DepthAnything3Extension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
LoadDepthAnything3,
|
||||
DepthAnything3,
|
||||
LoadDepthAnything3Model,
|
||||
DepthAnything3Inference,
|
||||
DepthAnything3Render,
|
||||
]
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user