From 2b2a1a3cd07e5af7b9f70468b21d71a4f8b90d49 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 20 May 2026 17:15:33 +0300 Subject: [PATCH] remove triton, custom datatype, split mesh postpro --- comfy/ldm/trellis2/flexgemm.py | 437 +++++++------------------ comfy_extras/nodes_mesh_postprocess.py | 178 +++++----- comfy_extras/nodes_trellis2.py | 42 ++- 3 files changed, 248 insertions(+), 409 deletions(-) diff --git a/comfy/ldm/trellis2/flexgemm.py b/comfy/ldm/trellis2/flexgemm.py index 047e785ff..eb08d2970 100644 --- a/comfy/ldm/trellis2/flexgemm.py +++ b/comfy/ldm/trellis2/flexgemm.py @@ -1,136 +1,43 @@ # will contain every cuda -> pytorch operation -import math +from typing import Optional, Tuple import torch -from typing import Callable -import logging -NO_TRITON = False -try: - allow_tf32 = torch.cuda.is_tf32_supported() -except Exception: - allow_tf32 = False -try: - import triton - import triton.language as tl - heuristics = { - 'valid_kernel': lambda args: args['valid_kernel'](args['B1']), - 'valid_kernel_seg': lambda args: args['valid_kernel_seg'](args['B1']), - } +UINT32_SENTINEL = 0xFFFFFFFF - #@triton_autotune( - # configs=config.autotune_config, - # key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'], - #) - @triton.heuristics(heuristics) - @triton.jit - def sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel( - input, - weight, - bias, - neighbor, - sorted_idx, - output, - # Tensor dimensions - N, LOGN, Ci, Co, V: tl.constexpr, - # Meta-parameters - B1: tl.constexpr, # Block size for N dimension - B2: tl.constexpr, # Block size for Co dimension - BK: tl.constexpr, # Block size for K dimension (V * Ci) - allow_tf32: tl.constexpr, # Allow TF32 precision for matmuls - # Huristic parameters - valid_kernel, - valid_kernel_seg, - ): - - block_id = tl.program_id(axis=0) - block_dim_co = tl.cdiv(Co, B2) - block_id_co = block_id % block_dim_co - block_id_n = block_id // block_dim_co - - # Create pointers for submatrices of A and B. - num_k = tl.cdiv(Ci, BK) # Number of blocks in K dimension - valid_kernel_start = tl.load(valid_kernel_seg + block_id_n) - valid_kernel_seglen = tl.load(valid_kernel_seg + block_id_n + 1) - valid_kernel_start - offset_n = block_id_n * B1 + tl.arange(0, B1) - n_mask = offset_n < N - offset_sorted_n = tl.load(sorted_idx + offset_n, mask=n_mask, other=0) # (B1,) - offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co # (B2,) - offset_k = tl.arange(0, BK) # (BK,) - - # Create a block of the output matrix C. - accumulator = tl.zeros((B1, B2), dtype=tl.float32) - - # Iterate along V*Ci dimension. - for k in range(num_k * valid_kernel_seglen): - v = k // num_k - bk = k % num_k - v = tl.load(valid_kernel + valid_kernel_start + v) - # Calculate pointers to input matrix. - neighbor_offset_n = tl.load(neighbor + offset_sorted_n * V + v) # (B1,) - input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :]) # (B1, BK) - # Calculate pointers to weight matrix. - weight_ptr = weight + v * Ci + bk * BK + (offset_co[None, :] * V * Ci + offset_k[:, None]) # (BK, B2) - # Load the next block of input and weight. - neigh_mask = neighbor_offset_n != 0xffffffff - k_mask = offset_k < Ci - bk * BK - input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0) - weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0) - # Accumulate along the K dimension. - accumulator = tl.dot(input_block, weight_block, accumulator, - input_precision='tf32' if allow_tf32 else 'ieee') # (B1, B2) - c = accumulator.to(input.type.element_ty) - - # add bias - if bias is not None: - bias_block = tl.load(bias + offset_co) - c += bias_block[None, :] - - # Write back the block of the output matrix with masks. - out_offset_n = offset_sorted_n - out_offset_co = block_id_co * B2 + tl.arange(0, B2) - out_ptr = output + (out_offset_n[:, None] * Co + out_offset_co[None, :]) - out_mask = n_mask[:, None] & (out_offset_co[None, :] < Co) - tl.store(out_ptr, c, mask=out_mask) - def sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk( - input: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - neighbor: torch.Tensor, - sorted_idx: torch.Tensor, - valid_kernel: Callable[[int], torch.Tensor], - valid_kernel_seg: Callable[[int], torch.Tensor], - ) -> torch.Tensor: - N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1] - LOGN = int(math.log2(N)) - output = torch.empty((N, Co), device=input.device, dtype=input.dtype) - grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),) - sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid]( - input, weight, bias, neighbor, sorted_idx, output, - N, LOGN, Ci, Co, V, - B1=128, - B2=64, - BK=32, - valid_kernel=valid_kernel, - valid_kernel_seg=valid_kernel_seg, - allow_tf32=allow_tf32, - ) - return output -except Exception: - NO_TRITON = True def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device): - # offsets in same order as CUDA kernel + """Kernel spatial offsets in the same order as the CUDA/Triton kernels.""" offsets = [] for vx in range(Kw): for vy in range(Kh): for vz in range(Kd): - offsets.append(( - vx * Dw, - vy * Dh, - vz * Dd - )) - return torch.tensor(offsets, device=device) + offsets.append((vx * Dw, vy * Dh, vz * Dd)) + return torch.tensor(offsets, device=device, dtype=torch.int32) + + +class TorchHashMap: + """Sorted-array hashmap backed by torch.searchsorted.""" + + def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int): + device = keys.device + self.sorted_keys, order = torch.sort(keys.to(torch.long)) + self.sorted_vals = values.to(torch.long)[order] + self.default_value = torch.tensor(default_value, dtype=torch.long, device=device) + self._n = self.sorted_keys.numel() + + def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor: + flat = flat_keys.to(torch.long) + if self._n == 0: + return torch.full((flat.shape[0],), -1, device=flat.device, dtype=torch.int32) + idx = torch.searchsorted(self.sorted_keys, flat) + idx_safe = torch.clamp(idx, max=self._n - 1) + found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat) + out = torch.full((flat.shape[0],), -1, device=flat.device, dtype=torch.int32) + if found.any(): + out[found] = self.sorted_vals[idx_safe[found]].to(torch.int32) + return out + def build_submanifold_neighbor_map( hashmap, @@ -143,10 +50,10 @@ def build_submanifold_neighbor_map( M = coords.shape[0] V = Kw * Kh * Kd half_V = V // 2 + 1 + INVALID = -1 - INVALID = hashmap.default_value - - neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.long) + # int32 neighbour map: 4 bytes/elem vs 8 bytes for int64 + neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.int32) b = coords[:, 0].long() x = coords[:, 1].long() @@ -161,7 +68,8 @@ def build_submanifold_neighbor_map( for v in range(half_V): if v == half_V - 1: - neighbor[:, v] = torch.arange(M, device=device) + # Center voxel always maps to itself + neighbor[:, v] = torch.arange(M, device=device, dtype=torch.int32) continue dx, dy, dz = offsets[v] @@ -170,7 +78,6 @@ def build_submanifold_neighbor_map( ky = oy + dy kz = oz + dz - # Check spatial bounds valid = ( (kx >= 0) & (kx < W) & (ky >= 0) & (ky < H) & @@ -187,192 +94,59 @@ def build_submanifold_neighbor_map( if flat.numel() > 0: found = hashmap.lookup_flat(flat) idx_in_M = torch.where(valid)[0] - neighbor[idx_in_M, v] = found + neighbor[idx_in_M, v] = found.to(torch.int32) - valid_found_mask = (found != INVALID) + # BUG FIX: old code used found != hashmap.default_value which + # compared int32 -1 against int64 4294967295 → always True. + # We now explicitly check for valid indices. + valid_found_mask = found >= 0 if valid_found_mask.any(): src_points = idx_in_M[valid_found_mask] - dst_points = found[valid_found_mask] - neighbor[dst_points, V - 1 - v] = src_points + dst_points = found[valid_found_mask].long() + neighbor[dst_points, V - 1 - v] = src_points.to(torch.int32) return neighbor -class TorchHashMap: - def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int): - device = keys.device - # use long for searchsorted - self.sorted_keys, order = torch.sort(keys.to(torch.long)) - self.sorted_vals = values.to(torch.long)[order] - self.default_value = torch.tensor(default_value, dtype=torch.long, device=device) - self._n = self.sorted_keys.numel() - def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor: - flat = flat_keys.to(torch.long) - if self._n == 0: - return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype) - idx = torch.searchsorted(self.sorted_keys, flat) - idx_safe = torch.clamp(idx, max=self._n - 1) - found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat) - out = torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype) - if found.any(): - out[found] = self.sorted_vals[idx_safe[found]] - return out +def sparse_submanifold_conv3d( + feats: torch.Tensor, + coords: torch.Tensor, + shape: tuple, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + neighbor_cache: Optional[torch.Tensor], + dilation: tuple, + max_chunk_mem_gb: float = 6.0, + accumulate_f32: bool = True, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - -UINT32_SENTINEL = 0xFFFFFFFF - -def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map): - device = neighbor_map.device - N, V = neighbor_map.shape - - sentinel = UINT32_SENTINEL - - neigh_map_T = neighbor_map.t().reshape(-1) - neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32) - - mask = (neighbor_map != sentinel).to(torch.long) - gray_code = torch.zeros(N, dtype=torch.long, device=device) - - for v in range(V): - gray_code |= (mask[:, v] << v) - - binary_code = gray_code.clone() - for v in range(1, V): - binary_code ^= (gray_code >> v) - - sorted_idx = torch.argsort(binary_code) - - prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T, dim=0) - - total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0 - - if total_valid_signal > 0: - pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0] - to = (prefix_sum_neighbor_mask[pos] - 1).long() - - valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device) - valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device) - - valid_signal_i[to] = (pos % N).to(torch.long) - valid_signal_o[to] = neigh_map_T[pos].to(torch.long) - else: - valid_signal_i = torch.empty((0,), dtype=torch.long, device=device) - valid_signal_o = torch.empty((0,), dtype=torch.long, device=device) - - seg = torch.empty((V + 1,), dtype=torch.long, device=device) - seg[0] = 0 - if V > 0: - idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1 - seg[1:] = prefix_sum_neighbor_mask[idxs] - - return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg - -def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor: - - x = x.to(torch.int64) - - m1 = torch.tensor(0x5555555555555555, dtype=torch.int64, device=x.device) - m2 = torch.tensor(0x3333333333333333, dtype=torch.int64, device=x.device) - m4 = torch.tensor(0x0F0F0F0F0F0F0F0F, dtype=torch.int64, device=x.device) - h01 = torch.tensor(0x0101010101010101, dtype=torch.int64, device=x.device) - - x = x - ((x >> 1) & m1) - x = (x & m2) + ((x >> 2) & m2) - x = (x + (x >> 4)) & m4 - x = (x * h01) >> 56 - return x.to(torch.int32) - - -def neighbor_map_post_process_for_masked_implicit_gemm_2( - gray_code: torch.Tensor, - sorted_idx: torch.Tensor, - block_size: int -): - device = gray_code.device - N = gray_code.numel() - num_blocks = (N + block_size - 1) // block_size - - pad = num_blocks * block_size - N - if pad > 0: - pad_vals = torch.zeros((pad,), dtype=torch.int32, device=device) - gray_padded = torch.cat([gray_code[sorted_idx], pad_vals], dim=0) - else: - gray_padded = gray_code[sorted_idx] - - gray_blocks = gray_padded.view(num_blocks, block_size) - - reduced_code = gray_blocks - while reduced_code.shape[1] > 1: - half = reduced_code.shape[1] // 2 - remainder = reduced_code.shape[1] % 2 - - left = reduced_code[:, :half * 2:2] - right = reduced_code[:, 1:half * 2:2] - merged = left | right - - if remainder: - reduced_code = torch.cat([merged, reduced_code[:, -1:]], dim=1) - else: - reduced_code = merged - - reduced_code = reduced_code.squeeze(1) - - seglen_counts = _popcount_int32_tensor(reduced_code).to(torch.int32) - - seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device) - seg[0] = 0 - if num_blocks > 0: - seg[1:] = torch.cumsum(seglen_counts, dim=0) - - total = int(seg[-1].item()) - - if total == 0: - return torch.empty((0,), dtype=torch.int32, device=device), seg - - V = int(reduced_code.max().item()).bit_length() if reduced_code.max() > 0 else 0 - - if V == 0: - return torch.empty((0,), dtype=torch.int32, device=device), seg - - bit_pos = torch.arange(0, V, dtype=torch.int32, device=device) - shifted = reduced_code.unsqueeze(1) >> bit_pos.unsqueeze(0) - bits = (shifted & 1).to(torch.bool) - - positions = bit_pos.unsqueeze(0).expand(num_blocks, V) - valid_kernel_idx = positions[bits].to(torch.int32).contiguous() - - return valid_kernel_idx, seg - - -def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation): - if NO_TRITON: # TODO - raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.") if feats.shape[0] == 0: - logging.warning("Found feats to be empty!") Co = weight.shape[0] return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None + if len(shape) == 5: - N, C, W, H, D = shape + _, _, W, H, D = shape else: W, H, D = shape Co, Kw, Kh, Kd, Ci = weight.shape - - b_stride = W * H * D - x_stride = H * D - y_stride = D - z_stride = 1 - - flat_keys = (coords[:, 0].long() * b_stride + - coords[:, 1].long() * x_stride + - coords[:, 2].long() * y_stride + - coords[:, 3].long() * z_stride) - - vals = torch.arange(coords.shape[0], dtype=torch.int32, device=coords.device) - - hashmap = TorchHashMap(flat_keys, vals, 0xFFFFFFFF) + V = Kw * Kh * Kd + device = feats.device + sentinel = -1 if neighbor_cache is None: + b_stride = W * H * D + x_stride = H * D + y_stride = D + z_stride = 1 + + flat_keys = (coords[:, 0].long() * b_stride + + coords[:, 1].long() * x_stride + + coords[:, 2].long() * y_stride + + coords[:, 3].long() * z_stride) + vals = torch.arange(coords.shape[0], dtype=torch.int32, device=device) + hashmap = TorchHashMap(flat_keys, vals, UINT32_SENTINEL) + neighbor = build_submanifold_neighbor_map( hashmap, coords, W, H, D, Kw, Kh, Kd, dilation[0], dilation[1], dilation[2] @@ -380,30 +154,67 @@ def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache else: neighbor = neighbor_cache - block_size = 128 + N_pts = feats.shape[0] - gray_code, sorted_idx, valid_signal_i, valid_signal_o, valid_signal_seg = \ - neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor) + if accumulate_f32: + weight_T = weight.view(Co, V * Ci).to(torch.float32).T.contiguous() + output = torch.zeros(N_pts, Co, device=device, dtype=torch.float32) + else: + weight_T = weight.view(Co, V * Ci).to(feats.dtype).T.contiguous() + output = torch.zeros(N_pts, Co, device=device, dtype=feats.dtype) - valid_kernel, valid_kernel_seg = \ - neighbor_map_post_process_for_masked_implicit_gemm_2(gray_code, sorted_idx, block_size) + # ------------------------------------------------------------------ + # Chunk size from memory budget + # ------------------------------------------------------------------ + bytes_per_elem = 4 if accumulate_f32 else feats.element_size() + mem_per_row = V * Ci * bytes_per_elem + max_chunk_mem = max_chunk_mem_gb * (1024 ** 3) + chunk_size = max(1, int(max_chunk_mem / mem_per_row)) + chunk_size = min(chunk_size, N_pts) - valid_kernel_fn = lambda b_size: valid_kernel - valid_kernel_seg_fn = lambda b_size: valid_kernel_seg + # ------------------------------------------------------------------ + # Chunked forward pass + # Each iteration: + # 1. gather (chunk, V, Ci) – memory bound + # 2. mask zero invalids – in-place, no extra alloc + # 3. reshape (chunk, V*Ci) + # 4. GEMM (chunk, V*Ci) @ (V*Ci, Co) → (chunk, Co) – cuBLAS + # written directly into output slice via out= argument + # ------------------------------------------------------------------ + for start in range(0, N_pts, chunk_size): + end = min(start + chunk_size, N_pts) + actual_chunk = end - start - weight_flat = weight.contiguous().view(Co, -1, Ci) + # (chunk, V) int32 + chunk_neighbor = neighbor[start:end] + chunk_valid = chunk_neighbor != sentinel - out = sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk( - feats, - weight_flat, - bias, - neighbor, - sorted_idx, - valid_kernel_fn, - valid_kernel_seg_fn - ) + # Clamp sentinel -1 → 0 for safe indexing. No clone of the full map. + chunk_idx = chunk_neighbor.clamp(min=0).long() - return out, neighbor + # Gather: (chunk, V, Ci). Memory-bound, single index_select. + gathered = feats[chunk_idx] + + # Zero invalid neighbours in-place. gathered is a fresh tensor from + # advanced indexing, so in-place mutation is safe. + gathered.mul_(chunk_valid.unsqueeze(-1)) + + # Reshape to (chunk, V*Ci) + gathered_flat = gathered.view(actual_chunk, V * Ci) + if accumulate_f32: + gathered_flat = gathered_flat.to(torch.float32) + + # Single GEMM call per chunk, written directly into output. + # This avoids allocating a temporary (chunk, Co) tensor. + torch.matmul(gathered_flat, weight_T, out=output[start:end]) + + if accumulate_f32: + output = output.to(feats.dtype) + + if bias is not None: + output = output + bias.unsqueeze(0).to(output.dtype) + + return output, neighbor class Mesh: def __init__(self, diff --git a/comfy_extras/nodes_mesh_postprocess.py b/comfy_extras/nodes_mesh_postprocess.py index 89ba6c9f2..a1f3c1a68 100644 --- a/comfy_extras/nodes_mesh_postprocess.py +++ b/comfy_extras/nodes_mesh_postprocess.py @@ -802,97 +802,127 @@ def compute_vertex_normals(verts, faces): return torch.nn.functional.normalize(vertex_normals, p=2, dim=-1, eps=1e-6) -class PostProcessMesh(IO.ComfyNode): +def _process_mesh_batch(mesh, per_item_fn): + """Handles list/batched/single mesh dispatching, color extraction, and stacking.""" + mesh = copy.deepcopy(mesh) + + def process_single(v, f, c, bar): + v, f, c = per_item_fn(v, f, c) + bar.update(1) + return v, f, c + + is_list = isinstance(mesh.vertices, list) + is_batched_tensor = not is_list and mesh.vertices.ndim == 3 + + if is_list or is_batched_tensor: + out_v, out_f, out_c = [], [], [] + bsz = len(mesh.vertices) if is_list else mesh.vertices.shape[0] + bar = comfy.utils.ProgressBar(bsz) + + for i in range(bsz): + v_i = mesh.vertices[i] + f_i = mesh.faces[i] + c_i = None + if hasattr(mesh, 'vertex_colors') and mesh.vertex_colors is not None: + c_i = mesh.vertex_colors[i] if (isinstance(mesh.vertex_colors, list) or mesh.vertex_colors.ndim == 3) else mesh.vertex_colors + + v_i, f_i, c_i = process_single(v_i, f_i, c_i, bar) + + out_v.append(v_i) + out_f.append(f_i) + if c_i is not None: + out_c.append(c_i) + + if all(v.shape == out_v[0].shape for v in out_v) and all(f.shape == out_f[0].shape for f in out_f): + mesh.vertices = torch.stack(out_v) + mesh.faces = torch.stack(out_f) + if out_c: + mesh.vertex_colors = torch.stack(out_c) + else: + mesh.vertices = out_v + mesh.faces = out_f + if out_c: + mesh.vertex_colors = out_c + else: + c = mesh.vertex_colors if hasattr(mesh, 'vertex_colors') and mesh.vertex_colors is not None else None + bar = comfy.utils.ProgressBar(1) + v, f, c = process_single(mesh.vertices, mesh.faces, c, bar) + mesh.vertices = v + mesh.faces = f + if c is not None: + mesh.vertex_colors = c + + return IO.NodeOutput(mesh) + + +class DecimateMesh(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( - node_id="PostProcessMesh", - display_name="Post Process Mesh", + node_id="DecimateMesh", + display_name="Decimate Mesh", category="latent/3d", - description=( - "Applies a sequence of mesh post-processing operations including optional hole filling" - " and mesh simplification to a target face count." - ), + description="Simplifies a mesh to a target face count using QEM.", inputs=[ IO.Mesh.Input("mesh"), - IO.Int.Input("target_face_count", default=1_000_000, min=0, max=50_000_000, - tooltip="Target maximum number of faces after mesh simplification. Set to 0 to disable simplification."), - IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0, step=0.0001, - tooltip=( - "Maximum hole perimeter threshold for filling holes in the mesh. " - "Smaller values only fill tiny holes, larger values fill larger gaps. " - "Set to 0 to disable hole filling.")) + IO.Int.Input("target_face_count", default=200_000, min=0, max=50_000_000, + tooltip="Target maximum number of faces. Set to 0 to disable."), ], - outputs=[ - IO.Mesh.Output("mesh"), - ] + outputs=[IO.Mesh.Output("mesh")], ) @classmethod - def execute(cls, mesh, target_face_count, fill_holes_perimeter): - mesh = copy.deepcopy(mesh) - - def process_single(v, f, c, bar): - if fill_holes_perimeter > 0: - v, f = fill_holes_fn(v, f, max_perimeter=fill_holes_perimeter) - bar.update(1) - - n = compute_vertex_normals(v, f) + def execute(cls, mesh, target_face_count): + def _fn(v, f, c): if target_face_count > 0 and f.shape[0] > target_face_count: + n = compute_vertex_normals(v, f) v, f, c, _ = simplify_fn_fast(v, f, colors=c, normals=n, target=target_face_count) - bar.update(1) - - v, f, c = make_double_sided(v, f, c) - bar.update(1) return v, f, c + return _process_mesh_batch(mesh, _fn) - is_list = isinstance(mesh.vertices, list) - is_batched_tensor = not is_list and mesh.vertices.ndim == 3 - if is_list or is_batched_tensor: - out_v, out_f, out_c = [], [],[] - bsz = len(mesh.vertices) if is_list else mesh.vertices.shape[0] - bar = comfy.utils.ProgressBar(3 * bsz) +class FillHoles(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="FillHoles", + display_name="Fill Holes", + category="latent/3d", + description="Fills holes in a mesh up to a maximum perimeter threshold.", + inputs=[ + IO.Mesh.Input("mesh"), + IO.Float.Input("max_perimeter", default=0.03, min=0.0, step=0.0001, + tooltip="Maximum hole perimeter to fill. Set to 0 to disable."), + ], + outputs=[IO.Mesh.Output("mesh")], + ) - for i in range(bsz): - v_i = mesh.vertices[i] - f_i = mesh.faces[i] + @classmethod + def execute(cls, mesh, max_perimeter): + def _fn(v, f, c): + if max_perimeter > 0: + v, f = fill_holes_fn(v, f, max_perimeter=max_perimeter) + return v, f, c + return _process_mesh_batch(mesh, _fn) - # Safely grab colors if they exist - c_i = None - if hasattr(mesh, 'vertex_colors') and mesh.vertex_colors is not None: - c_i = mesh.vertex_colors[i] if (isinstance(mesh.vertex_colors, list) or mesh.vertex_colors.ndim == 3) else mesh.vertex_colors - v_i, f_i, c_i = process_single(v_i, f_i, c_i, bar) +class MakeDoubleSided(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MakeDoubleSided", + display_name="Make Double Sided", + category="latent/3d", + description="Duplicates faces with flipped normals so the mesh renders from both sides.", + inputs=[IO.Mesh.Input("mesh")], + outputs=[IO.Mesh.Output("mesh")], + ) - out_v.append(v_i) - out_f.append(f_i) - if c_i is not None: - out_c.append(c_i) - - # If the output meshes happen to have the exact same shape, stack them nicely. - # Otherwise, just leave them as a List! (ComfyUI native standard) - if all(v.shape == out_v[0].shape for v in out_v) and all(f.shape == out_f[0].shape for f in out_f): - mesh.vertices = torch.stack(out_v) - mesh.faces = torch.stack(out_f) - if out_c: - mesh.vertex_colors = torch.stack(out_c) - else: - mesh.vertices = out_v - mesh.faces = out_f - if out_c: - mesh.vertex_colors = out_c - - else: - # Single Unbatched Mesh[V, 3] - c = mesh.vertex_colors if hasattr(mesh, 'vertex_colors') and mesh.vertex_colors is not None else None - v, f, c = process_single(mesh.vertices, mesh.faces, c) - mesh.vertices = v - mesh.faces = f - if c is not None: - mesh.vertex_colors = c - - return IO.NodeOutput(mesh) + @classmethod + def execute(cls, mesh): + def _fn(v, f, c): + return make_double_sided(v, f, c) + return _process_mesh_batch(mesh, _fn) @@ -900,7 +930,9 @@ class PostProcessMeshExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ - PostProcessMesh, + MakeDoubleSided, + FillHoles, + DecimateMesh, PaintMesh ] diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 3702ed511..c922d58b6 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -8,7 +8,6 @@ import numpy as np import torch ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES") -HighResVoxel = io.Custom("HIGH_RES_VOXEL") def prepare_trellis_vae_for_decode(vae, sample_shape): memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype) @@ -297,7 +296,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode): )) ], outputs=[ - HighResVoxel.Output( + IO.Voxel.Output( "high_res_voxel", tooltip=( "High-resolution sparse coordinates produced after cascade upsampling. " @@ -389,11 +388,11 @@ class Trellis2UpsampleCascade(IO.ComfyNode): final_coords_list.append(final_coords_i) output_coord_counts.append(int(final_coords_i.shape[0])) - output = { - "coords": torch.cat(final_coords_list, dim=0), - "coord_counts": torch.tensor(output_coord_counts, dtype=torch.int64), - "resolutions": torch.full((len(final_coords_list),), int(hr_resolution), dtype=torch.int64), - } + coords = torch.cat(final_coords_list, dim=0) + output = Types.VOXEL(coords) + output.coord_counts = torch.tensor(output_coord_counts, dtype=torch.int64) + output.resolutions = torch.full((len(final_coords_list),), int(hr_resolution), dtype=torch.int64) + output.upsampled = True return IO.NodeOutput(output,) @@ -537,9 +536,8 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode): node_id="EmptyTrellis2ShapeLatent", category="latent/3d", inputs=[ - IO.MultiType.Input( + IO.Voxel.Input( "voxel", - types=[IO.Voxel, HighResVoxel], tooltip=( "Shape structure input. Accepts either a voxel structure " "or upsampled voxel coordinates from a previous cascade stage." @@ -555,20 +553,18 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode): def execute(cls, voxel): # to accept the upscaled coords is_512_pass = False + upsampled = hasattr(voxel, "upsampled") + if upsampled: + voxel = voxel.data - if isinstance(voxel, dict): - voxel = voxel["coords"] - - if hasattr(voxel, "data") and voxel.data.ndim == 4: + if not upsampled: decoded = voxel.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() is_512_pass = True - elif isinstance(voxel, torch.Tensor) and voxel.ndim == 2: + else: coords = voxel.int() is_512_pass = False - else: - raise ValueError(f"Invalid input to EmptyTrellis2ShapeLatent: {type(voxel)}") batch_size, counts, max_tokens = infer_batched_coord_layout(coords) in_channels = 32 @@ -589,9 +585,8 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode): node_id="EmptyTrellis2LatentTexture", category="latent/3d", inputs=[ - IO.MultiType.Input( + IO.Voxel.Input( "voxel", - types=[IO.Voxel, HighResVoxel], tooltip=( "Shape structure input. Accepts either a voxel structure " "or upsampled voxel coordinates from a previous cascade stage." @@ -607,13 +602,14 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode): @classmethod def execute(cls, voxel, shape_latent): channels = 32 - if isinstance(voxel, dict): - voxel = voxel["coords"] - if hasattr(voxel, "data") and voxel.data.ndim == 4: + upsampled = hasattr(voxel, "upsampled") + if upsampled: + voxel = voxel.data + + if not upsampled: decoded = voxel.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() - - elif isinstance(voxel, torch.Tensor) and voxel.ndim == 2: + else: coords = voxel.int() batch_size, counts, max_tokens = infer_batched_coord_layout(coords)