Added variable-length attention and fixed the VAE issue

This commit is contained in:
Yousef Rafat 2025-12-19 20:23:39 +02:00
parent 74621b9d86
commit 7e62f8cc9f
3 changed files with 63 additions and 70 deletions

View File

@@ -32,7 +32,7 @@ except ImportError as e:
FLASH_ATTENTION_IS_AVAILABLE = False
try:
from flash_attn import flash_attn_func
from flash_attn import flash_attn_func, flash_attn_varlen_func
FLASH_ATTENTION_IS_AVAILABLE = True
except ImportError:
if model_management.flash_attention_enabled():
@@ -473,8 +473,29 @@
SDP_BATCH_LIMIT = 2**31
@wrap_attn
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
if var_length:
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
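# cu_seqlens_* follow the varlen convention: cumulative sequence lengths with a leading 0,
# e.g. [0, len_0, len_0 + len_1, ...]; the K offsets default to the Q offsets when not provided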
if not skip_reshape:
# assumes packed 2D q, k, v of shape [total_tokens, embed_dim]
total_tokens, embed_dim = q.shape
head_dim = embed_dim // heads
q = q.view(total_tokens, heads, head_dim)
k = k.view(k.shape[0], heads, head_dim)
v = v.view(v.shape[0], heads, head_dim)
b = q.size(0)
dim_head = q.shape[-1]
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
mask = None
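# transpose the jagged (nested) tensors into the [batch, heads, seq, head_dim] layout that
# scaled_dot_product_attention expects; the nested layout attends per sequence without padding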
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
elif skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
@@ -492,8 +513,10 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.ndim == 3:
mask = mask.unsqueeze(1)
if SDP_BATCH_LIMIT >= b:
if SDP_BATCH_LIMIT >= b or var_length:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
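# on the nested-tensor (var_length) path, .values() unpacks the result back to the
# packed [total_tokens, heads, head_dim] layout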
if var_length:
return out.contiguous().transpose(1, 2).values()
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@@ -583,7 +606,20 @@ except AttributeError as error:
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
@wrap_attn
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
if var_length:
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
max_seqlen_q = kwargs.get("max_seqlen_q", None)
max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
assert max_seqlen_q is not None, "max_seqlen_q shouldn't be None when var_length is True"
assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
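# flash_attn_varlen_func operates on packed (unpadded) q/k/v of shape [total_tokens, heads, head_dim],
# using the cumulative sequence lengths and per-batch maximum sequence lengths to delimit sequences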
return flash_attn_varlen_func(
q=q, k=k, v=v,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
dropout_p=0.0, softmax_scale=None, causal=False
)
if skip_reshape:
b, _, _, dim_head = q.shape
else:

View File

@@ -15,6 +15,11 @@ from comfy.rmsnorm import RMSNorm
from torch.nn.modules.utils import _triple
from torch import nn
import math
import logging
try:
from flash_attn import flash_attn_varlen_func
except ImportError:
logging.warning("Best results for SeedVR2 will be achieved with flash attention enabled")
class Cache:
def __init__(self, disable=False, prefix="", cache=None):
@@ -735,70 +740,21 @@ class NaSwinAttention(NaMMAttention):
)
else:
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
# TODO: continue testing
v_lens = vid_len_win.cpu().tolist()
t_lens_batch = txt_len.cpu().tolist()
win_counts = window_count.cpu().tolist()
vq_l = torch.split(vid_q, v_lens)
vk_l = torch.split(vid_k, v_lens)
vv_l = torch.split(vid_v, v_lens)
tv_batch = torch.split(txt_v, t_lens_batch)
tv_l = []
for i, count in enumerate(win_counts):
tv_l.extend([tv_batch[i]] * count)
current_txt_len = txt_q.shape[0]
expected_batch_len = sum(t_lens_batch)
if current_txt_len != expected_batch_len:
t_lens_win = txt_len_win.cpu().tolist()
tq_l = torch.split(txt_q, t_lens_win)
tk_l = torch.split(txt_k, t_lens_win)
else:
tq_batch = torch.split(txt_q, t_lens_batch)
tk_batch = torch.split(txt_k, t_lens_batch)
tq_l = []
tk_l = []
for i, count in enumerate(win_counts):
tq_l.extend([tq_batch[i]] * count)
tk_l.extend([tk_batch[i]] * count)
q_list = [torch.cat([v, t], dim=0) for v, t in zip(vq_l, tq_l)]
k_list = [torch.cat([v, t], dim=0) for v, t in zip(vk_l, tk_l)]
v_list = [torch.cat([v, t], dim=0) for v, t in zip(vv_l, tv_l)]
q = rnn_utils.pad_sequence(q_list, batch_first=True)
k = rnn_utils.pad_sequence(k_list, batch_first=True)
v = rnn_utils.pad_sequence(v_list, batch_first=True)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
B, Heads, Max_L, _ = q.shape
combined_lens = [v.shape[0] + t.shape[0] for v, t in zip(vq_l, tq_l)]
attn_mask = torch.zeros((B, 1, 1, Max_L), device=q.device, dtype=q.dtype)
idx = torch.arange(Max_L, device=q.device).unsqueeze(0).expand(B, Max_L)
len_tensor = torch.tensor(combined_lens, device=q.device).unsqueeze(1)
padding_mask = idx >= len_tensor
attn_mask.masked_fill_(padding_mask.unsqueeze(1).unsqueeze(1), float('-inf'))
out = optimized_attention(q, k, v, heads=self.heads, mask=attn_mask, skip_reshape=True, skip_output_reshape=True)
out = out.transpose(1, 2)
out_flat_list = []
for i, length in enumerate(combined_lens):
out_flat_list.append(out[i, :length])
out = torch.cat(out_flat_list, dim=0)
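# packed variable-length attention over the concatenated video+text windows; the cumulative
# per-window token counts (all_len_win) and the max window length are computed once and cached via cache_win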
out = optimized_attention(
q=concat_win(vid_q, txt_q).bfloat16(),
k=concat_win(vid_k, txt_k).bfloat16(),
v=concat_win(vid_v, txt_v).bfloat16(),
heads=self.heads, skip_reshape=True, var_length=True,
cu_seqlens_q=cache_win(
"vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
),
cu_seqlens_k=cache_win(
"vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
),
max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()),
max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()),
)
vid_out, txt_out = unconcat_win(out)
@@ -807,7 +763,8 @@ class NaSwinAttention(NaMMAttention):
vid_out = window_reverse(vid_out)
device = comfy.model_management.get_torch_device()
vid_out, txt_out = vid_out.to(device), txt_out.to(device)
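# q/k/v are cast to bfloat16 for the attention call above, so cast the outputs back to the
# dtype of the proj_out weights (presumably to avoid a dtype mismatch in the projection)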
dtype = next(self.proj_out.parameters()).dtype
vid_out, txt_out = vid_out.to(device=device, dtype=dtype), txt_out.to(device=device, dtype=dtype)
self.proj_out = self.proj_out.to(device)
vid_out, txt_out = self.proj_out(vid_out, txt_out)

View File

@@ -1565,7 +1565,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
latent = latent.unsqueeze(2)
target_device = comfy.model_management.get_torch_device()
self.to(target_device)
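# move only the decoder to the compute device rather than the whole autoencoder
# (presumably the full self.to() call was the source of the VAE issue named in the commit message)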
self.decoder.to(target_device)
x = super().decode(latent).squeeze(2)
input = rearrange(self.original_image_video[0], "c t h w -> t c h w")