mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
postprocessing node fixes + model small fixes
This commit is contained in:
parent
91fa563b21
commit
c14317d6e0
@ -787,8 +787,10 @@ class SparseStructureFlowModel(nn.Module):
|
|||||||
return h
|
return h
|
||||||
|
|
||||||
def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0):
|
def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0):
|
||||||
|
t_shifted /= 1000.0
|
||||||
t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1))
|
t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1))
|
||||||
t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear)
|
t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear)
|
||||||
|
t_new *= 1000.0
|
||||||
return t_new
|
return t_new
|
||||||
|
|
||||||
class Trellis2(nn.Module):
|
class Trellis2(nn.Module):
|
||||||
@ -841,10 +843,13 @@ class Trellis2(nn.Module):
|
|||||||
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
|
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
|
||||||
else: # structure
|
else: # structure
|
||||||
timestep = timestep_reshift(timestep)
|
timestep = timestep_reshift(timestep)
|
||||||
|
orig_bsz = x.shape[0]
|
||||||
if shape_rule:
|
if shape_rule:
|
||||||
x = x[0].unsqueeze(0)
|
x = x[0].unsqueeze(0)
|
||||||
timestep = timestep[0]
|
timestep = timestep[0].unsqueeze(0)
|
||||||
out = self.structure_model(x, timestep, context if not shape_rule else cond)
|
out = self.structure_model(x, timestep, context if not shape_rule else cond)
|
||||||
|
if shape_rule:
|
||||||
|
out = out.repeat(orig_bsz, 1, 1, 1, 1)
|
||||||
|
|
||||||
out.generation_mode = mode
|
out.generation_mode = mode
|
||||||
return out
|
return out
|
||||||
|
|||||||
@ -295,7 +295,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
|||||||
def simplify_fn(vertices, faces, target=100000):
|
def simplify_fn(vertices, faces, target=100000):
|
||||||
|
|
||||||
if vertices.shape[0] <= target:
|
if vertices.shape[0] <= target:
|
||||||
return
|
return vertices, faces
|
||||||
|
|
||||||
min_feat = vertices.min(dim=0)[0]
|
min_feat = vertices.min(dim=0)[0]
|
||||||
max_feat = vertices.max(dim=0)[0]
|
max_feat = vertices.max(dim=0)[0]
|
||||||
@ -334,6 +334,19 @@ def simplify_fn(vertices, faces, target=100000):
|
|||||||
return final_vertices, final_faces
|
return final_vertices, final_faces
|
||||||
|
|
||||||
def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2):
|
def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2):
|
||||||
|
is_batched = vertices.ndim == 3
|
||||||
|
if is_batched:
|
||||||
|
batch_size = vertices.shape[0]
|
||||||
|
if batch_size > 1:
|
||||||
|
v_out, f_out = [], []
|
||||||
|
for i in range(batch_size):
|
||||||
|
v, f = fill_holes_fn(vertices[i], faces[i], max_hole_perimeter)
|
||||||
|
v_out.append(v)
|
||||||
|
f_out.append(f)
|
||||||
|
return torch.stack(v_out), torch.stack(f_out)
|
||||||
|
|
||||||
|
vertices = vertices.squeeze(0)
|
||||||
|
faces = faces.squeeze(0)
|
||||||
|
|
||||||
device = vertices.device
|
device = vertices.device
|
||||||
orig_vertices = vertices
|
orig_vertices = vertices
|
||||||
@ -346,24 +359,23 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2):
|
|||||||
], dim=0)
|
], dim=0)
|
||||||
|
|
||||||
edges_sorted, _ = torch.sort(edges, dim=1)
|
edges_sorted, _ = torch.sort(edges, dim=1)
|
||||||
|
|
||||||
unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True)
|
unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True)
|
||||||
|
|
||||||
boundary_mask = counts == 1
|
boundary_mask = counts == 1
|
||||||
boundary_edges_sorted = unique_edges[boundary_mask]
|
boundary_edges_sorted = unique_edges[boundary_mask]
|
||||||
|
|
||||||
if boundary_edges_sorted.shape[0] == 0:
|
if boundary_edges_sorted.shape[0] == 0:
|
||||||
return
|
if is_batched:
|
||||||
|
return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0)
|
||||||
|
return orig_vertices, orig_faces
|
||||||
|
|
||||||
max_idx = vertices.shape[0]
|
max_idx = vertices.shape[0]
|
||||||
|
|
||||||
_, inverse_indices, counts_packed = torch.unique(
|
packed_edges_all = torch.sort(edges, dim=1).values
|
||||||
torch.sort(edges, dim=1).values[:, 0] * max_idx + torch.sort(edges, dim=1).values[:, 1],
|
packed_edges_all = packed_edges_all[:, 0] * max_idx + packed_edges_all[:, 1]
|
||||||
return_inverse=True, return_counts=True
|
|
||||||
)
|
|
||||||
|
|
||||||
boundary_packed_mask = counts_packed == 1
|
packed_boundary = boundary_edges_sorted[:, 0] * max_idx + boundary_edges_sorted[:, 1]
|
||||||
is_boundary_edge = boundary_packed_mask[inverse_indices]
|
|
||||||
|
|
||||||
|
is_boundary_edge = torch.isin(packed_edges_all, packed_boundary)
|
||||||
active_boundary_edges = edges[is_boundary_edge]
|
active_boundary_edges = edges[is_boundary_edge]
|
||||||
|
|
||||||
adj = {}
|
adj = {}
|
||||||
@ -373,78 +385,61 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2):
|
|||||||
|
|
||||||
loops = []
|
loops = []
|
||||||
visited_edges = set()
|
visited_edges = set()
|
||||||
|
|
||||||
possible_starts = list(adj.keys())
|
|
||||||
|
|
||||||
processed_nodes = set()
|
processed_nodes = set()
|
||||||
|
for start_node in list(adj.keys()):
|
||||||
for start_node in possible_starts:
|
if start_node in processed_nodes: continue
|
||||||
if start_node in processed_nodes:
|
current_loop, curr = [], start_node
|
||||||
continue
|
|
||||||
|
|
||||||
current_loop = []
|
|
||||||
curr = start_node
|
|
||||||
|
|
||||||
while curr in adj:
|
while curr in adj:
|
||||||
next_node = adj[curr]
|
next_node = adj[curr]
|
||||||
if (curr, next_node) in visited_edges:
|
if (curr, next_node) in visited_edges: break
|
||||||
break
|
|
||||||
|
|
||||||
visited_edges.add((curr, next_node))
|
visited_edges.add((curr, next_node))
|
||||||
processed_nodes.add(curr)
|
processed_nodes.add(curr)
|
||||||
current_loop.append(curr)
|
current_loop.append(curr)
|
||||||
|
|
||||||
curr = next_node
|
curr = next_node
|
||||||
|
|
||||||
if curr == start_node:
|
if curr == start_node:
|
||||||
loops.append(current_loop)
|
loops.append(current_loop)
|
||||||
break
|
break
|
||||||
|
if len(current_loop) > len(edges_np): break
|
||||||
if len(current_loop) > len(edges_np):
|
|
||||||
break
|
|
||||||
|
|
||||||
if not loops:
|
if not loops:
|
||||||
return
|
if is_batched: return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0)
|
||||||
|
return orig_vertices, orig_faces
|
||||||
|
|
||||||
new_faces = []
|
new_faces = []
|
||||||
|
|
||||||
v_offset = vertices.shape[0]
|
v_offset = vertices.shape[0]
|
||||||
|
|
||||||
valid_new_verts = []
|
valid_new_verts = []
|
||||||
|
|
||||||
for loop_indices in loops:
|
for loop_indices in loops:
|
||||||
if len(loop_indices) < 3:
|
if len(loop_indices) < 3: continue
|
||||||
continue
|
|
||||||
|
|
||||||
loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device)
|
loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device)
|
||||||
loop_verts = vertices[loop_tensor]
|
loop_verts = vertices[loop_tensor]
|
||||||
|
|
||||||
diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0)
|
diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0)
|
||||||
perimeter = torch.norm(diffs, dim=1).sum()
|
perimeter = torch.norm(diffs, dim=1).sum()
|
||||||
|
|
||||||
if perimeter > max_hole_perimeter:
|
if perimeter > max_hole_perimeter: continue
|
||||||
continue
|
|
||||||
|
|
||||||
center = loop_verts.mean(dim=0)
|
center = loop_verts.mean(dim=0)
|
||||||
valid_new_verts.append(center)
|
valid_new_verts.append(center)
|
||||||
|
|
||||||
c_idx = v_offset
|
c_idx = v_offset
|
||||||
v_offset += 1
|
v_offset += 1
|
||||||
|
|
||||||
num_v = len(loop_indices)
|
num_v = len(loop_indices)
|
||||||
for i in range(num_v):
|
for i in range(num_v):
|
||||||
v_curr = loop_indices[i]
|
v_curr, v_next = loop_indices[i], loop_indices[(i + 1) % num_v]
|
||||||
v_next = loop_indices[(i + 1) % num_v]
|
|
||||||
new_faces.append([v_curr, v_next, c_idx])
|
new_faces.append([v_curr, v_next, c_idx])
|
||||||
|
|
||||||
if len(valid_new_verts) > 0:
|
if len(valid_new_verts) > 0:
|
||||||
added_vertices = torch.stack(valid_new_verts, dim=0)
|
added_vertices = torch.stack(valid_new_verts, dim=0)
|
||||||
added_faces = torch.tensor(new_faces, dtype=torch.long, device=device)
|
added_faces = torch.tensor(new_faces, dtype=torch.long, device=device)
|
||||||
|
vertices = torch.cat([orig_vertices, added_vertices], dim=0)
|
||||||
|
faces = torch.cat([orig_faces, added_faces], dim=0)
|
||||||
|
else:
|
||||||
|
vertices, faces = orig_vertices, orig_faces
|
||||||
|
|
||||||
vertices_f = torch.cat([orig_vertices, added_vertices], dim=0)
|
if is_batched:
|
||||||
faces_f = torch.cat([orig_faces, added_faces], dim=0)
|
return vertices.unsqueeze(0), faces.unsqueeze(0)
|
||||||
|
|
||||||
return vertices_f, faces_f
|
return vertices, faces
|
||||||
|
|
||||||
class PostProcessMesh(IO.ComfyNode):
|
class PostProcessMesh(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -454,8 +449,8 @@ class PostProcessMesh(IO.ComfyNode):
|
|||||||
category="latent/3d",
|
category="latent/3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Mesh.Input("mesh"),
|
IO.Mesh.Input("mesh"),
|
||||||
IO.Int.Input("simplify", default=100_000, min=0), # max?
|
IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000), # max?
|
||||||
IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0)
|
IO.Float.Input("fill_holes_perimeter", default=0.003, min=0.0, step=0.0001)
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Mesh.Output("output_mesh"),
|
IO.Mesh.Output("output_mesh"),
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user