updated the trellis2 nodes

This commit is contained in:
Yousef Rafat 2026-02-02 21:20:46 +02:00
parent f1d25a460c
commit 23474ce816
2 changed files with 51 additions and 18 deletions

View File

@ -193,7 +193,48 @@ class Trellis2Conditioning(IO.ComfyNode):
negative = [[conditioning["cond_neg"], {embeds}]]
return IO.NodeOutput(positive, negative)
class EmptyLatentTrellis2(IO.ComfyNode):
class EmptyShapeLatentTrellis2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="EmptyLatentTrellis2",
category="latent/3d",
inputs=[
IO.Latent.Input("structure_output"),
],
outputs=[
IO.Latent.Output(),
]
)
def execute(cls, structure_output):
# i will see what i have to do here
coords = structure_output or structure_output.coords
in_channels = 32
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
class EmptyTextureLatentTrellis2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="EmptyLatentTrellis2",
category="latent/3d",
inputs=[
IO.Latent.Input("structure_output"),
],
outputs=[
IO.Latent.Output(),
]
)
def execute(cls, structure_output):
# TODO
in_channels = 32
latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1]))
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
class EmptyStructureLatentTrellis2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
@ -202,35 +243,26 @@ class EmptyLatentTrellis2(IO.ComfyNode):
inputs=[
IO.Int.Input("resolution", default=3072, min=1, max=8192),
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
IO.Vae.Input("vae"),
IO.Boolean.Input("shape_generation", tooltip="Setting to false will generate texture."),
IO.MultiCombo.Input("generation_type", options=["structure_generation", "shape_generation", "texture_generation"])
],
outputs=[
IO.Latent.Output(),
]
)
@classmethod
def execute(cls, batch_size, coords, vae, generation_type) -> IO.NodeOutput:
# TODO: i will probably update how shape/texture is generated
# could split this too
def execute(cls, res, batch_size):
in_channels = 32
shape_generation = generation_type == "shape_generation"
device = comfy.model_management.intermediate_device()
if shape_generation:
latent = SparseTensor(feats=torch.randn(batch_size, in_channels).to(device), coords=coords)
else:
# coords = shape_slat in txt gen case
latent = coords.replace(feats=torch.randn(coords.coords.shape[0], in_channels - coords.feats.shape[1]).to(device))
latent = torch.randn(batch_size, in_channels, res, res, res)
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
class Trellis2Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
Trellis2Conditioning,
EmptyLatentTrellis2,
EmptyShapeLatentTrellis2,
EmptyStructureLatentTrellis2,
EmptyTextureLatentTrellis2,
VaeDecodeTextureTrellis,
VaeDecodeShapeTrellis
]

View File

@ -2433,7 +2433,8 @@ async def init_builtin_extra_nodes():
"nodes_image_compare.py",
"nodes_zimage.py",
"nodes_lora_debug.py",
"nodes_color.py"
"nodes_color.py",
"nodes_trellis2.py"
]
import_failed = []