From 8e90bdc1ccad930527cbf3cd5170590ff3eb7902 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 12 Feb 2026 00:30:51 +0200 Subject: [PATCH] small error fixes --- comfy_extras/nodes_trellis2.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index c735469be..9fd257785 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,5 +1,5 @@ from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO +from comfy_api.latest import ComfyExtension, IO, Types import torch from comfy.ldm.trellis2.model import SparseTensor import comfy.model_management @@ -185,7 +185,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): IO.Vae.Input("vae"), ], outputs=[ - IO.Mesh.Output("structure_output"), + IO.Voxel.Output("structure_output"), ] ) @@ -194,16 +194,15 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): vae = vae.first_stage_model decoder = vae.struct_dec load_device = comfy.model_management.get_torch_device() - decoder = comfy.model_patcher.ModelPatcher( - decoder, load_device=load_device, offload_device=comfy.model_management.vae_offload_device() - ) - comfy.model_management.load_model_gpu(decoder) - decoder = decoder.model + offload_device = comfy.model_management.vae_offload_device() + decoder = decoder.to(load_device) samples = samples["samples"] samples = samples.to(load_device) decoded = decoder(samples)>0 - coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int() - return IO.NodeOutput(coords) + decoder.to(offload_device) + comfy.model_management.get_offload_stream + out = Types.VOXEL(decoded.squeeze(1).float()) + return IO.NodeOutput(out) class Trellis2Conditioning(IO.ComfyNode): @classmethod @@ -238,7 +237,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): node_id="EmptyShapeLatentTrellis2", category="latent/3d", inputs=[ - IO.Mesh.Input("structure_output"), + IO.Voxel.Input("structure_output"), ], outputs=[ IO.Latent.Output(), @@ -247,8 +246,8 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_output): - # i will see what i have to do here - coords = structure_output # or structure_output.coords + decoded = structure_output.data + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1) in_channels = 32 latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) @@ -260,7 +259,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): node_id="EmptyTextureLatentTrellis2", category="latent/3d", inputs=[ - IO.Mesh.Input("structure_output"), + IO.Voxel.Input("structure_output"), ], outputs=[ IO.Latent.Output(), @@ -271,7 +270,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): def execute(cls, structure_output): # TODO in_channels = 32 - latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1])) + latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1])) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyStructureLatentTrellis2(IO.ComfyNode):