# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# (synced 2026-07-03 21:20:49 +08:00; 163 lines, 5.4 KiB, Python)
"""Mesh container, edge/face adjacency, manifold cleanup."""
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Dict, List
|
|
|
|
import numpy as np
|
|
import torch
|
|
from scipy.sparse import csr_matrix
|
|
from scipy.sparse.csgraph import connected_components
|
|
from torch import Tensor
|
|
|
|
|
|
# ---- Per-face / per-chart geometry ----
|
|
|
|
def face_normals(vertices: Tensor, faces: Tensor) -> Tensor:
    """Return per-face unit normals, shape [F, 3].

    The normal direction follows the winding (v1 - v0) x (v2 - v0). The divisor
    is clamped to 1e-20, so faces whose cross product is zero (degenerate
    triangles) come back as zero vectors instead of NaN.
    """
    corners = vertices[faces]  # [F, 3, 3]: corner k of face f is corners[f, k]
    origin = corners[:, 0]
    raw = torch.linalg.cross(corners[:, 1] - origin, corners[:, 2] - origin)
    length = torch.clamp(raw.norm(dim=1, keepdim=True), min=1e-20)
    return raw / length
|
|
|
|
|
|
def face_areas(vertices: Tensor, faces: Tensor) -> Tensor:
    """Return the area of every triangle, shape [F].

    The area is half the magnitude of (v1 - v0) x (v2 - v0).
    """
    corners = vertices[faces]
    edge_u = corners[:, 1] - corners[:, 0]
    edge_v = corners[:, 2] - corners[:, 0]
    doubled_area = torch.linalg.cross(edge_u, edge_v).norm(dim=1)
    return doubled_area * 0.5
|
|
|
|
|
|
def face_centroids(vertices: Tensor, faces: Tensor) -> Tensor:
    """Return the centroid (mean of the three corners) of every triangle, shape [F, 3]."""
    corner_positions = vertices[faces]  # [F, 3, 3]
    return torch.mean(corner_positions, dim=1)
|
|
|
|
|
|
def face_edge_lengths(vertices: Tensor, faces: Tensor) -> Tensor:
    """Return per-face edge lengths as float32, shape [F, 3].

    Column e is the length of the edge from corner e to corner (e + 1) % 3,
    i.e. |v[faces[:, e]] - v[faces[:, (e + 1) % 3]]|.
    """
    next_corner = faces[:, [1, 2, 0]]  # corner (e + 1) % 3 for each column e
    deltas = vertices[next_corner] - vertices[faces]
    return deltas.norm(dim=-1).float()
|
|
|
|
|
|
def chart_3d_areas(face_area: Tensor, face_chart: Tensor, n_charts: int) -> Tensor:
    """Sum face areas per chart.

    Args:
        face_area: [F] per-face areas.
        face_chart: [F] chart id of each face, expected in [0, n_charts).
            Any integer dtype is accepted.
        n_charts: number of charts (length of the result).

    Returns:
        [n_charts] tensor with face_area's dtype and device; charts without
        faces get 0.
    """
    out = torch.zeros(n_charts, dtype=face_area.dtype, device=face_area.device)
    # scatter_add_ requires an int64 index on the output's device. Casting here
    # lets int32 label tensors work; it is a no-op for int64 input already there.
    index = face_chart.to(device=face_area.device, dtype=torch.long)
    out.scatter_add_(0, index, face_area)
    return out
|
|
|
|
|
|
@dataclass
class MeshData:
    """Triangle mesh bundled with face adjacency, per-face geometry and components.

    face_face[f, i] is the face sharing edge (faces[f, i], faces[f, (i + 1) % 3]),
    or -1 when that edge has no paired neighbour (treated as boundary).
    Field order is also the positional constructor order.
    """

    vertices: Tensor  # [V, 3] float
    faces: Tensor  # [F, 3] long, vertex indices per triangle
    face_face: Tensor  # [F, 3] long, neighbor face id across edge i, or -1
    face_normal: Tensor  # [F, 3] float, per-face normal
    face_area: Tensor  # [F] float, per-face area
    face_centroid: Tensor  # [F, 3] float, per-face centroid
    component: Tensor  # [F] long, connected-component id
    n_components: int  # number of distinct ids in `component`
|
|
|
|
|
|
def build_mesh(vertices: Tensor, faces: Tensor) -> MeshData:
    """Build face adjacency, connected components and per-face geometry.

    Two faces become neighbours across an undirected edge only when exactly two
    face-edge slots in the whole mesh use that edge. Edges used once (open
    boundary) or more than twice (non-manifold) get no neighbour (-1) and act
    as boundary. A face with a repeated vertex index, e.g. (a, a, b), uses edge
    (a, b) twice itself. Such a same-face pair is not linked, so a face is never
    reported as its own neighbour; those edges also stay -1.

    Args:
        vertices: [V, 3] vertex positions; cast to float32 if needed.
        faces: [F, 3] vertex indices; cast to int64 if needed.

    Returns:
        MeshData where face_face[f, i] is the face across edge
        (faces[f, i], faces[f, (i + 1) % 3]) or -1. component[f] is the
        connected-component id of face f, with connectivity defined only
        through face_face links.
    """
    if vertices.dtype != torch.float32:
        vertices = vertices.to(torch.float32)
    if faces.dtype != torch.long:
        faces = faces.to(torch.long)

    device = faces.device
    V = vertices.shape[0]
    F = faces.shape[0]

    # One slot per directed face-edge, flat layout p = f * 3 + i, holding the
    # edge faces[f, i] -> faces[f, (i + 1) % 3].
    a = faces.flatten()
    b = faces.roll(shifts=-1, dims=1).flatten()
    # Orientation-independent key packing (min, max) into one integer; distinct
    # edges get distinct keys as long as indices lie in [0, V).
    lo = torch.minimum(a, b)
    hi = torch.maximum(a, b)
    edge_key = lo * (V + 1) + hi

    # Only edges used by exactly two slots are paired; others get no neighbor.
    _, inverse, counts = torch.unique(edge_key, return_inverse=True, return_counts=True)
    manifold_mask = counts[inverse] == 2

    # After a stable sort by key, the two slots of each manifold edge are
    # adjacent. Keeping only manifold slots and splitting even/odd pairs them.
    sort_idx = torch.argsort(edge_key, stable=True)
    pair_positions = sort_idx[manifold_mask[sort_idx]]
    pair_a = pair_positions[0::2]
    pair_b = pair_positions[1::2]

    face_id_flat = torch.arange(F, device=device).repeat_interleave(3)

    # A degenerate face (repeated vertex) can supply both slots of a "manifold"
    # edge; linking it would make the face its own neighbour, so drop that pair.
    distinct = face_id_flat[pair_a] != face_id_flat[pair_b]
    pair_a = pair_a[distinct]
    pair_b = pair_b[distinct]

    face_face_flat = torch.full((3 * F,), -1, dtype=torch.long, device=device)
    face_face_flat[pair_a] = face_id_flat[pair_b]
    face_face_flat[pair_b] = face_id_flat[pair_a]
    face_face = face_face_flat.view(F, 3)

    # Connected components of the graph whose edges are the face_face links
    # (SciPy works on host arrays, hence the CPU round-trip).
    face_face_np = face_face.cpu().numpy()
    rows_mask = face_face_np >= 0
    if rows_mask.any():
        rows = np.broadcast_to(np.arange(F)[:, None], (F, 3))[rows_mask]
        cols = face_face_np[rows_mask]
        adj = csr_matrix(
            (np.ones(rows.size, dtype=np.int8), (rows, cols)),
            shape=(F, F),
        )
    else:
        adj = csr_matrix((F, F), dtype=np.int8)
    n_components, labels = connected_components(adj, directed=False)

    face_normal = face_normals(vertices, faces)
    face_area = face_areas(vertices, faces)
    face_centroid = face_centroids(vertices, faces)

    return MeshData(
        vertices=vertices,
        faces=faces,
        face_face=face_face,
        face_normal=face_normal,
        face_area=face_area,
        face_centroid=face_centroid,
        component=torch.from_numpy(labels.astype(np.int64)).to(device),
        n_components=int(n_components),
    )
|
|
|
|
|
|
def chart_boundary_loops(
    faces_subset: Tensor, face_face_subset: Tensor
) -> List[List[int]]:
    """Return ordered boundary vertex loops for a chart submesh.

    A boundary half-edge is faces_subset[f, i] -> faces_subset[f, (i + 1) % 3]
    wherever face_face_subset[f, i] == -1. Loops are traced along these
    half-edges, starting from vertices in discovery order (row-major over
    faces_subset).

    Every half-edge is kept, even when a vertex has several outgoing boundary
    edges (a pinch vertex where the boundary touches itself). A walk is split
    whenever it revisits a vertex, so each returned loop is a simple cycle; a
    pinch vertex may appear in more than one loop. A walk that reaches a vertex
    with no unused outgoing boundary edge is returned as an open chain.
    Loops/chains with fewer than 3 vertices are dropped.

    Args:
        faces_subset: [F, 3] vertex indices of the chart's faces.
        face_face_subset: [F, 3] neighbour ids; -1 marks a boundary edge.

    Returns:
        List of vertex-id lists. A closed loop does not repeat its first vertex.
    """
    faces_np = faces_subset.cpu().numpy()
    ff = face_face_subset.cpu().numpy()

    # All outgoing boundary half-edges per vertex, in discovery order. A list
    # rather than a single successor, so nothing is overwritten at pinch vertices.
    out_edges: Dict[int, List[int]] = {}
    for f, i in np.argwhere(ff == -1):
        a = int(faces_np[f, i])
        b = int(faces_np[f, (i + 1) % 3])
        out_edges.setdefault(a, []).append(b)

    loops: List[List[int]] = []
    for start in list(out_edges):
        # A start vertex may own several boundary edges (pinch); trace each.
        while out_edges[start]:
            path = [start]
            pos = {start: 0}  # vertex -> index in path, to detect revisits
            while True:
                succ = out_edges.get(path[-1])
                if not succ:
                    # Dead end: keep the open chain if it is long enough.
                    if len(path) >= 3:
                        loops.append(path)
                    break
                nxt = succ.pop(0)  # consume the half-edge so it is walked once
                k = pos.get(nxt)
                if k is None:
                    pos[nxt] = len(path)
                    path.append(nxt)
                    continue
                # Revisiting path[k] closes the simple cycle path[k:].
                cycle = path[k:]
                if len(cycle) >= 3:
                    loops.append(cycle)
                if k == 0:
                    break
                # Back up to the pinch vertex and continue from its next edge.
                for v in path[k + 1:]:
                    del pos[v]
                del path[k + 1:]
    return loops
|