mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-06 06:30:52 +08:00
Linting
This commit is contained in:
parent
80e9fc65b1
commit
2bbf53e8fc
@ -952,13 +952,13 @@ def flexible_dual_grid_to_mesh(
|
|||||||
values = torch.arange(N, dtype=torch.int32, device=device)
|
values = torch.arange(N, dtype=torch.int32, device=device)
|
||||||
torch_hashmap = TorchHashMap(flat_keys, values)
|
torch_hashmap = TorchHashMap(flat_keys, values)
|
||||||
|
|
||||||
# Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3]
|
# Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3]
|
||||||
n_idx, axis_idx = intersected_flag.nonzero(as_tuple=True) # (M,), (M,)
|
n_idx, axis_idx = intersected_flag.nonzero(as_tuple=True) # (M,), (M,)
|
||||||
offsets_per_axis = flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset[0] # (3, 4, 3)
|
offsets_per_axis = flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset[0] # (3, 4, 3)
|
||||||
connected_voxel = coords[n_idx].unsqueeze(1) + offsets_per_axis[axis_idx] # (M, 4, 3)
|
connected_voxel = coords[n_idx].unsqueeze(1) + offsets_per_axis[axis_idx] # (M, 4, 3)
|
||||||
M = connected_voxel.shape[0]
|
M = connected_voxel.shape[0]
|
||||||
# flatten connected voxel coords and lookup. In-place to avoid extra memory allocation.
|
# flatten connected voxel coords and lookup. In-place to avoid extra memory allocation.
|
||||||
W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item())
|
H, D = int(grid_size[1].item()), int(grid_size[2].item())
|
||||||
cv = connected_voxel.reshape(-1, 3)
|
cv = connected_voxel.reshape(-1, 3)
|
||||||
conn_flat = cv[:, 0].long() * (H * D)
|
conn_flat = cv[:, 0].long() * (H * D)
|
||||||
conn_flat.add_(cv[:, 1].long() * D)
|
conn_flat.add_(cv[:, 1].long() * D)
|
||||||
|
|||||||
@ -40,19 +40,19 @@ class QEMConfig:
|
|||||||
|
|
||||||
flip_reject_hard: bool = True # hard-reject (err=+inf) top-K collapses that flip any 1-ring normal
|
flip_reject_hard: bool = True # hard-reject (err=+inf) top-K collapses that flip any 1-ring normal
|
||||||
|
|
||||||
# Per-iteration batch sizing
|
# Per-iteration batch sizing
|
||||||
sampling_cap: int = 10_000_000 # max edges processed per outer iter
|
sampling_cap: int = 10_000_000 # max edges processed per outer iter
|
||||||
max_collapses_fraction: float = 0.25 # of remaining faces-to-remove
|
max_collapses_fraction: float = 0.25 # of remaining faces-to-remove
|
||||||
max_collapses_floor: int = 10_000
|
max_collapses_floor: int = 10_000
|
||||||
max_collapses_ceiling: int = 1_000_000
|
max_collapses_ceiling: int = 1_000_000
|
||||||
max_collapses_relative_cap: float = 0.10 # cap per-iter collapses as fraction of current faces; 0 disables
|
max_collapses_relative_cap: float = 0.10 # cap per-iter collapses as fraction of current faces; 0 disables
|
||||||
|
|
||||||
# Loop control
|
# Loop control
|
||||||
max_iterations: int = 5_000
|
max_iterations: int = 5_000
|
||||||
compaction_period: int = 5
|
compaction_period: int = 5
|
||||||
compaction_threshold: float = 0.85 # compact when alive_frac < this
|
compaction_threshold: float = 0.85 # compact when alive_frac < this
|
||||||
|
|
||||||
# Quality knobs
|
# Quality knobs
|
||||||
boundary_quadrics: bool = True
|
boundary_quadrics: bool = True
|
||||||
boundary_weight: float = 1000.0
|
boundary_weight: float = 1000.0
|
||||||
recompute_normals_post: bool = True
|
recompute_normals_post: bool = True
|
||||||
@ -63,7 +63,7 @@ class QEMConfig:
|
|||||||
feature_edge_quadric_weight: float = 0.0
|
feature_edge_quadric_weight: float = 0.0
|
||||||
feature_edge_min_dihedral_deg: float = 30.0
|
feature_edge_min_dihedral_deg: float = 30.0
|
||||||
|
|
||||||
# Flip check (FA-QEM §3.3)
|
# Flip check (FA-QEM §3.3)
|
||||||
quality_topk_multiplier: int = 4 # quality-check band size = this * max_collapses_per_iter
|
quality_topk_multiplier: int = 4 # quality-check band size = this * max_collapses_per_iter
|
||||||
flip_cos_threshold: float = 0.0 # 0 = count any sign reversal (dihedral > 90°)
|
flip_cos_threshold: float = 0.0 # 0 = count any sign reversal (dihedral > 90°)
|
||||||
flip_check_max_degree: int = 16 # cap on vertex degree for the flip-check table
|
flip_check_max_degree: int = 16 # cap on vertex degree for the flip-check table
|
||||||
@ -71,13 +71,13 @@ class QEMConfig:
|
|||||||
# Triangle shape penalty
|
# Triangle shape penalty
|
||||||
skinny_weight: float = 1e-3 # penalise top-K collapses producing needle/sliver tris; 0 disables
|
skinny_weight: float = 1e-3 # penalise top-K collapses producing needle/sliver tris; 0 disables
|
||||||
|
|
||||||
# Topology preservation
|
# Topology preservation
|
||||||
enforce_link_condition: bool = True # reject collapses that violate the link condition
|
enforce_link_condition: bool = True # reject collapses that violate the link condition
|
||||||
|
|
||||||
# Quadric area weighting
|
# Quadric area weighting
|
||||||
area_weighted_quadrics: bool = False # True: Garland-Heckbert area-weighted; False: un-weighted
|
area_weighted_quadrics: bool = False # True: Garland-Heckbert area-weighted; False: un-weighted
|
||||||
|
|
||||||
# edge-length cost regularizer
|
# edge-length cost regularizer
|
||||||
lambda_edge_length: float = 1e-2 # add λ*len² to bias toward short edges; 0 disables
|
lambda_edge_length: float = 1e-2 # add λ*len² to bias toward short edges; 0 disables
|
||||||
lambda_edge_length_absolute: bool = True # apply λ absolutely vs relative-to-QEM-median
|
lambda_edge_length_absolute: bool = True # apply λ absolutely vs relative-to-QEM-median
|
||||||
|
|
||||||
@ -87,16 +87,16 @@ class QEMConfig:
|
|||||||
memoryless_qem: bool = True # rebuild quadrics each round vs accumulate
|
memoryless_qem: bool = True # rebuild quadrics each round vs accumulate
|
||||||
repair_nonmanifold: bool = True # final repair_non_manifold_edges pass
|
repair_nonmanifold: bool = True # final repair_non_manifold_edges pass
|
||||||
|
|
||||||
# Pre-clean (input mesh)
|
# Pre-clean (input mesh)
|
||||||
preclean: bool = True # weld coincident verts, drop degenerate/duplicate/unused
|
preclean: bool = True # weld coincident verts, drop degenerate/duplicate/unused
|
||||||
|
|
||||||
# Post-clean (output mesh)
|
# Post-clean (output mesh)
|
||||||
postclean: bool = True # remove slivers, tiny components, unused verts left by collapse
|
postclean: bool = True # remove slivers, tiny components, unused verts left by collapse
|
||||||
postclean_min_angle_deg: float = 0.5
|
postclean_min_angle_deg: float = 0.5
|
||||||
postclean_max_aspect_ratio: float = 100.0
|
postclean_max_aspect_ratio: float = 100.0
|
||||||
postclean_min_component_faces: int = 8 # drop components with fewer faces than this
|
postclean_min_component_faces: int = 8 # drop components with fewer faces than this
|
||||||
|
|
||||||
# Preclean tuning
|
# Preclean tuning
|
||||||
preclean_weld_epsilon_rel: float = 1e-5 # weld tolerance as fraction of bbox diagonal
|
preclean_weld_epsilon_rel: float = 1e-5 # weld tolerance as fraction of bbox diagonal
|
||||||
preclean_min_component_faces: int = 0 # 0 = keep all components
|
preclean_min_component_faces: int = 0 # 0 = keep all components
|
||||||
|
|
||||||
@ -1631,8 +1631,10 @@ def qem_decimate_simplify(
|
|||||||
out_v.append(v)
|
out_v.append(v)
|
||||||
out_f.append(f)
|
out_f.append(f)
|
||||||
out_s.append(s)
|
out_s.append(s)
|
||||||
if c is not None: out_c.append(c)
|
if c is not None:
|
||||||
if n is not None: out_n.append(n)
|
out_c.append(c)
|
||||||
|
if n is not None:
|
||||||
|
out_n.append(n)
|
||||||
return (out_v, out_f,
|
return (out_v, out_f,
|
||||||
out_c if out_c else None,
|
out_c if out_c else None,
|
||||||
out_n if out_n else None,
|
out_n if out_n else None,
|
||||||
|
|||||||
@ -159,7 +159,7 @@ def _build_tri_spatial_hash(centroids: torch.Tensor, tri_radii: torch.Tensor,
|
|||||||
local = torch.arange(total, device=device) - cum[rep]
|
local = torch.arange(total, device=device) - cum[rep]
|
||||||
sx = spans[rep, 0]
|
sx = spans[rep, 0]
|
||||||
sy = spans[rep, 1]
|
sy = spans[rep, 1]
|
||||||
sz = spans[rep, 2]
|
|
||||||
lx = local % sx
|
lx = local % sx
|
||||||
ly = (local // sx) % sy
|
ly = (local // sx) % sy
|
||||||
lz = local // (sx * sy)
|
lz = local // (sx * sy)
|
||||||
@ -696,7 +696,6 @@ def _filter_components(verts: torch.Tensor, faces: torch.Tensor,
|
|||||||
"""Drop tiny / inverted-volume / bbox-enclosed connected components; returns filtered faces."""
|
"""Drop tiny / inverted-volume / bbox-enclosed connected components; returns filtered faces."""
|
||||||
device = faces.device
|
device = faces.device
|
||||||
V = verts.shape[0]
|
V = verts.shape[0]
|
||||||
F = faces.shape[0]
|
|
||||||
|
|
||||||
# Connected components via min-label propagation across faces (200-iter max)
|
# Connected components via min-label propagation across faces (200-iter max)
|
||||||
label = torch.arange(V, dtype=torch.long, device=device)
|
label = torch.arange(V, dtype=torch.long, device=device)
|
||||||
@ -1090,7 +1089,9 @@ def remesh_narrow_band_dc(
|
|||||||
safe_tri = closest_tri.clamp(min=0)
|
safe_tri = closest_tri.clamp(min=0)
|
||||||
tri_v_idx = faces[safe_tri].long() # (N, 3)
|
tri_v_idx = faces[safe_tri].long() # (N, 3)
|
||||||
tri_v = vertices[tri_v_idx] # (N, 3, 3)
|
tri_v = vertices[tri_v_idx] # (N, 3, 3)
|
||||||
v0 = tri_v[:, 0]; v1 = tri_v[:, 1]; v2 = tri_v[:, 2]
|
v0 = tri_v[:, 0]
|
||||||
|
v1 = tri_v[:, 1]
|
||||||
|
v2 = tri_v[:, 2]
|
||||||
e0 = v1 - v0
|
e0 = v1 - v0
|
||||||
e1 = v2 - v0
|
e1 = v2 - v0
|
||||||
e2 = closest_pts - v0
|
e2 = closest_pts - v0
|
||||||
|
|||||||
@ -39,11 +39,15 @@ def _best_rotation_jit(uvs_np: np.ndarray, n_angles: int) -> float:
|
|||||||
half_pi = math.pi * 0.5
|
half_pi = math.pi * 0.5
|
||||||
for k in range(n_angles):
|
for k in range(n_angles):
|
||||||
theta = half_pi * k / n_angles
|
theta = half_pi * k / n_angles
|
||||||
c = math.cos(theta); s = math.sin(theta)
|
c = math.cos(theta)
|
||||||
xmin = 1e30; xmax = -1e30
|
s = math.sin(theta)
|
||||||
ymin = 1e30; ymax = -1e30
|
xmin = 1e30
|
||||||
|
xmax = -1e30
|
||||||
|
ymin = 1e30
|
||||||
|
ymax = -1e30
|
||||||
for i in range(V):
|
for i in range(V):
|
||||||
ux = uvs_np[i, 0]; uy = uvs_np[i, 1]
|
ux = uvs_np[i, 0]
|
||||||
|
uy = uvs_np[i, 1]
|
||||||
xr = ux * c - uy * s
|
xr = ux * c - uy * s
|
||||||
yr = ux * s + uy * c
|
yr = ux * s + uy * c
|
||||||
if xr < xmin: xmin = xr
|
if xr < xmin: xmin = xr
|
||||||
@ -78,10 +82,15 @@ def _rasterize_chart_jit(
|
|||||||
F = faces.shape[0]
|
F = faces.shape[0]
|
||||||
eps = 1e-7
|
eps = 1e-7
|
||||||
for fi in range(F):
|
for fi in range(F):
|
||||||
i0 = faces[fi, 0]; i1 = faces[fi, 1]; i2 = faces[fi, 2]
|
i0 = faces[fi, 0]
|
||||||
x0 = uvs_tex[i0, 0]; y0 = uvs_tex[i0, 1]
|
i1 = faces[fi, 1]
|
||||||
x1 = uvs_tex[i1, 0]; y1 = uvs_tex[i1, 1]
|
i2 = faces[fi, 2]
|
||||||
x2 = uvs_tex[i2, 0]; y2 = uvs_tex[i2, 1]
|
x0 = uvs_tex[i0, 0]
|
||||||
|
y0 = uvs_tex[i0, 1]
|
||||||
|
x1 = uvs_tex[i1, 0]
|
||||||
|
y1 = uvs_tex[i1, 1]
|
||||||
|
x2 = uvs_tex[i2, 0]
|
||||||
|
y2 = uvs_tex[i2, 1]
|
||||||
xmin_f = x0
|
xmin_f = x0
|
||||||
if x1 < xmin_f: xmin_f = x1
|
if x1 < xmin_f: xmin_f = x1
|
||||||
if x2 < xmin_f: xmin_f = x2
|
if x2 < xmin_f: xmin_f = x2
|
||||||
@ -172,19 +181,25 @@ def _build_candidates_jit(
|
|||||||
for xs in range(x, x_end):
|
for xs in range(x, x_end):
|
||||||
if skyline[xs] > y:
|
if skyline[xs] > y:
|
||||||
y = int(skyline[xs])
|
y = int(skyline[xs])
|
||||||
out[k, 0] = x; out[k, 1] = y; out[k, 2] = swap_flag
|
out[k, 0] = x
|
||||||
|
out[k, 1] = y
|
||||||
|
out[k, 2] = swap_flag
|
||||||
k += 1
|
k += 1
|
||||||
x += step
|
x += step
|
||||||
for y_fixed in (0, cur_h):
|
for y_fixed in (0, cur_h):
|
||||||
x = 0
|
x = 0
|
||||||
while x <= cur_w:
|
while x <= cur_w:
|
||||||
out[k, 0] = x; out[k, 1] = y_fixed; out[k, 2] = swap_flag
|
out[k, 0] = x
|
||||||
|
out[k, 1] = y_fixed
|
||||||
|
out[k, 2] = swap_flag
|
||||||
k += 1
|
k += 1
|
||||||
x += step
|
x += step
|
||||||
for x_fixed in (0, cur_w):
|
for x_fixed in (0, cur_w):
|
||||||
y = 0
|
y = 0
|
||||||
while y <= cur_h:
|
while y <= cur_h:
|
||||||
out[k, 0] = x_fixed; out[k, 1] = y; out[k, 2] = swap_flag
|
out[k, 0] = x_fixed
|
||||||
|
out[k, 1] = y
|
||||||
|
out[k, 2] = swap_flag
|
||||||
k += 1
|
k += 1
|
||||||
y += step
|
y += step
|
||||||
return out[:k]
|
return out[:k]
|
||||||
@ -194,7 +209,8 @@ def _build_candidates_jit(
|
|||||||
def _update_skyline_jit(skyline: np.ndarray, chart: np.ndarray,
|
def _update_skyline_jit(skyline: np.ndarray, chart: np.ndarray,
|
||||||
x: int, y: int) -> None:
|
x: int, y: int) -> None:
|
||||||
"""Lift skyline[x+i] to y + topmost_True_row + 1 per chart column."""
|
"""Lift skyline[x+i] to y + topmost_True_row + 1 per chart column."""
|
||||||
ch = chart.shape[0]; cw = chart.shape[1]
|
ch = chart.shape[0]
|
||||||
|
cw = chart.shape[1]
|
||||||
sw = skyline.shape[0]
|
sw = skyline.shape[0]
|
||||||
for i in range(cw):
|
for i in range(cw):
|
||||||
col_x = x + i
|
col_x = x + i
|
||||||
@ -227,17 +243,22 @@ def _best_placement_jit(
|
|||||||
best_y = -1
|
best_y = -1
|
||||||
best_score = -1
|
best_score = -1
|
||||||
best_swap = 0
|
best_swap = 0
|
||||||
bh0 = bitmap.shape[0]; bw0 = bitmap.shape[1]
|
bh0 = bitmap.shape[0]
|
||||||
bh1 = bitmap_rot.shape[0]; bw1 = bitmap_rot.shape[1]
|
bw0 = bitmap.shape[1]
|
||||||
ah = atlas.shape[0]; aw = atlas.shape[1]
|
bh1 = bitmap_rot.shape[0]
|
||||||
|
bw1 = bitmap_rot.shape[1]
|
||||||
|
ah = atlas.shape[0]
|
||||||
|
aw = atlas.shape[1]
|
||||||
for k in range(n):
|
for k in range(n):
|
||||||
x = candidates[k, 0]
|
x = candidates[k, 0]
|
||||||
y = candidates[k, 1]
|
y = candidates[k, 1]
|
||||||
swap = candidates[k, 2]
|
swap = candidates[k, 2]
|
||||||
if swap == 0:
|
if swap == 0:
|
||||||
ch = bh0; cw = bw0
|
ch = bh0
|
||||||
|
cw = bw0
|
||||||
else:
|
else:
|
||||||
ch = bh1; cw = bw1
|
ch = bh1
|
||||||
|
cw = bw1
|
||||||
if x < 0 or y < 0:
|
if x < 0 or y < 0:
|
||||||
continue
|
continue
|
||||||
nw = cur_w if cur_w > x + cw else x + cw
|
nw = cur_w if cur_w > x + cw else x + cw
|
||||||
@ -265,8 +286,10 @@ def _best_placement_jit(
|
|||||||
break
|
break
|
||||||
if not ok:
|
if not ok:
|
||||||
continue
|
continue
|
||||||
best_x = x; best_y = y
|
best_x = x
|
||||||
best_score = score; best_swap = swap
|
best_y = y
|
||||||
|
best_score = score
|
||||||
|
best_swap = swap
|
||||||
if x + cw <= cur_w and y + ch <= cur_h:
|
if x + cw <= cur_w and y + ch <= cur_h:
|
||||||
break
|
break
|
||||||
return best_x, best_y, best_score, best_swap
|
return best_x, best_y, best_score, best_swap
|
||||||
@ -330,8 +353,10 @@ def _dilate_local(x: Tensor, p: int) -> Tensor:
|
|||||||
dilation OR-scattered equals dilating the assembled chart bitmap."""
|
dilation OR-scattered equals dilating the assembled chart bitmap."""
|
||||||
for _ in range(p):
|
for _ in range(p):
|
||||||
y = x.clone()
|
y = x.clone()
|
||||||
y[:, 1:, :] |= x[:, :-1, :]; y[:, :-1, :] |= x[:, 1:, :]
|
y[:, 1:, :] |= x[:, :-1, :]
|
||||||
y[:, :, 1:] |= x[:, :, :-1]; y[:, :, :-1] |= x[:, :, 1:]
|
y[:, :-1, :] |= x[:, 1:, :]
|
||||||
|
y[:, :, 1:] |= x[:, :, :-1]
|
||||||
|
y[:, :, :-1] |= x[:, :, 1:]
|
||||||
x = y
|
x = y
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -358,7 +383,8 @@ def _raster_all_torch(uvs_tex_pad, faces_pad, fmask, bw_t, bh_t, padding, device
|
|||||||
cid = torch.arange(n, device=device).repeat_interleave(fmax)[fm]
|
cid = torch.arange(n, device=device).repeat_interleave(fmax)[fm]
|
||||||
|
|
||||||
# per-triangle pixel bbox, inflated by padding (origin >= 0); bucket by next-pow2 max-dim
|
# per-triangle pixel bbox, inflated by padding (origin >= 0); bucket by next-pow2 max-dim
|
||||||
tmin = tri_f.amin(1); tmax = tri_f.amax(1)
|
tmin = tri_f.amin(1)
|
||||||
|
tmax = tri_f.amax(1)
|
||||||
x0 = (tmin[:, 0].floor().long() - padding).clamp_min(0)
|
x0 = (tmin[:, 0].floor().long() - padding).clamp_min(0)
|
||||||
y0 = (tmin[:, 1].floor().long() - padding).clamp_min(0)
|
y0 = (tmin[:, 1].floor().long() - padding).clamp_min(0)
|
||||||
bbw = (tmax[:, 0].ceil().long() + padding) - x0 + 1
|
bbw = (tmax[:, 0].ceil().long() + padding) - x0 + 1
|
||||||
@ -366,20 +392,31 @@ def _raster_all_torch(uvs_tex_pad, faces_pad, fmask, bw_t, bh_t, padding, device
|
|||||||
mxd = torch.maximum(bbw, bbh).clamp_min(1)
|
mxd = torch.maximum(bbw, bbh).clamp_min(1)
|
||||||
bsz = (2 ** torch.ceil(torch.log2(mxd.float())).long()).long()
|
bsz = (2 ** torch.ceil(torch.log2(mxd.float())).long()).long()
|
||||||
|
|
||||||
a = tri_f[:, 0]; b = tri_f[:, 1]; c = tri_f[:, 2]
|
a = tri_f[:, 0]
|
||||||
v0 = b - a; v1 = c - a
|
b = tri_f[:, 1]
|
||||||
d00 = (v0 * v0).sum(-1); d01 = (v0 * v1).sum(-1); d11 = (v1 * v1).sum(-1)
|
c = tri_f[:, 2]
|
||||||
|
v0 = b - a
|
||||||
|
v1 = c - a
|
||||||
|
d00 = (v0 * v0).sum(-1)
|
||||||
|
d01 = (v0 * v1).sum(-1)
|
||||||
|
d11 = (v1 * v1).sum(-1)
|
||||||
den = (d00 * d11 - d01 * d01).clamp(min=1e-20)
|
den = (d00 * d11 - d01 * d01).clamp(min=1e-20)
|
||||||
|
|
||||||
for g in sorted(set(bsz.tolist())): # one batch per pow2 grid
|
for g in sorted(set(bsz.tolist())): # one batch per pow2 grid
|
||||||
sel = (bsz == g).nonzero(as_tuple=True)[0]
|
sel = (bsz == g).nonzero(as_tuple=True)[0]
|
||||||
m = sel.shape[0]
|
m = sel.shape[0]
|
||||||
xs0 = x0[sel].view(m, 1, 1); ys0 = y0[sel].view(m, 1, 1)
|
xs0 = x0[sel].view(m, 1, 1)
|
||||||
cc = cid[sel]; bwp = bwL[cc].view(m, 1, 1); bhp = bhL[cc].view(m, 1, 1)
|
ys0 = y0[sel].view(m, 1, 1)
|
||||||
|
cc = cid[sel]
|
||||||
|
bwp = bwL[cc].view(m, 1, 1)
|
||||||
|
bhp = bhL[cc].view(m, 1, 1)
|
||||||
gi = torch.arange(g, device=device)
|
gi = torch.arange(g, device=device)
|
||||||
px = xs0 + gi.view(1, 1, g); py = ys0 + gi.view(1, g, 1) # (m,g,g) int
|
px = xs0 + gi.view(1, 1, g)
|
||||||
pxf = px.float() + 0.5; pyf = py.float() + 0.5
|
py = ys0 + gi.view(1, g, 1) # (m,g,g) int
|
||||||
v2x = pxf - a[sel, 0].view(m, 1, 1); v2y = pyf - a[sel, 1].view(m, 1, 1)
|
pxf = px.float() + 0.5
|
||||||
|
pyf = py.float() + 0.5
|
||||||
|
v2x = pxf - a[sel, 0].view(m, 1, 1)
|
||||||
|
v2y = pyf - a[sel, 1].view(m, 1, 1)
|
||||||
d20 = v2x * v0[sel, 0].view(m, 1, 1) + v2y * v0[sel, 1].view(m, 1, 1)
|
d20 = v2x * v0[sel, 0].view(m, 1, 1) + v2y * v0[sel, 1].view(m, 1, 1)
|
||||||
d21 = v2x * v1[sel, 0].view(m, 1, 1) + v2y * v1[sel, 1].view(m, 1, 1)
|
d21 = v2x * v1[sel, 0].view(m, 1, 1) + v2y * v1[sel, 1].view(m, 1, 1)
|
||||||
idn = den[sel].view(m, 1, 1).reciprocal()
|
idn = den[sel].view(m, 1, 1).reciprocal()
|
||||||
@ -437,7 +474,8 @@ def _best_placement_torch(atlas, pix0, dim0, pix1, dim1, cand0, cand1, cur_w, cu
|
|||||||
cx, cy = cand[:, 0], cand[:, 1]
|
cx, cy = cand[:, 0], cand[:, 1]
|
||||||
coll = atlas[cy[:, None] + pix[:, 0][None, :], # (M,k) True-pixel gather
|
coll = atlas[cy[:, None] + pix[:, 0][None, :], # (M,k) True-pixel gather
|
||||||
cx[:, None] + pix[:, 1][None, :]].any(dim=1)
|
cx[:, None] + pix[:, 1][None, :]].any(dim=1)
|
||||||
nw = torch.clamp(cx + cw, min=cur_w); nh = torch.clamp(cy + ch, min=cur_h)
|
nw = torch.clamp(cx + cw, min=cur_w)
|
||||||
|
nh = torch.clamp(cy + ch, min=cur_h)
|
||||||
ext = torch.maximum(nw, nh)
|
ext = torch.maximum(nw, nh)
|
||||||
score = torch.where(coll, torch.full_like(nw, INF), ext * ext + nw * nh)
|
score = torch.where(coll, torch.full_like(nw, INF), ext * ext + nw * nh)
|
||||||
j = score.argmin()
|
j = score.argmin()
|
||||||
@ -465,7 +503,8 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces,
|
|||||||
# ---- Prepare pass 1: best-rotation + scale + bbox for ALL charts at once (batched) ----
|
# ---- Prepare pass 1: best-rotation + scale + bbox for ALL charts at once (batched) ----
|
||||||
vcount = [int(u.shape[0]) for u in chart_uvs]
|
vcount = [int(u.shape[0]) for u in chart_uvs]
|
||||||
fcount = [int(f.shape[0]) for f in chart_faces]
|
fcount = [int(f.shape[0]) for f in chart_faces]
|
||||||
vmax = max(vcount); fmax = max(fcount)
|
vmax = max(vcount)
|
||||||
|
fmax = max(fcount)
|
||||||
uvs_pad = torch.zeros(n, vmax, 2, device=device)
|
uvs_pad = torch.zeros(n, vmax, 2, device=device)
|
||||||
vmask = torch.zeros(n, vmax, dtype=torch.bool, device=device)
|
vmask = torch.zeros(n, vmax, dtype=torch.bool, device=device)
|
||||||
faces_pad = torch.zeros(n, fmax, 3, dtype=torch.long, device=device)
|
faces_pad = torch.zeros(n, fmax, 3, dtype=torch.long, device=device)
|
||||||
@ -488,8 +527,10 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces,
|
|||||||
cc, ss = cos_a[ti][:, None], sin_a[ti][:, None] # (N,1)
|
cc, ss = cos_a[ti][:, None], sin_a[ti][:, None] # (N,1)
|
||||||
rx = torch.addcmul(u0 * cc, u1, ss, value=-1) # (N,Vmax)
|
rx = torch.addcmul(u0 * cc, u1, ss, value=-1) # (N,Vmax)
|
||||||
ry = torch.addcmul(u0 * ss, u1, cc)
|
ry = torch.addcmul(u0 * ss, u1, cc)
|
||||||
rxmin = (rx + mlo).amin(1); rxmax = (rx + mhi).amax(1) # (N,)
|
rxmin = (rx + mlo).amin(1) # (N,)
|
||||||
rymin = (ry + mlo).amin(1); rymax = (ry + mhi).amax(1)
|
rxmax = (rx + mhi).amax(1)
|
||||||
|
rymin = (ry + mlo).amin(1)
|
||||||
|
rymax = (ry + mhi).amax(1)
|
||||||
a3 = torch.tensor([max(a, 1e-12) for a in chart_3d_areas], device=device)
|
a3 = torch.tensor([max(a, 1e-12) for a in chart_3d_areas], device=device)
|
||||||
au = torch.tensor([max(a, 1e-12) for a in chart_uv_areas], device=device)
|
au = torch.tensor([max(a, 1e-12) for a in chart_uv_areas], device=device)
|
||||||
base = (a3 / au).sqrt() * texels_per_unit
|
base = (a3 / au).sqrt() * texels_per_unit
|
||||||
@ -504,7 +545,8 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces,
|
|||||||
# one sync: pull all per-chart scalars
|
# one sync: pull all per-chart scalars
|
||||||
thetas = ang[ti].cpu().tolist()
|
thetas = ang[ti].cpu().tolist()
|
||||||
scales = scale.cpu().tolist()
|
scales = scale.cpu().tolist()
|
||||||
bws = bw_t.cpu().tolist(); bhs = bh_t.cpu().tolist()
|
bws = bw_t.cpu().tolist()
|
||||||
|
bhs = bh_t.cpu().tolist()
|
||||||
|
|
||||||
# ---- Prepare pass 2: rasterize ALL charts at once, then trim each bitmap to its bounds ----
|
# ---- Prepare pass 2: rasterize ALL charts at once, then trim each bitmap to its bounds ----
|
||||||
buf, cbase = _raster_all_torch(uvs_tex_pad, faces_pad, fmask, bw_t, bh_t, padding_texels, device)
|
buf, cbase = _raster_all_torch(uvs_tex_pad, faces_pad, fmask, bw_t, bh_t, padding_texels, device)
|
||||||
@ -513,7 +555,8 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces,
|
|||||||
for i in range(n):
|
for i in range(n):
|
||||||
bm = buf[cb[i]:cb[i + 1]].view(bhs[i], bws[i])
|
bm = buf[cb[i]:cb[i + 1]].view(bhs[i], bws[i])
|
||||||
raw.append(bm)
|
raw.append(bm)
|
||||||
rr = torch.arange(bm.shape[0], device=device); cc = torch.arange(bm.shape[1], device=device)
|
rr = torch.arange(bm.shape[0], device=device)
|
||||||
|
cc = torch.arange(bm.shape[1], device=device)
|
||||||
rmax = torch.where(bm.any(1), rr, rr.new_full((), -1)).amax() # last occupied row / col (-1 if empty)
|
rmax = torch.where(bm.any(1), rr, rr.new_full((), -1)).amax() # last occupied row / col (-1 if empty)
|
||||||
cmax = torch.where(bm.any(0), cc, cc.new_full((), -1)).amax()
|
cmax = torch.where(bm.any(0), cc, cc.new_full((), -1)).amax()
|
||||||
bnd.append(torch.stack([rmax, cmax]))
|
bnd.append(torch.stack([rmax, cmax]))
|
||||||
@ -527,10 +570,12 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces,
|
|||||||
bm = (raw[i][:rm + 1, :cm + 1].contiguous() if rm >= 0 and cm >= 0
|
bm = (raw[i][:rm + 1, :cm + 1].contiguous() if rm >= 0 and cm >= 0
|
||||||
else torch.zeros((1, 1), dtype=torch.bool, device=device))
|
else torch.zeros((1, 1), dtype=torch.bool, device=device))
|
||||||
bm_rot = torch.flip(bm.t(), dims=[1]).contiguous()
|
bm_rot = torch.flip(bm.t(), dims=[1]).contiguous()
|
||||||
pix_l.append(bm.nonzero()); pixr_l.append(bm_rot.nonzero())
|
pix_l.append(bm.nonzero())
|
||||||
|
pixr_l.append(bm_rot.nonzero())
|
||||||
dim_l.append((int(bm.shape[0]), int(bm.shape[1])))
|
dim_l.append((int(bm.shape[0]), int(bm.shape[1])))
|
||||||
dimr_l.append((int(bm_rot.shape[0]), int(bm_rot.shape[1])))
|
dimr_l.append((int(bm_rot.shape[0]), int(bm_rot.shape[1])))
|
||||||
col_tops.append(_col_top(bm)); col_tops_rot.append(_col_top(bm_rot))
|
col_tops.append(_col_top(bm))
|
||||||
|
col_tops_rot.append(_col_top(bm_rot))
|
||||||
bm_h.append(int(bm.shape[0]))
|
bm_h.append(int(bm.shape[0]))
|
||||||
wmax = max(d[1] for d in dim_l + dimr_l)
|
wmax = max(d[1] for d in dim_l + dimr_l)
|
||||||
ct_pad = torch.full((n, wmax), -1, dtype=torch.long, device=device)
|
ct_pad = torch.full((n, wmax), -1, dtype=torch.long, device=device)
|
||||||
@ -557,8 +602,11 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces,
|
|||||||
if cur_h + margin > atlas.shape[0] or cur_w + margin > atlas.shape[1]:
|
if cur_h + margin > atlas.shape[0] or cur_w + margin > atlas.shape[1]:
|
||||||
ns = max(atlas.shape[0], cur_h + margin, cur_w + margin)
|
ns = max(atlas.shape[0], cur_h + margin, cur_w + margin)
|
||||||
na = torch.zeros((ns, ns), dtype=torch.bool, device=device)
|
na = torch.zeros((ns, ns), dtype=torch.bool, device=device)
|
||||||
na[:atlas.shape[0], :atlas.shape[1]] = atlas; atlas = na
|
na[:atlas.shape[0], :atlas.shape[1]] = atlas
|
||||||
nsk = torch.zeros(ns, dtype=torch.long, device=device); nsk[:sky_t.shape[0]] = sky_t; sky_t = nsk
|
atlas = na
|
||||||
|
nsk = torch.zeros(ns, dtype=torch.long, device=device)
|
||||||
|
nsk[:sky_t.shape[0]] = sky_t
|
||||||
|
sky_t = nsk
|
||||||
dim, dimr = dim_l[ci], dimr_l[ci]
|
dim, dimr = dim_l[ci], dimr_l[ci]
|
||||||
step = max(1, min(dim[0], dim[1]) // 8)
|
step = max(1, min(dim[0], dim[1]) // 8)
|
||||||
cand0, cand1 = _build_candidates_gpu(sky_t, cur_w, cur_h, dim[1], dimr[1], step, rand_n, gen, device)
|
cand0, cand1 = _build_candidates_gpu(sky_t, cur_w, cur_h, dim[1], dimr[1], step, rand_n, gen, device)
|
||||||
@ -569,7 +617,8 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces,
|
|||||||
pix = pixr_l[ci] if swap else pix_l[ci]
|
pix = pixr_l[ci] if swap else pix_l[ci]
|
||||||
bh_, bw_ = (dimr if swap else dim)
|
bh_, bw_ = (dimr if swap else dim)
|
||||||
atlas[by + pix[:, 0], bx + pix[:, 1]] = True # sparse blit
|
atlas[by + pix[:, 0], bx + pix[:, 1]] = True # sparse blit
|
||||||
cur_w = max(cur_w, bx + bw_); cur_h = max(cur_h, by + bh_)
|
cur_w = max(cur_w, bx + bw_)
|
||||||
|
cur_h = max(cur_h, by + bh_)
|
||||||
ct = (ctr_pad if swap else ct_pad)[ci, :bw_] # GPU skyline lift
|
ct = (ctr_pad if swap else ct_pad)[ci, :bw_] # GPU skyline lift
|
||||||
ix = torch.arange(bx, bx + bw_, device=device)
|
ix = torch.arange(bx, bx + bw_, device=device)
|
||||||
sky_t[ix] = torch.where(ct >= 0, torch.maximum(sky_t[ix], by + ct + 1), sky_t[ix])
|
sky_t[ix] = torch.where(ct >= 0, torch.maximum(sky_t[ix], by + ct + 1), sky_t[ix])
|
||||||
|
|||||||
@ -33,13 +33,17 @@ def _face_curvature_jit(face_normal: np.ndarray, face_face: np.ndarray) -> np.nd
|
|||||||
F = face_normal.shape[0]
|
F = face_normal.shape[0]
|
||||||
raw = np.zeros(F, dtype=np.float32)
|
raw = np.zeros(F, dtype=np.float32)
|
||||||
for f in range(F):
|
for f in range(F):
|
||||||
nx = face_normal[f, 0]; ny = face_normal[f, 1]; nz = face_normal[f, 2]
|
nx = face_normal[f, 0]
|
||||||
|
ny = face_normal[f, 1]
|
||||||
|
nz = face_normal[f, 2]
|
||||||
s = np.float32(0.0)
|
s = np.float32(0.0)
|
||||||
for e in range(3):
|
for e in range(3):
|
||||||
nb = face_face[f, e]
|
nb = face_face[f, e]
|
||||||
if nb < 0:
|
if nb < 0:
|
||||||
continue
|
continue
|
||||||
mx = face_normal[nb, 0]; my = face_normal[nb, 1]; mz = face_normal[nb, 2]
|
mx = face_normal[nb, 0]
|
||||||
|
my = face_normal[nb, 1]
|
||||||
|
mz = face_normal[nb, 2]
|
||||||
d = nx*mx + ny*my + nz*mz
|
d = nx*mx + ny*my + nz*mz
|
||||||
s += np.float32(1.0) - d
|
s += np.float32(1.0) - d
|
||||||
raw[f] = s
|
raw[f] = s
|
||||||
@ -70,7 +74,9 @@ def _farthest_point_seeds_jit(
|
|||||||
continue
|
continue
|
||||||
seeds[n_seeds] = s
|
seeds[n_seeds] = s
|
||||||
n_seeds += 1
|
n_seeds += 1
|
||||||
sx = face_centroid[s, 0]; sy = face_centroid[s, 1]; sz = face_centroid[s, 2]
|
sx = face_centroid[s, 0]
|
||||||
|
sy = face_centroid[s, 1]
|
||||||
|
sz = face_centroid[s, 2]
|
||||||
for f in range(F):
|
for f in range(F):
|
||||||
dx = face_centroid[f, 0] - sx
|
dx = face_centroid[f, 0] - sx
|
||||||
dy = face_centroid[f, 1] - sy
|
dy = face_centroid[f, 1] - sy
|
||||||
@ -145,7 +151,9 @@ def _cost_grow_iter_jit(
|
|||||||
for f in range(F):
|
for f in range(F):
|
||||||
if face_chart[f] != -1:
|
if face_chart[f] != -1:
|
||||||
continue
|
continue
|
||||||
nx = face_normal[f, 0]; ny = face_normal[f, 1]; nz = face_normal[f, 2]
|
nx = face_normal[f, 0]
|
||||||
|
ny = face_normal[f, 1]
|
||||||
|
nz = face_normal[f, 2]
|
||||||
af = face_area[f]
|
af = face_area[f]
|
||||||
for e0 in range(3):
|
for e0 in range(3):
|
||||||
nb0 = face_face[f, e0]
|
nb0 = face_face[f, e0]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user