mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
apply rope and optimized attention
This commit is contained in:
parent
3002708fe3
commit
cd0f7ba64e
@ -26,21 +26,21 @@ def scaled_dot_product_attention(*args, **kwargs):
|
||||
k = args[1] if len(args) > 1 else kwargs['k']
|
||||
v = args[2] if len(args) > 2 else kwargs['v']
|
||||
|
||||
# TODO verify
|
||||
heads = q or qkv
|
||||
heads = heads.shape[2]
|
||||
|
||||
if optimized_attention.__name__ == 'attention_xformers':
|
||||
if 'xops' not in globals():
|
||||
import xformers.ops as xops
|
||||
if num_all_args == 1:
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
elif num_all_args == 2:
|
||||
k, v = kv.unbind(dim=2)
|
||||
out = xops.memory_efficient_attention(q, k, v)
|
||||
#out = xops.memory_efficient_attention(q, k, v)
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||
elif optimized_attention.__name__ == 'attention_flash' and not FLASH_ATTN_3_AVA:
|
||||
if 'flash_attn' not in globals():
|
||||
import flash_attn
|
||||
if num_all_args == 2:
|
||||
out = flash_attn.flash_attn_kvpacked_func(q, kv)
|
||||
elif num_all_args == 3:
|
||||
out = flash_attn.flash_attn_func(q, k, v)
|
||||
k, v = kv.unbind(dim=2)
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||
elif optimized_attention.__name__ == 'attention_flash': # TODO
|
||||
if 'flash_attn_3' not in globals():
|
||||
import flash_attn_interface as flash_attn_3
|
||||
@ -59,15 +59,14 @@ def scaled_dot_product_attention(*args, **kwargs):
|
||||
q = q.permute(0, 2, 1, 3) # [N, H, L, C]
|
||||
k = k.permute(0, 2, 1, 3) # [N, H, L, C]
|
||||
v = v.permute(0, 2, 1, 3) # [N, H, L, C]
|
||||
out = sdpa(q, k, v) # [N, H, L, C]
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||
out = out.permute(0, 2, 1, 3) # [N, L, H, C]
|
||||
elif optimized_attention.__name__ == 'attention_basic':
|
||||
if num_all_args == 1:
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
elif num_all_args == 2:
|
||||
k, v = kv.unbind(dim=2)
|
||||
q = q.shape[2] # TODO
|
||||
out = optimized_attention(q, k, v)
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||
|
||||
return out
|
||||
|
||||
@ -86,19 +85,21 @@ def sparse_windowed_scaled_dot_product_self_attention(
|
||||
fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache
|
||||
|
||||
qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
|
||||
heads = qkv_feats.shape[2]
|
||||
|
||||
if optimized_attention.__name__ == 'attention_xformers':
|
||||
if 'xops' not in globals():
|
||||
import xformers.ops as xops
|
||||
q, k, v = qkv_feats.unbind(dim=1)
|
||||
q = q.unsqueeze(0) # [1, M, H, C]
|
||||
k = k.unsqueeze(0) # [1, M, H, C]
|
||||
v = v.unsqueeze(0) # [1, M, H, C]
|
||||
out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C]
|
||||
#out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C]
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||
elif optimized_attention.__name__ == 'attention_flash':
|
||||
if 'flash_attn' not in globals():
|
||||
import flash_attn
|
||||
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C]
|
||||
else:
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||
|
||||
out = out[bwd_indices] # [T, H, C]
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ from comfy.ldm.trellis2.attention import (
|
||||
)
|
||||
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
||||
from comfy.nested_tensor import NestedTensor
|
||||
from comfy.ldm.flux.math import apply_rope, apply_rope1
|
||||
|
||||
class SparseGELU(nn.GELU):
|
||||
def forward(self, input: VarLenTensor) -> VarLenTensor:
|
||||
@ -52,7 +53,6 @@ class SparseMultiHeadRMSNorm(nn.Module):
|
||||
x = F.normalize(x, dim=-1) * self.gamma * self.scale
|
||||
return x.to(x_type)
|
||||
|
||||
# TODO: replace with apply_rope1
|
||||
class SparseRotaryPositionEmbedder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -61,7 +61,6 @@ class SparseRotaryPositionEmbedder(nn.Module):
|
||||
rope_freq: Tuple[float, float] = (1.0, 10000.0)
|
||||
):
|
||||
super().__init__()
|
||||
assert head_dim % 2 == 0, "Head dim must be divisible by 2"
|
||||
self.head_dim = head_dim
|
||||
self.dim = dim
|
||||
self.rope_freq = rope_freq
|
||||
@ -69,46 +68,48 @@ class SparseRotaryPositionEmbedder(nn.Module):
|
||||
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
|
||||
self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))
|
||||
|
||||
def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
|
||||
self.freqs = self.freqs.to(indices.device)
|
||||
phases = torch.outer(indices, self.freqs)
|
||||
phases = torch.polar(torch.ones_like(phases), phases)
|
||||
return phases
|
||||
def _get_freqs_cis(self, coords: torch.Tensor) -> torch.Tensor:
|
||||
phases_list = []
|
||||
for i in range(self.dim):
|
||||
phases_list.append(torch.outer(coords[..., i], self.freqs.to(coords.device)))
|
||||
|
||||
def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
|
||||
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
x_rotated = x_complex * phases.unsqueeze(-2)
|
||||
x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
|
||||
return x_embed
|
||||
phases = torch.cat(phases_list, dim=-1)
|
||||
|
||||
if phases.shape[-1] < self.head_dim // 2:
|
||||
padn = self.head_dim // 2 - phases.shape[-1]
|
||||
phases = torch.cat([phases, torch.zeros(*phases.shape[:-1], padn, device=phases.device)], dim=-1)
|
||||
|
||||
cos = torch.cos(phases)
|
||||
sin = torch.sin(phases)
|
||||
|
||||
f_cis_0 = torch.stack([cos, sin], dim=-1)
|
||||
f_cis_1 = torch.stack([-sin, cos], dim=-1)
|
||||
freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1)
|
||||
|
||||
return freqs_cis
|
||||
|
||||
def forward(self, q, k=None):
|
||||
cache_name = f'rope_cis_{self.dim}d_f{self.rope_freq[1]}_hd{self.head_dim}'
|
||||
freqs_cis = q.get_spatial_cache(cache_name)
|
||||
|
||||
if freqs_cis is None:
|
||||
coords = q.coords[..., 1:].to(torch.float32)
|
||||
freqs_cis = self._get_freqs_cis(coords)
|
||||
q.register_spatial_cache(cache_name, freqs_cis)
|
||||
|
||||
if q.feats.ndim == 3:
|
||||
f_cis = freqs_cis.unsqueeze(1)
|
||||
else:
|
||||
f_cis = freqs_cis
|
||||
|
||||
def forward(self, q: SparseTensor, k: Optional[SparseTensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
q (SparseTensor): [..., N, H, D] tensor of queries
|
||||
k (SparseTensor): [..., N, H, D] tensor of keys
|
||||
"""
|
||||
assert q.coords.shape[-1] == self.dim + 1, "Last dimension of coords must be equal to dim+1"
|
||||
phases_cache_name = f'rope_phase_{self.dim}d_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}'
|
||||
phases = q.get_spatial_cache(phases_cache_name)
|
||||
if phases is None:
|
||||
coords = q.coords[..., 1:]
|
||||
phases = self._get_phases(coords.reshape(-1)).reshape(*coords.shape[:-1], -1)
|
||||
if phases.shape[-1] < self.head_dim // 2:
|
||||
padn = self.head_dim // 2 - phases.shape[-1]
|
||||
phases = torch.cat([phases, torch.polar(
|
||||
torch.ones(*phases.shape[:-1], padn, device=phases.device),
|
||||
torch.zeros(*phases.shape[:-1], padn, device=phases.device)
|
||||
)], dim=-1)
|
||||
q.register_spatial_cache(phases_cache_name, phases)
|
||||
q_embed = q.replace(self._rotary_embedding(q.feats, phases))
|
||||
if k is None:
|
||||
return q_embed
|
||||
k_embed = k.replace(self._rotary_embedding(k.feats, phases))
|
||||
return q_embed, k_embed
|
||||
return q.replace(apply_rope1(q.feats, f_cis))
|
||||
|
||||
q_feats, k_feats = apply_rope(q.feats, k.feats, f_cis)
|
||||
return q.replace(q_feats), k.replace(k_feats)
|
||||
|
||||
class RotaryPositionEmbedder(SparseRotaryPositionEmbedder):
|
||||
def forward(self, indices: torch.Tensor) -> torch.Tensor:
|
||||
assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}"
|
||||
phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
|
||||
if phases.shape[-1] < self.head_dim // 2:
|
||||
padn = self.head_dim // 2 - phases.shape[-1]
|
||||
@ -228,9 +229,6 @@ class SparseMultiHeadAttention(nn.Module):
|
||||
return h
|
||||
|
||||
class ModulatedSparseTransformerBlock(nn.Module):
|
||||
"""
|
||||
Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user