This commit is contained in:
kijai 2026-07-01 21:21:11 +03:00
parent 80e9fc65b1
commit 2bbf53e8fc
5 changed files with 122 additions and 62 deletions

View File

@ -952,13 +952,13 @@ def flexible_dual_grid_to_mesh(
values = torch.arange(N, dtype=torch.int32, device=device)
torch_hashmap = TorchHashMap(flat_keys, values)
# Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3]
# Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3]
n_idx, axis_idx = intersected_flag.nonzero(as_tuple=True) # (M,), (M,)
offsets_per_axis = flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset[0] # (3, 4, 3)
connected_voxel = coords[n_idx].unsqueeze(1) + offsets_per_axis[axis_idx] # (M, 4, 3)
M = connected_voxel.shape[0]
# flatten connected voxel coords and lookup. In-place to avoid extra memory allocation.
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
H, D = int(grid_size[1].item()), int(grid_size[2].item())
cv = connected_voxel.reshape(-1, 3)
conn_flat = cv[:, 0].long() * (H * D)
conn_flat.add_(cv[:, 1].long() * D)

View File

@ -40,19 +40,19 @@ class QEMConfig:
flip_reject_hard: bool = True # hard-reject (err=+inf) top-K collapses that flip any 1-ring normal
# Per-iteration batch sizing
# Per-iteration batch sizing
sampling_cap: int = 10_000_000 # max edges processed per outer iter
max_collapses_fraction: float = 0.25 # of remaining faces-to-remove
max_collapses_floor: int = 10_000
max_collapses_ceiling: int = 1_000_000
max_collapses_relative_cap: float = 0.10 # cap per-iter collapses as fraction of current faces; 0 disables
# Loop control
# Loop control
max_iterations: int = 5_000
compaction_period: int = 5
compaction_threshold: float = 0.85 # compact when alive_frac < this
# Quality knobs
# Quality knobs
boundary_quadrics: bool = True
boundary_weight: float = 1000.0
recompute_normals_post: bool = True
@ -63,7 +63,7 @@ class QEMConfig:
feature_edge_quadric_weight: float = 0.0
feature_edge_min_dihedral_deg: float = 30.0
# Flip check (FA-QEM §3.3)
# Flip check (FA-QEM §3.3)
quality_topk_multiplier: int = 4 # quality-check band size = this * max_collapses_per_iter
flip_cos_threshold: float = 0.0 # 0 = count any sign reversal (dihedral > 90°)
flip_check_max_degree: int = 16 # cap on vertex degree for the flip-check table
@ -71,13 +71,13 @@ class QEMConfig:
# Triangle shape penalty
skinny_weight: float = 1e-3 # penalise top-K collapses producing needle/sliver tris; 0 disables
# Topology preservation
# Topology preservation
enforce_link_condition: bool = True # reject collapses that violate the link condition
# Quadric area weighting
# Quadric area weighting
area_weighted_quadrics: bool = False # True: Garland-Heckbert area-weighted; False: un-weighted
# edge-length cost regularizer
# edge-length cost regularizer
lambda_edge_length: float = 1e-2 # add λ*len² to bias toward short edges; 0 disables
lambda_edge_length_absolute: bool = True # apply λ absolutely vs relative-to-QEM-median
@ -87,16 +87,16 @@ class QEMConfig:
memoryless_qem: bool = True # rebuild quadrics each round vs accumulate
repair_nonmanifold: bool = True # final repair_non_manifold_edges pass
# Pre-clean (input mesh)
# Pre-clean (input mesh)
preclean: bool = True # weld coincident verts, drop degenerate/duplicate/unused
# Post-clean (output mesh)
# Post-clean (output mesh)
postclean: bool = True # remove slivers, tiny components, unused verts left by collapse
postclean_min_angle_deg: float = 0.5
postclean_max_aspect_ratio: float = 100.0
postclean_min_component_faces: int = 8 # drop components with fewer faces than this
# Preclean tuning
# Preclean tuning
preclean_weld_epsilon_rel: float = 1e-5 # weld tolerance as fraction of bbox diagonal
preclean_min_component_faces: int = 0 # 0 = keep all components
@ -1631,8 +1631,10 @@ def qem_decimate_simplify(
out_v.append(v)
out_f.append(f)
out_s.append(s)
if c is not None: out_c.append(c)
if n is not None: out_n.append(n)
if c is not None:
out_c.append(c)
if n is not None:
out_n.append(n)
return (out_v, out_f,
out_c if out_c else None,
out_n if out_n else None,

View File

@ -159,7 +159,7 @@ def _build_tri_spatial_hash(centroids: torch.Tensor, tri_radii: torch.Tensor,
local = torch.arange(total, device=device) - cum[rep]
sx = spans[rep, 0]
sy = spans[rep, 1]
sz = spans[rep, 2]
lx = local % sx
ly = (local // sx) % sy
lz = local // (sx * sy)
@ -696,7 +696,6 @@ def _filter_components(verts: torch.Tensor, faces: torch.Tensor,
"""Drop tiny / inverted-volume / bbox-enclosed connected components; returns filtered faces."""
device = faces.device
V = verts.shape[0]
F = faces.shape[0]
# Connected components via min-label propagation across faces (200-iter max)
label = torch.arange(V, dtype=torch.long, device=device)
@ -1090,7 +1089,9 @@ def remesh_narrow_band_dc(
safe_tri = closest_tri.clamp(min=0)
tri_v_idx = faces[safe_tri].long() # (N, 3)
tri_v = vertices[tri_v_idx] # (N, 3, 3)
v0 = tri_v[:, 0]; v1 = tri_v[:, 1]; v2 = tri_v[:, 2]
v0 = tri_v[:, 0]
v1 = tri_v[:, 1]
v2 = tri_v[:, 2]
e0 = v1 - v0
e1 = v2 - v0
e2 = closest_pts - v0

View File

@ -39,11 +39,15 @@ def _best_rotation_jit(uvs_np: np.ndarray, n_angles: int) -> float:
half_pi = math.pi * 0.5
for k in range(n_angles):
theta = half_pi * k / n_angles
c = math.cos(theta); s = math.sin(theta)
xmin = 1e30; xmax = -1e30
ymin = 1e30; ymax = -1e30
c = math.cos(theta)
s = math.sin(theta)
xmin = 1e30
xmax = -1e30
ymin = 1e30
ymax = -1e30
for i in range(V):
ux = uvs_np[i, 0]; uy = uvs_np[i, 1]
ux = uvs_np[i, 0]
uy = uvs_np[i, 1]
xr = ux * c - uy * s
yr = ux * s + uy * c
if xr < xmin: xmin = xr
@ -78,10 +82,15 @@ def _rasterize_chart_jit(
F = faces.shape[0]
eps = 1e-7
for fi in range(F):
i0 = faces[fi, 0]; i1 = faces[fi, 1]; i2 = faces[fi, 2]
x0 = uvs_tex[i0, 0]; y0 = uvs_tex[i0, 1]
x1 = uvs_tex[i1, 0]; y1 = uvs_tex[i1, 1]
x2 = uvs_tex[i2, 0]; y2 = uvs_tex[i2, 1]
i0 = faces[fi, 0]
i1 = faces[fi, 1]
i2 = faces[fi, 2]
x0 = uvs_tex[i0, 0]
y0 = uvs_tex[i0, 1]
x1 = uvs_tex[i1, 0]
y1 = uvs_tex[i1, 1]
x2 = uvs_tex[i2, 0]
y2 = uvs_tex[i2, 1]
xmin_f = x0
if x1 < xmin_f: xmin_f = x1
if x2 < xmin_f: xmin_f = x2
@ -172,19 +181,25 @@ def _build_candidates_jit(
for xs in range(x, x_end):
if skyline[xs] > y:
y = int(skyline[xs])
out[k, 0] = x; out[k, 1] = y; out[k, 2] = swap_flag
out[k, 0] = x
out[k, 1] = y
out[k, 2] = swap_flag
k += 1
x += step
for y_fixed in (0, cur_h):
x = 0
while x <= cur_w:
out[k, 0] = x; out[k, 1] = y_fixed; out[k, 2] = swap_flag
out[k, 0] = x
out[k, 1] = y_fixed
out[k, 2] = swap_flag
k += 1
x += step
for x_fixed in (0, cur_w):
y = 0
while y <= cur_h:
out[k, 0] = x_fixed; out[k, 1] = y; out[k, 2] = swap_flag
out[k, 0] = x_fixed
out[k, 1] = y
out[k, 2] = swap_flag
k += 1
y += step
return out[:k]
@ -194,7 +209,8 @@ def _build_candidates_jit(
def _update_skyline_jit(skyline: np.ndarray, chart: np.ndarray,
x: int, y: int) -> None:
"""Lift skyline[x+i] to y + topmost_True_row + 1 per chart column."""
ch = chart.shape[0]; cw = chart.shape[1]
ch = chart.shape[0]
cw = chart.shape[1]
sw = skyline.shape[0]
for i in range(cw):
col_x = x + i
@ -227,17 +243,22 @@ def _best_placement_jit(
best_y = -1
best_score = -1
best_swap = 0
bh0 = bitmap.shape[0]; bw0 = bitmap.shape[1]
bh1 = bitmap_rot.shape[0]; bw1 = bitmap_rot.shape[1]
ah = atlas.shape[0]; aw = atlas.shape[1]
bh0 = bitmap.shape[0]
bw0 = bitmap.shape[1]
bh1 = bitmap_rot.shape[0]
bw1 = bitmap_rot.shape[1]
ah = atlas.shape[0]
aw = atlas.shape[1]
for k in range(n):
x = candidates[k, 0]
y = candidates[k, 1]
swap = candidates[k, 2]
if swap == 0:
ch = bh0; cw = bw0
ch = bh0
cw = bw0
else:
ch = bh1; cw = bw1
ch = bh1
cw = bw1
if x < 0 or y < 0:
continue
nw = cur_w if cur_w > x + cw else x + cw
@ -265,8 +286,10 @@ def _best_placement_jit(
break
if not ok:
continue
best_x = x; best_y = y
best_score = score; best_swap = swap
best_x = x
best_y = y
best_score = score
best_swap = swap
if x + cw <= cur_w and y + ch <= cur_h:
break
return best_x, best_y, best_score, best_swap
@ -330,8 +353,10 @@ def _dilate_local(x: Tensor, p: int) -> Tensor:
dilation OR-scattered equals dilating the assembled chart bitmap."""
for _ in range(p):
y = x.clone()
y[:, 1:, :] |= x[:, :-1, :]; y[:, :-1, :] |= x[:, 1:, :]
y[:, :, 1:] |= x[:, :, :-1]; y[:, :, :-1] |= x[:, :, 1:]
y[:, 1:, :] |= x[:, :-1, :]
y[:, :-1, :] |= x[:, 1:, :]
y[:, :, 1:] |= x[:, :, :-1]
y[:, :, :-1] |= x[:, :, 1:]
x = y
return x
@ -358,7 +383,8 @@ def _raster_all_torch(uvs_tex_pad, faces_pad, fmask, bw_t, bh_t, padding, device
cid = torch.arange(n, device=device).repeat_interleave(fmax)[fm]
# per-triangle pixel bbox, inflated by padding (origin >= 0); bucket by next-pow2 max-dim
tmin = tri_f.amin(1); tmax = tri_f.amax(1)
tmin = tri_f.amin(1)
tmax = tri_f.amax(1)
x0 = (tmin[:, 0].floor().long() - padding).clamp_min(0)
y0 = (tmin[:, 1].floor().long() - padding).clamp_min(0)
bbw = (tmax[:, 0].ceil().long() + padding) - x0 + 1
@ -366,20 +392,31 @@ def _raster_all_torch(uvs_tex_pad, faces_pad, fmask, bw_t, bh_t, padding, device
mxd = torch.maximum(bbw, bbh).clamp_min(1)
bsz = (2 ** torch.ceil(torch.log2(mxd.float())).long()).long()
a = tri_f[:, 0]; b = tri_f[:, 1]; c = tri_f[:, 2]
v0 = b - a; v1 = c - a
d00 = (v0 * v0).sum(-1); d01 = (v0 * v1).sum(-1); d11 = (v1 * v1).sum(-1)
a = tri_f[:, 0]
b = tri_f[:, 1]
c = tri_f[:, 2]
v0 = b - a
v1 = c - a
d00 = (v0 * v0).sum(-1)
d01 = (v0 * v1).sum(-1)
d11 = (v1 * v1).sum(-1)
den = (d00 * d11 - d01 * d01).clamp(min=1e-20)
for g in sorted(set(bsz.tolist())): # one batch per pow2 grid
sel = (bsz == g).nonzero(as_tuple=True)[0]
m = sel.shape[0]
xs0 = x0[sel].view(m, 1, 1); ys0 = y0[sel].view(m, 1, 1)
cc = cid[sel]; bwp = bwL[cc].view(m, 1, 1); bhp = bhL[cc].view(m, 1, 1)
xs0 = x0[sel].view(m, 1, 1)
ys0 = y0[sel].view(m, 1, 1)
cc = cid[sel]
bwp = bwL[cc].view(m, 1, 1)
bhp = bhL[cc].view(m, 1, 1)
gi = torch.arange(g, device=device)
px = xs0 + gi.view(1, 1, g); py = ys0 + gi.view(1, g, 1) # (m,g,g) int
pxf = px.float() + 0.5; pyf = py.float() + 0.5
v2x = pxf - a[sel, 0].view(m, 1, 1); v2y = pyf - a[sel, 1].view(m, 1, 1)
px = xs0 + gi.view(1, 1, g)
py = ys0 + gi.view(1, g, 1) # (m,g,g) int
pxf = px.float() + 0.5
pyf = py.float() + 0.5
v2x = pxf - a[sel, 0].view(m, 1, 1)
v2y = pyf - a[sel, 1].view(m, 1, 1)
d20 = v2x * v0[sel, 0].view(m, 1, 1) + v2y * v0[sel, 1].view(m, 1, 1)
d21 = v2x * v1[sel, 0].view(m, 1, 1) + v2y * v1[sel, 1].view(m, 1, 1)
idn = den[sel].view(m, 1, 1).reciprocal()
@ -437,7 +474,8 @@ def _best_placement_torch(atlas, pix0, dim0, pix1, dim1, cand0, cand1, cur_w, cu
cx, cy = cand[:, 0], cand[:, 1]
coll = atlas[cy[:, None] + pix[:, 0][None, :], # (M,k) True-pixel gather
cx[:, None] + pix[:, 1][None, :]].any(dim=1)
nw = torch.clamp(cx + cw, min=cur_w); nh = torch.clamp(cy + ch, min=cur_h)
nw = torch.clamp(cx + cw, min=cur_w)
nh = torch.clamp(cy + ch, min=cur_h)
ext = torch.maximum(nw, nh)
score = torch.where(coll, torch.full_like(nw, INF), ext * ext + nw * nh)
j = score.argmin()
@ -465,7 +503,8 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces,
# ---- Prepare pass 1: best-rotation + scale + bbox for ALL charts at once (batched) ----
vcount = [int(u.shape[0]) for u in chart_uvs]
fcount = [int(f.shape[0]) for f in chart_faces]
vmax = max(vcount); fmax = max(fcount)
vmax = max(vcount)
fmax = max(fcount)
uvs_pad = torch.zeros(n, vmax, 2, device=device)
vmask = torch.zeros(n, vmax, dtype=torch.bool, device=device)
faces_pad = torch.zeros(n, fmax, 3, dtype=torch.long, device=device)
@ -488,8 +527,10 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces,
cc, ss = cos_a[ti][:, None], sin_a[ti][:, None] # (N,1)
rx = torch.addcmul(u0 * cc, u1, ss, value=-1) # (N,Vmax)
ry = torch.addcmul(u0 * ss, u1, cc)
rxmin = (rx + mlo).amin(1); rxmax = (rx + mhi).amax(1) # (N,)
rymin = (ry + mlo).amin(1); rymax = (ry + mhi).amax(1)
rxmin = (rx + mlo).amin(1) # (N,)
rxmax = (rx + mhi).amax(1)
rymin = (ry + mlo).amin(1)
rymax = (ry + mhi).amax(1)
a3 = torch.tensor([max(a, 1e-12) for a in chart_3d_areas], device=device)
au = torch.tensor([max(a, 1e-12) for a in chart_uv_areas], device=device)
base = (a3 / au).sqrt() * texels_per_unit
@ -504,7 +545,8 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces,
# one sync: pull all per-chart scalars
thetas = ang[ti].cpu().tolist()
scales = scale.cpu().tolist()
bws = bw_t.cpu().tolist(); bhs = bh_t.cpu().tolist()
bws = bw_t.cpu().tolist()
bhs = bh_t.cpu().tolist()
# ---- Prepare pass 2: rasterize ALL charts at once, then trim each bitmap to its bounds ----
buf, cbase = _raster_all_torch(uvs_tex_pad, faces_pad, fmask, bw_t, bh_t, padding_texels, device)
@ -513,7 +555,8 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces,
for i in range(n):
bm = buf[cb[i]:cb[i + 1]].view(bhs[i], bws[i])
raw.append(bm)
rr = torch.arange(bm.shape[0], device=device); cc = torch.arange(bm.shape[1], device=device)
rr = torch.arange(bm.shape[0], device=device)
cc = torch.arange(bm.shape[1], device=device)
rmax = torch.where(bm.any(1), rr, rr.new_full((), -1)).amax() # last occupied row / col (-1 if empty)
cmax = torch.where(bm.any(0), cc, cc.new_full((), -1)).amax()
bnd.append(torch.stack([rmax, cmax]))
@ -527,10 +570,12 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces,
bm = (raw[i][:rm + 1, :cm + 1].contiguous() if rm >= 0 and cm >= 0
else torch.zeros((1, 1), dtype=torch.bool, device=device))
bm_rot = torch.flip(bm.t(), dims=[1]).contiguous()
pix_l.append(bm.nonzero()); pixr_l.append(bm_rot.nonzero())
pix_l.append(bm.nonzero())
pixr_l.append(bm_rot.nonzero())
dim_l.append((int(bm.shape[0]), int(bm.shape[1])))
dimr_l.append((int(bm_rot.shape[0]), int(bm_rot.shape[1])))
col_tops.append(_col_top(bm)); col_tops_rot.append(_col_top(bm_rot))
col_tops.append(_col_top(bm))
col_tops_rot.append(_col_top(bm_rot))
bm_h.append(int(bm.shape[0]))
wmax = max(d[1] for d in dim_l + dimr_l)
ct_pad = torch.full((n, wmax), -1, dtype=torch.long, device=device)
@ -557,8 +602,11 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces,
if cur_h + margin > atlas.shape[0] or cur_w + margin > atlas.shape[1]:
ns = max(atlas.shape[0], cur_h + margin, cur_w + margin)
na = torch.zeros((ns, ns), dtype=torch.bool, device=device)
na[:atlas.shape[0], :atlas.shape[1]] = atlas; atlas = na
nsk = torch.zeros(ns, dtype=torch.long, device=device); nsk[:sky_t.shape[0]] = sky_t; sky_t = nsk
na[:atlas.shape[0], :atlas.shape[1]] = atlas
atlas = na
nsk = torch.zeros(ns, dtype=torch.long, device=device)
nsk[:sky_t.shape[0]] = sky_t
sky_t = nsk
dim, dimr = dim_l[ci], dimr_l[ci]
step = max(1, min(dim[0], dim[1]) // 8)
cand0, cand1 = _build_candidates_gpu(sky_t, cur_w, cur_h, dim[1], dimr[1], step, rand_n, gen, device)
@ -569,7 +617,8 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces,
pix = pixr_l[ci] if swap else pix_l[ci]
bh_, bw_ = (dimr if swap else dim)
atlas[by + pix[:, 0], bx + pix[:, 1]] = True # sparse blit
cur_w = max(cur_w, bx + bw_); cur_h = max(cur_h, by + bh_)
cur_w = max(cur_w, bx + bw_)
cur_h = max(cur_h, by + bh_)
ct = (ctr_pad if swap else ct_pad)[ci, :bw_] # GPU skyline lift
ix = torch.arange(bx, bx + bw_, device=device)
sky_t[ix] = torch.where(ct >= 0, torch.maximum(sky_t[ix], by + ct + 1), sky_t[ix])

View File

@ -33,13 +33,17 @@ def _face_curvature_jit(face_normal: np.ndarray, face_face: np.ndarray) -> np.nd
F = face_normal.shape[0]
raw = np.zeros(F, dtype=np.float32)
for f in range(F):
nx = face_normal[f, 0]; ny = face_normal[f, 1]; nz = face_normal[f, 2]
nx = face_normal[f, 0]
ny = face_normal[f, 1]
nz = face_normal[f, 2]
s = np.float32(0.0)
for e in range(3):
nb = face_face[f, e]
if nb < 0:
continue
mx = face_normal[nb, 0]; my = face_normal[nb, 1]; mz = face_normal[nb, 2]
mx = face_normal[nb, 0]
my = face_normal[nb, 1]
mz = face_normal[nb, 2]
d = nx*mx + ny*my + nz*mz
s += np.float32(1.0) - d
raw[f] = s
@ -70,7 +74,9 @@ def _farthest_point_seeds_jit(
continue
seeds[n_seeds] = s
n_seeds += 1
sx = face_centroid[s, 0]; sy = face_centroid[s, 1]; sz = face_centroid[s, 2]
sx = face_centroid[s, 0]
sy = face_centroid[s, 1]
sz = face_centroid[s, 2]
for f in range(F):
dx = face_centroid[f, 0] - sx
dy = face_centroid[f, 1] - sy
@ -145,7 +151,9 @@ def _cost_grow_iter_jit(
for f in range(F):
if face_chart[f] != -1:
continue
nx = face_normal[f, 0]; ny = face_normal[f, 1]; nz = face_normal[f, 2]
nx = face_normal[f, 0]
ny = face_normal[f, 1]
nz = face_normal[f, 2]
af = face_area[f]
for e0 in range(3):
nb0 = face_face[f, e0]