Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-08 13:20:50 +08:00)

Commit 7e62f8cc9f (parent: 74621b9d86)
Commit message: added var length attention and fixed the vae issue
@@ -32,7 +32,7 @@ except ImportError as e:
 
 FLASH_ATTENTION_IS_AVAILABLE = False
 try:
-    from flash_attn import flash_attn_func
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
     FLASH_ATTENTION_IS_AVAILABLE = True
 except ImportError:
     if model_management.flash_attention_enabled():
@@ -473,8 +473,29 @@ else:
     SDP_BATCH_LIMIT = 2**31
 
 @wrap_attn
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
-    if skip_reshape:
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
+    if var_length:
+        cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
+        cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
+        assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True"
+        if not skip_reshape:
+            # assumes 2D q, k,v [total_tokens, embed_dim]
+            total_tokens, embed_dim = q.shape
+            head_dim = embed_dim // heads
+            q = q.view(total_tokens, heads, head_dim)
+            k = k.view(k.shape[0], heads, head_dim)
+            v = v.view(v.shape[0], heads, head_dim)
+
+        b = q.size(0); dim_head = q.shape[-1]
+        q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
+        k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
+        v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
+
+        mask = None
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+    elif skip_reshape:
         b, _, _, dim_head = q.shape
     else:
         b, _, dim_head = q.shape
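Note: the new var_length branch relies on PyTorch's jagged nested tensors so that scaled_dot_product_attention can process a batch of different-length sequences without padding. Below is a minimal standalone sketch of the same call chain with toy shapes; the names and values are illustrative, it is not part of the commit, and it assumes a recent PyTorch build with jagged nested-tensor SDPA support.

import torch
import torch.nn.functional as F

# Three sequences of different length, packed into one [total_tokens, heads, head_dim] tensor.
heads, head_dim = 4, 8
seq_lens = torch.tensor([3, 5, 2])
cu_seqlens = F.pad(seq_lens.cumsum(0), (1, 0)).int()    # tensor([ 0,  3,  8, 10], dtype=torch.int32)
total_tokens = int(seq_lens.sum())

q = torch.randn(total_tokens, heads, head_dim)
k = torch.randn(total_tokens, heads, head_dim)
v = torch.randn(total_tokens, heads, head_dim)

# Wrap the packed tokens as jagged nested tensors; the offsets mark sequence boundaries.
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens.long()).transpose(1, 2)
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens.long()).transpose(1, 2)
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens.long()).transpose(1, 2)

# Each sequence attends only to its own tokens; no padding tokens are ever materialized.
out = F.scaled_dot_product_attention(q, k, v)
out = out.contiguous().transpose(1, 2).values()         # back to packed [total_tokens, heads, head_dim]
print(out.shape)                                         # torch.Size([10, 4, 8])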
@@ -492,8 +513,10 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
         if mask.ndim == 3:
             mask = mask.unsqueeze(1)
 
-    if SDP_BATCH_LIMIT >= b:
+    if SDP_BATCH_LIMIT >= b or var_length:
         out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+        if var_length:
+            return out.contiguous().transpose(1, 2).values()
         if not skip_output_reshape:
             out = (
                 out.transpose(1, 2).reshape(b, -1, heads * dim_head)
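Note: from the caller's side, the var_length path takes packed [total_tokens, heads, head_dim] projections plus cumulative sequence lengths instead of a padded batch, and returns the output packed the same way. A hypothetical call site is sketched below; shapes are toy values, the import path is assumed to be comfy.ldm.modules.attention where these functions live, and the PyTorch fallback is assumed (the flash path additionally expects fp16/bf16 CUDA tensors).

import torch
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention

heads, head_dim = 8, 64
seq_lens = torch.tensor([17, 5, 42])                     # three variable-length sequences
cu_seqlens = F.pad(seq_lens.cumsum(0), (1, 0)).int()     # tensor([ 0, 17, 22, 64], dtype=torch.int32)

total = int(seq_lens.sum())
q = torch.randn(total, heads, head_dim)
k = torch.randn(total, heads, head_dim)
v = torch.randn(total, heads, head_dim)

# Output comes back packed as [total_tokens, heads, head_dim]; max_seqlen_* is only
# consumed by the flash attention backend and ignored by the PyTorch fallback.
out = optimized_attention(
    q, k, v, heads=heads, skip_reshape=True, var_length=True,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seq_lens.max()), max_seqlen_k=int(seq_lens.max()),
)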
@@ -583,7 +606,20 @@ except AttributeError as error:
         assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
 
 @wrap_attn
-def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
+    if var_length:
+        cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
+        cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
+        max_seqlen_q = kwargs.get("max_seqlen_q", None)
+        max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
+        assert max_seqlen_q != None, "max_seqlen_q shouldn't be None when var_length is True"
+        assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True"
+        return flash_attn_varlen_func(
+            q=q, k=k, v=v,
+            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
+            dropout_p=0.0, softmax_scale=None, causal=False
+        )
     if skip_reshape:
         b, _, _, dim_head = q.shape
     else:
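Note: one way to sanity-check either var_length backend is to compare it against a naive per-sequence loop over the packed tensors. The helper below is purely illustrative and not part of the commit.

import torch
import torch.nn.functional as F

def naive_varlen_attention(q, k, v, seq_lens):
    # q, k, v: packed [total_tokens, heads, head_dim]; each sequence attends only to itself.
    outs, start = [], 0
    for n in seq_lens.tolist():
        qs = q[start:start + n].transpose(0, 1)   # [heads, n, head_dim]
        ks = k[start:start + n].transpose(0, 1)
        vs = v[start:start + n].transpose(0, 1)
        o = F.scaled_dot_product_attention(qs, ks, vs)
        outs.append(o.transpose(0, 1))            # back to [n, heads, head_dim]
        start += n
    return torch.cat(outs, dim=0)                 # packed again: [total_tokens, heads, head_dim]

Comparing this reference against the var_length output with torch.allclose (using a loose tolerance when the real path runs in bf16) makes for a quick regression check.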
@@ -15,6 +15,11 @@ from comfy.rmsnorm import RMSNorm
 from torch.nn.modules.utils import _triple
 from torch import nn
 import math
+import logging
+try:
+    from flash_attn import flash_attn_varlen_func
+except:
+    logging.warning("Best results will be achieved with flash attention enabled for SeedVR2")
 
 class Cache:
     def __init__(self, disable=False, prefix="", cache=None):
@@ -735,70 +740,21 @@ class NaSwinAttention(NaMMAttention):
             )
         else:
             vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
 
-        # TODO: continue testing
-        v_lens = vid_len_win.cpu().tolist()
-        t_lens_batch = txt_len.cpu().tolist()
-        win_counts = window_count.cpu().tolist()
-
-        vq_l = torch.split(vid_q, v_lens)
-        vk_l = torch.split(vid_k, v_lens)
-        vv_l = torch.split(vid_v, v_lens)
-
-        tv_batch = torch.split(txt_v, t_lens_batch)
-        tv_l = []
-        for i, count in enumerate(win_counts):
-            tv_l.extend([tv_batch[i]] * count)
-
-        current_txt_len = txt_q.shape[0]
-        expected_batch_len = sum(t_lens_batch)
-
-        if current_txt_len != expected_batch_len:
-            t_lens_win = txt_len_win.cpu().tolist()
-            tq_l = torch.split(txt_q, t_lens_win)
-            tk_l = torch.split(txt_k, t_lens_win)
-        else:
-            tq_batch = torch.split(txt_q, t_lens_batch)
-            tk_batch = torch.split(txt_k, t_lens_batch)
-
-            tq_l = []
-            tk_l = []
-            for i, count in enumerate(win_counts):
-                tq_l.extend([tq_batch[i]] * count)
-                tk_l.extend([tk_batch[i]] * count)
-
-        q_list = [torch.cat([v, t], dim=0) for v, t in zip(vq_l, tq_l)]
-        k_list = [torch.cat([v, t], dim=0) for v, t in zip(vk_l, tk_l)]
-        v_list = [torch.cat([v, t], dim=0) for v, t in zip(vv_l, tv_l)]
-
-        q = rnn_utils.pad_sequence(q_list, batch_first=True)
-        k = rnn_utils.pad_sequence(k_list, batch_first=True)
-        v = rnn_utils.pad_sequence(v_list, batch_first=True)
-
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-
-        B, Heads, Max_L, _ = q.shape
-        combined_lens = [v.shape[0] + t.shape[0] for v, t in zip(vq_l, tq_l)]
-
-        attn_mask = torch.zeros((B, 1, 1, Max_L), device=q.device, dtype=q.dtype)
-        idx = torch.arange(Max_L, device=q.device).unsqueeze(0).expand(B, Max_L)
-        len_tensor = torch.tensor(combined_lens, device=q.device).unsqueeze(1)
-
-        padding_mask = idx >= len_tensor
-        attn_mask.masked_fill_(padding_mask.unsqueeze(1).unsqueeze(1), float('-inf'))
-
-        out = optimized_attention(q, k, v, heads=self.heads, mask=attn_mask, skip_reshape=True, skip_output_reshape=True)
-
-        out = out.transpose(1, 2)
-
-        out_flat_list = []
-        for i, length in enumerate(combined_lens):
-            out_flat_list.append(out[i, :length])
-
-        out = torch.cat(out_flat_list, dim=0)
-
+        out = optimized_attention(
+            q=concat_win(vid_q, txt_q).bfloat16(),
+            k=concat_win(vid_k, txt_k).bfloat16(),
+            v=concat_win(vid_v, txt_v).bfloat16(),
+            heads=self.heads, skip_reshape=True, var_length = True,
+            cu_seqlens_q=cache_win(
+                "vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
+            ),
+            cu_seqlens_k=cache_win(
+                "vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
+            ),
+            max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()),
+            max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()),
+        )
         vid_out, txt_out = unconcat_win(out)
 
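Note: the cu_seqlens_* arguments above are just the per-window token counts turned into cumulative offsets with a leading zero, and cache_win memoizes them so the cumsum is not recomputed every block. A small illustration of that computation follows; the window lengths are made-up example values, and safe_pad_operation is assumed to behave like F.pad here.

import torch
import torch.nn.functional as F

all_len_win = torch.tensor([140, 140, 96, 140])           # video + text tokens per window (example values)
cu_seqlens = F.pad(all_len_win.cumsum(0), (1, 0)).int()   # tensor([  0, 140, 280, 376, 516], dtype=torch.int32)
max_seqlen = all_len_win.max().item()                      # 140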
@@ -807,7 +763,8 @@ class NaSwinAttention(NaMMAttention):
         vid_out = window_reverse(vid_out)
 
         device = comfy.model_management.get_torch_device()
-        vid_out, txt_out = vid_out.to(device), txt_out.to(device)
+        dtype = next(self.proj_out.parameters()).dtype
+        vid_out, txt_out = vid_out.to(device=device, dtype=dtype), txt_out.to(device=device, dtype=dtype)
         self.proj_out = self.proj_out.to(device)
         vid_out, txt_out = self.proj_out(vid_out, txt_out)
 
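Note: the added dtype line casts the attention output to whatever dtype proj_out's weights hold, since the attention above now runs in bfloat16. The same pattern in isolation, with a toy module standing in for proj_out:

import torch
from torch import nn

proj = nn.Linear(64, 64)                         # stand-in for proj_out (float32 weights here)
x = torch.randn(10, 64, dtype=torch.bfloat16)    # stand-in for the bfloat16 attention output

dtype = next(proj.parameters()).dtype
device = next(proj.parameters()).device
y = proj(x.to(device=device, dtype=dtype))       # cast first, so the matmul dtypes match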
@@ -1565,7 +1565,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         latent = latent.unsqueeze(2)
 
         target_device = comfy.model_management.get_torch_device()
-        self.to(target_device)
+        self.decoder.to(target_device)
         x = super().decode(latent).squeeze(2)
 
         input = rearrange(self.original_image_video[0], "c t h w -> t c h w")
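Note: the VAE fix moves only the decoder submodule to the compute device before decoding, rather than the whole autoencoder wrapper, presumably so the unused parts of the wrapper do not need to be resident on the GPU during decode. A stripped-down illustration of the difference (toy module, not the actual wrapper):

import torch
from torch import nn

class ToyAutoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 4)
        self.decoder = nn.Linear(4, 8)

ae = ToyAutoencoder()
device = "cuda" if torch.cuda.is_available() else "cpu"

ae.decoder.to(device)                            # only the decoder's parameters are transferred
print(next(ae.decoder.parameters()).device)      # the chosen device
print(next(ae.encoder.parameters()).device)      # still cpu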