From 56e52e5d03f52c407cf529c8b211a1636a3ed221 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 24 Mar 2026 02:44:03 +0200 Subject: [PATCH] work on txt gen --- comfy/ldm/trellis2/model.py | 15 ++++++++------- comfy_extras/nodes_trellis2.py | 33 ++++++++++++++++++--------------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 8a0c6d8b6..34aeba3e1 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -754,6 +754,7 @@ class Trellis2(nn.Module): "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations } self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) + self.shape2txt = None if init_txt_model: self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args) @@ -835,11 +836,12 @@ class Trellis2(nn.Module): slat = transformer_options.get("shape_slat") if slat is None: raise ValueError("shape_slat can't be None") - slat.feats = slat.feats.repeat(B, 1) - x_st = sparse_cat([x_st, slat]) + + base_slat_feats = slat.feats[:N] + slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device) + x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1)) out = self.shape2txt(x_st, t_eval, c_eval) else: # structure - #timestep = timestep_reshift(timestep) orig_bsz = x.shape[0] if shape_rule: x = x[0].unsqueeze(0) @@ -850,8 +852,7 @@ class Trellis2(nn.Module): if not_struct_mode: out = out.feats - if not_struct_mode: - out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) - if rule and orig_bsz > 1: - out = out.repeat(orig_bsz, 1, 1, 1) + out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) + if rule and orig_bsz > 1: + out = out.repeat(orig_bsz, 1, 1, 1) return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index cba6b3241..dcf8dcb98 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -95,7 +95,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): IO.AnyType.Input("shape_subs"), ], outputs=[ - IO.Mesh.Output("mesh"), + IO.Voxel.Output("voxel"), ] ) @@ -116,11 +116,9 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean - mesh = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5 - faces = torch.stack([m.faces for m in mesh]) - verts = torch.stack([m.vertices for m in mesh]) - mesh = Types.MESH(vertices=verts, faces=faces) - return IO.NodeOutput(mesh) + voxel = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5 + voxel = Types.VOXEL(voxel) + return IO.NodeOutput(voxel) class VaeDecodeStructureTrellis2(IO.ComfyNode): @classmethod @@ -377,7 +375,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): node_id="EmptyTextureLatentTrellis2", category="latent/3d", inputs=[ - IO.Voxel.Input("structure_output"), + IO.Voxel.Input("structure_or_coords"), IO.Latent.Input("shape_latent"), IO.Model.Input("model") ], @@ -388,16 +386,21 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_output, shape_latent, model): - # TODO - decoded = structure_output.data.unsqueeze(1) - coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() - in_channels = 32 + def execute(cls, structure_or_coords, shape_latent, model): + channels = 32 + if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: + decoded = structure_or_coords.data.unsqueeze(1) + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + + elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: + coords = structure_or_coords.int() shape_latent = shape_latent["samples"] + if shape_latent.ndim == 4: + shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels) shape_latent = shape_norm(shape_latent, coords) - latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) + latent = torch.randn(1, channels, coords.shape[0], 1) model = model.clone() model.model_options = model.model_options.copy() if "transformer_options" in model.model_options: @@ -406,9 +409,9 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"] = {} model.model_options["transformer_options"]["coords"] = coords - model.model_options["transformer_options"]["generation_mode"] = "shape_generation" + model.model_options["transformer_options"]["generation_mode"] = "texture_generation" model.model_options["transformer_options"]["shape_slat"] = shape_latent - return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model) + return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model) class EmptyStructureLatentTrellis2(IO.ComfyNode):