diff --git a/comfy_extras/nodes_save_3d.py b/comfy_extras/nodes_save_3d.py new file mode 100644 index 000000000..d79917fb5 --- /dev/null +++ b/comfy_extras/nodes_save_3d.py @@ -0,0 +1,380 @@ +"""Save-side 3D nodes: mesh packing/slicing helpers + GLB writer + SaveGLB node. + +Pairs with nodes_load_3d.py (load-side counterpart). +""" + +import json +import logging +import os +import struct + +import numpy as np +import torch +from typing_extensions import override + +import folder_paths +from comfy.cli_args import args +from comfy_api.latest import ComfyExtension, IO, Types + + +def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None): + # Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors, + # stashing per-item lengths as runtime attrs so consumers can recover the real slice. + # uvs are 1:1 with vertices, so they're padded to max_vertices and read with vertex_counts. + batch_size = len(vertices) + max_vertices = max(v.shape[0] for v in vertices) + max_faces = max(f.shape[0] for f in faces) + + packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1])) + packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1])) + vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64) + face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64) + + for i, (v, f) in enumerate(zip(vertices, faces)): + packed_vertices[i, :v.shape[0]] = v + packed_faces[i, :f.shape[0]] = f + + packed_colors = None + color_counts = None + if colors is not None: + max_colors = max(c.shape[0] for c in colors) + packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1])) + color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64) + for i, c in enumerate(colors): + packed_colors[i, :c.shape[0]] = c + + packed_uvs = None + if uvs is not None: + packed_uvs = uvs[0].new_zeros((batch_size, max_vertices, uvs[0].shape[1])) + for i, u in enumerate(uvs): + packed_uvs[i, :u.shape[0]] = u + + mesh = Types.MESH(packed_vertices, packed_faces, uvs=packed_uvs, vertex_colors=packed_colors) + mesh.vertex_counts = vertex_counts + mesh.face_counts = face_counts + if color_counts is not None: + mesh.color_counts = color_counts + return mesh + + +def get_mesh_batch_item(mesh, index): + # Returns (vertices, faces, colors) for batch index, slicing to real lengths + # if pack_variable_mesh_batch added per-item counts. + if hasattr(mesh, "vertex_counts"): + vertex_count = int(mesh.vertex_counts[index].item()) + face_count = int(mesh.face_counts[index].item()) + vertices = mesh.vertices[index, :vertex_count] + faces = mesh.faces[index, :face_count] + colors = None + v_colors = getattr(mesh, "vertex_colors", None) + if v_colors is not None: + if hasattr(mesh, "color_counts"): + color_count = int(mesh.color_counts[index].item()) + colors = v_colors[index, :color_count] + else: + colors = v_colors[index, :vertex_count] + return vertices, faces, colors + + colors = None + v_colors = getattr(mesh, "vertex_colors", None) + if v_colors is not None: + colors = v_colors[index] + return mesh.vertices[index], mesh.faces[index], colors + + +def save_glb(vertices, faces, filepath, metadata=None, + uvs=None, vertex_colors=None, texture_image=None): + """ + Save PyTorch tensor vertices and faces as a GLB file without external dependencies. + + Parameters: + vertices: torch.Tensor of shape (N, 3) - The vertex coordinates + faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces) + filepath: str - Output filepath (should end with .glb) + metadata: dict - Optional asset.extras metadata + uvs: torch.Tensor of shape (N, 2) - Optional per-vertex texture coordinates + vertex_colors: torch.Tensor of shape (N, 3) or (N, 4) - Optional per-vertex colors in [0, 1] + texture_image: PIL.Image - Optional baseColor texture, embedded as PNG + """ + + # Convert tensors to numpy arrays + vertices_np = vertices.cpu().numpy().astype(np.float32) + faces_np = faces.cpu().numpy().astype(np.uint32) + uvs_np = uvs.cpu().numpy().astype(np.float32) if uvs is not None else None + colors_np = vertex_colors.cpu().numpy().astype(np.float32) if vertex_colors is not None else None + if colors_np is not None: + colors_np = np.clip(colors_np, 0.0, 1.0) + texture_png_bytes = None + if texture_image is not None: + import io as _io + buf = _io.BytesIO() + texture_image.save(buf, format="PNG") + texture_png_bytes = buf.getvalue() + + vertices_buffer = vertices_np.tobytes() + indices_buffer = faces_np.tobytes() + uvs_buffer = uvs_np.tobytes() if uvs_np is not None else b"" + colors_buffer = colors_np.tobytes() if colors_np is not None else b"" + texture_buffer = texture_png_bytes if texture_png_bytes is not None else b"" + + def pad_to_4_bytes(buffer): + padding_length = (4 - (len(buffer) % 4)) % 4 + return buffer + b'\x00' * padding_length + + vertices_buffer_padded = pad_to_4_bytes(vertices_buffer) + indices_buffer_padded = pad_to_4_bytes(indices_buffer) + uvs_buffer_padded = pad_to_4_bytes(uvs_buffer) + colors_buffer_padded = pad_to_4_bytes(colors_buffer) + texture_buffer_padded = pad_to_4_bytes(texture_buffer) + + buffer_data = (vertices_buffer_padded + indices_buffer_padded + + uvs_buffer_padded + colors_buffer_padded + texture_buffer_padded) + + vertices_byte_length = len(vertices_buffer) + vertices_byte_offset = 0 + indices_byte_length = len(indices_buffer) + indices_byte_offset = len(vertices_buffer_padded) + uvs_byte_offset = indices_byte_offset + len(indices_buffer_padded) + colors_byte_offset = uvs_byte_offset + len(uvs_buffer_padded) + texture_byte_offset = colors_byte_offset + len(colors_buffer_padded) + + buffer_views = [ + { + "buffer": 0, + "byteOffset": vertices_byte_offset, + "byteLength": vertices_byte_length, + "target": 34962 # ARRAY_BUFFER + }, + { + "buffer": 0, + "byteOffset": indices_byte_offset, + "byteLength": indices_byte_length, + "target": 34963 # ELEMENT_ARRAY_BUFFER + } + ] + accessors = [ + { + "bufferView": 0, + "byteOffset": 0, + "componentType": 5126, # FLOAT + "count": len(vertices_np), + "type": "VEC3", + "max": vertices_np.max(axis=0).tolist(), + "min": vertices_np.min(axis=0).tolist() + }, + { + "bufferView": 1, + "byteOffset": 0, + "componentType": 5125, # UNSIGNED_INT + "count": faces_np.size, + "type": "SCALAR" + } + ] + primitive_attributes = {"POSITION": 0} + + if uvs_np is not None and len(uvs_np) > 0: + buffer_views.append({ + "buffer": 0, + "byteOffset": uvs_byte_offset, + "byteLength": len(uvs_buffer), + "target": 34962 + }) + accessor_idx = len(accessors) + accessors.append({ + "bufferView": len(buffer_views) - 1, + "byteOffset": 0, + "componentType": 5126, + "count": len(uvs_np), + "type": "VEC2", + }) + primitive_attributes["TEXCOORD_0"] = accessor_idx + + if colors_np is not None and len(colors_np) > 0: + buffer_views.append({ + "buffer": 0, + "byteOffset": colors_byte_offset, + "byteLength": len(colors_buffer), + "target": 34962 + }) + accessor_idx = len(accessors) + accessors.append({ + "bufferView": len(buffer_views) - 1, + "byteOffset": 0, + "componentType": 5126, + "count": len(colors_np), + "type": "VEC3" if colors_np.shape[1] == 3 else "VEC4", + }) + primitive_attributes["COLOR_0"] = accessor_idx + + primitive = { + "attributes": primitive_attributes, + "indices": 1, + "mode": 4 # TRIANGLES + } + + images = [] + textures = [] + samplers = [] + materials = [] + if texture_png_bytes is not None and "TEXCOORD_0" in primitive_attributes: + buffer_views.append({ + "buffer": 0, + "byteOffset": texture_byte_offset, + "byteLength": len(texture_buffer), + }) + images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"}) + samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071}) + textures.append({"source": 0, "sampler": 0}) + materials.append({ + "pbrMetallicRoughness": { + "baseColorTexture": {"index": 0, "texCoord": 0}, + "metallicFactor": 0.0, + "roughnessFactor": 1.0, + }, + "doubleSided": True, + }) + primitive["material"] = 0 + + gltf = { + "asset": {"version": "2.0", "generator": "ComfyUI"}, + "buffers": [{"byteLength": len(buffer_data)}], + "bufferViews": buffer_views, + "accessors": accessors, + "meshes": [{"primitives": [primitive]}], + "nodes": [{"mesh": 0}], + "scenes": [{"nodes": [0]}], + "scene": 0, + } + if images: + gltf["images"] = images + if samplers: + gltf["samplers"] = samplers + if textures: + gltf["textures"] = textures + if materials: + gltf["materials"] = materials + + if metadata is not None: + gltf["asset"]["extras"] = metadata + + # Convert the JSON to bytes + gltf_json = json.dumps(gltf).encode('utf8') + + def pad_json_to_4_bytes(buffer): + padding_length = (4 - (len(buffer) % 4)) % 4 + return buffer + b' ' * padding_length + + gltf_json_padded = pad_json_to_4_bytes(gltf_json) + + # Create the GLB header + # Magic glTF + glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data)) + + # Create JSON chunk header (chunk type 0) + json_chunk_header = struct.pack('