From 288f3cf134d51aefa09a09972048ec50149e54cb Mon Sep 17 00:00:00 2001 From: kijai Date: Sat, 27 Jun 2026 00:05:45 +0300 Subject: [PATCH] Cleanup --- comfy/latent_formats.py | 5 +- comfy/ldm/modules/attention.py | 69 ---------- comfy/ldm/trellis2/attention.py | 138 +++++-------------- comfy/ldm/trellis2/flexgemm.py | 101 ++++---------- comfy/ldm/trellis2/model.py | 226 +++++++------------------------- latent_preview.py | 143 -------------------- 6 files changed, 109 insertions(+), 573 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 66eb2e0d2..1ff9ada9d 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -770,9 +770,10 @@ class Hunyuan3Dv2_1(LatentFormat): latent_channels = 64 latent_dimensions = 1 -class Trellis2(LatentFormat): # TODO +class Trellis2(LatentFormat): latent_channels = 32 - trellis3d_preview = True # routes the sampler preview to Trellis3DPreviewer + + class Hunyuan3Dv2mini(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index b75a76f77..55360535a 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -32,11 +32,6 @@ except ImportError as e: raise e exit(-1) -try: - from sageattention import sageattn_varlen # sageattention >= 2 -except ImportError: - sageattn_varlen = None - SAGE_ATTENTION3_IS_AVAILABLE = False try: from sageattn3 import sageattn3_blackwell @@ -53,24 +48,6 @@ except ImportError: logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn") exit(-1) -try: - from torch.nn.attention.varlen import varlen_attn as _torch_varlen_attn -except ImportError: - _torch_varlen_attn = None - - -def _is_varlen(kwargs): - """Varlen mode is opted into by passing cu_seqlens_q in kwargs.""" - return kwargs.get("cu_seqlens_q") is not None - - -def _varlen_args(kwargs): - cu_seqlens_q = kwargs["cu_seqlens_q"] - cu_seqlens_kv = kwargs.get("cu_seqlens_kv", cu_seqlens_q) - max_seqlen_q = int(kwargs["max_seqlen_q"]) - max_seqlen_kv = int(kwargs.get("max_seqlen_kv", max_seqlen_q)) - return cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv - REGISTERED_ATTENTION_FUNCTIONS = {} def register_attention_function(name: str, func: Callable): # avoid replacing existing functions @@ -167,8 +144,6 @@ def wrap_attn(func): @wrap_attn def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): - if _is_varlen(kwargs): - return attention_pytorch(q, k, v, heads, mask=mask, **kwargs) attn_precision = get_attn_precision(attn_precision, q.dtype) if skip_reshape: @@ -243,8 +218,6 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape @wrap_attn def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): - if _is_varlen(kwargs): - return attention_pytorch(query, key, value, heads, mask=mask, **kwargs) attn_precision = get_attn_precision(attn_precision, query.dtype) if skip_reshape: @@ -320,8 +293,6 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, @wrap_attn def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): - if _is_varlen(kwargs): - return attention_pytorch(q, k, v, heads, mask=mask, **kwargs) attn_precision = get_attn_precision(attn_precision, q.dtype) if skip_reshape: @@ -453,17 +424,6 @@ except: @wrap_attn def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): - if _is_varlen(kwargs): - # q, k, v expected as packed [T_total, H, C]. xformers wants per-item - # seqlen lists and a leading batch dim of 1. - cu_seqlens_q, cu_seqlens_kv, _max_q, _max_kv = _varlen_args(kwargs) - q_seqlen = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).tolist() - kv_seqlen = (cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]).tolist() - attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) - return xformers.ops.memory_efficient_attention( - q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attn_bias=attn_bias, - )[0] - b = q.shape[0] dim_head = q.shape[-1] # check to make sure xformers isn't broken @@ -533,22 +493,6 @@ else: @wrap_attn def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): - if _is_varlen(kwargs): - # q, k, v expected as packed [T_total, H, C]. cu_seqlens_q / cu_seqlens_kv - # describe per-item offsets. Native varlen kernel if available, else - # nested-tensor SDPA fallback. mask/attn_precision/etc are ignored here. - cu_seqlens_q, cu_seqlens_kv, max_q, max_kv = _varlen_args(kwargs) - if _torch_varlen_attn is not None: - return _torch_varlen_attn(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q, max_kv) - q_nj = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long()) - k_nj = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_kv.long()) - v_nj = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_kv.long()) - out = comfy.ops.scaled_dot_product_attention( - q_nj.transpose(1, 2), k_nj.transpose(1, 2), v_nj.transpose(1, 2), - attn_mask=None, dropout_p=0.0, is_causal=False, - ) - return out.transpose(1, 2).values() - if skip_reshape: b, _, _, dim_head = q.shape else: @@ -597,13 +541,6 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha @wrap_attn def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): - if _is_varlen(kwargs): - # q, k, v expected as packed [T_total, H, C]. - if sageattn_varlen is None: - # sageattention v1 has no varlen kernel; fall back to attention_pytorch's varlen path. - return attention_pytorch(q, k, v, heads, mask=mask, **kwargs) - cu_seqlens_q, cu_seqlens_kv, max_q, max_kv = _varlen_args(kwargs) - return sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q, max_kv) if kwargs.get("low_precision_attention", True) is False: return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs) @@ -761,12 +698,6 @@ except AttributeError as error: @wrap_attn def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): - if _is_varlen(kwargs): - # q, k, v expected as packed [T_total, H, C]. - from flash_attn import flash_attn_varlen_func - cu_seqlens_q, cu_seqlens_kv, max_q, max_kv = _varlen_args(kwargs) - return flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q, max_kv) - if skip_reshape: b, _, _, dim_head = q.shape else: diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 07a629c01..571101863 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -1,33 +1,7 @@ -import math -from typing import Tuple, Union - -import torch - from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.trellis2.vae import VarLenTensor -def _build_cu_seqlens(seqlen, device): - """Cumulative-offset tensor from a list/tensor of per-item sequence lengths.""" - if isinstance(seqlen, torch.Tensor): - return torch.cat( - [torch.zeros(1, dtype=torch.int32, device=seqlen.device), torch.cumsum(seqlen, dim=0).int()], - dim=0, - ).to(device) - return torch.cat( - [torch.tensor([0]), torch.cumsum(torch.tensor(seqlen), dim=0)], - ).int().to(device) - - -def _layout_to_feats(t): - """Returns (sparse_template_or_None, seqlen, feats[T, H, C]).""" - if isinstance(t, VarLenTensor): - seqlen = [t.layout[i].stop - t.layout[i].start for i in range(t.shape[0])] - return t, seqlen, t.feats - N, L = t.shape[:2] - return None, [L] * N, t.reshape(-1, *t.shape[-2:]) - - def dense_attention(q, k, v, **kwargs): """q, k, v: [B, L, H, C]. Permutes for comfy's [B, H, L, C] convention.""" heads = q.shape[2] @@ -38,86 +12,36 @@ def dense_attention(q, k, v, **kwargs): return out.permute(0, 2, 1, 3) +def _to_rect(t): + """Fold a VarLenTensor packed as [sum(L_i), H, C] into a dense [B, L, H, C]. + + The sparse generation stages run a single object per call (optionally + CFG-duplicated, which keeps every batch entry the same length), so the + packed layout is rectangular and attention is ordinary dense attention over + a batch dim — no variable-length kernel needed. A dense [B, L, H, C] tensor + (e.g. cross-attention context) passes through unchanged. + """ + if not isinstance(t, VarLenTensor): + return t + B = t.shape[0] + seqlens = [t.layout[i].stop - t.layout[i].start for i in range(B)] + if len(set(seqlens)) != 1: + raise ValueError( + "trellis2 sparse attention expects equal sequence lengths per batch " + f"(single object, optionally CFG-duplicated); got {seqlens}. " + "Multi-object batching is not supported." + ) + return t.feats.view(B, seqlens[0], *t.feats.shape[1:]) + + def sparse_attention(q, k, v, **kwargs): + """Full attention over a SparseTensor's voxels. + + Single object (optionally CFG-duplicated) => the packed layout is + rectangular, so we fold it into a batch dim and run ordinary dense + attention. Output type matches q. """ - Varlen attention for SparseTensor inputs. Each of q, k, v may be a VarLenTensor - (sparse) or dense [B, L, H, C]. Output type matches q. Backend dispatch lives - in comfy.ldm.modules.attention.optimized_attention; we just build cu_seqlens - from the layouts. - """ - s, q_seqlen, q_feats = _layout_to_feats(q) - _, kv_seqlen, k_feats = _layout_to_feats(k) - _, _, v_feats = _layout_to_feats(v) - heads = q_feats.shape[1] - - device = q_feats.device - cu_seqlens_q = _build_cu_seqlens(q_seqlen, device) - cu_seqlens_kv = _build_cu_seqlens(kv_seqlen, device) - - out = optimized_attention( - q_feats, k_feats, v_feats, heads, - cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - max_seqlen_q=max(q_seqlen), max_seqlen_kv=max(kv_seqlen), - skip_reshape=True, skip_output_reshape=True, - **kwargs, - ) - - if s is not None: - return s.replace(out) - N, L = q.shape[:2] - return out.reshape(N, L, heads, -1) - - -def sparse_windowed_self_attention(qkv, window_size: int, shift_window: Tuple[int, int, int] = (0, 0, 0)): - """Windowed sparse self-attention. Partitions voxels into windows via spatial - sort, then runs varlen attention with one sequence per non-empty window.""" - cache_name = f'windowed_attention_{window_size}_{shift_window}' - cache = qkv.get_spatial_cache(cache_name) - if cache is None: - cache = calc_window_partition(qkv, window_size, shift_window) - qkv.register_spatial_cache(cache_name, cache) - fwd_indices, bwd_indices, seq_lens = cache - - qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] - q, k, v = qkv_feats.unbind(dim=1) # each [M, H, C] - heads = q.shape[1] - device = q.device - - cu_seqlens = _build_cu_seqlens(seq_lens, device) - max_seqlen = int(seq_lens.max()) - - out = optimized_attention( - q, k, v, heads, - cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens, - max_seqlen_q=max_seqlen, max_seqlen_kv=max_seqlen, - skip_reshape=True, skip_output_reshape=True, - ) - out = out[bwd_indices] - return qkv.replace(out) - - -def calc_window_partition( - tensor, - window_size: Union[int, Tuple[int, ...]], - shift_window: Union[int, Tuple[int, ...]] = 0, -): - """Returns (fwd_indices, bwd_indices, seq_lens) for window partitioning.""" - DIM = tensor.coords.shape[1] - 1 - shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window - window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size - shifted_coords = tensor.coords.clone().detach() - shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0) - - MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)] - NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] - OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] - - shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0) - shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1) - fwd_indices = torch.argsort(shifted_indices) - bwd_indices = torch.empty_like(fwd_indices) - bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) - seq_lens = torch.bincount(shifted_indices) - seq_lens = seq_lens[seq_lens != 0] - - return fwd_indices, bwd_indices, seq_lens + out = dense_attention(_to_rect(q), _to_rect(k), _to_rect(v), **kwargs) # [B, Lq, H, C] + if isinstance(q, VarLenTensor): + return q.replace(out.reshape(-1, *out.shape[2:])) + return out diff --git a/comfy/ldm/trellis2/flexgemm.py b/comfy/ldm/trellis2/flexgemm.py index e33b50376..e48a3f3ce 100644 --- a/comfy/ldm/trellis2/flexgemm.py +++ b/comfy/ldm/trellis2/flexgemm.py @@ -3,8 +3,6 @@ import torch import comfy.model_management -UINT32_SENTINEL = 0xFFFFFFFF - def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device): """Kernel spatial offsets in the same order as the CUDA/Triton kernels.""" @@ -19,11 +17,9 @@ def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device): class TorchHashMap: """Sorted-array hashmap backed by torch.searchsorted.""" - def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int): - device = keys.device + def __init__(self, keys: torch.Tensor, values: torch.Tensor): self.sorted_keys, order = torch.sort(keys.to(torch.long)) self.sorted_vals = values.to(torch.long)[order] - self.default_value = torch.tensor(default_value, dtype=torch.long, device=device) self._n = self.sorted_keys.numel() # Chunk size for lookup_flat, caps each transient to ~CHUNK rows. @@ -54,65 +50,28 @@ def build_submanifold_neighbor_map( Kw, Kh, Kd, Dw, Dh, Dd, ): + # neighbor[i, v] = index of the voxel at voxel i's coord + kernel-offset v, or -1. + # Chunked over voxels so the [chunk, V, 3] candidate transient stays bounded. device = coords.device M = coords.shape[0] - V = Kw * Kh * Kd - half_V = V // 2 + 1 - INVALID = -1 + offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device).long() # [V, 3] + V = offsets.shape[0] + center = torch.tensor([(Kw // 2) * Dw, (Kh // 2) * Dh, (Kd // 2) * Dd], device=device) + WHD, HD = W * H * D, H * D - # int32 neighbour map: 4 bytes/elem vs 8 bytes for int64 - neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.int32) - - b = coords[:, 0].long() - x = coords[:, 1].long() - y = coords[:, 2].long() - z = coords[:, 3].long() - - offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device) - - ox = x - (Kw // 2) * Dw - oy = y - (Kh // 2) * Dh - oz = z - (Kd // 2) * Dd - - for v in range(half_V): - if v == half_V - 1: - # Center voxel always maps to itself - neighbor[:, v] = torch.arange(M, device=device, dtype=torch.int32) - continue - - dx, dy, dz = offsets[v] - - kx = ox + dx - ky = oy + dy - kz = oz + dz - - valid = ( - (kx >= 0) & (kx < W) & - (ky >= 0) & (ky < H) & - (kz >= 0) & (kz < D) - ) - - flat = ( - b[valid] * (W * H * D) + - kx[valid] * (H * D) + - ky[valid] * D + - kz[valid] - ) - - if flat.numel() > 0: - found = hashmap.lookup_flat(flat) - idx_in_M = torch.where(valid)[0] - neighbor[idx_in_M, v] = found.to(torch.int32) - - # BUG FIX: old code used found != hashmap.default_value which - # compared int32 -1 against int64 4294967295 → always True. - # We now explicitly check for valid indices. - valid_found_mask = found >= 0 - if valid_found_mask.any(): - src_points = idx_in_M[valid_found_mask] - dst_points = found[valid_found_mask].long() - neighbor[dst_points, V - 1 - v] = src_points.to(torch.int32) + neighbor = torch.empty((M, V), dtype=torch.int32, device=device) + # ~V*40 bytes/voxel of transient (int64 cand + flat + masks); cap at ~0.5 GB. + chunk = max(1, min(M, int(0.5 * (1024 ** 3) / (V * 40)))) + for s in range(0, M, chunk): + e = min(s + chunk, M) + b = coords[s:e, 0].long() + cand = coords[s:e, 1:4].long()[:, None, :] + offsets[None, :, :] - center # [c, V, 3] + x, y, z = cand[..., 0], cand[..., 1], cand[..., 2] + in_bounds = (x >= 0) & (x < W) & (y >= 0) & (y < H) & (z >= 0) & (z < D) # [c, V] + flat = b[:, None] * WHD + x * HD + y * D + z # [c, V] + flat = torch.where(in_bounds, flat, torch.full_like(flat, -1)) # OOB -> guaranteed miss + neighbor[s:e] = hashmap.lookup_flat(flat.reshape(-1)).view(e - s, V) return neighbor def get_recommended_chunk_mem( @@ -155,7 +114,7 @@ def sparse_submanifold_conv3d( coords[:, 2].long() * y_stride + coords[:, 3].long() * z_stride) vals = torch.arange(coords.shape[0], dtype=torch.int32, device=device) - hashmap = TorchHashMap(flat_keys, vals, UINT32_SENTINEL) + hashmap = TorchHashMap(flat_keys, vals) neighbor = build_submanifold_neighbor_map( hashmap, coords, W, H, D, Kw, Kh, Kd, @@ -165,13 +124,15 @@ def sparse_submanifold_conv3d( neighbor = neighbor_cache N_pts = feats.shape[0] - sentinel = -1 weight_T = weight.view(Co, V * Ci).T output = torch.empty(N_pts, Co, device=device, dtype=feats.dtype) - # Chunk size from memory budget. The dominant peak is `gathered`, of shape (chunk, V, Ci) in feats.dtype. + # Zero row at index N_pts; missing neighbors (-1) gather it -> no separate masking. + feats_padded = torch.cat([feats, feats.new_zeros(1, Ci)], dim=0) + + # Chunk over voxels to bound the (chunk, V, Ci) gather. max_chunk_mem_gb = get_recommended_chunk_mem(device) mem_per_row = V * Ci * feats.element_size() max_chunk_mem = max_chunk_mem_gb * (1024 ** 3) @@ -182,18 +143,10 @@ def sparse_submanifold_conv3d( end = min(start + chunk_size, N_pts) actual_chunk = end - start - chunk_neighbor = neighbor[start:end] - chunk_valid = chunk_neighbor != sentinel - # clamp(-1 -> 0) keeps invalid indices in-range so the gather is safe - chunk_idx = chunk_neighbor.clamp(min=0) - - # (chunk, V, Ci) gather, then in-place zero of invalid neighbors. - gathered = feats[chunk_idx] - gathered.mul_(chunk_valid.unsqueeze(-1)) - - # GEMM (chunk, V*Ci) @ (V*Ci, Co) -> (chunk, Co), written to output[start:end]. + chunk_idx = torch.where(neighbor[start:end] < 0, N_pts, neighbor[start:end]) # -1 -> zero row + gathered = feats_padded[chunk_idx] # (chunk, V, Ci) gathered_flat = gathered.view(actual_chunk, V * Ci) - torch.matmul(gathered_flat, weight_T, out=output[start:end]) + torch.matmul(gathered_flat, weight_T, out=output[start:end]) # (chunk, V*Ci) @ (V*Ci, Co) if bias is not None: output += bias.unsqueeze(0).to(output.dtype) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 3ed6e114d..3f3751052 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -3,12 +3,9 @@ import torch.nn.functional as F import torch.nn as nn from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor from typing import Optional, Tuple, Literal, Union, List -from comfy.ldm.trellis2.attention import ( - sparse_windowed_self_attention, sparse_attention, dense_attention -) +from comfy.ldm.trellis2.attention import sparse_attention, dense_attention from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder from comfy.ldm.flux.math import apply_rope, apply_rope1 -from comfy.ldm.trellis2 import sampling_preview class SparseGELU(nn.GELU): def forward(self, input: VarLenTensor) -> VarLenTensor: @@ -26,13 +23,6 @@ class SparseFeedForwardNet(nn.Module): def forward(self, x: VarLenTensor) -> VarLenTensor: return self.mlp(x) -class LayerNorm32(nn.LayerNorm): - def forward(self, x: torch.Tensor) -> torch.Tensor: - x_dtype = x.dtype - x = x.to(dtype=torch.float32) - o = super().forward(x) - return o.to(dtype=x_dtype) - class SparseMultiHeadRMSNorm(nn.Module): def __init__(self, dim: int, heads: int, device, dtype): super().__init__() @@ -44,13 +34,7 @@ class SparseMultiHeadRMSNorm(nn.Module): return F.rms_norm(x, (x.shape[-1],)) * self.gamma class SparseRotaryPositionEmbedder(nn.Module): - def __init__( - self, - head_dim: int, - dim: int = 3, - rope_freq: Tuple[float, float] = (1.0, 10000.0), - device=None - ): + def __init__(self, head_dim: int, dim: int = 3, rope_freq: Tuple[float, float] = (1.0, 10000.0), device=None): super().__init__() self.head_dim = head_dim self.dim = dim @@ -111,12 +95,7 @@ class SparseMultiHeadAttention(nn.Module): num_heads: int, ctx_channels: Optional[int] = None, type: Literal["self", "cross"] = "self", - attn_mode: Literal["full", "windowed", "double_windowed"] = "full", - window_size: Optional[int] = None, - shift_window: Optional[Tuple[int, int, int]] = None, qkv_bias: bool = True, - use_rope: bool = False, - rope_freq: Tuple[int, int] = (1.0, 10000.0), qk_rms_norm: bool = False, device=None, dtype=None, operations=None ): @@ -127,10 +106,6 @@ class SparseMultiHeadAttention(nn.Module): self.ctx_channels = ctx_channels if ctx_channels is not None else channels self.num_heads = num_heads self._type = type - self.attn_mode = attn_mode - self.window_size = window_size - self.shift_window = shift_window - self.use_rope = use_rope self.qk_rms_norm = qk_rms_norm if self._type == "self": @@ -145,8 +120,8 @@ class SparseMultiHeadAttention(nn.Module): self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype) - if use_rope: - self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq, device=device) + if self._type == "self": + self.rope = SparseRotaryPositionEmbedder(self.head_dim, device=device) @staticmethod def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: @@ -170,43 +145,16 @@ class SparseMultiHeadAttention(nn.Module): x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats - def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None, - transformer_options=None) -> SparseTensor: + def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None, transformer_options=None) -> SparseTensor: if self._type == "self": qkv = self._linear(self.to_qkv, x) qkv = self._fused_pre(qkv, num_fused=3) - if self.attn_mode == "full": - q, k, v = qkv.unbind(dim=-3) - if self.qk_rms_norm: - q = self.q_rms_norm(q) - k = self.k_rms_norm(k) - if self.use_rope: - q, k = self.rope(q, k) - h = sparse_attention(q, k, v, transformer_options=transformer_options) - else: - # Windowed paths take packed qkv; preserve any per-head norm/rope. - if self.qk_rms_norm or self.use_rope: - q, k, v = qkv.unbind(dim=-3) - if self.qk_rms_norm: - q = self.q_rms_norm(q) - k = self.k_rms_norm(k) - if self.use_rope: - q, k = self.rope(q, k) - qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) - if self.attn_mode == "windowed": - h = sparse_windowed_self_attention( - qkv, self.window_size, shift_window=self.shift_window - ) - elif self.attn_mode == "double_windowed": - qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:]) - qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2]) - h0 = sparse_windowed_self_attention( - qkv0, self.window_size, shift_window=(0, 0, 0) - ) - h1 = sparse_windowed_self_attention( - qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3) - ) - h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1)) + q, k, v = qkv.unbind(dim=-3) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + q, k = self.rope(q, k) + h = sparse_attention(q, k, v, transformer_options=transformer_options) else: q = self._linear(self.to_q, x) q = self._reshape_chs(q, (self.num_heads, -1)) @@ -276,37 +224,25 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): ctx_channels: int, num_heads: int, mlp_ratio: float = 4.0, - attn_mode: Literal["full", "swin"] = "full", - window_size: Optional[int] = None, - shift_window: Optional[Tuple[int, int, int]] = None, - use_checkpoint: bool = False, - use_rope: bool = False, - rope_freq: Tuple[float, float] = (1.0, 10000.0), qk_rms_norm: bool = False, qk_rms_norm_cross: bool = False, qkv_bias: bool = True, share_mod: bool = False, - image_attn_mode: Literal["global", "proj", "gated_proj"] = "global", + image_attn_mode: Literal["global", "proj"] = "global", proj_in_channels: Optional[int] = None, device=None, dtype=None, operations=None ): super().__init__() - self.use_checkpoint = use_checkpoint self.share_mod = share_mod self.image_attn_mode = image_attn_mode - self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) - self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device) - self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) + self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) self.self_attn = SparseMultiHeadAttention( channels, num_heads=num_heads, type="self", - attn_mode=attn_mode, - window_size=window_size, - shift_window=shift_window, qkv_bias=qkv_bias, - use_rope=use_rope, - rope_freq=rope_freq, qk_rms_norm=qk_rms_norm, device=device, dtype=dtype, operations=operations ) @@ -315,7 +251,6 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): ctx_channels=ctx_channels, num_heads=num_heads, type="cross", - attn_mode="full", qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, device=device, dtype=dtype, operations=operations @@ -347,8 +282,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) else: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) - # Fuse the (mul + add) and (mul + residual) pairs into addcmul so the - # mod/shift broadcasts hit one kernel each instead of two. + # Fuse the (mul + add) and (mul + residual) pairs into addcmul b_map = x.batch_boardcast_map h_feats = self.norm1(x.feats) @@ -386,18 +320,12 @@ class SLatFlowModel(nn.Module): num_heads: Optional[int] = None, num_head_channels: Optional[int] = 64, mlp_ratio: float = 4, - pe_mode: Literal["ape", "rope"] = "rope", - rope_freq: Tuple[float, float] = (1.0, 10000.0), - use_checkpoint: bool = False, share_mod: bool = False, - initialization: str = 'vanilla', qk_rms_norm: bool = False, qk_rms_norm_cross: bool = False, - image_attn_mode: Literal["global", "proj", "gated_proj"] = "global", + image_attn_mode: Literal["global", "proj"] = "global", proj_in_channels: Optional[int] = None, - dtype = None, - device = None, - operations = None, + dtype = None, device = None, operations = None, ): super().__init__() self.resolution = resolution @@ -408,10 +336,7 @@ class SLatFlowModel(nn.Module): self.num_blocks = num_blocks self.num_heads = num_heads or model_channels // num_head_channels self.mlp_ratio = mlp_ratio - self.pe_mode = pe_mode - self.use_checkpoint = use_checkpoint self.share_mod = share_mod - self.initialization = initialization self.qk_rms_norm = qk_rms_norm self.qk_rms_norm_cross = qk_rms_norm_cross self.image_attn_mode = image_attn_mode @@ -433,10 +358,6 @@ class SLatFlowModel(nn.Module): cond_channels, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, - attn_mode='full', - use_checkpoint=self.use_checkpoint, - use_rope=(pe_mode == "rope"), - rope_freq=rope_freq, share_mod=self.share_mod, qk_rms_norm=self.qk_rms_norm, qk_rms_norm_cross=self.qk_rms_norm_cross, @@ -491,14 +412,22 @@ class FeedForwardNet(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: return self.mlp(x) +# class MultiHeadRMSNorm(nn.Module): +# def __init__(self, dim: int, heads: int, device=None, dtype=None): +# super().__init__() +# self.scale = dim ** 0.5 +# self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype)) + +# def forward(self, x: torch.Tensor) -> torch.Tensor: +# return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) + class MultiHeadRMSNorm(nn.Module): def __init__(self, dim: int, heads: int, device=None, dtype=None): super().__init__() - self.scale = dim ** 0.5 self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype)) def forward(self, x: torch.Tensor) -> torch.Tensor: - return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) + return (F.rms_norm(x.float(), (x.shape[-1],)) * self.gamma).to(x.dtype) class MultiHeadAttention(nn.Module): @@ -508,12 +437,7 @@ class MultiHeadAttention(nn.Module): num_heads: int, ctx_channels: Optional[int]=None, type: Literal["self", "cross"] = "self", - attn_mode: Literal["full", "windowed"] = "full", - window_size: Optional[int] = None, - shift_window: Optional[Tuple[int, int, int]] = None, qkv_bias: bool = True, - use_rope: bool = False, - rope_freq: Tuple[float, float] = (1.0, 10000.0), qk_rms_norm: bool = False, device=None, dtype=None, operations=None ): @@ -524,10 +448,6 @@ class MultiHeadAttention(nn.Module): self.ctx_channels = ctx_channels if ctx_channels is not None else channels self.num_heads = num_heads self._type = type - self.attn_mode = attn_mode - self.window_size = window_size - self.shift_window = shift_window - self.use_rope = use_rope self.qk_rms_norm = qk_rms_norm if self._type == "self": @@ -552,12 +472,11 @@ class MultiHeadAttention(nn.Module): if self.qk_rms_norm: q = self.q_rms_norm(q) k = self.k_rms_norm(k) - if self.use_rope: - assert phases is not None, "Phases must be provided for RoPE" - # phases is [L, head_dim/2, 2, 2]; broadcast to [1, L, 1, ...] - # to align with q/k of shape [B, L, H, head_dim]. - f_cis = phases.unsqueeze(0).unsqueeze(2) - q, k = apply_rope(q, k, f_cis) + assert phases is not None, "Phases must be provided for RoPE" + # phases is [L, head_dim/2, 2, 2]; broadcast to [1, L, 1, ...] + # to align with q/k of shape [B, L, H, head_dim]. + f_cis = phases.unsqueeze(0).unsqueeze(2) + q, k = apply_rope(q, k, f_cis) h = dense_attention(q, k, v, transformer_options=transformer_options) else: Lkv = context.shape[1] @@ -581,37 +500,25 @@ class ModulatedTransformerCrossBlock(nn.Module): ctx_channels: int, num_heads: int, mlp_ratio: float = 4.0, - attn_mode: Literal["full", "windowed"] = "full", - window_size: Optional[int] = None, - shift_window: Optional[Tuple[int, int, int]] = None, - use_checkpoint: bool = False, - use_rope: bool = False, - rope_freq: Tuple[int, int] = (1.0, 10000.0), qk_rms_norm: bool = False, qk_rms_norm_cross: bool = False, qkv_bias: bool = True, share_mod: bool = False, - image_attn_mode: Literal["global", "proj", "gated_proj"] = "global", + image_attn_mode: Literal["global", "proj"] = "global", proj_in_channels: Optional[int] = None, device=None, dtype=None, operations=None ): super().__init__() - self.use_checkpoint = use_checkpoint self.share_mod = share_mod self.image_attn_mode = image_attn_mode - self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) - self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device) - self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) + self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype) self.self_attn = MultiHeadAttention( channels, num_heads=num_heads, type="self", - attn_mode=attn_mode, - window_size=window_size, - shift_window=shift_window, qkv_bias=qkv_bias, - use_rope=use_rope, - rope_freq=rope_freq, qk_rms_norm=qk_rms_norm, device=device, dtype=dtype, operations=operations ) @@ -620,7 +527,6 @@ class ModulatedTransformerCrossBlock(nn.Module): ctx_channels=ctx_channels, num_heads=num_heads, type="cross", - attn_mode="full", qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, device=device, dtype=dtype, operations=operations @@ -640,25 +546,17 @@ class ModulatedTransformerCrossBlock(nn.Module): device=device, dtype=dtype, operations=operations ) if not share_mod: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device) - ) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device)) else: self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5) def _forward(self, x: torch.Tensor, mod: torch.Tensor, context, phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor: if self.share_mod: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + mod = (self.modulation + mod).type(mod.dtype) else: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) - shift_msa = shift_msa.unsqueeze(1) - scale_msa = scale_msa.unsqueeze(1) - gate_msa = gate_msa.unsqueeze(1) - shift_mlp = shift_mlp.unsqueeze(1) - scale_mlp = scale_mlp.unsqueeze(1) - gate_mlp = gate_mlp.unsqueeze(1) + mod = self.adaLN_modulation(mod) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(6, dim=-1) h = torch.addcmul(shift_msa, self.norm1(x), 1 + scale_msa) h = self.self_attn(h, phases=phases, transformer_options=transformer_options) @@ -694,14 +592,10 @@ class SparseStructureFlowModel(nn.Module): num_heads: Optional[int] = None, num_head_channels: Optional[int] = 64, mlp_ratio: float = 4, - pe_mode: Literal["ape", "rope"] = "rope", - rope_freq: Tuple[float, float] = (1.0, 10000.0), - use_checkpoint: bool = False, share_mod: bool = False, - initialization: str = 'vanilla', qk_rms_norm: bool = False, qk_rms_norm_cross: bool = False, - image_attn_mode: Literal["global", "proj", "gated_proj"] = "global", + image_attn_mode: Literal["global", "proj"] = "global", proj_in_channels: Optional[int] = None, operations=None, device = None, @@ -718,10 +612,7 @@ class SparseStructureFlowModel(nn.Module): self.num_blocks = num_blocks self.num_heads = num_heads or model_channels // num_head_channels self.mlp_ratio = mlp_ratio - self.pe_mode = pe_mode - self.use_checkpoint = use_checkpoint self.share_mod = share_mod - self.initialization = initialization self.qk_rms_norm = qk_rms_norm self.qk_rms_norm_cross = qk_rms_norm_cross self.image_attn_mode = image_attn_mode @@ -742,9 +633,6 @@ class SparseStructureFlowModel(nn.Module): rope_phases = pos_embedder(coords) self.register_buffer("rope_phases", rope_phases, persistent=False) - if pe_mode != "rope": - self.rope_phases = None - self.input_layer = operations.Linear(in_channels, model_channels, device=device, dtype=dtype) self.blocks = nn.ModuleList([ @@ -753,10 +641,6 @@ class SparseStructureFlowModel(nn.Module): cond_channels, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, - attn_mode='full', - use_checkpoint=self.use_checkpoint, - use_rope=(pe_mode == "rope"), - rope_freq=rope_freq, share_mod=share_mod, qk_rms_norm=self.qk_rms_norm, qk_rms_norm_cross=self.qk_rms_norm_cross, @@ -788,18 +672,10 @@ class SparseStructureFlowModel(nn.Module): return h -def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0): - t_shifted = t_shifted / 1000.0 - t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1)) - t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear) - t_new *= 1000.0 - return t_new - -# Pixal3D ProjGrid math — port of upstream's ProjGrid + project_points_to_image_batch. -# World frame uses world Y as depth, camera looks along -Z local; -# transform_matrix is camera-to-world (inverted internally). Intrinsics: fx = 16 / tan(fov/2) -# with sensor_width = 32mm. +# Pixal3D ProjGrid math +# World frame uses world Y as depth, camera looks along -Z local +# transform_matrix is camera-to-world (inverted internally). Intrinsics: fx = 16 / tan(fov/2) with sensor_width = 32mm. _PROJ_GRID_ROTATION = torch.tensor( [[1.0, 0.0, 0.0], @@ -815,7 +691,7 @@ _PROJ_FRONT_VIEW_TRANSFORM = torch.tensor( ) -def _build_proj_transform_matrix(distance: torch.Tensor, batch_size: int, +def build_proj_transform_matrix(distance: torch.Tensor, batch_size: int, device, dtype=torch.float32) -> torch.Tensor: T = _PROJ_FRONT_VIEW_TRANSFORM.to(device=device, dtype=dtype) T = T.unsqueeze(0).expand(batch_size, -1, -1).clone() @@ -849,8 +725,7 @@ def _project_points_to_image(points_world: torch.Tensor, transform_matrix: torch def _sample_features(feature_map: torch.Tensor, uv_ndc: torch.Tensor) -> torch.Tensor: B, C, _, _ = feature_map.shape grid = uv_ndc.view(B, -1, 1, 2).to(feature_map.dtype) - feat = F.grid_sample(feature_map, grid, mode="bilinear", - padding_mode="border", align_corners=False) + feat = F.grid_sample(feature_map, grid, mode="bilinear", padding_mode="border", align_corners=False) return feat.squeeze(-1) @@ -922,8 +797,6 @@ def _back_project_to_tokens( return out - - def _select_stage_entry(proj_pack: dict, stage: Optional[str]): """Returns (feature_map_lr, feature_map_hr_or_None, image_resolution).""" stages = proj_pack.get("stages") @@ -1025,7 +898,7 @@ def _shape_proj_cond(global_cond: torch.Tensor, image_attn_mode: str, else: proj_feats = proj_feats.repeat((repeats, 1)) - # Mirror upstream's neg_cond by zeroing proj for any uncond batch slot. + # zero proj for any uncond batch slot if cond_or_uncond is not None and eval_batch is not None: uncond_slots = [i for i, v in enumerate(cond_or_uncond) if v == 1] if uncond_slots: @@ -1100,9 +973,6 @@ class Trellis2(nn.Module): # Pre-computed per-stage back-projected features proj_feats = kwargs.get("trellis2_proj_feats") - sampling_preview.set_context(mode=mode, coords=coords, coord_counts=coord_counts, - model_frame=kwargs.get("trellis2_model_frame")) - is_first_shape_pass = False if mode == "shape_generation_512": is_first_shape_pass = True diff --git a/latent_preview.py b/latent_preview.py index d97c538fe..a9d777661 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -75,153 +75,10 @@ class Latent2RGBPreviewer(LatentPreviewer): return preview_to_image(latent_image) -class Trellis3DPreviewer(LatentPreviewer): - """Per-step preview for the Trellis2/Pixal3D cascade. - - Structure stage: x0 is a dense [B, 32, 16, 16, 16] grid — project the per-cell - activation norm orthographically to a 2D occupancy heatmap (no decode, no coords). - Texture stage: x0 is sparse [B, 32, N, 1] — splat the first 3 latent channels as - pseudo-color onto the fixed voxel coords (read from the sampling side-channel). - Shape stage adds no visible motion (coords are fixed, only sub-voxel detail - evolves) and a full decode per step is too costly, so it's skipped. - - Both stages render through one orthographic point splatter (static view). - """ - _SIZE = 128 - _FILL = 0.9 # fraction of frame the texture splat spans (leaves a border) - _STRUCTURE_ZOOM = 0.66 # <1 pulls the SS camera back, leaving margin around the blob - - def _splat(self, points, colors, rad): - # points: [K, 3] voxel-index coords. colors: [K, 3] in [0, 1]. - # Center + isotropic-normalize, project orthographically front-on - # (x->horizontal, y->up, z->depth), then splat a square footprint per point - # with one global far->near sort (painter's). - S = self._SIZE - dev = points.device # keep every tensor here - p = points.float() - p = p - (p.amax(0) + p.amin(0)) * 0.5 - p = p / p.abs().amax().clamp(min=1e-8) - x, y, z = p[:, 0], p[:, 1], p[:, 2] - depth = z # into-screen - m = self._FILL - u = ((x * m * 0.5 + 0.5) * (S - 1)).long().clamp(0, S - 1) - v = (((-y) * m * 0.5 + 0.5) * (S - 1)).long().clamp(0, S - 1) # image up = +y - cols = colors.to(dev) - us, vs, ds, cs = [], [], [], [] - for dv in range(-rad, rad + 1): - for du in range(-rad, rad + 1): - us.append((u + du).clamp(0, S - 1)) - vs.append((v + dv).clamp(0, S - 1)) - ds.append(depth) - cs.append(cols) - order = torch.cat(ds).argsort() - img = torch.zeros(S, S, 3, device=dev) - img[torch.cat(vs)[order], torch.cat(us)[order]] = torch.cat(cs)[order] - return preview_to_image(img, do_scale=False) - - @staticmethod - def _turbo(x): - # Anton Mikhailov polynomial approximation of the turbo colormap. x: any shape - # in [0, 1] -> (..., 3) RGB. - x = x.clamp(0.0, 1.0) - x2 = x * x; x3 = x2 * x; x4 = x2 * x2; x5 = x4 * x - r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5 - g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5 - b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5 - return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0) - - def _structure(self, x0): - # x0: [B, 32, D, H, W]; the model only consumes the first 8 channels. - # Dense orthographic max-projection -> filled occupancy heatmap (turbo-colored, - # intensity-weighted so empty space stays black). - act = x0[0, :min(8, x0.shape[1])].float().norm(dim=0) # [D, H, W] - proj = act.amax(dim=2) # project along one axis - proj = (proj - proj.amin()) / (proj.amax() - proj.amin() + 1e-8) - inner = max(1, int(round(self._SIZE * self._STRUCTURE_ZOOM))) - img = torch.nn.functional.interpolate(proj[None, None], size=(inner, inner), mode="nearest") - pad = self._SIZE - inner - pl, pt = pad // 2, pad // 2 - gray = torch.nn.functional.pad(img, (pl, pad - pl, pt, pad - pt))[0, 0] # [S, S], zero margin - rgb = self._turbo(gray) * gray.unsqueeze(-1) # [S, S, 3], black where empty - return preview_to_image(rgb, do_scale=False) - - @staticmethod - def _latent_color(latent): - # Prefer the calibrated latent->base_color map (fit from real decoded - # albedo by VaeDecodeTextureTrellis); fall back to PCA pseudo-color until - # a texture decode has trained it. - try: - from comfy.ldm.trellis2 import sampling_preview - factors = sampling_preview.get_tex_rgb() - except Exception: - factors = None - if factors is not None: - W, b = factors - rgb = latent @ W.to(latent) + b.to(latent) - return rgb.clamp(0, 1) - return Trellis3DPreviewer._pca_color(latent) - - @staticmethod - def _pca_color(latent): - # latent: [n, C]. Map the 3 directions of maximum variance to RGB. - # Higher contrast and more coherent than picking 3 fixed channels. - X = latent - latent.mean(dim=0, keepdim=True) - cov = (X.transpose(0, 1) @ X) / max(X.shape[0] - 1, 1) # [C, C] - _, evecs = torch.linalg.eigh(cov) # ascending eigenvalues - pcs = evecs[:, -3:] # [C, 3] top-3 components - # Deterministic sign per component (largest-magnitude entry positive) to - # stop the preview's hues from flickering as the latent rotates each step. - sign = torch.sign(pcs[pcs.abs().argmax(dim=0), torch.arange(3, device=pcs.device)]) - pcs = pcs * sign.clamp(min=-1.0) - proj = X @ pcs # [n, 3] - pmin = proj.amin(dim=0, keepdim=True) - pmax = proj.amax(dim=0, keepdim=True) - return ((proj - pmin) / (pmax - pmin + 1e-8)).clamp(0, 1) - - def _texture(self, x0, coords, model_frame=None): - if coords.shape[-1] == 4: - b0 = coords[:, 0] == 0 - spatial = coords[b0][:, 1:4].float() - else: - spatial = coords[:, :3].float() - n0 = spatial.shape[0] - if n0 == 0: - return None - if model_frame == "z_up": - spatial = torch.stack([spatial[:, 0], spatial[:, 2], -spatial[:, 1]], dim=-1) - latent = x0[0, :, :n0, 0].float().transpose(0, 1) # [n0, C] - colors = self._latent_color(latent) # [n0, 3] - res = float(spatial.abs().max().item()) + 1.0 - rad = max(1, int(round(self._SIZE * self._FILL / max(res, 1) / 2))) - return self._splat(spatial, colors, rad) - - def decode_latent_to_preview(self, x0): - try: - from comfy.ldm.trellis2 import sampling_preview - ctx = sampling_preview.get_context() - if x0.ndim == 5: - return self._structure(x0) - mode = ctx.get("mode") - coords = ctx.get("coords") - if mode == "texture_generation" and coords is not None: - return self._texture(x0, coords, model_frame=ctx.get("model_frame")) - except Exception as e: - logging.debug(f"Trellis3DPreviewer: skipping preview ({e})") - return None - - def decode_latent_to_preview_image(self, preview_format, x0): - preview_image = self.decode_latent_to_preview(x0) - if preview_image is None: - return None - return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION) - - def get_previewer(device, latent_format): previewer = None method = args.preview_method if method != LatentPreviewMethod.NoPreviews: - if getattr(latent_format, "trellis3d_preview", False): - return Trellis3DPreviewer() # TODO previewer methods taesd_decoder_path = None if latent_format.taesd_decoder_name is not None: