custom data types

This commit is contained in:
Yousef Rafat 2026-05-15 22:39:19 +03:00
parent 97220f703b
commit 6980e15921

View File

@ -1,5 +1,5 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types
from comfy_api.latest import ComfyExtension, IO, Types, io
from comfy.ldm.trellis2.vae import SparseTensor
import comfy.model_management
from PIL import Image
@ -10,6 +10,9 @@ import torch
import scipy
import copy
ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
HighResVoxel = io.Custom("HIGH_RES_VOXEL")
def prepare_trellis_vae_for_decode(vae, sample_shape):
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
if len(sample_shape) == 5:
@ -247,7 +250,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
],
outputs=[
IO.Mesh.Output("mesh"),
IO.AnyType.Output("shape_subdivides"),
ShapeSubdivides.Output(display_name = "shape_subdivides"),
]
)
@ -304,7 +307,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
IO.Mesh.Input("mesh"),
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
IO.AnyType.Input("shape_subdivides",
ShapeSubdivides.Input("shape_subdivides",
tooltip=(
"Shape information used to guide higher-detail reconstruction during decoding. "
"Helps preserve structure consistency at higher resolutions."
@ -419,7 +422,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode):
))
],
outputs=[
IO.AnyType.Output(
HighResVoxel.Output(
"high_res_voxel",
tooltip=(
"High-resolution sparse coordinates produced after cascade upsampling. "
@ -662,7 +665,7 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
IO.Model.Input("model"),
IO.MultiType.Input(
"voxel",
types=[IO.Voxel, IO.AnyType],
types=[IO.Voxel, HighResVoxel],
tooltip=(
"Shape structure input. Accepts either a voxel structure "
"or upsampled voxel coordinates from a previous cascade stage."
@ -676,20 +679,20 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
)
@classmethod
def execute(cls, model, shape_structure):
def execute(cls, model, voxel):
# to accept the upscaled coords
is_512_pass = False
if hasattr(shape_structure, "data") and shape_structure.data.ndim == 4:
decoded = shape_structure.data.unsqueeze(1)
if hasattr(voxel, "data") and voxel.data.ndim == 4:
decoded = voxel.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
is_512_pass = True
elif isinstance(shape_structure, torch.Tensor) and shape_structure.ndim == 2:
coords = shape_structure.int()
elif isinstance(voxel, torch.Tensor) and voxel.ndim == 2:
coords = voxel.int()
is_512_pass = False
else:
raise ValueError(f"Invalid input to EmptyTrellis2ShapeLatent: {type(shape_structure)}")
raise ValueError(f"Invalid input to EmptyTrellis2ShapeLatent: {type(voxel)}")
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
in_channels = 32
@ -719,7 +722,7 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
IO.Model.Input("model"),
IO.MultiType.Input(
"shape_structure",
types=[IO.Voxel, IO.AnyType],
types=[IO.Voxel, HighResVoxel],
tooltip=(
"Shape structure input. Accepts either a voxel structure "
"or upsampled coordinates from a previous cascade stage."