From f2c0320fe84e533ad1e0173e0eb25c934027b216 Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Thu, 5 Feb 2026 17:19:57 +0200
Subject: [PATCH] fixes to vae and cumesh impl.

---
Review notes (free text between the three-dash line and the diffstat is
not part of the commit message):

 * Line breaks and indentation were reconstructed from a
   whitespace-collapsed copy of this patch. Hunk line counts and the
   diffstat were re-checked and still match, but context indentation is a
   best guess -- run `git apply --check` before relying on it.
 * cumesh.py, module top: `torch.cuda.is_tf32_supported` is now actually
   called. Without the parentheses `allow_tf32` was a function object
   (always truthy) instead of a bool.
 * cumesh.py, Mesh.fill_holes: the packed edge key is built in int64 from
   the already-sorted edges. Packing `v0 * num_vertices + v1` in the
   dtype of `faces` can overflow if faces are int32, and the edges were
   being sorted two extra times.
 * cumesh.py, Mesh.fill_holes: cap triangles are emitted as
   (v_next, v_curr, centre). Boundary edges are taken in face winding
   order, so the cap has to traverse them in the opposite direction to
   keep a consistent orientation with the neighbouring faces.
 * The three edits above change added lines in place (no line added or
   removed), so the post-image blob id on the cumesh.py `index` line is
   stale.

 comfy/ldm/trellis2/cumesh.py | 189 ++++++++++++++++++++++++++---------
 comfy/ldm/trellis2/vae.py    |   2 +-
 2 files changed, 143 insertions(+), 48 deletions(-)

diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py
index 41ac35db9..fe7e80e15 100644
--- a/comfy/ldm/trellis2/cumesh.py
+++ b/comfy/ldm/trellis2/cumesh.py
@@ -5,6 +5,10 @@ import torch
 from typing import Dict, Callable
 
 NO_TRITION = False
+try:
+    allow_tf32 = torch.cuda.is_tf32_supported()  # must be called: the bare attribute is a function, not a bool
+except Exception:
+    allow_tf32 = False
 try:
     import triton
     import triton.language as tl
@@ -102,10 +106,13 @@ try:
         grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),)
         sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid](
             input, weight, bias, neighbor, sorted_idx, output,
-            N, LOGN, Ci, Co, V,  #
+            N, LOGN, Ci, Co, V,
+            B1=128,
+            B2=64,
+            BK=32,
             valid_kernel=valid_kernel,
             valid_kernel_seg=valid_kernel_seg,
-            allow_tf32=torch.cuda.is_tf32_supported(),
+            allow_tf32=allow_tf32,
         )
         return output
 except:
@@ -140,16 +147,16 @@ def build_submanifold_neighbor_map(
 
     neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.long)
 
-    b = coords[:, 0]
-    x = coords[:, 1]
-    y = coords[:, 2]
-    z = coords[:, 3]
+    b = coords[:, 0].long()
+    x = coords[:, 1].long()
+    y = coords[:, 2].long()
+    z = coords[:, 3].long()
 
     offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device)
 
-    ox = x[:, None] - (Kw // 2) * Dw
-    oy = y[:, None] - (Kh // 2) * Dh
-    oz = z[:, None] - (Kd // 2) * Dd
+    ox = x - (Kw // 2) * Dw
+    oy = y - (Kh // 2) * Dh
+    oz = z - (Kd // 2) * Dd
 
     for v in range(half_V):
         if v == half_V - 1:
@@ -158,10 +165,11 @@ def build_submanifold_neighbor_map(
 
         dx, dy, dz = offsets[v]
 
-        kx = ox[:, v] + dx
-        ky = oy[:, v] + dy
-        kz = oz[:, v] + dz
+        kx = ox + dx
+        ky = oy + dy
+        kz = oz + dz
 
+        # Check spatial bounds
         valid = (
             (kx >= 0) & (kx < W) &
             (ky >= 0) & (ky < H) &
@@ -169,22 +177,22 @@ def build_submanifold_neighbor_map(
         )
 
         flat = (
-            b * (W * H * D) +
-            kx * (H * D) +
-            ky * D +
-            kz
+            b[valid] * (W * H * D) +
+            kx[valid] * (H * D) +
+            ky[valid] * D +
+            kz[valid]
         )
 
-        flat = flat[valid]
-        idx = torch.nonzero(valid, as_tuple=False).squeeze(1)
+        if flat.numel() > 0:
+            found = hashmap.lookup_flat(flat)
+            idx_in_M = torch.where(valid)[0]
+            neighbor[idx_in_M, v] = found
 
-        found = hashmap.lookup_flat(flat)
-
-        neighbor[idx, v] = found
-
-        # symmetric write
-        valid_found = found != INVALID
-        neighbor[found[valid_found], V - 1 - v] = idx[valid_found]
+            valid_found_mask = (found != INVALID)
+            if valid_found_mask.any():
+                src_points = idx_in_M[valid_found_mask]
+                dst_points = found[valid_found_mask]
+                neighbor[dst_points, V - 1 - v] = src_points
 
     return neighbor
 
@@ -461,31 +469,118 @@ class Mesh:
     def cpu(self):
         return self.to('cpu')
 
-    # TODO could be an option
+    # could make this into a new node
     def fill_holes(self, max_hole_perimeter=3e-2):
-        import cumesh
-        vertices = self.vertices.cuda()
-        faces = self.faces.cuda()
-        mesh = cumesh.CuMesh()
-        mesh.init(vertices, faces)
-        mesh.get_edges()
-        mesh.get_boundary_info()
-        if mesh.num_boundaries == 0:
-            return
-        mesh.get_vertex_edge_adjacency()
-        mesh.get_vertex_boundary_adjacency()
-        mesh.get_manifold_boundary_adjacency()
-        mesh.read_manifold_boundary_adjacency()
-        mesh.get_boundary_connected_components()
-        mesh.get_boundary_loops()
-        if mesh.num_boundary_loops == 0:
-            return
-        mesh.fill_holes(max_hole_perimeter=max_hole_perimeter)
-        new_vertices, new_faces = mesh.read()
+        device = self.vertices.device
+        vertices = self.vertices
+        faces = self.faces
+
+        edges = torch.cat([
+            faces[:, [0, 1]],
+            faces[:, [1, 2]],
+            faces[:, [2, 0]]
+        ], dim=0)
+
+        edges_sorted, _ = torch.sort(edges, dim=1)
+
+        unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True)
+
+        boundary_mask = counts == 1
+        boundary_edges_sorted = unique_edges[boundary_mask]
+
+        if boundary_edges_sorted.shape[0] == 0:
+            return
 
+        max_idx = vertices.shape[0]
+
+        _, inverse_indices, counts_packed = torch.unique(
+            edges_sorted[:, 0].long() * max_idx + edges_sorted[:, 1].long(),  # int64 key (no int32 overflow); reuses the sort above
+            return_inverse=True, return_counts=True
+        )
+
+        boundary_packed_mask = counts_packed == 1
+        is_boundary_edge = boundary_packed_mask[inverse_indices]
+
+        active_boundary_edges = edges[is_boundary_edge]
+
+        adj = {}
+        edges_np = active_boundary_edges.cpu().numpy()
+        for u, v in edges_np:
+            adj[u] = v
+
+        loops = []
+        visited_edges = set()
+
+        possible_starts = list(adj.keys())
+
+        processed_nodes = set()
+
+        for start_node in possible_starts:
+            if start_node in processed_nodes:
+                continue
+
+            current_loop = []
+            curr = start_node
+
+            while curr in adj:
+                next_node = adj[curr]
+                if (curr, next_node) in visited_edges:
+                    break
+
+                visited_edges.add((curr, next_node))
+                processed_nodes.add(curr)
+                current_loop.append(curr)
+
+                curr = next_node
+
+                if curr == start_node:
+                    loops.append(current_loop)
+                    break
+
+                if len(current_loop) > len(edges_np):
+                    break
+
+        if not loops:
+            return
+
+        new_faces = []
+
+        v_offset = vertices.shape[0]
+
+        valid_new_verts = []
+
+        for loop_indices in loops:
+            if len(loop_indices) < 3:
+                continue
+
+            loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device)
+            loop_verts = vertices[loop_tensor]
+
+            diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0)
+            perimeter = torch.norm(diffs, dim=1).sum()
+
+            if perimeter > max_hole_perimeter:
+                continue
+
+            center = loop_verts.mean(dim=0)
+            valid_new_verts.append(center)
+
+            c_idx = v_offset
+            v_offset += 1
+
+            num_v = len(loop_indices)
+            for i in range(num_v):
+                v_curr = loop_indices[i]
+                v_next = loop_indices[(i + 1) % num_v]
+                new_faces.append([v_next, v_curr, c_idx])  # reversed vs. the boundary edge so the cap's winding matches its neighbours
+
+        if len(valid_new_verts) > 0:
+            added_vertices = torch.stack(valid_new_verts, dim=0)
+            added_faces = torch.tensor(new_faces, dtype=torch.long, device=device)
+
+            self.vertices = torch.cat([self.vertices, added_vertices], dim=0)
+            self.faces = torch.cat([self.faces, added_faces], dim=0)
 
-        self.vertices = new_vertices.to(self.device)
-        self.faces = new_faces.to(self.device)
 
     # TODO could be an option
     def simplify(self, target=1000000, verbose: bool=False, options: dict={}):
diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py
index 584fa91ae..2bbfa938c 100644
--- a/comfy/ldm/trellis2/vae.py
+++ b/comfy/ldm/trellis2/vae.py
@@ -208,7 +208,7 @@ class SparseResBlockC2S3d(nn.Module):
             self.to_subdiv = SparseLinear(channels, 8)
         self.updown = SparseChannel2Spatial(2)
 
-    def _forward(self, x, subdiv = None):
+    def forward(self, x, subdiv = None):
         if self.pred_subdiv:
             subdiv = self.to_subdiv(x)
         h = x.replace(self.norm1(x.feats))