remove triton version

This commit is contained in:
Yousef Rafat 2026-05-14 15:07:50 +03:00
parent e3a25f1b11
commit d2b97b510c

View File

@ -4,9 +4,7 @@ from comfy.ldm.trellis2.vae import SparseTensor
import comfy.model_management import comfy.model_management
from PIL import Image from PIL import Image
import numpy as np import numpy as np
import triton.language as tl
import logging import logging
import triton
import torch import torch
import scipy import scipy
import copy import copy
@ -761,227 +759,6 @@ class EmptyTrellis2LatentStructure(IO.ComfyNode):
} }
return IO.NodeOutput(output) return IO.NodeOutput(output)
@triton.jit
def qem_edge_errors_kernel(
verts_ptr, Q_ptr, edges_ptr, optimal_ptr, error_ptr, wander_ptr,
n_edges, stabilizer, max_edge_length_sq, mesh_scale_sq,
BLOCK_SIZE: tl.constexpr
):
pid = tl.program_id(0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offs < n_edges
va = tl.load(edges_ptr + offs * 2, mask=mask, other=0).to(tl.int64)
vb = tl.load(edges_ptr + offs * 2 + 1, mask=mask, other=0).to(tl.int64)
vax = tl.load(verts_ptr + va * 3 + 0, mask=mask, other=0.0)
vay = tl.load(verts_ptr + va * 3 + 1, mask=mask, other=0.0)
vaz = tl.load(verts_ptr + va * 3 + 2, mask=mask, other=0.0)
vbx = tl.load(verts_ptr + vb * 3 + 0, mask=mask, other=0.0)
vby = tl.load(verts_ptr + vb * 3 + 1, mask=mask, other=0.0)
vbz = tl.load(verts_ptr + vb * 3 + 2, mask=mask, other=0.0)
ex = vbx - vax
ey = vby - vay
ez = vbz - vaz
el_sq = ex * ex + ey * ey + ez * ez
el = tl.sqrt(el_sq)
Qa_base = Q_ptr + va * 16
Qb_base = Q_ptr + vb * 16
qe0 = tl.load(Qa_base + 0, mask=mask, other=0.0) + tl.load(Qb_base + 0, mask=mask, other=0.0)
qe1 = tl.load(Qa_base + 1, mask=mask, other=0.0) + tl.load(Qb_base + 1, mask=mask, other=0.0)
qe2 = tl.load(Qa_base + 2, mask=mask, other=0.0) + tl.load(Qb_base + 2, mask=mask, other=0.0)
qe3 = tl.load(Qa_base + 3, mask=mask, other=0.0) + tl.load(Qb_base + 3, mask=mask, other=0.0)
qe4 = tl.load(Qa_base + 4, mask=mask, other=0.0) + tl.load(Qb_base + 4, mask=mask, other=0.0)
qe5 = tl.load(Qa_base + 5, mask=mask, other=0.0) + tl.load(Qb_base + 5, mask=mask, other=0.0)
qe6 = tl.load(Qa_base + 6, mask=mask, other=0.0) + tl.load(Qb_base + 6, mask=mask, other=0.0)
qe7 = tl.load(Qa_base + 7, mask=mask, other=0.0) + tl.load(Qb_base + 7, mask=mask, other=0.0)
qe8 = tl.load(Qa_base + 8, mask=mask, other=0.0) + tl.load(Qb_base + 8, mask=mask, other=0.0)
qe9 = tl.load(Qa_base + 9, mask=mask, other=0.0) + tl.load(Qb_base + 9, mask=mask, other=0.0)
qe10 = tl.load(Qa_base + 10, mask=mask, other=0.0) + tl.load(Qb_base + 10, mask=mask, other=0.0)
qe11 = tl.load(Qa_base + 11, mask=mask, other=0.0) + tl.load(Qb_base + 11, mask=mask, other=0.0)
qe12 = tl.load(Qa_base + 12, mask=mask, other=0.0) + tl.load(Qb_base + 12, mask=mask, other=0.0)
qe13 = tl.load(Qa_base + 13, mask=mask, other=0.0) + tl.load(Qb_base + 13, mask=mask, other=0.0)
qe14 = tl.load(Qa_base + 14, mask=mask, other=0.0) + tl.load(Qb_base + 14, mask=mask, other=0.0)
qe15 = tl.load(Qa_base + 15, mask=mask, other=0.0) + tl.load(Qb_base + 15, mask=mask, other=0.0)
a11 = qe0 + stabilizer
a12 = qe1
a13 = qe2
a21 = qe4
a22 = qe5 + stabilizer
a23 = qe6
a31 = qe8
a32 = qe9
a33 = qe10 + stabilizer
b1 = -qe3
b2 = -qe7
b3 = -qe11
det = (a11 * (a22 * a33 - a23 * a32)
- a12 * (a21 * a33 - a23 * a31)
+ a13 * (a21 * a32 - a22 * a31))
det_good = tl.abs(det) > 1e-12
det_x = (b1 * (a22 * a33 - a23 * a32)
- a12 * (b2 * a33 - a23 * b3)
+ a13 * (b2 * a32 - a22 * b3))
det_y = (a11 * (b2 * a33 - a23 * b3)
- b1 * (a21 * a33 - a23 * a31)
+ a13 * (a21 * b3 - b2 * a31))
det_z = (a11 * (a22 * b3 - b2 * a32)
- a12 * (a21 * b3 - b2 * a31)
+ b1 * (a21 * a32 - a22 * a31))
ox = tl.where(det_good, det_x / det, (vax + vbx) * 0.5)
oy = tl.where(det_good, det_y / det, (vay + vby) * 0.5)
oz = tl.where(det_good, det_z / det, (vaz + vbz) * 0.5)
dist_a_sq = (ox - vax) * (ox - vax) + (oy - vay) * (oy - vay) + (oz - vaz) * (oz - vaz)
dist_b_sq = (ox - vbx) * (ox - vbx) + (oy - vby) * (oy - vby) + (oz - vbz) * (oz - vbz)
wander_thresh = 16.0 * el_sq
wander_bad = (dist_a_sq > wander_thresh) | (dist_b_sq > wander_thresh)
ox = tl.where(wander_bad & (el > 0.0), (vax + vbx) * 0.5, ox)
oy = tl.where(wander_bad & (el > 0.0), (vay + vby) * 0.5, oy)
oz = tl.where(wander_bad & (el > 0.0), (vaz + vbz) * 0.5, oz)
v4_0 = ox
v4_1 = oy
v4_2 = oz
v4_3 = 1.0
qv0 = qe0 * v4_0 + qe1 * v4_1 + qe2 * v4_2 + qe3 * v4_3
qv1 = qe4 * v4_0 + qe5 * v4_1 + qe6 * v4_2 + qe7 * v4_3
qv2 = qe8 * v4_0 + qe9 * v4_1 + qe10 * v4_2 + qe11 * v4_3
qv3 = qe12 * v4_0 + qe13 * v4_1 + qe14 * v4_2 + qe15 * v4_3
err = tl.abs(v4_0 * qv0 + v4_1 * qv1 + v4_2 * qv2 + v4_3 * qv3)
tl.store(optimal_ptr + offs * 3 + 0, ox, mask=mask)
tl.store(optimal_ptr + offs * 3 + 1, oy, mask=mask)
tl.store(optimal_ptr + offs * 3 + 2, oz, mask=mask)
tl.store(error_ptr + offs, err, mask=mask)
tl.store(wander_ptr + offs, wander_bad.to(tl.int32), mask=mask)
@triton.jit
def validate_faces_kernel(
verts_ptr, faces_ptr, va_ptr, vb_ptr, opt_ptr, pair_edge_ptr, pair_face_ptr,
n_pairs, area_thresh, keep_mask_ptr,
BLOCK_SIZE: tl.constexpr
):
pid = tl.program_id(0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offs < n_pairs
ei = tl.load(pair_edge_ptr + offs, mask=mask, other=0).to(tl.int64)
fi = tl.load(pair_face_ptr + offs, mask=mask, other=0).to(tl.int64)
f0 = tl.load(faces_ptr + fi * 3 + 0, mask=mask, other=0).to(tl.int64)
f1 = tl.load(faces_ptr + fi * 3 + 1, mask=mask, other=0).to(tl.int64)
f2 = tl.load(faces_ptr + fi * 3 + 2, mask=mask, other=0).to(tl.int64)
vai = tl.load(va_ptr + ei, mask=mask, other=0).to(tl.int64)
vbi = tl.load(vb_ptr + ei, mask=mask, other=0).to(tl.int64)
optx = tl.load(opt_ptr + ei * 3 + 0, mask=mask, other=0.0)
opty = tl.load(opt_ptr + ei * 3 + 1, mask=mask, other=0.0)
optz = tl.load(opt_ptr + ei * 3 + 2, mask=mask, other=0.0)
v0x = tl.load(verts_ptr + f0 * 3 + 0, mask=mask, other=0.0)
v0y = tl.load(verts_ptr + f0 * 3 + 1, mask=mask, other=0.0)
v0z = tl.load(verts_ptr + f0 * 3 + 2, mask=mask, other=0.0)
v1x = tl.load(verts_ptr + f1 * 3 + 0, mask=mask, other=0.0)
v1y = tl.load(verts_ptr + f1 * 3 + 1, mask=mask, other=0.0)
v1z = tl.load(verts_ptr + f1 * 3 + 2, mask=mask, other=0.0)
v2x = tl.load(verts_ptr + f2 * 3 + 0, mask=mask, other=0.0)
v2y = tl.load(verts_ptr + f2 * 3 + 1, mask=mask, other=0.0)
v2z = tl.load(verts_ptr + f2 * 3 + 2, mask=mask, other=0.0)
is_v0_a = (f0 == vai) | (f0 == vbi)
is_v1_a = (f1 == vai) | (f1 == vbi)
is_v2_a = (f2 == vai) | (f2 == vbi)
n0x = tl.where(is_v0_a, optx, v0x)
n0y = tl.where(is_v0_a, opty, v0y)
n0z = tl.where(is_v0_a, optz, v0z)
n1x = tl.where(is_v1_a, optx, v1x)
n1y = tl.where(is_v1_a, opty, v1y)
n1z = tl.where(is_v1_a, optz, v1z)
n2x = tl.where(is_v2_a, optx, v2x)
n2y = tl.where(is_v2_a, opty, v2y)
n2z = tl.where(is_v2_a, optz, v2z)
e1x_old = v1x - v0x
e1y_old = v1y - v0y
e1z_old = v1z - v0z
e2x_old = v2x - v0x
e2y_old = v2y - v0y
e2z_old = v2z - v0z
nx_old = e1y_old * e2z_old - e1z_old * e2y_old
ny_old = e1z_old * e2x_old - e1x_old * e2z_old
nz_old = e1x_old * e2y_old - e1y_old * e2x_old
area_old_sq = nx_old * nx_old + ny_old * ny_old + nz_old * nz_old
area_old = tl.sqrt(area_old_sq)
e1x_new = n1x - n0x
e1y_new = n1y - n0y
e1z_new = n1z - n0z
e2x_new = n2x - n0x
e2y_new = n2y - n0y
e2z_new = n2z - n0z
nx_new = e1y_new * e2z_new - e1z_new * e2y_new
ny_new = e1z_new * e2x_new - e1x_new * e2z_new
nz_new = e1x_new * e2y_new - e1y_new * e2x_new
area_new_sq = nx_new * nx_new + ny_new * ny_new + nz_new * nz_new
area_new = tl.sqrt(area_new_sq)
area_bad = area_new_sq < area_thresh * area_thresh
dot = nx_old * nx_new + ny_old * ny_new + nz_old * nz_new
flip_bad = dot < -0.2 * area_old * area_new
e0x_new = n1x - n0x
e0y_new = n1y - n0y
e0z_new = n1z - n0z
e1x_new2 = n2x - n1x
e1y_new2 = n2y - n1y
e1z_new2 = n2z - n1z
e2x_new2 = n0x - n2x
e2y_new2 = n0y - n2y
e2z_new2 = n0z - n2z
l0_new_sq = e0x_new * e0x_new + e0y_new * e0y_new + e0z_new * e0z_new
l1_new_sq = e1x_new2 * e1x_new2 + e1y_new2 * e1y_new2 + e1z_new2 * e1z_new2
l2_new_sq = e2x_new2 * e2x_new2 + e2y_new2 * e2y_new2 + e2z_new2 * e2z_new2
max_new_sq = tl.maximum(tl.maximum(l0_new_sq, l1_new_sq), l2_new_sq)
e0x_old = v1x - v0x
e0y_old = v1y - v0y
e0z_old = v1z - v0z
e1x_old2 = v2x - v1x
e1y_old2 = v2y - v1y
e1z_old2 = v2z - v1z
e2x_old2 = v0x - v2x
e2y_old2 = v0y - v2y
e2z_old2 = v0z - v2z
l0_old_sq = e0x_old * e0x_old + e0y_old * e0y_old + e0z_old * e0z_old
l1_old_sq = e1x_old2 * e1x_old2 + e1y_old2 * e1y_old2 + e1z_old2 * e1z_old2
l2_old_sq = e2x_old2 * e2x_old2 + e2y_old2 * e2y_old2 + e2z_old2 * e2z_old2
max_old_sq = tl.maximum(tl.maximum(l0_old_sq, l1_old_sq), l2_old_sq)
stretch_bad = max_new_sq > 6.25 * max_old_sq
any_bad = area_bad | flip_bad | stretch_bad
tl.store(keep_mask_ptr + offs, any_bad.to(tl.int32), mask=mask)
def _pytorch_edge_errors(verts, Q, edges, stabilizer, max_edge_length_sq, mesh_scale_sq): def _pytorch_edge_errors(verts, Q, edges, stabilizer, max_edge_length_sq, mesh_scale_sq):
n_edges = edges.shape[0] n_edges = edges.shape[0]
if n_edges == 0: if n_edges == 0:
@ -1037,161 +814,6 @@ def _pytorch_edge_errors(verts, Q, edges, stabilizer, max_edge_length_sq, mesh_s
return opt, err, valid return opt, err, valid
def _triton_edge_errors(verts, Q, edges, stabilizer, max_edge_length_sq, mesh_scale_sq):
n_edges = edges.shape[0]
if n_edges == 0:
return (torch.empty((0, 3), dtype=torch.float64, device=verts.device),
torch.empty((0,), dtype=torch.float64, device=verts.device),
torch.zeros((0,), dtype=torch.bool, device=verts.device))
device = verts.device
optimal = torch.empty((n_edges, 3), dtype=torch.float64, device=device)
error = torch.empty((n_edges,), dtype=torch.float64, device=device)
wander = torch.empty((n_edges,), dtype=torch.int32, device=device)
BLOCK_SIZE = 256
grid = (triton.cdiv(n_edges, BLOCK_SIZE),)
try:
qem_edge_errors_kernel[grid](
verts, Q, edges, optimal, error, wander,
n_edges, stabilizer, max_edge_length_sq, mesh_scale_sq,
BLOCK_SIZE=BLOCK_SIZE
)
except Exception:
return _pytorch_edge_errors(verts, Q, edges, stabilizer, max_edge_length_sq, mesh_scale_sq)
has_nan = torch.isnan(optimal).any() or torch.isnan(error).any()
has_inf = torch.isinf(error).any()
if has_nan or has_inf:
return _pytorch_edge_errors(verts, Q, edges, stabilizer, max_edge_length_sq, mesh_scale_sq)
pa = verts[edges[:, 0]]
pb = verts[edges[:, 1]]
el = torch.norm(pb - pa, dim=-1)
mesh_scale = (mesh_scale_sq) ** 0.5
length_ok = el > mesh_scale * 1e-5
error_ok = error < max_edge_length_sq
nan_ok = ~torch.isnan(optimal).any(dim=-1) & ~torch.isnan(error)
valid = length_ok & error_ok & nan_ok
return optimal, error, valid
def _pytorch_validate_faces(verts, faces, v_a, v_b, opt_pos, pair_edge_idx, pair_face_idx, area_thresh):
n_pairs = len(pair_edge_idx)
if n_pairs == 0:
return torch.ones(v_a.numel(), dtype=torch.bool, device=verts.device)
device = verts.device
old_faces = faces[pair_face_idx]
v0_old = verts[old_faces[:, 0]]
v1_old = verts[old_faces[:, 1]]
v2_old = verts[old_faces[:, 2]]
v0_new = v0_old.clone()
v1_new = v1_old.clone()
v2_new = v2_old.clone()
va_t = v_a[pair_edge_idx]
vb_t = v_b[pair_edge_idx]
opt_t = opt_pos[pair_edge_idx]
mask0 = (old_faces[:, 0] == va_t) | (old_faces[:, 0] == vb_t)
mask1 = (old_faces[:, 1] == va_t) | (old_faces[:, 1] == vb_t)
mask2 = (old_faces[:, 2] == va_t) | (old_faces[:, 2] == vb_t)
v0_new[mask0] = opt_t[mask0]
v1_new[mask1] = opt_t[mask1]
v2_new[mask2] = opt_t[mask2]
e1_old = v1_old - v0_old
e2_old = v2_old - v0_old
n_old = torch.cross(e1_old, e2_old, dim=-1)
e1_new = v1_new - v0_new
e2_new = v2_new - v0_new
n_new = torch.cross(e1_new, e2_new, dim=-1)
area_new = torch.norm(n_new, dim=-1)
area_bad = area_new < area_thresh
n_old_norm = n_old / (torch.norm(n_old, dim=-1, keepdim=True) + 1e-12)
n_new_norm = n_new / (torch.norm(n_new, dim=-1, keepdim=True) + 1e-12)
dots = (n_old_norm * n_new_norm).sum(dim=-1)
flip_bad = dots < -0.2
old_edges = torch.stack([
torch.norm(v1_old - v0_old, dim=-1),
torch.norm(v2_old - v1_old, dim=-1),
torch.norm(v0_old - v2_old, dim=-1),
], dim=1).max(dim=1)[0]
new_edges = torch.stack([
torch.norm(v1_new - v0_new, dim=-1),
torch.norm(v2_new - v1_new, dim=-1),
torch.norm(v0_new - v2_new, dim=-1),
], dim=1).max(dim=1)[0]
stretch_bad = new_edges > 2.5 * old_edges
def face_angles(v0, v1, v2):
e0 = v1 - v0
e1 = v2 - v1
e2 = v0 - v2
l0 = torch.norm(e0, dim=-1)
l1 = torch.norm(e1, dim=-1)
l2 = torch.norm(e2, dim=-1)
cos_a = (l1 * l1 + l2 * l2 - l0 * l0) / (2 * l1 * l2 + 1e-12)
cos_b = (l0 * l0 + l2 * l2 - l1 * l1) / (2 * l0 * l2 + 1e-12)
cos_c = (l0 * l0 + l1 * l1 - l2 * l2) / (2 * l0 * l1 + 1e-12)
cos_all = torch.stack([cos_a, cos_b, cos_c], dim=-1)
return torch.acos(torch.clamp(cos_all, -1, 1)) * 180 / np.pi
new_angles = face_angles(v0_new, v1_new, v2_new)
angle_bad = (new_angles < 1.0).any(dim=-1) | (new_angles > 178.0).any(dim=-1)
any_bad = area_bad | flip_bad | stretch_bad | angle_bad
keep_mask = torch.ones(v_a.numel(), dtype=torch.bool, device=device)
if any_bad.any():
bad_edges = pair_edge_idx[any_bad]
keep_mask.scatter_(0, bad_edges, False)
return keep_mask
def _triton_validate_faces(verts, faces, v_a, v_b, opt_pos, pair_edge_idx, pair_face_idx, area_thresh):
n_pairs = len(pair_edge_idx)
if n_pairs == 0:
return torch.ones(v_a.numel(), dtype=torch.bool, device=verts.device)
device = verts.device
pair_bad = torch.empty(n_pairs, dtype=torch.int32, device=device)
BLOCK_SIZE = 256
grid = (triton.cdiv(n_pairs, BLOCK_SIZE),)
try:
validate_faces_kernel[grid](
verts, faces, v_a, v_b, opt_pos, pair_edge_idx, pair_face_idx,
n_pairs, area_thresh, pair_bad,
BLOCK_SIZE=BLOCK_SIZE
)
except Exception:
return _pytorch_validate_faces(verts, faces, v_a, v_b, opt_pos, pair_edge_idx, pair_face_idx, area_thresh)
keep_mask = torch.ones(v_a.numel(), dtype=torch.bool, device=device)
bad_mask = pair_bad.bool()
if bad_mask.any():
bad_edges = pair_edge_idx[bad_mask]
keep_mask.scatter_(0, bad_edges, False)
return keep_mask
def _build_quadrics(verts, faces): def _build_quadrics(verts, faces):
v0 = verts[faces[:, 0]] v0 = verts[faces[:, 0]]
v1 = verts[faces[:, 1]] v1 = verts[faces[:, 1]]
@ -1350,7 +972,7 @@ def _gpu_greedy_sampled(edges, errors, v_alive, max_select):
return torch.empty(0, dtype=torch.int64, device=device) return torch.empty(0, dtype=torch.int64, device=device)
return torch.tensor(selected, dtype=torch.int64, device=device) return torch.tensor(selected, dtype=torch.int64, device=device)
def _qem_simplify(verts_np, faces_np, colors_np, target_faces, device, max_edge_length=None, use_triton=False, fast_mode=False): def _qem_simplify(verts_np, faces_np, colors_np, target_faces, device, max_edge_length=None):
verts = torch.from_numpy(verts_np).to(device=device, dtype=torch.float64) verts = torch.from_numpy(verts_np).to(device=device, dtype=torch.float64)
faces = torch.from_numpy(faces_np).to(device=device, dtype=torch.int64) faces = torch.from_numpy(faces_np).to(device=device, dtype=torch.int64)
colors = ( colors = (
@ -1362,7 +984,7 @@ def _qem_simplify(verts_np, faces_np, colors_np, target_faces, device, max_edge_
num_verts = verts.shape[0] num_verts = verts.shape[0]
num_faces = faces.shape[0] num_faces = faces.shape[0]
logging.debug(f"[QEM] Input: {num_verts} verts, {num_faces} faces, target={target_faces}, fast={fast_mode}") logging.debug(f"[QEM] Input: {num_verts} verts, {num_faces} faces, target={target_faces}")
v_alive = torch.ones(num_verts, dtype=torch.bool, device=device) v_alive = torch.ones(num_verts, dtype=torch.bool, device=device)
f_alive = torch.ones(num_faces, dtype=torch.bool, device=device) f_alive = torch.ones(num_faces, dtype=torch.bool, device=device)
@ -1379,7 +1001,6 @@ def _qem_simplify(verts_np, faces_np, colors_np, target_faces, device, max_edge_
max_edge_length = 1.0 max_edge_length = 1.0
stabilizer = mesh_scale * mesh_scale * 0.001 stabilizer = mesh_scale * mesh_scale * 0.001
area_thresh = mesh_scale * mesh_scale * 1e-10
max_edge_length_sq = max_edge_length * max_edge_length max_edge_length_sq = max_edge_length * max_edge_length
mesh_scale_sq = mesh_scale * mesh_scale mesh_scale_sq = mesh_scale * mesh_scale
@ -1448,15 +1069,9 @@ def _qem_simplify(verts_np, faces_np, colors_np, target_faces, device, max_edge_
else: else:
n_edges = n_edges_total n_edges = n_edges_total
# Compute edge errors optimal, err, valid = _pytorch_edge_errors(
if use_triton and torch.cuda.is_available(): verts, Q, edges_orig, stabilizer, max_edge_length_sq, mesh_scale_sq
optimal, err, valid = _triton_edge_errors( )
verts, Q, edges_orig, stabilizer, max_edge_length_sq, mesh_scale_sq
)
else:
optimal, err, valid = _pytorch_edge_errors(
verts, Q, edges_orig, stabilizer, max_edge_length_sq, mesh_scale_sq
)
if not valid.any(): if not valid.any():
valid = torch.ones(n_edges, dtype=torch.bool, device=device) valid = torch.ones(n_edges, dtype=torch.bool, device=device)
@ -1477,7 +1092,6 @@ def _qem_simplify(verts_np, faces_np, colors_np, target_faces, device, max_edge_
v_a = edges_orig[sel, 0] v_a = edges_orig[sel, 0]
v_b = edges_orig[sel, 1] v_b = edges_orig[sel, 1]
opt_pos = optimal[sel]
# Build adjacency # Build adjacency
face_indices, vert_ptrs = _build_vertex_face_csr(active_faces, num_verts) face_indices, vert_ptrs = _build_vertex_face_csr(active_faces, num_verts)
@ -1508,56 +1122,6 @@ def _qem_simplify(verts_np, faces_np, colors_np, target_faces, device, max_edge_
keep_mask = torch.ones(v_a.numel(), dtype=torch.bool, device=device) keep_mask = torch.ones(v_a.numel(), dtype=torch.bool, device=device)
# Face validation (skip in fast_mode)
if not fast_mode and len(pair_edge_idx) > 0:
pair_edge_idx_t = torch.tensor(pair_edge_idx, dtype=torch.int64, device=device)
pair_face_idx_t = torch.tensor(pair_face_idx, dtype=torch.int64, device=device)
if use_triton and torch.cuda.is_available():
keep_mask = _triton_validate_faces(
verts, active_faces, v_a, v_b, opt_pos,
pair_edge_idx_t, pair_face_idx_t, area_thresh
)
else:
keep_mask = _pytorch_validate_faces(
verts, active_faces, v_a, v_b, opt_pos,
pair_edge_idx_t, pair_face_idx_t, area_thresh
)
# Link condition (skip in fast_mode for massive speedup)
if not fast_mode:
# Vectorized link condition using GPU operations
link_keep = torch.ones(v_a.numel(), dtype=torch.bool, device=device)
# Build neighbor sets for va and vb using GPU operations
for ei in range(v_a.numel()):
vai = v_a[ei].item()
vbi = v_b[ei].item()
f_va = _get_vertex_faces(vai, face_indices, vert_ptrs)
f_vb = _get_vertex_faces(vbi, face_indices, vert_ptrs)
if f_va.numel() == 0 or f_vb.numel() == 0:
continue
faces_va = active_faces[f_va]
verts_va = faces_va[faces_va != vai]
verts_va = verts_va[verts_va != vbi]
faces_vb = active_faces[f_vb]
verts_vb = faces_vb[faces_vb != vbi]
verts_vb = verts_vb[verts_vb != vai]
if verts_va.numel() == 0 or verts_vb.numel() == 0:
continue
# Use torch.intersect1d for GPU-native intersection
common = torch.intersect1d(verts_va, verts_vb)
if common.numel() > 2:
link_keep[ei] = False
keep_mask &= link_keep
if not keep_mask.any(): if not keep_mask.any():
break break
@ -1623,12 +1187,12 @@ def _qem_simplify(verts_np, faces_np, colors_np, target_faces, device, max_edge_
return final_v, final_f, final_c return final_v, final_f, final_c
def simplify_fn(vertices, faces, colors=None, target=100000, max_edge_length=None, use_triton=True, fast_mode=True): def simplify_fn(vertices, faces, colors=None, target=100000, max_edge_length=None):
if vertices.ndim == 3: if vertices.ndim == 3:
v_list, f_list, c_list = [], [], [] v_list, f_list, c_list = [], [], []
for i in range(vertices.shape[0]): for i in range(vertices.shape[0]):
c_in = colors[i] if colors is not None else None c_in = colors[i] if colors is not None else None
v_i, f_i, c_i = simplify_fn(vertices[i], faces[i], c_in, target, max_edge_length, use_triton, fast_mode) v_i, f_i, c_i = simplify_fn(vertices[i], faces[i], c_in, target, max_edge_length)
v_list.append(v_i) v_list.append(v_i)
f_list.append(f_i) f_list.append(f_i)
if c_i is not None: if c_i is not None:
@ -1651,7 +1215,7 @@ def simplify_fn(vertices, faces, colors=None, target=100000, max_edge_length=Non
) )
out_v, out_f, out_c = _qem_simplify( out_v, out_f, out_c = _qem_simplify(
verts_np, faces_np, colors_np, target, device, max_edge_length, use_triton, fast_mode verts_np, faces_np, colors_np, target, device, max_edge_length
) )
final_v = out_v.to(device=device, dtype=dtype) final_v = out_v.to(device=device, dtype=dtype)