mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
Fixed attention (apply_rope could not be used for DINOv3)
This commit is contained in:
parent
8e90bdc1cc
commit
0e239dc39b
@ -3,7 +3,6 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
|
||||
|
||||
class DINOv3ViTMLP(nn.Module):
|
||||
@ -18,6 +17,26 @@ class DINOv3ViTMLP(nn.Module):
|
||||
def forward(self, x):
    """Run the MLP: up-projection, activation, then down-projection."""
    hidden = self.up_proj(x)
    hidden = self.act_fn(hidden)
    return self.down_proj(hidden)
|
||||
|
||||
def rotate_half(x):
    """Return x with its last-dim halves swapped and the moved half negated.

    For x = (a, b) split along the last dimension, produces (-b, a) —
    the standard rotation used by rotary position embeddings.
    """
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((back.neg(), front), dim=-1)
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, **kwargs):
    """Apply rotary position embeddings to the patch tokens of q and k.

    Prefix tokens (everything before the patch tokens along dim -2, e.g.
    CLS/register tokens) are passed through unchanged; only the trailing
    `sin.shape[-2]` patch tokens are rotated. Extra keyword arguments are
    accepted and ignored. Returns the (q, k) pair.
    """
    n_patch = sin.shape[-2]
    # Number of leading tokens that receive no positional rotation.
    n_prefix = q.shape[-2] - n_patch

    def _rotate(t):
        # (a, b) -> (-b, a) along the last dimension.
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    def _rope(t):
        prefix, patch = t.split((n_prefix, n_patch), dim=-2)
        patch = (patch * cos) + (_rotate(patch) * sin)
        return torch.cat((prefix, patch), dim=-2)

    return _rope(q), _rope(k)
|
||||
|
||||
class DINOv3ViTAttention(nn.Module):
|
||||
def __init__(self, hidden_size, num_attention_heads, device, dtype, operations):
|
||||
super().__init__()
|
||||
@ -54,28 +73,7 @@ class DINOv3ViTAttention(nn.Module):
|
||||
|
||||
if position_embeddings is not None:
|
||||
cos, sin = position_embeddings
|
||||
|
||||
num_tokens = query_states.shape[-2]
|
||||
num_patches = cos.shape[-2]
|
||||
num_prefix_tokens = num_tokens - num_patches
|
||||
|
||||
q_prefix, q_patches = query_states.split((num_prefix_tokens, num_patches), dim=-2)
|
||||
k_prefix, k_patches = key_states.split((num_prefix_tokens, num_patches), dim=-2)
|
||||
|
||||
cos = cos[..., :self.head_dim // 2]
|
||||
sin = sin[..., :self.head_dim // 2]
|
||||
|
||||
f_cis_0 = torch.stack([cos, sin], dim=-1)
|
||||
f_cis_1 = torch.stack([-sin, cos], dim=-1)
|
||||
freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1)
|
||||
|
||||
while freqs_cis.ndim < q_patches.ndim + 1:
|
||||
freqs_cis = freqs_cis.unsqueeze(0)
|
||||
|
||||
q_patches, k_patches = apply_rope(q_patches, k_patches, freqs_cis)
|
||||
|
||||
query_states = torch.cat((q_prefix, q_patches), dim=-2)
|
||||
key_states = torch.cat((k_prefix, k_patches), dim=-2)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
attn = optimized_attention_for_device(query_states.device, mask=False)
|
||||
|
||||
@ -83,6 +81,7 @@ class DINOv3ViTAttention(nn.Module):
|
||||
query_states, key_states, value_states, self.num_heads, attention_mask, skip_reshape=True, skip_output_reshape=True
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user