apply rope and optimized attention

This commit is contained in:
Yousef Rafat 2026-02-05 02:34:08 +02:00
parent 3002708fe3
commit cd0f7ba64e
2 changed files with 52 additions and 53 deletions

View File

@ -26,21 +26,21 @@ def scaled_dot_product_attention(*args, **kwargs):
k = args[1] if len(args) > 1 else kwargs['k']
v = args[2] if len(args) > 2 else kwargs['v']
# TODO verify
heads = q or qkv
heads = heads.shape[2]
if optimized_attention.__name__ == 'attention_xformers':
if 'xops' not in globals():
import xformers.ops as xops
if num_all_args == 1:
q, k, v = qkv.unbind(dim=2)
elif num_all_args == 2:
k, v = kv.unbind(dim=2)
out = xops.memory_efficient_attention(q, k, v)
#out = xops.memory_efficient_attention(q, k, v)
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
elif optimized_attention.__name__ == 'attention_flash' and not FLASH_ATTN_3_AVA:
if 'flash_attn' not in globals():
import flash_attn
if num_all_args == 2:
out = flash_attn.flash_attn_kvpacked_func(q, kv)
elif num_all_args == 3:
out = flash_attn.flash_attn_func(q, k, v)
k, v = kv.unbind(dim=2)
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
elif optimized_attention.__name__ == 'attention_flash': # TODO
if 'flash_attn_3' not in globals():
import flash_attn_interface as flash_attn_3
@ -59,15 +59,14 @@ def scaled_dot_product_attention(*args, **kwargs):
q = q.permute(0, 2, 1, 3) # [N, H, L, C]
k = k.permute(0, 2, 1, 3) # [N, H, L, C]
v = v.permute(0, 2, 1, 3) # [N, H, L, C]
out = sdpa(q, k, v) # [N, H, L, C]
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
out = out.permute(0, 2, 1, 3) # [N, L, H, C]
elif optimized_attention.__name__ == 'attention_basic':
if num_all_args == 1:
q, k, v = qkv.unbind(dim=2)
elif num_all_args == 2:
k, v = kv.unbind(dim=2)
q = q.shape[2] # TODO
out = optimized_attention(q, k, v)
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
return out
@ -86,19 +85,21 @@ def sparse_windowed_scaled_dot_product_self_attention(
fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache
qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
heads = qkv_feats.shape[2]
if optimized_attention.__name__ == 'attention_xformers':
if 'xops' not in globals():
import xformers.ops as xops
q, k, v = qkv_feats.unbind(dim=1)
q = q.unsqueeze(0) # [1, M, H, C]
k = k.unsqueeze(0) # [1, M, H, C]
v = v.unsqueeze(0) # [1, M, H, C]
out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C]
#out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C]
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
elif optimized_attention.__name__ == 'attention_flash':
if 'flash_attn' not in globals():
import flash_attn
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C]
else:
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
out = out[bwd_indices] # [T, H, C]

View File

@ -8,6 +8,7 @@ from comfy.ldm.trellis2.attention import (
)
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
from comfy.nested_tensor import NestedTensor
from comfy.ldm.flux.math import apply_rope, apply_rope1
class SparseGELU(nn.GELU):
def forward(self, input: VarLenTensor) -> VarLenTensor:
@ -52,7 +53,6 @@ class SparseMultiHeadRMSNorm(nn.Module):
x = F.normalize(x, dim=-1) * self.gamma * self.scale
return x.to(x_type)
# TODO: replace with apply_rope1
class SparseRotaryPositionEmbedder(nn.Module):
def __init__(
self,
@ -61,7 +61,6 @@ class SparseRotaryPositionEmbedder(nn.Module):
rope_freq: Tuple[float, float] = (1.0, 10000.0)
):
super().__init__()
assert head_dim % 2 == 0, "Head dim must be divisible by 2"
self.head_dim = head_dim
self.dim = dim
self.rope_freq = rope_freq
@ -69,46 +68,48 @@ class SparseRotaryPositionEmbedder(nn.Module):
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))
def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
self.freqs = self.freqs.to(indices.device)
phases = torch.outer(indices, self.freqs)
phases = torch.polar(torch.ones_like(phases), phases)
return phases
def _get_freqs_cis(self, coords: torch.Tensor) -> torch.Tensor:
phases_list = []
for i in range(self.dim):
phases_list.append(torch.outer(coords[..., i], self.freqs.to(coords.device)))
def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
x_rotated = x_complex * phases.unsqueeze(-2)
x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
return x_embed
phases = torch.cat(phases_list, dim=-1)
if phases.shape[-1] < self.head_dim // 2:
padn = self.head_dim // 2 - phases.shape[-1]
phases = torch.cat([phases, torch.zeros(*phases.shape[:-1], padn, device=phases.device)], dim=-1)
cos = torch.cos(phases)
sin = torch.sin(phases)
f_cis_0 = torch.stack([cos, sin], dim=-1)
f_cis_1 = torch.stack([-sin, cos], dim=-1)
freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1)
return freqs_cis
def forward(self, q, k=None):
cache_name = f'rope_cis_{self.dim}d_f{self.rope_freq[1]}_hd{self.head_dim}'
freqs_cis = q.get_spatial_cache(cache_name)
if freqs_cis is None:
coords = q.coords[..., 1:].to(torch.float32)
freqs_cis = self._get_freqs_cis(coords)
q.register_spatial_cache(cache_name, freqs_cis)
if q.feats.ndim == 3:
f_cis = freqs_cis.unsqueeze(1)
else:
f_cis = freqs_cis
def forward(self, q: SparseTensor, k: Optional[SparseTensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
q (SparseTensor): [..., N, H, D] tensor of queries
k (SparseTensor): [..., N, H, D] tensor of keys
"""
assert q.coords.shape[-1] == self.dim + 1, "Last dimension of coords must be equal to dim+1"
phases_cache_name = f'rope_phase_{self.dim}d_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}'
phases = q.get_spatial_cache(phases_cache_name)
if phases is None:
coords = q.coords[..., 1:]
phases = self._get_phases(coords.reshape(-1)).reshape(*coords.shape[:-1], -1)
if phases.shape[-1] < self.head_dim // 2:
padn = self.head_dim // 2 - phases.shape[-1]
phases = torch.cat([phases, torch.polar(
torch.ones(*phases.shape[:-1], padn, device=phases.device),
torch.zeros(*phases.shape[:-1], padn, device=phases.device)
)], dim=-1)
q.register_spatial_cache(phases_cache_name, phases)
q_embed = q.replace(self._rotary_embedding(q.feats, phases))
if k is None:
return q_embed
k_embed = k.replace(self._rotary_embedding(k.feats, phases))
return q_embed, k_embed
return q.replace(apply_rope1(q.feats, f_cis))
q_feats, k_feats = apply_rope(q.feats, k.feats, f_cis)
return q.replace(q_feats), k.replace(k_feats)
class RotaryPositionEmbedder(SparseRotaryPositionEmbedder):
def forward(self, indices: torch.Tensor) -> torch.Tensor:
assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}"
phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
if phases.shape[-1] < self.head_dim // 2:
padn = self.head_dim // 2 - phases.shape[-1]
@ -228,9 +229,6 @@ class SparseMultiHeadAttention(nn.Module):
return h
class ModulatedSparseTransformerBlock(nn.Module):
"""
Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
"""
def __init__(
self,
channels: int,