mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-30 20:32:45 +08:00
more reliable detection
This commit is contained in:
parent
0e239dc39b
commit
0e51bee64f
@ -826,13 +826,7 @@ class Trellis2(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x, timestep, context, **kwargs):
|
def forward(self, x, timestep, context, **kwargs):
|
||||||
embeds = kwargs.get("embeds")
|
embeds = kwargs.get("embeds")
|
||||||
if not hasattr(x, "feats"):
|
mode = x.generation_mode
|
||||||
mode = "structure_generation"
|
|
||||||
else:
|
|
||||||
if x.feats.shape[1] == 32:
|
|
||||||
mode = "shape_generation"
|
|
||||||
else:
|
|
||||||
mode = "texture_generation"
|
|
||||||
if mode == "shape_generation":
|
if mode == "shape_generation":
|
||||||
# TODO
|
# TODO
|
||||||
out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)]))
|
out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)]))
|
||||||
|
|||||||
@ -250,6 +250,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
|||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1)
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1)
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
|
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
|
||||||
|
latent.generation_mode = "shape_generation"
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|
||||||
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||||
@ -271,6 +272,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
# TODO
|
# TODO
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1]))
|
latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1]))
|
||||||
|
latent.generation_mode = "texture_generation"
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|
||||||
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||||
@ -291,6 +293,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
|||||||
in_channels = 8
|
in_channels = 8
|
||||||
resolution = 16
|
resolution = 16
|
||||||
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
|
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
|
||||||
|
latent.generation_mode = "structure_generation"
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|
||||||
def simplify_fn(vertices, faces, target=100000):
|
def simplify_fn(vertices, faces, target=100000):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user