mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-11 00:37:53 +08:00
custom data types
This commit is contained in:
parent
97220f703b
commit
6980e15921
@ -1,5 +1,5 @@
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
from comfy_api.latest import ComfyExtension, IO, Types, io
|
||||
from comfy.ldm.trellis2.vae import SparseTensor
|
||||
import comfy.model_management
|
||||
from PIL import Image
|
||||
@ -10,6 +10,9 @@ import torch
|
||||
import scipy
|
||||
import copy
|
||||
|
||||
ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
|
||||
HighResVoxel = io.Custom("HIGH_RES_VOXEL")
|
||||
|
||||
def prepare_trellis_vae_for_decode(vae, sample_shape):
|
||||
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
|
||||
if len(sample_shape) == 5:
|
||||
@ -247,7 +250,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output("mesh"),
|
||||
IO.AnyType.Output("shape_subdivides"),
|
||||
ShapeSubdivides.Output(display_name = "shape_subdivides"),
|
||||
]
|
||||
)
|
||||
|
||||
@ -304,7 +307,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
IO.Mesh.Input("mesh"),
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.AnyType.Input("shape_subdivides",
|
||||
ShapeSubdivides.Input("shape_subdivides",
|
||||
tooltip=(
|
||||
"Shape information used to guide higher-detail reconstruction during decoding. "
|
||||
"Helps preserve structure consistency at higher resolutions."
|
||||
@ -419,7 +422,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode):
|
||||
))
|
||||
],
|
||||
outputs=[
|
||||
IO.AnyType.Output(
|
||||
HighResVoxel.Output(
|
||||
"high_res_voxel",
|
||||
tooltip=(
|
||||
"High-resolution sparse coordinates produced after cascade upsampling. "
|
||||
@ -662,7 +665,7 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
|
||||
IO.Model.Input("model"),
|
||||
IO.MultiType.Input(
|
||||
"voxel",
|
||||
types=[IO.Voxel, IO.AnyType],
|
||||
types=[IO.Voxel, HighResVoxel],
|
||||
tooltip=(
|
||||
"Shape structure input. Accepts either a voxel structure "
|
||||
"or upsampled voxel coordinates from a previous cascade stage."
|
||||
@ -676,20 +679,20 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, shape_structure):
|
||||
def execute(cls, model, voxel):
|
||||
# to accept the upscaled coords
|
||||
is_512_pass = False
|
||||
|
||||
if hasattr(shape_structure, "data") and shape_structure.data.ndim == 4:
|
||||
decoded = shape_structure.data.unsqueeze(1)
|
||||
if hasattr(voxel, "data") and voxel.data.ndim == 4:
|
||||
decoded = voxel.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
is_512_pass = True
|
||||
|
||||
elif isinstance(shape_structure, torch.Tensor) and shape_structure.ndim == 2:
|
||||
coords = shape_structure.int()
|
||||
elif isinstance(voxel, torch.Tensor) and voxel.ndim == 2:
|
||||
coords = voxel.int()
|
||||
is_512_pass = False
|
||||
else:
|
||||
raise ValueError(f"Invalid input to EmptyTrellis2ShapeLatent: {type(shape_structure)}")
|
||||
raise ValueError(f"Invalid input to EmptyTrellis2ShapeLatent: {type(voxel)}")
|
||||
|
||||
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
|
||||
in_channels = 32
|
||||
@ -719,7 +722,7 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
|
||||
IO.Model.Input("model"),
|
||||
IO.MultiType.Input(
|
||||
"shape_structure",
|
||||
types=[IO.Voxel, IO.AnyType],
|
||||
types=[IO.Voxel, HighResVoxel],
|
||||
tooltip=(
|
||||
"Shape structure input. Accepts either a voxel structure "
|
||||
"or upsampled coordinates from a previous cascade stage."
|
||||
|
||||
Loading…
Reference in New Issue
Block a user