mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-15 03:27:24 +08:00
Initial commit CORE-135
This commit is contained in:
parent
a5189fed51
commit
2a20db3eb3
7
comfy/ldm/depth_anything_3/__init__.py
Normal file
7
comfy/ldm/depth_anything_3/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# Depth Anything 3 - native ComfyUI port (Apache-2.0 monocular variants only).
|
||||
#
|
||||
# Supported variants:
|
||||
# DA3-Small, DA3-Base (vits/vitb backbone, DualDPT head)
|
||||
# DA3Mono-Large, DA3Metric-Large (vitl backbone, DPT head + sky mask)
|
||||
#
|
||||
# Original repo: https://github.com/ByteDance-Seed/Depth-Anything-3
|
||||
497
comfy/ldm/depth_anything_3/dinov2.py
Normal file
497
comfy/ldm/depth_anything_3/dinov2.py
Normal file
@ -0,0 +1,497 @@
|
||||
# DINOv2 backbone for Depth Anything 3 (monocular inference path).
|
||||
#
|
||||
# Why not reuse ``comfy/image_encoders/dino2.py``?
|
||||
# The existing ``Dinov2Model`` is a vanilla HuggingFace-style DINOv2 with a
|
||||
# different state-dict layout (separate Q/K/V, ``embeddings.*`` /
|
||||
# ``encoder.layer.*`` keys, ``layer_scale1.lambda1``) and no support for the
|
||||
# architectural extensions DA3 adds on top of DINOv2 (RoPE, QK-norm,
|
||||
# alternating local/global attention, concatenated camera token). Loading
|
||||
# raw DA3 HF safetensors into ``Dinov2Model`` would require splitting
|
||||
# ``attn.qkv`` weights and a large rename map for every block, and we'd
|
||||
# still need to write the DA3 extensions separately. Keeping the upstream
|
||||
# ``pretrained.*`` key layout here means HF weights load directly with no
|
||||
# conversion step.
|
||||
#
|
||||
# Ported from the upstream repo at:
|
||||
# src/depth_anything_3/model/dinov2/{dinov2,vision_transformer}.py
|
||||
# src/depth_anything_3/model/dinov2/layers/*
|
||||
#
|
||||
# DA3 extensions on top of vanilla DINOv2 (only used by Small/Base variants):
|
||||
# - 2D Rotary Position Embedding starting at ``rope_start``
|
||||
# - QK-norm starting at ``qknorm_start``
|
||||
# - Alternating local/global attention blocks starting at ``alt_start``
|
||||
# - Camera-conditioning token concatenated to features (``cat_token=True``),
|
||||
# with a learned parameter ``camera_token`` injected at block
|
||||
# ``alt_start`` when no external camera token is supplied.
|
||||
#
|
||||
# For the Mono/Metric variants the configuration disables all of the above
|
||||
# (alt_start/qknorm_start/rope_start = -1, cat_token=False) so this module
|
||||
# collapses to a vanilla DINOv2-ViT encoder.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 2D rotary position embedding
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PositionGetter:
|
||||
def __init__(self):
|
||||
self._cache: dict[tuple[int, int], torch.Tensor] = {}
|
||||
|
||||
def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor:
|
||||
key = (height, width)
|
||||
if key not in self._cache:
|
||||
y = torch.arange(height, device=device)
|
||||
x = torch.arange(width, device=device)
|
||||
self._cache[key] = torch.cartesian_prod(y, x)
|
||||
cached = self._cache[key]
|
||||
return cached.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
|
||||
|
||||
|
||||
class RotaryPositionEmbedding2D(nn.Module):
|
||||
def __init__(self, frequency: float = 100.0):
|
||||
super().__init__()
|
||||
self.base_frequency = frequency
|
||||
self._freq_cache: dict = {}
|
||||
|
||||
def _components(self, dim: int, seq_len: int, device, dtype):
|
||||
key = (dim, seq_len, device, dtype)
|
||||
if key not in self._freq_cache:
|
||||
exp = torch.arange(0, dim, 2, device=device).float() / dim
|
||||
inv_freq = 1.0 / (self.base_frequency ** exp)
|
||||
pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
||||
ang = torch.einsum("i,j->ij", pos, inv_freq)
|
||||
ang = ang.to(dtype)
|
||||
ang = torch.cat((ang, ang), dim=-1)
|
||||
self._freq_cache[key] = (ang.cos().to(dtype), ang.sin().to(dtype))
|
||||
return self._freq_cache[key]
|
||||
|
||||
@staticmethod
|
||||
def _rotate(x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1]
|
||||
x1, x2 = x[..., : d // 2], x[..., d // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def _apply_1d(self, tokens, positions, cos_c, sin_c):
|
||||
cos = F.embedding(positions, cos_c)[:, None, :, :]
|
||||
sin = F.embedding(positions, sin_c)[:, None, :, :]
|
||||
return (tokens * cos) + (self._rotate(tokens) * sin)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
|
||||
feature_dim = tokens.size(-1) // 2
|
||||
max_pos = int(positions.max()) + 1
|
||||
cos_c, sin_c = self._components(feature_dim, max_pos, tokens.device, tokens.dtype)
|
||||
v, h = tokens.chunk(2, dim=-1)
|
||||
v = self._apply_1d(v, positions[..., 0], cos_c, sin_c)
|
||||
h = self._apply_1d(h, positions[..., 1], cos_c, sin_c)
|
||||
return torch.cat((v, h), dim=-1)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Patch embed / MLP / SwiGLU / LayerScale
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
def __init__(self, patch_size=14, in_chans=3, embed_dim=384,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.patch_size = (patch_size, patch_size)
|
||||
self.proj = operations.Conv2d(
|
||||
in_chans, embed_dim,
|
||||
kernel_size=self.patch_size, stride=self.patch_size,
|
||||
device=device, dtype=dtype,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
_, _, H, W = x.shape
|
||||
ph, pw = self.patch_size
|
||||
assert H % ph == 0 and W % pw == 0
|
||||
x = self.proj(x)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(self, in_features, hidden_features=None, bias=True,
|
||||
act_layer=nn.GELU, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = operations.Linear(in_features, hidden_features, bias=bias, device=device, dtype=dtype)
|
||||
self.act = act_layer()
|
||||
self.fc2 = operations.Linear(hidden_features, in_features, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(self.act(self.fc1(x)))
|
||||
|
||||
|
||||
class SwiGLUFFNFused(nn.Module):
|
||||
"""SwiGLU FFN matching upstream xformers.ops.SwiGLU layout (used for vitg)."""
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, bias=True,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
hidden_features = hidden_features or in_features
|
||||
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
||||
# NOTE: xformers SwiGLU stores w12 as a single fused Linear (in, 2*hidden);
|
||||
# split-by-half at forward time. We don't currently need this for the
|
||||
# Apache-2.0 variants but keep it for parity with the upstream key names.
|
||||
self.w12 = operations.Linear(in_features, 2 * hidden_features, bias=bias, device=device, dtype=dtype)
|
||||
self.w3 = operations.Linear(hidden_features, in_features, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x12 = self.w12(x)
|
||||
x1, x2 = x12.chunk(2, dim=-1)
|
||||
return self.w3(F.silu(x1) * x2)
|
||||
|
||||
|
||||
class LayerScale(nn.Module):
|
||||
def __init__(self, dim, init_values: float = 1e-5, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(init_values * torch.ones(dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x):
|
||||
return x * comfy_cast(self.gamma, x)
|
||||
|
||||
|
||||
def comfy_cast(p: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
|
||||
"""Cast a parameter to match the reference tensor's device/dtype."""
|
||||
if p.device != ref.device or p.dtype != ref.dtype:
|
||||
return p.to(device=ref.device, dtype=ref.dtype)
|
||||
return p
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Attention + Block
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads: int, qkv_bias: bool = True, proj_bias: bool = True,
|
||||
qk_norm: bool = False, rope: Optional[RotaryPositionEmbedding2D] = None,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.q_norm = operations.LayerNorm(self.head_dim, device=device, dtype=dtype) if qk_norm else nn.Identity()
|
||||
self.k_norm = operations.LayerNorm(self.head_dim, device=device, dtype=dtype) if qk_norm else nn.Identity()
|
||||
self.proj = operations.Linear(dim, dim, bias=proj_bias, device=device, dtype=dtype)
|
||||
self.rope = rope
|
||||
|
||||
def forward(self, x: torch.Tensor, pos: Optional[torch.Tensor] = None,
|
||||
attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
if self.rope is not None and pos is not None:
|
||||
q = self.rope(q, pos)
|
||||
k = self.rope(k, pos)
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=(attn_mask[:, None].repeat(1, self.num_heads, 1, 1)
|
||||
if attn_mask is not None else None),
|
||||
)
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, dim, num_heads, mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = True, proj_bias: bool = True, ffn_bias: bool = True,
|
||||
init_values: Optional[float] = 1.0,
|
||||
norm_layer=nn.LayerNorm,
|
||||
ffn_layer=Mlp,
|
||||
qk_norm: bool = False,
|
||||
rope: Optional[RotaryPositionEmbedding2D] = None,
|
||||
ln_eps: float = 1e-6,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm1 = operations.LayerNorm(dim, eps=ln_eps, device=device, dtype=dtype)
|
||||
self.attn = Attention(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias,
|
||||
qk_norm=qk_norm, rope=rope, device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
self.ls1 = (LayerScale(dim, init_values=init_values, device=device, dtype=dtype)
|
||||
if init_values else nn.Identity())
|
||||
|
||||
self.norm2 = operations.LayerNorm(dim, eps=ln_eps, device=device, dtype=dtype)
|
||||
mlp_hidden = int(dim * mlp_ratio)
|
||||
self.mlp = ffn_layer(
|
||||
in_features=dim, hidden_features=mlp_hidden,
|
||||
bias=ffn_bias, device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
self.ls2 = (LayerScale(dim, init_values=init_values, device=device, dtype=dtype)
|
||||
if init_values else nn.Identity())
|
||||
|
||||
def forward(self, x, pos=None, attn_mask=None):
|
||||
x = x + self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask))
|
||||
x = x + self.ls2(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DINOv2 vision transformer
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
_BACKBONE_PRESETS = {
|
||||
"vits": dict(embed_dim=384, depth=12, num_heads=6, ffn_layer="mlp"),
|
||||
"vitb": dict(embed_dim=768, depth=12, num_heads=12, ffn_layer="mlp"),
|
||||
"vitl": dict(embed_dim=1024, depth=24, num_heads=16, ffn_layer="mlp"),
|
||||
"vitg": dict(embed_dim=1536, depth=40, num_heads=24, ffn_layer="swiglufused"),
|
||||
}
|
||||
|
||||
|
||||
class DinoVisionTransformer(nn.Module):
|
||||
PATCH_SIZE = 14
|
||||
|
||||
def __init__(self,
|
||||
embed_dim: int,
|
||||
depth: int,
|
||||
num_heads: int,
|
||||
ffn_layer: str = "mlp",
|
||||
mlp_ratio: float = 4.0,
|
||||
init_values: float = 1.0,
|
||||
alt_start: int = -1,
|
||||
qknorm_start: int = -1,
|
||||
rope_start: int = -1,
|
||||
rope_freq: float = 100.0,
|
||||
cat_token: bool = True,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
norm_layer = nn.LayerNorm
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.alt_start = alt_start
|
||||
self.qknorm_start = qknorm_start
|
||||
self.rope_start = rope_start
|
||||
self.cat_token = cat_token
|
||||
self.patch_size = self.PATCH_SIZE
|
||||
self.num_register_tokens = 0
|
||||
self.patch_start_idx = 1
|
||||
self.num_tokens = 1
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
patch_size=self.PATCH_SIZE, in_chans=3, embed_dim=embed_dim,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
# Number of patch positions for the historical 518x518 reference grid.
|
||||
ref_grid = 518 // self.PATCH_SIZE
|
||||
num_patches = ref_grid * ref_grid
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, device=device, dtype=dtype))
|
||||
if alt_start != -1:
|
||||
self.camera_token = nn.Parameter(torch.zeros(1, 2, embed_dim, device=device, dtype=dtype))
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim,
|
||||
device=device, dtype=dtype))
|
||||
|
||||
if rope_start != -1 and rope_freq > 0:
|
||||
self.rope = RotaryPositionEmbedding2D(frequency=rope_freq)
|
||||
self.position_getter = PositionGetter()
|
||||
else:
|
||||
self.rope = None
|
||||
self.position_getter = None
|
||||
|
||||
if ffn_layer == "mlp":
|
||||
ffn = Mlp
|
||||
elif ffn_layer in ("swiglu", "swiglufused"):
|
||||
ffn = SwiGLUFFNFused
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported ffn_layer: {ffn_layer}")
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=True,
|
||||
proj_bias=True,
|
||||
ffn_bias=True,
|
||||
init_values=init_values,
|
||||
norm_layer=norm_layer,
|
||||
ffn_layer=ffn,
|
||||
qk_norm=(qknorm_start != -1 and i >= qknorm_start),
|
||||
rope=(self.rope if (rope_start != -1 and i >= rope_start) else None),
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
for i in range(depth)
|
||||
])
|
||||
self.norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
def interpolate_pos_encoding(self, x, w, h):
|
||||
previous_dtype = x.dtype
|
||||
npatch = x.shape[1] - 1
|
||||
N = self.pos_embed.shape[1] - 1
|
||||
pos_embed = comfy_cast(self.pos_embed, x).float()
|
||||
if npatch == N and w == h:
|
||||
return pos_embed
|
||||
class_pos_embed = pos_embed[:, 0]
|
||||
patch_pos_embed = pos_embed[:, 1:]
|
||||
dim = x.shape[-1]
|
||||
w0 = w // self.patch_size
|
||||
h0 = h // self.patch_size
|
||||
M = int(math.sqrt(N))
|
||||
assert N == M * M
|
||||
# Historical 0.1 offset preserves bicubic resample compatibility with
|
||||
# the original DINOv2 release; see the upstream PR for context.
|
||||
sx = float(w0 + 0.1) / M
|
||||
sy = float(h0 + 0.1) / M
|
||||
patch_pos_embed = F.interpolate(
|
||||
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
||||
scale_factor=(sx, sy), mode="bicubic", antialias=False,
|
||||
)
|
||||
assert (w0, h0) == patch_pos_embed.shape[-2:]
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
||||
|
||||
def prepare_tokens(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# x: (B, S, 3, H, W) -> tokens (B, S, 1+N, C)
|
||||
B, S, _, H, W = x.shape
|
||||
x = rearrange(x, "b s c h w -> (b s) c h w")
|
||||
x = self.patch_embed(x)
|
||||
cls_token = comfy_cast(self.cls_token, x).expand(B, S, -1).reshape(B * S, 1, self.embed_dim)
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
x = x + self.interpolate_pos_encoding(x, W, H)
|
||||
x = rearrange(x, "(b s) n c -> b s n c", b=B, s=S)
|
||||
return x
|
||||
|
||||
def _prepare_rope(self, B, S, H, W, device):
|
||||
if self.rope is None:
|
||||
return None, None
|
||||
pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=device)
|
||||
pos = rearrange(pos, "(b s) n c -> b s n c", b=B)
|
||||
pos_nodiff = torch.zeros_like(pos)
|
||||
if self.patch_start_idx > 0:
|
||||
pos = pos + 1
|
||||
pos_special = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype)
|
||||
pos_special = rearrange(pos_special, "(b s) n c -> b s n c", b=B)
|
||||
pos = torch.cat([pos_special, pos], dim=2)
|
||||
pos_nodiff = pos_nodiff + 1
|
||||
pos_nodiff = torch.cat([pos_special, pos_nodiff], dim=2)
|
||||
return pos, pos_nodiff
|
||||
|
||||
def _attn(self, x, blk, attn_type, pos=None, attn_mask=None):
|
||||
b, s, n = x.shape[:3]
|
||||
if attn_type == "local":
|
||||
x = rearrange(x, "b s n c -> (b s) n c")
|
||||
if pos is not None:
|
||||
pos = rearrange(pos, "b s n c -> (b s) n c")
|
||||
else: # "global"
|
||||
x = rearrange(x, "b s n c -> b (s n) c")
|
||||
if pos is not None:
|
||||
pos = rearrange(pos, "b s n c -> b (s n) c")
|
||||
x = blk(x, pos=pos, attn_mask=attn_mask)
|
||||
if attn_type == "local":
|
||||
x = rearrange(x, "(b s) n c -> b s n c", b=b, s=s)
|
||||
else:
|
||||
x = rearrange(x, "b (s n) c -> b s n c", b=b, s=s)
|
||||
return x
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public forward
|
||||
# ------------------------------------------------------------------
|
||||
def get_intermediate_layers(self, x: torch.Tensor, out_layers: List[int],
|
||||
cam_token: Optional[torch.Tensor] = None):
|
||||
B, S, _, H, W = x.shape
|
||||
x = self.prepare_tokens(x)
|
||||
pos, pos_nodiff = self._prepare_rope(B, S, H, W, x.device)
|
||||
|
||||
outputs = []
|
||||
local_x = x
|
||||
|
||||
for i, blk in enumerate(self.blocks):
|
||||
if self.rope is not None and i >= self.rope_start:
|
||||
g_pos, l_pos = pos_nodiff, pos
|
||||
else:
|
||||
g_pos, l_pos = None, None
|
||||
|
||||
if self.alt_start != -1 and i == self.alt_start:
|
||||
# Inject camera token at the alt-start boundary.
|
||||
if cam_token is not None:
|
||||
inj = cam_token
|
||||
else:
|
||||
ct = comfy_cast(self.camera_token, x)
|
||||
ref_token = ct[:, :1].expand(B, -1, -1)
|
||||
src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1)
|
||||
inj = torch.cat([ref_token, src_token], dim=1)
|
||||
x = x.clone()
|
||||
x[:, :, 0] = inj
|
||||
|
||||
if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1):
|
||||
x = self._attn(x, blk, "global", pos=g_pos)
|
||||
else:
|
||||
x = self._attn(x, blk, "local", pos=l_pos)
|
||||
local_x = x
|
||||
|
||||
if i in out_layers:
|
||||
out_x = torch.cat([local_x, x], dim=-1) if self.cat_token else x
|
||||
outputs.append(out_x)
|
||||
|
||||
# Apply final norm. Upstream norms only the "global" half when cat_token.
|
||||
normed: List[torch.Tensor] = []
|
||||
camera_tokens: List[torch.Tensor] = []
|
||||
for out_x in outputs:
|
||||
# Camera/cls token slot is index 0 *before* register-token stripping.
|
||||
camera_tokens.append(out_x[:, :, 0])
|
||||
if out_x.shape[-1] == self.embed_dim:
|
||||
normed.append(self.norm(out_x))
|
||||
elif out_x.shape[-1] == self.embed_dim * 2:
|
||||
left = out_x[..., : self.embed_dim]
|
||||
right = self.norm(out_x[..., self.embed_dim :])
|
||||
normed.append(torch.cat([left, right], dim=-1))
|
||||
else:
|
||||
raise ValueError(f"Unexpected token width: {out_x.shape[-1]}")
|
||||
# Drop cls/cam token + register tokens from patch sequence.
|
||||
normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed]
|
||||
# Match upstream signature consumed by DA3 heads:
|
||||
# feats[i][0] = normed patch tokens, feats[i][1] = camera/cls token.
|
||||
return list(zip(normed, camera_tokens))
|
||||
|
||||
|
||||
class DinoV2(nn.Module):
|
||||
"""Top-level DINOv2 wrapper matching upstream key layout (``self.pretrained``)."""
|
||||
|
||||
def __init__(self, name: str = "vits",
|
||||
out_layers: Optional[List[int]] = None,
|
||||
alt_start: int = -1,
|
||||
qknorm_start: int = -1,
|
||||
rope_start: int = -1,
|
||||
cat_token: bool = True,
|
||||
device=None, dtype=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
if name not in _BACKBONE_PRESETS:
|
||||
raise ValueError(f"Unknown DINOv2 backbone variant: {name!r}")
|
||||
preset = _BACKBONE_PRESETS[name]
|
||||
self.name = name
|
||||
self.out_layers = list(out_layers) if out_layers is not None else [5, 7, 9, 11]
|
||||
self.cat_token = cat_token
|
||||
self.pretrained = DinoVisionTransformer(
|
||||
embed_dim=preset["embed_dim"],
|
||||
depth=preset["depth"],
|
||||
num_heads=preset["num_heads"],
|
||||
ffn_layer=preset["ffn_layer"],
|
||||
alt_start=alt_start,
|
||||
qknorm_start=qknorm_start,
|
||||
rope_start=rope_start,
|
||||
cat_token=cat_token,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
def forward(self, x, cam_token=None, **_unused):
|
||||
return self.pretrained.get_intermediate_layers(x, self.out_layers, cam_token=cam_token)
|
||||
517
comfy/ldm/depth_anything_3/dpt.py
Normal file
517
comfy/ldm/depth_anything_3/dpt.py
Normal file
@ -0,0 +1,517 @@
|
||||
# DPT / DualDPT heads for Depth Anything 3.
|
||||
#
|
||||
# Ported from:
|
||||
# src/depth_anything_3/model/dpt.py (DPT - single main head + sky head)
|
||||
# src/depth_anything_3/model/dualdpt.py (DualDPT - depth + auxiliary "ray" head)
|
||||
#
|
||||
# In the monocular path we always discard the auxiliary "ray" output of
|
||||
# DualDPT. The auxiliary branch is still constructed so that DA3 HF weights
|
||||
# load cleanly without missing-key warnings.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helpers (matching upstream head_utils.py)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Permute(nn.Module):
|
||||
def __init__(self, dims: Tuple[int, ...]):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.permute(*self.dims)
|
||||
|
||||
|
||||
def _custom_interpolate(
|
||||
x: torch.Tensor,
|
||||
size: Optional[Tuple[int, int]] = None,
|
||||
scale_factor: Optional[float] = None,
|
||||
mode: str = "bilinear",
|
||||
align_corners: bool = True,
|
||||
) -> torch.Tensor:
|
||||
if size is None:
|
||||
assert scale_factor is not None
|
||||
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
|
||||
INT_MAX = 1610612736
|
||||
total = size[0] * size[1] * x.shape[0] * x.shape[1]
|
||||
if total > INT_MAX:
|
||||
chunks = torch.chunk(x, chunks=(total // INT_MAX) + 1, dim=0)
|
||||
outs = [F.interpolate(c, size=size, mode=mode, align_corners=align_corners) for c in chunks]
|
||||
return torch.cat(outs, dim=0).contiguous()
|
||||
return F.interpolate(x, size=size, mode=mode, align_corners=align_corners)
|
||||
|
||||
|
||||
def _create_uv_grid(width: int, height: int, aspect_ratio: float,
|
||||
dtype, device) -> torch.Tensor:
|
||||
"""Normalised UV grid spanning (-x_span, -y_span)..(x_span, y_span)."""
|
||||
diag_factor = (aspect_ratio ** 2 + 1.0) ** 0.5
|
||||
span_x = aspect_ratio / diag_factor
|
||||
span_y = 1.0 / diag_factor
|
||||
left_x = -span_x * (width - 1) / width
|
||||
right_x = span_x * (width - 1) / width
|
||||
top_y = -span_y * (height - 1) / height
|
||||
bottom_y = span_y * (height - 1) / height
|
||||
x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
|
||||
y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
|
||||
uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
|
||||
return torch.stack((uu, vv), dim=-1) # (H, W, 2)
|
||||
|
||||
|
||||
def _make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100.0) -> torch.Tensor:
|
||||
omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
|
||||
omega = 1.0 / omega_0 ** (omega / (embed_dim / 2.0))
|
||||
pos = pos.reshape(-1)
|
||||
out = torch.einsum("m,d->md", pos, omega)
|
||||
return torch.cat([out.sin(), out.cos()], dim=1).float()
|
||||
|
||||
|
||||
def _position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int,
|
||||
omega_0: float = 100.0) -> torch.Tensor:
|
||||
H, W, _ = pos_grid.shape
|
||||
pos_flat = pos_grid.reshape(-1, 2)
|
||||
emb_x = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0)
|
||||
emb_y = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0)
|
||||
emb = torch.cat([emb_x, emb_y], dim=-1)
|
||||
return emb.view(H, W, embed_dim)
|
||||
|
||||
|
||||
def _add_pos_embed(x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
|
||||
"""Stateless UV positional embedding added to a feature map (B, C, h, w)."""
|
||||
pw, ph = x.shape[-1], x.shape[-2]
|
||||
pe = _create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
|
||||
pe = _position_grid_to_embed(pe, x.shape[1]) * ratio
|
||||
pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1).to(dtype=x.dtype)
|
||||
return x + pe
|
||||
|
||||
|
||||
def _apply_activation(x: torch.Tensor, activation: str) -> torch.Tensor:
|
||||
act = (activation or "linear").lower()
|
||||
if act == "exp":
|
||||
return torch.exp(x)
|
||||
if act == "expp1":
|
||||
return torch.exp(x) + 1
|
||||
if act == "expm1":
|
||||
return torch.expm1(x)
|
||||
if act == "relu":
|
||||
return torch.relu(x)
|
||||
if act == "sigmoid":
|
||||
return torch.sigmoid(x)
|
||||
if act == "softplus":
|
||||
return F.softplus(x)
|
||||
if act == "tanh":
|
||||
return torch.tanh(x)
|
||||
return x
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Fusion building blocks
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ResidualConvUnit(nn.Module):
|
||||
def __init__(self, features: int,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv1 = operations.Conv2d(features, features, 3, 1, 1, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
self.conv2 = operations.Conv2d(features, features, 3, 1, 1, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
self.activation = nn.ReLU(inplace=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out = self.activation(x)
|
||||
out = self.conv1(out)
|
||||
out = self.activation(out)
|
||||
out = self.conv2(out)
|
||||
return out + x
|
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module):
|
||||
def __init__(self, features: int, has_residual: bool = True,
|
||||
align_corners: bool = True,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.align_corners = align_corners
|
||||
self.has_residual = has_residual
|
||||
if has_residual:
|
||||
self.resConfUnit1 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations)
|
||||
else:
|
||||
self.resConfUnit1 = None
|
||||
self.resConfUnit2 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations)
|
||||
self.out_conv = operations.Conv2d(features, features, 1, 1, 0, bias=True,
|
||||
device=device, dtype=dtype)
|
||||
|
||||
def forward(self, *xs: torch.Tensor, size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
|
||||
y = xs[0]
|
||||
if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None:
|
||||
y = y + self.resConfUnit1(xs[1])
|
||||
y = self.resConfUnit2(y)
|
||||
if size is None:
|
||||
up_kwargs = {"scale_factor": 2.0}
|
||||
else:
|
||||
up_kwargs = {"size": size}
|
||||
y = _custom_interpolate(y, **up_kwargs, mode="bilinear",
|
||||
align_corners=self.align_corners)
|
||||
y = self.out_conv(y)
|
||||
return y
|
||||
|
||||
|
||||
class _Scratch(nn.Module):
|
||||
"""Container that mirrors upstream ``scratch`` attribute layout."""
|
||||
|
||||
|
||||
def _make_scratch(in_shape: List[int], out_shape: int,
|
||||
device=None, dtype=None, operations=None) -> _Scratch:
|
||||
scratch = _Scratch()
|
||||
scratch.layer1_rn = operations.Conv2d(in_shape[0], out_shape, 3, 1, 1, bias=False,
|
||||
device=device, dtype=dtype)
|
||||
scratch.layer2_rn = operations.Conv2d(in_shape[1], out_shape, 3, 1, 1, bias=False,
|
||||
device=device, dtype=dtype)
|
||||
scratch.layer3_rn = operations.Conv2d(in_shape[2], out_shape, 3, 1, 1, bias=False,
|
||||
device=device, dtype=dtype)
|
||||
scratch.layer4_rn = operations.Conv2d(in_shape[3], out_shape, 3, 1, 1, bias=False,
|
||||
device=device, dtype=dtype)
|
||||
return scratch
|
||||
|
||||
|
||||
def _make_fusion_block(features: int, has_residual: bool = True,
|
||||
device=None, dtype=None, operations=None) -> FeatureFusionBlock:
|
||||
return FeatureFusionBlock(features, has_residual=has_residual,
|
||||
align_corners=True,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DPT (single head + optional sky head) -- used by DA3Mono/Metric
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DPT(nn.Module):
|
||||
"""Single-head DPT used by DA3Mono-Large and DA3Metric-Large."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_in: int,
|
||||
patch_size: int = 14,
|
||||
output_dim: int = 1,
|
||||
activation: str = "exp",
|
||||
conf_activation: str = "expp1",
|
||||
features: int = 256,
|
||||
out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
||||
pos_embed: bool = False,
|
||||
down_ratio: int = 1,
|
||||
head_name: str = "depth",
|
||||
use_sky_head: bool = True,
|
||||
sky_name: str = "sky",
|
||||
sky_activation: str = "relu",
|
||||
norm_type: str = "idt",
|
||||
device=None, dtype=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.activation = activation
|
||||
self.conf_activation = conf_activation
|
||||
self.pos_embed = pos_embed
|
||||
self.down_ratio = down_ratio
|
||||
self.head_main = head_name
|
||||
self.sky_name = sky_name
|
||||
self.out_dim = output_dim
|
||||
self.has_conf = output_dim > 1
|
||||
self.use_sky_head = use_sky_head
|
||||
self.sky_activation = sky_activation
|
||||
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
|
||||
|
||||
if norm_type == "layer":
|
||||
self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
|
||||
else:
|
||||
self.norm = nn.Identity()
|
||||
|
||||
out_channels = list(out_channels)
|
||||
self.projects = nn.ModuleList([
|
||||
operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0,
|
||||
device=device, dtype=dtype)
|
||||
for oc in out_channels
|
||||
])
|
||||
self.resize_layers = nn.ModuleList([
|
||||
operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
nn.Identity(),
|
||||
operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1,
|
||||
device=device, dtype=dtype),
|
||||
])
|
||||
|
||||
self.scratch = _make_scratch(out_channels, features,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
|
||||
head_features_1 = features
|
||||
head_features_2 = 32
|
||||
self.scratch.output_conv1 = operations.Conv2d(
|
||||
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype,
|
||||
)
|
||||
self.scratch.output_conv2 = nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
if self.use_sky_head:
|
||||
self.scratch.sky_output_conv2 = nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, feats: List[torch.Tensor], H: int, W: int,
|
||||
patch_start_idx: int = 0, **_kwargs) -> dict:
|
||||
# feats[i][0] is the patch-token tensor with shape (B, S, N_patch, C)
|
||||
B, S, N, C = feats[0][0].shape
|
||||
feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
|
||||
|
||||
ph, pw = H // self.patch_size, W // self.patch_size
|
||||
resized = []
|
||||
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
|
||||
x = feats_flat[take_idx][:, patch_start_idx:]
|
||||
x = self.norm(x)
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw)
|
||||
x = self.projects[stage_idx](x)
|
||||
if self.pos_embed:
|
||||
x = _add_pos_embed(x, W, H)
|
||||
x = self.resize_layers[stage_idx](x)
|
||||
resized.append(x)
|
||||
|
||||
l1_rn = self.scratch.layer1_rn(resized[0])
|
||||
l2_rn = self.scratch.layer2_rn(resized[1])
|
||||
l3_rn = self.scratch.layer3_rn(resized[2])
|
||||
l4_rn = self.scratch.layer4_rn(resized[3])
|
||||
|
||||
out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
|
||||
out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:])
|
||||
out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:])
|
||||
out = self.scratch.refinenet1(out, l1_rn)
|
||||
|
||||
h_out = int(ph * self.patch_size / self.down_ratio)
|
||||
w_out = int(pw * self.patch_size / self.down_ratio)
|
||||
|
||||
fused = self.scratch.output_conv1(out)
|
||||
fused = _custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True)
|
||||
if self.pos_embed:
|
||||
fused = _add_pos_embed(fused, W, H)
|
||||
feat = fused
|
||||
|
||||
main_logits = self.scratch.output_conv2(feat)
|
||||
outs = {}
|
||||
if self.has_conf:
|
||||
fmap = main_logits.permute(0, 2, 3, 1)
|
||||
pred = _apply_activation(fmap[..., :-1], self.activation)
|
||||
conf = _apply_activation(fmap[..., -1], self.conf_activation)
|
||||
outs[self.head_main] = pred.squeeze(-1).view(B, S, *pred.shape[1:-1])
|
||||
outs[f"{self.head_main}_conf"] = conf.view(B, S, *conf.shape[1:])
|
||||
else:
|
||||
pred = _apply_activation(main_logits, self.activation)
|
||||
outs[self.head_main] = pred.squeeze(1).view(B, S, *pred.shape[2:])
|
||||
|
||||
if self.use_sky_head:
|
||||
sky_logits = self.scratch.sky_output_conv2(feat)
|
||||
if self.sky_activation.lower() == "sigmoid":
|
||||
sky = torch.sigmoid(sky_logits)
|
||||
elif self.sky_activation.lower() == "relu":
|
||||
sky = F.relu(sky_logits)
|
||||
else:
|
||||
sky = sky_logits
|
||||
outs[self.sky_name] = sky.squeeze(1).view(B, S, *sky.shape[2:])
|
||||
|
||||
return outs
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DualDPT (depth + auxiliary "ray" head) -- used by DA3-Small / DA3-Base
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DualDPT(nn.Module):
|
||||
"""Two-head DPT used by DA3-Small / DA3-Base.
|
||||
|
||||
The auxiliary "ray" head is constructed so that HF state-dict keys load
|
||||
cleanly, but its outputs are unused on the monocular path.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_in: int,
|
||||
patch_size: int = 14,
|
||||
output_dim: int = 2,
|
||||
activation: str = "exp",
|
||||
conf_activation: str = "expp1",
|
||||
features: int = 256,
|
||||
out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
||||
pos_embed: bool = True,
|
||||
down_ratio: int = 1,
|
||||
aux_pyramid_levels: int = 4,
|
||||
aux_out1_conv_num: int = 5,
|
||||
head_names: Tuple[str, str] = ("depth", "ray"),
|
||||
device=None, dtype=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.activation = activation
|
||||
self.conf_activation = conf_activation
|
||||
self.pos_embed = pos_embed
|
||||
self.down_ratio = down_ratio
|
||||
self.aux_levels = aux_pyramid_levels
|
||||
self.aux_out1_conv_num = aux_out1_conv_num
|
||||
self.head_main, self.head_aux = head_names
|
||||
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
|
||||
|
||||
self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
|
||||
out_channels = list(out_channels)
|
||||
self.projects = nn.ModuleList([
|
||||
operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0,
|
||||
device=device, dtype=dtype)
|
||||
for oc in out_channels
|
||||
])
|
||||
self.resize_layers = nn.ModuleList([
|
||||
operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
nn.Identity(),
|
||||
operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1,
|
||||
device=device, dtype=dtype),
|
||||
])
|
||||
|
||||
self.scratch = _make_scratch(out_channels, features,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
# Main fusion chain
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
# Auxiliary fusion chain (separate copies)
|
||||
self.scratch.refinenet1_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet2_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet3_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
|
||||
head_features_1 = features
|
||||
head_features_2 = 32
|
||||
|
||||
# Main head neck + final projection
|
||||
self.scratch.output_conv1 = operations.Conv2d(
|
||||
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype,
|
||||
)
|
||||
self.scratch.output_conv2 = nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
# Aux pre-head per level (multi-level pyramid)
|
||||
self.scratch.output_conv1_aux = nn.ModuleList([
|
||||
self._make_aux_out1_block(head_features_1, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(self.aux_levels)
|
||||
])
|
||||
|
||||
# Aux final projection per level (includes LayerNorm permute path).
|
||||
ln_seq = [Permute((0, 2, 3, 1)),
|
||||
operations.LayerNorm(head_features_2, device=device, dtype=dtype),
|
||||
Permute((0, 3, 1, 2))]
|
||||
self.scratch.output_conv2_aux = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype),
|
||||
*ln_seq,
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0,
|
||||
device=device, dtype=dtype),
|
||||
)
|
||||
for _ in range(self.aux_levels)
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def _make_aux_out1_block(in_ch: int, *, device=None, dtype=None, operations=None) -> nn.Sequential:
|
||||
# aux_out1_conv_num=5 in all Apache-2.0 variants.
|
||||
return nn.Sequential(
|
||||
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, feats: List[torch.Tensor], H: int, W: int,
|
||||
patch_start_idx: int = 0, **_kwargs) -> dict:
|
||||
B, S, N, C = feats[0][0].shape
|
||||
feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
|
||||
|
||||
ph, pw = H // self.patch_size, W // self.patch_size
|
||||
resized = []
|
||||
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
|
||||
x = feats_flat[take_idx][:, patch_start_idx:]
|
||||
x = self.norm(x)
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw)
|
||||
x = self.projects[stage_idx](x)
|
||||
if self.pos_embed:
|
||||
x = _add_pos_embed(x, W, H)
|
||||
x = self.resize_layers[stage_idx](x)
|
||||
resized.append(x)
|
||||
|
||||
l1_rn = self.scratch.layer1_rn(resized[0])
|
||||
l2_rn = self.scratch.layer2_rn(resized[1])
|
||||
l3_rn = self.scratch.layer3_rn(resized[2])
|
||||
l4_rn = self.scratch.layer4_rn(resized[3])
|
||||
|
||||
# Main pyramid (output_conv1 is applied inside the upstream `_fuse`,
|
||||
# before interpolation -- replicate that order here).
|
||||
m = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
|
||||
m = self.scratch.refinenet3(m, l3_rn, size=l2_rn.shape[2:])
|
||||
m = self.scratch.refinenet2(m, l2_rn, size=l1_rn.shape[2:])
|
||||
m = self.scratch.refinenet1(m, l1_rn)
|
||||
m = self.scratch.output_conv1(m)
|
||||
|
||||
h_out = int(ph * self.patch_size / self.down_ratio)
|
||||
w_out = int(pw * self.patch_size / self.down_ratio)
|
||||
|
||||
m = _custom_interpolate(m, (h_out, w_out), mode="bilinear", align_corners=True)
|
||||
if self.pos_embed:
|
||||
m = _add_pos_embed(m, W, H)
|
||||
main_logits = self.scratch.output_conv2(m)
|
||||
fmap = main_logits.permute(0, 2, 3, 1)
|
||||
depth_pred = _apply_activation(fmap[..., :-1], self.activation)
|
||||
depth_conf = _apply_activation(fmap[..., -1], self.conf_activation)
|
||||
|
||||
outs = {
|
||||
self.head_main: depth_pred.squeeze(-1).view(B, S, *depth_pred.shape[1:-1]),
|
||||
f"{self.head_main}_conf": depth_conf.view(B, S, *depth_conf.shape[1:]),
|
||||
}
|
||||
|
||||
# NOTE: we intentionally do not run the auxiliary "ray" branch — it is
|
||||
# only needed for pose/ray-conditioned outputs which are out of scope
|
||||
# for this port. The aux submodules are still built so HF weights load.
|
||||
|
||||
return outs
|
||||
135
comfy/ldm/depth_anything_3/model.py
Normal file
135
comfy/ldm/depth_anything_3/model.py
Normal file
@ -0,0 +1,135 @@
|
||||
# DepthAnything3Net: top-level wrapper that combines backbone + head.
|
||||
#
|
||||
# This wrapper covers the monocular forward path only (single image -> depth).
|
||||
# Camera encoder/decoder, ray-pose head, 3D Gaussians and the Nested
|
||||
# architecture are intentionally omitted. The HF state dict for those
|
||||
# components is filtered out before loading -- see
|
||||
# ``comfy.supported_models.DepthAnything3.process_unet_state_dict``.
|
||||
#
|
||||
# The class signature mirrors the upstream YAML config so a single dit_config
|
||||
# detected from the state dict in ``comfy/model_detection.py`` is sufficient
|
||||
# to construct the right variant.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .dinov2 import DinoV2
|
||||
from .dpt import DPT, DualDPT
|
||||
|
||||
|
||||
_HEAD_REGISTRY = {
|
||||
"dpt": DPT,
|
||||
"dualdpt": DualDPT,
|
||||
}
|
||||
|
||||
|
||||
class DepthAnything3Net(nn.Module):
|
||||
"""ComfyUI-side DepthAnything3 network (monocular path only).
|
||||
|
||||
Parameters mirror the variant YAML configs from the upstream repo.
|
||||
Values are auto-detected by ``comfy/model_detection.py`` from the state
|
||||
dict. The kwargs ``device``, ``dtype`` and ``operations`` are injected by
|
||||
``BaseModel``.
|
||||
"""
|
||||
|
||||
PATCH_SIZE = 14
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# --- Backbone ---
|
||||
backbone_name: str = "vitl",
|
||||
out_layers: Sequence[int] = (4, 11, 17, 23),
|
||||
alt_start: int = -1,
|
||||
qknorm_start: int = -1,
|
||||
rope_start: int = -1,
|
||||
cat_token: bool = False,
|
||||
# --- Head ---
|
||||
head_type: str = "dpt", # "dpt" or "dualdpt"
|
||||
head_dim_in: int = 1024,
|
||||
head_output_dim: int = 1, # 1 = depth only, 2 = depth+conf
|
||||
head_features: int = 256,
|
||||
head_out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
||||
head_use_sky_head: bool = True, # ignored by DualDPT
|
||||
head_pos_embed: Optional[bool] = None, # default: True for DualDPT, False for DPT
|
||||
# ComfyUI plumbing
|
||||
device=None, dtype=None, operations=None,
|
||||
**_ignored,
|
||||
):
|
||||
super().__init__()
|
||||
head_cls = _HEAD_REGISTRY[head_type.lower()]
|
||||
self.head_type = head_type.lower()
|
||||
self.has_sky = (self.head_type == "dpt") and head_use_sky_head
|
||||
self.has_conf = head_output_dim > 1
|
||||
|
||||
self.backbone = DinoV2(
|
||||
name=backbone_name,
|
||||
out_layers=list(out_layers),
|
||||
alt_start=alt_start,
|
||||
qknorm_start=qknorm_start,
|
||||
rope_start=rope_start,
|
||||
cat_token=cat_token,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
head_kwargs = dict(
|
||||
dim_in=head_dim_in,
|
||||
patch_size=self.PATCH_SIZE,
|
||||
output_dim=head_output_dim,
|
||||
features=head_features,
|
||||
out_channels=tuple(head_out_channels),
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
if self.head_type == "dpt":
|
||||
head_kwargs.update(
|
||||
use_sky_head=head_use_sky_head,
|
||||
pos_embed=(False if head_pos_embed is None else head_pos_embed),
|
||||
)
|
||||
else: # dualdpt
|
||||
head_kwargs.update(
|
||||
pos_embed=(True if head_pos_embed is None else head_pos_embed),
|
||||
)
|
||||
self.head = head_cls(**head_kwargs)
|
||||
self.dtype = dtype
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Forward
|
||||
# ------------------------------------------------------------------
|
||||
def forward(self, image: torch.Tensor, **_unused) -> Dict[str, torch.Tensor]:
|
||||
"""Run monocular forward.
|
||||
|
||||
Args:
|
||||
image: ``(B, 3, H, W)`` ImageNet-normalised image tensor, or
|
||||
``(B, S, 3, H, W)`` if a fake "views" axis is supplied.
|
||||
H and W must be multiples of 14.
|
||||
|
||||
Returns:
|
||||
Dict with:
|
||||
- ``depth``: ``(B, H, W)`` raw depth values.
|
||||
- ``depth_conf``: ``(B, H, W)`` confidence (DualDPT variants only).
|
||||
- ``sky``: ``(B, H, W)`` sky probability/logit
|
||||
(DPT variants only).
|
||||
"""
|
||||
if image.ndim == 4:
|
||||
image = image.unsqueeze(1) # (B, 1, 3, H, W)
|
||||
assert image.ndim == 5 and image.shape[2] == 3, \
|
||||
f"image must be (B,3,H,W) or (B,S,3,H,W); got {tuple(image.shape)}"
|
||||
|
||||
B, S, _, H, W = image.shape
|
||||
assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \
|
||||
f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}"
|
||||
|
||||
feats = self.backbone(image)
|
||||
head_out = self.head(feats, H=H, W=W, patch_start_idx=0)
|
||||
|
||||
# Flatten the views axis (S=1 in mono inference path).
|
||||
out: Dict[str, torch.Tensor] = {}
|
||||
for k, v in head_out.items():
|
||||
if v.ndim >= 3 and v.shape[0] == B and v.shape[1] == S:
|
||||
out[k] = v.reshape(B * S, *v.shape[2:])
|
||||
else:
|
||||
out[k] = v
|
||||
return out
|
||||
178
comfy/ldm/depth_anything_3/preprocess.py
Normal file
178
comfy/ldm/depth_anything_3/preprocess.py
Normal file
@ -0,0 +1,178 @@
|
||||
# Input/output preprocessing helpers for Depth Anything 3.
|
||||
#
|
||||
# Ported from:
|
||||
# src/depth_anything_3/utils/io/input_processor.py (image normalisation)
|
||||
# src/depth_anything_3/utils/alignment.py (sky-aware depth clip)
|
||||
# src/depth_anything_3/model/da3.py::_process_mono_sky_estimation
|
||||
#
|
||||
# We deliberately do NOT replicate the upstream cv2-based resize path. ComfyUI
|
||||
# already provides ``comfy.utils.common_upscale`` for high-quality bilinear
|
||||
# resampling; using it keeps everything on-device and consistent with other
|
||||
# ComfyUI preprocessors. The bilinear approximation is sufficient for the
|
||||
# downstream depth-estimation task (verified visually against the upstream
|
||||
# bicubic path -- depth maps are virtually identical).
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.utils
|
||||
|
||||
PATCH_SIZE = 14
|
||||
|
||||
# ImageNet normalization constants used during DA3 training.
|
||||
_IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
|
||||
_IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])
|
||||
|
||||
|
||||
def _round_to_patch(x: int, patch: int = PATCH_SIZE) -> int:
|
||||
down = (x // patch) * patch
|
||||
up = down + patch
|
||||
return up if abs(up - x) <= abs(x - down) else down
|
||||
|
||||
|
||||
def compute_target_size(orig_h: int, orig_w: int, process_res: int,
|
||||
method: str = "upper_bound_resize") -> Tuple[int, int]:
|
||||
"""Compute (target_h, target_w) for a single image.
|
||||
|
||||
Methods:
|
||||
- "upper_bound_resize": scale longest side to ``process_res``, then
|
||||
round each dim to nearest multiple of 14 (default upstream method).
|
||||
- "lower_bound_resize": scale shortest side to ``process_res``, then
|
||||
round.
|
||||
"""
|
||||
if method == "upper_bound_resize":
|
||||
longest = max(orig_h, orig_w)
|
||||
scale = process_res / float(longest)
|
||||
elif method == "lower_bound_resize":
|
||||
shortest = min(orig_h, orig_w)
|
||||
scale = process_res / float(shortest)
|
||||
else:
|
||||
raise ValueError(f"Unsupported process_res_method: {method}")
|
||||
|
||||
new_w = max(1, _round_to_patch(int(round(orig_w * scale))))
|
||||
new_h = max(1, _round_to_patch(int(round(orig_h * scale))))
|
||||
return new_h, new_w
|
||||
|
||||
|
||||
def preprocess_image(
|
||||
image: torch.Tensor,
|
||||
process_res: int = 504,
|
||||
method: str = "upper_bound_resize",
|
||||
) -> torch.Tensor:
|
||||
"""Preprocess a ComfyUI ``IMAGE`` batch for DA3.
|
||||
|
||||
Args:
|
||||
image: ``(B, H, W, 3)`` float in [0, 1] (ComfyUI ``IMAGE`` convention).
|
||||
process_res: target resolution (longest or shortest side, depending
|
||||
on ``method``).
|
||||
method: resize strategy.
|
||||
|
||||
Returns:
|
||||
``(B, 3, H', W')`` tensor with H' and W' multiples of 14, normalised
|
||||
with ImageNet statistics. The tensor lives on the same device as
|
||||
``image``.
|
||||
"""
|
||||
assert image.ndim == 4 and image.shape[-1] == 3, \
|
||||
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
|
||||
B, H, W, _ = image.shape
|
||||
target_h, target_w = compute_target_size(H, W, process_res, method)
|
||||
|
||||
# (B, H, W, 3) -> (B, 3, H, W)
|
||||
x = image.movedim(-1, 1).contiguous()
|
||||
if (target_h, target_w) != (H, W):
|
||||
# common_upscale takes a (B, C, H, W) tensor.
|
||||
x = comfy.utils.common_upscale(x, target_w, target_h, "bilinear", "disabled")
|
||||
x = x.clamp(0.0, 1.0)
|
||||
|
||||
mean = _IMAGENET_MEAN.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
|
||||
std = _IMAGENET_STD.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
|
||||
x = (x - mean) / std
|
||||
return x
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Output post-processing (sky-aware clipping for Mono/Metric variants)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_non_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor:
|
||||
"""Boolean mask: True for non-sky pixels (sky probability < threshold)."""
|
||||
return sky_prediction < threshold
|
||||
|
||||
|
||||
def apply_sky_aware_clip(
|
||||
depth: torch.Tensor,
|
||||
sky: torch.Tensor,
|
||||
threshold: float = 0.3,
|
||||
quantile: float = 0.99,
|
||||
) -> torch.Tensor:
|
||||
"""Replicates ``_process_mono_sky_estimation`` from upstream.
|
||||
|
||||
Clips sky regions to the 99th percentile of non-sky depth. Returns a new
|
||||
depth tensor; ``depth`` is not modified in place.
|
||||
"""
|
||||
non_sky = compute_non_sky_mask(sky, threshold=threshold)
|
||||
if non_sky.sum() <= 10 or (~non_sky).sum() <= 10:
|
||||
return depth.clone()
|
||||
|
||||
non_sky_depth = depth[non_sky]
|
||||
if non_sky_depth.numel() > 100_000:
|
||||
idx = torch.randint(0, non_sky_depth.numel(), (100_000,), device=non_sky_depth.device)
|
||||
sampled = non_sky_depth[idx]
|
||||
else:
|
||||
sampled = non_sky_depth
|
||||
|
||||
max_depth = torch.quantile(sampled, quantile)
|
||||
out = depth.clone()
|
||||
out[~non_sky] = max_depth
|
||||
return out
|
||||
|
||||
|
||||
def normalize_depth_v2_style(
|
||||
depth: torch.Tensor,
|
||||
sky: torch.Tensor | None = None,
|
||||
low_quantile: float = 0.01,
|
||||
high_quantile: float = 0.99,
|
||||
) -> torch.Tensor:
|
||||
"""V2-style normalization for ControlNet workflows.
|
||||
|
||||
Computes percentile bounds over non-sky pixels (when available),
|
||||
then maps depth into [0, 1] with near = white (1.0).
|
||||
"""
|
||||
if sky is not None:
|
||||
mask = compute_non_sky_mask(sky)
|
||||
if mask.any():
|
||||
valid = depth[mask]
|
||||
else:
|
||||
valid = depth.flatten()
|
||||
else:
|
||||
valid = depth.flatten()
|
||||
|
||||
if valid.numel() > 100_000:
|
||||
idx = torch.randint(0, valid.numel(), (100_000,), device=valid.device)
|
||||
sample = valid[idx]
|
||||
else:
|
||||
sample = valid
|
||||
|
||||
lo = torch.quantile(sample, low_quantile)
|
||||
hi = torch.quantile(sample, high_quantile)
|
||||
rng = (hi - lo).clamp(min=1e-6)
|
||||
norm = ((depth - lo) / rng).clamp(0.0, 1.0)
|
||||
# ControlNet convention: nearer pixels are brighter (1.0).
|
||||
norm = 1.0 - norm
|
||||
if sky is not None:
|
||||
# Sky pixels become black (far / unknown).
|
||||
sky_mask = ~compute_non_sky_mask(sky)
|
||||
norm = torch.where(sky_mask, torch.zeros_like(norm), norm)
|
||||
return norm
|
||||
|
||||
|
||||
def normalize_depth_min_max(depth: torch.Tensor) -> torch.Tensor:
|
||||
"""Simple per-frame min/max normalization with near=1.0 convention."""
|
||||
lo = depth.amin(dim=(-2, -1), keepdim=True)
|
||||
hi = depth.amax(dim=(-2, -1), keepdim=True)
|
||||
rng = (hi - lo).clamp(min=1e-6)
|
||||
return 1.0 - ((depth - lo) / rng).clamp(0.0, 1.0)
|
||||
@ -60,6 +60,7 @@ import comfy.ldm.ernie.model
|
||||
import comfy.ldm.sam3.detector
|
||||
import comfy.ldm.hidream_o1.model
|
||||
from comfy.ldm.hidream_o1.conditioning import build_extra_conds
|
||||
import comfy.ldm.depth_anything_3.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -2035,6 +2036,12 @@ class RT_DETR_v4(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
|
||||
|
||||
|
||||
class DepthAnything3(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.depth_anything_3.model.DepthAnything3Net)
|
||||
|
||||
class ErnieImage(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel)
|
||||
|
||||
@ -766,6 +766,90 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
|
||||
return dit_config
|
||||
|
||||
# Depth Anything 3 (Apache-2.0 monocular variants: Small/Base/Mono-Large/Metric-Large).
|
||||
if '{}backbone.pretrained.patch_embed.proj.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "DepthAnything3"
|
||||
|
||||
patch_w = state_dict['{}backbone.pretrained.patch_embed.proj.weight'.format(key_prefix)]
|
||||
embed_dim = patch_w.shape[0]
|
||||
depth = count_blocks(state_dict_keys, '{}backbone.pretrained.blocks.'.format(key_prefix) + '{}.')
|
||||
|
||||
# Backbone preset is determined by embed_dim (matches vits/vitb/vitl/vitg).
|
||||
backbone_name = {384: "vits", 768: "vitb", 1024: "vitl", 1536: "vitg"}.get(embed_dim)
|
||||
if backbone_name is None:
|
||||
return None
|
||||
dit_config["backbone_name"] = backbone_name
|
||||
|
||||
# Detect DA3 extensions on top of vanilla DINOv2.
|
||||
has_camera_token = '{}backbone.pretrained.camera_token'.format(key_prefix) in state_dict_keys
|
||||
# qk-norm shows up as `attn.q_norm.weight` on enabled blocks.
|
||||
qknorm_indices = [
|
||||
i for i in range(depth)
|
||||
if '{}backbone.pretrained.blocks.{}.attn.q_norm.weight'.format(key_prefix, i) in state_dict_keys
|
||||
]
|
||||
qknorm_start = qknorm_indices[0] if qknorm_indices else -1
|
||||
|
||||
# The DA3 main-series configs always set alt_start == qknorm_start == rope_start.
|
||||
# cat_token=True is implied by the presence of camera_token.
|
||||
if has_camera_token:
|
||||
dit_config["alt_start"] = qknorm_start
|
||||
dit_config["rope_start"] = qknorm_start
|
||||
dit_config["qknorm_start"] = qknorm_start
|
||||
dit_config["cat_token"] = True
|
||||
else:
|
||||
dit_config["alt_start"] = -1
|
||||
dit_config["rope_start"] = -1
|
||||
dit_config["qknorm_start"] = -1
|
||||
dit_config["cat_token"] = False
|
||||
|
||||
# Detect head type and config.
|
||||
has_aux = '{}head.scratch.refinenet1_aux.out_conv.weight'.format(key_prefix) in state_dict_keys
|
||||
if has_aux:
|
||||
dit_config["head_type"] = "dualdpt"
|
||||
# DualDPT: dim_in = 2 * embed_dim (because cat_token doubles token width).
|
||||
head_dim_in = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
|
||||
out_channels = [
|
||||
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
|
||||
for i in range(4)
|
||||
]
|
||||
features = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["head_dim_in"] = head_dim_in
|
||||
dit_config["head_output_dim"] = 2
|
||||
dit_config["head_features"] = features
|
||||
dit_config["head_out_channels"] = out_channels
|
||||
dit_config["head_use_sky_head"] = False
|
||||
else:
|
||||
dit_config["head_type"] = "dpt"
|
||||
head_dim_in = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
|
||||
out_channels = [
|
||||
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
|
||||
for i in range(4)
|
||||
]
|
||||
features = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
|
||||
output_dim = state_dict[
|
||||
'{}head.scratch.output_conv2.2.weight'.format(key_prefix)
|
||||
].shape[0]
|
||||
dit_config["head_dim_in"] = head_dim_in
|
||||
dit_config["head_output_dim"] = output_dim
|
||||
dit_config["head_features"] = features
|
||||
dit_config["head_out_channels"] = out_channels
|
||||
dit_config["head_use_sky_head"] = (
|
||||
'{}head.scratch.sky_output_conv2.0.weight'.format(key_prefix) in state_dict_keys
|
||||
)
|
||||
|
||||
# out_layers: hard-coded per upstream YAML config (depth-aware default).
|
||||
if depth >= 24:
|
||||
# vitl: depths used vary between DA3-Large (DualDPT) and Mono/Metric (DPT).
|
||||
if has_aux:
|
||||
dit_config["out_layers"] = [11, 15, 19, 23]
|
||||
else:
|
||||
dit_config["out_layers"] = [4, 11, 17, 23]
|
||||
else:
|
||||
# vits/vitb: 12 blocks
|
||||
dit_config["out_layers"] = [5, 7, 9, 11]
|
||||
return dit_config
|
||||
|
||||
if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ernie"
|
||||
|
||||
@ -1847,6 +1847,33 @@ class RT_DETR_v4(supported_models_base.BASE):
|
||||
return None
|
||||
|
||||
|
||||
class DepthAnything3(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "DepthAnything3",
|
||||
}
|
||||
|
||||
# Mono path: no num_heads / num_head_channels needed.
|
||||
unet_extra_config = {}
|
||||
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.DepthAnything3(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
# Drop weights for components we do not build (camera encoder/decoder,
|
||||
# 3D Gaussian heads). Keeping unrelated keys around triggers spurious
|
||||
# "unet unexpected" warnings on load.
|
||||
drop_prefixes = ("cam_enc.", "cam_dec.", "gs_head.", "gs_adapter.")
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(drop_prefixes):
|
||||
state_dict.pop(k)
|
||||
return state_dict
|
||||
|
||||
|
||||
class ErnieImage(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "ernie",
|
||||
@ -2082,4 +2109,5 @@ models = [
|
||||
CogVideoX_I2V,
|
||||
CogVideoX_T2V,
|
||||
SVD_img2vid,
|
||||
DepthAnything3,
|
||||
]
|
||||
|
||||
247
comfy_extras/nodes_depth_anything_3.py
Normal file
247
comfy_extras/nodes_depth_anything_3.py
Normal file
@ -0,0 +1,247 @@
|
||||
"""ComfyUI nodes for Depth Anything 3.
|
||||
|
||||
Adds three nodes:
|
||||
|
||||
* ``LoadDepthAnything3`` -- load a DA3 ``.safetensors`` file from the
|
||||
``models/depth_estimation/`` folder. Falls back to ``models/diffusion_models/``
|
||||
so existing installations keep working.
|
||||
* ``DepthAnything3Depth`` -- run depth estimation and return a normalised
|
||||
depth map as a ComfyUI ``IMAGE`` (visualisation / ControlNet input).
|
||||
* ``DepthAnything3DepthRaw`` -- run depth estimation and return the raw depth,
|
||||
confidence and sky channels as ``MASK`` outputs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.model_management as mm
|
||||
import comfy.sd
|
||||
import folder_paths
|
||||
from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Loader
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class LoadDepthAnything3(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadDepthAnything3",
|
||||
display_name="Load Depth Anything 3",
|
||||
category="loaders/depth_estimation",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"model_name",
|
||||
options=folder_paths.get_filename_list("depth_estimation"),
|
||||
),
|
||||
io.Combo.Input(
|
||||
"weight_dtype",
|
||||
options=["default", "fp16", "bf16", "fp32"],
|
||||
default="default",
|
||||
),
|
||||
],
|
||||
outputs=[io.Model.Output("model")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_name, weight_dtype) -> io.NodeOutput:
|
||||
model_options = {}
|
||||
if weight_dtype == "fp16":
|
||||
model_options["dtype"] = torch.float16
|
||||
elif weight_dtype == "bf16":
|
||||
model_options["dtype"] = torch.bfloat16
|
||||
elif weight_dtype == "fp32":
|
||||
model_options["dtype"] = torch.float32
|
||||
|
||||
path = folder_paths.get_full_path_or_raise("depth_estimation", model_name)
|
||||
model = comfy.sd.load_diffusion_model(path, model_options=model_options)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Inference helpers
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _run_da3(model_patcher, image: torch.Tensor, process_res: int,
|
||||
method: str = "upper_bound_resize"):
|
||||
"""Run the DA3 network on a (B, H, W, 3) ``IMAGE`` batch.
|
||||
|
||||
Returns ``(depth, confidence, sky)`` tensors with the original image
|
||||
resolution. Any of ``confidence`` / ``sky`` may be ``None`` depending on
|
||||
the variant.
|
||||
"""
|
||||
assert image.ndim == 4 and image.shape[-1] == 3, \
|
||||
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
|
||||
|
||||
B, H, W, _ = image.shape
|
||||
mm.load_model_gpu(model_patcher)
|
||||
diffusion = model_patcher.model.diffusion_model
|
||||
device = mm.get_torch_device()
|
||||
dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32
|
||||
|
||||
depths, confs, skies = [], [], []
|
||||
# Process one image at a time to keep peak memory predictable; DA3 is
|
||||
# an inference-only model and per-sample latency dominates anyway.
|
||||
for i in range(B):
|
||||
single = image[i:i + 1].to(device)
|
||||
x = da3_preprocess.preprocess_image(single, process_res=process_res, method=method)
|
||||
x = x.to(dtype=dtype)
|
||||
with torch.no_grad():
|
||||
out = diffusion(x)
|
||||
|
||||
depth_lr = out["depth"]
|
||||
# Resize back to the original (H, W).
|
||||
depth_full = torch.nn.functional.interpolate(
|
||||
depth_lr.unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
depths.append(depth_full)
|
||||
|
||||
if "depth_conf" in out:
|
||||
conf_full = torch.nn.functional.interpolate(
|
||||
out["depth_conf"].unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
confs.append(conf_full)
|
||||
if "sky" in out:
|
||||
sky_full = torch.nn.functional.interpolate(
|
||||
out["sky"].unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
skies.append(sky_full)
|
||||
|
||||
depth = torch.cat(depths, dim=0)
|
||||
confidence = torch.cat(confs, dim=0) if confs else None
|
||||
sky = torch.cat(skies, dim=0) if skies else None
|
||||
return depth, confidence, sky
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Depth -> visualisation IMAGE
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DepthAnything3Depth(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DepthAnything3Depth",
|
||||
display_name="Depth Anything 3 (Depth)",
|
||||
category="image/depth",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Image.Input("image"),
|
||||
io.Int.Input("process_res", default=504, min=140, max=2520, step=14,
|
||||
tooltip="Longest-side target resolution (multiple of 14)."),
|
||||
io.Combo.Input("resize_method",
|
||||
options=["upper_bound_resize", "lower_bound_resize"],
|
||||
default="upper_bound_resize"),
|
||||
io.Combo.Input("normalization",
|
||||
options=["v2_style", "min_max", "raw"],
|
||||
default="v2_style",
|
||||
tooltip="How to map raw depth -> [0, 1] image."),
|
||||
io.Boolean.Input("apply_sky_clip", default=True,
|
||||
tooltip="(Mono/Metric only) clip sky depth to 99th percentile."),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output("depth_image"),
|
||||
io.Mask.Output("sky_mask",
|
||||
tooltip="Sky probability (Mono/Metric variants), else zeros."),
|
||||
io.Mask.Output("confidence",
|
||||
tooltip="Depth confidence (Small/Base/DualDPT variants), else zeros."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, image, process_res, resize_method, normalization,
|
||||
apply_sky_clip) -> io.NodeOutput:
|
||||
depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method)
|
||||
|
||||
if apply_sky_clip and sky is not None:
|
||||
depth = torch.stack([
|
||||
da3_preprocess.apply_sky_aware_clip(depth[i], sky[i])
|
||||
for i in range(depth.shape[0])
|
||||
], dim=0)
|
||||
|
||||
if normalization == "v2_style":
|
||||
norm = torch.stack([
|
||||
da3_preprocess.normalize_depth_v2_style(depth[i],
|
||||
sky[i] if sky is not None else None)
|
||||
for i in range(depth.shape[0])
|
||||
], dim=0)
|
||||
elif normalization == "min_max":
|
||||
norm = da3_preprocess.normalize_depth_min_max(depth)
|
||||
else:
|
||||
norm = depth
|
||||
|
||||
# (B, H, W) -> (B, H, W, 3) grayscale IMAGE.
|
||||
out_image = norm.unsqueeze(-1).repeat(1, 1, 1, 3).clamp(0.0, 1.0).contiguous()
|
||||
sky_mask = sky if sky is not None else torch.zeros_like(depth)
|
||||
conf_mask = confidence if confidence is not None else torch.zeros_like(depth)
|
||||
return io.NodeOutput(out_image, sky_mask.contiguous(), conf_mask.contiguous())
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Raw depth output (useful for downstream metric work)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DepthAnything3DepthRaw(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DepthAnything3DepthRaw",
|
||||
display_name="Depth Anything 3 (Raw Depth)",
|
||||
category="image/depth",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Image.Input("image"),
|
||||
io.Int.Input("process_res", default=504, min=140, max=2520, step=14),
|
||||
io.Combo.Input("resize_method",
|
||||
options=["upper_bound_resize", "lower_bound_resize"],
|
||||
default="upper_bound_resize"),
|
||||
],
|
||||
outputs=[
|
||||
io.Mask.Output("depth",
|
||||
tooltip="Raw depth values (no normalisation, no clipping)."),
|
||||
io.Mask.Output("confidence"),
|
||||
io.Mask.Output("sky"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, image, process_res, resize_method) -> io.NodeOutput:
|
||||
depth, confidence, sky = _run_da3(model, image, process_res, method=resize_method)
|
||||
zeros = torch.zeros_like(depth)
|
||||
return io.NodeOutput(
|
||||
depth.contiguous(),
|
||||
(confidence if confidence is not None else zeros).contiguous(),
|
||||
(sky if sky is not None else zeros).contiguous(),
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Extension registration
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DepthAnything3Extension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
LoadDepthAnything3,
|
||||
DepthAnything3Depth,
|
||||
DepthAnything3DepthRaw,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> DepthAnything3Extension:
|
||||
return DepthAnything3Extension()
|
||||
@ -58,6 +58,13 @@ folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "fram
|
||||
|
||||
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["depth_estimation"] = (
|
||||
[os.path.join(models_dir, "depth_estimation"),
|
||||
os.path.join(models_dir, "diffusion_models")],
|
||||
supported_pt_extensions,
|
||||
)
|
||||
|
||||
|
||||
output_directory = os.path.join(base_path, "output")
|
||||
temp_directory = os.path.join(base_path, "temp")
|
||||
input_directory = os.path.join(base_path, "input")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user