small error fixes

This commit is contained in:
Yousef Rafat 2026-02-12 00:30:51 +02:00
parent b7764479c2
commit 8e90bdc1cc

View File

@ -1,5 +1,5 @@
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO from comfy_api.latest import ComfyExtension, IO, Types
import torch import torch
from comfy.ldm.trellis2.model import SparseTensor from comfy.ldm.trellis2.model import SparseTensor
import comfy.model_management import comfy.model_management
@ -185,7 +185,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
IO.Vae.Input("vae"), IO.Vae.Input("vae"),
], ],
outputs=[ outputs=[
IO.Mesh.Output("structure_output"), IO.Voxel.Output("structure_output"),
] ]
) )
@ -194,16 +194,15 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
vae = vae.first_stage_model vae = vae.first_stage_model
decoder = vae.struct_dec decoder = vae.struct_dec
load_device = comfy.model_management.get_torch_device() load_device = comfy.model_management.get_torch_device()
decoder = comfy.model_patcher.ModelPatcher( offload_device = comfy.model_management.vae_offload_device()
decoder, load_device=load_device, offload_device=comfy.model_management.vae_offload_device() decoder = decoder.to(load_device)
)
comfy.model_management.load_model_gpu(decoder)
decoder = decoder.model
samples = samples["samples"] samples = samples["samples"]
samples = samples.to(load_device) samples = samples.to(load_device)
decoded = decoder(samples)>0 decoded = decoder(samples)>0
coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int() decoder.to(offload_device)
return IO.NodeOutput(coords) comfy.model_management.get_offload_stream
out = Types.VOXEL(decoded.squeeze(1).float())
return IO.NodeOutput(out)
class Trellis2Conditioning(IO.ComfyNode): class Trellis2Conditioning(IO.ComfyNode):
@classmethod @classmethod
@ -238,7 +237,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
node_id="EmptyShapeLatentTrellis2", node_id="EmptyShapeLatentTrellis2",
category="latent/3d", category="latent/3d",
inputs=[ inputs=[
IO.Mesh.Input("structure_output"), IO.Voxel.Input("structure_output"),
], ],
outputs=[ outputs=[
IO.Latent.Output(), IO.Latent.Output(),
@ -247,8 +246,8 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
@classmethod @classmethod
def execute(cls, structure_output): def execute(cls, structure_output):
# i will see what i have to do here decoded = structure_output.data
coords = structure_output # or structure_output.coords coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1)
in_channels = 32 in_channels = 32
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
return IO.NodeOutput({"samples": latent, "type": "trellis2"}) return IO.NodeOutput({"samples": latent, "type": "trellis2"})
@ -260,7 +259,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
node_id="EmptyTextureLatentTrellis2", node_id="EmptyTextureLatentTrellis2",
category="latent/3d", category="latent/3d",
inputs=[ inputs=[
IO.Mesh.Input("structure_output"), IO.Voxel.Input("structure_output"),
], ],
outputs=[ outputs=[
IO.Latent.Output(), IO.Latent.Output(),
@ -271,7 +270,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
def execute(cls, structure_output): def execute(cls, structure_output):
# TODO # TODO
in_channels = 32 in_channels = 32
latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1])) latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1]))
return IO.NodeOutput({"samples": latent, "type": "trellis2"}) return IO.NodeOutput({"samples": latent, "type": "trellis2"})
class EmptyStructureLatentTrellis2(IO.ComfyNode): class EmptyStructureLatentTrellis2(IO.ComfyNode):