mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
52 lines
1.9 KiB
Python
52 lines
1.9 KiB
Python
import torch
|
|
|
|
from comfy.ldm.modules import attention as _attention
|
|
|
|
|
|
def _var_attention_qkv(q, k, v, heads, skip_reshape):
|
|
if skip_reshape:
|
|
return q, k, v, q.shape[-1]
|
|
total_tokens, embed_dim = q.shape
|
|
head_dim = embed_dim // heads
|
|
return (
|
|
q.view(total_tokens, heads, head_dim),
|
|
k.view(k.shape[0], heads, head_dim),
|
|
v.view(v.shape[0], heads, head_dim),
|
|
head_dim,
|
|
)
|
|
|
|
|
|
def _var_attention_output(out, heads, head_dim, skip_output_reshape):
|
|
if skip_output_reshape:
|
|
return out
|
|
return out.reshape(-1, heads * head_dim)
|
|
|
|
|
|
def _cu_seqlens_to_lengths(cu_seqlens, total_tokens, name):
    """Validate cumulative sequence boundaries and return per-sequence token counts.

    ``cu_seqlens`` may be a 1-D tensor (any device, any numeric dtype) or a
    sequence of ints. It must hold at least two entries, start at 0, be
    non-decreasing, and end at ``total_tokens``.

    Raises:
        ValueError: if any of those conditions does not hold.
    """
    if isinstance(cu_seqlens, torch.Tensor):
        # Splitting needs host-side ints; this copies to host (a sync if on GPU).
        cu_seqlens = cu_seqlens.tolist()
    bounds = [int(b) for b in cu_seqlens]
    if len(bounds) < 2 or bounds[0] != 0 or bounds[-1] != total_tokens:
        raise ValueError(f"{name} must start at 0 and end at the token count ({total_tokens})")
    lengths = [end - start for start, end in zip(bounds, bounds[1:])]
    if any(n < 0 for n in lengths):
        raise ValueError(f"{name} must be non-decreasing")
    return lengths


def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
    """Variable-length attention computed one packed sequence at a time.

    Tokens of several sequences are packed along dim 0 of q/k/v. Sequence ``i``
    occupies rows ``cu_seqlens[i]:cu_seqlens[i + 1]``. Each sequence is passed
    separately to ``_attention.optimized_attention`` and the per-sequence
    results are concatenated back in packed order.

    Args:
        q, k, v: packed tensors, ``(tokens, heads * head_dim)``, or already
            ``(tokens, heads, head_dim)`` when ``skip_reshape`` is true.
        heads: number of attention heads.
        cu_seqlens_q: cumulative boundaries ``[0, ..., q_tokens]`` for q
            (1-D tensor on any device, or a list/tuple of ints).
        cu_seqlens_k: cumulative boundaries ``[0, ..., kv_tokens]`` for k and v.
        *args, **kwargs: accepted for call-site compatibility; ignored.
        skip_reshape: inputs are already split into heads.
        skip_output_reshape: skip the final merge of heads back to 2-D.

    Returns:
        The concatenated attention output, with heads merged unless
        ``skip_output_reshape`` is set.

    Raises:
        ValueError: if the boundaries are malformed, disagree with the token
            counts of q/k/v, or describe different numbers of sequences.
    """
    q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)

    q_lens = _cu_seqlens_to_lengths(cu_seqlens_q, q.shape[0], "cu_seqlens_q")
    k_lens = _cu_seqlens_to_lengths(cu_seqlens_k, k.shape[0], "cu_seqlens_k")
    # k has just been validated against cu_seqlens_k, so a k/v size mismatch
    # means v disagrees with cu_seqlens_k.
    if v.shape[0] != k.shape[0]:
        raise ValueError("cu_seqlens_k does not match v token count")
    if len(q_lens) != len(k_lens):
        raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")

    q_splits = torch.split(q, q_lens, dim=0)
    k_splits = torch.split(k, k_lens, dim=0)
    v_splits = torch.split(v, k_lens, dim=0)

    out = []
    for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
        # (seq, heads, head_dim) -> (1, heads, seq, head_dim): one sequence as a batch of 1.
        q_i, k_i, v_i = (t.permute(1, 0, 2).unsqueeze(0) for t in (q_i, k_i, v_i))
        out_i = _attention.optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
        # NOTE(review): assumes out_i is (1, heads, seq_q, head_dim) with
        # skip_output_reshape=True; this turns it back into (seq_q, heads, head_dim).
        out.append(out_i.squeeze(0).permute(1, 0, 2))

    # At least one sequence is guaranteed by the boundary validation above.
    out = torch.cat(out, dim=0)
    return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
|
|
|
|
|
# Module-level name for the variable-length attention implementation; currently
# bound to the per-sequence split version defined above.
optimized_var_attention = var_attention_optimized_split
|