def _pad_per_vertex_attribute(label, values, vertices, max_vertices):
    """Pad one optional per-vertex attribute list into a (B, max_vertices, C) tensor.

    Returns None when ``values`` is None. Raises ValueError when the attribute does
    not line up 1:1 with ``vertices`` (wrong number of batch items, or a per-item
    row count that differs from that item's vertex count).
    """
    if values is None:
        return None
    if len(values) != len(vertices):
        raise ValueError(
            f"{label} has {len(values)} batch items, expected {len(vertices)} (one per mesh)"
        )
    packed = values[0].new_zeros((len(vertices), max_vertices, values[0].shape[1]))
    for i, (item, verts) in enumerate(zip(values, vertices)):
        # Explicit raise rather than `assert`: asserts vanish under `python -O`,
        # which would let a misaligned attribute be packed silently.
        if item.shape[0] != verts.shape[0]:
            raise ValueError(
                f"{label}[{i}] has {item.shape[0]} entries, expected {verts.shape[0]} (1:1 with vertices)"
            )
        packed[i, :item.shape[0]] = item
    return packed


def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None, unlit=False, normals=None,
                             metallic_roughness=None):
    """Pack per-mesh tensors of differing sizes into one zero-padded batched MESH.

    Parameters:
        vertices: list of (N_i, C) vertex tensors, one per mesh.
        faces: list of (M_i, K) face-index tensors, one per mesh (same length as ``vertices``).
        colors, uvs, normals: optional lists of per-vertex tensors, each 1:1 with the
            matching ``vertices`` entry; padded to the largest vertex count.
        texture: passed through unchanged (described by the original author as (B, H, W, 3)).
        unlit, metallic_roughness: passed through unchanged.

    Returns:
        ``Types.MESH`` holding the padded tensors plus ``vertex_counts`` / ``face_counts``
        (int64 tensors of the real per-item lengths) so consumers can recover the
        unpadded slice of every item.

    Raises:
        ValueError: on an empty batch, when ``faces`` and ``vertices`` have different
            lengths, or when a per-vertex attribute does not match its vertices.
    """
    batch_size = len(vertices)
    if batch_size == 0:
        raise ValueError("pack_variable_mesh_batch: vertices is empty")
    if len(faces) != batch_size:
        # zip() below would otherwise truncate silently and leave zero-filled rows.
        raise ValueError(
            f"pack_variable_mesh_batch: got {batch_size} vertex tensors but {len(faces)} face tensors"
        )

    max_vertices = max(v.shape[0] for v in vertices)
    max_faces = max(f.shape[0] for f in faces)

    # new_zeros keeps the dtype/device of the first item of each list.
    packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1]))
    packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1]))
    vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64)
    face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64)

    for i, (v, f) in enumerate(zip(vertices, faces)):
        packed_vertices[i, :v.shape[0]] = v
        packed_faces[i, :f.shape[0]] = f

    packed_colors = _pad_per_vertex_attribute("vertex_colors", colors, vertices, max_vertices)
    packed_uvs = _pad_per_vertex_attribute("uvs", uvs, vertices, max_vertices)
    packed_normals = _pad_per_vertex_attribute("normals", normals, vertices, max_vertices)

    return Types.MESH(packed_vertices, packed_faces,
                      uvs=packed_uvs, vertex_colors=packed_colors, texture=texture,
                      metallic_roughness=metallic_roughness,
                      vertex_counts=vertex_counts, face_counts=face_counts,
                      unlit=unlit, normals=packed_normals)
def get_mesh_batch_item(mesh, index):
    """Return ``(vertices, faces, colors, uvs, normals)`` for one item of a mesh batch.

    ``colors`` / ``uvs`` / ``normals`` are None when the mesh does not carry the
    corresponding ``vertex_colors`` / ``uvs`` / ``normals`` attribute. If the mesh
    has ``vertex_counts`` (variable-size batch), vertices and the per-vertex
    attributes are cut to that item's vertex count and faces to its
    ``face_counts`` entry; otherwise the full rows are returned.
    """
    optional = [getattr(mesh, attr, None) for attr in ("vertex_colors", "uvs", "normals")]

    if getattr(mesh, "vertex_counts", None) is None:
        # Fixed-size batch: nothing is padded, hand back the whole row.
        colors, uvs, normals = (None if t is None else t[index] for t in optional)
        return mesh.vertices[index], mesh.faces[index], colors, uvs, normals

    # Variable-size batch: strip the zero padding using the recorded real lengths.
    n_vertices = int(mesh.vertex_counts[index].item())
    n_faces = int(mesh.face_counts[index].item())
    vertices = mesh.vertices[index, :n_vertices]
    faces = mesh.faces[index, :n_faces]
    colors, uvs, normals = (None if t is None else t[index, :n_vertices] for t in optional)
    return vertices, faces, colors, uvs, normals
Un-normalized face normals (the raw cross product) have magnitude 2*area, so accumulating them onto their vertices yields an area-weighted average.""" tris = vertices_np[faces_np] # (M, 3, 3) face_n = np.cross(tris[:, 1] - tris[:, 0], tris[:, 2] - tris[:, 0]) normals = np.zeros((vertices_np.shape[0], 3), dtype=np.float64) for k in range(3): np.add.at(normals, faces_np[:, k], face_n) lens = np.linalg.norm(normals, axis=1, keepdims=True) normals /= np.where(lens > 1e-12, lens, 1.0) return normals.astype(np.float32) def _compute_vertex_normals(vertices_np, faces_np, crease_angle=None): """Compute per-vertex normals, returning (vertices, faces_uint32, normals, remap). crease_angle is None (or >= 180) -> fully smooth normals; vertices/faces are returned unchanged and remap is None. Otherwise vertices are split along edges whose dihedral angle exceeds crease_angle (degrees) so hard creases stay sharp while smooth regions still interpolate. remap maps each output vertex back to its source index, so the caller can duplicate any per-vertex attributes (uvs / colors) to match.""" faces_i = faces_np.astype(np.int64) if crease_angle is None or crease_angle >= 180.0: return (vertices_np, faces_i.astype(np.uint32), _smooth_vertex_normals(vertices_np, faces_i), None) M = faces_i.shape[0] tris = vertices_np[faces_i] face_n = np.cross(tris[:, 1] - tris[:, 0], tris[:, 2] - tris[:, 0]) areas = np.linalg.norm(face_n, axis=1, keepdims=True) face_unit = face_n / np.where(areas > 1e-12, areas, 1.0) cos_thresh = math.cos(math.radians(crease_angle)) # Union faces that share an edge whose dihedral angle is below the crease # threshold; each connected component becomes one smoothing group. 
parent = list(range(M)) def find(x): while parent[x] != x: parent[x] = parent[parent[x]] x = parent[x] return x edge_faces = {} for fi in range(M): a, b, c = int(faces_i[fi, 0]), int(faces_i[fi, 1]), int(faces_i[fi, 2]) for u, v in ((a, b), (b, c), (c, a)): edge_faces.setdefault((u, v) if u < v else (v, u), []).append(fi) for fl in edge_faces.values(): if len(fl) == 2 and float(np.dot(face_unit[fl[0]], face_unit[fl[1]])) >= cos_thresh: ra, rb = find(fl[0]), find(fl[1]) if ra != rb: parent[ra] = rb # Emit one output vertex per (original vertex, smoothing group) pair. new_index = {} remap = [] out_faces = np.empty((M, 3), dtype=np.int64) for fi in range(M): g = find(fi) for k in range(3): ov = int(faces_i[fi, k]) key = (ov, g) ni = new_index.get(key) if ni is None: ni = len(remap) new_index[key] = ni remap.append(ov) out_faces[fi, k] = ni remap = np.asarray(remap, dtype=np.int64) normals = np.zeros((remap.shape[0], 3), dtype=np.float64) for k in range(3): np.add.at(normals, out_faces[:, k], face_n) lens = np.linalg.norm(normals, axis=1, keepdims=True) normals /= np.where(lens > 1e-12, lens, 1.0) return (vertices_np[remap], out_faces.astype(np.uint32), normals.astype(np.float32), remap) def save_glb(vertices, faces, filepath, metadata=None, uvs=None, vertex_colors=None, texture_image=None, metallic_roughness_image=None, unlit=False, normals=None): """ Save PyTorch tensor vertices and faces as a GLB file without external dependencies. 
Parameters: vertices: torch.Tensor of shape (N, 3) - The vertex coordinates faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces) filepath: str - Output filepath (should end with .glb) metadata: dict - Optional asset.extras metadata uvs: torch.Tensor of shape (N, 2) - Optional per-vertex texture coordinates vertex_colors: torch.Tensor of shape (N, 3) or (N, 4) - Optional per-vertex colors in [0, 1] texture_image: PIL.Image - Optional baseColor texture, embedded as PNG metallic_roughness_image: PIL.Image - Optional glTF metallicRoughness texture (R unused, G=roughness, B=metallic), embedded as PNG normals: torch.Tensor of shape (N, 3) - Optional per-vertex normals, written as the glTF NORMAL attribute. When omitted, NO normals are written and viewers fall back to flat (per-face) shading — use the MeshSmoothNormals node to generate them. """ # Convert tensors to numpy arrays vertices_np = vertices.cpu().numpy().astype(np.float32) faces_signed = faces.cpu().numpy().astype(np.int64) uvs_np = uvs.cpu().numpy().astype(np.float32) if uvs is not None else None colors_np = vertex_colors.cpu().numpy().astype(np.float32) if vertex_colors is not None else None if colors_np is not None: colors_np = np.clip(colors_np, 0.0, 1.0) n_verts = vertices_np.shape[0] if n_verts == 0: raise ValueError("save_glb: vertices is empty") if faces_signed.size > 0: fmin = int(faces_signed.min()) fmax = int(faces_signed.max()) if fmin < 0 or fmax >= n_verts: raise ValueError( f"save_glb: face index out of range [0, {n_verts}): min={fmin}, max={fmax}" ) if uvs_np is not None and uvs_np.shape[0] != n_verts: raise ValueError( f"save_glb: uvs has {uvs_np.shape[0]} entries but vertex count is {n_verts}" ) if colors_np is not None and colors_np.shape[0] != n_verts: raise ValueError( f"save_glb: vertex_colors has {colors_np.shape[0]} entries but vertex count is {n_verts}" ) normals_np = normals.cpu().numpy().astype(np.float32) if normals is not None else None if normals_np is not None 
and normals_np.shape[0] != n_verts: raise ValueError( f"save_glb: normals has {normals_np.shape[0]} entries but vertex count is {n_verts}" ) faces_np = faces_signed.astype(np.uint32) texture_png_bytes = None if texture_image is not None: buf = BytesIO() texture_image.save(buf, format="PNG") texture_png_bytes = buf.getvalue() mr_png_bytes = None if metallic_roughness_image is not None: buf = BytesIO() metallic_roughness_image.save(buf, format="PNG") mr_png_bytes = buf.getvalue() vertices_buffer = vertices_np.tobytes() indices_buffer = faces_np.tobytes() uvs_buffer = uvs_np.tobytes() if uvs_np is not None else b"" colors_buffer = colors_np.tobytes() if colors_np is not None else b"" normals_buffer = normals_np.tobytes() if normals_np is not None else b"" texture_buffer = texture_png_bytes if texture_png_bytes is not None else b"" mr_buffer = mr_png_bytes if mr_png_bytes is not None else b"" def pad_to_4_bytes(buffer): padding_length = (4 - (len(buffer) % 4)) % 4 return buffer + b'\x00' * padding_length vertices_buffer_padded = pad_to_4_bytes(vertices_buffer) indices_buffer_padded = pad_to_4_bytes(indices_buffer) uvs_buffer_padded = pad_to_4_bytes(uvs_buffer) colors_buffer_padded = pad_to_4_bytes(colors_buffer) normals_buffer_padded = pad_to_4_bytes(normals_buffer) texture_buffer_padded = pad_to_4_bytes(texture_buffer) mr_buffer_padded = pad_to_4_bytes(mr_buffer) buffer_data = b"".join([ vertices_buffer_padded, indices_buffer_padded, uvs_buffer_padded, colors_buffer_padded, normals_buffer_padded, texture_buffer_padded, mr_buffer_padded, ]) vertices_byte_length = len(vertices_buffer) vertices_byte_offset = 0 indices_byte_length = len(indices_buffer) indices_byte_offset = len(vertices_buffer_padded) uvs_byte_offset = indices_byte_offset + len(indices_buffer_padded) colors_byte_offset = uvs_byte_offset + len(uvs_buffer_padded) normals_byte_offset = colors_byte_offset + len(colors_buffer_padded) texture_byte_offset = normals_byte_offset + len(normals_buffer_padded) 
mr_byte_offset = texture_byte_offset + len(texture_buffer_padded) buffer_views = [ { "buffer": 0, "byteOffset": vertices_byte_offset, "byteLength": vertices_byte_length, "target": 34962 # ARRAY_BUFFER }, { "buffer": 0, "byteOffset": indices_byte_offset, "byteLength": indices_byte_length, "target": 34963 # ELEMENT_ARRAY_BUFFER } ] accessors = [ { "bufferView": 0, "byteOffset": 0, "componentType": 5126, # FLOAT "count": len(vertices_np), "type": "VEC3", "max": vertices_np.max(axis=0).tolist(), "min": vertices_np.min(axis=0).tolist() }, { "bufferView": 1, "byteOffset": 0, "componentType": 5125, # UNSIGNED_INT "count": faces_np.size, "type": "SCALAR" } ] primitive_attributes = {"POSITION": 0} if uvs_np is not None and len(uvs_np) > 0: buffer_views.append({ "buffer": 0, "byteOffset": uvs_byte_offset, "byteLength": len(uvs_buffer), "target": 34962 }) accessor_idx = len(accessors) accessors.append({ "bufferView": len(buffer_views) - 1, "byteOffset": 0, "componentType": 5126, "count": len(uvs_np), "type": "VEC2", }) primitive_attributes["TEXCOORD_0"] = accessor_idx if colors_np is not None and len(colors_np) > 0: buffer_views.append({ "buffer": 0, "byteOffset": colors_byte_offset, "byteLength": len(colors_buffer), "target": 34962 }) accessor_idx = len(accessors) accessors.append({ "bufferView": len(buffer_views) - 1, "byteOffset": 0, "componentType": 5126, "count": len(colors_np), "type": "VEC3" if colors_np.shape[1] == 3 else "VEC4", }) primitive_attributes["COLOR_0"] = accessor_idx if normals_np is not None and len(normals_np) > 0: buffer_views.append({ "buffer": 0, "byteOffset": normals_byte_offset, "byteLength": len(normals_buffer), "target": 34962 }) accessor_idx = len(accessors) accessors.append({ "bufferView": len(buffer_views) - 1, "byteOffset": 0, "componentType": 5126, # FLOAT "count": len(normals_np), "type": "VEC3", }) primitive_attributes["NORMAL"] = accessor_idx primitive = { "attributes": primitive_attributes, "indices": 1, "mode": 4 # TRIANGLES } images = 
[] textures = [] samplers = [] materials = [] extensions_used = [] if unlit and texture_png_bytes is None: # Flat, light-independent shading (KHR_materials_unlit): COLOR_0 is shown as-is, matching how a # gaussian splat renders (emissive). Without this the viewer lights the mesh and washes the colours. materials.append({ "pbrMetallicRoughness": {"baseColorFactor": [1.0, 1.0, 1.0, 1.0], "metallicFactor": 0.0, "roughnessFactor": 1.0}, "extensions": {"KHR_materials_unlit": {}}, "doubleSided": True, }) extensions_used.append("KHR_materials_unlit") primitive["material"] = 0 else: pbr = { "metallicFactor": 0.0, "roughnessFactor": 0.5, "baseColorFactor": [0.22, 0.22, 0.22, 1.0], } if texture_png_bytes is not None and "TEXCOORD_0" in primitive_attributes: buffer_views.append({ "buffer": 0, "byteOffset": texture_byte_offset, "byteLength": len(texture_buffer), }) images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"}) samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071}) textures.append({"source": len(images) - 1, "sampler": 0}) pbr["baseColorTexture"] = {"index": len(textures) - 1, "texCoord": 0} if mr_png_bytes is not None and "TEXCOORD_0" in primitive_attributes: buffer_views.append({ "buffer": 0, "byteOffset": mr_byte_offset, "byteLength": len(mr_buffer), }) images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"}) if not samplers: samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071}) textures.append({"source": len(images) - 1, "sampler": 0}) pbr["metallicRoughnessTexture"] = {"index": len(textures) - 1, "texCoord": 0} # When a metallicRoughness texture is present, the factors scale it; use 1.0 # so the texture values pass through unchanged (glTF convention). 
pbr["metallicFactor"] = 1.0 pbr["roughnessFactor"] = 1.0 materials.append({ "pbrMetallicRoughness": pbr, "doubleSided": True, }) primitive["material"] = 0 gltf = { "asset": {"version": "2.0", "generator": "ComfyUI"}, "buffers": [{"byteLength": len(buffer_data)}], "bufferViews": buffer_views, "accessors": accessors, "meshes": [{"primitives": [primitive]}], "nodes": [{"mesh": 0}], "scenes": [{"nodes": [0]}], "scene": 0, } if images: gltf["images"] = images if samplers: gltf["samplers"] = samplers if textures: gltf["textures"] = textures if materials: gltf["materials"] = materials if extensions_used: gltf["extensionsUsed"] = extensions_used if metadata: gltf["asset"]["extras"] = metadata # Convert the JSON to bytes gltf_json = json.dumps(gltf).encode('utf8') def pad_json_to_4_bytes(buffer): padding_length = (4 - (len(buffer) % 4)) % 4 return buffer + b' ' * padding_length gltf_json_padded = pad_json_to_4_bytes(gltf_json) # Create the GLB header (a 4-byte ASCII magic identifier glTF) glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data)) # Create JSON chunk header (chunk type 0) json_chunk_header = struct.pack('