diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 4bbfbff5f..651d516b6 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -810,6 +810,10 @@ class Trellis2(nn.Module): elif mode == "texture_generation": if self.shape2txt is None: raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!") + slat = transformer_options.get("shape_slat") + if slat is None: + raise ValueError("shape_slat can't be None") + x = sparse_cat([x, slat]) out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure #timestep = timestep_reshift(timestep) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 23b2f72bb..0b94a2d0a 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -38,6 +38,13 @@ tex_slat_normalization = { ])[None] } +def shape_norm(shape_latent, coords): + std = shape_slat_normalization["std"].to(shape_latent) + mean = shape_slat_normalization["mean"].to(shape_latent) + samples = SparseTensor(feats = shape_latent, coords=coords) + samples = samples * std + mean + return samples + class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def define_schema(cls): @@ -70,10 +77,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): samples = samples["samples"] samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) - std = shape_slat_normalization["std"].to(samples) - mean = shape_slat_normalization["mean"].to(samples) - samples = SparseTensor(feats = samples, coords=coords) - samples = samples * std + mean + samples = shape_norm(samples, coords) mesh, subs = vae.decode_shape_slat(samples, resolution) faces = torch.stack([m.faces for m in mesh]) @@ -313,6 +317,8 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): category="latent/3d", inputs=[ IO.Voxel.Input("structure_output"), + IO.Latent.Input("shape_latent"), + IO.Model.Input("model") ], outputs=[ IO.Latent.Output(), @@ -321,11 +327,15 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_output, model): + def execute(cls, structure_output, shape_latent, model): # TODO decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 + + shape_latent = shape_latent["samples"] + shape_latent = shape_norm(shape_latent, coords) + latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) model = model.clone() model.model_options = model.model_options.copy() @@ -336,6 +346,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"]["coords"] = coords model.model_options["transformer_options"]["generation_mode"] = "shape_generation" + model.model_options["transformer_options"]["shape_slat"] = shape_latent return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model) @@ -360,25 +371,34 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): return IO.NodeOutput({"samples": latent, "type": "trellis2"}) def simplify_fn(vertices, faces, target=100000): + is_batched = vertices.ndim == 3 + if is_batched: + v_list, f_list = [], [] + for i in range(vertices.shape[0]): + v_i, f_i = simplify_fn(vertices[i], faces[i], target) + v_list.append(v_i) + f_list.append(f_i) + return torch.stack(v_list), torch.stack(f_list) - if vertices.shape[0] <= target: + if faces.shape[0] <= target: return vertices, faces - min_feat = vertices.min(dim=0)[0] - max_feat = vertices.max(dim=0)[0] - extent = (max_feat - min_feat).max() + device = vertices.device + target_v = target / 2.0 - grid_resolution = int(torch.sqrt(torch.tensor(target)).item() * 1.5) - voxel_size = extent / grid_resolution + min_v = vertices.min(dim=0)[0] + max_v = vertices.max(dim=0)[0] + extent = max_v - min_v - quantized_coords = ((vertices - min_feat) / voxel_size).long() + volume = (extent[0] * extent[1] * extent[2]).clamp(min=1e-8) + cell_size = (volume / target_v) ** (1/3.0) - unique_coords, inverse_indices = torch.unique(quantized_coords, dim=0, return_inverse=True) + quantized = ((vertices - min_v) / cell_size).round().long() + unique_coords, inverse_indices = torch.unique(quantized, dim=0, return_inverse=True) + num_cells = unique_coords.shape[0] - num_new_verts = unique_coords.shape[0] - new_vertices = torch.zeros((num_new_verts, 3), dtype=vertices.dtype, device=vertices.device) - - counts = torch.zeros((num_new_verts, 1), dtype=vertices.dtype, device=vertices.device) + new_vertices = torch.zeros((num_cells, 3), dtype=vertices.dtype, device=device) + counts = torch.zeros((num_cells, 1), dtype=vertices.dtype, device=device) new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices) counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1])) @@ -387,11 +407,9 @@ def simplify_fn(vertices, faces, target=100000): new_faces = inverse_indices[faces] - v0 = new_faces[:, 0] - v1 = new_faces[:, 1] - v2 = new_faces[:, 2] - - valid_mask = (v0 != v1) & (v1 != v2) & (v2 != v0) + valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \ + (new_faces[:, 1] != new_faces[:, 2]) & \ + (new_faces[:, 2] != new_faces[:, 0]) new_faces = new_faces[valid_mask] unique_face_indices, inv_face = torch.unique(new_faces.reshape(-1), return_inverse=True) @@ -414,7 +432,7 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): v = vertices f = faces - if f.shape[0] == 0: + if f.numel() == 0: return v, f edges = torch.cat([f[:, [0, 1]], f[:, [1, 2]], f[:, [2, 0]]], dim=0) @@ -424,145 +442,75 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03): packed_undirected = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long() unique_packed, counts = torch.unique(packed_undirected, return_counts=True) - boundary_mask = counts == 1 - boundary_packed = unique_packed[boundary_mask] + boundary_packed = unique_packed[counts == 1] if boundary_packed.numel() == 0: return v, f - packed_directed_sorted = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long() + packed_directed_sorted = edges[:, 0].min(edges[:, 1]).long() * max_v + edges[:, 0].max(edges[:, 1]).long() is_boundary = torch.isin(packed_directed_sorted, boundary_packed) - boundary_edges_directed = edges[is_boundary] + b_edges = edges[is_boundary] - adj = {} - in_deg = {} - out_deg = {} - - edges_list = boundary_edges_directed.tolist() - for u, v_idx in edges_list: - if u not in adj: adj[u] = [] - adj[u].append(v_idx) - out_deg[u] = out_deg.get(u, 0) + 1 - in_deg[v_idx] = in_deg.get(v_idx, 0) + 1 - - manifold_nodes = set() - for node in set(list(in_deg.keys()) + list(out_deg.keys())): - if in_deg.get(node, 0) == 1 and out_deg.get(node, 0) == 1: - manifold_nodes.add(node) + adj = {u.item(): v_idx.item() for u, v_idx in b_edges} loops =[] - visited_nodes = set() + visited = set() - for start_node in list(adj.keys()): - if start_node not in manifold_nodes or start_node in visited_nodes: + for start_node in adj.keys(): + if start_node in visited: continue curr = start_node - current_loop =[] + loop = [] - while True: - current_loop.append(curr) - visited_nodes.add(curr) + while curr not in visited: + visited.add(curr) + loop.append(curr) + curr = adj.get(curr, -1) - next_node = adj[curr][0] - - if next_node == start_node: - if len(current_loop) >= 3: - loops.append(current_loop) + if curr == -1: + loop = [] + break + if curr == start_node: + loops.append(loop) break - if next_node not in manifold_nodes or next_node in visited_nodes: - break - - curr = next_node - - if len(current_loop) > len(edges_list): - break - - new_faces =[] - new_verts = [] - curr_v_idx = v.shape[0] + new_verts =[] + new_faces = [] + v_idx = v.shape[0] for loop in loops: - loop_indices = torch.tensor(loop, device=device, dtype=torch.long) - loop_points = v[loop_indices] + loop_t = torch.tensor(loop, device=device, dtype=torch.long) + loop_v = v[loop_t] - # Calculate perimeter - p1 = loop_points - p2 = torch.roll(loop_points, shifts=-1, dims=0) - perimeter = torch.norm(p1 - p2, dim=1).sum().item() + diffs = loop_v - torch.roll(loop_v, shifts=-1, dims=0) + perimeter = torch.norm(diffs, dim=1).sum().item() if perimeter <= max_perimeter: - centroid = loop_points.mean(dim=0) - new_verts.append(centroid) - center_idx = curr_v_idx - curr_v_idx += 1 + new_verts.append(loop_v.mean(dim=0)) for i in range(len(loop)): - u_idx = loop[i] - v_next_idx = loop[(i + 1) % len(loop)] - new_faces.append([u_idx, v_next_idx, center_idx]) + new_faces.append([loop[i], loop[(i + 1) % len(loop)], v_idx]) + v_idx += 1 - if new_faces: + if new_verts: v = torch.cat([v, torch.stack(new_verts)], dim=0) f = torch.cat([f, torch.tensor(new_faces, device=device, dtype=torch.long)], dim=0) return v, f -def merge_duplicate_vertices(vertices, faces, tolerance=1e-5): +def make_double_sided(vertices, faces): is_batched = vertices.ndim == 3 if is_batched: - v_list, f_list = [],[] - for i in range(vertices.shape[0]): - v_i, f_i = merge_duplicate_vertices(vertices[i], faces[i], tolerance) - v_list.append(v_i) - f_list.append(f_i) - return torch.stack(v_list), torch.stack(f_list) + f_list =[] + for i in range(faces.shape[0]): + f_inv = faces[i][:,[0, 2, 1]] + f_list.append(torch.cat([faces[i], f_inv], dim=0)) + return vertices, torch.stack(f_list) - v_min = vertices.min(dim=0, keepdim=True)[0] - v_quant = ((vertices - v_min) / tolerance).round().long() - - unique_quant, inverse_indices = torch.unique(v_quant, dim=0, return_inverse=True) - - new_vertices = torch.zeros((unique_quant.shape[0], 3), dtype=vertices.dtype, device=vertices.device) - new_vertices.index_copy_(0, inverse_indices, vertices) - - new_faces = inverse_indices[faces.long()] - - valid = (new_faces[:, 0] != new_faces[:, 1]) & \ - (new_faces[:, 1] != new_faces[:, 2]) & \ - (new_faces[:, 2] != new_faces[:, 0]) - - return new_vertices, new_faces[valid] - -def fix_normals(vertices, faces): - is_batched = vertices.ndim == 3 - if is_batched: - v_list, f_list = [], [] - for i in range(vertices.shape[0]): - v_i, f_i = fix_normals(vertices[i], faces[i]) - v_list.append(v_i) - f_list.append(f_i) - return torch.stack(v_list), torch.stack(f_list) - - if faces.shape[0] == 0: - return vertices, faces - - center = vertices.mean(0) - v0 = vertices[faces[:, 0].long()] - v1 = vertices[faces[:, 1].long()] - v2 = vertices[faces[:, 2].long()] - - normals = torch.cross(v1 - v0, v2 - v0, dim=1) - - face_centers = (v0 + v1 + v2) / 3.0 - dir_from_center = face_centers - center - - dot = (normals * dir_from_center).sum(1) - flip_mask = dot < 0 - - faces[flip_mask] = faces[flip_mask][:, [0, 2, 1]] - return vertices, faces + faces_inv = faces[:, [0, 2, 1]] + faces_double = torch.cat([faces, faces_inv], dim=0) + return vertices, faces_double class PostProcessMesh(IO.ComfyNode): @classmethod @@ -572,7 +520,7 @@ class PostProcessMesh(IO.ComfyNode): category="latent/3d", inputs=[ IO.Mesh.Input("mesh"), - IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000), + IO.Int.Input("simplify", default=1_000_000, min=0, max=50_000_000), IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0, step=0.0001) ], outputs=[ @@ -585,15 +533,13 @@ class PostProcessMesh(IO.ComfyNode): mesh = copy.deepcopy(mesh) verts, faces = mesh.vertices, mesh.faces - verts, faces = merge_duplicate_vertices(verts, faces, tolerance=1e-5) - if fill_holes_perimeter > 0: verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter) if simplify > 0 and faces.shape[0] > simplify: verts, faces = simplify_fn(verts, faces, target=simplify) - verts, faces = fix_normals(verts, faces) + verts, faces = make_double_sided(verts, faces) mesh.vertices = verts mesh.faces = faces