From cd0f7ba64e6d0a195fdc53132cc58b053569dacb Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 5 Feb 2026 02:34:08 +0200 Subject: [PATCH] apply rope and optimized attention --- comfy/ldm/trellis2/attention.py | 29 +++++++------ comfy/ldm/trellis2/model.py | 76 ++++++++++++++++----------------- 2 files changed, 52 insertions(+), 53 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 6c912c8d9..edc85ce83 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -26,21 +26,21 @@ def scaled_dot_product_attention(*args, **kwargs): k = args[1] if len(args) > 1 else kwargs['k'] v = args[2] if len(args) > 2 else kwargs['v'] + # number of heads: dim -2 is H for both q [N, L, H, C] and packed qkv [N, L, 3, H, C] + heads = qkv if num_all_args == 1 else q + heads = heads.shape[-2] + if optimized_attention.__name__ == 'attention_xformers': - if 'xops' not in globals(): - import xformers.ops as xops if num_all_args == 1: q, k, v = qkv.unbind(dim=2) elif num_all_args == 2: k, v = kv.unbind(dim=2) - out = xops.memory_efficient_attention(q, k, v) + #out = xops.memory_efficient_attention(q, k, v) + out = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), heads, skip_output_reshape=True, skip_reshape=True).permute(0, 2, 1, 3) elif optimized_attention.__name__ == 'attention_flash' and not FLASH_ATTN_3_AVA: - if 'flash_attn' not in globals(): - import flash_attn if num_all_args == 2: - out = flash_attn.flash_attn_kvpacked_func(q, kv) - elif num_all_args == 3: - out = flash_attn.flash_attn_func(q, k, v) + k, v = kv.unbind(dim=2) + out = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), heads, skip_output_reshape=True, skip_reshape=True).permute(0, 2, 1, 3) elif optimized_attention.__name__ == 'attention_flash': # TODO if 'flash_attn_3' not in globals(): import flash_attn_interface as flash_attn_3 @@ -59,15 +59,14 @@ def scaled_dot_product_attention(*args, **kwargs): q = q.permute(0, 2, 1, 3) # [N, H, L, C] k = k.permute(0, 2, 1, 3) # [N, H, L, C] v = v.permute(0, 2, 1, 3) # [N, H, L, C] - out = sdpa(q, k, v) # [N, H, L, C] + out = 
optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) out = out.permute(0, 2, 1, 3) # [N, L, H, C] elif optimized_attention.__name__ == 'attention_basic': if num_all_args == 1: q, k, v = qkv.unbind(dim=2) elif num_all_args == 2: k, v = kv.unbind(dim=2) - q = q.shape[2] # TODO - out = optimized_attention(q, k, v) + out = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), heads, skip_output_reshape=True, skip_reshape=True).permute(0, 2, 1, 3) return out @@ -86,19 +85,27 @@ def sparse_windowed_scaled_dot_product_self_attention( fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + heads = qkv_feats.shape[2] if optimized_attention.__name__ == 'attention_xformers': - if 'xops' not in globals(): - import xformers.ops as xops q, k, v = qkv_feats.unbind(dim=1) q = q.unsqueeze(0) # [1, M, H, C] k = k.unsqueeze(0) # [1, M, H, C] v = v.unsqueeze(0) # [1, M, H, C] - out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C] + #out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C] + out = optimized_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), heads, skip_output_reshape=True, skip_reshape=True) + out = out[0].transpose(0, 1) # [M, H, C] elif optimized_attention.__name__ == 'attention_flash': if 'flash_attn' not in globals(): import flash_attn out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C] + else: + q, k, v = qkv_feats.unbind(dim=1) # each [M, H, C] + q = q.transpose(0, 1).unsqueeze(0) # [1, H, M, C] + k = k.transpose(0, 1).unsqueeze(0) + v = v.transpose(0, 1).unsqueeze(0) + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) + out = out[0].transpose(0, 1) # [M, H, C] out = out[bwd_indices] # [T, H, C] diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 1dbbc4955..484622d76 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -8,6 +8,7 @@ from comfy.ldm.trellis2.attention import ( ) from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder from comfy.nested_tensor import NestedTensor +from comfy.ldm.flux.math import apply_rope, apply_rope1 class SparseGELU(nn.GELU): def forward(self, input: VarLenTensor) -> VarLenTensor: @@
-52,7 +53,6 @@ class SparseMultiHeadRMSNorm(nn.Module): x = F.normalize(x, dim=-1) * self.gamma * self.scale return x.to(x_type) -# TODO: replace with apply_rope1 class SparseRotaryPositionEmbedder(nn.Module): def __init__( self, @@ -61,7 +61,6 @@ class SparseRotaryPositionEmbedder(nn.Module): rope_freq: Tuple[float, float] = (1.0, 10000.0) ): super().__init__() - assert head_dim % 2 == 0, "Head dim must be divisible by 2" self.head_dim = head_dim self.dim = dim self.rope_freq = rope_freq @@ -69,46 +68,48 @@ class SparseRotaryPositionEmbedder(nn.Module): self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) - def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: - self.freqs = self.freqs.to(indices.device) - phases = torch.outer(indices, self.freqs) - phases = torch.polar(torch.ones_like(phases), phases) - return phases + def _get_freqs_cis(self, coords: torch.Tensor) -> torch.Tensor: + phases_list = [] + for i in range(self.dim): + phases_list.append(torch.outer(coords[..., i], self.freqs.to(coords.device))) - def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: - x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) - x_rotated = x_complex * phases.unsqueeze(-2) - x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) - return x_embed + phases = torch.cat(phases_list, dim=-1) + + if phases.shape[-1] < self.head_dim // 2: + padn = self.head_dim // 2 - phases.shape[-1] + phases = torch.cat([phases, torch.zeros(*phases.shape[:-1], padn, device=phases.device)], dim=-1) + + cos = torch.cos(phases) + sin = torch.sin(phases) + + f_cis_0 = torch.stack([cos, sin], dim=-1) + f_cis_1 = torch.stack([-sin, cos], dim=-1) + freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1) + + return freqs_cis + + def forward(self, q, k=None): + cache_name = f'rope_cis_{self.dim}d_f{self.rope_freq[1]}_hd{self.head_dim}' 
+ freqs_cis = q.get_spatial_cache(cache_name) + + if freqs_cis is None: + coords = q.coords[..., 1:].to(torch.float32) + freqs_cis = self._get_freqs_cis(coords) + q.register_spatial_cache(cache_name, freqs_cis) + + if q.feats.ndim == 3: + f_cis = freqs_cis.unsqueeze(1) + else: + f_cis = freqs_cis - def forward(self, q: SparseTensor, k: Optional[SparseTensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - q (SparseTensor): [..., N, H, D] tensor of queries - k (SparseTensor): [..., N, H, D] tensor of keys - """ - assert q.coords.shape[-1] == self.dim + 1, "Last dimension of coords must be equal to dim+1" - phases_cache_name = f'rope_phase_{self.dim}d_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}' - phases = q.get_spatial_cache(phases_cache_name) - if phases is None: - coords = q.coords[..., 1:] - phases = self._get_phases(coords.reshape(-1)).reshape(*coords.shape[:-1], -1) - if phases.shape[-1] < self.head_dim // 2: - padn = self.head_dim // 2 - phases.shape[-1] - phases = torch.cat([phases, torch.polar( - torch.ones(*phases.shape[:-1], padn, device=phases.device), - torch.zeros(*phases.shape[:-1], padn, device=phases.device) - )], dim=-1) - q.register_spatial_cache(phases_cache_name, phases) - q_embed = q.replace(self._rotary_embedding(q.feats, phases)) if k is None: - return q_embed - k_embed = k.replace(self._rotary_embedding(k.feats, phases)) - return q_embed, k_embed + return q.replace(apply_rope1(q.feats, f_cis)) + + q_feats, k_feats = apply_rope(q.feats, k.feats, f_cis) + return q.replace(q_feats), k.replace(k_feats) class RotaryPositionEmbedder(SparseRotaryPositionEmbedder): def forward(self, indices: torch.Tensor) -> torch.Tensor: - assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}" phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) if phases.shape[-1] < self.head_dim // 2: padn = self.head_dim // 2 - phases.shape[-1] @@ -228,9 +229,6 @@ class 
SparseMultiHeadAttention(nn.Module): return h class ModulatedSparseTransformerBlock(nn.Module): - """ - Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. - """ def __init__( self, channels: int,