mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-04 05:31:03 +08:00
Cleanup model code
This commit is contained in:
parent
3aae4bf741
commit
eec0692bcb
@ -1,47 +0,0 @@
|
|||||||
from comfy.ldm.modules.attention import optimized_attention
|
|
||||||
from comfy.ldm.trellis2.vae import VarLenTensor
|
|
||||||
|
|
||||||
|
|
||||||
def dense_attention(q, k, v, **kwargs):
    """Run comfy's optimized attention on tensors laid out as [B, L, H, C].

    The inputs are permuted to the [B, H, L, C] layout before the call and
    the result is permuted back, so callers see the same layout they passed.
    Extra keyword arguments are forwarded to ``optimized_attention`` as-is.
    """
    # The head count must be read before the axes are swapped.
    heads = q.shape[2]
    q, k, v = (x.permute(0, 2, 1, 3) for x in (q, k, v))
    # Both reshape steps are skipped: the tensors are already split per head.
    attended = optimized_attention(
        q, k, v, heads,
        skip_output_reshape=True,
        skip_reshape=True,
        **kwargs,
    )
    return attended.permute(0, 2, 1, 3)
|
|
||||||
|
|
||||||
|
|
||||||
def _to_rect(t):
    """Reshape a packed VarLenTensor ([sum(L_i), H, C]) into dense [B, L, H, C].

    The sparse generation stages handle one object per call (possibly
    duplicated for CFG, which leaves every batch entry with the same length).
    The packed layout is therefore rectangular, and a plain batch dimension
    is enough; no variable-length attention kernel is required.

    Anything that is not a VarLenTensor (e.g. an already dense
    cross-attention context) is returned untouched.

    Raises:
        ValueError: if the per-batch sequence lengths are not all equal.
    """
    if isinstance(t, VarLenTensor):
        batch = t.shape[0]
        spans = [t.layout[idx] for idx in range(batch)]
        seqlens = [span.stop - span.start for span in spans]
        distinct = set(seqlens)
        if len(distinct) != 1:
            raise ValueError(
                "trellis2 sparse attention expects equal sequence lengths per batch "
                f"(single object, optionally CFG-duplicated); got {seqlens}. "
                "Multi-object batching is not supported."
            )
        # Split the packed leading dim into (batch, length); trailing dims stay.
        trailing = t.feats.shape[1:]
        return t.feats.view(batch, seqlens[0], *trailing)
    return t
|
|
||||||
|
|
||||||
|
|
||||||
def sparse_attention(q, k, v, **kwargs):
    """Full attention across the voxels of a sparse tensor.

    With a single object (optionally CFG-duplicated) the packed layout is
    rectangular, so each operand is folded into a batch dimension and passed
    through ordinary dense attention. The result has the same type as ``q``:
    a VarLenTensor when ``q`` is one, otherwise the dense tensor.
    """
    dense_q = _to_rect(q)
    dense_k = _to_rect(k)
    dense_v = _to_rect(v)
    out = dense_attention(dense_q, dense_k, dense_v, **kwargs)  # [B, Lq, H, C]
    if not isinstance(q, VarLenTensor):
        return out
    # Merge (B, Lq) back into the packed leading dim before re-wrapping.
    packed = out.reshape(-1, *out.shape[2:])
    return q.replace(packed)
|
|
||||||
@ -3,10 +3,42 @@ import torch.nn.functional as F
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor
|
from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor
|
||||||
from typing import Optional, Tuple, Literal, Union, List
|
from typing import Optional, Tuple, Literal, Union, List
|
||||||
from comfy.ldm.trellis2.attention import sparse_attention, dense_attention
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
||||||
from comfy.ldm.flux.math import apply_rope, apply_rope1
|
from comfy.ldm.flux.math import apply_rope, apply_rope1
|
||||||
|
|
||||||
|
|
||||||
|
def dense_attention(q, k, v, **kwargs):
    """Attention over [B, L, H, C] inputs via ``optimized_attention``.

    Axes 1 and 2 are swapped on the way in and on the way out, so the
    returned tensor uses the same axis order as the inputs. Keyword arguments
    are passed through to ``optimized_attention`` unchanged.
    """
    num_heads = q.shape[2]  # read before the permute moves the head axis
    swap = (0, 2, 1, 3)
    q = q.permute(*swap)
    k = k.permute(*swap)
    v = v.permute(*swap)
    result = optimized_attention(
        q, k, v, num_heads,
        skip_output_reshape=True,
        skip_reshape=True,
        **kwargs,
    )
    return result.permute(*swap)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_rect(t):
    """Fold a packed VarLenTensor into a dense tensor with a batch dim.

    A single object (optionally CFG-duplicated) gives a rectangular packed
    layout, so the features can be viewed as [B, L, ...] and fed to dense
    attention. Inputs that are not a VarLenTensor are returned unchanged.

    Raises:
        ValueError: if the batch entries do not all have the same length.
    """
    if isinstance(t, VarLenTensor):
        n_batch = t.shape[0]
        lengths = []
        for idx in range(n_batch):
            span = t.layout[idx]
            lengths.append(span.stop - span.start)
        if len(set(lengths)) != 1:
            raise ValueError(
                "trellis2 sparse attention expects equal sequence lengths per batch "
                f"(single object, optionally CFG-duplicated); got {lengths}."
            )
        # Only the leading packed dim is split; the remaining dims are kept.
        feats = t.feats
        return feats.view(n_batch, lengths[0], *feats.shape[1:])
    return t
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_attention(q, k, v, **kwargs):
    """Attention for packed sparse inputs, implemented with dense attention.

    Each operand goes through ``_to_rect`` first. If ``q`` is a VarLenTensor,
    the dense result is flattened over its first two dims and wrapped with
    ``q.replace``; otherwise the dense result is returned directly.
    """
    rect_q, rect_k, rect_v = _to_rect(q), _to_rect(k), _to_rect(v)
    dense_out = dense_attention(rect_q, rect_k, rect_v, **kwargs)
    if not isinstance(q, VarLenTensor):
        return dense_out
    flat = dense_out.reshape(-1, *dense_out.shape[2:])
    return q.replace(flat)
|
||||||
|
|
||||||
|
|
||||||
class SparseGELU(nn.GELU):
    """GELU that acts on the ``feats`` of a VarLenTensor."""

    def forward(self, input: VarLenTensor) -> VarLenTensor:
        # Apply the stock GELU to the raw features, then re-wrap the result
        # with `replace` so the return value is built from the input wrapper.
        activated = super().forward(input.feats)
        return input.replace(activated)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user