diff --git a/comfy/ldm/depth_anything_3/__init__.py b/comfy/ldm/depth_anything_3/__init__.py new file mode 100644 index 000000000..0bc31314e --- /dev/null +++ b/comfy/ldm/depth_anything_3/__init__.py @@ -0,0 +1,7 @@ +# Depth Anything 3 - native ComfyUI port (Apache-2.0 monocular variants only). +# +# Supported variants: +# DA3-Small, DA3-Base (vits/vitb backbone, DualDPT head) +# DA3Mono-Large, DA3Metric-Large (vitl backbone, DPT head + sky mask) +# +# Original repo: https://github.com/ByteDance-Seed/Depth-Anything-3 diff --git a/comfy/ldm/depth_anything_3/dinov2.py b/comfy/ldm/depth_anything_3/dinov2.py new file mode 100644 index 000000000..032cecda4 --- /dev/null +++ b/comfy/ldm/depth_anything_3/dinov2.py @@ -0,0 +1,497 @@ +# DINOv2 backbone for Depth Anything 3 (monocular inference path). +# +# Why not reuse ``comfy/image_encoders/dino2.py``? +# The existing ``Dinov2Model`` is a vanilla HuggingFace-style DINOv2 with a +# different state-dict layout (separate Q/K/V, ``embeddings.*`` / +# ``encoder.layer.*`` keys, ``layer_scale1.lambda1``) and no support for the +# architectural extensions DA3 adds on top of DINOv2 (RoPE, QK-norm, +# alternating local/global attention, concatenated camera token). Loading +# raw DA3 HF safetensors into ``Dinov2Model`` would require splitting +# ``attn.qkv`` weights and a large rename map for every block, and we'd +# still need to write the DA3 extensions separately. Keeping the upstream +# ``pretrained.*`` key layout here means HF weights load directly with no +# conversion step. +# +# Ported from the upstream repo at: +# src/depth_anything_3/model/dinov2/{dinov2,vision_transformer}.py +# src/depth_anything_3/model/dinov2/layers/* +# +# DA3 extensions on top of vanilla DINOv2 (only used by Small/Base variants): +# - 2D Rotary Position Embedding starting at ``rope_start`` +# - QK-norm starting at ``qknorm_start`` +# - Alternating local/global attention blocks starting at ``alt_start`` +# - Camera-conditioning token concatenated to features (``cat_token=True``), +# with a learned parameter ``camera_token`` injected at block +# ``alt_start`` when no external camera token is supplied. +# +# For the Mono/Metric variants the configuration disables all of the above +# (alt_start/qknorm_start/rope_start = -1, cat_token=False) so this module +# collapses to a vanilla DINOv2-ViT encoder. 
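+#
+# Construction sketch (hedged: the start indices below are illustrative, and
+# ``comfy.ops.disable_weight_init`` is assumed as the ``operations`` namespace,
+# as elsewhere in ComfyUI):
+#
+#   import torch
+#   import comfy.ops
+#   from comfy.ldm.depth_anything_3.dinov2 import DinoV2
+#
+#   net = DinoV2(name="vits", out_layers=[5, 7, 9, 11],
+#                alt_start=4, qknorm_start=4, rope_start=4, cat_token=True,
+#                operations=comfy.ops.disable_weight_init)
+#   feats = net(torch.randn(1, 1, 3, 252, 252))  # (B, S, 3, H, W), H/W % 14 == 0
+#   patch_tokens, cam_token = feats[-1]
+#   # cat_token=True doubles the channel width: patch_tokens is (1, 1, 324, 768)
+#   # for vits (18*18 = 324 patches, 2 * 384 channels).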
+ +from __future__ import annotations + +import math +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +# ----------------------------------------------------------------------------- +# 2D rotary position embedding +# ----------------------------------------------------------------------------- + + +class PositionGetter: + def __init__(self): + self._cache: dict[tuple[int, int], torch.Tensor] = {} + + def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor: + key = (height, width) + if key not in self._cache: + y = torch.arange(height, device=device) + x = torch.arange(width, device=device) + self._cache[key] = torch.cartesian_prod(y, x) + cached = self._cache[key] + return cached.view(1, height * width, 2).expand(batch_size, -1, -1).clone() + + +class RotaryPositionEmbedding2D(nn.Module): + def __init__(self, frequency: float = 100.0): + super().__init__() + self.base_frequency = frequency + self._freq_cache: dict = {} + + def _components(self, dim: int, seq_len: int, device, dtype): + key = (dim, seq_len, device, dtype) + if key not in self._freq_cache: + exp = torch.arange(0, dim, 2, device=device).float() / dim + inv_freq = 1.0 / (self.base_frequency ** exp) + pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + ang = torch.einsum("i,j->ij", pos, inv_freq) + ang = ang.to(dtype) + ang = torch.cat((ang, ang), dim=-1) + self._freq_cache[key] = (ang.cos().to(dtype), ang.sin().to(dtype)) + return self._freq_cache[key] + + @staticmethod + def _rotate(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] + x1, x2 = x[..., : d // 2], x[..., d // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_1d(self, tokens, positions, cos_c, sin_c): + cos = F.embedding(positions, cos_c)[:, None, :, :] + sin = F.embedding(positions, sin_c)[:, None, :, :] + return (tokens * cos) + (self._rotate(tokens) * sin) + + def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + feature_dim = tokens.size(-1) // 2 + max_pos = int(positions.max()) + 1 + cos_c, sin_c = self._components(feature_dim, max_pos, tokens.device, tokens.dtype) + v, h = tokens.chunk(2, dim=-1) + v = self._apply_1d(v, positions[..., 0], cos_c, sin_c) + h = self._apply_1d(h, positions[..., 1], cos_c, sin_c) + return torch.cat((v, h), dim=-1) + + +# ----------------------------------------------------------------------------- +# Patch embed / MLP / SwiGLU / LayerScale +# ----------------------------------------------------------------------------- + + +class PatchEmbed(nn.Module): + def __init__(self, patch_size=14, in_chans=3, embed_dim=384, + device=None, dtype=None, operations=None): + super().__init__() + self.patch_size = (patch_size, patch_size) + self.proj = operations.Conv2d( + in_chans, embed_dim, + kernel_size=self.patch_size, stride=self.patch_size, + device=device, dtype=dtype, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _, _, H, W = x.shape + ph, pw = self.patch_size + assert H % ph == 0 and W % pw == 0 + x = self.proj(x) + x = x.flatten(2).transpose(1, 2) + return x + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, bias=True, + act_layer=nn.GELU, device=None, dtype=None, operations=None): + super().__init__() + hidden_features = hidden_features or in_features + self.fc1 = operations.Linear(in_features, hidden_features, bias=bias, device=device, dtype=dtype) + self.act = act_layer() + self.fc2 = 
operations.Linear(hidden_features, in_features, bias=bias, device=device, dtype=dtype) + + def forward(self, x): + return self.fc2(self.act(self.fc1(x))) + + +class SwiGLUFFNFused(nn.Module): + """SwiGLU FFN matching upstream xformers.ops.SwiGLU layout (used for vitg).""" + + def __init__(self, in_features, hidden_features=None, bias=True, + device=None, dtype=None, operations=None): + super().__init__() + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + # NOTE: xformers SwiGLU stores w12 as a single fused Linear (in, 2*hidden); + # split-by-half at forward time. We don't currently need this for the + # Apache-2.0 variants but keep it for parity with the upstream key names. + self.w12 = operations.Linear(in_features, 2 * hidden_features, bias=bias, device=device, dtype=dtype) + self.w3 = operations.Linear(hidden_features, in_features, bias=bias, device=device, dtype=dtype) + + def forward(self, x): + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + return self.w3(F.silu(x1) * x2) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values: float = 1e-5, device=None, dtype=None): + super().__init__() + self.gamma = nn.Parameter(init_values * torch.ones(dim, device=device, dtype=dtype)) + + def forward(self, x): + return x * comfy_cast(self.gamma, x) + + +def comfy_cast(p: torch.Tensor, ref: torch.Tensor) -> torch.Tensor: + """Cast a parameter to match the reference tensor's device/dtype.""" + if p.device != ref.device or p.dtype != ref.dtype: + return p.to(device=ref.device, dtype=ref.dtype) + return p + + +# ----------------------------------------------------------------------------- +# Attention + Block +# ----------------------------------------------------------------------------- + + +class Attention(nn.Module): + def __init__(self, dim, num_heads: int, qkv_bias: bool = True, proj_bias: bool = True, + qk_norm: bool = False, rope: Optional[RotaryPositionEmbedding2D] = None, + device=None, dtype=None, operations=None): + super().__init__() + assert dim % num_heads == 0 + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype) + self.q_norm = operations.LayerNorm(self.head_dim, device=device, dtype=dtype) if qk_norm else nn.Identity() + self.k_norm = operations.LayerNorm(self.head_dim, device=device, dtype=dtype) if qk_norm else nn.Identity() + self.proj = operations.Linear(dim, dim, bias=proj_bias, device=device, dtype=dtype) + self.rope = rope + + def forward(self, x: torch.Tensor, pos: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q, k = self.q_norm(q), self.k_norm(k) + if self.rope is not None and pos is not None: + q = self.rope(q, pos) + k = self.rope(k, pos) + x = F.scaled_dot_product_attention( + q, k, v, + attn_mask=(attn_mask[:, None].repeat(1, self.num_heads, 1, 1) + if attn_mask is not None else None), + ) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + return x + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio: float = 4.0, + qkv_bias: bool = True, proj_bias: bool = True, ffn_bias: bool = True, + init_values: Optional[float] = 1.0, + norm_layer=nn.LayerNorm, + ffn_layer=Mlp, + qk_norm: bool = False, + rope: Optional[RotaryPositionEmbedding2D] = None, + ln_eps: float = 1e-6, 
+ device=None, dtype=None, operations=None): + super().__init__() + self.norm1 = operations.LayerNorm(dim, eps=ln_eps, device=device, dtype=dtype) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, + qk_norm=qk_norm, rope=rope, device=device, dtype=dtype, operations=operations, + ) + self.ls1 = (LayerScale(dim, init_values=init_values, device=device, dtype=dtype) + if init_values else nn.Identity()) + + self.norm2 = operations.LayerNorm(dim, eps=ln_eps, device=device, dtype=dtype) + mlp_hidden = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, hidden_features=mlp_hidden, + bias=ffn_bias, device=device, dtype=dtype, operations=operations, + ) + self.ls2 = (LayerScale(dim, init_values=init_values, device=device, dtype=dtype) + if init_values else nn.Identity()) + + def forward(self, x, pos=None, attn_mask=None): + x = x + self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask)) + x = x + self.ls2(self.mlp(self.norm2(x))) + return x + + +# ----------------------------------------------------------------------------- +# DINOv2 vision transformer +# ----------------------------------------------------------------------------- + + +_BACKBONE_PRESETS = { + "vits": dict(embed_dim=384, depth=12, num_heads=6, ffn_layer="mlp"), + "vitb": dict(embed_dim=768, depth=12, num_heads=12, ffn_layer="mlp"), + "vitl": dict(embed_dim=1024, depth=24, num_heads=16, ffn_layer="mlp"), + "vitg": dict(embed_dim=1536, depth=40, num_heads=24, ffn_layer="swiglufused"), +} + + +class DinoVisionTransformer(nn.Module): + PATCH_SIZE = 14 + + def __init__(self, + embed_dim: int, + depth: int, + num_heads: int, + ffn_layer: str = "mlp", + mlp_ratio: float = 4.0, + init_values: float = 1.0, + alt_start: int = -1, + qknorm_start: int = -1, + rope_start: int = -1, + rope_freq: float = 100.0, + cat_token: bool = True, + device=None, dtype=None, operations=None): + super().__init__() + norm_layer = nn.LayerNorm + self.embed_dim = embed_dim + self.num_heads = num_heads + self.alt_start = alt_start + self.qknorm_start = qknorm_start + self.rope_start = rope_start + self.cat_token = cat_token + self.patch_size = self.PATCH_SIZE + self.num_register_tokens = 0 + self.patch_start_idx = 1 + self.num_tokens = 1 + + self.patch_embed = PatchEmbed( + patch_size=self.PATCH_SIZE, in_chans=3, embed_dim=embed_dim, + device=device, dtype=dtype, operations=operations, + ) + # Number of patch positions for the historical 518x518 reference grid. 
+ ref_grid = 518 // self.PATCH_SIZE + num_patches = ref_grid * ref_grid + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, device=device, dtype=dtype)) + if alt_start != -1: + self.camera_token = nn.Parameter(torch.zeros(1, 2, embed_dim, device=device, dtype=dtype)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim, + device=device, dtype=dtype)) + + if rope_start != -1 and rope_freq > 0: + self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) + self.position_getter = PositionGetter() + else: + self.rope = None + self.position_getter = None + + if ffn_layer == "mlp": + ffn = Mlp + elif ffn_layer in ("swiglu", "swiglufused"): + ffn = SwiGLUFFNFused + else: + raise NotImplementedError(f"Unsupported ffn_layer: {ffn_layer}") + + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + init_values=init_values, + norm_layer=norm_layer, + ffn_layer=ffn, + qk_norm=(qknorm_start != -1 and i >= qknorm_start), + rope=(self.rope if (rope_start != -1 and i >= rope_start) else None), + device=device, dtype=dtype, operations=operations, + ) + for i in range(depth) + ]) + self.norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + pos_embed = comfy_cast(self.pos_embed, x).float() + if npatch == N and w == h: + return pos_embed + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) + assert N == M * M + # Historical 0.1 offset preserves bicubic resample compatibility with + # the original DINOv2 release; see the upstream PR for context. 
+ sx = float(w0 + 0.1) / M + sy = float(h0 + 0.1) / M + patch_pos_embed = F.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), mode="bicubic", antialias=False, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens(self, x: torch.Tensor) -> torch.Tensor: + # x: (B, S, 3, H, W) -> tokens (B, S, 1+N, C) + B, S, _, H, W = x.shape + x = rearrange(x, "b s c h w -> (b s) c h w") + x = self.patch_embed(x) + cls_token = comfy_cast(self.cls_token, x).expand(B, S, -1).reshape(B * S, 1, self.embed_dim) + x = torch.cat((cls_token, x), dim=1) + x = x + self.interpolate_pos_encoding(x, W, H) + x = rearrange(x, "(b s) n c -> b s n c", b=B, s=S) + return x + + def _prepare_rope(self, B, S, H, W, device): + if self.rope is None: + return None, None + pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=device) + pos = rearrange(pos, "(b s) n c -> b s n c", b=B) + pos_nodiff = torch.zeros_like(pos) + if self.patch_start_idx > 0: + pos = pos + 1 + pos_special = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype) + pos_special = rearrange(pos_special, "(b s) n c -> b s n c", b=B) + pos = torch.cat([pos_special, pos], dim=2) + pos_nodiff = pos_nodiff + 1 + pos_nodiff = torch.cat([pos_special, pos_nodiff], dim=2) + return pos, pos_nodiff + + def _attn(self, x, blk, attn_type, pos=None, attn_mask=None): + b, s, n = x.shape[:3] + if attn_type == "local": + x = rearrange(x, "b s n c -> (b s) n c") + if pos is not None: + pos = rearrange(pos, "b s n c -> (b s) n c") + else: # "global" + x = rearrange(x, "b s n c -> b (s n) c") + if pos is not None: + pos = rearrange(pos, "b s n c -> b (s n) c") + x = blk(x, pos=pos, attn_mask=attn_mask) + if attn_type == "local": + x = rearrange(x, "(b s) n c -> b s n c", b=b, s=s) + else: + x = rearrange(x, "b (s n) c -> b s n c", b=b, s=s) + return x + + # ------------------------------------------------------------------ + # Public forward + # ------------------------------------------------------------------ + def get_intermediate_layers(self, x: torch.Tensor, out_layers: List[int], + cam_token: Optional[torch.Tensor] = None): + B, S, _, H, W = x.shape + x = self.prepare_tokens(x) + pos, pos_nodiff = self._prepare_rope(B, S, H, W, x.device) + + outputs = [] + local_x = x + + for i, blk in enumerate(self.blocks): + if self.rope is not None and i >= self.rope_start: + g_pos, l_pos = pos_nodiff, pos + else: + g_pos, l_pos = None, None + + if self.alt_start != -1 and i == self.alt_start: + # Inject camera token at the alt-start boundary. + if cam_token is not None: + inj = cam_token + else: + ct = comfy_cast(self.camera_token, x) + ref_token = ct[:, :1].expand(B, -1, -1) + src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1) + inj = torch.cat([ref_token, src_token], dim=1) + x = x.clone() + x[:, :, 0] = inj + + if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1): + x = self._attn(x, blk, "global", pos=g_pos) + else: + x = self._attn(x, blk, "local", pos=l_pos) + local_x = x + + if i in out_layers: + out_x = torch.cat([local_x, x], dim=-1) if self.cat_token else x + outputs.append(out_x) + + # Apply final norm. Upstream norms only the "global" half when cat_token. 
+ normed: List[torch.Tensor] = [] + camera_tokens: List[torch.Tensor] = [] + for out_x in outputs: + # Camera/cls token slot is index 0 *before* register-token stripping. + camera_tokens.append(out_x[:, :, 0]) + if out_x.shape[-1] == self.embed_dim: + normed.append(self.norm(out_x)) + elif out_x.shape[-1] == self.embed_dim * 2: + left = out_x[..., : self.embed_dim] + right = self.norm(out_x[..., self.embed_dim :]) + normed.append(torch.cat([left, right], dim=-1)) + else: + raise ValueError(f"Unexpected token width: {out_x.shape[-1]}") + # Drop cls/cam token + register tokens from patch sequence. + normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed] + # Match upstream signature consumed by DA3 heads: + # feats[i][0] = normed patch tokens, feats[i][1] = camera/cls token. + return list(zip(normed, camera_tokens)) + + +class DinoV2(nn.Module): + """Top-level DINOv2 wrapper matching upstream key layout (``self.pretrained``).""" + + def __init__(self, name: str = "vits", + out_layers: Optional[List[int]] = None, + alt_start: int = -1, + qknorm_start: int = -1, + rope_start: int = -1, + cat_token: bool = True, + device=None, dtype=None, operations=None, **kwargs): + super().__init__() + if name not in _BACKBONE_PRESETS: + raise ValueError(f"Unknown DINOv2 backbone variant: {name!r}") + preset = _BACKBONE_PRESETS[name] + self.name = name + self.out_layers = list(out_layers) if out_layers is not None else [5, 7, 9, 11] + self.cat_token = cat_token + self.pretrained = DinoVisionTransformer( + embed_dim=preset["embed_dim"], + depth=preset["depth"], + num_heads=preset["num_heads"], + ffn_layer=preset["ffn_layer"], + alt_start=alt_start, + qknorm_start=qknorm_start, + rope_start=rope_start, + cat_token=cat_token, + device=device, dtype=dtype, operations=operations, + ) + + def forward(self, x, cam_token=None, **_unused): + return self.pretrained.get_intermediate_layers(x, self.out_layers, cam_token=cam_token) diff --git a/comfy/ldm/depth_anything_3/dpt.py b/comfy/ldm/depth_anything_3/dpt.py new file mode 100644 index 000000000..98a12a2b0 --- /dev/null +++ b/comfy/ldm/depth_anything_3/dpt.py @@ -0,0 +1,517 @@ +# DPT / DualDPT heads for Depth Anything 3. +# +# Ported from: +# src/depth_anything_3/model/dpt.py (DPT - single main head + sky head) +# src/depth_anything_3/model/dualdpt.py (DualDPT - depth + auxiliary "ray" head) +# +# In the monocular path we always discard the auxiliary "ray" output of +# DualDPT. The auxiliary branch is still constructed so that DA3 HF weights +# load cleanly without missing-key warnings. 
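+#
+# Shape sketch (hedged, illustrative sizes): a 252x252 input at patch size 14
+# yields 18*18 = 324 patch tokens per captured layer, and the head fuses four
+# such layers back to full resolution:
+#
+#   import torch
+#   import comfy.ops
+#   from comfy.ldm.depth_anything_3.dpt import DPT
+#
+#   head = DPT(dim_in=1024, output_dim=1, use_sky_head=True,
+#              operations=comfy.ops.disable_weight_init)
+#   feats = [(torch.randn(1, 1, 324, 1024), torch.randn(1, 1, 1024))
+#            for _ in range(4)]  # (patch tokens, cls/camera token) per layer
+#   out = head(feats, H=252, W=252)
+#   # out["depth"] and out["sky"] are both (B, S, H, W) = (1, 1, 252, 252).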
+ +from __future__ import annotations + +from typing import List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# ----------------------------------------------------------------------------- +# Helpers (matching upstream head_utils.py) +# ----------------------------------------------------------------------------- + + +class Permute(nn.Module): + def __init__(self, dims: Tuple[int, ...]): + super().__init__() + self.dims = dims + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute(*self.dims) + + +def _custom_interpolate( + x: torch.Tensor, + size: Optional[Tuple[int, int]] = None, + scale_factor: Optional[float] = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + if size is None: + assert scale_factor is not None + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + INT_MAX = 1610612736 + total = size[0] * size[1] * x.shape[0] * x.shape[1] + if total > INT_MAX: + chunks = torch.chunk(x, chunks=(total // INT_MAX) + 1, dim=0) + outs = [F.interpolate(c, size=size, mode=mode, align_corners=align_corners) for c in chunks] + return torch.cat(outs, dim=0).contiguous() + return F.interpolate(x, size=size, mode=mode, align_corners=align_corners) + + +def _create_uv_grid(width: int, height: int, aspect_ratio: float, + dtype, device) -> torch.Tensor: + """Normalised UV grid spanning (-x_span, -y_span)..(x_span, y_span).""" + diag_factor = (aspect_ratio ** 2 + 1.0) ** 0.5 + span_x = aspect_ratio / diag_factor + span_y = 1.0 / diag_factor + left_x = -span_x * (width - 1) / width + right_x = span_x * (width - 1) / width + top_y = -span_y * (height - 1) / height + bottom_y = span_y * (height - 1) / height + x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) + y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) + uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") + return torch.stack((uu, vv), dim=-1) # (H, W, 2) + + +def _make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100.0) -> torch.Tensor: + omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device) + omega = 1.0 / omega_0 ** (omega / (embed_dim / 2.0)) + pos = pos.reshape(-1) + out = torch.einsum("m,d->md", pos, omega) + return torch.cat([out.sin(), out.cos()], dim=1).float() + + +def _position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, + omega_0: float = 100.0) -> torch.Tensor: + H, W, _ = pos_grid.shape + pos_flat = pos_grid.reshape(-1, 2) + emb_x = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) + emb_y = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) + emb = torch.cat([emb_x, emb_y], dim=-1) + return emb.view(H, W, embed_dim) + + +def _add_pos_embed(x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: + """Stateless UV positional embedding added to a feature map (B, C, h, w).""" + pw, ph = x.shape[-1], x.shape[-2] + pe = _create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device) + pe = _position_grid_to_embed(pe, x.shape[1]) * ratio + pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1).to(dtype=x.dtype) + return x + pe + + +def _apply_activation(x: torch.Tensor, activation: str) -> torch.Tensor: + act = (activation or "linear").lower() + if act == "exp": + return torch.exp(x) + if act == "expp1": + return torch.exp(x) + 1 + if act == "expm1": + return torch.expm1(x) + if act == "relu": + 
return torch.relu(x) + if act == "sigmoid": + return torch.sigmoid(x) + if act == "softplus": + return F.softplus(x) + if act == "tanh": + return torch.tanh(x) + return x + + +# ----------------------------------------------------------------------------- +# Fusion building blocks +# ----------------------------------------------------------------------------- + + +class ResidualConvUnit(nn.Module): + def __init__(self, features: int, + device=None, dtype=None, operations=None): + super().__init__() + self.conv1 = operations.Conv2d(features, features, 3, 1, 1, bias=True, + device=device, dtype=dtype) + self.conv2 = operations.Conv2d(features, features, 3, 1, 1, bias=True, + device=device, dtype=dtype) + self.activation = nn.ReLU(inplace=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.activation(x) + out = self.conv1(out) + out = self.activation(out) + out = self.conv2(out) + return out + x + + +class FeatureFusionBlock(nn.Module): + def __init__(self, features: int, has_residual: bool = True, + align_corners: bool = True, + device=None, dtype=None, operations=None): + super().__init__() + self.align_corners = align_corners + self.has_residual = has_residual + if has_residual: + self.resConfUnit1 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations) + else: + self.resConfUnit1 = None + self.resConfUnit2 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations) + self.out_conv = operations.Conv2d(features, features, 1, 1, 0, bias=True, + device=device, dtype=dtype) + + def forward(self, *xs: torch.Tensor, size: Optional[Tuple[int, int]] = None) -> torch.Tensor: + y = xs[0] + if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None: + y = y + self.resConfUnit1(xs[1]) + y = self.resConfUnit2(y) + if size is None: + up_kwargs = {"scale_factor": 2.0} + else: + up_kwargs = {"size": size} + y = _custom_interpolate(y, **up_kwargs, mode="bilinear", + align_corners=self.align_corners) + y = self.out_conv(y) + return y + + +class _Scratch(nn.Module): + """Container that mirrors upstream ``scratch`` attribute layout.""" + + +def _make_scratch(in_shape: List[int], out_shape: int, + device=None, dtype=None, operations=None) -> _Scratch: + scratch = _Scratch() + scratch.layer1_rn = operations.Conv2d(in_shape[0], out_shape, 3, 1, 1, bias=False, + device=device, dtype=dtype) + scratch.layer2_rn = operations.Conv2d(in_shape[1], out_shape, 3, 1, 1, bias=False, + device=device, dtype=dtype) + scratch.layer3_rn = operations.Conv2d(in_shape[2], out_shape, 3, 1, 1, bias=False, + device=device, dtype=dtype) + scratch.layer4_rn = operations.Conv2d(in_shape[3], out_shape, 3, 1, 1, bias=False, + device=device, dtype=dtype) + return scratch + + +def _make_fusion_block(features: int, has_residual: bool = True, + device=None, dtype=None, operations=None) -> FeatureFusionBlock: + return FeatureFusionBlock(features, has_residual=has_residual, + align_corners=True, + device=device, dtype=dtype, operations=operations) + + +# ----------------------------------------------------------------------------- +# DPT (single head + optional sky head) -- used by DA3Mono/Metric +# ----------------------------------------------------------------------------- + + +class DPT(nn.Module): + """Single-head DPT used by DA3Mono-Large and DA3Metric-Large.""" + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 1, + activation: str = "exp", + conf_activation: str = "expp1", + features: int = 256, + out_channels: 
Sequence[int] = (256, 512, 1024, 1024), + pos_embed: bool = False, + down_ratio: int = 1, + head_name: str = "depth", + use_sky_head: bool = True, + sky_name: str = "sky", + sky_activation: str = "relu", + norm_type: str = "idt", + device=None, dtype=None, operations=None, + ): + super().__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.down_ratio = down_ratio + self.head_main = head_name + self.sky_name = sky_name + self.out_dim = output_dim + self.has_conf = output_dim > 1 + self.use_sky_head = use_sky_head + self.sky_activation = sky_activation + self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3) + + if norm_type == "layer": + self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype) + else: + self.norm = nn.Identity() + + out_channels = list(out_channels) + self.projects = nn.ModuleList([ + operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, + device=device, dtype=dtype) + for oc in out_channels + ]) + self.resize_layers = nn.ModuleList([ + operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, + device=device, dtype=dtype), + operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, + device=device, dtype=dtype), + nn.Identity(), + operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, + device=device, dtype=dtype), + ]) + + self.scratch = _make_scratch(out_channels, features, + device=device, dtype=dtype, operations=operations) + self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, + device=device, dtype=dtype, operations=operations) + + head_features_1 = features + head_features_2 = 32 + self.scratch.output_conv1 = operations.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1, + device=device, dtype=dtype, + ) + self.scratch.output_conv2 = nn.Sequential( + operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, + device=device, dtype=dtype), + nn.ReLU(inplace=False), + operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, + device=device, dtype=dtype), + ) + + if self.use_sky_head: + self.scratch.sky_output_conv2 = nn.Sequential( + operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, + device=device, dtype=dtype), + nn.ReLU(inplace=False), + operations.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0, + device=device, dtype=dtype), + ) + + def forward(self, feats: List[torch.Tensor], H: int, W: int, + patch_start_idx: int = 0, **_kwargs) -> dict: + # feats[i][0] is the patch-token tensor with shape (B, S, N_patch, C) + B, S, N, C = feats[0][0].shape + feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats] + + ph, pw = H // self.patch_size, W // self.patch_size + resized = [] + for stage_idx, take_idx in enumerate(self.intermediate_layer_idx): + x = feats_flat[take_idx][:, patch_start_idx:] + x = self.norm(x) + x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw) + x = self.projects[stage_idx](x) + if self.pos_embed: + x = 
_add_pos_embed(x, W, H) + x = self.resize_layers[stage_idx](x) + resized.append(x) + + l1_rn = self.scratch.layer1_rn(resized[0]) + l2_rn = self.scratch.layer2_rn(resized[1]) + l3_rn = self.scratch.layer3_rn(resized[2]) + l4_rn = self.scratch.layer4_rn(resized[3]) + + out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:]) + out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:]) + out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:]) + out = self.scratch.refinenet1(out, l1_rn) + + h_out = int(ph * self.patch_size / self.down_ratio) + w_out = int(pw * self.patch_size / self.down_ratio) + + fused = self.scratch.output_conv1(out) + fused = _custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True) + if self.pos_embed: + fused = _add_pos_embed(fused, W, H) + feat = fused + + main_logits = self.scratch.output_conv2(feat) + outs = {} + if self.has_conf: + fmap = main_logits.permute(0, 2, 3, 1) + pred = _apply_activation(fmap[..., :-1], self.activation) + conf = _apply_activation(fmap[..., -1], self.conf_activation) + outs[self.head_main] = pred.squeeze(-1).view(B, S, *pred.shape[1:-1]) + outs[f"{self.head_main}_conf"] = conf.view(B, S, *conf.shape[1:]) + else: + pred = _apply_activation(main_logits, self.activation) + outs[self.head_main] = pred.squeeze(1).view(B, S, *pred.shape[2:]) + + if self.use_sky_head: + sky_logits = self.scratch.sky_output_conv2(feat) + if self.sky_activation.lower() == "sigmoid": + sky = torch.sigmoid(sky_logits) + elif self.sky_activation.lower() == "relu": + sky = F.relu(sky_logits) + else: + sky = sky_logits + outs[self.sky_name] = sky.squeeze(1).view(B, S, *sky.shape[2:]) + + return outs + + +# ----------------------------------------------------------------------------- +# DualDPT (depth + auxiliary "ray" head) -- used by DA3-Small / DA3-Base +# ----------------------------------------------------------------------------- + + +class DualDPT(nn.Module): + """Two-head DPT used by DA3-Small / DA3-Base. + + The auxiliary "ray" head is constructed so that HF state-dict keys load + cleanly, but its outputs are unused on the monocular path. 
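+
+    Construction sketch (illustrative sizes; ``dim_in`` is twice the backbone
+    width because ``cat_token=True`` concatenates local and global features,
+    and ``feats`` stands for the list returned by ``DinoV2.forward``):
+
+        head = DualDPT(dim_in=768, output_dim=2,
+                       operations=comfy.ops.disable_weight_init)
+        out = head(feats, H=252, W=252)  # keys: "depth", "depth_conf"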
+ """ + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 2, + activation: str = "exp", + conf_activation: str = "expp1", + features: int = 256, + out_channels: Sequence[int] = (256, 512, 1024, 1024), + pos_embed: bool = True, + down_ratio: int = 1, + aux_pyramid_levels: int = 4, + aux_out1_conv_num: int = 5, + head_names: Tuple[str, str] = ("depth", "ray"), + device=None, dtype=None, operations=None, + ): + super().__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.down_ratio = down_ratio + self.aux_levels = aux_pyramid_levels + self.aux_out1_conv_num = aux_out1_conv_num + self.head_main, self.head_aux = head_names + self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3) + + self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype) + out_channels = list(out_channels) + self.projects = nn.ModuleList([ + operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, + device=device, dtype=dtype) + for oc in out_channels + ]) + self.resize_layers = nn.ModuleList([ + operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, + device=device, dtype=dtype), + operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, + device=device, dtype=dtype), + nn.Identity(), + operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, + device=device, dtype=dtype), + ]) + + self.scratch = _make_scratch(out_channels, features, + device=device, dtype=dtype, operations=operations) + # Main fusion chain + self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, + device=device, dtype=dtype, operations=operations) + # Auxiliary fusion chain (separate copies) + self.scratch.refinenet1_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet2_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet3_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False, + device=device, dtype=dtype, operations=operations) + + head_features_1 = features + head_features_2 = 32 + + # Main head neck + final projection + self.scratch.output_conv1 = operations.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1, + device=device, dtype=dtype, + ) + self.scratch.output_conv2 = nn.Sequential( + operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, + device=device, dtype=dtype), + nn.ReLU(inplace=False), + operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, + device=device, dtype=dtype), + ) + + # Aux pre-head per level (multi-level pyramid) + self.scratch.output_conv1_aux = nn.ModuleList([ + self._make_aux_out1_block(head_features_1, device=device, dtype=dtype, operations=operations) + for _ in range(self.aux_levels) + ]) + + # Aux final projection per level (includes LayerNorm permute path). 
+        def _make_ln_seq():
+            # Fresh LayerNorm (and Permute wrappers) per aux level. Reusing a
+            # single instance across the ModuleList would silently tie the
+            # per-level LayerNorm weights together.
+            return [Permute((0, 2, 3, 1)),
+                    operations.LayerNorm(head_features_2, device=device, dtype=dtype),
+                    Permute((0, 3, 1, 2))]
+
+        self.scratch.output_conv2_aux = nn.ModuleList([
+            nn.Sequential(
+                operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1,
+                                  device=device, dtype=dtype),
+                *_make_ln_seq(),
+                nn.ReLU(inplace=False),
+                operations.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0,
+                                  device=device, dtype=dtype),
+            )
+            for _ in range(self.aux_levels)
+        ])
+
+    @staticmethod
+    def _make_aux_out1_block(in_ch: int, *, device=None, dtype=None, operations=None) -> nn.Sequential:
+        # aux_out1_conv_num=5 in all Apache-2.0 variants.
+        return nn.Sequential(
+            operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
+            operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
+            operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
+            operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
+            operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
+        )
+
+    def forward(self, feats: List[torch.Tensor], H: int, W: int,
+                patch_start_idx: int = 0, **_kwargs) -> dict:
+        B, S, N, C = feats[0][0].shape
+        feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
+
+        ph, pw = H // self.patch_size, W // self.patch_size
+        resized = []
+        for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
+            x = feats_flat[take_idx][:, patch_start_idx:]
+            x = self.norm(x)
+            x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw)
+            x = self.projects[stage_idx](x)
+            if self.pos_embed:
+                x = _add_pos_embed(x, W, H)
+            x = self.resize_layers[stage_idx](x)
+            resized.append(x)
+
+        l1_rn = self.scratch.layer1_rn(resized[0])
+        l2_rn = self.scratch.layer2_rn(resized[1])
+        l3_rn = self.scratch.layer3_rn(resized[2])
+        l4_rn = self.scratch.layer4_rn(resized[3])
+
+        # Main pyramid (output_conv1 is applied inside the upstream `_fuse`,
+        # before interpolation -- replicate that order here).
+        m = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
+        m = self.scratch.refinenet3(m, l3_rn, size=l2_rn.shape[2:])
+        m = self.scratch.refinenet2(m, l2_rn, size=l1_rn.shape[2:])
+        m = self.scratch.refinenet1(m, l1_rn)
+        m = self.scratch.output_conv1(m)
+
+        h_out = int(ph * self.patch_size / self.down_ratio)
+        w_out = int(pw * self.patch_size / self.down_ratio)
+
+        m = _custom_interpolate(m, (h_out, w_out), mode="bilinear", align_corners=True)
+        if self.pos_embed:
+            m = _add_pos_embed(m, W, H)
+        main_logits = self.scratch.output_conv2(m)
+        fmap = main_logits.permute(0, 2, 3, 1)
+        depth_pred = _apply_activation(fmap[..., :-1], self.activation)
+        depth_conf = _apply_activation(fmap[..., -1], self.conf_activation)
+
+        outs = {
+            self.head_main: depth_pred.squeeze(-1).view(B, S, *depth_pred.shape[1:-1]),
+            f"{self.head_main}_conf": depth_conf.view(B, S, *depth_conf.shape[1:]),
+        }
+
+        # NOTE: we intentionally do not run the auxiliary "ray" branch -- it is
+        # only needed for pose/ray-conditioned outputs which are out of scope
+        # for this port. The aux submodules are still built so HF weights load.
+
+        return outs
diff --git a/comfy/ldm/depth_anything_3/model.py b/comfy/ldm/depth_anything_3/model.py
new file mode 100644
index 000000000..8bf5e9ec2
--- /dev/null
+++ b/comfy/ldm/depth_anything_3/model.py
@@ -0,0 +1,135 @@
+# DepthAnything3Net: top-level wrapper that combines backbone + head.
+#
+# This wrapper covers the monocular forward path only (single image -> depth).
+# Camera encoder/decoder, ray-pose head, 3D Gaussians and the Nested +# architecture are intentionally omitted. The HF state dict for those +# components is filtered out before loading -- see +# ``comfy.supported_models.DepthAnything3.process_unet_state_dict``. +# +# The class signature mirrors the upstream YAML config so a single dit_config +# detected from the state dict in ``comfy/model_detection.py`` is sufficient +# to construct the right variant. + +from __future__ import annotations + +from typing import Dict, List, Optional, Sequence + +import torch +import torch.nn as nn + +from .dinov2 import DinoV2 +from .dpt import DPT, DualDPT + + +_HEAD_REGISTRY = { + "dpt": DPT, + "dualdpt": DualDPT, +} + + +class DepthAnything3Net(nn.Module): + """ComfyUI-side DepthAnything3 network (monocular path only). + + Parameters mirror the variant YAML configs from the upstream repo. + Values are auto-detected by ``comfy/model_detection.py`` from the state + dict. The kwargs ``device``, ``dtype`` and ``operations`` are injected by + ``BaseModel``. + """ + + PATCH_SIZE = 14 + + def __init__( + self, + # --- Backbone --- + backbone_name: str = "vitl", + out_layers: Sequence[int] = (4, 11, 17, 23), + alt_start: int = -1, + qknorm_start: int = -1, + rope_start: int = -1, + cat_token: bool = False, + # --- Head --- + head_type: str = "dpt", # "dpt" or "dualdpt" + head_dim_in: int = 1024, + head_output_dim: int = 1, # 1 = depth only, 2 = depth+conf + head_features: int = 256, + head_out_channels: Sequence[int] = (256, 512, 1024, 1024), + head_use_sky_head: bool = True, # ignored by DualDPT + head_pos_embed: Optional[bool] = None, # default: True for DualDPT, False for DPT + # ComfyUI plumbing + device=None, dtype=None, operations=None, + **_ignored, + ): + super().__init__() + head_cls = _HEAD_REGISTRY[head_type.lower()] + self.head_type = head_type.lower() + self.has_sky = (self.head_type == "dpt") and head_use_sky_head + self.has_conf = head_output_dim > 1 + + self.backbone = DinoV2( + name=backbone_name, + out_layers=list(out_layers), + alt_start=alt_start, + qknorm_start=qknorm_start, + rope_start=rope_start, + cat_token=cat_token, + device=device, dtype=dtype, operations=operations, + ) + + head_kwargs = dict( + dim_in=head_dim_in, + patch_size=self.PATCH_SIZE, + output_dim=head_output_dim, + features=head_features, + out_channels=tuple(head_out_channels), + device=device, dtype=dtype, operations=operations, + ) + if self.head_type == "dpt": + head_kwargs.update( + use_sky_head=head_use_sky_head, + pos_embed=(False if head_pos_embed is None else head_pos_embed), + ) + else: # dualdpt + head_kwargs.update( + pos_embed=(True if head_pos_embed is None else head_pos_embed), + ) + self.head = head_cls(**head_kwargs) + self.dtype = dtype + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ + def forward(self, image: torch.Tensor, **_unused) -> Dict[str, torch.Tensor]: + """Run monocular forward. + + Args: + image: ``(B, 3, H, W)`` ImageNet-normalised image tensor, or + ``(B, S, 3, H, W)`` if a fake "views" axis is supplied. + H and W must be multiples of 14. + + Returns: + Dict with: + - ``depth``: ``(B, H, W)`` raw depth values. + - ``depth_conf``: ``(B, H, W)`` confidence (DualDPT variants only). + - ``sky``: ``(B, H, W)`` sky probability/logit + (DPT variants only). 
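+
+        A minimal call sketch (illustrative; the defaults here -- vitl
+        backbone, single-output DPT head with sky head -- are presumed to
+        match the Mono-Large layout):
+
+            net = DepthAnything3Net(operations=comfy.ops.disable_weight_init)
+            out = net(torch.randn(1, 3, 252, 252))
+            # out["depth"]: (1, 252, 252); out["sky"]: (1, 252, 252)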
+ """ + if image.ndim == 4: + image = image.unsqueeze(1) # (B, 1, 3, H, W) + assert image.ndim == 5 and image.shape[2] == 3, \ + f"image must be (B,3,H,W) or (B,S,3,H,W); got {tuple(image.shape)}" + + B, S, _, H, W = image.shape + assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \ + f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}" + + feats = self.backbone(image) + head_out = self.head(feats, H=H, W=W, patch_start_idx=0) + + # Flatten the views axis (S=1 in mono inference path). + out: Dict[str, torch.Tensor] = {} + for k, v in head_out.items(): + if v.ndim >= 3 and v.shape[0] == B and v.shape[1] == S: + out[k] = v.reshape(B * S, *v.shape[2:]) + else: + out[k] = v + return out diff --git a/comfy/ldm/depth_anything_3/preprocess.py b/comfy/ldm/depth_anything_3/preprocess.py new file mode 100644 index 000000000..667aab8cc --- /dev/null +++ b/comfy/ldm/depth_anything_3/preprocess.py @@ -0,0 +1,178 @@ +# Input/output preprocessing helpers for Depth Anything 3. +# +# Ported from: +# src/depth_anything_3/utils/io/input_processor.py (image normalisation) +# src/depth_anything_3/utils/alignment.py (sky-aware depth clip) +# src/depth_anything_3/model/da3.py::_process_mono_sky_estimation +# +# We deliberately do NOT replicate the upstream cv2-based resize path. ComfyUI +# already provides ``comfy.utils.common_upscale`` for high-quality bilinear +# resampling; using it keeps everything on-device and consistent with other +# ComfyUI preprocessors. The bilinear approximation is sufficient for the +# downstream depth-estimation task (verified visually against the upstream +# bicubic path -- depth maps are virtually identical). + +from __future__ import annotations + +from typing import Tuple + +import torch + +import comfy.utils + +PATCH_SIZE = 14 + +# ImageNet normalization constants used during DA3 training. +_IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]) +_IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]) + + +def _round_to_patch(x: int, patch: int = PATCH_SIZE) -> int: + down = (x // patch) * patch + up = down + patch + return up if abs(up - x) <= abs(x - down) else down + + +def compute_target_size(orig_h: int, orig_w: int, process_res: int, + method: str = "upper_bound_resize") -> Tuple[int, int]: + """Compute (target_h, target_w) for a single image. + + Methods: + - "upper_bound_resize": scale longest side to ``process_res``, then + round each dim to nearest multiple of 14 (default upstream method). + - "lower_bound_resize": scale shortest side to ``process_res``, then + round. + """ + if method == "upper_bound_resize": + longest = max(orig_h, orig_w) + scale = process_res / float(longest) + elif method == "lower_bound_resize": + shortest = min(orig_h, orig_w) + scale = process_res / float(shortest) + else: + raise ValueError(f"Unsupported process_res_method: {method}") + + new_w = max(1, _round_to_patch(int(round(orig_w * scale)))) + new_h = max(1, _round_to_patch(int(round(orig_h * scale)))) + return new_h, new_w + + +def preprocess_image( + image: torch.Tensor, + process_res: int = 504, + method: str = "upper_bound_resize", +) -> torch.Tensor: + """Preprocess a ComfyUI ``IMAGE`` batch for DA3. + + Args: + image: ``(B, H, W, 3)`` float in [0, 1] (ComfyUI ``IMAGE`` convention). + process_res: target resolution (longest or shortest side, depending + on ``method``). + method: resize strategy. + + Returns: + ``(B, 3, H', W')`` tensor with H' and W' multiples of 14, normalised + with ImageNet statistics. The tensor lives on the same device as + ``image``. 
+ """ + assert image.ndim == 4 and image.shape[-1] == 3, \ + f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}" + B, H, W, _ = image.shape + target_h, target_w = compute_target_size(H, W, process_res, method) + + # (B, H, W, 3) -> (B, 3, H, W) + x = image.movedim(-1, 1).contiguous() + if (target_h, target_w) != (H, W): + # common_upscale takes a (B, C, H, W) tensor. + x = comfy.utils.common_upscale(x, target_w, target_h, "bilinear", "disabled") + x = x.clamp(0.0, 1.0) + + mean = _IMAGENET_MEAN.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1) + std = _IMAGENET_STD.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1) + x = (x - mean) / std + return x + + +# ----------------------------------------------------------------------------- +# Output post-processing (sky-aware clipping for Mono/Metric variants) +# ----------------------------------------------------------------------------- + + +def compute_non_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor: + """Boolean mask: True for non-sky pixels (sky probability < threshold).""" + return sky_prediction < threshold + + +def apply_sky_aware_clip( + depth: torch.Tensor, + sky: torch.Tensor, + threshold: float = 0.3, + quantile: float = 0.99, +) -> torch.Tensor: + """Replicates ``_process_mono_sky_estimation`` from upstream. + + Clips sky regions to the 99th percentile of non-sky depth. Returns a new + depth tensor; ``depth`` is not modified in place. + """ + non_sky = compute_non_sky_mask(sky, threshold=threshold) + if non_sky.sum() <= 10 or (~non_sky).sum() <= 10: + return depth.clone() + + non_sky_depth = depth[non_sky] + if non_sky_depth.numel() > 100_000: + idx = torch.randint(0, non_sky_depth.numel(), (100_000,), device=non_sky_depth.device) + sampled = non_sky_depth[idx] + else: + sampled = non_sky_depth + + max_depth = torch.quantile(sampled, quantile) + out = depth.clone() + out[~non_sky] = max_depth + return out + + +def normalize_depth_v2_style( + depth: torch.Tensor, + sky: torch.Tensor | None = None, + low_quantile: float = 0.01, + high_quantile: float = 0.99, +) -> torch.Tensor: + """V2-style normalization for ControlNet workflows. + + Computes percentile bounds over non-sky pixels (when available), + then maps depth into [0, 1] with near = white (1.0). + """ + if sky is not None: + mask = compute_non_sky_mask(sky) + if mask.any(): + valid = depth[mask] + else: + valid = depth.flatten() + else: + valid = depth.flatten() + + if valid.numel() > 100_000: + idx = torch.randint(0, valid.numel(), (100_000,), device=valid.device) + sample = valid[idx] + else: + sample = valid + + lo = torch.quantile(sample, low_quantile) + hi = torch.quantile(sample, high_quantile) + rng = (hi - lo).clamp(min=1e-6) + norm = ((depth - lo) / rng).clamp(0.0, 1.0) + # ControlNet convention: nearer pixels are brighter (1.0). + norm = 1.0 - norm + if sky is not None: + # Sky pixels become black (far / unknown). 
+ sky_mask = ~compute_non_sky_mask(sky) + norm = torch.where(sky_mask, torch.zeros_like(norm), norm) + return norm + + +def normalize_depth_min_max(depth: torch.Tensor) -> torch.Tensor: + """Simple per-frame min/max normalization with near=1.0 convention.""" + lo = depth.amin(dim=(-2, -1), keepdim=True) + hi = depth.amax(dim=(-2, -1), keepdim=True) + rng = (hi - lo).clamp(min=1e-6) + return 1.0 - ((depth - lo) / rng).clamp(0.0, 1.0) diff --git a/comfy/model_base.py b/comfy/model_base.py index 0736321b3..ad040b244 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -60,6 +60,7 @@ import comfy.ldm.ernie.model import comfy.ldm.sam3.detector import comfy.ldm.hidream_o1.model from comfy.ldm.hidream_o1.conditioning import build_extra_conds +import comfy.ldm.depth_anything_3.model import comfy.model_management import comfy.patcher_extension @@ -2035,6 +2036,12 @@ class RT_DETR_v4(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4) + +class DepthAnything3(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, + unet_model=comfy.ldm.depth_anything_3.model.DepthAnything3Net) + class ErnieImage(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index bc0b933bc..a8e6bf467 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -766,6 +766,90 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0] return dit_config + # Depth Anything 3 (Apache-2.0 monocular variants: Small/Base/Mono-Large/Metric-Large). + if '{}backbone.pretrained.patch_embed.proj.weight'.format(key_prefix) in state_dict_keys: + dit_config = {} + dit_config["image_model"] = "DepthAnything3" + + patch_w = state_dict['{}backbone.pretrained.patch_embed.proj.weight'.format(key_prefix)] + embed_dim = patch_w.shape[0] + depth = count_blocks(state_dict_keys, '{}backbone.pretrained.blocks.'.format(key_prefix) + '{}.') + + # Backbone preset is determined by embed_dim (matches vits/vitb/vitl/vitg). + backbone_name = {384: "vits", 768: "vitb", 1024: "vitl", 1536: "vitg"}.get(embed_dim) + if backbone_name is None: + return None + dit_config["backbone_name"] = backbone_name + + # Detect DA3 extensions on top of vanilla DINOv2. + has_camera_token = '{}backbone.pretrained.camera_token'.format(key_prefix) in state_dict_keys + # qk-norm shows up as `attn.q_norm.weight` on enabled blocks. + qknorm_indices = [ + i for i in range(depth) + if '{}backbone.pretrained.blocks.{}.attn.q_norm.weight'.format(key_prefix, i) in state_dict_keys + ] + qknorm_start = qknorm_indices[0] if qknorm_indices else -1 + + # The DA3 main-series configs always set alt_start == qknorm_start == rope_start. + # cat_token=True is implied by the presence of camera_token. 
+ if has_camera_token: + dit_config["alt_start"] = qknorm_start + dit_config["rope_start"] = qknorm_start + dit_config["qknorm_start"] = qknorm_start + dit_config["cat_token"] = True + else: + dit_config["alt_start"] = -1 + dit_config["rope_start"] = -1 + dit_config["qknorm_start"] = -1 + dit_config["cat_token"] = False + + # Detect head type and config. + has_aux = '{}head.scratch.refinenet1_aux.out_conv.weight'.format(key_prefix) in state_dict_keys + if has_aux: + dit_config["head_type"] = "dualdpt" + # DualDPT: dim_in = 2 * embed_dim (because cat_token doubles token width). + head_dim_in = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1] + out_channels = [ + state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0] + for i in range(4) + ] + features = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0] + dit_config["head_dim_in"] = head_dim_in + dit_config["head_output_dim"] = 2 + dit_config["head_features"] = features + dit_config["head_out_channels"] = out_channels + dit_config["head_use_sky_head"] = False + else: + dit_config["head_type"] = "dpt" + head_dim_in = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1] + out_channels = [ + state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0] + for i in range(4) + ] + features = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0] + output_dim = state_dict[ + '{}head.scratch.output_conv2.2.weight'.format(key_prefix) + ].shape[0] + dit_config["head_dim_in"] = head_dim_in + dit_config["head_output_dim"] = output_dim + dit_config["head_features"] = features + dit_config["head_out_channels"] = out_channels + dit_config["head_use_sky_head"] = ( + '{}head.scratch.sky_output_conv2.0.weight'.format(key_prefix) in state_dict_keys + ) + + # out_layers: hard-coded per upstream YAML config (depth-aware default). + if depth >= 24: + # vitl: depths used vary between DA3-Large (DualDPT) and Mono/Metric (DPT). + if has_aux: + dit_config["out_layers"] = [11, 15, 19, 23] + else: + dit_config["out_layers"] = [4, 11, 17, 23] + else: + # vits/vitb: 12 blocks + dit_config["out_layers"] = [5, 7, 9, 11] + return dit_config + if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image dit_config = {} dit_config["image_model"] = "ernie" diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1e4434fd5..9540d1d69 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1847,6 +1847,33 @@ class RT_DETR_v4(supported_models_base.BASE): return None +class DepthAnything3(supported_models_base.BASE): + unet_config = { + "image_model": "DepthAnything3", + } + + # Mono path: no num_heads / num_head_channels needed. + unet_extra_config = {} + + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + def get_model(self, state_dict, prefix="", device=None): + return model_base.DepthAnything3(self, device=device) + + def clip_target(self, state_dict={}): + return None + + def process_unet_state_dict(self, state_dict): + # Drop weights for components we do not build (camera encoder/decoder, + # 3D Gaussian heads). Keeping unrelated keys around triggers spurious + # "unet unexpected" warnings on load. 
+ drop_prefixes = ("cam_enc.", "cam_dec.", "gs_head.", "gs_adapter.") + for k in list(state_dict.keys()): + if k.startswith(drop_prefixes): + state_dict.pop(k) + return state_dict + + class ErnieImage(supported_models_base.BASE): unet_config = { "image_model": "ernie", @@ -2082,4 +2109,5 @@ models = [ CogVideoX_I2V, CogVideoX_T2V, SVD_img2vid, + DepthAnything3, ] diff --git a/comfy_extras/nodes_depth_anything_3.py b/comfy_extras/nodes_depth_anything_3.py new file mode 100644 index 000000000..a3a86dc9e --- /dev/null +++ b/comfy_extras/nodes_depth_anything_3.py @@ -0,0 +1,247 @@ +"""ComfyUI nodes for Depth Anything 3. + +Adds three nodes: + +* ``LoadDepthAnything3`` -- load a DA3 ``.safetensors`` file from the + ``models/depth_estimation/`` folder. Falls back to ``models/diffusion_models/`` + so existing installations keep working. +* ``DepthAnything3Depth`` -- run depth estimation and return a normalised + depth map as a ComfyUI ``IMAGE`` (visualisation / ControlNet input). +* ``DepthAnything3DepthRaw`` -- run depth estimation and return the raw depth, + confidence and sky channels as ``MASK`` outputs. +""" + +from __future__ import annotations + +from typing_extensions import override + +import torch + +import comfy.model_management as mm +import comfy.sd +import folder_paths +from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess +from comfy_api.latest import ComfyExtension, io + + +# ----------------------------------------------------------------------------- +# Loader +# ----------------------------------------------------------------------------- + + +class LoadDepthAnything3(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadDepthAnything3", + display_name="Load Depth Anything 3", + category="loaders/depth_estimation", + inputs=[ + io.Combo.Input( + "model_name", + options=folder_paths.get_filename_list("depth_estimation"), + ), + io.Combo.Input( + "weight_dtype", + options=["default", "fp16", "bf16", "fp32"], + default="default", + ), + ], + outputs=[io.Model.Output("model")], + ) + + @classmethod + def execute(cls, model_name, weight_dtype) -> io.NodeOutput: + model_options = {} + if weight_dtype == "fp16": + model_options["dtype"] = torch.float16 + elif weight_dtype == "bf16": + model_options["dtype"] = torch.bfloat16 + elif weight_dtype == "fp32": + model_options["dtype"] = torch.float32 + + path = folder_paths.get_full_path_or_raise("depth_estimation", model_name) + model = comfy.sd.load_diffusion_model(path, model_options=model_options) + return io.NodeOutput(model) + + +# ----------------------------------------------------------------------------- +# Inference helpers +# ----------------------------------------------------------------------------- + + +def _run_da3(model_patcher, image: torch.Tensor, process_res: int, + method: str = "upper_bound_resize"): + """Run the DA3 network on a (B, H, W, 3) ``IMAGE`` batch. + + Returns ``(depth, confidence, sky)`` tensors with the original image + resolution. Any of ``confidence`` / ``sky`` may be ``None`` depending on + the variant. 
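+
+    Shape sketch (illustrative; ``model`` is a loaded DA3 ``ModelPatcher``):
+
+        depth, conf, sky = _run_da3(model, torch.rand(1, 480, 640, 3), 504)
+        depth.shape  # torch.Size([1, 480, 640]), resized back to source size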
+ """ + assert image.ndim == 4 and image.shape[-1] == 3, \ + f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}" + + B, H, W, _ = image.shape + mm.load_model_gpu(model_patcher) + diffusion = model_patcher.model.diffusion_model + device = mm.get_torch_device() + dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32 + + depths, confs, skies = [], [], [] + # Process one image at a time to keep peak memory predictable; DA3 is + # an inference-only model and per-sample latency dominates anyway. + for i in range(B): + single = image[i:i + 1].to(device) + x = da3_preprocess.preprocess_image(single, process_res=process_res, method=method) + x = x.to(dtype=dtype) + with torch.no_grad(): + out = diffusion(x) + + depth_lr = out["depth"] + # Resize back to the original (H, W). + depth_full = torch.nn.functional.interpolate( + depth_lr.unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + depths.append(depth_full) + + if "depth_conf" in out: + conf_full = torch.nn.functional.interpolate( + out["depth_conf"].unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + confs.append(conf_full) + if "sky" in out: + sky_full = torch.nn.functional.interpolate( + out["sky"].unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + skies.append(sky_full) + + depth = torch.cat(depths, dim=0) + confidence = torch.cat(confs, dim=0) if confs else None + sky = torch.cat(skies, dim=0) if skies else None + return depth, confidence, sky + + +# ----------------------------------------------------------------------------- +# Depth -> visualisation IMAGE +# ----------------------------------------------------------------------------- + + +class DepthAnything3Depth(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DepthAnything3Depth", + display_name="Depth Anything 3 (Depth)", + category="image/depth", + inputs=[ + io.Model.Input("model"), + io.Image.Input("image"), + io.Int.Input("process_res", default=504, min=140, max=2520, step=14, + tooltip="Longest-side target resolution (multiple of 14)."), + io.Combo.Input("resize_method", + options=["upper_bound_resize", "lower_bound_resize"], + default="upper_bound_resize"), + io.Combo.Input("normalization", + options=["v2_style", "min_max", "raw"], + default="v2_style", + tooltip="How to map raw depth -> [0, 1] image."), + io.Boolean.Input("apply_sky_clip", default=True, + tooltip="(Mono/Metric only) clip sky depth to 99th percentile."), + ], + outputs=[ + io.Image.Output("depth_image"), + io.Mask.Output("sky_mask", + tooltip="Sky probability (Mono/Metric variants), else zeros."), + io.Mask.Output("confidence", + tooltip="Depth confidence (Small/Base/DualDPT variants), else zeros."), + ], + ) + + @classmethod + def execute(cls, model, image, process_res, resize_method, normalization, + apply_sky_clip) -> io.NodeOutput: + depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method) + + if apply_sky_clip and sky is not None: + depth = torch.stack([ + da3_preprocess.apply_sky_aware_clip(depth[i], sky[i]) + for i in range(depth.shape[0]) + ], dim=0) + + if normalization == "v2_style": + norm = torch.stack([ + da3_preprocess.normalize_depth_v2_style(depth[i], + sky[i] if sky is not None else None) + for i in range(depth.shape[0]) + ], dim=0) + elif normalization == "min_max": + norm = da3_preprocess.normalize_depth_min_max(depth) + else: + norm = depth + + # (B, H, W) -> (B, H, W, 3) 
+
+
+# -----------------------------------------------------------------------------
+# Raw depth output (useful for downstream metric work)
+# -----------------------------------------------------------------------------
+
+
+class DepthAnything3DepthRaw(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="DepthAnything3DepthRaw",
+            display_name="Depth Anything 3 (Raw Depth)",
+            category="image/depth",
+            inputs=[
+                io.Model.Input("model"),
+                io.Image.Input("image"),
+                io.Int.Input("process_res", default=504, min=140, max=2520, step=14),
+                io.Combo.Input("resize_method",
+                               options=["upper_bound_resize", "lower_bound_resize"],
+                               default="upper_bound_resize"),
+            ],
+            outputs=[
+                io.Mask.Output("depth",
+                               tooltip="Raw depth values (no normalisation, no clipping)."),
+                io.Mask.Output("confidence"),
+                io.Mask.Output("sky"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, image, process_res, resize_method) -> io.NodeOutput:
+        depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method)
+        zeros = torch.zeros_like(depth)
+        return io.NodeOutput(
+            depth.contiguous(),
+            (confidence if confidence is not None else zeros).contiguous(),
+            (sky if sky is not None else zeros).contiguous(),
+        )
+
+
+# -----------------------------------------------------------------------------
+# Extension registration
+# -----------------------------------------------------------------------------
+
+
+class DepthAnything3Extension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            LoadDepthAnything3,
+            DepthAnything3Depth,
+            DepthAnything3DepthRaw,
+        ]
+
+
+async def comfy_entrypoint() -> DepthAnything3Extension:
+    return DepthAnything3Extension()
diff --git a/folder_paths.py b/folder_paths.py
index 92e8df3cf..bc95e9a8c 100644
--- a/folder_paths.py
+++ b/folder_paths.py
@@ -58,6 +58,13 @@ folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "fram
 folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)

+folder_names_and_paths["depth_estimation"] = (
+    [os.path.join(models_dir, "depth_estimation"),
+     os.path.join(models_dir, "diffusion_models")],
+    supported_pt_extensions,
+)
+
+
 output_directory = os.path.join(base_path, "output")
 temp_directory = os.path.join(base_path, "temp")
 input_directory = os.path.join(base_path, "input")
diff --git a/nodes.py b/nodes.py
index 78aaaef74..d5a05445c 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2436,6 +2436,7 @@ async def init_builtin_extra_nodes():
         "nodes_void.py",
         "nodes_wandancer.py",
         "nodes_hidream_o1.py",
+        "nodes_depth_anything_3.py",
     ]

     import_failed = []
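Because `DepthAnything3DepthRaw` returns plain (B, H, W) tensors, downstream code can combine the channels directly. For instance, masking out sky pixels before aggregating depth statistics might look like this sketch (illustrative only; the 0.5 sky threshold and the helper name are assumptions, not part of the patch):

```python
import torch

def mean_non_sky_depth(depth: torch.Tensor, sky: torch.Tensor) -> torch.Tensor:
    # depth and sky are the (B, H, W) MASK tensors returned by the raw node;
    # treat sky probability >= 0.5 as sky and average depth over the rest.
    valid = sky < 0.5
    masked = torch.where(valid, depth, torch.zeros_like(depth))
    return masked.sum(dim=(1, 2)) / valid.sum(dim=(1, 2)).clamp_min(1)
```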