mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-24 01:12:37 +08:00
apply rope and optimized attention
This commit is contained in:
parent
3002708fe3
commit
cd0f7ba64e
@ -26,21 +26,21 @@ def scaled_dot_product_attention(*args, **kwargs):
|
|||||||
k = args[1] if len(args) > 1 else kwargs['k']
|
k = args[1] if len(args) > 1 else kwargs['k']
|
||||||
v = args[2] if len(args) > 2 else kwargs['v']
|
v = args[2] if len(args) > 2 else kwargs['v']
|
||||||
|
|
||||||
|
# TODO verify
|
||||||
|
heads = q or qkv
|
||||||
|
heads = heads.shape[2]
|
||||||
|
|
||||||
if optimized_attention.__name__ == 'attention_xformers':
|
if optimized_attention.__name__ == 'attention_xformers':
|
||||||
if 'xops' not in globals():
|
|
||||||
import xformers.ops as xops
|
|
||||||
if num_all_args == 1:
|
if num_all_args == 1:
|
||||||
q, k, v = qkv.unbind(dim=2)
|
q, k, v = qkv.unbind(dim=2)
|
||||||
elif num_all_args == 2:
|
elif num_all_args == 2:
|
||||||
k, v = kv.unbind(dim=2)
|
k, v = kv.unbind(dim=2)
|
||||||
out = xops.memory_efficient_attention(q, k, v)
|
#out = xops.memory_efficient_attention(q, k, v)
|
||||||
|
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||||
elif optimized_attention.__name__ == 'attention_flash' and not FLASH_ATTN_3_AVA:
|
elif optimized_attention.__name__ == 'attention_flash' and not FLASH_ATTN_3_AVA:
|
||||||
if 'flash_attn' not in globals():
|
|
||||||
import flash_attn
|
|
||||||
if num_all_args == 2:
|
if num_all_args == 2:
|
||||||
out = flash_attn.flash_attn_kvpacked_func(q, kv)
|
k, v = kv.unbind(dim=2)
|
||||||
elif num_all_args == 3:
|
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||||
out = flash_attn.flash_attn_func(q, k, v)
|
|
||||||
elif optimized_attention.__name__ == 'attention_flash': # TODO
|
elif optimized_attention.__name__ == 'attention_flash': # TODO
|
||||||
if 'flash_attn_3' not in globals():
|
if 'flash_attn_3' not in globals():
|
||||||
import flash_attn_interface as flash_attn_3
|
import flash_attn_interface as flash_attn_3
|
||||||
@ -59,15 +59,14 @@ def scaled_dot_product_attention(*args, **kwargs):
|
|||||||
q = q.permute(0, 2, 1, 3) # [N, H, L, C]
|
q = q.permute(0, 2, 1, 3) # [N, H, L, C]
|
||||||
k = k.permute(0, 2, 1, 3) # [N, H, L, C]
|
k = k.permute(0, 2, 1, 3) # [N, H, L, C]
|
||||||
v = v.permute(0, 2, 1, 3) # [N, H, L, C]
|
v = v.permute(0, 2, 1, 3) # [N, H, L, C]
|
||||||
out = sdpa(q, k, v) # [N, H, L, C]
|
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||||
out = out.permute(0, 2, 1, 3) # [N, L, H, C]
|
out = out.permute(0, 2, 1, 3) # [N, L, H, C]
|
||||||
elif optimized_attention.__name__ == 'attention_basic':
|
elif optimized_attention.__name__ == 'attention_basic':
|
||||||
if num_all_args == 1:
|
if num_all_args == 1:
|
||||||
q, k, v = qkv.unbind(dim=2)
|
q, k, v = qkv.unbind(dim=2)
|
||||||
elif num_all_args == 2:
|
elif num_all_args == 2:
|
||||||
k, v = kv.unbind(dim=2)
|
k, v = kv.unbind(dim=2)
|
||||||
q = q.shape[2] # TODO
|
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||||
out = optimized_attention(q, k, v)
|
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -86,19 +85,21 @@ def sparse_windowed_scaled_dot_product_self_attention(
|
|||||||
fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache
|
fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache
|
||||||
|
|
||||||
qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
|
qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
|
||||||
|
heads = qkv_feats.shape[2]
|
||||||
|
|
||||||
if optimized_attention.__name__ == 'attention_xformers':
|
if optimized_attention.__name__ == 'attention_xformers':
|
||||||
if 'xops' not in globals():
|
|
||||||
import xformers.ops as xops
|
|
||||||
q, k, v = qkv_feats.unbind(dim=1)
|
q, k, v = qkv_feats.unbind(dim=1)
|
||||||
q = q.unsqueeze(0) # [1, M, H, C]
|
q = q.unsqueeze(0) # [1, M, H, C]
|
||||||
k = k.unsqueeze(0) # [1, M, H, C]
|
k = k.unsqueeze(0) # [1, M, H, C]
|
||||||
v = v.unsqueeze(0) # [1, M, H, C]
|
v = v.unsqueeze(0) # [1, M, H, C]
|
||||||
out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C]
|
#out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C]
|
||||||
|
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||||
elif optimized_attention.__name__ == 'attention_flash':
|
elif optimized_attention.__name__ == 'attention_flash':
|
||||||
if 'flash_attn' not in globals():
|
if 'flash_attn' not in globals():
|
||||||
import flash_attn
|
import flash_attn
|
||||||
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C]
|
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C]
|
||||||
|
else:
|
||||||
|
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||||
|
|
||||||
out = out[bwd_indices] # [T, H, C]
|
out = out[bwd_indices] # [T, H, C]
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from comfy.ldm.trellis2.attention import (
|
|||||||
)
|
)
|
||||||
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
||||||
from comfy.nested_tensor import NestedTensor
|
from comfy.nested_tensor import NestedTensor
|
||||||
|
from comfy.ldm.flux.math import apply_rope, apply_rope1
|
||||||
|
|
||||||
class SparseGELU(nn.GELU):
|
class SparseGELU(nn.GELU):
|
||||||
def forward(self, input: VarLenTensor) -> VarLenTensor:
|
def forward(self, input: VarLenTensor) -> VarLenTensor:
|
||||||
@ -52,7 +53,6 @@ class SparseMultiHeadRMSNorm(nn.Module):
|
|||||||
x = F.normalize(x, dim=-1) * self.gamma * self.scale
|
x = F.normalize(x, dim=-1) * self.gamma * self.scale
|
||||||
return x.to(x_type)
|
return x.to(x_type)
|
||||||
|
|
||||||
# TODO: replace with apply_rope1
|
|
||||||
class SparseRotaryPositionEmbedder(nn.Module):
|
class SparseRotaryPositionEmbedder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -61,7 +61,6 @@ class SparseRotaryPositionEmbedder(nn.Module):
|
|||||||
rope_freq: Tuple[float, float] = (1.0, 10000.0)
|
rope_freq: Tuple[float, float] = (1.0, 10000.0)
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert head_dim % 2 == 0, "Head dim must be divisible by 2"
|
|
||||||
self.head_dim = head_dim
|
self.head_dim = head_dim
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.rope_freq = rope_freq
|
self.rope_freq = rope_freq
|
||||||
@ -69,46 +68,48 @@ class SparseRotaryPositionEmbedder(nn.Module):
|
|||||||
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
|
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
|
||||||
self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))
|
self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))
|
||||||
|
|
||||||
def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
|
def _get_freqs_cis(self, coords: torch.Tensor) -> torch.Tensor:
|
||||||
self.freqs = self.freqs.to(indices.device)
|
phases_list = []
|
||||||
phases = torch.outer(indices, self.freqs)
|
for i in range(self.dim):
|
||||||
phases = torch.polar(torch.ones_like(phases), phases)
|
phases_list.append(torch.outer(coords[..., i], self.freqs.to(coords.device)))
|
||||||
return phases
|
|
||||||
|
|
||||||
def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
|
phases = torch.cat(phases_list, dim=-1)
|
||||||
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
|
||||||
x_rotated = x_complex * phases.unsqueeze(-2)
|
if phases.shape[-1] < self.head_dim // 2:
|
||||||
x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
|
padn = self.head_dim // 2 - phases.shape[-1]
|
||||||
return x_embed
|
phases = torch.cat([phases, torch.zeros(*phases.shape[:-1], padn, device=phases.device)], dim=-1)
|
||||||
|
|
||||||
|
cos = torch.cos(phases)
|
||||||
|
sin = torch.sin(phases)
|
||||||
|
|
||||||
|
f_cis_0 = torch.stack([cos, sin], dim=-1)
|
||||||
|
f_cis_1 = torch.stack([-sin, cos], dim=-1)
|
||||||
|
freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1)
|
||||||
|
|
||||||
|
return freqs_cis
|
||||||
|
|
||||||
|
def forward(self, q, k=None):
|
||||||
|
cache_name = f'rope_cis_{self.dim}d_f{self.rope_freq[1]}_hd{self.head_dim}'
|
||||||
|
freqs_cis = q.get_spatial_cache(cache_name)
|
||||||
|
|
||||||
|
if freqs_cis is None:
|
||||||
|
coords = q.coords[..., 1:].to(torch.float32)
|
||||||
|
freqs_cis = self._get_freqs_cis(coords)
|
||||||
|
q.register_spatial_cache(cache_name, freqs_cis)
|
||||||
|
|
||||||
|
if q.feats.ndim == 3:
|
||||||
|
f_cis = freqs_cis.unsqueeze(1)
|
||||||
|
else:
|
||||||
|
f_cis = freqs_cis
|
||||||
|
|
||||||
def forward(self, q: SparseTensor, k: Optional[SparseTensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
q (SparseTensor): [..., N, H, D] tensor of queries
|
|
||||||
k (SparseTensor): [..., N, H, D] tensor of keys
|
|
||||||
"""
|
|
||||||
assert q.coords.shape[-1] == self.dim + 1, "Last dimension of coords must be equal to dim+1"
|
|
||||||
phases_cache_name = f'rope_phase_{self.dim}d_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}'
|
|
||||||
phases = q.get_spatial_cache(phases_cache_name)
|
|
||||||
if phases is None:
|
|
||||||
coords = q.coords[..., 1:]
|
|
||||||
phases = self._get_phases(coords.reshape(-1)).reshape(*coords.shape[:-1], -1)
|
|
||||||
if phases.shape[-1] < self.head_dim // 2:
|
|
||||||
padn = self.head_dim // 2 - phases.shape[-1]
|
|
||||||
phases = torch.cat([phases, torch.polar(
|
|
||||||
torch.ones(*phases.shape[:-1], padn, device=phases.device),
|
|
||||||
torch.zeros(*phases.shape[:-1], padn, device=phases.device)
|
|
||||||
)], dim=-1)
|
|
||||||
q.register_spatial_cache(phases_cache_name, phases)
|
|
||||||
q_embed = q.replace(self._rotary_embedding(q.feats, phases))
|
|
||||||
if k is None:
|
if k is None:
|
||||||
return q_embed
|
return q.replace(apply_rope1(q.feats, f_cis))
|
||||||
k_embed = k.replace(self._rotary_embedding(k.feats, phases))
|
|
||||||
return q_embed, k_embed
|
q_feats, k_feats = apply_rope(q.feats, k.feats, f_cis)
|
||||||
|
return q.replace(q_feats), k.replace(k_feats)
|
||||||
|
|
||||||
class RotaryPositionEmbedder(SparseRotaryPositionEmbedder):
|
class RotaryPositionEmbedder(SparseRotaryPositionEmbedder):
|
||||||
def forward(self, indices: torch.Tensor) -> torch.Tensor:
|
def forward(self, indices: torch.Tensor) -> torch.Tensor:
|
||||||
assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}"
|
|
||||||
phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
|
phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
|
||||||
if phases.shape[-1] < self.head_dim // 2:
|
if phases.shape[-1] < self.head_dim // 2:
|
||||||
padn = self.head_dim // 2 - phases.shape[-1]
|
padn = self.head_dim // 2 - phases.shape[-1]
|
||||||
@ -228,9 +229,6 @@ class SparseMultiHeadAttention(nn.Module):
|
|||||||
return h
|
return h
|
||||||
|
|
||||||
class ModulatedSparseTransformerBlock(nn.Module):
|
class ModulatedSparseTransformerBlock(nn.Module):
|
||||||
"""
|
|
||||||
Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
|
|
||||||
"""
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
channels: int,
|
channels: int,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user