mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-16 03:57:27 +08:00
399 lines
16 KiB
Python
399 lines
16 KiB
Python
"""ComfyUI nodes for Depth Anything 3.
|
|
|
|
Adds these nodes:
|
|
|
|
* ``LoadDepthAnything3`` -- load a DA3 ``.safetensors`` file from the
|
|
``models/depth_estimation/`` folder. Falls back to ``models/diffusion_models/``
|
|
so existing installations keep working.
|
|
* ``DepthAnything3Depth`` -- run depth estimation and return a normalised
|
|
depth map as a ComfyUI ``IMAGE`` (visualisation / ControlNet input).
|
|
* ``DepthAnything3DepthRaw`` -- run depth estimation and return the raw depth,
|
|
confidence and sky channels as ``MASK`` outputs.
|
|
* ``DepthAnything3MultiView`` -- multi-view path: depth + per-view extrinsics
|
|
+ intrinsics. Pose is decoded either from the camera-decoder MLP (default)
|
|
or from the auxiliary ray output via RANSAC (DA3-Small/Base only).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing_extensions import override
|
|
|
|
import torch
|
|
|
|
import comfy.model_management as mm
|
|
import comfy.sd
|
|
import folder_paths
|
|
from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess
|
|
from comfy_api.latest import ComfyExtension, io
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Loader
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
class LoadDepthAnything3(io.ComfyNode):
    """Load a Depth Anything 3 ``.safetensors`` checkpoint as a MODEL.

    The primary location is ``models/depth_estimation/``; if the selected
    filename does not resolve there, resolution falls back to
    ``models/diffusion_models/`` so existing installations that keep DA3
    weights in the legacy folder continue to work (this is the behaviour
    promised by the module docstring).
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="LoadDepthAnything3",
            display_name="Load Depth Anything 3",
            category="loaders/depth_estimation",
            inputs=[
                io.Combo.Input(
                    "model_name",
                    options=folder_paths.get_filename_list("depth_estimation"),
                ),
                io.Combo.Input(
                    "weight_dtype",
                    options=["default", "fp16", "bf16", "fp32"],
                    default="default",
                ),
            ],
            outputs=[io.Model.Output("model")],
        )

    @classmethod
    def execute(cls, model_name, weight_dtype) -> io.NodeOutput:
        """Resolve *model_name* to a file and load it with the requested dtype.

        ``weight_dtype == "default"`` leaves dtype selection to the loader.
        """
        model_options = {}
        if weight_dtype == "fp16":
            model_options["dtype"] = torch.float16
        elif weight_dtype == "bf16":
            model_options["dtype"] = torch.bfloat16
        elif weight_dtype == "fp32":
            model_options["dtype"] = torch.float32

        try:
            path = folder_paths.get_full_path_or_raise("depth_estimation", model_name)
        except FileNotFoundError:
            # Fallback documented at module level: older installs keep DA3
            # weights under models/diffusion_models/.
            path = folder_paths.get_full_path_or_raise("diffusion_models", model_name)
        model = comfy.sd.load_diffusion_model(path, model_options=model_options)
        return io.NodeOutput(model)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Inference helpers
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
def _run_da3(model_patcher, image: torch.Tensor, process_res: int,
             method: str = "upper_bound_resize"):
    """Run the DA3 network on a (B, H, W, 3) ``IMAGE`` batch.

    Returns ``(depth, confidence, sky)`` tensors at the original image
    resolution. ``confidence`` / ``sky`` may be ``None`` depending on
    the model variant (only emitted when the network produced the
    corresponding output for every sample).
    """
    assert image.ndim == 4 and image.shape[-1] == 3, \
        f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"

    B, H, W, _ = image.shape
    mm.load_model_gpu(model_patcher)
    diffusion = model_patcher.model.diffusion_model
    device = mm.get_torch_device()
    dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32

    def _to_full_res(t: torch.Tensor) -> torch.Tensor:
        # (1, h', w') network output -> (1, H, W) on CPU. Shared by the
        # depth / confidence / sky channels, which all resize identically.
        return torch.nn.functional.interpolate(
            t.unsqueeze(1).float(), size=(H, W),
            mode="bilinear", align_corners=False,
        ).squeeze(1).cpu()

    depths, confs, skies = [], [], []
    # Process one image at a time to keep peak memory predictable; DA3 is
    # an inference-only model and per-sample latency dominates anyway.
    for i in range(B):
        single = image[i:i + 1].to(device)
        x = da3_preprocess.preprocess_image(single, process_res=process_res, method=method)
        x = x.to(dtype=dtype)
        with torch.no_grad():
            out = diffusion(x)

        depths.append(_to_full_res(out["depth"]))
        if "depth_conf" in out:
            confs.append(_to_full_res(out["depth_conf"]))
        if "sky" in out:
            skies.append(_to_full_res(out["sky"]))

    depth = torch.cat(depths, dim=0)
    confidence = torch.cat(confs, dim=0) if confs else None
    sky = torch.cat(skies, dim=0) if skies else None
    return depth, confidence, sky
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Depth -> visualisation IMAGE
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
class DepthAnything3Depth(io.ComfyNode):
    """Mono depth estimation node: IMAGE in, normalised depth IMAGE out."""

    @classmethod
    def define_schema(cls):
        node_inputs = [
            io.Model.Input("model"),
            io.Image.Input("image"),
            io.Int.Input("process_res", default=504, min=140, max=2520, step=14,
                         tooltip="Longest-side target resolution (multiple of 14)."),
            io.Combo.Input("resize_method",
                           options=["upper_bound_resize", "lower_bound_resize"],
                           default="upper_bound_resize"),
            io.Combo.Input("normalization",
                           options=["v2_style", "min_max", "raw"],
                           default="v2_style",
                           tooltip="How to map raw depth -> [0, 1] image."),
            io.Boolean.Input("apply_sky_clip", default=True,
                             tooltip="(Mono/Metric only) clip sky depth to 99th percentile."),
        ]
        node_outputs = [
            io.Image.Output("depth_image"),
            io.Mask.Output("sky_mask",
                           tooltip="Sky probability (Mono/Metric variants), else zeros."),
            io.Mask.Output("confidence",
                           tooltip="Depth confidence (Small/Base/DualDPT variants), else zeros."),
        ]
        return io.Schema(
            node_id="DepthAnything3Depth",
            display_name="Depth Anything 3 (Depth)",
            category="image/depth",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, model, image, process_res, resize_method, normalization,
                apply_sky_clip) -> io.NodeOutput:
        """Run DA3, optionally sky-clip, normalise, and emit image + masks."""
        depth, conf, sky = _run_da3(model, image, process_res, method=resize_method)

        # Sky clipping is only meaningful when the variant emits a sky channel.
        if apply_sky_clip and sky is not None:
            clipped = [da3_preprocess.apply_sky_aware_clip(d, s)
                       for d, s in zip(depth, sky)]
            depth = torch.stack(clipped, dim=0)

        if normalization == "v2_style":
            per_sample = []
            for idx in range(depth.shape[0]):
                sky_slice = None if sky is None else sky[idx]
                per_sample.append(
                    da3_preprocess.normalize_depth_v2_style(depth[idx], sky_slice))
            norm = torch.stack(per_sample, dim=0)
        elif normalization == "min_max":
            norm = da3_preprocess.normalize_depth_min_max(depth)
        else:
            # "raw": pass depth through untouched (clamped below for display).
            norm = depth

        # (B, H, W) -> (B, H, W, 3) grayscale IMAGE.
        viz = norm.unsqueeze(-1).repeat(1, 1, 1, 3).clamp(0.0, 1.0).contiguous()
        zeros = torch.zeros_like(depth)
        sky_out = zeros if sky is None else sky
        conf_out = zeros if conf is None else conf
        return io.NodeOutput(viz, sky_out.contiguous(), conf_out.contiguous())
|
|
|
|
|
|
# -----------------------------------------------------------------------------
# Multi-view depth + pose estimation
# -----------------------------------------------------------------------------
|
|
|
|
|
|
class DepthAnything3MultiView(io.ComfyNode):
    """Multi-view depth + pose estimation for DA3-Small / DA3-Base / DA3-Large.

    Treats each batch element of the input ``IMAGE`` as a separate view of
    the same scene. The selected reference view is auto-chosen by the
    backbone via ``ref_view_strategy`` (when at least 3 views are
    supplied), unless camera extrinsics are provided -- in which case the
    geometry is pinned by the user and no reordering is done.

    Output structure:

    * ``depth_image`` -- per-view normalised depth as a stacked ``IMAGE``
      batch (one frame per view, original input order).
    * ``confidence`` / ``sky`` -- per-view masks (zero when the variant
      does not produce them).
    * ``camera`` -- ``LATENT`` dict with keys::

        samples: (1, S, 1, H, W) -- raw depth packed as latent
        type: "da3_multiview"
        extrinsics: (1, S, 4, 4) world-to-camera matrices
        intrinsics: (1, S, 3, 3) pixel-space intrinsics
        depth_raw: (S, H, W) un-normalised depth
        confidence: (S, H, W)
    """

    @classmethod
    def define_schema(cls):
        """Declare inputs/outputs for the multi-view node."""
        return io.Schema(
            node_id="DepthAnything3MultiView",
            display_name="Depth Anything 3 (Multi-View)",
            category="image/depth",
            inputs=[
                io.Model.Input("model"),
                io.Image.Input("image",
                    tooltip="Image batch where each frame is a view of the same scene."),
                io.Int.Input("process_res", default=504, min=140, max=2520, step=14,
                    tooltip="Longest-side target resolution (multiple of 14)."),
                io.Combo.Input("resize_method",
                    options=["upper_bound_resize", "lower_bound_resize"],
                    default="upper_bound_resize"),
                io.Combo.Input("ref_view_strategy",
                    options=["saddle_balanced", "saddle_sim_range", "first", "middle"],
                    default="saddle_balanced",
                    tooltip="Reference view selection (only applied when "
                            "S>=3 and no extrinsics are provided)."),
                io.Combo.Input("pose_method",
                    options=["cam_dec", "ray_pose"],
                    default="cam_dec",
                    tooltip="cam_dec: small MLP on the final cam token (works for "
                            "all variants with cam_dec). ray_pose: RANSAC over the "
                            "DualDPT auxiliary ray output (DA3-Small/Base only)."),
                io.Combo.Input("normalization",
                    options=["v2_style", "min_max", "raw"],
                    default="v2_style"),
            ],
            outputs=[
                io.Image.Output("depth_image"),
                io.Mask.Output("confidence"),
                io.Mask.Output("sky_mask"),
                io.Latent.Output("camera",
                    tooltip="Per-view extrinsics + intrinsics + raw depth."),
            ],
        )

    @classmethod
    def execute(cls, model, image, process_res, resize_method, ref_view_strategy,
                pose_method, normalization) -> io.NodeOutput:
        """Run one multi-view forward pass and decode depth + camera pose.

        Unlike ``_run_da3`` (one sample per forward), all S views go through
        the network together so cross-view attention sees the whole scene.
        """
        assert image.ndim == 4 and image.shape[-1] == 3, \
            f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
        S, H, W, _ = image.shape

        mm.load_model_gpu(model)
        diffusion = model.model.diffusion_model
        device = mm.get_torch_device()
        dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32

        # Stack all views as a single batch element with views axis = S.
        x = image.to(device)
        x = da3_preprocess.preprocess_image(x, process_res=process_res, method=resize_method)
        x = x.to(dtype=dtype).unsqueeze(0)  # (1, S, 3, H', W')

        use_ray_pose = (pose_method == "ray_pose")
        with torch.no_grad():
            out = diffusion(x, use_ray_pose=use_ray_pose,
                            ref_view_strategy=ref_view_strategy)

        # ``out["depth"]`` is (S, h_p, w_p); resize back to (S, H, W).
        depth_lr = out["depth"].float()
        depth = torch.nn.functional.interpolate(
            depth_lr.unsqueeze(1), size=(H, W),
            mode="bilinear", align_corners=False,
        ).squeeze(1).cpu()

        # Confidence / sky are optional model outputs; substitute zeros so
        # downstream MASK consumers always get a tensor of the right shape.
        if "depth_conf" in out:
            conf = torch.nn.functional.interpolate(
                out["depth_conf"].unsqueeze(1).float(), size=(H, W),
                mode="bilinear", align_corners=False,
            ).squeeze(1).cpu()
        else:
            conf = torch.zeros_like(depth)

        if "sky" in out:
            sky = torch.nn.functional.interpolate(
                out["sky"].unsqueeze(1).float(), size=(H, W),
                mode="bilinear", align_corners=False,
            ).squeeze(1).cpu()
        else:
            sky = torch.zeros_like(depth)

        # Pose. Defaults to identity when neither cam_dec nor ray_pose is wired up.
        if "extrinsics" in out and "intrinsics" in out:
            extrinsics = out["extrinsics"].float().cpu()
            intrinsics = out["intrinsics"].float().cpu()
        else:
            extrinsics = torch.eye(4)[None, None].expand(1, S, 4, 4).clone()
            intrinsics = torch.eye(3)[None, None].expand(1, S, 3, 3).clone()

        # Normalised depth viz per view (same path as the mono node).
        if normalization == "v2_style":
            norm = torch.stack([
                da3_preprocess.normalize_depth_v2_style(depth[i],
                                                        sky[i] if "sky" in out else None)
                for i in range(S)
            ], dim=0)
        elif normalization == "min_max":
            norm = da3_preprocess.normalize_depth_min_max(depth)
        else:
            norm = depth

        # (S, H, W) -> (S, H, W, 3) grayscale IMAGE batch.
        depth_image = norm.unsqueeze(-1).repeat(1, 1, 1, 3).clamp(0.0, 1.0).contiguous()

        camera_latent = {
            # The Latent contract requires a ``samples`` field; pack the raw
            # depth there so a downstream node still has a tensor to chain on.
            "samples": depth.unsqueeze(0).unsqueeze(2).contiguous(),  # (1, S, 1, H, W)
            "type": "da3_multiview",
            "extrinsics": extrinsics.contiguous(),
            "intrinsics": intrinsics.contiguous(),
            "depth_raw": depth.contiguous(),
            "confidence": conf.contiguous(),
        }
        return io.NodeOutput(
            depth_image,
            conf.contiguous(),
            sky.contiguous(),
            camera_latent,
        )
|
|
|
|
|
|
class DepthAnything3DepthRaw(io.ComfyNode):
    """Expose the un-normalised DA3 outputs (depth/confidence/sky) as MASKs."""

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="DepthAnything3DepthRaw",
            display_name="Depth Anything 3 (Raw Depth)",
            category="image/depth",
            inputs=[
                io.Model.Input("model"),
                io.Image.Input("image"),
                io.Int.Input("process_res", default=504, min=140, max=2520, step=14),
                io.Combo.Input("resize_method",
                               options=["upper_bound_resize", "lower_bound_resize"],
                               default="upper_bound_resize"),
            ],
            outputs=[
                io.Mask.Output("depth",
                               tooltip="Raw depth values (no normalisation, no clipping)."),
                io.Mask.Output("confidence"),
                io.Mask.Output("sky"),
            ],
        )

    @classmethod
    def execute(cls, model, image, process_res, resize_method) -> io.NodeOutput:
        """Run DA3 and pass its outputs straight through, zero-filling gaps."""
        depth, conf, sky = _run_da3(model, image, process_res, method=resize_method)
        # Variants without confidence/sky heads still emit a correctly-shaped mask.
        fallback = torch.zeros_like(depth)
        if conf is None:
            conf = fallback
        if sky is None:
            sky = fallback
        return io.NodeOutput(depth.contiguous(), conf.contiguous(), sky.contiguous())
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Extension registration
|
|
# -----------------------------------------------------------------------------
|
|
|
|
|
|
class DepthAnything3Extension(ComfyExtension):
    """Extension object that registers the Depth Anything 3 nodes with ComfyUI."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return every node class this module contributes."""
        return [
            LoadDepthAnything3,
            DepthAnything3Depth,
            DepthAnything3DepthRaw,
            DepthAnything3MultiView,
        ]
|
|
|
|
|
|
async def comfy_entrypoint() -> DepthAnything3Extension:
    """Entry point awaited by ComfyUI to obtain this module's extension."""
    extension = DepthAnything3Extension()
    return extension
|