small error fixes

This commit is contained in:
Yousef Rafat 2026-02-12 00:30:51 +02:00
parent b7764479c2
commit 8e90bdc1cc

View File

@ -1,5 +1,5 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.latest import ComfyExtension, IO, Types
import torch
from comfy.ldm.trellis2.model import SparseTensor
import comfy.model_management
@ -185,7 +185,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
IO.Vae.Input("vae"),
],
outputs=[
IO.Mesh.Output("structure_output"),
IO.Voxel.Output("structure_output"),
]
)
@ -194,16 +194,15 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
vae = vae.first_stage_model
decoder = vae.struct_dec
load_device = comfy.model_management.get_torch_device()
decoder = comfy.model_patcher.ModelPatcher(
decoder, load_device=load_device, offload_device=comfy.model_management.vae_offload_device()
)
comfy.model_management.load_model_gpu(decoder)
decoder = decoder.model
offload_device = comfy.model_management.vae_offload_device()
decoder = decoder.to(load_device)
samples = samples["samples"]
samples = samples.to(load_device)
decoded = decoder(samples)>0
coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int()
return IO.NodeOutput(coords)
decoder.to(offload_device)
comfy.model_management.get_offload_stream
out = Types.VOXEL(decoded.squeeze(1).float())
return IO.NodeOutput(out)
class Trellis2Conditioning(IO.ComfyNode):
@classmethod
@ -238,7 +237,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
node_id="EmptyShapeLatentTrellis2",
category="latent/3d",
inputs=[
IO.Mesh.Input("structure_output"),
IO.Voxel.Input("structure_output"),
],
outputs=[
IO.Latent.Output(),
@ -247,8 +246,8 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
@classmethod
def execute(cls, structure_output):
# i will see what i have to do here
coords = structure_output # or structure_output.coords
decoded = structure_output.data
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1)
in_channels = 32
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
@ -260,7 +259,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
node_id="EmptyTextureLatentTrellis2",
category="latent/3d",
inputs=[
IO.Mesh.Input("structure_output"),
IO.Voxel.Input("structure_output"),
],
outputs=[
IO.Latent.Output(),
@ -271,7 +270,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
def execute(cls, structure_output):
# TODO
in_channels = 32
latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1]))
latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1]))
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
class EmptyStructureLatentTrellis2(IO.ComfyNode):