"""ComfyUI nodes for the Trellis2 image-to-3D pipeline.

The module provides conditioning, empty-latent, VAE-decode and mesh
post-processing nodes, and registers them through ``Trellis2Extension``.
"""
import copy
import logging

import numpy as np
import torch
from PIL import Image
from typing_extensions import override

from comfy_api.latest import ComfyExtension, IO, Types
from comfy.ldm.trellis2.vae import SparseTensor
import comfy.model_management

# Number of feature channels per voxel in the shape / texture latents.
_SLAT_CHANNELS = 32

# Per-channel statistics used to de-normalise sampled shape latents before decoding.
shape_slat_normalization = {
    "mean": torch.tensor([
        0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
        -0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
        0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
        -0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665
    ])[None],
    "std": torch.tensor([
        5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
        5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
        4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
        5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691
    ])[None]
}

# Per-channel statistics used to de-normalise sampled texture latents before decoding.
tex_slat_normalization = {
    "mean": torch.tensor([
        3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075,
        0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149,
        -1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717,
        1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862
    ])[None],
    "std": torch.tensor([
        2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822,
        2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588,
        2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 2.810596, 3.292720, 2.674999,
        2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190
    ])[None]
}


def _occupied_voxel_coords(structure_output):
    """Return the int32 index rows of every non-zero voxel in ``structure_output.data``.

    A singleton axis is inserted at position 1 before ``argwhere`` and its
    column is dropped again, so each row holds the indices along the four
    axes of ``structure_output.data``.
    """
    occupancy = structure_output.data.unsqueeze(1)
    return torch.argwhere(occupancy.bool())[:, [0, 2, 3, 4]].int()


def _denormalized_slat(latent, coords, normalization, device):
    """Build the de-normalised sparse latent that the VAE decoders consume.

    Args:
        latent: Latent dict; its ``"samples"`` tensor is flattened into rows
            of ``_SLAT_CHANNELS`` features.
        coords: Voxel coordinates paired with the feature rows.
        normalization: Dict with ``"mean"`` and ``"std"`` tensors.
        device: Device the features are moved to.

    Returns:
        ``SparseTensor(feats, coords) * std + mean``.
    """
    feats = latent["samples"].squeeze(-1).transpose(1, 2).reshape(-1, _SLAT_CHANNELS).to(device)
    std = normalization["std"].to(feats)
    mean = normalization["mean"].to(feats)
    return SparseTensor(feats=feats, coords=coords) * std + mean


def _stack_meshes(meshes):
    """Stack per-item meshes into one batched ``Types.MESH``.

    ``torch.stack`` requires every item to have the same number of vertices
    and faces.
    """
    faces = torch.stack([m.faces for m in meshes])
    vertices = torch.stack([m.vertices for m in meshes])
    return Types.MESH(vertices=vertices, faces=faces)


def _model_with_generation_options(model, coords, generation_mode):
    """Clone ``model`` and store ``coords`` / ``generation_mode`` in its transformer options.

    ``model_options`` and ``transformer_options`` are shallow-copied first so
    the dictionaries of the input model are not modified.
    """
    patched = model.clone()
    patched.model_options = patched.model_options.copy()
    transformer_options = patched.model_options.get("transformer_options", {}).copy()
    transformer_options["coords"] = coords
    transformer_options["generation_mode"] = generation_mode
    patched.model_options["transformer_options"] = transformer_options
    return patched


class VaeDecodeShapeTrellis(IO.ComfyNode):
    """Decode a sampled shape latent into a mesh plus the ``shape_subs`` side output."""

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="VaeDecodeShapeTrellis",
            category="latent/3d",
            inputs=[
                IO.Latent.Input("samples"),
                IO.Voxel.Input("structure_output"),
                IO.Vae.Input("vae"),
                IO.Combo.Input("resolution", options=["512", "1024"], default="512")
            ],
            outputs=[
                IO.Mesh.Output("mesh"),
                IO.AnyType.Output("shape_subs"),
            ]
        )

    @classmethod
    def execute(cls, samples, structure_output, vae, resolution):
        resolution = int(resolution)
        device = comfy.model_management.get_torch_device()
        comfy.model_management.load_model_gpu(vae.patcher)

        coords = _occupied_voxel_coords(structure_output)
        slat = _denormalized_slat(samples, coords, shape_slat_normalization, device)
        meshes, subs = vae.first_stage_model.decode_shape_slat(slat, resolution)
        return IO.NodeOutput(_stack_meshes(meshes), subs)


class VaeDecodeTextureTrellis(IO.ComfyNode):
    """Decode a sampled texture latent into a mesh, using ``shape_subs`` from the shape decoder."""

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="VaeDecodeTextureTrellis",
            category="latent/3d",
            inputs=[
                IO.Latent.Input("samples"),
                IO.Voxel.Input("structure_output"),
                IO.Vae.Input("vae"),
                IO.AnyType.Input("shape_subs"),
            ],
            outputs=[
                IO.Mesh.Output("mesh"),
            ]
        )

    @classmethod
    def execute(cls, samples, structure_output, vae, shape_subs):
        device = comfy.model_management.get_torch_device()
        comfy.model_management.load_model_gpu(vae.patcher)

        coords = _occupied_voxel_coords(structure_output)
        slat = _denormalized_slat(samples, coords, tex_slat_normalization, device)
        meshes = vae.first_stage_model.decode_tex_slat(slat, shape_subs)
        return IO.NodeOutput(_stack_meshes(meshes))


class VaeDecodeStructureTrellis2(IO.ComfyNode):
    """Decode a structure latent into a binary occupancy voxel grid."""

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="VaeDecodeStructureTrellis2",
            category="latent/3d",
            inputs=[
                IO.Latent.Input("samples"),
                IO.Vae.Input("vae"),
                IO.Combo.Input("resolution", options=["32", "64"], default="32")
            ],
            outputs=[
                IO.Voxel.Output("structure_output"),
            ]
        )

    @classmethod
    def execute(cls, samples, vae, resolution):
        resolution = int(resolution)
        decoder = vae.first_stage_model.struct_dec
        load_device = comfy.model_management.get_torch_device()
        offload_device = comfy.model_management.vae_offload_device()

        decoder = decoder.to(load_device)
        try:
            decoded = decoder(samples["samples"].to(load_device)) > 0
        finally:
            # Offload even when decoding raises so the decoder does not stay on the load device.
            decoder.to(offload_device)

        current_res = decoded.shape[2]
        if current_res > resolution:
            # Max-pooling keeps a coarse voxel occupied when any of its fine voxels is occupied.
            ratio = current_res // resolution
            decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
        elif current_res < resolution:
            # The pooling ratio would be 0 here, which max_pool3d rejects.
            logging.warning(
                "Structure decoder produced resolution %d, below the requested %d; keeping the native resolution.",
                current_res, resolution,
            )

        out = Types.VOXEL(decoded.squeeze(1).float())
        return IO.NodeOutput(out)


# Per-channel mean / std applied to the image before it is fed to the vision model.
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def run_conditioning(model, cropped_img_tensor, include_1024=True):
    """Encode an image with the vision model at 512 and, optionally, 1024 pixels.

    Args:
        model: Wrapper whose ``.model`` attribute is the callable vision model.
        cropped_img_tensor: Image tensor accepted by
            ``torch.nn.functional.interpolate`` with a 2-D ``size`` and
            broadcastable against the ``(3, 1, 1)`` mean / std.
        include_1024: Also compute the 1024-pixel embedding.

    Returns:
        Dict with ``'cond_512'``, an all-zero ``'neg_cond'`` of the same shape
        and, when requested, ``'cond_1024'``; all on the intermediate device.
    """
    model_internal = model.model
    device = comfy.model_management.intermediate_device()
    torch_device = comfy.model_management.get_torch_device()

    # NOTE(review): the model is not explicitly loaded onto ``torch_device``
    # here; presumably the caller guarantees it - confirm.
    img_t = cropped_img_tensor.to(torch_device)
    mean = dino_mean.to(torch_device)
    std = dino_std.to(torch_device)

    def prepare_tensor(img, size):
        resized = torch.nn.functional.interpolate(
            img, size=(size, size), mode='bicubic', align_corners=False
        ).clamp(0.0, 1.0)
        return (resized - mean) / std

    # NOTE(review): ``image_size`` is left at the last value set on the shared
    # model object after this function returns.
    model_internal.image_size = 512
    cond_512 = model_internal(prepare_tensor(img_t, 512))[0]

    cond_1024 = None
    if include_1024:
        model_internal.image_size = 1024
        cond_1024 = model_internal(prepare_tensor(img_t, 1024))[0]

    conditioning = {
        'cond_512': cond_512.to(device),
        'neg_cond': torch.zeros_like(cond_512).to(device),
    }
    if cond_1024 is not None:
        conditioning['cond_1024'] = cond_1024.to(device)

    return conditioning


def _square_crop_box(pixel_coords):
    """Return a square ``(x1, y1, x2, y2)`` box centred on the given ``(y, x)`` pixels.

    The side is the larger extent of the bounding box of ``pixel_coords``.
    Returns ``None`` when the resulting box has zero width or height.
    """
    y_min, x_min = pixel_coords.min(axis=0)
    y_max, x_max = pixel_coords.max(axis=0)
    center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
    half = max(y_max - y_min, x_max - x_min) // 2
    box = (int(center_x - half), int(center_y - half), int(center_x + half), int(center_y + half))
    if box[2] <= box[0] or box[3] <= box[1]:
        return None
    return box


class Trellis2Conditioning(IO.ComfyNode):
    """Crop an image to its mask, composite it on a flat background and encode it."""

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="Trellis2Conditioning",
            category="conditioning/video_models",
            inputs=[
                IO.ClipVision.Input("clip_vision_model"),
                IO.Image.Input("image"),
                IO.Mask.Input("mask"),
                IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black")
            ],
            outputs=[
                IO.Conditioning.Output(display_name="positive"),
                IO.Conditioning.Output(display_name="negative"),
            ]
        )

    @classmethod
    def execute(cls, clip_vision_model, image, mask, background_color) -> IO.NodeOutput:
        # Only the first item of a batched image / mask is used.
        if image.ndim == 4:
            image = image[0]
        if mask.ndim == 3:
            mask = mask[0]
        # Drop any extra channel so the RGB assignment below cannot fail on 4-channel input.
        image = image[..., :3]

        img_np = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
        mask_np = (mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
        pil_img = Image.fromarray(img_np)
        pil_mask = Image.fromarray(mask_np)

        # Downscale so the longest side is at most 1024 pixels.
        scale = min(1.0, 1024 / max(pil_img.size))
        if scale < 1.0:
            new_size = (max(1, int(pil_img.width * scale)), max(1, int(pil_img.height * scale)))
            pil_img = pil_img.resize(new_size, Image.Resampling.LANCZOS)
        # Always bring the mask to the image size; a mismatch would break the alpha assignment.
        if pil_mask.size != pil_img.size:
            pil_mask = pil_mask.resize(pil_img.size, Image.Resampling.NEAREST)

        rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8)
        rgba_np[:, :, :3] = np.array(pil_img)
        rgba_np[:, :, 3] = np.array(pil_mask)

        opaque = np.argwhere(rgba_np[:, :, 3] > 0.8 * 255)
        crop_box = _square_crop_box(opaque) if len(opaque) > 0 else None

        if crop_box is not None:
            # The image mode is inferred from the uint8 (H, W, 4) array.
            cropped_rgba = Image.fromarray(rgba_np).crop(crop_box)
            cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
        else:
            if len(opaque) == 0:
                logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
            else:
                logging.warning("Mask foreground is too small to crop. Using the uncropped image.")
            cropped_np = rgba_np.astype(np.float32) / 255.0

        bg_colors = {"black": [0.0, 0.0, 0.0], "gray": [0.5, 0.5, 0.5], "white": [1.0, 1.0, 1.0]}
        bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32)

        # Alpha-composite the foreground over the flat background colour.
        fg = cropped_np[:, :, :3]
        alpha_float = cropped_np[:, :, 3:4]
        composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float)

        cropped_img_tensor = torch.from_numpy(composite_np).movedim(-1, 0).unsqueeze(0).float()

        conditioning = run_conditioning(clip_vision_model, cropped_img_tensor, include_1024=True)

        embeds = conditioning["cond_1024"]
        positive = [[conditioning["cond_512"], {"embeds": embeds}]]
        negative = [[conditioning["neg_cond"], {"embeds": torch.zeros_like(embeds)}]]
        return IO.NodeOutput(positive, negative)


class EmptyShapeLatentTrellis2(IO.ComfyNode):
    """Create a random shape latent with one entry per occupied voxel and a model patched for shape generation."""

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="EmptyShapeLatentTrellis2",
            category="latent/3d",
            inputs=[
                IO.Voxel.Input("structure_output"),
                IO.Model.Input("model")
            ],
            outputs=[
                IO.Latent.Output(),
                IO.Model.Output()
            ]
        )

    @classmethod
    def execute(cls, structure_output, model):
        coords = _occupied_voxel_coords(structure_output)
        # image like format
        latent = torch.randn(1, _SLAT_CHANNELS, coords.shape[0], 1)
        model = _model_with_generation_options(model, coords, "shape_generation")
        return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model)


class EmptyTextureLatentTrellis2(IO.ComfyNode):
    """Create a random texture latent for the occupied voxels (work in progress, see TODO)."""

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="EmptyTextureLatentTrellis2",
            category="latent/3d",
            inputs=[
                IO.Voxel.Input("structure_output"),
                # execute() requires ``model``; without this input the node cannot be called.
                IO.Model.Input("model"),
            ],
            outputs=[
                IO.Latent.Output(),
                IO.Model.Output()
            ]
        )

    @classmethod
    def execute(cls, structure_output, model):
        # TODO
        coords = _occupied_voxel_coords(structure_output)
        # NOTE(review): this reads ``structure_output.feats`` while every other
        # node here only uses ``.data`` - confirm the expected input type.
        latent = torch.randn(coords.shape[0], _SLAT_CHANNELS - structure_output.feats.shape[1])
        # NOTE(review): the mode is "shape_generation" although this is the
        # texture node - possibly copied from the shape node; confirm.
        model = _model_with_generation_options(model, coords, "shape_generation")
        return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model)


class EmptyStructureLatentTrellis2(IO.ComfyNode):
    """Create a random 8-channel 16x16x16 structure latent."""

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="EmptyStructureLatentTrellis2",
            category="latent/3d",
            inputs=[
                IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
            ],
            outputs=[
                IO.Latent.Output(),
            ]
        )

    @classmethod
    def execute(cls, batch_size):
        in_channels = 8
        resolution = 16
        latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
        return IO.NodeOutput({"samples": latent, "type": "trellis2"})


def _map_over_batch(fn, vertices, faces, *args):
    """Apply ``fn(vertices[i], faces[i], *args)`` to each batch item and stack the results.

    ``torch.stack`` requires every item to end up with the same vertex and
    face counts.
    """
    v_list, f_list = [], []
    for v_i, f_i in zip(vertices, faces):
        v_out, f_out = fn(v_i, f_i, *args)
        v_list.append(v_out)
        f_list.append(f_out)
    return torch.stack(v_list), torch.stack(f_list)


def simplify_fn(vertices, faces, target=100000):
    """Simplify a mesh by merging all vertices that fall into the same grid cell.

    Args:
        vertices: ``(N, 3)`` vertex positions, or a ``(B, N, 3)`` batch.
        faces: ``(F, 3)`` vertex indices, or a ``(B, F, 3)`` batch.
        target: Meshes with at most this many vertices are returned unchanged;
            it also sets the grid resolution, ``int(sqrt(target) * 1.5)``.

    Returns:
        ``(vertices, faces)`` where each merged vertex is the mean of its
        cell, collapsed triangles are dropped and unreferenced vertices are
        removed.
    """
    if vertices.ndim == 3:
        return _map_over_batch(simplify_fn, vertices, faces, target)

    if vertices.shape[0] <= target:
        return vertices, faces

    faces = faces.long()
    min_feat = vertices.min(dim=0)[0]
    max_feat = vertices.max(dim=0)[0]
    extent = (max_feat - min_feat).max()

    grid_resolution = max(int(target ** 0.5 * 1.5), 1)
    # Clamp so coincident vertices (zero extent) do not cause a division by zero.
    voxel_size = (extent / grid_resolution).clamp(min=torch.finfo(vertices.dtype).tiny)

    quantized_coords = ((vertices - min_feat) / voxel_size).long()
    unique_coords, inverse_indices = torch.unique(quantized_coords, dim=0, return_inverse=True)

    # Average the positions of all vertices sharing a cell.
    num_new_verts = unique_coords.shape[0]
    new_vertices = torch.zeros((num_new_verts, 3), dtype=vertices.dtype, device=vertices.device)
    counts = torch.zeros((num_new_verts, 1), dtype=vertices.dtype, device=vertices.device)
    new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices)
    counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1]))
    new_vertices = new_vertices / counts.clamp(min=1)

    # Re-index faces and drop triangles with two or more corners merged together.
    new_faces = inverse_indices[faces]
    v0, v1, v2 = new_faces[:, 0], new_faces[:, 1], new_faces[:, 2]
    new_faces = new_faces[(v0 != v1) & (v1 != v2) & (v2 != v0)]

    # Keep only vertices still referenced by a face and renumber the faces accordingly.
    unique_face_indices, inv_face = torch.unique(new_faces.reshape(-1), return_inverse=True)
    final_vertices = new_vertices[unique_face_indices]
    final_faces = inv_face.reshape(-1, 3)

    return final_vertices, final_faces


def fill_holes_fn(vertices, faces, max_perimeter=0.03):
    """Close small holes by fanning triangles from each boundary loop to its centroid.

    Args:
        vertices: ``(N, 3)`` vertex positions, or a ``(B, N, 3)`` batch.
        faces: ``(F, 3)`` vertex indices, or a ``(B, F, 3)`` batch.
        max_perimeter: Only loops whose perimeter is at most this value are filled.

    Returns:
        ``(vertices, faces)`` with one centroid vertex and one triangle fan
        appended per filled loop; the inputs are returned unchanged when
        there is nothing to fill.
    """
    if vertices.ndim == 3:
        return _map_over_batch(fill_holes_fn, vertices, faces, max_perimeter)

    if faces.shape[0] == 0:
        return vertices, faces

    device = vertices.device

    # An undirected edge that belongs to exactly one face lies on a boundary.
    edges = torch.cat([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], dim=0)
    edges_sorted, _ = torch.sort(edges, dim=1)
    num_vertices = vertices.shape[0]
    packed = edges_sorted[:, 0].long() * num_vertices + edges_sorted[:, 1].long()
    _, inverse, counts = torch.unique(packed, return_inverse=True, return_counts=True)
    is_boundary = counts[inverse] == 1
    if not bool(is_boundary.any()):
        return vertices, faces

    # Directed boundary edges, in the winding of the face that owns them.
    edges_list = edges[is_boundary].tolist()

    adj = {}
    in_deg = {}
    out_deg = {}
    for u, w in edges_list:
        adj.setdefault(u, []).append(w)
        out_deg[u] = out_deg.get(u, 0) + 1
        in_deg[w] = in_deg.get(w, 0) + 1

    # Only nodes with exactly one incoming and one outgoing boundary edge can be walked unambiguously.
    manifold_nodes = {
        node for node, degree in out_deg.items()
        if degree == 1 and in_deg.get(node, 0) == 1
    }

    loops = []
    visited_nodes = set()
    for start_node in list(adj):
        if start_node not in manifold_nodes or start_node in visited_nodes:
            continue

        curr = start_node
        current_loop = []
        while True:
            current_loop.append(curr)
            visited_nodes.add(curr)
            next_node = adj[curr][0]

            if next_node == start_node:
                if len(current_loop) >= 3:
                    loops.append(current_loop)
                break
            if next_node not in manifold_nodes or next_node in visited_nodes:
                break
            curr = next_node
            # Safety stop: a loop cannot be longer than the number of boundary edges.
            if len(current_loop) > len(edges_list):
                break

    new_faces = []
    new_verts = []
    next_vertex_index = num_vertices

    for loop in loops:
        loop_points = vertices[torch.tensor(loop, device=device, dtype=torch.long)]

        # Perimeter: sum of distances between consecutive loop points, wrapping around.
        following = torch.roll(loop_points, shifts=-1, dims=0)
        perimeter = torch.linalg.vector_norm(loop_points - following, dim=1).sum().item()
        if perimeter > max_perimeter:
            continue

        new_verts.append(loop_points.mean(dim=0))
        center_idx = next_vertex_index
        next_vertex_index += 1

        for i, u_idx in enumerate(loop):
            v_next_idx = loop[(i + 1) % len(loop)]
            # The existing face runs u -> v_next along this edge; the cap must
            # run the opposite way for the two faces to share a winding.
            new_faces.append([v_next_idx, u_idx, center_idx])

    if new_faces:
        vertices = torch.cat([vertices, torch.stack(new_verts)], dim=0)
        faces = torch.cat([faces, torch.tensor(new_faces, device=device, dtype=torch.long)], dim=0)

    return vertices, faces


def merge_duplicate_vertices(vertices, faces, tolerance=1e-5):
    """Merge vertices that round to the same ``tolerance``-sized grid cell.

    Args:
        vertices: ``(N, 3)`` vertex positions, or a ``(B, N, 3)`` batch.
        faces: ``(F, 3)`` vertex indices, or a ``(B, F, 3)`` batch.
        tolerance: Grid cell size used to decide which vertices coincide.

    Returns:
        ``(vertices, faces)`` with one vertex per occupied cell (the first
        input vertex found in it), faces re-indexed, and triangles that
        collapsed to fewer than three distinct vertices removed.
    """
    if vertices.ndim == 3:
        return _map_over_batch(merge_duplicate_vertices, vertices, faces, tolerance)

    num_vertices = vertices.shape[0]
    if num_vertices == 0:
        # ``min`` over an empty dimension raises; there is nothing to merge.
        return vertices, faces

    v_min = vertices.min(dim=0, keepdim=True)[0]
    v_quant = ((vertices - v_min) / tolerance).round().long()
    unique_quant, inverse_indices = torch.unique(v_quant, dim=0, return_inverse=True)

    # Pick the lowest-index vertex of each cell as its representative.
    # A plain index_copy_ with duplicate indices is nondeterministic.
    first_index = torch.full(
        (unique_quant.shape[0],), num_vertices, dtype=torch.long, device=vertices.device
    )
    first_index.scatter_reduce_(
        0, inverse_indices, torch.arange(num_vertices, device=vertices.device), reduce="amin"
    )
    new_vertices = vertices[first_index]

    new_faces = inverse_indices[faces.long()]
    valid = (
        (new_faces[:, 0] != new_faces[:, 1])
        & (new_faces[:, 1] != new_faces[:, 2])
        & (new_faces[:, 2] != new_faces[:, 0])
    )
    return new_vertices, new_faces[valid]


def fix_normals(vertices, faces):
    """Flip faces whose normal points towards the mean of all vertices.

    Args:
        vertices: ``(N, 3)`` vertex positions, or a ``(B, N, 3)`` batch.
        faces: ``(F, 3)`` vertex indices, or a ``(B, F, 3)`` batch.

    Returns:
        ``(vertices, faces)``; flipped faces have their last two indices
        swapped. The input ``faces`` tensor is not modified.
    """
    if vertices.ndim == 3:
        return _map_over_batch(fix_normals, vertices, faces)

    if faces.shape[0] == 0:
        return vertices, faces

    center = vertices.mean(0)
    v0 = vertices[faces[:, 0].long()]
    v1 = vertices[faces[:, 1].long()]
    v2 = vertices[faces[:, 2].long()]

    normals = torch.cross(v1 - v0, v2 - v0, dim=1)
    face_centers = (v0 + v1 + v2) / 3.0
    # A negative dot product means the normal points back towards the centre.
    flip_mask = (normals * (face_centers - center)).sum(1) < 0

    # Build a new tensor instead of assigning in place so the caller's faces stay untouched.
    faces = torch.where(flip_mask.unsqueeze(1), faces[:, [0, 2, 1]], faces)
    return vertices, faces


class PostProcessMesh(IO.ComfyNode):
    """Merge duplicate vertices, fill small holes, simplify and re-orient a mesh."""

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="PostProcessMesh",
            category="latent/3d",
            inputs=[
                IO.Mesh.Input("mesh"),
                IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000),
                IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0, step=0.0001)
            ],
            outputs=[
                IO.Mesh.Output("output_mesh"),
            ]
        )

    @classmethod
    def execute(cls, mesh, simplify, fill_holes_perimeter):
        # Work on a copy so the input mesh is left untouched.
        mesh = copy.deepcopy(mesh)
        verts, faces = mesh.vertices, mesh.faces

        verts, faces = merge_duplicate_vertices(verts, faces, tolerance=1e-5)

        if fill_holes_perimeter > 0:
            verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter)

        # shape[-2] is the face count for both (F, 3) and batched (B, F, 3)
        # faces; shape[0] would be the batch size for the latter.
        if simplify > 0 and faces.shape[-2] > simplify:
            verts, faces = simplify_fn(verts, faces, target=simplify)

        verts, faces = fix_normals(verts, faces)

        mesh.vertices = verts
        mesh.faces = faces

        return IO.NodeOutput(mesh)


class Trellis2Extension(ComfyExtension):
    """Extension object exposing the Trellis2 nodes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [
            Trellis2Conditioning,
            EmptyShapeLatentTrellis2,
            EmptyStructureLatentTrellis2,
            EmptyTextureLatentTrellis2,
            VaeDecodeTextureTrellis,
            VaeDecodeShapeTrellis,
            VaeDecodeStructureTrellis2,
            PostProcessMesh
        ]


async def comfy_entrypoint() -> Trellis2Extension:
    """Return the extension instance exposing this module's nodes."""
    return Trellis2Extension()