postprocessing node fixes + model small fixes

This commit is contained in:
Yousef Rafat 2026-02-17 00:10:48 +02:00
parent 91fa563b21
commit c14317d6e0
2 changed files with 47 additions and 47 deletions

View File

@ -787,8 +787,10 @@ class SparseStructureFlowModel(nn.Module):
return h return h
def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0): def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0):
t_shifted /= 1000.0
t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1)) t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1))
t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear) t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear)
t_new *= 1000.0
return t_new return t_new
class Trellis2(nn.Module): class Trellis2(nn.Module):
@ -841,10 +843,13 @@ class Trellis2(nn.Module):
out = self.shape2txt(x, timestep, context if not txt_rule else cond) out = self.shape2txt(x, timestep, context if not txt_rule else cond)
else: # structure else: # structure
timestep = timestep_reshift(timestep) timestep = timestep_reshift(timestep)
orig_bsz = x.shape[0]
if shape_rule: if shape_rule:
x = x[0].unsqueeze(0) x = x[0].unsqueeze(0)
timestep = timestep[0] timestep = timestep[0].unsqueeze(0)
out = self.structure_model(x, timestep, context if not shape_rule else cond) out = self.structure_model(x, timestep, context if not shape_rule else cond)
if shape_rule:
out = out.repeat(orig_bsz, 1, 1, 1, 1)
out.generation_mode = mode out.generation_mode = mode
return out return out

View File

@ -295,7 +295,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
def simplify_fn(vertices, faces, target=100000): def simplify_fn(vertices, faces, target=100000):
if vertices.shape[0] <= target: if vertices.shape[0] <= target:
return return vertices, faces
min_feat = vertices.min(dim=0)[0] min_feat = vertices.min(dim=0)[0]
max_feat = vertices.max(dim=0)[0] max_feat = vertices.max(dim=0)[0]
@ -334,6 +334,19 @@ def simplify_fn(vertices, faces, target=100000):
return final_vertices, final_faces return final_vertices, final_faces
def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2):
is_batched = vertices.ndim == 3
if is_batched:
batch_size = vertices.shape[0]
if batch_size > 1:
v_out, f_out = [], []
for i in range(batch_size):
v, f = fill_holes_fn(vertices[i], faces[i], max_hole_perimeter)
v_out.append(v)
f_out.append(f)
return torch.stack(v_out), torch.stack(f_out)
vertices = vertices.squeeze(0)
faces = faces.squeeze(0)
device = vertices.device device = vertices.device
orig_vertices = vertices orig_vertices = vertices
@ -346,24 +359,23 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2):
], dim=0) ], dim=0)
edges_sorted, _ = torch.sort(edges, dim=1) edges_sorted, _ = torch.sort(edges, dim=1)
unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True) unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True)
boundary_mask = counts == 1 boundary_mask = counts == 1
boundary_edges_sorted = unique_edges[boundary_mask] boundary_edges_sorted = unique_edges[boundary_mask]
if boundary_edges_sorted.shape[0] == 0: if boundary_edges_sorted.shape[0] == 0:
return if is_batched:
return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0)
return orig_vertices, orig_faces
max_idx = vertices.shape[0] max_idx = vertices.shape[0]
_, inverse_indices, counts_packed = torch.unique( packed_edges_all = torch.sort(edges, dim=1).values
torch.sort(edges, dim=1).values[:, 0] * max_idx + torch.sort(edges, dim=1).values[:, 1], packed_edges_all = packed_edges_all[:, 0] * max_idx + packed_edges_all[:, 1]
return_inverse=True, return_counts=True
)
boundary_packed_mask = counts_packed == 1 packed_boundary = boundary_edges_sorted[:, 0] * max_idx + boundary_edges_sorted[:, 1]
is_boundary_edge = boundary_packed_mask[inverse_indices]
is_boundary_edge = torch.isin(packed_edges_all, packed_boundary)
active_boundary_edges = edges[is_boundary_edge] active_boundary_edges = edges[is_boundary_edge]
adj = {} adj = {}
@ -373,78 +385,61 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2):
loops = [] loops = []
visited_edges = set() visited_edges = set()
possible_starts = list(adj.keys())
processed_nodes = set() processed_nodes = set()
for start_node in list(adj.keys()):
for start_node in possible_starts: if start_node in processed_nodes: continue
if start_node in processed_nodes: current_loop, curr = [], start_node
continue
current_loop = []
curr = start_node
while curr in adj: while curr in adj:
next_node = adj[curr] next_node = adj[curr]
if (curr, next_node) in visited_edges: if (curr, next_node) in visited_edges: break
break
visited_edges.add((curr, next_node)) visited_edges.add((curr, next_node))
processed_nodes.add(curr) processed_nodes.add(curr)
current_loop.append(curr) current_loop.append(curr)
curr = next_node curr = next_node
if curr == start_node: if curr == start_node:
loops.append(current_loop) loops.append(current_loop)
break break
if len(current_loop) > len(edges_np): break
if len(current_loop) > len(edges_np):
break
if not loops: if not loops:
return if is_batched: return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0)
return orig_vertices, orig_faces
new_faces = [] new_faces = []
v_offset = vertices.shape[0] v_offset = vertices.shape[0]
valid_new_verts = [] valid_new_verts = []
for loop_indices in loops: for loop_indices in loops:
if len(loop_indices) < 3: if len(loop_indices) < 3: continue
continue
loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device) loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device)
loop_verts = vertices[loop_tensor] loop_verts = vertices[loop_tensor]
diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0) diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0)
perimeter = torch.norm(diffs, dim=1).sum() perimeter = torch.norm(diffs, dim=1).sum()
if perimeter > max_hole_perimeter: if perimeter > max_hole_perimeter: continue
continue
center = loop_verts.mean(dim=0) center = loop_verts.mean(dim=0)
valid_new_verts.append(center) valid_new_verts.append(center)
c_idx = v_offset c_idx = v_offset
v_offset += 1 v_offset += 1
num_v = len(loop_indices) num_v = len(loop_indices)
for i in range(num_v): for i in range(num_v):
v_curr = loop_indices[i] v_curr, v_next = loop_indices[i], loop_indices[(i + 1) % num_v]
v_next = loop_indices[(i + 1) % num_v]
new_faces.append([v_curr, v_next, c_idx]) new_faces.append([v_curr, v_next, c_idx])
if len(valid_new_verts) > 0: if len(valid_new_verts) > 0:
added_vertices = torch.stack(valid_new_verts, dim=0) added_vertices = torch.stack(valid_new_verts, dim=0)
added_faces = torch.tensor(new_faces, dtype=torch.long, device=device) added_faces = torch.tensor(new_faces, dtype=torch.long, device=device)
vertices = torch.cat([orig_vertices, added_vertices], dim=0)
faces = torch.cat([orig_faces, added_faces], dim=0)
else:
vertices, faces = orig_vertices, orig_faces
vertices_f = torch.cat([orig_vertices, added_vertices], dim=0) if is_batched:
faces_f = torch.cat([orig_faces, added_faces], dim=0) return vertices.unsqueeze(0), faces.unsqueeze(0)
return vertices_f, faces_f return vertices, faces
class PostProcessMesh(IO.ComfyNode): class PostProcessMesh(IO.ComfyNode):
@classmethod @classmethod
@ -454,8 +449,8 @@ class PostProcessMesh(IO.ComfyNode):
category="latent/3d", category="latent/3d",
inputs=[ inputs=[
IO.Mesh.Input("mesh"), IO.Mesh.Input("mesh"),
IO.Int.Input("simplify", default=100_000, min=0), # max? IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000), # max?
IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0) IO.Float.Input("fill_holes_perimeter", default=0.003, min=0.0, step=0.0001)
], ],
outputs=[ outputs=[
IO.Mesh.Output("output_mesh"), IO.Mesh.Output("output_mesh"),