Compare commits

...

2 Commits

Author SHA1 Message Date
Yousef Rafat
391eacf9f0 vertex colors 2026-05-18 23:45:19 +03:00
Yousef Rafat
faeb47b3c8 bug fix in fill_holes 2026-05-18 23:15:00 +03:00
2 changed files with 79 additions and 20 deletions

View File

@ -809,6 +809,10 @@ class Trellis2(nn.Module):
mode = "structure_generation"
not_struct_mode = False
if x.size(-1) == 16 and x.size(-2) == 16:
mode = "structure_generation"
not_struct_mode = False
if not not_struct_mode:
bsz = x.size(0)
x = x[:, :8]

View File

@ -168,11 +168,16 @@ def paint_mesh_default_colors(mesh):
def fill_holes_fn(vertices, faces, max_perimeter=0.03):
is_batched = vertices.ndim == 3
if is_batched:
v_list, f_list = [],[]
v_list, f_list = [], []
for i in range(vertices.shape[0]):
v_i, f_i = fill_holes_fn(vertices[i], faces[i], max_perimeter)
v_list.append(v_i)
f_list.append(f_i)
max_v = max(v.shape[0] for v in v_list)
for i in range(len(v_list)):
if v_list[i].shape[0] < max_v:
pad = torch.zeros(max_v - v_list[i].shape[0], 3, device=v_list[i].device, dtype=v_list[i].dtype)
v_list[i] = torch.cat([v_list[i], pad], dim=0)
return torch.stack(v_list), torch.stack(f_list)
device = vertices.device
@ -194,13 +199,19 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
if boundary_packed.numel() == 0:
return v, f
packed_directed_sorted = edges[:, 0].min(edges[:, 1]).long() * max_v + edges[:, 0].max(edges[:, 1]).long()
is_boundary = torch.isin(packed_directed_sorted, boundary_packed)
b_edges = edges[is_boundary]
# Build undirected boundary edge adjacency
boundary_mask = torch.isin(packed_undirected, boundary_packed)
b_edges = edges_sorted[boundary_mask]
adj = {u.item(): v_idx.item() for u, v_idx in b_edges}
adj = {}
for i in range(b_edges.shape[0]):
a = b_edges[i, 0].item()
b = b_edges[i, 1].item()
adj.setdefault(a, []).append(b)
adj.setdefault(b, []).append(a)
loops =[]
# Trace all boundary loops
loops = []
visited = set()
for start_node in adj.keys():
@ -208,40 +219,84 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
continue
curr = start_node
prev = -1
loop = []
while curr not in visited:
visited.add(curr)
loop.append(curr)
curr = adj.get(curr, -1)
if curr == -1:
neighbors = adj[curr]
candidates = [n for n in neighbors if n != prev]
if not candidates:
loop = []
break
next_node = candidates[0]
prev, curr = curr, next_node
if curr == start_node:
loops.append(loop)
break
new_verts =[]
new_faces = []
v_idx = v.shape[0]
if not loops:
return v, f
# Compute mesh normal for orientation
face_normals = torch.linalg.cross(
v[f[:, 1]] - v[f[:, 0]],
v[f[:, 2]] - v[f[:, 0]],
dim=-1
)
mesh_normal = face_normals.mean(dim=0)
mesh_normal = mesh_normal / (torch.norm(mesh_normal) + 1e-8)
# Classify loops: keep only holes (normal aligns with mesh_normal), discard outer boundary
hole_loops = []
for loop in loops:
loop_t = torch.tensor(loop, device=device, dtype=torch.long)
loop_v = v[loop_t]
next_v = torch.roll(loop_v, -1, dims=0)
cross = torch.linalg.cross(loop_v, next_v, dim=-1)
loop_normal = cross.sum(dim=0)
loop_normal = loop_normal / (torch.norm(loop_normal) + 1e-8)
# Hole: loop normal points same way as mesh normal (both "up" for a hole)
# Outer boundary: loop normal points opposite
if torch.dot(loop_normal, mesh_normal) > 0:
hole_loops.append(loop)
diffs = loop_v - torch.roll(loop_v, shifts=-1, dims=0)
new_verts = []
new_faces = []
v_idx = v.shape[0]
for loop in hole_loops:
loop_t = torch.tensor(loop, device=device, dtype=torch.long)
loop_v = v[loop_t]
# Perimeter check
next_v = torch.roll(loop_v, -1, dims=0)
diffs = loop_v - next_v
perimeter = torch.norm(diffs, dim=1).sum().item()
if perimeter <= max_perimeter:
new_verts.append(loop_v.mean(dim=0))
if perimeter > max_perimeter:
continue
# Ensure CCW winding consistent with mesh
cross = torch.linalg.cross(loop_v, next_v, dim=-1)
loop_normal = cross.sum(dim=0)
loop_normal = loop_normal / (torch.norm(loop_normal) + 1e-8)
if torch.dot(loop_normal, mesh_normal) < 0:
loop = loop[::-1]
if len(loop) == 3:
new_faces.append([loop[0], loop[1], loop[2]])
else:
centroid = loop_v.mean(dim=0)
new_verts.append(centroid)
for i in range(len(loop)):
new_faces.append([loop[(i + 1) % len(loop)], loop[i], v_idx])
new_faces.append([loop[i], loop[(i + 1) % len(loop)], v_idx])
v_idx += 1
if new_verts:
v = torch.cat([v, torch.stack(new_verts)], dim=0)
if new_faces:
f = torch.cat([f, torch.tensor(new_faces, device=device, dtype=torch.long)], dim=0)
return v, f
@ -250,7 +305,7 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
def make_double_sided(vertices, faces, colors=None):
"""
Creates a double-sided mesh by duplicating vertices, faces, and colors.
Duplicating vertices prevents opposite faces from sharing the same index,
Duplicating vertices prevents opposite faces from sharing the same index,
which stops the rendering engine from cancelling out the normals to [0,0,0].
"""
is_batched = vertices.ndim == 3
@ -797,8 +852,8 @@ class PostProcessMesh(IO.ComfyNode):
# Safely grab colors if they exist
c_i = None
if hasattr(mesh, 'colors') and mesh.colors is not None:
c_i = mesh.colors[i] if (isinstance(mesh.colors, list) or mesh.colors.ndim == 3) else mesh.colors
if hasattr(mesh, 'vertex_colors') and mesh.vertex_colors is not None:
c_i = mesh.vertex_colors[i] if (isinstance(mesh.vertex_colors, list) or mesh.vertex_colors.ndim == 3) else mesh.vertex_colors
v_i, f_i, c_i = process_single(v_i, f_i, c_i, bar)
@ -822,7 +877,7 @@ class PostProcessMesh(IO.ComfyNode):
else:
# Single Unbatched Mesh[V, 3]
c = mesh.colors if hasattr(mesh, 'colors') and mesh.colors is not None else None
c = mesh.vertex_colors if hasattr(mesh, 'vertex_colors') and mesh.vertex_colors is not None else None
v, f, c = process_single(mesh.vertices, mesh.faces, c)
mesh.vertices = v
mesh.faces = f