# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# (synced 2026-05-14 19:17:32 +08:00)
# Input/output preprocessing helpers for Depth Anything 3.
#
# Ported from:
#   src/depth_anything_3/utils/io/input_processor.py (image normalisation)
#   src/depth_anything_3/utils/alignment.py (sky-aware depth clip)
#   src/depth_anything_3/model/da3.py::_process_mono_sky_estimation
#
# Resize: ``comfy.utils.common_upscale`` with ``upscale_method="lanczos"``.
# Upstream uses cv2 INTER_CUBIC (upscale) / INTER_AREA (downscale); a sweep
# across {bilinear, bicubic, area, lanczos, bislerp} on a 768->504 test image
# showed lanczos has the lowest max-abs-diff vs the upstream cv2 output
# (~0.13 vs 0.21-0.71 for the others), so we use it in both directions for
# simplicity. This keeps the path stateless, on-device, and free of any
# OpenCV dependency.

from __future__ import annotations

from typing import Tuple

import torch

import comfy.utils

# ViT patch size: target side lengths are rounded to multiples of this
# (see ``compute_target_size``).
PATCH_SIZE = 14

# ImageNet normalization constants used during DA3 training.
_IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
_IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])
def _round_to_patch(x: int, patch: int = PATCH_SIZE) -> int:
    """Round ``x`` to the nearest multiple of ``patch``; ties round up."""
    lower = x - (x % patch)
    upper = lower + patch
    # Prefer the upper multiple when the distance is equal.
    if upper - x <= x - lower:
        return upper
    return lower
def compute_target_size(orig_h: int, orig_w: int, process_res: int,
                        method: str = "upper_bound_resize") -> Tuple[int, int]:
    """Compute ``(target_h, target_w)`` for a single image.

    Methods:
        - "upper_bound_resize": scale the longest side to ``process_res``,
          then round each dim to the nearest multiple of 14 (upstream
          default).
        - "lower_bound_resize": scale the shortest side to ``process_res``,
          then round.

    Raises:
        ValueError: for an unrecognized ``method``.
    """
    if method == "upper_bound_resize":
        reference = max(orig_h, orig_w)
    elif method == "lower_bound_resize":
        reference = min(orig_h, orig_w)
    else:
        raise ValueError(f"Unsupported process_res_method: {method}")

    scale = process_res / float(reference)
    # Both dims are rounded to patch multiples; floor at 1 so degenerate
    # aspect ratios never collapse to zero.
    target_w = max(1, _round_to_patch(int(round(orig_w * scale))))
    target_h = max(1, _round_to_patch(int(round(orig_h * scale))))
    return target_h, target_w
def preprocess_image(
    image: torch.Tensor,
    process_res: int = 504,
    method: str = "upper_bound_resize",
) -> torch.Tensor:
    """Preprocess a ComfyUI ``IMAGE`` batch for DA3.

    Args:
        image: ``(B, H, W, 3)`` float in [0, 1] (ComfyUI ``IMAGE`` convention).
        process_res: target resolution (longest or shortest side, depending
            on ``method``).
        method: resize strategy (see ``compute_target_size``).

    Returns:
        ``(B, 3, H', W')`` tensor with H' and W' multiples of 14, normalised
        with ImageNet statistics. The tensor lives on the same device as
        ``image``.

    Raises:
        ValueError: if ``image`` is not a 4-D ``(B, H, W, 3)`` tensor.
    """
    # Validate with a real exception: ``assert`` is stripped under
    # ``python -O`` and would silently let malformed input through.
    if image.ndim != 4 or image.shape[-1] != 3:
        raise ValueError(f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}")
    _, H, W, _ = image.shape
    target_h, target_w = compute_target_size(H, W, process_res, method)

    # (B, H, W, 3) -> (B, 3, H, W)
    x = image.movedim(-1, 1).contiguous()
    if (target_h, target_w) != (H, W):
        # Upstream uses cv2 INTER_CUBIC (upscale) / INTER_AREA (downscale).
        # Lanczos in ``common_upscale`` is anti-aliased and produces the
        # closest pixel-wise match in a sweep across {bilinear, bicubic,
        # area, lanczos, bislerp}. Used in both directions for simplicity.
        x = comfy.utils.common_upscale(
            x.float(), target_w, target_h, "lanczos", "disabled",
        )
        # Lanczos can overshoot [0, 1]; clamp back into the valid range.
        x = x.clamp(0.0, 1.0)

    mean = _IMAGENET_MEAN.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
    std = _IMAGENET_STD.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
    return (x - mean) / std
# -----------------------------------------------------------------------------
# Output post-processing (sky-aware clipping for Mono/Metric variants)
# -----------------------------------------------------------------------------

def compute_non_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor:
    """Return a boolean mask that is True wherever a pixel is NOT sky.

    A pixel counts as non-sky when its sky probability is strictly below
    ``threshold``.
    """
    return torch.lt(sky_prediction, threshold)
def apply_sky_aware_clip(
    depth: torch.Tensor,
    sky: torch.Tensor,
    threshold: float = 0.3,
    quantile: float = 0.99,
) -> torch.Tensor:
    """Replicates ``_process_mono_sky_estimation`` from upstream.

    Sky pixels are overwritten with the ``quantile`` (default 99th
    percentile) depth of the non-sky region. Returns a new tensor;
    ``depth`` itself is never modified.
    """
    non_sky_mask = compute_non_sky_mask(sky, threshold=threshold)
    result = depth.clone()

    # Degenerate masks (almost entirely sky, or almost no sky): skip.
    if non_sky_mask.sum() <= 10 or (~non_sky_mask).sum() <= 10:
        return result

    candidates = depth[non_sky_mask]
    # Cap the quantile computation at 100k random samples for speed.
    if candidates.numel() > 100_000:
        pick = torch.randint(
            0, candidates.numel(), (100_000,), device=candidates.device
        )
        candidates = candidates[pick]

    result[~non_sky_mask] = torch.quantile(candidates, quantile)
    return result
def normalize_depth_v2_style(
    depth: torch.Tensor,
    sky: torch.Tensor | None = None,
    low_quantile: float = 0.01,
    high_quantile: float = 0.99,
) -> torch.Tensor:
    """V2-style normalization for ControlNet workflows.

    Computes percentile bounds over non-sky pixels (when available),
    then maps depth into [0, 1] with near = white (1.0).
    """
    # Pixels used for the percentile bounds: the non-sky region when a
    # usable sky mask is supplied, otherwise every pixel.
    valid = depth.flatten()
    if sky is not None:
        non_sky = compute_non_sky_mask(sky)
        if non_sky.any():
            valid = depth[non_sky]

    # Cap the quantile computation at 100k random samples for speed.
    sample = valid
    if sample.numel() > 100_000:
        pick = torch.randint(0, sample.numel(), (100_000,), device=sample.device)
        sample = sample[pick]

    lo = torch.quantile(sample, low_quantile)
    hi = torch.quantile(sample, high_quantile)
    span = (hi - lo).clamp(min=1e-6)
    scaled = ((depth - lo) / span).clamp(0.0, 1.0)
    # ControlNet convention: nearer pixels are brighter (1.0).
    bright = 1.0 - scaled
    if sky is not None:
        # Sky pixels become black (far / unknown).
        sky_mask = ~compute_non_sky_mask(sky)
        bright = torch.where(sky_mask, torch.zeros_like(bright), bright)
    return bright
def normalize_depth_min_max(depth: torch.Tensor) -> torch.Tensor:
    """Simple per-frame min/max normalization with the near=1.0 convention.

    Bounds are computed independently for each frame over the last two
    (spatial) dims; a 1e-6 floor on the range guards constant frames.
    """
    frame_min = depth.amin(dim=(-2, -1), keepdim=True)
    frame_max = depth.amax(dim=(-2, -1), keepdim=True)
    span = (frame_max - frame_min).clamp(min=1e-6)
    scaled = ((depth - frame_min) / span).clamp(0.0, 1.0)
    return 1.0 - scaled