This commit is contained in:
Yousef Rafat 2026-05-16 01:00:55 +03:00
parent 6980e15921
commit ca7fe65e7e
4 changed files with 36 additions and 39 deletions

View File

@ -779,15 +779,18 @@ class Trellis2(nn.Module):
def forward(self, x, timestep, context, **kwargs):
transformer_options = kwargs.get("transformer_options", {})
model_options = {}
if hasattr(self, "meta"):
model_options = self.meta
timestep = timestep.to(x.dtype)
embeds = kwargs.get("embeds")
if embeds is None:
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
is_1024 = self.img2shape.resolution == 1024
coords = transformer_options.get("coords", None)
coord_counts = transformer_options.get("coord_counts", None)
mode = transformer_options.get("generation_mode", "structure_generation")
coords = model_options.get("coords", None)
coord_counts = model_options.get("coord_counts", None)
mode = model_options.get("generation_mode", "structure_generation")
is_512_run = False
if mode == "shape_generation_512":
@ -881,7 +884,7 @@ class Trellis2(nn.Module):
elif mode == "texture_generation":
if self.shape2txt is None:
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
slat = transformer_options.get("shape_slat")
slat = model_options.get("shape_slat")
if slat is None:
raise ValueError("shape_slat can't be None")

View File

@ -129,6 +129,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
unet_config["num_heads"] = 12
return unet_config
if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: # trellis2 texture
unet_config = {}
unet_config["image_model"] = "trellis2"
unet_config["resolution"] = 64
unet_config["num_heads"] = 12
unet_config["txt_only"] = True
return unet_config
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
unet_config = {}
unet_config["audio_model"] = "dit1.0"

View File

@ -257,7 +257,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
@classmethod
def execute(cls, samples, vae):
resolution = int(vae.resolution.item())
resolution = int(vae.first_stage_model.resolution.item())
sample_tensor = samples["samples"]
device = comfy.model_management.get_torch_device()
coords = samples["coords"]
@ -322,7 +322,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
def execute(cls, mesh, samples, vae, shape_subdivides):
shape_mesh = mesh
sample_tensor = samples["samples"]
resolution = int(vae.resolution.item())
resolution = int(vae.first_stage_model.resolution.item())
device = comfy.model_management.get_torch_device()
coords = samples["coords"]
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
@ -662,7 +662,6 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
node_id="EmptyTrellis2ShapeLatent",
category="latent/3d",
inputs=[
IO.Model.Input("model"),
IO.MultiType.Input(
"voxel",
types=[IO.Voxel, HighResVoxel],
@ -673,13 +672,12 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
)
],
outputs=[
IO.Model.Output(),
IO.Latent.Output(),
]
)
@classmethod
def execute(cls, model, voxel):
def execute(cls, voxel):
# to accept the upscaled coords
is_512_pass = False
@ -698,19 +696,13 @@ class EmptyTrellis2ShapeLatent(IO.ComfyNode):
in_channels = 32
# image like format
latent = torch.zeros(batch_size, in_channels, max_tokens, 1)
model = model.clone()
model.model_options = model.model_options.copy()
if "transformer_options" in model.model_options:
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
else:
model.model_options["transformer_options"] = {}
model.model_options["transformer_options"]["coords"] = coords
if is_512_pass:
model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512"
generation_mode = "shape_generation_512"
else:
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
return IO.NodeOutput(model, {"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"})
generation_mode = "shape_generation"
return IO.NodeOutput({"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2",
"model_options": {"generation_mode": generation_mode, "coords": coords, "coords_counts": counts}})
class EmptyTrellis2LatentTexture(IO.ComfyNode):
@classmethod
@ -719,32 +711,30 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
node_id="EmptyTrellis2LatentTexture",
category="latent/3d",
inputs=[
IO.Model.Input("model"),
IO.MultiType.Input(
"shape_structure",
"voxel",
types=[IO.Voxel, HighResVoxel],
tooltip=(
"Shape structure input. Accepts either a voxel structure "
"or upsampled coordinates from a previous cascade stage."
"or upsampled voxel coordinates from a previous cascade stage."
)
),
IO.Latent.Input("shape_latent"),
],
outputs=[
IO.Model.Output(),
IO.Latent.Output(),
]
)
@classmethod
def execute(cls, model, shape_structure, shape_latent):
def execute(cls, voxel, shape_latent):
channels = 32
if hasattr(shape_structure, "data") and shape_structure.data.ndim == 4:
decoded = shape_structure.data.unsqueeze(1)
if hasattr(voxel, "data") and voxel.data.ndim == 4:
decoded = voxel.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
elif isinstance(shape_structure, torch.Tensor) and shape_structure.ndim == 2:
coords = shape_structure.int()
elif isinstance(voxel, torch.Tensor) and voxel.ndim == 2:
coords = voxel.int()
batch_size, counts, max_tokens = infer_batched_coord_layout(coords)
@ -753,17 +743,9 @@ class EmptyTrellis2LatentTexture(IO.ComfyNode):
shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels)
latent = torch.zeros(batch_size, channels, max_tokens, 1)
model = model.clone()
model.model_options = model.model_options.copy()
if "transformer_options" in model.model_options:
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
else:
model.model_options["transformer_options"] = {}
model.model_options["transformer_options"]["coords"] = coords
model.model_options["transformer_options"]["generation_mode"] = "texture_generation"
model.model_options["transformer_options"]["shape_slat"] = shape_latent
return IO.NodeOutput(model, {"samples": latent, "coords": coords, "coords_counts": counts, "type": "trellis2"})
return IO.NodeOutput({"samples": latent, "type": "trellis2", "coords": coords, "coords_counts": counts,
"model_options": {"generation_mode": "texture_generation",
"coords": coords, "coords_counts": counts, "shape_slat": shape_latent}})
class EmptyTrellis2LatentStructure(IO.ComfyNode):

View File

@ -1532,6 +1532,10 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
if "noise_mask" in latent:
noise_mask = latent["noise_mask"]
if "model_options" in latent:
inner = model.model.diffusion_model
inner.meta = latent["model_options"]
callback = latent_preview.prepare_callback(model, steps)
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,