import torch
from comfy_api.latest import Types


def _pad_and_stack(tensors, label):
    """Zero-pad 2-D tensors along dim 0 and stack them into one batch tensor.

    Args:
        tensors: Non-empty sequence of 2-D tensors. The padded result takes
            its dtype, device and trailing dimension from ``tensors[0]``.
        label: Name of the argument being packed, used in error messages.

    Returns:
        A ``(packed, counts)`` tuple. ``packed`` has shape
        ``(len(tensors), max_rows, tensors[0].shape[1])`` with unused rows
        left as zeros; ``counts`` is an int64 tensor holding the original
        row count of every item, placed on the device of ``tensors[0]``.

    Raises:
        ValueError: If ``tensors`` is empty.
    """
    if len(tensors) == 0:
        # max() over an empty sequence would raise an opaque ValueError;
        # fail with a message that names the offending argument instead.
        raise ValueError(
            f"pack_variable_mesh_batch: '{label}' must contain at least one tensor"
        )

    first = tensors[0]
    lengths = [t.shape[0] for t in tensors]
    packed = first.new_zeros((len(tensors), max(lengths), first.shape[1]))
    for row, (tensor, length) in enumerate(zip(tensors, lengths)):
        packed[row, :length] = tensor
    counts = torch.tensor(lengths, device=first.device, dtype=torch.int64)
    return packed, counts


def pack_variable_mesh_batch(vertices, faces, colors=None):
    """Pack meshes of differing sizes into one zero-padded ``Types.MESH``.

    Each list entry describes one mesh. Entries are padded with zeros up to
    the largest vertex/face/color count in the batch, and the true counts are
    stored on the returned mesh so that ``get_mesh_batch_item`` can strip the
    padding again.

    Args:
        vertices: Non-empty sequence of 2-D vertex tensors, one per mesh.
        faces: Sequence of 2-D face tensors, same length as ``vertices``.
        colors: Optional sequence of 2-D color tensors, same length as
            ``vertices``. When ``None`` no color attributes are set.

    Returns:
        A ``Types.MESH`` built from the packed vertices and faces, with the
        extra attributes ``vertex_counts`` and ``face_counts`` (int64
        tensors of length ``len(vertices)``) and, when ``colors`` is given,
        ``colors`` and ``color_counts``.

    Raises:
        ValueError: If the batch is empty, or if ``faces`` / ``colors`` do
            not have the same number of entries as ``vertices``.
    """
    batch_size = len(vertices)
    # zip() would silently truncate to the shorter list and leave the packed
    # tensors and the count tensors disagreeing about the batch size.
    if len(faces) != batch_size:
        raise ValueError(
            f"pack_variable_mesh_batch: got {batch_size} vertex tensors "
            f"but {len(faces)} face tensors"
        )
    if colors is not None and len(colors) != batch_size:
        raise ValueError(
            f"pack_variable_mesh_batch: got {batch_size} vertex tensors "
            f"but {len(colors)} color tensors"
        )

    packed_vertices, vertex_counts = _pad_and_stack(vertices, "vertices")
    packed_faces, face_counts = _pad_and_stack(faces, "faces")

    mesh = Types.MESH(packed_vertices, packed_faces)
    mesh.vertex_counts = vertex_counts
    mesh.face_counts = face_counts

    if colors is not None:
        mesh.colors, mesh.color_counts = _pad_and_stack(colors, "colors")

    return mesh


def get_mesh_batch_item(mesh, index):
    """Return the ``(vertices, faces, colors)`` of one mesh in a batch.

    If the mesh carries ``vertex_counts`` (as set by
    ``pack_variable_mesh_batch``), the zero padding is sliced off using the
    stored counts. Otherwise the batch is treated as dense and the item is
    returned by plain indexing.

    Args:
        mesh: Object exposing ``vertices`` and ``faces`` batch tensors and,
            optionally, ``colors``, ``vertex_counts``, ``face_counts`` and
            ``color_counts``.
        index: Position of the mesh inside the batch.

    Returns:
        A ``(vertices, faces, colors)`` tuple; ``colors`` is ``None`` when
        the mesh has no ``colors`` attribute or it is ``None``.
    """
    colors = getattr(mesh, "colors", None)

    # A missing attribute and an attribute explicitly set to None both mean
    # "not a packed batch"; hasattr() alone would crash on the latter.
    vertex_counts = getattr(mesh, "vertex_counts", None)
    if vertex_counts is None:
        item_colors = None if colors is None else colors[index]
        return mesh.vertices[index], mesh.faces[index], item_colors

    vertex_count = int(vertex_counts[index].item())
    face_count = int(mesh.face_counts[index].item())
    vertices = mesh.vertices[index, :vertex_count]
    faces = mesh.faces[index, :face_count]

    if colors is None:
        return vertices, faces, None

    color_counts = getattr(mesh, "color_counts", None)
    if color_counts is None:
        # Without explicit color counts, fall back to the vertex count for
        # the color slice.
        color_count = vertex_count
    else:
        color_count = int(color_counts[index].item())
    return vertices, faces, colors[index, :color_count]