ComfyUI/comfy/ldm/depth_anything_3/dinov2.py
2026-05-13 10:59:29 +02:00

498 lines
21 KiB
Python

# DINOv2 backbone for Depth Anything 3 (monocular inference path).
#
# Why not reuse ``comfy/image_encoders/dino2.py``?
# The existing ``Dinov2Model`` is a vanilla HuggingFace-style DINOv2 with a
# different state-dict layout (separate Q/K/V, ``embeddings.*`` /
# ``encoder.layer.*`` keys, ``layer_scale1.lambda1``) and no support for the
# architectural extensions DA3 adds on top of DINOv2 (RoPE, QK-norm,
# alternating local/global attention, concatenated camera token). Loading
# raw DA3 HF safetensors into ``Dinov2Model`` would require splitting
# ``attn.qkv`` weights and a large rename map for every block, and we'd
# still need to write the DA3 extensions separately. Keeping the upstream
# ``pretrained.*`` key layout here means HF weights load directly with no
# conversion step.
#
# Ported from the upstream repo at:
# src/depth_anything_3/model/dinov2/{dinov2,vision_transformer}.py
# src/depth_anything_3/model/dinov2/layers/*
#
# DA3 extensions on top of vanilla DINOv2 (only used by Small/Base variants):
# - 2D Rotary Position Embedding starting at ``rope_start``
# - QK-norm starting at ``qknorm_start``
# - Alternating local/global attention blocks starting at ``alt_start``
# - Camera-conditioning token concatenated to features (``cat_token=True``),
# with a learned parameter ``camera_token`` injected at block
# ``alt_start`` when no external camera token is supplied.
#
# For the Mono/Metric variants the configuration disables all of the above
# (alt_start/qknorm_start/rope_start = -1, cat_token=False) so this module
# collapses to a vanilla DINOv2-ViT encoder.
from __future__ import annotations
import math
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
# -----------------------------------------------------------------------------
# 2D rotary position embedding
# -----------------------------------------------------------------------------
class PositionGetter:
    """Build integer ``(row, col)`` coordinates for every patch of a grid.

    The coordinate table for a grid size is computed once and memoised. The
    cache key includes the requested device, so a model that is moved between
    devices (e.g. offloaded and reloaded) never receives a table that lives on
    a stale device.
    """

    def __init__(self):
        # (height, width, str(device)) -> (height * width, 2) integer table
        self._cache: dict[tuple[int, int, str], torch.Tensor] = {}

    def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor:
        """Return row-major patch coordinates.

        Args:
            batch_size: number of copies stacked along the leading dimension.
            height: number of patch rows.
            width: number of patch columns.
            device: device on which the coordinates must live.

        Returns:
            Tensor of shape ``(batch_size, height * width, 2)`` whose last
            dimension is ``(row, col)``. The result is a fresh clone, so callers
            may modify it without corrupting the cache.
        """
        key = (height, width, str(device))
        table = self._cache.get(key)
        if table is None:
            rows = torch.arange(height, device=device)
            cols = torch.arange(width, device=device)
            table = torch.cartesian_prod(rows, cols)
            self._cache[key] = table
        return table.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
class RotaryPositionEmbedding2D(nn.Module):
    """Parameter-free 2D rotary position embedding (RoPE).

    The channel dimension of the input is split in two halves. The first half
    is rotated by the angle of ``positions[..., 0]`` and the second by the
    angle of ``positions[..., 1]``. These are the row and column coordinates
    when positions come from ``PositionGetter``.
    """

    def __init__(self, frequency: float = 100.0):
        super().__init__()
        self.base_frequency = frequency
        # (dim, seq_len, device, dtype) -> (cos table, sin table)
        self._freq_cache: dict = {}

    def _components(self, dim: int, seq_len: int, device, dtype):
        """Return memoised ``(cos, sin)`` tables, each shaped ``(seq_len, dim)``."""
        cache_key = (dim, seq_len, device, dtype)
        cached = self._freq_cache.get(cache_key)
        if cached is not None:
            return cached
        exponents = torch.arange(0, dim, 2, device=device).float() / dim
        inv_freq = 1.0 / (self.base_frequency ** exponents)
        steps = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
        angles = torch.outer(steps, inv_freq)
        # Angles are cast to the working dtype *before* cos/sin, mirroring
        # the upstream implementation.
        angles = angles.to(dtype).repeat(1, 2)
        tables = (angles.cos().to(dtype), angles.sin().to(dtype))
        self._freq_cache[cache_key] = tables
        return tables

    @staticmethod
    def _rotate(x: torch.Tensor) -> torch.Tensor:
        """Map ``[a, b]`` (two halves of the last dim) to ``[-b, a]``."""
        half = x.shape[-1] // 2
        first = x[..., :half]
        second = x[..., half:]
        return torch.cat((-second, first), dim=-1)

    def _apply_1d(self, tokens, positions, cos_c, sin_c):
        """Rotate ``tokens`` by the angles looked up for integer ``positions``."""
        # Insert a singleton head axis so the tables broadcast over heads.
        cos_t = F.embedding(positions, cos_c)[:, None, :, :]
        sin_t = F.embedding(positions, sin_c)[:, None, :, :]
        return tokens * cos_t + self._rotate(tokens) * sin_t

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Apply 2D RoPE.

        Args:
            tokens: per-head features; the indexing in ``_apply_1d`` implies
                ``(B, heads, N, head_dim)``.
            positions: integer coordinates of shape ``(B, N, 2)``.
        """
        half_dim = tokens.size(-1) // 2
        table_len = int(positions.max()) + 1
        cos_c, sin_c = self._components(half_dim, table_len, tokens.device, tokens.dtype)
        first_half, second_half = tokens.chunk(2, dim=-1)
        rotated_first = self._apply_1d(first_half, positions[..., 0], cos_c, sin_c)
        rotated_second = self._apply_1d(second_half, positions[..., 1], cos_c, sin_c)
        return torch.cat((rotated_first, rotated_second), dim=-1)
# -----------------------------------------------------------------------------
# Patch embed / MLP / SwiGLU / LayerScale
# -----------------------------------------------------------------------------
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    Implemented as a convolution whose kernel size equals its stride.
    """

    def __init__(self, patch_size=14, in_chans=3, embed_dim=384,
                 device=None, dtype=None, operations=None):
        super().__init__()
        self.patch_size = (patch_size,) * 2
        self.proj = operations.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            device=device,
            dtype=dtype,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(N, C, H, W)`` to ``(N, (H/p) * (W/p), embed_dim)``.

        ``H`` and ``W`` must be exact multiples of the patch size.
        """
        _, _, height, width = x.shape
        patch_h, patch_w = self.patch_size
        assert height % patch_h == 0 and width % patch_w == 0
        grid = self.proj(x)  # (N, embed_dim, H/p, W/p)
        return grid.flatten(2).transpose(1, 2)
class Mlp(nn.Module):
    """Two-layer feed-forward network: ``fc1 -> act -> fc2``.

    The output width equals ``in_features``; the hidden width defaults to it
    when ``hidden_features`` is falsy.
    """

    def __init__(self, in_features, hidden_features=None, bias=True,
                 act_layer=nn.GELU, device=None, dtype=None, operations=None):
        super().__init__()
        width = hidden_features if hidden_features else in_features
        self.fc1 = operations.Linear(in_features, width, bias=bias, device=device, dtype=dtype)
        self.act = act_layer()
        self.fc2 = operations.Linear(width, in_features, bias=bias, device=device, dtype=dtype)

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)
class SwiGLUFFNFused(nn.Module):
    """SwiGLU FFN matching upstream xformers.ops.SwiGLU layout (used for vitg).

    Computes ``w3(silu(a) * b)``, where ``a`` and ``b`` are the two halves of
    the fused ``w12`` projection.
    """

    def __init__(self, in_features, hidden_features=None, bias=True,
                 device=None, dtype=None, operations=None):
        super().__init__()
        nominal = hidden_features or in_features
        # SwiGLU keeps ~2/3 of the nominal MLP width, rounded up to a
        # multiple of 8.
        hidden = (int(nominal * 2 / 3) + 7) // 8 * 8
        # xformers SwiGLU stores the gate and value projections as a single
        # fused Linear (in -> 2*hidden); it is split in half at forward time.
        # Not needed for the Apache-2.0 variants, but kept so the upstream key
        # names (w12 / w3) load unchanged.
        self.w12 = operations.Linear(in_features, 2 * hidden, bias=bias, device=device, dtype=dtype)
        self.w3 = operations.Linear(hidden, in_features, bias=bias, device=device, dtype=dtype)

    def forward(self, x):
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
class LayerScale(nn.Module):
    """Learnable per-channel scale applied to a residual branch (``x * gamma``).

    ``gamma`` is moved to the input's device and dtype on the fly, because it
    is a raw ``nn.Parameter`` rather than a cast-managed op.
    """

    def __init__(self, dim, init_values: float = 1e-5, device=None, dtype=None):
        super().__init__()
        initial = torch.ones(dim, device=device, dtype=dtype)
        self.gamma = nn.Parameter(init_values * initial)

    def forward(self, x):
        # .to() is a no-op returning the same tensor when nothing changes.
        return x * self.gamma.to(device=x.device, dtype=x.dtype)
def comfy_cast(p: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Return ``p`` on ``ref``'s device and in ``ref``'s dtype.

    When ``p`` already matches both, it is returned as-is (same object, no copy).
    """
    matches = p.device == ref.device and p.dtype == ref.dtype
    if matches:
        return p
    return p.to(device=ref.device, dtype=ref.dtype)
# -----------------------------------------------------------------------------
# Attention + Block
# -----------------------------------------------------------------------------
class Attention(nn.Module):
    """Multi-head self-attention with optional per-head QK-norm and 2D RoPE.

    Queries, keys and values come from one fused ``qkv`` projection (the
    upstream checkpoint layout). ``rope`` is applied to q and k only when both
    the module and a ``pos`` tensor are supplied.
    """

    def __init__(self, dim, num_heads: int, qkv_bias: bool = True, proj_bias: bool = True,
                 qk_norm: bool = False, rope: Optional["RotaryPositionEmbedding2D"] = None,
                 device=None, dtype=None, operations=None):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        factory = dict(device=device, dtype=dtype)
        self.qkv = operations.Linear(dim, 3 * dim, bias=qkv_bias, **factory)
        if qk_norm:
            self.q_norm = operations.LayerNorm(self.head_dim, **factory)
            self.k_norm = operations.LayerNorm(self.head_dim, **factory)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        self.proj = operations.Linear(dim, dim, bias=proj_bias, **factory)
        self.rope = rope

    def forward(self, x: torch.Tensor, pos: Optional[torch.Tensor] = None,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend over the token axis of ``x`` (shape ``(B, N, C)``).

        ``attn_mask`` is repeated across heads before being handed to SDPA;
        presumably it is ``(B, N, N)`` — confirm against callers.
        """
        batch, tokens, channels = x.shape
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, heads, N, head_dim)
        q = self.q_norm(q)
        k = self.k_norm(k)
        if self.rope is not None and pos is not None:
            q = self.rope(q, pos)
            k = self.rope(k, pos)
        mask = None
        if attn_mask is not None:
            mask = attn_mask[:, None].repeat(1, self.num_heads, 1, 1)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        out = out.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj(out)
class Block(nn.Module):
    """Pre-norm transformer block.

    ``x = x + ls1(attn(norm1(x)))`` followed by ``x = x + ls2(mlp(norm2(x)))``.
    LayerScale is used when ``init_values`` is truthy; otherwise the residual
    branches are unscaled.

    ``norm_layer`` is accepted for signature compatibility but unused: both
    norms are ``operations.LayerNorm`` with ``eps=ln_eps``.
    """

    def __init__(self, dim, num_heads, mlp_ratio: float = 4.0,
                 qkv_bias: bool = True, proj_bias: bool = True, ffn_bias: bool = True,
                 init_values: Optional[float] = 1.0,
                 norm_layer=nn.LayerNorm,
                 ffn_layer=Mlp,
                 qk_norm: bool = False,
                 rope: Optional["RotaryPositionEmbedding2D"] = None,
                 ln_eps: float = 1e-6,
                 device=None, dtype=None, operations=None):
        super().__init__()
        factory = dict(device=device, dtype=dtype)

        def make_scale():
            # One LayerScale per residual branch, or a pass-through.
            if init_values:
                return LayerScale(dim, init_values=init_values, **factory)
            return nn.Identity()

        self.norm1 = operations.LayerNorm(dim, eps=ln_eps, **factory)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            qk_norm=qk_norm,
            rope=rope,
            operations=operations,
            **factory,
        )
        self.ls1 = make_scale()
        self.norm2 = operations.LayerNorm(dim, eps=ln_eps, **factory)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            bias=ffn_bias,
            operations=operations,
            **factory,
        )
        self.ls2 = make_scale()

    def forward(self, x, pos=None, attn_mask=None):
        attn_out = self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask)
        x = x + self.ls1(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.ls2(mlp_out)
# -----------------------------------------------------------------------------
# DINOv2 vision transformer
# -----------------------------------------------------------------------------
_BACKBONE_PRESETS = {
"vits": dict(embed_dim=384, depth=12, num_heads=6, ffn_layer="mlp"),
"vitb": dict(embed_dim=768, depth=12, num_heads=12, ffn_layer="mlp"),
"vitl": dict(embed_dim=1024, depth=24, num_heads=16, ffn_layer="mlp"),
"vitg": dict(embed_dim=1536, depth=40, num_heads=24, ffn_layer="swiglufused"),
}
class DinoVisionTransformer(nn.Module):
    """DINOv2 ViT encoder with the optional Depth Anything 3 extensions.

    Operates on image sequences shaped ``(B, S, 3, H, W)``; ``H`` and ``W``
    must be multiples of ``PATCH_SIZE``. Each image is patchified, prefixed
    with a cls token and given learned absolute position embeddings
    (resampled to the input grid). Optional extensions are each enabled from
    a block index onward (``-1`` disables them):

    * ``rope_start``   - 2D rotary embedding inside attention,
    * ``qknorm_start`` - LayerNorm on per-head queries / keys,
    * ``alt_start``    - from this block on, odd-indexed blocks attend
      globally across all ``S`` images, and at block ``alt_start`` the cls
      slot is overwritten with a camera token.

    Submodule and parameter names mirror the upstream checkpoint layout.
    """

    PATCH_SIZE = 14

    def __init__(self,
                 embed_dim: int,
                 depth: int,
                 num_heads: int,
                 ffn_layer: str = "mlp",
                 mlp_ratio: float = 4.0,
                 init_values: float = 1.0,
                 alt_start: int = -1,
                 qknorm_start: int = -1,
                 rope_start: int = -1,
                 rope_freq: float = 100.0,
                 cat_token: bool = True,
                 device=None, dtype=None, operations=None):
        super().__init__()
        norm_layer = nn.LayerNorm
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.alt_start = alt_start
        self.qknorm_start = qknorm_start
        self.rope_start = rope_start
        self.cat_token = cat_token
        self.patch_size = self.PATCH_SIZE
        self.num_register_tokens = 0
        self.patch_start_idx = 1  # number of special (non-patch) tokens before the patches
        self.num_tokens = 1
        self.patch_embed = PatchEmbed(
            patch_size=self.PATCH_SIZE, in_chans=3, embed_dim=embed_dim,
            device=device, dtype=dtype, operations=operations,
        )
        # Number of patch positions for the historical 518x518 reference grid.
        ref_grid = 518 // self.PATCH_SIZE
        num_patches = ref_grid * ref_grid
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, device=device, dtype=dtype))
        if alt_start != -1:
            # Slot 0: reference-view token, slot 1: shared token for the other views.
            self.camera_token = nn.Parameter(torch.zeros(1, 2, embed_dim, device=device, dtype=dtype))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim,
                                                  device=device, dtype=dtype))
        if rope_start != -1 and rope_freq > 0:
            self.rope = RotaryPositionEmbedding2D(frequency=rope_freq)
            self.position_getter = PositionGetter()
        else:
            self.rope = None
            self.position_getter = None
        if ffn_layer == "mlp":
            ffn = Mlp
        elif ffn_layer in ("swiglu", "swiglufused"):
            ffn = SwiGLUFFNFused
        else:
            raise NotImplementedError(f"Unsupported ffn_layer: {ffn_layer}")
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                proj_bias=True,
                ffn_bias=True,
                init_values=init_values,
                norm_layer=norm_layer,
                ffn_layer=ffn,
                qk_norm=(qknorm_start != -1 and i >= qknorm_start),
                rope=(self.rope if (rope_start != -1 and i >= rope_start) else None),
                device=device, dtype=dtype, operations=operations,
            )
            for i in range(depth)
        ])
        # DINOv2 builds every norm, including this final one, with eps=1e-6.
        # This also matches Block's ln_eps default.
        self.norm = operations.LayerNorm(embed_dim, eps=1e-6, device=device, dtype=dtype)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def interpolate_pos_encoding(self, x, w, h):
        """Return absolute position embeddings matching the token sequence ``x``.

        Args:
            x: tokens of shape ``(B*S, 1 + N, C)`` with the cls token first.
            w: input image width in pixels.
            h: input image height in pixels.

        Returns:
            ``(1, 1 + N, C)`` embeddings in ``x.dtype``. The stored embedding
            covers a square ``M x M`` patch grid. For other grids it is
            bicubically resampled to ``(h // p)`` rows by ``(w // p)`` columns,
            so its row-major flattening lines up with the patch order produced
            by ``PatchEmbed`` (rows follow image height).
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        # Go straight to float32; avoids a lossy detour through a
        # half-precision x dtype before resampling.
        pos_embed = self.pos_embed.to(device=x.device, dtype=torch.float32)
        if npatch == N and w == h:
            return pos_embed.to(previous_dtype)
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        rows = h // self.patch_size
        cols = w // self.patch_size
        M = int(math.sqrt(N))
        assert N == M * M
        # Historical 0.1 offset preserves bicubic resample compatibility with
        # the original DINOv2 release; see the upstream PR for context.
        scale_rows = float(rows + 0.1) / M
        scale_cols = float(cols + 0.1) / M
        patch_pos_embed = F.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            scale_factor=(scale_rows, scale_cols), mode="bicubic", antialias=False,
        )
        assert (rows, cols) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens(self, x: torch.Tensor) -> torch.Tensor:
        """Patchify ``(B, S, 3, H, W)`` images into tokens ``(B, S, 1 + N, C)``.

        The cls token is prepended to each image's patch tokens, and position
        embeddings are added.
        """
        B, S, _, H, W = x.shape
        x = rearrange(x, "b s c h w -> (b s) c h w")
        x = self.patch_embed(x)
        cls_token = comfy_cast(self.cls_token, x).expand(B, S, -1).reshape(B * S, 1, self.embed_dim)
        x = torch.cat((cls_token, x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w=W, h=H)
        return rearrange(x, "(b s) n c -> b s n c", b=B, s=S)

    def _prepare_rope(self, B, S, H, W, device):
        """Build RoPE coordinates, or ``(None, None)`` when RoPE is disabled.

        Returns ``(pos, pos_nodiff)``, each ``(B, S, patch_start_idx + N, 2)``:
        ``pos`` holds patch (row, col) + 1 with special tokens at 0 (used by
        frame-local blocks); ``pos_nodiff`` holds 1 for every patch and 0 for
        special tokens (used by global blocks).
        """
        if self.rope is None:
            return None, None
        pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=device)
        pos = rearrange(pos, "(b s) n c -> b s n c", b=B)
        pos_nodiff = torch.zeros_like(pos)
        if self.patch_start_idx > 0:
            pos = pos + 1
            pos_special = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype)
            pos_special = rearrange(pos_special, "(b s) n c -> b s n c", b=B)
            pos = torch.cat([pos_special, pos], dim=2)
            pos_nodiff = pos_nodiff + 1
            pos_nodiff = torch.cat([pos_special, pos_nodiff], dim=2)
        return pos, pos_nodiff

    def _attn(self, x, blk, attn_type, pos=None, attn_mask=None):
        """Run ``blk`` on ``x`` of shape ``(B, S, N, C)``.

        ``"local"`` attends within each image (the S axis is folded into the
        batch). Any other value (``"global"``) attends over all S*N tokens of
        a sequence jointly. The output keeps the ``(B, S, N, C)`` layout.
        """
        b, s, n = x.shape[:3]
        if attn_type == "local":
            x = rearrange(x, "b s n c -> (b s) n c")
            if pos is not None:
                pos = rearrange(pos, "b s n c -> (b s) n c")
        else:  # "global"
            x = rearrange(x, "b s n c -> b (s n) c")
            if pos is not None:
                pos = rearrange(pos, "b s n c -> b (s n) c")
        x = blk(x, pos=pos, attn_mask=attn_mask)
        if attn_type == "local":
            x = rearrange(x, "(b s) n c -> b s n c", b=b, s=s)
        else:
            x = rearrange(x, "b (s n) c -> b s n c", b=b, s=s)
        return x

    # ------------------------------------------------------------------
    # Public forward
    # ------------------------------------------------------------------
    def get_intermediate_layers(self, x: torch.Tensor, out_layers: List[int],
                                cam_token: Optional[torch.Tensor] = None):
        """Encode images and collect features from selected blocks.

        Args:
            x: images of shape ``(B, S, 3, H, W)``.
            out_layers: indices of blocks whose outputs are returned.
                Indices outside ``range(depth)`` are ignored.
            cam_token: optional replacement for the learned camera token
                written into token slot 0 at block ``alt_start``. Presumably
                ``(B, S, embed_dim)`` — it must be assignable to
                ``x[:, :, 0]``.

        Returns:
            A list with one ``(patch_tokens, camera_token)`` tuple per selected
            block, in increasing block order. ``patch_tokens`` is
            ``(B, S, N, C)``, or ``(B, S, N, 2C)`` when ``cat_token`` (local
            features concatenated with the final-normed features).
            ``camera_token`` is the un-normed slot-0 token, ``(B, S, C|2C)``.

        Raises:
            ValueError: if a collected feature has an unexpected width.
        """
        B, S, _, H, W = x.shape
        x = self.prepare_tokens(x)
        pos, pos_nodiff = self._prepare_rope(B, S, H, W, x.device)
        outputs = []
        local_x = x
        for i, blk in enumerate(self.blocks):
            if self.rope is not None and i >= self.rope_start:
                g_pos, l_pos = pos_nodiff, pos
            else:
                g_pos, l_pos = None, None
            if self.alt_start != -1 and i == self.alt_start:
                # Inject camera token at the alt-start boundary.
                if cam_token is not None:
                    inj = cam_token
                else:
                    ct = comfy_cast(self.camera_token, x)
                    ref_token = ct[:, :1].expand(B, -1, -1)
                    src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1)
                    inj = torch.cat([ref_token, src_token], dim=1)
                # Clone so the in-place write cannot mutate a tensor that
                # local_x or an earlier output still references.
                x = x.clone()
                x[:, :, 0] = inj
            if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1):
                x = self._attn(x, blk, "global", pos=g_pos)
            else:
                x = self._attn(x, blk, "local", pos=l_pos)
                # Only frame-local blocks refresh local_x. With cat_token, a
                # global block's output is paired with the preceding local
                # features rather than with itself.
                local_x = x
            if i in out_layers:
                out_x = torch.cat([local_x, x], dim=-1) if self.cat_token else x
                outputs.append(out_x)
        # Apply final norm. Upstream norms only the "global" half when cat_token.
        normed: List[torch.Tensor] = []
        camera_tokens: List[torch.Tensor] = []
        for out_x in outputs:
            # Camera/cls token slot is index 0 *before* register-token stripping.
            camera_tokens.append(out_x[:, :, 0])
            if out_x.shape[-1] == self.embed_dim:
                normed.append(self.norm(out_x))
            elif out_x.shape[-1] == self.embed_dim * 2:
                left = out_x[..., : self.embed_dim]
                right = self.norm(out_x[..., self.embed_dim :])
                normed.append(torch.cat([left, right], dim=-1))
            else:
                raise ValueError(f"Unexpected token width: {out_x.shape[-1]}")
        # Drop cls/cam token + register tokens from patch sequence.
        normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed]
        # Match upstream signature consumed by DA3 heads:
        # feats[i][0] = normed patch tokens, feats[i][1] = camera/cls token.
        return list(zip(normed, camera_tokens))
class DinoV2(nn.Module):
    """Top-level DINOv2 wrapper matching upstream key layout (``self.pretrained``).

    Args:
        name: backbone size; must be a key of ``_BACKBONE_PRESETS``.
        out_layers: block indices whose features ``forward`` returns;
            defaults to ``[5, 7, 9, 11]``.
        alt_start, qknorm_start, rope_start, cat_token: forwarded to
            ``DinoVisionTransformer`` (``-1`` disables an extension).
        device, dtype, operations: forwarded to submodule constructors.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``name`` is not a known preset.
    """

    def __init__(self, name: str = "vits",
                 out_layers: Optional[List[int]] = None,
                 alt_start: int = -1,
                 qknorm_start: int = -1,
                 rope_start: int = -1,
                 cat_token: bool = True,
                 device=None, dtype=None, operations=None, **kwargs):
        super().__init__()
        preset = _BACKBONE_PRESETS.get(name)
        if preset is None:
            raise ValueError(f"Unknown DINOv2 backbone variant: {name!r}")
        self.name = name
        self.out_layers = [5, 7, 9, 11] if out_layers is None else list(out_layers)
        self.cat_token = cat_token
        self.pretrained = DinoVisionTransformer(
            embed_dim=preset["embed_dim"],
            depth=preset["depth"],
            num_heads=preset["num_heads"],
            ffn_layer=preset["ffn_layer"],
            alt_start=alt_start,
            qknorm_start=qknorm_start,
            rope_start=rope_start,
            cat_token=cat_token,
            device=device,
            dtype=dtype,
            operations=operations,
        )

    def forward(self, x, cam_token=None, **_unused):
        """Return ``(patch_tokens, camera_token)`` pairs for ``self.out_layers``.

        Extra keyword arguments are ignored.
        """
        encoder = self.pretrained
        return encoder.get_intermediate_layers(x, self.out_layers, cam_token=cam_token)