From 2bbf53e8fc6c2b58d5d5b73f3c5617345380b28d Mon Sep 17 00:00:00 2001 From: kijai Date: Wed, 1 Jul 2026 21:21:11 +0300 Subject: [PATCH] Linting --- comfy/ldm/trellis2/vae.py | 4 +- .../mesh3d/postprocess/qem_decimate.py | 26 ++-- comfy_extras/mesh3d/postprocess/remesh.py | 7 +- comfy_extras/mesh3d/uv_unwrap/pack.py | 131 ++++++++++++------ comfy_extras/mesh3d/uv_unwrap/segment.py | 16 ++- 5 files changed, 122 insertions(+), 62 deletions(-) diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 2f7adf0fe..7f765fc6e 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -952,13 +952,13 @@ def flexible_dual_grid_to_mesh( values = torch.arange(N, dtype=torch.int32, device=device) torch_hashmap = TorchHashMap(flat_keys, values) - # Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3] + # Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3] n_idx, axis_idx = intersected_flag.nonzero(as_tuple=True) # (M,), (M,) offsets_per_axis = flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset[0] # (3, 4, 3) connected_voxel = coords[n_idx].unsqueeze(1) + offsets_per_axis[axis_idx] # (M, 4, 3) M = connected_voxel.shape[0] # flatten connected voxel coords and lookup. In-place to avoid extra memory allocation. - W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) + H, D = int(grid_size[1].item()), int(grid_size[2].item()) cv = connected_voxel.reshape(-1, 3) conn_flat = cv[:, 0].long() * (H * D) conn_flat.add_(cv[:, 1].long() * D) diff --git a/comfy_extras/mesh3d/postprocess/qem_decimate.py b/comfy_extras/mesh3d/postprocess/qem_decimate.py index d382a0661..eae272ef5 100644 --- a/comfy_extras/mesh3d/postprocess/qem_decimate.py +++ b/comfy_extras/mesh3d/postprocess/qem_decimate.py @@ -40,19 +40,19 @@ class QEMConfig: flip_reject_hard: bool = True # hard-reject (err=+inf) top-K collapses that flip any 1-ring normal - # Per-iteration batch sizing + # Per-iteration batch sizing sampling_cap: int = 10_000_000 # max edges processed per outer iter max_collapses_fraction: float = 0.25 # of remaining faces-to-remove max_collapses_floor: int = 10_000 max_collapses_ceiling: int = 1_000_000 max_collapses_relative_cap: float = 0.10 # cap per-iter collapses as fraction of current faces; 0 disables - # Loop control + # Loop control max_iterations: int = 5_000 compaction_period: int = 5 compaction_threshold: float = 0.85 # compact when alive_frac < this - # Quality knobs + # Quality knobs boundary_quadrics: bool = True boundary_weight: float = 1000.0 recompute_normals_post: bool = True @@ -63,7 +63,7 @@ class QEMConfig: feature_edge_quadric_weight: float = 0.0 feature_edge_min_dihedral_deg: float = 30.0 - # Flip check (FA-QEM §3.3) + # Flip check (FA-QEM §3.3) quality_topk_multiplier: int = 4 # quality-check band size = this * max_collapses_per_iter flip_cos_threshold: float = 0.0 # 0 = count any sign reversal (dihedral > 90°) flip_check_max_degree: int = 16 # cap on vertex degree for the flip-check table @@ -71,13 +71,13 @@ class QEMConfig: # Triangle shape penalty skinny_weight: float = 1e-3 # penalise top-K collapses producing needle/sliver tris; 0 disables - # Topology preservation + # Topology preservation enforce_link_condition: bool = True # reject collapses that violate the link condition - # Quadric area weighting + # Quadric area weighting area_weighted_quadrics: bool = False # True: Garland-Heckbert area-weighted; False: un-weighted - # edge-length cost regularizer + # edge-length cost regularizer lambda_edge_length: float = 1e-2 # add λ*len² to bias toward short edges; 0 disables lambda_edge_length_absolute: bool = True # apply λ absolutely vs relative-to-QEM-median @@ -87,16 +87,16 @@ class QEMConfig: memoryless_qem: bool = True # rebuild quadrics each round vs accumulate repair_nonmanifold: bool = True # final repair_non_manifold_edges pass - # Pre-clean (input mesh) + # Pre-clean (input mesh) preclean: bool = True # weld coincident verts, drop degenerate/duplicate/unused - # Post-clean (output mesh) + # Post-clean (output mesh) postclean: bool = True # remove slivers, tiny components, unused verts left by collapse postclean_min_angle_deg: float = 0.5 postclean_max_aspect_ratio: float = 100.0 postclean_min_component_faces: int = 8 # drop components with fewer faces than this - # Preclean tuning + # Preclean tuning preclean_weld_epsilon_rel: float = 1e-5 # weld tolerance as fraction of bbox diagonal preclean_min_component_faces: int = 0 # 0 = keep all components @@ -1631,8 +1631,10 @@ def qem_decimate_simplify( out_v.append(v) out_f.append(f) out_s.append(s) - if c is not None: out_c.append(c) - if n is not None: out_n.append(n) + if c is not None: + out_c.append(c) + if n is not None: + out_n.append(n) return (out_v, out_f, out_c if out_c else None, out_n if out_n else None, diff --git a/comfy_extras/mesh3d/postprocess/remesh.py b/comfy_extras/mesh3d/postprocess/remesh.py index 502c4e9bc..2e20f6254 100644 --- a/comfy_extras/mesh3d/postprocess/remesh.py +++ b/comfy_extras/mesh3d/postprocess/remesh.py @@ -159,7 +159,7 @@ def _build_tri_spatial_hash(centroids: torch.Tensor, tri_radii: torch.Tensor, local = torch.arange(total, device=device) - cum[rep] sx = spans[rep, 0] sy = spans[rep, 1] - sz = spans[rep, 2] + lx = local % sx ly = (local // sx) % sy lz = local // (sx * sy) @@ -696,7 +696,6 @@ def _filter_components(verts: torch.Tensor, faces: torch.Tensor, """Drop tiny / inverted-volume / bbox-enclosed connected components; returns filtered faces.""" device = faces.device V = verts.shape[0] - F = faces.shape[0] # Connected components via min-label propagation across faces (200-iter max) label = torch.arange(V, dtype=torch.long, device=device) @@ -1090,7 +1089,9 @@ def remesh_narrow_band_dc( safe_tri = closest_tri.clamp(min=0) tri_v_idx = faces[safe_tri].long() # (N, 3) tri_v = vertices[tri_v_idx] # (N, 3, 3) - v0 = tri_v[:, 0]; v1 = tri_v[:, 1]; v2 = tri_v[:, 2] + v0 = tri_v[:, 0] + v1 = tri_v[:, 1] + v2 = tri_v[:, 2] e0 = v1 - v0 e1 = v2 - v0 e2 = closest_pts - v0 diff --git a/comfy_extras/mesh3d/uv_unwrap/pack.py b/comfy_extras/mesh3d/uv_unwrap/pack.py index 9ad77257a..53dbec299 100644 --- a/comfy_extras/mesh3d/uv_unwrap/pack.py +++ b/comfy_extras/mesh3d/uv_unwrap/pack.py @@ -39,11 +39,15 @@ def _best_rotation_jit(uvs_np: np.ndarray, n_angles: int) -> float: half_pi = math.pi * 0.5 for k in range(n_angles): theta = half_pi * k / n_angles - c = math.cos(theta); s = math.sin(theta) - xmin = 1e30; xmax = -1e30 - ymin = 1e30; ymax = -1e30 + c = math.cos(theta) + s = math.sin(theta) + xmin = 1e30 + xmax = -1e30 + ymin = 1e30 + ymax = -1e30 for i in range(V): - ux = uvs_np[i, 0]; uy = uvs_np[i, 1] + ux = uvs_np[i, 0] + uy = uvs_np[i, 1] xr = ux * c - uy * s yr = ux * s + uy * c if xr < xmin: xmin = xr @@ -78,10 +82,15 @@ def _rasterize_chart_jit( F = faces.shape[0] eps = 1e-7 for fi in range(F): - i0 = faces[fi, 0]; i1 = faces[fi, 1]; i2 = faces[fi, 2] - x0 = uvs_tex[i0, 0]; y0 = uvs_tex[i0, 1] - x1 = uvs_tex[i1, 0]; y1 = uvs_tex[i1, 1] - x2 = uvs_tex[i2, 0]; y2 = uvs_tex[i2, 1] + i0 = faces[fi, 0] + i1 = faces[fi, 1] + i2 = faces[fi, 2] + x0 = uvs_tex[i0, 0] + y0 = uvs_tex[i0, 1] + x1 = uvs_tex[i1, 0] + y1 = uvs_tex[i1, 1] + x2 = uvs_tex[i2, 0] + y2 = uvs_tex[i2, 1] xmin_f = x0 if x1 < xmin_f: xmin_f = x1 if x2 < xmin_f: xmin_f = x2 @@ -172,19 +181,25 @@ def _build_candidates_jit( for xs in range(x, x_end): if skyline[xs] > y: y = int(skyline[xs]) - out[k, 0] = x; out[k, 1] = y; out[k, 2] = swap_flag + out[k, 0] = x + out[k, 1] = y + out[k, 2] = swap_flag k += 1 x += step for y_fixed in (0, cur_h): x = 0 while x <= cur_w: - out[k, 0] = x; out[k, 1] = y_fixed; out[k, 2] = swap_flag + out[k, 0] = x + out[k, 1] = y_fixed + out[k, 2] = swap_flag k += 1 x += step for x_fixed in (0, cur_w): y = 0 while y <= cur_h: - out[k, 0] = x_fixed; out[k, 1] = y; out[k, 2] = swap_flag + out[k, 0] = x_fixed + out[k, 1] = y + out[k, 2] = swap_flag k += 1 y += step return out[:k] @@ -194,7 +209,8 @@ def _build_candidates_jit( def _update_skyline_jit(skyline: np.ndarray, chart: np.ndarray, x: int, y: int) -> None: """Lift skyline[x+i] to y + topmost_True_row + 1 per chart column.""" - ch = chart.shape[0]; cw = chart.shape[1] + ch = chart.shape[0] + cw = chart.shape[1] sw = skyline.shape[0] for i in range(cw): col_x = x + i @@ -227,17 +243,22 @@ def _best_placement_jit( best_y = -1 best_score = -1 best_swap = 0 - bh0 = bitmap.shape[0]; bw0 = bitmap.shape[1] - bh1 = bitmap_rot.shape[0]; bw1 = bitmap_rot.shape[1] - ah = atlas.shape[0]; aw = atlas.shape[1] + bh0 = bitmap.shape[0] + bw0 = bitmap.shape[1] + bh1 = bitmap_rot.shape[0] + bw1 = bitmap_rot.shape[1] + ah = atlas.shape[0] + aw = atlas.shape[1] for k in range(n): x = candidates[k, 0] y = candidates[k, 1] swap = candidates[k, 2] if swap == 0: - ch = bh0; cw = bw0 + ch = bh0 + cw = bw0 else: - ch = bh1; cw = bw1 + ch = bh1 + cw = bw1 if x < 0 or y < 0: continue nw = cur_w if cur_w > x + cw else x + cw @@ -265,8 +286,10 @@ def _best_placement_jit( break if not ok: continue - best_x = x; best_y = y - best_score = score; best_swap = swap + best_x = x + best_y = y + best_score = score + best_swap = swap if x + cw <= cur_w and y + ch <= cur_h: break return best_x, best_y, best_score, best_swap @@ -330,8 +353,10 @@ def _dilate_local(x: Tensor, p: int) -> Tensor: dilation OR-scattered equals dilating the assembled chart bitmap.""" for _ in range(p): y = x.clone() - y[:, 1:, :] |= x[:, :-1, :]; y[:, :-1, :] |= x[:, 1:, :] - y[:, :, 1:] |= x[:, :, :-1]; y[:, :, :-1] |= x[:, :, 1:] + y[:, 1:, :] |= x[:, :-1, :] + y[:, :-1, :] |= x[:, 1:, :] + y[:, :, 1:] |= x[:, :, :-1] + y[:, :, :-1] |= x[:, :, 1:] x = y return x @@ -358,7 +383,8 @@ def _raster_all_torch(uvs_tex_pad, faces_pad, fmask, bw_t, bh_t, padding, device cid = torch.arange(n, device=device).repeat_interleave(fmax)[fm] # per-triangle pixel bbox, inflated by padding (origin >= 0); bucket by next-pow2 max-dim - tmin = tri_f.amin(1); tmax = tri_f.amax(1) + tmin = tri_f.amin(1) + tmax = tri_f.amax(1) x0 = (tmin[:, 0].floor().long() - padding).clamp_min(0) y0 = (tmin[:, 1].floor().long() - padding).clamp_min(0) bbw = (tmax[:, 0].ceil().long() + padding) - x0 + 1 @@ -366,20 +392,31 @@ def _raster_all_torch(uvs_tex_pad, faces_pad, fmask, bw_t, bh_t, padding, device mxd = torch.maximum(bbw, bbh).clamp_min(1) bsz = (2 ** torch.ceil(torch.log2(mxd.float())).long()).long() - a = tri_f[:, 0]; b = tri_f[:, 1]; c = tri_f[:, 2] - v0 = b - a; v1 = c - a - d00 = (v0 * v0).sum(-1); d01 = (v0 * v1).sum(-1); d11 = (v1 * v1).sum(-1) + a = tri_f[:, 0] + b = tri_f[:, 1] + c = tri_f[:, 2] + v0 = b - a + v1 = c - a + d00 = (v0 * v0).sum(-1) + d01 = (v0 * v1).sum(-1) + d11 = (v1 * v1).sum(-1) den = (d00 * d11 - d01 * d01).clamp(min=1e-20) for g in sorted(set(bsz.tolist())): # one batch per pow2 grid sel = (bsz == g).nonzero(as_tuple=True)[0] m = sel.shape[0] - xs0 = x0[sel].view(m, 1, 1); ys0 = y0[sel].view(m, 1, 1) - cc = cid[sel]; bwp = bwL[cc].view(m, 1, 1); bhp = bhL[cc].view(m, 1, 1) + xs0 = x0[sel].view(m, 1, 1) + ys0 = y0[sel].view(m, 1, 1) + cc = cid[sel] + bwp = bwL[cc].view(m, 1, 1) + bhp = bhL[cc].view(m, 1, 1) gi = torch.arange(g, device=device) - px = xs0 + gi.view(1, 1, g); py = ys0 + gi.view(1, g, 1) # (m,g,g) int - pxf = px.float() + 0.5; pyf = py.float() + 0.5 - v2x = pxf - a[sel, 0].view(m, 1, 1); v2y = pyf - a[sel, 1].view(m, 1, 1) + px = xs0 + gi.view(1, 1, g) + py = ys0 + gi.view(1, g, 1) # (m,g,g) int + pxf = px.float() + 0.5 + pyf = py.float() + 0.5 + v2x = pxf - a[sel, 0].view(m, 1, 1) + v2y = pyf - a[sel, 1].view(m, 1, 1) d20 = v2x * v0[sel, 0].view(m, 1, 1) + v2y * v0[sel, 1].view(m, 1, 1) d21 = v2x * v1[sel, 0].view(m, 1, 1) + v2y * v1[sel, 1].view(m, 1, 1) idn = den[sel].view(m, 1, 1).reciprocal() @@ -437,7 +474,8 @@ def _best_placement_torch(atlas, pix0, dim0, pix1, dim1, cand0, cand1, cur_w, cu cx, cy = cand[:, 0], cand[:, 1] coll = atlas[cy[:, None] + pix[:, 0][None, :], # (M,k) True-pixel gather cx[:, None] + pix[:, 1][None, :]].any(dim=1) - nw = torch.clamp(cx + cw, min=cur_w); nh = torch.clamp(cy + ch, min=cur_h) + nw = torch.clamp(cx + cw, min=cur_w) + nh = torch.clamp(cy + ch, min=cur_h) ext = torch.maximum(nw, nh) score = torch.where(coll, torch.full_like(nw, INF), ext * ext + nw * nh) j = score.argmin() @@ -465,7 +503,8 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces, # ---- Prepare pass 1: best-rotation + scale + bbox for ALL charts at once (batched) ---- vcount = [int(u.shape[0]) for u in chart_uvs] fcount = [int(f.shape[0]) for f in chart_faces] - vmax = max(vcount); fmax = max(fcount) + vmax = max(vcount) + fmax = max(fcount) uvs_pad = torch.zeros(n, vmax, 2, device=device) vmask = torch.zeros(n, vmax, dtype=torch.bool, device=device) faces_pad = torch.zeros(n, fmax, 3, dtype=torch.long, device=device) @@ -488,8 +527,10 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces, cc, ss = cos_a[ti][:, None], sin_a[ti][:, None] # (N,1) rx = torch.addcmul(u0 * cc, u1, ss, value=-1) # (N,Vmax) ry = torch.addcmul(u0 * ss, u1, cc) - rxmin = (rx + mlo).amin(1); rxmax = (rx + mhi).amax(1) # (N,) - rymin = (ry + mlo).amin(1); rymax = (ry + mhi).amax(1) + rxmin = (rx + mlo).amin(1) # (N,) + rxmax = (rx + mhi).amax(1) + rymin = (ry + mlo).amin(1) + rymax = (ry + mhi).amax(1) a3 = torch.tensor([max(a, 1e-12) for a in chart_3d_areas], device=device) au = torch.tensor([max(a, 1e-12) for a in chart_uv_areas], device=device) base = (a3 / au).sqrt() * texels_per_unit @@ -504,7 +545,8 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces, # one sync: pull all per-chart scalars thetas = ang[ti].cpu().tolist() scales = scale.cpu().tolist() - bws = bw_t.cpu().tolist(); bhs = bh_t.cpu().tolist() + bws = bw_t.cpu().tolist() + bhs = bh_t.cpu().tolist() # ---- Prepare pass 2: rasterize ALL charts at once, then trim each bitmap to its bounds ---- buf, cbase = _raster_all_torch(uvs_tex_pad, faces_pad, fmask, bw_t, bh_t, padding_texels, device) @@ -513,7 +555,8 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces, for i in range(n): bm = buf[cb[i]:cb[i + 1]].view(bhs[i], bws[i]) raw.append(bm) - rr = torch.arange(bm.shape[0], device=device); cc = torch.arange(bm.shape[1], device=device) + rr = torch.arange(bm.shape[0], device=device) + cc = torch.arange(bm.shape[1], device=device) rmax = torch.where(bm.any(1), rr, rr.new_full((), -1)).amax() # last occupied row / col (-1 if empty) cmax = torch.where(bm.any(0), cc, cc.new_full((), -1)).amax() bnd.append(torch.stack([rmax, cmax])) @@ -527,10 +570,12 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces, bm = (raw[i][:rm + 1, :cm + 1].contiguous() if rm >= 0 and cm >= 0 else torch.zeros((1, 1), dtype=torch.bool, device=device)) bm_rot = torch.flip(bm.t(), dims=[1]).contiguous() - pix_l.append(bm.nonzero()); pixr_l.append(bm_rot.nonzero()) + pix_l.append(bm.nonzero()) + pixr_l.append(bm_rot.nonzero()) dim_l.append((int(bm.shape[0]), int(bm.shape[1]))) dimr_l.append((int(bm_rot.shape[0]), int(bm_rot.shape[1]))) - col_tops.append(_col_top(bm)); col_tops_rot.append(_col_top(bm_rot)) + col_tops.append(_col_top(bm)) + col_tops_rot.append(_col_top(bm_rot)) bm_h.append(int(bm.shape[0])) wmax = max(d[1] for d in dim_l + dimr_l) ct_pad = torch.full((n, wmax), -1, dtype=torch.long, device=device) @@ -557,8 +602,11 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces, if cur_h + margin > atlas.shape[0] or cur_w + margin > atlas.shape[1]: ns = max(atlas.shape[0], cur_h + margin, cur_w + margin) na = torch.zeros((ns, ns), dtype=torch.bool, device=device) - na[:atlas.shape[0], :atlas.shape[1]] = atlas; atlas = na - nsk = torch.zeros(ns, dtype=torch.long, device=device); nsk[:sky_t.shape[0]] = sky_t; sky_t = nsk + na[:atlas.shape[0], :atlas.shape[1]] = atlas + atlas = na + nsk = torch.zeros(ns, dtype=torch.long, device=device) + nsk[:sky_t.shape[0]] = sky_t + sky_t = nsk dim, dimr = dim_l[ci], dimr_l[ci] step = max(1, min(dim[0], dim[1]) // 8) cand0, cand1 = _build_candidates_gpu(sky_t, cur_w, cur_h, dim[1], dimr[1], step, rand_n, gen, device) @@ -569,7 +617,8 @@ def _pack_bitmap_torch(chart_uvs, chart_3d_areas, chart_uv_areas, chart_faces, pix = pixr_l[ci] if swap else pix_l[ci] bh_, bw_ = (dimr if swap else dim) atlas[by + pix[:, 0], bx + pix[:, 1]] = True # sparse blit - cur_w = max(cur_w, bx + bw_); cur_h = max(cur_h, by + bh_) + cur_w = max(cur_w, bx + bw_) + cur_h = max(cur_h, by + bh_) ct = (ctr_pad if swap else ct_pad)[ci, :bw_] # GPU skyline lift ix = torch.arange(bx, bx + bw_, device=device) sky_t[ix] = torch.where(ct >= 0, torch.maximum(sky_t[ix], by + ct + 1), sky_t[ix]) diff --git a/comfy_extras/mesh3d/uv_unwrap/segment.py b/comfy_extras/mesh3d/uv_unwrap/segment.py index 48dc82ab9..12dce3e2b 100644 --- a/comfy_extras/mesh3d/uv_unwrap/segment.py +++ b/comfy_extras/mesh3d/uv_unwrap/segment.py @@ -33,13 +33,17 @@ def _face_curvature_jit(face_normal: np.ndarray, face_face: np.ndarray) -> np.nd F = face_normal.shape[0] raw = np.zeros(F, dtype=np.float32) for f in range(F): - nx = face_normal[f, 0]; ny = face_normal[f, 1]; nz = face_normal[f, 2] + nx = face_normal[f, 0] + ny = face_normal[f, 1] + nz = face_normal[f, 2] s = np.float32(0.0) for e in range(3): nb = face_face[f, e] if nb < 0: continue - mx = face_normal[nb, 0]; my = face_normal[nb, 1]; mz = face_normal[nb, 2] + mx = face_normal[nb, 0] + my = face_normal[nb, 1] + mz = face_normal[nb, 2] d = nx*mx + ny*my + nz*mz s += np.float32(1.0) - d raw[f] = s @@ -70,7 +74,9 @@ def _farthest_point_seeds_jit( continue seeds[n_seeds] = s n_seeds += 1 - sx = face_centroid[s, 0]; sy = face_centroid[s, 1]; sz = face_centroid[s, 2] + sx = face_centroid[s, 0] + sy = face_centroid[s, 1] + sz = face_centroid[s, 2] for f in range(F): dx = face_centroid[f, 0] - sx dy = face_centroid[f, 1] - sy @@ -145,7 +151,9 @@ def _cost_grow_iter_jit( for f in range(F): if face_chart[f] != -1: continue - nx = face_normal[f, 0]; ny = face_normal[f, 1]; nz = face_normal[f, 2] + nx = face_normal[f, 0] + ny = face_normal[f, 1] + nz = face_normal[f, 2] af = face_area[f] for e0 in range(3): nb0 = face_face[f, e0]