mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 16:29:25 +08:00
78 lines
3.1 KiB
Python
78 lines
3.1 KiB
Python
import torch
|
|
|
|
from comfy.ldm.modules import attention as _attention
|
|
|
|
|
|
def _var_attention_qkv(q, k, v, heads, skip_reshape):
    """Return ``(q, k, v, head_dim)`` with the head axis split out.

    When ``skip_reshape`` is true the tensors are returned untouched and
    ``head_dim`` is read from the last dimension of ``q``.

    Otherwise ``q`` must be 2D ``(tokens, embed_dim)``. Each of ``q``, ``k``
    and ``v`` is viewed as ``(tokens, heads, embed_dim // heads)``, keeping its
    own token count; the per-head width is always derived from ``q``.
    """
    if skip_reshape:
        return q, k, v, q.shape[-1]

    # Unpacking enforces that q is exactly two-dimensional.
    _, width = q.shape
    per_head = width // heads

    # Lazy generator: each tensor's shape is read and viewed in q, k, v order.
    q, k, v = (t.view(t.shape[0], heads, per_head) for t in (q, k, v))
    return q, k, v, per_head
|
|
|
|
|
|
def _var_attention_output(out, heads, head_dim, skip_output_reshape):
    """Merge the head axes of ``out`` unless the caller opted out.

    With ``skip_output_reshape`` true, ``out`` is returned as-is. Otherwise it
    is reshaped to ``(-1, heads * head_dim)``.
    """
    if not skip_output_reshape:
        merged_width = heads * head_dim
        out = out.reshape(-1, merged_width)
    return out
|
|
|
|
|
|
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
    """Check that ``cu_seqlens`` is a well-formed cumulative-offset tensor.

    Args:
        name: Label used as the prefix of every error message.
        cu_seqlens: Tensor of cumulative sequence offsets to validate.
        token_count: Expected value of the final offset.

    Raises:
        ValueError: If the dtype is not ``torch.int32``/``torch.int64``, the
            tensor is not 1D with at least two entries, the first offset is
            not 0, the offsets are not strictly increasing, or the last
            offset differs from ``token_count``.
    """
    if cu_seqlens.dtype not in (torch.int32, torch.int64):
        raise ValueError(f"{name} must use an integer dtype")
    if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
        raise ValueError(f"{name} must be a 1D tensor with at least two offsets")

    # Copy the offsets to the host once. The remaining checks are then plain
    # Python comparisons instead of several separate tensor reads, each of
    # which would have to wait on the device when the tensor lives there.
    offsets = cu_seqlens.tolist()

    if offsets[0] != 0:
        raise ValueError(f"{name} must start at 0")
    # Strictly increasing also rules out zero-length sequences.
    if any(nxt <= cur for cur, nxt in zip(offsets, offsets[1:])):
        raise ValueError(f"{name} must be strictly increasing")
    if offsets[-1] != token_count:
        raise ValueError(f"{name} does not match token count")
|
|
|
|
|
|
def _split_indices(cu_seqlens):
    """Return the interior offsets of ``cu_seqlens`` as a CPU ``long`` tensor.

    The first and last offsets are dropped; what remains are the boundaries
    between consecutive sequences, in the form passed to ``torch.tensor_split``
    by the caller.
    """
    interior_offsets = cu_seqlens[1:-1]
    return interior_offsets.to(dtype=torch.long, device="cpu")
|
|
|
|
|
|
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
    """Run attention separately on each sequence of a packed batch.

    ``q``, ``k`` and ``v`` hold several sequences concatenated along dim 0.
    ``cu_seqlens_q`` / ``cu_seqlens_k`` give the cumulative offsets of those
    sequences. Each sequence is sliced out, passed on its own to
    ``_attention.optimized_attention``, and the results are concatenated back
    along dim 0.

    Args:
        q, k, v: Packed tensors. With ``skip_reshape=False`` they are 2D
            ``(tokens, embed_dim)``; with ``skip_reshape=True`` they are used
            as given and must be 3D (assumed ``(tokens, heads, head_dim)`` --
            confirm against callers).
        heads: Number of attention heads.
        cu_seqlens_q: Cumulative offsets for ``q``.
        cu_seqlens_k: Cumulative offsets for ``k`` and ``v``.
        *args, **kwargs: Accepted and ignored.
        skip_reshape: Skip the input head split.
        skip_output_reshape: Return the per-head layout instead of merging
            heads into ``(tokens, heads * head_dim)``.

    Returns:
        The concatenated attention output, in the dtype of ``q``.

    Raises:
        ValueError: If either offset tensor is malformed, ``v`` has a
            different token count than ``k``, or the two offset tensors
            describe different numbers of sequences.
    """
    q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)

    _validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
    _validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
    # Validation has pinned cu_seqlens_k[-1] to k.shape[0], so comparing the
    # shapes is equivalent and avoids another read of the offset tensor.
    if k.shape[0] != v.shape[0]:
        raise ValueError("cu_seqlens_k does not match v token count")
    # N offsets always yield N - 1 slices, so a mismatch is detectable before
    # any splitting is done.
    if cu_seqlens_q.numel() != cu_seqlens_k.numel():
        raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")

    q_splits = torch.tensor_split(q, _split_indices(cu_seqlens_q), dim=0)
    kv_split_indices = _split_indices(cu_seqlens_k)
    k_splits = torch.tensor_split(k, kv_split_indices, dim=0)
    v_splits = torch.tensor_split(v, kv_split_indices, dim=0)

    # Neither the selected backend nor q's dtype changes between sequences, so
    # decide once whether inputs must be cast to bfloat16 for the sage backend.
    out_dtype = q.dtype
    cast_to_bf16 = (
        _attention.optimized_attention is _attention.attention_sage
        and out_dtype not in (torch.float16, torch.bfloat16)
    )

    out = []
    for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
        # Swap the first two axes and add a leading axis of size 1 before
        # handing the slice to the attention backend with skip_reshape=True.
        q_i, k_i, v_i = (t.permute(1, 0, 2).unsqueeze(0) for t in (q_i, k_i, v_i))
        if cast_to_bf16:
            q_i, k_i, v_i = (t.to(torch.bfloat16) for t in (q_i, k_i, v_i))

        out_i = _attention.optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
        if out_i.dtype != out_dtype:
            out_i = out_i.to(out_dtype)
        # Undo the layout change applied to the inputs.
        out.append(out_i.squeeze(0).permute(1, 0, 2))

    out = torch.cat(out, dim=0)
    return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
|
|
|
|
|
# Alias binding `optimized_var_attention` to the per-sequence split implementation.
optimized_var_attention = var_attention_optimized_split
|