postprocessing node fixes + model small fixes

This commit is contained in:
Yousef Rafat 2026-02-17 00:10:48 +02:00
parent 91fa563b21
commit c14317d6e0
2 changed files with 47 additions and 47 deletions

View File

@ -787,8 +787,10 @@ class SparseStructureFlowModel(nn.Module):
return h
def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0):
t_shifted /= 1000.0
t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1))
t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear)
t_new *= 1000.0
return t_new
class Trellis2(nn.Module):
@ -841,10 +843,13 @@ class Trellis2(nn.Module):
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
else: # structure
timestep = timestep_reshift(timestep)
orig_bsz = x.shape[0]
if shape_rule:
x = x[0].unsqueeze(0)
timestep = timestep[0]
timestep = timestep[0].unsqueeze(0)
out = self.structure_model(x, timestep, context if not shape_rule else cond)
if shape_rule:
out = out.repeat(orig_bsz, 1, 1, 1, 1)
out.generation_mode = mode
return out

View File

@ -295,7 +295,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
def simplify_fn(vertices, faces, target=100000):
if vertices.shape[0] <= target:
return
return vertices, faces
min_feat = vertices.min(dim=0)[0]
max_feat = vertices.max(dim=0)[0]
@ -334,6 +334,19 @@ def simplify_fn(vertices, faces, target=100000):
return final_vertices, final_faces
def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2):
is_batched = vertices.ndim == 3
if is_batched:
batch_size = vertices.shape[0]
if batch_size > 1:
v_out, f_out = [], []
for i in range(batch_size):
v, f = fill_holes_fn(vertices[i], faces[i], max_hole_perimeter)
v_out.append(v)
f_out.append(f)
return torch.stack(v_out), torch.stack(f_out)
vertices = vertices.squeeze(0)
faces = faces.squeeze(0)
device = vertices.device
orig_vertices = vertices
@ -346,24 +359,23 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2):
], dim=0)
edges_sorted, _ = torch.sort(edges, dim=1)
unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True)
boundary_mask = counts == 1
boundary_edges_sorted = unique_edges[boundary_mask]
if boundary_edges_sorted.shape[0] == 0:
return
if is_batched:
return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0)
return orig_vertices, orig_faces
max_idx = vertices.shape[0]
_, inverse_indices, counts_packed = torch.unique(
torch.sort(edges, dim=1).values[:, 0] * max_idx + torch.sort(edges, dim=1).values[:, 1],
return_inverse=True, return_counts=True
)
packed_edges_all = torch.sort(edges, dim=1).values
packed_edges_all = packed_edges_all[:, 0] * max_idx + packed_edges_all[:, 1]
boundary_packed_mask = counts_packed == 1
is_boundary_edge = boundary_packed_mask[inverse_indices]
packed_boundary = boundary_edges_sorted[:, 0] * max_idx + boundary_edges_sorted[:, 1]
is_boundary_edge = torch.isin(packed_edges_all, packed_boundary)
active_boundary_edges = edges[is_boundary_edge]
adj = {}
@ -373,78 +385,61 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2):
loops = []
visited_edges = set()
possible_starts = list(adj.keys())
processed_nodes = set()
for start_node in possible_starts:
if start_node in processed_nodes:
continue
current_loop = []
curr = start_node
for start_node in list(adj.keys()):
if start_node in processed_nodes: continue
current_loop, curr = [], start_node
while curr in adj:
next_node = adj[curr]
if (curr, next_node) in visited_edges:
break
if (curr, next_node) in visited_edges: break
visited_edges.add((curr, next_node))
processed_nodes.add(curr)
current_loop.append(curr)
curr = next_node
if curr == start_node:
loops.append(current_loop)
break
if len(current_loop) > len(edges_np):
break
if len(current_loop) > len(edges_np): break
if not loops:
return
if is_batched: return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0)
return orig_vertices, orig_faces
new_faces = []
v_offset = vertices.shape[0]
valid_new_verts = []
for loop_indices in loops:
if len(loop_indices) < 3:
continue
if len(loop_indices) < 3: continue
loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device)
loop_verts = vertices[loop_tensor]
diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0)
perimeter = torch.norm(diffs, dim=1).sum()
if perimeter > max_hole_perimeter:
continue
if perimeter > max_hole_perimeter: continue
center = loop_verts.mean(dim=0)
valid_new_verts.append(center)
c_idx = v_offset
v_offset += 1
num_v = len(loop_indices)
for i in range(num_v):
v_curr = loop_indices[i]
v_next = loop_indices[(i + 1) % num_v]
v_curr, v_next = loop_indices[i], loop_indices[(i + 1) % num_v]
new_faces.append([v_curr, v_next, c_idx])
if len(valid_new_verts) > 0:
added_vertices = torch.stack(valid_new_verts, dim=0)
added_faces = torch.tensor(new_faces, dtype=torch.long, device=device)
vertices = torch.cat([orig_vertices, added_vertices], dim=0)
faces = torch.cat([orig_faces, added_faces], dim=0)
else:
vertices, faces = orig_vertices, orig_faces
vertices_f = torch.cat([orig_vertices, added_vertices], dim=0)
faces_f = torch.cat([orig_faces, added_faces], dim=0)
if is_batched:
return vertices.unsqueeze(0), faces.unsqueeze(0)
return vertices_f, faces_f
return vertices, faces
class PostProcessMesh(IO.ComfyNode):
@classmethod
@ -454,8 +449,8 @@ class PostProcessMesh(IO.ComfyNode):
category="latent/3d",
inputs=[
IO.Mesh.Input("mesh"),
IO.Int.Input("simplify", default=100_000, min=0), # max?
IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0)
IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000), # max?
IO.Float.Input("fill_holes_perimeter", default=0.003, min=0.0, step=0.0001)
],
outputs=[
IO.Mesh.Output("output_mesh"),