From 4ad749ab1715d1d45fcbc7e155023f222fef3433 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Tue, 12 May 2026 09:54:59 +0200 Subject: [PATCH] Refactor custom da dinov2 to image_encoders/dino2 --- comfy/image_encoders/dino2.py | 356 ++++++++++++++++++- comfy/ldm/depth_anything_3/dinov2.py | 497 --------------------------- comfy/ldm/depth_anything_3/model.py | 60 +++- comfy/supported_models.py | 91 +++++ 4 files changed, 482 insertions(+), 522 deletions(-) delete mode 100644 comfy/ldm/depth_anything_3/dinov2.py diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py index 9b6dace9d..730890a33 100644 --- a/comfy/image_encoders/dino2.py +++ b/comfy/image_encoders/dino2.py @@ -1,4 +1,8 @@ +import math + import torch +import torch.nn.functional as F + from comfy.text_encoders.bert import BertAttention import comfy.model_management from comfy.ldm.modules.attention import optimized_attention_for_device @@ -14,13 +18,42 @@ class Dino2AttentionOutput(torch.nn.Module): class Dino2AttentionBlock(torch.nn.Module): - def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations): + def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations, + qk_norm=False): super().__init__() + self.heads = heads + self.head_dim = embed_dim // heads self.attention = BertAttention(embed_dim, heads, dtype, device, operations) self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations) + if qk_norm: + self.q_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device) + self.k_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device) + else: + self.q_norm = None + self.k_norm = None - def forward(self, x, mask, optimized_attention): - return self.output(self.attention(x, mask, optimized_attention)) + def forward(self, x, mask, optimized_attention, pos=None, rope=None): + # Fast path used by the existing CLIP-vision DINOv2 (no DA3 extensions). + if self.q_norm is None and rope is None: + return self.output(self.attention(x, mask, optimized_attention)) + + # DA3 path: do QKV manually so we can apply per-head QK-norm and 2D RoPE. 
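+        # Shapes: x is (B, N, C); q/k/v are reshaped to (B, heads, N, head_dim) so the
+        # per-head QK-norm and RoPE can be applied before torch's SDPA.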
+ attn = self.attention + B, N, C = x.shape + h = self.heads + d = self.head_dim + q = attn.query(x).view(B, N, h, d).transpose(1, 2) + k = attn.key(x).view(B, N, h, d).transpose(1, 2) + v = attn.value(x).view(B, N, h, d).transpose(1, 2) + if self.q_norm is not None: + q = self.q_norm(q) + k = self.k_norm(k) + if rope is not None and pos is not None: + q = rope(q, pos) + k = rope(k, pos) + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) + out = out.transpose(1, 2).reshape(B, N, C) + return self.output(out) class LayerScale(torch.nn.Module): @@ -64,9 +97,11 @@ class SwiGLUFFN(torch.nn.Module): class Dino2Block(torch.nn.Module): - def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn): + def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn, + qk_norm=False): super().__init__() - self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations) + self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations, + qk_norm=qk_norm) self.layer_scale1 = LayerScale(dim, dtype, device, operations) self.layer_scale2 = LayerScale(dim, dtype, device, operations) if use_swiglu_ffn: @@ -76,19 +111,93 @@ class Dino2Block(torch.nn.Module): self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) - def forward(self, x, optimized_attention): - x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention)) + def forward(self, x, optimized_attention, pos=None, rope=None, attn_mask=None): + x = x + self.layer_scale1(self.attention(self.norm1(x), attn_mask, optimized_attention, + pos=pos, rope=rope)) x = x + self.layer_scale2(self.mlp(self.norm2(x))) return x -class Dino2Encoder(torch.nn.Module): - def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn): +# ----------------------------------------------------------------------------- +# 2D Rotary position embedding (DA3 extension) +# ----------------------------------------------------------------------------- + + +class _PositionGetter: + """Cache (h, w) -> flat (y, x) position grid used to feed ``rope``.""" + + def __init__(self): + self._cache: dict = {} + + def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor: + key = (height, width, device) + if key not in self._cache: + y = torch.arange(height, device=device) + x = torch.arange(width, device=device) + self._cache[key] = torch.cartesian_prod(y, x) + cached = self._cache[key] + return cached.view(1, height * width, 2).expand(batch_size, -1, -1).clone() + + +class RotaryPositionEmbedding2D(torch.nn.Module): + """2D RoPE used by DA3-Small/Base. 
No learnable parameters.""" + + def __init__(self, frequency: float = 100.0): super().__init__() - self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn) - for _ in range(num_layers)]) + self.base_frequency = frequency + self._freq_cache: dict = {} + + def _components(self, dim: int, seq_len: int, device, dtype): + key = (dim, seq_len, device, dtype) + if key not in self._freq_cache: + exp = torch.arange(0, dim, 2, device=device).float() / dim + inv_freq = 1.0 / (self.base_frequency ** exp) + pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + ang = torch.einsum("i,j->ij", pos, inv_freq) + ang = ang.to(dtype) + ang = torch.cat((ang, ang), dim=-1) + self._freq_cache[key] = (ang.cos().to(dtype), ang.sin().to(dtype)) + return self._freq_cache[key] + + @staticmethod + def _rotate(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] + x1, x2 = x[..., : d // 2], x[..., d // 2:] + return torch.cat((-x2, x1), dim=-1) + + def _apply_1d(self, tokens, positions, cos_c, sin_c): + cos = F.embedding(positions, cos_c)[:, None, :, :] + sin = F.embedding(positions, sin_c)[:, None, :, :] + return (tokens * cos) + (self._rotate(tokens) * sin) + + def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + feature_dim = tokens.size(-1) // 2 + max_pos = int(positions.max()) + 1 + cos_c, sin_c = self._components(feature_dim, max_pos, tokens.device, tokens.dtype) + v, h = tokens.chunk(2, dim=-1) + v = self._apply_1d(v, positions[..., 0], cos_c, sin_c) + h = self._apply_1d(h, positions[..., 1], cos_c, sin_c) + return torch.cat((v, h), dim=-1) + + +class Dino2Encoder(torch.nn.Module): + def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn, + qknorm_start: int = -1, rope: "RotaryPositionEmbedding2D | None" = None, + rope_start: int = -1): + super().__init__() + self.layer = torch.nn.ModuleList([ + Dino2Block( + dim, num_heads, layer_norm_eps, dtype, device, operations, + use_swiglu_ffn=use_swiglu_ffn, + qk_norm=(qknorm_start != -1 and i >= qknorm_start), + ) + for i in range(num_layers) + ]) + self.rope = rope + self.rope_start = rope_start def forward(self, x, intermediate_output=None): + # Backward-compat path used by ``ClipVisionModel`` (no DA3 extensions). 
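+        # The DA3 path does not go through this method: ``Dinov2Model.get_intermediate_layers``
+        # iterates the ``layer`` ModuleList directly so it can alternate between per-view
+        # and cross-view attention.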
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True) if intermediate_output is not None: @@ -121,25 +230,79 @@ class Dino2PatchEmbeddings(torch.nn.Module): class Dino2Embeddings(torch.nn.Module): - def __init__(self, dim, dtype, device, operations): + def __init__(self, dim, dtype, device, operations, + patch_size: int = 14, image_size: int = 518, + use_mask_token: bool = True, + num_camera_tokens: int = 0): super().__init__() - patch_size = 14 - image_size = 518 + self.patch_size = patch_size + self.image_size = image_size self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations) self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device)) self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) - self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device)) + if use_mask_token: + self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device)) + else: + self.mask_token = None + if num_camera_tokens > 0: + # DA3 stores (ref_token, src_token) pairs that get injected at the + # alt-attn boundary; see ``Dinov2Model._inject_camera_token``. + self.camera_token = torch.nn.Parameter(torch.empty(1, num_camera_tokens, dim, dtype=dtype, device=device)) + else: + self.camera_token = None + + def _interpolate_pos_encoding(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor: + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.position_embeddings.shape[1] - 1 + pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype).float() + if npatch == N and w == h: + return pos_embed + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) + assert N == M * M + # Historical 0.1 offset preserves bicubic resample compatibility with + # the original DINOv2 release; see the upstream PR for context. + sx = float(w0 + 0.1) / M + sy = float(h0 + 0.1) / M + patch_pos_embed = F.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), mode="bicubic", antialias=False, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) def forward(self, pixel_values): + _, _, H, W = pixel_values.shape x = self.patch_embeddings(pixel_values) # TODO: mask_token? x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1) - x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype) + x = x + self._interpolate_pos_encoding(x, H, W) return x class Dinov2Model(torch.nn.Module): + """DINOv2 vision backbone. + + Supports two operating modes: + + * **CLIP-vision DINOv2** (default): vanilla DINOv2-ViT used for + ``ClipVisionModel`` and SigLIP-style image encoding. + * **Depth Anything 3** extensions (opt-in via config keys): 2D RoPE, + QK-norm, alternating local/global attention, camera-token injection, + ``cat_token`` output and multi-layer feature extraction. These are + enabled when the corresponding fields (``alt_start``, ``qknorm_start``, + ``rope_start``, ``cat_token``) are set in ``config_dict``. 
When all of + them are at their disabled defaults this module behaves identically to + the historical ``Dinov2Model``. + """ + def __init__(self, config_dict, dtype, device, operations): super().__init__() num_layers = config_dict["num_hidden_layers"] @@ -147,14 +310,171 @@ class Dinov2Model(torch.nn.Module): heads = config_dict["num_attention_heads"] layer_norm_eps = config_dict["layer_norm_eps"] use_swiglu_ffn = config_dict["use_swiglu_ffn"] + patch_size = config_dict.get("patch_size", 14) + image_size = config_dict.get("image_size", 518) + use_mask_token = config_dict.get("use_mask_token", True) - self.embeddings = Dino2Embeddings(dim, dtype, device, operations) - self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn) + # DA3 extensions (all default to disabled). + self.alt_start = config_dict.get("alt_start", -1) + self.qknorm_start = config_dict.get("qknorm_start", -1) + self.rope_start = config_dict.get("rope_start", -1) + self.cat_token = config_dict.get("cat_token", False) + rope_freq = config_dict.get("rope_freq", 100.0) + + self.embed_dim = dim + self.patch_size = patch_size + self.num_register_tokens = 0 + self.patch_start_idx = 1 + + if self.rope_start != -1 and rope_freq > 0: + self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) + self._position_getter = _PositionGetter() + else: + self.rope = None + self._position_getter = None + + # camera_token shape: (1, 2, dim) -> (ref_token, src_token). + num_cam_tokens = 2 if self.alt_start != -1 else 0 + + self.embeddings = Dino2Embeddings( + dim, dtype, device, operations, + patch_size=patch_size, image_size=image_size, + use_mask_token=use_mask_token, num_camera_tokens=num_cam_tokens, + ) + self.encoder = Dino2Encoder( + dim, heads, layer_norm_eps, num_layers, dtype, device, operations, + use_swiglu_ffn=use_swiglu_ffn, + qknorm_start=self.qknorm_start, + rope=self.rope, rope_start=self.rope_start, + ) self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) + # ------------------------------------------------------------------ + # CLIP-vision-style forward (no DA3 extensions, no multi-layer output). + # Kept for backward compatibility with ``ClipVisionModel.encode_image``. + # ------------------------------------------------------------------ def forward(self, pixel_values, attention_mask=None, intermediate_output=None): x = self.embeddings(pixel_values) x, i = self.encoder(x, intermediate_output=intermediate_output) x = self.layernorm(x) pooled_output = x[:, 0, :] return x, i, pooled_output, None + + # ------------------------------------------------------------------ + # Depth Anything 3 forward + # ------------------------------------------------------------------ + def _prepare_rope_positions(self, B, S, H, W, device): + if self.rope is None: + return None, None + ph, pw = H // self.patch_size, W // self.patch_size + pos = self._position_getter(B * S, ph, pw, device=device) + # Shift so the cls/cam token at position 0 is reserved for "no diff". + pos = pos + 1 + cls_pos = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype) + # Per-view local: real grid positions for patches, 0 for cls token. + pos_local = torch.cat([cls_pos, pos], dim=1) + # Global (across views): same grid positions; cls token still at 0, + # but patches share the same positions in every view. 
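+        # (Concretely, every patch gets the constant position (1, 1), matching the
+        # upstream ``pos_nodiff``, so RoPE adds no relative offset between patch
+        # tokens in the cross-view/global attention blocks.)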
+ pos_global = torch.cat([cls_pos, torch.zeros_like(pos) + 1], dim=1) + return pos_local, pos_global + + def _inject_camera_token(self, x: torch.Tensor, B: int, S: int, + cam_token: "torch.Tensor | None") -> torch.Tensor: + # x: (B, S, N, C). Replace token at index 0 with the camera token. + if cam_token is not None: + inj = cam_token + else: + ct = comfy.model_management.cast_to_device(self.embeddings.camera_token, x.device, x.dtype) + ref_token = ct[:, :1].expand(B, -1, -1) + src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1) + inj = torch.cat([ref_token, src_token], dim=1) + x = x.clone() + x[:, :, 0] = inj + return x + + def get_intermediate_layers(self, pixel_values, out_layers, cam_token=None): + """Multi-layer DINOv2 feature extraction used by Depth Anything 3. + + Args: + pixel_values: ``(B, S, 3, H, W)`` views or ``(B, 3, H, W)``. + out_layers: indices into ``self.encoder.layer``. + cam_token: optional ``(B, S, dim)`` camera token to inject at + ``alt_start``. If ``None`` and the model has its own + ``camera_token`` parameter, that is used. + + Returns: + List of ``(patch_tokens, cls_or_cam_token)`` tuples, one per + requested ``out_layers`` entry. ``patch_tokens`` has shape + ``(B, S, N_patch, C)`` (or ``(B, S, N_patch, 2*C)`` when the + model was configured with ``cat_token=True``); the second item + has shape ``(B, S, C)``. + """ + if pixel_values.ndim == 4: + pixel_values = pixel_values.unsqueeze(1) + assert pixel_values.ndim == 5 and pixel_values.shape[2] == 3, \ + f"expected (B,3,H,W) or (B,S,3,H,W); got {tuple(pixel_values.shape)}" + B, S, _, H, W = pixel_values.shape + + # Patch + cls + (interpolated) pos embed for each view. + x = pixel_values.reshape(B * S, 3, H, W) + x = self.embeddings(x) # (B*S, 1+N, C) + x = x.reshape(B, S, x.shape[-2], x.shape[-1]) # (B, S, 1+N, C) + + pos_local, pos_global = self._prepare_rope_positions(B, S, H, W, x.device) + # ``optimized_attention`` is only used by blocks without QK-norm/RoPE + # (vanilla DINOv2 path); enabling-aware blocks fall through to SDPA. + optimized_attention = optimized_attention_for_device(x.device, False, small_input=True) + + out_set = set(out_layers) + outputs: list[torch.Tensor] = [] + local_x = x + + for i, blk in enumerate(self.encoder.layer): + apply_rope = self.rope is not None and i >= self.rope_start + block_rope = self.rope if apply_rope else None + l_pos = pos_local if apply_rope else None + g_pos = pos_global if apply_rope else None + + if self.alt_start != -1 and i == self.alt_start: + x = self._inject_camera_token(x, B, S, cam_token) + + if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1): + # Global attention across views: flatten S into the seq dim. + t = x.reshape(B, S * x.shape[-2], x.shape[-1]) + p = g_pos.reshape(B, S * g_pos.shape[-2], g_pos.shape[-1]) if g_pos is not None else None + t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope) + x = t.reshape(B, S, x.shape[-2], x.shape[-1]) + else: + # Per-view local attention. + t = x.reshape(B * S, x.shape[-2], x.shape[-1]) + p = l_pos.reshape(B * S, l_pos.shape[-2], l_pos.shape[-1]) if l_pos is not None else None + t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope) + x = t.reshape(B, S, x.shape[-2], x.shape[-1]) + local_x = x + + if i in out_set: + if self.cat_token: + out_x = torch.cat([local_x, x], dim=-1) + else: + out_x = x + outputs.append(out_x) + + # Apply final norm. 
When ``cat_token`` is set, only the right half + # ("global" features) is normalised; the left half is left as-is to + # match the upstream DA3 head signature. + normed: list[torch.Tensor] = [] + cls_tokens: list[torch.Tensor] = [] + for out_x in outputs: + cls_tokens.append(out_x[:, :, 0]) + if out_x.shape[-1] == self.embed_dim: + normed.append(self.layernorm(out_x)) + elif out_x.shape[-1] == self.embed_dim * 2: + left = out_x[..., :self.embed_dim] + right = self.layernorm(out_x[..., self.embed_dim:]) + normed.append(torch.cat([left, right], dim=-1)) + else: + raise ValueError(f"Unexpected token width: {out_x.shape[-1]}") + + # Drop cls/cam token from the patch sequence. + normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed] + return list(zip(normed, cls_tokens)) diff --git a/comfy/ldm/depth_anything_3/dinov2.py b/comfy/ldm/depth_anything_3/dinov2.py deleted file mode 100644 index 032cecda4..000000000 --- a/comfy/ldm/depth_anything_3/dinov2.py +++ /dev/null @@ -1,497 +0,0 @@ -# DINOv2 backbone for Depth Anything 3 (monocular inference path). -# -# Why not reuse ``comfy/image_encoders/dino2.py``? -# The existing ``Dinov2Model`` is a vanilla HuggingFace-style DINOv2 with a -# different state-dict layout (separate Q/K/V, ``embeddings.*`` / -# ``encoder.layer.*`` keys, ``layer_scale1.lambda1``) and no support for the -# architectural extensions DA3 adds on top of DINOv2 (RoPE, QK-norm, -# alternating local/global attention, concatenated camera token). Loading -# raw DA3 HF safetensors into ``Dinov2Model`` would require splitting -# ``attn.qkv`` weights and a large rename map for every block, and we'd -# still need to write the DA3 extensions separately. Keeping the upstream -# ``pretrained.*`` key layout here means HF weights load directly with no -# conversion step. -# -# Ported from the upstream repo at: -# src/depth_anything_3/model/dinov2/{dinov2,vision_transformer}.py -# src/depth_anything_3/model/dinov2/layers/* -# -# DA3 extensions on top of vanilla DINOv2 (only used by Small/Base variants): -# - 2D Rotary Position Embedding starting at ``rope_start`` -# - QK-norm starting at ``qknorm_start`` -# - Alternating local/global attention blocks starting at ``alt_start`` -# - Camera-conditioning token concatenated to features (``cat_token=True``), -# with a learned parameter ``camera_token`` injected at block -# ``alt_start`` when no external camera token is supplied. -# -# For the Mono/Metric variants the configuration disables all of the above -# (alt_start/qknorm_start/rope_start = -1, cat_token=False) so this module -# collapses to a vanilla DINOv2-ViT encoder. 
- -from __future__ import annotations - -import math -from typing import List, Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange - - -# ----------------------------------------------------------------------------- -# 2D rotary position embedding -# ----------------------------------------------------------------------------- - - -class PositionGetter: - def __init__(self): - self._cache: dict[tuple[int, int], torch.Tensor] = {} - - def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor: - key = (height, width) - if key not in self._cache: - y = torch.arange(height, device=device) - x = torch.arange(width, device=device) - self._cache[key] = torch.cartesian_prod(y, x) - cached = self._cache[key] - return cached.view(1, height * width, 2).expand(batch_size, -1, -1).clone() - - -class RotaryPositionEmbedding2D(nn.Module): - def __init__(self, frequency: float = 100.0): - super().__init__() - self.base_frequency = frequency - self._freq_cache: dict = {} - - def _components(self, dim: int, seq_len: int, device, dtype): - key = (dim, seq_len, device, dtype) - if key not in self._freq_cache: - exp = torch.arange(0, dim, 2, device=device).float() / dim - inv_freq = 1.0 / (self.base_frequency ** exp) - pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) - ang = torch.einsum("i,j->ij", pos, inv_freq) - ang = ang.to(dtype) - ang = torch.cat((ang, ang), dim=-1) - self._freq_cache[key] = (ang.cos().to(dtype), ang.sin().to(dtype)) - return self._freq_cache[key] - - @staticmethod - def _rotate(x: torch.Tensor) -> torch.Tensor: - d = x.shape[-1] - x1, x2 = x[..., : d // 2], x[..., d // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def _apply_1d(self, tokens, positions, cos_c, sin_c): - cos = F.embedding(positions, cos_c)[:, None, :, :] - sin = F.embedding(positions, sin_c)[:, None, :, :] - return (tokens * cos) + (self._rotate(tokens) * sin) - - def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: - feature_dim = tokens.size(-1) // 2 - max_pos = int(positions.max()) + 1 - cos_c, sin_c = self._components(feature_dim, max_pos, tokens.device, tokens.dtype) - v, h = tokens.chunk(2, dim=-1) - v = self._apply_1d(v, positions[..., 0], cos_c, sin_c) - h = self._apply_1d(h, positions[..., 1], cos_c, sin_c) - return torch.cat((v, h), dim=-1) - - -# ----------------------------------------------------------------------------- -# Patch embed / MLP / SwiGLU / LayerScale -# ----------------------------------------------------------------------------- - - -class PatchEmbed(nn.Module): - def __init__(self, patch_size=14, in_chans=3, embed_dim=384, - device=None, dtype=None, operations=None): - super().__init__() - self.patch_size = (patch_size, patch_size) - self.proj = operations.Conv2d( - in_chans, embed_dim, - kernel_size=self.patch_size, stride=self.patch_size, - device=device, dtype=dtype, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - _, _, H, W = x.shape - ph, pw = self.patch_size - assert H % ph == 0 and W % pw == 0 - x = self.proj(x) - x = x.flatten(2).transpose(1, 2) - return x - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, bias=True, - act_layer=nn.GELU, device=None, dtype=None, operations=None): - super().__init__() - hidden_features = hidden_features or in_features - self.fc1 = operations.Linear(in_features, hidden_features, bias=bias, device=device, dtype=dtype) - self.act = act_layer() - self.fc2 = 
operations.Linear(hidden_features, in_features, bias=bias, device=device, dtype=dtype) - - def forward(self, x): - return self.fc2(self.act(self.fc1(x))) - - -class SwiGLUFFNFused(nn.Module): - """SwiGLU FFN matching upstream xformers.ops.SwiGLU layout (used for vitg).""" - - def __init__(self, in_features, hidden_features=None, bias=True, - device=None, dtype=None, operations=None): - super().__init__() - hidden_features = hidden_features or in_features - hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 - # NOTE: xformers SwiGLU stores w12 as a single fused Linear (in, 2*hidden); - # split-by-half at forward time. We don't currently need this for the - # Apache-2.0 variants but keep it for parity with the upstream key names. - self.w12 = operations.Linear(in_features, 2 * hidden_features, bias=bias, device=device, dtype=dtype) - self.w3 = operations.Linear(hidden_features, in_features, bias=bias, device=device, dtype=dtype) - - def forward(self, x): - x12 = self.w12(x) - x1, x2 = x12.chunk(2, dim=-1) - return self.w3(F.silu(x1) * x2) - - -class LayerScale(nn.Module): - def __init__(self, dim, init_values: float = 1e-5, device=None, dtype=None): - super().__init__() - self.gamma = nn.Parameter(init_values * torch.ones(dim, device=device, dtype=dtype)) - - def forward(self, x): - return x * comfy_cast(self.gamma, x) - - -def comfy_cast(p: torch.Tensor, ref: torch.Tensor) -> torch.Tensor: - """Cast a parameter to match the reference tensor's device/dtype.""" - if p.device != ref.device or p.dtype != ref.dtype: - return p.to(device=ref.device, dtype=ref.dtype) - return p - - -# ----------------------------------------------------------------------------- -# Attention + Block -# ----------------------------------------------------------------------------- - - -class Attention(nn.Module): - def __init__(self, dim, num_heads: int, qkv_bias: bool = True, proj_bias: bool = True, - qk_norm: bool = False, rope: Optional[RotaryPositionEmbedding2D] = None, - device=None, dtype=None, operations=None): - super().__init__() - assert dim % num_heads == 0 - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype) - self.q_norm = operations.LayerNorm(self.head_dim, device=device, dtype=dtype) if qk_norm else nn.Identity() - self.k_norm = operations.LayerNorm(self.head_dim, device=device, dtype=dtype) if qk_norm else nn.Identity() - self.proj = operations.Linear(dim, dim, bias=proj_bias, device=device, dtype=dtype) - self.rope = rope - - def forward(self, x: torch.Tensor, pos: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] - q, k = self.q_norm(q), self.k_norm(k) - if self.rope is not None and pos is not None: - q = self.rope(q, pos) - k = self.rope(k, pos) - x = F.scaled_dot_product_attention( - q, k, v, - attn_mask=(attn_mask[:, None].repeat(1, self.num_heads, 1, 1) - if attn_mask is not None else None), - ) - x = x.transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - return x - - -class Block(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio: float = 4.0, - qkv_bias: bool = True, proj_bias: bool = True, ffn_bias: bool = True, - init_values: Optional[float] = 1.0, - norm_layer=nn.LayerNorm, - ffn_layer=Mlp, - qk_norm: bool = False, - rope: Optional[RotaryPositionEmbedding2D] = None, - ln_eps: float = 1e-6, 
- device=None, dtype=None, operations=None): - super().__init__() - self.norm1 = operations.LayerNorm(dim, eps=ln_eps, device=device, dtype=dtype) - self.attn = Attention( - dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, - qk_norm=qk_norm, rope=rope, device=device, dtype=dtype, operations=operations, - ) - self.ls1 = (LayerScale(dim, init_values=init_values, device=device, dtype=dtype) - if init_values else nn.Identity()) - - self.norm2 = operations.LayerNorm(dim, eps=ln_eps, device=device, dtype=dtype) - mlp_hidden = int(dim * mlp_ratio) - self.mlp = ffn_layer( - in_features=dim, hidden_features=mlp_hidden, - bias=ffn_bias, device=device, dtype=dtype, operations=operations, - ) - self.ls2 = (LayerScale(dim, init_values=init_values, device=device, dtype=dtype) - if init_values else nn.Identity()) - - def forward(self, x, pos=None, attn_mask=None): - x = x + self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask)) - x = x + self.ls2(self.mlp(self.norm2(x))) - return x - - -# ----------------------------------------------------------------------------- -# DINOv2 vision transformer -# ----------------------------------------------------------------------------- - - -_BACKBONE_PRESETS = { - "vits": dict(embed_dim=384, depth=12, num_heads=6, ffn_layer="mlp"), - "vitb": dict(embed_dim=768, depth=12, num_heads=12, ffn_layer="mlp"), - "vitl": dict(embed_dim=1024, depth=24, num_heads=16, ffn_layer="mlp"), - "vitg": dict(embed_dim=1536, depth=40, num_heads=24, ffn_layer="swiglufused"), -} - - -class DinoVisionTransformer(nn.Module): - PATCH_SIZE = 14 - - def __init__(self, - embed_dim: int, - depth: int, - num_heads: int, - ffn_layer: str = "mlp", - mlp_ratio: float = 4.0, - init_values: float = 1.0, - alt_start: int = -1, - qknorm_start: int = -1, - rope_start: int = -1, - rope_freq: float = 100.0, - cat_token: bool = True, - device=None, dtype=None, operations=None): - super().__init__() - norm_layer = nn.LayerNorm - self.embed_dim = embed_dim - self.num_heads = num_heads - self.alt_start = alt_start - self.qknorm_start = qknorm_start - self.rope_start = rope_start - self.cat_token = cat_token - self.patch_size = self.PATCH_SIZE - self.num_register_tokens = 0 - self.patch_start_idx = 1 - self.num_tokens = 1 - - self.patch_embed = PatchEmbed( - patch_size=self.PATCH_SIZE, in_chans=3, embed_dim=embed_dim, - device=device, dtype=dtype, operations=operations, - ) - # Number of patch positions for the historical 518x518 reference grid. 
- ref_grid = 518 // self.PATCH_SIZE - num_patches = ref_grid * ref_grid - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, device=device, dtype=dtype)) - if alt_start != -1: - self.camera_token = nn.Parameter(torch.zeros(1, 2, embed_dim, device=device, dtype=dtype)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim, - device=device, dtype=dtype)) - - if rope_start != -1 and rope_freq > 0: - self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) - self.position_getter = PositionGetter() - else: - self.rope = None - self.position_getter = None - - if ffn_layer == "mlp": - ffn = Mlp - elif ffn_layer in ("swiglu", "swiglufused"): - ffn = SwiGLUFFNFused - else: - raise NotImplementedError(f"Unsupported ffn_layer: {ffn_layer}") - - self.blocks = nn.ModuleList([ - Block( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=True, - proj_bias=True, - ffn_bias=True, - init_values=init_values, - norm_layer=norm_layer, - ffn_layer=ffn, - qk_norm=(qknorm_start != -1 and i >= qknorm_start), - rope=(self.rope if (rope_start != -1 and i >= rope_start) else None), - device=device, dtype=dtype, operations=operations, - ) - for i in range(depth) - ]) - self.norm = operations.LayerNorm(embed_dim, device=device, dtype=dtype) - - # ------------------------------------------------------------------ - # Helpers - # ------------------------------------------------------------------ - def interpolate_pos_encoding(self, x, w, h): - previous_dtype = x.dtype - npatch = x.shape[1] - 1 - N = self.pos_embed.shape[1] - 1 - pos_embed = comfy_cast(self.pos_embed, x).float() - if npatch == N and w == h: - return pos_embed - class_pos_embed = pos_embed[:, 0] - patch_pos_embed = pos_embed[:, 1:] - dim = x.shape[-1] - w0 = w // self.patch_size - h0 = h // self.patch_size - M = int(math.sqrt(N)) - assert N == M * M - # Historical 0.1 offset preserves bicubic resample compatibility with - # the original DINOv2 release; see the upstream PR for context. 
- sx = float(w0 + 0.1) / M - sy = float(h0 + 0.1) / M - patch_pos_embed = F.interpolate( - patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), - scale_factor=(sx, sy), mode="bicubic", antialias=False, - ) - assert (w0, h0) == patch_pos_embed.shape[-2:] - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) - - def prepare_tokens(self, x: torch.Tensor) -> torch.Tensor: - # x: (B, S, 3, H, W) -> tokens (B, S, 1+N, C) - B, S, _, H, W = x.shape - x = rearrange(x, "b s c h w -> (b s) c h w") - x = self.patch_embed(x) - cls_token = comfy_cast(self.cls_token, x).expand(B, S, -1).reshape(B * S, 1, self.embed_dim) - x = torch.cat((cls_token, x), dim=1) - x = x + self.interpolate_pos_encoding(x, W, H) - x = rearrange(x, "(b s) n c -> b s n c", b=B, s=S) - return x - - def _prepare_rope(self, B, S, H, W, device): - if self.rope is None: - return None, None - pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=device) - pos = rearrange(pos, "(b s) n c -> b s n c", b=B) - pos_nodiff = torch.zeros_like(pos) - if self.patch_start_idx > 0: - pos = pos + 1 - pos_special = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype) - pos_special = rearrange(pos_special, "(b s) n c -> b s n c", b=B) - pos = torch.cat([pos_special, pos], dim=2) - pos_nodiff = pos_nodiff + 1 - pos_nodiff = torch.cat([pos_special, pos_nodiff], dim=2) - return pos, pos_nodiff - - def _attn(self, x, blk, attn_type, pos=None, attn_mask=None): - b, s, n = x.shape[:3] - if attn_type == "local": - x = rearrange(x, "b s n c -> (b s) n c") - if pos is not None: - pos = rearrange(pos, "b s n c -> (b s) n c") - else: # "global" - x = rearrange(x, "b s n c -> b (s n) c") - if pos is not None: - pos = rearrange(pos, "b s n c -> b (s n) c") - x = blk(x, pos=pos, attn_mask=attn_mask) - if attn_type == "local": - x = rearrange(x, "(b s) n c -> b s n c", b=b, s=s) - else: - x = rearrange(x, "b (s n) c -> b s n c", b=b, s=s) - return x - - # ------------------------------------------------------------------ - # Public forward - # ------------------------------------------------------------------ - def get_intermediate_layers(self, x: torch.Tensor, out_layers: List[int], - cam_token: Optional[torch.Tensor] = None): - B, S, _, H, W = x.shape - x = self.prepare_tokens(x) - pos, pos_nodiff = self._prepare_rope(B, S, H, W, x.device) - - outputs = [] - local_x = x - - for i, blk in enumerate(self.blocks): - if self.rope is not None and i >= self.rope_start: - g_pos, l_pos = pos_nodiff, pos - else: - g_pos, l_pos = None, None - - if self.alt_start != -1 and i == self.alt_start: - # Inject camera token at the alt-start boundary. - if cam_token is not None: - inj = cam_token - else: - ct = comfy_cast(self.camera_token, x) - ref_token = ct[:, :1].expand(B, -1, -1) - src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1) - inj = torch.cat([ref_token, src_token], dim=1) - x = x.clone() - x[:, :, 0] = inj - - if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1): - x = self._attn(x, blk, "global", pos=g_pos) - else: - x = self._attn(x, blk, "local", pos=l_pos) - local_x = x - - if i in out_layers: - out_x = torch.cat([local_x, x], dim=-1) if self.cat_token else x - outputs.append(out_x) - - # Apply final norm. Upstream norms only the "global" half when cat_token. 
- normed: List[torch.Tensor] = [] - camera_tokens: List[torch.Tensor] = [] - for out_x in outputs: - # Camera/cls token slot is index 0 *before* register-token stripping. - camera_tokens.append(out_x[:, :, 0]) - if out_x.shape[-1] == self.embed_dim: - normed.append(self.norm(out_x)) - elif out_x.shape[-1] == self.embed_dim * 2: - left = out_x[..., : self.embed_dim] - right = self.norm(out_x[..., self.embed_dim :]) - normed.append(torch.cat([left, right], dim=-1)) - else: - raise ValueError(f"Unexpected token width: {out_x.shape[-1]}") - # Drop cls/cam token + register tokens from patch sequence. - normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed] - # Match upstream signature consumed by DA3 heads: - # feats[i][0] = normed patch tokens, feats[i][1] = camera/cls token. - return list(zip(normed, camera_tokens)) - - -class DinoV2(nn.Module): - """Top-level DINOv2 wrapper matching upstream key layout (``self.pretrained``).""" - - def __init__(self, name: str = "vits", - out_layers: Optional[List[int]] = None, - alt_start: int = -1, - qknorm_start: int = -1, - rope_start: int = -1, - cat_token: bool = True, - device=None, dtype=None, operations=None, **kwargs): - super().__init__() - if name not in _BACKBONE_PRESETS: - raise ValueError(f"Unknown DINOv2 backbone variant: {name!r}") - preset = _BACKBONE_PRESETS[name] - self.name = name - self.out_layers = list(out_layers) if out_layers is not None else [5, 7, 9, 11] - self.cat_token = cat_token - self.pretrained = DinoVisionTransformer( - embed_dim=preset["embed_dim"], - depth=preset["depth"], - num_heads=preset["num_heads"], - ffn_layer=preset["ffn_layer"], - alt_start=alt_start, - qknorm_start=qknorm_start, - rope_start=rope_start, - cat_token=cat_token, - device=device, dtype=dtype, operations=operations, - ) - - def forward(self, x, cam_token=None, **_unused): - return self.pretrained.get_intermediate_layers(x, self.out_layers, cam_token=cam_token) diff --git a/comfy/ldm/depth_anything_3/model.py b/comfy/ldm/depth_anything_3/model.py index 8bf5e9ec2..30e6af24d 100644 --- a/comfy/ldm/depth_anything_3/model.py +++ b/comfy/ldm/depth_anything_3/model.py @@ -9,15 +9,25 @@ # The class signature mirrors the upstream YAML config so a single dit_config # detected from the state dict in ``comfy/model_detection.py`` is sufficient # to construct the right variant. +# +# Backbone: ``comfy.image_encoders.dino2.Dinov2Model`` is shared with the +# CLIP-vision DINOv2 path. DA3-specific extensions (RoPE, QK-norm, +# alternating local/global attention, camera token, multi-layer feature +# extraction, pos-embed interpolation) are opt-in via the config dict and are +# all disabled for the Mono/Metric variants. The upstream DA3 weight layout +# (``backbone.pretrained.*`` with fused QKV) is converted to the +# ``Dinov2Model`` layout in +# ``comfy.supported_models.DepthAnything3.process_unet_state_dict``. from __future__ import annotations -from typing import Dict, List, Optional, Sequence +from typing import Dict, Optional, Sequence import torch import torch.nn as nn -from .dinov2 import DinoV2 +from comfy.image_encoders.dino2 import Dinov2Model + from .dpt import DPT, DualDPT @@ -27,6 +37,42 @@ _HEAD_REGISTRY = { } +# Backbone presets (mirror the upstream DINOv2 ViT variants). 
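+# hidden_size / num_hidden_layers / num_attention_heads follow the standard
+# DINOv2 ViT-S/14, ViT-B/14, ViT-L/14 and ViT-g/14 dimensions.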
+_BACKBONE_PRESETS = { + "vits": dict(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, use_swiglu_ffn=False), + "vitb": dict(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, use_swiglu_ffn=False), + "vitl": dict(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, use_swiglu_ffn=False), + "vitg": dict(hidden_size=1536, num_hidden_layers=40, num_attention_heads=24, use_swiglu_ffn=True), +} + + +def _build_backbone_config( + backbone_name: str, + *, + alt_start: int, + qknorm_start: int, + rope_start: int, + cat_token: bool, +) -> dict: + if backbone_name not in _BACKBONE_PRESETS: + raise ValueError(f"Unknown DINOv2 backbone variant: {backbone_name!r}") + cfg = dict(_BACKBONE_PRESETS[backbone_name]) + cfg.update(dict( + layer_norm_eps=1e-6, + patch_size=14, + image_size=518, + # DA3 weights have no mask_token; skip registering it to avoid spurious + # missing-key warnings on load. + use_mask_token=False, + alt_start=alt_start, + qknorm_start=qknorm_start, + rope_start=rope_start, + cat_token=cat_token, + rope_freq=100.0, + )) + return cfg + + class DepthAnything3Net(nn.Module): """ComfyUI-side DepthAnything3 network (monocular path only). @@ -64,16 +110,16 @@ class DepthAnything3Net(nn.Module): self.head_type = head_type.lower() self.has_sky = (self.head_type == "dpt") and head_use_sky_head self.has_conf = head_output_dim > 1 + self.out_layers = list(out_layers) - self.backbone = DinoV2( - name=backbone_name, - out_layers=list(out_layers), + backbone_cfg = _build_backbone_config( + backbone_name, alt_start=alt_start, qknorm_start=qknorm_start, rope_start=rope_start, cat_token=cat_token, - device=device, dtype=dtype, operations=operations, ) + self.backbone = Dinov2Model(backbone_cfg, dtype, device, operations) head_kwargs = dict( dim_in=head_dim_in, @@ -122,7 +168,7 @@ class DepthAnything3Net(nn.Module): assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \ f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}" - feats = self.backbone(image) + feats = self.backbone.get_intermediate_layers(image, self.out_layers) head_out = self.head(feats, H=H, W=W, patch_start_idx=0) # Flatten the views axis (S=1 in mono inference path). diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 9540d1d69..2d4a1a34f 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1871,8 +1871,99 @@ class DepthAnything3(supported_models_base.BASE): for k in list(state_dict.keys()): if k.startswith(drop_prefixes): state_dict.pop(k) + # Remap upstream DA3 backbone keys (``backbone.pretrained.*`` with + # fused QKV) to the layout used by ``comfy.image_encoders.dino2.Dinov2Model``. + return _da3_remap_backbone_keys(state_dict, prefix="backbone.") + + +def _da3_remap_backbone_keys(state_dict, prefix="backbone."): + """Rewrite upstream DA3 DINOv2 keys to the shared ``Dinov2Model`` layout. 
+ + Upstream layout (under ``{prefix}pretrained.``): + patch_embed.proj.{weight,bias}, pos_embed, cls_token, camera_token, norm.*, + blocks.{i}.norm{1,2}.*, blocks.{i}.attn.qkv.{weight,bias}, + blocks.{i}.attn.q_norm.*, blocks.{i}.attn.k_norm.*, + blocks.{i}.attn.proj.*, blocks.{i}.ls{1,2}.gamma, + blocks.{i}.mlp.fc{1,2}.* (or w12/w3 for SwiGLU) + + Target layout (Dinov2Model under ``{prefix}``): + embeddings.patch_embeddings.projection.*, + embeddings.position_embeddings, embeddings.cls_token, embeddings.camera_token, + layernorm.*, + encoder.layer.{i}.norm{1,2}.*, + encoder.layer.{i}.attention.attention.{query,key,value}.*, + encoder.layer.{i}.attention.q_norm.*, encoder.layer.{i}.attention.k_norm.*, + encoder.layer.{i}.attention.output.dense.*, + encoder.layer.{i}.layer_scale{1,2}.lambda1, + encoder.layer.{i}.mlp.fc{1,2}.* (or weights_in/weights_out for SwiGLU) + """ + pre = prefix + "pretrained." + src_keys = [k for k in state_dict.keys() if k.startswith(pre)] + if not src_keys: return state_dict + static_renames = { + pre + "patch_embed.proj.weight": prefix + "embeddings.patch_embeddings.projection.weight", + pre + "patch_embed.proj.bias": prefix + "embeddings.patch_embeddings.projection.bias", + pre + "pos_embed": prefix + "embeddings.position_embeddings", + pre + "cls_token": prefix + "embeddings.cls_token", + pre + "camera_token": prefix + "embeddings.camera_token", + pre + "norm.weight": prefix + "layernorm.weight", + pre + "norm.bias": prefix + "layernorm.bias", + } + for src, dst in static_renames.items(): + if src in state_dict: + state_dict[dst] = state_dict.pop(src) + + block_pre = pre + "blocks." + block_keys = [k for k in state_dict.keys() if k.startswith(block_pre)] + for k in block_keys: + rest = k[len(block_pre):] # e.g. "5.attn.qkv.weight" + idx_str, _, sub = rest.partition(".") + target_block = "{}encoder.layer.{}.".format(prefix, idx_str) + + # Fused QKV -> split query/key/value linears. + if sub == "attn.qkv.weight": + qkv = state_dict.pop(k) + c = qkv.shape[0] // 3 + state_dict[target_block + "attention.attention.query.weight"] = qkv[:c].clone() + state_dict[target_block + "attention.attention.key.weight"] = qkv[c:2 * c].clone() + state_dict[target_block + "attention.attention.value.weight"] = qkv[2 * c:].clone() + continue + if sub == "attn.qkv.bias": + qkv = state_dict.pop(k) + c = qkv.shape[0] // 3 + state_dict[target_block + "attention.attention.query.bias"] = qkv[:c].clone() + state_dict[target_block + "attention.attention.key.bias"] = qkv[c:2 * c].clone() + state_dict[target_block + "attention.attention.value.bias"] = qkv[2 * c:].clone() + continue + + # Sub-key remap (suffix preserved). + if sub.startswith("attn.proj."): + tail = sub[len("attn.proj."):] + new = "attention.output.dense." + tail + elif sub.startswith("attn.q_norm."): + new = "attention.q_norm." + sub[len("attn.q_norm."):] + elif sub.startswith("attn.k_norm."): + new = "attention.k_norm." + sub[len("attn.k_norm."):] + elif sub == "ls1.gamma": + new = "layer_scale1.lambda1" + elif sub == "ls2.gamma": + new = "layer_scale2.lambda1" + elif sub.startswith("mlp.w12."): + new = "mlp.weights_in." + sub[len("mlp.w12."):] + elif sub.startswith("mlp.w3."): + new = "mlp.weights_out." + sub[len("mlp.w3."):] + elif sub.startswith(("norm1.", "norm2.", "mlp.fc1.", "mlp.fc2.")): + new = sub + else: + # Unrecognised key -- leave as-is so load_state_dict can complain. 
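+            # The key keeps its original ``{prefix}pretrained.blocks...`` name, so it
+            # shows up in ``unexpected_keys`` instead of being silently dropped.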
+ continue + + state_dict[target_block + new] = state_dict.pop(k) + + return state_dict + class ErnieImage(supported_models_base.BASE): unet_config = {
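For illustration, the fused-QKV rewrite performed by ``_da3_remap_backbone_keys`` (invoked
from ``DepthAnything3.process_unet_state_dict`` above) boils down to the following
standalone sketch. Toy dimensions and a hypothetical single-block state dict are used
here, not real checkpoint keys; real checkpoints carry many more entries per block.

    import torch

    dim = 4  # toy embed dim; DA3-Small (vits) uses 384
    sd = {
        "backbone.pretrained.blocks.0.attn.qkv.weight":
            torch.arange(3 * dim * dim, dtype=torch.float32).reshape(3 * dim, dim),
    }

    # Split the fused (3*dim, dim) QKV projection into separate query/key/value
    # weights, renamed to the ``Dinov2Model`` layout.
    qkv = sd.pop("backbone.pretrained.blocks.0.attn.qkv.weight")
    c = qkv.shape[0] // 3
    sd["backbone.encoder.layer.0.attention.attention.query.weight"] = qkv[:c].clone()
    sd["backbone.encoder.layer.0.attention.attention.key.weight"] = qkv[c:2 * c].clone()
    sd["backbone.encoder.layer.0.attention.attention.value.weight"] = qkv[2 * c:].clone()

    assert all(w.shape == (dim, dim) for w in sd.values())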