Refactor attention, optimize and heavily cleanup model code and cond paths

This commit is contained in:
kijai 2026-05-22 19:55:15 +03:00
parent b2abca0f33
commit 4585a731c1
7 changed files with 809 additions and 862 deletions

View File

@ -32,6 +32,11 @@ except ImportError as e:
raise e
exit(-1)
try:
from sageattention import sageattn_varlen # sageattention >= 2
except ImportError:
sageattn_varlen = None
SAGE_ATTENTION3_IS_AVAILABLE = False
try:
from sageattn3 import sageattn3_blackwell
@ -48,6 +53,24 @@ except ImportError:
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1)
try:
from torch.nn.attention.varlen import varlen_attn as _torch_varlen_attn
except ImportError:
_torch_varlen_attn = None
def _is_varlen(kwargs):
"""Varlen mode is opted into by passing cu_seqlens_q in kwargs."""
return kwargs.get("cu_seqlens_q") is not None
def _varlen_args(kwargs):
cu_seqlens_q = kwargs["cu_seqlens_q"]
cu_seqlens_kv = kwargs.get("cu_seqlens_kv", cu_seqlens_q)
max_seqlen_q = int(kwargs["max_seqlen_q"])
max_seqlen_kv = int(kwargs.get("max_seqlen_kv", max_seqlen_q))
return cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
REGISTERED_ATTENTION_FUNCTIONS = {}
def register_attention_function(name: str, func: Callable):
# avoid replacing existing functions
@ -144,6 +167,8 @@ def wrap_attn(func):
@wrap_attn
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if _is_varlen(kwargs):
return attention_pytorch(q, k, v, heads, mask=mask, **kwargs)
attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
@ -218,6 +243,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
@wrap_attn
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if _is_varlen(kwargs):
return attention_pytorch(query, key, value, heads, mask=mask, **kwargs)
attn_precision = get_attn_precision(attn_precision, query.dtype)
if skip_reshape:
@ -293,6 +320,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
@wrap_attn
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if _is_varlen(kwargs):
return attention_pytorch(q, k, v, heads, mask=mask, **kwargs)
attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
@ -424,6 +453,17 @@ except:
@wrap_attn
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if _is_varlen(kwargs):
# q, k, v expected as packed [T_total, H, C]. xformers wants per-item
# seqlen lists and a leading batch dim of 1.
cu_seqlens_q, cu_seqlens_kv, _max_q, _max_kv = _varlen_args(kwargs)
q_seqlen = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).tolist()
kv_seqlen = (cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]).tolist()
attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
return xformers.ops.memory_efficient_attention(
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attn_bias=attn_bias,
)[0]
b = q.shape[0]
dim_head = q.shape[-1]
# check to make sure xformers isn't broken
@ -493,6 +533,22 @@ else:
@wrap_attn
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if _is_varlen(kwargs):
# q, k, v expected as packed [T_total, H, C]. cu_seqlens_q / cu_seqlens_kv
# describe per-item offsets. Native varlen kernel if available, else
# nested-tensor SDPA fallback. mask/attn_precision/etc are ignored here.
cu_seqlens_q, cu_seqlens_kv, max_q, max_kv = _varlen_args(kwargs)
if _torch_varlen_attn is not None:
return _torch_varlen_attn(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q, max_kv)
q_nj = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
k_nj = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_kv.long())
v_nj = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_kv.long())
out = comfy.ops.scaled_dot_product_attention(
q_nj.transpose(1, 2), k_nj.transpose(1, 2), v_nj.transpose(1, 2),
attn_mask=None, dropout_p=0.0, is_causal=False,
)
return out.transpose(1, 2).values()
if skip_reshape:
b, _, _, dim_head = q.shape
else:
@ -541,6 +597,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if _is_varlen(kwargs):
# q, k, v expected as packed [T_total, H, C].
if sageattn_varlen is None:
# sageattention v1 has no varlen kernel; fall back to attention_pytorch's varlen path.
return attention_pytorch(q, k, v, heads, mask=mask, **kwargs)
cu_seqlens_q, cu_seqlens_kv, max_q, max_kv = _varlen_args(kwargs)
return sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q, max_kv)
if kwargs.get("low_precision_attention", True) is False:
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs)
@ -698,6 +761,12 @@ except AttributeError as error:
@wrap_attn
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if _is_varlen(kwargs):
# q, k, v expected as packed [T_total, H, C].
from flash_attn import flash_attn_varlen_func
cu_seqlens_q, cu_seqlens_kv, max_q, max_kv = _varlen_args(kwargs)
return flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q, max_kv)
if skip_reshape:
b, _, _, dim_head = q.shape
else:

View File

@ -1,125 +1,107 @@
import torch
import math
from typing import Tuple, Union
import torch
from comfy.ldm.modules.attention import optimized_attention
from typing import Tuple, Union, List
from comfy.ldm.trellis2.vae import VarLenTensor
import comfy.ops
try:
from torch.nn.attention.varlen import varlen_attn as _varlen_attn
except ImportError:
_varlen_attn = None
def var_attn_arg(kwargs):
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
max_seqlen_q = kwargs.get("max_seqlen_q", None)
cu_seqlens_k = kwargs.get("cu_seqlens_kv", cu_seqlens_q)
max_seqlen_k = kwargs.get("max_kv_seqlen", max_seqlen_q)
assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
def _build_cu_seqlens(seqlen, device):
"""Cumulative-offset tensor from a list/tensor of per-item sequence lengths."""
if isinstance(seqlen, torch.Tensor):
return torch.cat(
[torch.zeros(1, dtype=torch.int32, device=seqlen.device), torch.cumsum(seqlen, dim=0).int()],
dim=0,
).to(device)
return torch.cat(
[torch.tensor([0]), torch.cumsum(torch.tensor(seqlen), dim=0)],
).int().to(device)
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
if not skip_reshape:
total_tokens, embed_dim = q.shape
head_dim = embed_dim // heads
q = q.view(total_tokens, heads, head_dim)
k = k.view(k.shape[0], heads, head_dim)
v = v.view(v.shape[0], heads, head_dim)
if _varlen_attn is not None:
return _varlen_attn(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
int(max_seqlen_q), int(max_seqlen_k),
)
def _layout_to_feats(t):
"""Returns (sparse_template_or_None, seqlen, feats[T, H, C])."""
if isinstance(t, VarLenTensor):
seqlen = [t.layout[i].stop - t.layout[i].start for i in range(t.shape[0])]
return t, seqlen, t.feats
N, L = t.shape[:2]
return None, [L] * N, t.reshape(-1, *t.shape[-2:])
# Fallback: nested-tensor SDPA (PyTorch < the version that introduced varlen_attn)
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
return out.transpose(1, 2).values()
def scaled_dot_product_attention(*args, **kwargs):
num_all_args = len(args) + len(kwargs)
q = None
if num_all_args == 1:
qkv = args[0] if len(args) > 0 else kwargs.get('qkv')
elif num_all_args == 2:
q = args[0] if len(args) > 0 else kwargs.get('q')
kv = args[1] if len(args) > 1 else kwargs.get('kv')
elif num_all_args == 3:
q = args[0] if len(args) > 0 else kwargs.get('q')
k = args[1] if len(args) > 1 else kwargs.get('k')
v = args[2] if len(args) > 2 else kwargs.get('v')
if q is not None:
heads = q.shape[2]
else:
heads = qkv.shape[3]
if num_all_args == 1:
q, k, v = qkv.unbind(dim=2)
elif num_all_args == 2:
k, v = kv.unbind(dim=2)
def dense_attention(q, k, v, **kwargs):
"""q, k, v: [B, L, H, C]. Permutes for comfy's [B, H, L, C] convention."""
heads = q.shape[2]
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True, **kwargs)
return out.permute(0, 2, 1, 3)
out = out.permute(0, 2, 1, 3)
return out
def sparse_attention(q, k, v, **kwargs):
"""
Varlen attention for SparseTensor inputs. Each of q, k, v may be a VarLenTensor
(sparse) or dense [B, L, H, C]. Output type matches q. Backend dispatch lives
in comfy.ldm.modules.attention.optimized_attention; we just build cu_seqlens
from the layouts.
"""
s, q_seqlen, q_feats = _layout_to_feats(q)
_, kv_seqlen, k_feats = _layout_to_feats(k)
_, _, v_feats = _layout_to_feats(v)
heads = q_feats.shape[1]
def sparse_windowed_scaled_dot_product_self_attention(
qkv,
window_size: int,
shift_window: Tuple[int, int, int] = (0, 0, 0)
):
device = q_feats.device
cu_seqlens_q = _build_cu_seqlens(q_seqlen, device)
cu_seqlens_kv = _build_cu_seqlens(kv_seqlen, device)
serialization_spatial_cache_name = f'windowed_attention_{window_size}_{shift_window}'
serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
if serialization_spatial_cache is None:
fwd_indices, bwd_indices, seq_lens, attn_func_args = calc_window_partition(qkv, window_size, shift_window)
qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, attn_func_args))
else:
fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache
out = optimized_attention(
q_feats, k_feats, v_feats, heads,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max(q_seqlen), max_seqlen_kv=max(kv_seqlen),
skip_reshape=True, skip_output_reshape=True,
**kwargs,
)
if s is not None:
return s.replace(out)
N, L = q.shape[:2]
return out.reshape(N, L, heads, -1)
def sparse_windowed_self_attention(qkv, window_size: int, shift_window: Tuple[int, int, int] = (0, 0, 0)):
"""Windowed sparse self-attention. Partitions voxels into windows via spatial
sort, then runs varlen attention with one sequence per non-empty window."""
cache_name = f'windowed_attention_{window_size}_{shift_window}'
cache = qkv.get_spatial_cache(cache_name)
if cache is None:
cache = calc_window_partition(qkv, window_size, shift_window)
qkv.register_spatial_cache(cache_name, cache)
fwd_indices, bwd_indices, seq_lens = cache
qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
heads = qkv_feats.shape[2]
q, k, v = qkv_feats.unbind(dim=1) # each [M, H, C]
heads = q.shape[1]
device = q.device
if optimized_attention.__name__ == 'attention_xformers':
q, k, v = qkv_feats.unbind(dim=1)
q = q.unsqueeze(0) # [1, M, H, C]
k = k.unsqueeze(0) # [1, M, H, C]
v = v.unsqueeze(0) # [1, M, H, C]
#out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C]
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
elif optimized_attention.__name__ == 'attention_flash':
if 'flash_attn' not in globals():
import flash_attn
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C]
else:
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
out = out[bwd_indices] # [T, H, C]
cu_seqlens = _build_cu_seqlens(seq_lens, device)
max_seqlen = int(seq_lens.max())
out = optimized_attention(
q, k, v, heads,
cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
max_seqlen_q=max_seqlen, max_seqlen_kv=max_seqlen,
skip_reshape=True, skip_output_reshape=True,
)
out = out[bwd_indices]
return qkv.replace(out)
def calc_window_partition(
tensor,
window_size: Union[int, Tuple[int, ...]],
shift_window: Union[int, Tuple[int, ...]] = 0,
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
):
"""Returns (fwd_indices, bwd_indices, seq_lens) for window partitioning."""
DIM = tensor.coords.shape[1] - 1
shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
@ -136,139 +118,6 @@ def calc_window_partition(
bwd_indices = torch.empty_like(fwd_indices)
bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
seq_lens = torch.bincount(shifted_indices)
mask = seq_lens != 0
seq_lens = seq_lens[mask]
seq_lens = seq_lens[seq_lens != 0]
if optimized_attention.__name__ == 'attention_xformers':
if 'xops' not in globals():
import xformers.ops as xops
attn_func_args = {
'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
}
elif optimized_attention.__name__ == 'attention_flash':
attn_func_args = {
'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(),
'max_seqlen': torch.max(seq_lens)
}
return fwd_indices, bwd_indices, seq_lens, attn_func_args
def sparse_scaled_dot_product_attention(*args, **kwargs):
q=None
arg_names_dict = {
1: ['qkv'],
2: ['q', 'kv'],
3: ['q', 'k', 'v']
}
num_all_args = len(args) + len(kwargs)
for key in arg_names_dict[num_all_args][len(args):]:
assert key in kwargs, f"Missing argument {key}"
if num_all_args == 1:
qkv = args[0] if len(args) > 0 else kwargs['qkv']
device = qkv.device
s = qkv
q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
kv_seqlen = q_seqlen
qkv = qkv.feats # [T, 3, H, C]
elif num_all_args == 2:
q = args[0] if len(args) > 0 else kwargs['q']
kv = args[1] if len(args) > 1 else kwargs['kv']
device = q.device
if isinstance(q, VarLenTensor):
s = q
q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
q = q.feats # [T_Q, H, C]
else:
s = None
N, L, H, C = q.shape
q_seqlen = [L] * N
q = q.reshape(N * L, H, C) # [T_Q, H, C]
if isinstance(kv, VarLenTensor):
kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
kv = kv.feats # [T_KV, 2, H, C]
else:
N, L, _, H, C = kv.shape
kv_seqlen = [L] * N
kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C]
elif num_all_args == 3:
q = args[0] if len(args) > 0 else kwargs['q']
k = args[1] if len(args) > 1 else kwargs['k']
v = args[2] if len(args) > 2 else kwargs['v']
device = q.device
if isinstance(q, VarLenTensor):
s = q
q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
q = q.feats # [T_Q, H, Ci]
else:
s = None
N, L, H, CI = q.shape
q_seqlen = [L] * N
q = q.reshape(N * L, H, CI) # [T_Q, H, Ci]
if isinstance(k, VarLenTensor):
kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
k = k.feats # [T_KV, H, Ci]
v = v.feats # [T_KV, H, Co]
else:
N, L, H, CI, CO = *k.shape, v.shape[-1]
kv_seqlen = [L] * N
k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
# TODO: change
if q is not None:
heads = q
else:
heads = qkv
heads = heads.shape[2]
if optimized_attention.__name__ == 'attention_xformers':
if 'xops' not in globals():
import xformers.ops as xops
if num_all_args == 1:
q, k, v = qkv.unbind(dim=1)
elif num_all_args == 2:
k, v = kv.unbind(dim=1)
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
out = xops.memory_efficient_attention(q, k, v, mask)[0]
elif optimized_attention.__name__ == 'attention_flash':
if 'flash_attn' not in globals():
import flash_attn
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
if num_all_args in [2, 3]:
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
if num_all_args == 1:
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
elif num_all_args == 2:
out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
elif num_all_args == 3:
out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
elif optimized_attention.__name__ == "attention_pytorch":
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
if num_all_args in [2, 3]:
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
else:
cu_seqlens_kv = cu_seqlens_q
if num_all_args == 1:
q, k, v = qkv.unbind(dim=1)
elif num_all_args == 2:
k, v = kv.unbind(dim=1)
out = attention_pytorch(q, k, v, heads=heads,cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max(q_seqlen), max_kv_seqlen=max(kv_seqlen),
skip_reshape=True, skip_output_reshape=True)
if s is not None:
return s.replace(out)
else:
return out.reshape(N, L, H, -1)
return fwd_indices, bwd_indices, seq_lens

View File

@ -4,11 +4,12 @@ import torch.nn as nn
from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor
from typing import Optional, Tuple, Literal, Union, List
from comfy.ldm.trellis2.attention import (
sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention
sparse_windowed_self_attention, sparse_attention, dense_attention
)
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
from comfy.ldm.flux.math import apply_rope, apply_rope1
class SparseGELU(nn.GELU):
def forward(self, input: VarLenTensor) -> VarLenTensor:
return input.replace(super().forward(input.feats))
@ -31,22 +32,16 @@ class LayerNorm32(nn.LayerNorm):
x = x.to(dtype=torch.float32)
o = super().forward(x)
return o.to(dtype=x_dtype)
class SparseMultiHeadRMSNorm(nn.Module):
def __init__(self, dim: int, heads: int, device, dtype):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
x_type = x.dtype
x = x.float()
if isinstance(x, VarLenTensor):
x = x.replace(F.normalize(x.feats, dim=-1) * self.gamma * self.scale)
else:
x = F.normalize(x, dim=-1) * self.gamma * self.scale
return x.to(x_type)
return x.replace(F.rms_norm(x.feats, (x.feats.shape[-1],)) * self.gamma)
return F.rms_norm(x, (x.shape[-1],)) * self.gamma
class SparseRotaryPositionEmbedder(nn.Module):
def __init__(
@ -84,12 +79,6 @@ class SparseRotaryPositionEmbedder(nn.Module):
return freqs_cis
def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
self.freqs = self.freqs.to(indices.device)
phases = torch.outer(indices, self.freqs)
phases = torch.polar(torch.ones_like(phases), phases)
return phases
def forward(self, q, k=None):
cache_name = f'rope_cis_{self.dim}d_f{self.rope_freq[1]}_hd{self.head_dim}'
freqs_cis = q.get_spatial_cache(cache_name)
@ -110,27 +99,10 @@ class SparseRotaryPositionEmbedder(nn.Module):
q_feats, k_feats = apply_rope(q.feats, k.feats, f_cis)
return q.replace(q_feats), k.replace(k_feats)
@staticmethod
def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
x_rotated = x_complex * phases.unsqueeze(-2)
x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
return x_embed
class RotaryPositionEmbedder(SparseRotaryPositionEmbedder):
def forward(self, indices: torch.Tensor) -> torch.Tensor:
phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
if torch.is_complex(phases):
phases = phases.to(torch.complex64)
else:
phases = phases.to(torch.float32)
if phases.shape[-1] < self.head_dim // 2:
padn = self.head_dim // 2 - phases.shape[-1]
phases = torch.cat([phases, torch.polar(
torch.ones(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32),
torch.zeros(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32)
)], dim=-1)
return phases
def forward(self, coords: torch.Tensor) -> torch.Tensor:
return self._get_freqs_cis(coords) # [L, head_dim/2, 2, 2]
class SparseMultiHeadAttention(nn.Module):
def __init__(
@ -198,50 +170,53 @@ class SparseMultiHeadAttention(nn.Module):
x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats
def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor:
def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None,
transformer_options=None) -> SparseTensor:
if self._type == "self":
dtype = next(self.to_qkv.parameters()).dtype
x = x.to(dtype)
qkv = self._linear(self.to_qkv, x)
qkv = self._fused_pre(qkv, num_fused=3)
if self.qk_rms_norm or self.use_rope:
if self.attn_mode == "full":
q, k, v = qkv.unbind(dim=-3)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k = self.k_rms_norm(k)
if self.use_rope:
q, k = self.rope(q, k)
qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
if self.attn_mode == "full":
h = sparse_scaled_dot_product_attention(qkv)
elif self.attn_mode == "windowed":
h = sparse_windowed_scaled_dot_product_self_attention(
qkv, self.window_size, shift_window=self.shift_window
)
elif self.attn_mode == "double_windowed":
qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:])
qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2])
h0 = sparse_windowed_scaled_dot_product_self_attention(
qkv0, self.window_size, shift_window=(0, 0, 0)
)
h1 = sparse_windowed_scaled_dot_product_self_attention(
qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3)
)
h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1))
h = sparse_attention(q, k, v, transformer_options=transformer_options)
else:
# Windowed paths take packed qkv; preserve any per-head norm/rope.
if self.qk_rms_norm or self.use_rope:
q, k, v = qkv.unbind(dim=-3)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k = self.k_rms_norm(k)
if self.use_rope:
q, k = self.rope(q, k)
qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
if self.attn_mode == "windowed":
h = sparse_windowed_self_attention(
qkv, self.window_size, shift_window=self.shift_window
)
elif self.attn_mode == "double_windowed":
qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:])
qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2])
h0 = sparse_windowed_self_attention(
qkv0, self.window_size, shift_window=(0, 0, 0)
)
h1 = sparse_windowed_self_attention(
qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3)
)
h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1))
else:
q = self._linear(self.to_q, x)
q = self._reshape_chs(q, (self.num_heads, -1))
dtype = next(self.to_kv.parameters()).dtype
context = context.to(dtype)
kv = self._linear(self.to_kv, context)
kv = self._fused_pre(kv, num_fused=2)
k, v = kv.unbind(dim=-3)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k, v = kv.unbind(dim=-3)
k = self.k_rms_norm(k)
h = sparse_scaled_dot_product_attention(q, k, v)
else:
h = sparse_scaled_dot_product_attention(q, kv)
h = sparse_attention(q, k, v, transformer_options=transformer_options)
h = self._reshape_chs(h, (-1,))
h = self._linear(self.to_out, h)
return h
@ -265,9 +240,9 @@ class ProjectAttentionSparse(nn.Module):
self.proj_linear = operations.Linear(proj_in_channels, channels, bias=True,
device=device, dtype=dtype)
def forward(self, x: SparseTensor, context) -> SparseTensor:
def forward(self, x: SparseTensor, context, transformer_options=None) -> SparseTensor:
global_ctx, proj_in = _split_proj_context(context)
global_out = self.cross_attn_block(x, global_ctx)
global_out = self.cross_attn_block(x, global_ctx, transformer_options=transformer_options)
if isinstance(proj_in, tuple):
proj_in = torch.cat([proj_in[0], proj_in[1]], dim=-1)
proj_out = self.proj_linear(proj_in.to(self.proj_linear.weight.dtype))
@ -282,9 +257,9 @@ class ProjectAttentionDense(nn.Module):
self.proj_linear = operations.Linear(proj_in_channels, channels, bias=True,
device=device, dtype=dtype)
def forward(self, x: torch.Tensor, context) -> torch.Tensor:
def forward(self, x: torch.Tensor, context, transformer_options=None) -> torch.Tensor:
global_ctx, proj_in = _split_proj_context(context)
global_out = self.cross_attn_block(x, global_ctx)
global_out = self.cross_attn_block(x, global_ctx, transformer_options=transformer_options)
if isinstance(proj_in, tuple):
proj_in = torch.cat([proj_in[0], proj_in[1]], dim=-1)
proj_out = self.proj_linear(proj_in.to(self.proj_linear.weight.dtype))
@ -367,32 +342,36 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
else:
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
def _forward(self, x: SparseTensor, mod: torch.Tensor, context) -> SparseTensor:
def _forward(self, x: SparseTensor, mod: torch.Tensor, context, transformer_options=None) -> SparseTensor:
if self.share_mod:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
h = x.replace(self.norm1(x.feats))
h = h * (1 + scale_msa) + shift_msa
h = self.self_attn(h)
h = h * gate_msa
x = x + h
# Fuse the (mul + add) and (mul + residual) pairs into addcmul so the
# mod/shift broadcasts hit one kernel each instead of two.
b_map = x.batch_boardcast_map
h_feats = self.norm1(x.feats)
h_feats = torch.addcmul(shift_msa[b_map], h_feats, (1 + scale_msa)[b_map])
h = self.self_attn(x.replace(h_feats), transformer_options=transformer_options)
x = x.replace(torch.addcmul(x.feats, h.feats, gate_msa[b_map]))
h = x.replace(self.norm2(x.feats))
if self.image_attn_mode == "global":
global_ctx, _ = _split_proj_context(context)
h = self.cross_attn(h, global_ctx)
h = self.cross_attn(h, global_ctx, transformer_options=transformer_options)
else:
h = self.cross_attn(h, context)
x = x + h
h = x.replace(self.norm3(x.feats))
h = h * (1 + scale_mlp) + shift_mlp
h = self.mlp(h)
h = h * gate_mlp
h = self.cross_attn(h, context, transformer_options=transformer_options)
x = x + h
h_feats = self.norm3(x.feats)
h_feats = torch.addcmul(shift_mlp[b_map], h_feats, (1 + scale_mlp)[b_map])
h = self.mlp(x.replace(h_feats))
x = x.replace(torch.addcmul(x.feats, h.feats, gate_mlp[b_map]))
return x
def forward(self, x: SparseTensor, mod: torch.Tensor, context) -> SparseTensor:
return self._forward(x, mod, context)
def forward(self, x: SparseTensor, mod: torch.Tensor, context, transformer_options=None) -> SparseTensor:
return self._forward(x, mod, context, transformer_options=transformer_options)
class SLatFlowModel(nn.Module):
@ -480,24 +459,21 @@ class SLatFlowModel(nn.Module):
t: torch.Tensor,
cond: Union[torch.Tensor, List[torch.Tensor]],
concat_cond: Optional[SparseTensor] = None,
**kwargs
transformer_options=None,
**kwargs,
) -> SparseTensor:
if concat_cond is not None:
x = sparse_cat([x, concat_cond], dim=-1)
if isinstance(cond, list):
cond = VarLenTensor.from_tensor_list(cond)
dtype = next(self.input_layer.parameters()).dtype
x = x.to(dtype)
h = self.input_layer(x)
t = t.to(dtype)
t_embedder = self.t_embedder.to(dtype)
t_emb = t_embedder(t, out_dtype = t.dtype)
t_emb = self.t_embedder(t, out_dtype=t.dtype)
if self.share_mod:
t_emb = self.adaLN_modulation(t_emb)
for block in self.blocks:
h = block(h, t_emb, cond)
h = block(h, t_emb, cond, transformer_options=transformer_options)
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
h = self.out_layer(h)
@ -566,40 +542,34 @@ class MultiHeadAttention(nn.Module):
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None,
phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor:
B, L, C = x.shape
if self._type == "self":
x = x.to(next(self.to_qkv.parameters()).dtype)
qkv = self.to_qkv(x)
qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
if self.attn_mode == "full":
if self.qk_rms_norm or self.use_rope:
q, k, v = qkv.unbind(dim=2)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k = self.k_rms_norm(k)
if self.use_rope:
assert phases is not None, "Phases must be provided for RoPE"
q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases)
k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases)
h = scaled_dot_product_attention(q, k, v)
else:
h = scaled_dot_product_attention(qkv)
q, k, v = qkv.unbind(dim=2)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k = self.k_rms_norm(k)
if self.use_rope:
assert phases is not None, "Phases must be provided for RoPE"
# phases is [L, head_dim/2, 2, 2]; broadcast to [1, L, 1, ...]
# to align with q/k of shape [B, L, H, head_dim].
f_cis = phases.unsqueeze(0).unsqueeze(2)
q, k = apply_rope(q, k, f_cis)
h = dense_attention(q, k, v, transformer_options=transformer_options)
else:
Lkv = context.shape[1]
q = self.to_q(x)
context = context.to(next(self.to_kv.parameters()).dtype)
kv = self.to_kv(context)
q = q.reshape(B, L, self.num_heads, -1)
kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
k, v = kv.unbind(dim=2)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k, v = kv.unbind(dim=2)
k = self.k_rms_norm(k)
h = scaled_dot_product_attention(q, k, v)
else:
h = scaled_dot_product_attention(q, kv)
h = dense_attention(q, k, v, transformer_options=transformer_options)
h = h.reshape(B, L, -1)
h = self.to_out(h)
return h
@ -677,32 +647,39 @@ class ModulatedTransformerCrossBlock(nn.Module):
else:
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context,
phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor:
if self.share_mod:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
h = self.norm1(x)
h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
h = self.self_attn(h, phases=phases)
h = h * gate_msa.unsqueeze(1)
x = x + h
shift_msa = shift_msa.unsqueeze(1)
scale_msa = scale_msa.unsqueeze(1)
gate_msa = gate_msa.unsqueeze(1)
shift_mlp = shift_mlp.unsqueeze(1)
scale_mlp = scale_mlp.unsqueeze(1)
gate_mlp = gate_mlp.unsqueeze(1)
h = torch.addcmul(shift_msa, self.norm1(x), 1 + scale_msa)
h = self.self_attn(h, phases=phases, transformer_options=transformer_options)
x = torch.addcmul(x, h, gate_msa)
h = self.norm2(x)
if self.image_attn_mode == "global":
global_ctx, _ = _split_proj_context(context)
h = self.cross_attn(h, global_ctx)
h = self.cross_attn(h, global_ctx, transformer_options=transformer_options)
else:
h = self.cross_attn(h, context)
h = self.cross_attn(h, context, transformer_options=transformer_options)
x = x + h
h = self.norm3(x)
h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
h = torch.addcmul(shift_mlp, self.norm3(x), 1 + scale_mlp)
h = self.mlp(h)
h = h * gate_mlp.unsqueeze(1)
x = x + h
x = torch.addcmul(x, h, gate_mlp)
return x
def forward(self, x: torch.Tensor, mod: torch.Tensor, context, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
return self._forward(x, mod, context, phases)
def forward(self, x: torch.Tensor, mod: torch.Tensor, context,
phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor:
return self._forward(x, mod, context, phases, transformer_options=transformer_options)
class SparseStructureFlowModel(nn.Module):
@ -792,18 +769,18 @@ class SparseStructureFlowModel(nn.Module):
self.out_layer = operations.Linear(model_channels, out_channels, device=device, dtype=dtype)
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor,
transformer_options=None) -> torch.Tensor:
x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3)
h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous()
h = h.to(next(self.input_layer.parameters()).dtype)
h = self.input_layer(h)
t_emb = self.t_embedder(t, out_dtype = t.dtype)
t_emb = self.t_embedder(t, out_dtype=t.dtype)
if self.share_mod:
t_emb = self.adaLN_modulation(t_emb)
for block in self.blocks:
h = block(h, t_emb, cond, self.rope_phases)
h = block(h, t_emb, cond, self.rope_phases, transformer_options=transformer_options)
h = F.layer_norm(h, h.shape[-1:])
h = self.out_layer(h)
@ -820,7 +797,7 @@ def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0):
# Pixal3D ProjGrid math — port of upstream's ProjGrid + project_points_to_image_batch.
# World frame uses world Y as depth (Blender convention), camera looks along -Z local;
# World frame uses world Y as depth, camera looks along -Z local;
# transform_matrix is camera-to-world (inverted internally). Intrinsics: fx = 16 / tan(fov/2)
# with sensor_width = 32mm.
@ -927,7 +904,7 @@ def _back_project_to_tokens(
if not mask.any():
continue
p = coords_world[mask].unsqueeze(0)
uv, depth, valid = _project_points_to_image(
uv, _, _ = _project_points_to_image(
p, transform_matrix[b:b+1], camera_angle_x[b:b+1], image_resolution)
uv_ndc = (uv + 0.5) / image_resolution * 2.0 - 1.0
# padding_mode='border' is load-bearing: masking out-of-frame voxels confuses
@ -937,7 +914,7 @@ def _back_project_to_tokens(
out[mask] = sampled
return out
else:
uv, depth, valid = _project_points_to_image(
uv, _, _ = _project_points_to_image(
coords_world, transform_matrix, camera_angle_x, image_resolution)
uv_ndc = (uv + 0.5) / image_resolution * 2.0 - 1.0
sampled = _sample_features(feature_map, uv_ndc)
@ -945,21 +922,6 @@ def _back_project_to_tokens(
return out
def _pack_per_voxel_scalar(proj_pack: Optional[dict], key: str, eval_batch: int, device) -> torch.Tensor:
if proj_pack is None or key not in proj_pack:
return torch.ones((eval_batch,), device=device, dtype=torch.float32)
t = proj_pack[key].to(device=device, dtype=torch.float32)
if t.ndim == 0:
return t.expand(eval_batch).clone()
return _expand_pack(t, eval_batch)
def _expand_pack(t: torch.Tensor, eval_batch: int) -> torch.Tensor:
if eval_batch == t.shape[0]:
return t
if eval_batch % t.shape[0] != 0:
raise ValueError(f"eval batch {eval_batch} is not a multiple of pack batch {t.shape[0]}")
return t.repeat((eval_batch // t.shape[0],) + (1,) * (t.ndim - 1))
def _select_stage_entry(proj_pack: dict, stage: Optional[str]):
@ -973,53 +935,96 @@ def _select_stage_entry(proj_pack: dict, stage: Optional[str]):
raise ValueError(f"proj_feat_pack has no usable feature_map (stage={stage!r})")
def _build_proj_cond(global_cond: torch.Tensor, image_attn_mode: str, proj_pack: Optional[dict],
coords_world: torch.Tensor, batch_ids: Optional[torch.Tensor] = None,
eval_batch: Optional[int] = None,
proj_in_channels: Optional[int] = None,
stage: Optional[str] = None,
cond_or_uncond: Optional[list] = None):
if image_attn_mode == "global":
return global_cond
if proj_pack is None:
raise ValueError(f"image_attn_mode={image_attn_mode!r} but proj_feat_pack is missing")
device = coords_world.device
def compute_stage_proj_feats(
proj_pack: dict,
stage: str,
coords: Optional[torch.Tensor] = None,
coord_resolution: Optional[int] = None,
dense_grid_resolution: Optional[int] = None,
batch_size: Optional[int] = None,
device=None,
) -> torch.Tensor:
"""Back-project a Pixal3D stage's feature maps onto its target voxel/grid coords.
For sparse (shape / texture) stages: pass ``coords`` (with ``coord_resolution``).
Returns ``[N_voxels, C]`` per-voxel features with channel count =
LR channels + optional HR channels.
For the dense SS stage: pass ``dense_grid_resolution`` (16) + ``batch_size``.
Returns ``[B, R^3, C]`` features for the dense grid.
"""
if device is None:
device = coords.device if coords is not None else proj_pack["mesh_scale"].device
mesh_scale = proj_pack["mesh_scale"].to(device)
T = proj_pack["transform_matrix"].to(device)
cam_angle = proj_pack["camera_angle_x"].to(device)
feat_map_lr, feat_map_hr, image_resolution = _select_stage_entry(proj_pack, stage)
feat_map_lr = feat_map_lr.to(device)
if feat_map_hr is not None:
feat_map_hr = feat_map_hr.to(device)
if eval_batch is not None:
T = _expand_pack(T, eval_batch)
cam_angle = _expand_pack(cam_angle, eval_batch) if cam_angle.ndim >= 1 else cam_angle
feat_map_lr = _expand_pack(feat_map_lr, eval_batch)
if feat_map_hr is not None:
feat_map_hr = _expand_pack(feat_map_hr, eval_batch)
# Channel-count check against the trained proj_linear input. If HR is present, the
# block expects (LR_channels + HR_channels) since we concat the sampled features.
expected_channels = feat_map_lr.shape[1] + (feat_map_hr.shape[1] if feat_map_hr is not None else 0)
if proj_in_channels is not None and expected_channels != proj_in_channels:
if coords is not None:
if coord_resolution is None:
raise ValueError("compute_stage_proj_feats: coord_resolution required when coords is given")
coords_world, batch_ids = _coords_to_proj_world(coords, coord_resolution, mesh_scale)
else:
if dense_grid_resolution is None or batch_size is None:
raise ValueError("compute_stage_proj_feats: dense_grid_resolution + batch_size required for dense path")
coords_world = _dense_grid_proj_world(dense_grid_resolution, mesh_scale, batch_size,
device=device, dtype=torch.float32)
batch_ids = None
proj_lr = _back_project_to_tokens(coords_world, feat_map_lr, T, cam_angle,
image_resolution=image_resolution, batch_ids=batch_ids)
if feat_map_hr is not None:
proj_hr = _back_project_to_tokens(coords_world, feat_map_hr, T, cam_angle,
image_resolution=image_resolution, batch_ids=batch_ids)
return torch.cat([proj_lr, proj_hr], dim=-1)
return proj_lr
def _shape_proj_cond(global_cond: torch.Tensor, image_attn_mode: str,
proj_feats: Optional[torch.Tensor],
batch_ids: Optional[torch.Tensor] = None,
eval_batch: Optional[int] = None,
logical_batch: Optional[int] = None,
proj_in_channels: Optional[int] = None,
stage: Optional[str] = None,
cond_or_uncond: Optional[list] = None,
has_hr: bool = False):
"""Take pre-computed per-token proj features (from compute_stage_proj_feats),
apply CFG-batch duplication + uncond-slot zeroing, and wrap into the
``{"global", "proj"}`` context dict consumed by ProjectAttention.
proj_feats shape:
sparse (shape/texture): [N_voxels, C] (batch_ids gives per-voxel batch)
dense (SS): [B, N, C]
"""
if image_attn_mode == "global":
return global_cond
if proj_feats is None:
raise ValueError(f"image_attn_mode={image_attn_mode!r} but trellis2_proj_feats is missing — "
f"the stage setup node (or Pixal3DConditioning for SS) should have computed it.")
if proj_in_channels is not None and proj_feats.shape[-1] != proj_in_channels:
hint = ""
if feat_map_hr is None and expected_channels < proj_in_channels:
if not has_hr and proj_feats.shape[-1] < proj_in_channels:
hint = (" — feature_map_hr is missing for this stage. Connect a NAFModel "
"input to Pixal3DConditioning; the shape/texture stages of this "
"checkpoint need a NAF-upsampled HR feature map.")
raise ValueError(
f"proj_feat_pack[{stage!r}] has LR={feat_map_lr.shape[1]} "
f"+ HR={feat_map_hr.shape[1] if feat_map_hr is not None else 0} "
f"= {expected_channels} channels, sub-model expects {proj_in_channels}.{hint}"
f"proj_feats for stage {stage!r} has {proj_feats.shape[-1]} channels, "
f"sub-model expects {proj_in_channels}.{hint}"
)
proj_feats_lr = _back_project_to_tokens(coords_world, feat_map_lr, T, cam_angle,
image_resolution=image_resolution,
batch_ids=batch_ids)
if feat_map_hr is not None:
proj_feats_hr = _back_project_to_tokens(coords_world, feat_map_hr, T, cam_angle,
image_resolution=image_resolution,
batch_ids=batch_ids)
proj_feats = torch.cat([proj_feats_lr, proj_feats_hr], dim=-1)
else:
proj_feats = proj_feats_lr
# CFG-duplicate proj_feats to match the model's eval batch.
if eval_batch is not None and logical_batch is not None and eval_batch > logical_batch:
repeats = eval_batch // logical_batch
if batch_ids is None:
proj_feats = proj_feats.repeat((repeats,) + (1,) * (proj_feats.ndim - 1))
else:
proj_feats = proj_feats.repeat((repeats, 1))
# Mirror upstream's neg_cond by zeroing proj for any uncond batch slot.
if cond_or_uncond is not None and eval_batch is not None:
uncond_slots = [i for i, v in enumerate(cond_or_uncond) if v == 1]
@ -1056,10 +1061,6 @@ class Trellis2(nn.Module):
super().__init__()
self.dtype = dtype
operations = operations or nn
# for some reason it passes num_heads = -1
if num_heads == -1:
num_heads = 12
args = {
"out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels,
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
@ -1082,101 +1083,79 @@ class Trellis2(nn.Module):
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **struct_proj_kwargs, **args)
else:
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **tex_proj_kwargs, **args)
self.guidance_interval = [0.6, 1.0]
self.guidance_interval_txt = [0.6, 0.9]
def forward(self, x, timestep, context, **kwargs):
transformer_options = kwargs.get("transformer_options", {})
cond_or_uncond = transformer_options.get("cond_or_uncond")
model_options = {}
if hasattr(self, "meta"):
model_options = self.meta
timestep = timestep.to(x.dtype)
embeds = kwargs.get("embeds")
if embeds is None:
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
is_1024 = True#self.img2shape.resolution == 1024
coords = model_options.get("coords", None)
coord_counts = model_options.get("coord_counts", None)
mode = model_options.get("generation_mode", "structure_generation")
proj_feat_pack = model_options.get("proj_feat_pack", None)
coord_resolution = model_options.get("coord_resolution", None)
# Per-stage cascade metadata
coords = kwargs.get("trellis2_coords")
coord_counts = kwargs.get("trellis2_coord_counts")
mode = kwargs.get("trellis2_generation_mode", "structure_generation")
proj_feat_pack = kwargs.get("proj_feat_pack")
# Pre-computed per-stage back-projected features
proj_feats = kwargs.get("trellis2_proj_feats")
is_512_run = False
is_first_shape_pass = False
if mode == "shape_generation_512":
is_512_run = True
is_first_shape_pass = True
mode = "shape_generation"
if coords is not None:
if x.ndim == 4:
x = x.squeeze(-1).transpose(1, 2)
not_struct_mode = True
x = x.squeeze(-1).transpose(1, 2)
is_sparse_mode = True
else:
mode = "structure_generation"
not_struct_mode = False
is_sparse_mode = False
if x.size(-1) == 16 and x.size(-2) == 16:
mode = "structure_generation"
not_struct_mode = False
is_sparse_mode = False
if not not_struct_mode:
if not is_sparse_mode:
bsz = x.size(0)
x = x[:, :8]
x = x.view(bsz, 8, 16, 16, 16)
if is_1024 and not_struct_mode and not is_512_run:
if is_sparse_mode and not is_first_shape_pass:
context = embeds
sigmas = transformer_options.get("sigmas")[0].item()
if sigmas < 1.00001:
timestep *= 1000.0
if is_sparse_mode:
t_eval = timestep
c_eval = context
if context.size(0) > 1:
cond = context.chunk(2)[1]
else:
cond = context
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
if not_struct_mode:
orig_bsz = x.shape[0]
rule = txt_rule if mode == "texture_generation" else shape_rule
# CFG Bypass Slicing
if rule and orig_bsz > 1:
half = orig_bsz // 2
x_eval = x[half:]
t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep
c_eval = cond
else:
x_eval = x
t_eval = timestep
c_eval = context
B, N, C = x_eval.shape
B, N, C = x.shape
# Vectorized SparseTensor Construction
if mode in ["shape_generation", "texture_generation"]:
if coord_counts is not None:
logical_batch = coord_counts.shape[0]
# Duplicate coords if CFG is active
# Duplicate sparse coords when the sampler asks for >1 cond
# (CFG or otherwise). Each duplicate is offset along col 0
# so SparseTensor sees a fresh logical batch.
if B > logical_batch:
c_pos = coords.clone()
c_pos[:, 0] += logical_batch
batched_coords = torch.cat([coords, c_pos], dim=0)
counts_eval = torch.cat([coord_counts, coord_counts], dim=0)
reps = B // logical_batch
c_copies = []
for i in range(reps):
c = coords.clone()
c[:, 0] += i * logical_batch
c_copies.append(c)
batched_coords = torch.cat(c_copies, dim=0)
counts_eval = coord_counts.repeat(reps)
else:
batched_coords = coords
counts_eval = coord_counts
# Create boolean mask [B, N] to drop the padded zeros instantly
# Boolean mask [B, N] to drop the padded zeros instantly
mask = torch.arange(N, device=x.device).unsqueeze(0) < counts_eval.unsqueeze(1)
feats_flat = x_eval[mask]
feats_flat = x[mask]
else:
feats_flat = x_eval.reshape(-1, C)
coords_list =[]
feats_flat = x.reshape(-1, C)
coords_list = []
for i in range(B):
c = coords.clone()
c[:, 0] = i
@ -1185,35 +1164,42 @@ class Trellis2(nn.Module):
mask = None
else:
batched_coords = coords
feats_flat = x_eval
feats_flat = x
mask = None
x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32))
x_st = SparseTensor(
feats=feats_flat,
coords=batched_coords.to(torch.int32),
shape=torch.Size([B] + list(feats_flat.shape[1:])),
)
if mode == "shape_generation":
shape_attn = self.image_attn_mode_shape
if shape_attn != "global":
if coord_resolution is None:
raise ValueError("Pixal3D shape_generation requires coord_resolution in model_options; "
"EmptyTrellis2ShapeLatent should set it from the input voxel.")
mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", B, batched_coords.device)
xyz_world, batch_ids = _coords_to_proj_world(batched_coords, coord_resolution, mesh_scale)
sub_model = self.img2shape_512 if is_512_run else self.img2shape
stage_name = "shape_512" if is_512_run else "shape_1024"
c_eval = _build_proj_cond(c_eval, shape_attn, proj_feat_pack, xyz_world, batch_ids,
eval_batch=B,
sub_model = self.img2shape_512 if is_first_shape_pass else self.img2shape
stage_name = "shape_512" if is_first_shape_pass else "shape_1024"
# batched_coords carries CFG-doubled batch ids in col 0; per-voxel
# batch_ids drive uncond-slot masking inside _shape_proj_cond.
batch_ids = batched_coords[:, 0].long()
logical_batch = coord_counts.shape[0] if coord_counts is not None else B
has_hr = bool(proj_feat_pack and proj_feat_pack.get("stages", {})
.get(stage_name, {}).get("feature_map_hr") is not None)
c_eval = _shape_proj_cond(c_eval, shape_attn, proj_feats,
batch_ids=batch_ids,
eval_batch=B, logical_batch=logical_batch,
proj_in_channels=sub_model.proj_in_channels,
stage=stage_name,
cond_or_uncond=cond_or_uncond)
if is_512_run:
out = self.img2shape_512(x_st, t_eval, c_eval)
cond_or_uncond=cond_or_uncond,
has_hr=has_hr)
if is_first_shape_pass:
out = self.img2shape_512(x_st, t_eval, c_eval, transformer_options=transformer_options)
else:
out = self.img2shape(x_st, t_eval, c_eval)
out = self.img2shape(x_st, t_eval, c_eval, transformer_options=transformer_options)
elif mode == "texture_generation":
if self.shape2txt is None:
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
slat = model_options.get("shape_slat")
slat = kwargs.get("trellis2_shape_slat")
if slat is None:
raise ValueError("shape_slat can't be None")
@ -1227,49 +1213,40 @@ class Trellis2(nn.Module):
x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats.to(x_st.feats.device)], dim=-1))
tex_attn = self.image_attn_mode_texture
if tex_attn != "global":
if coord_resolution is None:
raise ValueError("Pixal3D texture_generation requires coord_resolution in model_options; "
"EmptyTrellis2LatentTexture should set it from the input voxel.")
mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", B, batched_coords.device)
xyz_world, batch_ids = _coords_to_proj_world(batched_coords, coord_resolution, mesh_scale)
c_eval = _build_proj_cond(c_eval, tex_attn, proj_feat_pack, xyz_world, batch_ids,
eval_batch=B,
batch_ids = batched_coords[:, 0].long()
logical_batch = coord_counts.shape[0] if coord_counts is not None else B
has_hr = bool(proj_feat_pack and proj_feat_pack.get("stages", {})
.get("tex_1024", {}).get("feature_map_hr") is not None)
c_eval = _shape_proj_cond(c_eval, tex_attn, proj_feats,
batch_ids=batch_ids,
eval_batch=B, logical_batch=logical_batch,
proj_in_channels=self.shape2txt.proj_in_channels,
stage="tex_1024",
cond_or_uncond=cond_or_uncond)
out = self.shape2txt(x_st, t_eval, c_eval)
cond_or_uncond=cond_or_uncond,
has_hr=has_hr)
out = self.shape2txt(x_st, t_eval, c_eval, transformer_options=transformer_options)
else: # structure
orig_bsz = x.shape[0]
struct_attn = self.image_attn_mode_structure
if shape_rule and orig_bsz > 1:
half = orig_bsz // 2
x_eval = x[half:]
t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep
struct_cond = cond
if struct_attn != "global":
mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", half, x.device)
grid_xyz = _dense_grid_proj_world(16, mesh_scale, half, device=x.device)
struct_cond = _build_proj_cond(cond, struct_attn, proj_feat_pack, grid_xyz,
eval_batch=half,
proj_in_channels=self.structure_model.proj_in_channels,
stage="ss",
cond_or_uncond=cond_or_uncond)
out = self.structure_model(x_eval, t_eval, struct_cond)
out = out.repeat(2, 1, 1, 1, 1)
else:
struct_cond = context
if struct_attn != "global":
mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", orig_bsz, x.device)
grid_xyz = _dense_grid_proj_world(16, mesh_scale, orig_bsz, device=x.device)
struct_cond = _build_proj_cond(context, struct_attn, proj_feat_pack, grid_xyz,
eval_batch=orig_bsz,
proj_in_channels=self.structure_model.proj_in_channels,
stage="ss",
cond_or_uncond=cond_or_uncond)
out = self.structure_model(x, timestep, struct_cond)
has_hr_ss = bool(proj_feat_pack and proj_feat_pack.get("stages", {})
.get("ss", {}).get("feature_map_hr") is not None)
logical_batch_ss = (
proj_feat_pack["mesh_scale"].shape[0]
if (proj_feat_pack is not None and torch.is_tensor(proj_feat_pack.get("mesh_scale")))
else x.shape[0]
)
struct_cond = context
if struct_attn != "global":
struct_cond = _shape_proj_cond(context, struct_attn, proj_feats,
batch_ids=None,
eval_batch=x.shape[0], logical_batch=logical_batch_ss,
proj_in_channels=self.structure_model.proj_in_channels,
stage="ss",
cond_or_uncond=cond_or_uncond,
has_hr=has_hr_ss)
out = self.structure_model(x, timestep, struct_cond, transformer_options=transformer_options)
if not_struct_mode:
if is_sparse_mode:
if mask is not None:
# Instantly scatter the valid tokens back into a padded rectangular tensor
padded_out = torch.zeros((B, N, out.feats.shape[-1]), device=x.device, dtype=out.feats.dtype)
@ -1277,9 +1254,6 @@ class Trellis2(nn.Module):
out_tensor = padded_out.transpose(1, 2).unsqueeze(-1)
else:
out_tensor = out.feats.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
if rule and orig_bsz > 1:
out_tensor = out_tensor.repeat(2, 1, 1, 1)
return out_tensor
else:
out = torch.nn.functional.pad(out, (0, 0, 0, 0, 0, 0, 0, 24))

View File

@ -1646,6 +1646,17 @@ class Trellis2(BaseModel):
out = super().extra_conds(**kwargs)
embeds = kwargs.get("embeds")
out["embeds"] = comfy.conds.CONDRegular(embeds)
# CONDConstant: shared across pos/neg
for k in ("trellis2_coords", "trellis2_coord_counts",
"trellis2_generation_mode", "trellis2_shape_slat",
"trellis2_proj_feats"):
v = kwargs.get(k)
if v is not None:
out[k] = comfy.conds.CONDConstant(v)
# Pixal3D's per-stage feature maps + camera params travel as a dict
proj_feat_pack = kwargs.get("proj_feat_pack")
if proj_feat_pack is not None:
out["proj_feat_pack"] = comfy.conds.CONDConstant(proj_feat_pack)
return out
class WAN21_FlowRVS(WAN21):

View File

@ -1322,10 +1322,10 @@ class Trellis2(supported_models_base.BASE):
unet_config = {
"image_model": "trellis2"
}
unet_extra_config = {}
sampling_settings = {
"shift": 3.0,
"multiplier": 1.0
}
memory_usage_factor = 3.5

View File

@ -1,7 +1,9 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types, io
from comfy.ldm.trellis2.vae import SparseTensor
from comfy.ldm.trellis2.model import _build_proj_transform_matrix, _project_points_to_image
from comfy.ldm.trellis2.model import (
_build_proj_transform_matrix, _project_points_to_image, compute_stage_proj_feats,
)
from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
import comfy.model_management
@ -14,48 +16,9 @@ import math
import torch
ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
Pixal3DProjPack = io.Custom("PIXAL3D_PROJ_PACK")
NAFModel = io.Custom("NAF_MODEL")
# Pixal3D trains in a 90°-X-rotated grid frame (F_p). We un-rotate decoder outputs for
# user-facing previews/meshes, then re-rotate before feeding coords back to the shape DiT.
def _pixal3d_unrotate_voxel_data(data: torch.Tensor) -> torch.Tensor:
if data.ndim == 4:
return data.flip(-1).permute(0, 1, 3, 2).contiguous()
if data.ndim == 5:
return data.flip(-1).permute(0, 1, 2, 4, 3).contiguous()
raise ValueError(f"unexpected voxel shape {tuple(data.shape)}")
def _pixal3d_rerotate_voxel_data(data: torch.Tensor) -> torch.Tensor:
if data.ndim == 4:
return data.permute(0, 1, 3, 2).flip(-1).contiguous()
if data.ndim == 5:
return data.permute(0, 1, 2, 4, 3).flip(-1).contiguous()
raise ValueError(f"unexpected voxel shape {tuple(data.shape)}")
def _pixal3d_unrotate_vertices(vertices: torch.Tensor) -> torch.Tensor:
if vertices.numel() == 0:
return vertices
x, y, z = vertices.unbind(-1)
return torch.stack([-x, y, -z], dim=-1).contiguous()
def _pixal3d_unrotate_sparse_coords(coords: torch.Tensor, resolution: int) -> torch.Tensor:
if coords.numel() == 0:
return coords
R1 = resolution - 1
if coords.shape[-1] == 4:
b, i, j, k = coords.unbind(-1)
return torch.stack([b, R1 - i, j, R1 - k], dim=-1).contiguous()
if coords.shape[-1] == 3:
i, j, k = coords.unbind(-1)
return torch.stack([R1 - i, j, R1 - k], dim=-1).contiguous()
raise ValueError(f"unexpected coord shape {tuple(coords.shape)}")
def prepare_trellis_vae_for_decode(vae, sample_shape):
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
if len(sample_shape) == 5:
@ -202,15 +165,21 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
@classmethod
def execute(cls, samples, vae):
resolution = int(vae.first_stage_model.resolution.item())
# Mesh grid_size must match the actual coord resolution the upstream
# stage was run at (1024 cascade -> 64, 1536 cascade -> 96). The VAE's
# built-in `.resolution` buffer defaults to 1024 and is otherwise stale;
# take coord_resolution from the latent dict if the stage node set it.
coord_resolution = samples.get("coord_resolution")
if coord_resolution is not None:
resolution = int(coord_resolution) * 16
else:
resolution = int(vae.first_stage_model.resolution.item())
sample_tensor = samples["samples"]
device = comfy.model_management.get_torch_device()
coords = samples["coords"]
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
trellis_vae = vae.first_stage_model
coord_counts = samples.get("coord_counts")
pixal3d_mode = samples.get("model_options", {}).get("proj_feat_pack") is not None
samples = samples["samples"]
if coord_counts is None:
@ -236,10 +205,6 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
coords_list = [stage_tensor.coords for stage_tensor in stage_tensors]
subs.append(SparseTensor.from_tensor_list(feats_list, coords_list))
if pixal3d_mode:
for m in mesh:
m.vertices = _pixal3d_unrotate_vertices(m.vertices)
face_list = [m.faces for m in mesh]
vert_list = [m.vertices for m in mesh]
if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list):
@ -276,7 +241,6 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
trellis_vae = vae.first_stage_model
coord_counts = samples.get("coord_counts")
pixal3d_mode = samples.get("model_options", {}).get("proj_feat_pack") is not None
samples = samples["samples"]
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
@ -288,7 +252,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
voxel = trellis_vae.decode_tex_slat(samples, shape_subdivides)
color_feats = voxel.feats[:, :3]
voxel_coords = voxel.coords#[:, 1:]
voxel_coords = voxel.coords
if voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3:
spatial = voxel_coords[:, -3:] if voxel_coords.shape[-1] == 4 else voxel_coords
@ -297,9 +261,6 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
else:
tex_resolution = 1024
if pixal3d_mode:
voxel_coords = _pixal3d_unrotate_sparse_coords(voxel_coords, resolution=tex_resolution)
voxel = Types.VOXEL(voxel_coords, color_feats, tex_resolution)
return IO.NodeOutput(voxel)
@ -312,7 +273,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
IO.Combo.Input("resolution", options=["32", "64"], default="32")
IO.Combo.Input("resolution", options=["32", "64"], default="32"),
],
outputs=[
IO.Voxel.Output("voxel"),
@ -338,129 +299,132 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
ratio = current_res // resolution
decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
voxel_data = decoded.squeeze(1).float()
if samples.get("model_options", {}).get("proj_feat_pack") is not None:
voxel_data = _pixal3d_unrotate_voxel_data(voxel_data)
out = Types.VOXEL(voxel_data)
return IO.NodeOutput(out)
return IO.NodeOutput(Types.VOXEL(voxel_data))
class Trellis2UpsampleCascade(IO.ComfyNode):
class Trellis2UpsampleStage(IO.ComfyNode):
"""Cascade-upsamples a 512-resolution shape latent into high-resolution
sparse coords and sets up the second shape-stage sampling pass at the
target resolution, attaching per-stage metadata to the conditioning for
the model to consume via extra_conds. target_resolution is reduced in
128-step decrements until the unique upsampled coord count fits under
max_tokens (floor 1024)."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Trellis2UpsampleCascade",
node_id="Trellis2UpsampleStage",
category="latent/3d",
display_name="Trellis2 Upsample Cascade",
description="Upsamples low-resolution Trellis2 shape latents into higher resolution coordinates while respecting the maximum token budget.",
display_name="Trellis2 Upsample Stage",
inputs=[
IO.Latent.Input("shape_latent"),
IO.Conditioning.Input("positive"),
IO.Conditioning.Input("negative"),
IO.Latent.Input("shape_latent", tooltip="The 512-resolution shape latent output from the first shape-stage KSampler."),
IO.Vae.Input("vae"),
IO.Combo.Input("target_resolution", options=["1024", "1536"], default="1024", tooltip="Controls output detail level for upsampling."),
IO.Int.Input("max_tokens", default=49152, min=1024, max=100000,
tooltip=(
"Maximum number of output elements (coordinates) allowed after upsampling. "
"Used to limit memory usage and control mesh density."
))
)),
],
outputs=[
IO.Voxel.Output(
"high_res_voxel",
tooltip=(
"High-resolution sparse coordinates produced after cascade upsampling. "
"Represents the refined 3D structure at target resolution."
)
)
IO.Conditioning.Output(display_name="positive"),
IO.Conditioning.Output(display_name="negative"),
IO.Latent.Output(),
]
)
@classmethod
def execute(cls, shape_latent, vae, target_resolution, max_tokens):
shape_latent_512 = shape_latent
device = comfy.model_management.get_torch_device()
prepare_trellis_vae_for_decode(vae, shape_latent_512["samples"].shape)
@staticmethod
def _quantize_unique(hr_coords: torch.Tensor, lr_resolution: int, hr_resolution: int) -> torch.Tensor:
# Fold the two scalar divisions into one and chain the float math in-place
# to avoid 3 full M*3 fp32 transients per call.
scale = (hr_resolution // 16) / lr_resolution
spatial = hr_coords[:, 1:].float()
spatial.add_(0.5).mul_(scale)
quant = torch.cat([hr_coords[:, :1], spatial.int()], dim=1)
return quant.unique(dim=0)
coord_counts = shape_latent_512.get("coord_counts")
@classmethod
def execute(cls, positive, negative, shape_latent, vae, target_resolution, max_tokens):
device = comfy.model_management.get_torch_device()
prepare_trellis_vae_for_decode(vae, shape_latent["samples"].shape)
coord_counts = shape_latent.get("coord_counts")
decoder = vae.first_stage_model.shape_dec
lr_resolution = 512
target_resolution = int(target_resolution)
if coord_counts is None:
feats, coords_512 = flatten_batched_sparse_latent(
shape_latent_512["samples"],
shape_latent_512["coords"],
coord_counts,
)
feats = feats.to(device)
coords_512 = coords_512.to(device)
slat = shape_norm(feats, coords_512)
slat.feats = slat.feats.to(next(decoder.parameters()).dtype)
hr_coords = decoder.upsample(slat, upsample_times=4)
hr_resolution = target_resolution
while True:
quant_coords = torch.cat([
hr_coords[:, :1],
((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
], dim=1)
final_coords = quant_coords.unique(dim=0)
num_tokens = final_coords.shape[0]
if num_tokens < max_tokens or hr_resolution <= 1024:
break
hr_resolution -= 128
return IO.NodeOutput(final_coords,)
items = split_batched_sparse_latent(
shape_latent_512["samples"],
shape_latent_512["coords"],
coord_counts,
)
decoder_dtype = next(decoder.parameters()).dtype
sample_hr_coords = []
for feats_i, coords_i in items:
feats_i = feats_i.to(device)
coords_i = coords_i.to(device).clone()
coords_i[:, 0] = 0
slat_i = shape_norm(feats_i, coords_i)
slat_i.feats = slat_i.feats.to(decoder_dtype)
sample_hr_coords.append(decoder.upsample(slat_i, upsample_times=4))
# Decode each sample's HR coords, then search for the largest hr_resolution
# that fits under max_tokens across all samples.
if coord_counts is None:
feats, coords_512 = flatten_batched_sparse_latent(
shape_latent["samples"], shape_latent["coords"], coord_counts,
)
slat = shape_norm(feats.to(device), coords_512.to(device))
slat.feats = slat.feats.to(decoder_dtype)
sample_hr_coords = [decoder.upsample(slat, upsample_times=4)]
else:
items = split_batched_sparse_latent(
shape_latent["samples"], shape_latent["coords"], coord_counts,
)
sample_hr_coords = []
for feats_i, coords_i in items:
coords_i = coords_i.to(device).clone()
coords_i[:, 0] = 0
slat_i = shape_norm(feats_i.to(device), coords_i)
slat_i.feats = slat_i.feats.to(decoder_dtype)
sample_hr_coords.append(decoder.upsample(slat_i, upsample_times=4))
# Resolution search — cache the final iteration's quantized unique tensors
# so we don't recompute .unique() per sample after picking hr_resolution.
hr_resolution = target_resolution
quant_unique_list = []
while True:
quant_unique_list = []
exceeds_limit = False
for hr_coords_i in sample_hr_coords:
quant_coords_i = torch.cat([
hr_coords_i[:, :1],
((hr_coords_i[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
], dim=1)
if quant_coords_i.unique(dim=0).shape[0] >= max_tokens:
qu = cls._quantize_unique(hr_coords_i, lr_resolution, hr_resolution)
quant_unique_list.append(qu)
if qu.shape[0] >= max_tokens:
exceeds_limit = True
break
if not exceeds_limit or hr_resolution <= 1024:
if not exceeds_limit:
break
if hr_resolution <= 1024:
for k in range(len(quant_unique_list), len(sample_hr_coords)):
quant_unique_list.append(
cls._quantize_unique(sample_hr_coords[k], lr_resolution, hr_resolution)
)
break
hr_resolution -= 128
final_coords_list = []
output_coord_counts = []
for sample_offset, hr_coords_i in enumerate(sample_hr_coords):
quant_coords_i = torch.cat([
hr_coords_i[:, :1],
((hr_coords_i[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
], dim=1)
final_coords_i = quant_coords_i.unique(dim=0)
final_coords_i = final_coords_i.clone()
final_coords_i[:, 0] = sample_offset
final_coords_list.append(final_coords_i)
output_coord_counts.append(int(final_coords_i.shape[0]))
# Rewrite batch column to match per-sample offset and concat.
per_sample_counts = []
for sample_offset, qu in enumerate(quant_unique_list):
qu[:, 0] = sample_offset
per_sample_counts.append(int(qu.shape[0]))
coords = torch.cat(quant_unique_list, dim=0)
counts = torch.tensor(per_sample_counts, dtype=torch.int64)
coord_resolution = hr_resolution // 16
coords = torch.cat(final_coords_list, dim=0)
output = Types.VOXEL(coords)
output.coord_counts = torch.tensor(output_coord_counts, dtype=torch.int64)
output.resolutions = torch.full((len(final_coords_list),), int(hr_resolution), dtype=torch.int64)
output.upsampled = True
batch_size, _, max_tokens_out = infer_batched_coord_layout(coords)
latent = torch.zeros(batch_size, 32, max_tokens_out, 1)
return IO.NodeOutput(output,)
extras = {
"trellis2_generation_mode": "shape_generation",
"trellis2_coords": coords,
"trellis2_coord_counts": counts,
}
proj_pack = _proj_pack_from_conditioning(positive)
if proj_pack is not None:
extras["trellis2_proj_feats"] = compute_stage_proj_feats(
proj_pack, "shape_1024", coords=coords, coord_resolution=coord_resolution,
)
positive_out = _conditioning_set_extras(positive, extras)
negative_out = _conditioning_set_extras(negative, extras)
out_latent = {"samples": latent, "coords": coords, "coord_counts": counts,
"coord_resolution": coord_resolution, "type": "trellis2"}
return IO.NodeOutput(positive_out, negative_out, out_latent)
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
@ -639,130 +603,152 @@ class Trellis2Conditioning(IO.ComfyNode):
negative = [[neg_cond_batched, {"embeds": neg_embeds_batched}]]
return IO.NodeOutput(positive, negative)
class EmptyTrellis2ShapeLatent(IO.ComfyNode):
def _proj_pack_from_conditioning(conditioning):
"""Return the proj_feat_pack dict embedded in a Pixal3D conditioning (or None
for vanilla Trellis2 / no conditioning connected). Pixal3DConditioning ships
the pack in cond[0][1]["proj_feat_pack"]; Trellis2Conditioning doesn't set it."""
if not conditioning:
return None
entry = conditioning[0]
if not isinstance(entry, (list, tuple)) or len(entry) < 2 or not isinstance(entry[1], dict):
return None
return entry[1].get("proj_feat_pack")
def _conditioning_set_extras(conditioning, extras: dict):
"""Return a copy of `conditioning` with `extras` merged into each entry's
dict same shallow-copy pattern ControlNetApplyAdvanced uses. The dicts
are copied so we don't mutate upstream conditioning."""
out = []
for entry in conditioning:
if isinstance(entry, (list, tuple)) and len(entry) >= 2 and isinstance(entry[1], dict):
new_dict = entry[1].copy()
new_dict.update(extras)
out.append([entry[0], new_dict])
else:
out.append(entry)
return out
class Trellis2ShapeStage(IO.ComfyNode):
"""Sets up the first shape-stage sampling pass: extracts sparse coords from
the dense structure voxel produced by VaeDecodeStructureTrellis2, builds an
empty sparse latent, and attaches per-stage metadata to the conditioning so
the model reads it via extra_conds at sample time. For the second shape pass
(post-upsample), use Trellis2UpsampleStage instead it combines the cascade
and the second-pass stage setup."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="EmptyTrellis2ShapeLatent",
node_id="Trellis2ShapeStage",
category="latent/3d",
inputs=[
IO.Conditioning.Input("positive"),
IO.Conditioning.Input("negative"),
IO.Voxel.Input(
"voxel",
tooltip=(
"Shape structure input. Accepts either a voxel structure "
"or upsampled voxel coordinates from a previous cascade stage."
)
),
Pixal3DProjPack.Input(
"proj_feat_pack",
optional=True,
tooltip="Pixal3D pixel-aligned projection pack from Pixal3DConditioning. Leave empty for vanilla Trellis2.",
tooltip="Dense structure voxel from VaeDecodeStructureTrellis2.",
),
],
outputs=[
IO.Conditioning.Output(display_name="positive"),
IO.Conditioning.Output(display_name="negative"),
IO.Latent.Output(),
]
)
@classmethod
def execute(cls, voxel, proj_feat_pack=None):
is_512_pass = False
coord_resolution = None
upsampled = hasattr(voxel, "upsampled")
if upsampled:
if hasattr(voxel, "resolutions") and voxel.resolutions is not None:
coord_resolution = int(voxel.resolutions[0].item()) // 16
voxel = voxel.data
if not upsampled:
voxel_data = voxel.data
if proj_feat_pack is not None:
voxel_data = _pixal3d_rerotate_voxel_data(voxel_data)
decoded = voxel_data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
is_512_pass = True
coord_resolution = int(decoded.shape[-1])
def execute(cls, positive, negative, voxel):
decoded = voxel.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
coord_resolution = int(decoded.shape[-1])
# Dispatch based on the upstream voxel resolution, mirroring upstream's
# pipeline_type → ss_res table:
# coord_res == 32 → first cascade shape pass OR pure-512 pipeline
# (img2shape_512 + shape_512 proj stage, 512 DINO).
# coord_res > 32 → pure-1024 non-cascade pipeline
# (img2shape + shape_1024 proj stage, 1024 DINO).
if coord_resolution <= 32:
mode = "shape_generation_512"
stage = "shape_512"
else:
coords = voxel.int()
mode = "shape_generation"
stage = "shape_1024"
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
in_channels = 32
# image like format
latent = torch.zeros(batch_size, in_channels, max_tokens, 1)
latent = torch.zeros(batch_size, 32, max_tokens, 1)
if is_512_pass:
generation_mode = "shape_generation_512"
else:
generation_mode = "shape_generation"
model_options = {"generation_mode": generation_mode, "coords": coords, "coord_counts": counts}
if coord_resolution is not None:
model_options["coord_resolution"] = coord_resolution
if proj_feat_pack is not None:
model_options["proj_feat_pack"] = proj_feat_pack
return IO.NodeOutput({"samples": latent, "coords": coords, "coord_counts": counts, "type": "trellis2",
"model_options": model_options})
extras = {
"trellis2_generation_mode": mode,
"trellis2_coords": coords,
"trellis2_coord_counts": counts,
}
proj_pack = _proj_pack_from_conditioning(positive)
if proj_pack is not None:
extras["trellis2_proj_feats"] = compute_stage_proj_feats(
proj_pack, stage, coords=coords, coord_resolution=coord_resolution,
)
positive_out = _conditioning_set_extras(positive, extras)
negative_out = _conditioning_set_extras(negative, extras)
out_latent = {"samples": latent, "coords": coords, "coord_counts": counts,
"coord_resolution": coord_resolution, "type": "trellis2"}
return IO.NodeOutput(positive_out, negative_out, out_latent)
class EmptyTrellis2LatentTexture(IO.ComfyNode):
class Trellis2TextureStage(IO.ComfyNode):
"""Sets up the texture-stage sampling pass. Reads coords / coord_counts /
coord_resolution and the shape_slat (the per-voxel shape latent) from the
incoming shape_latent dict set there by Trellis2ShapeStage or
Trellis2UpsampleStage. Builds an empty sparse latent at the same coord
layout and attaches per-stage metadata to the conditioning."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="EmptyTrellis2LatentTexture",
node_id="Trellis2TextureStage",
category="latent/3d",
inputs=[
IO.Voxel.Input(
"voxel",
tooltip=(
"Shape structure input. Accepts either a voxel structure "
"or upsampled voxel coordinates from a previous cascade stage."
)
),
IO.Conditioning.Input("positive"),
IO.Conditioning.Input("negative"),
IO.Latent.Input("shape_latent"),
Pixal3DProjPack.Input(
"proj_feat_pack",
optional=True,
tooltip="Pixal3D pixel-aligned projection pack from Pixal3DConditioning. Leave empty for vanilla Trellis2.",
),
],
outputs=[
IO.Conditioning.Output(display_name="positive"),
IO.Conditioning.Output(display_name="negative"),
IO.Latent.Output(),
]
)
@classmethod
def execute(cls, voxel, shape_latent, proj_feat_pack=None):
def execute(cls, positive, negative, shape_latent):
channels = 32
coord_resolution = None
upsampled = hasattr(voxel, "upsampled")
if upsampled:
if hasattr(voxel, "resolutions") and voxel.resolutions is not None:
coord_resolution = int(voxel.resolutions[0].item()) // 16
voxel = voxel.data
if not upsampled:
voxel_data = voxel.data
if proj_feat_pack is not None:
voxel_data = _pixal3d_rerotate_voxel_data(voxel_data)
decoded = voxel_data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
coord_resolution = int(decoded.shape[-1])
else:
coords = voxel.int()
coords = shape_latent["coords"]
coord_resolution = shape_latent.get("coord_resolution")
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
shape_latent = shape_latent["samples"]
if shape_latent.ndim == 4:
shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels)
shape_slat = shape_latent["samples"]
if shape_slat.ndim == 4:
shape_slat = shape_slat.squeeze(-1).transpose(1, 2).reshape(-1, channels)
latent = torch.zeros(batch_size, channels, max_tokens, 1)
model_options = {"generation_mode": "texture_generation", "coords": coords, "coord_counts": counts, "shape_slat": shape_latent}
extras = {
"trellis2_generation_mode": "texture_generation",
"trellis2_coords": coords,
"trellis2_coord_counts": counts,
"trellis2_shape_slat": shape_slat,
}
proj_pack = _proj_pack_from_conditioning(positive)
if proj_pack is not None and coord_resolution is not None:
extras["trellis2_proj_feats"] = compute_stage_proj_feats(
proj_pack, "tex_1024", coords=coords, coord_resolution=coord_resolution,
)
positive_out = _conditioning_set_extras(positive, extras)
negative_out = _conditioning_set_extras(negative, extras)
out_latent = {"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts}
if coord_resolution is not None:
model_options["coord_resolution"] = coord_resolution
if proj_feat_pack is not None:
model_options["proj_feat_pack"] = proj_feat_pack
return IO.NodeOutput({"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts,
"model_options": model_options})
out_latent["coord_resolution"] = coord_resolution
return IO.NodeOutput(positive_out, negative_out, out_latent)
class EmptyTrellis2LatentStructure(IO.ComfyNode):
@ -773,30 +759,17 @@ class EmptyTrellis2LatentStructure(IO.ComfyNode):
category="latent/3d",
inputs=[
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
Pixal3DProjPack.Input(
"proj_feat_pack",
optional=True,
tooltip="Pixal3D pixel-aligned projection pack. Leave empty for vanilla Trellis2.",
),
],
outputs=[
IO.Latent.Output(),
]
)
@classmethod
def execute(cls, batch_size, proj_feat_pack=None):
# Trellis2.forward slices x[:, :8] and pads out to 32; KSampler residual math
# needs the empty latent to match latent_format (32-channel).
def execute(cls, batch_size):
in_channels = 32
resolution = 16
latent = torch.zeros(batch_size, in_channels, resolution, resolution, resolution)
output = {
"samples": latent,
"type": "trellis2",
}
if proj_feat_pack is not None:
output["model_options"] = {"proj_feat_pack": proj_feat_pack}
return IO.NodeOutput(output)
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
def _dinov3_patches_to_2d(tokens, image_size, patch_size=16):
h_p = w_p = image_size // patch_size
@ -813,20 +786,11 @@ def _dinov3_patches_to_2d(tokens, image_size, patch_size=16):
return patches.transpose(1, 2).reshape(tokens.shape[0], -1, h_p, w_p).contiguous()
def _fov_from_moge_intrinsics(moge_intrinsics: torch.Tensor) -> float:
fx = moge_intrinsics[..., 0, 0].float()
fov = 2.0 * torch.atan(0.5 / fx.clamp(min=1e-4))
return float(fov.mean().item())
def _run_dinov3_with_patches(model, cropped_pil, image_size):
# Pixal3D's cross-attn was trained against CLS + registers only (~5 tokens), not the
# full patch grid. The patch grid goes to the proj branch via patches_2d.
def _run_dinov3_with_patches(model, composite, image_size):
model_internal = model.model
torch_device = comfy.model_management.get_torch_device()
resized = cropped_pil.resize((image_size, image_size), Image.Resampling.LANCZOS)
img_np = np.array(resized).astype(np.float32) / 255.0
img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device)
img_t = comfy.utils.common_upscale(composite, image_size, image_size, "lanczos", "disabled")
img_t = img_t.to(torch_device)
img_t = (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device)
model_internal.image_size = image_size
tokens = model_internal(img_t, skip_norm_elementwise=True)[0]
@ -838,48 +802,59 @@ def _run_dinov3_with_patches(model, cropped_pil, image_size):
def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
pil_img = Image.fromarray(img_np)
pil_mask = Image.fromarray(mask_np)
max_size = max(pil_img.size)
scale = min(1.0, max_image_size / max_size)
if scale < 1.0:
new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale)
pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS)
pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST)
scene_size = (pil_img.width, pil_img.height)
rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8)
rgba_np[:, :, :3] = np.array(pil_img)
rgba_np[:, :, 3] = np.array(pil_mask)
alpha = rgba_np[:, :, 3]
bbox_coords = np.argwhere(alpha > 0.8 * 255)
if len(bbox_coords) > 0:
y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1])
y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1])
img = item_image.permute(2, 0, 1).unsqueeze(0).cpu().float()
mask = item_mask.unsqueeze(0).unsqueeze(0).cpu().float()
# Upstream went float→PIL uint8 implicitly; match that to keep composite bit-exact.
img = (img.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0
mask = (mask.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0
H, W = img.shape[-2:]
if max(H, W) > max_image_size:
scale = max_image_size / max(H, W)
new_w, new_h = int(W * scale), int(H * scale)
img = comfy.utils.common_upscale(img, new_w, new_h, "lanczos", "disabled")
mask = comfy.utils.common_upscale(mask, new_w, new_h, "nearest-exact", "disabled")
H, W = new_h, new_w
scene_size = (W, H)
alpha_u8 = (mask[0, 0].clamp(0, 1) * 255.0).to(torch.uint8)
fg_pixels = (alpha_u8 > 204).nonzero()
if fg_pixels.numel() > 0:
y_min, x_min = fg_pixels.min(dim=0).values.tolist()
y_max, x_max = fg_pixels.max(dim=0).values.tolist()
center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
# Upstream pads the bbox by 10% — encoders were trained with that breathing room.
size = max(y_max - y_min, x_max - x_min)
size = int(size * 1.1)
size = int(max(y_max - y_min, x_max - x_min) * 1.1)
half = size // 2
crop_x1 = int(center_x - half)
crop_y1 = int(center_y - half)
crop_x2 = crop_x1 + 2 * half
crop_y2 = crop_y1 + 2 * half
crop_bbox = (crop_x1, crop_y1, crop_x2, crop_y2)
rgba_pil = Image.fromarray(rgba_np)
cropped_rgba = rgba_pil.crop(crop_bbox)
cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
else:
logging.warning("Mask for the image is empty. Pixal3D requires a clean foreground mask.")
cropped_np = rgba_np.astype(np.float32) / 255.0
crop_bbox = (0, 0, scene_size[0], scene_size[1])
fg = cropped_np[:, :, :3]
alpha_float = cropped_np[:, :, 3:4]
composite_np = fg * alpha_float
composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
return Image.fromarray(composite_uint8), crop_bbox, scene_size
crop_x1, crop_y1, crop_x2, crop_y2 = 0, 0, W, H
crop_bbox = (crop_x1, crop_y1, crop_x2, crop_y2)
# Zero-pad out-of-bounds slice (PIL.crop semantics).
pad_l = max(0, -crop_x1)
pad_t = max(0, -crop_y1)
pad_r = max(0, crop_x2 - W)
pad_b = max(0, crop_y2 - H)
if pad_l or pad_t or pad_r or pad_b:
img = torch.nn.functional.pad(img, (pad_l, pad_r, pad_t, pad_b), value=0.0)
mask = torch.nn.functional.pad(mask, (pad_l, pad_r, pad_t, pad_b), value=0.0)
crop_x1 += pad_l; crop_x2 += pad_l
crop_y1 += pad_t; crop_y2 += pad_t
cropped_img = img [..., crop_y1:crop_y2, crop_x1:crop_x2]
cropped_mask = mask[..., crop_y1:crop_y2, crop_x1:crop_x2]
composite = (cropped_img * cropped_mask).clamp(0, 1)
composite = (composite * 255.0).round().clamp(0, 255).to(torch.uint8).float() / 255.0
return composite, crop_bbox, scene_size
def _fov_from_moge_intrinsics(moge_intrinsics: torch.Tensor) -> float:
fx = moge_intrinsics[..., 0, 0].float()
fov = 2.0 * torch.atan(0.5 / fx.clamp(min=1e-4))
return float(fov.mean().item())
class Pixal3DConditioning(IO.ComfyNode):
@ -901,10 +876,6 @@ class Pixal3DConditioning(IO.ComfyNode):
"mesh_scale", default=1.0, min=0.1, max=4.0, step=0.01,
tooltip="Mesh scale; 1.0 means unit cube.",
),
IO.Float.Input(
"distance_override", default=0.0, min=0.0, max=10.0, step=0.001,
tooltip="Override camera distance directly. 0 = auto-derive from FOV.",
),
io.Custom("MOGE_GEOMETRY").Input(
"moge_geometry",
optional=True,
@ -920,13 +891,11 @@ class Pixal3DConditioning(IO.ComfyNode):
outputs=[
IO.Conditioning.Output(display_name="positive"),
IO.Conditioning.Output(display_name="negative"),
Pixal3DProjPack.Output(display_name="proj_feat_pack"),
],
)
@classmethod
def execute(cls, clip_vision_model, image, mask, camera_angle_x, mesh_scale,
distance_override=0.0,
moge_geometry=None, naf_model=None) -> IO.NodeOutput:
if image.ndim == 3:
image = image.unsqueeze(0)
@ -945,21 +914,21 @@ class Pixal3DConditioning(IO.ComfyNode):
cond_512_list, cond_1024_list = [], []
patches_512_list, patches_1024_list = [], []
cropped_pil_list = []
composite_list = []
crop_bbox_list, scene_size_list = [], []
torch_device = comfy.model_management.get_torch_device()
for b in range(batch_size):
item_image = image[b]
item_mask = mask[b] if mask.size(0) > 1 else mask[0]
cropped_pil, crop_bbox, scene_size = _crop_image_with_mask(
composite, crop_bbox, scene_size = _crop_image_with_mask(
item_image, item_mask, max_image_size=1024)
crop_bbox_list.append(crop_bbox)
scene_size_list.append(scene_size)
cropped_pil_list.append(cropped_pil)
composite_list.append(composite)
cond_512 = _run_dinov3_with_patches(clip_vision_model, cropped_pil, 512)
cond_1024 = _run_dinov3_with_patches(clip_vision_model, cropped_pil, 1024)
cond_512 = _run_dinov3_with_patches(clip_vision_model, composite, 512)
cond_1024 = _run_dinov3_with_patches(clip_vision_model, composite, 1024)
cond_512_list.append(cond_512["tokens"].to(device))
cond_1024_list.append(cond_1024["tokens"].to(device))
patches_512_list.append(cond_512["patches_2d"].to(device))
@ -971,42 +940,28 @@ class Pixal3DConditioning(IO.ComfyNode):
fm_512_dino = torch.cat(patches_512_list, dim=0)
fm_1024_dino = torch.cat(patches_1024_list, dim=0)
# Upstream samples the LR DINO grid AND the NAF HR grid separately at projected
# 3D points, then cats sampled features along channels. Back-projection (in model.py)
# mirrors that — here we just stash LR + optional HR per stage.
# The LR DINO grid AND the NAF HR grid are sampled separately
# NAF targets per stage: shape_512=512, shape_1024=512, tex_1024=1024.
def _naf_hr(lr_feat, image_pil_list, image_size, naf_target):
def _naf_hr(lr_feat, composites, image_size, naf_target):
if naf_model is None or naf_target is None:
return None
# Run NAF in the input feature dtype (typically fp16 since DINO/ClipVision
# loads that way). The previous .float() cast doubled NAF memory by forcing
# full fp32 — at tex_1024/target=1024 that's ~10 GB on its own. Model
# weights need to match input dtype since PyTorch conv ops error out on
# mixed fp16-input/fp32-weight.
target_dtype = lr_feat.dtype
if next(naf_model.parameters()).dtype != target_dtype:
naf_model.to(dtype=target_dtype)
imgs = torch.stack([
torch.from_numpy(
np.array(p.resize((image_size, image_size), Image.Resampling.LANCZOS))
.astype(np.float32) / 255.0
).permute(2, 0, 1)
for p in image_pil_list
imgs = torch.cat([
comfy.utils.common_upscale(c, image_size, image_size, "lanczos", "disabled")
for c in composites
], dim=0).to(torch_device).to(target_dtype)
hr = naf_model(imgs, lr_feat.to(torch_device).to(target_dtype), naf_target)
return hr.to(device)
hr_shape_512 = _naf_hr(fm_512_dino, cropped_pil_list, 512, (512, 512))
hr_shape_1024 = _naf_hr(fm_1024_dino, cropped_pil_list, 1024, (512, 512))
hr_tex_1024 = _naf_hr(fm_1024_dino, cropped_pil_list, 1024, (1024, 1024))
hr_shape_512 = _naf_hr(fm_512_dino, composite_list, 512, (512, 512))
hr_shape_1024 = _naf_hr(fm_1024_dino, composite_list, 1024, (512, 512))
hr_tex_1024 = _naf_hr(fm_1024_dino, composite_list, 1024, (1024, 1024))
# distance_from_fov: grid_point (-1, 0, 0) projects to pixel (0, image_resolution-1).
camera_angle_x = float(camera_angle_x)
if distance_override > 0:
distance = float(distance_override)
else:
distance = 0.5 / math.tan(camera_angle_x / 2.0) / float(mesh_scale)
distance = 0.5 / math.tan(camera_angle_x / 2.0) / float(mesh_scale)
cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=device, dtype=torch.float32)
dist_t = torch.tensor([distance] * batch_size, device=device, dtype=torch.float32)
scale_t = torch.tensor([float(mesh_scale)] * batch_size, device=device, dtype=torch.float32)
@ -1028,13 +983,33 @@ class Pixal3DConditioning(IO.ComfyNode):
"scene_sizes": scene_size_list,
}
# global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024
# (Trellis2.forward swaps context↔embeds for non-structure HR stages).
# global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024.
# proj_feat_pack rides in the conditioning dict (same place embeds, ControlNet
# hints etc. live); the sampler auto-promotes it to a model.forward kwarg via
# Trellis2.extra_conds. The same pack object is shared between pos/neg —
# CONDConstant.can_concat sees them equal and concats to a single dict, then
# Trellis2.forward zeros proj for the uncond slots via cond_or_uncond.
# Pre-compute the SS-stage proj features (dense 16³ grid) once here — the
# shape/texture stages do their own computes in their respective stage nodes.
# proj_pack lives on intermediate (CPU); force the compute onto cuda so
# the bilinear-sampling step doesn't run on CPU.
ss_proj_feats = compute_stage_proj_feats(
proj_pack, "ss", dense_grid_resolution=16, batch_size=batch_size,
device=torch_device,
)
neg_global = torch.zeros_like(global_512)
neg_embeds = torch.zeros_like(global_1024)
positive = [[global_512, {"embeds": global_1024}]]
negative = [[neg_global, {"embeds": neg_embeds}]]
return IO.NodeOutput(positive, negative, proj_pack)
base_extras = {
"embeds": global_1024, "proj_feat_pack": proj_pack,
"trellis2_proj_feats": ss_proj_feats,
}
neg_extras = {
"embeds": neg_embeds, "proj_feat_pack": proj_pack,
"trellis2_proj_feats": ss_proj_feats,
}
positive = [[global_512, base_extras]]
negative = [[neg_global, neg_extras]]
return IO.NodeOutput(positive, negative)
def _project_vertices_to_image_uv(vertices_world, transform_matrix, camera_angle_x, image_resolution):
@ -1069,7 +1044,7 @@ class Pixal3DAlignObject(IO.ComfyNode):
category="latent/3d",
inputs=[
IO.Mesh.Input("mesh"),
Pixal3DProjPack.Input("proj_feat_pack", tooltip="The proj pack produced by Pixal3DConditioning for this object."),
IO.Conditioning.Input("positive", tooltip="The positive conditioning from Pixal3DConditioning for this object — Pixal3DAlignObject reads transform_matrix / camera_angle_x / mesh_scale / crop_bboxes out of its proj_feat_pack."),
io.Custom("MOGE_GEOMETRY").Input("moge_geometry", tooltip="MoGe geometry computed on the original scene image."),
IO.Mask.Input(
"object_mask",
@ -1089,7 +1064,10 @@ class Pixal3DAlignObject(IO.ComfyNode):
)
@classmethod
def execute(cls, mesh, proj_feat_pack, moge_geometry, object_mask=None, batch_index=0) -> IO.NodeOutput:
def execute(cls, mesh, positive, moge_geometry, object_mask=None, batch_index=0) -> IO.NodeOutput:
proj_feat_pack = _proj_pack_from_conditioning(positive)
if proj_feat_pack is None:
raise ValueError("Pixal3DAlignObject: positive conditioning has no proj_feat_pack — connect a Pixal3DConditioning output.")
vertices = mesh.vertices
faces = mesh.faces
if vertices.ndim == 3:
@ -1117,10 +1095,11 @@ class Pixal3DAlignObject(IO.ComfyNode):
"image. Run MoGe on the same resized scene image Pixal3DConditioning used."
)
# Compose VaeDecodeShapeTrellis's R_y(180°) inverse with R_proj to map user mesh
# space to ProjGrid world: (X, Y, Z) -> (-X, Z, Y).
# Vertices come out of VaeDecodeShapeTrellis in the Pixal3D model frame
# (no un-rotation). Apply _PROJ_GRID_ROTATION = R_x(-90°) to map model
# frame → ProjGrid world: (X, Y, Z) -> (X, -Z, Y).
v = vertices_one.float()
verts_world = torch.stack([-v[..., 0], v[..., 2], v[..., 1]], dim=-1)
verts_world = torch.stack([v[..., 0], -v[..., 2], v[..., 1]], dim=-1)
verts_world = verts_world / float(mesh_scale.item())
uv_crop, _depth, valid = _project_vertices_to_image_uv(
verts_world, T[0], cam_angle[0], image_resolution)
@ -1200,6 +1179,74 @@ class LoadNAFModel(IO.ComfyNode):
return IO.NodeOutput(model)
class CFGGuidanceInterval(IO.ComfyNode):
"""Generic model patch: apply CFG only during [start_percent, end_percent] of
the sampling schedule. Outside that window, skip the uncond computation and
collapse to effective cfg=1 same idea as upstream Trellis2 / Pixal3D's
guidance_interval_mixin, but lives at the sampler level (via
sampler_calc_cond_batch_function) so it works for any model.
Percents use ComfyUI's standard convention: 0.0 = start of sampling
(max-noise step), 1.0 = end of sampling (clean step). Conversion to sigma
is done via model_sampling.percent_to_sigma so the window is portable
across schedules (flow / EDM / discrete) and shift settings.
Defaults are full-range (no bypass). For Trellis2's upstream behavior,
wire (start_percent=0.0, end_percent=0.667) on the SS / shape KSamplers;
texture defaults to cfg=1 so the node is moot there."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="CFGGuidanceInterval",
category="model_patches/sampling",
inputs=[
IO.Model.Input("model"),
IO.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001,
tooltip="Fraction of sampling at which CFG turns ON (0 = beginning)."),
IO.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001,
tooltip="Fraction of sampling at which CFG turns OFF (1 = end)."),
],
outputs=[IO.Model.Output()],
)
@classmethod
def execute(cls, model, start_percent, end_percent):
import comfy.samplers
model_sampling = model.get_model_object("model_sampling")
# percent_to_sigma is monotonically decreasing: percent=0 -> sigma_max,
# percent=1 -> sigma_min. So start_percent < end_percent in user space
# means sigma_start > sigma_end. "Inside the window" is sigma in
# [sigma_end, sigma_start].
sigma_start = float(model_sampling.percent_to_sigma(start_percent))
sigma_end = float(model_sampling.percent_to_sigma(end_percent))
def calc_cond_batch_with_interval(args):
sigma_val = args["sigma"][0].item()
conds = args["conds"]
input_x = args["input"]
timestep = args["sigma"]
model_ref = args["model"]
model_opts = args["model_options"]
# conds is typically [cond, uncond]; uncond may be None when ComfyUI's
# global cfg=1 optimization has already pruned it.
cond = conds[0]
uncond = conds[1] if len(conds) > 1 else None
inside = sigma_end <= sigma_val <= sigma_start
if uncond is None or inside:
return comfy.samplers.calc_cond_batch(model_ref, conds, input_x, timestep, model_opts)
# Outside the window: compute cond only, mirror it into the uncond slot
# so the downstream cfg_function collapses to `cond` (effective cfg=1).
out = comfy.samplers.calc_cond_batch(model_ref, [cond], input_x, timestep, model_opts)
return [out[0], out[0]]
m = model.clone()
m.model_options["sampler_calc_cond_batch_function"] = calc_cond_batch_with_interval
return IO.NodeOutput(m)
class Trellis2Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -1208,13 +1255,14 @@ class Trellis2Extension(ComfyExtension):
Pixal3DConditioning,
Pixal3DAlignObject,
LoadNAFModel,
EmptyTrellis2ShapeLatent,
Trellis2ShapeStage,
EmptyTrellis2LatentStructure,
EmptyTrellis2LatentTexture,
Trellis2TextureStage,
VaeDecodeTextureTrellis,
VaeDecodeShapeTrellis,
VaeDecodeStructureTrellis2,
Trellis2UpsampleCascade,
Trellis2UpsampleStage,
CFGGuidanceInterval,
]

View File

@ -1537,10 +1537,6 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if "noise_mask" in latent:
noise_mask = latent["noise_mask"]
if "model_options" in latent:
inner = model.model.diffusion_model
inner.meta = latent["model_options"]
callback = latent_preview.prepare_callback(model, steps)
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,