mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-14 19:17:32 +08:00
Refactor custom da dinov2 to image_encoders/dino2
This commit is contained in:
parent
b296c6a1aa
commit
4ad749ab17
@ -1,4 +1,8 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.text_encoders.bert import BertAttention
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
@ -14,13 +18,42 @@ class Dino2AttentionOutput(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2AttentionBlock(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
|
||||
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations,
|
||||
qk_norm=False):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.head_dim = embed_dim // heads
|
||||
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
|
||||
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
|
||||
if qk_norm:
|
||||
self.q_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device)
|
||||
self.k_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device)
|
||||
else:
|
||||
self.q_norm = None
|
||||
self.k_norm = None
|
||||
|
||||
def forward(self, x, mask, optimized_attention):
|
||||
return self.output(self.attention(x, mask, optimized_attention))
|
||||
def forward(self, x, mask, optimized_attention, pos=None, rope=None):
|
||||
# Fast path used by the existing CLIP-vision DINOv2 (no DA3 extensions).
|
||||
if self.q_norm is None and rope is None:
|
||||
return self.output(self.attention(x, mask, optimized_attention))
|
||||
|
||||
# DA3 path: do QKV manually so we can apply per-head QK-norm and 2D RoPE.
|
||||
attn = self.attention
|
||||
B, N, C = x.shape
|
||||
h = self.heads
|
||||
d = self.head_dim
|
||||
q = attn.query(x).view(B, N, h, d).transpose(1, 2)
|
||||
k = attn.key(x).view(B, N, h, d).transpose(1, 2)
|
||||
v = attn.value(x).view(B, N, h, d).transpose(1, 2)
|
||||
if self.q_norm is not None:
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
if rope is not None and pos is not None:
|
||||
q = rope(q, pos)
|
||||
k = rope(k, pos)
|
||||
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
|
||||
out = out.transpose(1, 2).reshape(B, N, C)
|
||||
return self.output(out)
|
||||
|
||||
|
||||
class LayerScale(torch.nn.Module):
|
||||
@ -64,9 +97,11 @@ class SwiGLUFFN(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Block(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn,
|
||||
qk_norm=False):
|
||||
super().__init__()
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations,
|
||||
qk_norm=qk_norm)
|
||||
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
|
||||
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
|
||||
if use_swiglu_ffn:
|
||||
@ -76,19 +111,93 @@ class Dino2Block(torch.nn.Module):
|
||||
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, optimized_attention):
|
||||
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
|
||||
def forward(self, x, optimized_attention, pos=None, rope=None, attn_mask=None):
|
||||
x = x + self.layer_scale1(self.attention(self.norm1(x), attn_mask, optimized_attention,
|
||||
pos=pos, rope=rope))
|
||||
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
|
||||
# -----------------------------------------------------------------------------
|
||||
# 2D Rotary position embedding (DA3 extension)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _PositionGetter:
|
||||
"""Cache (h, w) -> flat (y, x) position grid used to feed ``rope``."""
|
||||
|
||||
def __init__(self):
|
||||
self._cache: dict = {}
|
||||
|
||||
def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor:
|
||||
key = (height, width, device)
|
||||
if key not in self._cache:
|
||||
y = torch.arange(height, device=device)
|
||||
x = torch.arange(width, device=device)
|
||||
self._cache[key] = torch.cartesian_prod(y, x)
|
||||
cached = self._cache[key]
|
||||
return cached.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
|
||||
|
||||
|
||||
class RotaryPositionEmbedding2D(torch.nn.Module):
|
||||
"""2D RoPE used by DA3-Small/Base. No learnable parameters."""
|
||||
|
||||
def __init__(self, frequency: float = 100.0):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
for _ in range(num_layers)])
|
||||
self.base_frequency = frequency
|
||||
self._freq_cache: dict = {}
|
||||
|
||||
def _components(self, dim: int, seq_len: int, device, dtype):
|
||||
key = (dim, seq_len, device, dtype)
|
||||
if key not in self._freq_cache:
|
||||
exp = torch.arange(0, dim, 2, device=device).float() / dim
|
||||
inv_freq = 1.0 / (self.base_frequency ** exp)
|
||||
pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
||||
ang = torch.einsum("i,j->ij", pos, inv_freq)
|
||||
ang = ang.to(dtype)
|
||||
ang = torch.cat((ang, ang), dim=-1)
|
||||
self._freq_cache[key] = (ang.cos().to(dtype), ang.sin().to(dtype))
|
||||
return self._freq_cache[key]
|
||||
|
||||
@staticmethod
|
||||
def _rotate(x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1]
|
||||
x1, x2 = x[..., : d // 2], x[..., d // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def _apply_1d(self, tokens, positions, cos_c, sin_c):
|
||||
cos = F.embedding(positions, cos_c)[:, None, :, :]
|
||||
sin = F.embedding(positions, sin_c)[:, None, :, :]
|
||||
return (tokens * cos) + (self._rotate(tokens) * sin)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
|
||||
feature_dim = tokens.size(-1) // 2
|
||||
max_pos = int(positions.max()) + 1
|
||||
cos_c, sin_c = self._components(feature_dim, max_pos, tokens.device, tokens.dtype)
|
||||
v, h = tokens.chunk(2, dim=-1)
|
||||
v = self._apply_1d(v, positions[..., 0], cos_c, sin_c)
|
||||
h = self._apply_1d(h, positions[..., 1], cos_c, sin_c)
|
||||
return torch.cat((v, h), dim=-1)
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn,
|
||||
qknorm_start: int = -1, rope: "RotaryPositionEmbedding2D | None" = None,
|
||||
rope_start: int = -1):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([
|
||||
Dino2Block(
|
||||
dim, num_heads, layer_norm_eps, dtype, device, operations,
|
||||
use_swiglu_ffn=use_swiglu_ffn,
|
||||
qk_norm=(qknorm_start != -1 and i >= qknorm_start),
|
||||
)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
self.rope = rope
|
||||
self.rope_start = rope_start
|
||||
|
||||
def forward(self, x, intermediate_output=None):
|
||||
# Backward-compat path used by ``ClipVisionModel`` (no DA3 extensions).
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
|
||||
if intermediate_output is not None:
|
||||
@ -121,25 +230,79 @@ class Dino2PatchEmbeddings(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Embeddings(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
def __init__(self, dim, dtype, device, operations,
|
||||
patch_size: int = 14, image_size: int = 518,
|
||||
use_mask_token: bool = True,
|
||||
num_camera_tokens: int = 0):
|
||||
super().__init__()
|
||||
patch_size = 14
|
||||
image_size = 518
|
||||
self.patch_size = patch_size
|
||||
self.image_size = image_size
|
||||
|
||||
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
|
||||
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
|
||||
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
|
||||
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
||||
if use_mask_token:
|
||||
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
||||
else:
|
||||
self.mask_token = None
|
||||
if num_camera_tokens > 0:
|
||||
# DA3 stores (ref_token, src_token) pairs that get injected at the
|
||||
# alt-attn boundary; see ``Dinov2Model._inject_camera_token``.
|
||||
self.camera_token = torch.nn.Parameter(torch.empty(1, num_camera_tokens, dim, dtype=dtype, device=device))
|
||||
else:
|
||||
self.camera_token = None
|
||||
|
||||
def _interpolate_pos_encoding(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
|
||||
previous_dtype = x.dtype
|
||||
npatch = x.shape[1] - 1
|
||||
N = self.position_embeddings.shape[1] - 1
|
||||
pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype).float()
|
||||
if npatch == N and w == h:
|
||||
return pos_embed
|
||||
class_pos_embed = pos_embed[:, 0]
|
||||
patch_pos_embed = pos_embed[:, 1:]
|
||||
dim = x.shape[-1]
|
||||
w0 = w // self.patch_size
|
||||
h0 = h // self.patch_size
|
||||
M = int(math.sqrt(N))
|
||||
assert N == M * M
|
||||
# Historical 0.1 offset preserves bicubic resample compatibility with
|
||||
# the original DINOv2 release; see the upstream PR for context.
|
||||
sx = float(w0 + 0.1) / M
|
||||
sy = float(h0 + 0.1) / M
|
||||
patch_pos_embed = F.interpolate(
|
||||
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
||||
scale_factor=(sx, sy), mode="bicubic", antialias=False,
|
||||
)
|
||||
assert (w0, h0) == patch_pos_embed.shape[-2:]
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
_, _, H, W = pixel_values.shape
|
||||
x = self.patch_embeddings(pixel_values)
|
||||
# TODO: mask_token?
|
||||
x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
|
||||
x = x + self._interpolate_pos_encoding(x, H, W)
|
||||
return x
|
||||
|
||||
|
||||
class Dinov2Model(torch.nn.Module):
|
||||
"""DINOv2 vision backbone.
|
||||
|
||||
Supports two operating modes:
|
||||
|
||||
* **CLIP-vision DINOv2** (default): vanilla DINOv2-ViT used for
|
||||
``ClipVisionModel`` and SigLIP-style image encoding.
|
||||
* **Depth Anything 3** extensions (opt-in via config keys): 2D RoPE,
|
||||
QK-norm, alternating local/global attention, camera-token injection,
|
||||
``cat_token`` output and multi-layer feature extraction. These are
|
||||
enabled when the corresponding fields (``alt_start``, ``qknorm_start``,
|
||||
``rope_start``, ``cat_token``) are set in ``config_dict``. When all of
|
||||
them are at their disabled defaults this module behaves identically to
|
||||
the historical ``Dinov2Model``.
|
||||
"""
|
||||
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
num_layers = config_dict["num_hidden_layers"]
|
||||
@ -147,14 +310,171 @@ class Dinov2Model(torch.nn.Module):
|
||||
heads = config_dict["num_attention_heads"]
|
||||
layer_norm_eps = config_dict["layer_norm_eps"]
|
||||
use_swiglu_ffn = config_dict["use_swiglu_ffn"]
|
||||
patch_size = config_dict.get("patch_size", 14)
|
||||
image_size = config_dict.get("image_size", 518)
|
||||
use_mask_token = config_dict.get("use_mask_token", True)
|
||||
|
||||
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
# DA3 extensions (all default to disabled).
|
||||
self.alt_start = config_dict.get("alt_start", -1)
|
||||
self.qknorm_start = config_dict.get("qknorm_start", -1)
|
||||
self.rope_start = config_dict.get("rope_start", -1)
|
||||
self.cat_token = config_dict.get("cat_token", False)
|
||||
rope_freq = config_dict.get("rope_freq", 100.0)
|
||||
|
||||
self.embed_dim = dim
|
||||
self.patch_size = patch_size
|
||||
self.num_register_tokens = 0
|
||||
self.patch_start_idx = 1
|
||||
|
||||
if self.rope_start != -1 and rope_freq > 0:
|
||||
self.rope = RotaryPositionEmbedding2D(frequency=rope_freq)
|
||||
self._position_getter = _PositionGetter()
|
||||
else:
|
||||
self.rope = None
|
||||
self._position_getter = None
|
||||
|
||||
# camera_token shape: (1, 2, dim) -> (ref_token, src_token).
|
||||
num_cam_tokens = 2 if self.alt_start != -1 else 0
|
||||
|
||||
self.embeddings = Dino2Embeddings(
|
||||
dim, dtype, device, operations,
|
||||
patch_size=patch_size, image_size=image_size,
|
||||
use_mask_token=use_mask_token, num_camera_tokens=num_cam_tokens,
|
||||
)
|
||||
self.encoder = Dino2Encoder(
|
||||
dim, heads, layer_norm_eps, num_layers, dtype, device, operations,
|
||||
use_swiglu_ffn=use_swiglu_ffn,
|
||||
qknorm_start=self.qknorm_start,
|
||||
rope=self.rope, rope_start=self.rope_start,
|
||||
)
|
||||
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# CLIP-vision-style forward (no DA3 extensions, no multi-layer output).
|
||||
# Kept for backward compatibility with ``ClipVisionModel.encode_image``.
|
||||
# ------------------------------------------------------------------
|
||||
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
||||
x = self.embeddings(pixel_values)
|
||||
x, i = self.encoder(x, intermediate_output=intermediate_output)
|
||||
x = self.layernorm(x)
|
||||
pooled_output = x[:, 0, :]
|
||||
return x, i, pooled_output, None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Depth Anything 3 forward
|
||||
# ------------------------------------------------------------------
|
||||
def _prepare_rope_positions(self, B, S, H, W, device):
|
||||
if self.rope is None:
|
||||
return None, None
|
||||
ph, pw = H // self.patch_size, W // self.patch_size
|
||||
pos = self._position_getter(B * S, ph, pw, device=device)
|
||||
# Shift so the cls/cam token at position 0 is reserved for "no diff".
|
||||
pos = pos + 1
|
||||
cls_pos = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype)
|
||||
# Per-view local: real grid positions for patches, 0 for cls token.
|
||||
pos_local = torch.cat([cls_pos, pos], dim=1)
|
||||
# Global (across views): same grid positions; cls token still at 0,
|
||||
# but patches share the same positions in every view.
|
||||
pos_global = torch.cat([cls_pos, torch.zeros_like(pos) + 1], dim=1)
|
||||
return pos_local, pos_global
|
||||
|
||||
def _inject_camera_token(self, x: torch.Tensor, B: int, S: int,
|
||||
cam_token: "torch.Tensor | None") -> torch.Tensor:
|
||||
# x: (B, S, N, C). Replace token at index 0 with the camera token.
|
||||
if cam_token is not None:
|
||||
inj = cam_token
|
||||
else:
|
||||
ct = comfy.model_management.cast_to_device(self.embeddings.camera_token, x.device, x.dtype)
|
||||
ref_token = ct[:, :1].expand(B, -1, -1)
|
||||
src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1)
|
||||
inj = torch.cat([ref_token, src_token], dim=1)
|
||||
x = x.clone()
|
||||
x[:, :, 0] = inj
|
||||
return x
|
||||
|
||||
def get_intermediate_layers(self, pixel_values, out_layers, cam_token=None):
|
||||
"""Multi-layer DINOv2 feature extraction used by Depth Anything 3.
|
||||
|
||||
Args:
|
||||
pixel_values: ``(B, S, 3, H, W)`` views or ``(B, 3, H, W)``.
|
||||
out_layers: indices into ``self.encoder.layer``.
|
||||
cam_token: optional ``(B, S, dim)`` camera token to inject at
|
||||
``alt_start``. If ``None`` and the model has its own
|
||||
``camera_token`` parameter, that is used.
|
||||
|
||||
Returns:
|
||||
List of ``(patch_tokens, cls_or_cam_token)`` tuples, one per
|
||||
requested ``out_layers`` entry. ``patch_tokens`` has shape
|
||||
``(B, S, N_patch, C)`` (or ``(B, S, N_patch, 2*C)`` when the
|
||||
model was configured with ``cat_token=True``); the second item
|
||||
has shape ``(B, S, C)``.
|
||||
"""
|
||||
if pixel_values.ndim == 4:
|
||||
pixel_values = pixel_values.unsqueeze(1)
|
||||
assert pixel_values.ndim == 5 and pixel_values.shape[2] == 3, \
|
||||
f"expected (B,3,H,W) or (B,S,3,H,W); got {tuple(pixel_values.shape)}"
|
||||
B, S, _, H, W = pixel_values.shape
|
||||
|
||||
# Patch + cls + (interpolated) pos embed for each view.
|
||||
x = pixel_values.reshape(B * S, 3, H, W)
|
||||
x = self.embeddings(x) # (B*S, 1+N, C)
|
||||
x = x.reshape(B, S, x.shape[-2], x.shape[-1]) # (B, S, 1+N, C)
|
||||
|
||||
pos_local, pos_global = self._prepare_rope_positions(B, S, H, W, x.device)
|
||||
# ``optimized_attention`` is only used by blocks without QK-norm/RoPE
|
||||
# (vanilla DINOv2 path); enabling-aware blocks fall through to SDPA.
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
|
||||
out_set = set(out_layers)
|
||||
outputs: list[torch.Tensor] = []
|
||||
local_x = x
|
||||
|
||||
for i, blk in enumerate(self.encoder.layer):
|
||||
apply_rope = self.rope is not None and i >= self.rope_start
|
||||
block_rope = self.rope if apply_rope else None
|
||||
l_pos = pos_local if apply_rope else None
|
||||
g_pos = pos_global if apply_rope else None
|
||||
|
||||
if self.alt_start != -1 and i == self.alt_start:
|
||||
x = self._inject_camera_token(x, B, S, cam_token)
|
||||
|
||||
if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1):
|
||||
# Global attention across views: flatten S into the seq dim.
|
||||
t = x.reshape(B, S * x.shape[-2], x.shape[-1])
|
||||
p = g_pos.reshape(B, S * g_pos.shape[-2], g_pos.shape[-1]) if g_pos is not None else None
|
||||
t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
|
||||
x = t.reshape(B, S, x.shape[-2], x.shape[-1])
|
||||
else:
|
||||
# Per-view local attention.
|
||||
t = x.reshape(B * S, x.shape[-2], x.shape[-1])
|
||||
p = l_pos.reshape(B * S, l_pos.shape[-2], l_pos.shape[-1]) if l_pos is not None else None
|
||||
t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
|
||||
x = t.reshape(B, S, x.shape[-2], x.shape[-1])
|
||||
local_x = x
|
||||
|
||||
if i in out_set:
|
||||
if self.cat_token:
|
||||
out_x = torch.cat([local_x, x], dim=-1)
|
||||
else:
|
||||
out_x = x
|
||||
outputs.append(out_x)
|
||||
|
||||
# Apply final norm. When ``cat_token`` is set, only the right half
|
||||
# ("global" features) is normalised; the left half is left as-is to
|
||||
# match the upstream DA3 head signature.
|
||||
normed: list[torch.Tensor] = []
|
||||
cls_tokens: list[torch.Tensor] = []
|
||||
for out_x in outputs:
|
||||
cls_tokens.append(out_x[:, :, 0])
|
||||
if out_x.shape[-1] == self.embed_dim:
|
||||
normed.append(self.layernorm(out_x))
|
||||
elif out_x.shape[-1] == self.embed_dim * 2:
|
||||
left = out_x[..., :self.embed_dim]
|
||||
right = self.layernorm(out_x[..., self.embed_dim:])
|
||||
normed.append(torch.cat([left, right], dim=-1))
|
||||
else:
|
||||
raise ValueError(f"Unexpected token width: {out_x.shape[-1]}")
|
||||
|
||||
# Drop cls/cam token from the patch sequence.
|
||||
normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed]
|
||||
return list(zip(normed, cls_tokens))
|
||||
|
||||
@ -1,497 +0,0 @@
|
||||
# DINOv2 backbone for Depth Anything 3 (monocular inference path).
|
||||
#
|
||||
# Why not reuse ``comfy/image_encoders/dino2.py``?
|
||||
# The existing ``Dinov2Model`` is a vanilla HuggingFace-style DINOv2 with a
|
||||
# different state-dict layout (separate Q/K/V, ``embeddings.*`` /
|
||||
# ``encoder.layer.*`` keys, ``layer_scale1.lambda1``) and no support for the
|
||||
# architectural extensions DA3 adds on top of DINOv2 (RoPE, QK-norm,
|
||||
# alternating local/global attention, concatenated camera token). Loading
|
||||
# raw DA3 HF safetensors into ``Dinov2Model`` would require splitting
|
||||
# ``attn.qkv`` weights and a large rename map for every block, and we'd
|
||||
# still need to write the DA3 extensions separately. Keeping the upstream
|
||||
# ``pretrained.*`` key layout here means HF weights load directly with no
|
||||
# conversion step.
|
||||
#
|
||||
# Ported from the upstream repo at:
|
||||
# src/depth_anything_3/model/dinov2/{dinov2,vision_transformer}.py
|
||||
# src/depth_anything_3/model/dinov2/layers/*
|
||||
#
|
||||
# DA3 extensions on top of vanilla DINOv2 (only used by Small/Base variants):
|
||||
# - 2D Rotary Position Embedding starting at ``rope_start``
|
||||
# - QK-norm starting at ``qknorm_start``
|
||||
# - Alternating local/global attention blocks starting at ``alt_start``
|
||||
# - Camera-conditioning token concatenated to features (``cat_token=True``),
|
||||
# with a learned parameter ``camera_token`` injected at block
|
||||
# ``alt_start`` when no external camera token is supplied.
|
||||
#
|
||||
# For the Mono/Metric variants the configuration disables all of the above
|
||||
# (alt_start/qknorm_start/rope_start = -1, cat_token=False) so this module
|
||||
# collapses to a vanilla DINOv2-ViT encoder.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# 2D rotary position embedding
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PositionGetter:
|
||||
def __init__(self):
|
||||
self._cache: dict[tuple[int, int], torch.Tensor] = {}
|
||||
|
||||
def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor:
|
||||
key = (height, width)
|
||||
if key not in self._cache:
|
||||
y = torch.arange(height, device=device)
|
||||
x = torch.arange(width, device=device)
|
||||
self._cache[key] = torch.cartesian_prod(y, x)
|
||||
cached = self._cache[key]
|
||||
return cached.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
|
||||
|
||||
|
||||
class RotaryPositionEmbedding2D(nn.Module):
|
||||
def __init__(self, frequency: float = 100.0):
|
||||
super().__init__()
|
||||
self.base_frequency = frequency
|
||||
self._freq_cache: dict = {}
|
||||
|
||||
def _components(self, dim: int, seq_len: int, device, dtype):
|
||||
key = (dim, seq_len, device, dtype)
|
||||
if key not in self._freq_cache:
|
||||
exp = torch.arange(0, dim, 2, device=device).float() / dim
|
||||
inv_freq = 1.0 / (self.base_frequency ** exp)
|
||||
pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
||||
ang = torch.einsum("i,j->ij", pos, inv_freq)
|
||||
ang = ang.to(dtype)
|
||||
ang = torch.cat((ang, ang), dim=-1)
|
||||
self._freq_cache[key] = (ang.cos().to(dtype), ang.sin().to(dtype))
|
||||
return self._freq_cache[key]
|
||||
|
||||
@staticmethod
|
||||
def _rotate(x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1]
|
||||
x1, x2 = x[..., : d // 2], x[..., d // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def _apply_1d(self, tokens, positions, cos_c, sin_c):
|
||||
cos = F.embedding(positions, cos_c)[:, None, :, :]
|
||||
sin = F.embedding(positions, sin_c)[:, None, :, :]
|
||||
return (tokens * cos) + (self._rotate(tokens) * sin)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
|
||||
feature_dim = tokens.size(-1) // 2
|
||||
max_pos = int(positions.max()) + 1
|
||||
cos_c, sin_c = self._components(feature_dim, max_pos, tokens.device, tokens.dtype)
|
||||
v, h = tokens.chunk(2, dim=-1)
|
||||
v = self._apply_1d(v, positions[..., 0], cos_c, sin_c)
|
||||
h = self._apply_1d(h, positions[..., 1], cos_c, sin_c)
|
||||
return torch.cat((v, h), dim=-1)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Patch embed / MLP / SwiGLU / LayerScale
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
def __init__(self, patch_size=14, in_chans=3, embed_dim=384,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.patch_size = (patch_size, patch_size)
|
||||
self.proj = operations.Conv2d(
|
||||
in_chans, embed_dim,
|
||||
kernel_size=self.patch_size, stride=self.patch_size,
|
||||
device=device, dtype=dtype,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
_, _, H, W = x.shape
|
||||
ph, pw = self.patch_size
|
||||
assert H % ph == 0 and W % pw == 0
|
||||
x = self.proj(x)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(self, in_features, hidden_features=None, bias=True,
|
||||
act_layer=nn.GELU, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = operations.Linear(in_features, hidden_features, bias=bias, device=device, dtype=dtype)
|
||||
self.act = act_layer()
|
||||
self.fc2 = operations.Linear(hidden_features, in_features, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(self.act(self.fc1(x)))
|
||||
|
||||
|
||||
class SwiGLUFFNFused(nn.Module):
|
||||
"""SwiGLU FFN matching upstream xformers.ops.SwiGLU layout (used for vitg)."""
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, bias=True,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
hidden_features = hidden_features or in_features
|
||||
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
||||
# NOTE: xformers SwiGLU stores w12 as a single fused Linear (in, 2*hidden);
|
||||
# split-by-half at forward time. We don't currently need this for the
|
||||
# Apache-2.0 variants but keep it for parity with the upstream key names.
|
||||
self.w12 = operations.Linear(in_features, 2 * hidden_features, bias=bias, device=device, dtype=dtype)
|
||||
self.w3 = operations.Linear(hidden_features, in_features, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x12 = self.w12(x)
|
||||
x1, x2 = x12.chunk(2, dim=-1)
|
||||
return self.w3(F.silu(x1) * x2)
|
||||
|
||||
|
||||
class LayerScale(nn.Module):
|
||||
def __init__(self, dim, init_values: float = 1e-5, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(init_values * torch.ones(dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x):
|
||||
return x * comfy_cast(self.gamma, x)
|
||||
|
||||
|
||||
def comfy_cast(p: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
|
||||
"""Cast a parameter to match the reference tensor's device/dtype."""
|
||||
if p.device != ref.device or p.dtype != ref.dtype:
|
||||
return p.to(device=ref.device, dtype=ref.dtype)
|
||||
return p
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Attention + Block
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads: int, qkv_bias: bool = True, proj_bias: bool = True,
|
||||
qk_norm: bool = False, rope: Optional[RotaryPositionEmbedding2D] = None,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.q_norm = operations.LayerNorm(self.head_dim, device=device, dtype=dtype) if qk_norm else nn.Identity()
|
||||
self.k_norm = operations.LayerNorm(self.head_dim, device=device, dtype=dtype) if qk_norm else nn.Identity()
|
||||
self.proj = operations.Linear(dim, dim, bias=proj_bias, device=device, dtype=dtype)
|
||||
self.rope = rope
|
||||
|
||||
def forward(self, x: torch.Tensor, pos: Optional[torch.Tensor] = None,
|
||||
attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
if self.rope is not None and pos is not None:
|
||||
q = self.rope(q, pos)
|
||||
k = self.rope(k, pos)
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=(attn_mask[:, None].repeat(1, self.num_heads, 1, 1)
|
||||
if attn_mask is not None else None),
|
||||
)
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, dim, num_heads, mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = True, proj_bias: bool = True, ffn_bias: bool = True,
|
||||
init_values: Optional[float] = 1.0,
|
||||
norm_layer=nn.LayerNorm,
|
||||
ffn_layer=Mlp,
|
||||
qk_norm: bool = False,
|
||||
rope: Optional[RotaryPositionEmbedding2D] = None,
|
||||
ln_eps: float = 1e-6,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm1 = operations.LayerNorm(dim, eps=ln_eps, device=device, dtype=dtype)
|
||||
self.attn = Attention(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias,
|
||||
qk_norm=qk_norm, rope=rope, device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
self.ls1 = (LayerScale(dim, init_values=init_values, device=device, dtype=dtype)
|
||||
if init_values else nn.Identity())
|
||||
|
||||
self.norm2 = operations.LayerNorm(dim, eps=ln_eps, device=device, dtype=dtype)
|
||||
mlp_hidden = int(dim * mlp_ratio)
|
||||
self.mlp = ffn_layer(
|
||||
in_features=dim, hidden_features=mlp_hidden,
|
||||
bias=ffn_bias, device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
self.ls2 = (LayerScale(dim, init_values=init_values, device=device, dtype=dtype)
|
||||
if init_values else nn.Identity())
|
||||
|
||||
def forward(self, x, pos=None, attn_mask=None):
|
||||
x = x + self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask))
|
||||
x = x + self.ls2(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DINOv2 vision transformer
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
_BACKBONE_PRESETS = {
|
||||
"vits": dict(embed_dim=384, depth=12, num_heads=6, ffn_layer="mlp"),
|
||||
"vitb": dict(embed_dim=768, depth=12, num_heads=12, ffn_layer="mlp"),
|
||||
"vitl": dict(embed_dim=1024, depth=24, num_heads=16, ffn_layer="mlp"),
|
||||
"vitg": dict(embed_dim=1536, depth=40, num_heads=24, ffn_layer="swiglufused"),
|
||||
}
|
||||
|
||||
|
||||
class DinoVisionTransformer(nn.Module):
|
||||
PATCH_SIZE = 14
|
||||
|
||||
def __init__(self,
|
||||
embed_dim: int,
|
||||
depth: int,
|
||||
num_heads: int,
|
||||
ffn_layer: str = "mlp",
|
||||
mlp_ratio: float = 4.0,
|
||||
init_values: float = 1.0,
|
||||
alt_start: int = -1,
|
||||
qknorm_start: int = -1,
|
||||
rope_start: int = -1,
|
||||
rope_freq: float = 100.0,
|
||||
cat_token: bool = True,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
norm_layer = nn.LayerNorm
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.alt_start = alt_start
|
||||
self.qknorm_start = qknorm_start
|
||||
self.rope_start = rope_start
|
||||
self.cat_token = cat_token
|
||||
self.patch_size = self.PATCH_SIZE
|
||||
self.num_register_tokens = 0
|
||||
self.patch_start_idx = 1
|
||||
self.num_tokens = 1
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
patch_size=self.PATCH_SIZE, in_chans=3, embed_dim=embed_dim,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
# Number of patch positions for the historical 518x518 reference grid.
|
||||
ref_grid = 518 // self.PATCH_SIZE
|
||||
num_patches = ref_grid * ref_grid
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, device=device, dtype=dtype))
|
||||
if alt_start != -1:
|
||||
self.camera_token = nn.Parameter(torch.zeros(1, 2, embed_dim, device=device, dtype=dtype))
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim,
|
||||
device=device, dtype=dtype))
|
||||
|
||||
if rope_start != -1 and rope_freq > 0:
|
||||
self.rope = RotaryPositionEmbedding2D(frequency=rope_freq)
|
||||
self.position_getter = PositionGetter()
|
||||
else:
|
||||
self.rope = None
|
||||
self.position_getter = None
|
||||
|
||||
if ffn_layer == "mlp":
|
||||
ffn = Mlp
|
||||
elif ffn_layer in ("swiglu", "swiglufused"):
|
||||
ffn = SwiGLUFFNFused
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported ffn_layer: {ffn_layer}")
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=True,
|
||||
proj_bias=True,
|
||||
ffn_bias=True,
|
||||
init_values=init_values,
|
||||
norm_layer=norm_layer,
|
||||
ffn_layer=ffn,
|
||||
qk_norm=(qknorm_start != -1 and i >= qknorm_start),
|
||||
rope=(self.rope if (rope_start != -1 and i >= rope_start) else None),
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
for i in range(depth)
|
||||
])
|
||||
self.norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
def interpolate_pos_encoding(self, x, w, h):
|
||||
previous_dtype = x.dtype
|
||||
npatch = x.shape[1] - 1
|
||||
N = self.pos_embed.shape[1] - 1
|
||||
pos_embed = comfy_cast(self.pos_embed, x).float()
|
||||
if npatch == N and w == h:
|
||||
return pos_embed
|
||||
class_pos_embed = pos_embed[:, 0]
|
||||
patch_pos_embed = pos_embed[:, 1:]
|
||||
dim = x.shape[-1]
|
||||
w0 = w // self.patch_size
|
||||
h0 = h // self.patch_size
|
||||
M = int(math.sqrt(N))
|
||||
assert N == M * M
|
||||
# Historical 0.1 offset preserves bicubic resample compatibility with
|
||||
# the original DINOv2 release; see the upstream PR for context.
|
||||
sx = float(w0 + 0.1) / M
|
||||
sy = float(h0 + 0.1) / M
|
||||
patch_pos_embed = F.interpolate(
|
||||
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
||||
scale_factor=(sx, sy), mode="bicubic", antialias=False,
|
||||
)
|
||||
assert (w0, h0) == patch_pos_embed.shape[-2:]
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
||||
|
||||
def prepare_tokens(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# x: (B, S, 3, H, W) -> tokens (B, S, 1+N, C)
|
||||
B, S, _, H, W = x.shape
|
||||
x = rearrange(x, "b s c h w -> (b s) c h w")
|
||||
x = self.patch_embed(x)
|
||||
cls_token = comfy_cast(self.cls_token, x).expand(B, S, -1).reshape(B * S, 1, self.embed_dim)
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
x = x + self.interpolate_pos_encoding(x, W, H)
|
||||
x = rearrange(x, "(b s) n c -> b s n c", b=B, s=S)
|
||||
return x
|
||||
|
||||
def _prepare_rope(self, B, S, H, W, device):
|
||||
if self.rope is None:
|
||||
return None, None
|
||||
pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=device)
|
||||
pos = rearrange(pos, "(b s) n c -> b s n c", b=B)
|
||||
pos_nodiff = torch.zeros_like(pos)
|
||||
if self.patch_start_idx > 0:
|
||||
pos = pos + 1
|
||||
pos_special = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype)
|
||||
pos_special = rearrange(pos_special, "(b s) n c -> b s n c", b=B)
|
||||
pos = torch.cat([pos_special, pos], dim=2)
|
||||
pos_nodiff = pos_nodiff + 1
|
||||
pos_nodiff = torch.cat([pos_special, pos_nodiff], dim=2)
|
||||
return pos, pos_nodiff
|
||||
|
||||
def _attn(self, x, blk, attn_type, pos=None, attn_mask=None):
|
||||
b, s, n = x.shape[:3]
|
||||
if attn_type == "local":
|
||||
x = rearrange(x, "b s n c -> (b s) n c")
|
||||
if pos is not None:
|
||||
pos = rearrange(pos, "b s n c -> (b s) n c")
|
||||
else: # "global"
|
||||
x = rearrange(x, "b s n c -> b (s n) c")
|
||||
if pos is not None:
|
||||
pos = rearrange(pos, "b s n c -> b (s n) c")
|
||||
x = blk(x, pos=pos, attn_mask=attn_mask)
|
||||
if attn_type == "local":
|
||||
x = rearrange(x, "(b s) n c -> b s n c", b=b, s=s)
|
||||
else:
|
||||
x = rearrange(x, "b (s n) c -> b s n c", b=b, s=s)
|
||||
return x
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public forward
|
||||
# ------------------------------------------------------------------
|
||||
def get_intermediate_layers(self, x: torch.Tensor, out_layers: List[int],
|
||||
cam_token: Optional[torch.Tensor] = None):
|
||||
B, S, _, H, W = x.shape
|
||||
x = self.prepare_tokens(x)
|
||||
pos, pos_nodiff = self._prepare_rope(B, S, H, W, x.device)
|
||||
|
||||
outputs = []
|
||||
local_x = x
|
||||
|
||||
for i, blk in enumerate(self.blocks):
|
||||
if self.rope is not None and i >= self.rope_start:
|
||||
g_pos, l_pos = pos_nodiff, pos
|
||||
else:
|
||||
g_pos, l_pos = None, None
|
||||
|
||||
if self.alt_start != -1 and i == self.alt_start:
|
||||
# Inject camera token at the alt-start boundary.
|
||||
if cam_token is not None:
|
||||
inj = cam_token
|
||||
else:
|
||||
ct = comfy_cast(self.camera_token, x)
|
||||
ref_token = ct[:, :1].expand(B, -1, -1)
|
||||
src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1)
|
||||
inj = torch.cat([ref_token, src_token], dim=1)
|
||||
x = x.clone()
|
||||
x[:, :, 0] = inj
|
||||
|
||||
if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1):
|
||||
x = self._attn(x, blk, "global", pos=g_pos)
|
||||
else:
|
||||
x = self._attn(x, blk, "local", pos=l_pos)
|
||||
local_x = x
|
||||
|
||||
if i in out_layers:
|
||||
out_x = torch.cat([local_x, x], dim=-1) if self.cat_token else x
|
||||
outputs.append(out_x)
|
||||
|
||||
# Apply final norm. Upstream norms only the "global" half when cat_token.
|
||||
normed: List[torch.Tensor] = []
|
||||
camera_tokens: List[torch.Tensor] = []
|
||||
for out_x in outputs:
|
||||
# Camera/cls token slot is index 0 *before* register-token stripping.
|
||||
camera_tokens.append(out_x[:, :, 0])
|
||||
if out_x.shape[-1] == self.embed_dim:
|
||||
normed.append(self.norm(out_x))
|
||||
elif out_x.shape[-1] == self.embed_dim * 2:
|
||||
left = out_x[..., : self.embed_dim]
|
||||
right = self.norm(out_x[..., self.embed_dim :])
|
||||
normed.append(torch.cat([left, right], dim=-1))
|
||||
else:
|
||||
raise ValueError(f"Unexpected token width: {out_x.shape[-1]}")
|
||||
# Drop cls/cam token + register tokens from patch sequence.
|
||||
normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed]
|
||||
# Match upstream signature consumed by DA3 heads:
|
||||
# feats[i][0] = normed patch tokens, feats[i][1] = camera/cls token.
|
||||
return list(zip(normed, camera_tokens))
|
||||
|
||||
|
||||
class DinoV2(nn.Module):
|
||||
"""Top-level DINOv2 wrapper matching upstream key layout (``self.pretrained``)."""
|
||||
|
||||
def __init__(self, name: str = "vits",
|
||||
out_layers: Optional[List[int]] = None,
|
||||
alt_start: int = -1,
|
||||
qknorm_start: int = -1,
|
||||
rope_start: int = -1,
|
||||
cat_token: bool = True,
|
||||
device=None, dtype=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
if name not in _BACKBONE_PRESETS:
|
||||
raise ValueError(f"Unknown DINOv2 backbone variant: {name!r}")
|
||||
preset = _BACKBONE_PRESETS[name]
|
||||
self.name = name
|
||||
self.out_layers = list(out_layers) if out_layers is not None else [5, 7, 9, 11]
|
||||
self.cat_token = cat_token
|
||||
self.pretrained = DinoVisionTransformer(
|
||||
embed_dim=preset["embed_dim"],
|
||||
depth=preset["depth"],
|
||||
num_heads=preset["num_heads"],
|
||||
ffn_layer=preset["ffn_layer"],
|
||||
alt_start=alt_start,
|
||||
qknorm_start=qknorm_start,
|
||||
rope_start=rope_start,
|
||||
cat_token=cat_token,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
def forward(self, x, cam_token=None, **_unused):
|
||||
return self.pretrained.get_intermediate_layers(x, self.out_layers, cam_token=cam_token)
|
||||
@ -9,15 +9,25 @@
|
||||
# The class signature mirrors the upstream YAML config so a single dit_config
|
||||
# detected from the state dict in ``comfy/model_detection.py`` is sufficient
|
||||
# to construct the right variant.
|
||||
#
|
||||
# Backbone: ``comfy.image_encoders.dino2.Dinov2Model`` is shared with the
|
||||
# CLIP-vision DINOv2 path. DA3-specific extensions (RoPE, QK-norm,
|
||||
# alternating local/global attention, camera token, multi-layer feature
|
||||
# extraction, pos-embed interpolation) are opt-in via the config dict and are
|
||||
# all disabled for the Mono/Metric variants. The upstream DA3 weight layout
|
||||
# (``backbone.pretrained.*`` with fused QKV) is converted to the
|
||||
# ``Dinov2Model`` layout in
|
||||
# ``comfy.supported_models.DepthAnything3.process_unet_state_dict``.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
from typing import Dict, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .dinov2 import DinoV2
|
||||
from comfy.image_encoders.dino2 import Dinov2Model
|
||||
|
||||
from .dpt import DPT, DualDPT
|
||||
|
||||
|
||||
@ -27,6 +37,42 @@ _HEAD_REGISTRY = {
|
||||
}
|
||||
|
||||
|
||||
# Backbone presets (mirror the upstream DINOv2 ViT variants).
|
||||
_BACKBONE_PRESETS = {
|
||||
"vits": dict(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, use_swiglu_ffn=False),
|
||||
"vitb": dict(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, use_swiglu_ffn=False),
|
||||
"vitl": dict(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, use_swiglu_ffn=False),
|
||||
"vitg": dict(hidden_size=1536, num_hidden_layers=40, num_attention_heads=24, use_swiglu_ffn=True),
|
||||
}
|
||||
|
||||
|
||||
def _build_backbone_config(
|
||||
backbone_name: str,
|
||||
*,
|
||||
alt_start: int,
|
||||
qknorm_start: int,
|
||||
rope_start: int,
|
||||
cat_token: bool,
|
||||
) -> dict:
|
||||
if backbone_name not in _BACKBONE_PRESETS:
|
||||
raise ValueError(f"Unknown DINOv2 backbone variant: {backbone_name!r}")
|
||||
cfg = dict(_BACKBONE_PRESETS[backbone_name])
|
||||
cfg.update(dict(
|
||||
layer_norm_eps=1e-6,
|
||||
patch_size=14,
|
||||
image_size=518,
|
||||
# DA3 weights have no mask_token; skip registering it to avoid spurious
|
||||
# missing-key warnings on load.
|
||||
use_mask_token=False,
|
||||
alt_start=alt_start,
|
||||
qknorm_start=qknorm_start,
|
||||
rope_start=rope_start,
|
||||
cat_token=cat_token,
|
||||
rope_freq=100.0,
|
||||
))
|
||||
return cfg
|
||||
|
||||
|
||||
class DepthAnything3Net(nn.Module):
|
||||
"""ComfyUI-side DepthAnything3 network (monocular path only).
|
||||
|
||||
@ -64,16 +110,16 @@ class DepthAnything3Net(nn.Module):
|
||||
self.head_type = head_type.lower()
|
||||
self.has_sky = (self.head_type == "dpt") and head_use_sky_head
|
||||
self.has_conf = head_output_dim > 1
|
||||
self.out_layers = list(out_layers)
|
||||
|
||||
self.backbone = DinoV2(
|
||||
name=backbone_name,
|
||||
out_layers=list(out_layers),
|
||||
backbone_cfg = _build_backbone_config(
|
||||
backbone_name,
|
||||
alt_start=alt_start,
|
||||
qknorm_start=qknorm_start,
|
||||
rope_start=rope_start,
|
||||
cat_token=cat_token,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
self.backbone = Dinov2Model(backbone_cfg, dtype, device, operations)
|
||||
|
||||
head_kwargs = dict(
|
||||
dim_in=head_dim_in,
|
||||
@ -122,7 +168,7 @@ class DepthAnything3Net(nn.Module):
|
||||
assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \
|
||||
f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}"
|
||||
|
||||
feats = self.backbone(image)
|
||||
feats = self.backbone.get_intermediate_layers(image, self.out_layers)
|
||||
head_out = self.head(feats, H=H, W=W, patch_start_idx=0)
|
||||
|
||||
# Flatten the views axis (S=1 in mono inference path).
|
||||
|
||||
@ -1871,8 +1871,99 @@ class DepthAnything3(supported_models_base.BASE):
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(drop_prefixes):
|
||||
state_dict.pop(k)
|
||||
# Remap upstream DA3 backbone keys (``backbone.pretrained.*`` with
|
||||
# fused QKV) to the layout used by ``comfy.image_encoders.dino2.Dinov2Model``.
|
||||
return _da3_remap_backbone_keys(state_dict, prefix="backbone.")
|
||||
|
||||
|
||||
def _da3_remap_backbone_keys(state_dict, prefix="backbone."):
|
||||
"""Rewrite upstream DA3 DINOv2 keys to the shared ``Dinov2Model`` layout.
|
||||
|
||||
Upstream layout (under ``{prefix}pretrained.``):
|
||||
patch_embed.proj.{weight,bias}, pos_embed, cls_token, camera_token, norm.*,
|
||||
blocks.{i}.norm{1,2}.*, blocks.{i}.attn.qkv.{weight,bias},
|
||||
blocks.{i}.attn.q_norm.*, blocks.{i}.attn.k_norm.*,
|
||||
blocks.{i}.attn.proj.*, blocks.{i}.ls{1,2}.gamma,
|
||||
blocks.{i}.mlp.fc{1,2}.* (or w12/w3 for SwiGLU)
|
||||
|
||||
Target layout (Dinov2Model under ``{prefix}``):
|
||||
embeddings.patch_embeddings.projection.*,
|
||||
embeddings.position_embeddings, embeddings.cls_token, embeddings.camera_token,
|
||||
layernorm.*,
|
||||
encoder.layer.{i}.norm{1,2}.*,
|
||||
encoder.layer.{i}.attention.attention.{query,key,value}.*,
|
||||
encoder.layer.{i}.attention.q_norm.*, encoder.layer.{i}.attention.k_norm.*,
|
||||
encoder.layer.{i}.attention.output.dense.*,
|
||||
encoder.layer.{i}.layer_scale{1,2}.lambda1,
|
||||
encoder.layer.{i}.mlp.fc{1,2}.* (or weights_in/weights_out for SwiGLU)
|
||||
"""
|
||||
pre = prefix + "pretrained."
|
||||
src_keys = [k for k in state_dict.keys() if k.startswith(pre)]
|
||||
if not src_keys:
|
||||
return state_dict
|
||||
|
||||
static_renames = {
|
||||
pre + "patch_embed.proj.weight": prefix + "embeddings.patch_embeddings.projection.weight",
|
||||
pre + "patch_embed.proj.bias": prefix + "embeddings.patch_embeddings.projection.bias",
|
||||
pre + "pos_embed": prefix + "embeddings.position_embeddings",
|
||||
pre + "cls_token": prefix + "embeddings.cls_token",
|
||||
pre + "camera_token": prefix + "embeddings.camera_token",
|
||||
pre + "norm.weight": prefix + "layernorm.weight",
|
||||
pre + "norm.bias": prefix + "layernorm.bias",
|
||||
}
|
||||
for src, dst in static_renames.items():
|
||||
if src in state_dict:
|
||||
state_dict[dst] = state_dict.pop(src)
|
||||
|
||||
block_pre = pre + "blocks."
|
||||
block_keys = [k for k in state_dict.keys() if k.startswith(block_pre)]
|
||||
for k in block_keys:
|
||||
rest = k[len(block_pre):] # e.g. "5.attn.qkv.weight"
|
||||
idx_str, _, sub = rest.partition(".")
|
||||
target_block = "{}encoder.layer.{}.".format(prefix, idx_str)
|
||||
|
||||
# Fused QKV -> split query/key/value linears.
|
||||
if sub == "attn.qkv.weight":
|
||||
qkv = state_dict.pop(k)
|
||||
c = qkv.shape[0] // 3
|
||||
state_dict[target_block + "attention.attention.query.weight"] = qkv[:c].clone()
|
||||
state_dict[target_block + "attention.attention.key.weight"] = qkv[c:2 * c].clone()
|
||||
state_dict[target_block + "attention.attention.value.weight"] = qkv[2 * c:].clone()
|
||||
continue
|
||||
if sub == "attn.qkv.bias":
|
||||
qkv = state_dict.pop(k)
|
||||
c = qkv.shape[0] // 3
|
||||
state_dict[target_block + "attention.attention.query.bias"] = qkv[:c].clone()
|
||||
state_dict[target_block + "attention.attention.key.bias"] = qkv[c:2 * c].clone()
|
||||
state_dict[target_block + "attention.attention.value.bias"] = qkv[2 * c:].clone()
|
||||
continue
|
||||
|
||||
# Sub-key remap (suffix preserved).
|
||||
if sub.startswith("attn.proj."):
|
||||
tail = sub[len("attn.proj."):]
|
||||
new = "attention.output.dense." + tail
|
||||
elif sub.startswith("attn.q_norm."):
|
||||
new = "attention.q_norm." + sub[len("attn.q_norm."):]
|
||||
elif sub.startswith("attn.k_norm."):
|
||||
new = "attention.k_norm." + sub[len("attn.k_norm."):]
|
||||
elif sub == "ls1.gamma":
|
||||
new = "layer_scale1.lambda1"
|
||||
elif sub == "ls2.gamma":
|
||||
new = "layer_scale2.lambda1"
|
||||
elif sub.startswith("mlp.w12."):
|
||||
new = "mlp.weights_in." + sub[len("mlp.w12."):]
|
||||
elif sub.startswith("mlp.w3."):
|
||||
new = "mlp.weights_out." + sub[len("mlp.w3."):]
|
||||
elif sub.startswith(("norm1.", "norm2.", "mlp.fc1.", "mlp.fc2.")):
|
||||
new = sub
|
||||
else:
|
||||
# Unrecognised key -- leave as-is so load_state_dict can complain.
|
||||
continue
|
||||
|
||||
state_dict[target_block + new] = state_dict.pop(k)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
class ErnieImage(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user