mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-25 18:02:37 +08:00
parent
ef8f3cbcdc
commit
084e08c6e2
@ -54,7 +54,7 @@ class SplitMHA(nn.Module):
|
||||
if mask is not None and mask.ndim == 2:
|
||||
mask = mask[:, None, None, :] # [B, T] -> [B, 1, 1, T] for SDPA broadcast
|
||||
dtype = q.dtype # manual_cast may produce mixed dtypes
|
||||
out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask)
|
||||
out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask, low_precision_attention=False)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
|
||||
@ -40,7 +40,7 @@ class SAMAttention(nn.Module):
|
||||
q = self.q_proj(q)
|
||||
k = self.k_proj(k)
|
||||
v = self.v_proj(v)
|
||||
return self.out_proj(optimized_attention(q, k, v, self.num_heads))
|
||||
return self.out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))
|
||||
|
||||
|
||||
class TwoWayAttentionBlock(nn.Module):
|
||||
@ -179,7 +179,7 @@ class Attention(nn.Module):
|
||||
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0)
|
||||
if self.use_rope and freqs_cis is not None:
|
||||
q, k = apply_rope(q, k, freqs_cis)
|
||||
return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True))
|
||||
return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True, low_precision_attention=False))
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
|
||||
@ -364,7 +364,7 @@ class SplitAttn(nn.Module):
|
||||
v = self.v_proj(v)
|
||||
if rope is not None:
|
||||
q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope)
|
||||
out = optimized_attention(q, k, v, self.num_heads)
|
||||
out = optimized_attention(q, k, v, self.num_heads, low_precision_attention=False)
|
||||
return self.out_proj(out)
|
||||
|
||||
|
||||
@ -657,7 +657,7 @@ class DecoupledMemoryAttnLayer(nn.Module):
|
||||
v = self.self_attn_v_proj(normed)
|
||||
if rope is not None:
|
||||
q, k = apply_rope_memory(q, k, rope, self.num_heads, 0)
|
||||
x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads))
|
||||
x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))
|
||||
|
||||
# Decoupled cross-attention: fuse image and memory projections
|
||||
normed = self.norm2(x)
|
||||
@ -668,7 +668,7 @@ class DecoupledMemoryAttnLayer(nn.Module):
|
||||
v = self.cross_attn_v_proj(memory)
|
||||
if rope is not None:
|
||||
q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope)
|
||||
x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads))
|
||||
x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))
|
||||
|
||||
# FFN
|
||||
x = x + self.linear2(F.gelu(self.linear1(self.norm3(x))))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user