diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py deleted file mode 100644 index 571101863..000000000 --- a/comfy/ldm/trellis2/attention.py +++ /dev/null @@ -1,47 +0,0 @@ -from comfy.ldm.modules.attention import optimized_attention -from comfy.ldm.trellis2.vae import VarLenTensor - - -def dense_attention(q, k, v, **kwargs): - """q, k, v: [B, L, H, C]. Permutes for comfy's [B, H, L, C] convention.""" - heads = q.shape[2] - q = q.permute(0, 2, 1, 3) - k = k.permute(0, 2, 1, 3) - v = v.permute(0, 2, 1, 3) - out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True, **kwargs) - return out.permute(0, 2, 1, 3) - - -def _to_rect(t): - """Fold a VarLenTensor packed as [sum(L_i), H, C] into a dense [B, L, H, C]. - - The sparse generation stages run a single object per call (optionally - CFG-duplicated, which keeps every batch entry the same length), so the - packed layout is rectangular and attention is ordinary dense attention over - a batch dim — no variable-length kernel needed. A dense [B, L, H, C] tensor - (e.g. cross-attention context) passes through unchanged. - """ - if not isinstance(t, VarLenTensor): - return t - B = t.shape[0] - seqlens = [t.layout[i].stop - t.layout[i].start for i in range(B)] - if len(set(seqlens)) != 1: - raise ValueError( - "trellis2 sparse attention expects equal sequence lengths per batch " - f"(single object, optionally CFG-duplicated); got {seqlens}. " - "Multi-object batching is not supported." - ) - return t.feats.view(B, seqlens[0], *t.feats.shape[1:]) - - -def sparse_attention(q, k, v, **kwargs): - """Full attention over a SparseTensor's voxels. - - Single object (optionally CFG-duplicated) => the packed layout is - rectangular, so we fold it into a batch dim and run ordinary dense - attention. Output type matches q. - """ - out = dense_attention(_to_rect(q), _to_rect(k), _to_rect(v), **kwargs) # [B, Lq, H, C] - if isinstance(q, VarLenTensor): - return q.replace(out.reshape(-1, *out.shape[2:])) - return out diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 3f3751052..2beb389e7 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -3,10 +3,42 @@ import torch.nn.functional as F import torch.nn as nn from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor from typing import Optional, Tuple, Literal, Union, List -from comfy.ldm.trellis2.attention import sparse_attention, dense_attention +from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder from comfy.ldm.flux.math import apply_rope, apply_rope1 + +def dense_attention(q, k, v, **kwargs): + heads = q.shape[2] + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True, **kwargs) + return out.permute(0, 2, 1, 3) + + +def _to_rect(t): + # Single object (optionally CFG-duplicated) => packed layout is rectangular, + # so we can fold it into a batch dim and use dense attention. + if not isinstance(t, VarLenTensor): + return t + B = t.shape[0] + seqlens = [t.layout[i].stop - t.layout[i].start for i in range(B)] + if len(set(seqlens)) != 1: + raise ValueError( + "trellis2 sparse attention expects equal sequence lengths per batch " + f"(single object, optionally CFG-duplicated); got {seqlens}." + ) + return t.feats.view(B, seqlens[0], *t.feats.shape[1:]) + + +def sparse_attention(q, k, v, **kwargs): + out = dense_attention(_to_rect(q), _to_rect(k), _to_rect(v), **kwargs) + if isinstance(q, VarLenTensor): + return q.replace(out.reshape(-1, *out.shape[2:])) + return out + + class SparseGELU(nn.GELU): def forward(self, input: VarLenTensor) -> VarLenTensor: return input.replace(super().forward(input.feats))