From 83e2321d50963d734d254fc464f78684a88baee8 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Thu, 21 May 2026 20:27:44 +0200 Subject: [PATCH] Refactor DepthAnything3 into DepthAnything3Inference and DepthAnything3Render. --- comfy_extras/nodes_depth_anything_3.py | 308 +++++++++++++------------ 1 file changed, 156 insertions(+), 152 deletions(-) diff --git a/comfy_extras/nodes_depth_anything_3.py b/comfy_extras/nodes_depth_anything_3.py index 55a95a022..5b2929fdf 100644 --- a/comfy_extras/nodes_depth_anything_3.py +++ b/comfy_extras/nodes_depth_anything_3.py @@ -5,10 +5,11 @@ Adds these nodes: * ``LoadDepthAnything3`` -- load a DA3 ``.safetensors`` file from the ``models/geometry_estimation/`` folder. * ``DepthAnything3`` -- unified depth estimation node supporting both mono and - multi-view modes via a DynamicCombo selector. Returns a single DA3_GEOMETRY - dict containing raw depth, normalised depth image, source image, and - optionally sky/mask (Mono/Metric), confidence (Small/Base), and - extrinsics/intrinsics (multi-view). Compatible with MoGe Render. + multi-view modes via a DynamicCombo selector. Returns a DA3_GEOMETRY dict of + raw tensors (depth, sky, confidence, camera). Feed into ``DepthAnything3Render`` + to produce display images, or directly into ``MoGeRender`` for depth / mask views. +* ``DepthAnything3Render`` -- post-processes a DA3_GEOMETRY dict: applies optional + sky clipping, normalises depth and confidence, and returns display images. Model capability matrix ----------------------- @@ -41,24 +42,22 @@ DA3Geometry = io.Custom("DA3_GEOMETRY") # DA3_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them): # # Per-frame tensors — B = batch size in mono mode; B = S (number of views) in multi-view mode. -# "depth": torch.Tensor (B, H, W) -- raw depth (always present) -# "depth_image": torch.Tensor (B, H, W, 3) -- normalised depth for display (always present) +# "depth": torch.Tensor (B, H, W) -- raw model depth (always present; matches MoGe convention) # "image": torch.Tensor (B, H, W, 3) -- source image in [0, 1], CPU (always present) # "mode": str -- "mono" or "multiview" (always present) # "sky": torch.Tensor (B, H, W) -- sky probability in [0, 1] (Mono/Metric variants only) -# "mask": torch.Tensor (B, H, W) bool -- True = valid foreground / False = sky (present when sky head available) -# "confidence": torch.Tensor (B, H, W) -- normalised depth confidence in [0, 1] (Small/Base variants only) +# "confidence": torch.Tensor (B, H, W) -- raw model confidence output (Small/Base variants only) # # Multi-view only — S = number of views; the leading 1 is the scene dimension from the model. # "extrinsics": torch.Tensor (1, S, 4, 4) -- world-to-camera matrices # "intrinsics": torch.Tensor (1, S, 3, 3) -- pixel-space intrinsics -class LoadDepthAnything3(io.ComfyNode): +class LoadDepthAnything3Model(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( - node_id="LoadDepthAnything3", + node_id="LoadDepthAnything3Model", display_name="Load Depth Anything 3", category="loaders", inputs=[ @@ -90,26 +89,6 @@ class LoadDepthAnything3(io.ComfyNode): return io.NodeOutput(model) -def _normalize_confidence(conf: torch.Tensor) -> torch.Tensor: - """Map raw confidence (expp1 activaton, range [1, ∞)) to [0, 1] per image. - - The model uses ``exp(x) + 1`` so every pixel is guaranteed to be ≥ 1. - Min-max normalization per image preserves the spatial pattern (high - confidence = brighter) while producing a valid mask in [0, 1]. - """ - B = conf.shape[0] - out = [] - for i in range(B): - c = conf[i] - c_min = c.min() - c_max = c.max() - if c_max > c_min: - out.append((c - c_min) / (c_max - c_min)) - else: - out.append(torch.ones_like(c)) - return torch.stack(out, dim=0) - - def _run_da3(model_patcher, image: torch.Tensor, process_res: int, method: str = "upper_bound_resize"): """Run DA3 on ``(B,H,W,3)`` IMAGE; returns depth/conf/sky at original resolution (or None).""" @@ -156,29 +135,16 @@ def _run_da3(model_patcher, image: torch.Tensor, process_res: int, return depth, confidence, sky -class DepthAnything3(io.ComfyNode): - """Unified Depth Anything 3 node. +class DepthAnything3Inference(io.ComfyNode): + """Raw Depth Anything 3 inference node. - Returns a single DA3_GEOMETRY dict containing all useful outputs. - See the DA3_GEOMETRY comment block near the top of this module for the full key listing. + Outputs a DA3_GEOMETRY dict of raw tensors. All display normalization + (sky clipping, depth scaling, confidence normalisation) is handled by + the companion ``DepthAnything3Render`` node. - Mono mode - --------- - Runs the model on each batch element independently. - - Multi-view mode - --------------- - Treats every batch element as a separate view of the same scene. - Runs all views in a single forward pass so cross-view attention can - establish geometric consistency. Adds ``extrinsics`` and ``intrinsics`` - to the geometry dict. - - Capability errors - ----------------- - A ``ValueError`` is raised immediately when a parameter requires a - model feature that is absent in the loaded checkpoint (e.g. - ``apply_sky_clip=True`` on DA3-Small/Base which has no sky head, - or ``pose_method='cam_dec'`` on a monocular model). + Mono mode: each batch element is processed independently. + Multi-view mode: all frames share a single forward pass with cross-view + attention; adds ``extrinsics`` and ``intrinsics`` to the geometry dict. """ @classmethod @@ -206,18 +172,6 @@ class DepthAnything3(io.ComfyNode): "(caps memory, default). " "lower_bound_resize: scale so the shortest side = process_res " "(preserves more detail on tall/wide images, uses more memory)."), - io.Combo.Input("normalization", - options=["v2_style", "min_max", "raw"], - default="v2_style", - tooltip="How to map raw depth to [0, 1] for the output image. " - "'v2_style': normalizes using mean and std for perceptually balanced results (default). " - "'min_max': stretches the full depth range to [0, 1] for maximum contrast. " - "'raw': preserves absolute values — use this to keep metric units when running DA3-Metric-Large."), - io.Boolean.Input("apply_sky_clip", default=False, - tooltip="Clip sky-region depth to the 99th percentile before " - "normalisation. Requires a sky segmentation head " - "(DA3-Mono-Large or DA3-Metric-Large). " - "Raises an error on DA3-Small/Base."), io.DynamicCombo.Input("mode", tooltip="mono: single image or independent batch — " "use with any model. " @@ -254,36 +208,24 @@ class DepthAnything3(io.ComfyNode): ], outputs=[ DA3Geometry.Output("geometry", - tooltip="DA3_GEOMETRY dict. Always contains: " - "'depth' (raw), 'depth_image' (normalised), 'image' (source), 'mode'. " - "Optional: 'sky' + 'mask' (Mono/Metric variants), " - "'confidence' (Small/Base variants), " - "'extrinsics' + 'intrinsics' (multi-view only). " - "Compatible with MoGe Render for depth and mask visualisation."), + tooltip="DA3_GEOMETRY dict of raw tensors. " + "Always: 'depth' (B,H,W), 'image', 'mode'. " + "Optional: 'sky' + 'mask' (Mono/Metric), " + "'confidence' raw (Small/Base), " + "'extrinsics' + 'intrinsics' (multi-view). " + "Feed into DepthAnything3Render or MoGeRender."), ], ) @classmethod - def execute(cls, model, image, process_res, resize_method, normalization, - apply_sky_clip, mode) -> io.NodeOutput: - diffusion = model.model.diffusion_model + def execute(cls, model, image, process_res, resize_method, mode) -> io.NodeOutput: mode_val = mode["mode"] # "mono" or "multiview" - # Capability check for sky clip — fires in both modes. - if apply_sky_clip and not diffusion.has_sky: - raise ValueError( - "apply_sky_clip=True requires a sky segmentation head, but the loaded " - "model does not have one. Set apply_sky_clip=False, or load a model " - "that includes a sky head (e.g. DA3-Mono-Large or DA3-Metric-Large)." - ) - if mode_val == "mono": - return cls._execute_mono( - model, image, process_res, resize_method, - normalization, apply_sky_clip, - ) + return cls._execute_mono(model, image, process_res, resize_method) # Capability checks for multi-view pose. + diffusion = model.model.diffusion_model pose_method = mode["pose_method"] ref_view_strategy = mode["ref_view_strategy"] @@ -302,70 +244,26 @@ class DepthAnything3(io.ComfyNode): return cls._execute_multiview( model, image, process_res, resize_method, - normalization, apply_sky_clip, ref_view_strategy, pose_method, ) - @staticmethod - def _apply_sky_clip(depth: torch.Tensor, sky: torch.Tensor) -> torch.Tensor: - return torch.stack([ - da3_preprocess.apply_sky_aware_clip(depth[i], sky[i]) - for i in range(depth.shape[0]) - ], dim=0) - - @staticmethod - def _depth_to_image(depth: torch.Tensor, sky_for_norm: torch.Tensor | None, - normalization: str) -> torch.Tensor: - """Normalise depth and pack as an (N,H,W,3) image tensor. - - Preserves metric units when normalization is 'raw' (no clamping). - """ - N = depth.shape[0] - if normalization == "v2_style": - norm = torch.stack([ - da3_preprocess.normalize_depth_v2_style( - depth[i], sky_for_norm[i] if sky_for_norm is not None else None) - for i in range(N) - ], dim=0) - elif normalization == "min_max": - norm = da3_preprocess.normalize_depth_min_max(depth) - else: - norm = depth - - # Preserve metric units when normalization is raw. - out = norm.unsqueeze(-1).repeat(1, 1, 1, 3) - if normalization != "raw": - out = out.clamp(0.0, 1.0) - return out.contiguous() - @classmethod - def _execute_mono(cls, model, image, process_res, resize_method, - normalization, apply_sky_clip) -> io.NodeOutput: + def _execute_mono(cls, model, image, process_res, resize_method) -> io.NodeOutput: depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method) - if apply_sky_clip and sky is not None: - depth = cls._apply_sky_clip(depth, sky) - - depth_image = cls._depth_to_image(depth, sky, normalization) - geometry: dict = { "depth": depth.contiguous(), - "depth_image": depth_image, "image": image[..., :3].cpu(), "mode": "mono", } if sky is not None: geometry["sky"] = sky.contiguous() - # True = valid foreground, False = sky/invalid — matches MoGe mask semantics. - geometry["mask"] = (sky < 0.5).contiguous() if confidence is not None: geometry["confidence"] = confidence.contiguous() - geometry["confidence_image"] = _normalize_confidence(confidence).contiguous() return io.NodeOutput(geometry) @classmethod def _execute_multiview(cls, model, image, process_res, resize_method, - normalization, apply_sky_clip, ref_view_strategy, pose_method) -> io.NodeOutput: assert image.ndim == 4 and image.shape[-1] == 3, \ f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}" @@ -391,15 +289,6 @@ class DepthAnything3(io.ComfyNode): mode="bilinear", align_corners=False, ).squeeze(1).cpu() - conf_raw = torch.zeros_like(depth) - if "depth_conf" in out: - conf_raw = torch.nn.functional.interpolate( - out["depth_conf"].unsqueeze(1).float(), size=(H, W), - mode="bilinear", align_corners=False, - ).squeeze(1).cpu() - - conf_mask = _normalize_confidence(conf_raw) if conf_raw.any() else conf_raw - sky = None if "sky" in out: sky = torch.nn.functional.interpolate( @@ -407,9 +296,6 @@ class DepthAnything3(io.ComfyNode): mode="bilinear", align_corners=False, ).squeeze(1).cpu() - if apply_sky_clip and sky is not None: - depth = cls._apply_sky_clip(depth, sky) - if "extrinsics" in out and "intrinsics" in out: extrinsics = out["extrinsics"].float().cpu() intrinsics = out["intrinsics"].float().cpu() @@ -417,12 +303,8 @@ class DepthAnything3(io.ComfyNode): extrinsics = torch.eye(4)[None, None].expand(1, S, 4, 4).clone() intrinsics = torch.eye(3)[None, None].expand(1, S, 3, 3).clone() - sky_for_norm = sky if diffusion.has_sky else None - depth_image = cls._depth_to_image(depth, sky_for_norm, normalization) - geometry: dict = { "depth": depth.contiguous(), - "depth_image": depth_image, "image": image[..., :3].cpu(), "mode": "multiview", "extrinsics": extrinsics.contiguous(), @@ -430,20 +312,142 @@ class DepthAnything3(io.ComfyNode): } if sky is not None: geometry["sky"] = sky.contiguous() - # True = valid foreground, False = sky/invalid — matches MoGe mask semantics. - geometry["mask"] = (sky < 0.5).contiguous() - if conf_raw.any(): - geometry["confidence"] = conf_mask.contiguous() - geometry["confidence_image"] = _normalize_confidence(conf_mask).contiguous() + if "depth_conf" in out: + conf = torch.nn.functional.interpolate( + out["depth_conf"].unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + geometry["confidence"] = conf.contiguous() return io.NodeOutput(geometry) +class DepthAnything3Render(io.ComfyNode): + """Visualise a DA3_GEOMETRY packet as a single image. + + Mirrors the MoGeRender interface: one ``output`` selector, one IMAGE out. + Use multiple nodes in parallel to get depth + sky + confidence simultaneously. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DepthAnything3Render", + display_name="Depth Anything 3 Render", + category="image/geometry_estimation", + description="Visualise a DA3_GEOMETRY packet. Drop multiple nodes to get different views simultaneously.", + inputs=[ + DA3Geometry.Input("geometry"), + io.DynamicCombo.Input("output", + tooltip="depth: normalised depth image. " + "sky_mask: sky probability in [0, 1] (Mono/Metric variants only). " + "confidence: normalised depth confidence (Small/Base variants only).", + options=[ + io.DynamicCombo.Option("depth", [ + io.Combo.Input("normalization", + options=["v2_style", "min_max", "raw"], + default="v2_style", + tooltip="'v2_style': mean/std normalisation for perceptually balanced results (default). " + "'min_max': stretches the full depth range to [0, 1] for maximum contrast. " + "'raw': no scaling — preserves metric units for DA3-Metric-Large."), + io.Boolean.Input("apply_sky_clip", default=False, + tooltip="Clip sky-region depth to the 99th percentile of foreground depth before " + "normalisation. Requires a 'sky' tensor in the geometry " + "(DA3-Mono-Large or DA3-Metric-Large); raises an error otherwise."), + ]), + io.DynamicCombo.Option("sky_mask", []), + io.DynamicCombo.Option("confidence", []), + ]), + ], + outputs=[io.Image.Output()], + ) + + @classmethod + def execute(cls, geometry, output) -> io.NodeOutput: + output_val = output["output"] + + if output_val == "depth": + normalization = output["normalization"] + apply_sky_clip = output["apply_sky_clip"] + if apply_sky_clip and "sky" not in geometry: + raise ValueError( + "apply_sky_clip=True requires a sky tensor in the geometry, but none is present. " + "Run with DA3-Mono-Large or DA3-Metric-Large, or set apply_sky_clip=False." + ) + depth = geometry["depth"] + sky = geometry.get("sky") + if apply_sky_clip and sky is not None: + depth = torch.stack([ + da3_preprocess.apply_sky_aware_clip(depth[i], sky[i]) + for i in range(depth.shape[0]) + ], dim=0) + result = cls._depth_to_image(depth, sky, normalization) + + elif output_val == "sky_mask": + if "sky" not in geometry: + raise ValueError("geometry has no sky output; run with DA3-Mono-Large or DA3-Metric-Large.") + sky = geometry["sky"] + result = sky.unsqueeze(-1).expand(*sky.shape, 3).contiguous() + + elif output_val == "confidence": + if "confidence" not in geometry: + raise ValueError("geometry has no confidence output; run with DA3-Small or DA3-Base.") + result = cls._normalize_confidence(geometry["confidence"]) + result = result.unsqueeze(-1).expand(*result.shape, 3).contiguous() + + else: + raise ValueError(f"Unknown output mode: {output_val}") + + return io.NodeOutput(result.float()) + + @staticmethod + def _depth_to_image(depth: torch.Tensor, sky_for_norm: torch.Tensor | None, + normalization: str) -> torch.Tensor: + """Normalise depth and pack as an (B,H,W,3) image tensor.""" + N = depth.shape[0] + if normalization == "v2_style": + norm = torch.stack([ + da3_preprocess.normalize_depth_v2_style( + depth[i], sky_for_norm[i] if sky_for_norm is not None else None) + for i in range(N) + ], dim=0) + elif normalization == "min_max": + norm = da3_preprocess.normalize_depth_min_max(depth) + else: + norm = depth + + out = norm.unsqueeze(-1).repeat(1, 1, 1, 3) + if normalization != "raw": + out = out.clamp(0.0, 1.0) + return out.contiguous() + + @staticmethod + def _normalize_confidence(conf: torch.Tensor) -> torch.Tensor: + """Map raw confidence (expp1 activaton, range [1, ∞)) to [0, 1] per image. + + The model uses ``exp(x) + 1`` so every pixel is guaranteed to be ≥ 1. + Min-max normalization per image preserves the spatial pattern (high + confidence = brighter) while producing a valid mask in [0, 1]. + """ + B = conf.shape[0] + out = [] + for i in range(B): + c = conf[i] + c_min = c.min() + c_max = c.max() + if c_max > c_min: + out.append((c - c_min) / (c_max - c_min)) + else: + out.append(torch.ones_like(c)) + return torch.stack(out, dim=0) + + class DepthAnything3Extension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ - LoadDepthAnything3, - DepthAnything3, + LoadDepthAnything3Model, + DepthAnything3Inference, + DepthAnything3Render, ]