ComfyUI/comfy/ldm/trellis2/attention.py

import torch
import math
from comfy.ldm.modules.attention import optimized_attention
from typing import Tuple, Union, List
from .vae import VarLenTensor


def sparse_windowed_scaled_dot_product_self_attention(
    qkv,
    window_size: int,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
):
    """Shifted-window self-attention over a sparse, variable-length QKV tensor.

    The window partition is computed once per (window_size, shift_window) pair and
    cached on the tensor, so repeated calls with the same geometry reuse it.
    Only the xformers and flash-attn backends are handled here.
    """
    serialization_spatial_cache_name = f'windowed_attention_{window_size}_{shift_window}'
    serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
    if serialization_spatial_cache is None:
        fwd_indices, bwd_indices, seq_lens, attn_func_args = calc_window_partition(qkv, window_size, shift_window)
        qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, attn_func_args))
    else:
        fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache

    # Gather tokens so that tokens belonging to the same window are contiguous.
    qkv_feats = qkv.feats[fwd_indices]      # [M, 3, H, C]

    if optimized_attention.__name__ == 'attention_xformers':
        import xformers.ops as xops         # lazy import; cheap after the first call via sys.modules
        q, k, v = qkv_feats.unbind(dim=1)
        q = q.unsqueeze(0)                  # [1, M, H, C]
        k = k.unsqueeze(0)                  # [1, M, H, C]
        v = v.unsqueeze(0)                  # [1, M, H, C]
        out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0]     # [M, H, C]
    elif optimized_attention.__name__ == 'attention_flash':
        import flash_attn                   # lazy import; cheap after the first call via sys.modules
        out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args)  # [M, H, C]

    # Scatter the attended tokens back to their original ordering.
    out = out[bwd_indices]                  # [T, H, C]
    return qkv.replace(out)
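
# Illustrative usage sketch (comment only, not executed): assuming a VarLenTensor
# `qkv` whose .feats are packed as [T, 3, H, C] and whose .coords hold a batch
# index followed by spatial voxel indices, shifted-window self-attention could be
# applied as
#
#     h = sparse_windowed_scaled_dot_product_self_attention(qkv, window_size=8)
#     h = sparse_windowed_scaled_dot_product_self_attention(qkv, window_size=8,
#                                                           shift_window=(4, 4, 4))
#
# The exact VarLenTensor construction is defined in vae.py and is assumed here.
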
def calc_window_partition(
    tensor,
    window_size: Union[int, Tuple[int, ...]],
    shift_window: Union[int, Tuple[int, ...]] = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
    """Partition a sparse tensor into (shifted) attention windows.

    Returns the permutation that groups tokens by window (fwd_indices), its
    inverse (bwd_indices), the per-window token counts (seq_lens), and the
    backend-specific keyword arguments for the variable-length attention call.
    """
    DIM = tensor.coords.shape[1] - 1    # first coordinate column is the batch index
    shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
    window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size

    # Shift the spatial coordinates, then map each token to a linear window index.
    shifted_coords = tensor.coords.clone().detach()
    shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)
    MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)]
    NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
    OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
    shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
    shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)

    # Sort tokens by window index; bwd_indices undoes the sort.
    fwd_indices = torch.argsort(shifted_indices)
    bwd_indices = torch.empty_like(fwd_indices)
    bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
    seq_lens = torch.bincount(shifted_indices)
    mask = seq_lens != 0
    seq_lens = seq_lens[mask]

    # Only the xformers and flash-attn backends are handled here.
    if optimized_attention.__name__ == 'attention_xformers':
        import xformers.ops as xops     # lazy import; cheap after the first call via sys.modules
        attn_func_args = {
            'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens.tolist())    # expects a list of ints
        }
    elif optimized_attention.__name__ == 'attention_flash':
        attn_func_args = {
            'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(),
            'max_seqlen': int(seq_lens.max())   # flash-attn expects a plain int
        }
    return fwd_indices, bwd_indices, seq_lens, attn_func_args
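
# Worked example of the window indexing above (illustrative numbers only):
# with window_size = 2 and NUM_WINDOWS = [5, 5, 5], OFFSET becomes [125, 25, 5, 1].
# A token at coords (batch=0, x=3, y=1, z=5) maps to window (3//2, 1//2, 5//2) = (1, 0, 2),
# so its linear window index is 0*125 + 1*25 + 0*5 + 2*1 = 27. Tokens that share a
# linear index form one variable-length segment for the attention call.
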
def sparse_scaled_dot_product_attention(*args, **kwargs):
    """Variable-length scaled dot-product attention over sparse tensors.

    Accepts one of three calling conventions: a packed ``qkv``; a ``q`` plus a
    packed ``kv``; or separate ``q``, ``k``, ``v``. Each operand may be a
    VarLenTensor or a dense batched tensor. Only the xformers and flash-attn
    backends are handled here.
    """
    arg_names_dict = {
        1: ['qkv'],
        2: ['q', 'kv'],
        3: ['q', 'k', 'v']
    }
    num_all_args = len(args) + len(kwargs)
    for key in arg_names_dict[num_all_args][len(args):]:
        assert key in kwargs, f"Missing argument {key}"

    if num_all_args == 1:
        # Packed self-attention: qkv.feats is [T, 3, H, C].
        qkv = args[0] if len(args) > 0 else kwargs['qkv']
        device = qkv.device
        s = qkv
        q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
        kv_seqlen = q_seqlen
        qkv = qkv.feats     # [T, 3, H, C]
    elif num_all_args == 2:
        # Cross-attention with a packed kv: kv feats are [T_KV, 2, H, C].
        q = args[0] if len(args) > 0 else kwargs['q']
        kv = args[1] if len(args) > 1 else kwargs['kv']
        device = q.device
        if isinstance(q, VarLenTensor):
            s = q
            q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
            q = q.feats     # [T_Q, H, C]
        else:
            s = None
            N, L, H, C = q.shape
            q_seqlen = [L] * N
            q = q.reshape(N * L, H, C)      # [T_Q, H, C]
        if isinstance(kv, VarLenTensor):
            kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
            kv = kv.feats   # [T_KV, 2, H, C]
        else:
            N, L, _, H, C = kv.shape
            kv_seqlen = [L] * N
            kv = kv.reshape(N * L, 2, H, C)     # [T_KV, 2, H, C]
    elif num_all_args == 3:
        # Fully unpacked q, k, v; v may carry a different channel dim than q and k.
        q = args[0] if len(args) > 0 else kwargs['q']
        k = args[1] if len(args) > 1 else kwargs['k']
        v = args[2] if len(args) > 2 else kwargs['v']
        device = q.device
        if isinstance(q, VarLenTensor):
            s = q
            q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
            q = q.feats     # [T_Q, H, Ci]
        else:
            s = None
            N, L, H, CI = q.shape
            q_seqlen = [L] * N
            q = q.reshape(N * L, H, CI)     # [T_Q, H, Ci]
        if isinstance(k, VarLenTensor):
            kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
            k = k.feats     # [T_KV, H, Ci]
            v = v.feats     # [T_KV, H, Co]
        else:
            N, L, H, CI, CO = *k.shape, v.shape[-1]
            kv_seqlen = [L] * N
            k = k.reshape(N * L, H, CI)     # [T_KV, H, Ci]
            v = v.reshape(N * L, H, CO)     # [T_KV, H, Co]

    if optimized_attention.__name__ == 'attention_xformers':
        import xformers.ops as xops     # lazy import; cheap after the first call via sys.modules
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=1)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=1)
        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)
        # Block-diagonal mask keeps each sequence attending only to itself.
        mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
        out = xops.memory_efficient_attention(q, k, v, mask)[0]
    elif optimized_attention.__name__ == 'attention_flash':
        import flash_attn              # lazy import; cheap after the first call via sys.modules
        cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
        if num_all_args in [2, 3]:
            cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
        if num_all_args == 1:
            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
        elif num_all_args == 2:
            out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
        elif num_all_args == 3:
            out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
    elif optimized_attention.__name__ == 'flash_attn_3':    # TODO
        import flash_attn_interface as flash_attn_3     # lazy import; cheap after the first call via sys.modules
        cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=1)
            cu_seqlens_kv = cu_seqlens_q.clone()
            max_q_seqlen = max_kv_seqlen = max(q_seqlen)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=1)
            cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
            max_q_seqlen = max(q_seqlen)
            max_kv_seqlen = max(kv_seqlen)
        elif num_all_args == 3:
            cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
            max_q_seqlen = max(q_seqlen)
            max_kv_seqlen = max(kv_seqlen)
        out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q_seqlen, max_kv_seqlen)

    if s is not None:
        return s.replace(out)
    else:
        return out.reshape(N, L, H, -1)
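
# Illustrative calling conventions for sparse_scaled_dot_product_attention
# (comment only; VarLenTensor construction is assumed from vae.py):
#
#     out = sparse_scaled_dot_product_attention(qkv)        # qkv.feats: [T, 3, H, C], packed self-attention
#     out = sparse_scaled_dot_product_attention(q, kv)      # kv packed as [T_KV, 2, H, C] (or dense [N, L, 2, H, C])
#     out = sparse_scaled_dot_product_attention(q, k, v)    # fully unpacked; dense inputs are [N, L, H, C]
#
# When q is a VarLenTensor the result is returned as q.replace(out); for dense
# inputs it is reshaped back to [N, L, H, -1].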