From 7e62f8cc9fa8aa973072388fd27ef1a9ab3a4cc0 Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Fri, 19 Dec 2025 20:23:39 +0200
Subject: [PATCH] added variable-length attention and fixed the VAE issue

---
 comfy/ldm/modules/attention.py | 46 ++++++++++++++++--
 comfy/ldm/seedvr/model.py      | 85 +++++++++------------------------
 comfy/ldm/seedvr/vae.py        |  2 +-
 3 files changed, 63 insertions(+), 70 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index a8800ded0..332c65ffb 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -32,7 +32,7 @@ except ImportError as e:
     FLASH_ATTENTION_IS_AVAILABLE = False
 
 try:
-    from flash_attn import flash_attn_func
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
     FLASH_ATTENTION_IS_AVAILABLE = True
 except ImportError:
     if model_management.flash_attention_enabled():
@@ -473,8 +473,29 @@ else:
     SDP_BATCH_LIMIT = 2**31
 
 @wrap_attn
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
-    if skip_reshape:
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
+    if var_length:
+        cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
+        cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
+        assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
+        if not skip_reshape:
+            # assumes packed 2D q, k, v of shape [total_tokens, embed_dim]
+            total_tokens, embed_dim = q.shape
+            head_dim = embed_dim // heads
+            q = q.view(total_tokens, heads, head_dim)
+            k = k.view(k.shape[0], heads, head_dim)
+            v = v.view(v.shape[0], heads, head_dim)
+
+        b = q.size(0); dim_head = q.shape[-1]
+        q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
+        k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
+        v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
+
+        mask = None
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+    elif skip_reshape:
         b, _, _, dim_head = q.shape
     else:
         b, _, dim_head = q.shape
@@ -492,8 +513,10 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
         if mask.ndim == 3:
             mask = mask.unsqueeze(1)
 
-    if SDP_BATCH_LIMIT >= b:
+    if SDP_BATCH_LIMIT >= b or var_length:
         out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+        if var_length:
+            return out.contiguous().transpose(1, 2).values()
         if not skip_output_reshape:
             out = (
                 out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@@ -583,7 +606,20 @@ except AttributeError as error:
     assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
 
 @wrap_attn
-def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
+    if var_length:
+        cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
+        cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
+        max_seqlen_q = kwargs.get("max_seqlen_q", None)
+        max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
+        assert max_seqlen_q is not None, "max_seqlen_q shouldn't be None when var_length is True"
+        assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
+        return flash_attn_varlen_func(
+            q=q, k=k, v=v,
+            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
+            dropout_p=0.0, softmax_scale=None, causal=False
+        )
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
index 119799592..0825a12ba 100644
--- a/comfy/ldm/seedvr/model.py
+++ b/comfy/ldm/seedvr/model.py
@@ -15,6 +15,11 @@ from comfy.rmsnorm import RMSNorm
 from torch.nn.modules.utils import _triple
 from torch import nn
 import math
+import logging
+try:
+    from flash_attn import flash_attn_varlen_func
+except ImportError:
+    logging.warning("Best results will be achieved with flash attention enabled for SeedVR2")
 
 class Cache:
     def __init__(self, disable=False, prefix="", cache=None):
@@ -735,70 +740,21 @@ class NaSwinAttention(NaMMAttention):
             )
         else:
             vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
-
-        # TODO: continue testing
-        v_lens = vid_len_win.cpu().tolist()
-        t_lens_batch = txt_len.cpu().tolist()
-        win_counts = window_count.cpu().tolist()
-        vq_l = torch.split(vid_q, v_lens)
-        vk_l = torch.split(vid_k, v_lens)
-        vv_l = torch.split(vid_v, v_lens)
-
-        tv_batch = torch.split(txt_v, t_lens_batch)
-        tv_l = []
-        for i, count in enumerate(win_counts):
-            tv_l.extend([tv_batch[i]] * count)
-
-        current_txt_len = txt_q.shape[0]
-        expected_batch_len = sum(t_lens_batch)
-
-        if current_txt_len != expected_batch_len:
-            t_lens_win = txt_len_win.cpu().tolist()
-
-            tq_l = torch.split(txt_q, t_lens_win)
-            tk_l = torch.split(txt_k, t_lens_win)
-        else:
-            tq_batch = torch.split(txt_q, t_lens_batch)
-            tk_batch = torch.split(txt_k, t_lens_batch)
-
-            tq_l = []
-            tk_l = []
-            for i, count in enumerate(win_counts):
-                tq_l.extend([tq_batch[i]] * count)
-                tk_l.extend([tk_batch[i]] * count)
-
-        q_list = [torch.cat([v, t], dim=0) for v, t in zip(vq_l, tq_l)]
-        k_list = [torch.cat([v, t], dim=0) for v, t in zip(vk_l, tk_l)]
-        v_list = [torch.cat([v, t], dim=0) for v, t in zip(vv_l, tv_l)]
-
-        q = rnn_utils.pad_sequence(q_list, batch_first=True)
-        k = rnn_utils.pad_sequence(k_list, batch_first=True)
-        v = rnn_utils.pad_sequence(v_list, batch_first=True)
-
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        B, Heads, Max_L, _ = q.shape
-        combined_lens = [v.shape[0] + t.shape[0] for v, t in zip(vq_l, tq_l)]
-
-        attn_mask = torch.zeros((B, 1, 1, Max_L), device=q.device, dtype=q.dtype)
-        idx = torch.arange(Max_L, device=q.device).unsqueeze(0).expand(B, Max_L)
-        len_tensor = torch.tensor(combined_lens, device=q.device).unsqueeze(1)
-
-        padding_mask = idx >= len_tensor
-        attn_mask.masked_fill_(padding_mask.unsqueeze(1).unsqueeze(1), float('-inf'))
-
-        out = optimized_attention(q, k, v, heads=self.heads, mask=attn_mask, skip_reshape=True, skip_output_reshape=True)
-
-        out = out.transpose(1, 2)
-
-        out_flat_list = []
-        for i, length in enumerate(combined_lens):
-            out_flat_list.append(out[i, :length])
-
-        out = torch.cat(out_flat_list, dim=0)
+        out = optimized_attention(
+            q=concat_win(vid_q, txt_q).bfloat16(),
+            k=concat_win(vid_k, txt_k).bfloat16(),
+            v=concat_win(vid_v, txt_v).bfloat16(),
+            heads=self.heads, skip_reshape=True, var_length=True,
+            cu_seqlens_q=cache_win(
+                "vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
+            ),
+            cu_seqlens_k=cache_win(
+                "vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
+            ),
+            max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()),
+            max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()),
+        )
 
         vid_out, txt_out = unconcat_win(out)
 
@@ -807,7 +763,8 @@ class NaSwinAttention(NaMMAttention):
         vid_out = window_reverse(vid_out)
 
         device = comfy.model_management.get_torch_device()
-        vid_out, txt_out = vid_out.to(device), txt_out.to(device)
+        dtype = next(self.proj_out.parameters()).dtype
+        vid_out, txt_out = vid_out.to(device=device, dtype=dtype), txt_out.to(device=device, dtype=dtype)
         self.proj_out = self.proj_out.to(device)
         vid_out, txt_out = self.proj_out(vid_out, txt_out)
 
diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py
index 277f7a697..ac5e20b8d 100644
--- a/comfy/ldm/seedvr/vae.py
+++ b/comfy/ldm/seedvr/vae.py
@@ -1565,7 +1565,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
             latent = latent.unsqueeze(2)
 
         target_device = comfy.model_management.get_torch_device()
-        self.to(target_device)
+        self.decoder.to(target_device)
         x = super().decode(latent).squeeze(2)
 
         input = rearrange(self.original_image_video[0], "c t h w -> t c h w")
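
A minimal sketch of how the new var_length path can be driven, assuming this patch is applied: the head count, head size, and per-window lengths below are made-up values; the packed [total_tokens, heads, head_dim] layout and the zero-prefixed cumulative-sum cu_seqlens mirror the call site in comfy/ldm/seedvr/model.py; only the PyTorch-SDPA and flash-attn backends handle var_length after this change, and it needs a CUDA device.

# Illustrative sketch only, not part of the patch above.
import torch
from comfy.ldm.modules.attention import optimized_attention

heads, head_dim = 8, 64                    # made-up sizes for the sketch
seq_lens = torch.tensor([37, 52, 41])      # made-up per-window token counts
total_tokens = int(seq_lens.sum())

# Packed layout used with skip_reshape=True: [total_tokens, heads, head_dim]
q = torch.randn(total_tokens, heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Zero-prefixed cumulative lengths, like safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0), (1, 0)).int().to(q.device)

out = optimized_attention(
    q, k, v, heads=heads,
    skip_reshape=True, var_length=True,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seq_lens.max()), max_seqlen_k=int(seq_lens.max()),
)
# Both the SDPA and flash-attn paths return the output packed as [total_tokens, heads, head_dim].
print(out.shape)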