From 0e51bee64ff909f5dff90a2782e4973818196624 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 13 Feb 2026 00:10:25 +0200 Subject: [PATCH] more reliable detection --- comfy/ldm/trellis2/model.py | 8 +------- comfy_extras/nodes_trellis2.py | 3 +++ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 17286a553..760372f5c 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -826,13 +826,7 @@ class Trellis2(nn.Module): def forward(self, x, timestep, context, **kwargs): embeds = kwargs.get("embeds") - if not hasattr(x, "feats"): - mode = "structure_generation" - else: - if x.feats.shape[1] == 32: - mode = "shape_generation" - else: - mode = "texture_generation" + mode = x.generation_mode if mode == "shape_generation": # TODO out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)])) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 9fd257785..a5c387c1d 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -250,6 +250,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1) in_channels = 32 latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) + latent.generation_mode = "shape_generation" return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyTextureLatentTrellis2(IO.ComfyNode): @@ -271,6 +272,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): # TODO in_channels = 32 latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1])) + latent.generation_mode = "texture_generation" return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyStructureLatentTrellis2(IO.ComfyNode): @@ -291,6 +293,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): in_channels = 8 resolution = 16 latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) + latent.generation_mode = "structure_generation" return IO.NodeOutput({"samples": latent, "type": "trellis2"}) def simplify_fn(vertices, faces, target=100000):