mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 07:22:34 +08:00
.
This commit is contained in:
parent
6191cd86bf
commit
c5a750205d
@ -8,6 +8,7 @@ from comfy.ldm.trellis2.attention import (
|
|||||||
)
|
)
|
||||||
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
||||||
from comfy.ldm.flux.math import apply_rope, apply_rope1
|
from comfy.ldm.flux.math import apply_rope, apply_rope1
|
||||||
|
import builtins
|
||||||
|
|
||||||
class SparseGELU(nn.GELU):
|
class SparseGELU(nn.GELU):
|
||||||
def forward(self, input: VarLenTensor) -> VarLenTensor:
|
def forward(self, input: VarLenTensor) -> VarLenTensor:
|
||||||
@ -481,6 +482,8 @@ class SLatFlowModel(nn.Module):
|
|||||||
if isinstance(cond, list):
|
if isinstance(cond, list):
|
||||||
cond = VarLenTensor.from_tensor_list(cond)
|
cond = VarLenTensor.from_tensor_list(cond)
|
||||||
|
|
||||||
|
dtype = next(self.input_layer.parameters()).dtype
|
||||||
|
x = x.to(dtype)
|
||||||
h = self.input_layer(x)
|
h = self.input_layer(x)
|
||||||
h = manual_cast(h, self.dtype)
|
h = manual_cast(h, self.dtype)
|
||||||
t_emb = self.t_embedder(t, out_dtype = t.dtype)
|
t_emb = self.t_embedder(t, out_dtype = t.dtype)
|
||||||
@ -832,8 +835,14 @@ class Trellis2(nn.Module):
|
|||||||
_, cond = context.chunk(2)
|
_, cond = context.chunk(2)
|
||||||
cond = embeds.chunk(2)[0]
|
cond = embeds.chunk(2)[0]
|
||||||
context = torch.cat([torch.zeros_like(cond), cond])
|
context = torch.cat([torch.zeros_like(cond), cond])
|
||||||
mode = kwargs.get("generation_mode")
|
mode = getattr(builtins, "TRELLIS_MODE", "structure_generation")
|
||||||
coords = kwargs.get("coords")
|
coords = getattr(builtins, "TRELLIS_COORDS", None)
|
||||||
|
if coords is not None:
|
||||||
|
x = x.squeeze(0)
|
||||||
|
not_struct_mode = True
|
||||||
|
else:
|
||||||
|
mode = "structure_generation"
|
||||||
|
not_struct_mode = False
|
||||||
transformer_options = kwargs.get("transformer_options")
|
transformer_options = kwargs.get("transformer_options")
|
||||||
sigmas = transformer_options.get("sigmas")[0].item()
|
sigmas = transformer_options.get("sigmas")[0].item()
|
||||||
if sigmas < 1.00001:
|
if sigmas < 1.00001:
|
||||||
@ -842,7 +851,6 @@ class Trellis2(nn.Module):
|
|||||||
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
|
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
|
||||||
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
||||||
|
|
||||||
not_struct_mode = mode in ["shape_generation", "texture_generation"]
|
|
||||||
if not_struct_mode:
|
if not_struct_mode:
|
||||||
x = SparseTensor(feats=x, coords=coords)
|
x = SparseTensor(feats=x, coords=coords)
|
||||||
|
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import torch
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import builtins
|
||||||
|
|
||||||
shape_slat_normalization = {
|
shape_slat_normalization = {
|
||||||
"mean": torch.tensor([
|
"mean": torch.tensor([
|
||||||
@ -268,8 +269,10 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
|||||||
decoded = structure_output.data.unsqueeze(1)
|
decoded = structure_output.data.unsqueeze(1)
|
||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = torch.randn(coords.shape[0], in_channels)
|
latent = torch.randn(1, coords.shape[0], in_channels)
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "shape_generation", "coords": coords})
|
builtins.TRELLIS_MODE = "shape_generation"
|
||||||
|
builtins.TRELLIS_COORDS = coords
|
||||||
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|
||||||
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -292,7 +295,9 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
|
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "texture_generation", "coords": coords})
|
builtins.TRELLIS_MODE = "texture_generation"
|
||||||
|
builtins.TRELLIS_COORDS = coords
|
||||||
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|
||||||
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -312,7 +317,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
|||||||
in_channels = 8
|
in_channels = 8
|
||||||
resolution = 16
|
resolution = 16
|
||||||
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
|
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "structure_generation"})
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|
||||||
def simplify_fn(vertices, faces, target=100000):
|
def simplify_fn(vertices, faces, target=100000):
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user