ComfyUI/comfy_extras/mesh3d/uv_unwrap/mesh.py
2026-07-01 21:39:19 +03:00

163 lines
5.4 KiB
Python

"""Mesh container, edge/face adjacency, manifold cleanup."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List
import numpy as np
import torch
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from torch import Tensor
# ---- Per-face / per-chart geometry ----
def face_normals(vertices: Tensor, faces: Tensor) -> Tensor:
    """Return one unit normal per triangle, shape [F, 3].

    The normal follows the winding (v1 - v0) x (v2 - v0). For a degenerate
    triangle the cross product is zero, and the length is clamped to a tiny
    positive value instead of dividing by zero, so the result is a zero vector.
    """
    corners = vertices[faces]  # [F, 3, 3]: the three corner positions per face
    origin = corners[:, 0]
    raw = torch.linalg.cross(corners[:, 1] - origin, corners[:, 2] - origin)
    length = raw.norm(dim=1, keepdim=True)
    return raw / torch.clamp(length, min=1e-20)
def face_areas(vertices: Tensor, faces: Tensor) -> Tensor:
    """Return the area of every triangle, shape [F].

    Computed as half the magnitude of (v1 - v0) x (v2 - v0).
    """
    tri = vertices[faces]  # [F, 3, 3]
    base = tri[:, 0]
    doubled = torch.linalg.cross(tri[:, 1] - base, tri[:, 2] - base).norm(dim=1)
    return 0.5 * doubled
def face_centroids(vertices: Tensor, faces: Tensor) -> Tensor:
    """Return the centroid (mean of the three corners) of each triangle, [F, 3]."""
    corners = vertices[faces]  # [F, 3, 3]
    return torch.mean(corners, dim=1)
def face_edge_lengths(vertices: Tensor, faces: Tensor) -> Tensor:
    """Return per-face edge lengths as float32, shape [F, 3].

    Column e holds |v[faces[:, e]] - v[faces[:, (e + 1) % 3]]|, i.e. the edge
    leaving corner e in winding order.
    """
    successor = faces[:, [1, 2, 0]]  # corner (e + 1) % 3 for each column e
    delta = vertices[successor] - vertices[faces]
    return delta.norm(dim=-1).float()
def chart_3d_areas(face_area: Tensor, face_chart: Tensor, n_charts: int) -> Tensor:
    """Sum face areas per chart, returning a tensor of length ``n_charts``.

    ``face_chart[f]`` is the chart id of face f. Charts with no faces get 0.
    The result has the same dtype and device as ``face_area``.
    """
    totals = face_area.new_zeros(n_charts)
    # scatter_add_ accumulates in place and returns the same tensor.
    return totals.scatter_add_(0, face_chart, face_area)
@dataclass
class MeshData:
    """Triangle mesh bundled with per-face attributes and face adjacency.

    ``face_face[f, i]`` identifies the face across edge
    ``(faces[f, i], faces[f, (i + 1) % 3])``, with ``-1`` meaning the edge has
    no neighbour and acts as a boundary.
    """

    # Geometry
    vertices: Tensor  # [V, 3] float positions
    faces: Tensor  # [F, 3] long vertex indices
    # Adjacency
    face_face: Tensor  # [F, 3] long; neighbouring face id per edge, or -1
    # Per-face attributes
    face_normal: Tensor  # [F, 3] float
    face_area: Tensor  # [F] float
    face_centroid: Tensor  # [F, 3] float
    # Connectivity
    component: Tensor  # [F] long; connected-component label of each face
    n_components: int  # number of connected components
def build_mesh(vertices: Tensor, faces: Tensor) -> MeshData:
    """Build a MeshData with face adjacency, per-face geometry and components.

    ``face_face[f, i]`` is the face sharing edge
    ``(faces[f, i], faces[f, (i + 1) % 3])``. An edge is paired only when
    exactly two face-edges use it. Edges used once (open boundary) or more
    than twice (non-manifold) get -1 on every incident face and act as
    boundary.

    Faces with a repeated vertex index (e.g. ``(a, a, b)``) are degenerate and
    are left without neighbours (all -1). Otherwise such a face lists the same
    undirected edge twice: it pairs with itself (``face_face[f, i] == f``) and
    raises the incidence count of a neighbouring edge. Two genuine neighbours
    across that edge would then be treated as non-manifold and lose their
    adjacency.

    Args:
        vertices: [V, 3] positions; cast to float32 if needed.
        faces: [F, 3] vertex indices; cast to int64 if needed.

    Returns:
        MeshData. Components are computed from ``face_face`` only, so faces
        that touch at a single vertex end up in different components.
    """
    if vertices.dtype != torch.float32:
        vertices = vertices.to(torch.float32)
    if faces.dtype != torch.long:
        faces = faces.to(torch.long)
    device = faces.device
    V = vertices.shape[0]
    F = faces.shape[0]

    # One entry per directed face-edge; flat layout p = f * 3 + i.
    a = faces.flatten()
    b = faces.roll(shifts=-1, dims=1).flatten()
    lo = torch.minimum(a, b)
    hi = torch.maximum(a, b)
    # Undirected edge id; since hi < V + 1 the mapping (lo, hi) -> key is injective.
    edge_key = lo * (V + 1) + hi

    # Give face-edges of degenerate faces unique negative keys (valid keys are
    # >= 0). Each such key then has count 1: it never pairs and never inflates
    # another edge's count.
    degenerate = (
        (faces[:, 0] == faces[:, 1])
        | (faces[:, 1] == faces[:, 2])
        | (faces[:, 2] == faces[:, 0])
    )
    flat_pos = torch.arange(3 * F, device=device)
    edge_key = torch.where(degenerate.repeat_interleave(3), -1 - flat_pos, edge_key)

    # Pair manifold (count == 2) face-edges; all others get no neighbour.
    _, inverse, counts = torch.unique(edge_key, return_inverse=True, return_counts=True)
    edge_count = counts[inverse]
    manifold_mask = edge_count == 2
    # After a stable sort by key, the two face-edges of each manifold edge are
    # adjacent, so even/odd positions form the pairs.
    sort_idx = torch.argsort(edge_key, stable=True)
    sorted_manifold = manifold_mask[sort_idx]
    pair_positions = sort_idx[sorted_manifold]
    pair_a = pair_positions[0::2]
    pair_b = pair_positions[1::2]
    face_id_flat = torch.arange(F, device=device).repeat_interleave(3)
    face_face_flat = torch.full((3 * F,), -1, dtype=torch.long, device=device)
    face_face_flat[pair_a] = face_id_flat[pair_b]
    face_face_flat[pair_b] = face_id_flat[pair_a]
    face_face = face_face_flat.view(F, 3)

    # Connected components over the face adjacency graph (computed with SciPy on CPU).
    face_face_np = face_face.cpu().numpy()
    rows_mask = face_face_np >= 0
    if rows_mask.any():
        rows = np.broadcast_to(np.arange(F)[:, None], (F, 3))[rows_mask]
        cols = face_face_np[rows_mask]
        adj = csr_matrix(
            (np.ones(rows.size, dtype=np.int8), (rows, cols)),
            shape=(F, F),
        )
    else:
        adj = csr_matrix((F, F), dtype=np.int8)
    n_components, labels = connected_components(adj, directed=False)

    face_normal = face_normals(vertices, faces)
    face_area = face_areas(vertices, faces)
    face_centroid = face_centroids(vertices, faces)
    return MeshData(
        vertices=vertices,
        faces=faces,
        face_face=face_face,
        face_normal=face_normal,
        face_area=face_area,
        face_centroid=face_centroid,
        component=torch.from_numpy(labels.astype(np.int64)).to(device),
        n_components=int(n_components),
    )
def chart_boundary_loops(
    faces_subset: Tensor, face_face_subset: Tensor
) -> List[List[int]]:
    """Return ordered boundary vertex loops for a chart submesh.

    The face-edge ``(faces_subset[f, i], faces_subset[f, (i + 1) % 3])`` is a
    boundary half-edge when ``face_face_subset[f, i] == -1``. Boundary
    half-edges are chained head-to-tail, following the face winding.

    Each boundary half-edge is consumed exactly once. A "pinch" vertex has
    several outgoing boundary half-edges because the boundary touches itself
    there; it keeps all of them rather than only the last one seen. Such a
    vertex may therefore appear in several loops, or twice in one loop.

    Args:
        faces_subset: [F, 3] vertex indices of the chart's faces.
        face_face_subset: [F, 3] neighbour ids, -1 marking boundary edges.

    Returns:
        Loops in discovery order. Each loop lists vertices in walk order
        without repeating the closing vertex. A walk stops early if it
        reaches a vertex with no unused outgoing boundary half-edge (an open
        chain); the partial chain is kept. Walks with fewer than 3 vertices
        are discarded.
    """
    faces_np = faces_subset.cpu().numpy()
    ff = face_face_subset.cpu().numpy()

    # vertex -> successors along boundary half-edges, in (face, corner) order.
    # np.nonzero yields indices in row-major order, i.e. f-major then i.
    successors: Dict[int, List[int]] = {}
    for f, i in zip(*np.nonzero(ff == -1)):
        a = int(faces_np[f, i])
        b = int(faces_np[f, (i + 1) % 3])
        successors.setdefault(a, []).append(b)

    # One iterator per vertex: calling next() consumes a half-edge.
    pending = {v: iter(nxt) for v, nxt in successors.items()}

    loops: List[List[int]] = []
    for start, outgoing in pending.items():
        # Keep starting walks here while `start` still has unused outgoing
        # half-edges. A walk stops on returning to `start`, so it never
        # consumes from `outgoing` itself.
        for first in outgoing:
            loop = [start]
            cur = first
            while cur is not None and cur != start:
                loop.append(cur)
                succ = pending.get(cur)
                cur = None if succ is None else next(succ, None)
            if len(loop) >= 3:
                loops.append(loop)
    return loops