mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 19:42:34 +08:00
updated the trellis2 nodes
This commit is contained in:
parent
f1d25a460c
commit
23474ce816
@ -193,7 +193,48 @@ class Trellis2Conditioning(IO.ComfyNode):
|
||||
negative = [[conditioning["cond_neg"], {embeds}]]
|
||||
return IO.NodeOutput(positive, negative)
|
||||
|
||||
class EmptyLatentTrellis2(IO.ComfyNode):
|
||||
class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyLatentTrellis2",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("structure_output"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Latent.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
def execute(cls, structure_output):
|
||||
# i will see what i have to do here
|
||||
coords = structure_output or structure_output.coords
|
||||
in_channels = 32
|
||||
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||
|
||||
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyLatentTrellis2",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("structure_output"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Latent.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
def execute(cls, structure_output):
|
||||
# TODO
|
||||
in_channels = 32
|
||||
latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1]))
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||
|
||||
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
@ -202,35 +243,26 @@ class EmptyLatentTrellis2(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Int.Input("resolution", default=3072, min=1, max=8192),
|
||||
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Boolean.Input("shape_generation", tooltip="Setting to false will generate texture."),
|
||||
IO.MultiCombo.Input("generation_type", options=["structure_generation", "shape_generation", "texture_generation"])
|
||||
],
|
||||
outputs=[
|
||||
IO.Latent.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, batch_size, coords, vae, generation_type) -> IO.NodeOutput:
|
||||
# TODO: i will probably update how shape/texture is generated
|
||||
# could split this too
|
||||
|
||||
def execute(cls, res, batch_size):
|
||||
in_channels = 32
|
||||
shape_generation = generation_type == "shape_generation"
|
||||
device = comfy.model_management.intermediate_device()
|
||||
if shape_generation:
|
||||
latent = SparseTensor(feats=torch.randn(batch_size, in_channels).to(device), coords=coords)
|
||||
else:
|
||||
# coords = shape_slat in txt gen case
|
||||
latent = coords.replace(feats=torch.randn(coords.coords.shape[0], in_channels - coords.feats.shape[1]).to(device))
|
||||
latent = torch.randn(batch_size, in_channels, res, res, res)
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||
|
||||
|
||||
class Trellis2Extension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
Trellis2Conditioning,
|
||||
EmptyLatentTrellis2,
|
||||
EmptyShapeLatentTrellis2,
|
||||
EmptyStructureLatentTrellis2,
|
||||
EmptyTextureLatentTrellis2,
|
||||
VaeDecodeTextureTrellis,
|
||||
VaeDecodeShapeTrellis
|
||||
]
|
||||
Loading…
Reference in New Issue
Block a user