Fixed attention (could not use `apply_rope` for DINOv3)

This commit is contained in:
Yousef Rafat 2026-02-12 23:35:57 +02:00
parent 8e90bdc1cc
commit 0e239dc39b

View File

@ -3,7 +3,6 @@ import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.ldm.flux.math import apply_rope
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
class DINOv3ViTMLP(nn.Module):
@ -18,6 +17,26 @@ class DINOv3ViTMLP(nn.Module):
def forward(self, x):
return self.down_proj(self.act_fn(self.up_proj(x)))
def rotate_half(x):
    """Rotate the last dimension of *x* by swapping its two halves and
    negating the (original) second half: [a, b] -> [-b, a].

    The first piece takes the floor half of the last dimension; the
    remainder goes to the second piece, matching the standard RoPE
    rotate-half convention.
    """
    half = x.shape[-1] // 2
    lo, hi = x.split((half, x.shape[-1] - half), dim=-1)
    return torch.cat((-hi, lo), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, **kwargs):
    """Apply rotary position embeddings to the patch tokens of q and k.

    The sequence is assumed to be [prefix tokens..., patch tokens...]
    along dim -2; only the trailing ``sin.shape[-2]`` patch tokens are
    rotated, prefix tokens (e.g. CLS/register tokens) pass through
    unchanged. Extra keyword arguments are accepted and ignored.

    Returns the (q, k) pair with rotation applied.
    """
    n_patches = sin.shape[-2]
    n_prefix = q.shape[-2] - n_patches

    def _rot_half(t):
        # [a, b] -> [-b, a] on the last dimension (inlined rotate_half).
        half = t.shape[-1] // 2
        lo, hi = t[..., :half], t[..., half:]
        return torch.cat((-hi, lo), dim=-1)

    rotated = []
    for t in (q, k):
        prefix, patches = t.split((n_prefix, n_patches), dim=-2)
        patches = patches * cos + _rot_half(patches) * sin
        rotated.append(torch.cat((prefix, patches), dim=-2))
    return rotated[0], rotated[1]
class DINOv3ViTAttention(nn.Module):
def __init__(self, hidden_size, num_attention_heads, device, dtype, operations):
super().__init__()
@ -54,28 +73,7 @@ class DINOv3ViTAttention(nn.Module):
if position_embeddings is not None:
cos, sin = position_embeddings
num_tokens = query_states.shape[-2]
num_patches = cos.shape[-2]
num_prefix_tokens = num_tokens - num_patches
q_prefix, q_patches = query_states.split((num_prefix_tokens, num_patches), dim=-2)
k_prefix, k_patches = key_states.split((num_prefix_tokens, num_patches), dim=-2)
cos = cos[..., :self.head_dim // 2]
sin = sin[..., :self.head_dim // 2]
f_cis_0 = torch.stack([cos, sin], dim=-1)
f_cis_1 = torch.stack([-sin, cos], dim=-1)
freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1)
while freqs_cis.ndim < q_patches.ndim + 1:
freqs_cis = freqs_cis.unsqueeze(0)
q_patches, k_patches = apply_rope(q_patches, k_patches, freqs_cis)
query_states = torch.cat((q_prefix, q_patches), dim=-2)
key_states = torch.cat((k_prefix, k_patches), dim=-2)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
attn = optimized_attention_for_device(query_states.device, mask=False)
@ -83,6 +81,7 @@ class DINOv3ViTAttention(nn.Module):
query_states, key_states, value_states, self.num_heads, attention_mask, skip_reshape=True, skip_output_reshape=True
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
attn_output = self.o_proj(attn_output)