diff --git a/comfy/ldm/rf_detr/rfdetr_v4.py b/comfy/ldm/rf_detr/rfdetr_v4.py index 4793a20ac..f44a88279 100644 --- a/comfy/ldm/rf_detr/rfdetr_v4.py +++ b/comfy/ldm/rf_detr/rfdetr_v4.py @@ -6,6 +6,7 @@ import torch.nn as nn import torch.nn.functional as F import torchvision import comfy.model_management +from comfy.ldm.modules.attention import optimized_attention_for_device COCO_CLASSES = [ 'person','bicycle','car','motorcycle','airplane','bus','train','truck','boat', @@ -202,11 +203,29 @@ class SCDown(nn.Module): return self.cv2(self.cv1(x)) +class SelfAttention(nn.Module): + def __init__(self, embed_dim, num_heads, device=None, dtype=None, operations=None): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + self.q_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype) + self.k_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype) + self.v_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype) + self.out_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype) + + def forward(self, query, key, value, attn_mask=None): + optimized_attention = optimized_attention_for_device(query.device, False, small_input=True) + q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value) + out = optimized_attention(q, k, v, heads=self.num_heads, mask=attn_mask) + return self.out_proj(out) + + class _TransformerEncoderLayer(nn.Module): """Single AIFI encoder layer (pre- or post-norm, GELU by default).""" def __init__(self, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None): super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True, device=device, dtype=dtype) + self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations) self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype) self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype) self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) @@ -215,7 +234,7 @@ class _TransformerEncoderLayer(nn.Module): def forward(self, src, src_mask=None, pos_embed=None): q = k = src if pos_embed is None else src + pos_embed - src2, _ = self.self_attn(q, k, value=src, attn_mask=src_mask) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask) src = self.norm1(src + src2) src2 = self.linear2(self.activation(self.linear1(src))) return self.norm2(src + src2) @@ -307,7 +326,7 @@ class HybridEncoder(nn.Module): for i, enc_idx in enumerate(self.use_encoder_idx): h, w = proj[enc_idx].shape[2:] src = proj[enc_idx].flatten(2).permute(0, 2, 1) - pe = getattr(self, f'pos_embed{enc_idx}').to(src.device) + pe = getattr(self, f'pos_embed{enc_idx}').to(device=src.device, dtype=src.dtype) for layer in self.encoder[i].layers: src = layer(src, pos_embed=pe) proj[enc_idx] = src.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous() @@ -412,7 +431,7 @@ class MLP(nn.Module): class TransformerDecoderLayer(nn.Module): def __init__(self, d_model=256, nhead=8, dim_feedforward=1024, num_levels=3, num_points=4, device=None, dtype=None, operations=None): super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True, device=device, dtype=dtype) + self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations) self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) self.cross_attn = MSDeformableAttention(d_model, nhead, num_levels, num_points, device=device, dtype=dtype, operations=operations) self.gateway = Gate(d_model, device=device, dtype=dtype, operations=operations) @@ -423,7 +442,7 @@ class TransformerDecoderLayer(nn.Module): def forward(self, target, ref_pts, value, spatial_shapes, attn_mask=None, query_pos=None): q = k = target if query_pos is None else target + query_pos - t2, _ = self.self_attn(q, k, value=target, attn_mask=attn_mask) + t2 = self.self_attn(q, k, value=target, attn_mask=attn_mask) target = self.norm1(target + t2) t2 = self.cross_attn( target if query_pos is None else target + query_pos, @@ -451,7 +470,7 @@ def weighting_function(reg_max, up, reg_scale): def distance2bbox(points, distance, reg_scale): """Decode edge-distances → cxcywh boxes.""" - rs = abs(reg_scale) + rs = abs(reg_scale).to(dtype=points.dtype) x1 = points[..., 0] - (0.5 * rs + distance[..., 0]) * (points[..., 2] / rs) y1 = points[..., 1] - (0.5 * rs + distance[..., 1]) * (points[..., 3] / rs) x2 = points[..., 0] + (0.5 * rs + distance[..., 2]) * (points[..., 2] / rs) @@ -469,7 +488,7 @@ class Integral(nn.Module): def forward(self, x, project): shape = x.shape x = F.softmax(x.reshape(-1, self.reg_max + 1), 1) - x = F.linear(x, project.to(x.device)).reshape(-1, 4) + x = F.linear(x, project.to(device=x.device, dtype=x.dtype)).reshape(-1, 4) return x.reshape(list(shape[:-1]) + [-1]) @@ -644,7 +663,7 @@ class DFINETransformer(nn.Module): return torch.cat(flat, 1), shapes def _decoder_input(self, memory: torch.Tensor, spatial_shapes): - anchors, valid_mask = self.anchors, self.valid_mask + anchors, valid_mask = self.anchors.to(memory.dtype), self.valid_mask if memory.shape[0] > 1: anchors = anchors.repeat(memory.shape[0], 1, 1)