mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-04 13:40:52 +08:00
124 lines
4.8 KiB
Python
124 lines
4.8 KiB
Python
import math
|
|
from typing import Tuple, Union
|
|
|
|
import torch
|
|
|
|
from comfy.ldm.modules.attention import optimized_attention
|
|
from comfy.ldm.trellis2.vae import VarLenTensor
|
|
|
|
|
|
def _build_cu_seqlens(seqlen, device):
    """Turn per-item sequence lengths into an int32 cumulative-offset tensor.

    The result starts with 0, is followed by the running sum of ``seqlen``,
    and is moved to ``device``. ``seqlen`` may be a tensor or anything
    ``torch.tensor`` accepts (e.g. a list of ints).
    """
    if not isinstance(seqlen, torch.Tensor):
        # Plain Python lengths: build the offsets first, then cast and move.
        running_total = torch.cumsum(torch.tensor(seqlen), dim=0)
        return torch.cat([torch.tensor([0]), running_total]).int().to(device)

    # Tensor lengths: assemble on the tensor's own device before moving.
    leading_zero = torch.zeros(1, dtype=torch.int32, device=seqlen.device)
    running_total = torch.cumsum(seqlen, dim=0).int()
    return torch.cat([leading_zero, running_total], dim=0).to(device)
|
|
|
|
|
|
def _layout_to_feats(t):
    """Split an attention input into (template, per-item lengths, flat feats).

    For a VarLenTensor the tensor itself is returned as the template, the
    lengths come from the extent of each layout slice, and the features are
    ``t.feats``. For any other input the template is ``None``, every item has
    length ``t.shape[1]``, and all leading dims are flattened so only the last
    two dims are kept.
    """
    if not isinstance(t, VarLenTensor):
        batch, length = t.shape[:2]
        flattened = t.reshape(-1, *t.shape[-2:])
        return None, [length] * batch, flattened

    lengths = []
    for idx in range(t.shape[0]):
        span = t.layout[idx]
        lengths.append(span.stop - span.start)
    return t, lengths, t.feats
|
|
|
|
|
|
def dense_attention(q, k, v, **kwargs):
    """Attention over dense [B, L, H, C] inputs.

    Axes 1 and 2 are swapped to comfy's [B, H, L, C] convention before calling
    ``optimized_attention`` (with both reshape steps skipped) and swapped back
    on the result. Extra keyword arguments are forwarded unchanged.
    """
    num_heads = q.shape[2]  # read before the permute moves the head axis
    q, k, v = (x.permute(0, 2, 1, 3) for x in (q, k, v))
    attended = optimized_attention(
        q, k, v, num_heads,
        skip_output_reshape=True,
        skip_reshape=True,
        **kwargs,
    )
    return attended.permute(0, 2, 1, 3)
|
|
|
|
|
|
def sparse_attention(q, k, v, **kwargs):
    """
    Variable-length attention where q, k and v may each be a VarLenTensor
    (sparse) or a dense [B, L, H, C] tensor. The result has the same kind as
    q: a VarLenTensor built via ``q.replace`` or a dense [B, L, H, C] tensor.
    Backend selection is left to comfy.ldm.modules.attention.optimized_attention;
    this function only derives the cu_seqlens / max_seqlen arguments from the
    input layouts. The kv lengths are taken from k.
    """
    template, q_lens, q_flat = _layout_to_feats(q)
    _, kv_lens, k_flat = _layout_to_feats(k)
    v_flat = _layout_to_feats(v)[2]

    num_heads = q_flat.shape[1]
    target = q_flat.device
    offsets_q = _build_cu_seqlens(q_lens, target)
    offsets_kv = _build_cu_seqlens(kv_lens, target)

    out = optimized_attention(
        q_flat, k_flat, v_flat, num_heads,
        cu_seqlens_q=offsets_q,
        cu_seqlens_kv=offsets_kv,
        max_seqlen_q=max(q_lens),
        max_seqlen_kv=max(kv_lens),
        skip_reshape=True,
        skip_output_reshape=True,
        **kwargs,
    )

    if template is None:
        # Dense query: restore the [B, L, H, C] layout of q.
        batch, length = q.shape[:2]
        return out.reshape(batch, length, num_heads, -1)
    return template.replace(out)
|
|
|
|
|
|
def sparse_windowed_self_attention(qkv, window_size: int, shift_window: Tuple[int, int, int] = (0, 0, 0)):
    """Self-attention restricted to spatial windows of a sparse tensor.

    The window partition (sort order, inverse order, per-window lengths) is
    looked up in the tensor's spatial cache and computed with
    ``calc_window_partition`` on a miss. Features are gathered in window
    order, attended as one variable-length sequence per non-empty window,
    then scattered back to the original order.
    """
    cache_name = f'windowed_attention_{window_size}_{shift_window}'
    partition = qkv.get_spatial_cache(cache_name)
    if partition is None:
        partition = calc_window_partition(qkv, window_size, shift_window)
        qkv.register_spatial_cache(cache_name, partition)
    fwd_indices, bwd_indices, seq_lens = partition

    # Gathered features are [M, 3, H, C]; unbinding dim 1 yields q, k, v as [M, H, C].
    q, k, v = qkv.feats[fwd_indices].unbind(dim=1)
    num_heads = q.shape[1]

    offsets = _build_cu_seqlens(seq_lens, q.device)
    longest = int(seq_lens.max())

    attended = optimized_attention(
        q, k, v, num_heads,
        cu_seqlens_q=offsets,
        cu_seqlens_kv=offsets,
        max_seqlen_q=longest,
        max_seqlen_kv=longest,
        skip_reshape=True,
        skip_output_reshape=True,
    )
    # Undo the window sort so rows line up with qkv's original ordering.
    return qkv.replace(attended[bwd_indices])
|
|
|
|
|
|
def calc_window_partition(
    tensor,
    window_size: Union[int, Tuple[int, ...]],
    shift_window: Union[int, Tuple[int, ...]] = 0,
):
    """Compute the window partition of ``tensor.coords``.

    Column 0 of ``coords`` is kept as-is; the remaining columns are shifted by
    ``shift_window`` and floor-divided by ``window_size`` to get a window
    coordinate per row, which is then flattened into a single window id.

    Returns ``(fwd_indices, bwd_indices, seq_lens)``: the argsort of the
    window ids, its inverse permutation, and the row count of every non-empty
    window in ascending window-id order.
    """
    device = tensor.device
    ndim = tensor.coords.shape[1] - 1
    if isinstance(shift_window, int):
        shift_window = (shift_window,) * ndim
    if isinstance(window_size, int):
        window_size = (window_size,) * ndim

    def _as_row(values):
        # int32 row vector on the tensor's device, broadcast across all rows.
        return torch.tensor(values, device=device, dtype=torch.int32).unsqueeze(0)

    window_coords = tensor.coords.clone().detach()
    window_coords[:, 1:] += _as_row(shift_window)

    upper_bounds = [extent + shift for extent, shift in zip(tensor.spatial_shape, shift_window)]
    windows_per_axis = [math.ceil((bound + 1) / size) for bound, size in zip(upper_bounds, window_size)]
    # Mixed-radix strides: the last axis has stride 1, column 0 the full product.
    strides = torch.cumprod(torch.tensor([1] + windows_per_axis[::-1]), dim=0).tolist()[::-1]

    window_coords[:, 1:] //= _as_row(window_size)
    window_ids = (window_coords * _as_row(strides)).sum(dim=1)

    fwd_indices = torch.argsort(window_ids)
    bwd_indices = torch.empty_like(fwd_indices)
    bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=device)

    counts = torch.bincount(window_ids)
    seq_lens = counts[counts != 0]

    return fwd_indices, bwd_indices, seq_lens
|