import math

import torch
import torch.nn.functional as F

from comfy.text_encoders.bert import BertAttention
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.ldm.depth_anything_3.reference_view_selector import (
    select_reference_view,
    reorder_by_reference,
    restore_original_order,
    THRESH_FOR_REF_SELECTION,
)


class Dino2AttentionOutput(torch.nn.Module):
    def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)

    def forward(self, x):
        return self.dense(x)


class Dino2AttentionBlock(torch.nn.Module):
    def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations, qk_norm=False):
        super().__init__()
        self.heads = heads
        self.head_dim = embed_dim // heads
        self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
        self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
        if qk_norm:
            self.q_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device)
            self.k_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device)
        else:
            self.q_norm = None
            self.k_norm = None

    def forward(self, x, mask, optimized_attention, pos=None, rope=None):
        # Fast path used by the existing CLIP-vision DINOv2 (no DA3 extensions).
        if self.q_norm is None and rope is None:
            return self.output(self.attention(x, mask, optimized_attention))

        # DA3 path: do QKV manually so we can apply per-head QK-norm and 2D RoPE.
        attn = self.attention
        B, N, C = x.shape
        h = self.heads
        d = self.head_dim
        q = attn.query(x).view(B, N, h, d).transpose(1, 2)
        k = attn.key(x).view(B, N, h, d).transpose(1, 2)
        v = attn.value(x).view(B, N, h, d).transpose(1, 2)
        if self.q_norm is not None:
            q = self.q_norm(q)
            k = self.k_norm(k)
        if rope is not None and pos is not None:
            q = rope(q, pos)
            k = rope(k, pos)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        out = out.transpose(1, 2).reshape(B, N, C)
        return self.output(out)


class LayerScale(torch.nn.Module):
    def __init__(self, dim, dtype, device, operations):
        super().__init__()
        self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))

    def forward(self, x):
        return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)


class Dinov2MLP(torch.nn.Module):
    def __init__(self, hidden_size: int, dtype, device, operations):
        super().__init__()
        mlp_ratio = 4
        hidden_features = int(hidden_size * mlp_ratio)
        self.fc1 = operations.Linear(hidden_size, hidden_features, bias=True, device=device, dtype=dtype)
        self.fc2 = operations.Linear(hidden_features, hidden_size, bias=True, device=device, dtype=dtype)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        hidden_state = self.fc1(hidden_state)
        hidden_state = torch.nn.functional.gelu(hidden_state)
        hidden_state = self.fc2(hidden_state)
        return hidden_state


class SwiGLUFFN(torch.nn.Module):
    def __init__(self, dim, dtype, device, operations):
        super().__init__()
        in_features = out_features = dim
        hidden_features = int(dim * 4)
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
        self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)

    def forward(self, x):
        x = self.weights_in(x)
        x1, x2 = x.chunk(2, dim=-1)
        x = torch.nn.functional.silu(x1) * x2
        return self.weights_out(x)
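

# Worked example of the SwiGLU width rule above (illustrative only): for
# dim=1024, int(1024 * 4) = 4096 and int(4096 * 2 / 3) = 2730, which rounds
# up to a multiple of 8 as (2730 + 7) // 8 * 8 = 2736. ``weights_in`` then
# maps 1024 -> 2 * 2736 = 5472 and ``weights_out`` maps 2736 -> 1024.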


class Dino2Block(torch.nn.Module):
    def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn, qk_norm=False):
        super().__init__()
        self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations, qk_norm=qk_norm)
        self.layer_scale1 = LayerScale(dim, dtype, device, operations)
        self.layer_scale2 = LayerScale(dim, dtype, device, operations)
        if use_swiglu_ffn:
            self.mlp = SwiGLUFFN(dim, dtype, device, operations)
        else:
            self.mlp = Dinov2MLP(dim, dtype, device, operations)
        self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
        self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)

    def forward(self, x, optimized_attention, pos=None, rope=None, attn_mask=None):
        x = x + self.layer_scale1(self.attention(self.norm1(x), attn_mask, optimized_attention, pos=pos, rope=rope))
        x = x + self.layer_scale2(self.mlp(self.norm2(x)))
        return x


# -----------------------------------------------------------------------------
# 2D Rotary position embedding (DA3 extension)
# -----------------------------------------------------------------------------

class _PositionGetter:
    """Cache (h, w) -> flat (y, x) position grid used to feed ``rope``."""

    def __init__(self):
        self._cache: dict = {}

    def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor:
        key = (height, width, device)
        if key not in self._cache:
            y = torch.arange(height, device=device)
            x = torch.arange(width, device=device)
            self._cache[key] = torch.cartesian_prod(y, x)
        cached = self._cache[key]
        return cached.view(1, height * width, 2).expand(batch_size, -1, -1).clone()


class RotaryPositionEmbedding2D(torch.nn.Module):
    """2D RoPE used by DA3-Small/Base. No learnable parameters."""

    def __init__(self, frequency: float = 100.0):
        super().__init__()
        self.base_frequency = frequency
        self._freq_cache: dict = {}

    def _components(self, dim: int, seq_len: int, device, dtype):
        key = (dim, seq_len, device, dtype)
        if key not in self._freq_cache:
            exp = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency ** exp)
            pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            ang = torch.einsum("i,j->ij", pos, inv_freq)
            ang = ang.to(dtype)
            ang = torch.cat((ang, ang), dim=-1)
            self._freq_cache[key] = (ang.cos().to(dtype), ang.sin().to(dtype))
        return self._freq_cache[key]

    @staticmethod
    def _rotate(x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1]
        x1, x2 = x[..., : d // 2], x[..., d // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_1d(self, tokens, positions, cos_c, sin_c):
        cos = F.embedding(positions, cos_c)[:, None, :, :]
        sin = F.embedding(positions, sin_c)[:, None, :, :]
        return (tokens * cos) + (self._rotate(tokens) * sin)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        feature_dim = tokens.size(-1) // 2
        max_pos = int(positions.max()) + 1
        cos_c, sin_c = self._components(feature_dim, max_pos, tokens.device, tokens.dtype)
        v, h = tokens.chunk(2, dim=-1)
        v = self._apply_1d(v, positions[..., 0], cos_c, sin_c)
        h = self._apply_1d(h, positions[..., 1], cos_c, sin_c)
        return torch.cat((v, h), dim=-1)


class Dino2Encoder(torch.nn.Module):
    def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations,
                 use_swiglu_ffn, qknorm_start: int = -1,
                 rope: "RotaryPositionEmbedding2D | None" = None, rope_start: int = -1):
        super().__init__()
        self.layer = torch.nn.ModuleList([
            Dino2Block(
                dim, num_heads, layer_norm_eps, dtype, device, operations,
                use_swiglu_ffn=use_swiglu_ffn,
                qk_norm=(qknorm_start != -1 and i >= qknorm_start),
            ) for i in range(num_layers)
        ])
        self.rope = rope
        self.rope_start = rope_start

    def forward(self, x, intermediate_output=None):
        # Backward-compat path used by ``ClipVisionModel`` (no DA3 extensions).
        optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)

        if intermediate_output is not None:
            if intermediate_output < 0:
                intermediate_output = len(self.layer) + intermediate_output

        intermediate = None
        for i, layer in enumerate(self.layer):
            x = layer(x, optimized_attention)
            if i == intermediate_output:
                intermediate = x.clone()
        return x, intermediate


class Dino2PatchEmbeddings(torch.nn.Module):
    def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
        super().__init__()
        self.projection = operations.Conv2d(
            in_channels=num_channels,
            out_channels=dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,
            dtype=dtype,
            device=device,
        )

    def forward(self, pixel_values):
        return self.projection(pixel_values).flatten(2).transpose(1, 2)


class Dino2Embeddings(torch.nn.Module):
    def __init__(self, dim, dtype, device, operations, patch_size: int = 14, image_size: int = 518,
                 use_mask_token: bool = True, num_camera_tokens: int = 0):
        super().__init__()
        self.patch_size = patch_size
        self.image_size = image_size
        self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size,
                                                     dtype=dtype, device=device, operations=operations)
        self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
        self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
        if use_mask_token:
            self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
        else:
            self.mask_token = None
        if num_camera_tokens > 0:
            # DA3 stores (ref_token, src_token) pairs that get injected at the
            # alt-attn boundary; see ``Dinov2Model._inject_camera_token``.
            self.camera_token = torch.nn.Parameter(torch.empty(1, num_camera_tokens, dim, dtype=dtype, device=device))
        else:
            self.camera_token = None

    def _interpolate_pos_encoding(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.position_embeddings.shape[1] - 1
        pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype).float()
        if npatch == N and w == h:
            return pos_embed
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        ph = h // self.patch_size  # patch grid height
        pw = w // self.patch_size  # patch grid width
        M = int(math.sqrt(N))
        assert N == M * M
        # Historical 0.1 offset preserves bicubic resample compatibility with
        # the original DINOv2 release; see the upstream PR for context.
        # ``scale_factor`` is interpreted as (height_scale, width_scale) by
        # ``F.interpolate``, so we must put the height scale FIRST. Earlier
        # revisions of this function had it swapped, which only worked for
        # square inputs (e.g. CLIP-vision square crops); non-square inputs
        # like DA3-Small / DA3-Base multi-view paths exposed the bug.
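        # Worked example (illustrative): with the default image_size=518 and
        # patch_size=14 the stored grid is M=37 (N=1369). A 266x518 input
        # gives ph=19, pw=37, so sh = 19.1/37 and sw = 37.1/37, and the
        # bicubic resample below produces a (1, dim, 19, 37) grid.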
        sh = float(ph + 0.1) / M
        sw = float(pw + 0.1) / M
        patch_pos_embed = F.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            scale_factor=(sh, sw),
            mode="bicubic",
            antialias=False,
        )
        assert (ph, pw) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def forward(self, pixel_values):
        _, _, H, W = pixel_values.shape
        x = self.patch_embeddings(pixel_values)
        # TODO: mask_token?
        x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self._interpolate_pos_encoding(x, H, W)
        return x


class Dinov2Model(torch.nn.Module):
    """DINOv2 vision backbone.

    Supports two operating modes:

    * **CLIP-vision DINOv2** (default): vanilla DINOv2-ViT used for
      ``ClipVisionModel`` and SigLIP-style image encoding.
    * **Depth Anything 3** extensions (opt-in via config keys): 2D RoPE,
      QK-norm, alternating local/global attention, camera-token injection,
      ``cat_token`` output and multi-layer feature extraction.

    These are enabled when the corresponding fields (``alt_start``,
    ``qknorm_start``, ``rope_start``, ``cat_token``) are set in
    ``config_dict``. When all of them are at their disabled defaults, this
    module behaves identically to the historical ``Dinov2Model``.
    """

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        num_layers = config_dict["num_hidden_layers"]
        dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        layer_norm_eps = config_dict["layer_norm_eps"]
        use_swiglu_ffn = config_dict["use_swiglu_ffn"]
        patch_size = config_dict.get("patch_size", 14)
        image_size = config_dict.get("image_size", 518)
        use_mask_token = config_dict.get("use_mask_token", True)

        # DA3 extensions (all default to disabled).
        self.alt_start = config_dict.get("alt_start", -1)
        self.qknorm_start = config_dict.get("qknorm_start", -1)
        self.rope_start = config_dict.get("rope_start", -1)
        self.cat_token = config_dict.get("cat_token", False)
        rope_freq = config_dict.get("rope_freq", 100.0)

        self.embed_dim = dim
        self.patch_size = patch_size
        self.num_register_tokens = 0
        self.patch_start_idx = 1

        if self.rope_start != -1 and rope_freq > 0:
            self.rope = RotaryPositionEmbedding2D(frequency=rope_freq)
            self._position_getter = _PositionGetter()
        else:
            self.rope = None
            self._position_getter = None

        # camera_token shape: (1, 2, dim) -> (ref_token, src_token).
        num_cam_tokens = 2 if self.alt_start != -1 else 0

        self.embeddings = Dino2Embeddings(
            dim, dtype, device, operations,
            patch_size=patch_size,
            image_size=image_size,
            use_mask_token=use_mask_token,
            num_camera_tokens=num_cam_tokens,
        )
        self.encoder = Dino2Encoder(
            dim, heads, layer_norm_eps, num_layers, dtype, device, operations,
            use_swiglu_ffn=use_swiglu_ffn,
            qknorm_start=self.qknorm_start,
            rope=self.rope,
            rope_start=self.rope_start,
        )
        self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)

    # ------------------------------------------------------------------
    # CLIP-vision-style forward (no DA3 extensions, no multi-layer output).
    # Kept for backward compatibility with ``ClipVisionModel.encode_image``.
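    # Example shapes (illustrative, assuming the DINOv2-L defaults of
    # patch_size=14, image_size=518, hidden_size=1024): a (1, 3, 518, 518)
    # input yields x of shape (1, 1 + 37 * 37, 1024) = (1, 1370, 1024) and
    # pooled_output of shape (1, 1024).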
    # ------------------------------------------------------------------
    def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
        x = self.embeddings(pixel_values)
        x, i = self.encoder(x, intermediate_output=intermediate_output)
        x = self.layernorm(x)
        pooled_output = x[:, 0, :]
        return x, i, pooled_output, None

    # ------------------------------------------------------------------
    # Depth Anything 3 forward
    # ------------------------------------------------------------------
    def _prepare_rope_positions(self, B, S, H, W, device):
        if self.rope is None:
            return None, None
        ph, pw = H // self.patch_size, W // self.patch_size
        pos = self._position_getter(B * S, ph, pw, device=device)
        # Shift patch positions by 1 so that index 0 stays reserved for the
        # cls/cam token.
        pos = pos + 1
        cls_pos = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype)
        # Per-view local: real grid positions for patches, 0 for cls token.
        pos_local = torch.cat([cls_pos, pos], dim=1)
        # Global (across views): spatial RoPE is effectively disabled; every
        # patch gets position 1 (distinct from the cls/cam token at 0), so
        # cross-view attention is position-agnostic.
        pos_global = torch.cat([cls_pos, torch.zeros_like(pos) + 1], dim=1)
        return pos_local, pos_global

    def _inject_camera_token(self, x: torch.Tensor, B: int, S: int,
                             cam_token: "torch.Tensor | None") -> torch.Tensor:
        # x: (B, S, N, C). Replace the token at index 0 with the camera token.
        if cam_token is not None:
            inj = cam_token
        else:
            ct = comfy.model_management.cast_to_device(self.embeddings.camera_token, x.device, x.dtype)
            ref_token = ct[:, :1].expand(B, -1, -1)
            src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1)
            inj = torch.cat([ref_token, src_token], dim=1)
        x = x.clone()
        x[:, :, 0] = inj
        return x

    def get_intermediate_layers(self, pixel_values, out_layers, cam_token=None,
                                ref_view_strategy="saddle_balanced", export_feat_layers=None):
        """Multi-layer DINOv2 feature extraction used by Depth Anything 3.

        Args:
            pixel_values: ``(B, S, 3, H, W)`` views or ``(B, 3, H, W)``.
            out_layers: indices into ``self.encoder.layer``.
            cam_token: optional ``(B, S, dim)`` camera token to inject at
                ``alt_start``. If ``None`` and the model has its own
                ``camera_token`` parameter, that is used.
            ref_view_strategy: when ``S >= 3`` and ``cam_token is None``, pick
                a reference view via this strategy and move it to position 0
                right before the first alt-attention block. The original view
                order is restored on the way out.
            export_feat_layers: optional iterable of layer indices whose block
                outputs are also returned as auxiliary features
                (``(B, S, N_patch, C)`` after final norm). Used by the
                multi-view path to expose intermediate features to the
                nested-architecture wrapper.

        Returns:
            ``(layer_outputs, aux_outputs)`` where ``layer_outputs`` is a list
            of ``(patch_tokens, cls_or_cam_token)`` tuples (one per
            ``out_layers`` entry) and ``aux_outputs`` is a list of
            ``(B, S, N_patch, C)`` features for ``export_feat_layers`` (empty
            list when not requested).
        """
        if pixel_values.ndim == 4:
            pixel_values = pixel_values.unsqueeze(1)
        assert pixel_values.ndim == 5 and pixel_values.shape[2] == 3, \
            f"expected (B,3,H,W) or (B,S,3,H,W); got {tuple(pixel_values.shape)}"
        B, S, _, H, W = pixel_values.shape

        # Patch + cls + (interpolated) pos embed for each view.
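        # Shape walkthrough (illustrative, assuming patch_size=14): for B=1,
        # S=4, H=W=518, each view carries 1 + 37 * 37 = 1370 tokens, so after
        # the reshapes below x is (1, 4, 1370, C).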
        x = pixel_values.reshape(B * S, 3, H, W)
        x = self.embeddings(x)                         # (B*S, 1+N, C)
        x = x.reshape(B, S, x.shape[-2], x.shape[-1])  # (B, S, 1+N, C)

        pos_local, pos_global = self._prepare_rope_positions(B, S, H, W, x.device)

        # ``optimized_attention`` is only used by blocks without QK-norm/RoPE
        # (vanilla DINOv2 path); blocks with QK-norm or RoPE fall through to
        # SDPA instead.
        optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)

        out_set = set(out_layers)
        export_set = set(export_feat_layers) if export_feat_layers else set()
        outputs: list[torch.Tensor] = []
        aux_outputs: list[torch.Tensor] = []
        local_x = x
        b_idx = None

        for i, blk in enumerate(self.encoder.layer):
            apply_rope = self.rope is not None and i >= self.rope_start
            block_rope = self.rope if apply_rope else None
            l_pos = pos_local if apply_rope else None
            g_pos = pos_global if apply_rope else None

            # Reference-view selection threshold: matches the upstream constant
            # ``THRESH_FOR_REF_SELECTION = 3``. Skipped when a user-supplied
            # cam_token is provided (camera info already pins the geometry).
            if (self.alt_start != -1 and i == self.alt_start - 1
                    and S >= THRESH_FOR_REF_SELECTION and cam_token is None):
                b_idx = select_reference_view(x, strategy=ref_view_strategy)
                x = reorder_by_reference(x, b_idx)
                local_x = reorder_by_reference(local_x, b_idx)

            if self.alt_start != -1 and i == self.alt_start:
                x = self._inject_camera_token(x, B, S, cam_token)

            if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1):
                # Global attention across views: flatten S into the seq dim.
                t = x.reshape(B, S * x.shape[-2], x.shape[-1])
                p = g_pos.reshape(B, S * g_pos.shape[-2], g_pos.shape[-1]) if g_pos is not None else None
                t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
                x = t.reshape(B, S, x.shape[-2], x.shape[-1])
            else:
                # Per-view local attention.
                t = x.reshape(B * S, x.shape[-2], x.shape[-1])
                p = l_pos.reshape(B * S, l_pos.shape[-2], l_pos.shape[-1]) if l_pos is not None else None
                t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
                x = t.reshape(B, S, x.shape[-2], x.shape[-1])
                local_x = x

            if i in out_set:
                if self.cat_token:
                    out_x = torch.cat([local_x, x], dim=-1)
                else:
                    out_x = x
                # Restore original view order on the way out so heads see views
                # in the user's expected order.
                if b_idx is not None and self.alt_start != -1:
                    out_x = restore_original_order(out_x, b_idx)
                outputs.append(out_x)

            if i in export_set:
                aux = x
                if b_idx is not None and self.alt_start != -1:
                    aux = restore_original_order(aux, b_idx)
                aux_outputs.append(aux)

        # Apply final norm. When ``cat_token`` is set, only the right half
        # ("global" features) is normalised; the left half is left as-is to
        # match the upstream DA3 head signature.
        normed: list[torch.Tensor] = []
        cls_tokens: list[torch.Tensor] = []
        for out_x in outputs:
            cls_tokens.append(out_x[:, :, 0])
            if out_x.shape[-1] == self.embed_dim:
                normed.append(self.layernorm(out_x))
            elif out_x.shape[-1] == self.embed_dim * 2:
                left = out_x[..., :self.embed_dim]
                right = self.layernorm(out_x[..., self.embed_dim:])
                normed.append(torch.cat([left, right], dim=-1))
            else:
                raise ValueError(f"Unexpected token width: {out_x.shape[-1]}")

        # Drop the cls/cam token from the patch sequence.
        normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed]
        # Final layernorm + drop cls token from the auxiliary features too.
        aux_normed = [self.layernorm(o)[..., 1 + self.num_register_tokens:, :] for o in aux_outputs]
        return list(zip(normed, cls_tokens)), aux_normed
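

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the module proper). It assumes
    # ``comfy.ops.disable_weight_init`` supplies the ``Linear``/``Conv2d``/
    # ``LayerNorm`` factories expected as ``operations``; the tiny config
    # below is made up purely for shape-checking, and since the weights are
    # uninitialised the outputs are numerically meaningless.
    import comfy.ops

    config = {
        "num_hidden_layers": 2,
        "hidden_size": 64,
        "num_attention_heads": 4,
        "layer_norm_eps": 1e-6,
        "use_swiglu_ffn": False,
        # Hypothetical DA3 settings chosen to exercise every extension:
        # RoPE + QK-norm from layer 0, alternating attention from layer 1.
        "alt_start": 1,
        "qknorm_start": 0,
        "rope_start": 0,
        "cat_token": True,
    }
    model = Dinov2Model(config, dtype=torch.float32, device="cpu",
                        operations=comfy.ops.disable_weight_init)

    # Vanilla CLIP-vision-style forward: (B, 3, H, W) -> tokens + pooled.
    tokens, _, pooled, _ = model.forward(torch.zeros(1, 3, 140, 140))
    print(tokens.shape, pooled.shape)  # (1, 101, 64), (1, 64)

    # DA3 multi-view path. S=2 keeps reference-view selection disabled
    # (S < THRESH_FOR_REF_SELECTION).
    views = torch.zeros(1, 2, 3, 140, 140)
    layer_outputs, _ = model.get_intermediate_layers(views, out_layers=[1])
    patches, cam = layer_outputs[0]
    print(patches.shape, cam.shape)  # (1, 2, 100, 128), (1, 2, 128)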