mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-17 21:39:45 +08:00
remove triton, custom datatype, split mesh postpro
This commit is contained in:
parent
9bf7bbb496
commit
2b2a1a3cd0
@ -1,136 +1,43 @@
|
||||
# will contain every cuda -> pytorch operation
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
import torch
|
||||
from typing import Callable
|
||||
import logging
|
||||
|
||||
# Set to True below when the Triton kernels cannot be imported/defined.
NO_TRITON = False

# TF32 matmuls are only enabled when torch reports support; if the query
# itself fails we fall back to full precision.
try:
    allow_tf32 = torch.cuda.is_tf32_supported()
except Exception:
    allow_tf32 = False

# Everything Triton-related lives in one guarded section: any failure while
# importing Triton or defining the kernel flips NO_TRITON instead of breaking
# the module import (deliberate best-effort).
try:
    import triton
    import triton.language as tl

    # The launcher receives `valid_kernel` / `valid_kernel_seg` as callables
    # taking the N-block size; these heuristics call them with the chosen B1
    # so the kernel sees the resulting values.
    heuristics = {
        'valid_kernel': lambda args: args['valid_kernel'](args['B1']),
        'valid_kernel_seg': lambda args: args['valid_kernel_seg'](args['B1']),
    }
    UINT32_SENTINEL = 0xFFFFFFFF

    #@triton_autotune(
    #    configs=config.autotune_config,
    #    key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'],
    #)
    @triton.heuristics(heuristics)
    @triton.jit
    def sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel(
        input,
        weight,
        bias,
        neighbor,
        sorted_idx,
        output,
        # Tensor dimensions
        N, LOGN, Ci, Co, V: tl.constexpr,
        # Meta-parameters
        B1: tl.constexpr,  # Block size for N dimension
        B2: tl.constexpr,  # Block size for Co dimension
        BK: tl.constexpr,  # Block size for K dimension (V * Ci)
        allow_tf32: tl.constexpr,  # Allow TF32 precision for matmuls
        # Heuristic parameters
        valid_kernel,
        valid_kernel_seg,
    ):
        # One program per (N-block, Co-block) pair, Co-block index fastest.
        block_id = tl.program_id(axis=0)
        block_dim_co = tl.cdiv(Co, B2)
        block_id_co = block_id % block_dim_co
        block_id_n = block_id // block_dim_co

        # Create pointers for submatrices of A and B.
        num_k = tl.cdiv(Ci, BK)  # Number of blocks in K dimension
        # Segment [start, start + seglen) of `valid_kernel` lists the kernel
        # offsets this N-block has to visit.
        valid_kernel_start = tl.load(valid_kernel_seg + block_id_n)
        valid_kernel_seglen = tl.load(valid_kernel_seg + block_id_n + 1) - valid_kernel_start
        offset_n = block_id_n * B1 + tl.arange(0, B1)
        n_mask = offset_n < N
        # Rows are processed in `sorted_idx` order; out-of-range lanes read row 0
        # and are masked out again at the final store.
        offset_sorted_n = tl.load(sorted_idx + offset_n, mask=n_mask, other=0)  # (B1,)
        offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co  # (B2,)
        offset_k = tl.arange(0, BK)  # (BK,)

        # Create a block of the output matrix C.
        accumulator = tl.zeros((B1, B2), dtype=tl.float32)

        # Iterate along V*Ci dimension, restricted to the listed kernel offsets.
        for k in range(num_k * valid_kernel_seglen):
            v = k // num_k
            bk = k % num_k
            v = tl.load(valid_kernel + valid_kernel_start + v)
            # Calculate pointers to input matrix.
            neighbor_offset_n = tl.load(neighbor + offset_sorted_n * V + v)  # (B1,)
            input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :])  # (B1, BK)
            # Calculate pointers to weight matrix (indexed as [co, v, ci]).
            weight_ptr = weight + v * Ci + bk * BK + (offset_co[None, :] * V * Ci + offset_k[:, None])  # (BK, B2)
            # Load the next block of input and weight; rows whose neighbour entry
            # is the 0xffffffff sentinel contribute zeros.
            neigh_mask = neighbor_offset_n != 0xffffffff
            k_mask = offset_k < Ci - bk * BK
            input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0)
            weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0)
            # Accumulate along the K dimension.
            accumulator = tl.dot(input_block, weight_block, accumulator,
                                 input_precision='tf32' if allow_tf32 else 'ieee')  # (B1, B2)
        c = accumulator.to(input.type.element_ty)

        # add bias
        if bias is not None:
            bias_block = tl.load(bias + offset_co)
            c += bias_block[None, :]

        # Write back the block of the output matrix with masks.
        out_offset_n = offset_sorted_n
        out_offset_co = block_id_co * B2 + tl.arange(0, B2)
        out_ptr = output + (out_offset_n[:, None] * Co + out_offset_co[None, :])
        out_mask = n_mask[:, None] & (out_offset_co[None, :] < Co)
        tl.store(out_ptr, c, mask=out_mask)

    def sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk(
        input: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        neighbor: torch.Tensor,
        sorted_idx: torch.Tensor,
        valid_kernel: Callable[[int], torch.Tensor],
        valid_kernel_seg: Callable[[int], torch.Tensor],
    ) -> torch.Tensor:
        """Launch the masked implicit-GEMM submanifold convolution kernel.

        Args:
            input: features; ``input.shape[1]`` is taken as ``Ci``.
            weight: ``weight.shape[0]`` is ``Co`` and ``weight.shape[1]`` is ``V``;
                the kernel indexes it as contiguous ``[co, v, ci]``.
            bias: per-output-channel bias, or ``None`` to skip the bias add.
            neighbor: neighbour map with ``N = neighbor.shape[0]`` rows and ``V``
                entries per row; ``0xffffffff`` marks a missing neighbour.
            sorted_idx: row processing order (one entry per row).
            valid_kernel: callable mapping the N-block size to the flat list of
                kernel offsets to visit.
            valid_kernel_seg: callable mapping the N-block size to the per-block
                segment boundaries into ``valid_kernel``.

        Returns:
            Tensor of shape ``(N, Co)`` with ``input``'s dtype and device.
        """
        N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]
        output = torch.empty((N, Co), device=input.device, dtype=input.dtype)
        if N == 0:
            # Nothing to compute: avoids math.log2(0) and a zero-sized grid launch.
            return output
        LOGN = int(math.log2(N))
        grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),)
        sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid](
            input, weight, bias, neighbor, sorted_idx, output,
            N, LOGN, Ci, Co, V,
            B1=128,
            B2=64,
            BK=32,
            valid_kernel=valid_kernel,
            valid_kernel_seg=valid_kernel_seg,
            allow_tf32=allow_tf32,
        )
        return output
except Exception:
    NO_TRITON = True
|
||||
|
||||
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
    """Kernel spatial offsets in the same order as the CUDA/Triton kernels.

    Args:
        Kw, Kh, Kd: kernel extent along the three spatial axes.
        Dw, Dh, Dd: dilation applied to the matching axis.
        device: device on which the result tensor is created.

    Returns:
        int32 tensor of shape ``(Kw * Kh * Kd, 3)``. Row ``(vx * Kh + vy) * Kd + vz``
        holds ``(vx * Dw, vy * Dh, vz * Dd)``, i.e. the last axis varies fastest.
    """
    offsets = []
    for vx in range(Kw):
        for vy in range(Kh):
            for vz in range(Kd):
                offsets.append((vx * Dw, vy * Dh, vz * Dd))
    # reshape keeps the (rows, 3) layout even when the kernel is empty.
    return torch.tensor(offsets, device=device, dtype=torch.int32).reshape(-1, 3)
|
||||
|
||||
|
||||
class TorchHashMap:
    """Sorted-array hashmap backed by torch.searchsorted.

    Keys are sorted once at construction and every lookup is a binary search.
    ``lookup_flat`` reports missing keys as ``-1`` (int32); ``default_value`` is
    only stored on the instance (as a long tensor) and is not used by lookups.
    """

    def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
        """Build the map from parallel 1-D ``keys`` / ``values`` tensors."""
        device = keys.device
        self.sorted_keys, order = torch.sort(keys.to(torch.long))
        # Reorder the values so they stay aligned with the sorted keys.
        self.sorted_vals = values.to(torch.long)[order]
        self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
        self._n = self.sorted_keys.numel()

    def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
        """Return the int32 value stored for each key, or -1 where the key is absent."""
        flat = flat_keys.to(torch.long)
        if self._n == 0:
            return torch.full((flat.shape[0],), -1, device=flat.device, dtype=torch.int32)
        idx = torch.searchsorted(self.sorted_keys, flat)
        # Insertion points equal to _n are past the end; clamp so the gather
        # below stays in range (those entries are rejected by `found`).
        idx_safe = torch.clamp(idx, max=self._n - 1)
        found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat)
        # Gather for every query and overwrite the misses. This avoids the
        # `found.any()` host sync and the boolean-mask scatter.
        hits = self.sorted_vals[idx_safe].to(torch.int32)
        return hits.masked_fill(~found, -1)
|
||||
|
||||
|
||||
def build_submanifold_neighbor_map(
|
||||
hashmap,
|
||||
@ -143,10 +50,10 @@ def build_submanifold_neighbor_map(
|
||||
M = coords.shape[0]
|
||||
V = Kw * Kh * Kd
|
||||
half_V = V // 2 + 1
|
||||
INVALID = -1
|
||||
|
||||
INVALID = hashmap.default_value
|
||||
|
||||
neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.long)
|
||||
# int32 neighbour map: 4 bytes/elem vs 8 bytes for int64
|
||||
neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.int32)
|
||||
|
||||
b = coords[:, 0].long()
|
||||
x = coords[:, 1].long()
|
||||
@ -161,7 +68,8 @@ def build_submanifold_neighbor_map(
|
||||
|
||||
for v in range(half_V):
|
||||
if v == half_V - 1:
|
||||
neighbor[:, v] = torch.arange(M, device=device)
|
||||
# Center voxel always maps to itself
|
||||
neighbor[:, v] = torch.arange(M, device=device, dtype=torch.int32)
|
||||
continue
|
||||
|
||||
dx, dy, dz = offsets[v]
|
||||
@ -170,7 +78,6 @@ def build_submanifold_neighbor_map(
|
||||
ky = oy + dy
|
||||
kz = oz + dz
|
||||
|
||||
# Check spatial bounds
|
||||
valid = (
|
||||
(kx >= 0) & (kx < W) &
|
||||
(ky >= 0) & (ky < H) &
|
||||
@ -187,192 +94,59 @@ def build_submanifold_neighbor_map(
|
||||
if flat.numel() > 0:
|
||||
found = hashmap.lookup_flat(flat)
|
||||
idx_in_M = torch.where(valid)[0]
|
||||
neighbor[idx_in_M, v] = found
|
||||
neighbor[idx_in_M, v] = found.to(torch.int32)
|
||||
|
||||
valid_found_mask = (found != INVALID)
|
||||
# BUG FIX: old code used found != hashmap.default_value which
|
||||
# compared int32 -1 against int64 4294967295 → always True.
|
||||
# We now explicitly check for valid indices.
|
||||
valid_found_mask = found >= 0
|
||||
if valid_found_mask.any():
|
||||
src_points = idx_in_M[valid_found_mask]
|
||||
dst_points = found[valid_found_mask]
|
||||
neighbor[dst_points, V - 1 - v] = src_points
|
||||
dst_points = found[valid_found_mask].long()
|
||||
neighbor[dst_points, V - 1 - v] = src_points.to(torch.int32)
|
||||
|
||||
return neighbor
|
||||
|
||||
class TorchHashMap:
    """Sorted-array key/value map queried with torch.searchsorted.

    Missing keys are reported as ``default_value``; results use the dtype of the
    stored values (long).
    """

    def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
        """Build the map from parallel 1-D ``keys`` / ``values`` tensors."""
        device = keys.device
        # use long for searchsorted
        self.sorted_keys, order = torch.sort(keys.to(torch.long))
        self.sorted_vals = values.to(torch.long)[order]
        self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
        self._n = self.sorted_keys.numel()

    def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
        """Return the stored value for each key, or ``default_value`` where absent."""
        flat = flat_keys.to(torch.long)
        if self._n == 0:
            # Broadcast the 0-dim default instead of passing a tensor as a
            # `torch.full` fill value.
            return self.default_value.to(flat.device).expand(flat.shape[0]).clone()
        idx = torch.searchsorted(self.sorted_keys, flat)
        # Clamp past-the-end insertion points so the gather stays in range.
        idx_safe = torch.clamp(idx, max=self._n - 1)
        found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat)
        # Select per element; no `found.any()` host sync is needed.
        return torch.where(found, self.sorted_vals[idx_safe], self.default_value)
|
||||
def sparse_submanifold_conv3d(
|
||||
feats: torch.Tensor,
|
||||
coords: torch.Tensor,
|
||||
shape: tuple,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
neighbor_cache: Optional[torch.Tensor],
|
||||
dilation: tuple,
|
||||
max_chunk_mem_gb: float = 6.0,
|
||||
accumulate_f32: bool = True,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
|
||||
|
||||
UINT32_SENTINEL = 0xFFFFFFFF


def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map):
    """Derive ordering and validity tables from an (N, V) neighbour map.

    Entries equal to ``UINT32_SENTINEL`` are treated as missing neighbours.

    Returns:
        gray_code: (N,) long; bit ``v`` is set when ``neighbor_map[:, v]`` is valid.
        sorted_idx: (N,) argsort of the rows by their bitmask decoded as a Gray code.
        valid_signal_i: (T,) long row index of every valid entry, listed
            kernel-offset by kernel-offset (column-major).
        valid_signal_o: (T,) long neighbour index stored at those entries.
        seg: (V + 1,) long; ``seg[v]:seg[v + 1]`` is the slice of the two
            ``valid_signal_*`` tensors that belongs to kernel offset ``v``.
    """
    device = neighbor_map.device
    N, V = neighbor_map.shape

    sentinel = UINT32_SENTINEL

    valid = neighbor_map != sentinel  # (N, V) bool
    mask = valid.to(torch.long)

    # Pack each row's validity flags into one integer.
    gray_code = torch.zeros(N, dtype=torch.long, device=device)
    for v in range(V):
        gray_code |= (mask[:, v] << v)

    # Gray -> binary decode: XOR of every right shift of the code.
    binary_code = gray_code.clone()
    for v in range(1, V):
        binary_code ^= (gray_code >> v)

    sorted_idx = torch.argsort(binary_code)

    # Column-major flattening: position p maps to column p // N and row p % N.
    neigh_map_T = neighbor_map.t().reshape(-1)
    pos = torch.nonzero(valid.t().reshape(-1), as_tuple=True)[0]
    if pos.numel() > 0:
        # `pos` is already ascending, so the valid entries can be gathered
        # directly; no prefix-sum scatter is required.
        valid_signal_i = (pos % N).to(torch.long)
        valid_signal_o = neigh_map_T[pos].to(torch.long)
    else:
        valid_signal_i = torch.empty((0,), dtype=torch.long, device=device)
        valid_signal_o = torch.empty((0,), dtype=torch.long, device=device)

    # Segment boundaries are the running totals of valid entries per column.
    # Computed from per-column counts so an empty map (N == 0) is handled too.
    seg = torch.zeros((V + 1,), dtype=torch.long, device=device)
    if V > 0:
        seg[1:] = torch.cumsum(mask.sum(dim=0), dim=0)

    return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg
|
||||
|
||||
def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
x = x.to(torch.int64)
|
||||
|
||||
m1 = torch.tensor(0x5555555555555555, dtype=torch.int64, device=x.device)
|
||||
m2 = torch.tensor(0x3333333333333333, dtype=torch.int64, device=x.device)
|
||||
m4 = torch.tensor(0x0F0F0F0F0F0F0F0F, dtype=torch.int64, device=x.device)
|
||||
h01 = torch.tensor(0x0101010101010101, dtype=torch.int64, device=x.device)
|
||||
|
||||
x = x - ((x >> 1) & m1)
|
||||
x = (x & m2) + ((x >> 2) & m2)
|
||||
x = (x + (x >> 4)) & m4
|
||||
x = (x * h01) >> 56
|
||||
return x.to(torch.int32)
|
||||
|
||||
|
||||
def neighbor_map_post_process_for_masked_implicit_gemm_2(
    gray_code: torch.Tensor,
    sorted_idx: torch.Tensor,
    block_size: int
):
    """List, per block of rows, which kernel offsets have any valid neighbour.

    Rows are taken in ``sorted_idx`` order and grouped into consecutive blocks
    of ``block_size`` rows (the last block is padded with empty codes). The
    bitmasks of a block are OR-ed together; every set bit of the result is a
    kernel offset the block has to visit.

    Args:
        gray_code: (N,) integer tensor of per-row validity bitmasks
            (assumed non-negative).
        sorted_idx: (N,) row order.
        block_size: rows per block; must be positive.

    Returns:
        valid_kernel_idx: int32 tensor with the set bit positions of block 0
            (ascending), then block 1, and so on.
        seg: (num_blocks + 1,) int32; ``seg[b]:seg[b + 1]`` is block ``b``'s
            slice of ``valid_kernel_idx``.

    Raises:
        ValueError: if ``block_size`` is not positive.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    device = gray_code.device
    N = gray_code.numel()
    num_blocks = (N + block_size - 1) // block_size

    seg = torch.zeros((num_blocks + 1,), dtype=torch.int32, device=device)
    no_kernels = torch.empty((0,), dtype=torch.int32, device=device)
    if num_blocks == 0:
        return no_kernels, seg

    ordered = gray_code[sorted_idx]
    pad = num_blocks * block_size - N
    if pad > 0:
        # Pad with zero codes of the same dtype so the concatenation never
        # depends on implicit type promotion.
        ordered = torch.cat([ordered, ordered.new_zeros(pad)], dim=0)

    # OR-reduce every block by pairwise merging of adjacent columns.
    block_code = ordered.view(num_blocks, block_size)
    while block_code.shape[1] > 1:
        width = block_code.shape[1]
        even = width - (width % 2)
        merged = block_code[:, 0:even:2] | block_code[:, 1:even:2]
        if width % 2:
            # Odd width: carry the unpaired last column into the next round.
            merged = torch.cat([merged, block_code[:, -1:]], dim=1)
        block_code = merged
    block_code = block_code[:, 0]

    # Single host sync: the highest set bit bounds the number of bit positions.
    max_code = int(block_code.max().item())
    if max_code <= 0:
        return no_kernels, seg
    V = max_code.bit_length()

    bit_pos = torch.arange(0, V, dtype=torch.int32, device=device)
    bits = ((block_code.unsqueeze(1) >> bit_pos.unsqueeze(0)) & 1).to(torch.bool)  # (num_blocks, V)

    # Number of set bits per block -> running segment boundaries.
    seg[1:] = torch.cumsum(bits.sum(dim=1), dim=0)

    # Boolean indexing walks row-major: block by block, ascending bit position.
    positions = bit_pos.unsqueeze(0).expand(num_blocks, V)
    valid_kernel_idx = positions[bits].to(torch.int32).contiguous()

    return valid_kernel_idx, seg
|
||||
|
||||
|
||||
def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation):
|
||||
if NO_TRITON: # TODO
|
||||
raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.")
|
||||
if feats.shape[0] == 0:
|
||||
logging.warning("Found feats to be empty!")
|
||||
Co = weight.shape[0]
|
||||
return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None
|
||||
|
||||
if len(shape) == 5:
|
||||
N, C, W, H, D = shape
|
||||
_, _, W, H, D = shape
|
||||
else:
|
||||
W, H, D = shape
|
||||
|
||||
Co, Kw, Kh, Kd, Ci = weight.shape
|
||||
|
||||
b_stride = W * H * D
|
||||
x_stride = H * D
|
||||
y_stride = D
|
||||
z_stride = 1
|
||||
|
||||
flat_keys = (coords[:, 0].long() * b_stride +
|
||||
coords[:, 1].long() * x_stride +
|
||||
coords[:, 2].long() * y_stride +
|
||||
coords[:, 3].long() * z_stride)
|
||||
|
||||
vals = torch.arange(coords.shape[0], dtype=torch.int32, device=coords.device)
|
||||
|
||||
hashmap = TorchHashMap(flat_keys, vals, 0xFFFFFFFF)
|
||||
V = Kw * Kh * Kd
|
||||
device = feats.device
|
||||
sentinel = -1
|
||||
|
||||
if neighbor_cache is None:
|
||||
b_stride = W * H * D
|
||||
x_stride = H * D
|
||||
y_stride = D
|
||||
z_stride = 1
|
||||
|
||||
flat_keys = (coords[:, 0].long() * b_stride +
|
||||
coords[:, 1].long() * x_stride +
|
||||
coords[:, 2].long() * y_stride +
|
||||
coords[:, 3].long() * z_stride)
|
||||
vals = torch.arange(coords.shape[0], dtype=torch.int32, device=device)
|
||||
hashmap = TorchHashMap(flat_keys, vals, UINT32_SENTINEL)
|
||||
|
||||
neighbor = build_submanifold_neighbor_map(
|
||||
hashmap, coords, W, H, D, Kw, Kh, Kd,
|
||||
dilation[0], dilation[1], dilation[2]
|
||||
@ -380,30 +154,67 @@ def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache
|
||||
else:
|
||||
neighbor = neighbor_cache
|
||||
|
||||
block_size = 128
|
||||
N_pts = feats.shape[0]
|
||||
|
||||
gray_code, sorted_idx, valid_signal_i, valid_signal_o, valid_signal_seg = \
|
||||
neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor)
|
||||
if accumulate_f32:
|
||||
weight_T = weight.view(Co, V * Ci).to(torch.float32).T.contiguous()
|
||||
output = torch.zeros(N_pts, Co, device=device, dtype=torch.float32)
|
||||
else:
|
||||
weight_T = weight.view(Co, V * Ci).to(feats.dtype).T.contiguous()
|
||||
output = torch.zeros(N_pts, Co, device=device, dtype=feats.dtype)
|
||||
|
||||
valid_kernel, valid_kernel_seg = \
|
||||
neighbor_map_post_process_for_masked_implicit_gemm_2(gray_code, sorted_idx, block_size)
|
||||
# ------------------------------------------------------------------
|
||||
# Chunk size from memory budget
|
||||
# ------------------------------------------------------------------
|
||||
bytes_per_elem = 4 if accumulate_f32 else feats.element_size()
|
||||
mem_per_row = V * Ci * bytes_per_elem
|
||||
max_chunk_mem = max_chunk_mem_gb * (1024 ** 3)
|
||||
chunk_size = max(1, int(max_chunk_mem / mem_per_row))
|
||||
chunk_size = min(chunk_size, N_pts)
|
||||
|
||||
valid_kernel_fn = lambda b_size: valid_kernel
|
||||
valid_kernel_seg_fn = lambda b_size: valid_kernel_seg
|
||||
# ------------------------------------------------------------------
|
||||
# Chunked forward pass
|
||||
# Each iteration:
|
||||
# 1. gather (chunk, V, Ci) – memory bound
|
||||
# 2. mask zero invalids – in-place, no extra alloc
|
||||
# 3. reshape (chunk, V*Ci)
|
||||
# 4. GEMM (chunk, V*Ci) @ (V*Ci, Co) → (chunk, Co) – cuBLAS
|
||||
# written directly into output slice via out= argument
|
||||
# ------------------------------------------------------------------
|
||||
for start in range(0, N_pts, chunk_size):
|
||||
end = min(start + chunk_size, N_pts)
|
||||
actual_chunk = end - start
|
||||
|
||||
weight_flat = weight.contiguous().view(Co, -1, Ci)
|
||||
# (chunk, V) int32
|
||||
chunk_neighbor = neighbor[start:end]
|
||||
chunk_valid = chunk_neighbor != sentinel
|
||||
|
||||
out = sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk(
|
||||
feats,
|
||||
weight_flat,
|
||||
bias,
|
||||
neighbor,
|
||||
sorted_idx,
|
||||
valid_kernel_fn,
|
||||
valid_kernel_seg_fn
|
||||
)
|
||||
# Clamp sentinel -1 → 0 for safe indexing. No clone of the full map.
|
||||
chunk_idx = chunk_neighbor.clamp(min=0).long()
|
||||
|
||||
return out, neighbor
|
||||
# Gather: (chunk, V, Ci). Memory-bound, single index_select.
|
||||
gathered = feats[chunk_idx]
|
||||
|
||||
# Zero invalid neighbours in-place. gathered is a fresh tensor from
|
||||
# advanced indexing, so in-place mutation is safe.
|
||||
gathered.mul_(chunk_valid.unsqueeze(-1))
|
||||
|
||||
# Reshape to (chunk, V*Ci)
|
||||
gathered_flat = gathered.view(actual_chunk, V * Ci)
|
||||
if accumulate_f32:
|
||||
gathered_flat = gathered_flat.to(torch.float32)
|
||||
|
||||
# Single GEMM call per chunk, written directly into output.
|
||||
# This avoids allocating a temporary (chunk, Co) tensor.
|
||||
torch.matmul(gathered_flat, weight_T, out=output[start:end])
|
||||
|
||||
if accumulate_f32:
|
||||
output = output.to(feats.dtype)
|
||||
|
||||
if bias is not None:
|
||||
output = output + bias.unsqueeze(0).to(output.dtype)
|
||||
|
||||
return output, neighbor
|
||||
|
||||
class Mesh:
|
||||
def __init__(self,
|
||||
|
||||
@ -802,97 +802,127 @@ def compute_vertex_normals(verts, faces):
|
||||
|
||||
return torch.nn.functional.normalize(vertex_normals, p=2, dim=-1, eps=1e-6)
|
||||
|
||||
class PostProcessMesh(IO.ComfyNode):
|
||||
def _process_mesh_batch(mesh, per_item_fn):
    """Handles list/batched/single mesh dispatching, color extraction, and stacking.

    Works on a deep copy of ``mesh``. ``per_item_fn(vertices, faces, colors)`` is
    called once per mesh (``colors`` is ``None`` when the mesh has no
    ``vertex_colors``) and must return the processed ``(vertices, faces, colors)``.

    A mesh counts as batched when ``mesh.vertices`` is a list or a 3-D tensor.
    Batched results are stacked back into tensors when every item ends up with
    the same vertex and face shapes; otherwise they are stored as lists.

    Returns:
        ``IO.NodeOutput`` wrapping the processed copy.
    """
    mesh = copy.deepcopy(mesh)
    colors = getattr(mesh, 'vertex_colors', None)

    def process_single(v, f, c, bar):
        v, f, c = per_item_fn(v, f, c)
        bar.update(1)
        return v, f, c

    is_list = isinstance(mesh.vertices, list)
    is_batched_tensor = not is_list and mesh.vertices.ndim == 3

    if not (is_list or is_batched_tensor):
        # Single unbatched mesh.
        bar = comfy.utils.ProgressBar(1)
        v, f, c = process_single(mesh.vertices, mesh.faces, colors, bar)
        mesh.vertices = v
        mesh.faces = f
        if c is not None:
            mesh.vertex_colors = c
        return IO.NodeOutput(mesh)

    bsz = len(mesh.vertices) if is_list else mesh.vertices.shape[0]
    bar = comfy.utils.ProgressBar(bsz)
    # Colors are indexed per item only when they are batched themselves;
    # otherwise the same colors object is handed to every item.
    colors_per_item = colors is not None and (isinstance(colors, list) or colors.ndim == 3)

    out_v, out_f, out_c = [], [], []
    for i in range(bsz):
        c_i = None
        if colors is not None:
            c_i = colors[i] if colors_per_item else colors

        v_i, f_i, c_i = process_single(mesh.vertices[i], mesh.faces[i], c_i, bar)

        out_v.append(v_i)
        out_f.append(f_i)
        if c_i is not None:
            out_c.append(c_i)

    if not out_v:
        # Empty batch: nothing was processed and torch.stack([]) would raise.
        return IO.NodeOutput(mesh)

    same_shapes = (
        all(v.shape == out_v[0].shape for v in out_v)
        and all(f.shape == out_f[0].shape for f in out_f)
    )
    if same_shapes:
        mesh.vertices = torch.stack(out_v)
        mesh.faces = torch.stack(out_f)
        if out_c:
            mesh.vertex_colors = torch.stack(out_c)
    else:
        mesh.vertices = out_v
        mesh.faces = out_f
        if out_c:
            mesh.vertex_colors = out_c

    return IO.NodeOutput(mesh)
|
||||
|
||||
|
||||
class DecimateMesh(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="PostProcessMesh",
|
||||
display_name="Post Process Mesh",
|
||||
node_id="DecimateMesh",
|
||||
display_name="Decimate Mesh",
|
||||
category="latent/3d",
|
||||
description=(
|
||||
"Applies a sequence of mesh post-processing operations including optional hole filling"
|
||||
" and mesh simplification to a target face count."
|
||||
),
|
||||
description="Simplifies a mesh to a target face count using QEM.",
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
IO.Int.Input("target_face_count", default=1_000_000, min=0, max=50_000_000,
|
||||
tooltip="Target maximum number of faces after mesh simplification. Set to 0 to disable simplification."),
|
||||
IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0, step=0.0001,
|
||||
tooltip=(
|
||||
"Maximum hole perimeter threshold for filling holes in the mesh. "
|
||||
"Smaller values only fill tiny holes, larger values fill larger gaps. "
|
||||
"Set to 0 to disable hole filling."))
|
||||
IO.Int.Input("target_face_count", default=200_000, min=0, max=50_000_000,
|
||||
tooltip="Target maximum number of faces. Set to 0 to disable."),
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output("mesh"),
|
||||
]
|
||||
outputs=[IO.Mesh.Output("mesh")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh, target_face_count, fill_holes_perimeter):
|
||||
mesh = copy.deepcopy(mesh)
|
||||
|
||||
def process_single(v, f, c, bar):
|
||||
if fill_holes_perimeter > 0:
|
||||
v, f = fill_holes_fn(v, f, max_perimeter=fill_holes_perimeter)
|
||||
bar.update(1)
|
||||
|
||||
n = compute_vertex_normals(v, f)
|
||||
def execute(cls, mesh, target_face_count):
|
||||
def _fn(v, f, c):
|
||||
if target_face_count > 0 and f.shape[0] > target_face_count:
|
||||
n = compute_vertex_normals(v, f)
|
||||
v, f, c, _ = simplify_fn_fast(v, f, colors=c, normals=n, target=target_face_count)
|
||||
bar.update(1)
|
||||
|
||||
v, f, c = make_double_sided(v, f, c)
|
||||
bar.update(1)
|
||||
return v, f, c
|
||||
return _process_mesh_batch(mesh, _fn)
|
||||
|
||||
is_list = isinstance(mesh.vertices, list)
|
||||
is_batched_tensor = not is_list and mesh.vertices.ndim == 3
|
||||
|
||||
if is_list or is_batched_tensor:
|
||||
out_v, out_f, out_c = [], [],[]
|
||||
bsz = len(mesh.vertices) if is_list else mesh.vertices.shape[0]
|
||||
bar = comfy.utils.ProgressBar(3 * bsz)
|
||||
class FillHoles(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="FillHoles",
|
||||
display_name="Fill Holes",
|
||||
category="latent/3d",
|
||||
description="Fills holes in a mesh up to a maximum perimeter threshold.",
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
IO.Float.Input("max_perimeter", default=0.03, min=0.0, step=0.0001,
|
||||
tooltip="Maximum hole perimeter to fill. Set to 0 to disable."),
|
||||
],
|
||||
outputs=[IO.Mesh.Output("mesh")],
|
||||
)
|
||||
|
||||
for i in range(bsz):
|
||||
v_i = mesh.vertices[i]
|
||||
f_i = mesh.faces[i]
|
||||
@classmethod
|
||||
def execute(cls, mesh, max_perimeter):
|
||||
def _fn(v, f, c):
|
||||
if max_perimeter > 0:
|
||||
v, f = fill_holes_fn(v, f, max_perimeter=max_perimeter)
|
||||
return v, f, c
|
||||
return _process_mesh_batch(mesh, _fn)
|
||||
|
||||
# Safely grab colors if they exist
|
||||
c_i = None
|
||||
if hasattr(mesh, 'vertex_colors') and mesh.vertex_colors is not None:
|
||||
c_i = mesh.vertex_colors[i] if (isinstance(mesh.vertex_colors, list) or mesh.vertex_colors.ndim == 3) else mesh.vertex_colors
|
||||
|
||||
v_i, f_i, c_i = process_single(v_i, f_i, c_i, bar)
|
||||
class MakeDoubleSided(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MakeDoubleSided",
|
||||
display_name="Make Double Sided",
|
||||
category="latent/3d",
|
||||
description="Duplicates faces with flipped normals so the mesh renders from both sides.",
|
||||
inputs=[IO.Mesh.Input("mesh")],
|
||||
outputs=[IO.Mesh.Output("mesh")],
|
||||
)
|
||||
|
||||
out_v.append(v_i)
|
||||
out_f.append(f_i)
|
||||
if c_i is not None:
|
||||
out_c.append(c_i)
|
||||
|
||||
# If the output meshes happen to have the exact same shape, stack them nicely.
|
||||
# Otherwise, just leave them as a List! (ComfyUI native standard)
|
||||
if all(v.shape == out_v[0].shape for v in out_v) and all(f.shape == out_f[0].shape for f in out_f):
|
||||
mesh.vertices = torch.stack(out_v)
|
||||
mesh.faces = torch.stack(out_f)
|
||||
if out_c:
|
||||
mesh.vertex_colors = torch.stack(out_c)
|
||||
else:
|
||||
mesh.vertices = out_v
|
||||
mesh.faces = out_f
|
||||
if out_c:
|
||||
mesh.vertex_colors = out_c
|
||||
|
||||
else:
|
||||
# Single Unbatched Mesh[V, 3]
|
||||
c = mesh.vertex_colors if hasattr(mesh, 'vertex_colors') and mesh.vertex_colors is not None else None
|
||||
v, f, c = process_single(mesh.vertices, mesh.faces, c)
|
||||
mesh.vertices = v
|
||||
mesh.faces = f
|
||||
if c is not None:
|
||||
mesh.vertex_colors = c
|
||||
|
||||
return IO.NodeOutput(mesh)
|
||||
@classmethod
|
||||
def execute(cls, mesh):
|
||||
def _fn(v, f, c):
|
||||
return make_double_sided(v, f, c)
|
||||
return _process_mesh_batch(mesh, _fn)
|
||||
|
||||
|
||||
|
||||
@ -900,7 +930,9 @@ class PostProcessMeshExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
PostProcessMesh,
|
||||
MakeDoubleSided,
|
||||
FillHoles,
|
||||
DecimateMesh,
|
||||
PaintMesh
|
||||
]
|
||||
|
||||
|
||||
@ -8,7 +8,6 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
|
||||
HighResVoxel = io.Custom("HIGH_RES_VOXEL")
|
||||
|
||||
def prepare_trellis_vae_for_decode(vae, sample_shape):
|
||||
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
|
||||
@ -297,7 +296,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode):
|
||||
))
|
||||
],
|
||||
outputs=[
|
||||
HighResVoxel.Output(
|
||||
IO.Voxel.Output(
|
||||
"high_res_voxel",
|
||||
tooltip=(
|
||||
"High-resolution sparse coordinates produced after cascade upsampling. "
|
||||
@ -389,11 +388,11 @@ class Trellis2UpsampleCascade(IO.ComfyNode):
|
||||
final_coords_list.append(final_coords_i)
|
||||
output_coord_counts.append(int(final_coords_i.shape[0]))
|
||||
|
||||
output = {
|
||||
"coords": torch.cat(final_coords_list, dim=0),
|
||||
"coord_counts": torch.tensor(output_coord_counts, dtype=torch.int64),
|
||||
"resolutions": torch.full((len(final_coords_list),), int(hr_resolution), dtype=torch.int64),
|
||||
}
|
||||
coords = torch.cat(final_coords_list, dim=0)
|
||||
output = Types.VOXEL(coords)
|
||||
output.coord_counts = torch.tensor(output_coord_counts, dtype=torch.int64)
|
||||
output.resolutions = torch.full((len(final_coords_list),), int(hr_resolution), dtype=torch.int64)
|
||||
output.upsampled = True
|
||||
|
||||
return IO.NodeOutput(output,)
|
||||
|
||||
@ -537,9 +536,8 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
|
||||
node_id="EmptyTrellis2ShapeLatent",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
IO.Voxel.Input(
|
||||
"voxel",
|
||||
types=[IO.Voxel, HighResVoxel],
|
||||
tooltip=(
|
||||
"Shape structure input. Accepts either a voxel structure "
|
||||
"or upsampled voxel coordinates from a previous cascade stage."
|
||||
@ -555,20 +553,18 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
|
||||
def execute(cls, voxel):
|
||||
# to accept the upscaled coords
|
||||
is_512_pass = False
|
||||
upsampled = hasattr(voxel, "upsampled")
|
||||
if upsampled:
|
||||
voxel = voxel.data
|
||||
|
||||
if isinstance(voxel, dict):
|
||||
voxel = voxel["coords"]
|
||||
|
||||
if hasattr(voxel, "data") and voxel.data.ndim == 4:
|
||||
if not upsampled:
|
||||
decoded = voxel.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
is_512_pass = True
|
||||
|
||||
elif isinstance(voxel, torch.Tensor) and voxel.ndim == 2:
|
||||
else:
|
||||
coords = voxel.int()
|
||||
is_512_pass = False
|
||||
else:
|
||||
raise ValueError(f"Invalid input to EmptyTrellis2ShapeLatent: {type(voxel)}")
|
||||
|
||||
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
|
||||
in_channels = 32
|
||||
@ -589,9 +585,8 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
|
||||
node_id="EmptyTrellis2LatentTexture",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
IO.Voxel.Input(
|
||||
"voxel",
|
||||
types=[IO.Voxel, HighResVoxel],
|
||||
tooltip=(
|
||||
"Shape structure input. Accepts either a voxel structure "
|
||||
"or upsampled voxel coordinates from a previous cascade stage."
|
||||
@ -607,13 +602,14 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, voxel, shape_latent):
|
||||
channels = 32
|
||||
if isinstance(voxel, dict):
|
||||
voxel = voxel["coords"]
|
||||
if hasattr(voxel, "data") and voxel.data.ndim == 4:
|
||||
upsampled = hasattr(voxel, "upsampled")
|
||||
if upsampled:
|
||||
voxel = voxel.data
|
||||
|
||||
if not upsampled:
|
||||
decoded = voxel.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
|
||||
elif isinstance(voxel, torch.Tensor) and voxel.ndim == 2:
|
||||
else:
|
||||
coords = voxel.int()
|
||||
|
||||
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user