add color support for save mesh

This commit is contained in:
Yousef Rafat 2026-03-25 02:40:15 +02:00
parent 56e52e5d03
commit fe25190cae
2 changed files with 53 additions and 7 deletions

View File

@ -484,7 +484,7 @@ class VoxelToMesh(IO.ComfyNode):
decode = execute # TODO: remove decode = execute # TODO: remove
def save_glb(vertices, faces, filepath, metadata=None): def save_glb(vertices, faces, filepath, metadata=None, colors=None):
""" """
Save PyTorch tensor vertices and faces as a GLB file without external dependencies. Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
@ -515,6 +515,13 @@ def save_glb(vertices, faces, filepath, metadata=None):
indices_byte_length = len(indices_buffer) indices_byte_length = len(indices_buffer)
indices_byte_offset = len(vertices_buffer_padded) indices_byte_offset = len(vertices_buffer_padded)
if colors is not None:
colors_np = colors.cpu().numpy().astype(np.float32)
colors_buffer = colors_np.tobytes()
colors_byte_length = len(colors_buffer)
colors_byte_offset = len(buffer_data)
buffer_data += pad_to_4_bytes(colors_buffer)
gltf = { gltf = {
"asset": {"version": "2.0", "generator": "ComfyUI"}, "asset": {"version": "2.0", "generator": "ComfyUI"},
"buffers": [ "buffers": [
@ -580,6 +587,11 @@ def save_glb(vertices, faces, filepath, metadata=None):
"scene": 0 "scene": 0
} }
if colors is not None:
gltf["bufferViews"].append({"buffer": 0, "byteOffset": colors_byte_offset, "byteLength": colors_byte_length, "target": 34962})
gltf["accessors"].append({"bufferView": 2, "byteOffset": 0, "componentType": 5126, "count": len(colors_np), "type": "VEC3"})
gltf["meshes"][0]["primitives"][0]["attributes"]["COLOR_0"] = 2
if metadata is not None: if metadata is not None:
gltf["asset"]["extras"] = metadata gltf["asset"]["extras"] = metadata
@ -669,7 +681,8 @@ class SaveGLB(IO.ComfyNode):
# Handle Mesh input - save vertices and faces as GLB # Handle Mesh input - save vertices and faces as GLB
for i in range(mesh.vertices.shape[0]): for i in range(mesh.vertices.shape[0]):
f = f"{filename}_{counter:05}_.glb" f = f"{filename}_{counter:05}_.glb"
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata) v_colors = mesh.colors[i] if hasattr(mesh, "colors") and mesh.colors is not None else None
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata, v_colors)
results.append({ results.append({
"filename": f, "filename": f,
"subfolder": subfolder, "subfolder": subfolder,

View File

@ -45,6 +45,34 @@ def shape_norm(shape_latent, coords):
samples = samples * std + mean samples = samples * std + mean
return samples return samples
def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution, chunk_size=4096):
"""
Generic function to paint a mesh using nearest-neighbor colors from a sparse voxel field.
Keeps chunking internal to prevent OOM crashes on large matrices.
"""
device = voxel_coords.device
# Map Voxel Grid to Real 3D Space
origin = torch.tensor([-0.5, -0.5, -0.5], device=device)
voxel_size = 1.0 / resolution
voxel_pos = voxel_coords.float() * voxel_size + origin
verts = mesh.vertices.to(device).squeeze(0)
v_colors = torch.zeros((verts.shape[0], 3), device=device)
for i in range(0, verts.shape[0], chunk_size):
v_chunk = verts[i : i + chunk_size]
dists = torch.cdist(v_chunk, voxel_pos)
nearest_idx = torch.argmin(dists, dim=1)
v_colors[i : i + chunk_size] = voxel_colors[nearest_idx]
final_colors = (v_colors * 0.5 + 0.5).clamp(0, 1).unsqueeze(0)
out_mesh = copy.deepcopy(mesh)
out_mesh.colors = final_colors
return out_mesh
class VaeDecodeShapeTrellis(IO.ComfyNode): class VaeDecodeShapeTrellis(IO.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
@ -90,18 +118,20 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
node_id="VaeDecodeTextureTrellis", node_id="VaeDecodeTextureTrellis",
category="latent/3d", category="latent/3d",
inputs=[ inputs=[
IO.Mesh.Input("shape_mesh"),
IO.Latent.Input("samples"), IO.Latent.Input("samples"),
IO.Vae.Input("vae"), IO.Vae.Input("vae"),
IO.AnyType.Input("shape_subs"), IO.AnyType.Input("shape_subs"),
], ],
outputs=[ outputs=[
IO.Voxel.Output("voxel"), IO.Mesh.Output("mesh"),
] ]
) )
@classmethod @classmethod
def execute(cls, samples, vae, shape_subs): def execute(cls, shape_mesh, samples, vae, shape_subs):
resolution = 1024
patcher = vae.patcher patcher = vae.patcher
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
comfy.model_management.load_model_gpu(patcher) comfy.model_management.load_model_gpu(patcher)
@ -116,9 +146,12 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
samples = SparseTensor(feats = samples, coords=coords) samples = SparseTensor(feats = samples, coords=coords)
samples = samples * std + mean samples = samples * std + mean
voxel = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5 voxel = vae.decode_tex_slat(samples, shape_subs)
voxel = Types.VOXEL(voxel) color_feats = voxel.feats[:, :3]
return IO.NodeOutput(voxel) voxel_coords = voxel.coords[:, 1:]
out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution)
return IO.NodeOutput(out_mesh)
class VaeDecodeStructureTrellis2(IO.ComfyNode): class VaeDecodeStructureTrellis2(IO.ComfyNode):
@classmethod @classmethod