Use comfy attention

This commit is contained in:
kijai 2026-02-27 18:54:32 +02:00
parent 3929172408
commit 57ba8555fe

View File

@ -6,6 +6,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torchvision
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
COCO_CLASSES = [
'person','bicycle','car','motorcycle','airplane','bus','train','truck','boat',
@ -202,11 +203,29 @@ class SCDown(nn.Module):
return self.cv2(self.cv1(x))
class SelfAttention(nn.Module):
def __init__(self, embed_dim, num_heads, device=None, dtype=None, operations=None):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.q_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
self.k_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
self.v_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
self.out_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
def forward(self, query, key, value, attn_mask=None):
optimized_attention = optimized_attention_for_device(query.device, False, small_input=True)
q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)
out = optimized_attention(q, k, v, heads=self.num_heads, mask=attn_mask)
return self.out_proj(out)
class _TransformerEncoderLayer(nn.Module):
"""Single AIFI encoder layer (pre- or post-norm, GELU by default)."""
def __init__(self, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True, device=device, dtype=dtype)
self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations)
self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
@ -215,7 +234,7 @@ class _TransformerEncoderLayer(nn.Module):
def forward(self, src, src_mask=None, pos_embed=None):
q = k = src if pos_embed is None else src + pos_embed
src2, _ = self.self_attn(q, k, value=src, attn_mask=src_mask)
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)
src = self.norm1(src + src2)
src2 = self.linear2(self.activation(self.linear1(src)))
return self.norm2(src + src2)
@ -307,7 +326,7 @@ class HybridEncoder(nn.Module):
for i, enc_idx in enumerate(self.use_encoder_idx):
h, w = proj[enc_idx].shape[2:]
src = proj[enc_idx].flatten(2).permute(0, 2, 1)
pe = getattr(self, f'pos_embed{enc_idx}').to(src.device)
pe = getattr(self, f'pos_embed{enc_idx}').to(device=src.device, dtype=src.dtype)
for layer in self.encoder[i].layers:
src = layer(src, pos_embed=pe)
proj[enc_idx] = src.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous()
@ -412,7 +431,7 @@ class MLP(nn.Module):
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model=256, nhead=8, dim_feedforward=1024, num_levels=3, num_points=4, device=None, dtype=None, operations=None):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True, device=device, dtype=dtype)
self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations)
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.cross_attn = MSDeformableAttention(d_model, nhead, num_levels, num_points, device=device, dtype=dtype, operations=operations)
self.gateway = Gate(d_model, device=device, dtype=dtype, operations=operations)
@ -423,7 +442,7 @@ class TransformerDecoderLayer(nn.Module):
def forward(self, target, ref_pts, value, spatial_shapes, attn_mask=None, query_pos=None):
q = k = target if query_pos is None else target + query_pos
t2, _ = self.self_attn(q, k, value=target, attn_mask=attn_mask)
t2 = self.self_attn(q, k, value=target, attn_mask=attn_mask)
target = self.norm1(target + t2)
t2 = self.cross_attn(
target if query_pos is None else target + query_pos,
@ -451,7 +470,7 @@ def weighting_function(reg_max, up, reg_scale):
def distance2bbox(points, distance, reg_scale):
"""Decode edge-distances → cxcywh boxes."""
rs = abs(reg_scale)
rs = abs(reg_scale).to(dtype=points.dtype)
x1 = points[..., 0] - (0.5 * rs + distance[..., 0]) * (points[..., 2] / rs)
y1 = points[..., 1] - (0.5 * rs + distance[..., 1]) * (points[..., 3] / rs)
x2 = points[..., 0] + (0.5 * rs + distance[..., 2]) * (points[..., 2] / rs)
@ -469,7 +488,7 @@ class Integral(nn.Module):
def forward(self, x, project):
shape = x.shape
x = F.softmax(x.reshape(-1, self.reg_max + 1), 1)
x = F.linear(x, project.to(x.device)).reshape(-1, 4)
x = F.linear(x, project.to(device=x.device, dtype=x.dtype)).reshape(-1, 4)
return x.reshape(list(shape[:-1]) + [-1])
@ -644,7 +663,7 @@ class DFINETransformer(nn.Module):
return torch.cat(flat, 1), shapes
def _decoder_input(self, memory: torch.Tensor, spatial_shapes):
anchors, valid_mask = self.anchors, self.valid_mask
anchors, valid_mask = self.anchors.to(memory.dtype), self.valid_mask
if memory.shape[0] > 1:
anchors = anchors.repeat(memory.shape[0], 1, 1)