From e90bde2f82c0234f5f8f7fdd5ff6390763a4a62c Mon Sep 17 00:00:00 2001 From: "Yousef R. Gamaleldin" <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 22 May 2026 00:11:48 +0300 Subject: [PATCH] Vertex Clustering, Mask Fix, Normal Fix (#14035) * Vertex Clustering, Mask Fix, Normal Fix * detects inverted mask * update the decimate mesh --- comfy_extras/nodes_mesh_postprocess.py | 341 +++++++++++++++++++++++-- comfy_extras/nodes_trellis2.py | 57 ++++- 2 files changed, 371 insertions(+), 27 deletions(-) diff --git a/comfy_extras/nodes_mesh_postprocess.py b/comfy_extras/nodes_mesh_postprocess.py index 74eb9e0c1..9cc77a206 100644 --- a/comfy_extras/nodes_mesh_postprocess.py +++ b/comfy_extras/nodes_mesh_postprocess.py @@ -625,7 +625,6 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, normals_in, target_faces, # Finalize final_v = verts[v_alive] final_c = colors[v_alive] if colors is not None else None - final_n = normals[v_alive] if normals is not None else None remap = torch.full((num_verts,), -1, dtype=torch.int64, device=device) remap[v_alive] = torch.arange(int(v_alive.sum().item()), device=device) @@ -640,26 +639,9 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, normals_in, target_faces, if final_f.numel() > 0: final_f = torch.unique(torch.sort(final_f, dim=1)[0], dim=0) - if final_n is not None and final_f.numel() > 0: - v0, v1, v2 = final_v[final_f[:, 0]], final_v[final_f[:, 1]], final_v[final_f[:, 2]] - - # calculate the actual normal of the simplified faces - face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) - - # Get the average reference normal for each face - n0, n1, n2 = final_n[final_f[:, 0]], final_n[final_f[:, 1]], final_n[final_f[:, 2]] - ref_face_normals = (n0 + n1 + n2) / 3.0 - - # Dot product to check if they point in the same direction - dot_products = (face_normals * ref_face_normals).sum(dim=-1) - - # Flip the indices of ONLY the incorrect faces (swap vertex 1 and 2) - wrong_way_mask = dot_products < 0 - final_f[wrong_way_mask] = final_f[wrong_way_mask][:, [0, 2, 1]] - final_v, final_f = _cleanup_mesh(final_v, final_f, min_angle_deg=0.5, max_aspect=100.0) - return final_v, final_f, final_c, final_n + return final_v, final_f, final_c, None def simplify_fn_fast(vertices, faces, colors=None, normals=None, target=100000, max_edge_length=None): @@ -709,6 +691,79 @@ def simplify_fn_fast(vertices, faces, colors=None, normals=None, target=100000, ) return final_v, final_f, final_c, final_n +def simplify_fn_vertex(vertices, faces, colors=None, target=100000): + if vertices.ndim == 3: + v_list, f_list, c_list = [], [], [] + for i in range(vertices.shape[0]): + c_in = colors[i] if colors is not None else None + v_i, f_i, c_i = simplify_fn_vertex(vertices[i], faces[i], c_in, target) + v_list.append(v_i) + f_list.append(f_i) + if c_i is not None: + c_list.append(c_i) + + c_out = torch.stack(c_list) if len(c_list) > 0 else None + return torch.stack(v_list), torch.stack(f_list), c_out + + if faces.shape[0] <= target: + return vertices, faces, colors + + device = vertices.device + target_v = max(target / 4.0, 1.0) + + min_v = vertices.min(dim=0)[0] + max_v = vertices.max(dim=0)[0] + extent = max_v - min_v + + volume = (extent[0] * extent[1] * extent[2]).clamp(min=1e-8) + cell_size = (volume / target_v) ** (1/3.0) + + # Use CPU-side ordered reductions here so repeated runs produce identical + # simplified meshes instead of relying on GPU scatter-add accumulation order. + vertices_np = vertices.detach().cpu().numpy() + faces_np = faces.detach().cpu().numpy() + colors_np = colors.detach().cpu().numpy() if colors is not None else None + min_v_np = min_v.detach().cpu().numpy() + cell_size_value = float(cell_size.detach().cpu()) + + quantized = np.rint((vertices_np - min_v_np) / cell_size_value).astype(np.int64) + unique_coords, inverse_indices = np.unique(quantized, axis=0, return_inverse=True) + num_cells = unique_coords.shape[0] + + new_vertices_np = np.zeros((num_cells, 3), dtype=vertices_np.dtype) + np.add.at(new_vertices_np, inverse_indices, vertices_np) + + counts_np = np.bincount(inverse_indices, minlength=num_cells).astype(vertices_np.dtype).reshape(-1, 1) + new_vertices_np = new_vertices_np / np.clip(counts_np, 1, None) + + new_colors = None + if colors_np is not None: + new_colors_np = np.zeros((num_cells, colors_np.shape[1]), dtype=colors_np.dtype) + np.add.at(new_colors_np, inverse_indices, colors_np) + new_colors = new_colors_np / np.clip(counts_np, 1, None) + + new_faces = inverse_indices[faces_np] + valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \ + (new_faces[:, 1] != new_faces[:, 2]) & \ + (new_faces[:, 2] != new_faces[:, 0]) + new_faces = new_faces[valid_mask] + + if new_faces.size == 0: + final_vertices_np = new_vertices_np[:0] + final_faces_np = np.empty((0, 3), dtype=np.int64) + final_colors_np = new_colors[:0] if new_colors is not None else None + else: + unique_face_indices, inv_face = np.unique(new_faces.reshape(-1), return_inverse=True) + final_vertices_np = new_vertices_np[unique_face_indices] + final_faces_np = inv_face.reshape(-1, 3).astype(np.int64) + final_colors_np = new_colors[unique_face_indices] if new_colors is not None else None + + final_vertices = torch.from_numpy(final_vertices_np).to(device=device, dtype=vertices.dtype) + final_faces = torch.from_numpy(final_faces_np).to(device=device, dtype=faces.dtype) + final_colors = torch.from_numpy(final_colors_np).to(device=device, dtype=colors.dtype) if final_colors_np is not None else None + + return final_vertices, final_faces, final_colors + def compute_vertex_normals(verts, faces): """Computes area-weighted vertex normals.""" # QUICK FIX: Ensure indices are int64 for scatter_add_ @@ -781,6 +836,235 @@ def _process_mesh_batch(mesh, per_item_fn): return IO.NodeOutput(mesh) +def fix_face_orientation(vertices, faces, reference_normals=None): + num_faces = faces.shape[0] + if num_faces == 0: + return faces + + device = faces.device + corrected = faces.clone() + + idx = torch.tensor([[0, 1], [1, 2], [2, 0]], dtype=torch.int64, device=device) + edges = corrected[:, idx] # (num_faces, 3, 2) + + edges_canon = torch.sort(edges, dim=2)[0] + edges_flat = edges_canon.view(-1, 2) + + max_vert = vertices.shape[0] + edge_hash = edges_flat[:, 0] * max_vert + edges_flat[:, 1] + + hash_sorted, sort_idx = torch.sort(edge_hash) + + hash_diff = hash_sorted[1:] != hash_sorted[:-1] + hash_diff = torch.cat([torch.tensor([True], device=device), hash_diff]) + unique_starts = torch.nonzero(hash_diff, as_tuple=True)[0] + unique_ends = torch.cat([unique_starts[1:], torch.tensor([len(hash_sorted)], device=device)]) + run_lengths = unique_ends - unique_starts + + manifold_mask = run_lengths == 2 + manifold_starts = unique_starts[manifold_mask] + + component_id_np = np.full(num_faces, -1, dtype=np.int64) + + if manifold_starts.numel() > 0: + # Replaces slow, nested element-wise matching with direct index mapping + f_a = sort_idx[manifold_starts] // 3 + f_b = sort_idx[manifold_starts + 1] // 3 + local_edge_a = sort_idx[manifold_starts] % 3 + local_edge_b = sort_idx[manifold_starts + 1] % 3 + + dir_edge_a = edges[f_a, local_edge_a] + dir_edge_b = edges[f_b, local_edge_b] + + opposite = (dir_edge_a == dir_edge_b.flip(dims=[1])).all(dim=1) + needs_flip_rel = ~opposite + + adj_faces = torch.cat([f_a, f_b]) + adj_neighbors = torch.cat([f_b, f_a]) + adj_flip = torch.cat([needs_flip_rel, needs_flip_rel]) + + adj_order = torch.argsort(adj_faces) + adj_faces_np = adj_faces[adj_order].cpu().numpy() + adj_neighbors_np = adj_neighbors[adj_order].cpu().numpy() + adj_flip_np = adj_flip[adj_order].cpu().numpy() + + # Build CSR-style adjacency on CPU using NumPy + adj_ptr_np = np.zeros(num_faces + 1, dtype=np.int64) + counts_np = np.bincount(adj_faces_np, minlength=num_faces) + adj_ptr_np[1:] = np.cumsum(counts_np) + + visited_np = np.zeros(num_faces, dtype=bool) + flip_state_np = np.zeros(num_faces, dtype=bool) + comp_counter = 0 + + queue_np = np.empty(num_faces, dtype=np.int64) + + for seed in range(num_faces): + if visited_np[seed]: + continue + + visited_np[seed] = True + component_id_np[seed] = comp_counter + q_head = 0 + q_tail = 1 + queue_np[0] = seed + + while q_head < q_tail: + current = queue_np[q_head] + q_head += 1 + + start = adj_ptr_np[current] + end = adj_ptr_np[current + 1] + if start == end: + continue + + nbrs = adj_neighbors_np[start:end] + flips = adj_flip_np[start:end] + src_flip = flip_state_np[current] + + unvisited_mask = ~visited_np[nbrs] + if not np.any(unvisited_mask): + continue + + nbrs_new = nbrs[unvisited_mask] + flips_new = flips[unvisited_mask] + + visited_np[nbrs_new] = True + component_id_np[nbrs_new] = comp_counter + + # NumPy bitwise XOR is fast and direct + flip_state_np[nbrs_new] = flips_new ^ src_flip + + n_new = len(nbrs_new) + queue_np[q_tail:q_tail + n_new] = nbrs_new + q_tail += n_new + + comp_counter += 1 + + flip_state = torch.from_numpy(flip_state_np).to(device=device) + component_id = torch.from_numpy(component_id_np).to(device=device) + + if flip_state.any(): + corrected[flip_state] = corrected[flip_state][:, [0, 2, 1]] + else: + component_id = torch.arange(num_faces, device=device) + + v0 = vertices[corrected[:, 0]] + v1 = vertices[corrected[:, 1]] + v2 = vertices[corrected[:, 2]] + + face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1) + face_normals = face_normals / (torch.norm(face_normals, dim=-1, keepdim=True) + 1e-8) + + num_components = int(component_id.max().item()) + 1 if component_id.numel() > 0 else 0 + + if reference_normals is not None: + n0 = reference_normals[corrected[:, 0]] + n1 = reference_normals[corrected[:, 1]] + n2 = reference_normals[corrected[:, 2]] + ref_normals = (n0 + n1 + n2) / 3.0 + ref_normals = ref_normals / (torch.norm(ref_normals, dim=-1, keepdim=True) + 1e-8) + + votes = (face_normals * ref_normals).sum(dim=-1) + + outward_votes_comp = torch.zeros(num_components, dtype=torch.int64, device=device) + inward_votes_comp = torch.zeros(num_components, dtype=torch.int64, device=device) + + outward_votes_comp.scatter_add_(0, component_id, (votes > 0).to(torch.int64)) + inward_votes_comp.scatter_add_(0, component_id, (votes < 0).to(torch.int64)) + + n_faces_comp_int = torch.zeros(num_components, dtype=torch.int64, device=device) + n_faces_comp_int.scatter_add_(0, component_id, torch.ones(num_faces, dtype=torch.int64, device=device)) + + thresholds = torch.maximum(torch.ones_like(n_faces_comp_int), n_faces_comp_int // 10) + should_flip_comp = inward_votes_comp > outward_votes_comp + thresholds + else: + # Vectorized 3-Axis Extreme Majority Vote (Geometrically Infallible) + face_centroids = (v0 + v1 + v2) / 3.0 + + votes_by_axis = [] + for axis in range(3): + coords = face_centroids[:, axis] + + # Double stable sort acts as a vectorized lexsort on (coords, component_id) + sort_idx = torch.argsort(coords, stable=True) + sort_idx = sort_idx[torch.argsort(component_id[sort_idx], stable=True)] + + # Find group boundaries to get the extreme outer face along this axis per component + comp_id_sorted = component_id[sort_idx] + group_ends = torch.nonzero(comp_id_sorted[1:] != comp_id_sorted[:-1], as_tuple=True)[0] + group_ends = torch.cat([group_ends, torch.tensor([len(comp_id_sorted) - 1], device=device)]) + + extreme_face_indices = sort_idx[group_ends] + extreme_normals = face_normals[extreme_face_indices] + + # Normal's component along the respective axis should be positive + votes_by_axis.append(extreme_normals[:, axis] > 0) + + stacked_votes = torch.stack(votes_by_axis, dim=0) + should_flip_comp = stacked_votes.sum(dim=0) < 2 # False if at least 2 axes agree outward + + should_flip_face = should_flip_comp[component_id] + if should_flip_face.any(): + corrected[should_flip_face] = corrected[should_flip_face][:, [0, 2, 1]] + + return corrected + + +def unweld_and_offset_mesh(vertices, faces, colors=None, z_offset=1e-4): + is_batched = vertices.ndim == 3 + device = vertices.device + + if is_batched: + B = vertices.shape[0] + F = faces.shape[1] + + # 1. Advanced index broadcast to pull all faces in parallel without any Python loops + batch_idx = torch.arange(B, device=device).view(-1, 1, 1) + v_faces = vertices[batch_idx, faces] # shape (B, F, 3, 3) + + v0, v1, v2 = v_faces[:, :, 0], v_faces[:, :, 1], v_faces[:, :, 2] + + # 2. Compute face normals + fn = torch.cross(v1 - v0, v2 - v0, dim=-1) + fn = fn / (torch.norm(fn, dim=-1, keepdim=True) + 1e-8) + + # 3. Translate directly along the face normals in parallel + offset_verts = v_faces + fn.unsqueeze(2) * z_offset + out_v = offset_verts.reshape(B, -1, 3) + + # 4. Generate identical faces for all batches using constant expansion (O(1)) + f_single = torch.arange(F * 3, device=device).reshape(-1, 3) + out_f = f_single.unsqueeze(0).expand(B, -1, -1) + + if colors is not None: + c_faces = colors[batch_idx, faces] + out_c = c_faces.reshape(B, -1, colors.shape[-1]) + return out_v, out_f, out_c + return out_v, out_f + + # --- Unbatched (Single Mesh) --- + v_faces = vertices[faces] # shape (F, 3, 3) + v0, v1, v2 = v_faces[:, 0], v_faces[:, 1], v_faces[:, 2] + + # Compute face normals + fn = torch.cross(v1 - v0, v2 - v0, dim=-1) + fn = fn / (torch.norm(fn, dim=-1, keepdim=True) + 1e-8) + + # Offset each face's private vertices along its face normal + offset_verts = v_faces + fn.unsqueeze(1) * z_offset + offset_verts = offset_verts.reshape(-1, 3) + + # Generate sequential face indices for the unwelded vertices + f_unwelded = torch.arange(faces.shape[0] * 3, device=vertices.device).reshape(-1, 3) + + if colors is not None: + c_faces = colors[faces] + c_unwelded = c_faces.reshape(-1, colors.shape[-1]) + return offset_verts, f_unwelded, c_unwelded + + return offset_verts, f_unwelded, None + class DecimateMesh(IO.ComfyNode): @classmethod def define_schema(cls): @@ -801,8 +1085,23 @@ class DecimateMesh(IO.ComfyNode): def execute(cls, mesh, target_face_count): def _fn(v, f, c): if target_face_count > 0 and f.shape[0] > target_face_count: - n = compute_vertex_normals(v, f) - v, f, c, _ = simplify_fn_fast(v, f, colors=c, normals=n, target=target_face_count) + try: + v0, v1, v2 = v[f[:, 0]], v[f[:, 1]], v[f[:, 2]] + fn = torch.cross(v1 - v0, v2 - v0, dim=-1) + fn = fn / (torch.norm(fn, dim=-1, keepdim=True) + 1e-8) + + n = torch.zeros_like(v) + n.index_add_(0, f[:, 0], fn) + n.index_add_(0, f[:, 1], fn) + n.index_add_(0, f[:, 2], fn) + n = n / (torch.norm(n, dim=-1, keepdim=True) + 1e-8) + + v, f, c, _ = simplify_fn_fast(v, f, colors=c, normals=n, target=target_face_count) + f = fix_face_orientation(v, f) + v, f, c = unweld_and_offset_mesh(v, f, colors=c, z_offset=1e-4) + except Exception as e: + logging.warning("Ran into an error while QEM Simplifying, falling back to vertex clustering:\n" + str(e)) + v, f, c = simplify_fn_vertex(v, f, c, target_face_count) return v, f, c return _process_mesh_batch(mesh, _fn) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index c922d58b6..698b6b128 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -445,13 +445,31 @@ class Trellis2Conditioning(IO.ComfyNode): ] ) + @classmethod @classmethod def execute(cls, clip_vision_model, image, mask) -> IO.NodeOutput: # Normalize to batched form so per-image conditioning loop below is uniform. if image.ndim == 3: image = image.unsqueeze(0) - if mask.ndim == 2: + elif image.ndim == 4: + if image.shape[1] in [1, 3, 4] and image.shape[-1] not in [1, 3, 4]: + image = image.permute(0, 2, 3, 1) + + # normalize mask to standard [B, H, W] (handling 2D, 3D, and 4D variants) + if mask.ndim == 4: + if mask.shape[1] == 1: + mask = mask.squeeze(1) + elif mask.shape[-1] == 1: + mask = mask.squeeze(-1) + else: + mask = mask[:, :, :, 0] # take first channel as fallback + + if mask.ndim == 3: + if mask.shape[-1] == 1: + mask = mask.squeeze(-1).unsqueeze(0) + elif mask.ndim == 2: mask = mask.unsqueeze(0) + batch_size = image.shape[0] if mask.shape[0] == 1 and batch_size > 1: mask = mask.expand(batch_size, -1, -1) @@ -468,6 +486,27 @@ class Trellis2Conditioning(IO.ComfyNode): img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + # Ensure img_np is either 2D (grayscale) or 3D (RGB/RGBA) + if img_np.ndim == 3 and img_np.shape[-1] == 1: + img_np = img_np.squeeze(-1) + + mask_np = mask_np.squeeze() + + # detect inverted mask + border_pixels = np.concatenate([ + mask_np[0, :], mask_np[-1, :], mask_np[:, 0], mask_np[:, -1] + ]) + if np.mean(border_pixels) > 127: + mask_np = 255 - mask_np + + mask_np[mask_np < 35] = 0 + + border_shave = 4 + mask_np[:border_shave, :] = 0 + mask_np[-border_shave:, :] = 0 + mask_np[:, :border_shave] = 0 + mask_np[:, -border_shave:] = 0 + pil_img = Image.fromarray(img_np) pil_mask = Image.fromarray(mask_np) @@ -479,7 +518,7 @@ class Trellis2Conditioning(IO.ComfyNode): pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST) rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8) - rgba_np[:, :, :3] = np.array(pil_img) + rgba_np[:, :, :3] = np.array(pil_img.convert("RGB")) rgba_np[:, :, 3] = np.array(pil_mask) alpha = rgba_np[:, :, 3] @@ -511,12 +550,18 @@ class Trellis2Conditioning(IO.ComfyNode): alpha_float = cropped_np[:, :, 3:4] composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float) - # to match trellis2 code (quantize -> dequantize) - composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8) + # Keep the image as 4-channel RGBA to force TRELLIS to bypass its internal background remover + rgb_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8) + alpha_uint8 = (alpha_float.squeeze(-1) * 255.0).round().clip(0, 255).astype(np.uint8) - cropped_pil = Image.fromarray(composite_uint8) + rgba_composite = np.zeros((cropped_np.shape[0], cropped_np.shape[1], 4), dtype=np.uint8) + rgba_composite[:, :, :3] = rgb_uint8 + rgba_composite[:, :, 3] = alpha_uint8 - item_conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True) + cropped_pil = Image.fromarray(rgba_composite, mode="RGBA") + + # Convert to RGB to ensure the CLIP/DINO model receives a 3-channel image + item_conditioning = run_conditioning(clip_vision_model, cropped_pil.convert("RGB"), include_1024=True) cond_512_list.append(item_conditioning["cond_512"]) cond_1024_list.append(item_conditioning["cond_1024"])