Fixes to the VAE and cumesh implementations.

This commit is contained in:
Yousef Rafat 2026-02-05 17:19:57 +02:00
parent cd0f7ba64e
commit f2c0320fe8
2 changed files with 143 additions and 48 deletions

View File

@ -5,6 +5,10 @@ import torch
from typing import Dict, Callable
NO_TRITION = False
# Probe TF32 support once at import time so the kernel launch can pass a plain
# boolean flag.
try:
    # The probe must be *called*: the bare function object is always truthy
    # and is not a valid boolean for the kernel's `allow_tf32` argument.
    allow_tf32 = bool(torch.cuda.is_tf32_supported())
except Exception:
    # Best-effort fallback: if the probe is missing from this torch build or
    # raises for any reason, run without TF32.
    allow_tf32 = False
try:
import triton
import triton.language as tl
@ -102,10 +106,13 @@ try:
grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),)
sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid](
input, weight, bias, neighbor, sorted_idx, output,
N, LOGN, Ci, Co, V, #
N, LOGN, Ci, Co, V,
B1=128,
B2=64,
BK=32,
valid_kernel=valid_kernel,
valid_kernel_seg=valid_kernel_seg,
allow_tf32=torch.cuda.is_tf32_supported(),
allow_tf32=allow_tf32,
)
return output
except:
@ -140,16 +147,16 @@ def build_submanifold_neighbor_map(
neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.long)
b = coords[:, 0]
x = coords[:, 1]
y = coords[:, 2]
z = coords[:, 3]
b = coords[:, 0].long()
x = coords[:, 1].long()
y = coords[:, 2].long()
z = coords[:, 3].long()
offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device)
ox = x[:, None] - (Kw // 2) * Dw
oy = y[:, None] - (Kh // 2) * Dh
oz = z[:, None] - (Kd // 2) * Dd
ox = x - (Kw // 2) * Dw
oy = y - (Kh // 2) * Dh
oz = z - (Kd // 2) * Dd
for v in range(half_V):
if v == half_V - 1:
@ -158,10 +165,11 @@ def build_submanifold_neighbor_map(
dx, dy, dz = offsets[v]
kx = ox[:, v] + dx
ky = oy[:, v] + dy
kz = oz[:, v] + dz
kx = ox + dx
ky = oy + dy
kz = oz + dz
# Check spatial bounds
valid = (
(kx >= 0) & (kx < W) &
(ky >= 0) & (ky < H) &
@ -169,22 +177,22 @@ def build_submanifold_neighbor_map(
)
flat = (
b * (W * H * D) +
kx * (H * D) +
ky * D +
kz
b[valid] * (W * H * D) +
kx[valid] * (H * D) +
ky[valid] * D +
kz[valid]
)
flat = flat[valid]
idx = torch.nonzero(valid, as_tuple=False).squeeze(1)
if flat.numel() > 0:
found = hashmap.lookup_flat(flat)
idx_in_M = torch.where(valid)[0]
neighbor[idx_in_M, v] = found
found = hashmap.lookup_flat(flat)
neighbor[idx, v] = found
# symmetric write
valid_found = found != INVALID
neighbor[found[valid_found], V - 1 - v] = idx[valid_found]
valid_found_mask = (found != INVALID)
if valid_found_mask.any():
src_points = idx_in_M[valid_found_mask]
dst_points = found[valid_found_mask]
neighbor[dst_points, V - 1 - v] = src_points
return neighbor
@ -461,31 +469,118 @@ class Mesh:
def cpu(self):
return self.to('cpu')
# TODO could be an option
# could make this into a new node
def fill_holes(self, max_hole_perimeter=3e-2):
    """Close small boundary loops of the mesh with triangle fans, in place.

    A boundary edge is an edge used by exactly one face. Boundary edges are
    chained into closed loops by following each edge in the direction its
    face traverses it. Every loop with at least three vertices whose
    perimeter does not exceed ``max_hole_perimeter`` is closed by adding one
    vertex at the mean of the loop's vertices and one triangle per loop edge.

    Args:
        max_hole_perimeter: Largest loop perimeter (sum of edge lengths, in
            the units of ``self.vertices``) that is still filled.

    Returns:
        None. ``self.vertices`` and ``self.faces`` are replaced by extended
        tensors when at least one hole is filled; otherwise they are left
        untouched.
    """
    vertices = self.vertices
    faces = self.faces

    if faces.numel() == 0:
        return

    # Directed half-edges of every face. Work in int64 so that the packed
    # edge key below cannot overflow when faces are stored as int32.
    faces_long = faces.long()
    edges = torch.cat([
        faces_long[:, [0, 1]],
        faces_long[:, [1, 2]],
        faces_long[:, [2, 0]],
    ], dim=0)

    # Pack each undirected edge (min, max) into a single integer key and count
    # how many faces use it; a count of one marks a boundary edge.
    ordered = torch.sort(edges, dim=1).values
    keys = ordered[:, 0] * vertices.shape[0] + ordered[:, 1]
    _, inverse, counts = torch.unique(keys, return_inverse=True, return_counts=True)
    boundary_edges = edges[counts[inverse] == 1]
    if boundary_edges.shape[0] == 0:
        return

    # Successor map along the boundary. If a vertex has several outgoing
    # boundary edges (non-manifold boundary), the last one listed wins.
    successor = {}
    for u, v in boundary_edges.tolist():
        successor[u] = v

    # Walk the successor map; only walks that return to their start vertex
    # are kept as loops. Each vertex is visited at most once, so the walk
    # always terminates.
    loops = []
    visited = set()
    for start in successor:
        if start in visited:
            continue
        loop = []
        node = start
        while node in successor and node not in visited:
            visited.add(node)
            loop.append(node)
            node = successor[node]
        if node == start:
            loops.append(loop)

    if not loops:
        return

    new_vertices = []
    new_faces = []
    next_index = vertices.shape[0]
    for loop in loops:
        if len(loop) < 3:
            continue

        ring = torch.tensor(loop, dtype=torch.long, device=vertices.device)
        ring_xyz = vertices[ring]
        perimeter = torch.norm(ring_xyz - torch.roll(ring_xyz, shifts=-1, dims=0), dim=1).sum()
        if perimeter > max_hole_perimeter:
            continue

        new_vertices.append(ring_xyz.mean(dim=0))
        # The existing face traverses each boundary edge as (a -> b); the fill
        # triangle must traverse it as (b -> a) so that its winding agrees
        # with the neighbouring face.
        following = loop[1:] + loop[:1]
        new_faces.extend([b, a, next_index] for a, b in zip(loop, following))
        next_index += 1

    if not new_vertices:
        return

    added_vertices = torch.stack(new_vertices, dim=0)
    # Keep the dtype and device of the existing face tensor.
    added_faces = torch.tensor(new_faces, dtype=faces.dtype, device=faces.device)
    self.vertices = torch.cat([vertices, added_vertices], dim=0)
    self.faces = torch.cat([faces, added_faces], dim=0)
# TODO could be an option
def simplify(self, target=1000000, verbose: bool=False, options: dict={}):

View File

@ -208,7 +208,7 @@ class SparseResBlockC2S3d(nn.Module):
self.to_subdiv = SparseLinear(channels, 8)
self.updown = SparseChannel2Spatial(2)
def _forward(self, x, subdiv = None):
def forward(self, x, subdiv = None):
if self.pred_subdiv:
subdiv = self.to_subdiv(x)
h = x.replace(self.norm1(x.feats))