From 3d5f9aead7a49802d3fddaf0cf4c274a20acbb06 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 14 May 2026 20:37:58 +0300 Subject: [PATCH] shape_structure and tooltip --- comfy_extras/nodes_trellis2.py | 40 +++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index f424213e9..3acfdcabd 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -648,7 +648,14 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode): category="latent/3d", inputs=[ IO.Model.Input("model"), - IO.AnyType.Input("structure_or_coords"), + IO.MultiType.Input( + "shape_structure", + types=[IO.Voxel, IO.AnyType], + tooltip=( + "Shape structure input. Accepts either a voxel structure " + "or upsampled coordinates from a previous cascade stage." + ) + ) ], outputs=[ IO.Model.Output(), @@ -657,20 +664,20 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode): ) @classmethod - def execute(cls, model, structure_or_coords): + def execute(cls, model, shape_structure): # to accept the upscaled coords is_512_pass = False - if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: - decoded = structure_or_coords.data.unsqueeze(1) + if hasattr(shape_structure, "data") and shape_structure.data.ndim == 4: + decoded = shape_structure.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() is_512_pass = True - elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: - coords = structure_or_coords.int() + elif isinstance(shape_structure, torch.Tensor) and shape_structure.ndim == 2: + coords = shape_structure.int() is_512_pass = False else: - raise ValueError(f"Invalid input to EmptyTrellis2ShapeLatent: {type(structure_or_coords)}") + raise ValueError(f"Invalid input to EmptyTrellis2ShapeLatent: {type(shape_structure)}") batch_size, counts, max_tokens = infer_batched_coord_layout(coords) in_channels = 32 @@ -698,7 +705,14 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode): category="latent/3d", inputs=[ IO.Model.Input("model"), - IO.Voxel.Input("structure_or_coords"), + IO.MultiType.Input( + "shape_structure", + types=[IO.Voxel, IO.AnyType], + tooltip=( + "Shape structure input. Accepts either a voxel structure " + "or upsampled coordinates from a previous cascade stage." + ) + ), IO.Latent.Input("shape_latent"), ], outputs=[ @@ -708,14 +722,14 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode): ) @classmethod - def execute(cls, model, structure_or_coords, shape_latent): + def execute(cls, model, shape_structure, shape_latent): channels = 32 - if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: - decoded = structure_or_coords.data.unsqueeze(1) + if hasattr(shape_structure, "data") and shape_structure.data.ndim == 4: + decoded = shape_structure.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() - elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: - coords = structure_or_coords.int() + elif isinstance(shape_structure, torch.Tensor) and shape_structure.ndim == 2: + coords = shape_structure.int() batch_size, counts, max_tokens = infer_batched_coord_layout(coords)