diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 4c398294a..eb410fe8b 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -787,8 +787,10 @@ class SparseStructureFlowModel(nn.Module): return h def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0): + t_shifted /= 1000.0 t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1)) t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear) + t_new *= 1000.0 return t_new class Trellis2(nn.Module): @@ -841,10 +843,13 @@ class Trellis2(nn.Module): out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure timestep = timestep_reshift(timestep) + orig_bsz = x.shape[0] if shape_rule: x = x[0].unsqueeze(0) - timestep = timestep[0] + timestep = timestep[0].unsqueeze(0) out = self.structure_model(x, timestep, context if not shape_rule else cond) + if shape_rule: + out = out.repeat(orig_bsz, 1, 1, 1, 1) out.generation_mode = mode return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 4d97129eb..ad9881db7 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -295,7 +295,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): def simplify_fn(vertices, faces, target=100000): if vertices.shape[0] <= target: - return + return vertices, faces min_feat = vertices.min(dim=0)[0] max_feat = vertices.max(dim=0)[0] @@ -334,6 +334,19 @@ def simplify_fn(vertices, faces, target=100000): return final_vertices, final_faces def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): + is_batched = vertices.ndim == 3 + if is_batched: + batch_size = vertices.shape[0] + if batch_size > 1: + v_out, f_out = [], [] + for i in range(batch_size): + v, f = fill_holes_fn(vertices[i], faces[i], max_hole_perimeter) + v_out.append(v) + f_out.append(f) + return torch.stack(v_out), torch.stack(f_out) + + vertices = vertices.squeeze(0) + faces = faces.squeeze(0) device = vertices.device orig_vertices = vertices @@ -346,24 +359,23 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): ], dim=0) edges_sorted, _ = torch.sort(edges, dim=1) - unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True) - boundary_mask = counts == 1 boundary_edges_sorted = unique_edges[boundary_mask] if boundary_edges_sorted.shape[0] == 0: - return + if is_batched: + return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0) + return orig_vertices, orig_faces + max_idx = vertices.shape[0] - _, inverse_indices, counts_packed = torch.unique( - torch.sort(edges, dim=1).values[:, 0] * max_idx + torch.sort(edges, dim=1).values[:, 1], - return_inverse=True, return_counts=True - ) + packed_edges_all = torch.sort(edges, dim=1).values + packed_edges_all = packed_edges_all[:, 0] * max_idx + packed_edges_all[:, 1] - boundary_packed_mask = counts_packed == 1 - is_boundary_edge = boundary_packed_mask[inverse_indices] + packed_boundary = boundary_edges_sorted[:, 0] * max_idx + boundary_edges_sorted[:, 1] + is_boundary_edge = torch.isin(packed_edges_all, packed_boundary) active_boundary_edges = edges[is_boundary_edge] adj = {} @@ -373,78 +385,61 @@ def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): loops = [] visited_edges = set() - - possible_starts = list(adj.keys()) - processed_nodes = set() - - for start_node in possible_starts: - if start_node in processed_nodes: - continue - - current_loop = [] - curr = start_node - + for start_node in list(adj.keys()): + if start_node in processed_nodes: continue + current_loop, curr = [], start_node while curr in adj: next_node = adj[curr] - if (curr, next_node) in visited_edges: - break - + if (curr, next_node) in visited_edges: break visited_edges.add((curr, next_node)) processed_nodes.add(curr) current_loop.append(curr) - curr = next_node - if curr == start_node: loops.append(current_loop) break - - if len(current_loop) > len(edges_np): - break + if len(current_loop) > len(edges_np): break if not loops: - return + if is_batched: return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0) + return orig_vertices, orig_faces new_faces = [] - v_offset = vertices.shape[0] - valid_new_verts = [] for loop_indices in loops: - if len(loop_indices) < 3: - continue - + if len(loop_indices) < 3: continue loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device) loop_verts = vertices[loop_tensor] - diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0) perimeter = torch.norm(diffs, dim=1).sum() - if perimeter > max_hole_perimeter: - continue + if perimeter > max_hole_perimeter: continue center = loop_verts.mean(dim=0) valid_new_verts.append(center) - c_idx = v_offset v_offset += 1 num_v = len(loop_indices) for i in range(num_v): - v_curr = loop_indices[i] - v_next = loop_indices[(i + 1) % num_v] + v_curr, v_next = loop_indices[i], loop_indices[(i + 1) % num_v] new_faces.append([v_curr, v_next, c_idx]) if len(valid_new_verts) > 0: added_vertices = torch.stack(valid_new_verts, dim=0) added_faces = torch.tensor(new_faces, dtype=torch.long, device=device) + vertices = torch.cat([orig_vertices, added_vertices], dim=0) + faces = torch.cat([orig_faces, added_faces], dim=0) + else: + vertices, faces = orig_vertices, orig_faces - vertices_f = torch.cat([orig_vertices, added_vertices], dim=0) - faces_f = torch.cat([orig_faces, added_faces], dim=0) + if is_batched: + return vertices.unsqueeze(0), faces.unsqueeze(0) - return vertices_f, faces_f + return vertices, faces class PostProcessMesh(IO.ComfyNode): @classmethod @@ -454,8 +449,8 @@ class PostProcessMesh(IO.ComfyNode): category="latent/3d", inputs=[ IO.Mesh.Input("mesh"), - IO.Int.Input("simplify", default=100_000, min=0), # max? - IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0) + IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000), # max? + IO.Float.Input("fill_holes_perimeter", default=0.003, min=0.0, step=0.0001) ], outputs=[ IO.Mesh.Output("output_mesh"),