mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-07 12:02:37 +08:00
updated the trellis2 nodes
This commit is contained in:
parent
f1d25a460c
commit
23474ce816
@ -193,7 +193,48 @@ class Trellis2Conditioning(IO.ComfyNode):
|
|||||||
negative = [[conditioning["cond_neg"], {embeds}]]
|
negative = [[conditioning["cond_neg"], {embeds}]]
|
||||||
return IO.NodeOutput(positive, negative)
|
return IO.NodeOutput(positive, negative)
|
||||||
|
|
||||||
class EmptyLatentTrellis2(IO.ComfyNode):
|
class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="EmptyLatentTrellis2",
|
||||||
|
category="latent/3d",
|
||||||
|
inputs=[
|
||||||
|
IO.Latent.Input("structure_output"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Latent.Output(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def execute(cls, structure_output):
|
||||||
|
# i will see what i have to do here
|
||||||
|
coords = structure_output or structure_output.coords
|
||||||
|
in_channels = 32
|
||||||
|
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
|
||||||
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|
||||||
|
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="EmptyLatentTrellis2",
|
||||||
|
category="latent/3d",
|
||||||
|
inputs=[
|
||||||
|
IO.Latent.Input("structure_output"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Latent.Output(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def execute(cls, structure_output):
|
||||||
|
# TODO
|
||||||
|
in_channels = 32
|
||||||
|
latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1]))
|
||||||
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|
||||||
|
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
@ -202,35 +243,26 @@ class EmptyLatentTrellis2(IO.ComfyNode):
|
|||||||
inputs=[
|
inputs=[
|
||||||
IO.Int.Input("resolution", default=3072, min=1, max=8192),
|
IO.Int.Input("resolution", default=3072, min=1, max=8192),
|
||||||
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
|
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
|
||||||
IO.Vae.Input("vae"),
|
|
||||||
IO.Boolean.Input("shape_generation", tooltip="Setting to false will generate texture."),
|
|
||||||
IO.MultiCombo.Input("generation_type", options=["structure_generation", "shape_generation", "texture_generation"])
|
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Latent.Output(),
|
IO.Latent.Output(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
def execute(cls, res, batch_size):
|
||||||
def execute(cls, batch_size, coords, vae, generation_type) -> IO.NodeOutput:
|
|
||||||
# TODO: i will probably update how shape/texture is generated
|
|
||||||
# could split this too
|
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
shape_generation = generation_type == "shape_generation"
|
latent = torch.randn(batch_size, in_channels, res, res, res)
|
||||||
device = comfy.model_management.intermediate_device()
|
|
||||||
if shape_generation:
|
|
||||||
latent = SparseTensor(feats=torch.randn(batch_size, in_channels).to(device), coords=coords)
|
|
||||||
else:
|
|
||||||
# coords = shape_slat in txt gen case
|
|
||||||
latent = coords.replace(feats=torch.randn(coords.coords.shape[0], in_channels - coords.feats.shape[1]).to(device))
|
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|
||||||
|
|
||||||
class Trellis2Extension(ComfyExtension):
|
class Trellis2Extension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
Trellis2Conditioning,
|
Trellis2Conditioning,
|
||||||
EmptyLatentTrellis2,
|
EmptyShapeLatentTrellis2,
|
||||||
|
EmptyStructureLatentTrellis2,
|
||||||
|
EmptyTextureLatentTrellis2,
|
||||||
VaeDecodeTextureTrellis,
|
VaeDecodeTextureTrellis,
|
||||||
VaeDecodeShapeTrellis
|
VaeDecodeShapeTrellis
|
||||||
]
|
]
|
||||||
3
nodes.py
3
nodes.py
@ -2433,7 +2433,8 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_image_compare.py",
|
"nodes_image_compare.py",
|
||||||
"nodes_zimage.py",
|
"nodes_zimage.py",
|
||||||
"nodes_lora_debug.py",
|
"nodes_lora_debug.py",
|
||||||
"nodes_color.py"
|
"nodes_color.py",
|
||||||
|
"nodes_trellis2.py"
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user