diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index a68cb8439..d4f038a63 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -32,6 +32,11 @@ except ImportError as e: raise e exit(-1) +try: + from sageattention import sageattn_varlen # sageattention >= 2 +except ImportError: + sageattn_varlen = None + SAGE_ATTENTION3_IS_AVAILABLE = False try: from sageattn3 import sageattn3_blackwell @@ -48,6 +53,24 @@ except ImportError: logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn") exit(-1) +try: + from torch.nn.attention.varlen import varlen_attn as _torch_varlen_attn +except ImportError: + _torch_varlen_attn = None + + +def _is_varlen(kwargs): + """Varlen mode is opted into by passing cu_seqlens_q in kwargs.""" + return kwargs.get("cu_seqlens_q") is not None + + +def _varlen_args(kwargs): + cu_seqlens_q = kwargs["cu_seqlens_q"] + cu_seqlens_kv = kwargs.get("cu_seqlens_kv", cu_seqlens_q) + max_seqlen_q = int(kwargs["max_seqlen_q"]) + max_seqlen_kv = int(kwargs.get("max_seqlen_kv", max_seqlen_q)) + return cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv + REGISTERED_ATTENTION_FUNCTIONS = {} def register_attention_function(name: str, func: Callable): # avoid replacing existing functions @@ -144,6 +167,8 @@ def wrap_attn(func): @wrap_attn def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + if _is_varlen(kwargs): + return attention_pytorch(q, k, v, heads, mask=mask, **kwargs) attn_precision = get_attn_precision(attn_precision, q.dtype) if skip_reshape: @@ -218,6 +243,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape @wrap_attn def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + if 
_is_varlen(kwargs): + return attention_pytorch(query, key, value, heads, mask=mask, **kwargs) attn_precision = get_attn_precision(attn_precision, query.dtype) if skip_reshape: @@ -293,6 +320,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, @wrap_attn def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + if _is_varlen(kwargs): + return attention_pytorch(q, k, v, heads, mask=mask, **kwargs) attn_precision = get_attn_precision(attn_precision, q.dtype) if skip_reshape: @@ -424,6 +453,17 @@ except: @wrap_attn def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + if _is_varlen(kwargs): + # q, k, v expected as packed [T_total, H, C]. xformers wants per-item + # seqlen lists and a leading batch dim of 1. + cu_seqlens_q, cu_seqlens_kv, _max_q, _max_kv = _varlen_args(kwargs) + q_seqlen = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).tolist() + kv_seqlen = (cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]).tolist() + attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + return xformers.ops.memory_efficient_attention( + q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attn_bias=attn_bias, + )[0] + b = q.shape[0] dim_head = q.shape[-1] # check to make sure xformers isn't broken @@ -493,6 +533,22 @@ else: @wrap_attn def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + if _is_varlen(kwargs): + # q, k, v expected as packed [T_total, H, C]. cu_seqlens_q / cu_seqlens_kv + # describe per-item offsets. Native varlen kernel if available, else + # nested-tensor SDPA fallback. mask/attn_precision/etc are ignored here. 
+ cu_seqlens_q, cu_seqlens_kv, max_q, max_kv = _varlen_args(kwargs) + if _torch_varlen_attn is not None: + return _torch_varlen_attn(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q, max_kv) + q_nj = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long()) + k_nj = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_kv.long()) + v_nj = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_kv.long()) + out = comfy.ops.scaled_dot_product_attention( + q_nj.transpose(1, 2), k_nj.transpose(1, 2), v_nj.transpose(1, 2), + attn_mask=None, dropout_p=0.0, is_causal=False, + ) + return out.transpose(1, 2).values() + if skip_reshape: b, _, _, dim_head = q.shape else: @@ -541,6 +597,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha @wrap_attn def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + if _is_varlen(kwargs): + # q, k, v expected as packed [T_total, H, C]. + if sageattn_varlen is None: + # sageattention v1 has no varlen kernel; fall back to attention_pytorch's varlen path. + return attention_pytorch(q, k, v, heads, mask=mask, **kwargs) + cu_seqlens_q, cu_seqlens_kv, max_q, max_kv = _varlen_args(kwargs) + return sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q, max_kv) if kwargs.get("low_precision_attention", True) is False: return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs) @@ -698,6 +761,12 @@ except AttributeError as error: @wrap_attn def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + if _is_varlen(kwargs): + # q, k, v expected as packed [T_total, H, C]. 
+ from flash_attn import flash_attn_varlen_func + cu_seqlens_q, cu_seqlens_kv, max_q, max_kv = _varlen_args(kwargs) + return flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q, max_kv) + if skip_reshape: b, _, _, dim_head = q.shape else: diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index a09a8fca8..07a629c01 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -1,125 +1,107 @@ -import torch import math +from typing import Tuple, Union + +import torch + from comfy.ldm.modules.attention import optimized_attention -from typing import Tuple, Union, List from comfy.ldm.trellis2.vae import VarLenTensor -import comfy.ops - -try: - from torch.nn.attention.varlen import varlen_attn as _varlen_attn -except ImportError: - _varlen_attn = None -def var_attn_arg(kwargs): - cu_seqlens_q = kwargs.get("cu_seqlens_q", None) - max_seqlen_q = kwargs.get("max_seqlen_q", None) - cu_seqlens_k = kwargs.get("cu_seqlens_kv", cu_seqlens_q) - max_seqlen_k = kwargs.get("max_kv_seqlen", max_seqlen_q) - assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True" - return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k +def _build_cu_seqlens(seqlen, device): + """Cumulative-offset tensor from a list/tensor of per-item sequence lengths.""" + if isinstance(seqlen, torch.Tensor): + return torch.cat( + [torch.zeros(1, dtype=torch.int32, device=seqlen.device), torch.cumsum(seqlen, dim=0).int()], + dim=0, + ).to(device) + return torch.cat( + [torch.tensor([0]), torch.cumsum(torch.tensor(seqlen), dim=0)], + ).int().to(device) -def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): - cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs) - if not skip_reshape: - total_tokens, embed_dim = q.shape - head_dim = embed_dim // heads - q = q.view(total_tokens, heads, head_dim) - k = 
k.view(k.shape[0], heads, head_dim) - v = v.view(v.shape[0], heads, head_dim) - if _varlen_attn is not None: - return _varlen_attn( - q, k, v, - cu_seqlens_q, cu_seqlens_k, - int(max_seqlen_q), int(max_seqlen_k), - ) +def _layout_to_feats(t): + """Returns (sparse_template_or_None, seqlen, feats[T, H, C]).""" + if isinstance(t, VarLenTensor): + seqlen = [t.layout[i].stop - t.layout[i].start for i in range(t.shape[0])] + return t, seqlen, t.feats + N, L = t.shape[:2] + return None, [L] * N, t.reshape(-1, *t.shape[-2:]) - # Fallback: nested-tensor SDPA (PyTorch < the version that introduced varlen_attn) - q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long()) - k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long()) - v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long()) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) - out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) - return out.transpose(1, 2).values() - -def scaled_dot_product_attention(*args, **kwargs): - num_all_args = len(args) + len(kwargs) - - q = None - if num_all_args == 1: - qkv = args[0] if len(args) > 0 else kwargs.get('qkv') - elif num_all_args == 2: - q = args[0] if len(args) > 0 else kwargs.get('q') - kv = args[1] if len(args) > 1 else kwargs.get('kv') - elif num_all_args == 3: - q = args[0] if len(args) > 0 else kwargs.get('q') - k = args[1] if len(args) > 1 else kwargs.get('k') - v = args[2] if len(args) > 2 else kwargs.get('v') - - if q is not None: - heads = q.shape[2] - else: - heads = qkv.shape[3] - - if num_all_args == 1: - q, k, v = qkv.unbind(dim=2) - elif num_all_args == 2: - k, v = kv.unbind(dim=2) +def dense_attention(q, k, v, **kwargs): + """q, k, v: [B, L, H, C]. 
Permutes for comfy's [B, H, L, C] convention.""" + heads = q.shape[2] q = q.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3) v = v.permute(0, 2, 1, 3) - out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True, **kwargs) + return out.permute(0, 2, 1, 3) - out = out.permute(0, 2, 1, 3) - return out +def sparse_attention(q, k, v, **kwargs): + """ + Varlen attention for SparseTensor inputs. Each of q, k, v may be a VarLenTensor + (sparse) or dense [B, L, H, C]. Output type matches q. Backend dispatch lives + in comfy.ldm.modules.attention.optimized_attention; we just build cu_seqlens + from the layouts. + """ + s, q_seqlen, q_feats = _layout_to_feats(q) + _, kv_seqlen, k_feats = _layout_to_feats(k) + _, _, v_feats = _layout_to_feats(v) + heads = q_feats.shape[1] -def sparse_windowed_scaled_dot_product_self_attention( - qkv, - window_size: int, - shift_window: Tuple[int, int, int] = (0, 0, 0) -): + device = q_feats.device + cu_seqlens_q = _build_cu_seqlens(q_seqlen, device) + cu_seqlens_kv = _build_cu_seqlens(kv_seqlen, device) - serialization_spatial_cache_name = f'windowed_attention_{window_size}_{shift_window}' - serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) - if serialization_spatial_cache is None: - fwd_indices, bwd_indices, seq_lens, attn_func_args = calc_window_partition(qkv, window_size, shift_window) - qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, attn_func_args)) - else: - fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache + out = optimized_attention( + q_feats, k_feats, v_feats, heads, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max(q_seqlen), max_seqlen_kv=max(kv_seqlen), + skip_reshape=True, skip_output_reshape=True, + **kwargs, + ) + + if s is not None: + return s.replace(out) + N, L = q.shape[:2] + return out.reshape(N, L, heads, -1) + + +def sparse_windowed_self_attention(qkv, 
window_size: int, shift_window: Tuple[int, int, int] = (0, 0, 0)): + """Windowed sparse self-attention. Partitions voxels into windows via spatial + sort, then runs varlen attention with one sequence per non-empty window.""" + cache_name = f'windowed_attention_{window_size}_{shift_window}' + cache = qkv.get_spatial_cache(cache_name) + if cache is None: + cache = calc_window_partition(qkv, window_size, shift_window) + qkv.register_spatial_cache(cache_name, cache) + fwd_indices, bwd_indices, seq_lens = cache qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] - heads = qkv_feats.shape[2] + q, k, v = qkv_feats.unbind(dim=1) # each [M, H, C] + heads = q.shape[1] + device = q.device - if optimized_attention.__name__ == 'attention_xformers': - q, k, v = qkv_feats.unbind(dim=1) - q = q.unsqueeze(0) # [1, M, H, C] - k = k.unsqueeze(0) # [1, M, H, C] - v = v.unsqueeze(0) # [1, M, H, C] - #out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C] - out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) - elif optimized_attention.__name__ == 'attention_flash': - if 'flash_attn' not in globals(): - import flash_attn - out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C] - else: - out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) - - out = out[bwd_indices] # [T, H, C] + cu_seqlens = _build_cu_seqlens(seq_lens, device) + max_seqlen = int(seq_lens.max()) + out = optimized_attention( + q, k, v, heads, + cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens, + max_seqlen_q=max_seqlen, max_seqlen_kv=max_seqlen, + skip_reshape=True, skip_output_reshape=True, + ) + out = out[bwd_indices] return qkv.replace(out) + def calc_window_partition( tensor, window_size: Union[int, Tuple[int, ...]], shift_window: Union[int, Tuple[int, ...]] = 0, -) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: - +): + """Returns (fwd_indices, bwd_indices, seq_lens) for window 
partitioning.""" DIM = tensor.coords.shape[1] - 1 shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size @@ -136,139 +118,6 @@ def calc_window_partition( bwd_indices = torch.empty_like(fwd_indices) bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) seq_lens = torch.bincount(shifted_indices) - mask = seq_lens != 0 - seq_lens = seq_lens[mask] + seq_lens = seq_lens[seq_lens != 0] - if optimized_attention.__name__ == 'attention_xformers': - if 'xops' not in globals(): - import xformers.ops as xops - attn_func_args = { - 'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) - } - elif optimized_attention.__name__ == 'attention_flash': - attn_func_args = { - 'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(), - 'max_seqlen': torch.max(seq_lens) - } - - return fwd_indices, bwd_indices, seq_lens, attn_func_args - - -def sparse_scaled_dot_product_attention(*args, **kwargs): - q=None - arg_names_dict = { - 1: ['qkv'], - 2: ['q', 'kv'], - 3: ['q', 'k', 'v'] - } - num_all_args = len(args) + len(kwargs) - for key in arg_names_dict[num_all_args][len(args):]: - assert key in kwargs, f"Missing argument {key}" - - if num_all_args == 1: - qkv = args[0] if len(args) > 0 else kwargs['qkv'] - device = qkv.device - - s = qkv - q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])] - kv_seqlen = q_seqlen - qkv = qkv.feats # [T, 3, H, C] - - elif num_all_args == 2: - q = args[0] if len(args) > 0 else kwargs['q'] - kv = args[1] if len(args) > 1 else kwargs['kv'] - device = q.device - - if isinstance(q, VarLenTensor): - s = q - q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] - q = q.feats # [T_Q, H, C] - else: - s = None - N, L, H, C = q.shape - q_seqlen = [L] * N - q = q.reshape(N * L, H, C) # [T_Q, H, C] - - if 
isinstance(kv, VarLenTensor): - kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])] - kv = kv.feats # [T_KV, 2, H, C] - else: - N, L, _, H, C = kv.shape - kv_seqlen = [L] * N - kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] - - elif num_all_args == 3: - q = args[0] if len(args) > 0 else kwargs['q'] - k = args[1] if len(args) > 1 else kwargs['k'] - v = args[2] if len(args) > 2 else kwargs['v'] - device = q.device - - if isinstance(q, VarLenTensor): - s = q - q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] - q = q.feats # [T_Q, H, Ci] - else: - s = None - N, L, H, CI = q.shape - q_seqlen = [L] * N - q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] - - if isinstance(k, VarLenTensor): - kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] - k = k.feats # [T_KV, H, Ci] - v = v.feats # [T_KV, H, Co] - else: - N, L, H, CI, CO = *k.shape, v.shape[-1] - kv_seqlen = [L] * N - k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] - v = v.reshape(N * L, H, CO) # [T_KV, H, Co] - - # TODO: change - if q is not None: - heads = q - else: - heads = qkv - heads = heads.shape[2] - if optimized_attention.__name__ == 'attention_xformers': - if 'xops' not in globals(): - import xformers.ops as xops - if num_all_args == 1: - q, k, v = qkv.unbind(dim=1) - elif num_all_args == 2: - k, v = kv.unbind(dim=1) - q = q.unsqueeze(0) - k = k.unsqueeze(0) - v = v.unsqueeze(0) - mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) - out = xops.memory_efficient_attention(q, k, v, mask)[0] - elif optimized_attention.__name__ == 'attention_flash': - if 'flash_attn' not in globals(): - import flash_attn - cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) - if num_all_args in [2, 3]: - cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) - if num_all_args == 1: - out = 
flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) - elif num_all_args == 2: - out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) - elif num_all_args == 3: - out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) - - elif optimized_attention.__name__ == "attention_pytorch": - cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) - if num_all_args in [2, 3]: - cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) - else: - cu_seqlens_kv = cu_seqlens_q - if num_all_args == 1: - q, k, v = qkv.unbind(dim=1) - elif num_all_args == 2: - k, v = kv.unbind(dim=1) - out = attention_pytorch(q, k, v, heads=heads,cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max(q_seqlen), max_kv_seqlen=max(kv_seqlen), - skip_reshape=True, skip_output_reshape=True) - - if s is not None: - return s.replace(out) - else: - return out.reshape(N, L, H, -1) + return fwd_indices, bwd_indices, seq_lens diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 0b4181ad2..7e263169a 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -4,11 +4,12 @@ import torch.nn as nn from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor from typing import Optional, Tuple, Literal, Union, List from comfy.ldm.trellis2.attention import ( - sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention + sparse_windowed_self_attention, sparse_attention, dense_attention ) from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder from comfy.ldm.flux.math import apply_rope, apply_rope1 + class SparseGELU(nn.GELU): def forward(self, input: VarLenTensor) -> VarLenTensor: return 
input.replace(super().forward(input.feats)) @@ -31,22 +32,16 @@ class LayerNorm32(nn.LayerNorm): x = x.to(dtype=torch.float32) o = super().forward(x) return o.to(dtype=x_dtype) - - + class SparseMultiHeadRMSNorm(nn.Module): def __init__(self, dim: int, heads: int, device, dtype): super().__init__() - self.scale = dim ** 0.5 self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype)) def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: - x_type = x.dtype - x = x.float() if isinstance(x, VarLenTensor): - x = x.replace(F.normalize(x.feats, dim=-1) * self.gamma * self.scale) - else: - x = F.normalize(x, dim=-1) * self.gamma * self.scale - return x.to(x_type) + return x.replace(F.rms_norm(x.feats, (x.feats.shape[-1],)) * self.gamma) + return F.rms_norm(x, (x.shape[-1],)) * self.gamma class SparseRotaryPositionEmbedder(nn.Module): def __init__( @@ -84,12 +79,6 @@ class SparseRotaryPositionEmbedder(nn.Module): return freqs_cis - def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: - self.freqs = self.freqs.to(indices.device) - phases = torch.outer(indices, self.freqs) - phases = torch.polar(torch.ones_like(phases), phases) - return phases - def forward(self, q, k=None): cache_name = f'rope_cis_{self.dim}d_f{self.rope_freq[1]}_hd{self.head_dim}' freqs_cis = q.get_spatial_cache(cache_name) @@ -110,27 +99,10 @@ class SparseRotaryPositionEmbedder(nn.Module): q_feats, k_feats = apply_rope(q.feats, k.feats, f_cis) return q.replace(q_feats), k.replace(k_feats) - @staticmethod - def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: - x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) - x_rotated = x_complex * phases.unsqueeze(-2) - x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) - return x_embed class RotaryPositionEmbedder(SparseRotaryPositionEmbedder): - def forward(self, indices: torch.Tensor) -> torch.Tensor: - phases 
= self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) - if torch.is_complex(phases): - phases = phases.to(torch.complex64) - else: - phases = phases.to(torch.float32) - if phases.shape[-1] < self.head_dim // 2: - padn = self.head_dim // 2 - phases.shape[-1] - phases = torch.cat([phases, torch.polar( - torch.ones(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32), - torch.zeros(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32) - )], dim=-1) - return phases + def forward(self, coords: torch.Tensor) -> torch.Tensor: + return self._get_freqs_cis(coords) # [L, head_dim/2, 2, 2] class SparseMultiHeadAttention(nn.Module): def __init__( @@ -198,50 +170,53 @@ class SparseMultiHeadAttention(nn.Module): x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats - def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor: + def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None, + transformer_options=None) -> SparseTensor: if self._type == "self": - dtype = next(self.to_qkv.parameters()).dtype - x = x.to(dtype) qkv = self._linear(self.to_qkv, x) qkv = self._fused_pre(qkv, num_fused=3) - if self.qk_rms_norm or self.use_rope: + if self.attn_mode == "full": q, k, v = qkv.unbind(dim=-3) if self.qk_rms_norm: q = self.q_rms_norm(q) k = self.k_rms_norm(k) if self.use_rope: q, k = self.rope(q, k) - qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) - if self.attn_mode == "full": - h = sparse_scaled_dot_product_attention(qkv) - elif self.attn_mode == "windowed": - h = sparse_windowed_scaled_dot_product_self_attention( - qkv, self.window_size, shift_window=self.shift_window - ) - elif self.attn_mode == "double_windowed": - qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:]) - qkv1 = qkv.replace(qkv.feats[:, :, 
:self.num_heads//2]) - h0 = sparse_windowed_scaled_dot_product_self_attention( - qkv0, self.window_size, shift_window=(0, 0, 0) - ) - h1 = sparse_windowed_scaled_dot_product_self_attention( - qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3) - ) - h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1)) + h = sparse_attention(q, k, v, transformer_options=transformer_options) + else: + # Windowed paths take packed qkv; preserve any per-head norm/rope. + if self.qk_rms_norm or self.use_rope: + q, k, v = qkv.unbind(dim=-3) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + if self.use_rope: + q, k = self.rope(q, k) + qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) + if self.attn_mode == "windowed": + h = sparse_windowed_self_attention( + qkv, self.window_size, shift_window=self.shift_window + ) + elif self.attn_mode == "double_windowed": + qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:]) + qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2]) + h0 = sparse_windowed_self_attention( + qkv0, self.window_size, shift_window=(0, 0, 0) + ) + h1 = sparse_windowed_self_attention( + qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3) + ) + h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1)) else: q = self._linear(self.to_q, x) q = self._reshape_chs(q, (self.num_heads, -1)) - dtype = next(self.to_kv.parameters()).dtype - context = context.to(dtype) kv = self._linear(self.to_kv, context) kv = self._fused_pre(kv, num_fused=2) + k, v = kv.unbind(dim=-3) if self.qk_rms_norm: q = self.q_rms_norm(q) - k, v = kv.unbind(dim=-3) k = self.k_rms_norm(k) - h = sparse_scaled_dot_product_attention(q, k, v) - else: - h = sparse_scaled_dot_product_attention(q, kv) + h = sparse_attention(q, k, v, transformer_options=transformer_options) h = self._reshape_chs(h, (-1,)) h = self._linear(self.to_out, h) return h @@ -265,9 +240,9 @@ class ProjectAttentionSparse(nn.Module): self.proj_linear = 
operations.Linear(proj_in_channels, channels, bias=True, device=device, dtype=dtype) - def forward(self, x: SparseTensor, context) -> SparseTensor: + def forward(self, x: SparseTensor, context, transformer_options=None) -> SparseTensor: global_ctx, proj_in = _split_proj_context(context) - global_out = self.cross_attn_block(x, global_ctx) + global_out = self.cross_attn_block(x, global_ctx, transformer_options=transformer_options) if isinstance(proj_in, tuple): proj_in = torch.cat([proj_in[0], proj_in[1]], dim=-1) proj_out = self.proj_linear(proj_in.to(self.proj_linear.weight.dtype)) @@ -282,9 +257,9 @@ class ProjectAttentionDense(nn.Module): self.proj_linear = operations.Linear(proj_in_channels, channels, bias=True, device=device, dtype=dtype) - def forward(self, x: torch.Tensor, context) -> torch.Tensor: + def forward(self, x: torch.Tensor, context, transformer_options=None) -> torch.Tensor: global_ctx, proj_in = _split_proj_context(context) - global_out = self.cross_attn_block(x, global_ctx) + global_out = self.cross_attn_block(x, global_ctx, transformer_options=transformer_options) if isinstance(proj_in, tuple): proj_in = torch.cat([proj_in[0], proj_in[1]], dim=-1) proj_out = self.proj_linear(proj_in.to(self.proj_linear.weight.dtype)) @@ -367,32 +342,36 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): else: self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5) - def _forward(self, x: SparseTensor, mod: torch.Tensor, context) -> SparseTensor: + def _forward(self, x: SparseTensor, mod: torch.Tensor, context, transformer_options=None) -> SparseTensor: if self.share_mod: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) else: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) - h = x.replace(self.norm1(x.feats)) - h = h * (1 + scale_msa) + shift_msa - h = self.self_attn(h) - h = h * 
gate_msa - x = x + h + # Fuse the (mul + add) and (mul + residual) pairs into addcmul so the + # mod/shift broadcasts hit one kernel each instead of two. + b_map = x.batch_boardcast_map + + h_feats = self.norm1(x.feats) + h_feats = torch.addcmul(shift_msa[b_map], h_feats, (1 + scale_msa)[b_map]) + h = self.self_attn(x.replace(h_feats), transformer_options=transformer_options) + x = x.replace(torch.addcmul(x.feats, h.feats, gate_msa[b_map])) + h = x.replace(self.norm2(x.feats)) if self.image_attn_mode == "global": global_ctx, _ = _split_proj_context(context) - h = self.cross_attn(h, global_ctx) + h = self.cross_attn(h, global_ctx, transformer_options=transformer_options) else: - h = self.cross_attn(h, context) - x = x + h - h = x.replace(self.norm3(x.feats)) - h = h * (1 + scale_mlp) + shift_mlp - h = self.mlp(h) - h = h * gate_mlp + h = self.cross_attn(h, context, transformer_options=transformer_options) x = x + h + + h_feats = self.norm3(x.feats) + h_feats = torch.addcmul(shift_mlp[b_map], h_feats, (1 + scale_mlp)[b_map]) + h = self.mlp(x.replace(h_feats)) + x = x.replace(torch.addcmul(x.feats, h.feats, gate_mlp[b_map])) return x - def forward(self, x: SparseTensor, mod: torch.Tensor, context) -> SparseTensor: - return self._forward(x, mod, context) + def forward(self, x: SparseTensor, mod: torch.Tensor, context, transformer_options=None) -> SparseTensor: + return self._forward(x, mod, context, transformer_options=transformer_options) class SLatFlowModel(nn.Module): @@ -480,24 +459,21 @@ class SLatFlowModel(nn.Module): t: torch.Tensor, cond: Union[torch.Tensor, List[torch.Tensor]], concat_cond: Optional[SparseTensor] = None, - **kwargs + transformer_options=None, + **kwargs, ) -> SparseTensor: if concat_cond is not None: x = sparse_cat([x, concat_cond], dim=-1) if isinstance(cond, list): cond = VarLenTensor.from_tensor_list(cond) - dtype = next(self.input_layer.parameters()).dtype - x = x.to(dtype) h = self.input_layer(x) - t = t.to(dtype) - t_embedder = 
self.t_embedder.to(dtype) - t_emb = t_embedder(t, out_dtype = t.dtype) + t_emb = self.t_embedder(t, out_dtype=t.dtype) if self.share_mod: t_emb = self.adaLN_modulation(t_emb) for block in self.blocks: - h = block(h, t_emb, cond) + h = block(h, t_emb, cond, transformer_options=transformer_options) h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) h = self.out_layer(h) @@ -566,40 +542,34 @@ class MultiHeadAttention(nn.Module): self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype) - def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, + phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor: B, L, C = x.shape if self._type == "self": - x = x.to(next(self.to_qkv.parameters()).dtype) qkv = self.to_qkv(x) qkv = qkv.reshape(B, L, 3, self.num_heads, -1) - - if self.attn_mode == "full": - if self.qk_rms_norm or self.use_rope: - q, k, v = qkv.unbind(dim=2) - if self.qk_rms_norm: - q = self.q_rms_norm(q) - k = self.k_rms_norm(k) - if self.use_rope: - assert phases is not None, "Phases must be provided for RoPE" - q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases) - k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases) - h = scaled_dot_product_attention(q, k, v) - else: - h = scaled_dot_product_attention(qkv) + q, k, v = qkv.unbind(dim=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + if self.use_rope: + assert phases is not None, "Phases must be provided for RoPE" + # phases is [L, head_dim/2, 2, 2]; broadcast to [1, L, 1, ...] + # to align with q/k of shape [B, L, H, head_dim]. 
+ f_cis = phases.unsqueeze(0).unsqueeze(2) + q, k = apply_rope(q, k, f_cis) + h = dense_attention(q, k, v, transformer_options=transformer_options) else: Lkv = context.shape[1] q = self.to_q(x) - context = context.to(next(self.to_kv.parameters()).dtype) kv = self.to_kv(context) q = q.reshape(B, L, self.num_heads, -1) kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) + k, v = kv.unbind(dim=2) if self.qk_rms_norm: q = self.q_rms_norm(q) - k, v = kv.unbind(dim=2) k = self.k_rms_norm(k) - h = scaled_dot_product_attention(q, k, v) - else: - h = scaled_dot_product_attention(q, kv) + h = dense_attention(q, k, v, transformer_options=transformer_options) h = h.reshape(B, L, -1) h = self.to_out(h) return h @@ -677,32 +647,39 @@ class ModulatedTransformerCrossBlock(nn.Module): else: self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5) - def _forward(self, x: torch.Tensor, mod: torch.Tensor, context, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + def _forward(self, x: torch.Tensor, mod: torch.Tensor, context, + phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor: if self.share_mod: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) else: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) - h = self.norm1(x) - h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) - h = self.self_attn(h, phases=phases) - h = h * gate_msa.unsqueeze(1) - x = x + h + shift_msa = shift_msa.unsqueeze(1) + scale_msa = scale_msa.unsqueeze(1) + gate_msa = gate_msa.unsqueeze(1) + shift_mlp = shift_mlp.unsqueeze(1) + scale_mlp = scale_mlp.unsqueeze(1) + gate_mlp = gate_mlp.unsqueeze(1) + + h = torch.addcmul(shift_msa, self.norm1(x), 1 + scale_msa) + h = self.self_attn(h, phases=phases, transformer_options=transformer_options) + x = torch.addcmul(x, h, gate_msa) + h = self.norm2(x) if 
self.image_attn_mode == "global": global_ctx, _ = _split_proj_context(context) - h = self.cross_attn(h, global_ctx) + h = self.cross_attn(h, global_ctx, transformer_options=transformer_options) else: - h = self.cross_attn(h, context) + h = self.cross_attn(h, context, transformer_options=transformer_options) x = x + h - h = self.norm3(x) - h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + + h = torch.addcmul(shift_mlp, self.norm3(x), 1 + scale_mlp) h = self.mlp(h) - h = h * gate_mlp.unsqueeze(1) - x = x + h + x = torch.addcmul(x, h, gate_mlp) return x - def forward(self, x: torch.Tensor, mod: torch.Tensor, context, phases: Optional[torch.Tensor] = None) -> torch.Tensor: - return self._forward(x, mod, context, phases) + def forward(self, x: torch.Tensor, mod: torch.Tensor, context, + phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor: + return self._forward(x, mod, context, phases, transformer_options=transformer_options) class SparseStructureFlowModel(nn.Module): @@ -792,18 +769,18 @@ class SparseStructureFlowModel(nn.Module): self.out_layer = operations.Linear(model_channels, out_channels, device=device, dtype=dtype) - def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor, + transformer_options=None) -> torch.Tensor: x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3) h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous() - h = h.to(next(self.input_layer.parameters()).dtype) h = self.input_layer(h) - t_emb = self.t_embedder(t, out_dtype = t.dtype) + t_emb = self.t_embedder(t, out_dtype=t.dtype) if self.share_mod: t_emb = self.adaLN_modulation(t_emb) for block in self.blocks: - h = block(h, t_emb, cond, self.rope_phases) + h = block(h, t_emb, cond, self.rope_phases, transformer_options=transformer_options) h = F.layer_norm(h, h.shape[-1:]) h = self.out_layer(h) @@ -820,7 +797,7 @@ def 
timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0): # Pixal3D ProjGrid math — port of upstream's ProjGrid + project_points_to_image_batch. -# World frame uses world Y as depth (Blender convention), camera looks along -Z local; +# World frame uses world Y as depth, camera looks along -Z local; # transform_matrix is camera-to-world (inverted internally). Intrinsics: fx = 16 / tan(fov/2) # with sensor_width = 32mm. @@ -927,7 +904,7 @@ def _back_project_to_tokens( if not mask.any(): continue p = coords_world[mask].unsqueeze(0) - uv, depth, valid = _project_points_to_image( + uv, _, _ = _project_points_to_image( p, transform_matrix[b:b+1], camera_angle_x[b:b+1], image_resolution) uv_ndc = (uv + 0.5) / image_resolution * 2.0 - 1.0 # padding_mode='border' is load-bearing: masking out-of-frame voxels confuses @@ -937,7 +914,7 @@ def _back_project_to_tokens( out[mask] = sampled return out else: - uv, depth, valid = _project_points_to_image( + uv, _, _ = _project_points_to_image( coords_world, transform_matrix, camera_angle_x, image_resolution) uv_ndc = (uv + 0.5) / image_resolution * 2.0 - 1.0 sampled = _sample_features(feature_map, uv_ndc) @@ -945,21 +922,6 @@ def _back_project_to_tokens( return out -def _pack_per_voxel_scalar(proj_pack: Optional[dict], key: str, eval_batch: int, device) -> torch.Tensor: - if proj_pack is None or key not in proj_pack: - return torch.ones((eval_batch,), device=device, dtype=torch.float32) - t = proj_pack[key].to(device=device, dtype=torch.float32) - if t.ndim == 0: - return t.expand(eval_batch).clone() - return _expand_pack(t, eval_batch) - - -def _expand_pack(t: torch.Tensor, eval_batch: int) -> torch.Tensor: - if eval_batch == t.shape[0]: - return t - if eval_batch % t.shape[0] != 0: - raise ValueError(f"eval batch {eval_batch} is not a multiple of pack batch {t.shape[0]}") - return t.repeat((eval_batch // t.shape[0],) + (1,) * (t.ndim - 1)) def _select_stage_entry(proj_pack: dict, stage: Optional[str]): @@ -973,53 +935,96 @@ def 
_select_stage_entry(proj_pack: dict, stage: Optional[str]): raise ValueError(f"proj_feat_pack has no usable feature_map (stage={stage!r})") -def _build_proj_cond(global_cond: torch.Tensor, image_attn_mode: str, proj_pack: Optional[dict], - coords_world: torch.Tensor, batch_ids: Optional[torch.Tensor] = None, - eval_batch: Optional[int] = None, - proj_in_channels: Optional[int] = None, - stage: Optional[str] = None, - cond_or_uncond: Optional[list] = None): - if image_attn_mode == "global": - return global_cond - if proj_pack is None: - raise ValueError(f"image_attn_mode={image_attn_mode!r} but proj_feat_pack is missing") - device = coords_world.device +def compute_stage_proj_feats( + proj_pack: dict, + stage: str, + coords: Optional[torch.Tensor] = None, + coord_resolution: Optional[int] = None, + dense_grid_resolution: Optional[int] = None, + batch_size: Optional[int] = None, + device=None, +) -> torch.Tensor: + """Back-project a Pixal3D stage's feature maps onto its target voxel/grid coords. + + For sparse (shape / texture) stages: pass ``coords`` (with ``coord_resolution``). + Returns ``[N_voxels, C]`` per-voxel features with channel count = + LR channels + optional HR channels. + + For the dense SS stage: pass ``dense_grid_resolution`` (16) + ``batch_size``. + Returns ``[B, R^3, C]`` features for the dense grid. 
+ + """ + if device is None: + device = coords.device if coords is not None else proj_pack["mesh_scale"].device + mesh_scale = proj_pack["mesh_scale"].to(device) T = proj_pack["transform_matrix"].to(device) cam_angle = proj_pack["camera_angle_x"].to(device) feat_map_lr, feat_map_hr, image_resolution = _select_stage_entry(proj_pack, stage) feat_map_lr = feat_map_lr.to(device) if feat_map_hr is not None: feat_map_hr = feat_map_hr.to(device) - if eval_batch is not None: - T = _expand_pack(T, eval_batch) - cam_angle = _expand_pack(cam_angle, eval_batch) if cam_angle.ndim >= 1 else cam_angle - feat_map_lr = _expand_pack(feat_map_lr, eval_batch) - if feat_map_hr is not None: - feat_map_hr = _expand_pack(feat_map_hr, eval_batch) - # Channel-count check against the trained proj_linear input. If HR is present, the - # block expects (LR_channels + HR_channels) since we concat the sampled features. - expected_channels = feat_map_lr.shape[1] + (feat_map_hr.shape[1] if feat_map_hr is not None else 0) - if proj_in_channels is not None and expected_channels != proj_in_channels: + + if coords is not None: + if coord_resolution is None: + raise ValueError("compute_stage_proj_feats: coord_resolution required when coords is given") + coords_world, batch_ids = _coords_to_proj_world(coords, coord_resolution, mesh_scale) + else: + if dense_grid_resolution is None or batch_size is None: + raise ValueError("compute_stage_proj_feats: dense_grid_resolution + batch_size required for dense path") + coords_world = _dense_grid_proj_world(dense_grid_resolution, mesh_scale, batch_size, + device=device, dtype=torch.float32) + batch_ids = None + + proj_lr = _back_project_to_tokens(coords_world, feat_map_lr, T, cam_angle, + image_resolution=image_resolution, batch_ids=batch_ids) + if feat_map_hr is not None: + proj_hr = _back_project_to_tokens(coords_world, feat_map_hr, T, cam_angle, + image_resolution=image_resolution, batch_ids=batch_ids) + return torch.cat([proj_lr, proj_hr], dim=-1) + return 
proj_lr + + +def _shape_proj_cond(global_cond: torch.Tensor, image_attn_mode: str, + proj_feats: Optional[torch.Tensor], + batch_ids: Optional[torch.Tensor] = None, + eval_batch: Optional[int] = None, + logical_batch: Optional[int] = None, + proj_in_channels: Optional[int] = None, + stage: Optional[str] = None, + cond_or_uncond: Optional[list] = None, + has_hr: bool = False): + """Take pre-computed per-token proj features (from compute_stage_proj_feats), + apply CFG-batch duplication + uncond-slot zeroing, and wrap into the + ``{"global", "proj"}`` context dict consumed by ProjectAttention. + + proj_feats shape: + sparse (shape/texture): [N_voxels, C] (batch_ids gives per-voxel batch) + dense (SS): [B, N, C] + """ + if image_attn_mode == "global": + return global_cond + if proj_feats is None: + raise ValueError(f"image_attn_mode={image_attn_mode!r} but trellis2_proj_feats is missing — " + f"the stage setup node (or Pixal3DConditioning for SS) should have computed it.") + if proj_in_channels is not None and proj_feats.shape[-1] != proj_in_channels: hint = "" - if feat_map_hr is None and expected_channels < proj_in_channels: + if not has_hr and proj_feats.shape[-1] < proj_in_channels: hint = (" — feature_map_hr is missing for this stage. 
Connect a NAFModel " "input to Pixal3DConditioning; the shape/texture stages of this " "checkpoint need a NAF-upsampled HR feature map.") raise ValueError( - f"proj_feat_pack[{stage!r}] has LR={feat_map_lr.shape[1]} " - f"+ HR={feat_map_hr.shape[1] if feat_map_hr is not None else 0} " - f"= {expected_channels} channels, sub-model expects {proj_in_channels}.{hint}" + f"proj_feats for stage {stage!r} has {proj_feats.shape[-1]} channels, " + f"sub-model expects {proj_in_channels}.{hint}" ) - proj_feats_lr = _back_project_to_tokens(coords_world, feat_map_lr, T, cam_angle, - image_resolution=image_resolution, - batch_ids=batch_ids) - if feat_map_hr is not None: - proj_feats_hr = _back_project_to_tokens(coords_world, feat_map_hr, T, cam_angle, - image_resolution=image_resolution, - batch_ids=batch_ids) - proj_feats = torch.cat([proj_feats_lr, proj_feats_hr], dim=-1) - else: - proj_feats = proj_feats_lr + + # CFG-duplicate proj_feats to match the model's eval batch. + if eval_batch is not None and logical_batch is not None and eval_batch > logical_batch: + repeats = eval_batch // logical_batch + if batch_ids is None: + proj_feats = proj_feats.repeat((repeats,) + (1,) * (proj_feats.ndim - 1)) + else: + proj_feats = proj_feats.repeat((repeats, 1)) + # Mirror upstream's neg_cond by zeroing proj for any uncond batch slot. 
if cond_or_uncond is not None and eval_batch is not None: uncond_slots = [i for i, v in enumerate(cond_or_uncond) if v == 1] @@ -1056,10 +1061,6 @@ class Trellis2(nn.Module): super().__init__() self.dtype = dtype - operations = operations or nn - # for some reason it passes num_heads = -1 - if num_heads == -1: - num_heads = 12 args = { "out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels, "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, @@ -1082,101 +1083,79 @@ class Trellis2(nn.Module): self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **struct_proj_kwargs, **args) else: self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **tex_proj_kwargs, **args) - self.guidance_interval = [0.6, 1.0] - self.guidance_interval_txt = [0.6, 0.9] def forward(self, x, timestep, context, **kwargs): transformer_options = kwargs.get("transformer_options", {}) cond_or_uncond = transformer_options.get("cond_or_uncond") - model_options = {} - if hasattr(self, "meta"): - model_options = self.meta timestep = timestep.to(x.dtype) embeds = kwargs.get("embeds") if embeds is None: raise ValueError("Trellis2.forward requires 'embeds' in kwargs") - is_1024 = True#self.img2shape.resolution == 1024 - coords = model_options.get("coords", None) - coord_counts = model_options.get("coord_counts", None) - mode = model_options.get("generation_mode", "structure_generation") - proj_feat_pack = model_options.get("proj_feat_pack", None) - coord_resolution = model_options.get("coord_resolution", None) + # Per-stage cascade metadata + coords = kwargs.get("trellis2_coords") + coord_counts = kwargs.get("trellis2_coord_counts") + mode = kwargs.get("trellis2_generation_mode", "structure_generation") + proj_feat_pack = kwargs.get("proj_feat_pack") + # Pre-computed per-stage back-projected features + proj_feats = kwargs.get("trellis2_proj_feats") - is_512_run = 
False + is_first_shape_pass = False if mode == "shape_generation_512": - is_512_run = True + is_first_shape_pass = True mode = "shape_generation" if coords is not None: - if x.ndim == 4: - x = x.squeeze(-1).transpose(1, 2) - not_struct_mode = True + x = x.squeeze(-1).transpose(1, 2) + is_sparse_mode = True else: mode = "structure_generation" - not_struct_mode = False + is_sparse_mode = False if x.size(-1) == 16 and x.size(-2) == 16: mode = "structure_generation" - not_struct_mode = False + is_sparse_mode = False - if not not_struct_mode: + if not is_sparse_mode: bsz = x.size(0) x = x[:, :8] x = x.view(bsz, 8, 16, 16, 16) - if is_1024 and not_struct_mode and not is_512_run: + if is_sparse_mode and not is_first_shape_pass: context = embeds - sigmas = transformer_options.get("sigmas")[0].item() - if sigmas < 1.00001: - timestep *= 1000.0 + if is_sparse_mode: + t_eval = timestep + c_eval = context - if context.size(0) > 1: - cond = context.chunk(2)[1] - else: - cond = context - - shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] - txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] - - if not_struct_mode: - orig_bsz = x.shape[0] - rule = txt_rule if mode == "texture_generation" else shape_rule - - # CFG Bypass Slicing - if rule and orig_bsz > 1: - half = orig_bsz // 2 - x_eval = x[half:] - t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep - c_eval = cond - else: - x_eval = x - t_eval = timestep - c_eval = context - - B, N, C = x_eval.shape + B, N, C = x.shape # Vectorized SparseTensor Construction if mode in ["shape_generation", "texture_generation"]: if coord_counts is not None: logical_batch = coord_counts.shape[0] - # Duplicate coords if CFG is active + # Duplicate sparse coords when the sampler asks for >1 cond + # (CFG or otherwise). Each duplicate is offset along col 0 + # so SparseTensor sees a fresh logical batch. 
if B > logical_batch: - c_pos = coords.clone() - c_pos[:, 0] += logical_batch - batched_coords = torch.cat([coords, c_pos], dim=0) - counts_eval = torch.cat([coord_counts, coord_counts], dim=0) + reps = B // logical_batch + c_copies = [] + for i in range(reps): + c = coords.clone() + c[:, 0] += i * logical_batch + c_copies.append(c) + batched_coords = torch.cat(c_copies, dim=0) + counts_eval = coord_counts.repeat(reps) else: batched_coords = coords counts_eval = coord_counts - # Create boolean mask [B, N] to drop the padded zeros instantly + # Boolean mask [B, N] to drop the padded zeros instantly mask = torch.arange(N, device=x.device).unsqueeze(0) < counts_eval.unsqueeze(1) - feats_flat = x_eval[mask] + feats_flat = x[mask] else: - feats_flat = x_eval.reshape(-1, C) - coords_list =[] + feats_flat = x.reshape(-1, C) + coords_list = [] for i in range(B): c = coords.clone() c[:, 0] = i @@ -1185,35 +1164,42 @@ class Trellis2(nn.Module): mask = None else: batched_coords = coords - feats_flat = x_eval + feats_flat = x mask = None - x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) + x_st = SparseTensor( + feats=feats_flat, + coords=batched_coords.to(torch.int32), + shape=torch.Size([B] + list(feats_flat.shape[1:])), + ) if mode == "shape_generation": shape_attn = self.image_attn_mode_shape if shape_attn != "global": - if coord_resolution is None: - raise ValueError("Pixal3D shape_generation requires coord_resolution in model_options; " - "EmptyTrellis2ShapeLatent should set it from the input voxel.") - mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", B, batched_coords.device) - xyz_world, batch_ids = _coords_to_proj_world(batched_coords, coord_resolution, mesh_scale) - sub_model = self.img2shape_512 if is_512_run else self.img2shape - stage_name = "shape_512" if is_512_run else "shape_1024" - c_eval = _build_proj_cond(c_eval, shape_attn, proj_feat_pack, xyz_world, batch_ids, - eval_batch=B, + sub_model = self.img2shape_512 if 
is_first_shape_pass else self.img2shape + stage_name = "shape_512" if is_first_shape_pass else "shape_1024" + # batched_coords carries CFG-doubled batch ids in col 0; per-voxel + # batch_ids drive uncond-slot masking inside _shape_proj_cond. + batch_ids = batched_coords[:, 0].long() + logical_batch = coord_counts.shape[0] if coord_counts is not None else B + has_hr = bool(proj_feat_pack and proj_feat_pack.get("stages", {}) + .get(stage_name, {}).get("feature_map_hr") is not None) + c_eval = _shape_proj_cond(c_eval, shape_attn, proj_feats, + batch_ids=batch_ids, + eval_batch=B, logical_batch=logical_batch, proj_in_channels=sub_model.proj_in_channels, stage=stage_name, - cond_or_uncond=cond_or_uncond) - if is_512_run: - out = self.img2shape_512(x_st, t_eval, c_eval) + cond_or_uncond=cond_or_uncond, + has_hr=has_hr) + if is_first_shape_pass: + out = self.img2shape_512(x_st, t_eval, c_eval, transformer_options=transformer_options) else: - out = self.img2shape(x_st, t_eval, c_eval) + out = self.img2shape(x_st, t_eval, c_eval, transformer_options=transformer_options) elif mode == "texture_generation": if self.shape2txt is None: raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!") - slat = model_options.get("shape_slat") + slat = kwargs.get("trellis2_shape_slat") if slat is None: raise ValueError("shape_slat can't be None") @@ -1227,49 +1213,40 @@ class Trellis2(nn.Module): x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats.to(x_st.feats.device)], dim=-1)) tex_attn = self.image_attn_mode_texture if tex_attn != "global": - if coord_resolution is None: - raise ValueError("Pixal3D texture_generation requires coord_resolution in model_options; " - "EmptyTrellis2LatentTexture should set it from the input voxel.") - mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", B, batched_coords.device) - xyz_world, batch_ids = _coords_to_proj_world(batched_coords, coord_resolution, mesh_scale) - c_eval = _build_proj_cond(c_eval, tex_attn, 
proj_feat_pack, xyz_world, batch_ids, - eval_batch=B, + batch_ids = batched_coords[:, 0].long() + logical_batch = coord_counts.shape[0] if coord_counts is not None else B + has_hr = bool(proj_feat_pack and proj_feat_pack.get("stages", {}) + .get("tex_1024", {}).get("feature_map_hr") is not None) + c_eval = _shape_proj_cond(c_eval, tex_attn, proj_feats, + batch_ids=batch_ids, + eval_batch=B, logical_batch=logical_batch, proj_in_channels=self.shape2txt.proj_in_channels, stage="tex_1024", - cond_or_uncond=cond_or_uncond) - out = self.shape2txt(x_st, t_eval, c_eval) + cond_or_uncond=cond_or_uncond, + has_hr=has_hr) + out = self.shape2txt(x_st, t_eval, c_eval, transformer_options=transformer_options) else: # structure - orig_bsz = x.shape[0] struct_attn = self.image_attn_mode_structure - if shape_rule and orig_bsz > 1: - half = orig_bsz // 2 - x_eval = x[half:] - t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep - struct_cond = cond - if struct_attn != "global": - mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", half, x.device) - grid_xyz = _dense_grid_proj_world(16, mesh_scale, half, device=x.device) - struct_cond = _build_proj_cond(cond, struct_attn, proj_feat_pack, grid_xyz, - eval_batch=half, - proj_in_channels=self.structure_model.proj_in_channels, - stage="ss", - cond_or_uncond=cond_or_uncond) - out = self.structure_model(x_eval, t_eval, struct_cond) - out = out.repeat(2, 1, 1, 1, 1) - else: - struct_cond = context - if struct_attn != "global": - mesh_scale = _pack_per_voxel_scalar(proj_feat_pack, "mesh_scale", orig_bsz, x.device) - grid_xyz = _dense_grid_proj_world(16, mesh_scale, orig_bsz, device=x.device) - struct_cond = _build_proj_cond(context, struct_attn, proj_feat_pack, grid_xyz, - eval_batch=orig_bsz, - proj_in_channels=self.structure_model.proj_in_channels, - stage="ss", - cond_or_uncond=cond_or_uncond) - out = self.structure_model(x, timestep, struct_cond) + has_hr_ss = bool(proj_feat_pack and proj_feat_pack.get("stages", 
{}) + .get("ss", {}).get("feature_map_hr") is not None) + logical_batch_ss = ( + proj_feat_pack["mesh_scale"].shape[0] + if (proj_feat_pack is not None and torch.is_tensor(proj_feat_pack.get("mesh_scale"))) + else x.shape[0] + ) + struct_cond = context + if struct_attn != "global": + struct_cond = _shape_proj_cond(context, struct_attn, proj_feats, + batch_ids=None, + eval_batch=x.shape[0], logical_batch=logical_batch_ss, + proj_in_channels=self.structure_model.proj_in_channels, + stage="ss", + cond_or_uncond=cond_or_uncond, + has_hr=has_hr_ss) + out = self.structure_model(x, timestep, struct_cond, transformer_options=transformer_options) - if not_struct_mode: + if is_sparse_mode: if mask is not None: # Instantly scatter the valid tokens back into a padded rectangular tensor padded_out = torch.zeros((B, N, out.feats.shape[-1]), device=x.device, dtype=out.feats.dtype) @@ -1277,9 +1254,6 @@ class Trellis2(nn.Module): out_tensor = padded_out.transpose(1, 2).unsqueeze(-1) else: out_tensor = out.feats.view(B, N, -1).transpose(1, 2).unsqueeze(-1) - - if rule and orig_bsz > 1: - out_tensor = out_tensor.repeat(2, 1, 1, 1) return out_tensor else: out = torch.nn.functional.pad(out, (0, 0, 0, 0, 0, 0, 0, 24)) diff --git a/comfy/model_base.py b/comfy/model_base.py index 2a27e7cca..142ec530a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1646,6 +1646,17 @@ class Trellis2(BaseModel): out = super().extra_conds(**kwargs) embeds = kwargs.get("embeds") out["embeds"] = comfy.conds.CONDRegular(embeds) + # CONDConstant: shared across pos/neg + for k in ("trellis2_coords", "trellis2_coord_counts", + "trellis2_generation_mode", "trellis2_shape_slat", + "trellis2_proj_feats"): + v = kwargs.get(k) + if v is not None: + out[k] = comfy.conds.CONDConstant(v) + # Pixal3D's per-stage feature maps + camera params travel as a dict + proj_feat_pack = kwargs.get("proj_feat_pack") + if proj_feat_pack is not None: + out["proj_feat_pack"] = comfy.conds.CONDConstant(proj_feat_pack) 
return out class WAN21_FlowRVS(WAN21): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 86338f1d7..34b8b0d95 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1322,10 +1322,10 @@ class Trellis2(supported_models_base.BASE): unet_config = { "image_model": "trellis2" } + unet_extra_config = {} sampling_settings = { "shift": 3.0, - "multiplier": 1.0 } memory_usage_factor = 3.5 diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 3484d6d6d..7bc22074e 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,7 +1,9 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types, io from comfy.ldm.trellis2.vae import SparseTensor -from comfy.ldm.trellis2.model import _build_proj_transform_matrix, _project_points_to_image +from comfy.ldm.trellis2.model import ( + _build_proj_transform_matrix, _project_points_to_image, compute_stage_proj_feats, +) from comfy.ldm.trellis2.naf.model import build_naf_from_state_dict from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch import comfy.model_management @@ -14,48 +16,9 @@ import math import torch ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES") -Pixal3DProjPack = io.Custom("PIXAL3D_PROJ_PACK") NAFModel = io.Custom("NAF_MODEL") -# Pixal3D trains in a 90°-X-rotated grid frame (F_p). We un-rotate decoder outputs for -# user-facing previews/meshes, then re-rotate before feeding coords back to the shape DiT. 
- -def _pixal3d_unrotate_voxel_data(data: torch.Tensor) -> torch.Tensor: - if data.ndim == 4: - return data.flip(-1).permute(0, 1, 3, 2).contiguous() - if data.ndim == 5: - return data.flip(-1).permute(0, 1, 2, 4, 3).contiguous() - raise ValueError(f"unexpected voxel shape {tuple(data.shape)}") - - -def _pixal3d_rerotate_voxel_data(data: torch.Tensor) -> torch.Tensor: - if data.ndim == 4: - return data.permute(0, 1, 3, 2).flip(-1).contiguous() - if data.ndim == 5: - return data.permute(0, 1, 2, 4, 3).flip(-1).contiguous() - raise ValueError(f"unexpected voxel shape {tuple(data.shape)}") - - -def _pixal3d_unrotate_vertices(vertices: torch.Tensor) -> torch.Tensor: - if vertices.numel() == 0: - return vertices - x, y, z = vertices.unbind(-1) - return torch.stack([-x, y, -z], dim=-1).contiguous() - - -def _pixal3d_unrotate_sparse_coords(coords: torch.Tensor, resolution: int) -> torch.Tensor: - if coords.numel() == 0: - return coords - R1 = resolution - 1 - if coords.shape[-1] == 4: - b, i, j, k = coords.unbind(-1) - return torch.stack([b, R1 - i, j, R1 - k], dim=-1).contiguous() - if coords.shape[-1] == 3: - i, j, k = coords.unbind(-1) - return torch.stack([R1 - i, j, R1 - k], dim=-1).contiguous() - raise ValueError(f"unexpected coord shape {tuple(coords.shape)}") - def prepare_trellis_vae_for_decode(vae, sample_shape): memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype) if len(sample_shape) == 5: @@ -202,15 +165,21 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def execute(cls, samples, vae): - - resolution = int(vae.first_stage_model.resolution.item()) + # Mesh grid_size must match the actual coord resolution the upstream + # stage was run at (1024 cascade -> 64, 1536 cascade -> 96). The VAE's + # built-in `.resolution` buffer defaults to 1024 and is otherwise stale; + # take coord_resolution from the latent dict if the stage node set it. 
+ coord_resolution = samples.get("coord_resolution") + if coord_resolution is not None: + resolution = int(coord_resolution) * 16 + else: + resolution = int(vae.first_stage_model.resolution.item()) sample_tensor = samples["samples"] device = comfy.model_management.get_torch_device() coords = samples["coords"] prepare_trellis_vae_for_decode(vae, sample_tensor.shape) trellis_vae = vae.first_stage_model coord_counts = samples.get("coord_counts") - pixal3d_mode = samples.get("model_options", {}).get("proj_feat_pack") is not None samples = samples["samples"] if coord_counts is None: @@ -236,10 +205,6 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): coords_list = [stage_tensor.coords for stage_tensor in stage_tensors] subs.append(SparseTensor.from_tensor_list(feats_list, coords_list)) - if pixal3d_mode: - for m in mesh: - m.vertices = _pixal3d_unrotate_vertices(m.vertices) - face_list = [m.faces for m in mesh] vert_list = [m.vertices for m in mesh] if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list): @@ -276,7 +241,6 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): prepare_trellis_vae_for_decode(vae, sample_tensor.shape) trellis_vae = vae.first_stage_model coord_counts = samples.get("coord_counts") - pixal3d_mode = samples.get("model_options", {}).get("proj_feat_pack") is not None samples = samples["samples"] samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts) @@ -288,7 +252,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): voxel = trellis_vae.decode_tex_slat(samples, shape_subdivides) color_feats = voxel.feats[:, :3] - voxel_coords = voxel.coords#[:, 1:] + voxel_coords = voxel.coords if voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3: spatial = voxel_coords[:, -3:] if voxel_coords.shape[-1] == 4 else voxel_coords @@ -297,9 +261,6 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): else: tex_resolution = 1024 - if pixal3d_mode: - voxel_coords = 
_pixal3d_unrotate_sparse_coords(voxel_coords, resolution=tex_resolution) - voxel = Types.VOXEL(voxel_coords, color_feats, tex_resolution) return IO.NodeOutput(voxel) @@ -312,7 +273,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): inputs=[ IO.Latent.Input("samples"), IO.Vae.Input("vae"), - IO.Combo.Input("resolution", options=["32", "64"], default="32") + IO.Combo.Input("resolution", options=["32", "64"], default="32"), ], outputs=[ IO.Voxel.Output("voxel"), @@ -338,129 +299,132 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): ratio = current_res // resolution decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5 voxel_data = decoded.squeeze(1).float() - if samples.get("model_options", {}).get("proj_feat_pack") is not None: - voxel_data = _pixal3d_unrotate_voxel_data(voxel_data) - out = Types.VOXEL(voxel_data) - return IO.NodeOutput(out) + return IO.NodeOutput(Types.VOXEL(voxel_data)) -class Trellis2UpsampleCascade(IO.ComfyNode): +class Trellis2UpsampleStage(IO.ComfyNode): + """Cascade-upsamples a 512-resolution shape latent into high-resolution + sparse coords and sets up the second shape-stage sampling pass at the + target resolution, attaching per-stage metadata to the conditioning for + the model to consume via extra_conds. 
# NOTE(review): the three callables below are members of the node class whose
# header sits above this chunk (its schema registers node_id
# "Trellis2UpsampleStage"); they belong inside that class body.
@classmethod
def define_schema(cls):
    """Declare the node id, display name, inputs and outputs of the upsample stage."""
    return IO.Schema(
        node_id="Trellis2UpsampleStage",
        category="latent/3d",
        display_name="Trellis2 Upsample Stage",
        inputs=[
            IO.Conditioning.Input("positive"),
            IO.Conditioning.Input("negative"),
            IO.Latent.Input("shape_latent", tooltip="The 512-resolution shape latent output from the first shape-stage KSampler."),
            IO.Vae.Input("vae"),
            IO.Combo.Input("target_resolution", options=["1024", "1536"], default="1024", tooltip="Controls output detail level for upsampling."),
            IO.Int.Input("max_tokens", default=49152, min=1024, max=100000, tooltip=(
                "Maximum number of output elements (coordinates) allowed after upsampling. "
                "Used to limit memory usage and control mesh density."
            )),
        ],
        outputs=[
            IO.Conditioning.Output(display_name="positive"),
            IO.Conditioning.Output(display_name="negative"),
            IO.Latent.Output(),
        ]
    )


@staticmethod
def _quantize_unique(hr_coords: torch.Tensor, lr_resolution: int, hr_resolution: int) -> torch.Tensor:
    """Quantize upsampled coords to the ``hr_resolution // 16`` grid and dedupe rows.

    Column 0 of ``hr_coords`` is passed through untouched; columns 1: are
    mapped with ``int((x + 0.5) * (hr_resolution // 16) / lr_resolution)``.
    Returns the unique rows (sorted, as ``Tensor.unique(dim=0)`` does).
    The input tensor is never modified.
    """
    # One combined scale factor instead of a divide followed by a multiply.
    scale = (hr_resolution // 16) / lr_resolution
    # Explicit copy: `.float()` hands back the *same* view when hr_coords is
    # already float32, and the in-place ops below would then rescale the
    # caller's tensor, corrupting every later pass of the resolution search.
    spatial = hr_coords[:, 1:].to(dtype=torch.float32, copy=True)
    spatial.add_(0.5).mul_(scale)
    quant = torch.cat([hr_coords[:, :1], spatial.int()], dim=1)
    return quant.unique(dim=0)


@classmethod
def execute(cls, positive, negative, shape_latent, vae, target_resolution, max_tokens):
    """Upsample the shape latent's coords and set up the second shape pass.

    Decodes high-resolution coords with ``vae.first_stage_model.shape_dec``,
    lowers the working resolution from ``target_resolution`` in steps of 128
    (never below 1024) until every sample's unique coord count is below
    ``max_tokens``, then returns the positive/negative conditioning with the
    stage metadata merged in, plus an all-zero latent laid out for the
    resulting coords.
    """
    device = comfy.model_management.get_torch_device()
    prepare_trellis_vae_for_decode(vae, shape_latent["samples"].shape)

    coord_counts = shape_latent.get("coord_counts")
    decoder = vae.first_stage_model.shape_dec
    lr_resolution = 512
    target_resolution = int(target_resolution)
    decoder_dtype = next(decoder.parameters()).dtype

    # Decode each sample's HR coords once; the search below only re-quantizes.
    if coord_counts is None:
        feats, coords_512 = flatten_batched_sparse_latent(
            shape_latent["samples"], shape_latent["coords"], coord_counts,
        )
        slat = shape_norm(feats.to(device), coords_512.to(device))
        slat.feats = slat.feats.to(decoder_dtype)
        sample_hr_coords = [decoder.upsample(slat, upsample_times=4)]
    else:
        items = split_batched_sparse_latent(
            shape_latent["samples"], shape_latent["coords"], coord_counts,
        )
        sample_hr_coords = []
        for feats_i, coords_i in items:
            coords_i = coords_i.to(device).clone()
            coords_i[:, 0] = 0  # each sample is decoded as its own batch 0
            slat_i = shape_norm(feats_i.to(device), coords_i)
            slat_i.feats = slat_i.feats.to(decoder_dtype)
            sample_hr_coords.append(decoder.upsample(slat_i, upsample_times=4))

    # Resolution search. The quantized tensors of the final pass are kept so
    # .unique() is not recomputed once hr_resolution has been chosen.
    hr_resolution = target_resolution
    while True:
        quant_unique_list = []
        exceeds_limit = False
        for hr_coords_i in sample_hr_coords:
            qu = cls._quantize_unique(hr_coords_i, lr_resolution, hr_resolution)
            quant_unique_list.append(qu)
            if qu.shape[0] >= max_tokens:
                exceeds_limit = True
                break
        if not exceeds_limit:
            break
        if hr_resolution <= 1024:
            # Floor reached: accept this resolution even though a sample is
            # over budget, and quantize the samples the early `break` skipped.
            for hr_coords_k in sample_hr_coords[len(quant_unique_list):]:
                quant_unique_list.append(
                    cls._quantize_unique(hr_coords_k, lr_resolution, hr_resolution)
                )
            break
        hr_resolution -= 128

    # Rewrite the batch column to the sample's position in the list and concat.
    # NOTE(review): in the coord_counts-is-None path the whole flattened latent
    # is one list entry, so every row ends up with batch index 0 — confirm that
    # path is only reached with a single sample.
    per_sample_counts = []
    for sample_offset, qu in enumerate(quant_unique_list):
        qu[:, 0] = sample_offset
        per_sample_counts.append(int(qu.shape[0]))
    coords = torch.cat(quant_unique_list, dim=0)
    counts = torch.tensor(per_sample_counts, dtype=torch.int64)
    coord_resolution = hr_resolution // 16

    batch_size, _, max_tokens_out = infer_batched_coord_layout(coords)
    latent = torch.zeros(batch_size, 32, max_tokens_out, 1)

    extras = {
        "trellis2_generation_mode": "shape_generation",
        "trellis2_coords": coords,
        "trellis2_coord_counts": counts,
    }
    proj_pack = _proj_pack_from_conditioning(positive)
    if proj_pack is not None:
        extras["trellis2_proj_feats"] = compute_stage_proj_feats(
            proj_pack, "shape_1024", coords=coords, coord_resolution=coord_resolution,
        )
    positive_out = _conditioning_set_extras(positive, extras)
    negative_out = _conditioning_set_extras(negative, extras)
    out_latent = {"samples": latent, "coords": coords, "coord_counts": counts,
                  "coord_resolution": coord_resolution, "type": "trellis2"}
    return IO.NodeOutput(positive_out, negative_out, out_latent)
def _proj_pack_from_conditioning(conditioning):
    """Return the ``proj_feat_pack`` stored on the first conditioning entry.

    Looks at ``conditioning[0][1]["proj_feat_pack"]``. Returns None when the
    conditioning is empty/None, when its first entry is not a
    ``[tensor, dict]``-style pair, or when the key is absent (e.g. vanilla
    Trellis2 conditioning).
    """
    if not conditioning:
        return None
    first = conditioning[0]
    is_pair = isinstance(first, (list, tuple)) and len(first) >= 2 and isinstance(first[1], dict)
    if not is_pair:
        return None
    return first[1].get("proj_feat_pack")


def _conditioning_set_extras(conditioning, extras: dict):
    """Return a copy of ``conditioning`` with ``extras`` merged into each entry's dict.

    Each entry's options dict is shallow-copied before the merge, so the
    upstream conditioning is not mutated. Entries that are not
    ``[tensor, dict, ...]`` sequences are passed through unchanged.
    """
    out = []
    for entry in conditioning:
        if isinstance(entry, (list, tuple)) and len(entry) >= 2 and isinstance(entry[1], dict):
            merged = entry[1].copy()
            merged.update(extras)
            # Keep any elements past the dict: the guard accepts len >= 2, so
            # longer entries must not be silently truncated to two items.
            out.append([entry[0], merged, *entry[2:]])
        else:
            out.append(entry)
    return out


class Trellis2ShapeStage(IO.ComfyNode):
    """Sets up the first shape-stage sampling pass: extracts sparse coords from
    the dense structure voxel produced by VaeDecodeStructureTrellis2, builds an
    empty sparse latent, and attaches per-stage metadata to the conditioning so
    the model reads it via extra_conds at sample time. For the second shape pass
    (post-upsample), use Trellis2UpsampleStage instead — it combines the cascade
    and the second-pass stage setup."""
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="Trellis2ShapeStage",
            category="latent/3d",
            inputs=[
                IO.Conditioning.Input("positive"),
                IO.Conditioning.Input("negative"),
                IO.Voxel.Input(
                    "voxel",
                    tooltip="Dense structure voxel from VaeDecodeStructureTrellis2.",
                ),
            ],
            outputs=[
                IO.Conditioning.Output(display_name="positive"),
                IO.Conditioning.Output(display_name="negative"),
                IO.Latent.Output(),
            ]
        )

    @classmethod
    def execute(cls, positive, negative, voxel):
        """Build the empty shape latent and stage metadata from a dense voxel grid."""
        decoded = voxel.data.unsqueeze(1)
        # After unsqueeze(1) argwhere yields 5 index columns; column 1 is the
        # inserted singleton axis, so it is dropped.
        coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
        coord_resolution = int(decoded.shape[-1])
        # Dispatch based on the upstream voxel resolution, mirroring upstream's
        # pipeline_type → ss_res table:
        #   coord_res == 32 → first cascade shape pass OR pure-512 pipeline
        #                     (img2shape_512 + shape_512 proj stage, 512 DINO).
        #   coord_res >  32 → pure-1024 non-cascade pipeline
        #                     (img2shape + shape_1024 proj stage, 1024 DINO).
        if coord_resolution <= 32:
            mode = "shape_generation_512"
            stage = "shape_512"
        else:
            mode = "shape_generation"
            stage = "shape_1024"
        batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
        latent = torch.zeros(batch_size, 32, max_tokens, 1)

        extras = {
            "trellis2_generation_mode": mode,
            "trellis2_coords": coords,
            "trellis2_coord_counts": counts,
        }
        proj_pack = _proj_pack_from_conditioning(positive)
        if proj_pack is not None:
            extras["trellis2_proj_feats"] = compute_stage_proj_feats(
                proj_pack, stage, coords=coords, coord_resolution=coord_resolution,
            )
        positive_out = _conditioning_set_extras(positive, extras)
        negative_out = _conditioning_set_extras(negative, extras)
        out_latent = {"samples": latent, "coords": coords, "coord_counts": counts,
                      "coord_resolution": coord_resolution, "type": "trellis2"}
        return IO.NodeOutput(positive_out, negative_out, out_latent)


class Trellis2TextureStage(IO.ComfyNode):
    """Sets up the texture-stage sampling pass. Reads coords / coord_resolution
    and the shape_slat (the shape latent's "samples") from the incoming
    shape_latent dict. Builds an empty sparse latent at the same coord layout
    and attaches per-stage metadata to the conditioning."""
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="Trellis2TextureStage",
            category="latent/3d",
            inputs=[
                IO.Conditioning.Input("positive"),
                IO.Conditioning.Input("negative"),
                IO.Latent.Input("shape_latent"),
            ],
            outputs=[
                IO.Conditioning.Output(display_name="positive"),
                IO.Conditioning.Output(display_name="negative"),
                IO.Latent.Output(),
            ]
        )

    @classmethod
    def execute(cls, positive, negative, shape_latent):
        """Build the empty texture latent and stage metadata from a sampled shape latent."""
        channels = 32
        coords = shape_latent["coords"]
        coord_resolution = shape_latent.get("coord_resolution")

        batch_size, counts, max_tokens = infer_batched_coord_layout(coords)

        shape_slat = shape_latent["samples"]
        if shape_slat.ndim == 4:
            # 4-D latent -> drop the trailing axis and flatten to rows of `channels`.
            shape_slat = shape_slat.squeeze(-1).transpose(1, 2).reshape(-1, channels)

        latent = torch.zeros(batch_size, channels, max_tokens, 1)
        extras = {
            "trellis2_generation_mode": "texture_generation",
            "trellis2_coords": coords,
            "trellis2_coord_counts": counts,
            "trellis2_shape_slat": shape_slat,
        }
        proj_pack = _proj_pack_from_conditioning(positive)
        if proj_pack is not None:
            if coord_resolution is not None:
                extras["trellis2_proj_feats"] = compute_stage_proj_feats(
                    proj_pack, "tex_1024", coords=coords, coord_resolution=coord_resolution,
                )
            else:
                # Previously skipped silently; surface it so a missing
                # coord_resolution is visible instead of a quiet quality drop.
                logging.warning(
                    "Trellis2TextureStage: conditioning carries a proj_feat_pack but "
                    "shape_latent has no coord_resolution; skipping projection features."
                )

        positive_out = _conditioning_set_extras(positive, extras)
        negative_out = _conditioning_set_extras(negative, extras)
        out_latent = {"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts}
        if coord_resolution is not None:
            out_latent["coord_resolution"] = coord_resolution
        return IO.NodeOutput(positive_out, negative_out, out_latent)
+ def execute(cls, batch_size): in_channels = 32 resolution = 16 latent = torch.zeros(batch_size, in_channels, resolution, resolution, resolution) - output = { - "samples": latent, - "type": "trellis2", - } - if proj_feat_pack is not None: - output["model_options"] = {"proj_feat_pack": proj_feat_pack} - return IO.NodeOutput(output) + return IO.NodeOutput({"samples": latent, "type": "trellis2"}) def _dinov3_patches_to_2d(tokens, image_size, patch_size=16): h_p = w_p = image_size // patch_size @@ -813,20 +786,11 @@ def _dinov3_patches_to_2d(tokens, image_size, patch_size=16): return patches.transpose(1, 2).reshape(tokens.shape[0], -1, h_p, w_p).contiguous() -def _fov_from_moge_intrinsics(moge_intrinsics: torch.Tensor) -> float: - fx = moge_intrinsics[..., 0, 0].float() - fov = 2.0 * torch.atan(0.5 / fx.clamp(min=1e-4)) - return float(fov.mean().item()) - - -def _run_dinov3_with_patches(model, cropped_pil, image_size): - # Pixal3D's cross-attn was trained against CLS + registers only (~5 tokens), not the - # full patch grid. The patch grid goes to the proj branch via patches_2d. 
def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
    """Crop a square around the mask's foreground and composite it onto black.

    ``item_image`` is indexed as (H, W, C) and ``item_mask`` as (H, W). Both are
    quantized to 8-bit levels, downscaled if the longer side exceeds
    ``max_image_size``, and cropped to a square centred on the bounding box of
    mask pixels above 204/255, enlarged by 10%. Parts of the square that fall
    outside the image are zero-filled.

    Returns ``(composite, crop_bbox, scene_size)``: ``composite`` is a
    (1, C, h, w) float tensor of ``image * mask`` quantized to 8-bit levels,
    ``crop_bbox`` is ``(x1, y1, x2, y2)`` in (possibly resized) image
    coordinates and may extend outside the image, ``scene_size`` is ``(W, H)``.
    """
    img = item_image.permute(2, 0, 1).unsqueeze(0).cpu().float()
    mask = item_mask.unsqueeze(0).unsqueeze(0).cpu().float()
    # Upstream went float→PIL uint8 implicitly; match that to keep composite bit-exact.
    img = (img.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0
    mask = (mask.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0

    H, W = img.shape[-2:]
    if max(H, W) > max_image_size:
        scale = max_image_size / max(H, W)
        new_w, new_h = int(W * scale), int(H * scale)
        img = comfy.utils.common_upscale(img, new_w, new_h, "lanczos", "disabled")
        mask = comfy.utils.common_upscale(mask, new_w, new_h, "nearest-exact", "disabled")
        H, W = new_h, new_w
    scene_size = (W, H)

    alpha_u8 = (mask[0, 0].clamp(0, 1) * 255.0).to(torch.uint8)
    fg_pixels = (alpha_u8 > 204).nonzero()  # rows of (y, x); 204 == 0.8 * 255
    if fg_pixels.numel() > 0:
        y_min, x_min = fg_pixels.min(dim=0).values.tolist()
        y_max, x_max = fg_pixels.max(dim=0).values.tolist()
        center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
        # Square side = longer bbox extent plus 10% margin.
        size = int(max(y_max - y_min, x_max - x_min) * 1.1)
        # A foreground spanning fewer than 2 px would give half == 0 and a
        # zero-sized crop; keep at least a 2x2 window.
        half = max(1, size // 2)
        crop_x1 = int(center_x - half)
        crop_y1 = int(center_y - half)
        crop_x2 = crop_x1 + 2 * half
        crop_y2 = crop_y1 + 2 * half
    else:
        logging.warning("Mask for the image is empty. Pixal3D requires a clean foreground mask.")
        crop_x1, crop_y1, crop_x2, crop_y2 = 0, 0, W, H
    crop_bbox = (crop_x1, crop_y1, crop_x2, crop_y2)

    # Zero-pad out-of-bounds slice (PIL.crop semantics).
    pad_l = max(0, -crop_x1)
    pad_t = max(0, -crop_y1)
    pad_r = max(0, crop_x2 - W)
    pad_b = max(0, crop_y2 - H)
    if pad_l or pad_t or pad_r or pad_b:
        img = torch.nn.functional.pad(img, (pad_l, pad_r, pad_t, pad_b), value=0.0)
        mask = torch.nn.functional.pad(mask, (pad_l, pad_r, pad_t, pad_b), value=0.0)
        # Shift the window into the padded tensor's coordinates; crop_bbox
        # above keeps the un-shifted values.
        crop_x1 += pad_l
        crop_x2 += pad_l
        crop_y1 += pad_t
        crop_y2 += pad_t
    cropped_img = img[..., crop_y1:crop_y2, crop_x1:crop_x2]
    cropped_mask = mask[..., crop_y1:crop_y2, crop_x1:crop_x2]

    composite = (cropped_img * cropped_mask).clamp(0, 1)
    composite = (composite * 255.0).round().clamp(0, 255).to(torch.uint8).float() / 255.0
    return composite, crop_bbox, scene_size


def _fov_from_moge_intrinsics(moge_intrinsics: torch.Tensor) -> float:
    """Return the mean field of view, in radians, from ``intrinsics[..., 0, 0]``.

    Computes ``2 * atan(0.5 / fx)`` with ``fx`` clamped to at least 1e-4 and
    averages over all leading dimensions.
    NOTE(review): the 0.5 presumably means fx is normalized by image width —
    confirm against the MoGe output convention.
    """
    fx = moge_intrinsics[..., 0, 0].float()
    fov = 2.0 * torch.atan(0.5 / fx.clamp(min=1e-4))
    return float(fov.mean().item())
0 = auto-derive from FOV.", - ), io.Custom("MOGE_GEOMETRY").Input( "moge_geometry", optional=True, @@ -920,13 +891,11 @@ class Pixal3DConditioning(IO.ComfyNode): outputs=[ IO.Conditioning.Output(display_name="positive"), IO.Conditioning.Output(display_name="negative"), - Pixal3DProjPack.Output(display_name="proj_feat_pack"), ], ) @classmethod def execute(cls, clip_vision_model, image, mask, camera_angle_x, mesh_scale, - distance_override=0.0, moge_geometry=None, naf_model=None) -> IO.NodeOutput: if image.ndim == 3: image = image.unsqueeze(0) @@ -945,21 +914,21 @@ class Pixal3DConditioning(IO.ComfyNode): cond_512_list, cond_1024_list = [], [] patches_512_list, patches_1024_list = [], [] - cropped_pil_list = [] + composite_list = [] crop_bbox_list, scene_size_list = [], [] torch_device = comfy.model_management.get_torch_device() for b in range(batch_size): item_image = image[b] item_mask = mask[b] if mask.size(0) > 1 else mask[0] - cropped_pil, crop_bbox, scene_size = _crop_image_with_mask( + composite, crop_bbox, scene_size = _crop_image_with_mask( item_image, item_mask, max_image_size=1024) crop_bbox_list.append(crop_bbox) scene_size_list.append(scene_size) - cropped_pil_list.append(cropped_pil) + composite_list.append(composite) - cond_512 = _run_dinov3_with_patches(clip_vision_model, cropped_pil, 512) - cond_1024 = _run_dinov3_with_patches(clip_vision_model, cropped_pil, 1024) + cond_512 = _run_dinov3_with_patches(clip_vision_model, composite, 512) + cond_1024 = _run_dinov3_with_patches(clip_vision_model, composite, 1024) cond_512_list.append(cond_512["tokens"].to(device)) cond_1024_list.append(cond_1024["tokens"].to(device)) patches_512_list.append(cond_512["patches_2d"].to(device)) @@ -971,42 +940,28 @@ class Pixal3DConditioning(IO.ComfyNode): fm_512_dino = torch.cat(patches_512_list, dim=0) fm_1024_dino = torch.cat(patches_1024_list, dim=0) - # Upstream samples the LR DINO grid AND the NAF HR grid separately at projected - # 3D points, then cats sampled 
features along channels. Back-projection (in model.py) - # mirrors that — here we just stash LR + optional HR per stage. + # The LR DINO grid AND the NAF HR grid are sampled separately # NAF targets per stage: shape_512=512, shape_1024=512, tex_1024=1024. - def _naf_hr(lr_feat, image_pil_list, image_size, naf_target): + def _naf_hr(lr_feat, composites, image_size, naf_target): if naf_model is None or naf_target is None: return None - # Run NAF in the input feature dtype (typically fp16 since DINO/ClipVision - # loads that way). The previous .float() cast doubled NAF memory by forcing - # full fp32 — at tex_1024/target=1024 that's ~10 GB on its own. Model - # weights need to match input dtype since PyTorch conv ops error out on - # mixed fp16-input/fp32-weight. target_dtype = lr_feat.dtype if next(naf_model.parameters()).dtype != target_dtype: naf_model.to(dtype=target_dtype) - imgs = torch.stack([ - torch.from_numpy( - np.array(p.resize((image_size, image_size), Image.Resampling.LANCZOS)) - .astype(np.float32) / 255.0 - ).permute(2, 0, 1) - for p in image_pil_list + imgs = torch.cat([ + comfy.utils.common_upscale(c, image_size, image_size, "lanczos", "disabled") + for c in composites ], dim=0).to(torch_device).to(target_dtype) - hr = naf_model(imgs, lr_feat.to(torch_device).to(target_dtype), naf_target) return hr.to(device) - hr_shape_512 = _naf_hr(fm_512_dino, cropped_pil_list, 512, (512, 512)) - hr_shape_1024 = _naf_hr(fm_1024_dino, cropped_pil_list, 1024, (512, 512)) - hr_tex_1024 = _naf_hr(fm_1024_dino, cropped_pil_list, 1024, (1024, 1024)) + hr_shape_512 = _naf_hr(fm_512_dino, composite_list, 512, (512, 512)) + hr_shape_1024 = _naf_hr(fm_1024_dino, composite_list, 1024, (512, 512)) + hr_tex_1024 = _naf_hr(fm_1024_dino, composite_list, 1024, (1024, 1024)) # distance_from_fov: grid_point (-1, 0, 0) projects to pixel (0, image_resolution-1). 
camera_angle_x = float(camera_angle_x) - if distance_override > 0: - distance = float(distance_override) - else: - distance = 0.5 / math.tan(camera_angle_x / 2.0) / float(mesh_scale) + distance = 0.5 / math.tan(camera_angle_x / 2.0) / float(mesh_scale) cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=device, dtype=torch.float32) dist_t = torch.tensor([distance] * batch_size, device=device, dtype=torch.float32) scale_t = torch.tensor([float(mesh_scale)] * batch_size, device=device, dtype=torch.float32) @@ -1028,13 +983,33 @@ class Pixal3DConditioning(IO.ComfyNode): "scene_sizes": scene_size_list, } - # global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024 - # (Trellis2.forward swaps context↔embeds for non-structure HR stages). + # global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024. + # proj_feat_pack rides in the conditioning dict (same place embeds, ControlNet + # hints etc. live); the sampler auto-promotes it to a model.forward kwarg via + # Trellis2.extra_conds. The same pack object is shared between pos/neg — + # CONDConstant.can_concat sees them equal and concats to a single dict, then + # Trellis2.forward zeros proj for the uncond slots via cond_or_uncond. + # Pre-compute the SS-stage proj features (dense 16³ grid) once here — the + # shape/texture stages do their own computes in their respective stage nodes. + # proj_pack lives on intermediate (CPU); force the compute onto cuda so + # the bilinear-sampling step doesn't run on CPU. 
+ ss_proj_feats = compute_stage_proj_feats( + proj_pack, "ss", dense_grid_resolution=16, batch_size=batch_size, + device=torch_device, + ) neg_global = torch.zeros_like(global_512) neg_embeds = torch.zeros_like(global_1024) - positive = [[global_512, {"embeds": global_1024}]] - negative = [[neg_global, {"embeds": neg_embeds}]] - return IO.NodeOutput(positive, negative, proj_pack) + base_extras = { + "embeds": global_1024, "proj_feat_pack": proj_pack, + "trellis2_proj_feats": ss_proj_feats, + } + neg_extras = { + "embeds": neg_embeds, "proj_feat_pack": proj_pack, + "trellis2_proj_feats": ss_proj_feats, + } + positive = [[global_512, base_extras]] + negative = [[neg_global, neg_extras]] + return IO.NodeOutput(positive, negative) def _project_vertices_to_image_uv(vertices_world, transform_matrix, camera_angle_x, image_resolution): @@ -1069,7 +1044,7 @@ class Pixal3DAlignObject(IO.ComfyNode): category="latent/3d", inputs=[ IO.Mesh.Input("mesh"), - Pixal3DProjPack.Input("proj_feat_pack", tooltip="The proj pack produced by Pixal3DConditioning for this object."), + IO.Conditioning.Input("positive", tooltip="The positive conditioning from Pixal3DConditioning for this object — Pixal3DAlignObject reads transform_matrix / camera_angle_x / mesh_scale / crop_bboxes out of its proj_feat_pack."), io.Custom("MOGE_GEOMETRY").Input("moge_geometry", tooltip="MoGe geometry computed on the original scene image."), IO.Mask.Input( "object_mask", @@ -1089,7 +1064,10 @@ class Pixal3DAlignObject(IO.ComfyNode): ) @classmethod - def execute(cls, mesh, proj_feat_pack, moge_geometry, object_mask=None, batch_index=0) -> IO.NodeOutput: + def execute(cls, mesh, positive, moge_geometry, object_mask=None, batch_index=0) -> IO.NodeOutput: + proj_feat_pack = _proj_pack_from_conditioning(positive) + if proj_feat_pack is None: + raise ValueError("Pixal3DAlignObject: positive conditioning has no proj_feat_pack — connect a Pixal3DConditioning output.") vertices = mesh.vertices faces = mesh.faces if 
class CFGGuidanceInterval(IO.ComfyNode):
    """Generic model patch: apply CFG only during [start_percent, end_percent] of
    the sampling schedule. Outside that window, skip the uncond computation and
    collapse to effective cfg=1 — same idea as upstream Trellis2 / Pixal3D's
    guidance_interval_mixin, but lives at the sampler level (via
    sampler_calc_cond_batch_function) so it works for any model.

    Percents use ComfyUI's standard convention: 0.0 = start of sampling
    (max-noise step), 1.0 = end of sampling (clean step). Conversion to sigma
    is done via model_sampling.percent_to_sigma so the window is portable
    across schedules (flow / EDM / discrete) and shift settings.

    Defaults are full-range (no bypass). For Trellis2's upstream behavior,
    wire (start_percent=0.0, end_percent=0.667) on the SS / shape KSamplers;
    texture defaults to cfg=1 so the node is moot there."""
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="CFGGuidanceInterval",
            category="model_patches/sampling",
            inputs=[
                IO.Model.Input("model"),
                IO.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001,
                               tooltip="Fraction of sampling at which CFG turns ON (0 = beginning)."),
                IO.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001,
                               tooltip="Fraction of sampling at which CFG turns OFF (1 = end)."),
            ],
            outputs=[IO.Model.Output()],
        )

    @classmethod
    def execute(cls, model, start_percent, end_percent):
        """Return a clone of ``model`` whose cond-batch hook applies CFG only in the window.

        If the clone already carries a ``sampler_calc_cond_batch_function``,
        that hook is called in place of ``comfy.samplers.calc_cond_batch`` so
        an earlier patch is composed with rather than discarded.
        """
        import comfy.samplers

        model_sampling = model.get_model_object("model_sampling")
        # percent_to_sigma is monotonically decreasing: percent=0 -> sigma_max,
        # percent=1 -> sigma_min. So start_percent < end_percent in user space
        # means sigma_start > sigma_end. "Inside the window" is sigma in
        # [sigma_end, sigma_start].
        sigma_start = float(model_sampling.percent_to_sigma(start_percent))
        sigma_end = float(model_sampling.percent_to_sigma(end_percent))

        m = model.clone()
        # Hook installed by an earlier patch, if any; captured before overwrite.
        chained = m.model_options.get("sampler_calc_cond_batch_function")

        def run_batch(args, conds):
            # Delegate to the earlier hook when present, else to the stock batcher.
            if chained is not None:
                return chained({**args, "conds": conds})
            return comfy.samplers.calc_cond_batch(
                args["model"], conds, args["input"], args["sigma"], args["model_options"])

        def calc_cond_batch_with_interval(args):
            sigma_val = args["sigma"][0].item()
            conds = args["conds"]
            # conds is typically [cond, uncond]; uncond may be None when ComfyUI's
            # global cfg=1 optimization has already pruned it.
            uncond = conds[1] if len(conds) > 1 else None
            inside = sigma_end <= sigma_val <= sigma_start

            if uncond is None or inside:
                return run_batch(args, conds)
            # Outside the window: compute cond only and mirror it into every
            # other slot so the downstream cfg_function collapses to `cond`
            # (effective cfg=1).
            out = run_batch(args, [conds[0]])
            return [out[0]] * len(conds)

        m.model_options["sampler_calc_cond_batch_function"] = calc_cond_batch_with_interval
        return IO.NodeOutput(m)