From a64414a0fc934379b8a2c25e948aa5b9d8156f69 Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Thu, 19 Mar 2026 14:11:02 +0200
Subject: [PATCH] simplified var attn

---
 comfy/ldm/modules/attention.py | 131 ++++++-------------------
 comfy/ldm/seedvr/model.py      |   8 +-
 2 files changed, 35 insertions(+), 104 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index dd8c6ba72..17b185b69 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -19,15 +19,9 @@ if model_management.xformers_enabled():
     import xformers.ops
 
 SAGE_ATTENTION_IS_AVAILABLE = False
-SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = False
 try:
     from sageattention import sageattn
     SAGE_ATTENTION_IS_AVAILABLE = True
-    try:
-        from sageattention import sageattn_varlen
-        SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = True
-    except:
-        pass
 except ImportError as e:
     if model_management.sage_attention_enabled():
         if e.name == "sageattention":
@@ -93,13 +87,6 @@ def default(val, d):
         return val
     return d
 
-def var_attn_arg(kwargs):
-    cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
-    cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
-    max_seqlen_q = kwargs.get("max_seqlen_q", None)
-    max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
-    assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
-    return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
 # feedforward
 class GEGLU(nn.Module):
     def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
@@ -424,14 +411,13 @@ except:
 
 @wrap_attn
 def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
-    var_length = kwargs.get("var_length", False)
     b = q.shape[0]
     dim_head = q.shape[-1]
     # check to make sure xformers isn't broken
     disabled_xformers = False
 
     if BROKEN_XFORMERS:
-        if b * heads > 65535 and not var_length:
+        if b * heads > 65535:
             disabled_xformers = True
 
     if not disabled_xformers:
@@ -439,27 +425,9 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
             disabled_xformers = True
 
     if disabled_xformers:
-        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, var_length=var_length, **kwargs)
+        return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
 
-    if var_length:
-        if not skip_reshape:
-            total_tokens, hidden_dim = q.shape
-            dim_head = hidden_dim // heads
-            q = q.view(1, total_tokens, heads, dim_head)
-            k = k.view(1, total_tokens, heads, dim_head)
-            v = v.view(1, total_tokens, heads, dim_head)
-        else:
-            if q.ndim == 3:
-                q = q.unsqueeze(0)
-            if k.ndim == 3:
-                k = k.unsqueeze(0)
-            if v.ndim == 3:
-                v = v.unsqueeze(0)
-            dim_head = q.shape[-1]
-
-        target_output_shape = (q.shape[1], -1)
-        b = 1
-    elif skip_reshape:
+    if skip_reshape:
         # b h k d -> b k h d
         q, k, v = map(
             lambda t: t.permute(0, 2, 1, 3),
             (q, k, v),
         )
@@ -473,11 +441,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
             (q, k, v),
         )
 
-    if var_length:
-        cu_seqlens_q, _, _, _ = var_attn_arg(kwargs)
-        seq_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
-        mask = xformers.ops.BlockDiagonalMask.from_seqlens(seq_lens_q=seq_lens, seq_lens_k=seq_lens)
-    elif mask is not None:
+    if mask is not None:
         # add a singleton batch dimension
         if mask.ndim == 2:
             mask = mask.unsqueeze(0)
@@ -499,8 +463,6 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
 
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
 
-    if var_length:
-        return out.reshape(*target_output_shape)
     if skip_output_reshape:
         out = out.permute(0, 2, 1, 3)
     else:
@@ -518,28 +480,7 @@ else:
 
 @wrap_attn
 def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
-    var_length = kwargs.get("var_length", False)
-    if var_length:
-        cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
-        if not skip_reshape:
-            # assumes 2D q, k,v [total_tokens, embed_dim]
-            total_tokens, embed_dim = q.shape
-            head_dim = embed_dim // heads
-            q = q.view(total_tokens, heads, head_dim)
-            k = k.view(k.shape[0], heads, head_dim)
-            v = v.view(v.shape[0], heads, head_dim)
-
-        b = q.size(0)
-        dim_head = q.shape[-1]
-        q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
-        k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
-        v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
-
-        mask = None
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-    elif skip_reshape:
+    if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
         b, _, dim_head = q.shape
@@ -557,10 +498,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
         if mask.ndim == 3:
             mask = mask.unsqueeze(1)
 
-    if SDP_BATCH_LIMIT >= b or var_length:
+    if SDP_BATCH_LIMIT >= b:
         out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
-        if var_length:
-            return out.contiguous().transpose(1, 2).values()
         if not skip_output_reshape:
             out = (
                 out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@@ -584,19 +523,11 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
 
 @wrap_attn
 def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
-    var_length = kwargs.get("var_length", False)
+    if kwargs.get("low_precision_attention", True) is False:
+        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
+
     exception_fallback = False
-    if var_length:
-        if not skip_reshape:
-            total_tokens, hidden_dim = q.shape
-            dim_head = hidden_dim // heads
-            q, k, v = [t.view(total_tokens, heads, dim_head) for t in (q, k, v)]
-        b, _, dim_head = q.shape
-        # skips batched code
-        mask = None
-        tensor_layout = "VAR"
-        target_output_shape = (q.shape[0], -1)
-    elif skip_reshape:
+    if skip_reshape:
         b, _, _, dim_head = q.shape
         tensor_layout = "HND"
     else:
@@ -617,14 +548,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
             mask = mask.unsqueeze(1)
 
     try:
-        if var_length and not SAGE_ATTENTION_VAR_LENGTH_AVAILABLE:
-            raise ValueError("Sage Attention two is required to run variable length attention.")
-        elif var_length:
-            cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
-            sm_scale = 1.0 / (q.shape[-1] ** 0.5)
-            out = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, is_causal=False, sm_scale=sm_scale)
-        else:
-            out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+        out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
     except Exception as e:
         logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
         exception_fallback = True
@@ -634,7 +558,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
                 lambda t: t.transpose(1, 2),
                 (q, k, v),
             )
-        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, var_length=var_length, **kwargs)
+        return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
 
     if tensor_layout == "HND":
         if not skip_output_reshape:
@@ -645,8 +569,6 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
         if skip_output_reshape:
             out = out.transpose(1, 2)
         else:
-            if var_length:
-                return out.view(*target_output_shape)
             out = out.reshape(b, -1, heads * dim_head)
 
     return out
@@ -758,15 +680,6 @@ except AttributeError as error:
 
 @wrap_attn
 def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
-    var_length = kwargs.get("var_length", False)
-    if var_length:
-        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
-        return flash_attn_varlen_func(
-            q=q, k=k, v=v,
-            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
-            max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
-            dropout_p=0.0, softmax_scale=None, causal=False
-        )
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
@@ -804,7 +717,27 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
     )
     return out
 
+def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False):
+    if not skip_reshape:
+        # assumes 2D q, k,v [total_tokens, embed_dim]
+        total_tokens, embed_dim = q.shape
+        head_dim = embed_dim // heads
+        q = q.view(total_tokens, heads, head_dim)
+        k = k.view(k.shape[0], heads, head_dim)
+        v = v.view(v.shape[0], heads, head_dim)
+    q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
+    k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
+    v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
+
+    q = q.transpose(1, 2)
+    k = k.transpose(1, 2)
+    v = v.transpose(1, 2)
+    out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
+
+    return out.transpose(1, 2).values()
+
+optimized_var_attention = var_attention_pytorch
 optimized_attention = attention_basic
 
 if model_management.sage_attention_enabled():
diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
index 9c9edadce..01664052f 100644
--- a/comfy/ldm/seedvr/model.py
+++ b/comfy/ldm/seedvr/model.py
@@ -8,7 +8,7 @@ from math import ceil, pi
 import torch
 from itertools import chain
 from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding
-from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.modules.attention import optimized_var_attention
 from comfy.rmsnorm import RMSNorm
 from torch.nn.modules.utils import _triple
 from torch import nn
@@ -767,19 +767,17 @@ class NaSwinAttention(NaMMAttention):
         else:
             vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
 
-        out = optimized_attention(
+        out = optimized_var_attention(
             q=concat_win(vid_q, txt_q),
             k=concat_win(vid_k, txt_k),
             v=concat_win(vid_v, txt_v),
-            heads=self.heads, skip_reshape=True, var_length = True,
+            heads=self.heads, skip_reshape=True,
             cu_seqlens_q=cache_win(
                 "vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
            ),
             cu_seqlens_k=cache_win(
                 "vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
             ),
-            max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()),
-            max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()),
         )
 
         vid_out, txt_out = unconcat_win(out)
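
Reviewer note (a usage sketch, not part of the patch): the new var_attention_pytorch packs every sequence of a batch into one [total_tokens, heads, head_dim] tensor and uses the cumulative sequence lengths (cu_seqlens) as offsets for torch's jagged nested tensors, so a single scaled_dot_product_attention call handles all the variable-length sequences without padding and without the max_seqlen bookkeeping that the removed sageattn_varlen / flash_attn_varlen_func paths needed. The snippet below is a minimal standalone sketch of that pattern with made-up sizes (two sequences of 3 and 5 tokens, heads=2, head_dim=4); it uses plain torch.nn.functional.scaled_dot_product_attention in place of comfy.ops.scaled_dot_product_attention and assumes a recent PyTorch (2.4+) where the jagged nested-tensor SDPA path is available.

import torch
import torch.nn.functional as F

heads, head_dim = 2, 4                                 # made-up sizes for illustration
seq_lens = torch.tensor([3, 5])                        # two variable-length sequences
cu_seqlens = F.pad(seq_lens.cumsum(0), (1, 0)).int()   # [0, 3, 8], built like the seedvr call site
total_tokens = int(cu_seqlens[-1])

# Packed [total_tokens, heads, head_dim] tensors, the layout var_attention_pytorch
# sees after its optional reshape step.
q = torch.randn(total_tokens, heads, head_dim)
k = torch.randn(total_tokens, heads, head_dim)
v = torch.randn(total_tokens, heads, head_dim)

# Jagged nested tensors: each sequence becomes its own entry, so attention never
# crosses a sequence boundary and no padding is materialized.
q_j = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens.long()).transpose(1, 2)
k_j = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens.long()).transpose(1, 2)
v_j = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens.long()).transpose(1, 2)

# Plain SDPA stands in here for comfy.ops.scaled_dot_product_attention.
out = F.scaled_dot_product_attention(q_j, k_j, v_j, attn_mask=None, dropout_p=0.0, is_causal=False)

# Back to the packed [total_tokens, heads, head_dim] layout the SeedVR caller unconcatenates.
out = out.transpose(1, 2).values()
print(out.shape)  # torch.Size([8, 2, 4])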