diff --git a/comfy_extras/nodes_depth_anything_3.py b/comfy_extras/nodes_depth_anything_3.py index 9e0bee9aa..53767004c 100644 --- a/comfy_extras/nodes_depth_anything_3.py +++ b/comfy_extras/nodes_depth_anything_3.py @@ -5,13 +5,23 @@ Adds these nodes: * ``LoadDepthAnything3`` -- load a DA3 ``.safetensors`` file from the ``models/depth_estimation/`` folder. Falls back to ``models/diffusion_models/`` so existing installations keep working. -* ``DepthAnything3Depth`` -- run depth estimation and return a normalised - depth map as a ComfyUI ``IMAGE`` (visualisation / ControlNet input). -* ``DepthAnything3DepthRaw`` -- run depth estimation and return the raw depth, - confidence and sky channels as ``MASK`` outputs. -* ``DepthAnything3MultiView`` -- multi-view path: depth + per-view extrinsics - + intrinsics. Pose is decoded either from the camera-decoder MLP (default) - or from the auxiliary ray output via RANSAC (DA3-Small/Base only). +* ``DepthAnything3`` -- unified depth estimation node supporting both mono and + multi-view modes via a DynamicCombo selector. In mono mode, returns a + normalised depth image plus sky/confidence masks. In multi-view mode, + additionally returns per-view extrinsics, intrinsics and raw depth packed + as a LATENT. + +Model capability matrix +----------------------- + Variant head_type has_sky has_conf cam_dec + DA3-Small dualdpt False True yes + DA3-Base dualdpt False True yes + DA3-Mono-Large dpt True False no + DA3-Metric-Large dpt True False no (raw output is metres) + +The node raises a ``ValueError`` at execution time when the selected +parameters conflict with the loaded model's capabilities (e.g. +``apply_sky_clip=True`` on a model with no sky head). """ from __future__ import annotations @@ -26,12 +36,6 @@ import folder_paths from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess from comfy_api.latest import ComfyExtension, io - -# ----------------------------------------------------------------------------- -# Loader -# ----------------------------------------------------------------------------- - - class LoadDepthAnything3(io.ComfyNode): @classmethod def define_schema(cls): @@ -75,11 +79,11 @@ class LoadDepthAnything3(io.ComfyNode): def _run_da3(model_patcher, image: torch.Tensor, process_res: int, method: str = "upper_bound_resize"): - """Run the DA3 network on a (B, H, W, 3) ``IMAGE`` batch. + """Run the DA3 network on a (B, H, W, 3) IMAGE batch. - Returns ``(depth, confidence, sky)`` tensors with the original image - resolution. Any of ``confidence`` / ``sky`` may be ``None`` depending on - the variant. + Returns ``(depth, confidence, sky)`` tensors at the original image + resolution. ``confidence`` / ``sky`` are ``None`` when the variant does + not produce them. """ assert image.ndim == 4 and image.shape[-1] == 3, \ f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}" @@ -91,8 +95,7 @@ def _run_da3(model_patcher, image: torch.Tensor, process_res: int, dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32 depths, confs, skies = [], [], [] - # Process one image at a time to keep peak memory predictable; DA3 is - # an inference-only model and per-sample latency dominates anyway. + # Process one image at a time to keep peak memory predictable. for i in range(B): single = image[i:i + 1].to(device) x = da3_preprocess.preprocess_image(single, process_res=process_res, method=method) @@ -101,7 +104,6 @@ def _run_da3(model_patcher, image: torch.Tensor, process_res: int, out = diffusion(x) depth_lr = out["depth"] - # Resize back to the original (H, W). depth_full = torch.nn.functional.interpolate( depth_lr.unsqueeze(1).float(), size=(H, W), mode="bilinear", align_corners=False, @@ -127,21 +129,50 @@ def _run_da3(model_patcher, image: torch.Tensor, process_res: int, return depth, confidence, sky -# ----------------------------------------------------------------------------- -# Depth -> visualisation IMAGE -# ----------------------------------------------------------------------------- +class DepthAnything3(io.ComfyNode): + """Unified Depth Anything 3 node. + Mono mode + --------- + Runs the model on each batch element independently and returns a + normalised depth image together with sky and confidence masks. + + Multi-view mode + --------------- + Treats every batch element as a separate view of the same scene. + Runs all views in a single forward pass so cross-view attention can + establish geometric consistency. Additionally returns a ``LATENT`` + dict with per-view camera extrinsics, intrinsics and raw depth. + + Capability errors + ----------------- + A ``ValueError`` is raised immediately when a parameter requires a + model feature that is absent in the loaded checkpoint (e.g. + ``apply_sky_clip=True`` on DA3-Small/Base which has no sky head, + or ``pose_method='cam_dec'`` on a monocular model). + + Camera LATENT structure (multi-view only) + ----------------------------------------- + samples: (1, S, 1, H, W) -- raw depth packed as latent samples + type: "da3_multiview" + extrinsics: (1, S, 4, 4) -- world-to-camera matrices + intrinsics: (1, S, 3, 3) -- pixel-space intrinsics + depth_raw: (S, H, W) -- un-normalised depth + confidence: (S, H, W) -- per-pixel confidence (zeros if N/A) + """ -class DepthAnything3Depth(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( - node_id="DepthAnything3Depth", - display_name="Depth Anything 3 (Depth)", + node_id="DepthAnything3", + display_name="Depth Anything 3", category="image/depth", inputs=[ io.Model.Input("model"), - io.Image.Input("image"), + io.Image.Input("image", + tooltip="Single image or image batch. " + "In multi-view mode each frame is treated as " + "a separate view of the same scene."), io.Int.Input("process_res", default=504, min=140, max=2520, step=14, tooltip="Longest-side target resolution (multiple of 14)."), io.Combo.Input("resize_method", @@ -150,22 +181,93 @@ class DepthAnything3Depth(io.ComfyNode): io.Combo.Input("normalization", options=["v2_style", "min_max", "raw"], default="v2_style", - tooltip="How to map raw depth -> [0, 1] image."), + tooltip="How to map raw depth to [0, 1] for the output image. " + "'raw' preserves absolute values — use this to keep " + "metric units when running DA3-Metric-Large."), io.Boolean.Input("apply_sky_clip", default=True, - tooltip="(Mono/Metric only) clip sky depth to 99th percentile."), + tooltip="Clip sky-region depth to the 99th percentile before " + "normalisation. Requires a sky segmentation head " + "(DA3-Mono-Large or DA3-Metric-Large). " + "Raises an error on DA3-Small/Base."), + io.DynamicCombo.Input("mode", options=[ + io.DynamicCombo.Option("mono", []), + io.DynamicCombo.Option("multiview", [ + io.Combo.Input("ref_view_strategy", + options=["saddle_balanced", "saddle_sim_range", + "first", "middle"], + default="saddle_balanced", + tooltip="Reference view selection strategy (applied when " + "S >= 3 and no extrinsics are provided)."), + io.Combo.Input("pose_method", + options=["cam_dec", "ray_pose"], + default="cam_dec", + tooltip="cam_dec: small MLP on the final camera token " + "(DA3-Small/Base). " + "ray_pose: RANSAC over the DualDPT ray output " + "(DA3-Small/Base only)."), + ]), + ]), ], outputs=[ io.Image.Output("depth_image"), io.Mask.Output("sky_mask", - tooltip="Sky probability (Mono/Metric variants), else zeros."), + tooltip="Sky probability mask (Mono/Metric variants). " + "Zeros for Small/Base."), io.Mask.Output("confidence", - tooltip="Depth confidence (Small/Base/DualDPT variants), else zeros."), + tooltip="Depth confidence (Small/Base variants). " + "Zeros for Mono/Metric."), + io.Latent.Output("camera", + tooltip="Multi-view: per-view extrinsics + intrinsics + raw depth. " + "In mono mode this is an empty placeholder."), ], ) @classmethod def execute(cls, model, image, process_res, resize_method, normalization, - apply_sky_clip) -> io.NodeOutput: + apply_sky_clip, mode) -> io.NodeOutput: + diffusion = model.model.diffusion_model + mode_val = mode["mode"] # "mono" or "multiview" + + # Capability check for sky clip — fires in both modes. + if apply_sky_clip and not diffusion.has_sky: + raise ValueError( + "apply_sky_clip=True requires a sky segmentation head, but the loaded " + "model does not have one. Set apply_sky_clip=False, or load a model " + "that includes a sky head (e.g. DA3-Mono-Large or DA3-Metric-Large)." + ) + + if mode_val == "mono": + return cls._execute_mono( + model, image, process_res, resize_method, + normalization, apply_sky_clip, + ) + + # Capability checks for multi-view pose. + pose_method = mode["pose_method"] + ref_view_strategy = mode["ref_view_strategy"] + + if pose_method == "cam_dec" and diffusion.cam_dec is None: + raise ValueError( + "pose_method='cam_dec' requires a camera decoder, but the loaded " + "model does not have one. Load a model with a camera decoder " + "(e.g. DA3-Small or DA3-Base), or set pose_method='ray_pose'." + ) + if pose_method == "ray_pose" and diffusion.head_type != "dualdpt": + raise ValueError( + "pose_method='ray_pose' requires a DualDPT head, but the loaded " + "model has a DPT head. Load a model with a DualDPT head " + "(e.g. DA3-Small or DA3-Base), or set pose_method='cam_dec'." + ) + + return cls._execute_multiview( + model, image, process_res, resize_method, + normalization, apply_sky_clip, + ref_view_strategy, pose_method, + ) + + @classmethod + def _execute_mono(cls, model, image, process_res, resize_method, + normalization, apply_sky_clip) -> io.NodeOutput: depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method) if apply_sky_clip and sky is not None: @@ -176,8 +278,8 @@ class DepthAnything3Depth(io.ComfyNode): if normalization == "v2_style": norm = torch.stack([ - da3_preprocess.normalize_depth_v2_style(depth[i], - sky[i] if sky is not None else None) + da3_preprocess.normalize_depth_v2_style( + depth[i], sky[i] if sky is not None else None) for i in range(depth.shape[0]) ], dim=0) elif normalization == "min_max": @@ -185,83 +287,21 @@ class DepthAnything3Depth(io.ComfyNode): else: norm = depth - # (B, H, W) -> (B, H, W, 3) grayscale IMAGE. out_image = norm.unsqueeze(-1).repeat(1, 1, 1, 3).clamp(0.0, 1.0).contiguous() sky_mask = sky if sky is not None else torch.zeros_like(depth) conf_mask = confidence if confidence is not None else torch.zeros_like(depth) - return io.NodeOutput(out_image, sky_mask.contiguous(), conf_mask.contiguous()) - - -# ----------------------------------------------------------------------------- -# Raw depth output (useful for downstream metric work) -# ----------------------------------------------------------------------------- - - -class DepthAnything3MultiView(io.ComfyNode): - """Multi-view depth + pose estimation for DA3-Small / DA3-Base / DA3-Large. - - Treats each batch element of the input ``IMAGE`` as a separate view of - the same scene. The selected reference view is auto-chosen by the - backbone via ``ref_view_strategy`` (when at least 3 views are - supplied), unless camera extrinsics are provided -- in which case the - geometry is pinned by the user and no reordering is done. - - Output structure: - * ``depth_image`` -- per-view normalised depth as a stacked ``IMAGE`` - batch (one frame per view, original input order). - * ``confidence`` / ``sky`` -- per-view masks (zero when the variant - does not produce them). - * ``camera`` -- ``LATENT`` dict with keys:: - samples: (1, S, 1, h_p, w_p) -- raw depth packed as latent - type: "da3_multiview" - extrinsics: (1, S, 4, 4) world-to-camera matrices - intrinsics: (1, S, 3, 3) pixel-space intrinsics - depth_raw: (S, H, W) un-normalised depth - confidence: (S, H, W) - """ - - @classmethod - def define_schema(cls): - return io.Schema( - node_id="DepthAnything3MultiView", - display_name="Depth Anything 3 (Multi-View)", - category="image/depth", - inputs=[ - io.Model.Input("model"), - io.Image.Input("image", - tooltip="Image batch where each frame is a view of the same scene."), - io.Int.Input("process_res", default=504, min=140, max=2520, step=14, - tooltip="Longest-side target resolution (multiple of 14)."), - io.Combo.Input("resize_method", - options=["upper_bound_resize", "lower_bound_resize"], - default="upper_bound_resize"), - io.Combo.Input("ref_view_strategy", - options=["saddle_balanced", "saddle_sim_range", "first", "middle"], - default="saddle_balanced", - tooltip="Reference view selection (only applied when " - "S>=3 and no extrinsics are provided)."), - io.Combo.Input("pose_method", - options=["cam_dec", "ray_pose"], - default="cam_dec", - tooltip="cam_dec: small MLP on the final cam token (works for " - "all variants with cam_dec). ray_pose: RANSAC over the " - "DualDPT auxiliary ray output (DA3-Small/Base only)."), - io.Combo.Input("normalization", - options=["v2_style", "min_max", "raw"], - default="v2_style"), - ], - outputs=[ - io.Image.Output("depth_image"), - io.Mask.Output("confidence"), - io.Mask.Output("sky_mask"), - io.Latent.Output("camera", - tooltip="Per-view extrinsics + intrinsics + raw depth."), - ], + camera = {"samples": torch.zeros(1, 1, 1, 1, 1), "type": "mono"} + return io.NodeOutput( + out_image, + sky_mask.contiguous(), + conf_mask.contiguous(), + camera, ) @classmethod - def execute(cls, model, image, process_res, resize_method, ref_view_strategy, - pose_method, normalization) -> io.NodeOutput: + def _execute_multiview(cls, model, image, process_res, resize_method, + normalization, apply_sky_clip, + ref_view_strategy, pose_method) -> io.NodeOutput: assert image.ndim == 4 and image.shape[-1] == 3, \ f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}" S, H, W, _ = image.shape @@ -271,40 +311,41 @@ class DepthAnything3MultiView(io.ComfyNode): device = mm.get_torch_device() dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32 - # Stack all views as a single batch element with views axis = S. + # All views in a single forward pass: (1, S, 3, H', W'). x = image.to(device) x = da3_preprocess.preprocess_image(x, process_res=process_res, method=resize_method) - x = x.to(dtype=dtype).unsqueeze(0) # (1, S, 3, H', W') + x = x.to(dtype=dtype).unsqueeze(0) use_ray_pose = (pose_method == "ray_pose") with torch.no_grad(): out = diffusion(x, use_ray_pose=use_ray_pose, ref_view_strategy=ref_view_strategy) - # ``out["depth"]`` is (S, h_p, w_p); resize back to (S, H, W). - depth_lr = out["depth"].float() depth = torch.nn.functional.interpolate( - depth_lr.unsqueeze(1), size=(H, W), + out["depth"].float().unsqueeze(1), size=(H, W), mode="bilinear", align_corners=False, ).squeeze(1).cpu() + conf = torch.zeros_like(depth) if "depth_conf" in out: conf = torch.nn.functional.interpolate( out["depth_conf"].unsqueeze(1).float(), size=(H, W), mode="bilinear", align_corners=False, ).squeeze(1).cpu() - else: - conf = torch.zeros_like(depth) + sky = torch.zeros_like(depth) if "sky" in out: sky = torch.nn.functional.interpolate( out["sky"].unsqueeze(1).float(), size=(H, W), mode="bilinear", align_corners=False, ).squeeze(1).cpu() - else: - sky = torch.zeros_like(depth) - # Pose. Defaults to identity when neither cam_dec nor ray_pose is wired up. + if apply_sky_clip and "sky" in out: + depth = torch.stack([ + da3_preprocess.apply_sky_aware_clip(depth[i], sky[i]) + for i in range(S) + ], dim=0) + if "extrinsics" in out and "intrinsics" in out: extrinsics = out["extrinsics"].float().cpu() intrinsics = out["intrinsics"].float().cpu() @@ -312,11 +353,11 @@ class DepthAnything3MultiView(io.ComfyNode): extrinsics = torch.eye(4)[None, None].expand(1, S, 4, 4).clone() intrinsics = torch.eye(3)[None, None].expand(1, S, 3, 3).clone() - # Normalised depth viz per view (same path as the mono node). + sky_for_norm = sky if diffusion.has_sky else None if normalization == "v2_style": norm = torch.stack([ - da3_preprocess.normalize_depth_v2_style(depth[i], - sky[i] if "sky" in out else None) + da3_preprocess.normalize_depth_v2_style( + depth[i], sky_for_norm[i] if sky_for_norm is not None else None) for i in range(S) ], dim=0) elif normalization == "min_max": @@ -327,8 +368,6 @@ class DepthAnything3MultiView(io.ComfyNode): depth_image = norm.unsqueeze(-1).repeat(1, 1, 1, 3).clamp(0.0, 1.0).contiguous() camera_latent = { - # The Latent contract requires a ``samples`` field; pack the raw - # depth there so a downstream node still has a tensor to chain on. "samples": depth.unsqueeze(0).unsqueeze(2).contiguous(), # (1, S, 1, H, W) "type": "da3_multiview", "extrinsics": extrinsics.contiguous(), @@ -338,59 +377,18 @@ class DepthAnything3MultiView(io.ComfyNode): } return io.NodeOutput( depth_image, - conf.contiguous(), sky.contiguous(), + conf.contiguous(), camera_latent, ) -class DepthAnything3DepthRaw(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="DepthAnything3DepthRaw", - display_name="Depth Anything 3 (Raw Depth)", - category="image/depth", - inputs=[ - io.Model.Input("model"), - io.Image.Input("image"), - io.Int.Input("process_res", default=504, min=140, max=2520, step=14), - io.Combo.Input("resize_method", - options=["upper_bound_resize", "lower_bound_resize"], - default="upper_bound_resize"), - ], - outputs=[ - io.Mask.Output("depth", - tooltip="Raw depth values (no normalisation, no clipping)."), - io.Mask.Output("confidence"), - io.Mask.Output("sky"), - ], - ) - - @classmethod - def execute(cls, model, image, process_res, resize_method) -> io.NodeOutput: - depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method) - zeros = torch.zeros_like(depth) - return io.NodeOutput( - depth.contiguous(), - (confidence if confidence is not None else zeros).contiguous(), - (sky if sky is not None else zeros).contiguous(), - ) - - -# ----------------------------------------------------------------------------- -# Extension registration -# ----------------------------------------------------------------------------- - - class DepthAnything3Extension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ LoadDepthAnything3, - DepthAnything3Depth, - DepthAnything3DepthRaw, - DepthAnything3MultiView, + DepthAnything3, ]