simplified var attn

Yousef Rafat 2026-03-19 14:11:02 +02:00
parent bc4fd2cd11
commit a64414a0fc
2 changed files with 35 additions and 104 deletions

comfy/ldm/modules/attention.py

@@ -19,15 +19,9 @@ if model_management.xformers_enabled():
     import xformers.ops
 
 SAGE_ATTENTION_IS_AVAILABLE = False
-SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = False
 try:
     from sageattention import sageattn
     SAGE_ATTENTION_IS_AVAILABLE = True
-    try:
-        from sageattention import sageattn_varlen
-        SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = True
-    except:
-        pass
 except ImportError as e:
     if model_management.sage_attention_enabled():
         if e.name == "sageattention":
@@ -93,13 +87,6 @@ def default(val, d):
         return val
     return d
 
-def var_attn_arg(kwargs):
-    cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
-    cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
-    max_seqlen_q = kwargs.get("max_seqlen_q", None)
-    max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
-    assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
-    return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
 
 # feedforward
 class GEGLU(nn.Module):
     def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
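For context: the cu_seqlens_* arguments this helper unpacked are cumulative per-sequence offsets into a packed [total_tokens, ...] tensor, and max_seqlen_* is the length of the longest sequence. A minimal illustration, not part of the commit, using the same pad-after-cumsum construction the call site below uses:

import torch

# Three packed sequences of lengths 3, 5 and 2 -> total_tokens == 10.
seq_lens = torch.tensor([3, 5, 2])
cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0), (1, 0)).int()
# cu_seqlens == tensor([0, 3, 8, 10], dtype=torch.int32); sequence i
# occupies rows cu_seqlens[i]:cu_seqlens[i+1] of the packed tensor,
# and max_seqlen == seq_lens.max() == 5.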
@@ -424,14 +411,13 @@ except:
 @wrap_attn
 def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
-    var_length = kwargs.get("var_length", False)
     b = q.shape[0]
     dim_head = q.shape[-1]
     # check to make sure xformers isn't broken
     disabled_xformers = False
 
     if BROKEN_XFORMERS:
-        if b * heads > 65535 and not var_length:
+        if b * heads > 65535:
             disabled_xformers = True
 
     if not disabled_xformers:
@@ -439,27 +425,9 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
             disabled_xformers = True
 
     if disabled_xformers:
-        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, var_length=var_length, **kwargs)
+        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
 
-    if var_length:
-        if not skip_reshape:
-            total_tokens, hidden_dim = q.shape
-            dim_head = hidden_dim // heads
-            q = q.view(1, total_tokens, heads, dim_head)
-            k = k.view(1, total_tokens, heads, dim_head)
-            v = v.view(1, total_tokens, heads, dim_head)
-        else:
-            if q.ndim == 3:
-                q = q.unsqueeze(0)
-            if k.ndim == 3:
-                k = k.unsqueeze(0)
-            if v.ndim == 3:
-                v = v.unsqueeze(0)
-            dim_head = q.shape[-1]
-        target_output_shape = (q.shape[1], -1)
-        b = 1
-    elif skip_reshape:
+    if skip_reshape:
         # b h k d -> b k h d
         q, k, v = map(
             lambda t: t.permute(0, 2, 1, 3),
@@ -473,11 +441,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
             (q, k, v),
         )
 
-    if var_length:
-        cu_seqlens_q, _, _, _ = var_attn_arg(kwargs)
-        seq_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
-        mask = xformers.ops.BlockDiagonalMask.from_seqlens(seq_lens_q=seq_lens, seq_lens_k=seq_lens)
-    elif mask is not None:
+    if mask is not None:
         # add a singleton batch dimension
         if mask.ndim == 2:
             mask = mask.unsqueeze(0)
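The removed branch built a block-diagonal bias so that sequences packed into a single batch-1 tensor cannot attend across their boundaries. A hedged sketch of that mechanism; it assumes the documented xformers signature BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen=None) with lengths passed positionally, which may differ from the keyword names used above depending on the xformers version:

import torch
import xformers.ops

# Two sequences of lengths 3 and 2 packed into one [1, 5, heads, dim] tensor.
q = torch.randn(1, 5, 4, 64, device="cuda", dtype=torch.float16)
bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([3, 2])
# Tokens 0-2 and 3-4 now form independent attention blocks.
out = xformers.ops.memory_efficient_attention(q, q, q, attn_bias=bias)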
@@ -499,8 +463,6 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
 
-    if var_length:
-        return out.reshape(*target_output_shape)
-
     if skip_output_reshape:
         out = out.permute(0, 2, 1, 3)
     else:
@@ -518,28 +480,7 @@ else:
 @wrap_attn
 def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
-    var_length = kwargs.get("var_length", False)
-    if var_length:
-        cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
-        if not skip_reshape:
-            # assumes 2D q, k,v [total_tokens, embed_dim]
-            total_tokens, embed_dim = q.shape
-            head_dim = embed_dim // heads
-            q = q.view(total_tokens, heads, head_dim)
-            k = k.view(k.shape[0], heads, head_dim)
-            v = v.view(v.shape[0], heads, head_dim)
-        b = q.size(0)
-        dim_head = q.shape[-1]
-        q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
-        k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
-        v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
-        mask = None
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-    elif skip_reshape:
+    if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
         b, _, dim_head = q.shape
@@ -557,10 +498,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
         if mask.ndim == 3:
             mask = mask.unsqueeze(1)
 
-    if SDP_BATCH_LIMIT >= b or var_length:
+    if SDP_BATCH_LIMIT >= b:
         out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
-        if var_length:
-            return out.contiguous().transpose(1, 2).values()
         if not skip_output_reshape:
             out = (
                 out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@@ -584,19 +523,11 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
 @wrap_attn
 def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
-    var_length = kwargs.get("var_length", False)
     if kwargs.get("low_precision_attention", True) is False:
         return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
 
     exception_fallback = False
-    if var_length:
-        if not skip_reshape:
-            total_tokens, hidden_dim = q.shape
-            dim_head = hidden_dim // heads
-            q, k, v = [t.view(total_tokens, heads, dim_head) for t in (q, k, v)]
-        b, _, dim_head = q.shape
-        # skips batched code
-        mask = None
-        tensor_layout = "VAR"
-        target_output_shape = (q.shape[0], -1)
-    elif skip_reshape:
+    if skip_reshape:
         b, _, _, dim_head = q.shape
         tensor_layout = "HND"
     else:
@@ -617,14 +548,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
             mask = mask.unsqueeze(1)
 
     try:
-        if var_length and not SAGE_ATTENTION_VAR_LENGTH_AVAILABLE:
-            raise ValueError("Sage Attention two is required to run variable length attention.")
-        elif var_length:
-            cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
-            sm_scale = 1.0 / (q.shape[-1] ** 0.5)
-            out = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, is_causal=False, sm_scale=sm_scale)
-        else:
-            out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+        out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
     except Exception as e:
         logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
         exception_fallback = True
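The explicit sm_scale was only needed by the removed sageattn_varlen call; the surviving sageattn call already relies on SageAttention's default 1/sqrt(head_dim) softmax scaling, so the batched path behaves the same as before.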
@@ -634,7 +558,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
                 lambda t: t.transpose(1, 2),
                 (q, k, v),
             )
-        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, var_length=var_length, **kwargs)
+        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
 
     if tensor_layout == "HND":
         if not skip_output_reshape:
@@ -645,8 +569,6 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
         if skip_output_reshape:
             out = out.transpose(1, 2)
         else:
-            if var_length:
-                return out.view(*target_output_shape)
             out = out.reshape(b, -1, heads * dim_head)
 
     return out
@@ -758,15 +680,6 @@ except AttributeError as error:
 @wrap_attn
 def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
-    var_length = kwargs.get("var_length", False)
-    if var_length:
-        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
-        return flash_attn_varlen_func(
-            q=q, k=k, v=v,
-            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
-            max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
-            dropout_p=0.0, softmax_scale=None, causal=False
-        )
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
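After this removal none of the backend functions branch on var_length; variable-length attention is instead routed through the single nested-tensor helper added below.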
@@ -804,7 +717,27 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
         )
     return out
 
+def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False):
+    if not skip_reshape:
+        # assumes 2D q, k,v [total_tokens, embed_dim]
+        total_tokens, embed_dim = q.shape
+        head_dim = embed_dim // heads
+        q = q.view(total_tokens, heads, head_dim)
+        k = k.view(k.shape[0], heads, head_dim)
+        v = v.view(v.shape[0], heads, head_dim)
+
+    q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
+    k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
+    v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
+
+    q = q.transpose(1, 2)
+    k = k.transpose(1, 2)
+    v = v.transpose(1, 2)
+
+    out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+    return out.transpose(1, 2).values()
+
+optimized_var_attention = var_attention_pytorch
 
 optimized_attention = attention_basic
 if model_management.sage_attention_enabled():
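The new helper funnels every variable-length call through PyTorch's jagged NestedTensor layout, so scaled_dot_product_attention recovers the ragged lengths from the offsets instead of relying on per-backend varlen kernels. A usage sketch under assumed prerequisites (a recent PyTorch with torch.nested.nested_tensor_from_jagged and jagged-layout SDPA support; F.scaled_dot_product_attention stands in here for the comfy.ops wrapper):

import torch
import torch.nn.functional as F

heads, dim_head = 4, 64
x = torch.randn(10, heads * dim_head)        # 3 sequences packed: lengths 3, 5, 2
cu = torch.tensor([0, 3, 8, 10])
q = x.view(10, heads, dim_head)
nt = torch.nested.nested_tensor_from_jagged(q, offsets=cu)  # shape [3, j1, 4, 64]
nt = nt.transpose(1, 2)                      # [3, 4, j1, 64], heads-first for SDPA
out = F.scaled_dot_product_attention(nt, nt, nt)
packed = out.transpose(1, 2).values()        # back to the packed [10, 4, 64] layout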

View File

@@ -8,7 +8,7 @@ from math import ceil, pi
 import torch
 from itertools import chain
 from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding
-from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.modules.attention import optimized_var_attention
 from comfy.rmsnorm import RMSNorm
 from torch.nn.modules.utils import _triple
 from torch import nn
@@ -767,19 +767,17 @@ class NaSwinAttention(NaMMAttention):
         else:
             vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
 
-        out = optimized_attention(
+        out = optimized_var_attention(
             q=concat_win(vid_q, txt_q),
             k=concat_win(vid_k, txt_k),
             v=concat_win(vid_v, txt_v),
-            heads=self.heads, skip_reshape=True, var_length = True,
+            heads=self.heads, skip_reshape=True,
             cu_seqlens_q=cache_win(
                 "vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
             ),
             cu_seqlens_k=cache_win(
                 "vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
             ),
-            max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()),
-            max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()),
         )
 
         vid_out, txt_out = unconcat_win(out)
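With the nested-tensor path, per-sequence lengths come from the cu_seqlens offsets themselves, so the max_seqlen_q/max_seqlen_k tiling hints the flash and sage varlen kernels required can be dropped from the call site along with the var_length flag.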