This commit is contained in:
kijai 2026-06-27 00:05:45 +03:00
parent 7bd1fa6e78
commit 288f3cf134
6 changed files with 109 additions and 573 deletions

View File

@ -770,9 +770,10 @@ class Hunyuan3Dv2_1(LatentFormat):
latent_channels = 64 latent_channels = 64
latent_dimensions = 1 latent_dimensions = 1
class Trellis2(LatentFormat): # TODO class Trellis2(LatentFormat):
latent_channels = 32 latent_channels = 32
trellis3d_preview = True # routes the sampler preview to Trellis3DPreviewer
class Hunyuan3Dv2mini(LatentFormat): class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64 latent_channels = 64
latent_dimensions = 1 latent_dimensions = 1

View File

@ -32,11 +32,6 @@ except ImportError as e:
raise e raise e
exit(-1) exit(-1)
try:
from sageattention import sageattn_varlen # sageattention >= 2
except ImportError:
sageattn_varlen = None
SAGE_ATTENTION3_IS_AVAILABLE = False SAGE_ATTENTION3_IS_AVAILABLE = False
try: try:
from sageattn3 import sageattn3_blackwell from sageattn3 import sageattn3_blackwell
@ -53,24 +48,6 @@ except ImportError:
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn") logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1) exit(-1)
try:
from torch.nn.attention.varlen import varlen_attn as _torch_varlen_attn
except ImportError:
_torch_varlen_attn = None
def _is_varlen(kwargs):
"""Varlen mode is opted into by passing cu_seqlens_q in kwargs."""
return kwargs.get("cu_seqlens_q") is not None
def _varlen_args(kwargs):
cu_seqlens_q = kwargs["cu_seqlens_q"]
cu_seqlens_kv = kwargs.get("cu_seqlens_kv", cu_seqlens_q)
max_seqlen_q = int(kwargs["max_seqlen_q"])
max_seqlen_kv = int(kwargs.get("max_seqlen_kv", max_seqlen_q))
return cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
REGISTERED_ATTENTION_FUNCTIONS = {} REGISTERED_ATTENTION_FUNCTIONS = {}
def register_attention_function(name: str, func: Callable): def register_attention_function(name: str, func: Callable):
# avoid replacing existing functions # avoid replacing existing functions
@ -167,8 +144,6 @@ def wrap_attn(func):
@wrap_attn @wrap_attn
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if _is_varlen(kwargs):
return attention_pytorch(q, k, v, heads, mask=mask, **kwargs)
attn_precision = get_attn_precision(attn_precision, q.dtype) attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape: if skip_reshape:
@ -243,8 +218,6 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
@wrap_attn @wrap_attn
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if _is_varlen(kwargs):
return attention_pytorch(query, key, value, heads, mask=mask, **kwargs)
attn_precision = get_attn_precision(attn_precision, query.dtype) attn_precision = get_attn_precision(attn_precision, query.dtype)
if skip_reshape: if skip_reshape:
@ -320,8 +293,6 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
@wrap_attn @wrap_attn
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if _is_varlen(kwargs):
return attention_pytorch(q, k, v, heads, mask=mask, **kwargs)
attn_precision = get_attn_precision(attn_precision, q.dtype) attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape: if skip_reshape:
@ -453,17 +424,6 @@ except:
@wrap_attn @wrap_attn
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if _is_varlen(kwargs):
# q, k, v expected as packed [T_total, H, C]. xformers wants per-item
# seqlen lists and a leading batch dim of 1.
cu_seqlens_q, cu_seqlens_kv, _max_q, _max_kv = _varlen_args(kwargs)
q_seqlen = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).tolist()
kv_seqlen = (cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]).tolist()
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
return xformers.ops.memory_efficient_attention(
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attn_bias=attn_bias,
)[0]
b = q.shape[0] b = q.shape[0]
dim_head = q.shape[-1] dim_head = q.shape[-1]
# check to make sure xformers isn't broken # check to make sure xformers isn't broken
@ -533,22 +493,6 @@ else:
@wrap_attn @wrap_attn
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if _is_varlen(kwargs):
# q, k, v expected as packed [T_total, H, C]. cu_seqlens_q / cu_seqlens_kv
# describe per-item offsets. Native varlen kernel if available, else
# nested-tensor SDPA fallback. mask/attn_precision/etc are ignored here.
cu_seqlens_q, cu_seqlens_kv, max_q, max_kv = _varlen_args(kwargs)
if _torch_varlen_attn is not None:
return _torch_varlen_attn(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q, max_kv)
q_nj = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
k_nj = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_kv.long())
v_nj = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_kv.long())
out = comfy.ops.scaled_dot_product_attention(
q_nj.transpose(1, 2), k_nj.transpose(1, 2), v_nj.transpose(1, 2),
attn_mask=None, dropout_p=0.0, is_causal=False,
)
return out.transpose(1, 2).values()
if skip_reshape: if skip_reshape:
b, _, _, dim_head = q.shape b, _, _, dim_head = q.shape
else: else:
@ -597,13 +541,6 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
@wrap_attn @wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if _is_varlen(kwargs):
# q, k, v expected as packed [T_total, H, C].
if sageattn_varlen is None:
# sageattention v1 has no varlen kernel; fall back to attention_pytorch's varlen path.
return attention_pytorch(q, k, v, heads, mask=mask, **kwargs)
cu_seqlens_q, cu_seqlens_kv, max_q, max_kv = _varlen_args(kwargs)
return sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q, max_kv)
if kwargs.get("low_precision_attention", True) is False: if kwargs.get("low_precision_attention", True) is False:
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs) return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
@ -761,12 +698,6 @@ except AttributeError as error:
@wrap_attn @wrap_attn
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if _is_varlen(kwargs):
# q, k, v expected as packed [T_total, H, C].
from flash_attn import flash_attn_varlen_func
cu_seqlens_q, cu_seqlens_kv, max_q, max_kv = _varlen_args(kwargs)
return flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q, max_kv)
if skip_reshape: if skip_reshape:
b, _, _, dim_head = q.shape b, _, _, dim_head = q.shape
else: else:

View File

@ -1,33 +1,7 @@
import math
from typing import Tuple, Union
import torch
from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.trellis2.vae import VarLenTensor from comfy.ldm.trellis2.vae import VarLenTensor
def _build_cu_seqlens(seqlen, device):
"""Cumulative-offset tensor from a list/tensor of per-item sequence lengths."""
if isinstance(seqlen, torch.Tensor):
return torch.cat(
[torch.zeros(1, dtype=torch.int32, device=seqlen.device), torch.cumsum(seqlen, dim=0).int()],
dim=0,
).to(device)
return torch.cat(
[torch.tensor([0]), torch.cumsum(torch.tensor(seqlen), dim=0)],
).int().to(device)
def _layout_to_feats(t):
"""Returns (sparse_template_or_None, seqlen, feats[T, H, C])."""
if isinstance(t, VarLenTensor):
seqlen = [t.layout[i].stop - t.layout[i].start for i in range(t.shape[0])]
return t, seqlen, t.feats
N, L = t.shape[:2]
return None, [L] * N, t.reshape(-1, *t.shape[-2:])
def dense_attention(q, k, v, **kwargs): def dense_attention(q, k, v, **kwargs):
"""q, k, v: [B, L, H, C]. Permutes for comfy's [B, H, L, C] convention.""" """q, k, v: [B, L, H, C]. Permutes for comfy's [B, H, L, C] convention."""
heads = q.shape[2] heads = q.shape[2]
@ -38,86 +12,36 @@ def dense_attention(q, k, v, **kwargs):
return out.permute(0, 2, 1, 3) return out.permute(0, 2, 1, 3)
def _to_rect(t):
"""Fold a VarLenTensor packed as [sum(L_i), H, C] into a dense [B, L, H, C].
The sparse generation stages run a single object per call (optionally
CFG-duplicated, which keeps every batch entry the same length), so the
packed layout is rectangular and attention is ordinary dense attention over
a batch dim no variable-length kernel needed. A dense [B, L, H, C] tensor
(e.g. cross-attention context) passes through unchanged.
"""
if not isinstance(t, VarLenTensor):
return t
B = t.shape[0]
seqlens = [t.layout[i].stop - t.layout[i].start for i in range(B)]
if len(set(seqlens)) != 1:
raise ValueError(
"trellis2 sparse attention expects equal sequence lengths per batch "
f"(single object, optionally CFG-duplicated); got {seqlens}. "
"Multi-object batching is not supported."
)
return t.feats.view(B, seqlens[0], *t.feats.shape[1:])
def sparse_attention(q, k, v, **kwargs): def sparse_attention(q, k, v, **kwargs):
"""Full attention over a SparseTensor's voxels.
Single object (optionally CFG-duplicated) => the packed layout is
rectangular, so we fold it into a batch dim and run ordinary dense
attention. Output type matches q.
""" """
Varlen attention for SparseTensor inputs. Each of q, k, v may be a VarLenTensor out = dense_attention(_to_rect(q), _to_rect(k), _to_rect(v), **kwargs) # [B, Lq, H, C]
(sparse) or dense [B, L, H, C]. Output type matches q. Backend dispatch lives if isinstance(q, VarLenTensor):
in comfy.ldm.modules.attention.optimized_attention; we just build cu_seqlens return q.replace(out.reshape(-1, *out.shape[2:]))
from the layouts. return out
"""
s, q_seqlen, q_feats = _layout_to_feats(q)
_, kv_seqlen, k_feats = _layout_to_feats(k)
_, _, v_feats = _layout_to_feats(v)
heads = q_feats.shape[1]
device = q_feats.device
cu_seqlens_q = _build_cu_seqlens(q_seqlen, device)
cu_seqlens_kv = _build_cu_seqlens(kv_seqlen, device)
out = optimized_attention(
q_feats, k_feats, v_feats, heads,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max(q_seqlen), max_seqlen_kv=max(kv_seqlen),
skip_reshape=True, skip_output_reshape=True,
**kwargs,
)
if s is not None:
return s.replace(out)
N, L = q.shape[:2]
return out.reshape(N, L, heads, -1)
def sparse_windowed_self_attention(qkv, window_size: int, shift_window: Tuple[int, int, int] = (0, 0, 0)):
"""Windowed sparse self-attention. Partitions voxels into windows via spatial
sort, then runs varlen attention with one sequence per non-empty window."""
cache_name = f'windowed_attention_{window_size}_{shift_window}'
cache = qkv.get_spatial_cache(cache_name)
if cache is None:
cache = calc_window_partition(qkv, window_size, shift_window)
qkv.register_spatial_cache(cache_name, cache)
fwd_indices, bwd_indices, seq_lens = cache
qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
q, k, v = qkv_feats.unbind(dim=1) # each [M, H, C]
heads = q.shape[1]
device = q.device
cu_seqlens = _build_cu_seqlens(seq_lens, device)
max_seqlen = int(seq_lens.max())
out = optimized_attention(
q, k, v, heads,
cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
max_seqlen_q=max_seqlen, max_seqlen_kv=max_seqlen,
skip_reshape=True, skip_output_reshape=True,
)
out = out[bwd_indices]
return qkv.replace(out)
def calc_window_partition(
tensor,
window_size: Union[int, Tuple[int, ...]],
shift_window: Union[int, Tuple[int, ...]] = 0,
):
"""Returns (fwd_indices, bwd_indices, seq_lens) for window partitioning."""
DIM = tensor.coords.shape[1] - 1
shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
shifted_coords = tensor.coords.clone().detach()
shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)
MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)]
NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
fwd_indices = torch.argsort(shifted_indices)
bwd_indices = torch.empty_like(fwd_indices)
bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
seq_lens = torch.bincount(shifted_indices)
seq_lens = seq_lens[seq_lens != 0]
return fwd_indices, bwd_indices, seq_lens

View File

@ -3,8 +3,6 @@ import torch
import comfy.model_management import comfy.model_management
UINT32_SENTINEL = 0xFFFFFFFF
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device): def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
"""Kernel spatial offsets in the same order as the CUDA/Triton kernels.""" """Kernel spatial offsets in the same order as the CUDA/Triton kernels."""
@ -19,11 +17,9 @@ def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
class TorchHashMap: class TorchHashMap:
"""Sorted-array hashmap backed by torch.searchsorted.""" """Sorted-array hashmap backed by torch.searchsorted."""
def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int): def __init__(self, keys: torch.Tensor, values: torch.Tensor):
device = keys.device
self.sorted_keys, order = torch.sort(keys.to(torch.long)) self.sorted_keys, order = torch.sort(keys.to(torch.long))
self.sorted_vals = values.to(torch.long)[order] self.sorted_vals = values.to(torch.long)[order]
self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
self._n = self.sorted_keys.numel() self._n = self.sorted_keys.numel()
# Chunk size for lookup_flat, caps each transient to ~CHUNK rows. # Chunk size for lookup_flat, caps each transient to ~CHUNK rows.
@ -54,65 +50,28 @@ def build_submanifold_neighbor_map(
Kw, Kh, Kd, Kw, Kh, Kd,
Dw, Dh, Dd, Dw, Dh, Dd,
): ):
# neighbor[i, v] = index of the voxel at voxel i's coord + kernel-offset v, or -1.
# Chunked over voxels so the [chunk, V, 3] candidate transient stays bounded.
device = coords.device device = coords.device
M = coords.shape[0] M = coords.shape[0]
V = Kw * Kh * Kd offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device).long() # [V, 3]
half_V = V // 2 + 1 V = offsets.shape[0]
INVALID = -1 center = torch.tensor([(Kw // 2) * Dw, (Kh // 2) * Dh, (Kd // 2) * Dd], device=device)
WHD, HD = W * H * D, H * D
# int32 neighbour map: 4 bytes/elem vs 8 bytes for int64 neighbor = torch.empty((M, V), dtype=torch.int32, device=device)
neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.int32) # ~V*40 bytes/voxel of transient (int64 cand + flat + masks); cap at ~0.5 GB.
chunk = max(1, min(M, int(0.5 * (1024 ** 3) / (V * 40))))
b = coords[:, 0].long()
x = coords[:, 1].long()
y = coords[:, 2].long()
z = coords[:, 3].long()
offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device)
ox = x - (Kw // 2) * Dw
oy = y - (Kh // 2) * Dh
oz = z - (Kd // 2) * Dd
for v in range(half_V):
if v == half_V - 1:
# Center voxel always maps to itself
neighbor[:, v] = torch.arange(M, device=device, dtype=torch.int32)
continue
dx, dy, dz = offsets[v]
kx = ox + dx
ky = oy + dy
kz = oz + dz
valid = (
(kx >= 0) & (kx < W) &
(ky >= 0) & (ky < H) &
(kz >= 0) & (kz < D)
)
flat = (
b[valid] * (W * H * D) +
kx[valid] * (H * D) +
ky[valid] * D +
kz[valid]
)
if flat.numel() > 0:
found = hashmap.lookup_flat(flat)
idx_in_M = torch.where(valid)[0]
neighbor[idx_in_M, v] = found.to(torch.int32)
# BUG FIX: old code used found != hashmap.default_value which
# compared int32 -1 against int64 4294967295 → always True.
# We now explicitly check for valid indices.
valid_found_mask = found >= 0
if valid_found_mask.any():
src_points = idx_in_M[valid_found_mask]
dst_points = found[valid_found_mask].long()
neighbor[dst_points, V - 1 - v] = src_points.to(torch.int32)
for s in range(0, M, chunk):
e = min(s + chunk, M)
b = coords[s:e, 0].long()
cand = coords[s:e, 1:4].long()[:, None, :] + offsets[None, :, :] - center # [c, V, 3]
x, y, z = cand[..., 0], cand[..., 1], cand[..., 2]
in_bounds = (x >= 0) & (x < W) & (y >= 0) & (y < H) & (z >= 0) & (z < D) # [c, V]
flat = b[:, None] * WHD + x * HD + y * D + z # [c, V]
flat = torch.where(in_bounds, flat, torch.full_like(flat, -1)) # OOB -> guaranteed miss
neighbor[s:e] = hashmap.lookup_flat(flat.reshape(-1)).view(e - s, V)
return neighbor return neighbor
def get_recommended_chunk_mem( def get_recommended_chunk_mem(
@ -155,7 +114,7 @@ def sparse_submanifold_conv3d(
coords[:, 2].long() * y_stride + coords[:, 2].long() * y_stride +
coords[:, 3].long() * z_stride) coords[:, 3].long() * z_stride)
vals = torch.arange(coords.shape[0], dtype=torch.int32, device=device) vals = torch.arange(coords.shape[0], dtype=torch.int32, device=device)
hashmap = TorchHashMap(flat_keys, vals, UINT32_SENTINEL) hashmap = TorchHashMap(flat_keys, vals)
neighbor = build_submanifold_neighbor_map( neighbor = build_submanifold_neighbor_map(
hashmap, coords, W, H, D, Kw, Kh, Kd, hashmap, coords, W, H, D, Kw, Kh, Kd,
@ -165,13 +124,15 @@ def sparse_submanifold_conv3d(
neighbor = neighbor_cache neighbor = neighbor_cache
N_pts = feats.shape[0] N_pts = feats.shape[0]
sentinel = -1
weight_T = weight.view(Co, V * Ci).T weight_T = weight.view(Co, V * Ci).T
output = torch.empty(N_pts, Co, device=device, dtype=feats.dtype) output = torch.empty(N_pts, Co, device=device, dtype=feats.dtype)
# Chunk size from memory budget. The dominant peak is `gathered`, of shape (chunk, V, Ci) in feats.dtype. # Zero row at index N_pts; missing neighbors (-1) gather it -> no separate masking.
feats_padded = torch.cat([feats, feats.new_zeros(1, Ci)], dim=0)
# Chunk over voxels to bound the (chunk, V, Ci) gather.
max_chunk_mem_gb = get_recommended_chunk_mem(device) max_chunk_mem_gb = get_recommended_chunk_mem(device)
mem_per_row = V * Ci * feats.element_size() mem_per_row = V * Ci * feats.element_size()
max_chunk_mem = max_chunk_mem_gb * (1024 ** 3) max_chunk_mem = max_chunk_mem_gb * (1024 ** 3)
@ -182,18 +143,10 @@ def sparse_submanifold_conv3d(
end = min(start + chunk_size, N_pts) end = min(start + chunk_size, N_pts)
actual_chunk = end - start actual_chunk = end - start
chunk_neighbor = neighbor[start:end] chunk_idx = torch.where(neighbor[start:end] < 0, N_pts, neighbor[start:end]) # -1 -> zero row
chunk_valid = chunk_neighbor != sentinel gathered = feats_padded[chunk_idx] # (chunk, V, Ci)
# clamp(-1 -> 0) keeps invalid indices in-range so the gather is safe
chunk_idx = chunk_neighbor.clamp(min=0)
# (chunk, V, Ci) gather, then in-place zero of invalid neighbors.
gathered = feats[chunk_idx]
gathered.mul_(chunk_valid.unsqueeze(-1))
# GEMM (chunk, V*Ci) @ (V*Ci, Co) -> (chunk, Co), written to output[start:end].
gathered_flat = gathered.view(actual_chunk, V * Ci) gathered_flat = gathered.view(actual_chunk, V * Ci)
torch.matmul(gathered_flat, weight_T, out=output[start:end]) torch.matmul(gathered_flat, weight_T, out=output[start:end]) # (chunk, V*Ci) @ (V*Ci, Co)
if bias is not None: if bias is not None:
output += bias.unsqueeze(0).to(output.dtype) output += bias.unsqueeze(0).to(output.dtype)

View File

@ -3,12 +3,9 @@ import torch.nn.functional as F
import torch.nn as nn import torch.nn as nn
from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor
from typing import Optional, Tuple, Literal, Union, List from typing import Optional, Tuple, Literal, Union, List
from comfy.ldm.trellis2.attention import ( from comfy.ldm.trellis2.attention import sparse_attention, dense_attention
sparse_windowed_self_attention, sparse_attention, dense_attention
)
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
from comfy.ldm.flux.math import apply_rope, apply_rope1 from comfy.ldm.flux.math import apply_rope, apply_rope1
from comfy.ldm.trellis2 import sampling_preview
class SparseGELU(nn.GELU): class SparseGELU(nn.GELU):
def forward(self, input: VarLenTensor) -> VarLenTensor: def forward(self, input: VarLenTensor) -> VarLenTensor:
@ -26,13 +23,6 @@ class SparseFeedForwardNet(nn.Module):
def forward(self, x: VarLenTensor) -> VarLenTensor: def forward(self, x: VarLenTensor) -> VarLenTensor:
return self.mlp(x) return self.mlp(x)
class LayerNorm32(nn.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_dtype = x.dtype
x = x.to(dtype=torch.float32)
o = super().forward(x)
return o.to(dtype=x_dtype)
class SparseMultiHeadRMSNorm(nn.Module): class SparseMultiHeadRMSNorm(nn.Module):
def __init__(self, dim: int, heads: int, device, dtype): def __init__(self, dim: int, heads: int, device, dtype):
super().__init__() super().__init__()
@ -44,13 +34,7 @@ class SparseMultiHeadRMSNorm(nn.Module):
return F.rms_norm(x, (x.shape[-1],)) * self.gamma return F.rms_norm(x, (x.shape[-1],)) * self.gamma
class SparseRotaryPositionEmbedder(nn.Module): class SparseRotaryPositionEmbedder(nn.Module):
def __init__( def __init__(self, head_dim: int, dim: int = 3, rope_freq: Tuple[float, float] = (1.0, 10000.0), device=None):
self,
head_dim: int,
dim: int = 3,
rope_freq: Tuple[float, float] = (1.0, 10000.0),
device=None
):
super().__init__() super().__init__()
self.head_dim = head_dim self.head_dim = head_dim
self.dim = dim self.dim = dim
@ -111,12 +95,7 @@ class SparseMultiHeadAttention(nn.Module):
num_heads: int, num_heads: int,
ctx_channels: Optional[int] = None, ctx_channels: Optional[int] = None,
type: Literal["self", "cross"] = "self", type: Literal["self", "cross"] = "self",
attn_mode: Literal["full", "windowed", "double_windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
qkv_bias: bool = True, qkv_bias: bool = True,
use_rope: bool = False,
rope_freq: Tuple[int, int] = (1.0, 10000.0),
qk_rms_norm: bool = False, qk_rms_norm: bool = False,
device=None, dtype=None, operations=None device=None, dtype=None, operations=None
): ):
@ -127,10 +106,6 @@ class SparseMultiHeadAttention(nn.Module):
self.ctx_channels = ctx_channels if ctx_channels is not None else channels self.ctx_channels = ctx_channels if ctx_channels is not None else channels
self.num_heads = num_heads self.num_heads = num_heads
self._type = type self._type = type
self.attn_mode = attn_mode
self.window_size = window_size
self.shift_window = shift_window
self.use_rope = use_rope
self.qk_rms_norm = qk_rms_norm self.qk_rms_norm = qk_rms_norm
if self._type == "self": if self._type == "self":
@ -145,8 +120,8 @@ class SparseMultiHeadAttention(nn.Module):
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype) self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
if use_rope: if self._type == "self":
self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq, device=device) self.rope = SparseRotaryPositionEmbedder(self.head_dim, device=device)
@staticmethod @staticmethod
def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
@ -170,43 +145,16 @@ class SparseMultiHeadAttention(nn.Module):
x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats
def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None, def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None, transformer_options=None) -> SparseTensor:
transformer_options=None) -> SparseTensor:
if self._type == "self": if self._type == "self":
qkv = self._linear(self.to_qkv, x) qkv = self._linear(self.to_qkv, x)
qkv = self._fused_pre(qkv, num_fused=3) qkv = self._fused_pre(qkv, num_fused=3)
if self.attn_mode == "full": q, k, v = qkv.unbind(dim=-3)
q, k, v = qkv.unbind(dim=-3) if self.qk_rms_norm:
if self.qk_rms_norm: q = self.q_rms_norm(q)
q = self.q_rms_norm(q) k = self.k_rms_norm(k)
k = self.k_rms_norm(k) q, k = self.rope(q, k)
if self.use_rope: h = sparse_attention(q, k, v, transformer_options=transformer_options)
q, k = self.rope(q, k)
h = sparse_attention(q, k, v, transformer_options=transformer_options)
else:
# Windowed paths take packed qkv; preserve any per-head norm/rope.
if self.qk_rms_norm or self.use_rope:
q, k, v = qkv.unbind(dim=-3)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k = self.k_rms_norm(k)
if self.use_rope:
q, k = self.rope(q, k)
qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
if self.attn_mode == "windowed":
h = sparse_windowed_self_attention(
qkv, self.window_size, shift_window=self.shift_window
)
elif self.attn_mode == "double_windowed":
qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:])
qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2])
h0 = sparse_windowed_self_attention(
qkv0, self.window_size, shift_window=(0, 0, 0)
)
h1 = sparse_windowed_self_attention(
qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3)
)
h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1))
else: else:
q = self._linear(self.to_q, x) q = self._linear(self.to_q, x)
q = self._reshape_chs(q, (self.num_heads, -1)) q = self._reshape_chs(q, (self.num_heads, -1))
@ -276,37 +224,25 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
ctx_channels: int, ctx_channels: int,
num_heads: int, num_heads: int,
mlp_ratio: float = 4.0, mlp_ratio: float = 4.0,
attn_mode: Literal["full", "swin"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
rope_freq: Tuple[float, float] = (1.0, 10000.0),
qk_rms_norm: bool = False, qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False, qk_rms_norm_cross: bool = False,
qkv_bias: bool = True, qkv_bias: bool = True,
share_mod: bool = False, share_mod: bool = False,
image_attn_mode: Literal["global", "proj", "gated_proj"] = "global", image_attn_mode: Literal["global", "proj"] = "global",
proj_in_channels: Optional[int] = None, proj_in_channels: Optional[int] = None,
device=None, dtype=None, operations=None device=None, dtype=None, operations=None
): ):
super().__init__() super().__init__()
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod self.share_mod = share_mod
self.image_attn_mode = image_attn_mode self.image_attn_mode = image_attn_mode
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device) self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
self.self_attn = SparseMultiHeadAttention( self.self_attn = SparseMultiHeadAttention(
channels, channels,
num_heads=num_heads, num_heads=num_heads,
type="self", type="self",
attn_mode=attn_mode,
window_size=window_size,
shift_window=shift_window,
qkv_bias=qkv_bias, qkv_bias=qkv_bias,
use_rope=use_rope,
rope_freq=rope_freq,
qk_rms_norm=qk_rms_norm, qk_rms_norm=qk_rms_norm,
device=device, dtype=dtype, operations=operations device=device, dtype=dtype, operations=operations
) )
@ -315,7 +251,6 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
ctx_channels=ctx_channels, ctx_channels=ctx_channels,
num_heads=num_heads, num_heads=num_heads,
type="cross", type="cross",
attn_mode="full",
qkv_bias=qkv_bias, qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross, qk_rms_norm=qk_rms_norm_cross,
device=device, dtype=dtype, operations=operations device=device, dtype=dtype, operations=operations
@ -347,8 +282,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
else: else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
# Fuse the (mul + add) and (mul + residual) pairs into addcmul so the # Fuse the (mul + add) and (mul + residual) pairs into addcmul
# mod/shift broadcasts hit one kernel each instead of two.
b_map = x.batch_boardcast_map b_map = x.batch_boardcast_map
h_feats = self.norm1(x.feats) h_feats = self.norm1(x.feats)
@ -386,18 +320,12 @@ class SLatFlowModel(nn.Module):
num_heads: Optional[int] = None, num_heads: Optional[int] = None,
num_head_channels: Optional[int] = 64, num_head_channels: Optional[int] = 64,
mlp_ratio: float = 4, mlp_ratio: float = 4,
pe_mode: Literal["ape", "rope"] = "rope",
rope_freq: Tuple[float, float] = (1.0, 10000.0),
use_checkpoint: bool = False,
share_mod: bool = False, share_mod: bool = False,
initialization: str = 'vanilla',
qk_rms_norm: bool = False, qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False, qk_rms_norm_cross: bool = False,
image_attn_mode: Literal["global", "proj", "gated_proj"] = "global", image_attn_mode: Literal["global", "proj"] = "global",
proj_in_channels: Optional[int] = None, proj_in_channels: Optional[int] = None,
dtype = None, dtype = None, device = None, operations = None,
device = None,
operations = None,
): ):
super().__init__() super().__init__()
self.resolution = resolution self.resolution = resolution
@ -408,10 +336,7 @@ class SLatFlowModel(nn.Module):
self.num_blocks = num_blocks self.num_blocks = num_blocks
self.num_heads = num_heads or model_channels // num_head_channels self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio self.mlp_ratio = mlp_ratio
self.pe_mode = pe_mode
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod self.share_mod = share_mod
self.initialization = initialization
self.qk_rms_norm = qk_rms_norm self.qk_rms_norm = qk_rms_norm
self.qk_rms_norm_cross = qk_rms_norm_cross self.qk_rms_norm_cross = qk_rms_norm_cross
self.image_attn_mode = image_attn_mode self.image_attn_mode = image_attn_mode
@ -433,10 +358,6 @@ class SLatFlowModel(nn.Module):
cond_channels, cond_channels,
num_heads=self.num_heads, num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio, mlp_ratio=self.mlp_ratio,
attn_mode='full',
use_checkpoint=self.use_checkpoint,
use_rope=(pe_mode == "rope"),
rope_freq=rope_freq,
share_mod=self.share_mod, share_mod=self.share_mod,
qk_rms_norm=self.qk_rms_norm, qk_rms_norm=self.qk_rms_norm,
qk_rms_norm_cross=self.qk_rms_norm_cross, qk_rms_norm_cross=self.qk_rms_norm_cross,
@ -491,14 +412,22 @@ class FeedForwardNet(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.mlp(x) return self.mlp(x)
# class MultiHeadRMSNorm(nn.Module):
# def __init__(self, dim: int, heads: int, device=None, dtype=None):
# super().__init__()
# self.scale = dim ** 0.5
# self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
# def forward(self, x: torch.Tensor) -> torch.Tensor:
# return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
class MultiHeadRMSNorm(nn.Module): class MultiHeadRMSNorm(nn.Module):
def __init__(self, dim: int, heads: int, device=None, dtype=None): def __init__(self, dim: int, heads: int, device=None, dtype=None):
super().__init__() super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype)) self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) return (F.rms_norm(x.float(), (x.shape[-1],)) * self.gamma).to(x.dtype)
class MultiHeadAttention(nn.Module): class MultiHeadAttention(nn.Module):
@ -508,12 +437,7 @@ class MultiHeadAttention(nn.Module):
num_heads: int, num_heads: int,
ctx_channels: Optional[int]=None, ctx_channels: Optional[int]=None,
type: Literal["self", "cross"] = "self", type: Literal["self", "cross"] = "self",
attn_mode: Literal["full", "windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
qkv_bias: bool = True, qkv_bias: bool = True,
use_rope: bool = False,
rope_freq: Tuple[float, float] = (1.0, 10000.0),
qk_rms_norm: bool = False, qk_rms_norm: bool = False,
device=None, dtype=None, operations=None device=None, dtype=None, operations=None
): ):
@ -524,10 +448,6 @@ class MultiHeadAttention(nn.Module):
self.ctx_channels = ctx_channels if ctx_channels is not None else channels self.ctx_channels = ctx_channels if ctx_channels is not None else channels
self.num_heads = num_heads self.num_heads = num_heads
self._type = type self._type = type
self.attn_mode = attn_mode
self.window_size = window_size
self.shift_window = shift_window
self.use_rope = use_rope
self.qk_rms_norm = qk_rms_norm self.qk_rms_norm = qk_rms_norm
if self._type == "self": if self._type == "self":
@ -552,12 +472,11 @@ class MultiHeadAttention(nn.Module):
if self.qk_rms_norm: if self.qk_rms_norm:
q = self.q_rms_norm(q) q = self.q_rms_norm(q)
k = self.k_rms_norm(k) k = self.k_rms_norm(k)
if self.use_rope: assert phases is not None, "Phases must be provided for RoPE"
assert phases is not None, "Phases must be provided for RoPE" # phases is [L, head_dim/2, 2, 2]; broadcast to [1, L, 1, ...]
# phases is [L, head_dim/2, 2, 2]; broadcast to [1, L, 1, ...] # to align with q/k of shape [B, L, H, head_dim].
# to align with q/k of shape [B, L, H, head_dim]. f_cis = phases.unsqueeze(0).unsqueeze(2)
f_cis = phases.unsqueeze(0).unsqueeze(2) q, k = apply_rope(q, k, f_cis)
q, k = apply_rope(q, k, f_cis)
h = dense_attention(q, k, v, transformer_options=transformer_options) h = dense_attention(q, k, v, transformer_options=transformer_options)
else: else:
Lkv = context.shape[1] Lkv = context.shape[1]
@ -581,37 +500,25 @@ class ModulatedTransformerCrossBlock(nn.Module):
ctx_channels: int, ctx_channels: int,
num_heads: int, num_heads: int,
mlp_ratio: float = 4.0, mlp_ratio: float = 4.0,
attn_mode: Literal["full", "windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
rope_freq: Tuple[int, int] = (1.0, 10000.0),
qk_rms_norm: bool = False, qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False, qk_rms_norm_cross: bool = False,
qkv_bias: bool = True, qkv_bias: bool = True,
share_mod: bool = False, share_mod: bool = False,
image_attn_mode: Literal["global", "proj", "gated_proj"] = "global", image_attn_mode: Literal["global", "proj"] = "global",
proj_in_channels: Optional[int] = None, proj_in_channels: Optional[int] = None,
device=None, dtype=None, operations=None device=None, dtype=None, operations=None
): ):
super().__init__() super().__init__()
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod self.share_mod = share_mod
self.image_attn_mode = image_attn_mode self.image_attn_mode = image_attn_mode
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device) self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
self.self_attn = MultiHeadAttention( self.self_attn = MultiHeadAttention(
channels, channels,
num_heads=num_heads, num_heads=num_heads,
type="self", type="self",
attn_mode=attn_mode,
window_size=window_size,
shift_window=shift_window,
qkv_bias=qkv_bias, qkv_bias=qkv_bias,
use_rope=use_rope,
rope_freq=rope_freq,
qk_rms_norm=qk_rms_norm, qk_rms_norm=qk_rms_norm,
device=device, dtype=dtype, operations=operations device=device, dtype=dtype, operations=operations
) )
@ -620,7 +527,6 @@ class ModulatedTransformerCrossBlock(nn.Module):
ctx_channels=ctx_channels, ctx_channels=ctx_channels,
num_heads=num_heads, num_heads=num_heads,
type="cross", type="cross",
attn_mode="full",
qkv_bias=qkv_bias, qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross, qk_rms_norm=qk_rms_norm_cross,
device=device, dtype=dtype, operations=operations device=device, dtype=dtype, operations=operations
@ -640,25 +546,17 @@ class ModulatedTransformerCrossBlock(nn.Module):
device=device, dtype=dtype, operations=operations device=device, dtype=dtype, operations=operations
) )
if not share_mod: if not share_mod:
self.adaLN_modulation = nn.Sequential( self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
nn.SiLU(),
operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device)
)
else: else:
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5) self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context, def _forward(self, x: torch.Tensor, mod: torch.Tensor, context,
phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor: phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor:
if self.share_mod: if self.share_mod:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) mod = (self.modulation + mod).type(mod.dtype)
else: else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) mod = self.adaLN_modulation(mod)
shift_msa = shift_msa.unsqueeze(1) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(6, dim=-1)
scale_msa = scale_msa.unsqueeze(1)
gate_msa = gate_msa.unsqueeze(1)
shift_mlp = shift_mlp.unsqueeze(1)
scale_mlp = scale_mlp.unsqueeze(1)
gate_mlp = gate_mlp.unsqueeze(1)
h = torch.addcmul(shift_msa, self.norm1(x), 1 + scale_msa) h = torch.addcmul(shift_msa, self.norm1(x), 1 + scale_msa)
h = self.self_attn(h, phases=phases, transformer_options=transformer_options) h = self.self_attn(h, phases=phases, transformer_options=transformer_options)
@ -694,14 +592,10 @@ class SparseStructureFlowModel(nn.Module):
num_heads: Optional[int] = None, num_heads: Optional[int] = None,
num_head_channels: Optional[int] = 64, num_head_channels: Optional[int] = 64,
mlp_ratio: float = 4, mlp_ratio: float = 4,
pe_mode: Literal["ape", "rope"] = "rope",
rope_freq: Tuple[float, float] = (1.0, 10000.0),
use_checkpoint: bool = False,
share_mod: bool = False, share_mod: bool = False,
initialization: str = 'vanilla',
qk_rms_norm: bool = False, qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False, qk_rms_norm_cross: bool = False,
image_attn_mode: Literal["global", "proj", "gated_proj"] = "global", image_attn_mode: Literal["global", "proj"] = "global",
proj_in_channels: Optional[int] = None, proj_in_channels: Optional[int] = None,
operations=None, operations=None,
device = None, device = None,
@ -718,10 +612,7 @@ class SparseStructureFlowModel(nn.Module):
self.num_blocks = num_blocks self.num_blocks = num_blocks
self.num_heads = num_heads or model_channels // num_head_channels self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio self.mlp_ratio = mlp_ratio
self.pe_mode = pe_mode
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod self.share_mod = share_mod
self.initialization = initialization
self.qk_rms_norm = qk_rms_norm self.qk_rms_norm = qk_rms_norm
self.qk_rms_norm_cross = qk_rms_norm_cross self.qk_rms_norm_cross = qk_rms_norm_cross
self.image_attn_mode = image_attn_mode self.image_attn_mode = image_attn_mode
@ -742,9 +633,6 @@ class SparseStructureFlowModel(nn.Module):
rope_phases = pos_embedder(coords) rope_phases = pos_embedder(coords)
self.register_buffer("rope_phases", rope_phases, persistent=False) self.register_buffer("rope_phases", rope_phases, persistent=False)
if pe_mode != "rope":
self.rope_phases = None
self.input_layer = operations.Linear(in_channels, model_channels, device=device, dtype=dtype) self.input_layer = operations.Linear(in_channels, model_channels, device=device, dtype=dtype)
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList([
@ -753,10 +641,6 @@ class SparseStructureFlowModel(nn.Module):
cond_channels, cond_channels,
num_heads=self.num_heads, num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio, mlp_ratio=self.mlp_ratio,
attn_mode='full',
use_checkpoint=self.use_checkpoint,
use_rope=(pe_mode == "rope"),
rope_freq=rope_freq,
share_mod=share_mod, share_mod=share_mod,
qk_rms_norm=self.qk_rms_norm, qk_rms_norm=self.qk_rms_norm,
qk_rms_norm_cross=self.qk_rms_norm_cross, qk_rms_norm_cross=self.qk_rms_norm_cross,
@ -788,18 +672,10 @@ class SparseStructureFlowModel(nn.Module):
return h return h
def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0):
t_shifted = t_shifted / 1000.0
t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1))
t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear)
t_new *= 1000.0
return t_new
# Pixal3D ProjGrid math
# Pixal3D ProjGrid math — port of upstream's ProjGrid + project_points_to_image_batch. # World frame uses world Y as depth, camera looks along -Z local
# World frame uses world Y as depth, camera looks along -Z local; # transform_matrix is camera-to-world (inverted internally). Intrinsics: fx = 16 / tan(fov/2) with sensor_width = 32mm.
# transform_matrix is camera-to-world (inverted internally). Intrinsics: fx = 16 / tan(fov/2)
# with sensor_width = 32mm.
_PROJ_GRID_ROTATION = torch.tensor( _PROJ_GRID_ROTATION = torch.tensor(
[[1.0, 0.0, 0.0], [[1.0, 0.0, 0.0],
@ -815,7 +691,7 @@ _PROJ_FRONT_VIEW_TRANSFORM = torch.tensor(
) )
def _build_proj_transform_matrix(distance: torch.Tensor, batch_size: int, def build_proj_transform_matrix(distance: torch.Tensor, batch_size: int,
device, dtype=torch.float32) -> torch.Tensor: device, dtype=torch.float32) -> torch.Tensor:
T = _PROJ_FRONT_VIEW_TRANSFORM.to(device=device, dtype=dtype) T = _PROJ_FRONT_VIEW_TRANSFORM.to(device=device, dtype=dtype)
T = T.unsqueeze(0).expand(batch_size, -1, -1).clone() T = T.unsqueeze(0).expand(batch_size, -1, -1).clone()
@ -849,8 +725,7 @@ def _project_points_to_image(points_world: torch.Tensor, transform_matrix: torch
def _sample_features(feature_map: torch.Tensor, uv_ndc: torch.Tensor) -> torch.Tensor: def _sample_features(feature_map: torch.Tensor, uv_ndc: torch.Tensor) -> torch.Tensor:
B, C, _, _ = feature_map.shape B, C, _, _ = feature_map.shape
grid = uv_ndc.view(B, -1, 1, 2).to(feature_map.dtype) grid = uv_ndc.view(B, -1, 1, 2).to(feature_map.dtype)
feat = F.grid_sample(feature_map, grid, mode="bilinear", feat = F.grid_sample(feature_map, grid, mode="bilinear", padding_mode="border", align_corners=False)
padding_mode="border", align_corners=False)
return feat.squeeze(-1) return feat.squeeze(-1)
@ -922,8 +797,6 @@ def _back_project_to_tokens(
return out return out
def _select_stage_entry(proj_pack: dict, stage: Optional[str]): def _select_stage_entry(proj_pack: dict, stage: Optional[str]):
"""Returns (feature_map_lr, feature_map_hr_or_None, image_resolution).""" """Returns (feature_map_lr, feature_map_hr_or_None, image_resolution)."""
stages = proj_pack.get("stages") stages = proj_pack.get("stages")
@ -1025,7 +898,7 @@ def _shape_proj_cond(global_cond: torch.Tensor, image_attn_mode: str,
else: else:
proj_feats = proj_feats.repeat((repeats, 1)) proj_feats = proj_feats.repeat((repeats, 1))
# Mirror upstream's neg_cond by zeroing proj for any uncond batch slot. # zero proj for any uncond batch slot
if cond_or_uncond is not None and eval_batch is not None: if cond_or_uncond is not None and eval_batch is not None:
uncond_slots = [i for i, v in enumerate(cond_or_uncond) if v == 1] uncond_slots = [i for i, v in enumerate(cond_or_uncond) if v == 1]
if uncond_slots: if uncond_slots:
@ -1100,9 +973,6 @@ class Trellis2(nn.Module):
# Pre-computed per-stage back-projected features # Pre-computed per-stage back-projected features
proj_feats = kwargs.get("trellis2_proj_feats") proj_feats = kwargs.get("trellis2_proj_feats")
sampling_preview.set_context(mode=mode, coords=coords, coord_counts=coord_counts,
model_frame=kwargs.get("trellis2_model_frame"))
is_first_shape_pass = False is_first_shape_pass = False
if mode == "shape_generation_512": if mode == "shape_generation_512":
is_first_shape_pass = True is_first_shape_pass = True

View File

@ -75,153 +75,10 @@ class Latent2RGBPreviewer(LatentPreviewer):
return preview_to_image(latent_image) return preview_to_image(latent_image)
class Trellis3DPreviewer(LatentPreviewer):
"""Per-step preview for the Trellis2/Pixal3D cascade.
Structure stage: x0 is a dense [B, 32, 16, 16, 16] grid project the per-cell
activation norm orthographically to a 2D occupancy heatmap (no decode, no coords).
Texture stage: x0 is sparse [B, 32, N, 1] splat the first 3 latent channels as
pseudo-color onto the fixed voxel coords (read from the sampling side-channel).
Shape stage adds no visible motion (coords are fixed, only sub-voxel detail
evolves) and a full decode per step is too costly, so it's skipped.
Both stages render through one orthographic point splatter (static view).
"""
_SIZE = 128
_FILL = 0.9 # fraction of frame the texture splat spans (leaves a border)
_STRUCTURE_ZOOM = 0.66 # <1 pulls the SS camera back, leaving margin around the blob
def _splat(self, points, colors, rad):
# points: [K, 3] voxel-index coords. colors: [K, 3] in [0, 1].
# Center + isotropic-normalize, project orthographically front-on
# (x->horizontal, y->up, z->depth), then splat a square footprint per point
# with one global far->near sort (painter's).
S = self._SIZE
dev = points.device # keep every tensor here
p = points.float()
p = p - (p.amax(0) + p.amin(0)) * 0.5
p = p / p.abs().amax().clamp(min=1e-8)
x, y, z = p[:, 0], p[:, 1], p[:, 2]
depth = z # into-screen
m = self._FILL
u = ((x * m * 0.5 + 0.5) * (S - 1)).long().clamp(0, S - 1)
v = (((-y) * m * 0.5 + 0.5) * (S - 1)).long().clamp(0, S - 1) # image up = +y
cols = colors.to(dev)
us, vs, ds, cs = [], [], [], []
for dv in range(-rad, rad + 1):
for du in range(-rad, rad + 1):
us.append((u + du).clamp(0, S - 1))
vs.append((v + dv).clamp(0, S - 1))
ds.append(depth)
cs.append(cols)
order = torch.cat(ds).argsort()
img = torch.zeros(S, S, 3, device=dev)
img[torch.cat(vs)[order], torch.cat(us)[order]] = torch.cat(cs)[order]
return preview_to_image(img, do_scale=False)
@staticmethod
def _turbo(x):
# Anton Mikhailov polynomial approximation of the turbo colormap. x: any shape
# in [0, 1] -> (..., 3) RGB.
x = x.clamp(0.0, 1.0)
x2 = x * x; x3 = x2 * x; x4 = x2 * x2; x5 = x4 * x
r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5
g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5
b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5
return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0)
def _structure(self, x0):
# x0: [B, 32, D, H, W]; the model only consumes the first 8 channels.
# Dense orthographic max-projection -> filled occupancy heatmap (turbo-colored,
# intensity-weighted so empty space stays black).
act = x0[0, :min(8, x0.shape[1])].float().norm(dim=0) # [D, H, W]
proj = act.amax(dim=2) # project along one axis
proj = (proj - proj.amin()) / (proj.amax() - proj.amin() + 1e-8)
inner = max(1, int(round(self._SIZE * self._STRUCTURE_ZOOM)))
img = torch.nn.functional.interpolate(proj[None, None], size=(inner, inner), mode="nearest")
pad = self._SIZE - inner
pl, pt = pad // 2, pad // 2
gray = torch.nn.functional.pad(img, (pl, pad - pl, pt, pad - pt))[0, 0] # [S, S], zero margin
rgb = self._turbo(gray) * gray.unsqueeze(-1) # [S, S, 3], black where empty
return preview_to_image(rgb, do_scale=False)
@staticmethod
def _latent_color(latent):
# Prefer the calibrated latent->base_color map (fit from real decoded
# albedo by VaeDecodeTextureTrellis); fall back to PCA pseudo-color until
# a texture decode has trained it.
try:
from comfy.ldm.trellis2 import sampling_preview
factors = sampling_preview.get_tex_rgb()
except Exception:
factors = None
if factors is not None:
W, b = factors
rgb = latent @ W.to(latent) + b.to(latent)
return rgb.clamp(0, 1)
return Trellis3DPreviewer._pca_color(latent)
@staticmethod
def _pca_color(latent):
# latent: [n, C]. Map the 3 directions of maximum variance to RGB.
# Higher contrast and more coherent than picking 3 fixed channels.
X = latent - latent.mean(dim=0, keepdim=True)
cov = (X.transpose(0, 1) @ X) / max(X.shape[0] - 1, 1) # [C, C]
_, evecs = torch.linalg.eigh(cov) # ascending eigenvalues
pcs = evecs[:, -3:] # [C, 3] top-3 components
# Deterministic sign per component (largest-magnitude entry positive) to
# stop the preview's hues from flickering as the latent rotates each step.
sign = torch.sign(pcs[pcs.abs().argmax(dim=0), torch.arange(3, device=pcs.device)])
pcs = pcs * sign.clamp(min=-1.0)
proj = X @ pcs # [n, 3]
pmin = proj.amin(dim=0, keepdim=True)
pmax = proj.amax(dim=0, keepdim=True)
return ((proj - pmin) / (pmax - pmin + 1e-8)).clamp(0, 1)
def _texture(self, x0, coords, model_frame=None):
if coords.shape[-1] == 4:
b0 = coords[:, 0] == 0
spatial = coords[b0][:, 1:4].float()
else:
spatial = coords[:, :3].float()
n0 = spatial.shape[0]
if n0 == 0:
return None
if model_frame == "z_up":
spatial = torch.stack([spatial[:, 0], spatial[:, 2], -spatial[:, 1]], dim=-1)
latent = x0[0, :, :n0, 0].float().transpose(0, 1) # [n0, C]
colors = self._latent_color(latent) # [n0, 3]
res = float(spatial.abs().max().item()) + 1.0
rad = max(1, int(round(self._SIZE * self._FILL / max(res, 1) / 2)))
return self._splat(spatial, colors, rad)
def decode_latent_to_preview(self, x0):
try:
from comfy.ldm.trellis2 import sampling_preview
ctx = sampling_preview.get_context()
if x0.ndim == 5:
return self._structure(x0)
mode = ctx.get("mode")
coords = ctx.get("coords")
if mode == "texture_generation" and coords is not None:
return self._texture(x0, coords, model_frame=ctx.get("model_frame"))
except Exception as e:
logging.debug(f"Trellis3DPreviewer: skipping preview ({e})")
return None
def decode_latent_to_preview_image(self, preview_format, x0):
preview_image = self.decode_latent_to_preview(x0)
if preview_image is None:
return None
return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
def get_previewer(device, latent_format): def get_previewer(device, latent_format):
previewer = None previewer = None
method = args.preview_method method = args.preview_method
if method != LatentPreviewMethod.NoPreviews: if method != LatentPreviewMethod.NoPreviews:
if getattr(latent_format, "trellis3d_preview", False):
return Trellis3DPreviewer()
# TODO previewer methods # TODO previewer methods
taesd_decoder_path = None taesd_decoder_path = None
if latent_format.taesd_decoder_name is not None: if latent_format.taesd_decoder_name is not None: