diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py
index ef04556da..9cb231e28 100644
--- a/comfy/image_encoders/dino3.py
+++ b/comfy/image_encoders/dino3.py
@@ -3,7 +3,6 @@ import torch
 import torch.nn as nn
 
 from comfy.ldm.modules.attention import optimized_attention_for_device
-from comfy.ldm.flux.math import apply_rope
 from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
 
 class DINOv3ViTMLP(nn.Module):
@@ -18,6 +17,26 @@ class DINOv3ViTMLP(nn.Module):
     def forward(self, x):
         return self.down_proj(self.act_fn(self.up_proj(x)))
 
+def rotate_half(x):
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+def apply_rotary_pos_emb(q, k, cos, sin, **kwargs):
+    num_tokens = q.shape[-2]
+    num_patches = sin.shape[-2]
+    num_prefix_tokens = num_tokens - num_patches
+
+    q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
+    k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)
+
+    q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
+    k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)
+
+    q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
+    k = torch.cat((k_prefix_tokens, k_patches), dim=-2)
+
+    return q, k
+
 class DINOv3ViTAttention(nn.Module):
     def __init__(self, hidden_size, num_attention_heads, device, dtype, operations):
         super().__init__()
@@ -54,28 +73,7 @@ class DINOv3ViTAttention(nn.Module):
 
         if position_embeddings is not None:
             cos, sin = position_embeddings
-
-            num_tokens = query_states.shape[-2]
-            num_patches = cos.shape[-2]
-            num_prefix_tokens = num_tokens - num_patches
-
-            q_prefix, q_patches = query_states.split((num_prefix_tokens, num_patches), dim=-2)
-            k_prefix, k_patches = key_states.split((num_prefix_tokens, num_patches), dim=-2)
-
-            cos = cos[..., :self.head_dim // 2]
-            sin = sin[..., :self.head_dim // 2]
-
-            f_cis_0 = torch.stack([cos, sin], dim=-1)
-            f_cis_1 = torch.stack([-sin, cos], dim=-1)
-            freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1)
-
-            while freqs_cis.ndim < q_patches.ndim + 1:
-                freqs_cis = freqs_cis.unsqueeze(0)
-
-            q_patches, k_patches = apply_rope(q_patches, k_patches, freqs_cis)
-
-            query_states = torch.cat((q_prefix, q_patches), dim=-2)
-            key_states = torch.cat((k_prefix, k_patches), dim=-2)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         attn = optimized_attention_for_device(query_states.device, mask=False)
 
@@ -83,6 +81,7 @@ class DINOv3ViTAttention(nn.Module):
             query_states, key_states, value_states, self.num_heads, attention_mask, skip_reshape=True, skip_output_reshape=True
         )
 
+        attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
         attn_output = self.o_proj(attn_output)
 