mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-24 17:32:40 +08:00
small error fixes
This commit is contained in:
parent
b7764479c2
commit
8e90bdc1cc
@ -1,5 +1,5 @@
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, IO
|
from comfy_api.latest import ComfyExtension, IO, Types
|
||||||
import torch
|
import torch
|
||||||
from comfy.ldm.trellis2.model import SparseTensor
|
from comfy.ldm.trellis2.model import SparseTensor
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
@ -185,7 +185,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
|||||||
IO.Vae.Input("vae"),
|
IO.Vae.Input("vae"),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Mesh.Output("structure_output"),
|
IO.Voxel.Output("structure_output"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -194,16 +194,15 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
|||||||
vae = vae.first_stage_model
|
vae = vae.first_stage_model
|
||||||
decoder = vae.struct_dec
|
decoder = vae.struct_dec
|
||||||
load_device = comfy.model_management.get_torch_device()
|
load_device = comfy.model_management.get_torch_device()
|
||||||
decoder = comfy.model_patcher.ModelPatcher(
|
offload_device = comfy.model_management.vae_offload_device()
|
||||||
decoder, load_device=load_device, offload_device=comfy.model_management.vae_offload_device()
|
decoder = decoder.to(load_device)
|
||||||
)
|
|
||||||
comfy.model_management.load_model_gpu(decoder)
|
|
||||||
decoder = decoder.model
|
|
||||||
samples = samples["samples"]
|
samples = samples["samples"]
|
||||||
samples = samples.to(load_device)
|
samples = samples.to(load_device)
|
||||||
decoded = decoder(samples)>0
|
decoded = decoder(samples)>0
|
||||||
coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int()
|
decoder.to(offload_device)
|
||||||
return IO.NodeOutput(coords)
|
comfy.model_management.get_offload_stream
|
||||||
|
out = Types.VOXEL(decoded.squeeze(1).float())
|
||||||
|
return IO.NodeOutput(out)
|
||||||
|
|
||||||
class Trellis2Conditioning(IO.ComfyNode):
|
class Trellis2Conditioning(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -238,7 +237,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
|||||||
node_id="EmptyShapeLatentTrellis2",
|
node_id="EmptyShapeLatentTrellis2",
|
||||||
category="latent/3d",
|
category="latent/3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Mesh.Input("structure_output"),
|
IO.Voxel.Input("structure_output"),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Latent.Output(),
|
IO.Latent.Output(),
|
||||||
@ -247,8 +246,8 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, structure_output):
|
def execute(cls, structure_output):
|
||||||
# i will see what i have to do here
|
decoded = structure_output.data
|
||||||
coords = structure_output # or structure_output.coords
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1)
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
|
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
@ -260,7 +259,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
node_id="EmptyTextureLatentTrellis2",
|
node_id="EmptyTextureLatentTrellis2",
|
||||||
category="latent/3d",
|
category="latent/3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Mesh.Input("structure_output"),
|
IO.Voxel.Input("structure_output"),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Latent.Output(),
|
IO.Latent.Output(),
|
||||||
@ -271,7 +270,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
def execute(cls, structure_output):
|
def execute(cls, structure_output):
|
||||||
# TODO
|
# TODO
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1]))
|
latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1]))
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|
||||||
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user