From c5a750205d1a35d5ee8937f8997f8b2c10e37b10 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 20 Feb 2026 17:39:44 +0200 Subject: [PATCH] . --- comfy/ldm/trellis2/model.py | 14 +++++++++++--- comfy_extras/nodes_trellis2.py | 13 +++++++++---- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 8579b0580..5ff2a1ce0 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -8,6 +8,7 @@ from comfy.ldm.trellis2.attention import ( ) from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder from comfy.ldm.flux.math import apply_rope, apply_rope1 +import builtins class SparseGELU(nn.GELU): def forward(self, input: VarLenTensor) -> VarLenTensor: @@ -481,6 +482,8 @@ class SLatFlowModel(nn.Module): if isinstance(cond, list): cond = VarLenTensor.from_tensor_list(cond) + dtype = next(self.input_layer.parameters()).dtype + x = x.to(dtype) h = self.input_layer(x) h = manual_cast(h, self.dtype) t_emb = self.t_embedder(t, out_dtype = t.dtype) @@ -832,8 +835,14 @@ class Trellis2(nn.Module): _, cond = context.chunk(2) cond = embeds.chunk(2)[0] context = torch.cat([torch.zeros_like(cond), cond]) - mode = kwargs.get("generation_mode") - coords = kwargs.get("coords") + mode = getattr(builtins, "TRELLIS_MODE", "structure_generation") + coords = getattr(builtins, "TRELLIS_COORDS", None) + if coords is not None: + x = x.squeeze(0) + not_struct_mode = True + else: + mode = "structure_generation" + not_struct_mode = False transformer_options = kwargs.get("transformer_options") sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: @@ -842,7 +851,6 @@ class Trellis2(nn.Module): shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] - not_struct_mode = mode in ["shape_generation", "texture_generation"] if not_struct_mode: x = SparseTensor(feats=x, coords=coords) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 1683949a3..c8d84fd23 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -4,6 +4,7 @@ import torch import comfy.model_management from PIL import Image import numpy as np +import builtins shape_slat_normalization = { "mean": torch.tensor([ @@ -268,8 +269,10 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 - latent = torch.randn(coords.shape[0], in_channels) - return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "shape_generation", "coords": coords}) + latent = torch.randn(1, coords.shape[0], in_channels) + builtins.TRELLIS_MODE = "shape_generation" + builtins.TRELLIS_COORDS = coords + return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod @@ -292,7 +295,9 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) - return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "texture_generation", "coords": coords}) + builtins.TRELLIS_MODE = "texture_generation" + builtins.TRELLIS_COORDS = coords + return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyStructureLatentTrellis2(IO.ComfyNode): @classmethod @@ -312,7 +317,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): in_channels = 8 resolution = 16 latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) - return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "structure_generation"}) + return IO.NodeOutput({"samples": latent, "type": "trellis2"}) def simplify_fn(vertices, faces, target=100000):