Refactor custom da dinov2 to image_encoders/dino2

This commit is contained in:
Talmaj Marinc 2026-05-12 09:54:59 +02:00
parent b296c6a1aa
commit 4ad749ab17
4 changed files with 482 additions and 522 deletions

View File

@ -1,4 +1,8 @@
import math
import torch import torch
import torch.nn.functional as F
from comfy.text_encoders.bert import BertAttention from comfy.text_encoders.bert import BertAttention
import comfy.model_management import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.ldm.modules.attention import optimized_attention_for_device
@ -14,13 +18,42 @@ class Dino2AttentionOutput(torch.nn.Module):
class Dino2AttentionBlock(torch.nn.Module): class Dino2AttentionBlock(torch.nn.Module):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations): def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations,
qk_norm=False):
super().__init__() super().__init__()
self.heads = heads
self.head_dim = embed_dim // heads
self.attention = BertAttention(embed_dim, heads, dtype, device, operations) self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations) self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
if qk_norm:
self.q_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device)
self.k_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device)
else:
self.q_norm = None
self.k_norm = None
def forward(self, x, mask, optimized_attention): def forward(self, x, mask, optimized_attention, pos=None, rope=None):
return self.output(self.attention(x, mask, optimized_attention)) # Fast path used by the existing CLIP-vision DINOv2 (no DA3 extensions).
if self.q_norm is None and rope is None:
return self.output(self.attention(x, mask, optimized_attention))
# DA3 path: do QKV manually so we can apply per-head QK-norm and 2D RoPE.
attn = self.attention
B, N, C = x.shape
h = self.heads
d = self.head_dim
q = attn.query(x).view(B, N, h, d).transpose(1, 2)
k = attn.key(x).view(B, N, h, d).transpose(1, 2)
v = attn.value(x).view(B, N, h, d).transpose(1, 2)
if self.q_norm is not None:
q = self.q_norm(q)
k = self.k_norm(k)
if rope is not None and pos is not None:
q = rope(q, pos)
k = rope(k, pos)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
out = out.transpose(1, 2).reshape(B, N, C)
return self.output(out)
class LayerScale(torch.nn.Module): class LayerScale(torch.nn.Module):
@ -64,9 +97,11 @@ class SwiGLUFFN(torch.nn.Module):
class Dino2Block(torch.nn.Module): class Dino2Block(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn): def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn,
qk_norm=False):
super().__init__() super().__init__()
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations) self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations,
qk_norm=qk_norm)
self.layer_scale1 = LayerScale(dim, dtype, device, operations) self.layer_scale1 = LayerScale(dim, dtype, device, operations)
self.layer_scale2 = LayerScale(dim, dtype, device, operations) self.layer_scale2 = LayerScale(dim, dtype, device, operations)
if use_swiglu_ffn: if use_swiglu_ffn:
@ -76,19 +111,93 @@ class Dino2Block(torch.nn.Module):
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, x, optimized_attention): def forward(self, x, optimized_attention, pos=None, rope=None, attn_mask=None):
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention)) x = x + self.layer_scale1(self.attention(self.norm1(x), attn_mask, optimized_attention,
pos=pos, rope=rope))
x = x + self.layer_scale2(self.mlp(self.norm2(x))) x = x + self.layer_scale2(self.mlp(self.norm2(x)))
return x return x
class Dino2Encoder(torch.nn.Module): # -----------------------------------------------------------------------------
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn): # 2D Rotary position embedding (DA3 extension)
# -----------------------------------------------------------------------------
class _PositionGetter:
"""Cache (h, w) -> flat (y, x) position grid used to feed ``rope``."""
def __init__(self):
self._cache: dict = {}
def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor:
key = (height, width, device)
if key not in self._cache:
y = torch.arange(height, device=device)
x = torch.arange(width, device=device)
self._cache[key] = torch.cartesian_prod(y, x)
cached = self._cache[key]
return cached.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
class RotaryPositionEmbedding2D(torch.nn.Module):
"""2D RoPE used by DA3-Small/Base. No learnable parameters."""
def __init__(self, frequency: float = 100.0):
super().__init__() super().__init__()
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn) self.base_frequency = frequency
for _ in range(num_layers)]) self._freq_cache: dict = {}
def _components(self, dim: int, seq_len: int, device, dtype):
key = (dim, seq_len, device, dtype)
if key not in self._freq_cache:
exp = torch.arange(0, dim, 2, device=device).float() / dim
inv_freq = 1.0 / (self.base_frequency ** exp)
pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
ang = torch.einsum("i,j->ij", pos, inv_freq)
ang = ang.to(dtype)
ang = torch.cat((ang, ang), dim=-1)
self._freq_cache[key] = (ang.cos().to(dtype), ang.sin().to(dtype))
return self._freq_cache[key]
@staticmethod
def _rotate(x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1]
x1, x2 = x[..., : d // 2], x[..., d // 2:]
return torch.cat((-x2, x1), dim=-1)
def _apply_1d(self, tokens, positions, cos_c, sin_c):
cos = F.embedding(positions, cos_c)[:, None, :, :]
sin = F.embedding(positions, sin_c)[:, None, :, :]
return (tokens * cos) + (self._rotate(tokens) * sin)
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
feature_dim = tokens.size(-1) // 2
max_pos = int(positions.max()) + 1
cos_c, sin_c = self._components(feature_dim, max_pos, tokens.device, tokens.dtype)
v, h = tokens.chunk(2, dim=-1)
v = self._apply_1d(v, positions[..., 0], cos_c, sin_c)
h = self._apply_1d(h, positions[..., 1], cos_c, sin_c)
return torch.cat((v, h), dim=-1)
class Dino2Encoder(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn,
qknorm_start: int = -1, rope: "RotaryPositionEmbedding2D | None" = None,
rope_start: int = -1):
super().__init__()
self.layer = torch.nn.ModuleList([
Dino2Block(
dim, num_heads, layer_norm_eps, dtype, device, operations,
use_swiglu_ffn=use_swiglu_ffn,
qk_norm=(qknorm_start != -1 and i >= qknorm_start),
)
for i in range(num_layers)
])
self.rope = rope
self.rope_start = rope_start
def forward(self, x, intermediate_output=None): def forward(self, x, intermediate_output=None):
# Backward-compat path used by ``ClipVisionModel`` (no DA3 extensions).
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True) optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
if intermediate_output is not None: if intermediate_output is not None:
@ -121,25 +230,79 @@ class Dino2PatchEmbeddings(torch.nn.Module):
class Dino2Embeddings(torch.nn.Module): class Dino2Embeddings(torch.nn.Module):
def __init__(self, dim, dtype, device, operations): def __init__(self, dim, dtype, device, operations,
patch_size: int = 14, image_size: int = 518,
use_mask_token: bool = True,
num_camera_tokens: int = 0):
super().__init__() super().__init__()
patch_size = 14 self.patch_size = patch_size
image_size = 518 self.image_size = image_size
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations) self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device)) self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device)) if use_mask_token:
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
else:
self.mask_token = None
if num_camera_tokens > 0:
# DA3 stores (ref_token, src_token) pairs that get injected at the
# alt-attn boundary; see ``Dinov2Model._inject_camera_token``.
self.camera_token = torch.nn.Parameter(torch.empty(1, num_camera_tokens, dim, dtype=dtype, device=device))
else:
self.camera_token = None
def _interpolate_pos_encoding(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
previous_dtype = x.dtype
npatch = x.shape[1] - 1
N = self.position_embeddings.shape[1] - 1
pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype).float()
if npatch == N and w == h:
return pos_embed
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
M = int(math.sqrt(N))
assert N == M * M
# Historical 0.1 offset preserves bicubic resample compatibility with
# the original DINOv2 release; see the upstream PR for context.
sx = float(w0 + 0.1) / M
sy = float(h0 + 0.1) / M
patch_pos_embed = F.interpolate(
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
scale_factor=(sx, sy), mode="bicubic", antialias=False,
)
assert (w0, h0) == patch_pos_embed.shape[-2:]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
def forward(self, pixel_values): def forward(self, pixel_values):
_, _, H, W = pixel_values.shape
x = self.patch_embeddings(pixel_values) x = self.patch_embeddings(pixel_values)
# TODO: mask_token? # TODO: mask_token?
x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1) x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype) x = x + self._interpolate_pos_encoding(x, H, W)
return x return x
class Dinov2Model(torch.nn.Module): class Dinov2Model(torch.nn.Module):
"""DINOv2 vision backbone.
Supports two operating modes:
* **CLIP-vision DINOv2** (default): vanilla DINOv2-ViT used for
``ClipVisionModel`` and SigLIP-style image encoding.
* **Depth Anything 3** extensions (opt-in via config keys): 2D RoPE,
QK-norm, alternating local/global attention, camera-token injection,
``cat_token`` output and multi-layer feature extraction. These are
enabled when the corresponding fields (``alt_start``, ``qknorm_start``,
``rope_start``, ``cat_token``) are set in ``config_dict``. When all of
them are at their disabled defaults this module behaves identically to
the historical ``Dinov2Model``.
"""
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()
num_layers = config_dict["num_hidden_layers"] num_layers = config_dict["num_hidden_layers"]
@ -147,14 +310,171 @@ class Dinov2Model(torch.nn.Module):
heads = config_dict["num_attention_heads"] heads = config_dict["num_attention_heads"]
layer_norm_eps = config_dict["layer_norm_eps"] layer_norm_eps = config_dict["layer_norm_eps"]
use_swiglu_ffn = config_dict["use_swiglu_ffn"] use_swiglu_ffn = config_dict["use_swiglu_ffn"]
patch_size = config_dict.get("patch_size", 14)
image_size = config_dict.get("image_size", 518)
use_mask_token = config_dict.get("use_mask_token", True)
self.embeddings = Dino2Embeddings(dim, dtype, device, operations) # DA3 extensions (all default to disabled).
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn) self.alt_start = config_dict.get("alt_start", -1)
self.qknorm_start = config_dict.get("qknorm_start", -1)
self.rope_start = config_dict.get("rope_start", -1)
self.cat_token = config_dict.get("cat_token", False)
rope_freq = config_dict.get("rope_freq", 100.0)
self.embed_dim = dim
self.patch_size = patch_size
self.num_register_tokens = 0
self.patch_start_idx = 1
if self.rope_start != -1 and rope_freq > 0:
self.rope = RotaryPositionEmbedding2D(frequency=rope_freq)
self._position_getter = _PositionGetter()
else:
self.rope = None
self._position_getter = None
# camera_token shape: (1, 2, dim) -> (ref_token, src_token).
num_cam_tokens = 2 if self.alt_start != -1 else 0
self.embeddings = Dino2Embeddings(
dim, dtype, device, operations,
patch_size=patch_size, image_size=image_size,
use_mask_token=use_mask_token, num_camera_tokens=num_cam_tokens,
)
self.encoder = Dino2Encoder(
dim, heads, layer_norm_eps, num_layers, dtype, device, operations,
use_swiglu_ffn=use_swiglu_ffn,
qknorm_start=self.qknorm_start,
rope=self.rope, rope_start=self.rope_start,
)
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
# ------------------------------------------------------------------
# CLIP-vision-style forward (no DA3 extensions, no multi-layer output).
# Kept for backward compatibility with ``ClipVisionModel.encode_image``.
# ------------------------------------------------------------------
def forward(self, pixel_values, attention_mask=None, intermediate_output=None): def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
x = self.embeddings(pixel_values) x = self.embeddings(pixel_values)
x, i = self.encoder(x, intermediate_output=intermediate_output) x, i = self.encoder(x, intermediate_output=intermediate_output)
x = self.layernorm(x) x = self.layernorm(x)
pooled_output = x[:, 0, :] pooled_output = x[:, 0, :]
return x, i, pooled_output, None return x, i, pooled_output, None
# ------------------------------------------------------------------
# Depth Anything 3 forward
# ------------------------------------------------------------------
def _prepare_rope_positions(self, B, S, H, W, device):
if self.rope is None:
return None, None
ph, pw = H // self.patch_size, W // self.patch_size
pos = self._position_getter(B * S, ph, pw, device=device)
# Shift so the cls/cam token at position 0 is reserved for "no diff".
pos = pos + 1
cls_pos = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype)
# Per-view local: real grid positions for patches, 0 for cls token.
pos_local = torch.cat([cls_pos, pos], dim=1)
# Global (across views): same grid positions; cls token still at 0,
# but patches share the same positions in every view.
pos_global = torch.cat([cls_pos, torch.zeros_like(pos) + 1], dim=1)
return pos_local, pos_global
def _inject_camera_token(self, x: torch.Tensor, B: int, S: int,
cam_token: "torch.Tensor | None") -> torch.Tensor:
# x: (B, S, N, C). Replace token at index 0 with the camera token.
if cam_token is not None:
inj = cam_token
else:
ct = comfy.model_management.cast_to_device(self.embeddings.camera_token, x.device, x.dtype)
ref_token = ct[:, :1].expand(B, -1, -1)
src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1)
inj = torch.cat([ref_token, src_token], dim=1)
x = x.clone()
x[:, :, 0] = inj
return x
def get_intermediate_layers(self, pixel_values, out_layers, cam_token=None):
"""Multi-layer DINOv2 feature extraction used by Depth Anything 3.
Args:
pixel_values: ``(B, S, 3, H, W)`` views or ``(B, 3, H, W)``.
out_layers: indices into ``self.encoder.layer``.
cam_token: optional ``(B, S, dim)`` camera token to inject at
``alt_start``. If ``None`` and the model has its own
``camera_token`` parameter, that is used.
Returns:
List of ``(patch_tokens, cls_or_cam_token)`` tuples, one per
requested ``out_layers`` entry. ``patch_tokens`` has shape
``(B, S, N_patch, C)`` (or ``(B, S, N_patch, 2*C)`` when the
model was configured with ``cat_token=True``); the second item
has shape ``(B, S, C)``.
"""
if pixel_values.ndim == 4:
pixel_values = pixel_values.unsqueeze(1)
assert pixel_values.ndim == 5 and pixel_values.shape[2] == 3, \
f"expected (B,3,H,W) or (B,S,3,H,W); got {tuple(pixel_values.shape)}"
B, S, _, H, W = pixel_values.shape
# Patch + cls + (interpolated) pos embed for each view.
x = pixel_values.reshape(B * S, 3, H, W)
x = self.embeddings(x) # (B*S, 1+N, C)
x = x.reshape(B, S, x.shape[-2], x.shape[-1]) # (B, S, 1+N, C)
pos_local, pos_global = self._prepare_rope_positions(B, S, H, W, x.device)
# ``optimized_attention`` is only used by blocks without QK-norm/RoPE
# (vanilla DINOv2 path); enabling-aware blocks fall through to SDPA.
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
out_set = set(out_layers)
outputs: list[torch.Tensor] = []
local_x = x
for i, blk in enumerate(self.encoder.layer):
apply_rope = self.rope is not None and i >= self.rope_start
block_rope = self.rope if apply_rope else None
l_pos = pos_local if apply_rope else None
g_pos = pos_global if apply_rope else None
if self.alt_start != -1 and i == self.alt_start:
x = self._inject_camera_token(x, B, S, cam_token)
if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1):
# Global attention across views: flatten S into the seq dim.
t = x.reshape(B, S * x.shape[-2], x.shape[-1])
p = g_pos.reshape(B, S * g_pos.shape[-2], g_pos.shape[-1]) if g_pos is not None else None
t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
x = t.reshape(B, S, x.shape[-2], x.shape[-1])
else:
# Per-view local attention.
t = x.reshape(B * S, x.shape[-2], x.shape[-1])
p = l_pos.reshape(B * S, l_pos.shape[-2], l_pos.shape[-1]) if l_pos is not None else None
t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
x = t.reshape(B, S, x.shape[-2], x.shape[-1])
local_x = x
if i in out_set:
if self.cat_token:
out_x = torch.cat([local_x, x], dim=-1)
else:
out_x = x
outputs.append(out_x)
# Apply final norm. When ``cat_token`` is set, only the right half
# ("global" features) is normalised; the left half is left as-is to
# match the upstream DA3 head signature.
normed: list[torch.Tensor] = []
cls_tokens: list[torch.Tensor] = []
for out_x in outputs:
cls_tokens.append(out_x[:, :, 0])
if out_x.shape[-1] == self.embed_dim:
normed.append(self.layernorm(out_x))
elif out_x.shape[-1] == self.embed_dim * 2:
left = out_x[..., :self.embed_dim]
right = self.layernorm(out_x[..., self.embed_dim:])
normed.append(torch.cat([left, right], dim=-1))
else:
raise ValueError(f"Unexpected token width: {out_x.shape[-1]}")
# Drop cls/cam token from the patch sequence.
normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed]
return list(zip(normed, cls_tokens))

View File

@ -1,497 +0,0 @@
# DINOv2 backbone for Depth Anything 3 (monocular inference path).
#
# Why not reuse ``comfy/image_encoders/dino2.py``?
# The existing ``Dinov2Model`` is a vanilla HuggingFace-style DINOv2 with a
# different state-dict layout (separate Q/K/V, ``embeddings.*`` /
# ``encoder.layer.*`` keys, ``layer_scale1.lambda1``) and no support for the
# architectural extensions DA3 adds on top of DINOv2 (RoPE, QK-norm,
# alternating local/global attention, concatenated camera token). Loading
# raw DA3 HF safetensors into ``Dinov2Model`` would require splitting
# ``attn.qkv`` weights and a large rename map for every block, and we'd
# still need to write the DA3 extensions separately. Keeping the upstream
# ``pretrained.*`` key layout here means HF weights load directly with no
# conversion step.
#
# Ported from the upstream repo at:
# src/depth_anything_3/model/dinov2/{dinov2,vision_transformer}.py
# src/depth_anything_3/model/dinov2/layers/*
#
# DA3 extensions on top of vanilla DINOv2 (only used by Small/Base variants):
# - 2D Rotary Position Embedding starting at ``rope_start``
# - QK-norm starting at ``qknorm_start``
# - Alternating local/global attention blocks starting at ``alt_start``
# - Camera-conditioning token concatenated to features (``cat_token=True``),
# with a learned parameter ``camera_token`` injected at block
# ``alt_start`` when no external camera token is supplied.
#
# For the Mono/Metric variants the configuration disables all of the above
# (alt_start/qknorm_start/rope_start = -1, cat_token=False) so this module
# collapses to a vanilla DINOv2-ViT encoder.
from __future__ import annotations
import math
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
# -----------------------------------------------------------------------------
# 2D rotary position embedding
# -----------------------------------------------------------------------------
class PositionGetter:
def __init__(self):
self._cache: dict[tuple[int, int], torch.Tensor] = {}
def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor:
key = (height, width)
if key not in self._cache:
y = torch.arange(height, device=device)
x = torch.arange(width, device=device)
self._cache[key] = torch.cartesian_prod(y, x)
cached = self._cache[key]
return cached.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
class RotaryPositionEmbedding2D(nn.Module):
def __init__(self, frequency: float = 100.0):
super().__init__()
self.base_frequency = frequency
self._freq_cache: dict = {}
def _components(self, dim: int, seq_len: int, device, dtype):
key = (dim, seq_len, device, dtype)
if key not in self._freq_cache:
exp = torch.arange(0, dim, 2, device=device).float() / dim
inv_freq = 1.0 / (self.base_frequency ** exp)
pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
ang = torch.einsum("i,j->ij", pos, inv_freq)
ang = ang.to(dtype)
ang = torch.cat((ang, ang), dim=-1)
self._freq_cache[key] = (ang.cos().to(dtype), ang.sin().to(dtype))
return self._freq_cache[key]
@staticmethod
def _rotate(x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1]
x1, x2 = x[..., : d // 2], x[..., d // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _apply_1d(self, tokens, positions, cos_c, sin_c):
cos = F.embedding(positions, cos_c)[:, None, :, :]
sin = F.embedding(positions, sin_c)[:, None, :, :]
return (tokens * cos) + (self._rotate(tokens) * sin)
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
feature_dim = tokens.size(-1) // 2
max_pos = int(positions.max()) + 1
cos_c, sin_c = self._components(feature_dim, max_pos, tokens.device, tokens.dtype)
v, h = tokens.chunk(2, dim=-1)
v = self._apply_1d(v, positions[..., 0], cos_c, sin_c)
h = self._apply_1d(h, positions[..., 1], cos_c, sin_c)
return torch.cat((v, h), dim=-1)
# -----------------------------------------------------------------------------
# Patch embed / MLP / SwiGLU / LayerScale
# -----------------------------------------------------------------------------
class PatchEmbed(nn.Module):
def __init__(self, patch_size=14, in_chans=3, embed_dim=384,
device=None, dtype=None, operations=None):
super().__init__()
self.patch_size = (patch_size, patch_size)
self.proj = operations.Conv2d(
in_chans, embed_dim,
kernel_size=self.patch_size, stride=self.patch_size,
device=device, dtype=dtype,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
_, _, H, W = x.shape
ph, pw = self.patch_size
assert H % ph == 0 and W % pw == 0
x = self.proj(x)
x = x.flatten(2).transpose(1, 2)
return x
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, bias=True,
act_layer=nn.GELU, device=None, dtype=None, operations=None):
super().__init__()
hidden_features = hidden_features or in_features
self.fc1 = operations.Linear(in_features, hidden_features, bias=bias, device=device, dtype=dtype)
self.act = act_layer()
self.fc2 = operations.Linear(hidden_features, in_features, bias=bias, device=device, dtype=dtype)
def forward(self, x):
return self.fc2(self.act(self.fc1(x)))
class SwiGLUFFNFused(nn.Module):
"""SwiGLU FFN matching upstream xformers.ops.SwiGLU layout (used for vitg)."""
def __init__(self, in_features, hidden_features=None, bias=True,
device=None, dtype=None, operations=None):
super().__init__()
hidden_features = hidden_features or in_features
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
# NOTE: xformers SwiGLU stores w12 as a single fused Linear (in, 2*hidden);
# split-by-half at forward time. We don't currently need this for the
# Apache-2.0 variants but keep it for parity with the upstream key names.
self.w12 = operations.Linear(in_features, 2 * hidden_features, bias=bias, device=device, dtype=dtype)
self.w3 = operations.Linear(hidden_features, in_features, bias=bias, device=device, dtype=dtype)
def forward(self, x):
x12 = self.w12(x)
x1, x2 = x12.chunk(2, dim=-1)
return self.w3(F.silu(x1) * x2)
class LayerScale(nn.Module):
def __init__(self, dim, init_values: float = 1e-5, device=None, dtype=None):
super().__init__()
self.gamma = nn.Parameter(init_values * torch.ones(dim, device=device, dtype=dtype))
def forward(self, x):
return x * comfy_cast(self.gamma, x)
def comfy_cast(p: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
"""Cast a parameter to match the reference tensor's device/dtype."""
if p.device != ref.device or p.dtype != ref.dtype:
return p.to(device=ref.device, dtype=ref.dtype)
return p
# -----------------------------------------------------------------------------
# Attention + Block
# -----------------------------------------------------------------------------
class Attention(nn.Module):
def __init__(self, dim, num_heads: int, qkv_bias: bool = True, proj_bias: bool = True,
qk_norm: bool = False, rope: Optional[RotaryPositionEmbedding2D] = None,
device=None, dtype=None, operations=None):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
self.q_norm = operations.LayerNorm(self.head_dim, device=device, dtype=dtype) if qk_norm else nn.Identity()
self.k_norm = operations.LayerNorm(self.head_dim, device=device, dtype=dtype) if qk_norm else nn.Identity()
self.proj = operations.Linear(dim, dim, bias=proj_bias, device=device, dtype=dtype)
self.rope = rope
def forward(self, x: torch.Tensor, pos: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q, k = self.q_norm(q), self.k_norm(k)
if self.rope is not None and pos is not None:
q = self.rope(q, pos)
k = self.rope(k, pos)
x = F.scaled_dot_product_attention(
q, k, v,
attn_mask=(attn_mask[:, None].repeat(1, self.num_heads, 1, 1)
if attn_mask is not None else None),
)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio: float = 4.0,
qkv_bias: bool = True, proj_bias: bool = True, ffn_bias: bool = True,
init_values: Optional[float] = 1.0,
norm_layer=nn.LayerNorm,
ffn_layer=Mlp,
qk_norm: bool = False,
rope: Optional[RotaryPositionEmbedding2D] = None,
ln_eps: float = 1e-6,
device=None, dtype=None, operations=None):
super().__init__()
self.norm1 = operations.LayerNorm(dim, eps=ln_eps, device=device, dtype=dtype)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias,
qk_norm=qk_norm, rope=rope, device=device, dtype=dtype, operations=operations,
)
self.ls1 = (LayerScale(dim, init_values=init_values, device=device, dtype=dtype)
if init_values else nn.Identity())
self.norm2 = operations.LayerNorm(dim, eps=ln_eps, device=device, dtype=dtype)
mlp_hidden = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim, hidden_features=mlp_hidden,
bias=ffn_bias, device=device, dtype=dtype, operations=operations,
)
self.ls2 = (LayerScale(dim, init_values=init_values, device=device, dtype=dtype)
if init_values else nn.Identity())
def forward(self, x, pos=None, attn_mask=None):
x = x + self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask))
x = x + self.ls2(self.mlp(self.norm2(x)))
return x
# -----------------------------------------------------------------------------
# DINOv2 vision transformer
# -----------------------------------------------------------------------------
_BACKBONE_PRESETS = {
"vits": dict(embed_dim=384, depth=12, num_heads=6, ffn_layer="mlp"),
"vitb": dict(embed_dim=768, depth=12, num_heads=12, ffn_layer="mlp"),
"vitl": dict(embed_dim=1024, depth=24, num_heads=16, ffn_layer="mlp"),
"vitg": dict(embed_dim=1536, depth=40, num_heads=24, ffn_layer="swiglufused"),
}
class DinoVisionTransformer(nn.Module):
PATCH_SIZE = 14
def __init__(self,
embed_dim: int,
depth: int,
num_heads: int,
ffn_layer: str = "mlp",
mlp_ratio: float = 4.0,
init_values: float = 1.0,
alt_start: int = -1,
qknorm_start: int = -1,
rope_start: int = -1,
rope_freq: float = 100.0,
cat_token: bool = True,
device=None, dtype=None, operations=None):
super().__init__()
norm_layer = nn.LayerNorm
self.embed_dim = embed_dim
self.num_heads = num_heads
self.alt_start = alt_start
self.qknorm_start = qknorm_start
self.rope_start = rope_start
self.cat_token = cat_token
self.patch_size = self.PATCH_SIZE
self.num_register_tokens = 0
self.patch_start_idx = 1
self.num_tokens = 1
self.patch_embed = PatchEmbed(
patch_size=self.PATCH_SIZE, in_chans=3, embed_dim=embed_dim,
device=device, dtype=dtype, operations=operations,
)
# Number of patch positions for the historical 518x518 reference grid.
ref_grid = 518 // self.PATCH_SIZE
num_patches = ref_grid * ref_grid
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, device=device, dtype=dtype))
if alt_start != -1:
self.camera_token = nn.Parameter(torch.zeros(1, 2, embed_dim, device=device, dtype=dtype))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim,
device=device, dtype=dtype))
if rope_start != -1 and rope_freq > 0:
self.rope = RotaryPositionEmbedding2D(frequency=rope_freq)
self.position_getter = PositionGetter()
else:
self.rope = None
self.position_getter = None
if ffn_layer == "mlp":
ffn = Mlp
elif ffn_layer in ("swiglu", "swiglufused"):
ffn = SwiGLUFFNFused
else:
raise NotImplementedError(f"Unsupported ffn_layer: {ffn_layer}")
self.blocks = nn.ModuleList([
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
proj_bias=True,
ffn_bias=True,
init_values=init_values,
norm_layer=norm_layer,
ffn_layer=ffn,
qk_norm=(qknorm_start != -1 and i >= qknorm_start),
rope=(self.rope if (rope_start != -1 and i >= rope_start) else None),
device=device, dtype=dtype, operations=operations,
)
for i in range(depth)
])
self.norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def interpolate_pos_encoding(self, x, w, h):
previous_dtype = x.dtype
npatch = x.shape[1] - 1
N = self.pos_embed.shape[1] - 1
pos_embed = comfy_cast(self.pos_embed, x).float()
if npatch == N and w == h:
return pos_embed
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
M = int(math.sqrt(N))
assert N == M * M
# Historical 0.1 offset preserves bicubic resample compatibility with
# the original DINOv2 release; see the upstream PR for context.
sx = float(w0 + 0.1) / M
sy = float(h0 + 0.1) / M
patch_pos_embed = F.interpolate(
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
scale_factor=(sx, sy), mode="bicubic", antialias=False,
)
assert (w0, h0) == patch_pos_embed.shape[-2:]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
def prepare_tokens(self, x: torch.Tensor) -> torch.Tensor:
# x: (B, S, 3, H, W) -> tokens (B, S, 1+N, C)
B, S, _, H, W = x.shape
x = rearrange(x, "b s c h w -> (b s) c h w")
x = self.patch_embed(x)
cls_token = comfy_cast(self.cls_token, x).expand(B, S, -1).reshape(B * S, 1, self.embed_dim)
x = torch.cat((cls_token, x), dim=1)
x = x + self.interpolate_pos_encoding(x, W, H)
x = rearrange(x, "(b s) n c -> b s n c", b=B, s=S)
return x
def _prepare_rope(self, B, S, H, W, device):
if self.rope is None:
return None, None
pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=device)
pos = rearrange(pos, "(b s) n c -> b s n c", b=B)
pos_nodiff = torch.zeros_like(pos)
if self.patch_start_idx > 0:
pos = pos + 1
pos_special = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype)
pos_special = rearrange(pos_special, "(b s) n c -> b s n c", b=B)
pos = torch.cat([pos_special, pos], dim=2)
pos_nodiff = pos_nodiff + 1
pos_nodiff = torch.cat([pos_special, pos_nodiff], dim=2)
return pos, pos_nodiff
def _attn(self, x, blk, attn_type, pos=None, attn_mask=None):
b, s, n = x.shape[:3]
if attn_type == "local":
x = rearrange(x, "b s n c -> (b s) n c")
if pos is not None:
pos = rearrange(pos, "b s n c -> (b s) n c")
else: # "global"
x = rearrange(x, "b s n c -> b (s n) c")
if pos is not None:
pos = rearrange(pos, "b s n c -> b (s n) c")
x = blk(x, pos=pos, attn_mask=attn_mask)
if attn_type == "local":
x = rearrange(x, "(b s) n c -> b s n c", b=b, s=s)
else:
x = rearrange(x, "b (s n) c -> b s n c", b=b, s=s)
return x
# ------------------------------------------------------------------
# Public forward
# ------------------------------------------------------------------
def get_intermediate_layers(self, x: torch.Tensor, out_layers: List[int],
cam_token: Optional[torch.Tensor] = None):
B, S, _, H, W = x.shape
x = self.prepare_tokens(x)
pos, pos_nodiff = self._prepare_rope(B, S, H, W, x.device)
outputs = []
local_x = x
for i, blk in enumerate(self.blocks):
if self.rope is not None and i >= self.rope_start:
g_pos, l_pos = pos_nodiff, pos
else:
g_pos, l_pos = None, None
if self.alt_start != -1 and i == self.alt_start:
# Inject camera token at the alt-start boundary.
if cam_token is not None:
inj = cam_token
else:
ct = comfy_cast(self.camera_token, x)
ref_token = ct[:, :1].expand(B, -1, -1)
src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1)
inj = torch.cat([ref_token, src_token], dim=1)
x = x.clone()
x[:, :, 0] = inj
if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1):
x = self._attn(x, blk, "global", pos=g_pos)
else:
x = self._attn(x, blk, "local", pos=l_pos)
local_x = x
if i in out_layers:
out_x = torch.cat([local_x, x], dim=-1) if self.cat_token else x
outputs.append(out_x)
# Apply final norm. Upstream norms only the "global" half when cat_token.
normed: List[torch.Tensor] = []
camera_tokens: List[torch.Tensor] = []
for out_x in outputs:
# Camera/cls token slot is index 0 *before* register-token stripping.
camera_tokens.append(out_x[:, :, 0])
if out_x.shape[-1] == self.embed_dim:
normed.append(self.norm(out_x))
elif out_x.shape[-1] == self.embed_dim * 2:
left = out_x[..., : self.embed_dim]
right = self.norm(out_x[..., self.embed_dim :])
normed.append(torch.cat([left, right], dim=-1))
else:
raise ValueError(f"Unexpected token width: {out_x.shape[-1]}")
# Drop cls/cam token + register tokens from patch sequence.
normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed]
# Match upstream signature consumed by DA3 heads:
# feats[i][0] = normed patch tokens, feats[i][1] = camera/cls token.
return list(zip(normed, camera_tokens))
class DinoV2(nn.Module):
"""Top-level DINOv2 wrapper matching upstream key layout (``self.pretrained``)."""
def __init__(self, name: str = "vits",
out_layers: Optional[List[int]] = None,
alt_start: int = -1,
qknorm_start: int = -1,
rope_start: int = -1,
cat_token: bool = True,
device=None, dtype=None, operations=None, **kwargs):
super().__init__()
if name not in _BACKBONE_PRESETS:
raise ValueError(f"Unknown DINOv2 backbone variant: {name!r}")
preset = _BACKBONE_PRESETS[name]
self.name = name
self.out_layers = list(out_layers) if out_layers is not None else [5, 7, 9, 11]
self.cat_token = cat_token
self.pretrained = DinoVisionTransformer(
embed_dim=preset["embed_dim"],
depth=preset["depth"],
num_heads=preset["num_heads"],
ffn_layer=preset["ffn_layer"],
alt_start=alt_start,
qknorm_start=qknorm_start,
rope_start=rope_start,
cat_token=cat_token,
device=device, dtype=dtype, operations=operations,
)
def forward(self, x, cam_token=None, **_unused):
return self.pretrained.get_intermediate_layers(x, self.out_layers, cam_token=cam_token)

View File

@ -9,15 +9,25 @@
# The class signature mirrors the upstream YAML config so a single dit_config # The class signature mirrors the upstream YAML config so a single dit_config
# detected from the state dict in ``comfy/model_detection.py`` is sufficient # detected from the state dict in ``comfy/model_detection.py`` is sufficient
# to construct the right variant. # to construct the right variant.
#
# Backbone: ``comfy.image_encoders.dino2.Dinov2Model`` is shared with the
# CLIP-vision DINOv2 path. DA3-specific extensions (RoPE, QK-norm,
# alternating local/global attention, camera token, multi-layer feature
# extraction, pos-embed interpolation) are opt-in via the config dict and are
# all disabled for the Mono/Metric variants. The upstream DA3 weight layout
# (``backbone.pretrained.*`` with fused QKV) is converted to the
# ``Dinov2Model`` layout in
# ``comfy.supported_models.DepthAnything3.process_unet_state_dict``.
from __future__ import annotations from __future__ import annotations
from typing import Dict, List, Optional, Sequence from typing import Dict, Optional, Sequence
import torch import torch
import torch.nn as nn import torch.nn as nn
from .dinov2 import DinoV2 from comfy.image_encoders.dino2 import Dinov2Model
from .dpt import DPT, DualDPT from .dpt import DPT, DualDPT
@ -27,6 +37,42 @@ _HEAD_REGISTRY = {
} }
# Backbone presets (mirror the upstream DINOv2 ViT variants).
_BACKBONE_PRESETS = {
"vits": dict(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, use_swiglu_ffn=False),
"vitb": dict(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, use_swiglu_ffn=False),
"vitl": dict(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, use_swiglu_ffn=False),
"vitg": dict(hidden_size=1536, num_hidden_layers=40, num_attention_heads=24, use_swiglu_ffn=True),
}
def _build_backbone_config(
backbone_name: str,
*,
alt_start: int,
qknorm_start: int,
rope_start: int,
cat_token: bool,
) -> dict:
if backbone_name not in _BACKBONE_PRESETS:
raise ValueError(f"Unknown DINOv2 backbone variant: {backbone_name!r}")
cfg = dict(_BACKBONE_PRESETS[backbone_name])
cfg.update(dict(
layer_norm_eps=1e-6,
patch_size=14,
image_size=518,
# DA3 weights have no mask_token; skip registering it to avoid spurious
# missing-key warnings on load.
use_mask_token=False,
alt_start=alt_start,
qknorm_start=qknorm_start,
rope_start=rope_start,
cat_token=cat_token,
rope_freq=100.0,
))
return cfg
class DepthAnything3Net(nn.Module): class DepthAnything3Net(nn.Module):
"""ComfyUI-side DepthAnything3 network (monocular path only). """ComfyUI-side DepthAnything3 network (monocular path only).
@ -64,16 +110,16 @@ class DepthAnything3Net(nn.Module):
self.head_type = head_type.lower() self.head_type = head_type.lower()
self.has_sky = (self.head_type == "dpt") and head_use_sky_head self.has_sky = (self.head_type == "dpt") and head_use_sky_head
self.has_conf = head_output_dim > 1 self.has_conf = head_output_dim > 1
self.out_layers = list(out_layers)
self.backbone = DinoV2( backbone_cfg = _build_backbone_config(
name=backbone_name, backbone_name,
out_layers=list(out_layers),
alt_start=alt_start, alt_start=alt_start,
qknorm_start=qknorm_start, qknorm_start=qknorm_start,
rope_start=rope_start, rope_start=rope_start,
cat_token=cat_token, cat_token=cat_token,
device=device, dtype=dtype, operations=operations,
) )
self.backbone = Dinov2Model(backbone_cfg, dtype, device, operations)
head_kwargs = dict( head_kwargs = dict(
dim_in=head_dim_in, dim_in=head_dim_in,
@ -122,7 +168,7 @@ class DepthAnything3Net(nn.Module):
assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \ assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \
f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}" f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}"
feats = self.backbone(image) feats = self.backbone.get_intermediate_layers(image, self.out_layers)
head_out = self.head(feats, H=H, W=W, patch_start_idx=0) head_out = self.head(feats, H=H, W=W, patch_start_idx=0)
# Flatten the views axis (S=1 in mono inference path). # Flatten the views axis (S=1 in mono inference path).

View File

@ -1871,8 +1871,99 @@ class DepthAnything3(supported_models_base.BASE):
for k in list(state_dict.keys()): for k in list(state_dict.keys()):
if k.startswith(drop_prefixes): if k.startswith(drop_prefixes):
state_dict.pop(k) state_dict.pop(k)
# Remap upstream DA3 backbone keys (``backbone.pretrained.*`` with
# fused QKV) to the layout used by ``comfy.image_encoders.dino2.Dinov2Model``.
return _da3_remap_backbone_keys(state_dict, prefix="backbone.")
def _da3_remap_backbone_keys(state_dict, prefix="backbone."):
"""Rewrite upstream DA3 DINOv2 keys to the shared ``Dinov2Model`` layout.
Upstream layout (under ``{prefix}pretrained.``):
patch_embed.proj.{weight,bias}, pos_embed, cls_token, camera_token, norm.*,
blocks.{i}.norm{1,2}.*, blocks.{i}.attn.qkv.{weight,bias},
blocks.{i}.attn.q_norm.*, blocks.{i}.attn.k_norm.*,
blocks.{i}.attn.proj.*, blocks.{i}.ls{1,2}.gamma,
blocks.{i}.mlp.fc{1,2}.* (or w12/w3 for SwiGLU)
Target layout (Dinov2Model under ``{prefix}``):
embeddings.patch_embeddings.projection.*,
embeddings.position_embeddings, embeddings.cls_token, embeddings.camera_token,
layernorm.*,
encoder.layer.{i}.norm{1,2}.*,
encoder.layer.{i}.attention.attention.{query,key,value}.*,
encoder.layer.{i}.attention.q_norm.*, encoder.layer.{i}.attention.k_norm.*,
encoder.layer.{i}.attention.output.dense.*,
encoder.layer.{i}.layer_scale{1,2}.lambda1,
encoder.layer.{i}.mlp.fc{1,2}.* (or weights_in/weights_out for SwiGLU)
"""
pre = prefix + "pretrained."
src_keys = [k for k in state_dict.keys() if k.startswith(pre)]
if not src_keys:
return state_dict return state_dict
static_renames = {
pre + "patch_embed.proj.weight": prefix + "embeddings.patch_embeddings.projection.weight",
pre + "patch_embed.proj.bias": prefix + "embeddings.patch_embeddings.projection.bias",
pre + "pos_embed": prefix + "embeddings.position_embeddings",
pre + "cls_token": prefix + "embeddings.cls_token",
pre + "camera_token": prefix + "embeddings.camera_token",
pre + "norm.weight": prefix + "layernorm.weight",
pre + "norm.bias": prefix + "layernorm.bias",
}
for src, dst in static_renames.items():
if src in state_dict:
state_dict[dst] = state_dict.pop(src)
block_pre = pre + "blocks."
block_keys = [k for k in state_dict.keys() if k.startswith(block_pre)]
for k in block_keys:
rest = k[len(block_pre):] # e.g. "5.attn.qkv.weight"
idx_str, _, sub = rest.partition(".")
target_block = "{}encoder.layer.{}.".format(prefix, idx_str)
# Fused QKV -> split query/key/value linears.
if sub == "attn.qkv.weight":
qkv = state_dict.pop(k)
c = qkv.shape[0] // 3
state_dict[target_block + "attention.attention.query.weight"] = qkv[:c].clone()
state_dict[target_block + "attention.attention.key.weight"] = qkv[c:2 * c].clone()
state_dict[target_block + "attention.attention.value.weight"] = qkv[2 * c:].clone()
continue
if sub == "attn.qkv.bias":
qkv = state_dict.pop(k)
c = qkv.shape[0] // 3
state_dict[target_block + "attention.attention.query.bias"] = qkv[:c].clone()
state_dict[target_block + "attention.attention.key.bias"] = qkv[c:2 * c].clone()
state_dict[target_block + "attention.attention.value.bias"] = qkv[2 * c:].clone()
continue
# Sub-key remap (suffix preserved).
if sub.startswith("attn.proj."):
tail = sub[len("attn.proj."):]
new = "attention.output.dense." + tail
elif sub.startswith("attn.q_norm."):
new = "attention.q_norm." + sub[len("attn.q_norm."):]
elif sub.startswith("attn.k_norm."):
new = "attention.k_norm." + sub[len("attn.k_norm."):]
elif sub == "ls1.gamma":
new = "layer_scale1.lambda1"
elif sub == "ls2.gamma":
new = "layer_scale2.lambda1"
elif sub.startswith("mlp.w12."):
new = "mlp.weights_in." + sub[len("mlp.w12."):]
elif sub.startswith("mlp.w3."):
new = "mlp.weights_out." + sub[len("mlp.w3."):]
elif sub.startswith(("norm1.", "norm2.", "mlp.fc1.", "mlp.fc2.")):
new = sub
else:
# Unrecognised key -- leave as-is so load_state_dict can complain.
continue
state_dict[target_block + new] = state_dict.pop(k)
return state_dict
class ErnieImage(supported_models_base.BASE): class ErnieImage(supported_models_base.BASE):
unet_config = { unet_config = {