ComfyUI/comfy/ldm/trellis2/cumesh.py
Yousef R. Gamaleldin 6ea2e5b288 init
2026-01-30 23:34:48 +02:00

150 lines
4.8 KiB
Python

# Intended to hold PyTorch implementations of operations that otherwise require custom CUDA code.
import torch
from typing import Dict
class TorchHashMap:
    """Integer -> integer lookup table built from a sorted key tensor.

    Keys are sorted once at construction; lookups binary-search them with
    ``torch.searchsorted``, so all work stays on the keys' device.
    """

    def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
        """
        Args:
            keys: integer keys, expected to be 1-D (cast to int64).
            values: values aligned element-wise with ``keys`` (cast to int64).
            default_value: value returned for query keys that are not present.
        """
        device = keys.device
        # int64 keys so searchsorted compares like with like.
        # stable=True makes duplicate keys deterministic: searchsorted returns
        # the leftmost match, i.e. the duplicate that came first in ``keys``.
        self.sorted_keys, order = torch.sort(keys.long(), stable=True)
        self.sorted_vals = values.long()[order]
        self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
        # Plain Python int used as the fill value, which avoids a device sync.
        self._default_fill = int(default_value)
        self._n = self.sorted_keys.numel()

    def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
        """Return the value stored for each key in ``flat_keys``.

        Keys that are not in the map yield ``default_value``. The result is
        int64 and has the same shape as ``flat_keys``. ``flat_keys`` must be
        on the same device as the map's keys.
        """
        flat = flat_keys.long()
        if self._n == 0:
            # Nothing to search; every key is a miss.
            return torch.full(flat.shape, self._default_fill, dtype=torch.long, device=flat.device)
        idx = torch.searchsorted(self.sorted_keys, flat)
        # searchsorted returns _n for keys larger than every stored key. Clamp
        # so the gathers below stay in bounds. The equality test still rejects
        # those keys, because sorted_keys[_n - 1] < key.
        idx.clamp_(max=self._n - 1)
        found = self.sorted_keys[idx] == flat
        return self.sorted_vals[idx].masked_fill(~found, self._default_fill)
class Voxel:
    """Sparse voxel set: grid coordinates plus per-voxel attribute columns.

    ``layout`` maps an attribute name to the column index/slice of ``attrs``
    that holds it (see ``split_attrs``).
    """

    def __init__(
        self,
        origin: list,
        voxel_size: float,
        coords: torch.Tensor = None,
        attrs: torch.Tensor = None,
        layout: Dict = None,
        device: torch.device = 'cuda'
    ):
        self.origin = torch.tensor(origin, dtype=torch.float32, device=device)
        self.voxel_size = voxel_size
        self.coords = coords
        self.attrs = attrs
        # Fresh dict per instance. A literal `{}` default would be one dict
        # shared (and mutable) across every Voxel built without a layout.
        self.layout = {} if layout is None else layout
        self.device = device

    @property
    def position(self):
        """Centres of the voxels in the origin's coordinate frame.

        Each coordinate is offset by half a cell (+0.5), scaled by
        ``voxel_size``, and shifted by ``origin``. Assumes ``coords`` is
        (N, D) with D == len(origin); TODO confirm.
        """
        return (self.coords + 0.5) * self.voxel_size + self.origin[None, :]

    def split_attrs(self):
        """Return ``{name: attrs[:, cols]}`` for every entry of ``layout``."""
        return {name: self.attrs[:, cols] for name, cols in self.layout.items()}
class Mesh:
    """Triangle mesh held as PyTorch tensors.

    ``__init__`` stores vertices as float32 and faces as int32. Hole filling
    and simplification are delegated to the optional ``cumesh`` CUDA
    extension. It is imported lazily, so the rest of the class works without it.
    """

    def __init__(self, vertices, faces, vertex_attrs=None):
        self.vertices = vertices.float()
        self.faces = faces.int()
        # NOTE(review): fill_holes()/simplify() replace vertices/faces but never
        # touch vertex_attrs, so the two may fall out of sync after either call.
        self.vertex_attrs = vertex_attrs

    @property
    def device(self):
        """Device of the vertex tensor, treated as the mesh's device."""
        return self.vertices.device

    def to(self, device, non_blocking=False):
        """Return a new Mesh whose tensors have been moved to ``device``.

        Tensors already on ``device`` are shared rather than copied (standard
        ``torch.Tensor.to`` behaviour).
        """
        return Mesh(
            self.vertices.to(device, non_blocking=non_blocking),
            self.faces.to(device, non_blocking=non_blocking),
            self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None,
        )

    def cuda(self, non_blocking=False):
        return self.to('cuda', non_blocking=non_blocking)

    def cpu(self, non_blocking=False):
        return self.to('cpu', non_blocking=non_blocking)

    # TODO: could be exposed as an option
    def fill_holes(self, max_hole_perimeter=3e-2):
        """Fill holes in place using ``cumesh`` (requires CUDA).

        The geometry is processed on the GPU, and the result is moved back to
        the mesh's current device with the float32/int32 dtypes restored.
        Returns early, leaving the mesh untouched, when cumesh reports no
        boundaries or no boundary loops.

        Args:
            max_hole_perimeter: forwarded to ``CuMesh.fill_holes``. Judging by
                the name, holes with a longer perimeter are left open; units
                presumably match the vertex coordinates (TODO confirm).
        """
        import cumesh
        # Capture the target device before self.vertices is reassigned.
        device = self.device
        # Kept in locals so the CUDA copies stay alive for the whole call, in
        # case CuMesh does not hold its own reference (unverified).
        vertices = self.vertices.cuda()
        faces = self.faces.cuda()
        mesh = cumesh.CuMesh()
        mesh.init(vertices, faces)
        # Boundary analysis. Bail out before the adjacency passes if the mesh
        # has no open boundary at all.
        mesh.get_edges()
        mesh.get_boundary_info()
        if mesh.num_boundaries == 0:
            return
        mesh.get_vertex_edge_adjacency()
        mesh.get_vertex_boundary_adjacency()
        mesh.get_manifold_boundary_adjacency()
        mesh.read_manifold_boundary_adjacency()
        mesh.get_boundary_connected_components()
        mesh.get_boundary_loops()
        if mesh.num_boundary_loops == 0:
            return
        mesh.fill_holes(max_hole_perimeter=max_hole_perimeter)
        new_vertices, new_faces = mesh.read()
        self.vertices = new_vertices.to(device).float()
        self.faces = new_faces.to(device).int()

    # TODO: could be exposed as an option
    def simplify(self, target=1000000, verbose: bool = False, options: dict = None):
        """Simplify the mesh in place using ``cumesh`` (requires CUDA).

        Args:
            target: forwarded to ``CuMesh.simplify``; presumably the target
                face count (TODO confirm).
            verbose: forwarded to ``CuMesh.simplify``.
            options: extra options forwarded to ``CuMesh.simplify``. ``None``
                means an empty dict; a fresh one is created per call so nothing
                leaks between calls.
        """
        import cumesh
        if options is None:
            options = {}
        # Capture the target device before self.vertices is reassigned.
        device = self.device
        # Kept in locals so the CUDA copies outlive the cumesh calls.
        vertices = self.vertices.cuda()
        faces = self.faces.cuda()
        mesh = cumesh.CuMesh()
        mesh.init(vertices, faces)
        mesh.simplify(target, verbose=verbose, options=options)
        new_vertices, new_faces = mesh.read()
        self.vertices = new_vertices.to(device).float()
        self.faces = new_faces.to(device).int()
class MeshWithVoxel(Mesh, Voxel):
    """Mesh bundled with the sparse voxel data it is associated with.

    Carries ``Mesh`` geometry plus the ``Voxel`` fields (origin, voxel_size,
    coords, attrs, layout) and the grid's ``voxel_shape``. The origin tensor
    is created on the same device as the vertices.
    """

    def __init__(self,
                 vertices: torch.Tensor,
                 faces: torch.Tensor,
                 origin: list,
                 voxel_size: float,
                 coords: torch.Tensor,
                 attrs: torch.Tensor,
                 voxel_shape: torch.Size,
                 layout: Dict = None,
                 ):
        # Mesh.__init__ applies the float32/int32 casts and sets
        # vertex_attrs=None, so the attribute exists just as on a plain Mesh.
        # Voxel.__init__ is deliberately not called: it assigns `self.device`,
        # which is a read-only property inherited from Mesh.
        Mesh.__init__(self, vertices, faces)
        self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device)
        self.voxel_size = voxel_size
        self.coords = coords
        self.attrs = attrs
        self.voxel_shape = voxel_shape
        # Fresh dict per instance instead of a shared mutable `{}` default.
        self.layout = {} if layout is None else layout

    def to(self, device, non_blocking=False):
        """Return a new MeshWithVoxel with its tensors moved to ``device``.

        The origin is rebuilt from a Python list so it lands on the same
        device as the moved vertices. ``layout`` is shared, not copied.
        """
        return MeshWithVoxel(
            self.vertices.to(device, non_blocking=non_blocking),
            self.faces.to(device, non_blocking=non_blocking),
            self.origin.tolist(),
            self.voxel_size,
            self.coords.to(device, non_blocking=non_blocking),
            self.attrs.to(device, non_blocking=non_blocking),
            self.voxel_shape,
            self.layout,
        )