diff --git a/comfy/ldm/depth_anything_3/camera.py b/comfy/ldm/depth_anything_3/camera.py index 8a5c9361e..46eeba314 100644 --- a/comfy/ldm/depth_anything_3/camera.py +++ b/comfy/ldm/depth_anything_3/camera.py @@ -18,6 +18,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from comfy.ldm.modules.attention import optimized_attention_for_device from .transform import affine_inverse, extri_intri_to_pose_encoding @@ -74,11 +75,10 @@ class _Attention(nn.Module): def forward(self, x): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) - qkv = qkv.permute(2, 0, 3, 1, 4) # 3, B, h, N, d - q, k, v = qkv.unbind(0) - out = F.scaled_dot_product_attention(q, k, v) - out = out.transpose(1, 2).reshape(B, N, C) + qkv = self.qkv(x).reshape(B, N, 3, C) + q, k, v = qkv.unbind(2) # each (B, N, C) + attn_fn = optimized_attention_for_device(x.device, small_input=True) + out = attn_fn(q, k, v, heads=self.num_heads) return self.proj(out)