Vertex Clustering, Mask Fix, Normal Fix (#14035)
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled

* Vertex Clustering, Mask Fix, Normal Fix

* detects inverted mask

* update the decimate mesh
This commit is contained in:
Yousef R. Gamaleldin 2026-05-22 00:11:48 +03:00 committed by GitHub
parent 7547897b6f
commit e90bde2f82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 371 additions and 27 deletions

View File

@ -625,7 +625,6 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, normals_in, target_faces,
# Finalize # Finalize
final_v = verts[v_alive] final_v = verts[v_alive]
final_c = colors[v_alive] if colors is not None else None final_c = colors[v_alive] if colors is not None else None
final_n = normals[v_alive] if normals is not None else None
remap = torch.full((num_verts,), -1, dtype=torch.int64, device=device) remap = torch.full((num_verts,), -1, dtype=torch.int64, device=device)
remap[v_alive] = torch.arange(int(v_alive.sum().item()), device=device) remap[v_alive] = torch.arange(int(v_alive.sum().item()), device=device)
@ -640,26 +639,9 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, normals_in, target_faces,
if final_f.numel() > 0: if final_f.numel() > 0:
final_f = torch.unique(torch.sort(final_f, dim=1)[0], dim=0) final_f = torch.unique(torch.sort(final_f, dim=1)[0], dim=0)
if final_n is not None and final_f.numel() > 0:
v0, v1, v2 = final_v[final_f[:, 0]], final_v[final_f[:, 1]], final_v[final_f[:, 2]]
# calculate the actual normal of the simplified faces
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
# Get the average reference normal for each face
n0, n1, n2 = final_n[final_f[:, 0]], final_n[final_f[:, 1]], final_n[final_f[:, 2]]
ref_face_normals = (n0 + n1 + n2) / 3.0
# Dot product to check if they point in the same direction
dot_products = (face_normals * ref_face_normals).sum(dim=-1)
# Flip the indices of ONLY the incorrect faces (swap vertex 1 and 2)
wrong_way_mask = dot_products < 0
final_f[wrong_way_mask] = final_f[wrong_way_mask][:, [0, 2, 1]]
final_v, final_f = _cleanup_mesh(final_v, final_f, min_angle_deg=0.5, max_aspect=100.0) final_v, final_f = _cleanup_mesh(final_v, final_f, min_angle_deg=0.5, max_aspect=100.0)
return final_v, final_f, final_c, final_n return final_v, final_f, final_c, None
def simplify_fn_fast(vertices, faces, colors=None, normals=None, target=100000, max_edge_length=None): def simplify_fn_fast(vertices, faces, colors=None, normals=None, target=100000, max_edge_length=None):
@ -709,6 +691,79 @@ def simplify_fn_fast(vertices, faces, colors=None, normals=None, target=100000,
) )
return final_v, final_f, final_c, final_n return final_v, final_f, final_c, final_n
def simplify_fn_vertex(vertices, faces, colors=None, target=100000):
    """Simplify a triangle mesh by uniform-grid vertex clustering.

    Vertices are snapped to a regular grid whose cell size is derived from the
    bounding-box volume and ``target``; all vertices that fall into the same
    cell are merged into their average position (and average color). Faces
    that collapse to fewer than three distinct vertices are dropped, and
    vertices no longer referenced by any face are removed.

    Args:
        vertices: ``(V, 3)`` tensor, or ``(B, V, 3)`` for a batch.
        faces: ``(F, 3)`` integer tensor, or ``(B, F, 3)`` for a batch.
        colors: optional per-vertex attribute tensor ``(V, C)`` / ``(B, V, C)``.
        target: face budget. A mesh with ``F <= target`` is returned untouched
            (the very same tensor objects). The grid is sized for roughly
            ``target / 4`` cells, so the resulting face count is approximate.

    Returns:
        ``(vertices, faces, colors)`` with the input dtypes and device;
        ``colors`` is ``None`` when no colors were given.

    Note:
        The batched path stacks the per-item results with ``torch.stack``, so
        it only succeeds when every item simplifies to the same vertex/face
        counts.
    """
    if vertices.ndim == 3:
        v_list, f_list, c_list = [], [], []
        for i in range(vertices.shape[0]):
            c_in = colors[i] if colors is not None else None
            v_i, f_i, c_i = simplify_fn_vertex(vertices[i], faces[i], c_in, target)
            v_list.append(v_i)
            f_list.append(f_i)
            if c_i is not None:
                c_list.append(c_i)
        c_out = torch.stack(c_list) if len(c_list) > 0 else None
        return torch.stack(v_list), torch.stack(f_list), c_out

    if faces.shape[0] <= target:
        return vertices, faces, colors

    device = vertices.device

    # Size the grid so the bounding box holds about target / 4 cells.
    target_v = max(target / 4.0, 1.0)
    min_v = vertices.min(dim=0)[0]
    max_v = vertices.max(dim=0)[0]
    extent = max_v - min_v
    # The clamp keeps the cube root well-defined for flat (zero-volume) boxes.
    volume = (extent[0] * extent[1] * extent[2]).clamp(min=1e-8)
    cell_size = (volume / target_v) ** (1 / 3.0)

    # Use CPU-side ordered reductions here so repeated runs produce identical
    # simplified meshes instead of relying on GPU scatter-add accumulation order.
    vertices_np = vertices.detach().cpu().numpy()
    faces_np = faces.detach().cpu().numpy()
    colors_np = colors.detach().cpu().numpy() if colors is not None else None
    min_v_np = min_v.detach().cpu().numpy()
    cell_size_value = float(cell_size.detach().cpu())

    quantized = np.rint((vertices_np - min_v_np) / cell_size_value).astype(np.int64)
    unique_coords, inverse_indices = np.unique(quantized, axis=0, return_inverse=True)
    # NumPy 2.0.0 returns a non-flat inverse when `axis` is given; every use
    # below needs one flat cell index per input vertex.
    inverse_indices = inverse_indices.reshape(-1)
    num_cells = unique_coords.shape[0]

    # Per-cell mean position.
    new_vertices_np = np.zeros((num_cells, 3), dtype=vertices_np.dtype)
    np.add.at(new_vertices_np, inverse_indices, vertices_np)
    counts_np = np.bincount(inverse_indices, minlength=num_cells).astype(vertices_np.dtype).reshape(-1, 1)
    new_vertices_np = new_vertices_np / np.clip(counts_np, 1, None)

    # Per-cell mean color.
    new_colors = None
    if colors_np is not None:
        # Integer colors would overflow while being summed in their own dtype;
        # accumulate those in float64. Float colors keep their dtype.
        if np.issubdtype(colors_np.dtype, np.floating):
            acc_dtype = colors_np.dtype
        else:
            acc_dtype = np.float64
        new_colors_np = np.zeros((num_cells, colors_np.shape[1]), dtype=acc_dtype)
        np.add.at(new_colors_np, inverse_indices, colors_np)
        new_colors = new_colors_np / np.clip(counts_np, 1, None)

    # Re-index faces to cells and drop the ones that collapsed.
    new_faces = inverse_indices[faces_np]
    valid_mask = (
        (new_faces[:, 0] != new_faces[:, 1])
        & (new_faces[:, 1] != new_faces[:, 2])
        & (new_faces[:, 2] != new_faces[:, 0])
    )
    new_faces = new_faces[valid_mask]

    if new_faces.size == 0:
        final_vertices_np = new_vertices_np[:0]
        final_faces_np = np.empty((0, 3), dtype=np.int64)
        final_colors_np = new_colors[:0] if new_colors is not None else None
    else:
        # Compact away cells that no surviving face references.
        unique_face_indices, inv_face = np.unique(new_faces.reshape(-1), return_inverse=True)
        final_vertices_np = new_vertices_np[unique_face_indices]
        final_faces_np = inv_face.reshape(-1, 3).astype(np.int64)
        final_colors_np = new_colors[unique_face_indices] if new_colors is not None else None

    final_vertices = torch.from_numpy(final_vertices_np).to(device=device, dtype=vertices.dtype)
    final_faces = torch.from_numpy(final_faces_np).to(device=device, dtype=faces.dtype)
    final_colors = (
        torch.from_numpy(final_colors_np).to(device=device, dtype=colors.dtype)
        if final_colors_np is not None
        else None
    )
    return final_vertices, final_faces, final_colors
def compute_vertex_normals(verts, faces): def compute_vertex_normals(verts, faces):
"""Computes area-weighted vertex normals.""" """Computes area-weighted vertex normals."""
# QUICK FIX: Ensure indices are int64 for scatter_add_ # QUICK FIX: Ensure indices are int64 for scatter_add_
@ -781,6 +836,235 @@ def _process_mesh_batch(mesh, per_item_fn):
return IO.NodeOutput(mesh) return IO.NodeOutput(mesh)
def fix_face_orientation(vertices, faces, reference_normals=None):
    """Return a copy of ``faces`` with consistent, outward-facing winding.

    Works in two passes:

    1. Local consistency. Faces are linked through every edge that is used by
       exactly two faces. A breadth-first search over each connected component
       flips faces so that two linked faces traverse their shared edge in
       opposite directions. Edges used by one face or by more than two faces
       do not link anything.
    2. Global direction. Each component is then flipped as a whole if a vote
       says it points inward. With ``reference_normals`` the vote compares
       every face normal with the average reference normal of its three
       vertices. Without it, the face with the largest centroid coordinate
       along each of the three axes must have a positive normal component on
       that axis for at least two of the axes.

    Args:
        vertices: ``(V, 3)`` vertex positions.
        faces: ``(F, 3)`` integer vertex indices.
        reference_normals: optional ``(V, 3)`` per-vertex normals.

    Returns:
        A new ``(F, 3)`` tensor with the dtype and device of ``faces``; rows
        are either unchanged or have their last two indices swapped. An empty
        ``faces`` tensor is returned as is.
    """
    num_faces = faces.shape[0]
    if num_faces == 0:
        return faces

    device = faces.device
    corrected = faces.clone()

    # Directed edges of every face in winding order: (v0,v1), (v1,v2), (v2,v0).
    idx = torch.tensor([[0, 1], [1, 2], [2, 0]], dtype=torch.int64, device=device)
    edges = corrected[:, idx]  # (num_faces, 3, 2)

    # Undirected edge key. Computed in int64 so that `a * num_vertices + b`
    # cannot wrap around when the face indices are stored as int32.
    edges_canon = torch.sort(edges, dim=2)[0]
    edges_flat = edges_canon.reshape(-1, 2).to(torch.int64)
    max_vert = vertices.shape[0]
    edge_hash = edges_flat[:, 0] * max_vert + edges_flat[:, 1]

    # Group equal keys into runs; a run of length two is an edge shared by
    # exactly two faces.
    hash_sorted, sort_idx = torch.sort(edge_hash)
    hash_diff = hash_sorted[1:] != hash_sorted[:-1]
    hash_diff = torch.cat([torch.tensor([True], device=device), hash_diff])
    unique_starts = torch.nonzero(hash_diff, as_tuple=True)[0]
    unique_ends = torch.cat([unique_starts[1:], torch.tensor([len(hash_sorted)], device=device)])
    run_lengths = unique_ends - unique_starts
    manifold_mask = run_lengths == 2
    manifold_starts = unique_starts[manifold_mask]

    component_id_np = np.full(num_faces, -1, dtype=np.int64)

    if manifold_starts.numel() > 0:
        # Flat edge slot -> (face, local edge) by direct index arithmetic.
        f_a = sort_idx[manifold_starts] // 3
        f_b = sort_idx[manifold_starts + 1] // 3
        local_edge_a = sort_idx[manifold_starts] % 3
        local_edge_b = sort_idx[manifold_starts + 1] % 3

        dir_edge_a = edges[f_a, local_edge_a]
        dir_edge_b = edges[f_b, local_edge_b]
        # Consistently wound neighbours walk the shared edge in opposite directions.
        opposite = (dir_edge_a == dir_edge_b.flip(dims=[1])).all(dim=1)
        needs_flip_rel = ~opposite

        # Symmetric adjacency list, grouped by source face.
        adj_faces = torch.cat([f_a, f_b])
        adj_neighbors = torch.cat([f_b, f_a])
        adj_flip = torch.cat([needs_flip_rel, needs_flip_rel])
        adj_order = torch.argsort(adj_faces)
        adj_faces_np = adj_faces[adj_order].cpu().numpy()
        adj_neighbors_np = adj_neighbors[adj_order].cpu().numpy()
        adj_flip_np = adj_flip[adj_order].cpu().numpy()

        # Build CSR-style adjacency on CPU using NumPy
        adj_ptr_np = np.zeros(num_faces + 1, dtype=np.int64)
        counts_np = np.bincount(adj_faces_np, minlength=num_faces)
        adj_ptr_np[1:] = np.cumsum(counts_np)

        visited_np = np.zeros(num_faces, dtype=bool)
        flip_state_np = np.zeros(num_faces, dtype=bool)
        comp_counter = 0
        queue_np = np.empty(num_faces, dtype=np.int64)

        for seed in range(num_faces):
            if visited_np[seed]:
                continue
            visited_np[seed] = True
            component_id_np[seed] = comp_counter
            q_head = 0
            q_tail = 1
            queue_np[0] = seed

            while q_head < q_tail:
                current = queue_np[q_head]
                q_head += 1
                start = adj_ptr_np[current]
                end = adj_ptr_np[current + 1]
                if start == end:
                    continue

                nbrs = adj_neighbors_np[start:end]
                flips = adj_flip_np[start:end]
                src_flip = flip_state_np[current]

                unvisited_mask = ~visited_np[nbrs]
                if not np.any(unvisited_mask):
                    continue
                nbrs_new = nbrs[unvisited_mask]
                flips_new = flips[unvisited_mask]

                # Two faces built on the same three vertices are linked through
                # several edges, so the same neighbour can be listed more than
                # once. Enqueue each face only once (keeping first-seen order),
                # otherwise the fixed-size queue can overflow.
                _, first_seen = np.unique(nbrs_new, return_index=True)
                if first_seen.size != nbrs_new.size:
                    first_seen.sort()
                    nbrs_new = nbrs_new[first_seen]
                    flips_new = flips_new[first_seen]

                visited_np[nbrs_new] = True
                component_id_np[nbrs_new] = comp_counter
                # A neighbour is flipped iff exactly one of "source is flipped"
                # and "the pair disagrees" holds.
                flip_state_np[nbrs_new] = flips_new ^ src_flip

                n_new = len(nbrs_new)
                queue_np[q_tail:q_tail + n_new] = nbrs_new
                q_tail += n_new

            comp_counter += 1

        flip_state = torch.from_numpy(flip_state_np).to(device=device)
        component_id = torch.from_numpy(component_id_np).to(device=device)
        if flip_state.any():
            corrected[flip_state] = corrected[flip_state][:, [0, 2, 1]]
    else:
        # No face shares a two-face edge: every face is its own component.
        component_id = torch.arange(num_faces, device=device)

    v0 = vertices[corrected[:, 0]]
    v1 = vertices[corrected[:, 1]]
    v2 = vertices[corrected[:, 2]]
    face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
    face_normals = face_normals / (torch.norm(face_normals, dim=-1, keepdim=True) + 1e-8)

    num_components = int(component_id.max().item()) + 1 if component_id.numel() > 0 else 0

    if reference_normals is not None:
        n0 = reference_normals[corrected[:, 0]]
        n1 = reference_normals[corrected[:, 1]]
        n2 = reference_normals[corrected[:, 2]]
        ref_normals = (n0 + n1 + n2) / 3.0
        ref_normals = ref_normals / (torch.norm(ref_normals, dim=-1, keepdim=True) + 1e-8)
        votes = (face_normals * ref_normals).sum(dim=-1)

        outward_votes_comp = torch.zeros(num_components, dtype=torch.int64, device=device)
        inward_votes_comp = torch.zeros(num_components, dtype=torch.int64, device=device)
        outward_votes_comp.scatter_add_(0, component_id, (votes > 0).to(torch.int64))
        inward_votes_comp.scatter_add_(0, component_id, (votes < 0).to(torch.int64))

        n_faces_comp_int = torch.zeros(num_components, dtype=torch.int64, device=device)
        n_faces_comp_int.scatter_add_(0, component_id, torch.ones(num_faces, dtype=torch.int64, device=device))
        # Flip only when "inward" wins by a margin of max(1, 10% of the component).
        thresholds = torch.maximum(torch.ones_like(n_faces_comp_int), n_faces_comp_int // 10)
        should_flip_comp = inward_votes_comp > outward_votes_comp + thresholds
    else:
        # Vectorized 3-axis extreme-face majority vote.
        face_centroids = (v0 + v1 + v2) / 3.0
        votes_by_axis = []
        for axis in range(3):
            coords = face_centroids[:, axis]
            # Double stable sort acts as a vectorized lexsort on (component_id, coords)
            sort_idx = torch.argsort(coords, stable=True)
            sort_idx = sort_idx[torch.argsort(component_id[sort_idx], stable=True)]

            # Last entry of each component group = face with the largest
            # centroid coordinate along this axis.
            comp_id_sorted = component_id[sort_idx]
            group_ends = torch.nonzero(comp_id_sorted[1:] != comp_id_sorted[:-1], as_tuple=True)[0]
            group_ends = torch.cat([group_ends, torch.tensor([len(comp_id_sorted) - 1], device=device)])
            extreme_face_indices = sort_idx[group_ends]
            extreme_normals = face_normals[extreme_face_indices]

            # Normal's component along the respective axis should be positive
            votes_by_axis.append(extreme_normals[:, axis] > 0)

        stacked_votes = torch.stack(votes_by_axis, dim=0)
        should_flip_comp = stacked_votes.sum(dim=0) < 2  # False if at least 2 axes agree outward

    should_flip_face = should_flip_comp[component_id]
    if should_flip_face.any():
        corrected[should_flip_face] = corrected[should_flip_face][:, [0, 2, 1]]
    return corrected
def unweld_and_offset_mesh(vertices, faces, colors=None, z_offset=1e-4):
    """Give every face its own three vertices and push them along the face normal.

    Each face's corner positions are copied out ("unwelded"), so no vertex is
    shared between faces any more, and every copy is translated by
    ``z_offset`` along that face's unit normal. The new faces are simply
    ``[0, 1, 2], [3, 4, 5], ...`` in the original face order.

    Args:
        vertices: ``(V, 3)`` tensor, or ``(B, V, 3)`` for a batch.
        faces: ``(F, 3)`` integer tensor, or ``(B, F, 3)`` for a batch.
        colors: optional per-vertex attributes ``(V, C)`` / ``(B, V, C)``;
            they are gathered per face corner, without modification.
        z_offset: distance to move each face along its normal.

    Returns:
        ``(vertices, faces, colors)`` with ``3 * F`` vertices per mesh.
        ``colors`` is ``None`` when no colors were given. In the batched case
        the returned faces are an expanded view of a single index table
        shared by all batch items.
    """
    device = vertices.device

    if vertices.ndim == 3:
        B = vertices.shape[0]
        F = faces.shape[1]

        # Broadcast a batch index against the face table to gather all
        # corners at once: (B, 1, 1) with (B, F, 3) -> (B, F, 3, 3).
        batch_idx = torch.arange(B, device=device).view(-1, 1, 1)
        v_faces = vertices[batch_idx, faces]
        v0, v1, v2 = v_faces[:, :, 0], v_faces[:, :, 1], v_faces[:, :, 2]

        # Unit face normals (epsilon guards zero-area faces).
        fn = torch.cross(v1 - v0, v2 - v0, dim=-1)
        fn = fn / (torch.norm(fn, dim=-1, keepdim=True) + 1e-8)

        # Move the three private corners of every face along its normal.
        offset_verts = v_faces + fn.unsqueeze(2) * z_offset
        out_v = offset_verts.reshape(B, -1, 3)

        # Same sequential index table for every batch item.
        f_single = torch.arange(F * 3, device=device).reshape(-1, 3)
        out_f = f_single.unsqueeze(0).expand(B, -1, -1)

        out_c = None
        if colors is not None:
            c_faces = colors[batch_idx, faces]
            out_c = c_faces.reshape(B, -1, colors.shape[-1])
        # Always a 3-tuple, matching the unbatched path.
        return out_v, out_f, out_c

    # --- Unbatched (single mesh) ---
    v_faces = vertices[faces]  # (F, 3, 3)
    v0, v1, v2 = v_faces[:, 0], v_faces[:, 1], v_faces[:, 2]

    # Unit face normals (epsilon guards zero-area faces).
    fn = torch.cross(v1 - v0, v2 - v0, dim=-1)
    fn = fn / (torch.norm(fn, dim=-1, keepdim=True) + 1e-8)

    # Offset each face's private vertices along its face normal.
    offset_verts = v_faces + fn.unsqueeze(1) * z_offset
    offset_verts = offset_verts.reshape(-1, 3)

    # Sequential face indices for the unwelded vertices.
    f_unwelded = torch.arange(faces.shape[0] * 3, device=device).reshape(-1, 3)

    c_unwelded = None
    if colors is not None:
        c_faces = colors[faces]
        c_unwelded = c_faces.reshape(-1, colors.shape[-1])
    return offset_verts, f_unwelded, c_unwelded
class DecimateMesh(IO.ComfyNode): class DecimateMesh(IO.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
@ -801,8 +1085,23 @@ class DecimateMesh(IO.ComfyNode):
def execute(cls, mesh, target_face_count): def execute(cls, mesh, target_face_count):
def _fn(v, f, c): def _fn(v, f, c):
if target_face_count > 0 and f.shape[0] > target_face_count: if target_face_count > 0 and f.shape[0] > target_face_count:
n = compute_vertex_normals(v, f) try:
v, f, c, _ = simplify_fn_fast(v, f, colors=c, normals=n, target=target_face_count) v0, v1, v2 = v[f[:, 0]], v[f[:, 1]], v[f[:, 2]]
fn = torch.cross(v1 - v0, v2 - v0, dim=-1)
fn = fn / (torch.norm(fn, dim=-1, keepdim=True) + 1e-8)
n = torch.zeros_like(v)
n.index_add_(0, f[:, 0], fn)
n.index_add_(0, f[:, 1], fn)
n.index_add_(0, f[:, 2], fn)
n = n / (torch.norm(n, dim=-1, keepdim=True) + 1e-8)
v, f, c, _ = simplify_fn_fast(v, f, colors=c, normals=n, target=target_face_count)
f = fix_face_orientation(v, f)
v, f, c = unweld_and_offset_mesh(v, f, colors=c, z_offset=1e-4)
except Exception as e:
logging.warning("Ran into an error while QEM Simplifying, falling back to vertex clustering:\n" + str(e))
v, f, c = simplify_fn_vertex(v, f, c, target_face_count)
return v, f, c return v, f, c
return _process_mesh_batch(mesh, _fn) return _process_mesh_batch(mesh, _fn)

View File

@ -445,13 +445,31 @@ class Trellis2Conditioning(IO.ComfyNode):
] ]
) )
@classmethod
@classmethod @classmethod
def execute(cls, clip_vision_model, image, mask) -> IO.NodeOutput: def execute(cls, clip_vision_model, image, mask) -> IO.NodeOutput:
# Normalize to batched form so per-image conditioning loop below is uniform. # Normalize to batched form so per-image conditioning loop below is uniform.
if image.ndim == 3: if image.ndim == 3:
image = image.unsqueeze(0) image = image.unsqueeze(0)
if mask.ndim == 2: elif image.ndim == 4:
if image.shape[1] in [1, 3, 4] and image.shape[-1] not in [1, 3, 4]:
image = image.permute(0, 2, 3, 1)
# normalize mask to standard [B, H, W] (handling 2D, 3D, and 4D variants)
if mask.ndim == 4:
if mask.shape[1] == 1:
mask = mask.squeeze(1)
elif mask.shape[-1] == 1:
mask = mask.squeeze(-1)
else:
mask = mask[:, :, :, 0] # take first channel as fallback
if mask.ndim == 3:
if mask.shape[-1] == 1:
mask = mask.squeeze(-1).unsqueeze(0)
elif mask.ndim == 2:
mask = mask.unsqueeze(0) mask = mask.unsqueeze(0)
batch_size = image.shape[0] batch_size = image.shape[0]
if mask.shape[0] == 1 and batch_size > 1: if mask.shape[0] == 1 and batch_size > 1:
mask = mask.expand(batch_size, -1, -1) mask = mask.expand(batch_size, -1, -1)
@ -468,6 +486,27 @@ class Trellis2Conditioning(IO.ComfyNode):
img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
# Ensure img_np is either 2D (grayscale) or 3D (RGB/RGBA)
if img_np.ndim == 3 and img_np.shape[-1] == 1:
img_np = img_np.squeeze(-1)
mask_np = mask_np.squeeze()
# detect inverted mask
border_pixels = np.concatenate([
mask_np[0, :], mask_np[-1, :], mask_np[:, 0], mask_np[:, -1]
])
if np.mean(border_pixels) > 127:
mask_np = 255 - mask_np
mask_np[mask_np < 35] = 0
border_shave = 4
mask_np[:border_shave, :] = 0
mask_np[-border_shave:, :] = 0
mask_np[:, :border_shave] = 0
mask_np[:, -border_shave:] = 0
pil_img = Image.fromarray(img_np) pil_img = Image.fromarray(img_np)
pil_mask = Image.fromarray(mask_np) pil_mask = Image.fromarray(mask_np)
@ -479,7 +518,7 @@ class Trellis2Conditioning(IO.ComfyNode):
pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST) pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST)
rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8) rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8)
rgba_np[:, :, :3] = np.array(pil_img) rgba_np[:, :, :3] = np.array(pil_img.convert("RGB"))
rgba_np[:, :, 3] = np.array(pil_mask) rgba_np[:, :, 3] = np.array(pil_mask)
alpha = rgba_np[:, :, 3] alpha = rgba_np[:, :, 3]
@ -511,12 +550,18 @@ class Trellis2Conditioning(IO.ComfyNode):
alpha_float = cropped_np[:, :, 3:4] alpha_float = cropped_np[:, :, 3:4]
composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float) composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float)
# to match trellis2 code (quantize -> dequantize) # Keep the image as 4-channel RGBA to force TRELLIS to bypass its internal background remover
composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8) rgb_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
alpha_uint8 = (alpha_float.squeeze(-1) * 255.0).round().clip(0, 255).astype(np.uint8)
cropped_pil = Image.fromarray(composite_uint8) rgba_composite = np.zeros((cropped_np.shape[0], cropped_np.shape[1], 4), dtype=np.uint8)
rgba_composite[:, :, :3] = rgb_uint8
rgba_composite[:, :, 3] = alpha_uint8
item_conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True) cropped_pil = Image.fromarray(rgba_composite, mode="RGBA")
# Convert to RGB to ensure the CLIP/DINO model receives a 3-channel image
item_conditioning = run_conditioning(clip_vision_model, cropped_pil.convert("RGB"), include_1024=True)
cond_512_list.append(item_conditioning["cond_512"]) cond_512_list.append(item_conditioning["cond_512"])
cond_1024_list.append(item_conditioning["cond_1024"]) cond_1024_list.append(item_conditioning["cond_1024"])