diff --git a/comfy/ldm/sam3/detector.py b/comfy/ldm/sam3/detector.py
index 6ae919a79..12d3a01ab 100644
--- a/comfy/ldm/sam3/detector.py
+++ b/comfy/ldm/sam3/detector.py
@@ -54,7 +54,7 @@ class SplitMHA(nn.Module):
         if mask is not None and mask.ndim == 2:
             mask = mask[:, None, None, :] # [B, T] -> [B, 1, 1, T] for SDPA broadcast
         dtype = q.dtype # manual_cast may produce mixed dtypes
-        out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask)
+        out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask, low_precision_attention=False)
         return self.out_proj(out)
diff --git a/comfy/ldm/sam3/sam.py b/comfy/ldm/sam3/sam.py
index 272781d45..75cb457cf 100644
--- a/comfy/ldm/sam3/sam.py
+++ b/comfy/ldm/sam3/sam.py
@@ -40,7 +40,7 @@ class SAMAttention(nn.Module):
         q = self.q_proj(q)
         k = self.k_proj(k)
         v = self.v_proj(v)
-        return self.out_proj(optimized_attention(q, k, v, self.num_heads))
+        return self.out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))
 
 
 class TwoWayAttentionBlock(nn.Module):
@@ -179,7 +179,7 @@ class Attention(nn.Module):
         q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0)
         if self.use_rope and freqs_cis is not None:
             q, k = apply_rope(q, k, freqs_cis)
-        return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True))
+        return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True, low_precision_attention=False))
 
 
 class Block(nn.Module):
diff --git a/comfy/ldm/sam3/tracker.py b/comfy/ldm/sam3/tracker.py
index 6ff6369d1..8f7481003 100644
--- a/comfy/ldm/sam3/tracker.py
+++ b/comfy/ldm/sam3/tracker.py
@@ -364,7 +364,7 @@ class SplitAttn(nn.Module):
         v = self.v_proj(v)
         if rope is not None:
             q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope)
-        out = optimized_attention(q, k, v, self.num_heads)
+        out = optimized_attention(q, k, v, self.num_heads, low_precision_attention=False)
         return self.out_proj(out)
@@ -657,7 +657,7 @@ class DecoupledMemoryAttnLayer(nn.Module):
         v = self.self_attn_v_proj(normed)
         if rope is not None:
             q, k = apply_rope_memory(q, k, rope, self.num_heads, 0)
-        x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads))
+        x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))
 
         # Decoupled cross-attention: fuse image and memory projections
         normed = self.norm2(x)
@@ -668,7 +668,7 @@ class DecoupledMemoryAttnLayer(nn.Module):
         v = self.cross_attn_v_proj(memory)
         if rope is not None:
             q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope)
-        x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads))
+        x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))
 
         # FFN
         x = x + self.linear2(F.gelu(self.linear1(self.norm3(x))))
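
Note on the change: every optimized_attention call in the SAM3 detector, SAM decoder, and tracker now passes low_precision_attention=False, presumably opting these attention ops out of any reduced-precision fast path even when the surrounding model runs in fp16/bf16. The snippet below is a minimal standalone sketch of the upcast-to-float32 pattern such a flag typically gates; attention_sketch and its signature are illustrative assumptions, not ComfyUI's optimized_attention (whose actual behavior lives in comfy.ldm.modules.attention).

    # Standalone sketch (assumption, not ComfyUI's implementation) of what a
    # low_precision_attention flag typically gates: when False, upcast q/k/v to
    # float32 before scaled-dot-product attention, then cast the result back.
    import torch
    import torch.nn.functional as F

    def attention_sketch(q, k, v, heads, mask=None, low_precision_attention=True):
        # q, k, v: [B, T, heads * head_dim]
        b, t, c = q.shape
        head_dim = c // heads
        # [B, T, C] -> [B, heads, T, head_dim]
        q, k, v = (x.view(b, -1, heads, head_dim).transpose(1, 2) for x in (q, k, v))
        compute_dtype = q.dtype if low_precision_attention else torch.float32
        out = F.scaled_dot_product_attention(
            q.to(compute_dtype), k.to(compute_dtype), v.to(compute_dtype), attn_mask=mask
        )
        # back to [B, T, C], in the caller's dtype
        return out.transpose(1, 2).reshape(b, t, c).to(q.dtype)

    q = k = v = torch.randn(1, 16, 64, dtype=torch.float16)
    out = attention_sketch(q, k, v, heads=8, low_precision_attention=False)
    print(out.shape, out.dtype)  # torch.Size([1, 16, 64]) torch.float16

The trade-off is speed and memory for numerical stability of the attention logits, which is why the flag is forced off only in these SAM3 modules rather than globally.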