remove triton, custom datatype, split mesh postpro

This commit is contained in:
Yousef Rafat 2026-05-20 17:15:33 +03:00
parent 9bf7bbb496
commit 2b2a1a3cd0
3 changed files with 248 additions and 409 deletions

View File

@ -1,136 +1,43 @@
# will contain every cuda -> pytorch operation
import math
from typing import Optional, Tuple
import torch
from typing import Callable
import logging
NO_TRITON = False
try:
allow_tf32 = torch.cuda.is_tf32_supported()
except Exception:
allow_tf32 = False
try:
import triton
import triton.language as tl
heuristics = {
'valid_kernel': lambda args: args['valid_kernel'](args['B1']),
'valid_kernel_seg': lambda args: args['valid_kernel_seg'](args['B1']),
}
UINT32_SENTINEL = 0xFFFFFFFF
#@triton_autotune(
# configs=config.autotune_config,
# key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'],
#)
@triton.heuristics(heuristics)
@triton.jit
def sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel(
input,
weight,
bias,
neighbor,
sorted_idx,
output,
# Tensor dimensions
N, LOGN, Ci, Co, V: tl.constexpr,
# Meta-parameters
B1: tl.constexpr, # Block size for N dimension
B2: tl.constexpr, # Block size for Co dimension
BK: tl.constexpr, # Block size for K dimension (V * Ci)
allow_tf32: tl.constexpr, # Allow TF32 precision for matmuls
# Huristic parameters
valid_kernel,
valid_kernel_seg,
):
block_id = tl.program_id(axis=0)
block_dim_co = tl.cdiv(Co, B2)
block_id_co = block_id % block_dim_co
block_id_n = block_id // block_dim_co
# Create pointers for submatrices of A and B.
num_k = tl.cdiv(Ci, BK) # Number of blocks in K dimension
valid_kernel_start = tl.load(valid_kernel_seg + block_id_n)
valid_kernel_seglen = tl.load(valid_kernel_seg + block_id_n + 1) - valid_kernel_start
offset_n = block_id_n * B1 + tl.arange(0, B1)
n_mask = offset_n < N
offset_sorted_n = tl.load(sorted_idx + offset_n, mask=n_mask, other=0) # (B1,)
offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co # (B2,)
offset_k = tl.arange(0, BK) # (BK,)
# Create a block of the output matrix C.
accumulator = tl.zeros((B1, B2), dtype=tl.float32)
# Iterate along V*Ci dimension.
for k in range(num_k * valid_kernel_seglen):
v = k // num_k
bk = k % num_k
v = tl.load(valid_kernel + valid_kernel_start + v)
# Calculate pointers to input matrix.
neighbor_offset_n = tl.load(neighbor + offset_sorted_n * V + v) # (B1,)
input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :]) # (B1, BK)
# Calculate pointers to weight matrix.
weight_ptr = weight + v * Ci + bk * BK + (offset_co[None, :] * V * Ci + offset_k[:, None]) # (BK, B2)
# Load the next block of input and weight.
neigh_mask = neighbor_offset_n != 0xffffffff
k_mask = offset_k < Ci - bk * BK
input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0)
weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0)
# Accumulate along the K dimension.
accumulator = tl.dot(input_block, weight_block, accumulator,
input_precision='tf32' if allow_tf32 else 'ieee') # (B1, B2)
c = accumulator.to(input.type.element_ty)
# add bias
if bias is not None:
bias_block = tl.load(bias + offset_co)
c += bias_block[None, :]
# Write back the block of the output matrix with masks.
out_offset_n = offset_sorted_n
out_offset_co = block_id_co * B2 + tl.arange(0, B2)
out_ptr = output + (out_offset_n[:, None] * Co + out_offset_co[None, :])
out_mask = n_mask[:, None] & (out_offset_co[None, :] < Co)
tl.store(out_ptr, c, mask=out_mask)
def sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
neighbor: torch.Tensor,
sorted_idx: torch.Tensor,
valid_kernel: Callable[[int], torch.Tensor],
valid_kernel_seg: Callable[[int], torch.Tensor],
) -> torch.Tensor:
N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]
LOGN = int(math.log2(N))
output = torch.empty((N, Co), device=input.device, dtype=input.dtype)
grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),)
sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid](
input, weight, bias, neighbor, sorted_idx, output,
N, LOGN, Ci, Co, V,
B1=128,
B2=64,
BK=32,
valid_kernel=valid_kernel,
valid_kernel_seg=valid_kernel_seg,
allow_tf32=allow_tf32,
)
return output
except Exception:
NO_TRITON = True
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
# offsets in same order as CUDA kernel
"""Kernel spatial offsets in the same order as the CUDA/Triton kernels."""
offsets = []
for vx in range(Kw):
for vy in range(Kh):
for vz in range(Kd):
offsets.append((
vx * Dw,
vy * Dh,
vz * Dd
))
return torch.tensor(offsets, device=device)
offsets.append((vx * Dw, vy * Dh, vz * Dd))
return torch.tensor(offsets, device=device, dtype=torch.int32)
class TorchHashMap:
"""Sorted-array hashmap backed by torch.searchsorted."""
def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
device = keys.device
self.sorted_keys, order = torch.sort(keys.to(torch.long))
self.sorted_vals = values.to(torch.long)[order]
self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
self._n = self.sorted_keys.numel()
def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
flat = flat_keys.to(torch.long)
if self._n == 0:
return torch.full((flat.shape[0],), -1, device=flat.device, dtype=torch.int32)
idx = torch.searchsorted(self.sorted_keys, flat)
idx_safe = torch.clamp(idx, max=self._n - 1)
found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat)
out = torch.full((flat.shape[0],), -1, device=flat.device, dtype=torch.int32)
if found.any():
out[found] = self.sorted_vals[idx_safe[found]].to(torch.int32)
return out
def build_submanifold_neighbor_map(
hashmap,
@ -143,10 +50,10 @@ def build_submanifold_neighbor_map(
M = coords.shape[0]
V = Kw * Kh * Kd
half_V = V // 2 + 1
INVALID = -1
INVALID = hashmap.default_value
neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.long)
# int32 neighbour map: 4 bytes/elem vs 8 bytes for int64
neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.int32)
b = coords[:, 0].long()
x = coords[:, 1].long()
@ -161,7 +68,8 @@ def build_submanifold_neighbor_map(
for v in range(half_V):
if v == half_V - 1:
neighbor[:, v] = torch.arange(M, device=device)
# Center voxel always maps to itself
neighbor[:, v] = torch.arange(M, device=device, dtype=torch.int32)
continue
dx, dy, dz = offsets[v]
@ -170,7 +78,6 @@ def build_submanifold_neighbor_map(
ky = oy + dy
kz = oz + dz
# Check spatial bounds
valid = (
(kx >= 0) & (kx < W) &
(ky >= 0) & (ky < H) &
@ -187,192 +94,59 @@ def build_submanifold_neighbor_map(
if flat.numel() > 0:
found = hashmap.lookup_flat(flat)
idx_in_M = torch.where(valid)[0]
neighbor[idx_in_M, v] = found
neighbor[idx_in_M, v] = found.to(torch.int32)
valid_found_mask = (found != INVALID)
# BUG FIX: old code used found != hashmap.default_value which
# compared int32 -1 against int64 4294967295 → always True.
# We now explicitly check for valid indices.
valid_found_mask = found >= 0
if valid_found_mask.any():
src_points = idx_in_M[valid_found_mask]
dst_points = found[valid_found_mask]
neighbor[dst_points, V - 1 - v] = src_points
dst_points = found[valid_found_mask].long()
neighbor[dst_points, V - 1 - v] = src_points.to(torch.int32)
return neighbor
class TorchHashMap:
def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
device = keys.device
# use long for searchsorted
self.sorted_keys, order = torch.sort(keys.to(torch.long))
self.sorted_vals = values.to(torch.long)[order]
self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
self._n = self.sorted_keys.numel()
def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
flat = flat_keys.to(torch.long)
if self._n == 0:
return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype)
idx = torch.searchsorted(self.sorted_keys, flat)
idx_safe = torch.clamp(idx, max=self._n - 1)
found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat)
out = torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype)
if found.any():
out[found] = self.sorted_vals[idx_safe[found]]
return out
def sparse_submanifold_conv3d(
feats: torch.Tensor,
coords: torch.Tensor,
shape: tuple,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
neighbor_cache: Optional[torch.Tensor],
dilation: tuple,
max_chunk_mem_gb: float = 6.0,
accumulate_f32: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
UINT32_SENTINEL = 0xFFFFFFFF
def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map):
device = neighbor_map.device
N, V = neighbor_map.shape
sentinel = UINT32_SENTINEL
neigh_map_T = neighbor_map.t().reshape(-1)
neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32)
mask = (neighbor_map != sentinel).to(torch.long)
gray_code = torch.zeros(N, dtype=torch.long, device=device)
for v in range(V):
gray_code |= (mask[:, v] << v)
binary_code = gray_code.clone()
for v in range(1, V):
binary_code ^= (gray_code >> v)
sorted_idx = torch.argsort(binary_code)
prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T, dim=0)
total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0
if total_valid_signal > 0:
pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0]
to = (prefix_sum_neighbor_mask[pos] - 1).long()
valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
valid_signal_i[to] = (pos % N).to(torch.long)
valid_signal_o[to] = neigh_map_T[pos].to(torch.long)
else:
valid_signal_i = torch.empty((0,), dtype=torch.long, device=device)
valid_signal_o = torch.empty((0,), dtype=torch.long, device=device)
seg = torch.empty((V + 1,), dtype=torch.long, device=device)
seg[0] = 0
if V > 0:
idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1
seg[1:] = prefix_sum_neighbor_mask[idxs]
return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg
def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor:
x = x.to(torch.int64)
m1 = torch.tensor(0x5555555555555555, dtype=torch.int64, device=x.device)
m2 = torch.tensor(0x3333333333333333, dtype=torch.int64, device=x.device)
m4 = torch.tensor(0x0F0F0F0F0F0F0F0F, dtype=torch.int64, device=x.device)
h01 = torch.tensor(0x0101010101010101, dtype=torch.int64, device=x.device)
x = x - ((x >> 1) & m1)
x = (x & m2) + ((x >> 2) & m2)
x = (x + (x >> 4)) & m4
x = (x * h01) >> 56
return x.to(torch.int32)
def neighbor_map_post_process_for_masked_implicit_gemm_2(
gray_code: torch.Tensor,
sorted_idx: torch.Tensor,
block_size: int
):
device = gray_code.device
N = gray_code.numel()
num_blocks = (N + block_size - 1) // block_size
pad = num_blocks * block_size - N
if pad > 0:
pad_vals = torch.zeros((pad,), dtype=torch.int32, device=device)
gray_padded = torch.cat([gray_code[sorted_idx], pad_vals], dim=0)
else:
gray_padded = gray_code[sorted_idx]
gray_blocks = gray_padded.view(num_blocks, block_size)
reduced_code = gray_blocks
while reduced_code.shape[1] > 1:
half = reduced_code.shape[1] // 2
remainder = reduced_code.shape[1] % 2
left = reduced_code[:, :half * 2:2]
right = reduced_code[:, 1:half * 2:2]
merged = left | right
if remainder:
reduced_code = torch.cat([merged, reduced_code[:, -1:]], dim=1)
else:
reduced_code = merged
reduced_code = reduced_code.squeeze(1)
seglen_counts = _popcount_int32_tensor(reduced_code).to(torch.int32)
seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device)
seg[0] = 0
if num_blocks > 0:
seg[1:] = torch.cumsum(seglen_counts, dim=0)
total = int(seg[-1].item())
if total == 0:
return torch.empty((0,), dtype=torch.int32, device=device), seg
V = int(reduced_code.max().item()).bit_length() if reduced_code.max() > 0 else 0
if V == 0:
return torch.empty((0,), dtype=torch.int32, device=device), seg
bit_pos = torch.arange(0, V, dtype=torch.int32, device=device)
shifted = reduced_code.unsqueeze(1) >> bit_pos.unsqueeze(0)
bits = (shifted & 1).to(torch.bool)
positions = bit_pos.unsqueeze(0).expand(num_blocks, V)
valid_kernel_idx = positions[bits].to(torch.int32).contiguous()
return valid_kernel_idx, seg
def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation):
if NO_TRITON: # TODO
raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.")
if feats.shape[0] == 0:
logging.warning("Found feats to be empty!")
Co = weight.shape[0]
return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None
if len(shape) == 5:
N, C, W, H, D = shape
_, _, W, H, D = shape
else:
W, H, D = shape
Co, Kw, Kh, Kd, Ci = weight.shape
b_stride = W * H * D
x_stride = H * D
y_stride = D
z_stride = 1
flat_keys = (coords[:, 0].long() * b_stride +
coords[:, 1].long() * x_stride +
coords[:, 2].long() * y_stride +
coords[:, 3].long() * z_stride)
vals = torch.arange(coords.shape[0], dtype=torch.int32, device=coords.device)
hashmap = TorchHashMap(flat_keys, vals, 0xFFFFFFFF)
V = Kw * Kh * Kd
device = feats.device
sentinel = -1
if neighbor_cache is None:
b_stride = W * H * D
x_stride = H * D
y_stride = D
z_stride = 1
flat_keys = (coords[:, 0].long() * b_stride +
coords[:, 1].long() * x_stride +
coords[:, 2].long() * y_stride +
coords[:, 3].long() * z_stride)
vals = torch.arange(coords.shape[0], dtype=torch.int32, device=device)
hashmap = TorchHashMap(flat_keys, vals, UINT32_SENTINEL)
neighbor = build_submanifold_neighbor_map(
hashmap, coords, W, H, D, Kw, Kh, Kd,
dilation[0], dilation[1], dilation[2]
@ -380,30 +154,67 @@ def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache
else:
neighbor = neighbor_cache
block_size = 128
N_pts = feats.shape[0]
gray_code, sorted_idx, valid_signal_i, valid_signal_o, valid_signal_seg = \
neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor)
if accumulate_f32:
weight_T = weight.view(Co, V * Ci).to(torch.float32).T.contiguous()
output = torch.zeros(N_pts, Co, device=device, dtype=torch.float32)
else:
weight_T = weight.view(Co, V * Ci).to(feats.dtype).T.contiguous()
output = torch.zeros(N_pts, Co, device=device, dtype=feats.dtype)
valid_kernel, valid_kernel_seg = \
neighbor_map_post_process_for_masked_implicit_gemm_2(gray_code, sorted_idx, block_size)
# ------------------------------------------------------------------
# Chunk size from memory budget
# ------------------------------------------------------------------
bytes_per_elem = 4 if accumulate_f32 else feats.element_size()
mem_per_row = V * Ci * bytes_per_elem
max_chunk_mem = max_chunk_mem_gb * (1024 ** 3)
chunk_size = max(1, int(max_chunk_mem / mem_per_row))
chunk_size = min(chunk_size, N_pts)
valid_kernel_fn = lambda b_size: valid_kernel
valid_kernel_seg_fn = lambda b_size: valid_kernel_seg
# ------------------------------------------------------------------
# Chunked forward pass
# Each iteration:
# 1. gather (chunk, V, Ci) memory bound
# 2. mask zero invalids in-place, no extra alloc
# 3. reshape (chunk, V*Ci)
# 4. GEMM (chunk, V*Ci) @ (V*Ci, Co) → (chunk, Co) cuBLAS
# written directly into output slice via out= argument
# ------------------------------------------------------------------
for start in range(0, N_pts, chunk_size):
end = min(start + chunk_size, N_pts)
actual_chunk = end - start
weight_flat = weight.contiguous().view(Co, -1, Ci)
# (chunk, V) int32
chunk_neighbor = neighbor[start:end]
chunk_valid = chunk_neighbor != sentinel
out = sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk(
feats,
weight_flat,
bias,
neighbor,
sorted_idx,
valid_kernel_fn,
valid_kernel_seg_fn
)
# Clamp sentinel -1 → 0 for safe indexing. No clone of the full map.
chunk_idx = chunk_neighbor.clamp(min=0).long()
return out, neighbor
# Gather: (chunk, V, Ci). Memory-bound, single index_select.
gathered = feats[chunk_idx]
# Zero invalid neighbours in-place. gathered is a fresh tensor from
# advanced indexing, so in-place mutation is safe.
gathered.mul_(chunk_valid.unsqueeze(-1))
# Reshape to (chunk, V*Ci)
gathered_flat = gathered.view(actual_chunk, V * Ci)
if accumulate_f32:
gathered_flat = gathered_flat.to(torch.float32)
# Single GEMM call per chunk, written directly into output.
# This avoids allocating a temporary (chunk, Co) tensor.
torch.matmul(gathered_flat, weight_T, out=output[start:end])
if accumulate_f32:
output = output.to(feats.dtype)
if bias is not None:
output = output + bias.unsqueeze(0).to(output.dtype)
return output, neighbor
class Mesh:
def __init__(self,

View File

@ -802,97 +802,127 @@ def compute_vertex_normals(verts, faces):
return torch.nn.functional.normalize(vertex_normals, p=2, dim=-1, eps=1e-6)
class PostProcessMesh(IO.ComfyNode):
def _process_mesh_batch(mesh, per_item_fn):
"""Handles list/batched/single mesh dispatching, color extraction, and stacking."""
mesh = copy.deepcopy(mesh)
def process_single(v, f, c, bar):
v, f, c = per_item_fn(v, f, c)
bar.update(1)
return v, f, c
is_list = isinstance(mesh.vertices, list)
is_batched_tensor = not is_list and mesh.vertices.ndim == 3
if is_list or is_batched_tensor:
out_v, out_f, out_c = [], [], []
bsz = len(mesh.vertices) if is_list else mesh.vertices.shape[0]
bar = comfy.utils.ProgressBar(bsz)
for i in range(bsz):
v_i = mesh.vertices[i]
f_i = mesh.faces[i]
c_i = None
if hasattr(mesh, 'vertex_colors') and mesh.vertex_colors is not None:
c_i = mesh.vertex_colors[i] if (isinstance(mesh.vertex_colors, list) or mesh.vertex_colors.ndim == 3) else mesh.vertex_colors
v_i, f_i, c_i = process_single(v_i, f_i, c_i, bar)
out_v.append(v_i)
out_f.append(f_i)
if c_i is not None:
out_c.append(c_i)
if all(v.shape == out_v[0].shape for v in out_v) and all(f.shape == out_f[0].shape for f in out_f):
mesh.vertices = torch.stack(out_v)
mesh.faces = torch.stack(out_f)
if out_c:
mesh.vertex_colors = torch.stack(out_c)
else:
mesh.vertices = out_v
mesh.faces = out_f
if out_c:
mesh.vertex_colors = out_c
else:
c = mesh.vertex_colors if hasattr(mesh, 'vertex_colors') and mesh.vertex_colors is not None else None
bar = comfy.utils.ProgressBar(1)
v, f, c = process_single(mesh.vertices, mesh.faces, c, bar)
mesh.vertices = v
mesh.faces = f
if c is not None:
mesh.vertex_colors = c
return IO.NodeOutput(mesh)
class DecimateMesh(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="PostProcessMesh",
display_name="Post Process Mesh",
node_id="DecimateMesh",
display_name="Decimate Mesh",
category="latent/3d",
description=(
"Applies a sequence of mesh post-processing operations including optional hole filling"
" and mesh simplification to a target face count."
),
description="Simplifies a mesh to a target face count using QEM.",
inputs=[
IO.Mesh.Input("mesh"),
IO.Int.Input("target_face_count", default=1_000_000, min=0, max=50_000_000,
tooltip="Target maximum number of faces after mesh simplification. Set to 0 to disable simplification."),
IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0, step=0.0001,
tooltip=(
"Maximum hole perimeter threshold for filling holes in the mesh. "
"Smaller values only fill tiny holes, larger values fill larger gaps. "
"Set to 0 to disable hole filling."))
IO.Int.Input("target_face_count", default=200_000, min=0, max=50_000_000,
tooltip="Target maximum number of faces. Set to 0 to disable."),
],
outputs=[
IO.Mesh.Output("mesh"),
]
outputs=[IO.Mesh.Output("mesh")],
)
@classmethod
def execute(cls, mesh, target_face_count, fill_holes_perimeter):
mesh = copy.deepcopy(mesh)
def process_single(v, f, c, bar):
if fill_holes_perimeter > 0:
v, f = fill_holes_fn(v, f, max_perimeter=fill_holes_perimeter)
bar.update(1)
n = compute_vertex_normals(v, f)
def execute(cls, mesh, target_face_count):
def _fn(v, f, c):
if target_face_count > 0 and f.shape[0] > target_face_count:
n = compute_vertex_normals(v, f)
v, f, c, _ = simplify_fn_fast(v, f, colors=c, normals=n, target=target_face_count)
bar.update(1)
v, f, c = make_double_sided(v, f, c)
bar.update(1)
return v, f, c
return _process_mesh_batch(mesh, _fn)
is_list = isinstance(mesh.vertices, list)
is_batched_tensor = not is_list and mesh.vertices.ndim == 3
if is_list or is_batched_tensor:
out_v, out_f, out_c = [], [],[]
bsz = len(mesh.vertices) if is_list else mesh.vertices.shape[0]
bar = comfy.utils.ProgressBar(3 * bsz)
class FillHoles(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="FillHoles",
display_name="Fill Holes",
category="latent/3d",
description="Fills holes in a mesh up to a maximum perimeter threshold.",
inputs=[
IO.Mesh.Input("mesh"),
IO.Float.Input("max_perimeter", default=0.03, min=0.0, step=0.0001,
tooltip="Maximum hole perimeter to fill. Set to 0 to disable."),
],
outputs=[IO.Mesh.Output("mesh")],
)
for i in range(bsz):
v_i = mesh.vertices[i]
f_i = mesh.faces[i]
@classmethod
def execute(cls, mesh, max_perimeter):
def _fn(v, f, c):
if max_perimeter > 0:
v, f = fill_holes_fn(v, f, max_perimeter=max_perimeter)
return v, f, c
return _process_mesh_batch(mesh, _fn)
# Safely grab colors if they exist
c_i = None
if hasattr(mesh, 'vertex_colors') and mesh.vertex_colors is not None:
c_i = mesh.vertex_colors[i] if (isinstance(mesh.vertex_colors, list) or mesh.vertex_colors.ndim == 3) else mesh.vertex_colors
v_i, f_i, c_i = process_single(v_i, f_i, c_i, bar)
class MakeDoubleSided(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="MakeDoubleSided",
display_name="Make Double Sided",
category="latent/3d",
description="Duplicates faces with flipped normals so the mesh renders from both sides.",
inputs=[IO.Mesh.Input("mesh")],
outputs=[IO.Mesh.Output("mesh")],
)
out_v.append(v_i)
out_f.append(f_i)
if c_i is not None:
out_c.append(c_i)
# If the output meshes happen to have the exact same shape, stack them nicely.
# Otherwise, just leave them as a List! (ComfyUI native standard)
if all(v.shape == out_v[0].shape for v in out_v) and all(f.shape == out_f[0].shape for f in out_f):
mesh.vertices = torch.stack(out_v)
mesh.faces = torch.stack(out_f)
if out_c:
mesh.vertex_colors = torch.stack(out_c)
else:
mesh.vertices = out_v
mesh.faces = out_f
if out_c:
mesh.vertex_colors = out_c
else:
# Single Unbatched Mesh[V, 3]
c = mesh.vertex_colors if hasattr(mesh, 'vertex_colors') and mesh.vertex_colors is not None else None
v, f, c = process_single(mesh.vertices, mesh.faces, c)
mesh.vertices = v
mesh.faces = f
if c is not None:
mesh.vertex_colors = c
return IO.NodeOutput(mesh)
@classmethod
def execute(cls, mesh):
def _fn(v, f, c):
return make_double_sided(v, f, c)
return _process_mesh_batch(mesh, _fn)
@ -900,7 +930,9 @@ class PostProcessMeshExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
PostProcessMesh,
MakeDoubleSided,
FillHoles,
DecimateMesh,
PaintMesh
]

View File

@ -8,7 +8,6 @@ import numpy as np
import torch
ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
HighResVoxel = io.Custom("HIGH_RES_VOXEL")
def prepare_trellis_vae_for_decode(vae, sample_shape):
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
@ -297,7 +296,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode):
))
],
outputs=[
HighResVoxel.Output(
IO.Voxel.Output(
"high_res_voxel",
tooltip=(
"High-resolution sparse coordinates produced after cascade upsampling. "
@ -389,11 +388,11 @@ class Trellis2UpsampleCascade(IO.ComfyNode):
final_coords_list.append(final_coords_i)
output_coord_counts.append(int(final_coords_i.shape[0]))
output = {
"coords": torch.cat(final_coords_list, dim=0),
"coord_counts": torch.tensor(output_coord_counts, dtype=torch.int64),
"resolutions": torch.full((len(final_coords_list),), int(hr_resolution), dtype=torch.int64),
}
coords = torch.cat(final_coords_list, dim=0)
output = Types.VOXEL(coords)
output.coord_counts = torch.tensor(output_coord_counts, dtype=torch.int64)
output.resolutions = torch.full((len(final_coords_list),), int(hr_resolution), dtype=torch.int64)
output.upsampled = True
return IO.NodeOutput(output,)
@ -537,9 +536,8 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
node_id="EmptyTrellis2ShapeLatent",
category="latent/3d",
inputs=[
IO.MultiType.Input(
IO.Voxel.Input(
"voxel",
types=[IO.Voxel, HighResVoxel],
tooltip=(
"Shape structure input. Accepts either a voxel structure "
"or upsampled voxel coordinates from a previous cascade stage."
@ -555,20 +553,18 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
def execute(cls, voxel):
# to accept the upscaled coords
is_512_pass = False
upsampled = hasattr(voxel, "upsampled")
if upsampled:
voxel = voxel.data
if isinstance(voxel, dict):
voxel = voxel["coords"]
if hasattr(voxel, "data") and voxel.data.ndim == 4:
if not upsampled:
decoded = voxel.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
is_512_pass = True
elif isinstance(voxel, torch.Tensor) and voxel.ndim == 2:
else:
coords = voxel.int()
is_512_pass = False
else:
raise ValueError(f"Invalid input to EmptyTrellis2ShapeLatent: {type(voxel)}")
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
in_channels = 32
@ -589,9 +585,8 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
node_id="EmptyTrellis2LatentTexture",
category="latent/3d",
inputs=[
IO.MultiType.Input(
IO.Voxel.Input(
"voxel",
types=[IO.Voxel, HighResVoxel],
tooltip=(
"Shape structure input. Accepts either a voxel structure "
"or upsampled voxel coordinates from a previous cascade stage."
@ -607,13 +602,14 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
@classmethod
def execute(cls, voxel, shape_latent):
channels = 32
if isinstance(voxel, dict):
voxel = voxel["coords"]
if hasattr(voxel, "data") and voxel.data.ndim == 4:
upsampled = hasattr(voxel, "upsampled")
if upsampled:
voxel = voxel.data
if not upsampled:
decoded = voxel.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
elif isinstance(voxel, torch.Tensor) and voxel.ndim == 2:
else:
coords = voxel.int()
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)