This commit is contained in:
Yousef Rafat 2026-05-17 23:50:04 +03:00
parent 178e859b1b
commit 18c5f7d956
2 changed files with 533 additions and 563 deletions

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,7 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types, io
from comfy.ldm.trellis2.vae import SparseTensor
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
import comfy.model_management
from PIL import Image
import numpy as np
@ -24,57 +25,6 @@ def prepare_trellis_vae_for_decode(vae, sample_shape):
batch_number = max(1, int(free_memory / memory_required))
return batch_number
def pack_variable_mesh_batch(vertices, faces, colors=None):
batch_size = len(vertices)
max_vertices = max(v.shape[0] for v in vertices)
max_faces = max(f.shape[0] for f in faces)
packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1]))
packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1]))
vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64)
face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64)
for i, (v, f) in enumerate(zip(vertices, faces)):
packed_vertices[i, :v.shape[0]] = v
packed_faces[i, :f.shape[0]] = f
mesh = Types.MESH(packed_vertices, packed_faces)
mesh.vertex_counts = vertex_counts
mesh.face_counts = face_counts
if colors is not None:
max_colors = max(c.shape[0] for c in colors)
packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1]))
color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64)
for i, c in enumerate(colors):
packed_colors[i, :c.shape[0]] = c
mesh.vertex_colors = packed_colors
mesh.color_counts = color_counts
return mesh
def get_mesh_batch_item(mesh, index):
if hasattr(mesh, "vertex_counts"):
vertex_count = int(mesh.vertex_counts[index].item())
face_count = int(mesh.face_counts[index].item())
vertices = mesh.vertices[index, :vertex_count]
faces = mesh.faces[index, :face_count]
colors = None
if hasattr(mesh, "colors") and mesh.colors is not None:
if hasattr(mesh, "color_counts"):
color_count = int(mesh.color_counts[index].item())
colors = mesh.colors[index, :color_count]
else:
colors = mesh.colors[index, :vertex_count]
return vertices, faces, colors
colors = None
if hasattr(mesh, "colors") and mesh.colors is not None:
colors = mesh.colors[index]
return mesh.vertices[index], mesh.faces[index], colors
shape_slat_normalization = {
"mean": torch.tensor([
0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
@ -263,14 +213,13 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
)),
],
outputs=[
IO.Voxel.Output("color_voxel"),
IO.Voxel.Output("voxel_colors"),
]
)
@classmethod
def execute(cls, samples, vae, shape_subdivides):
sample_tensor = samples["samples"]
resolution = int(vae.first_stage_model.resolution.item())
device = comfy.model_management.get_torch_device()
coords = samples["coords"]
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
@ -287,9 +236,9 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
voxel = trellis_vae.decode_tex_slat(samples, shape_subdivides)
color_feats = voxel.feats[:, :3]
voxel_coords = voxel.coords[:, 1:]
voxel_coords = voxel.coords#[:, 1:]
voxel = Types.VOXEL(voxel_coords, color_feats, resolution)
voxel = Types.VOXEL(voxel_coords, color_feats, 1024)
return IO.NodeOutput(voxel)
class VaeDecodeStructureTrellis2(IO.ComfyNode):
@ -607,6 +556,9 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
# to accept the upscaled coords
is_512_pass = False
if isinstance(voxel, dict):
voxel = voxel["coords"]
if hasattr(voxel, "data") and voxel.data.ndim == 4:
decoded = voxel.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
@ -627,8 +579,8 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
generation_mode = "shape_generation_512"
else:
generation_mode = "shape_generation"
return IO.NodeOutput({"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2",
"model_options": {"generation_mode": generation_mode, "coords": coords, "coords_counts": counts}})
return IO.NodeOutput({"samples": latent, "coords": coords, "coord_counts": counts, "type": "trellis2",
"model_options": {"generation_mode": generation_mode, "coords": coords, "coord_counts": counts}})
class EmptyTrellis2LatentTexture(IO.ComfyNode):
@classmethod
@ -655,6 +607,8 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
@classmethod
def execute(cls, voxel, shape_latent):
channels = 32
if isinstance(voxel, dict):
voxel = voxel["coords"]
if hasattr(voxel, "data") and voxel.data.ndim == 4:
decoded = voxel.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
@ -669,9 +623,9 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels)
latent = torch.zeros(batch_size, channels, max_tokens, 1)
return IO.NodeOutput({"samples": latent, "type": "trellis2", "coords": coords, "coords_counts": counts,
return IO.NodeOutput({"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts,
"model_options": {"generation_mode": "texture_generation",
"coords": coords, "coords_counts": counts, "shape_slat": shape_latent}})
"coords": coords, "coord_counts": counts, "shape_slat": shape_latent}})
class EmptyTrellis2LatentStructure(IO.ComfyNode):