From 6ea2e5b288d14eac984ad38499fd76aa1f9295c7 Mon Sep 17 00:00:00 2001 From: "Yousef R. Gamaleldin" Date: Fri, 30 Jan 2026 23:34:48 +0200 Subject: [PATCH 01/93] init --- comfy/image_encoders/dino3.py | 240 +++++ comfy/image_encoders/dino3_large.json | 24 + comfy/ldm/trellis2/attention.py | 194 ++++ comfy/ldm/trellis2/cumesh.py | 149 ++++ comfy/ldm/trellis2/model.py | 499 +++++++++++ comfy/ldm/trellis2/vae.py | 1185 +++++++++++++++++++++++++ comfy_extras/trellis2.py | 240 +++++ 7 files changed, 2531 insertions(+) create mode 100644 comfy/image_encoders/dino3.py create mode 100644 comfy/image_encoders/dino3_large.json create mode 100644 comfy/ldm/trellis2/attention.py create mode 100644 comfy/ldm/trellis2/cumesh.py create mode 100644 comfy/ldm/trellis2/model.py create mode 100644 comfy/ldm/trellis2/vae.py create mode 100644 comfy_extras/trellis2.py diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py new file mode 100644 index 000000000..d07c2c5b8 --- /dev/null +++ b/comfy/image_encoders/dino3.py @@ -0,0 +1,240 @@ +import math +import torch +import torch.nn as nn + +from comfy.ldm.modules.attention import optimized_attention_for_device +from comfy.ldm.flux.math import apply_rope +from dino2 import Dinov2MLP as DINOv3ViTMLP, LayerScale as DINOv3ViTLayerScale + +class DINOv3ViTAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, device, dtype, operations): + super().__init__() + self.embed_dim = hidden_size + self.num_heads = num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.is_causal = False + + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False + self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + + self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + 
class DINOv3ViTAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings.

    The key projection is created without a bias; the query, value and output
    projections all carry one.
    """

    def __init__(self, hidden_size, num_attention_heads, device, dtype, operations):
        super().__init__()
        self.embed_dim = hidden_size
        self.num_heads = num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.scaling = self.head_dim**-0.5
        self.is_causal = False

        factory = dict(device=device, dtype=dtype)
        # Creation order is kept: k, v, q, o.
        self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, **factory)  # key_bias = False
        self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, **factory)

        self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, **factory)
        self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, **factory)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Project, rotate and attend over ``hidden_states``.

        Args:
            hidden_states: tensor unpacked as ``(batch, patches, channels)``.
            attention_mask: forwarded unchanged to the attention helper.
            position_embeddings: ``(cos, sin)`` pair; it is unpacked
                unconditionally, so ``None`` is not actually accepted.

        Returns:
            ``(attn_output, attn_weights)`` where ``attn_output`` went through
            ``o_proj``.
        """
        batch_size, patches, _ = hidden_states.size()

        def split_heads(projected):
            # (batch, patches, channels) -> (batch, heads, patches, head_dim)
            return projected.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)

        query_states = split_heads(self.q_proj(hidden_states))
        key_states = split_heads(self.k_proj(hidden_states))
        value_states = split_heads(self.v_proj(hidden_states))

        cos, sin = position_embeddings
        rope_table = torch.stack([cos, sin], dim=-1)
        query_states, key_states = apply_rope(query_states, key_states, rope_table)

        # NOTE(review): `optimized_attention_for_device` is called directly with
        # the q/k/v tensors and its result is unpacked as a pair -- confirm this
        # matches the helper's real signature and return value.
        attn_output, attn_weights = optimized_attention_for_device(
            query_states, key_states, value_states, attention_mask, skip_reshape=True, skip_output_reshape=True
        )

        # NOTE(review): the heads axis is not transposed back before this
        # reshape -- verify the layout produced with skip_output_reshape=True.
        merged = attn_output.reshape(batch_size, patches, -1).contiguous()
        return self.o_proj(merged), attn_weights


class DINOv3ViTGatedMLP(nn.Module):
    """Gated feed-forward block: ``down(GELU(gate(x)) * up(x))``."""

    def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        factory = dict(bias=mlp_bias, device=device, dtype=dtype)
        self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, **factory)
        self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, **factory)
        self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, **factory)
        self.act_fn = torch.nn.GELU()

    def forward(self, x):
        """Return the gated projection of ``x`` back to ``hidden_size``."""
        gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
def get_patches_center_coordinates(
    num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """Return the centre of every patch of a grid, scaled to ``[-1, 1]``.

    Args:
        num_patches_h: number of patch rows.
        num_patches_w: number of patch columns.
        dtype: dtype of the returned coordinates.
        device: device of the returned coordinates.

    Returns:
        Tensor of shape ``(num_patches_h * num_patches_w, 2)`` holding
        ``(row, column)`` centres in row-major order.
    """
    # arange(0.5, n) yields 0.5, 1.5, ... i.e. the patch centres in patch units.
    coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device) / num_patches_h
    coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device) / num_patches_w
    coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
    coords = coords.flatten(0, 1)
    # [0, 1] -> [-1, 1]
    return 2.0 * coords - 1.0


class DINOv3ViTRopePositionEmbedding(nn.Module):
    """Builds the rotary ``(cos, sin)`` tables for the patch grid of an image."""

    inv_freq: torch.Tensor

    def __init__(self, rope_theta, hidden_size, num_attention_heads, image_size, patch_size, device, dtype):
        super().__init__()
        self.base = rope_theta
        self.head_dim = hidden_size // num_attention_heads
        # forward() derives the patch grid from the input resolution, so the
        # patch size has to be kept on the module (it was previously missing,
        # which made forward() raise AttributeError).
        self.patch_size = patch_size
        self.num_patches_h = image_size // patch_size
        self.num_patches_w = image_size // patch_size

        # head_dim / 4 frequencies: each is used for both axes and for the two
        # halves of the head dimension (see the flatten + tile in forward()).
        inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute ``(cos, sin)`` for every patch of ``pixel_values``.

        Args:
            pixel_values: tensor unpacked as ``(batch, channels, height, width)``.

        Returns:
            Two tensors of shape ``(num_patches, head_dim)`` cast to the dtype
            of ``pixel_values``.
        """
        _, _, height, width = pixel_values.shape
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size

        device = pixel_values.device
        device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"
        # Angles are always computed in float32 with autocast disabled.
        with torch.amp.autocast(device_type=device_type, enabled=False):
            patch_coords = get_patches_center_coordinates(
                num_patches_h, num_patches_w, dtype=torch.float32, device=device
            )

            angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
            angles = angles.flatten(1, 2)
            angles = angles.tile(2)

            cos = torch.cos(angles)
            sin = torch.sin(angles)

        dtype = pixel_values.dtype
        return cos.to(dtype=dtype), sin.to(dtype=dtype)
DINOv3ViTEmbeddings(nn.Module): + def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations): + super().__init__() + self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_size, device=device, dtype=dtype)) + self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size, device=device, dtype=dtype)) + self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype)) + self.patch_embeddings = operations.Conv2d( + num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype + ) + + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None): + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embeddings.weight.dtype + + patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2) + + if bool_masked_pos is not None: + mask_token = self.mask_token.to(patch_embeddings.dtype) + patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings) + + cls_token = self.cls_token.expand(batch_size, -1, -1) + register_tokens = self.register_tokens.expand(batch_size, -1, -1) + embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1) + + return embeddings + +class DINOv3ViTLayer(nn.Module): + + def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, layerscale_value, mlp_bias, intermediate_size, num_attention_heads, + device, dtype, operations): + super().__init__() + + self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps) + self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations) + self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, layerscale_value, device=device, dtype=dtype) + + self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype) + + if use_gated_mlp: + 
class DINOv3ViTLayer(nn.Module):
    """Pre-norm transformer block: attention and MLP, each scaled and added to a residual."""

    def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, layerscale_value, mlp_bias, intermediate_size, num_attention_heads,
                 device, dtype, operations):
        super().__init__()

        # norm1 now receives device/dtype exactly like norm2; previously it was
        # the only layer of the block created without them.
        self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
        self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations)
        # NOTE(review): layerscale_value is passed positionally to the class
        # imported from dino2 -- confirm it matches that constructor's signature.
        self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, layerscale_value, device=device, dtype=dtype)

        self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)

        if use_gated_mlp:
            self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations)
        else:
            self.mlp = DINOv3ViTMLP(hidden_size, device=device, dtype=dtype, operations=operations)
        self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, layerscale_value, device=device, dtype=dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks and return the updated hidden states."""
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        # The attention weights returned alongside the output are discarded.
        hidden_states, _ = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        hidden_states = self.layer_scale1(hidden_states)
        hidden_states = hidden_states + residual

        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.layer_scale2(hidden_states)
        hidden_states = hidden_states + residual

        return hidden_states


class DINOv3ViTModel(nn.Module):
    """DINOv3 vision transformer: patch embedding, rotary positions, layer stack, final norm."""

    def __init__(self, config, device, dtype, operations):
        super().__init__()
        num_hidden_layers = config["num_hidden_layers"]
        hidden_size = config["hidden_size"]
        num_attention_heads = config["num_attention_heads"]
        num_register_tokens = config["num_register_tokens"]
        intermediate_size = config["intermediate_size"]
        layer_norm_eps = config["layer_norm_eps"]
        layerscale_value = config["layerscale_value"]
        num_channels = config["num_channels"]
        patch_size = config["patch_size"]
        rope_theta = config["rope_theta"]
        # Honour the config when it carries these switches; the defaults are the
        # values that used to be hard-coded here.
        use_gated_mlp = config.get("use_gated_mlp", False)
        mlp_bias = config.get("mlp_bias", True)

        self.embeddings = DINOv3ViTEmbeddings(
            hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size, dtype=dtype, device=device, operations=operations
        )
        # NOTE(review): image_size stays fixed at 512 rather than being read
        # from the config -- confirm that is intended.
        self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
            rope_theta, hidden_size, num_attention_heads, image_size=512, patch_size=patch_size, dtype=dtype, device=device
        )
        self.layer = nn.ModuleList(
            [DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=use_gated_mlp, layerscale_value=layerscale_value, mlp_bias=mlp_bias,
                            intermediate_size=intermediate_size, num_attention_heads=num_attention_heads,
                            dtype=dtype, device=device, operations=operations)
             for _ in range(num_hidden_layers)])
        # Built through `operations` like every other layer of the model.
        self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)

    def get_input_embeddings(self):
        """Return the patch-embedding convolution."""
        return self.embeddings.patch_embeddings

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: torch.Tensor | None = None,
        **kwargs,
    ):
        """Encode an image batch.

        Returns:
            ``(sequence_output, None, pooled_output, None)`` where
            ``pooled_output`` is the first token of the normalised sequence.
        """
        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
        hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        position_embeddings = self.rope_embeddings(pixel_values)

        for layer_module in self.layer:
            hidden_states = layer_module(
                hidden_states,
                position_embeddings=position_embeddings,
            )

        sequence_output = self.norm(hidden_states)
        pooled_output = sequence_output[:, 0, :]

        return sequence_output, None, pooled_output, None
nn.ModuleList( + [DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=False, layerscale_value=layerscale_value, mlp_bias=True, + intermediate_size=intermediate_size,num_attention_heads = num_attention_heads, + dtype=dtype, device=device, operations=operations) + for _ in range(num_hidden_layers)]) + self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device) + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: torch.Tensor | None = None, + **kwargs, + ): + + pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype) + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + position_embeddings = self.rope_embeddings(pixel_values) + + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module( + hidden_states, + position_embeddings=position_embeddings, + ) + + sequence_output = self.norm(hidden_states) + pooled_output = sequence_output[:, 0, :] + + return sequence_output, None, pooled_output, None diff --git a/comfy/image_encoders/dino3_large.json b/comfy/image_encoders/dino3_large.json new file mode 100644 index 000000000..96263f0d6 --- /dev/null +++ b/comfy/image_encoders/dino3_large.json @@ -0,0 +1,24 @@ +{ + + "hidden_size": 384, + "image_size": 224, + "initializer_range": 0.02, + "intermediate_size": 1536, + "key_bias": false, + "layer_norm_eps": 1e-05, + "layerscale_value": 1.0, + "mlp_bias": true, + "num_attention_heads": 6, + "num_channels": 3, + "num_hidden_layers": 12, + "num_register_tokens": 4, + "patch_size": 16, + "pos_embed_rescale": 2.0, + "proj_bias": true, + "query_bias": true, + "rope_theta": 100.0, + "use_gated_mlp": false, + "value_bias": true, + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225] +} diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py new file mode 100644 index 000000000..9cd7d4995 --- /dev/null +++ 
def _windowed_attention_torch(qkv_feats: torch.Tensor, seq_lens) -> torch.Tensor:
    """Backend-agnostic attention over consecutive windows of a packed qkv tensor.

    Args:
        qkv_feats: ``[M, 3, H, C]`` features already sorted so that every
            window occupies a contiguous run of rows.
        seq_lens: length of each run, as a tensor or a sequence of ints.

    Returns:
        ``[M, H, C]`` attention output in the same (sorted) row order.
    """
    lengths = seq_lens.tolist() if torch.is_tensor(seq_lens) else list(seq_lens)
    outputs = []
    start = 0
    for length in lengths:
        q, k, v = qkv_feats[start:start + length].unbind(dim=1)  # each [n, H, C]
        # SDPA wants [..., seq, channels]; the heads axis acts as the batch axis.
        attended = torch.nn.functional.scaled_dot_product_attention(
            q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1)
        )
        outputs.append(attended.transpose(0, 1))
        start += length
    return torch.cat(outputs, dim=0)


def sparse_windowed_scaled_dot_product_self_attention(
    qkv,
    window_size: int,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
):
    """Self-attention restricted to (optionally shifted) spatial windows.

    The window partition is computed once per ``(window_size, shift_window)``
    and stored in the spatial cache of ``qkv``.

    Args:
        qkv: sparse tensor whose ``feats`` are ``[T, 3, H, C]``.
        window_size: window edge length.
        shift_window: per-axis shift applied before partitioning.

    Returns:
        ``qkv.replace(out)`` with ``out`` of shape ``[T, H, C]`` in the
        original token order.
    """
    cache_name = f'windowed_attention_{window_size}_{shift_window}'
    partition = qkv.get_spatial_cache(cache_name)
    if partition is None:
        partition = calc_window_partition(qkv, window_size, shift_window)
        qkv.register_spatial_cache(cache_name, partition)
    fwd_indices, bwd_indices, seq_lens, attn_func_args = partition

    qkv_feats = qkv.feats[fwd_indices]      # [M, 3, H, C]

    backend = optimized_attention.__name__
    if backend == 'attention_xformers':
        # A function-level import binds a local name, so it must not be guarded
        # by a check on globals(); re-importing is a cheap cache lookup.
        import xformers.ops as xops
        q, k, v = qkv_feats.unbind(dim=1)
        q = q.unsqueeze(0)                  # [1, M, H, C]
        k = k.unsqueeze(0)                  # [1, M, H, C]
        v = v.unsqueeze(0)                  # [1, M, H, C]
        out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0]    # [M, H, C]
    elif backend == 'attention_flash':
        import flash_attn
        out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args)  # [M, H, C]
    else:
        # Any other backend previously left `out` unbound; fall back to a
        # per-window loop over torch's scaled_dot_product_attention.
        out = _windowed_attention_torch(qkv_feats, seq_lens)

    out = out[bwd_indices]      # [T, H, C]

    return qkv.replace(out)


def calc_window_partition(
    tensor,
    window_size: Union[int, Tuple[int, ...]],
    shift_window: Union[int, Tuple[int, ...]] = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
    """Group the tokens of a sparse tensor into spatial windows.

    Args:
        tensor: object exposing ``coords`` (``[T, 1 + DIM]``, first column
            followed by the spatial coordinates), ``spatial_shape`` and ``device``.
        window_size: window edge length, one value or one per axis.
        shift_window: shift added to the coordinates, one value or one per axis.

    Returns:
        ``(fwd_indices, bwd_indices, seq_lens, attn_func_args)``:
        ``fwd_indices`` sorts tokens by window, ``bwd_indices`` is its inverse
        permutation, ``seq_lens`` holds the size of every non-empty window and
        ``attn_func_args`` carries backend-specific keyword arguments (empty
        for backends other than xformers / flash-attn).
    """
    DIM = tensor.coords.shape[1] - 1
    shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
    window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
    shifted_coords = tensor.coords.clone().detach()
    shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)

    MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)]
    NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
    # Strides that linearise (first column, window index per axis) into one id.
    OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]

    shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
    shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
    fwd_indices = torch.argsort(shifted_indices)
    bwd_indices = torch.empty_like(fwd_indices)
    bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
    seq_lens = torch.bincount(shifted_indices)
    seq_lens = seq_lens[seq_lens != 0]  # drop empty windows

    backend = optimized_attention.__name__
    if backend == 'attention_xformers':
        import xformers.ops as xops
        attn_func_args = {
            # from_seqlens takes a sequence of Python ints, not a tensor
            'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens.tolist())
        }
    elif backend == 'attention_flash':
        attn_func_args = {
            'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(),
            'max_seqlen': int(seq_lens.max())
        }
    else:
        # No backend-specific arguments; previously this name stayed unbound.
        attn_func_args = {}

    return fwd_indices, bwd_indices, seq_lens, attn_func_args
def _varlen_attention_torch(q, k, v, q_seqlen, kv_seqlen):
    """Backend-agnostic attention over packed variable-length sequences.

    Args:
        q: ``[T_Q, H, Ci]`` packed queries.
        k: ``[T_KV, H, Ci]`` packed keys.
        v: ``[T_KV, H, Co]`` packed values.
        q_seqlen: length of every query sequence.
        kv_seqlen: length of every key/value sequence.

    Returns:
        ``[T_Q, H, Co]`` packed attention output.
    """
    outputs = []
    q_start = kv_start = 0
    for q_len, kv_len in zip(q_seqlen, kv_seqlen):
        # SDPA wants [..., seq, channels]; the heads axis acts as the batch axis.
        q_i = q[q_start:q_start + q_len].transpose(0, 1)
        k_i = k[kv_start:kv_start + kv_len].transpose(0, 1)
        v_i = v[kv_start:kv_start + kv_len].transpose(0, 1)
        attended = torch.nn.functional.scaled_dot_product_attention(q_i, k_i, v_i)
        outputs.append(attended.transpose(0, 1))
        q_start += q_len
        kv_start += kv_len
    return torch.cat(outputs, dim=0)


def sparse_scaled_dot_product_attention(*args, **kwargs):
    """Scaled dot-product attention over variable-length or dense inputs.

    Accepted call forms (positionally or by keyword):
        * ``(qkv)``     -- packed ``VarLenTensor`` with feats ``[T, 3, H, C]``
        * ``(q, kv)``   -- ``q`` and packed ``kv``; each a ``VarLenTensor`` or dense
        * ``(q, k, v)`` -- separate tensors; each a ``VarLenTensor`` or dense

    Dense tensors are ``[N, L, H, C]`` (``[N, L, 2, H, C]`` for ``kv``).

    Returns:
        ``q.replace(out)`` (or ``qkv.replace(out)``) when the query is a
        ``VarLenTensor``, otherwise a dense ``[N, L, H, C]`` tensor shaped
        after the query.
    """
    arg_names_dict = {
        1: ['qkv'],
        2: ['q', 'kv'],
        3: ['q', 'k', 'v']
    }
    num_all_args = len(args) + len(kwargs)
    for key in arg_names_dict[num_all_args][len(args):]:
        assert key in kwargs, f"Missing argument {key}"

    def layout_lengths(t):
        return [t.layout[i].stop - t.layout[i].start for i in range(t.shape[0])]

    # (N, L, H) of a dense query. Kept separately because the dense key/value
    # branches unpack their own N/L/H, which used to clobber the query's shape
    # and corrupt the final reshape when query and key lengths differed.
    dense_q_shape = None

    if num_all_args == 1:
        qkv = args[0] if len(args) > 0 else kwargs['qkv']
        device = qkv.device

        s = qkv
        q_seqlen = layout_lengths(qkv)
        kv_seqlen = q_seqlen
        qkv = qkv.feats     # [T, 3, H, C]

    elif num_all_args == 2:
        q = args[0] if len(args) > 0 else kwargs['q']
        kv = args[1] if len(args) > 1 else kwargs['kv']
        device = q.device

        if isinstance(q, VarLenTensor):
            s = q
            q_seqlen = layout_lengths(q)
            q = q.feats     # [T_Q, H, C]
        else:
            s = None
            N, L, H, C = q.shape
            dense_q_shape = (N, L, H)
            q_seqlen = [L] * N
            q = q.reshape(N * L, H, C)   # [T_Q, H, C]

        if isinstance(kv, VarLenTensor):
            kv_seqlen = layout_lengths(kv)
            kv = kv.feats     # [T_KV, 2, H, C]
        else:
            N, L, _, H, C = kv.shape
            kv_seqlen = [L] * N
            kv = kv.reshape(N * L, 2, H, C)   # [T_KV, 2, H, C]

    elif num_all_args == 3:
        q = args[0] if len(args) > 0 else kwargs['q']
        k = args[1] if len(args) > 1 else kwargs['k']
        v = args[2] if len(args) > 2 else kwargs['v']
        device = q.device

        if isinstance(q, VarLenTensor):
            s = q
            q_seqlen = layout_lengths(q)
            q = q.feats     # [T_Q, H, Ci]
        else:
            s = None
            N, L, H, CI = q.shape
            dense_q_shape = (N, L, H)
            q_seqlen = [L] * N
            q = q.reshape(N * L, H, CI)  # [T_Q, H, Ci]

        if isinstance(k, VarLenTensor):
            kv_seqlen = layout_lengths(k)
            k = k.feats     # [T_KV, H, Ci]
            v = v.feats     # [T_KV, H, Co]
        else:
            N, L, H, CI, CO = *k.shape, v.shape[-1]
            kv_seqlen = [L] * N
            k = k.reshape(N * L, H, CI)     # [T_KV, H, Ci]
            v = v.reshape(N * L, H, CO)     # [T_KV, H, Co]

    def cumulative(lengths):
        return torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(lengths), dim=0)]).int().to(device)

    backend = optimized_attention.__name__
    if backend == 'attention_xformers':
        # Function-level imports bind local names, so they are done
        # unconditionally instead of being guarded by a globals() check.
        import xformers.ops as xops
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=1)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=1)
        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)
        mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
        out = xops.memory_efficient_attention(q, k, v, mask)[0]
    elif backend == 'attention_flash':
        import flash_attn
        cu_seqlens_q = cumulative(q_seqlen)
        if num_all_args in [2, 3]:
            cu_seqlens_kv = cumulative(kv_seqlen)
        if num_all_args == 1:
            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
        elif num_all_args == 2:
            out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
        elif num_all_args == 3:
            out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
    elif backend == 'flash_attn_3': # TODO
        import flash_attn_interface as flash_attn_3
        cu_seqlens_q = cumulative(q_seqlen)
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=1)
            cu_seqlens_kv = cu_seqlens_q.clone()
            max_q_seqlen = max_kv_seqlen = max(q_seqlen)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=1)
            cu_seqlens_kv = cumulative(kv_seqlen)
            max_q_seqlen = max(q_seqlen)
            max_kv_seqlen = max(kv_seqlen)
        elif num_all_args == 3:
            cu_seqlens_kv = cumulative(kv_seqlen)
            max_q_seqlen = max(q_seqlen)
            max_kv_seqlen = max(kv_seqlen)
        out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q_seqlen, max_kv_seqlen)
    else:
        # Any other backend previously left `out` unbound; fall back to a
        # per-sequence loop over torch's scaled_dot_product_attention.
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=1)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=1)
        out = _varlen_attention_torch(q, k, v, q_seqlen, kv_seqlen)

    if s is not None:
        return s.replace(out)
    N, L, H = dense_q_shape
    return out.reshape(N, L, H, -1)
class TorchHashMap:
    """Read-only integer key -> value map backed by sorted tensors and binary search."""

    def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
        device = keys.device
        # use long for searchsorted
        self.sorted_keys, order = torch.sort(keys.long())
        self.sorted_vals = values.long()[order]
        self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
        # Plain Python copy used as fill value for missing keys.
        self._default_scalar = int(default_value)
        self._n = self.sorted_keys.numel()

    def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
        """Look up a 1-D tensor of keys.

        Returns:
            A long tensor of the same length holding the stored value for
            every key that is present and ``default_value`` for the others.
        """
        flat = flat_keys.long()
        out = torch.full((flat.shape[0],), self._default_scalar, device=flat.device, dtype=self.sorted_vals.dtype)
        if self._n == 0:
            # Nothing stored: every key is a miss (indexing would be invalid).
            return out
        idx = torch.searchsorted(self.sorted_keys, flat)
        # searchsorted returns `_n` for keys above the largest stored key;
        # clamp before indexing so such keys become ordinary misses instead of
        # raising an out-of-bounds error.
        idx = idx.clamp(max=self._n - 1)
        found = self.sorted_keys[idx] == flat
        return torch.where(found, self.sorted_vals[idx], out)


class Voxel:
    """Sparse voxel grid: integer coordinates plus per-voxel attributes."""

    def __init__(
        self,
        origin: list,
        voxel_size: float,
        coords: torch.Tensor = None,
        attrs: torch.Tensor = None,
        layout: Dict = None,
        device: torch.device = 'cuda'
    ):
        self.origin = torch.tensor(origin, dtype=torch.float32, device=device)
        self.voxel_size = voxel_size
        self.coords = coords
        self.attrs = attrs
        # A fresh dict per instance; a `{}` default would be shared by all of them.
        self.layout = layout if layout is not None else {}
        self.device = device

    @property
    def position(self):
        """Centre of every voxel: ``(coords + 0.5) * voxel_size + origin``."""
        return (self.coords + 0.5) * self.voxel_size + self.origin[None, :]

    def split_attrs(self):
        """Return ``{name: attrs[:, layout[name]]}`` for every entry of ``layout``."""
        return {
            k: self.attrs[:, self.layout[k]]
            for k in self.layout
        }


class Mesh:
    """Triangle mesh with float vertices, int faces and optional vertex attributes."""

    def __init__(self,
        vertices,
        faces,
        vertex_attrs=None
    ):
        self.vertices = vertices.float()
        self.faces = faces.int()
        self.vertex_attrs = vertex_attrs

    @property
    def device(self):
        """Device of the vertex tensor."""
        return self.vertices.device

    def to(self, device, non_blocking=False):
        """Return a new ``Mesh`` with all tensors moved to ``device``."""
        return Mesh(
            self.vertices.to(device, non_blocking=non_blocking),
            self.faces.to(device, non_blocking=non_blocking),
            self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None,
        )

    def cuda(self, non_blocking=False):
        return self.to('cuda', non_blocking=non_blocking)

    def cpu(self):
        return self.to('cpu')

    # TODO could be an option
    def fill_holes(self, max_hole_perimeter=3e-2):
        """Fill boundary loops in place using ``cumesh``.

        Returns ``None``; the mesh is left untouched when it has no boundary
        or no boundary loop.
        """
        import cumesh
        target_device = self.device
        vertices = self.vertices.cuda()
        faces = self.faces.cuda()

        mesh = cumesh.CuMesh()
        mesh.init(vertices, faces)
        mesh.get_edges()
        mesh.get_boundary_info()
        if mesh.num_boundaries == 0:
            return
        mesh.get_vertex_edge_adjacency()
        mesh.get_vertex_boundary_adjacency()
        mesh.get_manifold_boundary_adjacency()
        mesh.read_manifold_boundary_adjacency()
        mesh.get_boundary_connected_components()
        mesh.get_boundary_loops()
        if mesh.num_boundary_loops == 0:
            return
        mesh.fill_holes(max_hole_perimeter=max_hole_perimeter)
        new_vertices, new_faces = mesh.read()

        # Move the result back to the device the mesh lived on before the call.
        self.vertices = new_vertices.to(target_device)
        self.faces = new_faces.to(target_device)

    # TODO could be an option
    def simplify(self, target=1000000, verbose: bool=False, options: dict=None):
        """Simplify the mesh in place with ``cumesh``; ``options`` defaults to an empty dict."""
        import cumesh
        target_device = self.device
        vertices = self.vertices.cuda()
        faces = self.faces.cuda()

        mesh = cumesh.CuMesh()
        mesh.init(vertices, faces)
        mesh.simplify(target, verbose=verbose, options=options if options is not None else {})
        new_vertices, new_faces = mesh.read()

        self.vertices = new_vertices.to(target_device)
        self.faces = new_faces.to(target_device)


class MeshWithVoxel(Mesh, Voxel):
    """A mesh together with the sparse voxel attributes it was extracted from."""

    def __init__(self,
        vertices: torch.Tensor,
        faces: torch.Tensor,
        origin: list,
        voxel_size: float,
        coords: torch.Tensor,
        attrs: torch.Tensor,
        voxel_shape: torch.Size,
        layout: Dict = None,
    ):
        # Neither parent __init__ is called: `device` is a read-only property
        # inherited from Mesh, so Voxel.__init__ could not assign it.
        self.vertices = vertices.float()
        self.faces = faces.int()
        # Keep the attribute every Mesh instance has, so inherited code can rely on it.
        self.vertex_attrs = None
        self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device)
        self.voxel_size = voxel_size
        self.coords = coords
        self.attrs = attrs
        self.voxel_shape = voxel_shape
        self.layout = layout if layout is not None else {}

    def to(self, device, non_blocking=False):
        """Return a new ``MeshWithVoxel`` with all tensors moved to ``device``."""
        return MeshWithVoxel(
            self.vertices.to(device, non_blocking=non_blocking),
            self.faces.to(device, non_blocking=non_blocking),
            self.origin.tolist(),
            self.voxel_size,
            self.coords.to(device, non_blocking=non_blocking),
            self.attrs.to(device, non_blocking=non_blocking),
            self.voxel_shape,
            self.layout,
        )
class SparseGELU(nn.GELU):
    """GELU applied to the ``feats`` of a variable-length tensor."""

    def forward(self, input: VarLenTensor) -> VarLenTensor:
        activated = super().forward(input.feats)
        return input.replace(activated)


class SparseFeedForwardNet(nn.Module):
    """Two-layer MLP (``channels -> channels * mlp_ratio -> channels``) with tanh-GELU."""

    def __init__(self, channels: int, mlp_ratio: float = 4.0):
        super().__init__()
        hidden = int(channels * mlp_ratio)
        self.mlp = nn.Sequential(
            SparseLinear(channels, hidden),
            SparseGELU(approximate="tanh"),
            SparseLinear(hidden, channels),
        )

    def forward(self, x: VarLenTensor) -> VarLenTensor:
        return self.mlp(x)


def manual_cast(tensor, dtype):
    """Cast ``tensor`` to ``dtype`` unless autocast is enabled, in which case it is returned as is."""
    if torch.is_autocast_enabled():
        return tensor
    return tensor.type(dtype)


class LayerNorm32(nn.LayerNorm):
    """LayerNorm evaluated in float32 and cast back to the input dtype."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        normed = super().forward(manual_cast(x, torch.float32))
        return manual_cast(normed, input_dtype)


class SparseMultiHeadRMSNorm(nn.Module):
    """Per-head normalisation: unit-normalise the last axis, then scale by ``gamma * sqrt(dim)``."""

    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(heads, dim))

    def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
        input_dtype = x.dtype
        x = x.float()  # normalise in float32

        def rescale(values):
            return F.normalize(values, dim=-1) * self.gamma * self.scale

        x = x.replace(rescale(x.feats)) if isinstance(x, VarLenTensor) else rescale(x)
        return x.to(input_dtype)
# TODO: replace with apply_rope1
class SparseRotaryPositionEmbedder(nn.Module):
    """Rotary position embedding driven by the integer coordinates of a sparse tensor."""

    def __init__(
        self,
        head_dim: int,
        dim: int = 3,
        rope_freq: Tuple[float, float] = (1.0, 10000.0)
    ):
        super().__init__()
        assert head_dim % 2 == 0, "Head dim must be divisible by 2"
        self.head_dim = head_dim
        self.dim = dim
        self.rope_freq = rope_freq
        # Complex pairs available per spatial axis.
        self.freq_dim = head_dim // 2 // dim
        exponents = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
        self.freqs = rope_freq[0] / (rope_freq[1] ** (exponents))

    def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
        """Return unit complex numbers ``exp(i * index * freq)`` for every index/frequency pair."""
        self.freqs = self.freqs.to(indices.device)
        angles = torch.outer(indices, self.freqs)
        return torch.polar(torch.ones_like(angles), angles)

    def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
        """Rotate consecutive channel pairs of ``x`` by ``phases`` (broadcast over heads)."""
        as_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        rotated = as_complex * phases.unsqueeze(-2)
        return torch.view_as_real(rotated).reshape(*rotated.shape[:-1], -1).to(x.dtype)

    def forward(self, q: SparseTensor, k: Optional[SparseTensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            q (SparseTensor): [..., N, H, D] tensor of queries
            k (SparseTensor): [..., N, H, D] tensor of keys

        Returns the rotated ``q`` alone when ``k`` is None, otherwise the
        rotated ``(q, k)`` pair. Phases are cached on ``q``.
        """
        assert q.coords.shape[-1] == self.dim + 1, "Last dimension of coords must be equal to dim+1"
        cache_key = f'rope_phase_{self.dim}d_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}'
        phases = q.get_spatial_cache(cache_key)
        if phases is None:
            spatial = q.coords[..., 1:]  # drop the first coordinate column
            phases = self._get_phases(spatial.reshape(-1)).reshape(*spatial.shape[:-1], -1)
            missing = self.head_dim // 2 - phases.shape[-1]
            if missing > 0:
                # Pad with zero-angle phases so the leftover channel pairs stay unrotated.
                lead = phases.shape[:-1]
                identity = torch.polar(
                    torch.ones(*lead, missing, device=phases.device),
                    torch.zeros(*lead, missing, device=phases.device)
                )
                phases = torch.cat([phases, identity], dim=-1)
            q.register_spatial_cache(cache_key, phases)
        q_embed = q.replace(self._rotary_embedding(q.feats, phases))
        if k is None:
            return q_embed
        return q_embed, k.replace(self._rotary_embedding(k.feats, phases))
class SparseMultiHeadAttention(nn.Module):
    """Multi-head self- or cross-attention over sparse / variable-length tensors.

    Self-attention supports ``attn_mode`` ``"full"``, ``"windowed"`` and
    ``"double_windowed"``; cross-attention always attends over the whole context.
    """

    def __init__(
        self,
        channels: int,
        num_heads: int,
        ctx_channels: Optional[int] = None,
        type: Literal["self", "cross"] = "self",
        attn_mode: Literal["full", "windowed", "double_windowed"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        qkv_bias: bool = True,
        use_rope: bool = False,
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
    ):
        super().__init__()

        self.channels = channels
        self.head_dim = channels // num_heads
        self.ctx_channels = ctx_channels if ctx_channels is not None else channels
        self.num_heads = num_heads
        self._type = type
        self.attn_mode = attn_mode
        self.window_size = window_size
        self.shift_window = shift_window
        self.use_rope = use_rope
        self.qk_rms_norm = qk_rms_norm

        if self._type == "self":
            self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
        else:
            self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
            self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)

        if self.qk_rms_norm:
            self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads)
            self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads)

        self.to_out = nn.Linear(channels, channels)

        if use_rope:
            self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq)

    @staticmethod
    def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
        """Apply ``module`` to ``x``, going through ``feats`` for variable-length tensors."""
        if isinstance(x, VarLenTensor):
            return x.replace(module(x.feats))
        else:
            return module(x)

    @staticmethod
    def _reshape_chs(x: Union[VarLenTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[VarLenTensor, torch.Tensor]:
        """Reshape the channel axes of ``x`` to ``shape``; dense tensors keep their first two axes."""
        if isinstance(x, VarLenTensor):
            return x.reshape(*shape)
        else:
            return x.reshape(*x.shape[:2], *shape)

    def _fused_pre(self, x: Union[VarLenTensor, torch.Tensor], num_fused: int) -> Union[VarLenTensor, torch.Tensor]:
        """Split a fused projection into ``[..., num_fused, num_heads, head_channels]``."""
        if isinstance(x, VarLenTensor):
            # Add a temporary batch axis so both layouts share one reshape.
            x_feats = x.feats.unsqueeze(0)
        else:
            x_feats = x
        x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
        return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats

    def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor:
        """Attend over ``x`` (self) or from ``x`` to ``context`` (cross).

        Raises:
            ValueError: if this is a self-attention module whose ``attn_mode``
                is not one of the supported modes.
        """
        if self._type == "self":
            qkv = self._linear(self.to_qkv, x)
            qkv = self._fused_pre(qkv, num_fused=3)
            if self.qk_rms_norm or self.use_rope:
                q, k, v = qkv.unbind(dim=-3)
                if self.qk_rms_norm:
                    q = self.q_rms_norm(q)
                    k = self.k_rms_norm(k)
                if self.use_rope:
                    q, k = self.rope(q, k)
                qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
            if self.attn_mode == "full":
                h = sparse_scaled_dot_product_attention(qkv)
            elif self.attn_mode == "windowed":
                h = sparse_windowed_scaled_dot_product_self_attention(
                    qkv, self.window_size, shift_window=self.shift_window
                )
            elif self.attn_mode == "double_windowed":
                # Half of the heads use unshifted windows, the other half
                # windows shifted by half a window; the halves are concatenated
                # in the order h0, h1.
                qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:])
                qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2])
                h0 = sparse_windowed_scaled_dot_product_self_attention(
                    qkv0, self.window_size, shift_window=(0, 0, 0)
                )
                h1 = sparse_windowed_scaled_dot_product_self_attention(
                    qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3)
                )
                h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1))
            else:
                # Previously an unknown mode fell through and failed later with
                # an UnboundLocalError on `h`.
                raise ValueError(
                    f"Unsupported attn_mode {self.attn_mode!r}; expected 'full', 'windowed' or 'double_windowed'"
                )
        else:
            q = self._linear(self.to_q, x)
            q = self._reshape_chs(q, (self.num_heads, -1))
            kv = self._linear(self.to_kv, context)
            kv = self._fused_pre(kv, num_fused=2)
            if self.qk_rms_norm:
                q = self.q_rms_norm(q)
                k, v = kv.unbind(dim=-3)
                k = self.k_rms_norm(k)
                h = sparse_scaled_dot_product_attention(q, k, v)
            else:
                h = sparse_scaled_dot_product_attention(q, kv)
        h = self._reshape_chs(h, (-1,))
        h = self._linear(self.to_out, h)
        return h
class ModulatedSparseTransformerBlock(nn.Module):
    """
    Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
    """
    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        # Both norms are affine-free; scale and shift come from the modulation.
        norm_kwargs = dict(elementwise_affine=False, eps=1e-6)
        self.norm1 = LayerNorm32(channels, **norm_kwargs)
        self.norm2 = LayerNorm32(channels, **norm_kwargs)
        # NOTE(review): attn_mode is forwarded unchanged -- confirm that the
        # values named in this annotation are accepted by SparseMultiHeadAttention.
        self.attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            rope_freq=rope_freq,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = SparseFeedForwardNet(channels, mlp_ratio=mlp_ratio)
        if share_mod:
            # Learned offset added to an externally computed modulation.
            self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
        else:
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
        """Run attention and MLP, each modulated by a shift/scale/gate triple derived from ``mod``."""
        if self.share_mod:
            modulation = (self.modulation + mod).type(mod.dtype)
        else:
            modulation = self.adaLN_modulation(mod)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=1)

        attn_in = x.replace(self.norm1(x.feats))
        attn_in = attn_in * (1 + scale_msa) + shift_msa
        attn_out = self.attn(attn_in)
        x = x + attn_out * gate_msa

        mlp_in = x.replace(self.norm2(x.feats))
        mlp_in = mlp_in * (1 + scale_mlp) + shift_mlp
        mlp_out = self.mlp(mlp_in)
        x = x + mlp_out * gate_mlp
        return x

    def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
        """Apply the block, through activation checkpointing when ``use_checkpoint`` is set."""
        if not self.use_checkpoint:
            return self._forward(x, mod)
        return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + else: + return self._forward(x, mod) + + +class ModulatedSparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "swin"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + if 
self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.self_attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + return self._forward(x, mod, context) + + +class SLatFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + pe_mode: Literal["ape", "rope"] = "rope", + rope_freq: Tuple[float, float] = (1.0, 10000.0), + use_checkpoint: bool = False, + share_mod: bool = False, + initialization: str = 'vanilla', + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + dtype = None, + device = None, + operations = None, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.pe_mode = pe_mode + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.initialization = initialization + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = dtype + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + 
self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + self.input_layer = SparseLinear(in_channels, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedSparseTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + rope_freq=rope_freq, + share_mod=self.share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_layer = SparseLinear(model_channels, out_channels) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def forward( + self, + x: SparseTensor, + t: torch.Tensor, + cond: Union[torch.Tensor, List[torch.Tensor]], + concat_cond: Optional[SparseTensor] = None, + **kwargs + ) -> SparseTensor: + if concat_cond is not None: + x = sparse_cat([x, concat_cond], dim=-1) + if isinstance(cond, list): + cond = VarLenTensor.from_tensor_list(cond) + + h = self.input_layer(x) + h = manual_cast(h, self.dtype) + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = manual_cast(t_emb, self.dtype) + cond = manual_cast(cond, self.dtype) + + if self.pe_mode == "ape": + pe = self.pos_embedder(h.coords[:, 1:]) + h = h + manual_cast(pe, self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + + h = manual_cast(h, x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + return h + +class Trellis2(nn.Module): + def __init__(self, resolution, + in_channels = 32, + out_channels = 32, + model_channels = 1536, + cond_channels = 1024, + num_blocks = 30, + num_heads = 12, + mlp_ratio = 5.3334, + share_mod = True, + qk_rms_norm = True, + qk_rms_norm_cross = True, + dtype=None, device=None, operations=None): + args = { + "out_channels":out_channels, 
"num_blocks":num_blocks, "cond_channels" :cond_channels, + "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, + "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations + } + # TODO: update the names/checkpoints + self.img2shape = SLatFlowModel(resolution, in_channels=in_channels, *args) + self.shape2txt = SLatFlowModel(resolution, in_channels=in_channels*2, *args) + self.shape_generation = True + + def forward(self, x, timestep, context): + pass diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py new file mode 100644 index 000000000..1d564bca2 --- /dev/null +++ b/comfy/ldm/trellis2/vae.py @@ -0,0 +1,1185 @@ +import torch +import torch.nn as nn +from typing import List, Any, Dict, Optional, overload, Union, Tuple +from fractions import Fraction +import torch.nn.functional as F +from dataclasses import dataclass +import numpy as np +from cumesh import TorchHashMap, Mesh, MeshWithVoxel + +# TODO: determine which conv they actually use +@dataclass +class config: + CONV = "none" + +# TODO post processing +def simplify(self, target_num_faces: int, verbose: bool=False, options: dict={}): + + num_face = self.cu_mesh.num_faces() + if num_face <= target_num_faces: + return + + thresh = options.get('thresh', 1e-8) + lambda_edge_length = options.get('lambda_edge_length', 1e-2) + lambda_skinny = options.get('lambda_skinny', 1e-3) + while True: + new_num_vert, new_num_face = self.cu_mesh.simplify_step(lambda_edge_length, lambda_skinny, thresh, False) + + if new_num_face <= target_num_faces: + break + + del_num_face = num_face - new_num_face + if del_num_face / num_face < 1e-2: + thresh *= 10 + num_face = new_num_face + +class VarLenTensor: + + def __init__(self, feats: torch.Tensor, layout: List[slice]=None): + self.feats = feats + self.layout = layout if layout is not None else [slice(0, feats.shape[0])] + self._cache = {} + + @staticmethod 
+ def layout_from_seqlen(seqlen: list) -> List[slice]: + """ + Create a layout from a tensor of sequence lengths. + """ + layout = [] + start = 0 + for l in seqlen: + layout.append(slice(start, start + l)) + start += l + return layout + + @staticmethod + def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor': + """ + Create a VarLenTensor from a list of tensors. + """ + feats = torch.cat(tensor_list, dim=0) + layout = [] + start = 0 + for tensor in tensor_list: + layout.append(slice(start, start + tensor.shape[0])) + start += tensor.shape[0] + return VarLenTensor(feats, layout) + + def __len__(self) -> int: + return len(self.layout) + + @property + def shape(self) -> torch.Size: + return torch.Size([len(self.layout), *self.feats.shape[1:]]) + + def dim(self) -> int: + return len(self.shape) + + @property + def ndim(self) -> int: + return self.dim() + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @property + def seqlen(self) -> torch.LongTensor: + if 'seqlen' not in self._cache: + self._cache['seqlen'] = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device) + return self._cache['seqlen'] + + @property + def cum_seqlen(self) -> torch.LongTensor: + if 'cum_seqlen' not in self._cache: + self._cache['cum_seqlen'] = torch.cat([ + torch.tensor([0], dtype=torch.long, device=self.device), + self.seqlen.cumsum(dim=0) + ], dim=0) + return self._cache['cum_seqlen'] + + @property + def batch_boardcast_map(self) -> torch.LongTensor: + """ + Get the broadcast map for the varlen tensor. + """ + if 'batch_boardcast_map' not in self._cache: + self._cache['batch_boardcast_map'] = torch.repeat_interleave( + torch.arange(len(self.layout), device=self.device), + self.seqlen, + ) + return self._cache['batch_boardcast_map'] + + @overload + def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ... 
+ + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ... + + def to(self, *args, **kwargs) -> 'VarLenTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + non_blocking = kwargs.get('non_blocking', False) + copy = kwargs.get('copy', False) + + new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy) + return self.replace(new_feats) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'VarLenTensor': + new_feats = self.feats.cpu() + return self.replace(new_feats) + + def cuda(self) -> 'VarLenTensor': + new_feats = self.feats.cuda() + return self.replace(new_feats) + + def half(self) -> 'VarLenTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'VarLenTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'VarLenTensor': + new_feats = self.feats.detach() + return self.replace(new_feats) + + def reshape(self, *shape) -> 'VarLenTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['VarLenTensor']: + return varlen_unbind(self, dim) + + def replace(self, feats: torch.Tensor) -> 'VarLenTensor': + new_tensor = VarLenTensor( + feats=feats, + layout=self.layout, + ) + new_tensor._cache = self._cache + return new_tensor + + def to_dense(self, max_length=None) -> torch.Tensor: + N = 
len(self) + L = max_length or self.seqlen.max().item() + spatial = self.feats.shape[1:] + idx = torch.arange(L, device=self.device).unsqueeze(0).expand(N, L) + mask = (idx < self.seqlen.unsqueeze(1)) + mapping = mask.reshape(-1).cumsum(dim=0) - 1 + dense = self.feats[mapping] + dense = dense.reshape(N, L, *spatial) + return dense, mask + + def __neg__(self) -> 'VarLenTensor': + return self.replace(-self.feats) + + def __elemwise__(self, other: Union[torch.Tensor, 'VarLenTensor'], op: callable) -> 'VarLenTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = other[self.batch_boardcast_map] + except: + pass + if isinstance(other, VarLenTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + return new_tensor + + def __add__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.add) + + def __radd__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.add) + + def __sub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.sub) + + def __rsub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, lambda x, y: torch.sub(y, x)) + + def __mul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.mul) + + def __rmul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.mul) + + def __truediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.div) + + def __rtruediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, lambda x, y: torch.div(y, x)) + + def __getitem__(self, 
idx): + if isinstance(idx, int): + idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, list): + assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}" + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + new_feats = [] + new_layout = [] + start = 0 + for new_idx, old_idx in enumerate(idx): + new_feats.append(self.feats[self.layout[old_idx]]) + new_layout.append(slice(start, start + len(new_feats[-1]))) + start += len(new_feats[-1]) + new_feats = torch.cat(new_feats, dim=0).contiguous() + new_tensor = VarLenTensor(feats=new_feats, layout=new_layout) + return new_tensor + + def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + if isinstance(dim, int): + dim = (dim,) + + if op =='mean': + red = self.feats.mean(dim=dim, keepdim=keepdim) + elif op =='sum': + red = self.feats.sum(dim=dim, keepdim=keepdim) + elif op == 'prod': + red = self.feats.prod(dim=dim, keepdim=keepdim) + else: + raise ValueError(f"Unsupported reduce operation: {op}") + + if dim is None or 0 in dim: + return red + + red = torch.segment_reduce(red, reduce=op, lengths=self.seqlen) + return red + + def mean(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='mean', dim=dim, keepdim=keepdim) + + def sum(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='sum', dim=dim, keepdim=keepdim) + + def prod(self, dim: Optional[Union[int, Tuple[int,...]]] = None, 
keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='prod', dim=dim, keepdim=keepdim) + + def std(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + mean = self.mean(dim=dim, keepdim=True) + mean2 = self.replace(self.feats ** 2).mean(dim=dim, keepdim=True) + std = (mean2 - mean ** 2).sqrt() + return std + + def __repr__(self) -> str: + return f"VarLenTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})" + +def varlen_unbind(input: VarLenTensor, dim: int) -> Union[List[VarLenTensor]]: + + if dim == 0: + return [input[i] for i in range(len(input))] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] + + +class SparseTensor(VarLenTensor): + + SparseTensorData = None + + @overload + def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, **kwargs): ... + + @overload + def __init__(self, data, shape: Optional[torch.Size] = None, **kwargs): ... + + def __init__(self, *args, **kwargs): + # Lazy import of sparse tensor backend + if self.SparseTensorData is None: + import importlib + if config.CONV == 'torchsparse': + self.SparseTensorData = importlib.import_module('torchsparse').SparseTensor + elif config.CONV == 'spconv': + self.SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor + + method_id = 0 + if len(args) != 0: + method_id = 0 if isinstance(args[0], torch.Tensor) else 1 + else: + method_id = 1 if 'data' in kwargs else 0 + + if method_id == 0: + feats, coords, shape = args + (None,) * (3 - len(args)) + if 'feats' in kwargs: + feats = kwargs['feats'] + del kwargs['feats'] + if 'coords' in kwargs: + coords = kwargs['coords'] + del kwargs['coords'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + + if config.CONV == 'torchsparse': + self.data = self.SparseTensorData(feats, coords, **kwargs) + elif config.CONV == 'spconv': + spatial_shape = list(coords.max(0)[0] + 1) + 
self.data = self.SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape[1:], spatial_shape[0], **kwargs) + self.data._features = feats + else: + self.data = { + 'feats': feats, + 'coords': coords, + } + elif method_id == 1: + data, shape = args + (None,) * (2 - len(args)) + if 'data' in kwargs: + data = kwargs['data'] + del kwargs['data'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + + self.data = data + + self._shape = shape + self._scale = kwargs.get('scale', (Fraction(1, 1), Fraction(1, 1), Fraction(1, 1))) + self._spatial_cache = kwargs.get('spatial_cache', {}) + + @staticmethod + def from_tensor_list(feats_list: List[torch.Tensor], coords_list: List[torch.Tensor]) -> 'SparseTensor': + """ + Create a SparseTensor from a list of tensors. + """ + feats = torch.cat(feats_list, dim=0) + coords = [] + for i, coord in enumerate(coords_list): + coord = torch.cat([torch.full_like(coord[:, :1], i), coord[:, 1:]], dim=1) + coords.append(coord) + coords = torch.cat(coords, dim=0) + return SparseTensor(feats, coords) + + def to_tensor_list(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Convert a SparseTensor to list of tensors. 
+ """ + feats_list = [] + coords_list = [] + for s in self.layout: + feats_list.append(self.feats[s]) + coords_list.append(self.coords[s]) + return feats_list, coords_list + + def __len__(self) -> int: + return len(self.layout) + + def __cal_shape(self, feats, coords): + shape = [] + shape.append(coords[:, 0].max().item() + 1) + shape.extend([*feats.shape[1:]]) + return torch.Size(shape) + + def __cal_layout(self, coords, batch_size): + seq_len = torch.bincount(coords[:, 0], minlength=batch_size) + offset = torch.cumsum(seq_len, dim=0) + layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)] + return layout + + def __cal_spatial_shape(self, coords): + return torch.Size((coords[:, 1:].max(0)[0] + 1).tolist()) + + @property + def shape(self) -> torch.Size: + if self._shape is None: + self._shape = self.__cal_shape(self.feats, self.coords) + return self._shape + + @property + def layout(self) -> List[slice]: + layout = self.get_spatial_cache('layout') + if layout is None: + layout = self.__cal_layout(self.coords, self.shape[0]) + self.register_spatial_cache('layout', layout) + return layout + + @property + def spatial_shape(self) -> torch.Size: + spatial_shape = self.get_spatial_cache('shape') + if spatial_shape is None: + spatial_shape = self.__cal_spatial_shape(self.coords) + self.register_spatial_cache('shape', spatial_shape) + return spatial_shape + + @property + def feats(self) -> torch.Tensor: + if config.CONV == 'torchsparse': + return self.data.F + elif config.CONV == 'spconv': + return self.data.features + else: + return self.data['feats'] + + @feats.setter + def feats(self, value: torch.Tensor): + if config.CONV == 'torchsparse': + self.data.F = value + elif config.CONV == 'spconv': + self.data.features = value + else: + self.data['feats'] = value + + @property + def coords(self) -> torch.Tensor: + if config.CONV == 'torchsparse': + return self.data.C + elif config.CONV == 'spconv': + return self.data.indices + else: + 
return self.data['coords'] + + @coords.setter + def coords(self, value: torch.Tensor): + if config.CONV == 'torchsparse': + self.data.C = value + elif config.CONV == 'spconv': + self.data.indices = value + else: + self.data['coords'] = value + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @property + def seqlen(self) -> torch.LongTensor: + seqlen = self.get_spatial_cache('seqlen') + if seqlen is None: + seqlen = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device) + self.register_spatial_cache('seqlen', seqlen) + return seqlen + + @property + def cum_seqlen(self) -> torch.LongTensor: + cum_seqlen = self.get_spatial_cache('cum_seqlen') + if cum_seqlen is None: + cum_seqlen = torch.cat([ + torch.tensor([0], dtype=torch.long, device=self.device), + self.seqlen.cumsum(dim=0) + ], dim=0) + self.register_spatial_cache('cum_seqlen', cum_seqlen) + return cum_seqlen + + @property + def batch_boardcast_map(self) -> torch.LongTensor: + """ + Get the broadcast map for the varlen tensor. + """ + batch_boardcast_map = self.get_spatial_cache('batch_boardcast_map') + if batch_boardcast_map is None: + batch_boardcast_map = torch.repeat_interleave( + torch.arange(len(self.layout), device=self.device), + self.seqlen, + ) + self.register_spatial_cache('batch_boardcast_map', batch_boardcast_map) + return batch_boardcast_map + + @overload + def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ... 
+ + def to(self, *args, **kwargs) -> 'SparseTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + non_blocking = kwargs.get('non_blocking', False) + copy = kwargs.get('copy', False) + + new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy) + new_coords = self.coords.to(device=device, non_blocking=non_blocking, copy=copy) + return self.replace(new_feats, new_coords) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'SparseTensor': + new_feats = self.feats.cpu() + new_coords = self.coords.cpu() + return self.replace(new_feats, new_coords) + + def cuda(self) -> 'SparseTensor': + new_feats = self.feats.cuda() + new_coords = self.coords.cuda() + return self.replace(new_feats, new_coords) + + def half(self) -> 'SparseTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'SparseTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'SparseTensor': + new_coords = self.coords.detach() + new_feats = self.feats.detach() + return self.replace(new_feats, new_coords) + + def reshape(self, *shape) -> 'SparseTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['SparseTensor']: + return sparse_unbind(self, dim) + + def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor': + if config.CONV == 'torchsparse': + new_data = self.SparseTensorData( + feats=feats, + 
coords=self.data.coords if coords is None else coords, + stride=self.data.stride, + spatial_range=self.data.spatial_range, + ) + new_data._caches = self.data._caches + elif config.CONV == 'spconv': + new_data = self.SparseTensorData( + self.data.features.reshape(self.data.features.shape[0], -1), + self.data.indices, + self.data.spatial_shape, + self.data.batch_size, + self.data.grid, + self.data.voxel_num, + self.data.indice_dict + ) + new_data._features = feats + new_data.benchmark = self.data.benchmark + new_data.benchmark_record = self.data.benchmark_record + new_data.thrust_allocator = self.data.thrust_allocator + new_data._timer = self.data._timer + new_data.force_algo = self.data.force_algo + new_data.int8_scale = self.data.int8_scale + if coords is not None: + new_data.indices = coords + else: + new_data = { + 'feats': feats, + 'coords': self.data['coords'] if coords is None else coords, + } + new_tensor = SparseTensor( + new_data, + shape=torch.Size([self._shape[0]] + list(feats.shape[1:])) if self._shape is not None else None, + scale=self._scale, + spatial_cache=self._spatial_cache + ) + return new_tensor + + def to_dense(self) -> torch.Tensor: + if config.CONV == 'torchsparse': + return self.data.dense() + elif config.CONV == 'spconv': + return self.data.dense() + else: + spatial_shape = self.spatial_shape + ret = torch.zeros(*self.shape, *spatial_shape, dtype=self.dtype, device=self.device) + idx = [self.coords[:, 0], slice(None)] + self.coords[:, 1:].unbind(1) + ret[tuple(idx)] = self.feats + return ret + + @staticmethod + def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor': + N, C = dim + x = torch.arange(aabb[0], aabb[3] + 1) + y = torch.arange(aabb[1], aabb[4] + 1) + z = torch.arange(aabb[2], aabb[5] + 1) + coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3) + coords = torch.cat([ + torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1), + coords.repeat(N, 1), + ], 
dim=1).to(dtype=torch.int32, device=device) + feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device) + return SparseTensor(feats=feats, coords=coords) + + def __merge_sparse_cache(self, other: 'SparseTensor') -> dict: + new_cache = {} + for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())): + if k in self._spatial_cache: + new_cache[k] = self._spatial_cache[k] + if k in other._spatial_cache: + if k not in new_cache: + new_cache[k] = other._spatial_cache[k] + else: + new_cache[k].update(other._spatial_cache[k]) + return new_cache + + def __elemwise__(self, other: Union[torch.Tensor, VarLenTensor], op: callable) -> 'SparseTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = other[self.batch_boardcast_map] + except: + pass + if isinstance(other, VarLenTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + if isinstance(other, SparseTensor): + new_tensor._spatial_cache = self.__merge_sparse_cache(other) + return new_tensor + + def __getitem__(self, idx): + if isinstance(idx, int): + idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, list): + assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}" + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + new_coords = [] + new_feats = [] + new_layout = [] + new_shape = torch.Size([len(idx)] + list(self.shape[1:])) + start = 0 + for new_idx, old_idx in enumerate(idx): + 
new_coords.append(self.coords[self.layout[old_idx]].clone()) + new_coords[-1][:, 0] = new_idx + new_feats.append(self.feats[self.layout[old_idx]]) + new_layout.append(slice(start, start + len(new_coords[-1]))) + start += len(new_coords[-1]) + new_coords = torch.cat(new_coords, dim=0).contiguous() + new_feats = torch.cat(new_feats, dim=0).contiguous() + new_tensor = SparseTensor(feats=new_feats, coords=new_coords, shape=new_shape) + new_tensor.register_spatial_cache('layout', new_layout) + return new_tensor + + def clear_spatial_cache(self) -> None: + """ + Clear all spatial caches. + """ + self._spatial_cache = {} + + def register_spatial_cache(self, key, value) -> None: + """ + Register a spatial cache. + The spatial cache can be any thing you want to cache. + The registery and retrieval of the cache is based on current scale. + """ + scale_key = str(self._scale) + if scale_key not in self._spatial_cache: + self._spatial_cache[scale_key] = {} + self._spatial_cache[scale_key][key] = value + + def get_spatial_cache(self, key=None): + """ + Get a spatial cache. 
+ """ + scale_key = str(self._scale) + cur_scale_cache = self._spatial_cache.get(scale_key, {}) + if key is None: + return cur_scale_cache + return cur_scale_cache.get(key, None) + + def __repr__(self) -> str: + return f"SparseTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})" + +def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: + if dim == 0: + start = 0 + coords = [] + for input in inputs: + coords.append(input.coords.clone()) + coords[-1][:, 0] += start + start += input.shape[0] + coords = torch.cat(coords, dim=0) + feats = torch.cat([input.feats for input in inputs], dim=0) + output = SparseTensor( + coords=coords, + feats=feats, + ) + else: + feats = torch.cat([input.feats for input in inputs], dim=dim) + output = inputs[0].replace(feats) + + return output + + +def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: + if dim == 0: + return [input[i] for i in range(input.shape[0])] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] + +class SparseLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super(SparseLinear, self).__init__(in_features, out_features, bias) + + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) + + + +class SparseUnetVaeEncoder(nn.Module): + """ + Sparse Swin Transformer Unet VAE model. 
+ """ + def __init__( + self, + in_channels: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + down_block_type: List[str], + block_args: List[Dict[str, Any]], + use_fp16: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = SparseLinear(in_channels, model_channels[0]) + self.to_latent = SparseLinear(model_channels[-1], 2 * latent_channels) + + self.blocks = nn.ModuleList([]) + for i in range(len(num_blocks)): + self.blocks.append(nn.ModuleList([])) + for j in range(num_blocks[i]): + self.blocks[-1].append( + globals()[block_type[i]]( + model_channels[i], + **block_args[i], + ) + ) + if i < len(num_blocks) - 1: + self.blocks[-1].append( + globals()[down_block_type[i]]( + model_channels[i], + model_channels[i+1], + **block_args[i], + ) + ) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def forward(self, x: SparseTensor, sample_posterior=False, return_raw=False): + h = self.input_layer(x) + h = h.type(self.dtype) + for i, res in enumerate(self.blocks): + for j, block in enumerate(res): + h = block(h) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.to_latent(h) + + # Sample from the posterior distribution + mean, logvar = h.feats.chunk(2, dim=-1) + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + z = h.replace(z) + + if return_raw: + return z, mean, logvar + else: + return z + + + +class FlexiDualGridVaeEncoder(SparseUnetVaeEncoder): + def __init__( + self, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + down_block_type: List[str], + block_args: List[Dict[str, Any]], + 
use_fp16: bool = False, + ): + super().__init__( + 6, + model_channels, + latent_channels, + num_blocks, + block_type, + down_block_type, + block_args, + use_fp16, + ) + + def forward(self, vertices: SparseTensor, intersected: SparseTensor, sample_posterior=False, return_raw=False): + x = vertices.replace(torch.cat([ + vertices.feats - 0.5, + intersected.feats.float() - 0.5, + ], dim=1)) + return super().forward(x, sample_posterior, return_raw) + +class SparseUnetVaeDecoder(nn.Module): + """ + Sparse Swin Transformer Unet VAE model. + """ + def __init__( + self, + out_channels: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + up_block_type: List[str], + block_args: List[Dict[str, Any]], + use_fp16: bool = False, + pred_subdiv: bool = True, + ): + super().__init__() + self.out_channels = out_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.use_fp16 = use_fp16 + self.pred_subdiv = pred_subdiv + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.low_vram = False + + self.output_layer = SparseLinear(model_channels[-1], out_channels) + self.from_latent = SparseLinear(latent_channels, model_channels[0]) + + self.blocks = nn.ModuleList([]) + for i in range(len(num_blocks)): + self.blocks.append(nn.ModuleList([])) + for j in range(num_blocks[i]): + self.blocks[-1].append( + globals()[block_type[i]]( + model_channels[i], + **block_args[i], + ) + ) + if i < len(num_blocks) - 1: + self.blocks[-1].append( + globals()[up_block_type[i]]( + model_channels[i], + model_channels[i+1], + pred_subdiv=pred_subdiv, + **block_args[i], + ) + ) + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def forward(self, x: SparseTensor, guide_subs: Optional[List[SparseTensor]] = None, return_subs: bool = False) -> SparseTensor: + + h = self.from_latent(x) + h = h.type(self.dtype) + subs = [] + for i, res in enumerate(self.blocks): + for j, block 
in enumerate(res): + if i < len(self.blocks) - 1 and j == len(res) - 1: + if self.pred_subdiv: + h, sub = block(h) + subs.append(sub) + else: + h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None) + else: + h = block(h) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.output_layer(h) + if return_subs: + return h, subs + else: + return h + + def upsample(self, x: SparseTensor, upsample_times: int) -> torch.Tensor: + + h = self.from_latent(x) + h = h.type(self.dtype) + for i, res in enumerate(self.blocks): + if i == upsample_times: + return h.coords + for j, block in enumerate(res): + if i < len(self.blocks) - 1 and j == len(res) - 1: + h, sub = block(h) + else: + h = block(h) + +class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder): + def __init__( + self, + resolution: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + up_block_type: List[str], + block_args: List[Dict[str, Any]], + voxel_margin: float = 0.5, + use_fp16: bool = False, + ): + self.resolution = resolution + self.voxel_margin = voxel_margin + # cache for a TorchHashMap instance + self._torch_hashmap_cache = None + + super().__init__( + 7, + model_channels, + latent_channels, + num_blocks, + block_type, + up_block_type, + block_args, + use_fp16, + ) + + def set_resolution(self, resolution: int) -> None: + self.resolution = resolution + + def _build_or_get_hashmap(self, coords: torch.Tensor, grid_size: torch.Tensor): + device = coords.device + N = coords.shape[0] + # compute flat keys for all coords (prepend batch 0 same as original code) + b = torch.zeros((N,), dtype=torch.long, device=device) + x, y, z = coords[:, 0].long(), coords[:, 1].long(), coords[:, 2].long() + W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) + flat_keys = b * (W * H * D) + x * (H * D) + y * D + z + values = torch.arange(N, dtype=torch.long, device=device) + DEFAULT_VAL = 
0xffffffff # sentinel used in original code + return TorchHashMap(flat_keys, values, DEFAULT_VAL) + + def forward(self, x: SparseTensor, gt_intersected: SparseTensor = None, **kwargs): + decoded = super().forward(x, **kwargs) + out_list = list(decoded) if isinstance(decoded, tuple) else [decoded] + h = out_list[0] + vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin) + intersected = h.replace(h.feats[..., 3:6] > 0) + quad_lerp = h.replace(F.softplus(h.feats[..., 6:7])) + mesh = [Mesh(*flexible_dual_grid_to_mesh( + v.coords[:, 1:], v.feats, i.feats, q.feats, + aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + grid_size=self.resolution, + train=False, + hashmap_builder=self._build_or_get_hashmap, + )) for v, i, q in zip(vertices, intersected, quad_lerp)] + out_list[0] = mesh + return out_list[0] if len(out_list) == 1 else tuple(out_list) + +def flexible_dual_grid_to_mesh( + coords: torch.Tensor, + dual_vertices: torch.Tensor, + intersected_flag: torch.Tensor, + split_weight: Union[torch.Tensor, None], + aabb: Union[list, tuple, np.ndarray, torch.Tensor], + voxel_size: Union[float, list, tuple, np.ndarray, torch.Tensor] = None, + grid_size: Union[int, list, tuple, np.ndarray, torch.Tensor] = None, + train: bool = False, + hashmap_builder=None, # optional callable for building/caching a TorchHashMap +): + + if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset"): + flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset = torch.tensor([ + [[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis + [[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis + [[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis + ], dtype=torch.int, device=coords.device).unsqueeze(0) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1"): + flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=coords.device, requires_grad=False) + if not hasattr(flexible_dual_grid_to_mesh, 
"quad_split_2"): + flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=coords.device, requires_grad=False) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_train"): + flexible_dual_grid_to_mesh.quad_split_train = torch.tensor([0, 1, 4, 1, 2, 4, 2, 3, 4, 3, 0, 4], dtype=torch.long, device=coords.device, requires_grad=False) + + # AABB + if isinstance(aabb, (list, tuple)): + aabb = np.array(aabb) + if isinstance(aabb, np.ndarray): + aabb = torch.tensor(aabb, dtype=torch.float32, device=coords.device) + + # Voxel size + if voxel_size is not None: + if isinstance(voxel_size, float): + voxel_size = [voxel_size, voxel_size, voxel_size] + if isinstance(voxel_size, (list, tuple)): + voxel_size = np.array(voxel_size) + if isinstance(voxel_size, np.ndarray): + voxel_size = torch.tensor(voxel_size, dtype=torch.float32, device=coords.device) + grid_size = ((aabb[1] - aabb[0]) / voxel_size).round().int() + else: + if isinstance(grid_size, int): + grid_size = [grid_size, grid_size, grid_size] + if isinstance(grid_size, (list, tuple)): + grid_size = np.array(grid_size) + if isinstance(grid_size, np.ndarray): + grid_size = torch.tensor(grid_size, dtype=torch.int32, device=coords.device) + voxel_size = (aabb[1] - aabb[0]) / grid_size + + # Extract mesh + N = dual_vertices.shape[0] + mesh_vertices = (coords.float() + dual_vertices) / (2 * N) - 0.5 + + if hashmap_builder is None: + # build local TorchHashMap + device = coords.device + b = torch.zeros((N,), dtype=torch.long, device=device) + x, y, z = coords[:, 0].long(), coords[:, 1].long(), coords[:, 2].long() + W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) + flat_keys = b * (W * H * D) + x * (H * D) + y * D + z + values = torch.arange(N, dtype=torch.long, device=device) + DEFAULT_VAL = 0xffffffff + torch_hashmap = TorchHashMap(flat_keys, values, DEFAULT_VAL) + else: + torch_hashmap = hashmap_builder(coords, grid_size) + + # Find 
connected voxels + edge_neighbor_voxel = coords.reshape(N, 1, 1, 3) + flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset # (N, 3, 4, 3) + connected_voxel = edge_neighbor_voxel[intersected_flag] # (M, 4, 3) + M = connected_voxel.shape[0] + # flatten connected voxel coords and lookup + conn_flat_b = torch.zeros((M * 4,), dtype=torch.long, device=coords.device) + conn_x = connected_voxel.reshape(-1, 3)[:, 0].long() + conn_y = connected_voxel.reshape(-1, 3)[:, 1].long() + conn_z = connected_voxel.reshape(-1, 3)[:, 2].long() + W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) + conn_flat = conn_flat_b * (W * H * D) + conn_x * (H * D) + conn_y * D + conn_z + + conn_indices = torch_hashmap.lookup_flat(conn_flat).reshape(M, 4).int() + connected_voxel_valid = (conn_indices != 0xffffffff).all(dim=1) + quad_indices = conn_indices[connected_voxel_valid].int() # (L, 4) + + mesh_vertices = (coords.float() + dual_vertices) * voxel_size + aabb[0].reshape(1, 3) + if split_weight is None: + # if split 1 + atempt_triangles_0 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1] + normals0 = torch.cross(mesh_vertices[atempt_triangles_0[:, 1]] - mesh_vertices[atempt_triangles_0[:, 0]], mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 0]]) + normals1 = torch.cross(mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 1]], mesh_vertices[atempt_triangles_0[:, 3]] - mesh_vertices[atempt_triangles_0[:, 1]]) + align0 = (normals0 * normals1).sum(dim=1, keepdim=True).abs() + # if split 2 + atempt_triangles_1 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_2] + normals0 = torch.cross(mesh_vertices[atempt_triangles_1[:, 1]] - mesh_vertices[atempt_triangles_1[:, 0]], mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 0]]) + normals1 = torch.cross(mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 1]], mesh_vertices[atempt_triangles_1[:, 
3]] - mesh_vertices[atempt_triangles_1[:, 1]]) + align1 = (normals0 * normals1).sum(dim=1, keepdim=True).abs() + # select split + mesh_triangles = torch.where(align0 > align1, atempt_triangles_0, atempt_triangles_1).reshape(-1, 3) + else: + split_weight_ws = split_weight[quad_indices] + split_weight_ws_02 = split_weight_ws[:, 0] * split_weight_ws[:, 2] + split_weight_ws_13 = split_weight_ws[:, 1] * split_weight_ws[:, 3] + mesh_triangles = torch.where( + split_weight_ws_02 > split_weight_ws_13, + quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1], + quad_indices[:, flexible_dual_grid_to_mesh.quad_split_2] + ).reshape(-1, 3) + + return mesh_vertices, mesh_triangles + +class Vae(nn.Module): + def __init__(self, config, operations=None): + operations = operations or torch.nn + + self.txt_dec = SparseUnetVaeDecoder( + out_channels=6, + model_channels=[1024, 512, 256, 128, 64], + latent_channels=32, + num_blocks=[4, 16, 8, 4, 0], + block_type=["SparseConvNeXtBlock3d"] * 5, + up_block_type=["SparseResBlockS2C3d"] * 4, + pred_subdiv=False + ) + + self.shape_dec = FlexiDualGridVaeDecoder( + resolution=256, + model_channels=[1024, 512, 256, 128, 64], + latent_channels=32, + num_blocks=[4, 16, 8, 4, 0], + block_type=["SparseConvNeXtBlock3d"] * 5, + up_block_type=["SparseResBlockS2C3d"] * 4, + ) + + def decode_shape_slat(self, slat, resolution: int): + self.shape_dec.set_resolution(resolution) + return self.shape_dec(slat, return_subs=True) + + def decode_tex_slat(self, slat, subs): + return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5 + + @torch.no_grad() + def decode( + self, + shape_slat: SparseTensor, + tex_slat: SparseTensor, + resolution: int, + ): + meshes, subs = self.decode_shape_slat(shape_slat, resolution) + tex_voxels = self.decode_tex_slat(tex_slat, subs) + out_mesh = [] + for m, v in zip(meshes, tex_voxels): + m.fill_holes() # TODO + out_mesh.append( + MeshWithVoxel( + m.vertices, m.faces, + origin = [-0.5, -0.5, -0.5], + voxel_size = 1 / resolution, + 
coords = v.coords[:, 1:], + attrs = v.feats, + voxel_shape = torch.Size([*v.shape, *v.spatial_shape]), + layout=self.pbr_attr_layout + ) + ) + return out_mesh diff --git a/comfy_extras/trellis2.py b/comfy_extras/trellis2.py new file mode 100644 index 000000000..c3ad56007 --- /dev/null +++ b/comfy_extras/trellis2.py @@ -0,0 +1,240 @@ +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO +import torch +from comfy.ldm.trellis2.model import SparseTensor +import comfy.model_management +from PIL import Image +import PIL +import numpy as np + +shape_slat_normalization = { + "mean": torch.tensor([ + 0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218, + -0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944, + 0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667, + -0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665 + ])[None], + "std": torch.tensor([ + 5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004, + 5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578, + 4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194, + 5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691 + ])[None] +} + +tex_slat_normalization = { + "mean": torch.tensor([ + 3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075, + 0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149, + -1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717, + 1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862 + ])[None], + "std": torch.tensor([ + 2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822, + 2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588, + 2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 
2.810596, 3.292720, 2.674999, + 2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190 + ])[None] +} + +def smart_crop_square( + image: torch.Tensor, + background_color=(128, 128, 128), +): + C, H, W = image.shape + size = max(H, W) + canvas = torch.empty( + (C, size, size), + dtype=image.dtype, + device=image.device + ) + for c in range(C): + canvas[c].fill_(background_color[c]) + top = (size - H) // 2 + left = (size - W) // 2 + canvas[:, top:top + H, left:left + W] = image + + return canvas + +def run_conditioning( + model, + image: torch.Tensor, + include_1024: bool = True, + background_color: str = "black", +): + # TODO: should check if normalization was applied in these steps + model = model.model + device = comfy.model_management.intermediate_device() # replaces .cpu() + torch_device = comfy.model_management.get_torch_device() # replaces .cuda() + bg_colors = { + "black": (0, 0, 0), + "gray": (128, 128, 128), + "white": (255, 255, 255), + } + bg_color = bg_colors.get(background_color, (128, 128, 128)) + + # Convert image to PIL + if image.dim() == 4: + pil_image = (image[0] * 255).clip(0, 255).astype(torch.uint8) + else: + pil_image = (image * 255).clip(0, 255).astype(torch.uint8) + + pil_image = smart_crop_square(pil_image, background_color=bg_color) + + model.image_size = 512 + def set_image_size(image, image_size=512): + image = PIL.from_array(image) + image = [i.resize((image_size, image_size), Image.LANCZOS) for i in image] + image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] + image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] + image = torch.stack(image).to(torch_device) + return image + + pil_image = set_image_size(image, 512) + cond_512 = model([pil_image]) + + cond_1024 = None + if include_1024: + model.image_size = 1024 + pil_image = set_image_size(pil_image, 1024) + cond_1024 = model([pil_image]) + + neg_cond = torch.zeros_like(cond_512) + + conditioning = { + 'cond_512': 
cond_512.to(device), + 'neg_cond': neg_cond.to(device), + } + if cond_1024 is not None: + conditioning['cond_1024'] = cond_1024.to(device) + + preprocessed_tensor = pil_image.to(torch.float32) / 255.0 + preprocessed_tensor = torch.from_numpy(preprocessed_tensor).unsqueeze(0) + + return conditioning, preprocessed_tensor + +class VaeDecodeShapeTrellis(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VaeDecodeShapeTrellis", + category="latent/3d", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + IO.Int.Input("resolution", tooltip="Shape Generation Resolution"), + ], + outputs=[ + IO.Mesh.Output("mesh"), + IO.AnyType.Output("shape_subs"), + ] + ) + + @classmethod + def execute(cls, samples, vae, resolution): + std = shape_slat_normalization["std"] + mean = shape_slat_normalization["mean"] + samples = samples * std + mean + + mesh, subs = vae.decode_shape_slat(resolution, samples) + return mesh, subs + +class VaeDecodeTextureTrellis(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VaeDecodeTextureTrellis", + category="latent/3d", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + IO.AnyType.Input("shape_subs"), + ], + outputs=[ + IO.Mesh.Output("mesh"), + ] + ) + + @classmethod + def execute(cls, samples, vae, shape_subs): + if shape_subs is None: + raise ValueError("Shape subs must be provided for texture generation") + + std = tex_slat_normalization["std"] + mean = tex_slat_normalization["mean"] + samples = samples * std + mean + + mesh = vae.decode_tex_slat(samples, shape_subs) + return mesh + +class Trellis2Conditioning(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Trellis2Conditioning", + category="conditioning/video_models", + inputs=[ + IO.ClipVision.Input("clip_vision_model"), + IO.Image.Input("image"), + IO.MultiCombo.Input("background_color", options=["black", "gray", "white"], default="black") + ], + outputs=[ + 
IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + ] + ) + + @classmethod + def execute(cls, clip_vision_model, image, background_color) -> IO.NodeOutput: + # could make 1024 an option + conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color) + embeds = conditioning["cond_1024"] # should add that + positive = [[conditioning["cond_512"], {embeds}]] + negative = [[conditioning["cond_neg"], {embeds}]] + return IO.NodeOutput(positive, negative) + +class EmptyLatentTrellis2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyLatentTrellis2", + category="latent/3d", + inputs=[ + IO.Int.Input("resolution", default=3072, min=1, max=8192), + IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), + IO.Vae.Input("vae"), + IO.Boolean.Input("shape_generation", tooltip="Setting to false will generate texture."), + IO.MultiCombo.Input("generation_type", options=["structure_generation", "shape_generation", "texture_generation"]) + ], + outputs=[ + IO.Latent.Output(), + ] + ) + + @classmethod + def execute(cls, batch_size, coords, vae, generation_type) -> IO.NodeOutput: + # TODO: i will probably update how shape/texture is generated + # could split this too + in_channels = 32 + shape_generation = generation_type == "shape_generation" + device = comfy.model_management.intermediate_device() + if shape_generation: + latent = SparseTensor(feats=torch.randn(batch_size, in_channels).to(device), coords=coords) + else: + # coords = shape_slat in txt gen case + latent = coords.replace(feats=torch.randn(coords.coords.shape[0], in_channels - coords.feats.shape[1]).to(device)) + return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + +class Trellis2Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + Trellis2Conditioning, + 
EmptyLatentTrellis2, + VaeDecodeTextureTrellis, + VaeDecodeShapeTrellis + ] + + +async def comfy_entrypoint() -> Trellis2Extension: + return Trellis2Extension() From 23474ce816845341b2a79f9f31d5ff7df1709e85 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 2 Feb 2026 21:20:46 +0200 Subject: [PATCH 02/93] updated the trellis2 nodes --- .../{trellis2.py => nodes_trellis2.py} | 66 ++++++++++++++----- nodes.py | 3 +- 2 files changed, 51 insertions(+), 18 deletions(-) rename comfy_extras/{trellis2.py => nodes_trellis2.py} (82%) diff --git a/comfy_extras/trellis2.py b/comfy_extras/nodes_trellis2.py similarity index 82% rename from comfy_extras/trellis2.py rename to comfy_extras/nodes_trellis2.py index c3ad56007..304d95493 100644 --- a/comfy_extras/trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -193,7 +193,48 @@ class Trellis2Conditioning(IO.ComfyNode): negative = [[conditioning["cond_neg"], {embeds}]] return IO.NodeOutput(positive, negative) -class EmptyLatentTrellis2(IO.ComfyNode): +class EmptyShapeLatentTrellis2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyLatentTrellis2", + category="latent/3d", + inputs=[ + IO.Latent.Input("structure_output"), + ], + outputs=[ + IO.Latent.Output(), + ] + ) + + def execute(cls, structure_output): + # i will see what i have to do here + coords = structure_output or structure_output.coords + in_channels = 32 + latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) + return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + +class EmptyTextureLatentTrellis2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyLatentTrellis2", + category="latent/3d", + inputs=[ + IO.Latent.Input("structure_output"), + ], + outputs=[ + IO.Latent.Output(), + ] + ) + + def execute(cls, structure_output): + # TODO + in_channels = 32 + latent = 
structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1])) + return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + +class EmptyStructureLatentTrellis2(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( @@ -202,35 +243,26 @@ class EmptyLatentTrellis2(IO.ComfyNode): inputs=[ IO.Int.Input("resolution", default=3072, min=1, max=8192), IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), - IO.Vae.Input("vae"), - IO.Boolean.Input("shape_generation", tooltip="Setting to false will generate texture."), - IO.MultiCombo.Input("generation_type", options=["structure_generation", "shape_generation", "texture_generation"]) ], outputs=[ IO.Latent.Output(), ] ) - - @classmethod - def execute(cls, batch_size, coords, vae, generation_type) -> IO.NodeOutput: - # TODO: i will probably update how shape/texture is generated - # could split this too + + def execute(cls, res, batch_size): in_channels = 32 - shape_generation = generation_type == "shape_generation" - device = comfy.model_management.intermediate_device() - if shape_generation: - latent = SparseTensor(feats=torch.randn(batch_size, in_channels).to(device), coords=coords) - else: - # coords = shape_slat in txt gen case - latent = coords.replace(feats=torch.randn(coords.coords.shape[0], in_channels - coords.feats.shape[1]).to(device)) + latent = torch.randn(batch_size, in_channels, res, res, res) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + class Trellis2Extension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ Trellis2Conditioning, - EmptyLatentTrellis2, + EmptyShapeLatentTrellis2, + EmptyStructureLatentTrellis2, + EmptyTextureLatentTrellis2, VaeDecodeTextureTrellis, VaeDecodeShapeTrellis ] diff --git a/nodes.py b/nodes.py index 1cb43d9e2..051e808cc 100644 --- a/nodes.py +++ b/nodes.py @@ -2433,7 +2433,8 @@ async 
def init_builtin_extra_nodes(): "nodes_image_compare.py", "nodes_zimage.py", "nodes_lora_debug.py", - "nodes_color.py" + "nodes_color.py", + "nodes_trellis2.py" ] import_failed = [] From 614b167994aa847fda2b4a718048a2cdbc84c132 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 2 Feb 2026 21:23:19 +0200 Subject: [PATCH 03/93] . --- comfy_extras/nodes_trellis2.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 304d95493..70bbbb29d 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -206,7 +206,8 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): IO.Latent.Output(), ] ) - + + @classmethod def execute(cls, structure_output): # i will see what i have to do here coords = structure_output or structure_output.coords @@ -227,7 +228,8 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): IO.Latent.Output(), ] ) - + + @classmethod def execute(cls, structure_output): # TODO in_channels = 32 @@ -248,7 +250,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): IO.Latent.Output(), ] ) - + @classmethod def execute(cls, res, batch_size): in_channels = 32 latent = torch.randn(batch_size, in_channels, res, res, res) From f76e3a11b5dc4c09f2edb087c013ba7bcc5c7a6b Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 2 Feb 2026 21:27:15 +0200 Subject: [PATCH 04/93] .. 
--- comfy/ldm/trellis2/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index a0889c4dd..cdbfbf6fc 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -1,9 +1,9 @@ import torch import torch.nn.functional as F import torch.nn as nn -from vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor +from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor from typing import Optional, Tuple, Literal, Union, List -from attention import sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention +from comfy.ldm.trellis2.attention import sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder class SparseGELU(nn.GELU): From d6573fd26d63e6fe00515e903d4c2535b02fbeea Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:10:20 +0200 Subject: [PATCH 05/93] model init working --- comfy/latent_formats.py | 2 + comfy/ldm/trellis2/cumesh.py | 387 ++++++++++++++++++++++++++++++++++- comfy/ldm/trellis2/model.py | 20 +- comfy/ldm/trellis2/vae.py | 221 +++++++++++++++++++- comfy/model_base.py | 8 + comfy/supported_models.py | 13 +- 6 files changed, 639 insertions(+), 12 deletions(-) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 4b3a3798c..fc4c4e6d3 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -746,6 +746,8 @@ class Hunyuan3Dv2_1(LatentFormat): latent_channels = 64 latent_dimensions = 1 +class Trellis2(LatentFormat): # TODO + latent_channels = 32 class Hunyuan3Dv2mini(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index be8200341..41ac35db9 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py @@ -1,8 
+1,192 @@ # will contain every cuda -> pytorch operation +import math import torch -from typing import Dict +from typing import Dict, Callable +NO_TRITION = False +try: + import triton + import triton.language as tl + heuristics = { + 'valid_kernel': lambda args: args['valid_kernel'](args['B1']), + 'valid_kernel_seg': lambda args: args['valid_kernel_seg'](args['B1']), + } + + #@triton_autotune( + # configs=config.autotune_config, + # key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'], + #) + @triton.heuristics(heuristics) + @triton.jit + def sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel( + input, + weight, + bias, + neighbor, + sorted_idx, + output, + # Tensor dimensions + N, LOGN, Ci, Co, V: tl.constexpr, + # Meta-parameters + B1: tl.constexpr, # Block size for N dimension + B2: tl.constexpr, # Block size for Co dimension + BK: tl.constexpr, # Block size for K dimension (V * Ci) + allow_tf32: tl.constexpr, # Allow TF32 precision for matmuls + # Huristic parameters + valid_kernel, + valid_kernel_seg, + ): + + block_id = tl.program_id(axis=0) + block_dim_co = tl.cdiv(Co, B2) + block_id_co = block_id % block_dim_co + block_id_n = block_id // block_dim_co + + # Create pointers for submatrices of A and B. + num_k = tl.cdiv(Ci, BK) # Number of blocks in K dimension + valid_kernel_start = tl.load(valid_kernel_seg + block_id_n) + valid_kernel_seglen = tl.load(valid_kernel_seg + block_id_n + 1) - valid_kernel_start + offset_n = block_id_n * B1 + tl.arange(0, B1) + n_mask = offset_n < N + offset_sorted_n = tl.load(sorted_idx + offset_n, mask=n_mask, other=0) # (B1,) + offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co # (B2,) + offset_k = tl.arange(0, BK) # (BK,) + + # Create a block of the output matrix C. + accumulator = tl.zeros((B1, B2), dtype=tl.float32) + + # Iterate along V*Ci dimension. + for k in range(num_k * valid_kernel_seglen): + v = k // num_k + bk = k % num_k + v = tl.load(valid_kernel + valid_kernel_start + v) + # Calculate pointers to input matrix. 
+ neighbor_offset_n = tl.load(neighbor + offset_sorted_n * V + v) # (B1,) + input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :]) # (B1, BK) + # Calculate pointers to weight matrix. + weight_ptr = weight + v * Ci + bk * BK + (offset_co[None, :] * V * Ci + offset_k[:, None]) # (BK, B2) + # Load the next block of input and weight. + neigh_mask = neighbor_offset_n != 0xffffffff + k_mask = offset_k < Ci - bk * BK + input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0) + weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0) + # Accumulate along the K dimension. + accumulator = tl.dot(input_block, weight_block, accumulator, + input_precision='tf32' if allow_tf32 else 'ieee') # (B1, B2) + c = accumulator.to(input.type.element_ty) + + # add bias + if bias is not None: + bias_block = tl.load(bias + offset_co) + c += bias_block[None, :] + + # Write back the block of the output matrix with masks. + out_offset_n = offset_sorted_n + out_offset_co = block_id_co * B2 + tl.arange(0, B2) + out_ptr = output + (out_offset_n[:, None] * Co + out_offset_co[None, :]) + out_mask = n_mask[:, None] & (out_offset_co[None, :] < Co) + tl.store(out_ptr, c, mask=out_mask) + def sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + neighbor: torch.Tensor, + sorted_idx: torch.Tensor, + valid_kernel: Callable[[int], torch.Tensor], + valid_kernel_seg: Callable[[int], torch.Tensor], + ) -> torch.Tensor: + N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1] + LOGN = int(math.log2(N)) + output = torch.empty((N, Co), device=input.device, dtype=input.dtype) + grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),) + sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid]( + input, weight, bias, neighbor, sorted_idx, output, + N, LOGN, Ci, Co, V, # + valid_kernel=valid_kernel, + 
valid_kernel_seg=valid_kernel_seg, + allow_tf32=torch.cuda.is_tf32_supported(), + ) + return output +except: + NO_TRITION = True + +def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device): + # offsets in same order as CUDA kernel + offsets = [] + for vx in range(Kw): + for vy in range(Kh): + for vz in range(Kd): + offsets.append(( + vx * Dw, + vy * Dh, + vz * Dd + )) + return torch.tensor(offsets, device=device) + +def build_submanifold_neighbor_map( + hashmap, + coords: torch.Tensor, + W, H, D, + Kw, Kh, Kd, + Dw, Dh, Dd, +): + device = coords.device + M = coords.shape[0] + V = Kw * Kh * Kd + half_V = V // 2 + 1 + + INVALID = hashmap.default_value + + neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.long) + + b = coords[:, 0] + x = coords[:, 1] + y = coords[:, 2] + z = coords[:, 3] + + offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device) + + ox = x[:, None] - (Kw // 2) * Dw + oy = y[:, None] - (Kh // 2) * Dh + oz = z[:, None] - (Kd // 2) * Dd + + for v in range(half_V): + if v == half_V - 1: + neighbor[:, v] = torch.arange(M, device=device) + continue + + dx, dy, dz = offsets[v] + + kx = ox[:, v] + dx + ky = oy[:, v] + dy + kz = oz[:, v] + dz + + valid = ( + (kx >= 0) & (kx < W) & + (ky >= 0) & (ky < H) & + (kz >= 0) & (kz < D) + ) + + flat = ( + b * (W * H * D) + + kx * (H * D) + + ky * D + + kz + ) + + flat = flat[valid] + idx = torch.nonzero(valid, as_tuple=False).squeeze(1) + + found = hashmap.lookup_flat(flat) + + neighbor[idx, v] = found + + # symmetric write + valid_found = found != INVALID + neighbor[found[valid_found], V - 1 - v] = idx[valid_found] + + return neighbor class TorchHashMap: def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int): @@ -22,6 +206,207 @@ class TorchHashMap: out[found] = self.sorted_vals[idx[found]] return out + +UINT32_SENTINEL = 0xFFFFFFFF + +def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map): + device = neighbor_map.device + N, V = neighbor_map.shape 
+ + + neigh = neighbor_map.to(torch.long) + sentinel = torch.tensor(UINT32_SENTINEL, dtype=torch.long, device=device) + + + neigh_map_T = neigh.t().reshape(-1) + + neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32) + + mask = (neigh != sentinel).to(torch.long) + + powers = (1 << torch.arange(V, dtype=torch.long, device=device)) + + gray_long = (mask * powers).sum(dim=1) + + gray_code = gray_long.to(torch.int32) + + binary_long = gray_long.clone() + for v in range(1, V): + binary_long ^= (gray_long >> v) + binary_code = binary_long.to(torch.int32) + + sorted_idx = torch.argsort(binary_code) + + prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T.to(torch.int32), dim=0) # (V*N,) + + total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0 + + if total_valid_signal > 0: + valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device) + valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device) + + pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0] + + to = (prefix_sum_neighbor_mask[pos] - 1).to(torch.long) + + valid_signal_i[to] = (pos % N).to(torch.long) + + valid_signal_o[to] = neigh_map_T[pos].to(torch.long) + else: + valid_signal_i = torch.empty((0,), dtype=torch.long, device=device) + valid_signal_o = torch.empty((0,), dtype=torch.long, device=device) + + seg = torch.empty((V + 1,), dtype=torch.long, device=device) + seg[0] = 0 + if V > 0: + idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1 + seg[1:] = prefix_sum_neighbor_mask[idxs].to(torch.long) + else: + pass + + return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg + +def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor: + + x = x.to(torch.int64) + + m1 = torch.tensor(0x5555555555555555, dtype=torch.int64, device=x.device) + m2 = torch.tensor(0x3333333333333333, dtype=torch.int64, device=x.device) + m4 = torch.tensor(0x0F0F0F0F0F0F0F0F, dtype=torch.int64, 
device=x.device) + h01 = torch.tensor(0x0101010101010101, dtype=torch.int64, device=x.device) + + x = x - ((x >> 1) & m1) + x = (x & m2) + ((x >> 2) & m2) + x = (x + (x >> 4)) & m4 + x = (x * h01) >> 56 + return x.to(torch.int32) + + +def neighbor_map_post_process_for_masked_implicit_gemm_2( + gray_code: torch.Tensor, # [N], int32-like (non-negative) + sorted_idx: torch.Tensor, # [N], long (indexing into gray_code) + block_size: int +): + device = gray_code.device + N = gray_code.numel() + + # num of blocks (same as CUDA) + num_blocks = (N + block_size - 1) // block_size + + # Ensure dtypes + gray_long = gray_code.to(torch.int64) # safer to OR in 64-bit then cast + sorted_idx = sorted_idx.to(torch.long) + + # 1) Group gray_code by blocks and compute OR across each block + # pad the last block with zeros if necessary so we can reshape + pad = num_blocks * block_size - N + if pad > 0: + pad_vals = torch.zeros((pad,), dtype=torch.int64, device=device) + gray_padded = torch.cat([gray_long[sorted_idx], pad_vals], dim=0) + else: + gray_padded = gray_long[sorted_idx] + + # reshape to (num_blocks, block_size) and compute bitwise_or across dim=1 + gray_blocks = gray_padded.view(num_blocks, block_size) # each row = block entries + # reduce with bitwise_or + reduced_code = gray_blocks[:, 0].clone() + for i in range(1, block_size): + reduced_code |= gray_blocks[:, i] + reduced_code = reduced_code.to(torch.int32) # match CUDA int32 + + # 2) compute seglen (popcount per reduced_code) and seg (prefix sum) + seglen_counts = _popcount_int32_tensor(reduced_code.to(torch.int64)).to(torch.int32) # [num_blocks] + # seg: length num_blocks+1, seg[0]=0, seg[i+1]=cumsum(seglen_counts) up to i + seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device) + seg[0] = 0 + if num_blocks > 0: + seg[1:] = torch.cumsum(seglen_counts, dim=0) + + total = int(seg[-1].item()) + + # 3) scatter — produce valid_kernel_idx as concatenated ascending set-bit positions for each reduced_code row + 
if total == 0: + valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device) + return valid_kernel_idx, seg + + max_val = int(reduced_code.max().item()) + V = max_val.bit_length() if max_val > 0 else 0 + # If you know V externally, pass it instead or set here explicitly. + + if V == 0: + # no bits set anywhere + valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device) + return valid_kernel_idx, seg + + # build mask of shape (num_blocks, V): True where bit is set + bit_pos = torch.arange(0, V, dtype=torch.int64, device=device) # [V] + # shifted = reduced_code[:, None] >> bit_pos[None, :] + shifted = reduced_code.to(torch.int64).unsqueeze(1) >> bit_pos.unsqueeze(0) + bits = (shifted & 1).to(torch.bool) # (num_blocks, V) + + positions = bit_pos.unsqueeze(0).expand(num_blocks, V) + + valid_positions = positions[bits] + valid_kernel_idx = valid_positions.to(torch.int32).contiguous() + + return valid_kernel_idx, seg + + +def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation): + if len(shape) == 5: + N, C, W, H, D = shape + else: + W, H, D = shape + + Co, Kw, Kh, Kd, Ci = weight.shape + + b_stride = W * H * D + x_stride = H * D + y_stride = D + z_stride = 1 + + flat_keys = (coords[:, 0].long() * b_stride + + coords[:, 1].long() * x_stride + + coords[:, 2].long() * y_stride + + coords[:, 3].long() * z_stride) + + vals = torch.arange(coords.shape[0], dtype=torch.int32, device=coords.device) + + hashmap = TorchHashMap(flat_keys, vals, 0xFFFFFFFF) + + if neighbor_cache is None: + neighbor = build_submanifold_neighbor_map( + hashmap, coords, W, H, D, Kw, Kh, Kd, + dilation[0], dilation[1], dilation[2] + ) + else: + neighbor = neighbor_cache + + block_size = 128 + + gray_code, sorted_idx, valid_signal_i, valid_signal_o, valid_signal_seg = \ + neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor) + + valid_kernel, valid_kernel_seg = \ + neighbor_map_post_process_for_masked_implicit_gemm_2(gray_code, 
sorted_idx, block_size) + + valid_kernel_fn = lambda b_size: valid_kernel + valid_kernel_seg_fn = lambda b_size: valid_kernel_seg + + weight_flat = weight.contiguous().view(Co, -1, Ci) + + out = sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk( + feats, + weight_flat, + bias, + neighbor, + sorted_idx, + valid_kernel_fn, + valid_kernel_seg_fn + ) + + return out, neighbor + class Voxel: def __init__( self, diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index cdbfbf6fc..9d1a8fdb4 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -408,7 +408,7 @@ class SLatFlowModel(nn.Module): self.qk_rms_norm_cross = qk_rms_norm_cross self.dtype = dtype - self.t_embedder = TimestepEmbedder(model_channels) + self.t_embedder = TimestepEmbedder(model_channels, device=device, dtype=dtype, operations=operations) if share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), @@ -485,15 +485,25 @@ class Trellis2(nn.Module): qk_rms_norm = True, qk_rms_norm_cross = True, dtype=None, device=None, operations=None): + + super().__init__() args = { "out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels, "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations } # TODO: update the names/checkpoints - self.img2shape = SLatFlowModel(resolution, in_channels=in_channels, *args) - self.shape2txt = SLatFlowModel(resolution, in_channels=in_channels*2, *args) + self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) + self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) self.shape_generation = True - def forward(self, x, timestep, context): - pass + def forward(self, x, timestep, context, **kwargs): + # TODO add mode + mode = kwargs.get("mode", "shape_generation") + mode = 
"texture_generation" if mode == 1 else "shape_generation" + if mode == "shape_generation": + out = self.img2shape(x, timestep, context) + if mode == "texture_generation": + out = self.shape2txt(x, timestep, context) + + return out diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 1d564bca2..5dabf5246 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -1,3 +1,4 @@ +import math import torch import torch.nn as nn from typing import List, Any, Dict, Optional, overload, Union, Tuple @@ -5,12 +6,219 @@ from fractions import Fraction import torch.nn.functional as F from dataclasses import dataclass import numpy as np -from cumesh import TorchHashMap, Mesh, MeshWithVoxel +from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d + + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, padding, bias, indice_key) + + def forward(self, x): + return sparse_conv3d_forward(self, x) + + +def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = tuple(kernel_size) if isinstance(kernel_size, (list, tuple)) else (kernel_size, ) * 3 + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, ) * 3 + self.dilation = tuple(dilation) if isinstance(dilation, (list, tuple)) else (dilation, ) * 3 + + self.weight = nn.Parameter(torch.empty((out_channels, in_channels, *self.kernel_size))) + if bias: + self.bias = nn.Parameter(torch.empty(out_channels)) + else: + self.register_parameter("bias", None) + + if self.bias is not None: + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) + if fan_in != 0: 
+ bound = 1 / math.sqrt(fan_in) + torch.nn.init.uniform_(self.bias, -bound, bound) + + # Permute weight (Co, Ci, Kd, Kh, Kw) -> (Co, Kd, Kh, Kw, Ci) + self.weight = nn.Parameter(self.weight.permute(0, 2, 3, 4, 1).contiguous()) + + +def sparse_conv3d_forward(self, x): + # check if neighbor map is already computed + Co, Kd, Kh, Kw, Ci = self.weight.shape + neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}' + neighbor_cache = x.get_spatial_cache(neighbor_cache_key) + + out, neighbor_cache_ = sparse_submanifold_conv3d( + x.feats, + x.coords, + torch.Size([*x.shape, *x.spatial_shape]), + self.weight, + self.bias, + neighbor_cache, + self.dilation + ) + + if neighbor_cache is None: + x.register_spatial_cache(neighbor_cache_key, neighbor_cache_) + + out = x.replace(out) + return out + +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dtype = x.dtype + x = x.to(torch.float32) + o = super().forward(x) + return o.to(x_dtype) + +class SparseConvNeXtBlock3d(nn.Module): + def __init__( + self, + channels: int, + mlp_ratio: float = 4.0, + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.use_checkpoint = use_checkpoint + + self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.conv = SparseConv3d(channels, channels, 3) + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.SiLU(), + nn.Linear(int(channels * mlp_ratio), channels), + ) + + def _forward(self, x): + h = self.conv(x) + h = h.replace(self.norm(h.feats)) + h = h.replace(self.mlp(h.feats)) + return h + x + + def forward(self, x): + return self._forward(x) + +class SparseSpatial2Channel(nn.Module): + def __init__(self, factor: int = 2): + super(SparseSpatial2Channel, self).__init__() + self.factor = factor + + def forward(self, x): + DIM = x.coords.shape[-1] - 1 + cache = x.get_spatial_cache(f'spatial2channel_{self.factor}') + if cache is None: + coord = 
list(x.coords.unbind(dim=-1)) + for i in range(DIM): + coord[i+1] = coord[i+1] // self.factor + subidx = x.coords[:, 1:] % self.factor + subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)]) + + MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape] + OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] + code = sum([c * o for c, o in zip(coord, OFFSET)]) + code, idx = code.unique(return_inverse=True) + + new_coords = torch.stack( + [code // OFFSET[0]] + + [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], + dim=-1 + ) + else: + new_coords, idx, subidx = cache + + new_feats = torch.zeros(new_coords.shape[0] * self.factor ** DIM, x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype) + new_feats[idx * self.factor ** DIM + subidx] = x.feats + + out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM])) + out._scale = tuple([s * self.factor for s in x._scale]) + out._spatial_cache = x._spatial_cache + + if cache is None: + x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx)) + out.register_spatial_cache(f'channel2spatial_{self.factor}', (x.coords, idx, subidx)) + out.register_spatial_cache('shape', torch.Size(MAX)) + + return out + +class SparseChannel2Spatial(nn.Module): + def __init__(self, factor: int = 2): + super(SparseChannel2Spatial, self).__init__() + self.factor = factor + + def forward(self, x, subdivision = None): + DIM = x.coords.shape[-1] - 1 + + cache = x.get_spatial_cache(f'channel2spatial_{self.factor}') + if cache is None: + if subdivision is None: + raise ValueError('Cache not found. 
Provide subdivision tensor or pair SparseChannel2Spatial with SparseSpatial2Channel.') + else: + sub = subdivision.feats # [N, self.factor ** DIM] + N_leaf = sub.sum(dim=-1) # [N] + subidx = sub.nonzero()[:, -1] + new_coords = x.coords.clone().detach() + new_coords[:, 1:] *= self.factor + new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0]) + for i in range(DIM): + new_coords[:, i+1] += subidx // self.factor ** i % self.factor + idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0]) + else: + new_coords, idx, subidx = cache + + x_feats = x.feats.reshape(x.feats.shape[0] * self.factor ** DIM, -1) + new_feats = x_feats[idx * self.factor ** DIM + subidx] + out = SparseTensor(new_feats, new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] // self.factor ** DIM])) + out._scale = tuple([s / self.factor for s in x._scale]) + if cache is not None: # only keep cache when subdiv following it + out._spatial_cache = x._spatial_cache + return out + +class SparseResBlockC2S3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + use_checkpoint: bool = False, + pred_subdiv: bool = True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_checkpoint = use_checkpoint + self.pred_subdiv = pred_subdiv + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = SparseConv3d(channels, self.out_channels * 8, 3) + self.conv2 = SparseConv3d(self.out_channels, self.out_channels, 3) + self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1)) + if pred_subdiv: + self.to_subdiv = SparseLinear(channels, 8) + self.updown = SparseChannel2Spatial(2) + + def _forward(self, x, subdiv = None): + if 
self.pred_subdiv: + subdiv = self.to_subdiv(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv1(h) + subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None + h = self.updown(h, subdiv_binarized) + x = self.updown(x, subdiv_binarized) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + if self.pred_subdiv: + return h, subdiv + else: + return h -# TODO: determine which conv they actually use @dataclass class config: - CONV = "none" + CONV = "flexgemm" + FLEX_GEMM_HASHMAP_RATIO = 2.0 # TODO post processing def simplify(self, target_num_faces: int, verbose: bool=False, options: dict={}): @@ -1131,6 +1339,7 @@ def flexible_dual_grid_to_mesh( class Vae(nn.Module): def __init__(self, config, operations=None): + super().__init__() operations = operations or torch.nn self.txt_dec = SparseUnetVaeDecoder( @@ -1139,7 +1348,8 @@ class Vae(nn.Module): latent_channels=32, num_blocks=[4, 16, 8, 4, 0], block_type=["SparseConvNeXtBlock3d"] * 5, - up_block_type=["SparseResBlockS2C3d"] * 4, + up_block_type=["SparseResBlockC2S3d"] * 4, + block_args=[{}, {}, {}, {}, {}], pred_subdiv=False ) @@ -1149,7 +1359,8 @@ class Vae(nn.Module): latent_channels=32, num_blocks=[4, 16, 8, 4, 0], block_type=["SparseConvNeXtBlock3d"] * 5, - up_block_type=["SparseResBlockS2C3d"] * 4, + up_block_type=["SparseResBlockC2S3d"] * 4, + block_args=[{}, {}, {}, {}, {}], ) def decode_shape_slat(self, slat, resolution: int): diff --git a/comfy/model_base.py b/comfy/model_base.py index 85acdb66a..a5fc81c4d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -50,6 +50,7 @@ import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model +import comfy.ldm.trellis2.model import comfy.model_management import comfy.patcher_extension @@ -1455,6 +1456,13 @@ class WAN22(WAN21): def scale_latent_inpaint(self, 
sigma, noise, latent_image, **kwargs): return latent_image +class Trellis2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None, unet_model=comfy.ldm.trellis2.model.Trellis2): + super().__init__(model_config, model_type, device, unet_model) + + def extra_conds(self, **kwargs): + return super().extra_conds(**kwargs) + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index d25271d6e..6c2725b9f 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1242,6 +1242,17 @@ class WAN22_T2V(WAN21_T2V): out = model_base.WAN22(self, image_to_video=True, device=device) return out +class Trellis2(supported_models_base.BASE): + unet_config = { + "image_model": "trellis2" + } + + latent_format = latent_formats.Trellis2 + vae_key_prefix = ["vae."] + + def get_model(self, state_dict, prefix="", device=None): + return model_base.Trellis2(self, device=device) + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1596,6 +1607,6 @@ class Kandinsky5Image(Kandinsky5): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, 
CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, Trellis2] models += [SVD_img2vid] From 66249395056ab311f4459fdec138fc49e080702d Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 3 Feb 2026 22:40:54 +0200 Subject: [PATCH 06/93] structure model --- comfy/ldm/trellis2/attention.py | 67 +++++++ comfy/ldm/trellis2/model.py | 316 +++++++++++++++++++++++++++++++- comfy/supported_models.py | 4 + 3 files changed, 382 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 9cd7d4995..6c912c8d9 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -4,6 +4,73 @@ from comfy.ldm.modules.attention import optimized_attention from typing import Tuple, Union, List from vae import VarLenTensor +FLASH_ATTN_3_AVA = 
True +try: + import flash_attn_interface as flash_attn_3 +except: + FLASH_ATTN_3_AVA = False + +# TODO repalce with optimized attention +def scaled_dot_product_attention(*args, **kwargs): + num_all_args = len(args) + len(kwargs) + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + + if optimized_attention.__name__ == 'attention_xformers': + if 'xops' not in globals(): + import xformers.ops as xops + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = xops.memory_efficient_attention(q, k, v) + elif optimized_attention.__name__ == 'attention_flash' and not FLASH_ATTN_3_AVA: + if 'flash_attn' not in globals(): + import flash_attn + if num_all_args == 2: + out = flash_attn.flash_attn_kvpacked_func(q, kv) + elif num_all_args == 3: + out = flash_attn.flash_attn_func(q, k, v) + elif optimized_attention.__name__ == 'attention_flash': # TODO + if 'flash_attn_3' not in globals(): + import flash_attn_interface as flash_attn_3 + if num_all_args == 2: + k, v = kv.unbind(dim=2) + out = flash_attn_3.flash_attn_func(q, k, v) + elif num_all_args == 3: + out = flash_attn_3.flash_attn_func(q, k, v) + elif optimized_attention.__name__ == 'attention_pytorch': + if 'sdpa' not in globals(): + from torch.nn.functional import scaled_dot_product_attention as sdpa + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + out = sdpa(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + elif optimized_attention.__name__ 
== 'attention_basic': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.shape[2] # TODO + out = optimized_attention(q, k, v) + + return out + def sparse_windowed_scaled_dot_product_self_attention( qkv, window_size: int, diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 9d1a8fdb4..add3f21ce 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -3,7 +3,9 @@ import torch.nn.functional as F import torch.nn as nn from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor from typing import Optional, Tuple, Literal, Union, List -from comfy.ldm.trellis2.attention import sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention +from comfy.ldm.trellis2.attention import ( + sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention +) from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder class SparseGELU(nn.GELU): @@ -103,6 +105,18 @@ class SparseRotaryPositionEmbedder(nn.Module): k_embed = k.replace(self._rotary_embedding(k.feats, phases)) return q_embed, k_embed +class RotaryPositionEmbedder(SparseRotaryPositionEmbedder): + def forward(self, indices: torch.Tensor) -> torch.Tensor: + assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}" + phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) + if phases.shape[-1] < self.head_dim // 2: + padn = self.head_dim // 2 - phases.shape[-1] + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], padn, device=phases.device), + torch.zeros(*phases.shape[:-1], padn, device=phases.device) + )], dim=-1) + return phases + class SparseMultiHeadAttention(nn.Module): def __init__( self, @@ -472,6 +486,292 @@ class SLatFlowModel(nn.Module): h = self.out_layer(h) return h +class FeedForwardNet(nn.Module): + def __init__(self, channels: 
int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.GELU(approximate="tanh"), + nn.Linear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + +class MultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int]=None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + ): + super().__init__() + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> 
torch.Tensor: + B, L, C = x.shape + if self._type == "self": + qkv = self.to_qkv(x) + qkv = qkv.reshape(B, L, 3, self.num_heads, -1) + + if self.attn_mode == "full": + if self.qk_rms_norm or self.use_rope: + q, k, v = qkv.unbind(dim=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + if self.use_rope: + assert phases is not None, "Phases must be provided for RoPE" + q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases) + k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(qkv) + else: + Lkv = context.shape[1] + q = self.to_q(x) + kv = self.to_kv(context) + q = q.reshape(B, L, self.num_heads, -1) + kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=2) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(q, kv) + h = h.reshape(B, L, -1) + h = self.to_out(h) + return h + +class ModulatedTransformerCrossBlock(nn.Module): + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + 
window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.self_attn(h, phases=phases) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, phases, use_reentrant=False) + else: + return self._forward(x, mod, context, phases) + + +class SparseStructureFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: 
Optional[int] = 64, + mlp_ratio: float = 4, + pe_mode: Literal["ape", "rope"] = "ape", + rope_freq: Tuple[float, float] = (1.0, 10000.0), + dtype: str = 'float32', + use_checkpoint: bool = False, + share_mod: bool = False, + initialization: str = 'vanilla', + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + **kwargs + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.pe_mode = pe_mode + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.initialization = initialization + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = dtype + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3) + coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij') + coords = torch.stack(coords, dim=-1).reshape(-1, 3) + rope_phases = pos_embedder(coords) + self.register_buffer("rope_phases", rope_phases) + + if pe_mode != "rope": + self.rope_phases = None + + self.input_layer = nn.Linear(in_channels, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + rope_freq=rope_freq, + share_mod=share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_layer = nn.Linear(model_channels, 
out_channels) + + self.initialize_weights() + self.convert_to(self.dtype) + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ + f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" + + h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous() + + h = self.input_layer(h) + if self.pe_mode == "ape": + h = h + self.pos_emb[None] + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = manual_cast(t_emb, self.dtype) + h = manual_cast(h, self.dtype) + cond = manual_cast(cond, self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond, self.rope_phases) + h = manual_cast(h, x.dtype) + h = F.layer_norm(h, h.shape[-1:]) + h = self.out_layer(h) + + h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous() + + return h + class Trellis2(nn.Module): def __init__(self, resolution, in_channels = 32, @@ -492,18 +792,24 @@ class Trellis2(nn.Module): "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations } - # TODO: update the names/checkpoints self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) - self.shape_generation = True + args.pop("out_channels") + args.pop("in_channels") + self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) def forward(self, x, timestep, context, **kwargs): # TODO add mode mode = kwargs.get("mode", "shape_generation") - mode = "texture_generation" if mode == 1 else "shape_generation" + if mode != 0: + mode = "texture_generation" if mode == 2 else 
"shape_generation" + else: + mode = "structure_generation" if mode == "shape_generation": out = self.img2shape(x, timestep, context) - if mode == "texture_generation": + elif mode == "texture_generation": out = self.shape2txt(x, timestep, context) + else: + out = self.structure_model(x, timestep, context) return out diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 6c2725b9f..9e2f17149 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1247,6 +1247,10 @@ class Trellis2(supported_models_base.BASE): "image_model": "trellis2" } + sampling_settings = { + "shift": 3.0, + } + latent_format = latent_formats.Trellis2 vae_key_prefix = ["vae."] From 3002708fe390b1dd8b75a725cb5ae14ce8ceb7a3 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 4 Feb 2026 14:15:00 +0200 Subject: [PATCH 07/93] needed updates --- comfy/ldm/trellis2/model.py | 29 +++++-- comfy/ldm/trellis2/vae.py | 148 +++++++++++++++++++++++++++++++++ comfy/model_base.py | 5 +- comfy_extras/nodes_trellis2.py | 46 +++++++--- 4 files changed, 209 insertions(+), 19 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index add3f21ce..1dbbc4955 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -7,6 +7,7 @@ from comfy.ldm.trellis2.attention import ( sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention ) from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder +from comfy.nested_tensor import NestedTensor class SparseGELU(nn.GELU): def forward(self, input: VarLenTensor) -> VarLenTensor: @@ -772,6 +773,11 @@ class SparseStructureFlowModel(nn.Module): return h +def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0): + t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1)) + t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear) + return t_new + class 
Trellis2(nn.Module): def __init__(self, resolution, in_channels = 32, @@ -798,18 +804,25 @@ class Trellis2(nn.Module): args.pop("in_channels") self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) - def forward(self, x, timestep, context, **kwargs): - # TODO add mode - mode = kwargs.get("mode", "shape_generation") - if mode != 0: - mode = "texture_generation" if mode == 2 else "shape_generation" - else: + def forward(self, x: NestedTensor, timestep, context, **kwargs): + x = x.tensors[0] + embeds = kwargs.get("embeds") + if not hasattr(x, "feats"): mode = "structure_generation" + else: + if x.feats.shape[1] == 32: + mode = "shape_generation" + else: + mode = "texture_generation" if mode == "shape_generation": - out = self.img2shape(x, timestep, context) + # TODO + out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)])) elif mode == "texture_generation": out = self.shape2txt(x, timestep, context) - else: + else: # structure + timestep = timestep_reshift(timestep) out = self.structure_model(x, timestep, context) + out = NestedTensor([out]) + out.generation_mode = mode return out diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 5dabf5246..584fa91ae 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -9,6 +9,17 @@ import numpy as np from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d +def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: + """ + 3D pixel shuffle. 
+ """ + B, C, H, W, D = x.shape + C_ = C // scale_factor**3 + x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) + x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) + return x + class SparseConv3d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): super(SparseConv3d, self).__init__() @@ -1337,6 +1348,135 @@ def flexible_dual_grid_to_mesh( return mesh_vertices, mesh_triangles +class ChannelLayerNorm32(LayerNorm32): + def forward(self, x: torch.Tensor) -> torch.Tensor: + DIM = x.dim() + x = x.permute(0, *range(2, DIM), 1).contiguous() + x = super().forward(x) + x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() + return x + +class UpsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode = "conv", + ): + assert mode in ["conv", "nearest"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1) + elif mode == "nearest": + assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + x = self.conv(x) + return pixel_shuffle_3d(x, 2) + else: + return F.interpolate(x, scale_factor=2, mode="nearest") + +def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module: + return ChannelLayerNorm32(*args, **kwargs) + +class ResBlock3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + norm_type = "layer", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.norm1 = norm_layer(norm_type, channels) + self.norm2 = norm_layer(norm_type, self.out_channels) + self.conv1 = nn.Conv3d(channels, self.out_channels, 3, 
padding=1) + self.conv2 = nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1) + self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = F.silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + +class SparseStructureDecoder(nn.Module): + def __init__( + self, + out_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type = "layer", + use_fp16: bool = False, + ): + super().__init__() + self.out_channels = out_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[0], channels[0]) + for _ in range(num_res_blocks_middle) + ]) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + UpsampleBlock3d(ch, channels[i+1]) + ) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], out_channels, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.input_layer(x) + + h = h.type(self.dtype) + + h = self.middle_block(h) + for block in self.blocks: + h = block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + return h + class Vae(nn.Module): def 
__init__(self, config, operations=None): super().__init__() @@ -1363,6 +1503,14 @@ class Vae(nn.Module): block_args=[{}, {}, {}, {}, {}], ) + self.struct_dec = SparseStructureDecoder( + out_channels=1, + latent_channels=8, + num_res_blocks=2, + num_res_blocks_middle=2, + channels=[512, 128, 32], + ) + def decode_shape_slat(self, slat, resolution: int): self.shape_dec.set_resolution(resolution) return self.shape_dec(slat, return_subs=True) diff --git a/comfy/model_base.py b/comfy/model_base.py index a5fc81c4d..6bf4fadc9 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1461,7 +1461,10 @@ class Trellis2(BaseModel): super().__init__(model_config, model_type, device, unet_model) def extra_conds(self, **kwargs): - return super().extra_conds(**kwargs) + out = super().extra_conds(**kwargs) + embeds = kwargs.get("embeds") + out["embeds"] = comfy.conds.CONDRegular(embeds) + return out class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 70bbbb29d..4a36e2fee 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -6,6 +6,7 @@ import comfy.model_management from PIL import Image import PIL import numpy as np +from comfy.nested_tensor import NestedTensor shape_slat_normalization = { "mean": torch.tensor([ @@ -131,7 +132,8 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): ) @classmethod - def execute(cls, samples, vae, resolution): + def execute(cls, samples: NestedTensor, vae, resolution): + samples = samples.tensors[0] std = shape_slat_normalization["std"] mean = shape_slat_normalization["mean"] samples = samples * std + mean @@ -157,9 +159,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): @classmethod def execute(cls, samples, vae, shape_subs): - if shape_subs is None: - raise ValueError("Shape subs must be provided for texture generation") - + samples = samples.tensors[0] std = tex_slat_normalization["std"] 
mean = tex_slat_normalization["mean"] samples = samples * std + mean @@ -167,6 +167,28 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): mesh = vae.decode_tex_slat(samples, shape_subs) return mesh +class VaeDecodeStructureTrellis2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VaeDecodeStructureTrellis2", + category="latent/3d", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + ], + outputs=[ + IO.Mesh.Output("structure_output"), + ] + ) + + @classmethod + def execute(cls, samples, vae): + decoder = vae.struct_dec + decoded = decoder(samples)>0 + coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int() + return coords + class Trellis2Conditioning(IO.ComfyNode): @classmethod def define_schema(cls): @@ -189,8 +211,8 @@ class Trellis2Conditioning(IO.ComfyNode): # could make 1024 an option conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color) embeds = conditioning["cond_1024"] # should add that - positive = [[conditioning["cond_512"], {embeds}]] - negative = [[conditioning["cond_neg"], {embeds}]] + positive = [[conditioning["cond_512"], {"embeds": embeds}]] + negative = [[conditioning["cond_neg"], {"embeds": embeds}]] return IO.NodeOutput(positive, negative) class EmptyShapeLatentTrellis2(IO.ComfyNode): @@ -200,7 +222,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): node_id="EmptyLatentTrellis2", category="latent/3d", inputs=[ - IO.Latent.Input("structure_output"), + IO.Mesh.Input("structure_output"), ], outputs=[ IO.Latent.Output(), @@ -210,9 +232,10 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_output): # i will see what i have to do here - coords = structure_output or structure_output.coords + coords = structure_output # or structure_output.coords in_channels = 32 latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) + latent = NestedTensor([latent]) return 
IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyTextureLatentTrellis2(IO.ComfyNode): @@ -222,7 +245,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): node_id="EmptyLatentTrellis2", category="latent/3d", inputs=[ - IO.Latent.Input("structure_output"), + IO.Mesh.Input("structure_output"), ], outputs=[ IO.Latent.Output(), @@ -234,6 +257,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): # TODO in_channels = 32 latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1])) + latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyStructureLatentTrellis2(IO.ComfyNode): @@ -254,6 +278,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): def execute(cls, res, batch_size): in_channels = 32 latent = torch.randn(batch_size, in_channels, res, res, res) + latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) @@ -266,7 +291,8 @@ class Trellis2Extension(ComfyExtension): EmptyStructureLatentTrellis2, EmptyTextureLatentTrellis2, VaeDecodeTextureTrellis, - VaeDecodeShapeTrellis + VaeDecodeShapeTrellis, + VaeDecodeStructureTrellis2 ] From cd0f7ba64e6d0a195fdc53132cc58b053569dacb Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 5 Feb 2026 02:34:08 +0200 Subject: [PATCH 08/93] apply rope and optimized attention --- comfy/ldm/trellis2/attention.py | 29 +++++++------ comfy/ldm/trellis2/model.py | 76 ++++++++++++++++----------------- 2 files changed, 52 insertions(+), 53 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 6c912c8d9..edc85ce83 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -26,21 +26,21 @@ def scaled_dot_product_attention(*args, **kwargs): k = args[1] if len(args) > 1 else kwargs['k'] v = args[2] if len(args) > 2 else kwargs['v'] + # TODO verify + heads 
= q or qkv + heads = heads.shape[2] + if optimized_attention.__name__ == 'attention_xformers': - if 'xops' not in globals(): - import xformers.ops as xops if num_all_args == 1: q, k, v = qkv.unbind(dim=2) elif num_all_args == 2: k, v = kv.unbind(dim=2) - out = xops.memory_efficient_attention(q, k, v) + #out = xops.memory_efficient_attention(q, k, v) + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) elif optimized_attention.__name__ == 'attention_flash' and not FLASH_ATTN_3_AVA: - if 'flash_attn' not in globals(): - import flash_attn if num_all_args == 2: - out = flash_attn.flash_attn_kvpacked_func(q, kv) - elif num_all_args == 3: - out = flash_attn.flash_attn_func(q, k, v) + k, v = kv.unbind(dim=2) + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) elif optimized_attention.__name__ == 'attention_flash': # TODO if 'flash_attn_3' not in globals(): import flash_attn_interface as flash_attn_3 @@ -59,15 +59,14 @@ def scaled_dot_product_attention(*args, **kwargs): q = q.permute(0, 2, 1, 3) # [N, H, L, C] k = k.permute(0, 2, 1, 3) # [N, H, L, C] v = v.permute(0, 2, 1, 3) # [N, H, L, C] - out = sdpa(q, k, v) # [N, H, L, C] + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) out = out.permute(0, 2, 1, 3) # [N, L, H, C] elif optimized_attention.__name__ == 'attention_basic': if num_all_args == 1: q, k, v = qkv.unbind(dim=2) elif num_all_args == 2: k, v = kv.unbind(dim=2) - q = q.shape[2] # TODO - out = optimized_attention(q, k, v) + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) return out @@ -86,19 +85,21 @@ def sparse_windowed_scaled_dot_product_self_attention( fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + heads = qkv_feats.shape[2] if optimized_attention.__name__ == 'attention_xformers': - if 'xops' not in globals(): - import 
xformers.ops as xops q, k, v = qkv_feats.unbind(dim=1) q = q.unsqueeze(0) # [1, M, H, C] k = k.unsqueeze(0) # [1, M, H, C] v = v.unsqueeze(0) # [1, M, H, C] - out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C] + #out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C] + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) elif optimized_attention.__name__ == 'attention_flash': if 'flash_attn' not in globals(): import flash_attn out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C] + else: + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) out = out[bwd_indices] # [T, H, C] diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 1dbbc4955..484622d76 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -8,6 +8,7 @@ from comfy.ldm.trellis2.attention import ( ) from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder from comfy.nested_tensor import NestedTensor +from comfy.ldm.flux.math import apply_rope, apply_rope1 class SparseGELU(nn.GELU): def forward(self, input: VarLenTensor) -> VarLenTensor: @@ -52,7 +53,6 @@ class SparseMultiHeadRMSNorm(nn.Module): x = F.normalize(x, dim=-1) * self.gamma * self.scale return x.to(x_type) -# TODO: replace with apply_rope1 class SparseRotaryPositionEmbedder(nn.Module): def __init__( self, @@ -61,7 +61,6 @@ class SparseRotaryPositionEmbedder(nn.Module): rope_freq: Tuple[float, float] = (1.0, 10000.0) ): super().__init__() - assert head_dim % 2 == 0, "Head dim must be divisible by 2" self.head_dim = head_dim self.dim = dim self.rope_freq = rope_freq @@ -69,46 +68,48 @@ class SparseRotaryPositionEmbedder(nn.Module): self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) - def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: - 
self.freqs = self.freqs.to(indices.device) - phases = torch.outer(indices, self.freqs) - phases = torch.polar(torch.ones_like(phases), phases) - return phases + def _get_freqs_cis(self, coords: torch.Tensor) -> torch.Tensor: + phases_list = [] + for i in range(self.dim): + phases_list.append(torch.outer(coords[..., i], self.freqs.to(coords.device))) - def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: - x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) - x_rotated = x_complex * phases.unsqueeze(-2) - x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) - return x_embed + phases = torch.cat(phases_list, dim=-1) + + if phases.shape[-1] < self.head_dim // 2: + padn = self.head_dim // 2 - phases.shape[-1] + phases = torch.cat([phases, torch.zeros(*phases.shape[:-1], padn, device=phases.device)], dim=-1) + + cos = torch.cos(phases) + sin = torch.sin(phases) + + f_cis_0 = torch.stack([cos, sin], dim=-1) + f_cis_1 = torch.stack([-sin, cos], dim=-1) + freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1) + + return freqs_cis + + def forward(self, q, k=None): + cache_name = f'rope_cis_{self.dim}d_f{self.rope_freq[1]}_hd{self.head_dim}' + freqs_cis = q.get_spatial_cache(cache_name) + + if freqs_cis is None: + coords = q.coords[..., 1:].to(torch.float32) + freqs_cis = self._get_freqs_cis(coords) + q.register_spatial_cache(cache_name, freqs_cis) + + if q.feats.ndim == 3: + f_cis = freqs_cis.unsqueeze(1) + else: + f_cis = freqs_cis - def forward(self, q: SparseTensor, k: Optional[SparseTensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - q (SparseTensor): [..., N, H, D] tensor of queries - k (SparseTensor): [..., N, H, D] tensor of keys - """ - assert q.coords.shape[-1] == self.dim + 1, "Last dimension of coords must be equal to dim+1" - phases_cache_name = f'rope_phase_{self.dim}d_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}' - phases = 
q.get_spatial_cache(phases_cache_name) - if phases is None: - coords = q.coords[..., 1:] - phases = self._get_phases(coords.reshape(-1)).reshape(*coords.shape[:-1], -1) - if phases.shape[-1] < self.head_dim // 2: - padn = self.head_dim // 2 - phases.shape[-1] - phases = torch.cat([phases, torch.polar( - torch.ones(*phases.shape[:-1], padn, device=phases.device), - torch.zeros(*phases.shape[:-1], padn, device=phases.device) - )], dim=-1) - q.register_spatial_cache(phases_cache_name, phases) - q_embed = q.replace(self._rotary_embedding(q.feats, phases)) if k is None: - return q_embed - k_embed = k.replace(self._rotary_embedding(k.feats, phases)) - return q_embed, k_embed + return q.replace(apply_rope1(q.feats, f_cis)) + + q_feats, k_feats = apply_rope(q.feats, k.feats, f_cis) + return q.replace(q_feats), k.replace(k_feats) class RotaryPositionEmbedder(SparseRotaryPositionEmbedder): def forward(self, indices: torch.Tensor) -> torch.Tensor: - assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}" phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) if phases.shape[-1] < self.head_dim // 2: padn = self.head_dim // 2 - phases.shape[-1] @@ -228,9 +229,6 @@ class SparseMultiHeadAttention(nn.Module): return h class ModulatedSparseTransformerBlock(nn.Module): - """ - Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. - """ def __init__( self, channels: int, From f2c0320fe84e533ad1e0173e0eb25c934027b216 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 5 Feb 2026 17:19:57 +0200 Subject: [PATCH 09/93] fixes to vae and cumesh impl. 
--- comfy/ldm/trellis2/cumesh.py | 189 ++++++++++++++++++++++++++--------- comfy/ldm/trellis2/vae.py | 2 +- 2 files changed, 143 insertions(+), 48 deletions(-) diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index 41ac35db9..fe7e80e15 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py @@ -5,6 +5,10 @@ import torch from typing import Dict, Callable NO_TRITION = False +try: + allow_tf32 = torch.cuda.is_tf32_supported +except Exception: + allow_tf32 = False try: import triton import triton.language as tl @@ -102,10 +106,13 @@ try: grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),) sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid]( input, weight, bias, neighbor, sorted_idx, output, - N, LOGN, Ci, Co, V, # + N, LOGN, Ci, Co, V, + B1=128, + B2=64, + BK=32, valid_kernel=valid_kernel, valid_kernel_seg=valid_kernel_seg, - allow_tf32=torch.cuda.is_tf32_supported(), + allow_tf32=allow_tf32, ) return output except: @@ -140,16 +147,16 @@ def build_submanifold_neighbor_map( neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.long) - b = coords[:, 0] - x = coords[:, 1] - y = coords[:, 2] - z = coords[:, 3] + b = coords[:, 0].long() + x = coords[:, 1].long() + y = coords[:, 2].long() + z = coords[:, 3].long() offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device) - ox = x[:, None] - (Kw // 2) * Dw - oy = y[:, None] - (Kh // 2) * Dh - oz = z[:, None] - (Kd // 2) * Dd + ox = x - (Kw // 2) * Dw + oy = y - (Kh // 2) * Dh + oz = z - (Kd // 2) * Dd for v in range(half_V): if v == half_V - 1: @@ -158,10 +165,11 @@ def build_submanifold_neighbor_map( dx, dy, dz = offsets[v] - kx = ox[:, v] + dx - ky = oy[:, v] + dy - kz = oz[:, v] + dz + kx = ox + dx + ky = oy + dy + kz = oz + dz + # Check spatial bounds valid = ( (kx >= 0) & (kx < W) & (ky >= 0) & (ky < H) & @@ -169,22 +177,22 @@ def build_submanifold_neighbor_map( ) flat = ( - b * (W * H * D) + - kx * (H * D) + - ky * D + 
- kz + b[valid] * (W * H * D) + + kx[valid] * (H * D) + + ky[valid] * D + + kz[valid] ) - flat = flat[valid] - idx = torch.nonzero(valid, as_tuple=False).squeeze(1) + if flat.numel() > 0: + found = hashmap.lookup_flat(flat) + idx_in_M = torch.where(valid)[0] + neighbor[idx_in_M, v] = found - found = hashmap.lookup_flat(flat) - - neighbor[idx, v] = found - - # symmetric write - valid_found = found != INVALID - neighbor[found[valid_found], V - 1 - v] = idx[valid_found] + valid_found_mask = (found != INVALID) + if valid_found_mask.any(): + src_points = idx_in_M[valid_found_mask] + dst_points = found[valid_found_mask] + neighbor[dst_points, V - 1 - v] = src_points return neighbor @@ -461,31 +469,118 @@ class Mesh: def cpu(self): return self.to('cpu') - # TODO could be an option + # could make this into a new node def fill_holes(self, max_hole_perimeter=3e-2): - import cumesh - vertices = self.vertices.cuda() - faces = self.faces.cuda() - mesh = cumesh.CuMesh() - mesh.init(vertices, faces) - mesh.get_edges() - mesh.get_boundary_info() - if mesh.num_boundaries == 0: - return - mesh.get_vertex_edge_adjacency() - mesh.get_vertex_boundary_adjacency() - mesh.get_manifold_boundary_adjacency() - mesh.read_manifold_boundary_adjacency() - mesh.get_boundary_connected_components() - mesh.get_boundary_loops() - if mesh.num_boundary_loops == 0: - return - mesh.fill_holes(max_hole_perimeter=max_hole_perimeter) - new_vertices, new_faces = mesh.read() + device = self.vertices.device + vertices = self.vertices + faces = self.faces + + edges = torch.cat([ + faces[:, [0, 1]], + faces[:, [1, 2]], + faces[:, [2, 0]] + ], dim=0) + + edges_sorted, _ = torch.sort(edges, dim=1) + + unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True) + + boundary_mask = counts == 1 + boundary_edges_sorted = unique_edges[boundary_mask] + + if boundary_edges_sorted.shape[0] == 0: + return + max_idx = vertices.shape[0] + + _, inverse_indices, counts_packed = torch.unique( + 
torch.sort(edges, dim=1).values[:, 0] * max_idx + torch.sort(edges, dim=1).values[:, 1], + return_inverse=True, return_counts=True + ) + + boundary_packed_mask = counts_packed == 1 + is_boundary_edge = boundary_packed_mask[inverse_indices] + + active_boundary_edges = edges[is_boundary_edge] + + adj = {} + edges_np = active_boundary_edges.cpu().numpy() + for u, v in edges_np: + adj[u] = v + + loops = [] + visited_edges = set() + + possible_starts = list(adj.keys()) + + processed_nodes = set() + + for start_node in possible_starts: + if start_node in processed_nodes: + continue + + current_loop = [] + curr = start_node + + while curr in adj: + next_node = adj[curr] + if (curr, next_node) in visited_edges: + break + + visited_edges.add((curr, next_node)) + processed_nodes.add(curr) + current_loop.append(curr) + + curr = next_node + + if curr == start_node: + loops.append(current_loop) + break + + if len(current_loop) > len(edges_np): + break + + if not loops: + return + + new_faces = [] + + v_offset = vertices.shape[0] + + valid_new_verts = [] + + for loop_indices in loops: + if len(loop_indices) < 3: + continue + + loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device) + loop_verts = vertices[loop_tensor] + + diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0) + perimeter = torch.norm(diffs, dim=1).sum() + + if perimeter > max_hole_perimeter: + continue + + center = loop_verts.mean(dim=0) + valid_new_verts.append(center) + + c_idx = v_offset + v_offset += 1 + + num_v = len(loop_indices) + for i in range(num_v): + v_curr = loop_indices[i] + v_next = loop_indices[(i + 1) % num_v] + new_faces.append([v_curr, v_next, c_idx]) + + if len(valid_new_verts) > 0: + added_vertices = torch.stack(valid_new_verts, dim=0) + added_faces = torch.tensor(new_faces, dtype=torch.long, device=device) + + self.vertices = torch.cat([self.vertices, added_vertices], dim=0) + self.faces = torch.cat([self.faces, added_faces], dim=0) - self.vertices = 
new_vertices.to(self.device) - self.faces = new_faces.to(self.device) # TODO could be an option def simplify(self, target=1000000, verbose: bool=False, options: dict={}): diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 584fa91ae..2bbfa938c 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -208,7 +208,7 @@ class SparseResBlockC2S3d(nn.Module): self.to_subdiv = SparseLinear(channels, 8) self.updown = SparseChannel2Spatial(2) - def _forward(self, x, subdiv = None): + def forward(self, x, subdiv = None): if self.pred_subdiv: subdiv = self.to_subdiv(x) h = x.replace(self.norm1(x.feats)) From cdd7ced1e8c2353538506e66d0cd306c9b462c3a Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 6 Feb 2026 19:28:49 +0200 Subject: [PATCH 10/93] model fixes --- comfy/ldm/trellis2/model.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 484622d76..2367fc42c 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -26,10 +26,9 @@ class SparseFeedForwardNet(nn.Module): def forward(self, x: VarLenTensor) -> VarLenTensor: return self.mlp(x) -def manual_cast(tensor, dtype): - if not torch.is_autocast_enabled(): - return tensor.type(dtype) - return tensor +def manual_cast(obj, dtype): + return obj.to(dtype=dtype) + class LayerNorm32(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: x_dtype = x.dtype @@ -88,6 +87,12 @@ class SparseRotaryPositionEmbedder(nn.Module): return freqs_cis + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + def forward(self, q, k=None): cache_name = f'rope_cis_{self.dim}d_f{self.rope_freq[1]}_hd{self.head_dim}' freqs_cis = 
q.get_spatial_cache(cache_name) @@ -111,11 +116,15 @@ class SparseRotaryPositionEmbedder(nn.Module): class RotaryPositionEmbedder(SparseRotaryPositionEmbedder): def forward(self, indices: torch.Tensor) -> torch.Tensor: phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) + if torch.is_complex(phases): + phases = phases.to(torch.complex64) + else: + phases = phases.to(torch.float32) if phases.shape[-1] < self.head_dim // 2: padn = self.head_dim // 2 - phases.shape[-1] phases = torch.cat([phases, torch.polar( - torch.ones(*phases.shape[:-1], padn, device=phases.device), - torch.zeros(*phases.shape[:-1], padn, device=phases.device) + torch.ones(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32), + torch.zeros(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32) )], dim=-1) return phases @@ -468,7 +477,7 @@ class SLatFlowModel(nn.Module): h = self.input_layer(x) h = manual_cast(h, self.dtype) - t_emb = self.t_embedder(t) + t_emb = self.t_embedder(t, out_dtype = t.dtype) if self.share_mod: t_emb = self.adaLN_modulation(t_emb) t_emb = manual_cast(t_emb, self.dtype) @@ -687,9 +696,12 @@ class SparseStructureFlowModel(nn.Module): initialization: str = 'vanilla', qk_rms_norm: bool = False, qk_rms_norm_cross: bool = False, + operations=None, + device = None, **kwargs ): super().__init__() + self.device = device self.resolution = resolution self.in_channels = in_channels self.model_channels = model_channels @@ -706,7 +718,7 @@ class SparseStructureFlowModel(nn.Module): self.qk_rms_norm_cross = qk_rms_norm_cross self.dtype = dtype - self.t_embedder = TimestepEmbedder(model_channels) + self.t_embedder = TimestepEmbedder(model_channels, operations=operations) if share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), @@ -743,9 +755,6 @@ class SparseStructureFlowModel(nn.Module): self.out_layer = nn.Linear(model_channels, out_channels) - self.initialize_weights() - self.convert_to(self.dtype) - def forward(self, x: 
torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" @@ -755,7 +764,7 @@ class SparseStructureFlowModel(nn.Module): h = self.input_layer(h) if self.pe_mode == "ape": h = h + self.pos_emb[None] - t_emb = self.t_embedder(t) + t_emb = self.t_embedder(t, out_dtype = t.dtype) if self.share_mod: t_emb = self.adaLN_modulation(t_emb) t_emb = manual_cast(t_emb, self.dtype) @@ -799,7 +808,6 @@ class Trellis2(nn.Module): self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) args.pop("out_channels") - args.pop("in_channels") self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) def forward(self, x: NestedTensor, timestep, context, **kwargs): From 64a52f5585d46a3e804567640da2a9627f48257e Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 6 Feb 2026 20:35:33 +0200 Subject: [PATCH 11/93] checkpoint --- comfy/ldm/trellis2/model.py | 4 +++- comfy/ldm/trellis2/vae.py | 23 ++++++++++++----------- comfy/model_detection.py | 14 ++++++++++++++ comfy/sd.py | 7 +++++++ 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 2367fc42c..8ca112b13 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -797,6 +797,7 @@ class Trellis2(nn.Module): share_mod = True, qk_rms_norm = True, qk_rms_norm_cross = True, + init_txt_model=False, # for now dtype=None, device=None, operations=None): super().__init__() @@ -806,7 +807,8 @@ class Trellis2(nn.Module): "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations } 
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) - self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) + if init_txt_model: + self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) args.pop("out_channels") self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 2bbfa938c..d997bbc41 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -1481,17 +1481,18 @@ class Vae(nn.Module): def __init__(self, config, operations=None): super().__init__() operations = operations or torch.nn - - self.txt_dec = SparseUnetVaeDecoder( - out_channels=6, - model_channels=[1024, 512, 256, 128, 64], - latent_channels=32, - num_blocks=[4, 16, 8, 4, 0], - block_type=["SparseConvNeXtBlock3d"] * 5, - up_block_type=["SparseResBlockC2S3d"] * 4, - block_args=[{}, {}, {}, {}, {}], - pred_subdiv=False - ) + init_txt_model = config.get("init_txt_model", False) + if init_txt_model: + self.txt_dec = SparseUnetVaeDecoder( + out_channels=6, + model_channels=[1024, 512, 256, 128, 64], + latent_channels=32, + num_blocks=[4, 16, 8, 4, 0], + block_type=["SparseConvNeXtBlock3d"] * 5, + up_block_type=["SparseResBlockC2S3d"] * 4, + block_args=[{}, {}, {}, {}, {}], + pred_subdiv=False + ) self.shape_dec = FlexiDualGridVaeDecoder( resolution=256, diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8cea16e50..4f5542af5 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -106,6 +106,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]] return unet_config + if '{}img2shape.blocks.1.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: + unet_config = {} + unet_config["image_model"] = "trellis2" + if 
'{}model.shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: + unet_config["init_txt_model"] = True + else: + unet_config["init_txt_model"] = False + if metadata is not None: + if metadata["is_512"] is True: + unet_config["resolution"] = 32 + else: + unet_config["resolution"] = 64 + return unet_config + if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit unet_config = {} unet_config["audio_model"] = "dit1.0" diff --git a/comfy/sd.py b/comfy/sd.py index fd0ac85e7..be3d1b4f0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model import comfy.ldm.lightricks.vae.causal_video_autoencoder import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae +import comfy.ldm.trellis2.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae import comfy.ldm.ace.vae.music_dcae_pipeline @@ -492,6 +493,12 @@ class VAE: self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 + elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd: # trellis2 + if "txt_dec.blocks.1.16.norm1.weight" in sd: + config["init_txt_model"] = True + else: + config["init_txt_model"] = False + self.first_stage_model = comfy.ldm.trellis2.vae.Vae(config) elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} From 955c00ee38356fd35d4c347f617ef15cb10659ec Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 6 Feb 2026 23:54:27 +0200 Subject: [PATCH 12/93] post-process node --- comfy/ldm/trellis2/cumesh.py | 127 ---------------------- comfy/ldm/trellis2/vae.py | 22 ---- comfy_extras/nodes_trellis2.py | 187 ++++++++++++++++++++++++++++++++- 3 files changed, 186 insertions(+), 
150 deletions(-) diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index fe7e80e15..972fb13c3 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py @@ -469,133 +469,6 @@ class Mesh: def cpu(self): return self.to('cpu') - # could make this into a new node - def fill_holes(self, max_hole_perimeter=3e-2): - - device = self.vertices.device - vertices = self.vertices - faces = self.faces - - edges = torch.cat([ - faces[:, [0, 1]], - faces[:, [1, 2]], - faces[:, [2, 0]] - ], dim=0) - - edges_sorted, _ = torch.sort(edges, dim=1) - - unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True) - - boundary_mask = counts == 1 - boundary_edges_sorted = unique_edges[boundary_mask] - - if boundary_edges_sorted.shape[0] == 0: - return - max_idx = vertices.shape[0] - - _, inverse_indices, counts_packed = torch.unique( - torch.sort(edges, dim=1).values[:, 0] * max_idx + torch.sort(edges, dim=1).values[:, 1], - return_inverse=True, return_counts=True - ) - - boundary_packed_mask = counts_packed == 1 - is_boundary_edge = boundary_packed_mask[inverse_indices] - - active_boundary_edges = edges[is_boundary_edge] - - adj = {} - edges_np = active_boundary_edges.cpu().numpy() - for u, v in edges_np: - adj[u] = v - - loops = [] - visited_edges = set() - - possible_starts = list(adj.keys()) - - processed_nodes = set() - - for start_node in possible_starts: - if start_node in processed_nodes: - continue - - current_loop = [] - curr = start_node - - while curr in adj: - next_node = adj[curr] - if (curr, next_node) in visited_edges: - break - - visited_edges.add((curr, next_node)) - processed_nodes.add(curr) - current_loop.append(curr) - - curr = next_node - - if curr == start_node: - loops.append(current_loop) - break - - if len(current_loop) > len(edges_np): - break - - if not loops: - return - - new_faces = [] - - v_offset = vertices.shape[0] - - valid_new_verts = [] - - for loop_indices in loops: - if len(loop_indices) < 3: - 
continue - - loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device) - loop_verts = vertices[loop_tensor] - - diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0) - perimeter = torch.norm(diffs, dim=1).sum() - - if perimeter > max_hole_perimeter: - continue - - center = loop_verts.mean(dim=0) - valid_new_verts.append(center) - - c_idx = v_offset - v_offset += 1 - - num_v = len(loop_indices) - for i in range(num_v): - v_curr = loop_indices[i] - v_next = loop_indices[(i + 1) % num_v] - new_faces.append([v_curr, v_next, c_idx]) - - if len(valid_new_verts) > 0: - added_vertices = torch.stack(valid_new_verts, dim=0) - added_faces = torch.tensor(new_faces, dtype=torch.long, device=device) - - self.vertices = torch.cat([self.vertices, added_vertices], dim=0) - self.faces = torch.cat([self.faces, added_faces], dim=0) - - - # TODO could be an option - def simplify(self, target=1000000, verbose: bool=False, options: dict={}): - import cumesh - vertices = self.vertices.cuda() - faces = self.faces.cuda() - - mesh = cumesh.CuMesh() - mesh.init(vertices, faces) - mesh.simplify(target, verbose=verbose, options=options) - new_vertices, new_faces = mesh.read() - - self.vertices = new_vertices.to(self.device) - self.faces = new_faces.to(self.device) - class MeshWithVoxel(Mesh, Voxel): def __init__(self, vertices: torch.Tensor, diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index d997bbc41..1d26986cc 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -231,27 +231,6 @@ class config: CONV = "flexgemm" FLEX_GEMM_HASHMAP_RATIO = 2.0 -# TODO post processing -def simplify(self, target_num_faces: int, verbose: bool=False, options: dict={}): - - num_face = self.cu_mesh.num_faces() - if num_face <= target_num_faces: - return - - thresh = options.get('thresh', 1e-8) - lambda_edge_length = options.get('lambda_edge_length', 1e-2) - lambda_skinny = options.get('lambda_skinny', 1e-3) - while True: - new_num_vert, new_num_face = 
self.cu_mesh.simplify_step(lambda_edge_length, lambda_skinny, thresh, False) - - if new_num_face <= target_num_faces: - break - - del_num_face = num_face - new_num_face - if del_num_face / num_face < 1e-2: - thresh *= 10 - num_face = new_num_face - class VarLenTensor: def __init__(self, feats: torch.Tensor, layout: List[slice]=None): @@ -1530,7 +1509,6 @@ class Vae(nn.Module): tex_voxels = self.decode_tex_slat(tex_slat, subs) out_mesh = [] for m, v in zip(meshes, tex_voxels): - m.fill_holes() # TODO out_mesh.append( MeshWithVoxel( m.vertices, m.faces, diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 4a36e2fee..8497b83e2 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -281,6 +281,190 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) +def simplify_fn(vertices, faces, target=100000): + + if vertices.shape[0] <= target: + return + + min_feat = vertices.min(dim=0)[0] + max_feat = vertices.max(dim=0)[0] + extent = (max_feat - min_feat).max() + + grid_resolution = int(torch.sqrt(torch.tensor(target)).item() * 1.5) + voxel_size = extent / grid_resolution + + quantized_coords = ((vertices - min_feat) / voxel_size).long() + + unique_coords, inverse_indices = torch.unique(quantized_coords, dim=0, return_inverse=True) + + num_new_verts = unique_coords.shape[0] + new_vertices = torch.zeros((num_new_verts, 3), dtype=vertices.dtype, device=vertices.device) + + counts = torch.zeros((num_new_verts, 1), dtype=vertices.dtype, device=vertices.device) + + new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices) + counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1])) + + new_vertices = new_vertices / counts.clamp(min=1) + + new_faces = inverse_indices[faces] + + v0 = new_faces[:, 0] + v1 = new_faces[:, 1] + v2 = new_faces[:, 2] + + valid_mask = (v0 != v1) & (v1 != 
v2) & (v2 != v0) + new_faces = new_faces[valid_mask] + + unique_face_indices, inv_face = torch.unique(new_faces.reshape(-1), return_inverse=True) + final_vertices = new_vertices[unique_face_indices] + final_faces = inv_face.reshape(-1, 3) + + return final_vertices, final_faces + +def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): + + device = vertices.device + orig_vertices = vertices + orig_faces = faces + + edges = torch.cat([ + faces[:, [0, 1]], + faces[:, [1, 2]], + faces[:, [2, 0]] + ], dim=0) + + edges_sorted, _ = torch.sort(edges, dim=1) + + unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True) + + boundary_mask = counts == 1 + boundary_edges_sorted = unique_edges[boundary_mask] + + if boundary_edges_sorted.shape[0] == 0: + return + max_idx = vertices.shape[0] + + _, inverse_indices, counts_packed = torch.unique( + torch.sort(edges, dim=1).values[:, 0] * max_idx + torch.sort(edges, dim=1).values[:, 1], + return_inverse=True, return_counts=True + ) + + boundary_packed_mask = counts_packed == 1 + is_boundary_edge = boundary_packed_mask[inverse_indices] + + active_boundary_edges = edges[is_boundary_edge] + + adj = {} + edges_np = active_boundary_edges.cpu().numpy() + for u, v in edges_np: + adj[u] = v + + loops = [] + visited_edges = set() + + possible_starts = list(adj.keys()) + + processed_nodes = set() + + for start_node in possible_starts: + if start_node in processed_nodes: + continue + + current_loop = [] + curr = start_node + + while curr in adj: + next_node = adj[curr] + if (curr, next_node) in visited_edges: + break + + visited_edges.add((curr, next_node)) + processed_nodes.add(curr) + current_loop.append(curr) + + curr = next_node + + if curr == start_node: + loops.append(current_loop) + break + + if len(current_loop) > len(edges_np): + break + + if not loops: + return + + new_faces = [] + + v_offset = vertices.shape[0] + + valid_new_verts = [] + + for loop_indices in loops: + if len(loop_indices) < 3: + continue + + 
loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device) + loop_verts = vertices[loop_tensor] + + diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0) + perimeter = torch.norm(diffs, dim=1).sum() + + if perimeter > max_hole_perimeter: + continue + + center = loop_verts.mean(dim=0) + valid_new_verts.append(center) + + c_idx = v_offset + v_offset += 1 + + num_v = len(loop_indices) + for i in range(num_v): + v_curr = loop_indices[i] + v_next = loop_indices[(i + 1) % num_v] + new_faces.append([v_curr, v_next, c_idx]) + + if len(valid_new_verts) > 0: + added_vertices = torch.stack(valid_new_verts, dim=0) + added_faces = torch.tensor(new_faces, dtype=torch.long, device=device) + + vertices_f = torch.cat([orig_vertices, added_vertices], dim=0) + faces_f = torch.cat([orig_faces, added_faces], dim=0) + + return vertices_f, faces_f + +class PostProcessMesh(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PostProcessMesh", + category="latent/3d", + inputs=[ + IO.Mesh.Input("mesh"), + IO.Int.Input("simplify", default=100_000, min=0), # max? 
+ IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0) + ], + outputs=[ + IO.Mesh.Output("output_mesh"), + ] + ) + @classmethod + def execute(cls, mesh, simplify, fill_holes_perimeter): + verts, faces = mesh.vertices, mesh.faces + + if fill_holes_perimeter != 0.0: + verts, faces = fill_holes_fn(verts, faces, max_hole_perimeter=fill_holes_perimeter) + + if simplify != 0: + verts, faces = simplify_fn(verts, faces, simplify) + + + mesh.vertices = verts + mesh.faces = faces + + return mesh class Trellis2Extension(ComfyExtension): @override @@ -292,7 +476,8 @@ class Trellis2Extension(ComfyExtension): EmptyTextureLatentTrellis2, VaeDecodeTextureTrellis, VaeDecodeShapeTrellis, - VaeDecodeStructureTrellis2 + VaeDecodeStructureTrellis2, + PostProcessMesh ] From 704e1b54621a78b30d529b5367a2951383bd890c Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 9 Feb 2026 00:41:01 +0200 Subject: [PATCH 13/93] small bug fixes --- comfy/ldm/trellis2/attention.py | 2 +- comfy/ldm/trellis2/model.py | 5 ++++- comfy/ldm/trellis2/vae.py | 5 ++--- comfy/model_detection.py | 12 +++++++----- comfy/sd.py | 7 +++---- comfy_extras/nodes_trellis2.py | 8 ++++---- 6 files changed, 21 insertions(+), 18 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index edc85ce83..3038f4023 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -2,7 +2,7 @@ import torch import math from comfy.ldm.modules.attention import optimized_attention from typing import Tuple, Union, List -from vae import VarLenTensor +from comfy.ldm.trellis2.vae import VarLenTensor FLASH_ATTN_3_AVA = True try: diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 8ca112b13..9aab045c7 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -798,9 +798,12 @@ class Trellis2(nn.Module): qk_rms_norm = True, qk_rms_norm_cross = True, init_txt_model=False, # for now - 
dtype=None, device=None, operations=None): + dtype=None, device=None, operations=None, **kwargs): super().__init__() + self.dtype = dtype + # for some reason it passes num_heads = -1 + num_heads = 12 args = { "out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels, "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 1d26986cc..6e13afd8d 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -6,7 +6,7 @@ from fractions import Fraction import torch.nn.functional as F from dataclasses import dataclass import numpy as np -from cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d +from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: @@ -1457,10 +1457,9 @@ class SparseStructureDecoder(nn.Module): return h class Vae(nn.Module): - def __init__(self, config, operations=None): + def __init__(self, init_txt_model, operations=None): super().__init__() operations = operations or torch.nn - init_txt_model = config.get("init_txt_model", False) if init_txt_model: self.txt_dec = SparseUnetVaeDecoder( out_channels=6, diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 4f5542af5..004adbf71 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -109,15 +109,17 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}img2shape.blocks.1.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: unet_config = {} unet_config["image_model"] = "trellis2" + + unet_config["init_txt_model"] = False if '{}model.shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: unet_config["init_txt_model"] = True - else: - unet_config["init_txt_model"] = False + + unet_config["resolution"] = 64 if metadata 
is not None: - if metadata["is_512"] is True: + if "is_512" in metadata and metadata["metadata"]: unet_config["resolution"] = 32 - else: - unet_config["resolution"] = 64 + + unet_config["num_heads"] = 12 return unet_config if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit diff --git a/comfy/sd.py b/comfy/sd.py index be3d1b4f0..25fd3ba7b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -494,11 +494,10 @@ class VAE: self.downscale_ratio = 32 self.latent_channels = 16 elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd: # trellis2 + init_txt_model = False if "txt_dec.blocks.1.16.norm1.weight" in sd: - config["init_txt_model"] = True - else: - config["init_txt_model"] = False - self.first_stage_model = comfy.ldm.trellis2.vae.Vae(config) + init_txt_model = True + self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model) elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 8497b83e2..17ba94ec8 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -198,7 +198,7 @@ class Trellis2Conditioning(IO.ComfyNode): inputs=[ IO.ClipVision.Input("clip_vision_model"), IO.Image.Input("image"), - IO.MultiCombo.Input("background_color", options=["black", "gray", "white"], default="black") + IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black") ], outputs=[ IO.Conditioning.Output(display_name="positive"), @@ -219,7 +219,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( - node_id="EmptyLatentTrellis2", + node_id="EmptyShapeLatentTrellis2", category="latent/3d", inputs=[ 
IO.Mesh.Input("structure_output"), @@ -242,7 +242,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( - node_id="EmptyLatentTrellis2", + node_id="EmptyTextureLatentTrellis2", category="latent/3d", inputs=[ IO.Mesh.Input("structure_output"), @@ -264,7 +264,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( - node_id="EmptyLatentTrellis2", + node_id="EmptyStructureLatentTrellis2", category="latent/3d", inputs=[ IO.Int.Input("resolution", default=3072, min=1, max=8192), From 2eef826def23599d395bb1854d6b95cb657c8385 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 9 Feb 2026 22:47:50 +0200 Subject: [PATCH 14/93] multiple fixes --- comfy/clip_vision.py | 4 +++ comfy/image_encoders/dino3.py | 32 +++++++++++++++++------ comfy/image_encoders/dino3_large.json | 11 ++++---- comfy/supported_models.py | 6 +++++ comfy_extras/nodes_trellis2.py | 37 +++++++++++++++------------ 5 files changed, 60 insertions(+), 30 deletions(-) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 1691fca81..71f2200b7 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -9,6 +9,7 @@ import comfy.model_management import comfy.utils import comfy.clip_model import comfy.image_encoders.dino2 +import comfy.image_encoders.dino3 class Output: def __getitem__(self, key): @@ -23,6 +24,7 @@ IMAGE_ENCODERS = { "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection, "siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection, "dinov2": comfy.image_encoders.dino2.Dinov2Model, + "dinov3": comfy.image_encoders.dino3.DINOv3ViTModel } class ClipVisionModel(): @@ -134,6 +136,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json") elif 'encoder.layer.23.layer_scale2.lambda1' in 
sd: json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json") + elif 'layer.9.attention.o_proj.bias' in sd: # dinov3 + json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino3_large.json") else: return None diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index d07c2c5b8..b27b95b5f 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -4,7 +4,19 @@ import torch.nn as nn from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.ldm.flux.math import apply_rope -from dino2 import Dinov2MLP as DINOv3ViTMLP, LayerScale as DINOv3ViTLayerScale +from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale + +class DINOv3ViTMLP(nn.Module): + def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype) + self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype) + self.act_fn = torch.nn.GELU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.up_proj(x))) class DINOv3ViTAttention(nn.Module): def __init__(self, hidden_size, num_attention_heads, device, dtype, operations): @@ -90,6 +102,7 @@ class DINOv3ViTRopePositionEmbedding(nn.Module): self.head_dim = hidden_size // num_attention_heads self.num_patches_h = image_size // patch_size self.num_patches_w = image_size // patch_size + self.patch_size = patch_size inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device) self.register_buffer("inv_freq", inv_freq, persistent=False) @@ -106,6 +119,7 @@ class DINOv3ViTRopePositionEmbedding(nn.Module): num_patches_h, 
num_patches_w, dtype=torch.float32, device=device ) + self.inv_freq = self.inv_freq.to(device) angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :] angles = angles.flatten(1, 2) angles = angles.tile(2) @@ -140,27 +154,30 @@ class DINOv3ViTEmbeddings(nn.Module): cls_token = self.cls_token.expand(batch_size, -1, -1) register_tokens = self.register_tokens.expand(batch_size, -1, -1) + device = patch_embeddings + cls_token = cls_token.to(device) + register_tokens = register_tokens.to(device) embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1) return embeddings class DINOv3ViTLayer(nn.Module): - def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, layerscale_value, mlp_bias, intermediate_size, num_attention_heads, + def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size, num_attention_heads, device, dtype, operations): super().__init__() self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps) self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations) - self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, layerscale_value, device=device, dtype=dtype) + self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None) self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype) if use_gated_mlp: self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations) else: - self.mlp = DINOv3ViTMLP(hidden_size, device=device, dtype=dtype, operations=operations) - self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, layerscale_value, device=device, dtype=dtype) + self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations) + self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, 
operations=None) def forward( self, @@ -188,7 +205,7 @@ class DINOv3ViTLayer(nn.Module): class DINOv3ViTModel(nn.Module): - def __init__(self, config, device, dtype, operations): + def __init__(self, config, dtype, device, operations): super().__init__() num_hidden_layers = config["num_hidden_layers"] hidden_size = config["hidden_size"] @@ -196,7 +213,6 @@ class DINOv3ViTModel(nn.Module): num_register_tokens = config["num_register_tokens"] intermediate_size = config["intermediate_size"] layer_norm_eps = config["layer_norm_eps"] - layerscale_value = config["layerscale_value"] num_channels = config["num_channels"] patch_size = config["patch_size"] rope_theta = config["rope_theta"] @@ -208,7 +224,7 @@ class DINOv3ViTModel(nn.Module): rope_theta, hidden_size, num_attention_heads, image_size=512, patch_size=patch_size, dtype=dtype, device=device ) self.layer = nn.ModuleList( - [DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=False, layerscale_value=layerscale_value, mlp_bias=True, + [DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=False, mlp_bias=True, intermediate_size=intermediate_size,num_attention_heads = num_attention_heads, dtype=dtype, device=device, operations=operations) for _ in range(num_hidden_layers)]) diff --git a/comfy/image_encoders/dino3_large.json b/comfy/image_encoders/dino3_large.json index 96263f0d6..53f761a25 100644 --- a/comfy/image_encoders/dino3_large.json +++ b/comfy/image_encoders/dino3_large.json @@ -1,16 +1,15 @@ { - - "hidden_size": 384, + "model_type": "dinov3", + "hidden_size": 1024, "image_size": 224, "initializer_range": 0.02, - "intermediate_size": 1536, + "intermediate_size": 4096, "key_bias": false, "layer_norm_eps": 1e-05, - "layerscale_value": 1.0, "mlp_bias": true, - "num_attention_heads": 6, + "num_attention_heads": 16, "num_channels": 3, - "num_hidden_layers": 12, + "num_hidden_layers": 24, "num_register_tokens": 4, "patch_size": 16, "pos_embed_rescale": 2.0, diff --git a/comfy/supported_models.py 
b/comfy/supported_models.py index 9e2f17149..3373f78a2 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1251,12 +1251,18 @@ class Trellis2(supported_models_base.BASE): "shift": 3.0, } + memory_usage_factor = 3.5 + latent_format = latent_formats.Trellis2 vae_key_prefix = ["vae."] + clip_vision_prefix = "conditioner.main_image_encoder.model." def get_model(self, state_dict, prefix="", device=None): return model_base.Trellis2(self, device=device) + def clip_target(self, state_dict={}): + return None + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 17ba94ec8..f53d36736 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -3,10 +3,8 @@ from comfy_api.latest import ComfyExtension, IO import torch from comfy.ldm.trellis2.model import SparseTensor import comfy.model_management -from PIL import Image -import PIL -import numpy as np from comfy.nested_tensor import NestedTensor +from torchvision.transforms import ToPILImage, ToTensor, Resize, InterpolationMode shape_slat_normalization = { "mean": torch.tensor([ @@ -76,23 +74,30 @@ def run_conditioning( # Convert image to PIL if image.dim() == 4: - pil_image = (image[0] * 255).clip(0, 255).astype(torch.uint8) + pil_image = (image[0] * 255).clip(0, 255).to(torch.uint8) else: - pil_image = (image * 255).clip(0, 255).astype(torch.uint8) + pil_image = (image * 255).clip(0, 255).to(torch.uint8) + pil_image = pil_image.movedim(-1, 0) pil_image = smart_crop_square(pil_image, background_color=bg_color) model.image_size = 512 def set_image_size(image, image_size=512): - image = PIL.from_array(image) - image = [i.resize((image_size, image_size), Image.LANCZOS) for i in image] - image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] - image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] - image = 
torch.stack(image).to(torch_device) - return image + if image.ndim == 3: + image = image.unsqueeze(0) - pil_image = set_image_size(image, 512) - cond_512 = model([pil_image]) + to_pil = ToPILImage() + to_tensor = ToTensor() + resizer = Resize((image_size, image_size), interpolation=InterpolationMode.LANCZOS) + + pil_img = to_pil(image.squeeze(0)) + resized_pil = resizer(pil_img) + image = to_tensor(resized_pil).unsqueeze(0) + + return image.to(torch_device).float() + + pil_image = set_image_size(pil_image, 512) + cond_512 = model(pil_image) cond_1024 = None if include_1024: @@ -267,7 +272,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): node_id="EmptyStructureLatentTrellis2", category="latent/3d", inputs=[ - IO.Int.Input("resolution", default=3072, min=1, max=8192), + IO.Int.Input("resolution", default=256, min=1, max=8192), IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), ], outputs=[ @@ -275,9 +280,9 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): ] ) @classmethod - def execute(cls, res, batch_size): + def execute(cls, resolution, batch_size): in_channels = 32 - latent = torch.randn(batch_size, in_channels, res, res, res) + latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) From f4059c189e92f7a1a6c65014c4901f9ad67b6b25 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 11 Feb 2026 01:27:54 +0200 Subject: [PATCH 15/93] dinov3 fixes + other --- comfy/image_encoders/dino3.py | 39 +++++++++++++++++++++++++++------- comfy/ldm/trellis2/model.py | 17 +++++++++++---- comfy_extras/nodes_trellis2.py | 14 ++++++------ 3 files changed, 51 insertions(+), 19 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index b27b95b5f..ef04556da 100644 --- a/comfy/image_encoders/dino3.py +++ 
b/comfy/image_encoders/dino3.py @@ -24,7 +24,6 @@ class DINOv3ViTAttention(nn.Module): self.embed_dim = hidden_size self.num_heads = num_attention_heads self.head_dim = self.embed_dim // self.num_heads - self.is_causal = False self.scaling = self.head_dim**-0.5 self.is_causal = False @@ -53,18 +52,41 @@ class DINOv3ViTAttention(nn.Module): key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) - cos, sin = position_embeddings - position_embeddings = torch.stack([cos, sin], dim = -1) - query_states, key_states = apply_rope(query_states, key_states, position_embeddings) + if position_embeddings is not None: + cos, sin = position_embeddings - attn_output, attn_weights = optimized_attention_for_device( - query_states, key_states, value_states, attention_mask, skip_reshape=True, skip_output_reshape=True + num_tokens = query_states.shape[-2] + num_patches = cos.shape[-2] + num_prefix_tokens = num_tokens - num_patches + + q_prefix, q_patches = query_states.split((num_prefix_tokens, num_patches), dim=-2) + k_prefix, k_patches = key_states.split((num_prefix_tokens, num_patches), dim=-2) + + cos = cos[..., :self.head_dim // 2] + sin = sin[..., :self.head_dim // 2] + + f_cis_0 = torch.stack([cos, sin], dim=-1) + f_cis_1 = torch.stack([-sin, cos], dim=-1) + freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1) + + while freqs_cis.ndim < q_patches.ndim + 1: + freqs_cis = freqs_cis.unsqueeze(0) + + q_patches, k_patches = apply_rope(q_patches, k_patches, freqs_cis) + + query_states = torch.cat((q_prefix, q_patches), dim=-2) + key_states = torch.cat((k_prefix, k_patches), dim=-2) + + attn = optimized_attention_for_device(query_states.device, mask=False) + + attn_output = attn( + query_states, key_states, value_states, self.num_heads, attention_mask, skip_reshape=True, skip_output_reshape=True ) attn_output = attn_output.reshape(batch_size, 
patches, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights + return attn_output class DINOv3ViTGatedMLP(nn.Module): def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations): @@ -187,7 +209,7 @@ class DINOv3ViTLayer(nn.Module): ) -> torch.Tensor: residual = hidden_states hidden_states = self.norm1(hidden_states) - hidden_states, _ = self.attention( + hidden_states = self.attention( hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings, @@ -250,6 +272,7 @@ class DINOv3ViTModel(nn.Module): position_embeddings=position_embeddings, ) + self.norm = self.norm.to(hidden_states.device) sequence_output = self.norm(hidden_states) pooled_output = sequence_output[:, 0, :] diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 9aab045c7..8bc8e8f7a 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -113,6 +113,13 @@ class SparseRotaryPositionEmbedder(nn.Module): q_feats, k_feats = apply_rope(q.feats, k.feats, f_cis) return q.replace(q_feats), k.replace(k_feats) + @staticmethod + def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases.unsqueeze(-2) + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + class RotaryPositionEmbedder(SparseRotaryPositionEmbedder): def forward(self, indices: torch.Tensor) -> torch.Tensor: phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) @@ -559,6 +566,7 @@ class MultiHeadAttention(nn.Module): def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor: B, L, C = x.shape if self._type == "self": + x = x.to(next(self.to_qkv.parameters()).dtype) qkv = self.to_qkv(x) qkv = qkv.reshape(B, L, 3, self.num_heads, -1) @@ 
-688,7 +696,7 @@ class SparseStructureFlowModel(nn.Module): num_heads: Optional[int] = None, num_head_channels: Optional[int] = 64, mlp_ratio: float = 4, - pe_mode: Literal["ape", "rope"] = "ape", + pe_mode: Literal["ape", "rope"] = "rope", rope_freq: Tuple[float, float] = (1.0, 10000.0), dtype: str = 'float32', use_checkpoint: bool = False, @@ -756,14 +764,14 @@ class SparseStructureFlowModel(nn.Module): self.out_layer = nn.Linear(model_channels, out_channels) def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3) assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous() + h = h.to(next(self.input_layer.parameters()).dtype) h = self.input_layer(h) - if self.pe_mode == "ape": - h = h + self.pos_emb[None] t_emb = self.t_embedder(t, out_dtype = t.dtype) if self.share_mod: t_emb = self.adaLN_modulation(t_emb) @@ -816,7 +824,8 @@ class Trellis2(nn.Module): self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) def forward(self, x: NestedTensor, timestep, context, **kwargs): - x = x.tensors[0] + if isinstance(x, NestedTensor): + x = x.tensors[0] embeds = kwargs.get("embeds") if not hasattr(x, "feats"): mode = "structure_generation" diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index f53d36736..4eff2dbc3 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -97,13 +97,13 @@ def run_conditioning( return image.to(torch_device).float() pil_image = set_image_size(pil_image, 512) - cond_512 = model(pil_image) + cond_512 = model(pil_image)[0] cond_1024 = None if include_1024: model.image_size = 1024 pil_image = set_image_size(pil_image, 1024) - cond_1024 = model([pil_image]) + cond_1024 = 
model(pil_image)[0] neg_cond = torch.zeros_like(cond_512) @@ -115,7 +115,7 @@ def run_conditioning( conditioning['cond_1024'] = cond_1024.to(device) preprocessed_tensor = pil_image.to(torch.float32) / 255.0 - preprocessed_tensor = torch.from_numpy(preprocessed_tensor).unsqueeze(0) + preprocessed_tensor = preprocessed_tensor.unsqueeze(0) return conditioning, preprocessed_tensor @@ -217,7 +217,7 @@ class Trellis2Conditioning(IO.ComfyNode): conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color) embeds = conditioning["cond_1024"] # should add that positive = [[conditioning["cond_512"], {"embeds": embeds}]] - negative = [[conditioning["cond_neg"], {"embeds": embeds}]] + negative = [[conditioning["neg_cond"], {"embeds": embeds}]] return IO.NodeOutput(positive, negative) class EmptyShapeLatentTrellis2(IO.ComfyNode): @@ -272,7 +272,6 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): node_id="EmptyStructureLatentTrellis2", category="latent/3d", inputs=[ - IO.Int.Input("resolution", default=256, min=1, max=8192), IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), ], outputs=[ @@ -280,8 +279,9 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): ] ) @classmethod - def execute(cls, resolution, batch_size): - in_channels = 32 + def execute(cls, batch_size): + in_channels = 8 + resolution = 16 latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) From b7764479c263c9d41bd077c7453b8d4c15551a34 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 11 Feb 2026 20:33:59 +0200 Subject: [PATCH 16/93] debugging --- comfy/ldm/trellis2/attention.py | 7 ++++-- comfy/ldm/trellis2/model.py | 8 +++---- comfy/ldm/trellis2/vae.py | 41 +++++++++++++++++++++++++++------ comfy/sd.py | 4 ++++ 
comfy_extras/nodes_trellis2.py | 30 +++++++++++++++--------- 5 files changed, 65 insertions(+), 25 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 3038f4023..e6aa50842 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -14,6 +14,7 @@ except: def scaled_dot_product_attention(*args, **kwargs): num_all_args = len(args) + len(kwargs) + q = None if num_all_args == 1: qkv = args[0] if len(args) > 0 else kwargs['qkv'] @@ -26,8 +27,10 @@ def scaled_dot_product_attention(*args, **kwargs): k = args[1] if len(args) > 1 else kwargs['k'] v = args[2] if len(args) > 2 else kwargs['v'] - # TODO verify - heads = q or qkv + if q is not None: + heads = q + else: + heads = qkv heads = heads.shape[2] if optimized_attention.__name__ == 'attention_xformers': diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 8bc8e8f7a..17286a553 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -7,7 +7,6 @@ from comfy.ldm.trellis2.attention import ( sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention ) from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder -from comfy.nested_tensor import NestedTensor from comfy.ldm.flux.math import apply_rope, apply_rope1 class SparseGELU(nn.GELU): @@ -586,6 +585,7 @@ class MultiHeadAttention(nn.Module): else: Lkv = context.shape[1] q = self.to_q(x) + context = context.to(next(self.to_kv.parameters()).dtype) kv = self.to_kv(context) q = q.reshape(B, L, self.num_heads, -1) kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) @@ -782,6 +782,7 @@ class SparseStructureFlowModel(nn.Module): h = block(h, t_emb, cond, self.rope_phases) h = manual_cast(h, x.dtype) h = F.layer_norm(h, h.shape[-1:]) + h = h.to(next(self.out_layer.parameters()).dtype) h = self.out_layer(h) h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous() @@ -823,9 +824,7 
@@ class Trellis2(nn.Module): args.pop("out_channels") self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) - def forward(self, x: NestedTensor, timestep, context, **kwargs): - if isinstance(x, NestedTensor): - x = x.tensors[0] + def forward(self, x, timestep, context, **kwargs): embeds = kwargs.get("embeds") if not hasattr(x, "feats"): mode = "structure_generation" @@ -843,6 +842,5 @@ class Trellis2(nn.Module): timestep = timestep_reshift(timestep) out = self.structure_model(x, timestep, context) - out = NestedTensor([out]) out.generation_mode = mode return out diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 6e13afd8d..57bf78346 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -10,9 +10,6 @@ from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_ def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: - """ - 3D pixel shuffle. - """ B, C, H, W, D = x.shape C_ = C // scale_factor**3 x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) @@ -967,6 +964,25 @@ class SparseLinear(nn.Linear): return input.replace(super().forward(input.feats)) +MIX_PRECISION_MODULES = ( + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + nn.Linear, + SparseConv3d, + SparseLinear, +) + + +def convert_module_to_f16(l): + if isinstance(l, MIX_PRECISION_MODULES): + for p in l.parameters(): + p.data = p.data.half() + + class SparseUnetVaeEncoder(nn.Module): """ @@ -1381,8 +1397,12 @@ class ResBlock3d(nn.Module): self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: + self.norm1 = self.norm1.to(torch.float32) + self.norm2 = self.norm2.to(torch.float32) h = self.norm1(x) h = F.silu(h) + dtype = next(self.conv1.parameters()).dtype + h = h.to(dtype) h = self.conv1(h) h = 
self.norm2(h) h = F.silu(h) @@ -1400,7 +1420,7 @@ class SparseStructureDecoder(nn.Module): channels: List[int], num_res_blocks_middle: int = 2, norm_type = "layer", - use_fp16: bool = False, + use_fp16: bool = True, ): super().__init__() self.out_channels = out_channels @@ -1439,20 +1459,27 @@ class SparseStructureDecoder(nn.Module): if use_fp16: self.convert_to_fp16() - @property def device(self) -> torch.device: return next(self.parameters()).device + def convert_to_fp16(self) -> None: + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + def forward(self, x: torch.Tensor) -> torch.Tensor: + dtype = next(self.input_layer.parameters()).dtype + x = x.to(dtype) h = self.input_layer(x) h = h.type(self.dtype) - h = self.middle_block(h) for block in self.blocks: h = block(h) - h = h.type(x.dtype) + h = h.to(torch.float32) + self.out_layer = self.out_layer.to(torch.float32) h = self.out_layer(h) return h diff --git a/comfy/sd.py b/comfy/sd.py index 25fd3ba7b..276e87d2a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -497,6 +497,10 @@ class VAE: init_txt_model = False if "txt_dec.blocks.1.16.norm1.weight" in sd: init_txt_model = True + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + # TODO + self.memory_used_decode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model) elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 4eff2dbc3..c735469be 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -3,7 +3,7 @@ from comfy_api.latest import ComfyExtension, IO import torch from comfy.ldm.trellis2.model import 
SparseTensor import comfy.model_management -from comfy.nested_tensor import NestedTensor +import comfy.model_patcher from torchvision.transforms import ToPILImage, ToTensor, Resize, InterpolationMode shape_slat_normalization = { @@ -137,14 +137,15 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): ) @classmethod - def execute(cls, samples: NestedTensor, vae, resolution): - samples = samples.tensors[0] + def execute(cls, samples, vae, resolution): + vae = vae.first_stage_model + samples = samples["samples"] std = shape_slat_normalization["std"] mean = shape_slat_normalization["mean"] samples = samples * std + mean mesh, subs = vae.decode_shape_slat(resolution, samples) - return mesh, subs + return IO.NodeOutput(mesh, subs) class VaeDecodeTextureTrellis(IO.ComfyNode): @classmethod @@ -164,13 +165,14 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): @classmethod def execute(cls, samples, vae, shape_subs): - samples = samples.tensors[0] + vae = vae.first_stage_model + samples = samples["samples"] std = tex_slat_normalization["std"] mean = tex_slat_normalization["mean"] samples = samples * std + mean mesh = vae.decode_tex_slat(samples, shape_subs) - return mesh + return IO.NodeOutput(mesh) class VaeDecodeStructureTrellis2(IO.ComfyNode): @classmethod @@ -189,10 +191,19 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): @classmethod def execute(cls, samples, vae): + vae = vae.first_stage_model decoder = vae.struct_dec + load_device = comfy.model_management.get_torch_device() + decoder = comfy.model_patcher.ModelPatcher( + decoder, load_device=load_device, offload_device=comfy.model_management.vae_offload_device() + ) + comfy.model_management.load_model_gpu(decoder) + decoder = decoder.model + samples = samples["samples"] + samples = samples.to(load_device) decoded = decoder(samples)>0 coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int() - return coords + return IO.NodeOutput(coords) class Trellis2Conditioning(IO.ComfyNode): @classmethod @@ -240,7 +251,6 @@ class 
EmptyShapeLatentTrellis2(IO.ComfyNode): coords = structure_output # or structure_output.coords in_channels = 32 latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) - latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyTextureLatentTrellis2(IO.ComfyNode): @@ -262,7 +272,6 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): # TODO in_channels = 32 latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1])) - latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyStructureLatentTrellis2(IO.ComfyNode): @@ -283,7 +292,6 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): in_channels = 8 resolution = 16 latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) - latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) def simplify_fn(vertices, faces, target=100000): @@ -469,7 +477,7 @@ class PostProcessMesh(IO.ComfyNode): mesh.vertices = verts mesh.faces = faces - return mesh + return IO.NodeOutput(mesh) class Trellis2Extension(ComfyExtension): @override From 8e90bdc1ccad930527cbf3cd5170590ff3eb7902 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 12 Feb 2026 00:30:51 +0200 Subject: [PATCH 17/93] small error fixes --- comfy_extras/nodes_trellis2.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index c735469be..9fd257785 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,5 +1,5 @@ from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO +from comfy_api.latest import ComfyExtension, IO, Types import torch from comfy.ldm.trellis2.model import SparseTensor import 
comfy.model_management @@ -185,7 +185,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): IO.Vae.Input("vae"), ], outputs=[ - IO.Mesh.Output("structure_output"), + IO.Voxel.Output("structure_output"), ] ) @@ -194,16 +194,15 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): vae = vae.first_stage_model decoder = vae.struct_dec load_device = comfy.model_management.get_torch_device() - decoder = comfy.model_patcher.ModelPatcher( - decoder, load_device=load_device, offload_device=comfy.model_management.vae_offload_device() - ) - comfy.model_management.load_model_gpu(decoder) - decoder = decoder.model + offload_device = comfy.model_management.vae_offload_device() + decoder = decoder.to(load_device) samples = samples["samples"] samples = samples.to(load_device) decoded = decoder(samples)>0 - coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int() - return IO.NodeOutput(coords) + decoder.to(offload_device) + comfy.model_management.get_offload_stream + out = Types.VOXEL(decoded.squeeze(1).float()) + return IO.NodeOutput(out) class Trellis2Conditioning(IO.ComfyNode): @classmethod @@ -238,7 +237,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): node_id="EmptyShapeLatentTrellis2", category="latent/3d", inputs=[ - IO.Mesh.Input("structure_output"), + IO.Voxel.Input("structure_output"), ], outputs=[ IO.Latent.Output(), @@ -247,8 +246,8 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_output): - # i will see what i have to do here - coords = structure_output # or structure_output.coords + decoded = structure_output.data + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1) in_channels = 32 latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) @@ -260,7 +259,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): node_id="EmptyTextureLatentTrellis2", category="latent/3d", inputs=[ - IO.Mesh.Input("structure_output"), + 
IO.Voxel.Input("structure_output"), ], outputs=[ IO.Latent.Output(), @@ -271,7 +270,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): def execute(cls, structure_output): # TODO in_channels = 32 - latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1])) + latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1])) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyStructureLatentTrellis2(IO.ComfyNode): From 0e239dc39b878b1bc3357fadeece4c76ff335892 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 12 Feb 2026 23:35:57 +0200 Subject: [PATCH 18/93] fixed attn (couldn't use apply_rope for dino3) --- comfy/image_encoders/dino3.py | 45 +++++++++++++++++------------------ 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index ef04556da..9cb231e28 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn from comfy.ldm.modules.attention import optimized_attention_for_device -from comfy.ldm.flux.math import apply_rope from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale class DINOv3ViTMLP(nn.Module): @@ -18,6 +17,26 @@ class DINOv3ViTMLP(nn.Module): def forward(self, x): return self.down_proj(self.act_fn(self.up_proj(x))) +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) +def apply_rotary_pos_emb(q, k, cos, sin, **kwargs): + num_tokens = q.shape[-2] + num_patches = sin.shape[-2] + num_prefix_tokens = num_tokens - num_patches + + q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2) + k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2) + + q_patches = 
(q_patches * cos) + (rotate_half(q_patches) * sin) + k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin) + + q = torch.cat((q_prefix_tokens, q_patches), dim=-2) + k = torch.cat((k_prefix_tokens, k_patches), dim=-2) + + return q, k + class DINOv3ViTAttention(nn.Module): def __init__(self, hidden_size, num_attention_heads, device, dtype, operations): super().__init__() @@ -54,28 +73,7 @@ class DINOv3ViTAttention(nn.Module): if position_embeddings is not None: cos, sin = position_embeddings - - num_tokens = query_states.shape[-2] - num_patches = cos.shape[-2] - num_prefix_tokens = num_tokens - num_patches - - q_prefix, q_patches = query_states.split((num_prefix_tokens, num_patches), dim=-2) - k_prefix, k_patches = key_states.split((num_prefix_tokens, num_patches), dim=-2) - - cos = cos[..., :self.head_dim // 2] - sin = sin[..., :self.head_dim // 2] - - f_cis_0 = torch.stack([cos, sin], dim=-1) - f_cis_1 = torch.stack([-sin, cos], dim=-1) - freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1) - - while freqs_cis.ndim < q_patches.ndim + 1: - freqs_cis = freqs_cis.unsqueeze(0) - - q_patches, k_patches = apply_rope(q_patches, k_patches, freqs_cis) - - query_states = torch.cat((q_prefix, q_patches), dim=-2) - key_states = torch.cat((k_prefix, k_patches), dim=-2) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) attn = optimized_attention_for_device(query_states.device, mask=False) @@ -83,6 +81,7 @@ class DINOv3ViTAttention(nn.Module): query_states, key_states, value_states, self.num_heads, attention_mask, skip_reshape=True, skip_output_reshape=True ) + attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() attn_output = self.o_proj(attn_output) From 0e51bee64ff909f5dff90a2782e4973818196624 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 13 Feb 2026 00:10:25 +0200 Subject: [PATCH 19/93] more reliable detection --- 
comfy/ldm/trellis2/model.py | 8 +------- comfy_extras/nodes_trellis2.py | 3 +++ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 17286a553..760372f5c 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -826,13 +826,7 @@ class Trellis2(nn.Module): def forward(self, x, timestep, context, **kwargs): embeds = kwargs.get("embeds") - if not hasattr(x, "feats"): - mode = "structure_generation" - else: - if x.feats.shape[1] == 32: - mode = "shape_generation" - else: - mode = "texture_generation" + mode = x.generation_mode if mode == "shape_generation": # TODO out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)])) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 9fd257785..a5c387c1d 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -250,6 +250,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1) in_channels = 32 latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) + latent.generation_mode = "shape_generation" return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyTextureLatentTrellis2(IO.ComfyNode): @@ -271,6 +272,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): # TODO in_channels = 32 latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1])) + latent.generation_mode = "texture_generation" return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyStructureLatentTrellis2(IO.ComfyNode): @@ -291,6 +293,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): in_channels = 8 resolution = 16 latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) + latent.generation_mode = "structure_generation" return IO.NodeOutput({"samples": latent, "type": 
"trellis2"}) def simplify_fn(vertices, faces, target=100000): From 92aa058a587900c24ff163e89288295f9d334acf Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 13 Feb 2026 21:05:59 +0200 Subject: [PATCH 20/93] . --- comfy/image_encoders/dino3.py | 2 ++ comfy/ldm/trellis2/model.py | 18 +++++++++++------- comfy_extras/nodes_trellis2.py | 9 +++------ 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index 9cb231e28..ce6b2edd9 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -228,6 +228,8 @@ class DINOv3ViTLayer(nn.Module): class DINOv3ViTModel(nn.Module): def __init__(self, config, dtype, device, operations): super().__init__() + if dtype == torch.float16: + dtype = torch.bfloat16 num_hidden_layers = config["num_hidden_layers"] hidden_size = config["hidden_size"] num_attention_heads = config["num_attention_heads"] diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 760372f5c..76fe8ad19 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -678,10 +678,7 @@ class ModulatedTransformerCrossBlock(nn.Module): return x def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: - if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, phases, use_reentrant=False) - else: - return self._forward(x, mod, context, phases) + return self._forward(x, mod, context, phases) class SparseStructureFlowModel(nn.Module): @@ -823,18 +820,25 @@ class Trellis2(nn.Module): self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) args.pop("out_channels") self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) + self.guidance_interval = [0.6, 1.0] + self.guidance_interval_txt = [0.6, 0.9] def 
forward(self, x, timestep, context, **kwargs): embeds = kwargs.get("embeds") - mode = x.generation_mode + mode = kwargs.get("generation_mode") + sigmas = kwargs.get("sigmas")[0].item() + cond = context.chunk(2) + shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] + txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] + if mode == "shape_generation": # TODO out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)])) elif mode == "texture_generation": - out = self.shape2txt(x, timestep, context) + out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure timestep = timestep_reshift(timestep) - out = self.structure_model(x, timestep, context) + out = self.structure_model(x, timestep, context if not shape_rule else cond) out.generation_mode = mode return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index a5c387c1d..560751091 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -250,8 +250,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1) in_channels = 32 latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) - latent.generation_mode = "shape_generation" - return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "shape_generation"}) class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod @@ -272,8 +271,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): # TODO in_channels = 32 latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1])) - latent.generation_mode = "texture_generation" - return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + return IO.NodeOutput({"samples": latent, "type": "trellis2", 
"generation_mode": "texture_generation"}) class EmptyStructureLatentTrellis2(IO.ComfyNode): @classmethod @@ -293,8 +291,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): in_channels = 8 resolution = 16 latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) - latent.generation_mode = "structure_generation" - return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "structure_generation"}) def simplify_fn(vertices, faces, target=100000): From 91fa563b21a745041c50bb7f7e5038330e01ae38 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 16 Feb 2026 01:53:53 +0200 Subject: [PATCH 21/93] rewriting conditioning logic + model code addition --- comfy/ldm/trellis2/model.py | 10 ++- comfy_extras/nodes_trellis2.py | 125 ++++++++++++++++----------------- 2 files changed, 70 insertions(+), 65 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 76fe8ad19..4c398294a 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -826,8 +826,11 @@ class Trellis2(nn.Module): def forward(self, x, timestep, context, **kwargs): embeds = kwargs.get("embeds") mode = kwargs.get("generation_mode") - sigmas = kwargs.get("sigmas")[0].item() - cond = context.chunk(2) + transformer_options = kwargs.get("transformer_options") + sigmas = transformer_options.get("sigmas")[0].item() + if sigmas < 1.00001: + timestep *= 1000.0 + cond = context.chunk(2)[1] shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] @@ -838,6 +841,9 @@ class Trellis2(nn.Module): out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure timestep = timestep_reshift(timestep) + if shape_rule: + x = x[0].unsqueeze(0) + timestep = timestep[0] out = 
self.structure_model(x, timestep, context if not shape_rule else cond) out.generation_mode = mode diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 560751091..4d97129eb 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -4,7 +4,6 @@ import torch from comfy.ldm.trellis2.model import SparseTensor import comfy.model_management import comfy.model_patcher -from torchvision.transforms import ToPILImage, ToTensor, Resize, InterpolationMode shape_slat_normalization = { "mean": torch.tensor([ @@ -36,86 +35,85 @@ tex_slat_normalization = { ])[None] } -def smart_crop_square( - image: torch.Tensor, - background_color=(128, 128, 128), -): +dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) +dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + +def smart_crop_square(image, mask, margin_ratio=0.1, bg_color=(128, 128, 128)): + nz = torch.nonzero(mask[0] > 0.5) + if nz.shape[0] == 0: + C, H, W = image.shape + side = max(H, W) + canvas = torch.full((C, side, side), 0.5, device=image.device) # Gray + canvas[:, (side-H)//2:(side-H)//2+H, (side-W)//2:(side-W)//2+W] = image + return canvas + + y_min, x_min = nz.min(dim=0)[0] + y_max, x_max = nz.max(dim=0)[0] + + obj_w, obj_h = x_max - x_min, y_max - y_min + center_x, center_y = (x_min + x_max) / 2, (y_min + y_max) / 2 + + side = int(max(obj_w, obj_h) * (1 + margin_ratio * 2)) + half_side = side / 2 + + x1, y1 = int(center_x - half_side), int(center_y - half_side) + x2, y2 = x1 + side, y1 + side + C, H, W = image.shape - size = max(H, W) - canvas = torch.empty( - (C, size, size), - dtype=image.dtype, - device=image.device - ) + canvas = torch.ones((C, side, side), device=image.device) for c in range(C): - canvas[c].fill_(background_color[c]) - top = (size - H) // 2 - left = (size - W) // 2 - canvas[:, top:top + H, left:left + W] = image + canvas[c] *= (bg_color[c] / 255.0) + + src_x1, src_y1 = max(0, x1), max(0, y1) + src_x2, src_y2 = min(W, x2), 
min(H, y2) + + dst_x1, dst_y1 = max(0, -x1), max(0, -y1) + dst_x2 = dst_x1 + (src_x2 - src_x1) + dst_y2 = dst_y1 + (src_y2 - src_y1) + + canvas[:, dst_y1:dst_y2, dst_x1:dst_x2] = image[:, src_y1:src_y2, src_x1:src_x2] return canvas -def run_conditioning( - model, - image: torch.Tensor, - include_1024: bool = True, - background_color: str = "black", -): - # TODO: should check if normalization was applied in these steps - model = model.model - device = comfy.model_management.intermediate_device() # replaces .cpu() - torch_device = comfy.model_management.get_torch_device() # replaces .cuda() - bg_colors = { - "black": (0, 0, 0), - "gray": (128, 128, 128), - "white": (255, 255, 255), - } - bg_color = bg_colors.get(background_color, (128, 128, 128)) +def run_conditioning(model, image, mask, include_1024 = True, background_color = "black"): + model_internal = model.model + device = comfy.model_management.intermediate_device() + torch_device = comfy.model_management.get_torch_device() - # Convert image to PIL - if image.dim() == 4: - pil_image = (image[0] * 255).clip(0, 255).to(torch.uint8) - else: - pil_image = (image * 255).clip(0, 255).to(torch.uint8) + bg_colors = {"black": (0, 0, 0), "gray": (128, 128, 128), "white": (255, 255, 255)} + bg_rgb = bg_colors.get(background_color, (128, 128, 128)) - pil_image = pil_image.movedim(-1, 0) - pil_image = smart_crop_square(pil_image, background_color=bg_color) + img_t = image[0].movedim(-1, 0).to(torch_device).float() + mask_t = mask[0].to(torch_device).float() + if mask_t.ndim == 2: + mask_t = mask_t.unsqueeze(0) - model.image_size = 512 - def set_image_size(image, image_size=512): - if image.ndim == 3: - image = image.unsqueeze(0) + cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb) - to_pil = ToPILImage() - to_tensor = ToTensor() - resizer = Resize((image_size, image_size), interpolation=InterpolationMode.LANCZOS) + def prepare_tensor(img, size): + resized = torch.nn.functional.interpolate( + img.unsqueeze(0), 
size=(size, size), mode='bicubic', align_corners=False + ) + return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device) - pil_img = to_pil(image.squeeze(0)) - resized_pil = resizer(pil_img) - image = to_tensor(resized_pil).unsqueeze(0) - - return image.to(torch_device).float() - - pil_image = set_image_size(pil_image, 512) - cond_512 = model(pil_image)[0] + model_internal.image_size = 512 + input_512 = prepare_tensor(cropped_img, 512) + cond_512 = model_internal(input_512)[0] cond_1024 = None if include_1024: - model.image_size = 1024 - pil_image = set_image_size(pil_image, 1024) - cond_1024 = model(pil_image)[0] - - neg_cond = torch.zeros_like(cond_512) + model_internal.image_size = 1024 + input_1024 = prepare_tensor(cropped_img, 1024) + cond_1024 = model_internal(input_1024)[0] conditioning = { 'cond_512': cond_512.to(device), - 'neg_cond': neg_cond.to(device), + 'neg_cond': torch.zeros_like(cond_512).to(device), } if cond_1024 is not None: conditioning['cond_1024'] = cond_1024.to(device) - preprocessed_tensor = pil_image.to(torch.float32) / 255.0 - preprocessed_tensor = preprocessed_tensor.unsqueeze(0) + preprocessed_tensor = cropped_img.movedim(0, -1).unsqueeze(0).cpu() return conditioning, preprocessed_tensor @@ -213,6 +211,7 @@ class Trellis2Conditioning(IO.ComfyNode): inputs=[ IO.ClipVision.Input("clip_vision_model"), IO.Image.Input("image"), + IO.Mask.Input("mask"), IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black") ], outputs=[ @@ -222,9 +221,9 @@ class Trellis2Conditioning(IO.ComfyNode): ) @classmethod - def execute(cls, clip_vision_model, image, background_color) -> IO.NodeOutput: + def execute(cls, clip_vision_model, image, mask, background_color) -> IO.NodeOutput: # could make 1024 an option - conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color) + conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, 
background_color=background_color) embeds = conditioning["cond_1024"] # should add that positive = [[conditioning["cond_512"], {"embeds": embeds}]] negative = [[conditioning["neg_cond"], {"embeds": embeds}]] From c14317d6e0d574253fcda8184470683d4ea9ded6 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 17 Feb 2026 00:10:48 +0200 Subject: [PATCH 22/93] postprocessing node fixes + model small fixes --- comfy/ldm/trellis2/model.py | 7 ++- comfy_extras/nodes_trellis2.py | 87 ++++++++++++++++------------------ 2 files changed, 47 insertions(+), 47 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 4c398294a..eb410fe8b 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -787,8 +787,10 @@ class SparseStructureFlowModel(nn.Module): return h def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0): + t_shifted /= 1000.0 t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1)) t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear) + t_new *= 1000.0 return t_new class Trellis2(nn.Module): @@ -841,10 +843,13 @@ class Trellis2(nn.Module): out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure timestep = timestep_reshift(timestep) + orig_bsz = x.shape[0] if shape_rule: x = x[0].unsqueeze(0) - timestep = timestep[0] + timestep = timestep[0].unsqueeze(0) out = self.structure_model(x, timestep, context if not shape_rule else cond) + if shape_rule: + out = out.repeat(orig_bsz, 1, 1, 1, 1) out.generation_mode = mode return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 4d97129eb..ad9881db7 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -295,7 +295,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): def simplify_fn(vertices, faces, target=100000): if vertices.shape[0] <= target: - return + return vertices, faces min_feat = 
vertices.min(dim=0)[0] max_feat = vertices.max(dim=0)[0] @@ -334,6 +334,19 @@ def simplify_fn(vertices, faces, target=100000): return final_vertices, final_faces def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): + is_batched = vertices.ndim == 3 + if is_batched: + batch_size = vertices.shape[0] + if batch_size > 1: + v_out, f_out = [], [] + for i in range(batch_size): + v, f = fill_holes_fn(vertices[i], faces[i], max_hole_perimeter) + v_out.append(v) + f_out.append(f) + return torch.stack(v_out), torch.stack(f_out) + + vertices = vertices.squeeze(0) + faces = faces.squeeze(0) device = vertices.device orig_vertices = vertices @@ -346,24 +359,23 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): ], dim=0) edges_sorted, _ = torch.sort(edges, dim=1) - unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True) - boundary_mask = counts == 1 boundary_edges_sorted = unique_edges[boundary_mask] if boundary_edges_sorted.shape[0] == 0: - return + if is_batched: + return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0) + return orig_vertices, orig_faces + max_idx = vertices.shape[0] - _, inverse_indices, counts_packed = torch.unique( - torch.sort(edges, dim=1).values[:, 0] * max_idx + torch.sort(edges, dim=1).values[:, 1], - return_inverse=True, return_counts=True - ) + packed_edges_all = torch.sort(edges, dim=1).values + packed_edges_all = packed_edges_all[:, 0] * max_idx + packed_edges_all[:, 1] - boundary_packed_mask = counts_packed == 1 - is_boundary_edge = boundary_packed_mask[inverse_indices] + packed_boundary = boundary_edges_sorted[:, 0] * max_idx + boundary_edges_sorted[:, 1] + is_boundary_edge = torch.isin(packed_edges_all, packed_boundary) active_boundary_edges = edges[is_boundary_edge] adj = {} @@ -373,78 +385,61 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): loops = [] visited_edges = set() - - possible_starts = list(adj.keys()) - processed_nodes = set() - - for start_node in possible_starts: - if 
start_node in processed_nodes: - continue - - current_loop = [] - curr = start_node - + for start_node in list(adj.keys()): + if start_node in processed_nodes: continue + current_loop, curr = [], start_node while curr in adj: next_node = adj[curr] - if (curr, next_node) in visited_edges: - break - + if (curr, next_node) in visited_edges: break visited_edges.add((curr, next_node)) processed_nodes.add(curr) current_loop.append(curr) - curr = next_node - if curr == start_node: loops.append(current_loop) break - - if len(current_loop) > len(edges_np): - break + if len(current_loop) > len(edges_np): break if not loops: - return + if is_batched: return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0) + return orig_vertices, orig_faces new_faces = [] - v_offset = vertices.shape[0] - valid_new_verts = [] for loop_indices in loops: - if len(loop_indices) < 3: - continue - + if len(loop_indices) < 3: continue loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device) loop_verts = vertices[loop_tensor] - diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0) perimeter = torch.norm(diffs, dim=1).sum() - if perimeter > max_hole_perimeter: - continue + if perimeter > max_hole_perimeter: continue center = loop_verts.mean(dim=0) valid_new_verts.append(center) - c_idx = v_offset v_offset += 1 num_v = len(loop_indices) for i in range(num_v): - v_curr = loop_indices[i] - v_next = loop_indices[(i + 1) % num_v] + v_curr, v_next = loop_indices[i], loop_indices[(i + 1) % num_v] new_faces.append([v_curr, v_next, c_idx]) if len(valid_new_verts) > 0: added_vertices = torch.stack(valid_new_verts, dim=0) added_faces = torch.tensor(new_faces, dtype=torch.long, device=device) + vertices = torch.cat([orig_vertices, added_vertices], dim=0) + faces = torch.cat([orig_faces, added_faces], dim=0) + else: + vertices, faces = orig_vertices, orig_faces - vertices_f = torch.cat([orig_vertices, added_vertices], dim=0) - faces_f = torch.cat([orig_faces, added_faces], dim=0) + if 
is_batched: + return vertices.unsqueeze(0), faces.unsqueeze(0) - return vertices_f, faces_f + return vertices, faces class PostProcessMesh(IO.ComfyNode): @classmethod @@ -454,8 +449,8 @@ class PostProcessMesh(IO.ComfyNode): category="latent/3d", inputs=[ IO.Mesh.Input("mesh"), - IO.Int.Input("simplify", default=100_000, min=0), # max? - IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0) + IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000), # max? + IO.Float.Input("fill_holes_perimeter", default=0.003, min=0.0, step=0.0001) ], outputs=[ IO.Mesh.Output("output_mesh"), From ff04ef555854f524bce57c8f932500b67282844d Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 18 Feb 2026 21:52:40 +0200 Subject: [PATCH 23/93] fix run_conditioning --- comfy_extras/nodes_trellis2.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index ad9881db7..bf457135c 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -71,7 +71,14 @@ def smart_crop_square(image, mask, margin_ratio=0.1, bg_color=(128, 128, 128)): dst_x2 = dst_x1 + (src_x2 - src_x1) dst_y2 = dst_y1 + (src_y2 - src_y1) - canvas[:, dst_y1:dst_y2, dst_x1:dst_x2] = image[:, src_y1:src_y2, src_x1:src_x2] + img_crop = image[:, src_y1:src_y2, src_x1:src_x2] + mask_crop = mask[0, src_y1:src_y2, src_x1:src_x2] + + bg_val = torch.tensor(bg_color, device=image.device, dtype=image.dtype).view(-1, 1, 1) / 255.0 + + masked_crop = img_crop * mask_crop + bg_val * (1.0 - mask_crop) + + canvas[:, dst_y1:dst_y2, dst_x1:dst_x2] = masked_crop return canvas From 0a49718194d3f3bf42330cb88f8a1bbb7ade55fe Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 18 Feb 2026 21:54:05 +0200 Subject: [PATCH 24/93] .. 
--- comfy_extras/nodes_trellis2.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index bf457135c..817769d08 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -394,11 +394,13 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): visited_edges = set() processed_nodes = set() for start_node in list(adj.keys()): - if start_node in processed_nodes: continue + if start_node in processed_nodes: + continue current_loop, curr = [], start_node while curr in adj: next_node = adj[curr] - if (curr, next_node) in visited_edges: break + if (curr, next_node) in visited_edges: + break visited_edges.add((curr, next_node)) processed_nodes.add(curr) current_loop.append(curr) @@ -406,10 +408,12 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): if curr == start_node: loops.append(current_loop) break - if len(current_loop) > len(edges_np): break + if len(current_loop) > len(edges_np): + break if not loops: - if is_batched: return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0) + if is_batched: + return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0) return orig_vertices, orig_faces new_faces = [] @@ -417,13 +421,15 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): valid_new_verts = [] for loop_indices in loops: - if len(loop_indices) < 3: continue + if len(loop_indices) < 3: + continue loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device) loop_verts = vertices[loop_tensor] diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0) perimeter = torch.norm(diffs, dim=1).sum() - if perimeter > max_hole_perimeter: continue + if perimeter > max_hole_perimeter: + continue center = loop_verts.mean(dim=0) valid_new_verts.append(center) From b5feac202c45f5106ac91cffb361dcb0c411fd5e Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 18 Feb 
2026 22:01:09 +0200 Subject: [PATCH 25/93] . --- comfy/ldm/trellis2/attention.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index e6aa50842..e8e401fd7 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -6,7 +6,7 @@ from comfy.ldm.trellis2.vae import VarLenTensor FLASH_ATTN_3_AVA = True try: - import flash_attn_interface as flash_attn_3 + import flash_attn_interface as flash_attn_3 # noqa: F401 except: FLASH_ATTN_3_AVA = False @@ -53,8 +53,6 @@ def scaled_dot_product_attention(*args, **kwargs): elif num_all_args == 3: out = flash_attn_3.flash_attn_func(q, k, v) elif optimized_attention.__name__ == 'attention_pytorch': - if 'sdpa' not in globals(): - from torch.nn.functional import scaled_dot_product_attention as sdpa if num_all_args == 1: q, k, v = qkv.unbind(dim=2) elif num_all_args == 2: From ee2b66a2f2b7d699ccc95d0a5b3bcb7fd5390814 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 19 Feb 2026 00:48:26 +0200 Subject: [PATCH 26/93] small fixes --- comfy/ldm/trellis2/model.py | 4 ++++ comfy_extras/nodes_trellis2.py | 15 +++++++-------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index eb410fe8b..b4fc15abc 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -828,6 +828,7 @@ class Trellis2(nn.Module): def forward(self, x, timestep, context, **kwargs): embeds = kwargs.get("embeds") mode = kwargs.get("generation_mode") + coords = kwargs.get("coords") transformer_options = kwargs.get("transformer_options") sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: @@ -836,6 +837,9 @@ class Trellis2(nn.Module): shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > 
self.guidance_interval_txt[1] + if mode in ["shape_generation", "texture_generation"]: + x = SparseTensor(feats=x, coords=coords) + if mode == "shape_generation": # TODO out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)])) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 817769d08..bd250e5f5 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,9 +1,7 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types import torch -from comfy.ldm.trellis2.model import SparseTensor import comfy.model_management -import comfy.model_patcher shape_slat_normalization = { "mean": torch.tensor([ @@ -205,7 +203,6 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): samples = samples.to(load_device) decoded = decoder(samples)>0 decoder.to(offload_device) - comfy.model_management.get_offload_stream out = Types.VOXEL(decoded.squeeze(1).float()) return IO.NodeOutput(out) @@ -253,10 +250,10 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_output): decoded = structure_output.data - coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1) + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 - latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) - return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "shape_generation"}) + latent = torch.randn(coords.shape[0], in_channels) + return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "shape_generation", "coords": coords}) class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod @@ -275,9 +272,11 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_output): # TODO + decoded = structure_output.data + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 - latent = 
structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1])) - return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "texture_generation"}) + latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) + return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "texture_generation", "coords": coords}) class EmptyStructureLatentTrellis2(IO.ComfyNode): @classmethod From 9243ae347b59cad350e139f4c3bce5b01020f9a4 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 19 Feb 2026 02:01:26 +0200 Subject: [PATCH 27/93] added conditioning --- comfy_extras/nodes_trellis2.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index bd250e5f5..fc9a15cfa 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -2,6 +2,8 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types import torch import comfy.model_management +from PIL import Image +import numpy as np shape_slat_normalization = { "mean": torch.tensor([ @@ -226,6 +228,31 @@ class Trellis2Conditioning(IO.ComfyNode): @classmethod def execute(cls, clip_vision_model, image, mask, background_color) -> IO.NodeOutput: + + if image.ndim == 4: + image = image[0] + + # TODO + image = Image.fromarray(image.numpy()) + max_size = max(image.size) + scale = min(1, 1024 / max_size) + if scale < 1: + image = image.resize((int(image.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) + + output_np = np.array(image) + alpha = output_np[:, :, 3] + bbox = np.argwhere(alpha > 0.8 * 255) + bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) + center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 + size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) + size 
= int(size * 1) + bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 + output = image.crop(bbox) # type: ignore + output = np.array(output).astype(np.float32) / 255 + output = output[:, :, :3] * output[:, :, 3:4] + + image = torch.tensor(output) + # could make 1024 an option conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color) embeds = conditioning["cond_1024"] # should add that From 6191cd86bfcaa1c18df68d2f0c932212dbc0a64d Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 19 Feb 2026 22:05:33 +0200 Subject: [PATCH 28/93] trellis2conditioning and a hidden bug --- comfy/ldm/trellis2/model.py | 13 ++++++++++--- comfy_extras/nodes_trellis2.py | 21 +++++---------------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index b4fc15abc..8579b0580 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -826,7 +826,12 @@ class Trellis2(nn.Module): self.guidance_interval_txt = [0.6, 0.9] def forward(self, x, timestep, context, **kwargs): + # FIXME: should find a way to distinguish between 512/1024 models + # currently assumes 1024 embeds = kwargs.get("embeds") + _, cond = context.chunk(2) + cond = embeds.chunk(2)[0] + context = torch.cat([torch.zeros_like(cond), cond]) mode = kwargs.get("generation_mode") coords = kwargs.get("coords") transformer_options = kwargs.get("transformer_options") @@ -837,12 +842,13 @@ class Trellis2(nn.Module): shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] - if mode in ["shape_generation", "texture_generation"]: + not_struct_mode = mode in ["shape_generation", "texture_generation"] + if not_struct_mode: x = SparseTensor(feats=x, coords=coords) if mode == 
"shape_generation": # TODO - out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)])) + out = self.img2shape(x, timestep, context) elif mode == "texture_generation": out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure @@ -855,5 +861,6 @@ class Trellis2(nn.Module): if shape_rule: out = out.repeat(orig_bsz, 1, 1, 1, 1) - out.generation_mode = mode + if not_struct_mode: + out = out.feats return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index fc9a15cfa..1683949a3 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -233,25 +233,14 @@ class Trellis2Conditioning(IO.ComfyNode): image = image[0] # TODO - image = Image.fromarray(image.numpy()) + image = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + image = Image.fromarray(image) max_size = max(image.size) scale = min(1, 1024 / max_size) if scale < 1: image = image.resize((int(image.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) - output_np = np.array(image) - alpha = output_np[:, :, 3] - bbox = np.argwhere(alpha > 0.8 * 255) - bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) - center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 - size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) - size = int(size * 1) - bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 - output = image.crop(bbox) # type: ignore - output = np.array(output).astype(np.float32) / 255 - output = output[:, :, :3] * output[:, :, 3:4] - - image = torch.tensor(output) + image = torch.tensor(np.array(image)).unsqueeze(0) # could make 1024 an option conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color) @@ -276,7 +265,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_output): - decoded = structure_output.data + 
decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 latent = torch.randn(coords.shape[0], in_channels) @@ -299,7 +288,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_output): # TODO - decoded = structure_output.data + decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) From c5a750205d1a35d5ee8937f8997f8b2c10e37b10 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 20 Feb 2026 17:39:44 +0200 Subject: [PATCH 29/93] . --- comfy/ldm/trellis2/model.py | 14 +++++++++++--- comfy_extras/nodes_trellis2.py | 13 +++++++++---- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 8579b0580..5ff2a1ce0 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -8,6 +8,7 @@ from comfy.ldm.trellis2.attention import ( ) from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder from comfy.ldm.flux.math import apply_rope, apply_rope1 +import builtins class SparseGELU(nn.GELU): def forward(self, input: VarLenTensor) -> VarLenTensor: @@ -481,6 +482,8 @@ class SLatFlowModel(nn.Module): if isinstance(cond, list): cond = VarLenTensor.from_tensor_list(cond) + dtype = next(self.input_layer.parameters()).dtype + x = x.to(dtype) h = self.input_layer(x) h = manual_cast(h, self.dtype) t_emb = self.t_embedder(t, out_dtype = t.dtype) @@ -832,8 +835,14 @@ class Trellis2(nn.Module): _, cond = context.chunk(2) cond = embeds.chunk(2)[0] context = torch.cat([torch.zeros_like(cond), cond]) - mode = kwargs.get("generation_mode") - coords = kwargs.get("coords") + mode = getattr(builtins, "TRELLIS_MODE", "structure_generation") + coords = getattr(builtins, "TRELLIS_COORDS", None) + if 
coords is not None: + x = x.squeeze(0) + not_struct_mode = True + else: + mode = "structure_generation" + not_struct_mode = False transformer_options = kwargs.get("transformer_options") sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: @@ -842,7 +851,6 @@ class Trellis2(nn.Module): shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] - not_struct_mode = mode in ["shape_generation", "texture_generation"] if not_struct_mode: x = SparseTensor(feats=x, coords=coords) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 1683949a3..c8d84fd23 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -4,6 +4,7 @@ import torch import comfy.model_management from PIL import Image import numpy as np +import builtins shape_slat_normalization = { "mean": torch.tensor([ @@ -268,8 +269,10 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 - latent = torch.randn(coords.shape[0], in_channels) - return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "shape_generation", "coords": coords}) + latent = torch.randn(1, coords.shape[0], in_channels) + builtins.TRELLIS_MODE = "shape_generation" + builtins.TRELLIS_COORDS = coords + return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod @@ -292,7 +295,9 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) - return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "texture_generation", "coords": coords}) + builtins.TRELLIS_MODE = 
"texture_generation" + builtins.TRELLIS_COORDS = coords + return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyStructureLatentTrellis2(IO.ComfyNode): @classmethod @@ -312,7 +317,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): in_channels = 8 resolution = 16 latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) - return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "structure_generation"}) + return IO.NodeOutput({"samples": latent, "type": "trellis2"}) def simplify_fn(vertices, faces, target=100000): From f3d4125e4904c427d26aa86c84d26b8dba48fe22 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 20 Feb 2026 20:16:49 +0200 Subject: [PATCH 30/93] code rabbit suggestions --- comfy/image_encoders/dino3.py | 6 +++--- comfy/image_encoders/dino3_large.json | 4 ++-- comfy/ldm/trellis2/cumesh.py | 2 +- comfy/ldm/trellis2/model.py | 8 +++----- comfy/model_detection.py | 4 ++-- 5 files changed, 11 insertions(+), 13 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index ce6b2edd9..e009a7291 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -188,7 +188,7 @@ class DINOv3ViTLayer(nn.Module): device, dtype, operations): super().__init__() - self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps) + self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype) self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations) self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None) @@ -273,8 +273,8 @@ class DINOv3ViTModel(nn.Module): position_embeddings=position_embeddings, ) - self.norm = self.norm.to(hidden_states.device) - sequence_output = self.norm(hidden_states) + norm = self.norm.to(hidden_states.device) + sequence_output = 
norm(hidden_states) pooled_output = sequence_output[:, 0, :] return sequence_output, None, pooled_output, None diff --git a/comfy/image_encoders/dino3_large.json b/comfy/image_encoders/dino3_large.json index 53f761a25..b37b61dc8 100644 --- a/comfy/image_encoders/dino3_large.json +++ b/comfy/image_encoders/dino3_large.json @@ -18,6 +18,6 @@ "rope_theta": 100.0, "use_gated_mlp": false, "value_bias": true, - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225] + "image_mean": [0.485, 0.456, 0.406], + "image_std": [0.229, 0.224, 0.225] } diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index 972fb13c3..cb067a32f 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py @@ -6,7 +6,7 @@ from typing import Dict, Callable NO_TRITION = False try: - allow_tf32 = torch.cuda.is_tf32_supported + allow_tf32 = torch.cuda.is_tf32_supported() except Exception: allow_tf32 = False try: diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 5ff2a1ce0..07cf86d30 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -306,10 +306,7 @@ class ModulatedSparseTransformerBlock(nn.Module): return x def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: - if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) - else: - return self._forward(x, mod) + return self._forward(x, mod) class ModulatedSparseTransformerCrossBlock(nn.Module): @@ -486,6 +483,7 @@ class SLatFlowModel(nn.Module): x = x.to(dtype) h = self.input_layer(x) h = manual_cast(h, self.dtype) + t = t.to(dtype) t_emb = self.t_embedder(t, out_dtype = t.dtype) if self.share_mod: t_emb = self.adaLN_modulation(t_emb) @@ -790,7 +788,7 @@ class SparseStructureFlowModel(nn.Module): return h def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0): - t_shifted /= 1000.0 + t_shifted = t_shifted / 1000.0 t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1)) t_new 
= (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear) t_new *= 1000.0 diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 375cb87b1..6cadc8af6 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -117,12 +117,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config["image_model"] = "trellis2" unet_config["init_txt_model"] = False - if '{}model.shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: + if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: unet_config["init_txt_model"] = True unet_config["resolution"] = 64 if metadata is not None: - if "is_512" in metadata and metadata["metadata"]: + if "is_512" in metadata: unet_config["resolution"] = 32 unet_config["num_heads"] = 12 From b3da8ed4c594744c5a9a8e14577b9a6898ada8f1 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 20 Feb 2026 21:13:13 +0200 Subject: [PATCH 31/93] coderabbit 2 --- comfy/ldm/trellis2/cumesh.py | 7 +++++-- comfy/ldm/trellis2/model.py | 9 ++++----- comfy_extras/nodes_trellis2.py | 37 +++++++++++++++++++++++++--------- 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index cb067a32f..1be8408c6 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py @@ -207,11 +207,14 @@ class TorchHashMap: def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor: flat = flat_keys.long() + if self._n == 0: + return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype) idx = torch.searchsorted(self.sorted_keys, flat) - found = (idx < self._n) & (self.sorted_keys[idx] == flat) + idx_safe = torch.clamp(idx, max=self._n - 1) + found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat) out = torch.full((flat.shape[0],), self.default_value, device=flat.device, 
dtype=self.sorted_vals.dtype) if found.any(): - out[found] = self.sorted_vals[idx[found]] + out[found] = self.sorted_vals[idx_safe[found]] return out diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 07cf86d30..ef1c25d33 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -8,7 +8,6 @@ from comfy.ldm.trellis2.attention import ( ) from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder from comfy.ldm.flux.math import apply_rope, apply_rope1 -import builtins class SparseGELU(nn.GELU): def forward(self, input: VarLenTensor) -> VarLenTensor: @@ -829,19 +828,19 @@ class Trellis2(nn.Module): def forward(self, x, timestep, context, **kwargs): # FIXME: should find a way to distinguish between 512/1024 models # currently assumes 1024 + transformer_options = kwargs.get("transformer_options") embeds = kwargs.get("embeds") - _, cond = context.chunk(2) + #_, cond = context.chunk(2) # TODO cond = embeds.chunk(2)[0] context = torch.cat([torch.zeros_like(cond), cond]) - mode = getattr(builtins, "TRELLIS_MODE", "structure_generation") - coords = getattr(builtins, "TRELLIS_COORDS", None) + coords = transformer_options.get("coords", None) + mode = transformer_options.get("generation_mode", "structure_generation") if coords is not None: x = x.squeeze(0) not_struct_mode = True else: mode = "structure_generation" not_struct_mode = False - transformer_options = kwargs.get("transformer_options") sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: timestep *= 1000.0 diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index c8d84fd23..14f5484d6 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -4,7 +4,6 @@ import torch import comfy.model_management from PIL import Image import numpy as np -import builtins shape_slat_normalization = { "mean": torch.tensor([ @@ -258,21 +257,31 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): category="latent/3d", 
inputs=[ IO.Voxel.Input("structure_output"), + IO.Model.Input("model") ], outputs=[ IO.Latent.Output(), + IO.Model.Output() ] ) @classmethod - def execute(cls, structure_output): + def execute(cls, structure_output, model): decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 latent = torch.randn(1, coords.shape[0], in_channels) - builtins.TRELLIS_MODE = "shape_generation" - builtins.TRELLIS_COORDS = coords - return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + model = model.clone() + if "transformer_options" not in model.model_options: + model.model_options = {} + else: + model.model_options = model.model_options.copy() + + model.model_options["transformer_options"] = model.model_options["transformer_options"].copy() + + model.model_options["transformer_options"]["coords"] = coords + model.model_options["transformer_options"]["generation_mode"] = "shape_generation" + return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model) class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod @@ -285,19 +294,29 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ], outputs=[ IO.Latent.Output(), + IO.Model.Output() ] ) @classmethod - def execute(cls, structure_output): + def execute(cls, structure_output, model): # TODO decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) - builtins.TRELLIS_MODE = "texture_generation" - builtins.TRELLIS_COORDS = coords - return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + model = model.clone() + if "transformer_options" not in model.model_options: + model.model_options = {} + else: + model.model_options = model.model_options.copy() + + model.model_options["transformer_options"] = model.model_options["transformer_options"].copy() + + 
model.model_options["transformer_options"]["coords"] = coords + model.model_options["transformer_options"]["generation_mode"] = "shape_generation" + return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model) + class EmptyStructureLatentTrellis2(IO.ComfyNode): @classmethod From 1fde60b2bc67d39ab4177c1aabc828350245d5f9 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 20 Feb 2026 22:04:37 +0200 Subject: [PATCH 32/93] debugging --- comfy/image_encoders/dino3.py | 2 +- comfy/ldm/trellis2/model.py | 94 +++++++++------------------------- comfy_extras/nodes_trellis2.py | 4 +- 3 files changed, 27 insertions(+), 73 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index e009a7291..3ec7f8a04 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -175,7 +175,7 @@ class DINOv3ViTEmbeddings(nn.Module): cls_token = self.cls_token.expand(batch_size, -1, -1) register_tokens = self.register_tokens.expand(batch_size, -1, -1) - device = patch_embeddings + device = patch_embeddings.device cls_token = cls_token.to(device) register_tokens = register_tokens.to(device) embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index ef1c25d33..7f16c4d41 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -201,6 +201,8 @@ class SparseMultiHeadAttention(nn.Module): def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor: if self._type == "self": + dtype = next(self.to_qkv.parameters()).dtype + x = x.to(dtype) qkv = self._linear(self.to_qkv, x) qkv = self._fused_pre(qkv, num_fused=3) if self.qk_rms_norm or self.use_rope: @@ -243,71 +245,6 @@ class SparseMultiHeadAttention(nn.Module): h = self._linear(self.to_out, h) return h -class ModulatedSparseTransformerBlock(nn.Module): - def __init__( - 
self, - channels: int, - num_heads: int, - mlp_ratio: float = 4.0, - attn_mode: Literal["full", "swin"] = "full", - window_size: Optional[int] = None, - shift_window: Optional[Tuple[int, int, int]] = None, - use_checkpoint: bool = False, - use_rope: bool = False, - rope_freq: Tuple[float, float] = (1.0, 10000.0), - qk_rms_norm: bool = False, - qkv_bias: bool = True, - share_mod: bool = False, - ): - super().__init__() - self.use_checkpoint = use_checkpoint - self.share_mod = share_mod - self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.attn = SparseMultiHeadAttention( - channels, - num_heads=num_heads, - attn_mode=attn_mode, - window_size=window_size, - shift_window=shift_window, - qkv_bias=qkv_bias, - use_rope=use_rope, - rope_freq=rope_freq, - qk_rms_norm=qk_rms_norm, - ) - self.mlp = SparseFeedForwardNet( - channels, - mlp_ratio=mlp_ratio, - ) - if not share_mod: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(channels, 6 * channels, bias=True) - ) - else: - self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) - - def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: - if self.share_mod: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) - else: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) - h = x.replace(self.norm1(x.feats)) - h = h * (1 + scale_msa) + shift_msa - h = self.attn(h) - h = h * gate_msa - x = x + h - h = x.replace(self.norm2(x.feats)) - h = h * (1 + scale_mlp) + shift_mlp - h = self.mlp(h) - h = h * gate_mlp - x = x + h - return x - - def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: - return self._forward(x, mod) - - class ModulatedSparseTransformerCrossBlock(nn.Module): """ Sparse Transformer cross-attention block (MSA + MCA + FFN) 
with adaptive layer norm conditioning. @@ -483,15 +420,13 @@ class SLatFlowModel(nn.Module): h = self.input_layer(x) h = manual_cast(h, self.dtype) t = t.to(dtype) - t_emb = self.t_embedder(t, out_dtype = t.dtype) + t_embedder = self.t_embedder.to(dtype) + t_emb = t_embedder(t, out_dtype = t.dtype) if self.share_mod: t_emb = self.adaLN_modulation(t_emb) t_emb = manual_cast(t_emb, self.dtype) cond = manual_cast(cond, self.dtype) - if self.pe_mode == "ape": - pe = self.pos_embedder(h.coords[:, 1:]) - h = h + manual_cast(pe, self.dtype) for block in self.blocks: h = block(h, t_emb, cond) @@ -849,7 +784,24 @@ class Trellis2(nn.Module): txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] if not_struct_mode: - x = SparseTensor(feats=x, coords=coords) + B, N, C = x.shape + + if mode == "shape_generation": + feats_flat = x.reshape(-1, C) + + # 3. inflate coords [N, 4] -> [B*N, 4] + coords_list = [] + for i in range(B): + c = coords.clone() + c[:, 0] = i + coords_list.append(c) + + batched_coords = torch.cat(coords_list, dim=0) + else: # TODO: texture + # may remove the else if texture doesn't require special handling + batched_coords = coords + feats_flat = x + x = SparseTensor(feats=feats_flat, coords=batched_coords) if mode == "shape_generation": # TODO @@ -868,4 +820,6 @@ class Trellis2(nn.Module): if not_struct_mode: out = out.feats + if mode == "shape_generation": + out = out.view(B, N, -1) return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 14f5484d6..623430b9e 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -238,9 +238,9 @@ class Trellis2Conditioning(IO.ComfyNode): max_size = max(image.size) scale = min(1, 1024 / max_size) if scale < 1: - image = image.resize((int(image.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) + image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS) - image = 
torch.tensor(np.array(image)).unsqueeze(0) + image = torch.tensor(np.array(image)).unsqueeze(0).float() / 255 # could make 1024 an option conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color) From 253ee4c02c8396555cff8bc07c7983fa2ccd0074 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sun, 22 Feb 2026 01:25:10 +0200 Subject: [PATCH 33/93] fixes --- comfy/image_encoders/dino3.py | 5 +---- comfy/ldm/trellis2/cumesh.py | 8 +++++--- comfy/ldm/trellis2/model.py | 11 +++++++---- comfy/ldm/trellis2/vae.py | 2 ++ comfy_extras/nodes_trellis2.py | 17 +++++++++++------ 5 files changed, 26 insertions(+), 17 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index 3ec7f8a04..40ece19ed 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -44,9 +44,6 @@ class DINOv3ViTAttention(nn.Module): self.num_heads = num_attention_heads self.head_dim = self.embed_dim // self.num_heads - self.scaling = self.head_dim**-0.5 - self.is_causal = False - self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) @@ -251,7 +248,7 @@ class DINOv3ViTModel(nn.Module): intermediate_size=intermediate_size,num_attention_heads = num_attention_heads, dtype=dtype, device=device, operations=operations) for _ in range(num_hidden_layers)]) - self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device) + self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device) def get_input_embeddings(self): return self.embeddings.patch_embeddings diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index 1be8408c6..8f677ce24 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py 
@@ -4,7 +4,7 @@ import math import torch from typing import Dict, Callable -NO_TRITION = False +NO_TRITON = False try: allow_tf32 = torch.cuda.is_tf32_supported() except Exception: @@ -115,8 +115,8 @@ try: allow_tf32=allow_tf32, ) return output -except: - NO_TRITION = True +except Exception: + NO_TRITON = True def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device): # offsets in same order as CUDA kernel @@ -364,6 +364,8 @@ def neighbor_map_post_process_for_masked_implicit_gemm_2( def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation): + if NO_TRITON: # TODO + raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.") if len(shape) == 5: N, C, W, H, D = shape else: diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 7f16c4d41..7cc1c1678 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -697,8 +697,6 @@ class SparseStructureFlowModel(nn.Module): def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3) - assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ - f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous() @@ -746,7 +744,8 @@ class Trellis2(nn.Module): super().__init__() self.dtype = dtype # for some reason it passes num_heads = -1 - num_heads = 12 + if num_heads == -1: + num_heads = 12 args = { "out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels, "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, @@ -763,8 +762,10 @@ class Trellis2(nn.Module): def forward(self, x, timestep, context, **kwargs): # FIXME: should find a way to distinguish between 512/1024 models # currently assumes 1024 - transformer_options = 
kwargs.get("transformer_options") + transformer_options = kwargs.get("transformer_options", {}) embeds = kwargs.get("embeds") + if embeds is None: + raise ValueError("Trellis2.forward requires 'embeds' in kwargs") #_, cond = context.chunk(2) # TODO cond = embeds.chunk(2)[0] context = torch.cat([torch.zeros_like(cond), cond]) @@ -807,6 +808,8 @@ class Trellis2(nn.Module): # TODO out = self.img2shape(x, timestep, context) elif mode == "texture_generation": + if self.shape2txt is None: + raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!") out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure timestep = timestep_reshift(timestep) diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 57bf78346..c6ea5deb2 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -1522,6 +1522,8 @@ class Vae(nn.Module): return self.shape_dec(slat, return_subs=True) def decode_tex_slat(self, slat, subs): + if self.txt_dec is None: + raise ValueError("Checkpoint doesn't include texture model") return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5 @torch.no_grad() diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 623430b9e..c688c343d 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,9 +1,11 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types -import torch +import torch.nn.functional as TF import comfy.model_management from PIL import Image import numpy as np +import torch +import copy shape_slat_normalization = { "mean": torch.tensor([ @@ -145,11 +147,11 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): def execute(cls, samples, vae, resolution): vae = vae.first_stage_model samples = samples["samples"] - std = shape_slat_normalization["std"] - mean = shape_slat_normalization["mean"] + std = shape_slat_normalization["std"].to(samples) + mean = shape_slat_normalization["mean"].to(samples) 
samples = samples * std + mean - mesh, subs = vae.decode_shape_slat(resolution, samples) + mesh, subs = vae.decode_shape_slat(samples, resolution) return IO.NodeOutput(mesh, subs) class VaeDecodeTextureTrellis(IO.ComfyNode): @@ -172,8 +174,8 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): def execute(cls, samples, vae, shape_subs): vae = vae.first_stage_model samples = samples["samples"] - std = tex_slat_normalization["std"] - mean = tex_slat_normalization["mean"] + std = tex_slat_normalization["std"].to(samples) + mean = tex_slat_normalization["mean"].to(samples) samples = samples * std + mean mesh = vae.decode_tex_slat(samples, shape_subs) @@ -239,6 +241,8 @@ class Trellis2Conditioning(IO.ComfyNode): scale = min(1, 1024 / max_size) if scale < 1: image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS) + new_h, new_w = int(mask.shape[-2] * scale), int(mask.shape[-1] * scale) + mask = TF.interpolate(mask.unsqueeze(0).float(), size=(new_h, new_w), mode='nearest').squeeze(0) image = torch.tensor(np.array(image)).unsqueeze(0).float() / 255 @@ -510,6 +514,7 @@ class PostProcessMesh(IO.ComfyNode): ) @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): + mesh = copy.deepcopy(mesh) verts, faces = mesh.vertices, mesh.faces if fill_holes_perimeter != 0.0: From c9f5c788a733c68029be33141197706b07f9f669 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sun, 22 Feb 2026 23:47:49 +0200 Subject: [PATCH 34/93] bunch of fixes --- comfy/image_encoders/dino3.py | 3 +- comfy/ldm/trellis2/attention.py | 82 ++++++++++++++++++++++----------- comfy/ldm/trellis2/cumesh.py | 11 ++++- comfy/ldm/trellis2/model.py | 2 +- comfy/ldm/trellis2/vae.py | 21 +++++---- comfy_extras/nodes_trellis2.py | 18 ++++---- 6 files changed, 86 insertions(+), 51 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index 40ece19ed..1bf404498 100644 --- 
a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -2,6 +2,7 @@ import math import torch import torch.nn as nn +import comfy.model_management from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale @@ -225,7 +226,7 @@ class DINOv3ViTLayer(nn.Module): class DINOv3ViTModel(nn.Module): def __init__(self, config, dtype, device, operations): super().__init__() - if dtype == torch.float16: + if dtype == torch.float16 and comfy.model_management.should_use_bf16(device, prioritize_performance=False): dtype = torch.bfloat16 num_hidden_layers = config["num_hidden_layers"] hidden_size = config["hidden_size"] diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index e8e401fd7..19de93b96 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -3,12 +3,56 @@ import math from comfy.ldm.modules.attention import optimized_attention from typing import Tuple, Union, List from comfy.ldm.trellis2.vae import VarLenTensor +import comfy.ops + + +# replica of the seedvr2 code +def var_attn_arg(kwargs): + cu_seqlens_q = kwargs.get("cu_seqlens_q", None) + max_seqlen_q = kwargs.get("max_seqlen_q", None) + cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q) or kwargs.get("cu_seqlens_kv", cu_seqlens_q) + max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q) or kwargs.get("max_kv_seqlen", max_seqlen_q) + assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True" + return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k + +def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + var_length = True + if var_length: + cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs) + if not skip_reshape: + # assumes 2D q, k,v [total_tokens, embed_dim] + total_tokens, embed_dim = q.shape + head_dim = embed_dim // heads + q = 
q.view(total_tokens, heads, head_dim) + k = k.view(k.shape[0], heads, head_dim) + v = v.view(v.shape[0], heads, head_dim) + + b = q.size(0) + dim_head = q.shape[-1] + q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long()) + k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long()) + v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long()) + + mask = None + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + if mask is not None: + if mask.ndim == 2: + mask = mask.unsqueeze(0) + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + if var_length: + return out.contiguous().transpose(1, 2).values() + if not skip_output_reshape: + out = ( + out.transpose(1, 2).reshape(b, -1, heads * dim_head) + ) + return out -FLASH_ATTN_3_AVA = True -try: - import flash_attn_interface as flash_attn_3 # noqa: F401 -except: - FLASH_ATTN_3_AVA = False # TODO repalce with optimized attention def scaled_dot_product_attention(*args, **kwargs): @@ -40,18 +84,10 @@ def scaled_dot_product_attention(*args, **kwargs): k, v = kv.unbind(dim=2) #out = xops.memory_efficient_attention(q, k, v) out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) - elif optimized_attention.__name__ == 'attention_flash' and not FLASH_ATTN_3_AVA: + elif optimized_attention.__name__ == 'attention_flash': if num_all_args == 2: k, v = kv.unbind(dim=2) out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) - elif optimized_attention.__name__ == 'attention_flash': # TODO - if 'flash_attn_3' not in globals(): - import flash_attn_interface as flash_attn_3 - if num_all_args == 2: - k, v = kv.unbind(dim=2) - out = flash_attn_3.flash_attn_func(q, k, v) - elif num_all_args == 3: - out = flash_attn_3.flash_attn_func(q, k, v) elif optimized_attention.__name__ == 'attention_pytorch': if 
num_all_args == 1: q, k, v = qkv.unbind(dim=2) @@ -238,24 +274,16 @@ def sparse_scaled_dot_product_attention(*args, **kwargs): out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) elif num_all_args == 3: out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) - elif optimized_attention.__name__ == 'flash_attn_3': # TODO - if 'flash_attn_3' not in globals(): - import flash_attn_interface as flash_attn_3 + + elif optimized_attention.__name__ == "attention_pytorch": cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args in [2, 3]: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) if num_all_args == 1: q, k, v = qkv.unbind(dim=1) - cu_seqlens_kv = cu_seqlens_q.clone() - max_q_seqlen = max_kv_seqlen = max(q_seqlen) elif num_all_args == 2: k, v = kv.unbind(dim=1) - cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) - max_q_seqlen = max(q_seqlen) - max_kv_seqlen = max(kv_seqlen) - elif num_all_args == 3: - cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) - max_q_seqlen = max(q_seqlen) - max_kv_seqlen = max(kv_seqlen) - out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q_seqlen, max_kv_seqlen) + out = attention_pytorch(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) if s is not None: return s.replace(out) diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index 8f677ce24..ea069a465 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py @@ -3,6 +3,7 @@ import math import torch from typing import Dict, Callable +import logging NO_TRITON = False try: @@ -366,6 +367,10 @@ def 
neighbor_map_post_process_for_masked_implicit_gemm_2( def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation): if NO_TRITON: # TODO raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.") + if feats.shape[0] == 0: + logging.warning("Found feats to be empty!") + Co = weight.shape[0] + return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None if len(shape) == 5: N, C, W, H, D = shape else: @@ -427,9 +432,11 @@ class Voxel: voxel_size: float, coords: torch.Tensor = None, attrs: torch.Tensor = None, - layout: Dict = {}, - device: torch.device = 'cuda' + layout = None, + device: torch.device = None ): + if layout is None: + layout = {} self.origin = torch.tensor(origin, dtype=torch.float32, device=device) self.voxel_size = voxel_size self.coords = coords diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 7cc1c1678..52242b4e0 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -630,7 +630,6 @@ class SparseStructureFlowModel(nn.Module): mlp_ratio: float = 4, pe_mode: Literal["ape", "rope"] = "rope", rope_freq: Tuple[float, float] = (1.0, 10000.0), - dtype: str = 'float32', use_checkpoint: bool = False, share_mod: bool = False, initialization: str = 'vanilla', @@ -638,6 +637,7 @@ class SparseStructureFlowModel(nn.Module): qk_rms_norm_cross: bool = False, operations=None, device = None, + dtype = torch.float32, **kwargs ): super().__init__() diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index c6ea5deb2..36e2f3df5 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -1004,7 +1004,6 @@ class SparseUnetVaeEncoder(nn.Module): self.model_channels = model_channels self.num_blocks = num_blocks self.dtype = torch.float16 if use_fp16 else torch.float32 - self.dtype = torch.float16 if use_fp16 else torch.float32 self.input_layer = SparseLinear(in_channels, model_channels[0]) self.to_latent = 
SparseLinear(model_channels[-1], 2 * latent_channels) @@ -1247,24 +1246,26 @@ def flexible_dual_grid_to_mesh( hashmap_builder=None, # optional callable for building/caching a TorchHashMap ): - if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset"): + device = coords.device + if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset") \ + or flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset.device != device: flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset = torch.tensor([ [[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis [[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis [[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis - ], dtype=torch.int, device=coords.device).unsqueeze(0) - if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1"): - flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=coords.device, requires_grad=False) - if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2"): - flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=coords.device, requires_grad=False) - if not hasattr(flexible_dual_grid_to_mesh, "quad_split_train"): - flexible_dual_grid_to_mesh.quad_split_train = torch.tensor([0, 1, 4, 1, 2, 4, 2, 3, 4, 3, 0, 4], dtype=torch.long, device=coords.device, requires_grad=False) + ], dtype=torch.int, device=device).unsqueeze(0) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1") or flexible_dual_grid_to_mesh.quad_split_1.device != device: + flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2") or flexible_dual_grid_to_mesh.quad_split_2.device != device: + flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_train") or 
flexible_dual_grid_to_mesh.quad_split_train.device != device: + flexible_dual_grid_to_mesh.quad_split_train = torch.tensor([0, 1, 4, 1, 2, 4, 2, 3, 4, 3, 0, 4], dtype=torch.long, device=device, requires_grad=False) # AABB if isinstance(aabb, (list, tuple)): aabb = np.array(aabb) if isinstance(aabb, np.ndarray): - aabb = torch.tensor(aabb, dtype=torch.float32, device=coords.device) + aabb = torch.tensor(aabb, dtype=torch.float32, device=device) # Voxel size if voxel_size is not None: diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index c688c343d..2b44a19eb 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -276,12 +276,11 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): in_channels = 32 latent = torch.randn(1, coords.shape[0], in_channels) model = model.clone() - if "transformer_options" not in model.model_options: - model.model_options = {} + model.model_options = model.model_options.copy() + if "transformer_options" in model.model_options: + model.model_options["transformer_options"] = model.model_options["transformer_options"].copy() else: - model.model_options = model.model_options.copy() - - model.model_options["transformer_options"] = model.model_options["transformer_options"].copy() + model.model_options["transformer_options"] = {} model.model_options["transformer_options"]["coords"] = coords model.model_options["transformer_options"]["generation_mode"] = "shape_generation" @@ -310,12 +309,11 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): in_channels = 32 latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) model = model.clone() - if "transformer_options" not in model.model_options: - model.model_options = {} + model.model_options = model.model_options.copy() + if "transformer_options" in model.model_options: + model.model_options["transformer_options"] = model.model_options["transformer_options"].copy() else: - model.model_options = model.model_options.copy() - 
- model.model_options["transformer_options"] = model.model_options["transformer_options"].copy() + model.model_options["transformer_options"] = {} model.model_options["transformer_options"]["coords"] = coords model.model_options["transformer_options"]["generation_mode"] = "shape_generation" From 2a27c3b41738b375d2820ddde3dc512cb8f8d2b4 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 23 Feb 2026 01:40:56 +0200 Subject: [PATCH 35/93] progressing --- comfy/image_encoders/dino3.py | 5 ++++- comfy/ldm/trellis2/attention.py | 17 ++++++++++++++--- comfy/ldm/trellis2/model.py | 10 +++++----- comfy_extras/nodes_trellis2.py | 11 ++++++++++- 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index 1bf404498..ff17d78d6 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -226,8 +226,11 @@ class DINOv3ViTLayer(nn.Module): class DINOv3ViTModel(nn.Module): def __init__(self, config, dtype, device, operations): super().__init__() - if dtype == torch.float16 and comfy.model_management.should_use_bf16(device, prioritize_performance=False): + use_bf16 = comfy.model_management.should_use_bf16(device, prioritize_performance=True) + if dtype == torch.float16 and use_bf16: dtype = torch.bfloat16 + elif dtype == torch.float16 and not use_bf16: + dtype = torch.float32 num_hidden_layers = config["num_hidden_layers"] hidden_size = config["hidden_size"] num_attention_heads = config["num_attention_heads"] diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 19de93b96..0b9c12294 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -10,8 +10,8 @@ import comfy.ops def var_attn_arg(kwargs): cu_seqlens_q = kwargs.get("cu_seqlens_q", None) max_seqlen_q = kwargs.get("max_seqlen_q", None) - cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q) or kwargs.get("cu_seqlens_kv", 
cu_seqlens_q) - max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q) or kwargs.get("max_kv_seqlen", max_seqlen_q) + cu_seqlens_k = kwargs.get("cu_seqlens_kv", cu_seqlens_q) + max_seqlen_k = kwargs.get("max_kv_seqlen", max_seqlen_q) assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True" return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k @@ -183,6 +183,7 @@ def calc_window_partition( def sparse_scaled_dot_product_attention(*args, **kwargs): + q=None arg_names_dict = { 1: ['qkv'], 2: ['q', 'kv'], @@ -250,6 +251,12 @@ def sparse_scaled_dot_product_attention(*args, **kwargs): k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + # TODO: change + if q is not None: + heads = q + else: + heads = qkv + heads = heads.shape[2] if optimized_attention.__name__ == 'attention_xformers': if 'xops' not in globals(): import xformers.ops as xops @@ -279,11 +286,15 @@ def sparse_scaled_dot_product_attention(*args, **kwargs): cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) if num_all_args in [2, 3]: cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + else: + cu_seqlens_kv = cu_seqlens_q if num_all_args == 1: q, k, v = qkv.unbind(dim=1) elif num_all_args == 2: k, v = kv.unbind(dim=1) - out = attention_pytorch(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + out = attention_pytorch(q, k, v, heads=heads,cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max(q_seqlen), max_kv_seqlen=max(kv_seqlen), + skip_reshape=True, skip_output_reshape=True) if s is not None: return s.replace(out) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 52242b4e0..a565ec37e 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -232,6 +232,8 @@ class SparseMultiHeadAttention(nn.Module): else: q = 
self._linear(self.to_q, x) q = self._reshape_chs(q, (self.num_heads, -1)) + dtype = next(self.to_kv.parameters()).dtype + context = context.to(dtype) kv = self._linear(self.to_kv, context) kv = self._fused_pre(kv, num_fused=2) if self.qk_rms_norm: @@ -760,15 +762,13 @@ class Trellis2(nn.Module): self.guidance_interval_txt = [0.6, 0.9] def forward(self, x, timestep, context, **kwargs): - # FIXME: should find a way to distinguish between 512/1024 models - # currently assumes 1024 transformer_options = kwargs.get("transformer_options", {}) embeds = kwargs.get("embeds") if embeds is None: raise ValueError("Trellis2.forward requires 'embeds' in kwargs") - #_, cond = context.chunk(2) # TODO - cond = embeds.chunk(2)[0] - context = torch.cat([torch.zeros_like(cond), cond]) + is_1024 = self.img2shape.resolution == 1024 + if is_1024: + context = embeds coords = transformer_options.get("coords", None) mode = transformer_options.get("generation_mode", "structure_generation") if coords is not None: diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 2b44a19eb..1b43f7f62 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -2,6 +2,7 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types import torch.nn.functional as TF import comfy.model_management +from comfy.utils import ProgressBar from PIL import Image import numpy as np import torch @@ -250,7 +251,7 @@ class Trellis2Conditioning(IO.ComfyNode): conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color) embeds = conditioning["cond_1024"] # should add that positive = [[conditioning["cond_512"], {"embeds": embeds}]] - negative = [[conditioning["neg_cond"], {"embeds": embeds}]] + negative = [[conditioning["neg_cond"], {"embeds": torch.zeros_like(embeds)}]] return IO.NodeOutput(positive, negative) class EmptyShapeLatentTrellis2(IO.ComfyNode): @@ -512,15 +513,23 @@ 
class PostProcessMesh(IO.ComfyNode): ) @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): + bar = ProgressBar(2) mesh = copy.deepcopy(mesh) verts, faces = mesh.vertices, mesh.faces if fill_holes_perimeter != 0.0: verts, faces = fill_holes_fn(verts, faces, max_hole_perimeter=fill_holes_perimeter) + bar.update(1) + else: + bar.update(1) if simplify != 0: verts, faces = simplify_fn(verts, faces, simplify) + bar.update(1) + else: + bar.update(1) + # potentially adding laplacian smoothing mesh.vertices = verts mesh.faces = faces From a2c8a7aab5e33df51942454f59fa16bb19c8a9c7 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 24 Feb 2026 14:53:54 +0200 Subject: [PATCH 36/93] . --- comfy/ldm/trellis2/model.py | 4 ++-- comfy_extras/nodes_trellis2.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index a565ec37e..fb5276f94 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -772,7 +772,7 @@ class Trellis2(nn.Module): coords = transformer_options.get("coords", None) mode = transformer_options.get("generation_mode", "structure_generation") if coords is not None: - x = x.squeeze(0) + x = x.squeeze(-1).transpose(1, 2) not_struct_mode = True else: mode = "structure_generation" @@ -824,5 +824,5 @@ class Trellis2(nn.Module): if not_struct_mode: out = out.feats if mode == "shape_generation": - out = out.view(B, N, -1) + out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 1b43f7f62..f40ff5161 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -275,7 +275,8 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 - latent = torch.randn(1, coords.shape[0], in_channels) 
+ # image like format + latent = torch.randn(1, in_channels, coords.shape[0], 1) model = model.clone() model.model_options = model.model_options.copy() if "transformer_options" in model.model_options: From f31c2e1d1d7359f05995804f636444313b794068 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 25 Feb 2026 04:26:33 +0200 Subject: [PATCH 37/93] vae shape decode fixes --- comfy/ldm/trellis2/model.py | 2 +- comfy/ldm/trellis2/vae.py | 30 ++++++++++++++++++++---------- comfy/sd.py | 4 ++-- comfy_extras/nodes_trellis2.py | 15 ++++++++++++--- 4 files changed, 35 insertions(+), 16 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index fb5276f94..45740faea 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -802,7 +802,7 @@ class Trellis2(nn.Module): # may remove the else if texture doesn't require special handling batched_coords = coords feats_flat = x - x = SparseTensor(feats=feats_flat, coords=batched_coords) + x = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) if mode == "shape_generation": # TODO diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 36e2f3df5..0b1975092 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -1,11 +1,12 @@ import math import torch -import torch.nn as nn -from typing import List, Any, Dict, Optional, overload, Union, Tuple -from fractions import Fraction -import torch.nn.functional as F -from dataclasses import dataclass import numpy as np +import torch.nn as nn +import comfy.model_management +import torch.nn.functional as F +from fractions import Fraction +from dataclasses import dataclass +from typing import List, Any, Dict, Optional, overload, Union, Tuple from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d @@ -55,11 +56,12 @@ def sparse_conv3d_forward(self, x): Co, Kd, Kh, Kw, Ci = self.weight.shape 
neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}' neighbor_cache = x.get_spatial_cache(neighbor_cache_key) + x = x.to(self.weight.dtype).to(self.weight.device) out, neighbor_cache_ = sparse_submanifold_conv3d( x.feats, x.coords, - torch.Size([*x.shape, *x.spatial_shape]), + x.spatial_shape, self.weight, self.bias, neighbor_cache, @@ -100,7 +102,8 @@ class SparseConvNeXtBlock3d(nn.Module): def _forward(self, x): h = self.conv(x) - h = h.replace(self.norm(h.feats)) + norm = self.norm.to(torch.float32) + h = h.replace(norm(h.feats)) h = h.replace(self.mlp(h.feats)) return h + x @@ -208,13 +211,15 @@ class SparseResBlockC2S3d(nn.Module): def forward(self, x, subdiv = None): if self.pred_subdiv: subdiv = self.to_subdiv(x) - h = x.replace(self.norm1(x.feats)) + norm1 = self.norm1.to(torch.float32) + norm2 = self.norm2.to(torch.float32) + h = x.replace(norm1(x.feats)) h = h.replace(F.silu(h.feats)) h = self.conv1(h) subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None h = self.updown(h, subdiv_binarized) x = self.updown(x, subdiv_binarized) - h = h.replace(self.norm2(h.feats)) + h = h.replace(norm2(h.feats)) h = h.replace(F.silu(h.feats)) h = self.conv2(h) h = h + self.skip_connection(x) @@ -1139,6 +1144,9 @@ class SparseUnetVaeDecoder(nn.Module): def forward(self, x: SparseTensor, guide_subs: Optional[List[SparseTensor]] = None, return_subs: bool = False) -> SparseTensor: + dtype = next(self.from_latent.parameters()).dtype + device = next(self.from_latent.parameters()).device + x.feats = x.feats.to(dtype).to(device) h = self.from_latent(x) h = h.type(self.dtype) subs = [] @@ -1152,7 +1160,7 @@ class SparseUnetVaeDecoder(nn.Module): h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None) else: h = block(h) - h = h.type(x.dtype) + h = h.type(x.feats.dtype) h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) h = self.output_layer(h) if return_subs: @@ -1520,6 +1528,8 @@ class 
Vae(nn.Module): def decode_shape_slat(self, slat, resolution: int): self.shape_dec.set_resolution(resolution) + device = comfy.model_management.get_torch_device() + self.shape_dec = self.shape_dec.to(device) return self.shape_dec(slat, return_subs=True) def decode_tex_slat(self, slat, subs): diff --git a/comfy/sd.py b/comfy/sd.py index fecd16c88..f9898b0de 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -513,8 +513,8 @@ class VAE: init_txt_model = True self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] # TODO - self.memory_used_decode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model) elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index f40ff5161..96510e916 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,8 +1,9 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types +from comfy.ldm.trellis2.vae import SparseTensor +from comfy.utils import ProgressBar import torch.nn.functional as TF import comfy.model_management -from comfy.utils import ProgressBar from PIL import Image import numpy as np import torch @@ -135,6 +136,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): category="latent/3d", inputs=[ IO.Latent.Input("samples"), + IO.Voxel.Input("structure_output"), IO.Vae.Input("vae"), IO.Int.Input("resolution", tooltip="Shape Generation Resolution"), ], @@ -145,11 +147,14 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): ) 
@classmethod - def execute(cls, samples, vae, resolution): + def execute(cls, samples, structure_output, vae, resolution): vae = vae.first_stage_model + decoded = structure_output.data.unsqueeze(1) + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() samples = samples["samples"] std = shape_slat_normalization["std"].to(samples) mean = shape_slat_normalization["mean"].to(samples) + samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean mesh, subs = vae.decode_shape_slat(samples, resolution) @@ -163,6 +168,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): category="latent/3d", inputs=[ IO.Latent.Input("samples"), + IO.Voxel.Input("structure_output"), IO.Vae.Input("vae"), IO.AnyType.Input("shape_subs"), ], @@ -172,11 +178,14 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): ) @classmethod - def execute(cls, samples, vae, shape_subs): + def execute(cls, samples, structure_output, vae, shape_subs): vae = vae.first_stage_model + decoded = structure_output.data.unsqueeze(1) + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() samples = samples["samples"] std = tex_slat_normalization["std"].to(samples) mean = tex_slat_normalization["mean"].to(samples) + samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean mesh = vae.decode_tex_slat(samples, shape_subs) From 39270fdca941fee1d5a4df0c930e28c90505a0ad Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 25 Feb 2026 23:54:03 +0200 Subject: [PATCH 38/93] removed unnecessary code + optimizations + progres --- comfy/ldm/trellis2/cumesh.py | 177 +++++++++------------------------ comfy/ldm/trellis2/vae.py | 144 +++------------------------ comfy_extras/nodes_trellis2.py | 18 +++- 3 files changed, 75 insertions(+), 264 deletions(-) diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index ea069a465..047e785ff 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ 
b/comfy/ldm/trellis2/cumesh.py @@ -2,7 +2,7 @@ import math import torch -from typing import Dict, Callable +from typing import Callable import logging NO_TRITON = False @@ -201,13 +201,13 @@ class TorchHashMap: def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int): device = keys.device # use long for searchsorted - self.sorted_keys, order = torch.sort(keys.long()) - self.sorted_vals = values.long()[order] + self.sorted_keys, order = torch.sort(keys.to(torch.long)) + self.sorted_vals = values.to(torch.long)[order] self.default_value = torch.tensor(default_value, dtype=torch.long, device=device) self._n = self.sorted_keys.numel() def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor: - flat = flat_keys.long() + flat = flat_keys.to(torch.long) if self._n == 0: return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype) idx = torch.searchsorted(self.sorted_keys, flat) @@ -225,44 +225,35 @@ def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map): device = neighbor_map.device N, V = neighbor_map.shape + sentinel = UINT32_SENTINEL - neigh = neighbor_map.to(torch.long) - sentinel = torch.tensor(UINT32_SENTINEL, dtype=torch.long, device=device) - - - neigh_map_T = neigh.t().reshape(-1) - + neigh_map_T = neighbor_map.t().reshape(-1) neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32) - mask = (neigh != sentinel).to(torch.long) + mask = (neighbor_map != sentinel).to(torch.long) + gray_code = torch.zeros(N, dtype=torch.long, device=device) - powers = (1 << torch.arange(V, dtype=torch.long, device=device)) + for v in range(V): + gray_code |= (mask[:, v] << v) - gray_long = (mask * powers).sum(dim=1) - - gray_code = gray_long.to(torch.int32) - - binary_long = gray_long.clone() + binary_code = gray_code.clone() for v in range(1, V): - binary_long ^= (gray_long >> v) - binary_code = binary_long.to(torch.int32) + binary_code ^= (gray_code >> v) sorted_idx = 
torch.argsort(binary_code) - prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T.to(torch.int32), dim=0) # (V*N,) + prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T, dim=0) total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0 if total_valid_signal > 0: + pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0] + to = (prefix_sum_neighbor_mask[pos] - 1).long() + valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device) valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device) - pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0] - - to = (prefix_sum_neighbor_mask[pos] - 1).to(torch.long) - valid_signal_i[to] = (pos % N).to(torch.long) - valid_signal_o[to] = neigh_map_T[pos].to(torch.long) else: valid_signal_i = torch.empty((0,), dtype=torch.long, device=device) @@ -272,9 +263,7 @@ def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map): seg[0] = 0 if V > 0: idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1 - seg[1:] = prefix_sum_neighbor_mask[idxs].to(torch.long) - else: - pass + seg[1:] = prefix_sum_neighbor_mask[idxs] return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg @@ -295,40 +284,41 @@ def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor: def neighbor_map_post_process_for_masked_implicit_gemm_2( - gray_code: torch.Tensor, # [N], int32-like (non-negative) - sorted_idx: torch.Tensor, # [N], long (indexing into gray_code) + gray_code: torch.Tensor, + sorted_idx: torch.Tensor, block_size: int ): device = gray_code.device N = gray_code.numel() - - # num of blocks (same as CUDA) num_blocks = (N + block_size - 1) // block_size - # Ensure dtypes - gray_long = gray_code.to(torch.int64) # safer to OR in 64-bit then cast - sorted_idx = sorted_idx.to(torch.long) - - # 1) Group gray_code by blocks and compute OR across each block - # pad the last block with zeros if necessary so we can reshape pad = 
num_blocks * block_size - N if pad > 0: - pad_vals = torch.zeros((pad,), dtype=torch.int64, device=device) - gray_padded = torch.cat([gray_long[sorted_idx], pad_vals], dim=0) + pad_vals = torch.zeros((pad,), dtype=torch.int32, device=device) + gray_padded = torch.cat([gray_code[sorted_idx], pad_vals], dim=0) else: - gray_padded = gray_long[sorted_idx] + gray_padded = gray_code[sorted_idx] - # reshape to (num_blocks, block_size) and compute bitwise_or across dim=1 - gray_blocks = gray_padded.view(num_blocks, block_size) # each row = block entries - # reduce with bitwise_or - reduced_code = gray_blocks[:, 0].clone() - for i in range(1, block_size): - reduced_code |= gray_blocks[:, i] - reduced_code = reduced_code.to(torch.int32) # match CUDA int32 + gray_blocks = gray_padded.view(num_blocks, block_size) + + reduced_code = gray_blocks + while reduced_code.shape[1] > 1: + half = reduced_code.shape[1] // 2 + remainder = reduced_code.shape[1] % 2 + + left = reduced_code[:, :half * 2:2] + right = reduced_code[:, 1:half * 2:2] + merged = left | right + + if remainder: + reduced_code = torch.cat([merged, reduced_code[:, -1:]], dim=1) + else: + reduced_code = merged + + reduced_code = reduced_code.squeeze(1) + + seglen_counts = _popcount_int32_tensor(reduced_code).to(torch.int32) - # 2) compute seglen (popcount per reduced_code) and seg (prefix sum) - seglen_counts = _popcount_int32_tensor(reduced_code.to(torch.int64)).to(torch.int32) # [num_blocks] - # seg: length num_blocks+1, seg[0]=0, seg[i+1]=cumsum(seglen_counts) up to i seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device) seg[0] = 0 if num_blocks > 0: @@ -336,30 +326,20 @@ def neighbor_map_post_process_for_masked_implicit_gemm_2( total = int(seg[-1].item()) - # 3) scatter — produce valid_kernel_idx as concatenated ascending set-bit positions for each reduced_code row if total == 0: - valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device) - return valid_kernel_idx, seg + return 
torch.empty((0,), dtype=torch.int32, device=device), seg - max_val = int(reduced_code.max().item()) - V = max_val.bit_length() if max_val > 0 else 0 - # If you know V externally, pass it instead or set here explicitly. + V = int(reduced_code.max().item()).bit_length() if reduced_code.max() > 0 else 0 if V == 0: - # no bits set anywhere - valid_kernel_idx = torch.empty((0,), dtype=torch.int32, device=device) - return valid_kernel_idx, seg + return torch.empty((0,), dtype=torch.int32, device=device), seg - # build mask of shape (num_blocks, V): True where bit is set - bit_pos = torch.arange(0, V, dtype=torch.int64, device=device) # [V] - # shifted = reduced_code[:, None] >> bit_pos[None, :] - shifted = reduced_code.to(torch.int64).unsqueeze(1) >> bit_pos.unsqueeze(0) - bits = (shifted & 1).to(torch.bool) # (num_blocks, V) + bit_pos = torch.arange(0, V, dtype=torch.int32, device=device) + shifted = reduced_code.unsqueeze(1) >> bit_pos.unsqueeze(0) + bits = (shifted & 1).to(torch.bool) positions = bit_pos.unsqueeze(0).expand(num_blocks, V) - - valid_positions = positions[bits] - valid_kernel_idx = valid_positions.to(torch.int32).contiguous() + valid_kernel_idx = positions[bits].to(torch.int32).contiguous() return valid_kernel_idx, seg @@ -425,35 +405,6 @@ def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache return out, neighbor -class Voxel: - def __init__( - self, - origin: list, - voxel_size: float, - coords: torch.Tensor = None, - attrs: torch.Tensor = None, - layout = None, - device: torch.device = None - ): - if layout is None: - layout = {} - self.origin = torch.tensor(origin, dtype=torch.float32, device=device) - self.voxel_size = voxel_size - self.coords = coords - self.attrs = attrs - self.layout = layout - self.device = device - - @property - def position(self): - return (self.coords + 0.5) * self.voxel_size + self.origin[None, :] - - def split_attrs(self): - return { - k: self.attrs[:, self.layout[k]] - for k in self.layout - } - 
class Mesh: def __init__(self, vertices, @@ -480,35 +431,3 @@ class Mesh: def cpu(self): return self.to('cpu') - -class MeshWithVoxel(Mesh, Voxel): - def __init__(self, - vertices: torch.Tensor, - faces: torch.Tensor, - origin: list, - voxel_size: float, - coords: torch.Tensor, - attrs: torch.Tensor, - voxel_shape: torch.Size, - layout: Dict = {}, - ): - self.vertices = vertices.float() - self.faces = faces.int() - self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device) - self.voxel_size = voxel_size - self.coords = coords - self.attrs = attrs - self.voxel_shape = voxel_shape - self.layout = layout - - def to(self, device, non_blocking=False): - return MeshWithVoxel( - self.vertices.to(device, non_blocking=non_blocking), - self.faces.to(device, non_blocking=non_blocking), - self.origin.tolist(), - self.voxel_size, - self.coords.to(device, non_blocking=non_blocking), - self.attrs.to(device, non_blocking=non_blocking), - self.voxel_shape, - self.layout, - ) diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 0b1975092..2a18c496a 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from fractions import Fraction from dataclasses import dataclass from typing import List, Any, Dict, Optional, overload, Union, Tuple -from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_submanifold_conv3d +from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, sparse_submanifold_conv3d def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: @@ -210,6 +210,8 @@ class SparseResBlockC2S3d(nn.Module): def forward(self, x, subdiv = None): if self.pred_subdiv: + dtype = next(self.to_subdiv.parameters()).dtype + x = x.to(dtype) subdiv = self.to_subdiv(x) norm1 = self.norm1.to(torch.float32) norm2 = self.norm2.to(torch.float32) @@ -987,114 +989,7 @@ def convert_module_to_f16(l): for p in l.parameters(): p.data = p.data.half() - - -class 
SparseUnetVaeEncoder(nn.Module): - """ - Sparse Swin Transformer Unet VAE model. - """ - def __init__( - self, - in_channels: int, - model_channels: List[int], - latent_channels: int, - num_blocks: List[int], - block_type: List[str], - down_block_type: List[str], - block_args: List[Dict[str, Any]], - use_fp16: bool = False, - ): - super().__init__() - self.in_channels = in_channels - self.model_channels = model_channels - self.num_blocks = num_blocks - self.dtype = torch.float16 if use_fp16 else torch.float32 - - self.input_layer = SparseLinear(in_channels, model_channels[0]) - self.to_latent = SparseLinear(model_channels[-1], 2 * latent_channels) - - self.blocks = nn.ModuleList([]) - for i in range(len(num_blocks)): - self.blocks.append(nn.ModuleList([])) - for j in range(num_blocks[i]): - self.blocks[-1].append( - globals()[block_type[i]]( - model_channels[i], - **block_args[i], - ) - ) - if i < len(num_blocks) - 1: - self.blocks[-1].append( - globals()[down_block_type[i]]( - model_channels[i], - model_channels[i+1], - **block_args[i], - ) - ) - - @property - def device(self) -> torch.device: - return next(self.parameters()).device - - def forward(self, x: SparseTensor, sample_posterior=False, return_raw=False): - h = self.input_layer(x) - h = h.type(self.dtype) - for i, res in enumerate(self.blocks): - for j, block in enumerate(res): - h = block(h) - h = h.type(x.dtype) - h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) - h = self.to_latent(h) - - # Sample from the posterior distribution - mean, logvar = h.feats.chunk(2, dim=-1) - if sample_posterior: - std = torch.exp(0.5 * logvar) - z = mean + std * torch.randn_like(std) - else: - z = mean - z = h.replace(z) - - if return_raw: - return z, mean, logvar - else: - return z - - - -class FlexiDualGridVaeEncoder(SparseUnetVaeEncoder): - def __init__( - self, - model_channels: List[int], - latent_channels: int, - num_blocks: List[int], - block_type: List[str], - down_block_type: List[str], - block_args: 
List[Dict[str, Any]], - use_fp16: bool = False, - ): - super().__init__( - 6, - model_channels, - latent_channels, - num_blocks, - block_type, - down_block_type, - block_args, - use_fp16, - ) - - def forward(self, vertices: SparseTensor, intersected: SparseTensor, sample_posterior=False, return_raw=False): - x = vertices.replace(torch.cat([ - vertices.feats - 0.5, - intersected.feats.float() - 0.5, - ], dim=1)) - return super().forward(x, sample_posterior, return_raw) - class SparseUnetVaeDecoder(nn.Module): - """ - Sparse Swin Transformer Unet VAE model. - """ def __init__( self, out_channels: int, @@ -1218,10 +1113,10 @@ class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder): N = coords.shape[0] # compute flat keys for all coords (prepend batch 0 same as original code) b = torch.zeros((N,), dtype=torch.long, device=device) - x, y, z = coords[:, 0].long(), coords[:, 1].long(), coords[:, 2].long() + x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32) W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) flat_keys = b * (W * H * D) + x * (H * D) + y * D + z - values = torch.arange(N, dtype=torch.long, device=device) + values = torch.arange(N, dtype=torch.int32, device=device) DEFAULT_VAL = 0xffffffff # sentinel used in original code return TorchHashMap(flat_keys, values, DEFAULT_VAL) @@ -1295,13 +1190,12 @@ def flexible_dual_grid_to_mesh( # Extract mesh N = dual_vertices.shape[0] - mesh_vertices = (coords.float() + dual_vertices) / (2 * N) - 0.5 if hashmap_builder is None: # build local TorchHashMap device = coords.device b = torch.zeros((N,), dtype=torch.long, device=device) - x, y, z = coords[:, 0].long(), coords[:, 1].long(), coords[:, 2].long() + x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32) W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) flat_keys = b * (W * H * D) + x * (H * D) + y * D + z 
values = torch.arange(N, dtype=torch.long, device=device) @@ -1316,9 +1210,9 @@ def flexible_dual_grid_to_mesh( M = connected_voxel.shape[0] # flatten connected voxel coords and lookup conn_flat_b = torch.zeros((M * 4,), dtype=torch.long, device=coords.device) - conn_x = connected_voxel.reshape(-1, 3)[:, 0].long() - conn_y = connected_voxel.reshape(-1, 3)[:, 1].long() - conn_z = connected_voxel.reshape(-1, 3)[:, 2].long() + conn_x = connected_voxel.reshape(-1, 3)[:, 0].to(torch.int32) + conn_y = connected_voxel.reshape(-1, 3)[:, 1].to(torch.int32) + conn_z = connected_voxel.reshape(-1, 3)[:, 2].to(torch.int32) W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) conn_flat = conn_flat_b * (W * H * D) + conn_x * (H * D) + conn_y * D + conn_z @@ -1526,17 +1420,18 @@ class Vae(nn.Module): channels=[512, 128, 32], ) + @torch.no_grad() def decode_shape_slat(self, slat, resolution: int): self.shape_dec.set_resolution(resolution) - device = comfy.model_management.get_torch_device() - self.shape_dec = self.shape_dec.to(device) return self.shape_dec(slat, return_subs=True) + @torch.no_grad() def decode_tex_slat(self, slat, subs): if self.txt_dec is None: raise ValueError("Checkpoint doesn't include texture model") return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5 + # shouldn't be called (placeholder) @torch.no_grad() def decode( self, @@ -1546,17 +1441,4 @@ class Vae(nn.Module): ): meshes, subs = self.decode_shape_slat(shape_slat, resolution) tex_voxels = self.decode_tex_slat(tex_slat, subs) - out_mesh = [] - for m, v in zip(meshes, tex_voxels): - out_mesh.append( - MeshWithVoxel( - m.vertices, m.faces, - origin = [-0.5, -0.5, -0.5], - voxel_size = 1 / resolution, - coords = v.coords[:, 1:], - attrs = v.feats, - voxel_shape = torch.Size([*v.shape, *v.spatial_shape]), - layout=self.pbr_attr_layout - ) - ) - return out_mesh + return tex_voxels diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 
96510e916..e781d35e3 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,7 +1,7 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy.ldm.trellis2.vae import SparseTensor -from comfy.utils import ProgressBar +from comfy.utils import ProgressBar, lanczos import torch.nn.functional as TF import comfy.model_management from PIL import Image @@ -102,9 +102,7 @@ def run_conditioning(model, image, mask, include_1024 = True, background_color = cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb) def prepare_tensor(img, size): - resized = torch.nn.functional.interpolate( - img.unsqueeze(0), size=(size, size), mode='bicubic', align_corners=False - ) + resized = lanczos(img.unsqueeze(0), size, size) return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device) model_internal.image_size = 512 @@ -148,10 +146,16 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def execute(cls, samples, structure_output, vae, resolution): + + patcher = vae.patcher + device = comfy.model_management.get_torch_device() + comfy.model_management.load_model_gpu(patcher) + vae = vae.first_stage_model decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() samples = samples["samples"] + samples = samples.squeeze(-1).transpose(1, 2).to(device) std = shape_slat_normalization["std"].to(samples) mean = shape_slat_normalization["mean"].to(samples) samples = SparseTensor(feats = samples, coords=coords) @@ -179,10 +183,16 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): @classmethod def execute(cls, samples, structure_output, vae, shape_subs): + + patcher = vae.patcher + device = comfy.model_management.get_torch_device() + comfy.model_management.load_model_gpu(patcher) + vae = vae.first_stage_model decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() samples = samples["samples"] + 
samples = samples.squeeze(-1).transpose(1, 2).to(device) std = tex_slat_normalization["std"].to(samples) mean = tex_slat_normalization["mean"].to(samples) samples = SparseTensor(feats = samples, coords=coords) From 7d444a4fcca545cf37d3bd42bc3e2d53f0994a3b Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 27 Feb 2026 22:22:07 +0200 Subject: [PATCH 39/93] resolution logit --- comfy/ldm/trellis2/model.py | 2 +- comfy_extras/nodes_trellis2.py | 23 +++++++++++++++++++---- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 45740faea..bd8309f2b 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -812,7 +812,7 @@ class Trellis2(nn.Module): raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!") out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure - timestep = timestep_reshift(timestep) + #timestep = timestep_reshift(timestep) orig_bsz = x.shape[0] if shape_rule: x = x[0].unsqueeze(0) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index e781d35e3..739233523 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -136,7 +136,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): IO.Latent.Input("samples"), IO.Voxel.Input("structure_output"), IO.Vae.Input("vae"), - IO.Int.Input("resolution", tooltip="Shape Generation Resolution"), + IO.Combo.Input("resolution", options=["512", "1024"], default="512") ], outputs=[ IO.Mesh.Output("mesh"), @@ -147,6 +147,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def execute(cls, samples, structure_output, vae, resolution): + resolution = int(resolution) patcher = vae.patcher device = comfy.model_management.get_torch_device() comfy.model_management.load_model_gpu(patcher) @@ -154,14 +155,18 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): vae = vae.first_stage_model 
decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + samples = samples["samples"] - samples = samples.squeeze(-1).transpose(1, 2).to(device) + samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) std = shape_slat_normalization["std"].to(samples) mean = shape_slat_normalization["mean"].to(samples) samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean mesh, subs = vae.decode_shape_slat(samples, resolution) + faces = torch.stack([m.faces for m in mesh]) + verts = torch.stack([m.vertices for m in mesh]) + mesh = Types.MESH(vertices=verts, faces=faces) return IO.NodeOutput(mesh, subs) class VaeDecodeTextureTrellis(IO.ComfyNode): @@ -192,13 +197,16 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() samples = samples["samples"] - samples = samples.squeeze(-1).transpose(1, 2).to(device) + samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) std = tex_slat_normalization["std"].to(samples) mean = tex_slat_normalization["mean"].to(samples) samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean mesh = vae.decode_tex_slat(samples, shape_subs) + faces = torch.stack([m.faces for m in mesh]) + verts = torch.stack([m.vertices for m in mesh]) + mesh = Types.MESH(vertices=verts, faces=faces) return IO.NodeOutput(mesh) class VaeDecodeStructureTrellis2(IO.ComfyNode): @@ -210,6 +218,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): inputs=[ IO.Latent.Input("samples"), IO.Vae.Input("vae"), + IO.Combo.Input("resolution", options=["32", "64"], default="32") ], outputs=[ IO.Voxel.Output("structure_output"), @@ -217,7 +226,8 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, samples, vae): + def execute(cls, samples, vae, resolution): + resolution = int(resolution) vae = vae.first_stage_model 
decoder = vae.struct_dec load_device = comfy.model_management.get_torch_device() @@ -227,6 +237,11 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): samples = samples.to(load_device) decoded = decoder(samples)>0 decoder.to(offload_device) + current_res = decoded.shape[2] + + if current_res != resolution: + ratio = current_res // resolution + decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5 out = Types.VOXEL(decoded.squeeze(1).float()) return IO.NodeOutput(out) From 44adb27782ea1df23ea43ca44fde808ac8b893d2 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 10 Mar 2026 22:27:10 +0200 Subject: [PATCH 40/93] working version --- comfy/ldm/trellis2/attention.py | 2 +- comfy/ldm/trellis2/model.py | 4 +- comfy_extras/nodes_trellis2.py | 430 +++++++++++++++++--------------- 3 files changed, 232 insertions(+), 204 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 0b9c12294..681666430 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -46,7 +46,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) if var_length: - return out.contiguous().transpose(1, 2).values() + return out.transpose(1, 2).values() if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index bd8309f2b..4bbfbff5f 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -767,8 +767,6 @@ class Trellis2(nn.Module): if embeds is None: raise ValueError("Trellis2.forward requires 'embeds' in kwargs") is_1024 = self.img2shape.resolution == 1024 - if is_1024: - context = embeds coords = transformer_options.get("coords", None) mode = transformer_options.get("generation_mode", 
"structure_generation") if coords is not None: @@ -777,6 +775,8 @@ class Trellis2(nn.Module): else: mode = "structure_generation" not_struct_mode = False + if is_1024 and mode == "shape_generation": + context = embeds sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: timestep *= 1000.0 diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 739233523..23b2f72bb 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,9 +1,8 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy.ldm.trellis2.vae import SparseTensor -from comfy.utils import ProgressBar, lanczos -import torch.nn.functional as TF import comfy.model_management +import logging from PIL import Image import numpy as np import torch @@ -39,93 +38,6 @@ tex_slat_normalization = { ])[None] } -dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) -dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) - -def smart_crop_square(image, mask, margin_ratio=0.1, bg_color=(128, 128, 128)): - nz = torch.nonzero(mask[0] > 0.5) - if nz.shape[0] == 0: - C, H, W = image.shape - side = max(H, W) - canvas = torch.full((C, side, side), 0.5, device=image.device) # Gray - canvas[:, (side-H)//2:(side-H)//2+H, (side-W)//2:(side-W)//2+W] = image - return canvas - - y_min, x_min = nz.min(dim=0)[0] - y_max, x_max = nz.max(dim=0)[0] - - obj_w, obj_h = x_max - x_min, y_max - y_min - center_x, center_y = (x_min + x_max) / 2, (y_min + y_max) / 2 - - side = int(max(obj_w, obj_h) * (1 + margin_ratio * 2)) - half_side = side / 2 - - x1, y1 = int(center_x - half_side), int(center_y - half_side) - x2, y2 = x1 + side, y1 + side - - C, H, W = image.shape - canvas = torch.ones((C, side, side), device=image.device) - for c in range(C): - canvas[c] *= (bg_color[c] / 255.0) - - src_x1, src_y1 = max(0, x1), max(0, y1) - src_x2, src_y2 = min(W, x2), min(H, y2) - - dst_x1, dst_y1 = max(0, -x1), max(0, 
-y1) - dst_x2 = dst_x1 + (src_x2 - src_x1) - dst_y2 = dst_y1 + (src_y2 - src_y1) - - img_crop = image[:, src_y1:src_y2, src_x1:src_x2] - mask_crop = mask[0, src_y1:src_y2, src_x1:src_x2] - - bg_val = torch.tensor(bg_color, device=image.device, dtype=image.dtype).view(-1, 1, 1) / 255.0 - - masked_crop = img_crop * mask_crop + bg_val * (1.0 - mask_crop) - - canvas[:, dst_y1:dst_y2, dst_x1:dst_x2] = masked_crop - - return canvas - -def run_conditioning(model, image, mask, include_1024 = True, background_color = "black"): - model_internal = model.model - device = comfy.model_management.intermediate_device() - torch_device = comfy.model_management.get_torch_device() - - bg_colors = {"black": (0, 0, 0), "gray": (128, 128, 128), "white": (255, 255, 255)} - bg_rgb = bg_colors.get(background_color, (128, 128, 128)) - - img_t = image[0].movedim(-1, 0).to(torch_device).float() - mask_t = mask[0].to(torch_device).float() - if mask_t.ndim == 2: - mask_t = mask_t.unsqueeze(0) - - cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb) - - def prepare_tensor(img, size): - resized = lanczos(img.unsqueeze(0), size, size) - return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device) - - model_internal.image_size = 512 - input_512 = prepare_tensor(cropped_img, 512) - cond_512 = model_internal(input_512)[0] - - cond_1024 = None - if include_1024: - model_internal.image_size = 1024 - input_1024 = prepare_tensor(cropped_img, 1024) - cond_1024 = model_internal(input_1024)[0] - - conditioning = { - 'cond_512': cond_512.to(device), - 'neg_cond': torch.zeros_like(cond_512).to(device), - } - if cond_1024 is not None: - conditioning['cond_1024'] = cond_1024.to(device) - - preprocessed_tensor = cropped_img.movedim(0, -1).unsqueeze(0).cpu() - - return conditioning, preprocessed_tensor - class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def define_schema(cls): @@ -245,6 +157,39 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): out = 
Types.VOXEL(decoded.squeeze(1).float()) return IO.NodeOutput(out) +dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) +dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + +def run_conditioning(model, cropped_img_tensor, include_1024=True): + model_internal = model.model + device = comfy.model_management.intermediate_device() + torch_device = comfy.model_management.get_torch_device() + + img_t = cropped_img_tensor.to(torch_device) + + def prepare_tensor(img, size): + resized = torch.nn.functional.interpolate(img, size=(size, size), mode='bicubic', align_corners=False).clamp(0.0, 1.0) + return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device) + + model_internal.image_size = 512 + input_512 = prepare_tensor(img_t, 512) + cond_512 = model_internal(input_512)[0] + + cond_1024 = None + if include_1024: + model_internal.image_size = 1024 + input_1024 = prepare_tensor(img_t, 1024) + cond_1024 = model_internal(input_1024)[0] + + conditioning = { + 'cond_512': cond_512.to(device), + 'neg_cond': torch.zeros_like(cond_512).to(device), + } + if cond_1024 is not None: + conditioning['cond_1024'] = cond_1024.to(device) + + return conditioning + class Trellis2Conditioning(IO.ComfyNode): @classmethod def define_schema(cls): @@ -268,22 +213,60 @@ class Trellis2Conditioning(IO.ComfyNode): if image.ndim == 4: image = image[0] + if mask.ndim == 3: + mask = mask[0] - # TODO - image = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) - image = Image.fromarray(image) - max_size = max(image.size) - scale = min(1, 1024 / max_size) - if scale < 1: - image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS) - new_h, new_w = int(mask.shape[-2] * scale), int(mask.shape[-1] * scale) - mask = TF.interpolate(mask.unsqueeze(0).float(), size=(new_h, new_w), mode='nearest').squeeze(0) + img_np = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + mask_np = (mask.cpu().numpy() * 255).clip(0, 
255).astype(np.uint8) - image = torch.tensor(np.array(image)).unsqueeze(0).float() / 255 + pil_img = Image.fromarray(img_np) + pil_mask = Image.fromarray(mask_np) - # could make 1024 an option - conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color) - embeds = conditioning["cond_1024"] # should add that + max_size = max(pil_img.size) + scale = min(1.0, 1024 / max_size) + if scale < 1.0: + new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale) + pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS) + pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST) + + rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8) + rgba_np[:, :, :3] = np.array(pil_img) + rgba_np[:, :, 3] = np.array(pil_mask) + + alpha = rgba_np[:, :, 3] + bbox_coords = np.argwhere(alpha > 0.8 * 255) + + if len(bbox_coords) > 0: + y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1]) + y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1]) + + center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0 + size = max(y_max - y_min, x_max - x_min) + + crop_x1 = int(center_x - size // 2) + crop_y1 = int(center_y - size // 2) + crop_x2 = int(center_x + size // 2) + crop_y2 = int(center_y + size // 2) + + rgba_pil = Image.fromarray(rgba_np, 'RGBA') + cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2)) + cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0 + else: + logging.warning("Mask for the image is empty. 
Trellis2 requires an image with a mask for the best mesh quality.") + cropped_np = rgba_np.astype(np.float32) / 255.0 + + bg_colors = {"black": [0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]} + bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32) + + fg = cropped_np[:, :, :3] + alpha_float = cropped_np[:, :, 3:4] + composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float) + + cropped_img_tensor = torch.from_numpy(composite_np).movedim(-1, 0).unsqueeze(0).float() + + conditioning = run_conditioning(clip_vision_model, cropped_img_tensor, include_1024=True) + + embeds = conditioning["cond_1024"] positive = [[conditioning["cond_512"], {"embeds": embeds}]] negative = [[conditioning["neg_cond"], {"embeds": torch.zeros_like(embeds)}]] return IO.NodeOutput(positive, negative) @@ -417,118 +400,168 @@ def simplify_fn(vertices, faces, target=100000): return final_vertices, final_faces -def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): +def fill_holes_fn(vertices, faces, max_perimeter=0.03): is_batched = vertices.ndim == 3 if is_batched: - batch_size = vertices.shape[0] - if batch_size > 1: - v_out, f_out = [], [] - for i in range(batch_size): - v, f = fill_holes_fn(vertices[i], faces[i], max_hole_perimeter) - v_out.append(v) - f_out.append(f) - return torch.stack(v_out), torch.stack(f_out) - - vertices = vertices.squeeze(0) - faces = faces.squeeze(0) + v_list, f_list = [],[] + for i in range(vertices.shape[0]): + v_i, f_i = fill_holes_fn(vertices[i], faces[i], max_perimeter) + v_list.append(v_i) + f_list.append(f_i) + return torch.stack(v_list), torch.stack(f_list) device = vertices.device - orig_vertices = vertices - orig_faces = faces + v = vertices + f = faces - edges = torch.cat([ - faces[:, [0, 1]], - faces[:, [1, 2]], - faces[:, [2, 0]] - ], dim=0) + if f.shape[0] == 0: + return v, f + edges = torch.cat([f[:, [0, 1]], f[:, [1, 2]], f[:, [2, 0]]], dim=0) edges_sorted, _ = torch.sort(edges, dim=1) - 
unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True) + + max_v = v.shape[0] + packed_undirected = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long() + + unique_packed, counts = torch.unique(packed_undirected, return_counts=True) boundary_mask = counts == 1 - boundary_edges_sorted = unique_edges[boundary_mask] + boundary_packed = unique_packed[boundary_mask] - if boundary_edges_sorted.shape[0] == 0: - if is_batched: - return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0) - return orig_vertices, orig_faces + if boundary_packed.numel() == 0: + return v, f - max_idx = vertices.shape[0] - - packed_edges_all = torch.sort(edges, dim=1).values - packed_edges_all = packed_edges_all[:, 0] * max_idx + packed_edges_all[:, 1] - - packed_boundary = boundary_edges_sorted[:, 0] * max_idx + boundary_edges_sorted[:, 1] - - is_boundary_edge = torch.isin(packed_edges_all, packed_boundary) - active_boundary_edges = edges[is_boundary_edge] + packed_directed_sorted = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long() + is_boundary = torch.isin(packed_directed_sorted, boundary_packed) + boundary_edges_directed = edges[is_boundary] adj = {} - edges_np = active_boundary_edges.cpu().numpy() - for u, v in edges_np: - adj[u] = v + in_deg = {} + out_deg = {} + + edges_list = boundary_edges_directed.tolist() + for u, v_idx in edges_list: + if u not in adj: adj[u] = [] + adj[u].append(v_idx) + out_deg[u] = out_deg.get(u, 0) + 1 + in_deg[v_idx] = in_deg.get(v_idx, 0) + 1 + + manifold_nodes = set() + for node in set(list(in_deg.keys()) + list(out_deg.keys())): + if in_deg.get(node, 0) == 1 and out_deg.get(node, 0) == 1: + manifold_nodes.add(node) + + loops =[] + visited_nodes = set() - loops = [] - visited_edges = set() - processed_nodes = set() for start_node in list(adj.keys()): - if start_node in processed_nodes: + if start_node not in manifold_nodes or start_node in visited_nodes: continue - current_loop, curr = [], start_node - while curr in 
adj: - next_node = adj[curr] - if (curr, next_node) in visited_edges: - break - visited_edges.add((curr, next_node)) - processed_nodes.add(curr) + + curr = start_node + current_loop =[] + + while True: current_loop.append(curr) + visited_nodes.add(curr) + + next_node = adj[curr][0] + + if next_node == start_node: + if len(current_loop) >= 3: + loops.append(current_loop) + break + + if next_node not in manifold_nodes or next_node in visited_nodes: + break + curr = next_node - if curr == start_node: - loops.append(current_loop) - break - if len(current_loop) > len(edges_np): + + if len(current_loop) > len(edges_list): break - if not loops: - if is_batched: - return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0) - return orig_vertices, orig_faces + new_faces =[] + new_verts = [] + curr_v_idx = v.shape[0] - new_faces = [] - v_offset = vertices.shape[0] - valid_new_verts = [] + for loop in loops: + loop_indices = torch.tensor(loop, device=device, dtype=torch.long) + loop_points = v[loop_indices] - for loop_indices in loops: - if len(loop_indices) < 3: - continue - loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device) - loop_verts = vertices[loop_tensor] - diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0) - perimeter = torch.norm(diffs, dim=1).sum() + # Calculate perimeter + p1 = loop_points + p2 = torch.roll(loop_points, shifts=-1, dims=0) + perimeter = torch.norm(p1 - p2, dim=1).sum().item() - if perimeter > max_hole_perimeter: - continue + if perimeter <= max_perimeter: + centroid = loop_points.mean(dim=0) + new_verts.append(centroid) + center_idx = curr_v_idx + curr_v_idx += 1 - center = loop_verts.mean(dim=0) - valid_new_verts.append(center) - c_idx = v_offset - v_offset += 1 + for i in range(len(loop)): + u_idx = loop[i] + v_next_idx = loop[(i + 1) % len(loop)] + new_faces.append([u_idx, v_next_idx, center_idx]) - num_v = len(loop_indices) - for i in range(num_v): - v_curr, v_next = loop_indices[i], loop_indices[(i + 1) % 
num_v] - new_faces.append([v_curr, v_next, c_idx]) + if new_faces: + v = torch.cat([v, torch.stack(new_verts)], dim=0) + f = torch.cat([f, torch.tensor(new_faces, device=device, dtype=torch.long)], dim=0) - if len(valid_new_verts) > 0: - added_vertices = torch.stack(valid_new_verts, dim=0) - added_faces = torch.tensor(new_faces, dtype=torch.long, device=device) - vertices = torch.cat([orig_vertices, added_vertices], dim=0) - faces = torch.cat([orig_faces, added_faces], dim=0) - else: - vertices, faces = orig_vertices, orig_faces + return v, f +def merge_duplicate_vertices(vertices, faces, tolerance=1e-5): + is_batched = vertices.ndim == 3 if is_batched: - return vertices.unsqueeze(0), faces.unsqueeze(0) + v_list, f_list = [],[] + for i in range(vertices.shape[0]): + v_i, f_i = merge_duplicate_vertices(vertices[i], faces[i], tolerance) + v_list.append(v_i) + f_list.append(f_i) + return torch.stack(v_list), torch.stack(f_list) + v_min = vertices.min(dim=0, keepdim=True)[0] + v_quant = ((vertices - v_min) / tolerance).round().long() + + unique_quant, inverse_indices = torch.unique(v_quant, dim=0, return_inverse=True) + + new_vertices = torch.zeros((unique_quant.shape[0], 3), dtype=vertices.dtype, device=vertices.device) + new_vertices.index_copy_(0, inverse_indices, vertices) + + new_faces = inverse_indices[faces.long()] + + valid = (new_faces[:, 0] != new_faces[:, 1]) & \ + (new_faces[:, 1] != new_faces[:, 2]) & \ + (new_faces[:, 2] != new_faces[:, 0]) + + return new_vertices, new_faces[valid] + +def fix_normals(vertices, faces): + is_batched = vertices.ndim == 3 + if is_batched: + v_list, f_list = [], [] + for i in range(vertices.shape[0]): + v_i, f_i = fix_normals(vertices[i], faces[i]) + v_list.append(v_i) + f_list.append(f_i) + return torch.stack(v_list), torch.stack(f_list) + + if faces.shape[0] == 0: + return vertices, faces + + center = vertices.mean(0) + v0 = vertices[faces[:, 0].long()] + v1 = vertices[faces[:, 1].long()] + v2 = vertices[faces[:, 2].long()] 
+ + normals = torch.cross(v1 - v0, v2 - v0, dim=1) + + face_centers = (v0 + v1 + v2) / 3.0 + dir_from_center = face_centers - center + + dot = (normals * dir_from_center).sum(1) + flip_mask = dot < 0 + + faces[flip_mask] = faces[flip_mask][:, [0, 2, 1]] return vertices, faces class PostProcessMesh(IO.ComfyNode): @@ -539,36 +572,31 @@ class PostProcessMesh(IO.ComfyNode): category="latent/3d", inputs=[ IO.Mesh.Input("mesh"), - IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000), # max? - IO.Float.Input("fill_holes_perimeter", default=0.003, min=0.0, step=0.0001) + IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000), + IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0, step=0.0001) ], outputs=[ IO.Mesh.Output("output_mesh"), ] ) + @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): - bar = ProgressBar(2) mesh = copy.deepcopy(mesh) verts, faces = mesh.vertices, mesh.faces - if fill_holes_perimeter != 0.0: - verts, faces = fill_holes_fn(verts, faces, max_hole_perimeter=fill_holes_perimeter) - bar.update(1) - else: - bar.update(1) + verts, faces = merge_duplicate_vertices(verts, faces, tolerance=1e-5) - if simplify != 0: - verts, faces = simplify_fn(verts, faces, simplify) - bar.update(1) - else: - bar.update(1) + if fill_holes_perimeter > 0: + verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter) - # potentially adding laplacian smoothing + if simplify > 0 and faces.shape[0] > simplify: + verts, faces = simplify_fn(verts, faces, target=simplify) + + verts, faces = fix_normals(verts, faces) mesh.vertices = verts mesh.faces = faces - return IO.NodeOutput(mesh) class Trellis2Extension(ComfyExtension): From 011f624dd548540e92cb8ae08671e36ae3532b6a Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 11 Mar 2026 20:11:58 +0200 Subject: [PATCH 41/93] post-process rewrite + light texture model work --- comfy/ldm/trellis2/model.py | 4 + 
comfy_extras/nodes_trellis2.py | 214 ++++++++++++--------------------- 2 files changed, 84 insertions(+), 134 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 4bbfbff5f..651d516b6 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -810,6 +810,10 @@ class Trellis2(nn.Module): elif mode == "texture_generation": if self.shape2txt is None: raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!") + slat = transformer_options.get("shape_slat") + if slat is None: + raise ValueError("shape_slat can't be None") + x = sparse_cat([x, slat]) out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure #timestep = timestep_reshift(timestep) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 23b2f72bb..0b94a2d0a 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -38,6 +38,13 @@ tex_slat_normalization = { ])[None] } +def shape_norm(shape_latent, coords): + std = shape_slat_normalization["std"].to(shape_latent) + mean = shape_slat_normalization["mean"].to(shape_latent) + samples = SparseTensor(feats = shape_latent, coords=coords) + samples = samples * std + mean + return samples + class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def define_schema(cls): @@ -70,10 +77,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): samples = samples["samples"] samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) - std = shape_slat_normalization["std"].to(samples) - mean = shape_slat_normalization["mean"].to(samples) - samples = SparseTensor(feats = samples, coords=coords) - samples = samples * std + mean + samples = shape_norm(samples, coords) mesh, subs = vae.decode_shape_slat(samples, resolution) faces = torch.stack([m.faces for m in mesh]) @@ -313,6 +317,8 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): category="latent/3d", inputs=[ IO.Voxel.Input("structure_output"), + 
IO.Latent.Input("shape_latent"), + IO.Model.Input("model") ], outputs=[ IO.Latent.Output(), @@ -321,11 +327,15 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_output, model): + def execute(cls, structure_output, shape_latent, model): # TODO decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 + + shape_latent = shape_latent["samples"] + shape_latent = shape_norm(shape_latent, coords) + latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) model = model.clone() model.model_options = model.model_options.copy() @@ -336,6 +346,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"]["coords"] = coords model.model_options["transformer_options"]["generation_mode"] = "shape_generation" + model.model_options["transformer_options"]["shape_slat"] = shape_latent return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model) @@ -360,25 +371,34 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): return IO.NodeOutput({"samples": latent, "type": "trellis2"}) def simplify_fn(vertices, faces, target=100000): + is_batched = vertices.ndim == 3 + if is_batched: + v_list, f_list = [], [] + for i in range(vertices.shape[0]): + v_i, f_i = simplify_fn(vertices[i], faces[i], target) + v_list.append(v_i) + f_list.append(f_i) + return torch.stack(v_list), torch.stack(f_list) - if vertices.shape[0] <= target: + if faces.shape[0] <= target: return vertices, faces - min_feat = vertices.min(dim=0)[0] - max_feat = vertices.max(dim=0)[0] - extent = (max_feat - min_feat).max() + device = vertices.device + target_v = target / 2.0 - grid_resolution = int(torch.sqrt(torch.tensor(target)).item() * 1.5) - voxel_size = extent / grid_resolution + min_v = vertices.min(dim=0)[0] + max_v = vertices.max(dim=0)[0] + extent = max_v - min_v - quantized_coords = ((vertices - min_feat) / voxel_size).long() + volume = 
(extent[0] * extent[1] * extent[2]).clamp(min=1e-8) + cell_size = (volume / target_v) ** (1/3.0) - unique_coords, inverse_indices = torch.unique(quantized_coords, dim=0, return_inverse=True) + quantized = ((vertices - min_v) / cell_size).round().long() + unique_coords, inverse_indices = torch.unique(quantized, dim=0, return_inverse=True) + num_cells = unique_coords.shape[0] - num_new_verts = unique_coords.shape[0] - new_vertices = torch.zeros((num_new_verts, 3), dtype=vertices.dtype, device=vertices.device) - - counts = torch.zeros((num_new_verts, 1), dtype=vertices.dtype, device=vertices.device) + new_vertices = torch.zeros((num_cells, 3), dtype=vertices.dtype, device=device) + counts = torch.zeros((num_cells, 1), dtype=vertices.dtype, device=device) new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices) counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1])) @@ -387,11 +407,9 @@ def simplify_fn(vertices, faces, target=100000): new_faces = inverse_indices[faces] - v0 = new_faces[:, 0] - v1 = new_faces[:, 1] - v2 = new_faces[:, 2] - - valid_mask = (v0 != v1) & (v1 != v2) & (v2 != v0) + valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \ + (new_faces[:, 1] != new_faces[:, 2]) & \ + (new_faces[:, 2] != new_faces[:, 0]) new_faces = new_faces[valid_mask] unique_face_indices, inv_face = torch.unique(new_faces.reshape(-1), return_inverse=True) @@ -414,7 +432,7 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): v = vertices f = faces - if f.shape[0] == 0: + if f.numel() == 0: return v, f edges = torch.cat([f[:, [0, 1]], f[:, [1, 2]], f[:, [2, 0]]], dim=0) @@ -424,145 +442,75 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): packed_undirected = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long() unique_packed, counts = torch.unique(packed_undirected, return_counts=True) - boundary_mask = counts == 1 - boundary_packed = unique_packed[boundary_mask] + boundary_packed = 
unique_packed[counts == 1] if boundary_packed.numel() == 0: return v, f - packed_directed_sorted = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long() + packed_directed_sorted = edges[:, 0].min(edges[:, 1]).long() * max_v + edges[:, 0].max(edges[:, 1]).long() is_boundary = torch.isin(packed_directed_sorted, boundary_packed) - boundary_edges_directed = edges[is_boundary] + b_edges = edges[is_boundary] - adj = {} - in_deg = {} - out_deg = {} - - edges_list = boundary_edges_directed.tolist() - for u, v_idx in edges_list: - if u not in adj: adj[u] = [] - adj[u].append(v_idx) - out_deg[u] = out_deg.get(u, 0) + 1 - in_deg[v_idx] = in_deg.get(v_idx, 0) + 1 - - manifold_nodes = set() - for node in set(list(in_deg.keys()) + list(out_deg.keys())): - if in_deg.get(node, 0) == 1 and out_deg.get(node, 0) == 1: - manifold_nodes.add(node) + adj = {u.item(): v_idx.item() for u, v_idx in b_edges} loops =[] - visited_nodes = set() + visited = set() - for start_node in list(adj.keys()): - if start_node not in manifold_nodes or start_node in visited_nodes: + for start_node in adj.keys(): + if start_node in visited: continue curr = start_node - current_loop =[] + loop = [] - while True: - current_loop.append(curr) - visited_nodes.add(curr) + while curr not in visited: + visited.add(curr) + loop.append(curr) + curr = adj.get(curr, -1) - next_node = adj[curr][0] - - if next_node == start_node: - if len(current_loop) >= 3: - loops.append(current_loop) + if curr == -1: + loop = [] + break + if curr == start_node: + loops.append(loop) break - if next_node not in manifold_nodes or next_node in visited_nodes: - break - - curr = next_node - - if len(current_loop) > len(edges_list): - break - - new_faces =[] - new_verts = [] - curr_v_idx = v.shape[0] + new_verts =[] + new_faces = [] + v_idx = v.shape[0] for loop in loops: - loop_indices = torch.tensor(loop, device=device, dtype=torch.long) - loop_points = v[loop_indices] + loop_t = torch.tensor(loop, device=device, dtype=torch.long) + 
loop_v = v[loop_t] - # Calculate perimeter - p1 = loop_points - p2 = torch.roll(loop_points, shifts=-1, dims=0) - perimeter = torch.norm(p1 - p2, dim=1).sum().item() + diffs = loop_v - torch.roll(loop_v, shifts=-1, dims=0) + perimeter = torch.norm(diffs, dim=1).sum().item() if perimeter <= max_perimeter: - centroid = loop_points.mean(dim=0) - new_verts.append(centroid) - center_idx = curr_v_idx - curr_v_idx += 1 + new_verts.append(loop_v.mean(dim=0)) for i in range(len(loop)): - u_idx = loop[i] - v_next_idx = loop[(i + 1) % len(loop)] - new_faces.append([u_idx, v_next_idx, center_idx]) + new_faces.append([loop[i], loop[(i + 1) % len(loop)], v_idx]) + v_idx += 1 - if new_faces: + if new_verts: v = torch.cat([v, torch.stack(new_verts)], dim=0) f = torch.cat([f, torch.tensor(new_faces, device=device, dtype=torch.long)], dim=0) return v, f -def merge_duplicate_vertices(vertices, faces, tolerance=1e-5): +def make_double_sided(vertices, faces): is_batched = vertices.ndim == 3 if is_batched: - v_list, f_list = [],[] - for i in range(vertices.shape[0]): - v_i, f_i = merge_duplicate_vertices(vertices[i], faces[i], tolerance) - v_list.append(v_i) - f_list.append(f_i) - return torch.stack(v_list), torch.stack(f_list) + f_list =[] + for i in range(faces.shape[0]): + f_inv = faces[i][:,[0, 2, 1]] + f_list.append(torch.cat([faces[i], f_inv], dim=0)) + return vertices, torch.stack(f_list) - v_min = vertices.min(dim=0, keepdim=True)[0] - v_quant = ((vertices - v_min) / tolerance).round().long() - - unique_quant, inverse_indices = torch.unique(v_quant, dim=0, return_inverse=True) - - new_vertices = torch.zeros((unique_quant.shape[0], 3), dtype=vertices.dtype, device=vertices.device) - new_vertices.index_copy_(0, inverse_indices, vertices) - - new_faces = inverse_indices[faces.long()] - - valid = (new_faces[:, 0] != new_faces[:, 1]) & \ - (new_faces[:, 1] != new_faces[:, 2]) & \ - (new_faces[:, 2] != new_faces[:, 0]) - - return new_vertices, new_faces[valid] - -def 
fix_normals(vertices, faces): - is_batched = vertices.ndim == 3 - if is_batched: - v_list, f_list = [], [] - for i in range(vertices.shape[0]): - v_i, f_i = fix_normals(vertices[i], faces[i]) - v_list.append(v_i) - f_list.append(f_i) - return torch.stack(v_list), torch.stack(f_list) - - if faces.shape[0] == 0: - return vertices, faces - - center = vertices.mean(0) - v0 = vertices[faces[:, 0].long()] - v1 = vertices[faces[:, 1].long()] - v2 = vertices[faces[:, 2].long()] - - normals = torch.cross(v1 - v0, v2 - v0, dim=1) - - face_centers = (v0 + v1 + v2) / 3.0 - dir_from_center = face_centers - center - - dot = (normals * dir_from_center).sum(1) - flip_mask = dot < 0 - - faces[flip_mask] = faces[flip_mask][:, [0, 2, 1]] - return vertices, faces + faces_inv = faces[:, [0, 2, 1]] + faces_double = torch.cat([faces, faces_inv], dim=0) + return vertices, faces_double class PostProcessMesh(IO.ComfyNode): @classmethod @@ -572,7 +520,7 @@ class PostProcessMesh(IO.ComfyNode): category="latent/3d", inputs=[ IO.Mesh.Input("mesh"), - IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000), + IO.Int.Input("simplify", default=1_000_000, min=0, max=50_000_000), IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0, step=0.0001) ], outputs=[ @@ -585,15 +533,13 @@ class PostProcessMesh(IO.ComfyNode): mesh = copy.deepcopy(mesh) verts, faces = mesh.vertices, mesh.faces - verts, faces = merge_duplicate_vertices(verts, faces, tolerance=1e-5) - if fill_holes_perimeter > 0: verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter) if simplify > 0 and faces.shape[0] > simplify: verts, faces = simplify_fn(verts, faces, target=simplify) - verts, faces = fix_normals(verts, faces) + verts, faces = make_double_sided(verts, faces) mesh.vertices = verts mesh.faces = faces From 2d904b28da9631a756784b9bd54c4b46b8290522 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 11 Mar 2026 22:50:17 +0200 Subject: 
[PATCH 42/93] upscale node + simple node simplification --- comfy_extras/nodes_trellis2.py | 78 ++++++++++++++++++++++++++++------ 1 file changed, 65 insertions(+), 13 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 0b94a2d0a..86f08f8bd 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -53,7 +53,6 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): category="latent/3d", inputs=[ IO.Latent.Input("samples"), - IO.Voxel.Input("structure_output"), IO.Vae.Input("vae"), IO.Combo.Input("resolution", options=["512", "1024"], default="512") ], @@ -64,7 +63,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): ) @classmethod - def execute(cls, samples, structure_output, vae, resolution): + def execute(cls, samples, vae, resolution): resolution = int(resolution) patcher = vae.patcher @@ -72,8 +71,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): comfy.model_management.load_model_gpu(patcher) vae = vae.first_stage_model - decoded = structure_output.data.unsqueeze(1) - coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + coords = samples["coords"] samples = samples["samples"] samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) @@ -93,7 +91,6 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): category="latent/3d", inputs=[ IO.Latent.Input("samples"), - IO.Voxel.Input("structure_output"), IO.Vae.Input("vae"), IO.AnyType.Input("shape_subs"), ], @@ -103,15 +100,15 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): ) @classmethod - def execute(cls, samples, structure_output, vae, shape_subs): + def execute(cls, samples, vae, shape_subs): patcher = vae.patcher device = comfy.model_management.get_torch_device() comfy.model_management.load_model_gpu(patcher) vae = vae.first_stage_model - decoded = structure_output.data.unsqueeze(1) - coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + coords = samples["coords"] + samples = samples["samples"] samples = 
samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) std = tex_slat_normalization["std"].to(samples) @@ -161,6 +158,56 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): out = Types.VOXEL(decoded.squeeze(1).float()) return IO.NodeOutput(out) +class Trellis2UpsampleCascade(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Trellis2UpsampleCascade", + category="latent/3d", + inputs=[ + IO.Latent.Input("shape_latent_512"), + IO.Vae.Input("vae"), + IO.Combo.Input("target_resolution", options=["1024", "1536"], default="1024"), + IO.Int.Input("max_tokens", default=49152, min=1024, max=100000) + ], + outputs=[ + IO.AnyType.Output("hr_coords"), + ] + ) + + @classmethod + def execute(cls, shape_latent_512, vae, target_resolution, max_tokens): + device = comfy.model_management.get_torch_device() + + feats = shape_latent_512["samples"].squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) + coords_512 = shape_latent_512["coords"].to(device) + + slat = shape_norm(feats, coords_512) + + decoder = vae.first_stage_model.shape_dec + decoder.to(device) + + slat.feats = slat.feats.to(next(decoder.parameters()).dtype) + hr_coords = decoder.upsample(slat, upsample_times=4) + decoder.cpu() + + lr_resolution = 512 + hr_resolution = int(target_resolution) + + while True: + quant_coords = torch.cat([ + hr_coords[:, :1], + ((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), + ], dim=1) + final_coords = quant_coords.unique(dim=0) + num_tokens = final_coords.shape[0] + + if num_tokens < max_tokens or hr_resolution <= 1024: + break + hr_resolution -= 128 + + return IO.NodeOutput(final_coords.cpu()) + dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) @@ -282,7 +329,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): node_id="EmptyShapeLatentTrellis2", category="latent/3d", inputs=[ - IO.Voxel.Input("structure_output"), + 
IO.AnyType.Input("structure_or_coords"), IO.Model.Input("model") ], outputs=[ @@ -292,9 +339,13 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_output, model): - decoded = structure_output.data.unsqueeze(1) - coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + def execute(cls, structure_or_coords, model): + # to accept the upscaled coords + if hasattr(structure_or_coords, "data"): + decoded = structure_or_coords.data.unsqueeze(1) + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + else: + coords = structure_or_coords in_channels = 32 # image like format latent = torch.randn(1, in_channels, coords.shape[0], 1) @@ -307,7 +358,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"]["coords"] = coords model.model_options["transformer_options"]["generation_mode"] = "shape_generation" - return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model) + return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model) class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod @@ -556,6 +607,7 @@ class Trellis2Extension(ComfyExtension): VaeDecodeTextureTrellis, VaeDecodeShapeTrellis, VaeDecodeStructureTrellis2, + Trellis2UpsampleCascade, PostProcessMesh ] From 5d2548822c4519b8fb22b454c15114f7ba0ea26d Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 20 Mar 2026 02:36:01 +0200 Subject: [PATCH 43/93] . 
--- comfy/ldm/trellis2/model.py | 1 + comfy_extras/nodes_trellis2.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 651d516b6..7a3e387c3 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -756,6 +756,7 @@ class Trellis2(nn.Module): self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) if init_txt_model: self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) + self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args) args.pop("out_channels") self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) self.guidance_interval = [0.6, 1.0] diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 86f08f8bd..409d2d23c 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -54,7 +54,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): inputs=[ IO.Latent.Input("samples"), IO.Vae.Input("vae"), - IO.Combo.Input("resolution", options=["512", "1024"], default="512") + IO.Combo.Input("resolution", options=["512", "1024"], default="1024") ], outputs=[ IO.Mesh.Output("mesh"), @@ -116,7 +116,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean - mesh = vae.decode_tex_slat(samples, shape_subs) + mesh = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5 faces = torch.stack([m.faces for m in mesh]) verts = torch.stack([m.vertices for m in mesh]) mesh = Types.MESH(vertices=verts, faces=faces) @@ -541,7 +541,7 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): new_verts.append(loop_v.mean(dim=0)) for i in range(len(loop)): - new_faces.append([loop[i], loop[(i + 1) % len(loop)], v_idx]) + new_faces.append([loop[(i + 1) % len(loop)], loop[i], v_idx]) v_idx += 1 if new_verts: From 
def8947e75b406fa87df78f9c6f4b9d0929eb6d8 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 20 Mar 2026 18:37:11 +0200 Subject: [PATCH 44/93] shape working --- comfy/ldm/trellis2/model.py | 50 +++++++++++++++++++++++++--------- comfy_extras/nodes_trellis2.py | 22 +++++++++++---- 2 files changed, 53 insertions(+), 19 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 7a3e387c3..8a0c6d8b6 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -770,14 +770,20 @@ class Trellis2(nn.Module): is_1024 = self.img2shape.resolution == 1024 coords = transformer_options.get("coords", None) mode = transformer_options.get("generation_mode", "structure_generation") + is_512_run = False + if mode == "shape_generation_512": + is_512_run = True + mode = "shape_generation" if coords is not None: x = x.squeeze(-1).transpose(1, 2) not_struct_mode = True else: mode = "structure_generation" not_struct_mode = False - if is_1024 and mode == "shape_generation": + + if is_1024 and mode == "shape_generation" and not is_512_run: context = embeds + sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: timestep *= 1000.0 @@ -786,12 +792,24 @@ class Trellis2(nn.Module): txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] if not_struct_mode: - B, N, C = x.shape + orig_bsz = x.shape[0] + rule = txt_rule if mode == "texture_generation" else shape_rule - if mode == "shape_generation": - feats_flat = x.reshape(-1, C) + if rule and orig_bsz > 1: + x_eval = x[1].unsqueeze(0) + t_eval = timestep[1].unsqueeze(0) if timestep.shape[0] > 1 else timestep + c_eval = cond + else: + x_eval = x + t_eval = timestep + c_eval = context - # 3. 
inflate coords [N, 4] -> [B*N, 4] + B, N, C = x_eval.shape + + if mode in ["shape_generation", "texture_generation"]: + feats_flat = x_eval.reshape(-1, C) + + # inflate coords [N, 4] -> [B*N, 4] coords_list = [] for i in range(B): c = coords.clone() @@ -799,23 +817,27 @@ class Trellis2(nn.Module): coords_list.append(c) batched_coords = torch.cat(coords_list, dim=0) - else: # TODO: texture - # may remove the else if texture doesn't require special handling + else: batched_coords = coords - feats_flat = x - x = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) + feats_flat = x_eval + + x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) if mode == "shape_generation": # TODO - out = self.img2shape(x, timestep, context) + if is_512_run: + out = self.img2shape_512(x_st, t_eval, c_eval) + else: + out = self.img2shape(x_st, t_eval, c_eval) elif mode == "texture_generation": if self.shape2txt is None: raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!") slat = transformer_options.get("shape_slat") if slat is None: raise ValueError("shape_slat can't be None") - x = sparse_cat([x, slat]) - out = self.shape2txt(x, timestep, context if not txt_rule else cond) + slat.feats = slat.feats.repeat(B, 1) + x_st = sparse_cat([x_st, slat]) + out = self.shape2txt(x_st, t_eval, c_eval) else: # structure #timestep = timestep_reshift(timestep) orig_bsz = x.shape[0] @@ -828,6 +850,8 @@ class Trellis2(nn.Module): if not_struct_mode: out = out.feats - if mode == "shape_generation": + if not_struct_mode: out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) + if rule and orig_bsz > 1: + out = out.repeat(orig_bsz, 1, 1, 1) return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 409d2d23c..cba6b3241 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -178,6 +178,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode): @classmethod def execute(cls, 
shape_latent_512, vae, target_resolution, max_tokens): device = comfy.model_management.get_torch_device() + comfy.model_management.load_model_gpu(vae.patcher) feats = shape_latent_512["samples"].squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) coords_512 = shape_latent_512["coords"].to(device) @@ -185,11 +186,9 @@ class Trellis2UpsampleCascade(IO.ComfyNode): slat = shape_norm(feats, coords_512) decoder = vae.first_stage_model.shape_dec - decoder.to(device) slat.feats = slat.feats.to(next(decoder.parameters()).dtype) hr_coords = decoder.upsample(slat, upsample_times=4) - decoder.cpu() lr_resolution = 512 hr_resolution = int(target_resolution) @@ -206,7 +205,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode): break hr_resolution -= 128 - return IO.NodeOutput(final_coords.cpu()) + return IO.NodeOutput(final_coords,) dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) @@ -341,11 +340,19 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_or_coords, model): # to accept the upscaled coords - if hasattr(structure_or_coords, "data"): + is_512_pass = False + + if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: decoded = structure_or_coords.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + is_512_pass = True + + elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: + coords = structure_or_coords.int() + is_512_pass = False + else: - coords = structure_or_coords + raise ValueError(f"Invalid input to EmptyShapeLatent: {type(structure_or_coords)}") in_channels = 32 # image like format latent = torch.randn(1, in_channels, coords.shape[0], 1) @@ -357,7 +364,10 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"] = {} model.model_options["transformer_options"]["coords"] = coords - 
model.model_options["transformer_options"]["generation_mode"] = "shape_generation" + if is_512_pass: + model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512" + else: + model.model_options["transformer_options"]["generation_mode"] = "shape_generation" return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model) class EmptyTextureLatentTrellis2(IO.ComfyNode): From 56e52e5d03f52c407cf529c8b211a1636a3ed221 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 24 Mar 2026 02:44:03 +0200 Subject: [PATCH 45/93] work on txt gen --- comfy/ldm/trellis2/model.py | 15 ++++++++------- comfy_extras/nodes_trellis2.py | 33 ++++++++++++++++++--------------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 8a0c6d8b6..34aeba3e1 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -754,6 +754,7 @@ class Trellis2(nn.Module): "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations } self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) + self.shape2txt = None if init_txt_model: self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args) @@ -835,11 +836,12 @@ class Trellis2(nn.Module): slat = transformer_options.get("shape_slat") if slat is None: raise ValueError("shape_slat can't be None") - slat.feats = slat.feats.repeat(B, 1) - x_st = sparse_cat([x_st, slat]) + + base_slat_feats = slat.feats[:N] + slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device) + x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1)) out = self.shape2txt(x_st, t_eval, c_eval) else: # structure - #timestep = timestep_reshift(timestep) orig_bsz = x.shape[0] 
if shape_rule: x = x[0].unsqueeze(0) @@ -850,8 +852,7 @@ class Trellis2(nn.Module): if not_struct_mode: out = out.feats - if not_struct_mode: - out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) - if rule and orig_bsz > 1: - out = out.repeat(orig_bsz, 1, 1, 1) + out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) + if rule and orig_bsz > 1: + out = out.repeat(orig_bsz, 1, 1, 1) return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index cba6b3241..dcf8dcb98 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -95,7 +95,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): IO.AnyType.Input("shape_subs"), ], outputs=[ - IO.Mesh.Output("mesh"), + IO.Voxel.Output("voxel"), ] ) @@ -116,11 +116,9 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean - mesh = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5 - faces = torch.stack([m.faces for m in mesh]) - verts = torch.stack([m.vertices for m in mesh]) - mesh = Types.MESH(vertices=verts, faces=faces) - return IO.NodeOutput(mesh) + voxel = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5 + voxel = Types.VOXEL(voxel) + return IO.NodeOutput(voxel) class VaeDecodeStructureTrellis2(IO.ComfyNode): @classmethod @@ -377,7 +375,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): node_id="EmptyTextureLatentTrellis2", category="latent/3d", inputs=[ - IO.Voxel.Input("structure_output"), + IO.Voxel.Input("structure_or_coords"), IO.Latent.Input("shape_latent"), IO.Model.Input("model") ], @@ -388,16 +386,21 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_output, shape_latent, model): - # TODO - decoded = structure_output.data.unsqueeze(1) - coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() - in_channels = 32 + def execute(cls, structure_or_coords, shape_latent, model): + channels = 32 + if hasattr(structure_or_coords, 
"data") and structure_or_coords.data.ndim == 4: + decoded = structure_or_coords.data.unsqueeze(1) + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + + elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: + coords = structure_or_coords.int() shape_latent = shape_latent["samples"] + if shape_latent.ndim == 4: + shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels) shape_latent = shape_norm(shape_latent, coords) - latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) + latent = torch.randn(1, channels, coords.shape[0], 1) model = model.clone() model.model_options = model.model_options.copy() if "transformer_options" in model.model_options: @@ -406,9 +409,9 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"] = {} model.model_options["transformer_options"]["coords"] = coords - model.model_options["transformer_options"]["generation_mode"] = "shape_generation" + model.model_options["transformer_options"]["generation_mode"] = "texture_generation" model.model_options["transformer_options"]["shape_slat"] = shape_latent - return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model) + return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model) class EmptyStructureLatentTrellis2(IO.ComfyNode): From fe25190cae5e3a0bbaf20aab8e4f2d5f25dd0538 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 25 Mar 2026 02:40:15 +0200 Subject: [PATCH 46/93] add color support for save mesh --- comfy_extras/nodes_hunyuan3d.py | 17 +++++++++++-- comfy_extras/nodes_trellis2.py | 43 +++++++++++++++++++++++++++++---- 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index df0c3e4b1..692834c2b 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -484,7 +484,7 @@ 
class VoxelToMesh(IO.ComfyNode): decode = execute # TODO: remove -def save_glb(vertices, faces, filepath, metadata=None): +def save_glb(vertices, faces, filepath, metadata=None, colors=None): """ Save PyTorch tensor vertices and faces as a GLB file without external dependencies. @@ -515,6 +515,13 @@ def save_glb(vertices, faces, filepath, metadata=None): indices_byte_length = len(indices_buffer) indices_byte_offset = len(vertices_buffer_padded) + if colors is not None: + colors_np = colors.cpu().numpy().astype(np.float32) + colors_buffer = colors_np.tobytes() + colors_byte_length = len(colors_buffer) + colors_byte_offset = len(buffer_data) + buffer_data += pad_to_4_bytes(colors_buffer) + gltf = { "asset": {"version": "2.0", "generator": "ComfyUI"}, "buffers": [ @@ -580,6 +587,11 @@ def save_glb(vertices, faces, filepath, metadata=None): "scene": 0 } + if colors is not None: + gltf["bufferViews"].append({"buffer": 0, "byteOffset": colors_byte_offset, "byteLength": colors_byte_length, "target": 34962}) + gltf["accessors"].append({"bufferView": 2, "byteOffset": 0, "componentType": 5126, "count": len(colors_np), "type": "VEC3"}) + gltf["meshes"][0]["primitives"][0]["attributes"]["COLOR_0"] = 2 + if metadata is not None: gltf["asset"]["extras"] = metadata @@ -669,7 +681,8 @@ class SaveGLB(IO.ComfyNode): # Handle Mesh input - save vertices and faces as GLB for i in range(mesh.vertices.shape[0]): f = f"{filename}_{counter:05}_.glb" - save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata) + v_colors = mesh.colors[i] if hasattr(mesh, "colors") and mesh.colors is not None else None + save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata, v_colors) results.append({ "filename": f, "subfolder": subfolder, diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index dcf8dcb98..4126fb536 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -45,6 +45,34 @@ def 
shape_norm(shape_latent, coords): samples = samples * std + mean return samples +def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution, chunk_size=4096): + """ + Generic function to paint a mesh using nearest-neighbor colors from a sparse voxel field. + Keeps chunking internal to prevent OOM crashes on large matrices. + """ + device = voxel_coords.device + + # Map Voxel Grid to Real 3D Space + origin = torch.tensor([-0.5, -0.5, -0.5], device=device) + voxel_size = 1.0 / resolution + voxel_pos = voxel_coords.float() * voxel_size + origin + + verts = mesh.vertices.to(device).squeeze(0) + v_colors = torch.zeros((verts.shape[0], 3), device=device) + + for i in range(0, verts.shape[0], chunk_size): + v_chunk = verts[i : i + chunk_size] + dists = torch.cdist(v_chunk, voxel_pos) + nearest_idx = torch.argmin(dists, dim=1) + v_colors[i : i + chunk_size] = voxel_colors[nearest_idx] + + final_colors = (v_colors * 0.5 + 0.5).clamp(0, 1).unsqueeze(0) + + out_mesh = copy.deepcopy(mesh) + out_mesh.colors = final_colors + + return out_mesh + class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def define_schema(cls): @@ -90,18 +118,20 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): node_id="VaeDecodeTextureTrellis", category="latent/3d", inputs=[ + IO.Mesh.Input("shape_mesh"), IO.Latent.Input("samples"), IO.Vae.Input("vae"), IO.AnyType.Input("shape_subs"), ], outputs=[ - IO.Voxel.Output("voxel"), + IO.Mesh.Output("mesh"), ] ) @classmethod - def execute(cls, samples, vae, shape_subs): + def execute(cls, shape_mesh, samples, vae, shape_subs): + resolution = 1024 patcher = vae.patcher device = comfy.model_management.get_torch_device() comfy.model_management.load_model_gpu(patcher) @@ -116,9 +146,12 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean - voxel = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5 - voxel = Types.VOXEL(voxel) - return IO.NodeOutput(voxel) + voxel = 
vae.decode_tex_slat(samples, shape_subs) + color_feats = voxel.feats[:, :3] + voxel_coords = voxel.coords[:, 1:] + + out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution) + return IO.NodeOutput(out_mesh) class VaeDecodeStructureTrellis2(IO.ComfyNode): @classmethod From d2c37c222a9cc869e8aff1783a2add7b333c0c69 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 25 Mar 2026 17:03:52 +0200 Subject: [PATCH 47/93] pytorch -> scipy --- comfy_extras/nodes_trellis2.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 4126fb536..469b460eb 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -6,6 +6,7 @@ import logging from PIL import Image import numpy as np import torch +import scipy import copy shape_slat_normalization = { @@ -45,26 +46,30 @@ def shape_norm(shape_latent, coords): samples = samples * std + mean return samples -def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution, chunk_size=4096): +def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): """ Generic function to paint a mesh using nearest-neighbor colors from a sparse voxel field. - Keeps chunking internal to prevent OOM crashes on large matrices. 
""" - device = voxel_coords.device + device = comfy.model_management.vae_offload_device() - # Map Voxel Grid to Real 3D Space origin = torch.tensor([-0.5, -0.5, -0.5], device=device) voxel_size = 1.0 / resolution - voxel_pos = voxel_coords.float() * voxel_size + origin + # map voxels + voxel_pos = voxel_coords.to(device).float() * voxel_size + origin verts = mesh.vertices.to(device).squeeze(0) - v_colors = torch.zeros((verts.shape[0], 3), device=device) + voxel_colors = voxel_colors.to(device) - for i in range(0, verts.shape[0], chunk_size): - v_chunk = verts[i : i + chunk_size] - dists = torch.cdist(v_chunk, voxel_pos) - nearest_idx = torch.argmin(dists, dim=1) - v_colors[i : i + chunk_size] = voxel_colors[nearest_idx] + voxel_pos_np = voxel_pos.numpy() + verts_np = verts.numpy() + + tree = scipy.spatial.cKDTree(voxel_pos_np) + + # nearest neighbour k=1 + _, nearest_idx_np = tree.query(verts_np, k=1, workers=-1) + + nearest_idx = torch.from_numpy(nearest_idx_np).long() + v_colors = voxel_colors[nearest_idx] final_colors = (v_colors * 0.5 + 0.5).clamp(0, 1).unsqueeze(0) @@ -343,7 +348,11 @@ class Trellis2Conditioning(IO.ComfyNode): alpha_float = cropped_np[:, :, 3:4] composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float) - cropped_img_tensor = torch.from_numpy(composite_np).movedim(-1, 0).unsqueeze(0).float() + # to match trellis2 code (quantize -> dequantize) + composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8) + + cropped_img_tensor = torch.from_numpy(composite_uint8).float() / 255.0 + cropped_img_tensor = cropped_img_tensor.movedim(-1, 0).unsqueeze(0) conditioning = run_conditioning(clip_vision_model, cropped_img_tensor, include_1024=True) From 22dcd81fb3b8360f4d602dda97f62187921e97f3 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 25 Mar 2026 23:51:54 +0200 Subject: [PATCH 48/93] fixed color addition --- comfy/ldm/trellis2/model.py | 3 +-- comfy_extras/nodes_hunyuan3d.py 
| 3 +++ comfy_extras/nodes_trellis2.py | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 34aeba3e1..613f7ef50 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -782,7 +782,7 @@ class Trellis2(nn.Module): mode = "structure_generation" not_struct_mode = False - if is_1024 and mode == "shape_generation" and not is_512_run: + if is_1024 and not_struct_mode and not is_512_run: context = embeds sigmas = transformer_options.get("sigmas")[0].item() @@ -825,7 +825,6 @@ class Trellis2(nn.Module): x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) if mode == "shape_generation": - # TODO if is_512_run: out = self.img2shape_512(x_st, t_eval, c_eval) else: diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 692834c2b..ac91fe0a7 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -591,6 +591,9 @@ def save_glb(vertices, faces, filepath, metadata=None, colors=None): gltf["bufferViews"].append({"buffer": 0, "byteOffset": colors_byte_offset, "byteLength": colors_byte_length, "target": 34962}) gltf["accessors"].append({"bufferView": 2, "byteOffset": 0, "componentType": 5126, "count": len(colors_np), "type": "VEC3"}) gltf["meshes"][0]["primitives"][0]["attributes"]["COLOR_0"] = 2 + # Define a base material so Three.js actually activates vertex coloring + gltf["materials"] =[{"pbrMetallicRoughness": {"baseColorFactor": [1.0, 1.0, 1.0, 1.0]}}] + gltf["meshes"][0]["primitives"][0]["material"] = 0 if metadata is not None: gltf["asset"]["extras"] = metadata diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 469b460eb..ff95e8332 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -53,6 +53,7 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): device = comfy.model_management.vae_offload_device() origin = 
torch.tensor([-0.5, -0.5, -0.5], device=device) + # TODO: generic independent node? if so: figure how pass the resolution parameter voxel_size = 1.0 / resolution # map voxels From 55595697356038132c643bd6d5a5fa1a95cc9814 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 27 Mar 2026 20:52:23 +0200 Subject: [PATCH 49/93] .. --- comfy/ldm/trellis2/model.py | 4 ++-- comfy_extras/nodes_trellis2.py | 6 +++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 613f7ef50..40646f369 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -843,8 +843,8 @@ class Trellis2(nn.Module): else: # structure orig_bsz = x.shape[0] if shape_rule: - x = x[0].unsqueeze(0) - timestep = timestep[0].unsqueeze(0) + x = x[1].unsqueeze(0) + timestep = timestep[1].unsqueeze(0) out = self.structure_model(x, timestep, context if not shape_rule else cond) if shape_rule: out = out.repeat(orig_bsz, 1, 1, 1, 1) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index ff95e8332..f58dbf592 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -441,7 +441,11 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): shape_latent = shape_latent["samples"] if shape_latent.ndim == 4: shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels) - shape_latent = shape_norm(shape_latent, coords) + + std = shape_slat_normalization["std"].to(shape_latent) + mean = shape_slat_normalization["mean"].to(shape_latent) + shape_latent = SparseTensor(feats = shape_latent, coords=coords) + shape_latent = (shape_latent - mean) / std latent = torch.randn(1, channels, coords.shape[0], 1) model = model.clone() From 72640888ffc6bcdda612432c27c8528cf0e76da8 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 30 Mar 2026 23:20:46 +0200 Subject: [PATCH 50/93] wrong 
normalization for the texture node --- comfy_extras/nodes_trellis2.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index f58dbf592..77e6a3add 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -442,11 +442,6 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): if shape_latent.ndim == 4: shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels) - std = shape_slat_normalization["std"].to(shape_latent) - mean = shape_slat_normalization["mean"].to(shape_latent) - shape_latent = SparseTensor(feats = shape_latent, coords=coords) - shape_latent = (shape_latent - mean) / std - latent = torch.randn(1, channels, coords.shape[0], 1) model = model.clone() model.model_options = model.model_options.copy() From 57b306464ecb2b87b93491198a632c720d0bb85c Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 3 Apr 2026 01:22:38 +0200 Subject: [PATCH 51/93] texture generation works --- comfy/ldm/trellis2/attention.py | 60 +++++++++++---------------------- comfy/ldm/trellis2/model.py | 7 ++-- comfy_extras/nodes_trellis2.py | 8 ++++- 3 files changed, 32 insertions(+), 43 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 681666430..d95b071b5 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -53,57 +53,37 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha ) return out - -# TODO repalce with optimized attention def scaled_dot_product_attention(*args, **kwargs): num_all_args = len(args) + len(kwargs) q = None if num_all_args == 1: - qkv = args[0] if len(args) > 0 else kwargs['qkv'] - + qkv = args[0] if len(args) > 0 else kwargs.get('qkv') elif num_all_args == 2: - q = args[0] if len(args) > 0 else kwargs['q'] - kv = args[1] if len(args) > 1 else kwargs['kv'] - + q = args[0] if len(args) > 0 else 
kwargs.get('q') + kv = args[1] if len(args) > 1 else kwargs.get('kv') elif num_all_args == 3: - q = args[0] if len(args) > 0 else kwargs['q'] - k = args[1] if len(args) > 1 else kwargs['k'] - v = args[2] if len(args) > 2 else kwargs['v'] + q = args[0] if len(args) > 0 else kwargs.get('q') + k = args[1] if len(args) > 1 else kwargs.get('k') + v = args[2] if len(args) > 2 else kwargs.get('v') if q is not None: - heads = q + heads = q.shape[2] else: - heads = qkv - heads = heads.shape[2] + heads = qkv.shape[3] - if optimized_attention.__name__ == 'attention_xformers': - if num_all_args == 1: - q, k, v = qkv.unbind(dim=2) - elif num_all_args == 2: - k, v = kv.unbind(dim=2) - #out = xops.memory_efficient_attention(q, k, v) - out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) - elif optimized_attention.__name__ == 'attention_flash': - if num_all_args == 2: - k, v = kv.unbind(dim=2) - out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) - elif optimized_attention.__name__ == 'attention_pytorch': - if num_all_args == 1: - q, k, v = qkv.unbind(dim=2) - elif num_all_args == 2: - k, v = kv.unbind(dim=2) - q = q.permute(0, 2, 1, 3) # [N, H, L, C] - k = k.permute(0, 2, 1, 3) # [N, H, L, C] - v = v.permute(0, 2, 1, 3) # [N, H, L, C] - out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) - out = out.permute(0, 2, 1, 3) # [N, L, H, C] - elif optimized_attention.__name__ == 'attention_basic': - if num_all_args == 1: - q, k, v = qkv.unbind(dim=2) - elif num_all_args == 2: - k, v = kv.unbind(dim=2) - out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True, **kwargs) + + out 
= out.permute(0, 2, 1, 3) return out diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 40646f369..7c6ffdd69 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -788,7 +788,10 @@ class Trellis2(nn.Module): sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: timestep *= 1000.0 - cond = context.chunk(2)[1] + if context.size(0) > 1: + cond = context.chunk(2)[1] + else: + cond = context shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] @@ -836,7 +839,7 @@ class Trellis2(nn.Module): if slat is None: raise ValueError("shape_slat can't be None") - base_slat_feats = slat.feats[:N] + base_slat_feats = slat[:N] slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device) x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1)) out = self.shape2txt(x_st, t_eval, c_eval) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 77e6a3add..088cdd3f1 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -72,7 +72,13 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): nearest_idx = torch.from_numpy(nearest_idx_np).long() v_colors = voxel_colors[nearest_idx] - final_colors = (v_colors * 0.5 + 0.5).clamp(0, 1).unsqueeze(0) + # to [0, 1] + srgb_colors = (v_colors * 0.5 + 0.5).clamp(0, 1) + + # to Linear RGB (required for GLTF) + linear_colors = torch.pow(srgb_colors, 2.2) + + final_colors = linear_colors.unsqueeze(0) out_mesh = copy.deepcopy(mesh) out_mesh.colors = final_colors From 2cb06431e8e481b41c2c4da2aa92ba30ea07d66c Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 7 Apr 2026 23:02:33 +0200 Subject: [PATCH 52/93] fix for conditioning --- comfy/ldm/trellis2/model.py | 2 +- comfy_extras/nodes_trellis2.py | 29 
++++++++++++++--------------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 7c6ffdd69..ea7ada9f8 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -671,7 +671,7 @@ class SparseStructureFlowModel(nn.Module): coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij') coords = torch.stack(coords, dim=-1).reshape(-1, 3) rope_phases = pos_embedder(coords) - self.register_buffer("rope_phases", rope_phases) + self.register_buffer("rope_phases", rope_phases, persistent=False) if pe_mode != "rope": self.rope_phases = None diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 088cdd3f1..d3f5e4940 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -2,7 +2,6 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy.ldm.trellis2.vae import SparseTensor import comfy.model_management -import logging from PIL import Image import numpy as np import torch @@ -250,28 +249,28 @@ class Trellis2UpsampleCascade(IO.ComfyNode): return IO.NodeOutput(final_coords,) -dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) -dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) +dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) +dino_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) def run_conditioning(model, cropped_img_tensor, include_1024=True): model_internal = model.model device = comfy.model_management.intermediate_device() torch_device = comfy.model_management.get_torch_device() - img_t = cropped_img_tensor.to(torch_device) - - def prepare_tensor(img, size): - resized = torch.nn.functional.interpolate(img, size=(size, size), mode='bicubic', align_corners=False).clamp(0.0, 1.0) - return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device) + def prepare_tensor(pil_img, 
size): + resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS) + img_np = np.array(resized_pil).astype(np.float32) / 255.0 + img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device) + return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device) model_internal.image_size = 512 - input_512 = prepare_tensor(img_t, 512) + input_512 = prepare_tensor(cropped_img_tensor, 512) cond_512 = model_internal(input_512)[0] cond_1024 = None if include_1024: model_internal.image_size = 1024 - input_1024 = prepare_tensor(img_t, 1024) + input_1024 = prepare_tensor(cropped_img_tensor, 1024) cond_1024 = model_internal(input_1024)[0] conditioning = { @@ -341,14 +340,15 @@ class Trellis2Conditioning(IO.ComfyNode): crop_x2 = int(center_x + size // 2) crop_y2 = int(center_y + size // 2) - rgba_pil = Image.fromarray(rgba_np, 'RGBA') + rgba_pil = Image.fromarray(rgba_np) cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2)) cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0 else: + import logging logging.warning("Mask for the image is empty. 
Trellis2 requires an image with a mask for the best mesh quality.") cropped_np = rgba_np.astype(np.float32) / 255.0 - bg_colors = {"black": [0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]} + bg_colors = {"black":[0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]} bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32) fg = cropped_np[:, :, :3] @@ -358,10 +358,9 @@ class Trellis2Conditioning(IO.ComfyNode): # to match trellis2 code (quantize -> dequantize) composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8) - cropped_img_tensor = torch.from_numpy(composite_uint8).float() / 255.0 - cropped_img_tensor = cropped_img_tensor.movedim(-1, 0).unsqueeze(0) + cropped_pil = Image.fromarray(composite_uint8) - conditioning = run_conditioning(clip_vision_model, cropped_img_tensor, include_1024=True) + conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True) embeds = conditioning["cond_1024"] positive = [[conditioning["cond_512"], {"embeds": embeds}]] From 0ebeac98a78885d4c13c09691e027fb141d9e581 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 8 Apr 2026 19:08:26 +0200 Subject: [PATCH 53/93] removed unnecessary vae float32 upcast --- comfy/ldm/trellis2/vae.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 2a18c496a..c42ad8d2f 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -2,7 +2,6 @@ import math import torch import numpy as np import torch.nn as nn -import comfy.model_management import torch.nn.functional as F from fractions import Fraction from dataclasses import dataclass @@ -78,7 +77,10 @@ class LayerNorm32(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: x_dtype = x.dtype x = x.to(torch.float32) - o = super().forward(x) + w = self.weight.to(torch.float32) + b = 
self.bias.to(torch.float32) if self.bias is not None else None + + o = F.layer_norm(x, self.normalized_shape, w, b, self.eps) return o.to(x_dtype) class SparseConvNeXtBlock3d(nn.Module): @@ -102,8 +104,7 @@ class SparseConvNeXtBlock3d(nn.Module): def _forward(self, x): h = self.conv(x) - norm = self.norm.to(torch.float32) - h = h.replace(norm(h.feats)) + h = h.replace(self.norm(h.feats)) h = h.replace(self.mlp(h.feats)) return h + x @@ -213,15 +214,13 @@ class SparseResBlockC2S3d(nn.Module): dtype = next(self.to_subdiv.parameters()).dtype x = x.to(dtype) subdiv = self.to_subdiv(x) - norm1 = self.norm1.to(torch.float32) - norm2 = self.norm2.to(torch.float32) - h = x.replace(norm1(x.feats)) + h = x.replace(self.norm1(x.feats)) h = h.replace(F.silu(h.feats)) h = self.conv1(h) subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None h = self.updown(h, subdiv_binarized) x = self.updown(x, subdiv_binarized) - h = h.replace(norm2(h.feats)) + h = h.replace(self.norm2(h.feats)) h = h.replace(F.silu(h.feats)) h = self.conv2(h) h = h + self.skip_connection(x) @@ -1300,8 +1299,6 @@ class ResBlock3d(nn.Module): self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: - self.norm1 = self.norm1.to(torch.float32) - self.norm2 = self.norm2.to(torch.float32) h = self.norm1(x) h = F.silu(h) dtype = next(self.conv1.parameters()).dtype @@ -1381,8 +1378,7 @@ class SparseStructureDecoder(nn.Module): for block in self.blocks: h = block(h) - h = h.to(torch.float32) - self.out_layer = self.out_layer.to(torch.float32) + h = h.type(x.dtype) h = self.out_layer(h) return h From ea255543e648000ed987576494dbc227320b61d9 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 10 Apr 2026 14:24:07 +0200 Subject: [PATCH 54/93] structure generation works --- comfy/image_encoders/dino3.py | 8 ++++++-- 
comfy/ldm/trellis2/vae.py | 2 +- comfy_extras/nodes_trellis2.py | 4 ++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index ff17d78d6..145bd5490 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -1,6 +1,7 @@ import math import torch import torch.nn as nn +import torch.nn.functional as F import comfy.model_management from comfy.ldm.modules.attention import optimized_attention_for_device @@ -274,8 +275,11 @@ class DINOv3ViTModel(nn.Module): position_embeddings=position_embeddings, ) - norm = self.norm.to(hidden_states.device) - sequence_output = norm(hidden_states) + if kwargs.get("skip_norm_elementwise", False): + sequence_output= F.layer_norm(hidden_states, hidden_states.shape[-1:]) + else: + norm = self.norm.to(hidden_states.device) + sequence_output = norm(hidden_states) pooled_output = sequence_output[:, 0, :] return sequence_output, None, pooled_output, None diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index c42ad8d2f..cd37ccd30 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -77,7 +77,7 @@ class LayerNorm32(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: x_dtype = x.dtype x = x.to(torch.float32) - w = self.weight.to(torch.float32) + w = self.weight.to(torch.float32) if self.weight is not None else None b = self.bias.to(torch.float32) if self.bias is not None else None o = F.layer_norm(x, self.normalized_shape, w, b, self.eps) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index d3f5e4940..56ce4f5ea 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -265,13 +265,13 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): model_internal.image_size = 512 input_512 = prepare_tensor(cropped_img_tensor, 512) - cond_512 = model_internal(input_512)[0] + cond_512 = model_internal(input_512, 
skip_norm_elementwise=True)[0] cond_1024 = None if include_1024: model_internal.image_size = 1024 input_1024 = prepare_tensor(cropped_img_tensor, 1024) - cond_1024 = model_internal(input_1024)[0] + cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0] conditioning = { 'cond_512': cond_512.to(device), From 4e14d42da1c26f48af7836c3c2eff8aa8cc8d4f5 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 10 Apr 2026 16:12:23 +0200 Subject: [PATCH 55/93] comfy ops + color support in postprocess --- comfy/ldm/trellis2/model.py | 105 +++++++++++++++++++-------------- comfy/ldm/trellis2/vae.py | 14 +++-- comfy_extras/nodes_trellis2.py | 57 +++++++++--------- 3 files changed, 98 insertions(+), 78 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index ea7ada9f8..a613fb325 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -14,12 +14,12 @@ class SparseGELU(nn.GELU): return input.replace(super().forward(input.feats)) class SparseFeedForwardNet(nn.Module): - def __init__(self, channels: int, mlp_ratio: float = 4.0): + def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None): super().__init__() self.mlp = nn.Sequential( - SparseLinear(channels, int(channels * mlp_ratio)), + SparseLinear(channels, int(channels * mlp_ratio), device=device, dtype=dtype, operations=operations), SparseGELU(approximate="tanh"), - SparseLinear(int(channels * mlp_ratio), channels), + SparseLinear(int(channels * mlp_ratio), channels, device=device, dtype=dtype, operations=operations), ) def forward(self, x: VarLenTensor) -> VarLenTensor: @@ -37,10 +37,10 @@ class LayerNorm32(nn.LayerNorm): class SparseMultiHeadRMSNorm(nn.Module): - def __init__(self, dim: int, heads: int): + def __init__(self, dim: int, heads: int, device, dtype): super().__init__() self.scale = dim ** 0.5 - self.gamma = nn.Parameter(torch.ones(heads, dim)) + self.gamma = 
nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype)) def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: x_type = x.dtype @@ -56,14 +56,15 @@ class SparseRotaryPositionEmbedder(nn.Module): self, head_dim: int, dim: int = 3, - rope_freq: Tuple[float, float] = (1.0, 10000.0) + rope_freq: Tuple[float, float] = (1.0, 10000.0), + device=None ): super().__init__() self.head_dim = head_dim self.dim = dim self.rope_freq = rope_freq self.freq_dim = head_dim // 2 // dim - self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32, device=device) / self.freq_dim self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) def _get_freqs_cis(self, coords: torch.Tensor) -> torch.Tensor: @@ -148,6 +149,7 @@ class SparseMultiHeadAttention(nn.Module): use_rope: bool = False, rope_freq: Tuple[int, int] = (1.0, 10000.0), qk_rms_norm: bool = False, + device=None, dtype=None, operations=None ): super().__init__() @@ -163,19 +165,19 @@ class SparseMultiHeadAttention(nn.Module): self.qk_rms_norm = qk_rms_norm if self._type == "self": - self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, device=device, dtype=dtype) else: - self.to_q = nn.Linear(channels, channels, bias=qkv_bias) - self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype) + self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype) if self.qk_rms_norm: - self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) - self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) + self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype) + self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, 
device=device, dtype=dtype) - self.to_out = nn.Linear(channels, channels) + self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype) if use_rope: - self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq) + self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq, device=device) @staticmethod def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: @@ -267,14 +269,14 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): qk_rms_norm_cross: bool = False, qkv_bias: bool = True, share_mod: bool = False, - + device=None, dtype=None, operations=None ): super().__init__() self.use_checkpoint = use_checkpoint self.share_mod = share_mod - self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) - self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) self.self_attn = SparseMultiHeadAttention( channels, num_heads=num_heads, @@ -286,6 +288,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): use_rope=use_rope, rope_freq=rope_freq, qk_rms_norm=qk_rms_norm, + device=device, dtype=dtype, operations=operations ) self.cross_attn = SparseMultiHeadAttention( channels, @@ -295,18 +298,20 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): attn_mode="full", qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, + device=device, dtype=dtype, operations=operations ) self.mlp = SparseFeedForwardNet( channels, mlp_ratio=mlp_ratio, + device=device, dtype=dtype, operations=operations ) if not share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), - nn.Linear(channels, 6 * channels, bias=True) + 
operations.Linear(channels, 6 * channels, bias=True, device=device, dtype=dtype) ) else: - self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5) def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: if self.share_mod: @@ -376,10 +381,10 @@ class SLatFlowModel(nn.Module): if share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), - nn.Linear(model_channels, 6 * model_channels, bias=True) + operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype) ) - self.input_layer = SparseLinear(in_channels, model_channels) + self.input_layer = SparseLinear(in_channels, model_channels, device=device, dtype=dtype, operations=operations) self.blocks = nn.ModuleList([ ModulatedSparseTransformerCrossBlock( @@ -394,11 +399,12 @@ class SLatFlowModel(nn.Module): share_mod=self.share_mod, qk_rms_norm=self.qk_rms_norm, qk_rms_norm_cross=self.qk_rms_norm_cross, + device=device, dtype=dtype, operations=operations ) for _ in range(num_blocks) ]) - self.out_layer = SparseLinear(model_channels, out_channels) + self.out_layer = SparseLinear(model_channels, out_channels, device=device, dtype=dtype, operations=operations) @property def device(self) -> torch.device: @@ -438,22 +444,22 @@ class SLatFlowModel(nn.Module): return h class FeedForwardNet(nn.Module): - def __init__(self, channels: int, mlp_ratio: float = 4.0): + def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None): super().__init__() self.mlp = nn.Sequential( - nn.Linear(channels, int(channels * mlp_ratio)), + operations.Linear(channels, int(channels * mlp_ratio), device=device, dtype=dtype), nn.GELU(approximate="tanh"), - nn.Linear(int(channels * mlp_ratio), channels), + operations.Linear(int(channels * mlp_ratio), channels, device=device, dtype=dtype), ) def 
forward(self, x: torch.Tensor) -> torch.Tensor: return self.mlp(x) class MultiHeadRMSNorm(nn.Module): - def __init__(self, dim: int, heads: int): + def __init__(self, dim: int, heads: int, device=None, dtype=None): super().__init__() self.scale = dim ** 0.5 - self.gamma = nn.Parameter(torch.ones(heads, dim)) + self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype)) def forward(self, x: torch.Tensor) -> torch.Tensor: return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) @@ -473,6 +479,7 @@ class MultiHeadAttention(nn.Module): use_rope: bool = False, rope_freq: Tuple[float, float] = (1.0, 10000.0), qk_rms_norm: bool = False, + device=None, dtype=None, operations=None ): super().__init__() @@ -488,16 +495,16 @@ class MultiHeadAttention(nn.Module): self.qk_rms_norm = qk_rms_norm if self._type == "self": - self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device) else: - self.to_q = nn.Linear(channels, channels, bias=qkv_bias) - self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype) + self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype) if self.qk_rms_norm: - self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) - self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype) - self.to_out = nn.Linear(channels, channels) + self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype) def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor: B, L, C = x.shape @@ -554,13 +561,14 @@ 
class ModulatedTransformerCrossBlock(nn.Module): qk_rms_norm_cross: bool = False, qkv_bias: bool = True, share_mod: bool = False, + device=None, dtype=None, operations=None ): super().__init__() self.use_checkpoint = use_checkpoint self.share_mod = share_mod - self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) - self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) self.self_attn = MultiHeadAttention( channels, num_heads=num_heads, @@ -572,6 +580,7 @@ class ModulatedTransformerCrossBlock(nn.Module): use_rope=use_rope, rope_freq=rope_freq, qk_rms_norm=qk_rms_norm, + device=device, dtype=dtype, operations=operations ) self.cross_attn = MultiHeadAttention( channels, @@ -581,18 +590,20 @@ class ModulatedTransformerCrossBlock(nn.Module): attn_mode="full", qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, + device=device, dtype=dtype, operations=operations ) self.mlp = FeedForwardNet( channels, mlp_ratio=mlp_ratio, + device=device, dtype=dtype, operations=operations ) if not share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), - nn.Linear(channels, 6 * channels, bias=True) + operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device) ) else: - self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5) def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: if self.share_mod: @@ -659,16 +670,17 @@ class SparseStructureFlowModel(nn.Module): self.qk_rms_norm = qk_rms_norm 
self.qk_rms_norm_cross = qk_rms_norm_cross self.dtype = dtype + self.device = device - self.t_embedder = TimestepEmbedder(model_channels, operations=operations) + self.t_embedder = TimestepEmbedder(model_channels, dtype=dtype, device=device, operations=operations) if share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), - nn.Linear(model_channels, 6 * model_channels, bias=True) + operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype) ) - pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3) - coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij') + pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3, device=device) + coords = torch.meshgrid(*[torch.arange(res, device=self.device, dtype=dtype) for res in [resolution] * 3], indexing='ij') coords = torch.stack(coords, dim=-1).reshape(-1, 3) rope_phases = pos_embedder(coords) self.register_buffer("rope_phases", rope_phases, persistent=False) @@ -676,7 +688,7 @@ class SparseStructureFlowModel(nn.Module): if pe_mode != "rope": self.rope_phases = None - self.input_layer = nn.Linear(in_channels, model_channels) + self.input_layer = operations.Linear(in_channels, model_channels, device=device, dtype=dtype) self.blocks = nn.ModuleList([ ModulatedTransformerCrossBlock( @@ -691,11 +703,12 @@ class SparseStructureFlowModel(nn.Module): share_mod=share_mod, qk_rms_norm=self.qk_rms_norm, qk_rms_norm_cross=self.qk_rms_norm_cross, + device=device, dtype=dtype, operations=operations ) for _ in range(num_blocks) ]) - self.out_layer = nn.Linear(model_channels, out_channels) + self.out_layer = operations.Linear(model_channels, out_channels, device=device, dtype=dtype) def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3) @@ -745,6 +758,7 @@ class Trellis2(nn.Module): super().__init__() 
self.dtype = dtype + operations = operations or nn # for some reason it passes num_heads = -1 if num_heads == -1: num_heads = 12 @@ -772,6 +786,7 @@ class Trellis2(nn.Module): coords = transformer_options.get("coords", None) mode = transformer_options.get("generation_mode", "structure_generation") is_512_run = False + timestep = timestep.to(self.dtype) if mode == "shape_generation_512": is_512_run = True mode = "shape_generation" diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index cd37ccd30..30f902868 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -962,13 +962,17 @@ def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: feats = input.feats.unbind(dim) return [input.replace(f) for f in feats] -class SparseLinear(nn.Linear): - def __init__(self, in_features, out_features, bias=True): - super(SparseLinear, self).__init__(in_features, out_features, bias) +# allow operations.Linear inheritance +class SparseLinear: + def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None, operations=nn, *args, **kwargs): + class _SparseLinear(operations.Linear): + def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): + super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) - def forward(self, input: VarLenTensor) -> VarLenTensor: - return input.replace(super().forward(input.feats)) + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) + return _SparseLinear(in_features, out_features, bias=bias, device=device, dtype=dtype, *args, **kwargs) MIX_PRECISION_MODULES = ( nn.Conv1d, diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 56ce4f5ea..1bf7c55b8 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -481,21 +481,25 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): latent = torch.randn(batch_size, in_channels, resolution, 
resolution, resolution) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) -def simplify_fn(vertices, faces, target=100000): - is_batched = vertices.ndim == 3 - if is_batched: - v_list, f_list = [], [] +def simplify_fn(vertices, faces, colors=None, target=100000): + if vertices.ndim == 3: + v_list, f_list, c_list = [], [], [] for i in range(vertices.shape[0]): - v_i, f_i = simplify_fn(vertices[i], faces[i], target) + c_in = colors[i] if colors is not None else None + v_i, f_i, c_i = simplify_fn(vertices[i], faces[i], c_in, target) v_list.append(v_i) f_list.append(f_i) - return torch.stack(v_list), torch.stack(f_list) + if c_i is not None: + c_list.append(c_i) + + c_out = torch.stack(c_list) if len(c_list) > 0 else None + return torch.stack(v_list), torch.stack(f_list), c_out if faces.shape[0] <= target: - return vertices, faces + return vertices, faces, colors device = vertices.device - target_v = target / 2.0 + target_v = max(target / 4.0, 1.0) min_v = vertices.min(dim=0)[0] max_v = vertices.max(dim=0)[0] @@ -510,14 +514,17 @@ def simplify_fn(vertices, faces, target=100000): new_vertices = torch.zeros((num_cells, 3), dtype=vertices.dtype, device=device) counts = torch.zeros((num_cells, 1), dtype=vertices.dtype, device=device) - new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices) counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1])) - new_vertices = new_vertices / counts.clamp(min=1) - new_faces = inverse_indices[faces] + new_colors = None + if colors is not None: + new_colors = torch.zeros((num_cells, colors.shape[1]), dtype=colors.dtype, device=device) + new_colors.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, colors.shape[1]), colors) + new_colors = new_colors / counts.clamp(min=1) + new_faces = inverse_indices[faces] valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \ (new_faces[:, 1] != new_faces[:, 2]) & \ (new_faces[:, 2] != new_faces[:, 0]) @@ -527,7 +534,10 @@ def 
simplify_fn(vertices, faces, target=100000): final_vertices = new_vertices[unique_face_indices] final_faces = inv_face.reshape(-1, 3) - return final_vertices, final_faces + # assign colors + final_colors = new_colors[unique_face_indices] if new_colors is not None else None + + return final_vertices, final_faces, final_colors def fill_holes_fn(vertices, faces, max_perimeter=0.03): is_batched = vertices.ndim == 3 @@ -610,19 +620,6 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): return v, f -def make_double_sided(vertices, faces): - is_batched = vertices.ndim == 3 - if is_batched: - f_list =[] - for i in range(faces.shape[0]): - f_inv = faces[i][:,[0, 2, 1]] - f_list.append(torch.cat([faces[i], f_inv], dim=0)) - return vertices, torch.stack(f_list) - - faces_inv = faces[:, [0, 2, 1]] - faces_double = torch.cat([faces, faces_inv], dim=0) - return vertices, faces_double - class PostProcessMesh(IO.ComfyNode): @classmethod def define_schema(cls): @@ -641,19 +638,23 @@ class PostProcessMesh(IO.ComfyNode): @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): + # TODO: batched mode may break mesh = copy.deepcopy(mesh) verts, faces = mesh.vertices, mesh.faces + colors = None + if hasattr(mesh, "colors"): + colors = mesh.colors if fill_holes_perimeter > 0: verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter) if simplify > 0 and faces.shape[0] > simplify: - verts, faces = simplify_fn(verts, faces, target=simplify) - - verts, faces = make_double_sided(verts, faces) + verts, faces, colors = simplify_fn(verts, faces, target=simplify, colors=colors) mesh.vertices = verts mesh.faces = faces + if colors is not None: + mesh.colors = None return IO.NodeOutput(mesh) class Trellis2Extension(ComfyExtension): From 937faafe21c70036df67659bef321619adb99eaa Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 10 Apr 2026 16:33:19 +0200 Subject: [PATCH 56/93] corrected simplification logic 
--- comfy_extras/nodes_trellis2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 1bf7c55b8..4aaa2c5f4 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -639,18 +639,19 @@ class PostProcessMesh(IO.ComfyNode): @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): # TODO: batched mode may break - mesh = copy.deepcopy(mesh) verts, faces = mesh.vertices, mesh.faces colors = None if hasattr(mesh, "colors"): colors = mesh.colors + actual_face_count = faces.shape[1] if faces.ndim == 3 else faces.shape[0] if fill_holes_perimeter > 0: verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter) - if simplify > 0 and faces.shape[0] > simplify: + if simplify > 0 and actual_face_count > simplify: verts, faces, colors = simplify_fn(verts, faces, target=simplify, colors=colors) + mesh = type(mesh)(vertices=verts, faces=faces) mesh.vertices = verts mesh.faces = faces if colors is not None: From 243691c258803ebbf22c9f7fc38d7baa3e4239f2 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 10 Apr 2026 19:18:14 +0200 Subject: [PATCH 57/93] texture fixes --- comfy/supported_models.py | 2 ++ comfy_extras/nodes_trellis2.py | 17 ++++++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 20a60194b..f445edf66 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1270,6 +1270,8 @@ class Trellis2(supported_models_base.BASE): latent_format = latent_formats.Trellis2 vae_key_prefix = ["vae."] clip_vision_prefix = "conditioner.main_image_encoder.model." 
+ # this is only needed for the texture model + supported_inference_dtypes = [torch.bfloat16, torch.float32] def get_model(self, state_dict, prefix="", device=None): return model_base.Trellis2(self, device=device) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 4aaa2c5f4..42fe7f707 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -620,6 +620,19 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): return v, f + +def make_double_sided(vertices, faces): + is_batched = vertices.ndim == 3 + if is_batched: + f_list = [] + for i in range(faces.shape[0]): + f_inv = faces[i][:, [0, 2, 1]] + f_list.append(torch.cat([faces[i], f_inv], dim=0)) + return vertices, torch.stack(f_list) + + faces_inv = faces[:, [0, 2, 1]] + return vertices, torch.cat([faces, faces_inv], dim=0) + class PostProcessMesh(IO.ComfyNode): @classmethod def define_schema(cls): @@ -651,11 +664,13 @@ class PostProcessMesh(IO.ComfyNode): if simplify > 0 and actual_face_count > simplify: verts, faces, colors = simplify_fn(verts, faces, target=simplify, colors=colors) + verts, faces = make_double_sided(verts, faces) + mesh = type(mesh)(vertices=verts, faces=faces) mesh.vertices = verts mesh.faces = faces if colors is not None: - mesh.colors = None + mesh.colors = colors return IO.NodeOutput(mesh) class Trellis2Extension(ComfyExtension): From a1364a7b0076343eb9ae10e9c85c602470b2147b Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sun, 12 Apr 2026 20:14:47 +0200 Subject: [PATCH 58/93] small final change --- comfy_extras/nodes_trellis2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 42fe7f707..3479d5410 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -72,7 +72,7 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): v_colors = 
voxel_colors[nearest_idx] # to [0, 1] - srgb_colors = (v_colors * 0.5 + 0.5).clamp(0, 1) + srgb_colors = v_colors.clamp(0, 1)#(v_colors * 0.5 + 0.5).clamp(0, 1) # to Linear RGB (required for GLTF) linear_colors = torch.pow(srgb_colors, 2.2) From f4ae7b8391b6c977f17312fb0eb3446251ac0d35 Mon Sep 17 00:00:00 2001 From: "Yousef R. Gamaleldin" <81116377+yousef-rafat@users.noreply.github.com> Date: Sun, 12 Apr 2026 20:21:19 +0200 Subject: [PATCH 59/93] Fix return statement --- comfy/model_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 157bdd929..5f258178f 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1589,6 +1589,7 @@ class WAN21_SCAIL(WAN21): pose_latents = kwargs.get("pose_video_latent", None) if pose_latents is not None: out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]] + return out class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): From d62bbe5fe069f15b27507e1e374f57c376328580 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 18:23:17 -0500 Subject: [PATCH 60/93] fix: issue 86 1024 conditioning gate --- comfy/ldm/trellis2/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index a613fb325..96c1eeef2 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -782,7 +782,7 @@ class Trellis2(nn.Module): embeds = kwargs.get("embeds") if embeds is None: raise ValueError("Trellis2.forward requires 'embeds' in kwargs") - is_1024 = self.img2shape.resolution == 1024 + is_1024 = self.img2shape.resolution == 64 coords = transformer_options.get("coords", None) mode = transformer_options.get("generation_mode", "structure_generation") is_512_run = False From 5575e06ff355829ddb7ad957ec62a9bd5793f717 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 19:20:14 -0500 Subject: [PATCH 61/93] clarify: 
issue 86 latent-to-pixel resolution mapping --- comfy/ldm/trellis2/model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 96c1eeef2..3081d5919 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -782,7 +782,9 @@ class Trellis2(nn.Module): embeds = kwargs.get("embeds") if embeds is None: raise ValueError("Trellis2.forward requires 'embeds' in kwargs") - is_1024 = self.img2shape.resolution == 64 + # img2shape.resolution is the latent-grid size, not the input pixel size: + # 32 -> 512px path, 64 -> 1024px path. + uses_1024_conditioning = self.img2shape.resolution == 64 coords = transformer_options.get("coords", None) mode = transformer_options.get("generation_mode", "structure_generation") is_512_run = False @@ -797,7 +799,7 @@ class Trellis2(nn.Module): mode = "structure_generation" not_struct_mode = False - if is_1024 and not_struct_mode and not is_512_run: + if uses_1024_conditioning and not_struct_mode and not is_512_run: context = embeds sigmas = transformer_options.get("sigmas")[0].item() From 04099ef605c373ec82081f3137a23bd1926b67ae Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 19:53:30 -0500 Subject: [PATCH 62/93] Restore Trellis2 clip vision image_size state --- comfy_extras/nodes_trellis2.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 3479d5410..b1ad5d1e1 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -256,6 +256,7 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): model_internal = model.model device = comfy.model_management.intermediate_device() torch_device = comfy.model_management.get_torch_device() + original_image_size = getattr(model_internal, "image_size", None) def prepare_tensor(pil_img, size): resized_pil = pil_img.resize((size, size), 
Image.Resampling.LANCZOS) @@ -268,10 +269,13 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0] cond_1024 = None - if include_1024: - model_internal.image_size = 1024 - input_1024 = prepare_tensor(cropped_img_tensor, 1024) - cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0] + try: + if include_1024: + model_internal.image_size = 1024 + input_1024 = prepare_tensor(cropped_img_tensor, 1024) + cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0] + finally: + model_internal.image_size = original_image_size conditioning = { 'cond_512': cond_512.to(device), From d7416e56906b9bc8280223fd22532364428fc716 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 20:49:42 -0500 Subject: [PATCH 63/93] Guard full Trellis2 conditioning image_size restore --- comfy_extras/nodes_trellis2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index b1ad5d1e1..c8ac9bc33 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -264,12 +264,12 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device) return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device) - model_internal.image_size = 512 - input_512 = prepare_tensor(cropped_img_tensor, 512) - cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0] - cond_1024 = None try: + model_internal.image_size = 512 + input_512 = prepare_tensor(cropped_img_tensor, 512) + cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0] + if include_1024: model_internal.image_size = 1024 input_1024 = prepare_tensor(cropped_img_tensor, 1024) From 2ad1ca5531b96ad61cc4c80d81118d219b635afc Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 20:51:22 -0500 Subject: [PATCH 
64/93] Handle missing Trellis2 image_size restore state --- comfy_extras/nodes_trellis2.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index c8ac9bc33..2b712d113 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -256,7 +256,8 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): model_internal = model.model device = comfy.model_management.intermediate_device() torch_device = comfy.model_management.get_torch_device() - original_image_size = getattr(model_internal, "image_size", None) + image_size_missing = object() + original_image_size = getattr(model_internal, "image_size", image_size_missing) def prepare_tensor(pil_img, size): resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS) @@ -275,7 +276,10 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): input_1024 = prepare_tensor(cropped_img_tensor, 1024) cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0] finally: - model_internal.image_size = original_image_size + if original_image_size is image_size_missing: + delattr(model_internal, "image_size") + else: + model_internal.image_size = original_image_size conditioning = { 'cond_512': cond_512.to(device), From 7c6b237fe89d074510da0cd4f382a07829649547 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 21:05:07 -0500 Subject: [PATCH 65/93] Match Copilot image_size restore pattern --- comfy_extras/nodes_trellis2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 2b712d113..61d3532a1 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -256,8 +256,8 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): model_internal = model.model device = comfy.model_management.intermediate_device() torch_device = 
comfy.model_management.get_torch_device() - image_size_missing = object() - original_image_size = getattr(model_internal, "image_size", image_size_missing) + had_image_size = hasattr(model_internal, "image_size") + original_image_size = getattr(model_internal, "image_size", None) def prepare_tensor(pil_img, size): resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS) @@ -276,7 +276,7 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): input_1024 = prepare_tensor(cropped_img_tensor, 1024) cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0] finally: - if original_image_size is image_size_missing: + if not had_image_size: delattr(model_internal, "image_size") else: model_internal.image_size = original_image_size From cf3cfec5964afd91dd8404f2fb8ac7312ad458fa Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 21:11:58 -0500 Subject: [PATCH 66/93] Add Trellis2 image_size restore tests --- .../comfy_extras_test/nodes_trellis2_test.py | 127 ++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 tests-unit/comfy_extras_test/nodes_trellis2_test.py diff --git a/tests-unit/comfy_extras_test/nodes_trellis2_test.py b/tests-unit/comfy_extras_test/nodes_trellis2_test.py new file mode 100644 index 000000000..920eca471 --- /dev/null +++ b/tests-unit/comfy_extras_test/nodes_trellis2_test.py @@ -0,0 +1,127 @@ +import importlib +import sys +import types +import unittest +from unittest.mock import patch + +import torch +from PIL import Image + + +class _DummyPort: + @staticmethod + def Input(*args, **kwargs): + return None + + @staticmethod + def Output(*args, **kwargs): + return None + + +class _DummyIO: + ComfyNode = object + + @staticmethod + def Schema(*args, **kwargs): + return None + + @staticmethod + def NodeOutput(*args, **kwargs): + return args + + def __getattr__(self, name): + return _DummyPort + + +class _DummyTypes: + def __getattr__(self, name): + return lambda *args, **kwargs: None + + 
+dummy_comfy_api_latest = types.SimpleNamespace( + ComfyExtension=object, + IO=_DummyIO(), + Types=_DummyTypes(), +) + +dummy_sparse_tensor = type("SparseTensor", (), {}) +dummy_trellis_vae = types.SimpleNamespace(SparseTensor=dummy_sparse_tensor) + +with patch.dict(sys.modules, { + "comfy_api.latest": dummy_comfy_api_latest, + "comfy.ldm.trellis2.vae": dummy_trellis_vae, +}): + nodes_trellis2 = importlib.import_module("comfy_extras.nodes_trellis2") + + +class DummyInnerModel: + def __init__(self, image_size=..., fail_on_call=None): + self.call_count = 0 + self.fail_on_call = fail_on_call + if image_size is not ...: + self.image_size = image_size + + def __call__(self, input_tensor, skip_norm_elementwise=True): + self.call_count += 1 + if self.fail_on_call == self.call_count: + raise RuntimeError("expected conditioning failure") + return (torch.ones((1, 4), dtype=torch.float32),) + + +class DummyModel: + def __init__(self, inner_model): + self.model = inner_model + + +class TestRunConditioningRestore(unittest.TestCase): + def setUp(self): + self.intermediate_patch = patch.object( + nodes_trellis2.comfy.model_management, "intermediate_device", lambda: "cpu" + ) + self.torch_device_patch = patch.object( + nodes_trellis2.comfy.model_management, "get_torch_device", lambda: "cpu" + ) + self.intermediate_patch.start() + self.torch_device_patch.start() + + def tearDown(self): + self.intermediate_patch.stop() + self.torch_device_patch.stop() + + @staticmethod + def make_test_image(): + return Image.new("RGB", (8, 8), color="white") + + def test_restores_existing_image_size_after_success(self): + inner_model = DummyInnerModel(image_size=777) + + nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True) + + self.assertEqual(inner_model.image_size, 777) + + def test_deletes_missing_image_size_after_success(self): + inner_model = DummyInnerModel() + + nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), 
include_1024=True) + + self.assertFalse(hasattr(inner_model, "image_size")) + + def test_restores_existing_image_size_after_512_failure(self): + inner_model = DummyInnerModel(image_size=777, fail_on_call=1) + + with self.assertRaisesRegex(RuntimeError, "expected conditioning failure"): + nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True) + + self.assertEqual(inner_model.image_size, 777) + + def test_deletes_missing_image_size_after_1024_failure(self): + inner_model = DummyInnerModel(fail_on_call=2) + + with self.assertRaisesRegex(RuntimeError, "expected conditioning failure"): + nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True) + + self.assertFalse(hasattr(inner_model, "image_size")) + + +if __name__ == "__main__": + unittest.main() From b443f423b430429ad9e6c98c9318a7658eb2ad8d Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 21:26:48 -0500 Subject: [PATCH 67/93] Trellis2: slice cond half of x symmetrically under shape_rule pruning --- comfy/ldm/trellis2/model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 3081d5919..fc8df3c2b 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -863,11 +863,12 @@ class Trellis2(nn.Module): else: # structure orig_bsz = x.shape[0] if shape_rule: - x = x[1].unsqueeze(0) - timestep = timestep[1].unsqueeze(0) + half = orig_bsz // 2 + x = x[half:] + timestep = timestep[half:] if timestep.shape[0] > 1 else timestep out = self.structure_model(x, timestep, context if not shape_rule else cond) if shape_rule: - out = out.repeat(orig_bsz, 1, 1, 1, 1) + out = out.repeat(2, 1, 1, 1, 1) if not_struct_mode: out = out.feats From 70511a9a91d89efeca0cc3d11e7e4041c5e2438c Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 21:38:45 -0500 Subject: [PATCH 68/93] Trellis2: guard structure shape_rule pruning 
to CFG batches --- comfy/ldm/trellis2/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index fc8df3c2b..1c5d6c3ec 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -862,12 +862,12 @@ class Trellis2(nn.Module): out = self.shape2txt(x_st, t_eval, c_eval) else: # structure orig_bsz = x.shape[0] - if shape_rule: + if shape_rule and orig_bsz > 1: half = orig_bsz // 2 x = x[half:] timestep = timestep[half:] if timestep.shape[0] > 1 else timestep - out = self.structure_model(x, timestep, context if not shape_rule else cond) - if shape_rule: + out = self.structure_model(x, timestep, cond if shape_rule and orig_bsz > 1 else context) + if shape_rule and orig_bsz > 1: out = out.repeat(2, 1, 1, 1, 1) if not_struct_mode: From 44043ee6e5fadaa569c8a905915901b677dec65f Mon Sep 17 00:00:00 2001 From: John Pollock Date: Fri, 17 Apr 2026 22:42:42 -0500 Subject: [PATCH 69/93] Trellis2/Hunyuan3d: n>1 batched cascade support Mesh-producing nodes (VoxelToMeshBasic, VoxelToMesh, VaeDecodeShapeTrellis) previously stacked per-batch vertex/face tensors with torch.stack, which failed under batch>1 because per-item meshes have variable shapes. Store per-item tensors as lists when shapes differ; keep stacked-tensor fast path when shapes match. Update SaveGLB, PostProcessMesh, and VaeDecodeTextureTrellis consumers to iterate per-item when input is a list. Trellis2Conditioning.execute previously collapsed batched image/mask input to index 0, yielding identical conditioning for every batch item. Loop over the batch and produce per-image cond_512/cond_1024/neg_cond tensors stacked along the batch dim, matching the latent batch size. batch_size=1 behavior is unchanged. batch_size>1 runs now emit one GLB per batch item per SaveGLB node and carry per-image conditioning through the structure/shape/texture cascade. 
--- comfy_extras/nodes_hunyuan3d.py | 11 ++- comfy_extras/nodes_trellis2.py | 156 +++++++++++++++++++++----------- 2 files changed, 112 insertions(+), 55 deletions(-) diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index ac91fe0a7..8f58e85d9 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -443,7 +443,9 @@ class VoxelToMeshBasic(IO.ComfyNode): vertices.append(v) faces.append(f) - return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) + if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces): + return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) + return IO.NodeOutput(Types.MESH(vertices, faces)) decode = execute # TODO: remove @@ -479,7 +481,9 @@ class VoxelToMesh(IO.ComfyNode): vertices.append(v) faces.append(f) - return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) + if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces): + return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) + return IO.NodeOutput(Types.MESH(vertices, faces)) decode = execute # TODO: remove @@ -682,7 +686,8 @@ class SaveGLB(IO.ComfyNode): }) else: # Handle Mesh input - save vertices and faces as GLB - for i in range(mesh.vertices.shape[0]): + bsz = len(mesh.vertices) if isinstance(mesh.vertices, list) else mesh.vertices.shape[0] + for i in range(bsz): f = f"{filename}_{counter:05}_.glb" v_colors = mesh.colors[i] if hasattr(mesh, "colors") and mesh.colors is not None else None save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata, v_colors) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 61d3532a1..8ef3e8f5a 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -117,9 +117,12 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): 
samples = shape_norm(samples, coords) mesh, subs = vae.decode_shape_slat(samples, resolution) - faces = torch.stack([m.faces for m in mesh]) - verts = torch.stack([m.vertices for m in mesh]) - mesh = Types.MESH(vertices=verts, faces=faces) + face_list = [m.faces for m in mesh] + vert_list = [m.vertices for m in mesh] + if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list): + mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list)) + else: + mesh = Types.MESH(vertices=vert_list, faces=face_list) return IO.NodeOutput(mesh, subs) class VaeDecodeTextureTrellis(IO.ComfyNode): @@ -160,8 +163,23 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): voxel = vae.decode_tex_slat(samples, shape_subs) color_feats = voxel.feats[:, :3] voxel_coords = voxel.coords[:, 1:] + voxel_batch_idx = voxel.coords[:, 0] - out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution) + if isinstance(shape_mesh.vertices, list): + out_verts, out_faces, out_colors = [], [], [] + for i in range(len(shape_mesh.vertices)): + sel = voxel_batch_idx == i + item_coords = voxel_coords[sel] + item_colors = color_feats[sel] + item_mesh = Types.MESH(vertices=shape_mesh.vertices[i].unsqueeze(0), faces=shape_mesh.faces[i].unsqueeze(0)) + painted = paint_mesh_with_voxels(item_mesh, item_coords, item_colors, resolution=resolution) + out_verts.append(painted.vertices.squeeze(0)) + out_faces.append(painted.faces.squeeze(0)) + out_colors.append(painted.colors.squeeze(0)) + out_mesh = Types.MESH(vertices=out_verts, faces=out_faces) + out_mesh.colors = out_colors + else: + out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution) return IO.NodeOutput(out_mesh) class VaeDecodeStructureTrellis2(IO.ComfyNode): @@ -310,69 +328,83 @@ class Trellis2Conditioning(IO.ComfyNode): @classmethod def execute(cls, clip_vision_model, image, mask, background_color) -> IO.NodeOutput: 
+ # Normalize to batched form so per-image conditioning loop below is uniform. + if image.ndim == 3: + image = image.unsqueeze(0) + if mask.ndim == 2: + mask = mask.unsqueeze(0) + batch_size = image.shape[0] - if image.ndim == 4: - image = image[0] - if mask.ndim == 3: - mask = mask[0] + cond_512_list = [] + cond_1024_list = [] - img_np = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) - mask_np = (mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + for b in range(batch_size): + item_image = image[b] + item_mask = mask[b] - pil_img = Image.fromarray(img_np) - pil_mask = Image.fromarray(mask_np) + img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) - max_size = max(pil_img.size) - scale = min(1.0, 1024 / max_size) - if scale < 1.0: - new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale) - pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS) - pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST) + pil_img = Image.fromarray(img_np) + pil_mask = Image.fromarray(mask_np) - rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8) - rgba_np[:, :, :3] = np.array(pil_img) - rgba_np[:, :, 3] = np.array(pil_mask) + max_size = max(pil_img.size) + scale = min(1.0, 1024 / max_size) + if scale < 1.0: + new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale) + pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS) + pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST) - alpha = rgba_np[:, :, 3] - bbox_coords = np.argwhere(alpha > 0.8 * 255) + rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8) + rgba_np[:, :, :3] = np.array(pil_img) + rgba_np[:, :, 3] = np.array(pil_mask) - if len(bbox_coords) > 0: - y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1]) - y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1]) + alpha = 
rgba_np[:, :, 3] + bbox_coords = np.argwhere(alpha > 0.8 * 255) - center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0 - size = max(y_max - y_min, x_max - x_min) + if len(bbox_coords) > 0: + y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1]) + y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1]) - crop_x1 = int(center_x - size // 2) - crop_y1 = int(center_y - size // 2) - crop_x2 = int(center_x + size // 2) - crop_y2 = int(center_y + size // 2) + center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0 + size = max(y_max - y_min, x_max - x_min) - rgba_pil = Image.fromarray(rgba_np) - cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2)) - cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0 - else: - import logging - logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.") - cropped_np = rgba_np.astype(np.float32) / 255.0 + crop_x1 = int(center_x - size // 2) + crop_y1 = int(center_y - size // 2) + crop_x2 = int(center_x + size // 2) + crop_y2 = int(center_y + size // 2) - bg_colors = {"black":[0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]} - bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32) + rgba_pil = Image.fromarray(rgba_np) + cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2)) + cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0 + else: + import logging + logging.warning("Mask for the image is empty. 
Trellis2 requires an image with a mask for the best mesh quality.") + cropped_np = rgba_np.astype(np.float32) / 255.0 - fg = cropped_np[:, :, :3] - alpha_float = cropped_np[:, :, 3:4] - composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float) + bg_colors = {"black":[0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]} + bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32) - # to match trellis2 code (quantize -> dequantize) - composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8) + fg = cropped_np[:, :, :3] + alpha_float = cropped_np[:, :, 3:4] + composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float) - cropped_pil = Image.fromarray(composite_uint8) + # to match trellis2 code (quantize -> dequantize) + composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8) - conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True) + cropped_pil = Image.fromarray(composite_uint8) - embeds = conditioning["cond_1024"] - positive = [[conditioning["cond_512"], {"embeds": embeds}]] - negative = [[conditioning["neg_cond"], {"embeds": torch.zeros_like(embeds)}]] + item_conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True) + cond_512_list.append(item_conditioning["cond_512"]) + cond_1024_list.append(item_conditioning["cond_1024"]) + + cond_512_batched = torch.cat(cond_512_list, dim=0) + cond_1024_batched = torch.cat(cond_1024_list, dim=0) + neg_cond_batched = torch.zeros_like(cond_512_batched) + neg_embeds_batched = torch.zeros_like(cond_1024_batched) + + positive = [[cond_512_batched, {"embeds": cond_1024_batched}]] + negative = [[neg_cond_batched, {"embeds": neg_embeds_batched}]] return IO.NodeOutput(positive, negative) class EmptyShapeLatentTrellis2(IO.ComfyNode): @@ -659,7 +691,27 @@ class PostProcessMesh(IO.ComfyNode): @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): - # TODO: batched mode may break + if 
isinstance(mesh.vertices, list): + out_verts, out_faces, out_colors = [], [], [] + colors_in = mesh.colors if hasattr(mesh, "colors") and mesh.colors is not None else None + for i in range(len(mesh.vertices)): + v_i = mesh.vertices[i] + f_i = mesh.faces[i] + c_i = colors_in[i] if colors_in is not None else None + actual_face_count = f_i.shape[0] + if fill_holes_perimeter > 0: + v_i, f_i = fill_holes_fn(v_i, f_i, max_perimeter=fill_holes_perimeter) + if simplify > 0 and actual_face_count > simplify: + v_i, f_i, c_i = simplify_fn(v_i, f_i, target=simplify, colors=c_i) + v_i, f_i = make_double_sided(v_i, f_i) + out_verts.append(v_i) + out_faces.append(f_i) + if c_i is not None: + out_colors.append(c_i) + out_mesh = type(mesh)(vertices=out_verts, faces=out_faces) + if len(out_colors) == len(out_verts): + out_mesh.colors = out_colors + return IO.NodeOutput(out_mesh) verts, faces = mesh.vertices, mesh.faces colors = None if hasattr(mesh, "colors"): From 6d99b636c12315f340c668b05d97431b1b547b5c Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 22:55:38 -0500 Subject: [PATCH 70/93] Trellis2/Hunyuan3d: preserve mesh tensor contract in batch mode --- comfy_extras/nodes_hunyuan3d.py | 61 ++++++++++++++++++++++--- comfy_extras/nodes_trellis2.py | 80 ++++++++++++++++++++++++++------- 2 files changed, 121 insertions(+), 20 deletions(-) diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 8f58e85d9..0b7e17bb5 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -445,7 +445,7 @@ class VoxelToMeshBasic(IO.ComfyNode): if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces): return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) - return IO.NodeOutput(Types.MESH(vertices, faces)) + return IO.NodeOutput(pack_variable_mesh_batch(vertices, faces)) decode = execute # TODO: remove @@ -483,7 +483,7 @@ class 
VoxelToMesh(IO.ComfyNode): if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces): return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) - return IO.NodeOutput(Types.MESH(vertices, faces)) + return IO.NodeOutput(pack_variable_mesh_batch(vertices, faces)) decode = execute # TODO: remove @@ -632,6 +632,57 @@ def save_glb(vertices, faces, filepath, metadata=None, colors=None): return filepath +def pack_variable_mesh_batch(vertices, faces, colors=None): + batch_size = len(vertices) + max_vertices = max(v.shape[0] for v in vertices) + max_faces = max(f.shape[0] for f in faces) + + packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) + packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) + vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) + face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) + + for i, (v, f) in enumerate(zip(vertices, faces)): + packed_vertices[i, :v.shape[0]] = v + packed_faces[i, :f.shape[0]] = f + + mesh = Types.MESH(packed_vertices, packed_faces) + mesh.vertex_counts = vertex_counts + mesh.face_counts = face_counts + + if colors is not None: + max_colors = max(c.shape[0] for c in colors) + packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) + color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) + for i, c in enumerate(colors): + packed_colors[i, :c.shape[0]] = c + mesh.colors = packed_colors + mesh.color_counts = color_counts + + return mesh + + +def get_mesh_batch_item(mesh, index): + if hasattr(mesh, "vertex_counts"): + vertex_count = int(mesh.vertex_counts[index].item()) + face_count = int(mesh.face_counts[index].item()) + vertices = mesh.vertices[index, :vertex_count] + faces = mesh.faces[index, :face_count] + colors = 
None + if hasattr(mesh, "colors") and mesh.colors is not None: + if hasattr(mesh, "color_counts"): + color_count = int(mesh.color_counts[index].item()) + colors = mesh.colors[index, :color_count] + else: + colors = mesh.colors[index, :vertex_count] + return vertices, faces, colors + + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + colors = mesh.colors[index] + return mesh.vertices[index], mesh.faces[index], colors + + class SaveGLB(IO.ComfyNode): @classmethod def define_schema(cls): @@ -686,11 +737,11 @@ class SaveGLB(IO.ComfyNode): }) else: # Handle Mesh input - save vertices and faces as GLB - bsz = len(mesh.vertices) if isinstance(mesh.vertices, list) else mesh.vertices.shape[0] + bsz = mesh.vertices.shape[0] for i in range(bsz): f = f"{filename}_{counter:05}_.glb" - v_colors = mesh.colors[i] if hasattr(mesh, "colors") and mesh.colors is not None else None - save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata, v_colors) + vertices, faces, v_colors = get_mesh_batch_item(mesh, i) + save_glb(vertices, faces, os.path.join(full_output_folder, f), metadata, v_colors) results.append({ "filename": f, "subfolder": subfolder, diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 8ef3e8f5a..57732151b 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -8,6 +8,57 @@ import torch import scipy import copy + +def pack_variable_mesh_batch(vertices, faces, colors=None): + batch_size = len(vertices) + max_vertices = max(v.shape[0] for v in vertices) + max_faces = max(f.shape[0] for f in faces) + + packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) + packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) + vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) + face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, 
dtype=torch.int64) + + for i, (v, f) in enumerate(zip(vertices, faces)): + packed_vertices[i, :v.shape[0]] = v + packed_faces[i, :f.shape[0]] = f + + mesh = Types.MESH(packed_vertices, packed_faces) + mesh.vertex_counts = vertex_counts + mesh.face_counts = face_counts + + if colors is not None: + max_colors = max(c.shape[0] for c in colors) + packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) + color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) + for i, c in enumerate(colors): + packed_colors[i, :c.shape[0]] = c + mesh.colors = packed_colors + mesh.color_counts = color_counts + + return mesh + + +def get_mesh_batch_item(mesh, index): + if hasattr(mesh, "vertex_counts"): + vertex_count = int(mesh.vertex_counts[index].item()) + face_count = int(mesh.face_counts[index].item()) + vertices = mesh.vertices[index, :vertex_count] + faces = mesh.faces[index, :face_count] + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + if hasattr(mesh, "color_counts"): + color_count = int(mesh.color_counts[index].item()) + colors = mesh.colors[index, :color_count] + else: + colors = mesh.colors[index, :vertex_count] + return vertices, faces, colors + + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + colors = mesh.colors[index] + return mesh.vertices[index], mesh.faces[index], colors + shape_slat_normalization = { "mean": torch.tensor([ 0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218, @@ -122,7 +173,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list): mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list)) else: - mesh = Types.MESH(vertices=vert_list, faces=face_list) + mesh = pack_variable_mesh_batch(vert_list, face_list) return IO.NodeOutput(mesh, subs) class 
VaeDecodeTextureTrellis(IO.ComfyNode): @@ -165,19 +216,19 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): voxel_coords = voxel.coords[:, 1:] voxel_batch_idx = voxel.coords[:, 0] - if isinstance(shape_mesh.vertices, list): + if hasattr(shape_mesh, "vertex_counts"): out_verts, out_faces, out_colors = [], [], [] - for i in range(len(shape_mesh.vertices)): + for i in range(shape_mesh.vertices.shape[0]): sel = voxel_batch_idx == i item_coords = voxel_coords[sel] item_colors = color_feats[sel] - item_mesh = Types.MESH(vertices=shape_mesh.vertices[i].unsqueeze(0), faces=shape_mesh.faces[i].unsqueeze(0)) + item_vertices, item_faces, _ = get_mesh_batch_item(shape_mesh, i) + item_mesh = Types.MESH(vertices=item_vertices.unsqueeze(0), faces=item_faces.unsqueeze(0)) painted = paint_mesh_with_voxels(item_mesh, item_coords, item_colors, resolution=resolution) out_verts.append(painted.vertices.squeeze(0)) out_faces.append(painted.faces.squeeze(0)) out_colors.append(painted.colors.squeeze(0)) - out_mesh = Types.MESH(vertices=out_verts, faces=out_faces) - out_mesh.colors = out_colors + out_mesh = pack_variable_mesh_batch(out_verts, out_faces, out_colors) else: out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution) return IO.NodeOutput(out_mesh) @@ -334,6 +385,10 @@ class Trellis2Conditioning(IO.ComfyNode): if mask.ndim == 2: mask = mask.unsqueeze(0) batch_size = image.shape[0] + if mask.shape[0] == 1 and batch_size > 1: + mask = mask.repeat(batch_size, 1, 1) + elif mask.shape[0] != batch_size: + raise ValueError(f"Trellis2Conditioning mask batch {mask.shape[0]} does not match image batch {batch_size}") cond_512_list = [] cond_1024_list = [] @@ -691,13 +746,10 @@ class PostProcessMesh(IO.ComfyNode): @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): - if isinstance(mesh.vertices, list): + if hasattr(mesh, "vertex_counts"): out_verts, out_faces, out_colors = [], [], [] - colors_in = mesh.colors if hasattr(mesh, "colors") and 
mesh.colors is not None else None - for i in range(len(mesh.vertices)): - v_i = mesh.vertices[i] - f_i = mesh.faces[i] - c_i = colors_in[i] if colors_in is not None else None + for i in range(mesh.vertices.shape[0]): + v_i, f_i, c_i = get_mesh_batch_item(mesh, i) actual_face_count = f_i.shape[0] if fill_holes_perimeter > 0: v_i, f_i = fill_holes_fn(v_i, f_i, max_perimeter=fill_holes_perimeter) @@ -708,9 +760,7 @@ class PostProcessMesh(IO.ComfyNode): out_faces.append(f_i) if c_i is not None: out_colors.append(c_i) - out_mesh = type(mesh)(vertices=out_verts, faces=out_faces) - if len(out_colors) == len(out_verts): - out_mesh.colors = out_colors + out_mesh = pack_variable_mesh_batch(out_verts, out_faces, out_colors if len(out_colors) == len(out_verts) else None) return IO.NodeOutput(out_mesh) verts, faces = mesh.vertices, mesh.faces colors = None From c297a9f839a26c22c44a879aeeb3aed302055448 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 23:23:24 -0500 Subject: [PATCH 71/93] Trellis2: handle empty and batched texture paint paths --- comfy_extras/nodes_trellis2.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 57732151b..7a72b2824 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -135,6 +135,13 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): return out_mesh + +def paint_mesh_default_colors(mesh): + out_mesh = copy.deepcopy(mesh) + vertex_count = mesh.vertices.shape[1] + out_mesh.colors = mesh.vertices.new_zeros((1, vertex_count, 3)) + return out_mesh + class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def define_schema(cls): @@ -216,21 +223,28 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): voxel_coords = voxel.coords[:, 1:] voxel_batch_idx = voxel.coords[:, 0] - if hasattr(shape_mesh, "vertex_counts"): + mesh_batch_size = shape_mesh.vertices.shape[0] + if 
mesh_batch_size > 1: out_verts, out_faces, out_colors = [], [], [] - for i in range(shape_mesh.vertices.shape[0]): + for i in range(mesh_batch_size): sel = voxel_batch_idx == i item_coords = voxel_coords[sel] item_colors = color_feats[sel] item_vertices, item_faces, _ = get_mesh_batch_item(shape_mesh, i) item_mesh = Types.MESH(vertices=item_vertices.unsqueeze(0), faces=item_faces.unsqueeze(0)) - painted = paint_mesh_with_voxels(item_mesh, item_coords, item_colors, resolution=resolution) + if item_coords.shape[0] == 0: + painted = paint_mesh_default_colors(item_mesh) + else: + painted = paint_mesh_with_voxels(item_mesh, item_coords, item_colors, resolution=resolution) out_verts.append(painted.vertices.squeeze(0)) out_faces.append(painted.faces.squeeze(0)) out_colors.append(painted.colors.squeeze(0)) out_mesh = pack_variable_mesh_batch(out_verts, out_faces, out_colors) else: - out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution) + if voxel_coords.shape[0] == 0: + out_mesh = paint_mesh_default_colors(shape_mesh) + else: + out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution) return IO.NodeOutput(out_mesh) class VaeDecodeStructureTrellis2(IO.ComfyNode): From 40219ab0fce492f8ff91f99f909e3b5060483e32 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 23:33:09 -0500 Subject: [PATCH 72/93] Trellis2: share batched mesh helpers --- comfy_extras/mesh_batch_utils.py | 53 +++++++++++++++++++++++++++++ comfy_extras/nodes_hunyuan3d.py | 53 +---------------------------- comfy_extras/nodes_trellis2.py | 58 +++----------------------------- 3 files changed, 58 insertions(+), 106 deletions(-) create mode 100644 comfy_extras/mesh_batch_utils.py diff --git a/comfy_extras/mesh_batch_utils.py b/comfy_extras/mesh_batch_utils.py new file mode 100644 index 000000000..841328776 --- /dev/null +++ b/comfy_extras/mesh_batch_utils.py @@ -0,0 +1,53 @@ +import torch +from comfy_api.latest import 
Types + + +def pack_variable_mesh_batch(vertices, faces, colors=None): + batch_size = len(vertices) + max_vertices = max(v.shape[0] for v in vertices) + max_faces = max(f.shape[0] for f in faces) + + packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) + packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) + vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) + face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) + + for i, (v, f) in enumerate(zip(vertices, faces)): + packed_vertices[i, :v.shape[0]] = v + packed_faces[i, :f.shape[0]] = f + + mesh = Types.MESH(packed_vertices, packed_faces) + mesh.vertex_counts = vertex_counts + mesh.face_counts = face_counts + + if colors is not None: + max_colors = max(c.shape[0] for c in colors) + packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) + color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) + for i, c in enumerate(colors): + packed_colors[i, :c.shape[0]] = c + mesh.colors = packed_colors + mesh.color_counts = color_counts + + return mesh + + +def get_mesh_batch_item(mesh, index): + if hasattr(mesh, "vertex_counts"): + vertex_count = int(mesh.vertex_counts[index].item()) + face_count = int(mesh.face_counts[index].item()) + vertices = mesh.vertices[index, :vertex_count] + faces = mesh.faces[index, :face_count] + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + if hasattr(mesh, "color_counts"): + color_count = int(mesh.color_counts[index].item()) + colors = mesh.colors[index, :color_count] + else: + colors = mesh.colors[index, :vertex_count] + return vertices, faces, colors + + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + colors = mesh.colors[index] + return mesh.vertices[index], mesh.faces[index], colors diff --git 
a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 0b7e17bb5..78ab3b841 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -10,6 +10,7 @@ from comfy.cli_args import args from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa +from comfy_extras.mesh_batch_utils import pack_variable_mesh_batch, get_mesh_batch_item class EmptyLatentHunyuan3Dv2(IO.ComfyNode): @@ -631,58 +632,6 @@ def save_glb(vertices, faces, filepath, metadata=None, colors=None): return filepath - -def pack_variable_mesh_batch(vertices, faces, colors=None): - batch_size = len(vertices) - max_vertices = max(v.shape[0] for v in vertices) - max_faces = max(f.shape[0] for f in faces) - - packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) - packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) - vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) - face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) - - for i, (v, f) in enumerate(zip(vertices, faces)): - packed_vertices[i, :v.shape[0]] = v - packed_faces[i, :f.shape[0]] = f - - mesh = Types.MESH(packed_vertices, packed_faces) - mesh.vertex_counts = vertex_counts - mesh.face_counts = face_counts - - if colors is not None: - max_colors = max(c.shape[0] for c in colors) - packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) - color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) - for i, c in enumerate(colors): - packed_colors[i, :c.shape[0]] = c - mesh.colors = packed_colors - mesh.color_counts = color_counts - - return mesh - - -def get_mesh_batch_item(mesh, index): - if 
hasattr(mesh, "vertex_counts"): - vertex_count = int(mesh.vertex_counts[index].item()) - face_count = int(mesh.face_counts[index].item()) - vertices = mesh.vertices[index, :vertex_count] - faces = mesh.faces[index, :face_count] - colors = None - if hasattr(mesh, "colors") and mesh.colors is not None: - if hasattr(mesh, "color_counts"): - color_count = int(mesh.color_counts[index].item()) - colors = mesh.colors[index, :color_count] - else: - colors = mesh.colors[index, :vertex_count] - return vertices, faces, colors - - colors = None - if hasattr(mesh, "colors") and mesh.colors is not None: - colors = mesh.colors[index] - return mesh.vertices[index], mesh.faces[index], colors - - class SaveGLB(IO.ComfyNode): @classmethod def define_schema(cls): diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 7a72b2824..cdac6f103 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,6 +1,7 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy.ldm.trellis2.vae import SparseTensor +from comfy_extras.mesh_batch_utils import pack_variable_mesh_batch, get_mesh_batch_item import comfy.model_management from PIL import Image import numpy as np @@ -8,57 +9,6 @@ import torch import scipy import copy - -def pack_variable_mesh_batch(vertices, faces, colors=None): - batch_size = len(vertices) - max_vertices = max(v.shape[0] for v in vertices) - max_faces = max(f.shape[0] for f in faces) - - packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) - packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) - vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) - face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) - - for i, (v, f) in enumerate(zip(vertices, faces)): - packed_vertices[i, :v.shape[0]] = v - packed_faces[i, 
:f.shape[0]] = f - - mesh = Types.MESH(packed_vertices, packed_faces) - mesh.vertex_counts = vertex_counts - mesh.face_counts = face_counts - - if colors is not None: - max_colors = max(c.shape[0] for c in colors) - packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) - color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) - for i, c in enumerate(colors): - packed_colors[i, :c.shape[0]] = c - mesh.colors = packed_colors - mesh.color_counts = color_counts - - return mesh - - -def get_mesh_batch_item(mesh, index): - if hasattr(mesh, "vertex_counts"): - vertex_count = int(mesh.vertex_counts[index].item()) - face_count = int(mesh.face_counts[index].item()) - vertices = mesh.vertices[index, :vertex_count] - faces = mesh.faces[index, :face_count] - colors = None - if hasattr(mesh, "colors") and mesh.colors is not None: - if hasattr(mesh, "color_counts"): - color_count = int(mesh.color_counts[index].item()) - colors = mesh.colors[index, :color_count] - else: - colors = mesh.colors[index, :vertex_count] - return vertices, faces, colors - - colors = None - if hasattr(mesh, "colors") and mesh.colors is not None: - colors = mesh.colors[index] - return mesh.vertices[index], mesh.faces[index], colors - shape_slat_normalization = { "mean": torch.tensor([ 0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218, @@ -130,14 +80,14 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): final_colors = linear_colors.unsqueeze(0) - out_mesh = copy.deepcopy(mesh) + out_mesh = copy.copy(mesh) out_mesh.colors = final_colors return out_mesh def paint_mesh_default_colors(mesh): - out_mesh = copy.deepcopy(mesh) + out_mesh = copy.copy(mesh) vertex_count = mesh.vertices.shape[1] out_mesh.colors = mesh.vertices.new_zeros((1, vertex_count, 3)) return out_mesh @@ -400,7 +350,7 @@ class Trellis2Conditioning(IO.ComfyNode): mask = mask.unsqueeze(0) batch_size = image.shape[0] 
if mask.shape[0] == 1 and batch_size > 1: - mask = mask.repeat(batch_size, 1, 1) + mask = mask.expand(batch_size, -1, -1) elif mask.shape[0] != batch_size: raise ValueError(f"Trellis2Conditioning mask batch {mask.shape[0]} does not match image batch {batch_size}") From 9cfa8f2c0171ca386c2702d482252a3d6cf64ce8 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Sun, 19 Apr 2026 23:47:57 -0500 Subject: [PATCH 73/93] Trellis2: inline batched mesh helpers --- comfy_extras/mesh_batch_utils.py | 53 -------------------------------- comfy_extras/nodes_hunyuan3d.py | 52 ++++++++++++++++++++++++++++++- comfy_extras/nodes_trellis2.py | 52 ++++++++++++++++++++++++++++++- 3 files changed, 102 insertions(+), 55 deletions(-) delete mode 100644 comfy_extras/mesh_batch_utils.py diff --git a/comfy_extras/mesh_batch_utils.py b/comfy_extras/mesh_batch_utils.py deleted file mode 100644 index 841328776..000000000 --- a/comfy_extras/mesh_batch_utils.py +++ /dev/null @@ -1,53 +0,0 @@ -import torch -from comfy_api.latest import Types - - -def pack_variable_mesh_batch(vertices, faces, colors=None): - batch_size = len(vertices) - max_vertices = max(v.shape[0] for v in vertices) - max_faces = max(f.shape[0] for f in faces) - - packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) - packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) - vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) - face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) - - for i, (v, f) in enumerate(zip(vertices, faces)): - packed_vertices[i, :v.shape[0]] = v - packed_faces[i, :f.shape[0]] = f - - mesh = Types.MESH(packed_vertices, packed_faces) - mesh.vertex_counts = vertex_counts - mesh.face_counts = face_counts - - if colors is not None: - max_colors = max(c.shape[0] for c in colors) - packed_colors = colors[0].new_zeros((batch_size, max_colors, 
colors[0].shape[1])) - color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) - for i, c in enumerate(colors): - packed_colors[i, :c.shape[0]] = c - mesh.colors = packed_colors - mesh.color_counts = color_counts - - return mesh - - -def get_mesh_batch_item(mesh, index): - if hasattr(mesh, "vertex_counts"): - vertex_count = int(mesh.vertex_counts[index].item()) - face_count = int(mesh.face_counts[index].item()) - vertices = mesh.vertices[index, :vertex_count] - faces = mesh.faces[index, :face_count] - colors = None - if hasattr(mesh, "colors") and mesh.colors is not None: - if hasattr(mesh, "color_counts"): - color_count = int(mesh.color_counts[index].item()) - colors = mesh.colors[index, :color_count] - else: - colors = mesh.colors[index, :vertex_count] - return vertices, faces, colors - - colors = None - if hasattr(mesh, "colors") and mesh.colors is not None: - colors = mesh.colors[index] - return mesh.vertices[index], mesh.faces[index], colors diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 78ab3b841..7ae69db98 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -10,7 +10,6 @@ from comfy.cli_args import args from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa -from comfy_extras.mesh_batch_utils import pack_variable_mesh_batch, get_mesh_batch_item class EmptyLatentHunyuan3Dv2(IO.ComfyNode): @@ -632,6 +631,57 @@ def save_glb(vertices, faces, filepath, metadata=None, colors=None): return filepath + +def pack_variable_mesh_batch(vertices, faces, colors=None): + batch_size = len(vertices) + max_vertices = max(v.shape[0] for v in vertices) + max_faces = max(f.shape[0] for f in faces) + + packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, 
vertices[0].shape[1])) + packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) + vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) + face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) + + for i, (v, f) in enumerate(zip(vertices, faces)): + packed_vertices[i, :v.shape[0]] = v + packed_faces[i, :f.shape[0]] = f + + mesh = Types.MESH(packed_vertices, packed_faces) + mesh.vertex_counts = vertex_counts + mesh.face_counts = face_counts + + if colors is not None: + max_colors = max(c.shape[0] for c in colors) + packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) + color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) + for i, c in enumerate(colors): + packed_colors[i, :c.shape[0]] = c + mesh.colors = packed_colors + mesh.color_counts = color_counts + + return mesh + + +def get_mesh_batch_item(mesh, index): + if hasattr(mesh, "vertex_counts"): + vertex_count = int(mesh.vertex_counts[index].item()) + face_count = int(mesh.face_counts[index].item()) + vertices = mesh.vertices[index, :vertex_count] + faces = mesh.faces[index, :face_count] + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + if hasattr(mesh, "color_counts"): + color_count = int(mesh.color_counts[index].item()) + colors = mesh.colors[index, :color_count] + else: + colors = mesh.colors[index, :vertex_count] + return vertices, faces, colors + + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + colors = mesh.colors[index] + return mesh.vertices[index], mesh.faces[index], colors + class SaveGLB(IO.ComfyNode): @classmethod def define_schema(cls): diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index cdac6f103..8121e261b 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,7 +1,6 @@ from typing_extensions 
import override from comfy_api.latest import ComfyExtension, IO, Types from comfy.ldm.trellis2.vae import SparseTensor -from comfy_extras.mesh_batch_utils import pack_variable_mesh_batch, get_mesh_batch_item import comfy.model_management from PIL import Image import numpy as np @@ -9,6 +8,57 @@ import torch import scipy import copy + +def pack_variable_mesh_batch(vertices, faces, colors=None): + batch_size = len(vertices) + max_vertices = max(v.shape[0] for v in vertices) + max_faces = max(f.shape[0] for f in faces) + + packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) + packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) + vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) + face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) + + for i, (v, f) in enumerate(zip(vertices, faces)): + packed_vertices[i, :v.shape[0]] = v + packed_faces[i, :f.shape[0]] = f + + mesh = Types.MESH(packed_vertices, packed_faces) + mesh.vertex_counts = vertex_counts + mesh.face_counts = face_counts + + if colors is not None: + max_colors = max(c.shape[0] for c in colors) + packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) + color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) + for i, c in enumerate(colors): + packed_colors[i, :c.shape[0]] = c + mesh.colors = packed_colors + mesh.color_counts = color_counts + + return mesh + + +def get_mesh_batch_item(mesh, index): + if hasattr(mesh, "vertex_counts"): + vertex_count = int(mesh.vertex_counts[index].item()) + face_count = int(mesh.face_counts[index].item()) + vertices = mesh.vertices[index, :vertex_count] + faces = mesh.faces[index, :face_count] + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + if hasattr(mesh, "color_counts"): + color_count = 
int(mesh.color_counts[index].item()) + colors = mesh.colors[index, :color_count] + else: + colors = mesh.colors[index, :vertex_count] + return vertices, faces, colors + + colors = None + if hasattr(mesh, "colors") and mesh.colors is not None: + colors = mesh.colors[index] + return mesh.vertices[index], mesh.faces[index], colors + shape_slat_normalization = { "mean": torch.tensor([ 0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218, From c81ddf23498d27f82293272494fd66f31dacb7fc Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 11:06:04 -0500 Subject: [PATCH 74/93] Fix Trellis2 batched shape and texture semantics --- comfy/ldm/trellis2/model.py | 347 ++++++++++++++++++++++++++++++--- comfy/sample.py | 17 ++ comfy_extras/nodes_trellis2.py | 332 +++++++++++++++++++++++++++---- 3 files changed, 635 insertions(+), 61 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 1c5d6c3ec..76dbacc93 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -786,6 +786,7 @@ class Trellis2(nn.Module): # 32 -> 512px path, 64 -> 1024px path. 
uses_1024_conditioning = self.img2shape.resolution == 64 coords = transformer_options.get("coords", None) + coord_counts = transformer_options.get("coord_counts") mode = transformer_options.get("generation_mode", "structure_generation") is_512_run = False timestep = timestep.to(self.dtype) @@ -811,40 +812,205 @@ class Trellis2(nn.Module): cond = context shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] + dense_out = None if not_struct_mode: orig_bsz = x.shape[0] rule = txt_rule if mode == "texture_generation" else shape_rule - if rule and orig_bsz > 1: - x_eval = x[1].unsqueeze(0) - t_eval = timestep[1].unsqueeze(0) if timestep.shape[0] > 1 else timestep + logical_batch = coord_counts.shape[0] if coord_counts is not None else 1 + if rule and orig_bsz > logical_batch: + half = orig_bsz // 2 + x_eval = x[half:] + t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep c_eval = cond else: x_eval = x t_eval = timestep c_eval = context + x_eval_norms = [float(v) for v in x_eval.square().sum(dim=(1, 2)).detach().cpu().tolist()] + c_eval_norms = [float(v) for v in c_eval.square().sum(dim=(1, 2)).detach().cpu().tolist()] + print( + "TRELLIS2_NOT_STRUCT_INPUT_TRACE", + { + "mode": mode, + "orig_bsz": int(orig_bsz), + "logical_batch": int(logical_batch), + "rule": bool(rule), + "coord_counts": coord_counts.tolist() if coord_counts is not None else None, + "x_eval_norms": x_eval_norms, + "c_eval_norms": c_eval_norms, + }, + ) + B, N, C = x_eval.shape if mode in ["shape_generation", "texture_generation"]: - feats_flat = x_eval.reshape(-1, C) + if coord_counts is not None: + logical_batch = coord_counts.shape[0] + if B % logical_batch != 0: + raise ValueError( + f"Trellis2 coord_counts batch {logical_batch} doesn't divide latent batch {B}" + ) + repeat_factor = B // logical_batch + sparse_outs = [] + active_coord_counts = [] + if mode == 
"shape_generation" and repeat_factor > 1: + grouped_outs = [] + grouped_counts = [] + for i in range(logical_batch): + count = int(coord_counts[i].item()) + coords_i = coords[coords[:, 0] == i].clone() + if coords_i.shape[0] != count: + raise ValueError( + f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}" + ) - # inflate coords [N, 4] -> [B*N, 4] - coords_list = [] - for i in range(B): - c = coords.clone() - c[:, 0] = i - coords_list.append(c) + feat_batches = [] + coord_batches = [] + index_batch = [] + for rep in range(repeat_factor): + out_index = rep * logical_batch + i + feat_batches.append(x_eval[out_index, :count]) + coords_rep = coords_i.clone() + coords_rep[:, 0] = rep + coord_batches.append(coords_rep) + index_batch.append(out_index) - batched_coords = torch.cat(coords_list, dim=0) + print( + "TRELLIS2_GROUPED_INPUT_TRACE", + { + "mode": mode, + "sample_index": int(i), + "coord_count": int(count), + "feat_norms": [float(v.square().sum().detach().cpu().item()) for v in feat_batches], + }, + ) + + x_st_i = SparseTensor( + feats=torch.cat(feat_batches, dim=0), + coords=torch.cat(coord_batches, dim=0).to(torch.int32), + ) + index_tensor = torch.tensor(index_batch, device=x_eval.device, dtype=torch.long) + if t_eval.shape[0] > 1: + t_i = t_eval.index_select(0, index_tensor) + else: + t_i = t_eval + if c_eval.shape[0] > 1: + c_i = c_eval.index_select(0, index_tensor) + else: + c_i = c_eval + + if is_512_run: + sparse_out = self.img2shape_512(x_st_i, t_i, c_i) + else: + sparse_out = self.img2shape(x_st_i, t_i, c_i) + + feats_group, coords_group = sparse_out.to_tensor_list() + if len(feats_group) != repeat_factor: + raise ValueError( + f"Trellis2 expected {repeat_factor} sparse output groups for batch {i}, got {len(feats_group)}" + ) + for rep, (feats_rep, coords_rep) in enumerate(zip(feats_group, coords_group)): + if feats_rep.shape[0] != count: + raise ValueError( + f"Trellis2 sparse output rows for batch {i} rep {rep} expected 
{count}, got {feats_rep.shape[0]}" + ) + if coords_rep.shape[0] != count: + raise ValueError( + f"Trellis2 sparse output coords for batch {i} rep {rep} expected {count}, got {coords_rep.shape[0]}" + ) + grouped_outs.append(feats_group) + grouped_counts.append(count) + + for rep in range(repeat_factor): + for i in range(logical_batch): + sparse_outs.append(grouped_outs[i][rep]) + active_coord_counts.append(grouped_counts[i]) + else: + for rep in range(repeat_factor): + for i in range(logical_batch): + out_index = rep * logical_batch + i + count = int(coord_counts[i].item()) + coords_i = coords[coords[:, 0] == i].clone() + if coords_i.shape[0] != count: + raise ValueError( + f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}" + ) + coords_i[:, 0] = 0 + feats_i = x_eval[out_index, :count] + x_st_i = SparseTensor(feats=feats_i, coords=coords_i.to(torch.int32)) + t_i = t_eval[out_index].unsqueeze(0) if t_eval.shape[0] > 1 else t_eval + c_i = c_eval[out_index].unsqueeze(0) if c_eval.shape[0] > 1 else c_eval + + if mode == "shape_generation": + if is_512_run: + sparse_out = self.img2shape_512(x_st_i, t_i, c_i) + else: + sparse_out = self.img2shape(x_st_i, t_i, c_i) + else: + slat = transformer_options.get("shape_slat") + if slat is None: + raise ValueError("shape_slat can't be None") + if slat.ndim == 3: + if slat.shape[0] != logical_batch: + raise ValueError( + f"shape_slat batch {slat.shape[0]} doesn't match coord_counts batch {logical_batch}" + ) + if slat.shape[1] < count: + raise ValueError( + f"shape_slat tokens {slat.shape[1]} can't cover coord count {count} for batch {i}" + ) + slat_feats = slat[i, :count].to(x_st_i.device) + else: + slat_feats = slat[:count].to(x_st_i.device) + x_st_i = x_st_i.replace(feats=torch.cat([x_st_i.feats, slat_feats], dim=-1)) + sparse_out = self.shape2txt(x_st_i, t_i, c_i) + + sparse_outs.append(sparse_out.feats) + active_coord_counts.append(count) + + out_channels = sparse_outs[0].shape[-1] + 
sparse_out_norms = [float(feats.square().sum().detach().cpu().item()) for feats in sparse_outs] + print( + "TRELLIS2_SPARSE_OUT_TRACE", + { + "mode": mode, + "coords_rows": int(coords.shape[0]), + "active_coord_counts": active_coord_counts, + "sparse_out_norms": sparse_out_norms, + }, + ) + padded = sparse_outs[0].new_zeros((B, N, out_channels)) + for out_index, (count, feats_i) in enumerate(zip(active_coord_counts, sparse_outs)): + padded[out_index, :count] = feats_i + dense_out = padded.transpose(1, 2).unsqueeze(-1) + elif coords.shape[0] == N: + feats_flat = x_eval.reshape(-1, C) + coords_list = [] + for i in range(B): + c = coords.clone() + c[:, 0] = i + coords_list.append(c) + batched_coords = torch.cat(coords_list, dim=0) + elif coords.shape[0] == B * N: + feats_flat = x_eval.reshape(-1, C) + batched_coords = coords + else: + raise ValueError( + f"Trellis2 expected coords rows {N} or {B * N}, got {coords.shape[0]}" + ) else: batched_coords = coords feats_flat = x_eval - x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) + if dense_out is None: + x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) - if mode == "shape_generation": + if dense_out is not None: + out = dense_out + elif mode == "shape_generation": if is_512_run: out = self.img2shape_512(x_st, t_eval, c_eval) else: @@ -856,23 +1022,152 @@ class Trellis2(nn.Module): if slat is None: raise ValueError("shape_slat can't be None") - base_slat_feats = slat[:N] - slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device) + if slat.ndim == 3: + if coord_counts is not None: + logical_batch = coord_counts.shape[0] + if slat.shape[0] != logical_batch: + raise ValueError( + f"shape_slat batch {slat.shape[0]} doesn't match coord_counts batch {logical_batch}" + ) + if B % logical_batch != 0: + raise ValueError( + f"Trellis2 coord_counts batch {logical_batch} doesn't divide latent batch {B}" + ) + repeat_factor = B // logical_batch + slat_list = [] + for _ 
in range(repeat_factor): + for i in range(logical_batch): + count = int(coord_counts[i].item()) + if slat.shape[1] < count: + raise ValueError( + f"shape_slat tokens {slat.shape[1]} can't cover coord count {count} for batch {i}" + ) + slat_list.append(slat[i, :count]) + slat_feats_batched = torch.cat(slat_list, dim=0).to(x_st.device) + else: + if slat.shape[0] != B: + raise ValueError(f"shape_slat batch {slat.shape[0]} doesn't match latent batch {B}") + if slat.shape[1] != N: + raise ValueError(f"shape_slat tokens {slat.shape[1]} doesn't match latent tokens {N}") + slat_feats_batched = slat.reshape(B * N, -1).to(x_st.device) + else: + base_slat_feats = slat[:N] + slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device) x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1)) out = self.shape2txt(x_st, t_eval, c_eval) else: # structure orig_bsz = x.shape[0] - if shape_rule and orig_bsz > 1: - half = orig_bsz // 2 - x = x[half:] - timestep = timestep[half:] if timestep.shape[0] > 1 else timestep - out = self.structure_model(x, timestep, cond if shape_rule and orig_bsz > 1 else context) - if shape_rule and orig_bsz > 1: - out = out.repeat(2, 1, 1, 1, 1) + cond_or_uncond = transformer_options.get("cond_or_uncond") or [] + batch_groups = len(cond_or_uncond) if len(cond_or_uncond) > 0 and orig_bsz % len(cond_or_uncond) == 0 else 1 + logical_batch = orig_bsz // batch_groups + print( + "TRELLIS2_STRUCTURE_INPUT_TRACE", + { + "orig_bsz": int(orig_bsz), + "batch_groups": int(batch_groups), + "logical_batch": int(logical_batch), + "cond_or_uncond": cond_or_uncond, + "x_norms": [float(v) for v in x.square().sum(dim=(1, 2, 3, 4)).detach().cpu().tolist()], + "x_sums": [float(v) for v in x.sum(dim=(1, 2, 3, 4)).detach().cpu().tolist()], + "c_norms": [float(v) for v in context.square().sum(dim=(1, 2)).detach().cpu().tolist()], + "c_sums": [float(v) for v in context.sum(dim=(1, 2)).detach().cpu().tolist()], + }, + ) + + if logical_batch > 1: + 
x_groups = x.reshape(batch_groups, logical_batch, *x.shape[1:]) + if timestep.shape[0] > 1: + t_groups = timestep.reshape(batch_groups, logical_batch, *timestep.shape[1:]) + else: + t_groups = timestep + c_groups = context.reshape(batch_groups, logical_batch, *context.shape[1:]) + + if shape_rule and batch_groups > 1: + selected_group_indices = [batch_groups - 1] + else: + selected_group_indices = list(range(batch_groups)) + + out_groups = [] + selected_x_norms = [] + selected_x_sums = [] + selected_c_norms = [] + selected_c_sums = [] + for sample_index in range(logical_batch): + if shape_rule and batch_groups > 1: + half = orig_bsz // 2 + x_i = x[half + sample_index].unsqueeze(0) + if timestep.shape[0] > 1: + t_i = timestep[half + sample_index].unsqueeze(0) + else: + t_i = timestep + if cond.shape[0] > 1: + c_i = cond[sample_index].unsqueeze(0) + else: + c_i = cond + else: + x_i = x_groups[selected_group_indices, sample_index] + if timestep.shape[0] > 1: + t_i = t_groups[selected_group_indices, sample_index] + else: + t_i = timestep + c_i = c_groups[selected_group_indices, sample_index] + selected_x_norms.extend(float(v) for v in x_i.square().sum(dim=(1, 2, 3, 4)).detach().cpu().tolist()) + selected_x_sums.extend(float(v) for v in x_i.sum(dim=(1, 2, 3, 4)).detach().cpu().tolist()) + selected_c_norms.extend(float(v) for v in c_i.square().sum(dim=(1, 2)).detach().cpu().tolist()) + selected_c_sums.extend(float(v) for v in c_i.sum(dim=(1, 2)).detach().cpu().tolist()) + out_groups.append(self.structure_model(x_i, t_i, c_i)) + + print( + "TRELLIS2_STRUCTURE_SELECTED_TRACE", + { + "selected_group_indices": selected_group_indices, + "selected_x_norms": selected_x_norms, + "selected_x_sums": selected_x_sums, + "selected_c_norms": selected_c_norms, + "selected_c_sums": selected_c_sums, + }, + ) + + out = out_groups[0].new_zeros((orig_bsz, *out_groups[0].shape[1:])) + for sample_index, out_sample in enumerate(out_groups): + if shape_rule and batch_groups > 1: + repeated = 
out_sample[0] + for group_index in range(batch_groups): + out[group_index * logical_batch + sample_index] = repeated + else: + for local_group_index, group_index in enumerate(selected_group_indices): + out[group_index * logical_batch + sample_index] = out_sample[local_group_index] + else: + if shape_rule and orig_bsz > 1: + half = orig_bsz // 2 + x = x[half:] + timestep = timestep[half:] if timestep.shape[0] > 1 else timestep + out = self.structure_model(x, timestep, cond if shape_rule and orig_bsz > 1 else context) + if shape_rule and orig_bsz > 1: + out = out.repeat(2, 1, 1, 1, 1) + + print( + "TRELLIS2_STRUCTURE_OUTPUT_TRACE", + { + "out_norms": [float(v) for v in out.square().sum(dim=(1, 2, 3, 4)).detach().cpu().tolist()], + "out_sums": [float(v) for v in out.sum(dim=(1, 2, 3, 4)).detach().cpu().tolist()], + }, + ) if not_struct_mode: - out = out.feats - out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) - if rule and orig_bsz > 1: - out = out.repeat(orig_bsz, 1, 1, 1) + if dense_out is None: + out = out.feats + out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) + if rule and orig_bsz > B: + out = out.repeat(orig_bsz // B, 1, 1, 1) + print( + "TRELLIS2_DENSE_OUT_TRACE", + { + "mode": mode, + "coords_rows": int(coords.shape[0]) if coords is not None else None, + "output_shape": list(out.shape), + "output_norms": [float(v) for v in out.squeeze(-1).square().sum(dim=(1, 2)).detach().cpu().tolist()], + "coord_counts": coord_counts.tolist() if coord_counts is not None else None, + }, + ) return out diff --git a/comfy/sample.py b/comfy/sample.py index 653829582..3967fba1b 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -7,6 +7,23 @@ import logging import comfy.nested_tensor def prepare_noise_inner(latent_image, generator, noise_inds=None): + coord_counts = getattr(latent_image, "trellis_coord_counts", None) + if coord_counts is not None: + noise = torch.zeros(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, device="cpu") + base_state 
= generator.get_state() + for i, count in enumerate(coord_counts.tolist()): + local_generator = torch.Generator(device="cpu") + local_generator.set_state(base_state.clone()) + sample_noise = torch.randn( + [1, latent_image.size(1), int(count), latent_image.size(3)], + dtype=torch.float32, + layout=latent_image.layout, + generator=local_generator, + device="cpu", + ) + noise[i:i + 1, :, :int(count), :] = sample_noise + return noise.to(dtype=latent_image.dtype) + if noise_inds is None: return torch.randn(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 8121e261b..26cb135e7 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -96,6 +96,70 @@ def shape_norm(shape_latent, coords): samples = samples * std + mean return samples + +def infer_batched_coord_layout(coords): + if coords.ndim != 2 or coords.shape[1] != 4: + raise ValueError(f"Expected Trellis2 coords with shape [N, 4], got {tuple(coords.shape)}") + + if coords.shape[0] == 0: + raise ValueError("Trellis2 coords can't be empty") + + batch_ids = coords[:, 0].to(torch.int64) + batch_size = int(batch_ids.max().item()) + 1 + counts = torch.bincount(batch_ids, minlength=batch_size) + + if (counts == 0).any(): + raise ValueError(f"Non-contiguous Trellis2 batch ids in coords: {batch_ids.unique(sorted=True).tolist()}") + + max_tokens = int(counts.max().item()) + return batch_size, counts, max_tokens + + +def flatten_batched_sparse_latent(samples, coords, coord_counts): + samples = samples.squeeze(-1).transpose(1, 2) + if coord_counts is None: + return samples.reshape(-1, samples.shape[-1]), coords + + feat_list = [] + coord_list = [] + for i in range(coord_counts.shape[0]): + count = int(coord_counts[i].item()) + coords_i = coords[coords[:, 0] == i] + if coords_i.shape[0] != count: + raise ValueError(f"Trellis2 coords rows 
for batch {i} expected {count}, got {coords_i.shape[0]}") + feat_list.append(samples[i, :count]) + coord_list.append(coords_i) + + return torch.cat(feat_list, dim=0), torch.cat(coord_list, dim=0) + + +def split_batched_sparse_latent(samples, coords, coord_counts): + samples = samples.squeeze(-1).transpose(1, 2) + if coord_counts is None: + return [(samples.reshape(-1, samples.shape[-1]), coords)] + + items = [] + for i in range(coord_counts.shape[0]): + count = int(coord_counts[i].item()) + coords_i = coords[coords[:, 0] == i] + if coords_i.shape[0] != count: + raise ValueError(f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}") + items.append((samples[i, :count], coords_i)) + return items + + +def log_sparse_batch_trace(tag, items): + feat_norms = [float(feats.square().sum().detach().cpu().item()) for feats, _ in items] + coord_rows = [int(coords_i.shape[0]) for _, coords_i in items] + print( + tag, + { + "batch_size": len(items), + "coord_rows": coord_rows, + "feat_norms": feat_norms, + }, + ) + def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): """ Generic function to paint a mesh using nearest-neighbor colors from a sparse voxel field. 
@@ -169,12 +233,32 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): vae = vae.first_stage_model coords = samples["coords"] + coord_counts = samples.get("coord_counts") samples = samples["samples"] - samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) - samples = shape_norm(samples, coords) + if coord_counts is None: + samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts) + samples = shape_norm(samples.to(device), coords.to(device)) + mesh, subs = vae.decode_shape_slat(samples, resolution) + else: + split_items = split_batched_sparse_latent(samples, coords, coord_counts) + mesh = [] + subs_per_sample = [] + for feats_i, coords_i in split_items: + coords_i = coords_i.to(device).clone() + coords_i[:, 0] = 0 + sample_i = shape_norm(feats_i.to(device), coords_i) + mesh_i, subs_i = vae.decode_shape_slat(sample_i, resolution) + mesh.append(mesh_i[0]) + subs_per_sample.append(subs_i) + + subs = [] + for stage_index in range(len(subs_per_sample[0])): + stage_tensors = [sample_subs[stage_index] for sample_subs in subs_per_sample] + feats_list = [stage_tensor.feats for stage_tensor in stage_tensors] + coords_list = [stage_tensor.coords for stage_tensor in stage_tensors] + subs.append(SparseTensor.from_tensor_list(feats_list, coords_list)) - mesh, subs = vae.decode_shape_slat(samples, resolution) face_list = [m.faces for m in mesh] vert_list = [m.vertices for m in mesh] if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list): @@ -210,12 +294,14 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): vae = vae.first_stage_model coords = samples["coords"] + coord_counts = samples.get("coord_counts") samples = samples["samples"] - samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) + samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts) + samples = samples.to(device) std = tex_slat_normalization["std"].to(samples) mean = 
tex_slat_normalization["mean"].to(samples) - samples = SparseTensor(feats = samples, coords=coords) + samples = SparseTensor(feats = samples, coords=coords.to(device)) samples = samples * std + mean voxel = vae.decode_tex_slat(samples, shape_subs) @@ -273,7 +359,13 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): decoder = decoder.to(load_device) samples = samples["samples"] samples = samples.to(load_device) - decoded = decoder(samples)>0 + if samples.shape[0] > 1: + decoded_items = [] + for i in range(samples.shape[0]): + decoded_items.append(decoder(samples[i:i + 1]) > 0) + decoded = torch.cat(decoded_items, dim=0) + else: + decoded = decoder(samples) > 0 decoder.to(offload_device) current_res = decoded.shape[2] @@ -305,32 +397,102 @@ class Trellis2UpsampleCascade(IO.ComfyNode): device = comfy.model_management.get_torch_device() comfy.model_management.load_model_gpu(vae.patcher) - feats = shape_latent_512["samples"].squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) - coords_512 = shape_latent_512["coords"].to(device) - - slat = shape_norm(feats, coords_512) - + coord_counts = shape_latent_512.get("coord_counts") decoder = vae.first_stage_model.shape_dec - - slat.feats = slat.feats.to(next(decoder.parameters()).dtype) - hr_coords = decoder.upsample(slat, upsample_times=4) - lr_resolution = 512 - hr_resolution = int(target_resolution) + target_resolution = int(target_resolution) - while True: - quant_coords = torch.cat([ - hr_coords[:, :1], - ((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), - ], dim=1) - final_coords = quant_coords.unique(dim=0) - num_tokens = final_coords.shape[0] + if coord_counts is None: + feats, coords_512 = flatten_batched_sparse_latent( + shape_latent_512["samples"], + shape_latent_512["coords"], + coord_counts, + ) + feats = feats.to(device) + coords_512 = coords_512.to(device) + print( + "TRELLIS2_UPSAMPLE_INPUT_TRACE", + { + "batch_size": 1, + "coord_rows": [int(coords_512.shape[0])], + "feat_norms": 
[float(feats.square().sum().detach().cpu().item())], + }, + ) + slat = shape_norm(feats, coords_512) + slat.feats = slat.feats.to(next(decoder.parameters()).dtype) + hr_coords = decoder.upsample(slat, upsample_times=4) - if num_tokens < max_tokens or hr_resolution <= 1024: - break - hr_resolution -= 128 + hr_resolution = target_resolution + while True: + quant_coords = torch.cat([ + hr_coords[:, :1], + ((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), + ], dim=1) + final_coords = quant_coords.unique(dim=0) + num_tokens = final_coords.shape[0] - return IO.NodeOutput(final_coords,) + if num_tokens < max_tokens or hr_resolution <= 1024: + break + hr_resolution -= 128 + + print( + "TRELLIS2_UPSAMPLE_OUTPUT_TRACE", + { + "batch_size": 1, + "coord_rows": [int(final_coords.shape[0])], + "hr_resolution": int(hr_resolution), + }, + ) + return IO.NodeOutput(final_coords,) + + final_coords_list = [] + items = split_batched_sparse_latent( + shape_latent_512["samples"], + shape_latent_512["coords"], + coord_counts, + ) + log_sparse_batch_trace("TRELLIS2_UPSAMPLE_INPUT_TRACE", items) + decoder_dtype = next(decoder.parameters()).dtype + + output_coord_rows = [] + output_resolutions = [] + for batch_index, (feats_i, coords_i) in enumerate(items): + feats_i = feats_i.to(device) + coords_i = coords_i.to(device).clone() + coords_i[:, 0] = 0 + slat_i = shape_norm(feats_i, coords_i) + slat_i.feats = slat_i.feats.to(decoder_dtype) + hr_coords_i = decoder.upsample(slat_i, upsample_times=4) + + hr_resolution = target_resolution + while True: + quant_coords_i = torch.cat([ + hr_coords_i[:, :1], + ((hr_coords_i[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), + ], dim=1) + final_coords_i = quant_coords_i.unique(dim=0) + num_tokens = final_coords_i.shape[0] + + if num_tokens < max_tokens or hr_resolution <= 1024: + break + hr_resolution -= 128 + + final_coords_i = final_coords_i.clone() + final_coords_i[:, 0] = batch_index + 
final_coords_list.append(final_coords_i) + output_coord_rows.append(int(final_coords_i.shape[0])) + output_resolutions.append(int(hr_resolution)) + + print( + "TRELLIS2_UPSAMPLE_OUTPUT_TRACE", + { + "batch_size": len(final_coords_list), + "coord_rows": output_coord_rows, + "hr_resolution": output_resolutions, + }, + ) + + return IO.NodeOutput(torch.cat(final_coords_list, dim=0),) dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) dino_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) @@ -406,6 +568,7 @@ class Trellis2Conditioning(IO.ComfyNode): cond_512_list = [] cond_1024_list = [] + composite_trace = [] for b in range(batch_size): item_image = image[b] @@ -460,6 +623,14 @@ class Trellis2Conditioning(IO.ComfyNode): # to match trellis2 code (quantize -> dequantize) composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8) + composite_trace.append( + { + "sample_index": int(b), + "shape": list(composite_uint8.shape), + "sum": int(composite_uint8.sum(dtype=np.int64)), + "prefix": composite_uint8[:2, :2, :].reshape(-1).tolist(), + } + ) cropped_pil = Image.fromarray(composite_uint8) @@ -471,6 +642,19 @@ class Trellis2Conditioning(IO.ComfyNode): cond_1024_batched = torch.cat(cond_1024_list, dim=0) neg_cond_batched = torch.zeros_like(cond_512_batched) neg_embeds_batched = torch.zeros_like(cond_1024_batched) + print( + "TRELLIS2_CONDITIONING_TRACE", + { + "batch_size": int(batch_size), + "cond_512_norms": [float(v) for v in cond_512_batched.square().sum(dim=(1, 2)).detach().cpu().tolist()], + "cond_512_sums": [float(v) for v in cond_512_batched.sum(dim=(1, 2)).detach().cpu().tolist()], + "cond_512_prefix": cond_512_batched[:, 0, :8].detach().cpu().tolist(), + "cond_1024_norms": [float(v) for v in cond_1024_batched.square().sum(dim=(1, 2)).detach().cpu().tolist()], + "cond_1024_sums": [float(v) for v in cond_1024_batched.sum(dim=(1, 2)).detach().cpu().tolist()], + "cond_1024_prefix": cond_1024_batched[:, 0, 
:8].detach().cpu().tolist(), + "composite_trace": composite_trace, + }, + ) positive = [[cond_512_batched, {"embeds": cond_1024_batched}]] negative = [[neg_cond_batched, {"embeds": neg_embeds_batched}]] @@ -509,8 +693,32 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): else: raise ValueError(f"Invalid input to EmptyShapeLatent: {type(structure_or_coords)}") in_channels = 32 - # image like format - latent = torch.randn(1, in_channels, coords.shape[0], 1) + batch_size, coord_counts, max_tokens = infer_batched_coord_layout(coords) + if batch_size == 1: + coord_counts = None + latent = torch.randn(1, in_channels, coords.shape[0], 1) + else: + latent = torch.zeros(batch_size, in_channels, max_tokens, 1) + base_state = torch.random.get_rng_state() + for i in range(batch_size): + count = int(coord_counts[i].item()) + generator = torch.Generator(device="cpu") + generator.set_state(base_state.clone()) + latent_i = torch.randn(1, in_channels, count, 1, generator=generator) + latent[i, :, :count] = latent_i[0] + if coords.shape[0] > 1000: + norms = [float(v) for v in latent.squeeze(-1).square().sum(dim=(1, 2)).detach().cpu().tolist()] + print( + "TRELLIS2_EMPTY_SHAPE_TRACE", + { + "coords_rows": int(coords.shape[0]), + "batch_size": int(batch_size), + "coord_counts": coord_counts.tolist() if coord_counts is not None else None, + "latent_norms": norms, + }, + ) + if coord_counts is not None: + latent.trellis_coord_counts = coord_counts.clone() model = model.clone() model.model_options = model.model_options.copy() if "transformer_options" in model.model_options: @@ -519,11 +727,17 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"] = {} model.model_options["transformer_options"]["coords"] = coords + if coord_counts is not None: + model.model_options["transformer_options"]["coord_counts"] = coord_counts if is_512_pass: model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512" else: 
model.model_options["transformer_options"]["generation_mode"] = "shape_generation" - return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model) + output = {"samples": latent, "coords": coords, "type": "trellis2"} + if coord_counts is not None: + output["coord_counts"] = coord_counts + output["batch_index"] = [0] * batch_size + return IO.NodeOutput(output, model) class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod @@ -553,10 +767,45 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): coords = structure_or_coords.int() shape_latent = shape_latent["samples"] + batch_size, coord_counts, max_tokens = infer_batched_coord_layout(coords) if shape_latent.ndim == 4: - shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels) + if shape_latent.shape[0] != batch_size: + raise ValueError( + f"shape_latent batch {shape_latent.shape[0]} doesn't match coords batch {batch_size}" + ) + shape_latent = shape_latent.squeeze(-1).transpose(1, 2) + if shape_latent.shape[1] < max_tokens: + raise ValueError( + f"shape_latent tokens {shape_latent.shape[1]} can't cover coords max tokens {max_tokens}" + ) - latent = torch.randn(1, channels, coords.shape[0], 1) + if batch_size == 1: + coord_counts = None + latent = torch.randn(1, channels, coords.shape[0], 1) + else: + latent = torch.zeros(batch_size, channels, max_tokens, 1) + base_state = torch.random.get_rng_state() + for i in range(batch_size): + count = int(coord_counts[i].item()) + generator = torch.Generator(device="cpu") + generator.set_state(base_state.clone()) + latent_i = torch.randn(1, channels, count, 1, generator=generator) + latent[i, :, :count] = latent_i[0] + if coords.shape[0] > 1000: + norms = [float(v) for v in latent.squeeze(-1).square().sum(dim=(1, 2)).detach().cpu().tolist()] + shape_norms = [float(v) for v in shape_latent.square().sum(dim=(1, 2)).detach().cpu().tolist()] if shape_latent.ndim == 3 else None + print( + "TRELLIS2_EMPTY_TEXTURE_TRACE", + { + 
"coords_rows": int(coords.shape[0]), + "batch_size": int(batch_size), + "coord_counts": coord_counts.tolist() if coord_counts is not None else None, + "latent_norms": norms, + "shape_latent_norms": shape_norms, + }, + ) + if coord_counts is not None: + latent.trellis_coord_counts = coord_counts.clone() model = model.clone() model.model_options = model.model_options.copy() if "transformer_options" in model.model_options: @@ -565,9 +814,15 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"] = {} model.model_options["transformer_options"]["coords"] = coords + if coord_counts is not None: + model.model_options["transformer_options"]["coord_counts"] = coord_counts model.model_options["transformer_options"]["generation_mode"] = "texture_generation" model.model_options["transformer_options"]["shape_slat"] = shape_latent - return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model) + output = {"samples": latent, "coords": coords, "type": "trellis2"} + if coord_counts is not None: + output["coord_counts"] = coord_counts + output["batch_index"] = [0] * batch_size + return IO.NodeOutput(output, model) class EmptyStructureLatentTrellis2(IO.ComfyNode): @@ -587,8 +842,15 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): def execute(cls, batch_size): in_channels = 8 resolution = 16 - latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) - return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + generator = torch.Generator(device="cpu") + generator.manual_seed(11426) + latent = torch.randn(1, in_channels, resolution, resolution, resolution, generator=generator).repeat(batch_size, 1, 1, 1, 1) + norms = [float(v) for v in latent.square().sum(dim=(1, 2, 3, 4)).detach().cpu().tolist()] + print("TRELLIS2_EMPTY_STRUCTURE_TRACE", {"batch_size": int(batch_size), "latent_norms": norms}) + output = {"samples": latent, "type": "trellis2"} + if batch_size > 1: + output["batch_index"] = 
[0] * batch_size + return IO.NodeOutput(output) def simplify_fn(vertices, faces, colors=None, target=100000): if vertices.ndim == 3: From 49c1adeed6c91b790bf1ab87dffbec6a6e1eae6f Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 12:15:49 -0500 Subject: [PATCH 75/93] Fix Trellis PR review regressions --- comfy/ldm/trellis2/model.py | 86 -------------------- comfy/sample.py | 29 +++++-- comfy_extras/nodes_trellis2.py | 139 +++++++++++---------------------- 3 files changed, 68 insertions(+), 186 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 76dbacc93..f61c50629 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -829,21 +829,6 @@ class Trellis2(nn.Module): t_eval = timestep c_eval = context - x_eval_norms = [float(v) for v in x_eval.square().sum(dim=(1, 2)).detach().cpu().tolist()] - c_eval_norms = [float(v) for v in c_eval.square().sum(dim=(1, 2)).detach().cpu().tolist()] - print( - "TRELLIS2_NOT_STRUCT_INPUT_TRACE", - { - "mode": mode, - "orig_bsz": int(orig_bsz), - "logical_batch": int(logical_batch), - "rule": bool(rule), - "coord_counts": coord_counts.tolist() if coord_counts is not None else None, - "x_eval_norms": x_eval_norms, - "c_eval_norms": c_eval_norms, - }, - ) - B, N, C = x_eval.shape if mode in ["shape_generation", "texture_generation"]: @@ -878,16 +863,6 @@ class Trellis2(nn.Module): coord_batches.append(coords_rep) index_batch.append(out_index) - print( - "TRELLIS2_GROUPED_INPUT_TRACE", - { - "mode": mode, - "sample_index": int(i), - "coord_count": int(count), - "feat_norms": [float(v.square().sum().detach().cpu().item()) for v in feat_batches], - }, - ) - x_st_i = SparseTensor( feats=torch.cat(feat_batches, dim=0), coords=torch.cat(coord_batches, dim=0).to(torch.int32), @@ -972,16 +947,6 @@ class Trellis2(nn.Module): active_coord_counts.append(count) out_channels = sparse_outs[0].shape[-1] - sparse_out_norms = [float(feats.square().sum().detach().cpu().item()) for 
feats in sparse_outs] - print( - "TRELLIS2_SPARSE_OUT_TRACE", - { - "mode": mode, - "coords_rows": int(coords.shape[0]), - "active_coord_counts": active_coord_counts, - "sparse_out_norms": sparse_out_norms, - }, - ) padded = sparse_outs[0].new_zeros((B, N, out_channels)) for out_index, (count, feats_i) in enumerate(zip(active_coord_counts, sparse_outs)): padded[out_index, :count] = feats_i @@ -1060,20 +1025,6 @@ class Trellis2(nn.Module): cond_or_uncond = transformer_options.get("cond_or_uncond") or [] batch_groups = len(cond_or_uncond) if len(cond_or_uncond) > 0 and orig_bsz % len(cond_or_uncond) == 0 else 1 logical_batch = orig_bsz // batch_groups - print( - "TRELLIS2_STRUCTURE_INPUT_TRACE", - { - "orig_bsz": int(orig_bsz), - "batch_groups": int(batch_groups), - "logical_batch": int(logical_batch), - "cond_or_uncond": cond_or_uncond, - "x_norms": [float(v) for v in x.square().sum(dim=(1, 2, 3, 4)).detach().cpu().tolist()], - "x_sums": [float(v) for v in x.sum(dim=(1, 2, 3, 4)).detach().cpu().tolist()], - "c_norms": [float(v) for v in context.square().sum(dim=(1, 2)).detach().cpu().tolist()], - "c_sums": [float(v) for v in context.sum(dim=(1, 2)).detach().cpu().tolist()], - }, - ) - if logical_batch > 1: x_groups = x.reshape(batch_groups, logical_batch, *x.shape[1:]) if timestep.shape[0] > 1: @@ -1088,10 +1039,6 @@ class Trellis2(nn.Module): selected_group_indices = list(range(batch_groups)) out_groups = [] - selected_x_norms = [] - selected_x_sums = [] - selected_c_norms = [] - selected_c_sums = [] for sample_index in range(logical_batch): if shape_rule and batch_groups > 1: half = orig_bsz // 2 @@ -1111,23 +1058,8 @@ class Trellis2(nn.Module): else: t_i = timestep c_i = c_groups[selected_group_indices, sample_index] - selected_x_norms.extend(float(v) for v in x_i.square().sum(dim=(1, 2, 3, 4)).detach().cpu().tolist()) - selected_x_sums.extend(float(v) for v in x_i.sum(dim=(1, 2, 3, 4)).detach().cpu().tolist()) - selected_c_norms.extend(float(v) for v in 
c_i.square().sum(dim=(1, 2)).detach().cpu().tolist()) - selected_c_sums.extend(float(v) for v in c_i.sum(dim=(1, 2)).detach().cpu().tolist()) out_groups.append(self.structure_model(x_i, t_i, c_i)) - print( - "TRELLIS2_STRUCTURE_SELECTED_TRACE", - { - "selected_group_indices": selected_group_indices, - "selected_x_norms": selected_x_norms, - "selected_x_sums": selected_x_sums, - "selected_c_norms": selected_c_norms, - "selected_c_sums": selected_c_sums, - }, - ) - out = out_groups[0].new_zeros((orig_bsz, *out_groups[0].shape[1:])) for sample_index, out_sample in enumerate(out_groups): if shape_rule and batch_groups > 1: @@ -1146,28 +1078,10 @@ class Trellis2(nn.Module): if shape_rule and orig_bsz > 1: out = out.repeat(2, 1, 1, 1, 1) - print( - "TRELLIS2_STRUCTURE_OUTPUT_TRACE", - { - "out_norms": [float(v) for v in out.square().sum(dim=(1, 2, 3, 4)).detach().cpu().tolist()], - "out_sums": [float(v) for v in out.sum(dim=(1, 2, 3, 4)).detach().cpu().tolist()], - }, - ) - if not_struct_mode: if dense_out is None: out = out.feats out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) if rule and orig_bsz > B: out = out.repeat(orig_bsz // B, 1, 1, 1) - print( - "TRELLIS2_DENSE_OUT_TRACE", - { - "mode": mode, - "coords_rows": int(coords.shape[0]) if coords is not None else None, - "output_shape": list(out.shape), - "output_norms": [float(v) for v in out.squeeze(-1).square().sum(dim=(1, 2)).detach().cpu().tolist()], - "coord_counts": coord_counts.tolist() if coord_counts is not None else None, - }, - ) return out diff --git a/comfy/sample.py b/comfy/sample.py index 3967fba1b..7251aa799 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -10,18 +10,37 @@ def prepare_noise_inner(latent_image, generator, noise_inds=None): coord_counts = getattr(latent_image, "trellis_coord_counts", None) if coord_counts is not None: noise = torch.zeros(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, device="cpu") - base_state = generator.get_state() - for i, count in 
enumerate(coord_counts.tolist()): + if noise_inds is None: + noise_inds = np.arange(latent_image.size(0), dtype=np.int64) + else: + noise_inds = np.asarray(noise_inds, dtype=np.int64) + + unique_inds = np.unique(noise_inds) + first_indices = {int(unique_index): int(np.flatnonzero(noise_inds == unique_index)[0]) for unique_index in unique_inds.tolist()} + index_states = {} + for unique_index in sorted(first_indices): + index_states[unique_index] = generator.get_state().clone() + count = int(coord_counts[first_indices[unique_index]].item()) + torch.randn( + [1, latent_image.size(1), count, latent_image.size(3)], + dtype=torch.float32, + layout=latent_image.layout, + generator=generator, + device="cpu", + ) + + for batch_index, noise_index in enumerate(noise_inds.tolist()): + count = int(coord_counts[batch_index].item()) local_generator = torch.Generator(device="cpu") - local_generator.set_state(base_state.clone()) + local_generator.set_state(index_states[int(noise_index)].clone()) sample_noise = torch.randn( - [1, latent_image.size(1), int(count), latent_image.size(3)], + [1, latent_image.size(1), count, latent_image.size(3)], dtype=torch.float32, layout=latent_image.layout, generator=local_generator, device="cpu", ) - noise[i:i + 1, :, :int(count), :] = sample_noise + noise[batch_index:batch_index + 1, :, :count, :] = sample_noise return noise.to(dtype=latent_image.dtype) if noise_inds is None: diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 26cb135e7..621cc9586 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -148,18 +148,6 @@ def split_batched_sparse_latent(samples, coords, coord_counts): return items -def log_sparse_batch_trace(tag, items): - feat_norms = [float(feats.square().sum().detach().cpu().item()) for feats, _ in items] - coord_rows = [int(coords_i.shape[0]) for _, coords_i in items] - print( - tag, - { - "batch_size": len(items), - "coord_rows": coord_rows, - "feat_norms": feat_norms, 
- }, - ) - def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): """ Generic function to paint a mesh using nearest-neighbor colors from a sparse voxel field. @@ -410,14 +398,6 @@ class Trellis2UpsampleCascade(IO.ComfyNode): ) feats = feats.to(device) coords_512 = coords_512.to(device) - print( - "TRELLIS2_UPSAMPLE_INPUT_TRACE", - { - "batch_size": 1, - "coord_rows": [int(coords_512.shape[0])], - "feat_norms": [float(feats.square().sum().detach().cpu().item())], - }, - ) slat = shape_norm(feats, coords_512) slat.feats = slat.feats.to(next(decoder.parameters()).dtype) hr_coords = decoder.upsample(slat, upsample_times=4) @@ -435,27 +415,18 @@ class Trellis2UpsampleCascade(IO.ComfyNode): break hr_resolution -= 128 - print( - "TRELLIS2_UPSAMPLE_OUTPUT_TRACE", - { - "batch_size": 1, - "coord_rows": [int(final_coords.shape[0])], - "hr_resolution": int(hr_resolution), - }, - ) return IO.NodeOutput(final_coords,) - final_coords_list = [] items = split_batched_sparse_latent( shape_latent_512["samples"], shape_latent_512["coords"], coord_counts, ) - log_sparse_batch_trace("TRELLIS2_UPSAMPLE_INPUT_TRACE", items) decoder_dtype = next(decoder.parameters()).dtype - output_coord_rows = [] + final_coords_list = [] output_resolutions = [] + output_coord_counts = [] for batch_index, (feats_i, coords_i) in enumerate(items): feats_i = feats_i.to(device) coords_i = coords_i.to(device).clone() @@ -480,19 +451,14 @@ class Trellis2UpsampleCascade(IO.ComfyNode): final_coords_i = final_coords_i.clone() final_coords_i[:, 0] = batch_index final_coords_list.append(final_coords_i) - output_coord_rows.append(int(final_coords_i.shape[0])) output_resolutions.append(int(hr_resolution)) + output_coord_counts.append(int(final_coords_i.shape[0])) - print( - "TRELLIS2_UPSAMPLE_OUTPUT_TRACE", - { - "batch_size": len(final_coords_list), - "coord_rows": output_coord_rows, - "hr_resolution": output_resolutions, - }, - ) - - return IO.NodeOutput(torch.cat(final_coords_list, dim=0),) + 
return IO.NodeOutput({ + "coords": torch.cat(final_coords_list, dim=0), + "coord_counts": torch.tensor(output_coord_counts, dtype=torch.int64), + "resolutions": torch.tensor(output_resolutions, dtype=torch.int64), + },) dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) dino_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) @@ -568,7 +534,6 @@ class Trellis2Conditioning(IO.ComfyNode): cond_512_list = [] cond_1024_list = [] - composite_trace = [] for b in range(batch_size): item_image = image[b] @@ -623,14 +588,6 @@ class Trellis2Conditioning(IO.ComfyNode): # to match trellis2 code (quantize -> dequantize) composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8) - composite_trace.append( - { - "sample_index": int(b), - "shape": list(composite_uint8.shape), - "sum": int(composite_uint8.sum(dtype=np.int64)), - "prefix": composite_uint8[:2, :2, :].reshape(-1).tolist(), - } - ) cropped_pil = Image.fromarray(composite_uint8) @@ -642,19 +599,6 @@ class Trellis2Conditioning(IO.ComfyNode): cond_1024_batched = torch.cat(cond_1024_list, dim=0) neg_cond_batched = torch.zeros_like(cond_512_batched) neg_embeds_batched = torch.zeros_like(cond_1024_batched) - print( - "TRELLIS2_CONDITIONING_TRACE", - { - "batch_size": int(batch_size), - "cond_512_norms": [float(v) for v in cond_512_batched.square().sum(dim=(1, 2)).detach().cpu().tolist()], - "cond_512_sums": [float(v) for v in cond_512_batched.sum(dim=(1, 2)).detach().cpu().tolist()], - "cond_512_prefix": cond_512_batched[:, 0, :8].detach().cpu().tolist(), - "cond_1024_norms": [float(v) for v in cond_1024_batched.square().sum(dim=(1, 2)).detach().cpu().tolist()], - "cond_1024_sums": [float(v) for v in cond_1024_batched.sum(dim=(1, 2)).detach().cpu().tolist()], - "cond_1024_prefix": cond_1024_batched[:, 0, :8].detach().cpu().tolist(), - "composite_trace": composite_trace, - }, - ) positive = [[cond_512_batched, {"embeds": cond_1024_batched}]] negative = [[neg_cond_batched, {"embeds": 
neg_embeds_batched}]] @@ -680,12 +624,20 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): def execute(cls, structure_or_coords, model): # to accept the upscaled coords is_512_pass = False + coord_counts = None + coord_resolutions = None if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: decoded = structure_or_coords.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() is_512_pass = True + elif isinstance(structure_or_coords, dict): + coords = structure_or_coords["coords"].int() + coord_counts = structure_or_coords.get("coord_counts") + coord_resolutions = structure_or_coords.get("resolutions") + is_512_pass = False + elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: coords = structure_or_coords.int() is_512_pass = False @@ -693,7 +645,15 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): else: raise ValueError(f"Invalid input to EmptyShapeLatent: {type(structure_or_coords)}") in_channels = 32 - batch_size, coord_counts, max_tokens = infer_batched_coord_layout(coords) + batch_size, inferred_coord_counts, max_tokens = infer_batched_coord_layout(coords) + if coord_counts is not None: + coord_counts = coord_counts.to(dtype=torch.int64, device=coords.device) + if coord_counts.shape != inferred_coord_counts.shape or not torch.equal(coord_counts, inferred_coord_counts): + raise ValueError( + f"Trellis2 coord_counts metadata {coord_counts.tolist()} does not match coords layout {inferred_coord_counts.tolist()}" + ) + else: + coord_counts = inferred_coord_counts if batch_size == 1: coord_counts = None latent = torch.randn(1, in_channels, coords.shape[0], 1) @@ -706,17 +666,6 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): generator.set_state(base_state.clone()) latent_i = torch.randn(1, in_channels, count, 1, generator=generator) latent[i, :, :count] = latent_i[0] - if coords.shape[0] > 1000: - norms = [float(v) for v in latent.squeeze(-1).square().sum(dim=(1, 
2)).detach().cpu().tolist()] - print( - "TRELLIS2_EMPTY_SHAPE_TRACE", - { - "coords_rows": int(coords.shape[0]), - "batch_size": int(batch_size), - "coord_counts": coord_counts.tolist() if coord_counts is not None else None, - "latent_norms": norms, - }, - ) if coord_counts is not None: latent.trellis_coord_counts = coord_counts.clone() model = model.clone() @@ -729,6 +678,8 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"]["coords"] = coords if coord_counts is not None: model.model_options["transformer_options"]["coord_counts"] = coord_counts + if coord_resolutions is not None: + model.model_options["transformer_options"]["coord_resolutions"] = coord_resolutions if is_512_pass: model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512" else: @@ -736,6 +687,8 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): output = {"samples": latent, "coords": coords, "type": "trellis2"} if coord_counts is not None: output["coord_counts"] = coord_counts + if coord_resolutions is not None: + output["coord_resolutions"] = coord_resolutions output["batch_index"] = [0] * batch_size return IO.NodeOutput(output, model) @@ -759,15 +712,28 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_or_coords, shape_latent, model): channels = 32 + coord_counts = None if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: decoded = structure_or_coords.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + elif isinstance(structure_or_coords, dict): + coords = structure_or_coords["coords"].int() + coord_counts = structure_or_coords.get("coord_counts") + elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: coords = structure_or_coords.int() shape_latent = shape_latent["samples"] - batch_size, coord_counts, max_tokens = infer_batched_coord_layout(coords) + batch_size, inferred_coord_counts, max_tokens = 
infer_batched_coord_layout(coords) + if coord_counts is not None: + coord_counts = coord_counts.to(dtype=torch.int64, device=coords.device) + if coord_counts.shape != inferred_coord_counts.shape or not torch.equal(coord_counts, inferred_coord_counts): + raise ValueError( + f"Trellis2 coord_counts metadata {coord_counts.tolist()} does not match coords layout {inferred_coord_counts.tolist()}" + ) + else: + coord_counts = inferred_coord_counts if shape_latent.ndim == 4: if shape_latent.shape[0] != batch_size: raise ValueError( @@ -791,19 +757,6 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): generator.set_state(base_state.clone()) latent_i = torch.randn(1, channels, count, 1, generator=generator) latent[i, :, :count] = latent_i[0] - if coords.shape[0] > 1000: - norms = [float(v) for v in latent.squeeze(-1).square().sum(dim=(1, 2)).detach().cpu().tolist()] - shape_norms = [float(v) for v in shape_latent.square().sum(dim=(1, 2)).detach().cpu().tolist()] if shape_latent.ndim == 3 else None - print( - "TRELLIS2_EMPTY_TEXTURE_TRACE", - { - "coords_rows": int(coords.shape[0]), - "batch_size": int(batch_size), - "coord_counts": coord_counts.tolist() if coord_counts is not None else None, - "latent_norms": norms, - "shape_latent_norms": shape_norms, - }, - ) if coord_counts is not None: latent.trellis_coord_counts = coord_counts.clone() model = model.clone() @@ -842,11 +795,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): def execute(cls, batch_size): in_channels = 8 resolution = 16 - generator = torch.Generator(device="cpu") - generator.manual_seed(11426) - latent = torch.randn(1, in_channels, resolution, resolution, resolution, generator=generator).repeat(batch_size, 1, 1, 1, 1) - norms = [float(v) for v in latent.square().sum(dim=(1, 2, 3, 4)).detach().cpu().tolist()] - print("TRELLIS2_EMPTY_STRUCTURE_TRACE", {"batch_size": int(batch_size), "latent_norms": norms}) + latent = torch.randn(1, in_channels, resolution, resolution, resolution).repeat(batch_size, 1, 1, 
1, 1) output = {"samples": latent, "type": "trellis2"} if batch_size > 1: output["batch_index"] = [0] * batch_size From 7d98cc1305612becdf0baa734997f84eb296a49d Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 14:29:07 -0500 Subject: [PATCH 76/93] Fix Trellis seeded sparse batch semantics --- comfy/ldm/trellis2/model.py | 183 +++++++----------- comfy/sample.py | 34 ++-- comfy_extras/nodes_trellis2.py | 161 ++++++++++----- .../comfy_extras_test/nodes_trellis2_test.py | 83 ++++++++ tests-unit/comfy_test/sample_test.py | 47 +++++ 5 files changed, 333 insertions(+), 175 deletions(-) create mode 100644 tests-unit/comfy_test/sample_test.py diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index f61c50629..15939e5c6 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -813,6 +813,14 @@ class Trellis2(nn.Module): shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] dense_out = None + cond_or_uncond = transformer_options.get("cond_or_uncond") or [] + + def cond_group_indices(batch_groups): + if len(cond_or_uncond) == batch_groups: + cond_groups = [i for i, marker in enumerate(cond_or_uncond) if marker == 0] + if len(cond_groups) > 0: + return cond_groups + return [batch_groups - 1] if not_struct_mode: orig_bsz = x.shape[0] @@ -820,10 +828,17 @@ class Trellis2(nn.Module): logical_batch = coord_counts.shape[0] if coord_counts is not None else 1 if rule and orig_bsz > logical_batch: - half = orig_bsz // 2 - x_eval = x[half:] - t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep - c_eval = cond + batch_groups = orig_bsz // logical_batch + selected_groups = cond_group_indices(batch_groups) + x_groups = x.reshape(batch_groups, logical_batch, *x.shape[1:]) + x_eval = x_groups[selected_groups].reshape(-1, *x.shape[1:]) + if timestep.shape[0] > 1: + t_groups = 
timestep.reshape(batch_groups, logical_batch, *timestep.shape[1:]) + t_eval = t_groups[selected_groups].reshape(-1, *timestep.shape[1:]) + else: + t_eval = timestep + c_groups = context.reshape(batch_groups, logical_batch, *context.shape[1:]) + c_eval = c_groups[selected_groups].reshape(-1, *context.shape[1:]) else: x_eval = x t_eval = timestep @@ -838,113 +853,62 @@ class Trellis2(nn.Module): raise ValueError( f"Trellis2 coord_counts batch {logical_batch} doesn't divide latent batch {B}" ) + batch_ids = coords[:, 0].to(torch.int64) + order = torch.argsort(batch_ids, stable=True) + sorted_coords = coords.index_select(0, order) + sorted_batch_ids = batch_ids.index_select(0, order) + offsets = coord_counts.cumsum(0) - coord_counts + coords_by_batch = [] + for i in range(logical_batch): + count = int(coord_counts[i].item()) + start = int(offsets[i].item()) + coords_i = sorted_coords[start:start + count] + ids_i = sorted_batch_ids[start:start + count] + if coords_i.shape[0] != count or not torch.all(ids_i == i): + raise ValueError( + f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}" + ) + coords_by_batch.append(coords_i) repeat_factor = B // logical_batch sparse_outs = [] active_coord_counts = [] - if mode == "shape_generation" and repeat_factor > 1: - grouped_outs = [] - grouped_counts = [] + for rep in range(repeat_factor): for i in range(logical_batch): + out_index = rep * logical_batch + i count = int(coord_counts[i].item()) - coords_i = coords[coords[:, 0] == i].clone() - if coords_i.shape[0] != count: - raise ValueError( - f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}" - ) + coords_i = coords_by_batch[i].clone() + coords_i[:, 0] = 0 + feats_i = x_eval[out_index, :count].clone() + x_st_i = SparseTensor(feats=feats_i, coords=coords_i.to(torch.int32)) + t_i = t_eval[out_index].unsqueeze(0).clone() if t_eval.shape[0] > 1 else t_eval + c_i = c_eval[out_index].unsqueeze(0).clone() if c_eval.shape[0] > 1 
else c_eval - feat_batches = [] - coord_batches = [] - index_batch = [] - for rep in range(repeat_factor): - out_index = rep * logical_batch + i - feat_batches.append(x_eval[out_index, :count]) - coords_rep = coords_i.clone() - coords_rep[:, 0] = rep - coord_batches.append(coords_rep) - index_batch.append(out_index) - - x_st_i = SparseTensor( - feats=torch.cat(feat_batches, dim=0), - coords=torch.cat(coord_batches, dim=0).to(torch.int32), - ) - index_tensor = torch.tensor(index_batch, device=x_eval.device, dtype=torch.long) - if t_eval.shape[0] > 1: - t_i = t_eval.index_select(0, index_tensor) - else: - t_i = t_eval - if c_eval.shape[0] > 1: - c_i = c_eval.index_select(0, index_tensor) - else: - c_i = c_eval - - if is_512_run: - sparse_out = self.img2shape_512(x_st_i, t_i, c_i) - else: - sparse_out = self.img2shape(x_st_i, t_i, c_i) - - feats_group, coords_group = sparse_out.to_tensor_list() - if len(feats_group) != repeat_factor: - raise ValueError( - f"Trellis2 expected {repeat_factor} sparse output groups for batch {i}, got {len(feats_group)}" - ) - for rep, (feats_rep, coords_rep) in enumerate(zip(feats_group, coords_group)): - if feats_rep.shape[0] != count: - raise ValueError( - f"Trellis2 sparse output rows for batch {i} rep {rep} expected {count}, got {feats_rep.shape[0]}" - ) - if coords_rep.shape[0] != count: - raise ValueError( - f"Trellis2 sparse output coords for batch {i} rep {rep} expected {count}, got {coords_rep.shape[0]}" - ) - grouped_outs.append(feats_group) - grouped_counts.append(count) - - for rep in range(repeat_factor): - for i in range(logical_batch): - sparse_outs.append(grouped_outs[i][rep]) - active_coord_counts.append(grouped_counts[i]) - else: - for rep in range(repeat_factor): - for i in range(logical_batch): - out_index = rep * logical_batch + i - count = int(coord_counts[i].item()) - coords_i = coords[coords[:, 0] == i].clone() - if coords_i.shape[0] != count: - raise ValueError( - f"Trellis2 coords rows for batch {i} expected 
{count}, got {coords_i.shape[0]}" - ) - coords_i[:, 0] = 0 - feats_i = x_eval[out_index, :count] - x_st_i = SparseTensor(feats=feats_i, coords=coords_i.to(torch.int32)) - t_i = t_eval[out_index].unsqueeze(0) if t_eval.shape[0] > 1 else t_eval - c_i = c_eval[out_index].unsqueeze(0) if c_eval.shape[0] > 1 else c_eval - - if mode == "shape_generation": - if is_512_run: - sparse_out = self.img2shape_512(x_st_i, t_i, c_i) - else: - sparse_out = self.img2shape(x_st_i, t_i, c_i) + if mode == "shape_generation": + if is_512_run: + sparse_out = self.img2shape_512(x_st_i, t_i, c_i) else: - slat = transformer_options.get("shape_slat") - if slat is None: - raise ValueError("shape_slat can't be None") - if slat.ndim == 3: - if slat.shape[0] != logical_batch: - raise ValueError( - f"shape_slat batch {slat.shape[0]} doesn't match coord_counts batch {logical_batch}" - ) - if slat.shape[1] < count: - raise ValueError( - f"shape_slat tokens {slat.shape[1]} can't cover coord count {count} for batch {i}" - ) - slat_feats = slat[i, :count].to(x_st_i.device) - else: - slat_feats = slat[:count].to(x_st_i.device) - x_st_i = x_st_i.replace(feats=torch.cat([x_st_i.feats, slat_feats], dim=-1)) - sparse_out = self.shape2txt(x_st_i, t_i, c_i) + sparse_out = self.img2shape(x_st_i, t_i, c_i) + else: + slat = transformer_options.get("shape_slat") + if slat is None: + raise ValueError("shape_slat can't be None") + if slat.ndim == 3: + if slat.shape[0] != logical_batch: + raise ValueError( + f"shape_slat batch {slat.shape[0]} doesn't match coord_counts batch {logical_batch}" + ) + if slat.shape[1] < count: + raise ValueError( + f"shape_slat tokens {slat.shape[1]} can't cover coord count {count} for batch {i}" + ) + slat_feats = slat[i, :count].to(x_st_i.device) + else: + slat_feats = slat[:count].to(x_st_i.device) + x_st_i = x_st_i.replace(feats=torch.cat([x_st_i.feats, slat_feats], dim=-1)) + sparse_out = self.shape2txt(x_st_i, t_i, c_i) - sparse_outs.append(sparse_out.feats) - 
active_coord_counts.append(count) + sparse_outs.append(sparse_out.feats) + active_coord_counts.append(count) out_channels = sparse_outs[0].shape[-1] padded = sparse_outs[0].new_zeros((B, N, out_channels)) @@ -1022,7 +986,6 @@ class Trellis2(nn.Module): out = self.shape2txt(x_st, t_eval, c_eval) else: # structure orig_bsz = x.shape[0] - cond_or_uncond = transformer_options.get("cond_or_uncond") or [] batch_groups = len(cond_or_uncond) if len(cond_or_uncond) > 0 and orig_bsz % len(cond_or_uncond) == 0 else 1 logical_batch = orig_bsz // batch_groups if logical_batch > 1: @@ -1034,23 +997,19 @@ class Trellis2(nn.Module): c_groups = context.reshape(batch_groups, logical_batch, *context.shape[1:]) if shape_rule and batch_groups > 1: - selected_group_indices = [batch_groups - 1] + selected_group_indices = cond_group_indices(batch_groups) else: selected_group_indices = list(range(batch_groups)) out_groups = [] for sample_index in range(logical_batch): if shape_rule and batch_groups > 1: - half = orig_bsz // 2 - x_i = x[half + sample_index].unsqueeze(0) + x_i = x_groups[selected_group_indices, sample_index] if timestep.shape[0] > 1: - t_i = timestep[half + sample_index].unsqueeze(0) + t_i = t_groups[selected_group_indices, sample_index] else: t_i = timestep - if cond.shape[0] > 1: - c_i = cond[sample_index].unsqueeze(0) - else: - c_i = cond + c_i = c_groups[selected_group_indices, sample_index] else: x_i = x_groups[selected_group_indices, sample_index] if timestep.shape[0] > 1: diff --git a/comfy/sample.py b/comfy/sample.py index 7251aa799..6fba221ed 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -15,32 +15,26 @@ def prepare_noise_inner(latent_image, generator, noise_inds=None): else: noise_inds = np.asarray(noise_inds, dtype=np.int64) + base_seed = int(generator.initial_seed()) unique_inds = np.unique(noise_inds) - first_indices = {int(unique_index): int(np.flatnonzero(noise_inds == unique_index)[0]) for unique_index in unique_inds.tolist()} - index_states = {} - 
for unique_index in sorted(first_indices): - index_states[unique_index] = generator.get_state().clone() - count = int(coord_counts[first_indices[unique_index]].item()) - torch.randn( - [1, latent_image.size(1), count, latent_image.size(3)], - dtype=torch.float32, - layout=latent_image.layout, - generator=generator, - device="cpu", - ) - - for batch_index, noise_index in enumerate(noise_inds.tolist()): - count = int(coord_counts[batch_index].item()) + sample_noises = {} + for noise_index in unique_inds.tolist(): + rows = np.flatnonzero(noise_inds == noise_index) + max_count = max(int(coord_counts[row].item()) for row in rows.tolist()) local_generator = torch.Generator(device="cpu") - local_generator.set_state(index_states[int(noise_index)].clone()) - sample_noise = torch.randn( - [1, latent_image.size(1), count, latent_image.size(3)], + local_generator.manual_seed(base_seed + int(noise_index)) + sample_noises[int(noise_index)] = torch.randn( + [1, latent_image.size(1), max_count, latent_image.size(3)], dtype=torch.float32, layout=latent_image.layout, generator=local_generator, device="cpu", ) - noise[batch_index:batch_index + 1, :, :count, :] = sample_noise + + for batch_index, noise_index in enumerate(noise_inds.tolist()): + count = int(coord_counts[batch_index].item()) + sample_noise = sample_noises[int(noise_index)] + noise[batch_index:batch_index + 1, :, :count, :] = sample_noise[:, :, :count, :] return noise.to(dtype=latent_image.dtype) if noise_inds is None: @@ -76,6 +70,8 @@ def prepare_noise(latent_image, seed, noise_inds=None): def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None): if latent_image.is_nested: return latent_image + if getattr(latent_image, "trellis_skip_empty_fix", False): + return latent_image latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels if torch.count_nonzero(latent_image) == 0: if latent_format.latent_channels != 
latent_image.shape[1]: diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 621cc9586..6556ed176 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -115,18 +115,54 @@ def infer_batched_coord_layout(coords): return batch_size, counts, max_tokens +def split_batched_coords(coords, coord_counts): + batch_ids = coords[:, 0].to(torch.int64) + order = torch.argsort(batch_ids, stable=True) + sorted_coords = coords.index_select(0, order) + sorted_batch_ids = batch_ids.index_select(0, order) + + offsets = coord_counts.cumsum(0) - coord_counts + items = [] + for i in range(coord_counts.shape[0]): + count = int(coord_counts[i].item()) + start = int(offsets[i].item()) + coords_i = sorted_coords[start:start + count] + ids_i = sorted_batch_ids[start:start + count] + if coords_i.shape[0] != count or not torch.all(ids_i == i): + raise ValueError(f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}") + items.append(coords_i) + return items + + +def normalize_batch_index(batch_index): + if batch_index is None: + return None + if isinstance(batch_index, int): + return [int(batch_index)] + return list(batch_index) + + +def resolve_sample_indices(batch_index, batch_size): + sample_indices = normalize_batch_index(batch_index) + if sample_indices is None: + return list(range(batch_size)) + if len(sample_indices) != batch_size: + raise ValueError( + f"Trellis2 batch_index length {len(sample_indices)} does not match batch size {batch_size}" + ) + return sample_indices + + def flatten_batched_sparse_latent(samples, coords, coord_counts): samples = samples.squeeze(-1).transpose(1, 2) if coord_counts is None: return samples.reshape(-1, samples.shape[-1]), coords + coords_items = split_batched_coords(coords, coord_counts) feat_list = [] coord_list = [] - for i in range(coord_counts.shape[0]): + for i, coords_i in enumerate(coords_items): count = int(coord_counts[i].item()) - coords_i = coords[coords[:, 0] 
== i] - if coords_i.shape[0] != count: - raise ValueError(f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}") feat_list.append(samples[i, :count]) coord_list.append(coords_i) @@ -138,12 +174,10 @@ def split_batched_sparse_latent(samples, coords, coord_counts): if coord_counts is None: return [(samples.reshape(-1, samples.shape[-1]), coords)] + coords_items = split_batched_coords(coords, coord_counts) items = [] - for i in range(coord_counts.shape[0]): + for i, coords_i in enumerate(coords_items): count = int(coord_counts[i].item()) - coords_i = coords[coords[:, 0] == i] - if coords_i.shape[0] != count: - raise ValueError(f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}") items.append((samples[i, :count], coords_i)) return items @@ -345,6 +379,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): load_device = comfy.model_management.get_torch_device() offload_device = comfy.model_management.vae_offload_device() decoder = decoder.to(load_device) + batch_index = normalize_batch_index(samples.get("batch_index")) samples = samples["samples"] samples = samples.to(load_device) if samples.shape[0] > 1: @@ -361,6 +396,8 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): ratio = current_res // resolution decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5 out = Types.VOXEL(decoded.squeeze(1).float()) + if batch_index is not None: + out.batch_index = normalize_batch_index(batch_index) return IO.NodeOutput(out) class Trellis2UpsampleCascade(IO.ComfyNode): @@ -386,6 +423,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode): comfy.model_management.load_model_gpu(vae.patcher) coord_counts = shape_latent_512.get("coord_counts") + batch_index = normalize_batch_index(shape_latent_512.get("batch_index")) decoder = vae.first_stage_model.shape_dec lr_resolution = 512 target_resolution = int(target_resolution) @@ -424,40 +462,48 @@ class Trellis2UpsampleCascade(IO.ComfyNode): ) decoder_dtype = 
next(decoder.parameters()).dtype - final_coords_list = [] - output_resolutions = [] - output_coord_counts = [] - for batch_index, (feats_i, coords_i) in enumerate(items): + sample_hr_coords = [] + for feats_i, coords_i in items: feats_i = feats_i.to(device) coords_i = coords_i.to(device).clone() coords_i[:, 0] = 0 slat_i = shape_norm(feats_i, coords_i) slat_i.feats = slat_i.feats.to(decoder_dtype) - hr_coords_i = decoder.upsample(slat_i, upsample_times=4) + sample_hr_coords.append(decoder.upsample(slat_i, upsample_times=4)) - hr_resolution = target_resolution - while True: + hr_resolution = target_resolution + while True: + exceeds_limit = False + for hr_coords_i in sample_hr_coords: quant_coords_i = torch.cat([ hr_coords_i[:, :1], ((hr_coords_i[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), ], dim=1) - final_coords_i = quant_coords_i.unique(dim=0) - num_tokens = final_coords_i.shape[0] - - if num_tokens < max_tokens or hr_resolution <= 1024: + if quant_coords_i.unique(dim=0).shape[0] >= max_tokens: + exceeds_limit = True break - hr_resolution -= 128 + if not exceeds_limit or hr_resolution <= 1024: + break + hr_resolution -= 128 + final_coords_list = [] + output_coord_counts = [] + for sample_offset, hr_coords_i in enumerate(sample_hr_coords): + quant_coords_i = torch.cat([ + hr_coords_i[:, :1], + ((hr_coords_i[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), + ], dim=1) + final_coords_i = quant_coords_i.unique(dim=0) final_coords_i = final_coords_i.clone() - final_coords_i[:, 0] = batch_index + final_coords_i[:, 0] = sample_offset final_coords_list.append(final_coords_i) - output_resolutions.append(int(hr_resolution)) output_coord_counts.append(int(final_coords_i.shape[0])) return IO.NodeOutput({ "coords": torch.cat(final_coords_list, dim=0), "coord_counts": torch.tensor(output_coord_counts, dtype=torch.int64), - "resolutions": torch.tensor(output_resolutions, dtype=torch.int64), + "resolutions": torch.full((len(final_coords_list),), 
int(hr_resolution), dtype=torch.int64), + "batch_index": normalize_batch_index(batch_index), },) dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) @@ -612,7 +658,8 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): category="latent/3d", inputs=[ IO.AnyType.Input("structure_or_coords"), - IO.Model.Input("model") + IO.Model.Input("model"), + IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), ], outputs=[ IO.Latent.Output(), @@ -621,21 +668,24 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_or_coords, model): + def execute(cls, structure_or_coords, model, seed): # to accept the upscaled coords is_512_pass = False coord_counts = None coord_resolutions = None + batch_index = None if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: decoded = structure_or_coords.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() is_512_pass = True + batch_index = normalize_batch_index(getattr(structure_or_coords, "batch_index", None)) elif isinstance(structure_or_coords, dict): coords = structure_or_coords["coords"].int() coord_counts = structure_or_coords.get("coord_counts") coord_resolutions = structure_or_coords.get("resolutions") + batch_index = normalize_batch_index(structure_or_coords.get("batch_index")) is_512_pass = False elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: @@ -655,15 +705,17 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): else: coord_counts = inferred_coord_counts if batch_size == 1: - coord_counts = None - latent = torch.randn(1, in_channels, coords.shape[0], 1) + sample_indices = normalize_batch_index(batch_index) or [0] + generator = torch.Generator(device="cpu") + generator.manual_seed(int(seed) + int(sample_indices[0])) + latent = torch.randn(1, in_channels, coords.shape[0], 1, generator=generator) else: + sample_indices = resolve_sample_indices(batch_index, batch_size) latent = 
torch.zeros(batch_size, in_channels, max_tokens, 1) - base_state = torch.random.get_rng_state() - for i in range(batch_size): + for i, sample_index in enumerate(sample_indices): count = int(coord_counts[i].item()) generator = torch.Generator(device="cpu") - generator.set_state(base_state.clone()) + generator.manual_seed(int(seed) + int(sample_index)) latent_i = torch.randn(1, in_channels, count, 1, generator=generator) latent[i, :, :count] = latent_i[0] if coord_counts is not None: @@ -685,11 +737,12 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): else: model.model_options["transformer_options"]["generation_mode"] = "shape_generation" output = {"samples": latent, "coords": coords, "type": "trellis2"} + if batch_index is not None: + output["batch_index"] = normalize_batch_index(batch_index) if coord_counts is not None: output["coord_counts"] = coord_counts if coord_resolutions is not None: output["coord_resolutions"] = coord_resolutions - output["batch_index"] = [0] * batch_size return IO.NodeOutput(output, model) class EmptyTextureLatentTrellis2(IO.ComfyNode): @@ -701,7 +754,8 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): inputs=[ IO.Voxel.Input("structure_or_coords"), IO.Latent.Input("shape_latent"), - IO.Model.Input("model") + IO.Model.Input("model"), + IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), ], outputs=[ IO.Latent.Output(), @@ -710,20 +764,24 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_or_coords, shape_latent, model): + def execute(cls, structure_or_coords, shape_latent, model, seed): channels = 32 coord_counts = None + batch_index = None if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: decoded = structure_or_coords.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + batch_index = normalize_batch_index(getattr(structure_or_coords, "batch_index", None)) elif isinstance(structure_or_coords, dict): coords = 
structure_or_coords["coords"].int() coord_counts = structure_or_coords.get("coord_counts") + batch_index = normalize_batch_index(structure_or_coords.get("batch_index")) elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: coords = structure_or_coords.int() + shape_batch_index = normalize_batch_index(shape_latent.get("batch_index")) shape_latent = shape_latent["samples"] batch_size, inferred_coord_counts, max_tokens = infer_batched_coord_layout(coords) if coord_counts is not None: @@ -746,19 +804,23 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ) if batch_size == 1: - coord_counts = None - latent = torch.randn(1, channels, coords.shape[0], 1) + sample_indices = normalize_batch_index(batch_index) or [0] + generator = torch.Generator(device="cpu") + generator.manual_seed(int(seed) + int(sample_indices[0])) + latent = torch.randn(1, channels, coords.shape[0], 1, generator=generator) else: + sample_indices = resolve_sample_indices(batch_index, batch_size) latent = torch.zeros(batch_size, channels, max_tokens, 1) - base_state = torch.random.get_rng_state() - for i in range(batch_size): + for i, sample_index in enumerate(sample_indices): count = int(coord_counts[i].item()) generator = torch.Generator(device="cpu") - generator.set_state(base_state.clone()) + generator.manual_seed(int(seed) + int(sample_index)) latent_i = torch.randn(1, channels, count, 1, generator=generator) latent[i, :, :count] = latent_i[0] if coord_counts is not None: latent.trellis_coord_counts = coord_counts.clone() + if batch_index is None: + batch_index = shape_batch_index model = model.clone() model.model_options = model.model_options.copy() if "transformer_options" in model.model_options: @@ -772,9 +834,10 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"]["generation_mode"] = "texture_generation" model.model_options["transformer_options"]["shape_slat"] = shape_latent output = {"samples": latent, "coords": coords, 
"type": "trellis2"} + if batch_index is not None: + output["batch_index"] = normalize_batch_index(batch_index) if coord_counts is not None: output["coord_counts"] = coord_counts - output["batch_index"] = [0] * batch_size return IO.NodeOutput(output, model) @@ -786,19 +849,29 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): category="latent/3d", inputs=[ IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), + IO.Int.Input("batch_index_start", default=0, min=0, max=4096, tooltip="Starting sample index for per-sample sampler noise."), + IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), ], outputs=[ IO.Latent.Output(), ] ) @classmethod - def execute(cls, batch_size): + def execute(cls, batch_size, batch_index_start, seed): in_channels = 8 resolution = 16 - latent = torch.randn(1, in_channels, resolution, resolution, resolution).repeat(batch_size, 1, 1, 1, 1) - output = {"samples": latent, "type": "trellis2"} - if batch_size > 1: - output["batch_index"] = [0] * batch_size + sample_indices = [int(batch_index_start) + i for i in range(batch_size)] + latent = torch.zeros(batch_size, in_channels, resolution, resolution, resolution) + for i, sample_index in enumerate(sample_indices): + generator = torch.Generator(device="cpu") + generator.manual_seed(int(seed) + sample_index) + latent[i] = torch.randn(1, in_channels, resolution, resolution, resolution, generator=generator)[0] + output = { + "samples": latent, + "type": "trellis2", + } + if batch_size > 1 or batch_index_start != 0: + output["batch_index"] = sample_indices return IO.NodeOutput(output) def simplify_fn(vertices, faces, colors=None, target=100000): diff --git a/tests-unit/comfy_extras_test/nodes_trellis2_test.py b/tests-unit/comfy_extras_test/nodes_trellis2_test.py index 920eca471..95f64d031 100644 --- a/tests-unit/comfy_extras_test/nodes_trellis2_test.py +++ b/tests-unit/comfy_extras_test/nodes_trellis2_test.py @@ -123,5 +123,88 @@ class 
TestRunConditioningRestore(unittest.TestCase): self.assertFalse(hasattr(inner_model, "image_size")) +class DummyCloneModel: + def __init__(self): + self.model_options = {} + + def clone(self): + cloned = DummyCloneModel() + cloned.model_options = self.model_options.copy() + return cloned + + +class TestTrellisBatchSemantics(unittest.TestCase): + def test_empty_structure_latent_is_deterministic_and_propagates_sample_indices(self): + batch_output = nodes_trellis2.EmptyStructureLatentTrellis2.execute(2, 0, 17)[0] + single_output = nodes_trellis2.EmptyStructureLatentTrellis2.execute(1, 5, 17)[0] + + expected_batch = torch.zeros(2, 8, 16, 16, 16) + expected_batch[0] = torch.randn(1, 8, 16, 16, 16, generator=torch.Generator(device="cpu").manual_seed(17))[0] + expected_batch[1] = torch.randn(1, 8, 16, 16, 16, generator=torch.Generator(device="cpu").manual_seed(18))[0] + expected_single = torch.randn(1, 8, 16, 16, 16, generator=torch.Generator(device="cpu").manual_seed(22)) + + self.assertTrue(torch.equal(batch_output["samples"], expected_batch)) + self.assertEqual(batch_output["batch_index"], [0, 1]) + self.assertTrue(torch.equal(single_output["samples"], expected_single)) + self.assertEqual(single_output["batch_index"], [5]) + + def test_empty_shape_latent_is_deterministic_and_propagates_batch_index(self): + coords = torch.tensor( + [ + [1, 5, 5, 5], + [0, 1, 1, 1], + [1, 6, 6, 6], + [0, 2, 2, 2], + [1, 7, 7, 7], + ], + dtype=torch.int32, + ) + structure = { + "coords": coords, + "coord_counts": torch.tensor([2, 3], dtype=torch.int64), + "batch_index": [4, 9], + } + + output, _ = nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 23) + + expected = torch.zeros(2, 32, 3, 1) + expected[0, :, :2, :] = torch.randn(1, 32, 2, 1, generator=torch.Generator(device="cpu").manual_seed(27))[0] + expected[1, :, :3, :] = torch.randn(1, 32, 3, 1, generator=torch.Generator(device="cpu").manual_seed(32))[0] + + self.assertTrue(torch.equal(output["samples"], 
expected)) + self.assertTrue(torch.equal(output["coord_counts"], torch.tensor([2, 3], dtype=torch.int64))) + self.assertEqual(output["batch_index"], [4, 9]) + + def test_empty_shape_latent_keeps_singleton_coord_counts(self): + structure = { + "coords": torch.tensor( + [ + [0, 1, 1, 1], + [0, 2, 2, 2], + ], + dtype=torch.int32, + ), + } + + output, _ = nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 11) + + self.assertTrue(torch.equal(output["coord_counts"], torch.tensor([2], dtype=torch.int64))) + + def test_flatten_batched_sparse_latent_validates_coord_counts(self): + samples = torch.zeros(2, 32, 3, 1) + coords = torch.tensor( + [ + [0, 1, 1, 1], + [1, 2, 2, 2], + [1, 3, 3, 3], + ], + dtype=torch.int32, + ) + coord_counts = torch.tensor([2, 1], dtype=torch.int64) + + with self.assertRaises(ValueError): + nodes_trellis2.flatten_batched_sparse_latent(samples, coords, coord_counts) + + if __name__ == "__main__": unittest.main() diff --git a/tests-unit/comfy_test/sample_test.py b/tests-unit/comfy_test/sample_test.py new file mode 100644 index 000000000..ad154aca8 --- /dev/null +++ b/tests-unit/comfy_test/sample_test.py @@ -0,0 +1,47 @@ +import unittest + +import torch + +import comfy.sample + + +class TestPrepareNoiseInnerTrellis(unittest.TestCase): + def test_coord_counts_noise_matches_per_index_prefix_draws(self): + latent = torch.zeros(2, 4, 5, 1) + latent.trellis_coord_counts = torch.tensor([3, 5], dtype=torch.int64) + + generator = torch.Generator(device="cpu") + generator.manual_seed(123) + noise = comfy.sample.prepare_noise_inner(latent, generator) + + expected = torch.zeros_like(noise, dtype=torch.float32) + row0 = torch.Generator(device="cpu") + row0.manual_seed(123) + expected[0, :, :3, :] = torch.randn(1, 4, 3, 1, generator=row0)[0] + row1 = torch.Generator(device="cpu") + row1.manual_seed(124) + expected[1] = torch.randn(1, 4, 5, 1, generator=row1)[0] + + self.assertTrue(torch.equal(noise.float(), expected)) + 
self.assertTrue(torch.equal(noise[0, :, 3:, :], torch.zeros_like(noise[0, :, 3:, :]))) + + def test_coord_counts_noise_inds_share_prefixes_for_duplicates(self): + latent = torch.zeros(2, 4, 5, 1) + latent.trellis_coord_counts = torch.tensor([3, 5], dtype=torch.int64) + + generator = torch.Generator(device="cpu") + generator.manual_seed(456) + noise = comfy.sample.prepare_noise_inner(latent, generator, noise_inds=[7, 7]) + + replay = torch.Generator(device="cpu") + replay.manual_seed(463) + expected1 = torch.randn(1, 4, 5, 1, generator=replay) + expected0 = expected1[:, :, :3, :] + + self.assertTrue(torch.equal(noise[0:1, :, :3, :], expected0)) + self.assertTrue(torch.equal(noise[1:2, :, :5, :], expected1)) + self.assertTrue(torch.equal(noise[0, :, 3:, :], torch.zeros_like(noise[0, :, 3:, :]))) + + +if __name__ == "__main__": + unittest.main() From 06661522d9fd884939edd2ab297e4c1f90bc2893 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 14:35:08 -0500 Subject: [PATCH 77/93] fix: issue 88 texture path resolution and gpu-only host conversion --- comfy_extras/nodes_trellis2.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 8121e261b..e00444f8c 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -111,8 +111,8 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): verts = mesh.vertices.to(device).squeeze(0) voxel_colors = voxel_colors.to(device) - voxel_pos_np = voxel_pos.numpy() - verts_np = verts.numpy() + voxel_pos_np = voxel_pos.cpu().numpy() + verts_np = verts.cpu().numpy() tree = scipy.spatial.cKDTree(voxel_pos_np) @@ -194,6 +194,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): IO.Latent.Input("samples"), IO.Vae.Input("vae"), IO.AnyType.Input("shape_subs"), + IO.Combo.Input("resolution", options=["512", "1024"], default="1024") ], outputs=[ IO.Mesh.Output("mesh"), @@ -201,9 +202,9 @@ class 
VaeDecodeTextureTrellis(IO.ComfyNode): ) @classmethod - def execute(cls, shape_mesh, samples, vae, shape_subs): + def execute(cls, shape_mesh, samples, vae, shape_subs, resolution): - resolution = 1024 + resolution = int(resolution) patcher = vae.patcher device = comfy.model_management.get_torch_device() comfy.model_management.load_model_gpu(patcher) From 7b95f7c4b05414d60e893c40a3f9d2b5b9e58732 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 14:39:15 -0500 Subject: [PATCH 78/93] fix: issue 88 unify texture paint color path on host --- comfy_extras/nodes_trellis2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index e00444f8c..6651ea72a 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -109,7 +109,7 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): # map voxels voxel_pos = voxel_coords.to(device).float() * voxel_size + origin verts = mesh.vertices.to(device).squeeze(0) - voxel_colors = voxel_colors.to(device) + voxel_colors = voxel_colors.cpu() voxel_pos_np = voxel_pos.cpu().numpy() verts_np = verts.cpu().numpy() From 55997759d837d97a9f8c861693c60d41b8ccea9c Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 14:44:45 -0500 Subject: [PATCH 79/93] fix: issue 88 make texture voxel query deterministic --- comfy_extras/nodes_trellis2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 6651ea72a..fe09ec7b9 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -117,7 +117,7 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): tree = scipy.spatial.cKDTree(voxel_pos_np) # nearest neighbour k=1 - _, nearest_idx_np = tree.query(verts_np, k=1, workers=-1) + _, nearest_idx_np = tree.query(verts_np, k=1, workers=1) nearest_idx = 
torch.from_numpy(nearest_idx_np).long() v_colors = voxel_colors[nearest_idx] From a752dd473642020b66b4171600154b4435d20638 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 14:46:23 -0500 Subject: [PATCH 80/93] Harden Trellis sparse metadata validation --- comfy/ldm/trellis2/model.py | 4 + comfy_extras/nodes_trellis2.py | 34 ++++++-- .../comfy_extras_test/nodes_trellis2_test.py | 77 +++++++++++++++++++ 3 files changed, 108 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 15939e5c6..7cf3e728e 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -853,6 +853,10 @@ class Trellis2(nn.Module): raise ValueError( f"Trellis2 coord_counts batch {logical_batch} doesn't divide latent batch {B}" ) + if int(coord_counts.sum().item()) != coords.shape[0]: + raise ValueError( + f"Trellis2 coord_counts total {int(coord_counts.sum().item())} does not match coords rows {coords.shape[0]}" + ) batch_ids = coords[:, 0].to(torch.int64) order = torch.argsort(batch_ids, stable=True) sorted_coords = coords.index_select(0, order) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 6556ed176..ce184a946 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -105,6 +105,8 @@ def infer_batched_coord_layout(coords): raise ValueError("Trellis2 coords can't be empty") batch_ids = coords[:, 0].to(torch.int64) + if (batch_ids < 0).any(): + raise ValueError(f"Trellis2 batch ids must be non-negative, got {batch_ids.unique(sorted=True).tolist()}") batch_size = int(batch_ids.max().item()) + 1 counts = torch.bincount(batch_ids, minlength=batch_size) @@ -116,6 +118,15 @@ def infer_batched_coord_layout(coords): def split_batched_coords(coords, coord_counts): + if coord_counts.ndim != 1: + raise ValueError(f"Trellis2 coord_counts must be 1D, got shape {tuple(coord_counts.shape)}") + if (coord_counts < 0).any(): + raise ValueError(f"Trellis2 
coord_counts must be non-negative, got {coord_counts.tolist()}") + if int(coord_counts.sum().item()) != coords.shape[0]: + raise ValueError( + f"Trellis2 coord_counts total {int(coord_counts.sum().item())} does not match coords rows {coords.shape[0]}" + ) + batch_ids = coords[:, 0].to(torch.int64) order = torch.argsort(batch_ids, stable=True) sorted_coords = coords.index_select(0, order) @@ -153,6 +164,17 @@ def resolve_sample_indices(batch_index, batch_size): return sample_indices +def resolve_singleton_sample_index(batch_index): + sample_indices = normalize_batch_index(batch_index) + if sample_indices is None: + return 0 + if len(sample_indices) != 1: + raise ValueError( + f"Trellis2 batch_index must be an int or single-element iterable for singleton coords, got {sample_indices}" + ) + return int(sample_indices[0]) + + def flatten_batched_sparse_latent(samples, coords, coord_counts): samples = samples.squeeze(-1).transpose(1, 2) if coord_counts is None: @@ -705,9 +727,9 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): else: coord_counts = inferred_coord_counts if batch_size == 1: - sample_indices = normalize_batch_index(batch_index) or [0] + sample_index = resolve_singleton_sample_index(batch_index) generator = torch.Generator(device="cpu") - generator.manual_seed(int(seed) + int(sample_indices[0])) + generator.manual_seed(int(seed) + sample_index) latent = torch.randn(1, in_channels, coords.shape[0], 1, generator=generator) else: sample_indices = resolve_sample_indices(batch_index, batch_size) @@ -730,8 +752,6 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"]["coords"] = coords if coord_counts is not None: model.model_options["transformer_options"]["coord_counts"] = coord_counts - if coord_resolutions is not None: - model.model_options["transformer_options"]["coord_resolutions"] = coord_resolutions if is_512_pass: model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512" else: @@ -742,7 
+762,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): if coord_counts is not None: output["coord_counts"] = coord_counts if coord_resolutions is not None: - output["coord_resolutions"] = coord_resolutions + output["resolutions"] = coord_resolutions return IO.NodeOutput(output, model) class EmptyTextureLatentTrellis2(IO.ComfyNode): @@ -804,9 +824,9 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ) if batch_size == 1: - sample_indices = normalize_batch_index(batch_index) or [0] + sample_index = resolve_singleton_sample_index(batch_index) generator = torch.Generator(device="cpu") - generator.manual_seed(int(seed) + int(sample_indices[0])) + generator.manual_seed(int(seed) + sample_index) latent = torch.randn(1, channels, coords.shape[0], 1, generator=generator) else: sample_indices = resolve_sample_indices(batch_index, batch_size) diff --git a/tests-unit/comfy_extras_test/nodes_trellis2_test.py b/tests-unit/comfy_extras_test/nodes_trellis2_test.py index 95f64d031..196a88343 100644 --- a/tests-unit/comfy_extras_test/nodes_trellis2_test.py +++ b/tests-unit/comfy_extras_test/nodes_trellis2_test.py @@ -190,6 +190,40 @@ class TestTrellisBatchSemantics(unittest.TestCase): self.assertTrue(torch.equal(output["coord_counts"], torch.tensor([2], dtype=torch.int64))) + def test_empty_shape_latent_rejects_multi_index_singleton(self): + structure = { + "coords": torch.tensor( + [ + [0, 1, 1, 1], + [0, 2, 2, 2], + ], + dtype=torch.int32, + ), + "batch_index": [5, 6], + } + + with self.assertRaises(ValueError): + nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 11) + + def test_empty_texture_latent_rejects_multi_index_singleton(self): + coords = torch.tensor( + [ + [0, 1, 1, 1], + [0, 2, 2, 2], + ], + dtype=torch.int32, + ) + structure = {"coords": coords, "batch_index": [7, 8]} + shape_latent = {"samples": torch.zeros(1, 32, 2, 1)} + + with self.assertRaises(ValueError): + nodes_trellis2.EmptyTextureLatentTrellis2.execute( + structure, + 
shape_latent, + DummyCloneModel(), + 13, + ) + def test_flatten_batched_sparse_latent_validates_coord_counts(self): samples = torch.zeros(2, 32, 3, 1) coords = torch.tensor( @@ -205,6 +239,49 @@ class TestTrellisBatchSemantics(unittest.TestCase): with self.assertRaises(ValueError): nodes_trellis2.flatten_batched_sparse_latent(samples, coords, coord_counts) + def test_infer_batched_coord_layout_rejects_negative_batch_ids(self): + coords = torch.tensor( + [ + [-1, 1, 1, 1], + [0, 2, 2, 2], + ], + dtype=torch.int32, + ) + + with self.assertRaises(ValueError): + nodes_trellis2.infer_batched_coord_layout(coords) + + def test_split_batched_coords_validates_total_count(self): + coords = torch.tensor( + [ + [0, 1, 1, 1], + [1, 2, 2, 2], + [1, 3, 3, 3], + ], + dtype=torch.int32, + ) + coord_counts = torch.tensor([1, 1], dtype=torch.int64) + + with self.assertRaises(ValueError): + nodes_trellis2.split_batched_coords(coords, coord_counts) + + def test_empty_shape_latent_preserves_resolutions_key(self): + structure = { + "coords": torch.tensor( + [ + [0, 1, 1, 1], + [0, 2, 2, 2], + ], + dtype=torch.int32, + ), + "resolutions": torch.tensor([1024], dtype=torch.int64), + } + + output, model = nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 11) + + self.assertTrue(torch.equal(output["resolutions"], torch.tensor([1024], dtype=torch.int64))) + self.assertNotIn("coord_resolutions", model.model_options["transformer_options"]) + if __name__ == "__main__": unittest.main() From 0b99c8c44acf964b9989b71439826d2582363238 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 15:50:40 -0500 Subject: [PATCH 81/93] Fail loud on Trellis invalid batch metadata --- comfy/sample.py | 4 ++++ comfy_extras/nodes_trellis2.py | 5 +++++ tests-unit/comfy_extras_test/nodes_trellis2_test.py | 9 +++++++++ tests-unit/comfy_test/sample_test.py | 10 ++++++++++ 4 files changed, 28 insertions(+) diff --git a/comfy/sample.py b/comfy/sample.py index 6fba221ed..8626269a1 
100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -14,6 +14,10 @@ def prepare_noise_inner(latent_image, generator, noise_inds=None): noise_inds = np.arange(latent_image.size(0), dtype=np.int64) else: noise_inds = np.asarray(noise_inds, dtype=np.int64) + if noise_inds.shape[0] != latent_image.size(0): + raise ValueError( + f"Trellis2 noise_inds length {noise_inds.shape[0]} does not match latent batch {latent_image.size(0)}" + ) base_seed = int(generator.initial_seed()) unique_inds = np.unique(noise_inds) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index ce184a946..328cec6e7 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -800,6 +800,11 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: coords = structure_or_coords.int() + else: + raise ValueError( + "structure_or_coords must be a voxel input with data.ndim == 4, " + f'a dict containing "coords", or a 2D torch.Tensor; got {type(structure_or_coords).__name__}' + ) shape_batch_index = normalize_batch_index(shape_latent.get("batch_index")) shape_latent = shape_latent["samples"] diff --git a/tests-unit/comfy_extras_test/nodes_trellis2_test.py b/tests-unit/comfy_extras_test/nodes_trellis2_test.py index 196a88343..43647e793 100644 --- a/tests-unit/comfy_extras_test/nodes_trellis2_test.py +++ b/tests-unit/comfy_extras_test/nodes_trellis2_test.py @@ -224,6 +224,15 @@ class TestTrellisBatchSemantics(unittest.TestCase): 13, ) + def test_empty_texture_latent_rejects_invalid_structure_input(self): + with self.assertRaises(ValueError): + nodes_trellis2.EmptyTextureLatentTrellis2.execute( + "bad-input", + {"samples": torch.zeros(1, 32, 2, 1)}, + DummyCloneModel(), + 13, + ) + def test_flatten_batched_sparse_latent_validates_coord_counts(self): samples = torch.zeros(2, 32, 3, 1) coords = torch.tensor( diff --git a/tests-unit/comfy_test/sample_test.py 
b/tests-unit/comfy_test/sample_test.py index ad154aca8..e76e65266 100644 --- a/tests-unit/comfy_test/sample_test.py +++ b/tests-unit/comfy_test/sample_test.py @@ -42,6 +42,16 @@ class TestPrepareNoiseInnerTrellis(unittest.TestCase): self.assertTrue(torch.equal(noise[1:2, :, :5, :], expected1)) self.assertTrue(torch.equal(noise[0, :, 3:, :], torch.zeros_like(noise[0, :, 3:, :]))) + def test_coord_counts_noise_inds_length_must_match_batch(self): + latent = torch.zeros(2, 4, 5, 1) + latent.trellis_coord_counts = torch.tensor([3, 5], dtype=torch.int64) + + generator = torch.Generator(device="cpu") + generator.manual_seed(456) + + with self.assertRaises(ValueError): + comfy.sample.prepare_noise_inner(latent, generator, noise_inds=[7]) + if __name__ == "__main__": unittest.main() From 90ebb50f00bf89ed8c947a0e4ed4ed0803981ea1 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 16:05:10 -0500 Subject: [PATCH 82/93] Harden Trellis sparse latent seeding --- comfy/ldm/trellis2/model.py | 4 +++ comfy/sample.py | 2 -- comfy_extras/nodes_trellis2.py | 4 +-- .../comfy_extras_test/nodes_trellis2_test.py | 29 +++++++++++++++++++ 4 files changed, 35 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 7cf3e728e..e8ed39aed 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -880,6 +880,10 @@ class Trellis2(nn.Module): for i in range(logical_batch): out_index = rep * logical_batch + i count = int(coord_counts[i].item()) + if count > N: + raise ValueError( + f"Trellis2 coord count {count} exceeds latent token dimension {N} for batch {i}" + ) coords_i = coords_by_batch[i].clone() coords_i[:, 0] = 0 feats_i = x_eval[out_index, :count].clone() diff --git a/comfy/sample.py b/comfy/sample.py index 8626269a1..a4ce5f56f 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -74,8 +74,6 @@ def prepare_noise(latent_image, seed, noise_inds=None): def fix_empty_latent_channels(model, latent_image, 
downscale_ratio_spacial=None): if latent_image.is_nested: return latent_image - if getattr(latent_image, "trellis_skip_empty_fix", False): - return latent_image latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels if torch.count_nonzero(latent_image) == 0: if latent_format.latent_channels != latent_image.shape[1]: diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 328cec6e7..d345641b1 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -807,6 +807,8 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ) shape_batch_index = normalize_batch_index(shape_latent.get("batch_index")) + if batch_index is None: + batch_index = shape_batch_index shape_latent = shape_latent["samples"] batch_size, inferred_coord_counts, max_tokens = infer_batched_coord_layout(coords) if coord_counts is not None: @@ -844,8 +846,6 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): latent[i, :, :count] = latent_i[0] if coord_counts is not None: latent.trellis_coord_counts = coord_counts.clone() - if batch_index is None: - batch_index = shape_batch_index model = model.clone() model.model_options = model.model_options.copy() if "transformer_options" in model.model_options: diff --git a/tests-unit/comfy_extras_test/nodes_trellis2_test.py b/tests-unit/comfy_extras_test/nodes_trellis2_test.py index 43647e793..49e872bc7 100644 --- a/tests-unit/comfy_extras_test/nodes_trellis2_test.py +++ b/tests-unit/comfy_extras_test/nodes_trellis2_test.py @@ -233,6 +233,35 @@ class TestTrellisBatchSemantics(unittest.TestCase): 13, ) + def test_empty_texture_latent_uses_shape_batch_index_for_seed_fallback(self): + coords = torch.tensor( + [ + [0, 1, 1, 1], + [1, 2, 2, 2], + [1, 3, 3, 3], + ], + dtype=torch.int32, + ) + structure = {"coords": coords} + shape_latent = { + "samples": torch.zeros(2, 32, 2, 1), + "batch_index": [4, 9], + } + + output, _ = 
nodes_trellis2.EmptyTextureLatentTrellis2.execute( + structure, + shape_latent, + DummyCloneModel(), + 13, + ) + + expected = torch.zeros(2, 32, 2, 1) + expected[0, :, :1, :] = torch.randn(1, 32, 1, 1, generator=torch.Generator(device="cpu").manual_seed(17))[0] + expected[1, :, :2, :] = torch.randn(1, 32, 2, 1, generator=torch.Generator(device="cpu").manual_seed(22))[0] + + self.assertTrue(torch.equal(output["samples"], expected)) + self.assertEqual(output["batch_index"], [4, 9]) + def test_flatten_batched_sparse_latent_validates_coord_counts(self): samples = torch.zeros(2, 32, 3, 1) coords = torch.tensor( From 33caec301a6f1a6ab4e802555e80a0e0c5e5c83c Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 16:36:48 -0500 Subject: [PATCH 83/93] Validate Trellis coord_counts noise metadata --- comfy/sample.py | 10 ++++++++++ tests-unit/comfy_test/sample_test.py | 19 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/comfy/sample.py b/comfy/sample.py index a4ce5f56f..878c4e984 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -9,6 +9,16 @@ import comfy.nested_tensor def prepare_noise_inner(latent_image, generator, noise_inds=None): coord_counts = getattr(latent_image, "trellis_coord_counts", None) if coord_counts is not None: + if coord_counts.ndim != 1: + raise ValueError(f"Trellis2 coord_counts must be 1D, got shape {tuple(coord_counts.shape)}") + if coord_counts.shape[0] != latent_image.size(0): + raise ValueError( + f"Trellis2 coord_counts length {coord_counts.shape[0]} does not match latent batch {latent_image.size(0)}" + ) + if (coord_counts < 0).any() or (coord_counts > latent_image.size(2)).any(): + raise ValueError( + f"Trellis2 coord_counts must be within [0, {latent_image.size(2)}], got {coord_counts.tolist()}" + ) noise = torch.zeros(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, device="cpu") if noise_inds is None: noise_inds = np.arange(latent_image.size(0), dtype=np.int64) diff --git 
a/tests-unit/comfy_test/sample_test.py b/tests-unit/comfy_test/sample_test.py index e76e65266..227659994 100644 --- a/tests-unit/comfy_test/sample_test.py +++ b/tests-unit/comfy_test/sample_test.py @@ -52,6 +52,25 @@ class TestPrepareNoiseInnerTrellis(unittest.TestCase): with self.assertRaises(ValueError): comfy.sample.prepare_noise_inner(latent, generator, noise_inds=[7]) + def test_coord_counts_metadata_must_match_batch_and_bounds(self): + generator = torch.Generator(device="cpu") + generator.manual_seed(456) + + latent = torch.zeros(2, 4, 5, 1) + latent.trellis_coord_counts = torch.tensor([[3, 5]], dtype=torch.int64) + with self.assertRaises(ValueError): + comfy.sample.prepare_noise_inner(latent, generator) + + latent = torch.zeros(2, 4, 5, 1) + latent.trellis_coord_counts = torch.tensor([3], dtype=torch.int64) + with self.assertRaises(ValueError): + comfy.sample.prepare_noise_inner(latent, generator) + + latent = torch.zeros(2, 4, 5, 1) + latent.trellis_coord_counts = torch.tensor([3, 6], dtype=torch.int64) + with self.assertRaises(ValueError): + comfy.sample.prepare_noise_inner(latent, generator) + if __name__ == "__main__": unittest.main() From 939ac7ebb40f486820cc39a47ff0d3c28b3c44b2 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 17:20:57 -0500 Subject: [PATCH 84/93] Omit null batch_index from Trellis upsample output --- comfy_extras/nodes_trellis2.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index d345641b1..56ec3e736 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -521,12 +521,16 @@ class Trellis2UpsampleCascade(IO.ComfyNode): final_coords_list.append(final_coords_i) output_coord_counts.append(int(final_coords_i.shape[0])) - return IO.NodeOutput({ + normalized_batch_index = normalize_batch_index(batch_index) + output = { "coords": torch.cat(final_coords_list, dim=0), "coord_counts": 
torch.tensor(output_coord_counts, dtype=torch.int64), "resolutions": torch.full((len(final_coords_list),), int(hr_resolution), dtype=torch.int64), - "batch_index": normalize_batch_index(batch_index), - },) + } + if normalized_batch_index is not None: + output["batch_index"] = normalized_batch_index + + return IO.NodeOutput(output,) dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) dino_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) From 597adfce3ffa96bf5b11c187404e5b492ea28cfc Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 17:22:31 -0500 Subject: [PATCH 85/93] fix: stabilize Trellis2 mesh simplification --- comfy_extras/nodes_trellis2.py | 62 +++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 23 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 8121e261b..8501ef128 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -109,15 +109,15 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): # map voxels voxel_pos = voxel_coords.to(device).float() * voxel_size + origin verts = mesh.vertices.to(device).squeeze(0) - voxel_colors = voxel_colors.to(device) + voxel_colors = voxel_colors.cpu() - voxel_pos_np = voxel_pos.numpy() - verts_np = verts.numpy() + voxel_pos_np = voxel_pos.cpu().numpy() + verts_np = verts.cpu().numpy() tree = scipy.spatial.cKDTree(voxel_pos_np) # nearest neighbour k=1 - _, nearest_idx_np = tree.query(verts_np, k=1, workers=-1) + _, nearest_idx_np = tree.query(verts_np, k=1, workers=1) nearest_idx = torch.from_numpy(nearest_idx_np).long() v_colors = voxel_colors[nearest_idx] @@ -194,6 +194,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): IO.Latent.Input("samples"), IO.Vae.Input("vae"), IO.AnyType.Input("shape_subs"), + IO.Combo.Input("resolution", options=["512", "1024"], default="1024") ], outputs=[ IO.Mesh.Output("mesh"), @@ -201,9 +202,9 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): ) 
@classmethod - def execute(cls, shape_mesh, samples, vae, shape_subs): + def execute(cls, shape_mesh, samples, vae, shape_subs, resolution): - resolution = 1024 + resolution = int(resolution) patcher = vae.patcher device = comfy.model_management.get_torch_device() comfy.model_management.load_model_gpu(patcher) @@ -617,34 +618,49 @@ def simplify_fn(vertices, faces, colors=None, target=100000): volume = (extent[0] * extent[1] * extent[2]).clamp(min=1e-8) cell_size = (volume / target_v) ** (1/3.0) - quantized = ((vertices - min_v) / cell_size).round().long() - unique_coords, inverse_indices = torch.unique(quantized, dim=0, return_inverse=True) + # Use CPU-side ordered reductions here so repeated runs produce identical + # simplified meshes instead of relying on GPU scatter-add accumulation order. + vertices_np = vertices.detach().cpu().numpy() + faces_np = faces.detach().cpu().numpy() + colors_np = colors.detach().cpu().numpy() if colors is not None else None + min_v_np = min_v.detach().cpu().numpy() + cell_size_value = float(cell_size.detach().cpu()) + + quantized = np.rint((vertices_np - min_v_np) / cell_size_value).astype(np.int64) + unique_coords, inverse_indices = np.unique(quantized, axis=0, return_inverse=True) num_cells = unique_coords.shape[0] - new_vertices = torch.zeros((num_cells, 3), dtype=vertices.dtype, device=device) - counts = torch.zeros((num_cells, 1), dtype=vertices.dtype, device=device) - new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices) - counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1])) - new_vertices = new_vertices / counts.clamp(min=1) + new_vertices_np = np.zeros((num_cells, 3), dtype=vertices_np.dtype) + np.add.at(new_vertices_np, inverse_indices, vertices_np) + + counts_np = np.bincount(inverse_indices, minlength=num_cells).astype(vertices_np.dtype).reshape(-1, 1) + new_vertices_np = new_vertices_np / np.clip(counts_np, 1, None) new_colors = None - if colors is not None: 
- new_colors = torch.zeros((num_cells, colors.shape[1]), dtype=colors.dtype, device=device) - new_colors.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, colors.shape[1]), colors) - new_colors = new_colors / counts.clamp(min=1) + if colors_np is not None: + new_colors_np = np.zeros((num_cells, colors_np.shape[1]), dtype=colors_np.dtype) + np.add.at(new_colors_np, inverse_indices, colors_np) + new_colors = new_colors_np / np.clip(counts_np, 1, None) - new_faces = inverse_indices[faces] + new_faces = inverse_indices[faces_np] valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \ (new_faces[:, 1] != new_faces[:, 2]) & \ (new_faces[:, 2] != new_faces[:, 0]) new_faces = new_faces[valid_mask] - unique_face_indices, inv_face = torch.unique(new_faces.reshape(-1), return_inverse=True) - final_vertices = new_vertices[unique_face_indices] - final_faces = inv_face.reshape(-1, 3) + if new_faces.size == 0: + final_vertices_np = new_vertices_np[:0] + final_faces_np = np.empty((0, 3), dtype=np.int64) + final_colors_np = new_colors[:0] if new_colors is not None else None + else: + unique_face_indices, inv_face = np.unique(new_faces.reshape(-1), return_inverse=True) + final_vertices_np = new_vertices_np[unique_face_indices] + final_faces_np = inv_face.reshape(-1, 3).astype(np.int64) + final_colors_np = new_colors[unique_face_indices] if new_colors is not None else None - # assign colors - final_colors = new_colors[unique_face_indices] if new_colors is not None else None + final_vertices = torch.from_numpy(final_vertices_np).to(device=device, dtype=vertices.dtype) + final_faces = torch.from_numpy(final_faces_np).to(device=device, dtype=faces.dtype) + final_colors = torch.from_numpy(final_colors_np).to(device=device, dtype=colors.dtype) if final_colors_np is not None else None return final_vertices, final_faces, final_colors From f15bf73d5cde33f03798b2f99932bb37234f65a5 Mon Sep 17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 20:39:08 -0500 Subject: [PATCH 86/93] Fix 
Trellis VAE decode memory management --- comfy_extras/nodes_trellis2.py | 86 ++++++++++++++++++++++------------ 1 file changed, 57 insertions(+), 29 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 8121e261b..89fb2443e 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,6 +1,6 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types -from comfy.ldm.trellis2.vae import SparseTensor +from comfy.ldm.trellis2.vae import SparseTensor, sparse_cat import comfy.model_management from PIL import Image import numpy as np @@ -8,6 +8,25 @@ import torch import scipy import copy +def prepare_trellis_vae_for_decode(vae, sample_shape): + memory_required = max(1, int(vae.memory_used_decode(sample_shape, vae.vae_dtype))) + device = comfy.model_management.get_torch_device() + comfy.model_management.free_memory(memory_required, device, for_dynamic=False) + comfy.model_management.load_models_gpu( + [vae.patcher], + memory_required=memory_required, + force_full_load=getattr(vae, "disable_offload", False), + ) + free_memory = vae.patcher.get_free_memory(device) + batch_number = max(1, int(free_memory / memory_required)) + return min(sample_shape[0], batch_number) + + +def combine_sparse_sub_batches(sub_batches): + if len(sub_batches) == 1: + return sub_batches[0] + return [sparse_cat([batch[level] for batch in sub_batches], dim=0) for level in range(len(sub_batches[0]))] + def pack_variable_mesh_batch(vertices, faces, colors=None): batch_size = len(vertices) @@ -163,18 +182,24 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): def execute(cls, samples, vae, resolution): resolution = int(resolution) - patcher = vae.patcher + sample_tensor = samples["samples"] device = comfy.model_management.get_torch_device() - comfy.model_management.load_model_gpu(patcher) - - vae = vae.first_stage_model coords = samples["coords"] + batch_number = prepare_trellis_vae_for_decode(vae, 
sample_tensor.shape) + trellis_vae = vae.first_stage_model - samples = samples["samples"] - samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) - samples = shape_norm(samples, coords) + shape_samples = sample_tensor.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) + shape_latent = shape_norm(shape_samples, coords.to(device)) - mesh, subs = vae.decode_shape_slat(samples, resolution) + mesh = [] + sub_batches = [] + for start in range(0, shape_latent.shape[0], batch_number): + end = start + batch_number + mesh_chunk, subs_chunk = trellis_vae.decode_shape_slat(shape_latent[start:end], resolution) + mesh.extend(mesh_chunk) + sub_batches.append(subs_chunk) + + subs = combine_sparse_sub_batches(sub_batches) face_list = [m.faces for m in mesh] vert_list = [m.vertices for m in mesh] if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list): @@ -204,21 +229,24 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): def execute(cls, shape_mesh, samples, vae, shape_subs): resolution = 1024 - patcher = vae.patcher + sample_tensor = samples["samples"] device = comfy.model_management.get_torch_device() - comfy.model_management.load_model_gpu(patcher) - - vae = vae.first_stage_model coords = samples["coords"] + batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape) + trellis_vae = vae.first_stage_model - samples = samples["samples"] - samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) - std = tex_slat_normalization["std"].to(samples) - mean = tex_slat_normalization["mean"].to(samples) - samples = SparseTensor(feats = samples, coords=coords) - samples = samples * std + mean + tex_samples = sample_tensor.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) + std = tex_slat_normalization["std"].to(tex_samples) + mean = tex_slat_normalization["mean"].to(tex_samples) + tex_latent = SparseTensor(feats=tex_samples, coords=coords.to(device)) + tex_latent = tex_latent * std 
+ mean - voxel = vae.decode_tex_slat(samples, shape_subs) + voxel_batches = [] + for start in range(0, tex_latent.shape[0], batch_number): + end = start + batch_number + guide_subs = [sub[start:end] for sub in shape_subs] + voxel_batches.append(trellis_vae.decode_tex_slat(tex_latent[start:end], guide_subs)) + voxel = voxel_batches[0] if len(voxel_batches) == 1 else sparse_cat(voxel_batches, dim=0) color_feats = voxel.feats[:, :3] voxel_coords = voxel.coords[:, 1:] voxel_batch_idx = voxel.coords[:, 0] @@ -266,15 +294,15 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): @classmethod def execute(cls, samples, vae, resolution): resolution = int(resolution) - vae = vae.first_stage_model - decoder = vae.struct_dec + sample_tensor = samples["samples"] + batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape) + decoder = vae.first_stage_model.struct_dec load_device = comfy.model_management.get_torch_device() - offload_device = comfy.model_management.vae_offload_device() - decoder = decoder.to(load_device) - samples = samples["samples"] - samples = samples.to(load_device) - decoded = decoder(samples)>0 - decoder.to(offload_device) + decoded_batches = [] + for start in range(0, sample_tensor.shape[0], batch_number): + sample_chunk = sample_tensor[start:start + batch_number].to(load_device) + decoded_batches.append(decoder(sample_chunk) > 0) + decoded = torch.cat(decoded_batches, dim=0) current_res = decoded.shape[2] if current_res != resolution: @@ -303,7 +331,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode): @classmethod def execute(cls, shape_latent_512, vae, target_resolution, max_tokens): device = comfy.model_management.get_torch_device() - comfy.model_management.load_model_gpu(vae.patcher) + prepare_trellis_vae_for_decode(vae, shape_latent_512["samples"].shape) feats = shape_latent_512["samples"].squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) coords_512 = shape_latent_512["coords"].to(device) From 8816699e7c2b4d1c5c8d3595541928e92026677a Mon Sep 
17 00:00:00 2001 From: John Pollock Date: Mon, 20 Apr 2026 22:10:15 -0500 Subject: [PATCH 87/93] Address Trellis VAE decode review feedback --- comfy_extras/nodes_trellis2.py | 8 +-- .../comfy_extras_test/nodes_trellis2_test.py | 51 +++++++++++++++++++ 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 397453562..bc2d6bcab 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -9,9 +9,11 @@ import scipy import copy def prepare_trellis_vae_for_decode(vae, sample_shape): - memory_required = max(1, int(vae.memory_used_decode(sample_shape, vae.vae_dtype))) + memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype) + if len(sample_shape) == 5: + memory_required *= max(1, int(sample_shape[4])) + memory_required = max(1, int(memory_required)) device = comfy.model_management.get_torch_device() - comfy.model_management.free_memory(memory_required, device, for_dynamic=False) comfy.model_management.load_models_gpu( [vae.patcher], memory_required=memory_required, @@ -19,7 +21,7 @@ def prepare_trellis_vae_for_decode(vae, sample_shape): ) free_memory = vae.patcher.get_free_memory(device) batch_number = max(1, int(free_memory / memory_required)) - return min(sample_shape[0], batch_number) + return batch_number def pack_variable_mesh_batch(vertices, faces, colors=None): diff --git a/tests-unit/comfy_extras_test/nodes_trellis2_test.py b/tests-unit/comfy_extras_test/nodes_trellis2_test.py index 49e872bc7..96fb4395a 100644 --- a/tests-unit/comfy_extras_test/nodes_trellis2_test.py +++ b/tests-unit/comfy_extras_test/nodes_trellis2_test.py @@ -73,6 +73,57 @@ class DummyModel: self.model = inner_model +class DummyPatcher: + def __init__(self, free_memory): + self.free_memory = free_memory + + def get_free_memory(self, device): + return self.free_memory + + +class DummyVAE: + vae_dtype = torch.float16 + + def __init__(self, free_memory, memory_factor=2): + 
self.patcher = DummyPatcher(free_memory) + self.memory_factor = memory_factor + + def memory_used_decode(self, shape, dtype): + return shape[2] * shape[3] * self.memory_factor + + +class TestPrepareTrellisVaeForDecode(unittest.TestCase): + def test_uses_load_models_gpu_without_pre_freeing_memory(self): + vae = DummyVAE(free_memory=1000) + + with patch.object(nodes_trellis2.comfy.model_management, "get_torch_device", return_value="cuda"): + with patch.object(nodes_trellis2.comfy.model_management, "free_memory") as free_memory: + with patch.object(nodes_trellis2.comfy.model_management, "load_models_gpu") as load_models_gpu: + batch_number = nodes_trellis2.prepare_trellis_vae_for_decode(vae, (3, 32, 10, 1)) + + free_memory.assert_not_called() + load_models_gpu.assert_called_once_with( + [vae.patcher], + memory_required=20, + force_full_load=False, + ) + self.assertEqual(batch_number, 50) + + def test_scales_memory_estimate_for_5d_structure_latents(self): + vae = DummyVAE(free_memory=40960, memory_factor=1) + + with patch.object(nodes_trellis2.comfy.model_management, "get_torch_device", return_value="cuda"): + with patch.object(nodes_trellis2.comfy.model_management, "load_models_gpu") as load_models_gpu: + batch_number = nodes_trellis2.prepare_trellis_vae_for_decode(vae, (2, 8, 16, 16, 16)) + + load_models_gpu.assert_called_once_with( + [vae.patcher], + memory_required=4096, + force_full_load=False, + ) + self.assertEqual(batch_number, 10) + + class TestRunConditioningRestore(unittest.TestCase): def setUp(self): self.intermediate_patch = patch.object( From b29edb0ec4536c6920bbb6ae36e99dfb373f0395 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 24 Apr 2026 01:30:31 +0300 Subject: [PATCH 88/93] removed test files --- comfy_extras/nodes_trellis2.py | 2 +- .../comfy_extras_test/nodes_trellis2_test.py | 376 ------------------ tests-unit/comfy_test/sample_test.py | 76 ---- 3 files changed, 1 insertion(+), 453 
deletions(-) delete mode 100644 tests-unit/comfy_extras_test/nodes_trellis2_test.py delete mode 100644 tests-unit/comfy_test/sample_test.py diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index bc2d6bcab..d6edaedb6 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -617,7 +617,7 @@ class Trellis2Conditioning(IO.ComfyNode): for b in range(batch_size): item_image = image[b] - item_mask = mask[b] + item_mask = mask[b] if mask.size(0) > 1 else mask[0] img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) diff --git a/tests-unit/comfy_extras_test/nodes_trellis2_test.py b/tests-unit/comfy_extras_test/nodes_trellis2_test.py deleted file mode 100644 index 96fb4395a..000000000 --- a/tests-unit/comfy_extras_test/nodes_trellis2_test.py +++ /dev/null @@ -1,376 +0,0 @@ -import importlib -import sys -import types -import unittest -from unittest.mock import patch - -import torch -from PIL import Image - - -class _DummyPort: - @staticmethod - def Input(*args, **kwargs): - return None - - @staticmethod - def Output(*args, **kwargs): - return None - - -class _DummyIO: - ComfyNode = object - - @staticmethod - def Schema(*args, **kwargs): - return None - - @staticmethod - def NodeOutput(*args, **kwargs): - return args - - def __getattr__(self, name): - return _DummyPort - - -class _DummyTypes: - def __getattr__(self, name): - return lambda *args, **kwargs: None - - -dummy_comfy_api_latest = types.SimpleNamespace( - ComfyExtension=object, - IO=_DummyIO(), - Types=_DummyTypes(), -) - -dummy_sparse_tensor = type("SparseTensor", (), {}) -dummy_trellis_vae = types.SimpleNamespace(SparseTensor=dummy_sparse_tensor) - -with patch.dict(sys.modules, { - "comfy_api.latest": dummy_comfy_api_latest, - "comfy.ldm.trellis2.vae": dummy_trellis_vae, -}): - nodes_trellis2 = importlib.import_module("comfy_extras.nodes_trellis2") - - -class 
DummyInnerModel: - def __init__(self, image_size=..., fail_on_call=None): - self.call_count = 0 - self.fail_on_call = fail_on_call - if image_size is not ...: - self.image_size = image_size - - def __call__(self, input_tensor, skip_norm_elementwise=True): - self.call_count += 1 - if self.fail_on_call == self.call_count: - raise RuntimeError("expected conditioning failure") - return (torch.ones((1, 4), dtype=torch.float32),) - - -class DummyModel: - def __init__(self, inner_model): - self.model = inner_model - - -class DummyPatcher: - def __init__(self, free_memory): - self.free_memory = free_memory - - def get_free_memory(self, device): - return self.free_memory - - -class DummyVAE: - vae_dtype = torch.float16 - - def __init__(self, free_memory, memory_factor=2): - self.patcher = DummyPatcher(free_memory) - self.memory_factor = memory_factor - - def memory_used_decode(self, shape, dtype): - return shape[2] * shape[3] * self.memory_factor - - -class TestPrepareTrellisVaeForDecode(unittest.TestCase): - def test_uses_load_models_gpu_without_pre_freeing_memory(self): - vae = DummyVAE(free_memory=1000) - - with patch.object(nodes_trellis2.comfy.model_management, "get_torch_device", return_value="cuda"): - with patch.object(nodes_trellis2.comfy.model_management, "free_memory") as free_memory: - with patch.object(nodes_trellis2.comfy.model_management, "load_models_gpu") as load_models_gpu: - batch_number = nodes_trellis2.prepare_trellis_vae_for_decode(vae, (3, 32, 10, 1)) - - free_memory.assert_not_called() - load_models_gpu.assert_called_once_with( - [vae.patcher], - memory_required=20, - force_full_load=False, - ) - self.assertEqual(batch_number, 50) - - def test_scales_memory_estimate_for_5d_structure_latents(self): - vae = DummyVAE(free_memory=40960, memory_factor=1) - - with patch.object(nodes_trellis2.comfy.model_management, "get_torch_device", return_value="cuda"): - with patch.object(nodes_trellis2.comfy.model_management, "load_models_gpu") as load_models_gpu: - 
batch_number = nodes_trellis2.prepare_trellis_vae_for_decode(vae, (2, 8, 16, 16, 16)) - - load_models_gpu.assert_called_once_with( - [vae.patcher], - memory_required=4096, - force_full_load=False, - ) - self.assertEqual(batch_number, 10) - - -class TestRunConditioningRestore(unittest.TestCase): - def setUp(self): - self.intermediate_patch = patch.object( - nodes_trellis2.comfy.model_management, "intermediate_device", lambda: "cpu" - ) - self.torch_device_patch = patch.object( - nodes_trellis2.comfy.model_management, "get_torch_device", lambda: "cpu" - ) - self.intermediate_patch.start() - self.torch_device_patch.start() - - def tearDown(self): - self.intermediate_patch.stop() - self.torch_device_patch.stop() - - @staticmethod - def make_test_image(): - return Image.new("RGB", (8, 8), color="white") - - def test_restores_existing_image_size_after_success(self): - inner_model = DummyInnerModel(image_size=777) - - nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True) - - self.assertEqual(inner_model.image_size, 777) - - def test_deletes_missing_image_size_after_success(self): - inner_model = DummyInnerModel() - - nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True) - - self.assertFalse(hasattr(inner_model, "image_size")) - - def test_restores_existing_image_size_after_512_failure(self): - inner_model = DummyInnerModel(image_size=777, fail_on_call=1) - - with self.assertRaisesRegex(RuntimeError, "expected conditioning failure"): - nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True) - - self.assertEqual(inner_model.image_size, 777) - - def test_deletes_missing_image_size_after_1024_failure(self): - inner_model = DummyInnerModel(fail_on_call=2) - - with self.assertRaisesRegex(RuntimeError, "expected conditioning failure"): - nodes_trellis2.run_conditioning(DummyModel(inner_model), self.make_test_image(), include_1024=True) - - 
self.assertFalse(hasattr(inner_model, "image_size")) - - -class DummyCloneModel: - def __init__(self): - self.model_options = {} - - def clone(self): - cloned = DummyCloneModel() - cloned.model_options = self.model_options.copy() - return cloned - - -class TestTrellisBatchSemantics(unittest.TestCase): - def test_empty_structure_latent_is_deterministic_and_propagates_sample_indices(self): - batch_output = nodes_trellis2.EmptyStructureLatentTrellis2.execute(2, 0, 17)[0] - single_output = nodes_trellis2.EmptyStructureLatentTrellis2.execute(1, 5, 17)[0] - - expected_batch = torch.zeros(2, 8, 16, 16, 16) - expected_batch[0] = torch.randn(1, 8, 16, 16, 16, generator=torch.Generator(device="cpu").manual_seed(17))[0] - expected_batch[1] = torch.randn(1, 8, 16, 16, 16, generator=torch.Generator(device="cpu").manual_seed(18))[0] - expected_single = torch.randn(1, 8, 16, 16, 16, generator=torch.Generator(device="cpu").manual_seed(22)) - - self.assertTrue(torch.equal(batch_output["samples"], expected_batch)) - self.assertEqual(batch_output["batch_index"], [0, 1]) - self.assertTrue(torch.equal(single_output["samples"], expected_single)) - self.assertEqual(single_output["batch_index"], [5]) - - def test_empty_shape_latent_is_deterministic_and_propagates_batch_index(self): - coords = torch.tensor( - [ - [1, 5, 5, 5], - [0, 1, 1, 1], - [1, 6, 6, 6], - [0, 2, 2, 2], - [1, 7, 7, 7], - ], - dtype=torch.int32, - ) - structure = { - "coords": coords, - "coord_counts": torch.tensor([2, 3], dtype=torch.int64), - "batch_index": [4, 9], - } - - output, _ = nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 23) - - expected = torch.zeros(2, 32, 3, 1) - expected[0, :, :2, :] = torch.randn(1, 32, 2, 1, generator=torch.Generator(device="cpu").manual_seed(27))[0] - expected[1, :, :3, :] = torch.randn(1, 32, 3, 1, generator=torch.Generator(device="cpu").manual_seed(32))[0] - - self.assertTrue(torch.equal(output["samples"], expected)) - 
self.assertTrue(torch.equal(output["coord_counts"], torch.tensor([2, 3], dtype=torch.int64))) - self.assertEqual(output["batch_index"], [4, 9]) - - def test_empty_shape_latent_keeps_singleton_coord_counts(self): - structure = { - "coords": torch.tensor( - [ - [0, 1, 1, 1], - [0, 2, 2, 2], - ], - dtype=torch.int32, - ), - } - - output, _ = nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 11) - - self.assertTrue(torch.equal(output["coord_counts"], torch.tensor([2], dtype=torch.int64))) - - def test_empty_shape_latent_rejects_multi_index_singleton(self): - structure = { - "coords": torch.tensor( - [ - [0, 1, 1, 1], - [0, 2, 2, 2], - ], - dtype=torch.int32, - ), - "batch_index": [5, 6], - } - - with self.assertRaises(ValueError): - nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 11) - - def test_empty_texture_latent_rejects_multi_index_singleton(self): - coords = torch.tensor( - [ - [0, 1, 1, 1], - [0, 2, 2, 2], - ], - dtype=torch.int32, - ) - structure = {"coords": coords, "batch_index": [7, 8]} - shape_latent = {"samples": torch.zeros(1, 32, 2, 1)} - - with self.assertRaises(ValueError): - nodes_trellis2.EmptyTextureLatentTrellis2.execute( - structure, - shape_latent, - DummyCloneModel(), - 13, - ) - - def test_empty_texture_latent_rejects_invalid_structure_input(self): - with self.assertRaises(ValueError): - nodes_trellis2.EmptyTextureLatentTrellis2.execute( - "bad-input", - {"samples": torch.zeros(1, 32, 2, 1)}, - DummyCloneModel(), - 13, - ) - - def test_empty_texture_latent_uses_shape_batch_index_for_seed_fallback(self): - coords = torch.tensor( - [ - [0, 1, 1, 1], - [1, 2, 2, 2], - [1, 3, 3, 3], - ], - dtype=torch.int32, - ) - structure = {"coords": coords} - shape_latent = { - "samples": torch.zeros(2, 32, 2, 1), - "batch_index": [4, 9], - } - - output, _ = nodes_trellis2.EmptyTextureLatentTrellis2.execute( - structure, - shape_latent, - DummyCloneModel(), - 13, - ) - - expected = torch.zeros(2, 32, 2, 
1) - expected[0, :, :1, :] = torch.randn(1, 32, 1, 1, generator=torch.Generator(device="cpu").manual_seed(17))[0] - expected[1, :, :2, :] = torch.randn(1, 32, 2, 1, generator=torch.Generator(device="cpu").manual_seed(22))[0] - - self.assertTrue(torch.equal(output["samples"], expected)) - self.assertEqual(output["batch_index"], [4, 9]) - - def test_flatten_batched_sparse_latent_validates_coord_counts(self): - samples = torch.zeros(2, 32, 3, 1) - coords = torch.tensor( - [ - [0, 1, 1, 1], - [1, 2, 2, 2], - [1, 3, 3, 3], - ], - dtype=torch.int32, - ) - coord_counts = torch.tensor([2, 1], dtype=torch.int64) - - with self.assertRaises(ValueError): - nodes_trellis2.flatten_batched_sparse_latent(samples, coords, coord_counts) - - def test_infer_batched_coord_layout_rejects_negative_batch_ids(self): - coords = torch.tensor( - [ - [-1, 1, 1, 1], - [0, 2, 2, 2], - ], - dtype=torch.int32, - ) - - with self.assertRaises(ValueError): - nodes_trellis2.infer_batched_coord_layout(coords) - - def test_split_batched_coords_validates_total_count(self): - coords = torch.tensor( - [ - [0, 1, 1, 1], - [1, 2, 2, 2], - [1, 3, 3, 3], - ], - dtype=torch.int32, - ) - coord_counts = torch.tensor([1, 1], dtype=torch.int64) - - with self.assertRaises(ValueError): - nodes_trellis2.split_batched_coords(coords, coord_counts) - - def test_empty_shape_latent_preserves_resolutions_key(self): - structure = { - "coords": torch.tensor( - [ - [0, 1, 1, 1], - [0, 2, 2, 2], - ], - dtype=torch.int32, - ), - "resolutions": torch.tensor([1024], dtype=torch.int64), - } - - output, model = nodes_trellis2.EmptyShapeLatentTrellis2.execute(structure, DummyCloneModel(), 11) - - self.assertTrue(torch.equal(output["resolutions"], torch.tensor([1024], dtype=torch.int64))) - self.assertNotIn("coord_resolutions", model.model_options["transformer_options"]) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests-unit/comfy_test/sample_test.py b/tests-unit/comfy_test/sample_test.py deleted file mode 100644 
index 227659994..000000000 --- a/tests-unit/comfy_test/sample_test.py +++ /dev/null @@ -1,76 +0,0 @@ -import unittest - -import torch - -import comfy.sample - - -class TestPrepareNoiseInnerTrellis(unittest.TestCase): - def test_coord_counts_noise_matches_per_index_prefix_draws(self): - latent = torch.zeros(2, 4, 5, 1) - latent.trellis_coord_counts = torch.tensor([3, 5], dtype=torch.int64) - - generator = torch.Generator(device="cpu") - generator.manual_seed(123) - noise = comfy.sample.prepare_noise_inner(latent, generator) - - expected = torch.zeros_like(noise, dtype=torch.float32) - row0 = torch.Generator(device="cpu") - row0.manual_seed(123) - expected[0, :, :3, :] = torch.randn(1, 4, 3, 1, generator=row0)[0] - row1 = torch.Generator(device="cpu") - row1.manual_seed(124) - expected[1] = torch.randn(1, 4, 5, 1, generator=row1)[0] - - self.assertTrue(torch.equal(noise.float(), expected)) - self.assertTrue(torch.equal(noise[0, :, 3:, :], torch.zeros_like(noise[0, :, 3:, :]))) - - def test_coord_counts_noise_inds_share_prefixes_for_duplicates(self): - latent = torch.zeros(2, 4, 5, 1) - latent.trellis_coord_counts = torch.tensor([3, 5], dtype=torch.int64) - - generator = torch.Generator(device="cpu") - generator.manual_seed(456) - noise = comfy.sample.prepare_noise_inner(latent, generator, noise_inds=[7, 7]) - - replay = torch.Generator(device="cpu") - replay.manual_seed(463) - expected1 = torch.randn(1, 4, 5, 1, generator=replay) - expected0 = expected1[:, :, :3, :] - - self.assertTrue(torch.equal(noise[0:1, :, :3, :], expected0)) - self.assertTrue(torch.equal(noise[1:2, :, :5, :], expected1)) - self.assertTrue(torch.equal(noise[0, :, 3:, :], torch.zeros_like(noise[0, :, 3:, :]))) - - def test_coord_counts_noise_inds_length_must_match_batch(self): - latent = torch.zeros(2, 4, 5, 1) - latent.trellis_coord_counts = torch.tensor([3, 5], dtype=torch.int64) - - generator = torch.Generator(device="cpu") - generator.manual_seed(456) - - with self.assertRaises(ValueError): - 
comfy.sample.prepare_noise_inner(latent, generator, noise_inds=[7]) - - def test_coord_counts_metadata_must_match_batch_and_bounds(self): - generator = torch.Generator(device="cpu") - generator.manual_seed(456) - - latent = torch.zeros(2, 4, 5, 1) - latent.trellis_coord_counts = torch.tensor([[3, 5]], dtype=torch.int64) - with self.assertRaises(ValueError): - comfy.sample.prepare_noise_inner(latent, generator) - - latent = torch.zeros(2, 4, 5, 1) - latent.trellis_coord_counts = torch.tensor([3], dtype=torch.int64) - with self.assertRaises(ValueError): - comfy.sample.prepare_noise_inner(latent, generator) - - latent = torch.zeros(2, 4, 5, 1) - latent.trellis_coord_counts = torch.tensor([3, 6], dtype=torch.int64) - with self.assertRaises(ValueError): - comfy.sample.prepare_noise_inner(latent, generator) - - -if __name__ == "__main__": - unittest.main() From e180d4ad799f533b82bb2e8e83f977317d458ff9 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 7 May 2026 18:47:03 +0300 Subject: [PATCH 89/93] simplify and optimize model.forward --- comfy/ldm/trellis2/model.py | 272 ++++++++++-------------------------- 1 file changed, 70 insertions(+), 202 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index e8ed39aed..a54e4ca94 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -779,66 +779,54 @@ class Trellis2(nn.Module): def forward(self, x, timestep, context, **kwargs): transformer_options = kwargs.get("transformer_options", {}) + timestep = timestep.to(x.dtype) embeds = kwargs.get("embeds") if embeds is None: raise ValueError("Trellis2.forward requires 'embeds' in kwargs") - # img2shape.resolution is the latent-grid size, not the input pixel size: - # 32 -> 512px path, 64 -> 1024px path. 
- uses_1024_conditioning = self.img2shape.resolution == 64 + + is_1024 = self.img2shape.resolution == 1024 coords = transformer_options.get("coords", None) - coord_counts = transformer_options.get("coord_counts") + coord_counts = transformer_options.get("coord_counts", None) mode = transformer_options.get("generation_mode", "structure_generation") + is_512_run = False - timestep = timestep.to(self.dtype) if mode == "shape_generation_512": is_512_run = True mode = "shape_generation" + if coords is not None: - x = x.squeeze(-1).transpose(1, 2) + if x.ndim == 4: + x = x.squeeze(-1).transpose(1, 2) not_struct_mode = True else: mode = "structure_generation" not_struct_mode = False - if uses_1024_conditioning and not_struct_mode and not is_512_run: + if is_1024 and not_struct_mode and not is_512_run: context = embeds sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: timestep *= 1000.0 + if context.size(0) > 1: cond = context.chunk(2)[1] else: cond = context + shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] - dense_out = None - cond_or_uncond = transformer_options.get("cond_or_uncond") or [] - - def cond_group_indices(batch_groups): - if len(cond_or_uncond) == batch_groups: - cond_groups = [i for i, marker in enumerate(cond_or_uncond) if marker == 0] - if len(cond_groups) > 0: - return cond_groups - return [batch_groups - 1] if not_struct_mode: orig_bsz = x.shape[0] rule = txt_rule if mode == "texture_generation" else shape_rule - logical_batch = coord_counts.shape[0] if coord_counts is not None else 1 - if rule and orig_bsz > logical_batch: - batch_groups = orig_bsz // logical_batch - selected_groups = cond_group_indices(batch_groups) - x_groups = x.reshape(batch_groups, logical_batch, *x.shape[1:]) - x_eval = x_groups[selected_groups].reshape(-1, *x.shape[1:]) - if timestep.shape[0] > 1: - t_groups = 
timestep.reshape(batch_groups, logical_batch, *timestep.shape[1:]) - t_eval = t_groups[selected_groups].reshape(-1, *timestep.shape[1:]) - else: - t_eval = timestep - c_groups = context.reshape(batch_groups, logical_batch, *context.shape[1:]) - c_eval = c_groups[selected_groups].reshape(-1, *context.shape[1:]) + # 1. CFG Bypass Slicing + if rule and orig_bsz > 1: + half = orig_bsz // 2 + x_eval = x[half:] + t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep + c_eval = cond else: x_eval = x t_eval = timestep @@ -846,112 +834,45 @@ class Trellis2(nn.Module): B, N, C = x_eval.shape + # 2. Vectorized SparseTensor Construction (NO FOR LOOPS!) if mode in ["shape_generation", "texture_generation"]: if coord_counts is not None: logical_batch = coord_counts.shape[0] - if B % logical_batch != 0: - raise ValueError( - f"Trellis2 coord_counts batch {logical_batch} doesn't divide latent batch {B}" - ) - if int(coord_counts.sum().item()) != coords.shape[0]: - raise ValueError( - f"Trellis2 coord_counts total {int(coord_counts.sum().item())} does not match coords rows {coords.shape[0]}" - ) - batch_ids = coords[:, 0].to(torch.int64) - order = torch.argsort(batch_ids, stable=True) - sorted_coords = coords.index_select(0, order) - sorted_batch_ids = batch_ids.index_select(0, order) - offsets = coord_counts.cumsum(0) - coord_counts - coords_by_batch = [] - for i in range(logical_batch): - count = int(coord_counts[i].item()) - start = int(offsets[i].item()) - coords_i = sorted_coords[start:start + count] - ids_i = sorted_batch_ids[start:start + count] - if coords_i.shape[0] != count or not torch.all(ids_i == i): - raise ValueError( - f"Trellis2 coords rows for batch {i} expected {count}, got {coords_i.shape[0]}" - ) - coords_by_batch.append(coords_i) - repeat_factor = B // logical_batch - sparse_outs = [] - active_coord_counts = [] - for rep in range(repeat_factor): - for i in range(logical_batch): - out_index = rep * logical_batch + i - count = 
int(coord_counts[i].item()) - if count > N: - raise ValueError( - f"Trellis2 coord count {count} exceeds latent token dimension {N} for batch {i}" - ) - coords_i = coords_by_batch[i].clone() - coords_i[:, 0] = 0 - feats_i = x_eval[out_index, :count].clone() - x_st_i = SparseTensor(feats=feats_i, coords=coords_i.to(torch.int32)) - t_i = t_eval[out_index].unsqueeze(0).clone() if t_eval.shape[0] > 1 else t_eval - c_i = c_eval[out_index].unsqueeze(0).clone() if c_eval.shape[0] > 1 else c_eval + # Duplicate coords if CFG is active + if B > logical_batch: + c_pos = coords.clone() + c_pos[:, 0] += logical_batch + batched_coords = torch.cat([coords, c_pos], dim=0) + counts_eval = torch.cat([coord_counts, coord_counts], dim=0) + else: + batched_coords = coords + counts_eval = coord_counts - if mode == "shape_generation": - if is_512_run: - sparse_out = self.img2shape_512(x_st_i, t_i, c_i) - else: - sparse_out = self.img2shape(x_st_i, t_i, c_i) - else: - slat = transformer_options.get("shape_slat") - if slat is None: - raise ValueError("shape_slat can't be None") - if slat.ndim == 3: - if slat.shape[0] != logical_batch: - raise ValueError( - f"shape_slat batch {slat.shape[0]} doesn't match coord_counts batch {logical_batch}" - ) - if slat.shape[1] < count: - raise ValueError( - f"shape_slat tokens {slat.shape[1]} can't cover coord count {count} for batch {i}" - ) - slat_feats = slat[i, :count].to(x_st_i.device) - else: - slat_feats = slat[:count].to(x_st_i.device) - x_st_i = x_st_i.replace(feats=torch.cat([x_st_i.feats, slat_feats], dim=-1)) - sparse_out = self.shape2txt(x_st_i, t_i, c_i) - - sparse_outs.append(sparse_out.feats) - active_coord_counts.append(count) - - out_channels = sparse_outs[0].shape[-1] - padded = sparse_outs[0].new_zeros((B, N, out_channels)) - for out_index, (count, feats_i) in enumerate(zip(active_coord_counts, sparse_outs)): - padded[out_index, :count] = feats_i - dense_out = padded.transpose(1, 2).unsqueeze(-1) - elif coords.shape[0] == N: + # 
Create boolean mask [B, N] to drop the padded zeros instantly + mask = torch.arange(N, device=x.device).unsqueeze(0) < counts_eval.unsqueeze(1) + feats_flat = x_eval[mask] + else: feats_flat = x_eval.reshape(-1, C) - coords_list = [] + coords_list =[] for i in range(B): c = coords.clone() c[:, 0] = i coords_list.append(c) batched_coords = torch.cat(coords_list, dim=0) - elif coords.shape[0] == B * N: - feats_flat = x_eval.reshape(-1, C) - batched_coords = coords - else: - raise ValueError( - f"Trellis2 expected coords rows {N} or {B * N}, got {coords.shape[0]}" - ) + mask = None else: batched_coords = coords feats_flat = x_eval + mask = None - if dense_out is None: - x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) + x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) - if dense_out is not None: - out = dense_out - elif mode == "shape_generation": + if mode == "shape_generation": if is_512_run: out = self.img2shape_512(x_st, t_eval, c_eval) else: out = self.img2shape(x_st, t_eval, c_eval) + elif mode == "texture_generation": if self.shape2txt is None: raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!") @@ -959,96 +880,43 @@ class Trellis2(nn.Module): if slat is None: raise ValueError("shape_slat can't be None") - if slat.ndim == 3: - if coord_counts is not None: - logical_batch = coord_counts.shape[0] - if slat.shape[0] != logical_batch: - raise ValueError( - f"shape_slat batch {slat.shape[0]} doesn't match coord_counts batch {logical_batch}" - ) - if B % logical_batch != 0: - raise ValueError( - f"Trellis2 coord_counts batch {logical_batch} doesn't divide latent batch {B}" - ) - repeat_factor = B // logical_batch - slat_list = [] - for _ in range(repeat_factor): - for i in range(logical_batch): - count = int(coord_counts[i].item()) - if slat.shape[1] < count: - raise ValueError( - f"shape_slat tokens {slat.shape[1]} can't cover coord count {count} for batch {i}" - ) - 
slat_list.append(slat[i, :count]) - slat_feats_batched = torch.cat(slat_list, dim=0).to(x_st.device) - else: - if slat.shape[0] != B: - raise ValueError(f"shape_slat batch {slat.shape[0]} doesn't match latent batch {B}") - if slat.shape[1] != N: - raise ValueError(f"shape_slat tokens {slat.shape[1]} doesn't match latent tokens {N}") - slat_feats_batched = slat.reshape(B * N, -1).to(x_st.device) - else: - base_slat_feats = slat[:N] - slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device) - x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1)) + slat_feats = slat.feats + # Duplicate shape context if CFG is active + if coord_counts is not None and B > coord_counts.shape[0]: + slat_feats = torch.cat([slat_feats, slat_feats], dim=0) + elif coord_counts is None: + slat_feats = slat.feats[:N].repeat(B, 1) + + x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats], dim=-1)) out = self.shape2txt(x_st, t_eval, c_eval) + else: # structure orig_bsz = x.shape[0] - batch_groups = len(cond_or_uncond) if len(cond_or_uncond) > 0 and orig_bsz % len(cond_or_uncond) == 0 else 1 - logical_batch = orig_bsz // batch_groups - if logical_batch > 1: - x_groups = x.reshape(batch_groups, logical_batch, *x.shape[1:]) - if timestep.shape[0] > 1: - t_groups = timestep.reshape(batch_groups, logical_batch, *timestep.shape[1:]) - else: - t_groups = timestep - c_groups = context.reshape(batch_groups, logical_batch, *context.shape[1:]) - - if shape_rule and batch_groups > 1: - selected_group_indices = cond_group_indices(batch_groups) - else: - selected_group_indices = list(range(batch_groups)) - - out_groups = [] - for sample_index in range(logical_batch): - if shape_rule and batch_groups > 1: - x_i = x_groups[selected_group_indices, sample_index] - if timestep.shape[0] > 1: - t_i = t_groups[selected_group_indices, sample_index] - else: - t_i = timestep - c_i = c_groups[selected_group_indices, sample_index] - else: - x_i = x_groups[selected_group_indices, 
sample_index] - if timestep.shape[0] > 1: - t_i = t_groups[selected_group_indices, sample_index] - else: - t_i = timestep - c_i = c_groups[selected_group_indices, sample_index] - out_groups.append(self.structure_model(x_i, t_i, c_i)) - - out = out_groups[0].new_zeros((orig_bsz, *out_groups[0].shape[1:])) - for sample_index, out_sample in enumerate(out_groups): - if shape_rule and batch_groups > 1: - repeated = out_sample[0] - for group_index in range(batch_groups): - out[group_index * logical_batch + sample_index] = repeated - else: - for local_group_index, group_index in enumerate(selected_group_indices): - out[group_index * logical_batch + sample_index] = out_sample[local_group_index] + if shape_rule and orig_bsz > 1: + half = orig_bsz // 2 + x_eval = x[half:] + t_eval = timestep[half:] if timestep.shape[0] > 1 else timestep + out = self.structure_model(x_eval, t_eval, cond) + out = out.repeat(2, 1, 1, 1, 1) else: - if shape_rule and orig_bsz > 1: - half = orig_bsz // 2 - x = x[half:] - timestep = timestep[half:] if timestep.shape[0] > 1 else timestep - out = self.structure_model(x, timestep, cond if shape_rule and orig_bsz > 1 else context) - if shape_rule and orig_bsz > 1: - out = out.repeat(2, 1, 1, 1, 1) + out = self.structure_model(x, timestep, context) + # ================================================== + # RE-PAD AND FORMAT OUTPUT + # ================================================== if not_struct_mode: - if dense_out is None: - out = out.feats - out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) - if rule and orig_bsz > B: - out = out.repeat(orig_bsz // B, 1, 1, 1) + if mask is not None: + # Instantly scatter the valid tokens back into a padded rectangular tensor + padded_out = torch.zeros((B, N, out.feats.shape[-1]), device=x.device, dtype=out.feats.dtype) + padded_out[mask] = out.feats + out_tensor = padded_out.transpose(1, 2).unsqueeze(-1) + else: + out_tensor = out.feats.view(B, N, -1).transpose(1, 2).unsqueeze(-1) + + if rule and orig_bsz > 
1: + out_tensor = out_tensor.repeat(2, 1, 1, 1) + return out_tensor + #else: + # out = torch.nn.functional.pad(out, (0, 0, 0, 0, 0, 0, 24, 0)) + return out From 94adce93ab8c6affb811c67c0589c8fffe204995 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 8 May 2026 15:13:07 +0300 Subject: [PATCH 90/93] update the simplify function --- comfy_extras/nodes_trellis2.py | 322 ++++++++++++++++++++++++++++----- 1 file changed, 277 insertions(+), 45 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index d6edaedb6..704f6f32f 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -911,12 +911,12 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): output["batch_index"] = sample_indices return IO.NodeOutput(output) -def simplify_fn(vertices, faces, colors=None, target=100000): +def simplify_fn(vertices, faces, colors=None, target=100000, max_edge_length=None): if vertices.ndim == 3: v_list, f_list, c_list = [], [], [] for i in range(vertices.shape[0]): c_in = colors[i] if colors is not None else None - v_i, f_i, c_i = simplify_fn(vertices[i], faces[i], c_in, target) + v_i, f_i, c_i = simplify_fn(vertices[i], faces[i], c_in, target, max_edge_length) v_list.append(v_i) f_list.append(f_i) if c_i is not None: @@ -929,60 +929,292 @@ def simplify_fn(vertices, faces, colors=None, target=100000): return vertices, faces, colors device = vertices.device - target_v = max(target / 4.0, 1.0) + dtype = vertices.dtype - min_v = vertices.min(dim=0)[0] - max_v = vertices.max(dim=0)[0] - extent = max_v - min_v + verts_np = vertices.detach().cpu().numpy().astype(np.float64) + faces_np = faces.detach().cpu().numpy().astype(np.int64) + colors_np = ( + colors.detach().cpu().numpy().astype(np.float64) + if colors is not None + else None + ) - volume = (extent[0] * extent[1] * extent[2]).clamp(min=1e-8) - cell_size = (volume / target_v) ** (1/3.0) + out_v, out_f, out_c = 
_qem_simplify_robust( + verts_np, faces_np, colors_np, target, device, max_edge_length + ) - # Use CPU-side ordered reductions here so repeated runs produce identical - # simplified meshes instead of relying on GPU scatter-add accumulation order. - vertices_np = vertices.detach().cpu().numpy() - faces_np = faces.detach().cpu().numpy() - colors_np = colors.detach().cpu().numpy() if colors is not None else None - min_v_np = min_v.detach().cpu().numpy() - cell_size_value = float(cell_size.detach().cpu()) + final_v = out_v.to(device=device, dtype=dtype) + final_f = out_f.to(device=device, dtype=faces.dtype) + final_c = ( + out_c.to(device=device, dtype=colors.dtype) + if out_c is not None + else None + ) + return final_v, final_f, final_c - quantized = np.rint((vertices_np - min_v_np) / cell_size_value).astype(np.int64) - unique_coords, inverse_indices = np.unique(quantized, axis=0, return_inverse=True) - num_cells = unique_coords.shape[0] +def _qem_simplify_robust(verts_np, faces_np, colors_np, target_faces, device, max_edge_length=None): + verts = torch.from_numpy(verts_np).to(device=device, dtype=torch.float64) + faces = torch.from_numpy(faces_np).to(device=device, dtype=torch.int64) + colors = ( + torch.from_numpy(colors_np).to(device=device, dtype=torch.float64) + if colors_np is not None + else None + ) - new_vertices_np = np.zeros((num_cells, 3), dtype=vertices_np.dtype) - np.add.at(new_vertices_np, inverse_indices, vertices_np) + num_verts = verts.shape[0] + num_faces = faces.shape[0] - counts_np = np.bincount(inverse_indices, minlength=num_cells).astype(vertices_np.dtype).reshape(-1, 1) - new_vertices_np = new_vertices_np / np.clip(counts_np, 1, None) + v_alive = torch.ones(num_verts, dtype=torch.bool, device=device) + f_alive = torch.ones(num_faces, dtype=torch.bool, device=device) - new_colors = None - if colors_np is not None: - new_colors_np = np.zeros((num_cells, colors_np.shape[1]), dtype=colors_np.dtype) - np.add.at(new_colors_np, inverse_indices, 
colors_np) - new_colors = new_colors_np / np.clip(counts_np, 1, None) + Q = _build_quadrics_fast(verts, faces) - new_faces = inverse_indices[faces_np] - valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \ - (new_faces[:, 1] != new_faces[:, 2]) & \ - (new_faces[:, 2] != new_faces[:, 0]) - new_faces = new_faces[valid_mask] + # Mesh scale for relative thresholds + bbox = verts.max(dim=0)[0] - verts.min(dim=0)[0] + mesh_scale = torch.norm(bbox).item() - if new_faces.size == 0: - final_vertices_np = new_vertices_np[:0] - final_faces_np = np.empty((0, 3), dtype=np.int64) - final_colors_np = new_colors[:0] if new_colors is not None else None - else: - unique_face_indices, inv_face = np.unique(new_faces.reshape(-1), return_inverse=True) - final_vertices_np = new_vertices_np[unique_face_indices] - final_faces_np = inv_face.reshape(-1, 3).astype(np.int64) - final_colors_np = new_colors[unique_face_indices] if new_colors is not None else None + # Default max_edge_length: 2x bounding box diagonal (MeshLib-style) + if max_edge_length is None or max_edge_length <= 0: + max_edge_length = mesh_scale * 2.0 - final_vertices = torch.from_numpy(final_vertices_np).to(device=device, dtype=vertices.dtype) - final_faces = torch.from_numpy(final_faces_np).to(device=device, dtype=faces.dtype) - final_colors = torch.from_numpy(final_colors_np).to(device=device, dtype=colors.dtype) if final_colors_np is not None else None + # Stabilizer: regularization to prevent extreme vertex movement + stabilizer = mesh_scale * mesh_scale * 0.001 # MeshLib default ~0.001 * scale^2 - return final_vertices, final_faces, final_colors + iteration = 0 + while True: + n_faces = int(f_alive.sum().item()) + if n_faces <= target_faces: + break + + alive_v = torch.nonzero(v_alive, as_tuple=True)[0] + alive_f = torch.nonzero(f_alive, as_tuple=True)[0] + + if alive_v.numel() <= 4 or alive_f.numel() == 0: + break + + # ---- compact active mesh ------------------------------------------- + vmap = 
torch.full((num_verts,), -1, dtype=torch.int64, device=device) + vmap[alive_v] = torch.arange(alive_v.numel(), device=device) + + active_faces = faces[alive_f] + remapped = vmap[active_faces] + + # ---- extract edges -------------------------------------------------- + e0 = remapped[:, [0, 1]] + e1 = remapped[:, [1, 2]] + e2 = remapped[:, [2, 0]] + edges = torch.cat([e0, e1, e2], dim=0) + edges = torch.sort(edges, dim=1)[0] + edges = edges[(edges >= 0).all(dim=1)] + edges = edges[edges[:, 0] != edges[:, 1]] + + if edges.shape[0] == 0: + break + + edges_orig = alive_v[edges] + + # ---- MeshLib-style: only process edges longer than maxEdgeLen ------ + pa = verts[edges_orig[:, 0]] + pb = verts[edges_orig[:, 1]] + el = torch.norm(pb - pa, dim=-1) + + long_enough = el > max_edge_length * 0.1 # Allow some tolerance + if not long_enough.any(): + # If no long edges, lower threshold + long_enough = el > max_edge_length * 0.01 + + edges_orig = edges_orig[long_enough] + if edges_orig.shape[0] == 0: + break + + # subsample so we never chew on >300 k edges + if edges_orig.shape[0] > 300_000: + step = edges_orig.shape[0] // 300_000 + 1 + edges_orig = edges_orig[::step] + + n_edges = edges_orig.shape[0] + if n_edges == 0: + break + + # chunking the qem + Q0 = Q[edges_orig[:, 0]] + Q1 = Q[edges_orig[:, 1]] + Qe = Q0 + Q1 + + A = Qe[:, :3, :3] + b = -Qe[:, :3, 3] + + optimal = torch.zeros((n_edges, 3), dtype=torch.float64, device=device) + SOLVE_CHUNK = 50_000 + + for i in range(0, n_edges, SOLVE_CHUNK): + sl = slice(i, min(i + SOLVE_CHUNK, n_edges)) + A_c = A[sl] + b_c = b[sl].unsqueeze(-1) + + # Add stabilizer to prevent extreme solutions + A_reg = A_c + torch.eye(3, device=device, dtype=torch.float64).unsqueeze(0) * stabilizer + + dets = torch.det(A_reg) + good = dets.abs() > 1e-12 + + if good.any(): + try: + sol = torch.linalg.solve(A_reg[good], b_c[good]) + good_idx = torch.nonzero(good, as_tuple=True)[0] + i + optimal[good_idx] = sol.squeeze(-1) + except RuntimeError: + good 
= torch.zeros_like(good) + + if (~good).any(): + bad_idx = torch.nonzero(~good, as_tuple=True)[0] + i + va = edges_orig[bad_idx, 0] + vb = edges_orig[bad_idx, 1] + optimal[bad_idx] = (verts[va] + verts[vb]) * 0.5 + + # ---- error = v^T Q v (homogeneous) -------------------------------- + v4 = torch.cat([ + optimal, + torch.ones((n_edges, 1), device=device, dtype=torch.float64) + ], dim=1) + err = torch.abs(torch.einsum("ei,eij,ej->e", v4, Qe, v4)) + + # geometeric guards + pa = verts[edges_orig[:, 0]] + pb = verts[edges_orig[:, 1]] + el = torch.norm(pb - pa, dim=-1) + + # reject near zero edges + length_ok = el > mesh_scale * 1e-5 + + # moderate wander: stabilizer keeps optimal close, so we can be looser + dist_a = torch.norm(optimal - pa, dim=-1) + dist_b = torch.norm(optimal - pb, dim=-1) + wander_ok = (dist_a <= 4.0 * el) & (dist_b <= 4.0 * el) + + nan_ok = ~torch.isnan(optimal).any(dim=-1) + + # MAX ERROR CAP: hard limit on quadric error (MeshLib-style) + # Prevents collapses that would remove too much detail + max_error = max_edge_length * max_edge_length + error_ok = err < max_error + + valid = length_ok & wander_ok & nan_ok & error_ok + if not valid.any(): + break + + valid_idx = torch.nonzero(valid, as_tuple=True)[0] + edges_orig = edges_orig[valid_idx] + optimal = optimal[valid_idx] + err = err[valid_idx] + + # ---- vectorized greedy independent set ------------------------------ + sorted_idx = torch.argsort(err) + used = torch.zeros(num_verts, dtype=torch.bool, device=device) + used[~v_alive] = True + + max_collapses = max(2_000, (n_faces - target_faces) // 5) + selected_edges = [] + n_selected = 0 + GREEDY_CHUNK = 100_000 + + for start in range(0, sorted_idx.numel(), GREEDY_CHUNK): + chunk = sorted_idx[start:start + GREEDY_CHUNK] + va = edges_orig[chunk, 0] + vb = edges_orig[chunk, 1] + + valid_mask = ~used[va] & ~used[vb] + if not valid_mask.any(): + continue + + sel = chunk[valid_mask] + selected_edges.append(sel) + + used[edges_orig[sel, 0]] = True + 
used[edges_orig[sel, 1]] = True + n_selected += sel.numel() + + if n_selected >= max_collapses: + break + + if n_selected == 0: + break + + sel = torch.cat(selected_edges) + + # ---- apply collapses ------------------------------------------------ + v_a = edges_orig[sel, 0] + v_b = edges_orig[sel, 1] + + verts[v_a] = optimal[sel] + v_alive[v_b] = False + Q[v_a] += Q[v_b] + + if colors is not None: + colors[v_a] = (colors[v_a] + colors[v_b]) * 0.5 + + merge_map = torch.arange(num_verts, device=device) + merge_map[v_b] = v_a + faces = merge_map[faces] + + bad = ( + (faces[:, 0] == faces[:, 1]) + | (faces[:, 1] == faces[:, 2]) + | (faces[:, 2] == faces[:, 0]) + ) + f_alive &= ~bad + + iteration += 1 + if iteration % 5 == 0 and int(f_alive.sum().item()) < num_faces * 0.5: + faces = faces[f_alive] + f_alive = torch.ones(faces.shape[0], dtype=torch.bool, device=device) + num_faces = faces.shape[0] + + final_v = verts[v_alive] + final_c = colors[v_alive] if colors is not None else None + + remap = torch.full((num_verts,), -1, dtype=torch.int64, device=device) + remap[v_alive] = torch.arange(int(v_alive.sum().item()), device=device) + final_f = remap[faces[f_alive]] + + if final_f.numel() > 0: + final_f = torch.unique(torch.sort(final_f, dim=1)[0], dim=0) + + return final_v, final_f, final_c + + +def _build_quadrics_fast(verts, faces): + """GPU quadric build. 
Fast; non-deterministic on CUDA.""" + v0 = verts[faces[:, 0]] + v1 = verts[faces[:, 1]] + v2 = verts[faces[:, 2]] + + e1 = v1 - v0 + e2 = v2 - v0 + n = torch.cross(e1, e2, dim=-1) + area = torch.norm(n, dim=-1) + + mask = area > 1e-12 + n_norm = torch.zeros_like(n) + n_norm[mask] = n[mask] / area[mask].unsqueeze(-1) + + d = -(n_norm * v0).sum(dim=-1, keepdim=True) + p = torch.cat([n_norm, d], dim=-1) + + K = torch.einsum("fi,fj->fij", p, p) + K = K * area[:, None, None] + + V = verts.shape[0] + Q = torch.zeros((V, 4, 4), dtype=torch.float64, device=verts.device) + + K_flat = K.reshape(-1, 16) + Q_flat = Q.reshape(V, 16) + + for corner in range(3): + idx = faces[:, corner].unsqueeze(1).expand(-1, 16) + Q_flat.scatter_add_(0, idx, K_flat) + + return Q_flat.reshape(V, 4, 4) def fill_holes_fn(vertices, faces, max_perimeter=0.03): is_batched = vertices.ndim == 3 From 9d0f678f6f51ae707a34dc5c3fddf8dd1c7d74af Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 8 May 2026 19:03:06 +0300 Subject: [PATCH 91/93] removing seeds from node display --- comfy/ldm/trellis2/model.py | 22 ++-- comfy_extras/nodes_trellis2.py | 211 +++++---------------------------- 2 files changed, 41 insertions(+), 192 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index a54e4ca94..14810d56d 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -802,6 +802,11 @@ class Trellis2(nn.Module): mode = "structure_generation" not_struct_mode = False + if not not_struct_mode: + bsz = x.size(0) + x = x[:, :8] + x = x.view(bsz, 8, 16, 16, 16) + if is_1024 and not_struct_mode and not is_512_run: context = embeds @@ -821,7 +826,7 @@ class Trellis2(nn.Module): orig_bsz = x.shape[0] rule = txt_rule if mode == "texture_generation" else shape_rule - # 1. 
CFG Bypass Slicing + # CFG Bypass Slicing if rule and orig_bsz > 1: half = orig_bsz // 2 x_eval = x[half:] @@ -834,7 +839,7 @@ class Trellis2(nn.Module): B, N, C = x_eval.shape - # 2. Vectorized SparseTensor Construction (NO FOR LOOPS!) + # Vectorized SparseTensor Construction if mode in ["shape_generation", "texture_generation"]: if coord_counts is not None: logical_batch = coord_counts.shape[0] @@ -880,14 +885,14 @@ class Trellis2(nn.Module): if slat is None: raise ValueError("shape_slat can't be None") - slat_feats = slat.feats + slat_feats = slat # Duplicate shape context if CFG is active if coord_counts is not None and B > coord_counts.shape[0]: slat_feats = torch.cat([slat_feats, slat_feats], dim=0) elif coord_counts is None: - slat_feats = slat.feats[:N].repeat(B, 1) + slat_feats = slat_feats[:N].repeat(B, 1) - x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats], dim=-1)) + x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats.to(x_st.feats.device)], dim=-1)) out = self.shape2txt(x_st, t_eval, c_eval) else: # structure @@ -901,9 +906,6 @@ class Trellis2(nn.Module): else: out = self.structure_model(x, timestep, context) - # ================================================== - # RE-PAD AND FORMAT OUTPUT - # ================================================== if not_struct_mode: if mask is not None: # Instantly scatter the valid tokens back into a padded rectangular tensor @@ -916,7 +918,7 @@ class Trellis2(nn.Module): if rule and orig_bsz > 1: out_tensor = out_tensor.repeat(2, 1, 1, 1) return out_tensor - #else: - # out = torch.nn.functional.pad(out, (0, 0, 0, 0, 0, 0, 24, 0)) + else: + out = torch.nn.functional.pad(out, (0, 0, 0, 0, 0, 0, 0, 24)) return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 704f6f32f..e65fd9787 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -159,37 +159,6 @@ def split_batched_coords(coords, coord_counts): items.append(coords_i) return items - 
-def normalize_batch_index(batch_index): - if batch_index is None: - return None - if isinstance(batch_index, int): - return [int(batch_index)] - return list(batch_index) - - -def resolve_sample_indices(batch_index, batch_size): - sample_indices = normalize_batch_index(batch_index) - if sample_indices is None: - return list(range(batch_size)) - if len(sample_indices) != batch_size: - raise ValueError( - f"Trellis2 batch_index length {len(sample_indices)} does not match batch size {batch_size}" - ) - return sample_indices - - -def resolve_singleton_sample_index(batch_index): - sample_indices = normalize_batch_index(batch_index) - if sample_indices is None: - return 0 - if len(sample_indices) != 1: - raise ValueError( - f"Trellis2 batch_index must be an int or single-element iterable for singleton coords, got {sample_indices}" - ) - return int(sample_indices[0]) - - def flatten_batched_sparse_latent(samples, coords, coord_counts): samples = samples.squeeze(-1).transpose(1, 2) if coord_counts is None: @@ -218,7 +187,6 @@ def split_batched_sparse_latent(samples, coords, coord_counts): items.append((samples[i, :count], coords_i)) return items - def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): """ Generic function to paint a mesh using nearest-neighbor colors from a sparse voxel field. 
@@ -232,15 +200,15 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): # map voxels voxel_pos = voxel_coords.to(device).float() * voxel_size + origin verts = mesh.vertices.to(device).squeeze(0) - voxel_colors = voxel_colors.cpu() + voxel_colors = voxel_colors.to(device) - voxel_pos_np = voxel_pos.cpu().numpy() - verts_np = verts.cpu().numpy() + voxel_pos_np = voxel_pos.numpy() + verts_np = verts.numpy() tree = scipy.spatial.cKDTree(voxel_pos_np) # nearest neighbour k=1 - _, nearest_idx_np = tree.query(verts_np, k=1, workers=1) + _, nearest_idx_np = tree.query(verts_np, k=1, workers=-1) nearest_idx = torch.from_numpy(nearest_idx_np).long() v_colors = voxel_colors[nearest_idx] @@ -253,7 +221,7 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): final_colors = linear_colors.unsqueeze(0) - out_mesh = copy.copy(mesh) + out_mesh = copy.deepcopy(mesh) out_mesh.colors = final_colors return out_mesh @@ -411,10 +379,10 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): def execute(cls, samples, vae, resolution): resolution = int(resolution) sample_tensor = samples["samples"] + sample_tensor = sample_tensor[:, :8] batch_number = prepare_trellis_vae_for_decode(vae, sample_tensor.shape) decoder = vae.first_stage_model.struct_dec load_device = comfy.model_management.get_torch_device() - batch_index = normalize_batch_index(samples.get("batch_index")) decoded_batches = [] for start in range(0, sample_tensor.shape[0], batch_number): sample_chunk = sample_tensor[start:start + batch_number].to(load_device) @@ -426,8 +394,6 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): ratio = current_res // resolution decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5 out = Types.VOXEL(decoded.squeeze(1).float()) - if batch_index is not None: - out.batch_index = normalize_batch_index(batch_index) return IO.NodeOutput(out) class Trellis2UpsampleCascade(IO.ComfyNode): @@ -453,7 +419,6 @@ class 
Trellis2UpsampleCascade(IO.ComfyNode): prepare_trellis_vae_for_decode(vae, shape_latent_512["samples"].shape) coord_counts = shape_latent_512.get("coord_counts") - batch_index = normalize_batch_index(shape_latent_512.get("batch_index")) decoder = vae.first_stage_model.shape_dec lr_resolution = 512 target_resolution = int(target_resolution) @@ -529,14 +494,11 @@ class Trellis2UpsampleCascade(IO.ComfyNode): final_coords_list.append(final_coords_i) output_coord_counts.append(int(final_coords_i.shape[0])) - normalized_batch_index = normalize_batch_index(batch_index) output = { "coords": torch.cat(final_coords_list, dim=0), "coord_counts": torch.tensor(output_coord_counts, dtype=torch.int64), "resolutions": torch.full((len(final_coords_list),), int(hr_resolution), dtype=torch.int64), } - if normalized_batch_index is not None: - output["batch_index"] = normalized_batch_index return IO.NodeOutput(output,) @@ -547,8 +509,6 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): model_internal = model.model device = comfy.model_management.intermediate_device() torch_device = comfy.model_management.get_torch_device() - had_image_size = hasattr(model_internal, "image_size") - original_image_size = getattr(model_internal, "image_size", None) def prepare_tensor(pil_img, size): resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS) @@ -556,21 +516,15 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device) return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device) - cond_1024 = None - try: - model_internal.image_size = 512 - input_512 = prepare_tensor(cropped_img_tensor, 512) - cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0] + model_internal.image_size = 512 + input_512 = prepare_tensor(cropped_img_tensor, 512) + cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0] - if include_1024: - model_internal.image_size 
= 1024 - input_1024 = prepare_tensor(cropped_img_tensor, 1024) - cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0] - finally: - if not had_image_size: - delattr(model_internal, "image_size") - else: - model_internal.image_size = original_image_size + cond_1024 = None + if include_1024: + model_internal.image_size = 1024 + input_1024 = prepare_tensor(cropped_img_tensor, 1024) + cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0] conditioning = { 'cond_512': cond_512.to(device), @@ -580,7 +534,6 @@ def run_conditioning(model, cropped_img_tensor, include_1024=True): conditioning['cond_1024'] = cond_1024.to(device) return conditioning - class Trellis2Conditioning(IO.ComfyNode): @classmethod def define_schema(cls): @@ -693,7 +646,6 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): inputs=[ IO.AnyType.Input("structure_or_coords"), IO.Model.Input("model"), - IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), ], outputs=[ IO.Latent.Output(), @@ -702,58 +654,25 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_or_coords, model, seed): + def execute(cls, structure_or_coords, model): # to accept the upscaled coords is_512_pass = False - coord_counts = None - coord_resolutions = None - batch_index = None if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: decoded = structure_or_coords.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() is_512_pass = True - batch_index = normalize_batch_index(getattr(structure_or_coords, "batch_index", None)) - - elif isinstance(structure_or_coords, dict): - coords = structure_or_coords["coords"].int() - coord_counts = structure_or_coords.get("coord_counts") - coord_resolutions = structure_or_coords.get("resolutions") - batch_index = normalize_batch_index(structure_or_coords.get("batch_index")) - is_512_pass = False elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 
2: coords = structure_or_coords.int() is_512_pass = False - else: raise ValueError(f"Invalid input to EmptyShapeLatent: {type(structure_or_coords)}") + + batch_size, counts, max_tokens = infer_batched_coord_layout(coords) in_channels = 32 - batch_size, inferred_coord_counts, max_tokens = infer_batched_coord_layout(coords) - if coord_counts is not None: - coord_counts = coord_counts.to(dtype=torch.int64, device=coords.device) - if coord_counts.shape != inferred_coord_counts.shape or not torch.equal(coord_counts, inferred_coord_counts): - raise ValueError( - f"Trellis2 coord_counts metadata {coord_counts.tolist()} does not match coords layout {inferred_coord_counts.tolist()}" - ) - else: - coord_counts = inferred_coord_counts - if batch_size == 1: - sample_index = resolve_singleton_sample_index(batch_index) - generator = torch.Generator(device="cpu") - generator.manual_seed(int(seed) + sample_index) - latent = torch.randn(1, in_channels, coords.shape[0], 1, generator=generator) - else: - sample_indices = resolve_sample_indices(batch_index, batch_size) - latent = torch.zeros(batch_size, in_channels, max_tokens, 1) - for i, sample_index in enumerate(sample_indices): - count = int(coord_counts[i].item()) - generator = torch.Generator(device="cpu") - generator.manual_seed(int(seed) + int(sample_index)) - latent_i = torch.randn(1, in_channels, count, 1, generator=generator) - latent[i, :, :count] = latent_i[0] - if coord_counts is not None: - latent.trellis_coord_counts = coord_counts.clone() + # image like format + latent = torch.zeros(batch_size, in_channels, max_tokens, 1) model = model.clone() model.model_options = model.model_options.copy() if "transformer_options" in model.model_options: @@ -762,20 +681,11 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"] = {} model.model_options["transformer_options"]["coords"] = coords - if coord_counts is not None: - model.model_options["transformer_options"]["coord_counts"] = 
coord_counts if is_512_pass: model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512" else: model.model_options["transformer_options"]["generation_mode"] = "shape_generation" - output = {"samples": latent, "coords": coords, "type": "trellis2"} - if batch_index is not None: - output["batch_index"] = normalize_batch_index(batch_index) - if coord_counts is not None: - output["coord_counts"] = coord_counts - if coord_resolutions is not None: - output["resolutions"] = coord_resolutions - return IO.NodeOutput(output, model) + return IO.NodeOutput({"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"}, model) class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod @@ -787,7 +697,6 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): IO.Voxel.Input("structure_or_coords"), IO.Latent.Input("shape_latent"), IO.Model.Input("model"), - IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), ], outputs=[ IO.Latent.Output(), @@ -796,68 +705,22 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_or_coords, shape_latent, model, seed): + def execute(cls, structure_or_coords, shape_latent, model): channels = 32 - coord_counts = None - batch_index = None if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: decoded = structure_or_coords.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() - batch_index = normalize_batch_index(getattr(structure_or_coords, "batch_index", None)) - - elif isinstance(structure_or_coords, dict): - coords = structure_or_coords["coords"].int() - coord_counts = structure_or_coords.get("coord_counts") - batch_index = normalize_batch_index(structure_or_coords.get("batch_index")) elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: coords = structure_or_coords.int() - else: - raise ValueError( - "structure_or_coords must be a voxel input with data.ndim == 4, " - f'a 
dict containing "coords", or a 2D torch.Tensor; got {type(structure_or_coords).__name__}' - ) - shape_batch_index = normalize_batch_index(shape_latent.get("batch_index")) - if batch_index is None: - batch_index = shape_batch_index + batch_size, counts, max_tokens = infer_batched_coord_layout(coords) + shape_latent = shape_latent["samples"] - batch_size, inferred_coord_counts, max_tokens = infer_batched_coord_layout(coords) - if coord_counts is not None: - coord_counts = coord_counts.to(dtype=torch.int64, device=coords.device) - if coord_counts.shape != inferred_coord_counts.shape or not torch.equal(coord_counts, inferred_coord_counts): - raise ValueError( - f"Trellis2 coord_counts metadata {coord_counts.tolist()} does not match coords layout {inferred_coord_counts.tolist()}" - ) - else: - coord_counts = inferred_coord_counts if shape_latent.ndim == 4: - if shape_latent.shape[0] != batch_size: - raise ValueError( - f"shape_latent batch {shape_latent.shape[0]} doesn't match coords batch {batch_size}" - ) - shape_latent = shape_latent.squeeze(-1).transpose(1, 2) - if shape_latent.shape[1] < max_tokens: - raise ValueError( - f"shape_latent tokens {shape_latent.shape[1]} can't cover coords max tokens {max_tokens}" - ) + shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels) - if batch_size == 1: - sample_index = resolve_singleton_sample_index(batch_index) - generator = torch.Generator(device="cpu") - generator.manual_seed(int(seed) + sample_index) - latent = torch.randn(1, channels, coords.shape[0], 1, generator=generator) - else: - sample_indices = resolve_sample_indices(batch_index, batch_size) - latent = torch.zeros(batch_size, channels, max_tokens, 1) - for i, sample_index in enumerate(sample_indices): - count = int(coord_counts[i].item()) - generator = torch.Generator(device="cpu") - generator.manual_seed(int(seed) + int(sample_index)) - latent_i = torch.randn(1, channels, count, 1, generator=generator) - latent[i, :, :count] = latent_i[0] - 
if coord_counts is not None: - latent.trellis_coord_counts = coord_counts.clone() + latent = torch.zeros(batch_size, channels, max_tokens, 1) model = model.clone() model.model_options = model.model_options.copy() if "transformer_options" in model.model_options: @@ -866,16 +729,9 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"] = {} model.model_options["transformer_options"]["coords"] = coords - if coord_counts is not None: - model.model_options["transformer_options"]["coord_counts"] = coord_counts model.model_options["transformer_options"]["generation_mode"] = "texture_generation" model.model_options["transformer_options"]["shape_slat"] = shape_latent - output = {"samples": latent, "coords": coords, "type": "trellis2"} - if batch_index is not None: - output["batch_index"] = normalize_batch_index(batch_index) - if coord_counts is not None: - output["coord_counts"] = coord_counts - return IO.NodeOutput(output, model) + return IO.NodeOutput({"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"}, model) class EmptyStructureLatentTrellis2(IO.ComfyNode): @@ -886,29 +742,20 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): category="latent/3d", inputs=[ IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), - IO.Int.Input("batch_index_start", default=0, min=0, max=4096, tooltip="Starting sample index for per-sample sampler noise."), - IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), ], outputs=[ IO.Latent.Output(), ] ) @classmethod - def execute(cls, batch_size, batch_index_start, seed): + def execute(cls, batch_size): in_channels = 8 resolution = 16 - sample_indices = [int(batch_index_start) + i for i in range(batch_size)] latent = torch.zeros(batch_size, in_channels, resolution, resolution, resolution) - for i, sample_index in enumerate(sample_indices): - generator = torch.Generator(device="cpu") - 
generator.manual_seed(int(seed) + sample_index) - latent[i] = torch.randn(1, in_channels, resolution, resolution, resolution, generator=generator)[0] output = { "samples": latent, "type": "trellis2", } - if batch_size > 1 or batch_index_start != 0: - output["batch_index"] = sample_indices return IO.NodeOutput(output) def simplify_fn(vertices, faces, colors=None, target=100000, max_edge_length=None): @@ -939,7 +786,7 @@ def simplify_fn(vertices, faces, colors=None, target=100000, max_edge_length=Non else None ) - out_v, out_f, out_c = _qem_simplify_robust( + out_v, out_f, out_c = _qem_simplify( verts_np, faces_np, colors_np, target, device, max_edge_length ) @@ -952,7 +799,7 @@ def simplify_fn(vertices, faces, colors=None, target=100000, max_edge_length=Non ) return final_v, final_f, final_c -def _qem_simplify_robust(verts_np, faces_np, colors_np, target_faces, device, max_edge_length=None): +def _qem_simplify(verts_np, faces_np, colors_np, target_faces, device, max_edge_length=None): verts = torch.from_numpy(verts_np).to(device=device, dtype=torch.float64) faces = torch.from_numpy(faces_np).to(device=device, dtype=torch.int64) colors = ( From 487a67129b57b52e24db33eefe0e7f888199fd72 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 8 May 2026 19:17:28 +0300 Subject: [PATCH 92/93] revert --- comfy_extras/nodes_hunyuan3d.py | 63 +++------------------------------ 1 file changed, 4 insertions(+), 59 deletions(-) diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 09c213cf4..29bdab1dc 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -444,9 +444,7 @@ class VoxelToMeshBasic(IO.ComfyNode): vertices.append(v) faces.append(f) - if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces): - return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) - return 
IO.NodeOutput(pack_variable_mesh_batch(vertices, faces)) + return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) decode = execute # TODO: remove @@ -483,9 +481,7 @@ class VoxelToMesh(IO.ComfyNode): vertices.append(v) faces.append(f) - if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces): - return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) - return IO.NodeOutput(pack_variable_mesh_batch(vertices, faces)) + return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) decode = execute # TODO: remove @@ -633,57 +629,6 @@ def save_glb(vertices, faces, filepath, metadata=None, colors=None): return filepath - -def pack_variable_mesh_batch(vertices, faces, colors=None): - batch_size = len(vertices) - max_vertices = max(v.shape[0] for v in vertices) - max_faces = max(f.shape[0] for f in faces) - - packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) - packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) - vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) - face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) - - for i, (v, f) in enumerate(zip(vertices, faces)): - packed_vertices[i, :v.shape[0]] = v - packed_faces[i, :f.shape[0]] = f - - mesh = Types.MESH(packed_vertices, packed_faces) - mesh.vertex_counts = vertex_counts - mesh.face_counts = face_counts - - if colors is not None: - max_colors = max(c.shape[0] for c in colors) - packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) - color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) - for i, c in enumerate(colors): - packed_colors[i, :c.shape[0]] = c - mesh.colors = packed_colors - mesh.color_counts = color_counts - - return mesh - - -def 
get_mesh_batch_item(mesh, index): - if hasattr(mesh, "vertex_counts"): - vertex_count = int(mesh.vertex_counts[index].item()) - face_count = int(mesh.face_counts[index].item()) - vertices = mesh.vertices[index, :vertex_count] - faces = mesh.faces[index, :face_count] - colors = None - if hasattr(mesh, "colors") and mesh.colors is not None: - if hasattr(mesh, "color_counts"): - color_count = int(mesh.color_counts[index].item()) - colors = mesh.colors[index, :color_count] - else: - colors = mesh.colors[index, :vertex_count] - return vertices, faces, colors - - colors = None - if hasattr(mesh, "colors") and mesh.colors is not None: - colors = mesh.colors[index] - return mesh.vertices[index], mesh.faces[index], colors - class SaveGLB(IO.ComfyNode): @classmethod def define_schema(cls): @@ -741,8 +686,8 @@ class SaveGLB(IO.ComfyNode): bsz = mesh.vertices.shape[0] for i in range(bsz): f = f"{filename}_{counter:05}_.glb" - vertices, faces, v_colors = get_mesh_batch_item(mesh, i) - save_glb(vertices, faces, os.path.join(full_output_folder, f), metadata, v_colors) + v_colors = mesh.colors[i] if hasattr(mesh, "colors") and mesh.colors is not None else None + save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata, v_colors) results.append({ "filename": f, "subfolder": subfolder, From 96d0cfe0d7c061f38c5e070e49878fa69c7ede66 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 8 May 2026 20:02:09 +0300 Subject: [PATCH 93/93] . 
--- comfy/sample.py | 44 -------------------------------------------- 1 file changed, 44 deletions(-) diff --git a/comfy/sample.py b/comfy/sample.py index 878c4e984..653829582 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -7,50 +7,6 @@ import logging import comfy.nested_tensor def prepare_noise_inner(latent_image, generator, noise_inds=None): - coord_counts = getattr(latent_image, "trellis_coord_counts", None) - if coord_counts is not None: - if coord_counts.ndim != 1: - raise ValueError(f"Trellis2 coord_counts must be 1D, got shape {tuple(coord_counts.shape)}") - if coord_counts.shape[0] != latent_image.size(0): - raise ValueError( - f"Trellis2 coord_counts length {coord_counts.shape[0]} does not match latent batch {latent_image.size(0)}" - ) - if (coord_counts < 0).any() or (coord_counts > latent_image.size(2)).any(): - raise ValueError( - f"Trellis2 coord_counts must be within [0, {latent_image.size(2)}], got {coord_counts.tolist()}" - ) - noise = torch.zeros(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, device="cpu") - if noise_inds is None: - noise_inds = np.arange(latent_image.size(0), dtype=np.int64) - else: - noise_inds = np.asarray(noise_inds, dtype=np.int64) - if noise_inds.shape[0] != latent_image.size(0): - raise ValueError( - f"Trellis2 noise_inds length {noise_inds.shape[0]} does not match latent batch {latent_image.size(0)}" - ) - - base_seed = int(generator.initial_seed()) - unique_inds = np.unique(noise_inds) - sample_noises = {} - for noise_index in unique_inds.tolist(): - rows = np.flatnonzero(noise_inds == noise_index) - max_count = max(int(coord_counts[row].item()) for row in rows.tolist()) - local_generator = torch.Generator(device="cpu") - local_generator.manual_seed(base_seed + int(noise_index)) - sample_noises[int(noise_index)] = torch.randn( - [1, latent_image.size(1), max_count, latent_image.size(3)], - dtype=torch.float32, - layout=latent_image.layout, - generator=local_generator, - device="cpu", - ) - - 
for batch_index, noise_index in enumerate(noise_inds.tolist()): - count = int(coord_counts[batch_index].item()) - sample_noise = sample_noises[int(noise_index)] - noise[batch_index:batch_index + 1, :, :count, :] = sample_noise[:, :, :count, :] - return noise.to(dtype=latent_image.dtype) - if noise_inds is None: return torch.randn(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)