mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-07 04:30:51 +08:00
added var length attention and fixed the vae issue
This commit is contained in:
parent
74621b9d86
commit
7e62f8cc9f
@ -32,7 +32,7 @@ except ImportError as e:
|
||||
|
||||
FLASH_ATTENTION_IS_AVAILABLE = False
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
FLASH_ATTENTION_IS_AVAILABLE = True
|
||||
except ImportError:
|
||||
if model_management.flash_attention_enabled():
|
||||
@ -473,8 +473,29 @@ else:
|
||||
SDP_BATCH_LIMIT = 2**31
|
||||
|
||||
@wrap_attn
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
if skip_reshape:
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
|
||||
if var_length:
|
||||
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
|
||||
cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
|
||||
assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True"
|
||||
if not skip_reshape:
|
||||
# assumes 2D q, k,v [total_tokens, embed_dim]
|
||||
total_tokens, embed_dim = q.shape
|
||||
head_dim = embed_dim // heads
|
||||
q = q.view(total_tokens, heads, head_dim)
|
||||
k = k.view(k.shape[0], heads, head_dim)
|
||||
v = v.view(v.shape[0], heads, head_dim)
|
||||
|
||||
b = q.size(0); dim_head = q.shape[-1]
|
||||
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
|
||||
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
|
||||
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
|
||||
|
||||
mask = None
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
elif skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
@ -492,8 +513,10 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
if SDP_BATCH_LIMIT >= b:
|
||||
if SDP_BATCH_LIMIT >= b or var_length:
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
if var_length:
|
||||
return out.contiguous().transpose(1, 2).values()
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
@ -583,7 +606,20 @@ except AttributeError as error:
|
||||
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
|
||||
|
||||
@wrap_attn
|
||||
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
|
||||
if var_length:
|
||||
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
|
||||
cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
|
||||
max_seqlen_q = kwargs.get("max_seqlen_q", None)
|
||||
max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
|
||||
assert max_seqlen_q != None, "max_seqlen_q shouldn't be None when var_length is True"
|
||||
assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True"
|
||||
return flash_attn_varlen_func(
|
||||
q=q, k=k, v=v,
|
||||
cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
|
||||
dropout_p=0.0, softmax_scale=None, causal=False
|
||||
)
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
|
||||
@ -15,6 +15,11 @@ from comfy.rmsnorm import RMSNorm
|
||||
from torch.nn.modules.utils import _triple
|
||||
from torch import nn
|
||||
import math
|
||||
import logging
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
except:
|
||||
logging.warning("Best results will be achieved with flash attention enabled for SeedVR2")
|
||||
|
||||
class Cache:
|
||||
def __init__(self, disable=False, prefix="", cache=None):
|
||||
@ -735,70 +740,21 @@ class NaSwinAttention(NaMMAttention):
|
||||
)
|
||||
else:
|
||||
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
|
||||
|
||||
# TODO: continue testing
|
||||
v_lens = vid_len_win.cpu().tolist()
|
||||
t_lens_batch = txt_len.cpu().tolist()
|
||||
win_counts = window_count.cpu().tolist()
|
||||
|
||||
vq_l = torch.split(vid_q, v_lens)
|
||||
vk_l = torch.split(vid_k, v_lens)
|
||||
vv_l = torch.split(vid_v, v_lens)
|
||||
|
||||
tv_batch = torch.split(txt_v, t_lens_batch)
|
||||
tv_l = []
|
||||
for i, count in enumerate(win_counts):
|
||||
tv_l.extend([tv_batch[i]] * count)
|
||||
|
||||
current_txt_len = txt_q.shape[0]
|
||||
expected_batch_len = sum(t_lens_batch)
|
||||
|
||||
if current_txt_len != expected_batch_len:
|
||||
t_lens_win = txt_len_win.cpu().tolist()
|
||||
|
||||
tq_l = torch.split(txt_q, t_lens_win)
|
||||
tk_l = torch.split(txt_k, t_lens_win)
|
||||
else:
|
||||
tq_batch = torch.split(txt_q, t_lens_batch)
|
||||
tk_batch = torch.split(txt_k, t_lens_batch)
|
||||
|
||||
tq_l = []
|
||||
tk_l = []
|
||||
for i, count in enumerate(win_counts):
|
||||
tq_l.extend([tq_batch[i]] * count)
|
||||
tk_l.extend([tk_batch[i]] * count)
|
||||
|
||||
q_list = [torch.cat([v, t], dim=0) for v, t in zip(vq_l, tq_l)]
|
||||
k_list = [torch.cat([v, t], dim=0) for v, t in zip(vk_l, tk_l)]
|
||||
v_list = [torch.cat([v, t], dim=0) for v, t in zip(vv_l, tv_l)]
|
||||
|
||||
q = rnn_utils.pad_sequence(q_list, batch_first=True)
|
||||
k = rnn_utils.pad_sequence(k_list, batch_first=True)
|
||||
v = rnn_utils.pad_sequence(v_list, batch_first=True)
|
||||
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
B, Heads, Max_L, _ = q.shape
|
||||
combined_lens = [v.shape[0] + t.shape[0] for v, t in zip(vq_l, tq_l)]
|
||||
|
||||
attn_mask = torch.zeros((B, 1, 1, Max_L), device=q.device, dtype=q.dtype)
|
||||
idx = torch.arange(Max_L, device=q.device).unsqueeze(0).expand(B, Max_L)
|
||||
len_tensor = torch.tensor(combined_lens, device=q.device).unsqueeze(1)
|
||||
|
||||
padding_mask = idx >= len_tensor
|
||||
attn_mask.masked_fill_(padding_mask.unsqueeze(1).unsqueeze(1), float('-inf'))
|
||||
|
||||
out = optimized_attention(q, k, v, heads=self.heads, mask=attn_mask, skip_reshape=True, skip_output_reshape=True)
|
||||
|
||||
out = out.transpose(1, 2)
|
||||
|
||||
out_flat_list = []
|
||||
for i, length in enumerate(combined_lens):
|
||||
out_flat_list.append(out[i, :length])
|
||||
|
||||
out = torch.cat(out_flat_list, dim=0)
|
||||
out = optimized_attention(
|
||||
q=concat_win(vid_q, txt_q).bfloat16(),
|
||||
k=concat_win(vid_k, txt_k).bfloat16(),
|
||||
v=concat_win(vid_v, txt_v).bfloat16(),
|
||||
heads=self.heads, skip_reshape=True, var_length = True,
|
||||
cu_seqlens_q=cache_win(
|
||||
"vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
|
||||
),
|
||||
cu_seqlens_k=cache_win(
|
||||
"vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
|
||||
),
|
||||
max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()),
|
||||
max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()),
|
||||
)
|
||||
|
||||
vid_out, txt_out = unconcat_win(out)
|
||||
|
||||
@ -807,7 +763,8 @@ class NaSwinAttention(NaMMAttention):
|
||||
vid_out = window_reverse(vid_out)
|
||||
|
||||
device = comfy.model_management.get_torch_device()
|
||||
vid_out, txt_out = vid_out.to(device), txt_out.to(device)
|
||||
dtype = next(self.proj_out.parameters()).dtype
|
||||
vid_out, txt_out = vid_out.to(device=device, dtype=dtype), txt_out.to(device=device, dtype=dtype)
|
||||
self.proj_out = self.proj_out.to(device)
|
||||
vid_out, txt_out = self.proj_out(vid_out, txt_out)
|
||||
|
||||
|
||||
@ -1565,7 +1565,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
latent = latent.unsqueeze(2)
|
||||
|
||||
target_device = comfy.model_management.get_torch_device()
|
||||
self.to(target_device)
|
||||
self.decoder.to(target_device)
|
||||
x = super().decode(latent).squeeze(2)
|
||||
|
||||
input = rearrange(self.original_image_video[0], "c t h w -> t c h w")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user