diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 748ff967d..55777e1e9 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,5 +1,5 @@ from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO, Types +from comfy_api.latest import ComfyExtension, IO, Types, io from comfy.ldm.trellis2.vae import SparseTensor import comfy.model_management from PIL import Image @@ -10,6 +10,9 @@ import torch import scipy import copy +ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES") +HighResVoxel = io.Custom("HIGH_RES_VOXEL") + def prepare_trellis_vae_for_decode(vae, sample_shape): memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype) if len(sample_shape) == 5: @@ -247,7 +250,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): ], outputs=[ IO.Mesh.Output("mesh"), - IO.AnyType.Output("shape_subdivides"), + ShapeSubdivides.Output(display_name = "shape_subdivides"), ] ) @@ -304,7 +307,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): IO.Mesh.Input("mesh"), IO.Latent.Input("samples"), IO.Vae.Input("vae"), - IO.AnyType.Input("shape_subdivides", + ShapeSubdivides.Input("shape_subdivides", tooltip=( "Shape information used to guide higher-detail reconstruction during decoding. " "Helps preserve structure consistency at higher resolutions." @@ -419,7 +422,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode): )) ], outputs=[ - IO.AnyType.Output( + HighResVoxel.Output( "high_res_voxel", tooltip=( "High-resolution sparse coordinates produced after cascade upsampling. " @@ -662,7 +665,7 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode): IO.Model.Input("model"), IO.MultiType.Input( "voxel", - types=[IO.Voxel, IO.AnyType], + types=[IO.Voxel, HighResVoxel], tooltip=( "Shape structure input. Accepts either a voxel structure " "or upsampled voxel coordinates from a previous cascade stage." @@ -676,20 +679,20 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode): ) @classmethod - def execute(cls, model, shape_structure): + def execute(cls, model, voxel): # to accept the upscaled coords is_512_pass = False - if hasattr(shape_structure, "data") and shape_structure.data.ndim == 4: - decoded = shape_structure.data.unsqueeze(1) + if hasattr(voxel, "data") and voxel.data.ndim == 4: + decoded = voxel.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() is_512_pass = True - elif isinstance(shape_structure, torch.Tensor) and shape_structure.ndim == 2: - coords = shape_structure.int() + elif isinstance(voxel, torch.Tensor) and voxel.ndim == 2: + coords = voxel.int() is_512_pass = False else: - raise ValueError(f"Invalid input to EmptyTrellis2ShapeLatent: {type(shape_structure)}") + raise ValueError(f"Invalid input to EmptyTrellis2ShapeLatent: {type(voxel)}") batch_size, counts, max_tokens = infer_batched_coord_layout(coords) in_channels = 32 @@ -719,7 +722,7 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode): IO.Model.Input("model"), IO.MultiType.Input( "shape_structure", - types=[IO.Voxel, IO.AnyType], + types=[IO.Voxel, HighResVoxel], tooltip=( "Shape structure input. Accepts either a voxel structure " "or upsampled coordinates from a previous cascade stage."