diff --git a/comfy_extras/trellis2.py b/comfy_extras/nodes_trellis2.py similarity index 82% rename from comfy_extras/trellis2.py rename to comfy_extras/nodes_trellis2.py index c3ad56007..304d95493 100644 --- a/comfy_extras/trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -193,7 +193,48 @@ class Trellis2Conditioning(IO.ComfyNode): negative = [[conditioning["cond_neg"], {embeds}]] return IO.NodeOutput(positive, negative) -class EmptyLatentTrellis2(IO.ComfyNode): +class EmptyShapeLatentTrellis2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyLatentTrellis2", + category="latent/3d", + inputs=[ + IO.Latent.Input("structure_output"), + ], + outputs=[ + IO.Latent.Output(), + ] + ) + + def execute(cls, structure_output): + # i will see what i have to do here + coords = structure_output or structure_output.coords + in_channels = 32 + latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) + return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + +class EmptyTextureLatentTrellis2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyLatentTrellis2", + category="latent/3d", + inputs=[ + IO.Latent.Input("structure_output"), + ], + outputs=[ + IO.Latent.Output(), + ] + ) + + def execute(cls, structure_output): + # TODO + in_channels = 32 + latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1])) + return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + +class EmptyStructureLatentTrellis2(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( @@ -202,35 +243,26 @@ class EmptyLatentTrellis2(IO.ComfyNode): inputs=[ IO.Int.Input("resolution", default=3072, min=1, max=8192), IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), - IO.Vae.Input("vae"), - IO.Boolean.Input("shape_generation", tooltip="Setting to false will generate texture."), - IO.MultiCombo.Input("generation_type", options=["structure_generation", "shape_generation", "texture_generation"]) ], outputs=[ IO.Latent.Output(), ] ) - - @classmethod - def execute(cls, batch_size, coords, vae, generation_type) -> IO.NodeOutput: - # TODO: i will probably update how shape/texture is generated - # could split this too + + def execute(cls, res, batch_size): in_channels = 32 - shape_generation = generation_type == "shape_generation" - device = comfy.model_management.intermediate_device() - if shape_generation: - latent = SparseTensor(feats=torch.randn(batch_size, in_channels).to(device), coords=coords) - else: - # coords = shape_slat in txt gen case - latent = coords.replace(feats=torch.randn(coords.coords.shape[0], in_channels - coords.feats.shape[1]).to(device)) + latent = torch.randn(batch_size, in_channels, res, res, res) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + class Trellis2Extension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ Trellis2Conditioning, - EmptyLatentTrellis2, + EmptyShapeLatentTrellis2, + EmptyStructureLatentTrellis2, + EmptyTextureLatentTrellis2, VaeDecodeTextureTrellis, VaeDecodeShapeTrellis ] diff --git a/nodes.py b/nodes.py index 1cb43d9e2..051e808cc 100644 --- a/nodes.py +++ b/nodes.py @@ -2433,7 +2433,8 @@ async def init_builtin_extra_nodes(): "nodes_image_compare.py", "nodes_zimage.py", "nodes_lora_debug.py", - "nodes_color.py" + "nodes_color.py", + "nodes_trellis2.py" ] import_failed = []