From 7d444a4fcca545cf37d3bd42bc3e2d53f0994a3b Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 27 Feb 2026 22:22:07 +0200 Subject: [PATCH] resolution logit --- comfy/ldm/trellis2/model.py | 2 +- comfy_extras/nodes_trellis2.py | 23 +++++++++++++++++++---- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 45740faea..bd8309f2b 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -812,7 +812,7 @@ class Trellis2(nn.Module): raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!") out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure - timestep = timestep_reshift(timestep) + #timestep = timestep_reshift(timestep) orig_bsz = x.shape[0] if shape_rule: x = x[0].unsqueeze(0) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index e781d35e3..739233523 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -136,7 +136,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): IO.Latent.Input("samples"), IO.Voxel.Input("structure_output"), IO.Vae.Input("vae"), - IO.Int.Input("resolution", tooltip="Shape Generation Resolution"), + IO.Combo.Input("resolution", options=["512", "1024"], default="512") ], outputs=[ IO.Mesh.Output("mesh"), @@ -147,6 +147,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def execute(cls, samples, structure_output, vae, resolution): + resolution = int(resolution) patcher = vae.patcher device = comfy.model_management.get_torch_device() comfy.model_management.load_model_gpu(patcher) @@ -154,14 +155,18 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): vae = vae.first_stage_model decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + samples = samples["samples"] - samples = samples.squeeze(-1).transpose(1, 2).to(device) + samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) std = shape_slat_normalization["std"].to(samples) mean = shape_slat_normalization["mean"].to(samples) samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean mesh, subs = vae.decode_shape_slat(samples, resolution) + faces = torch.stack([m.faces for m in mesh]) + verts = torch.stack([m.vertices for m in mesh]) + mesh = Types.MESH(vertices=verts, faces=faces) return IO.NodeOutput(mesh, subs) class VaeDecodeTextureTrellis(IO.ComfyNode): @@ -192,13 +197,16 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() samples = samples["samples"] - samples = samples.squeeze(-1).transpose(1, 2).to(device) + samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) std = tex_slat_normalization["std"].to(samples) mean = tex_slat_normalization["mean"].to(samples) samples = SparseTensor(feats = samples, coords=coords) samples = samples * std + mean mesh = vae.decode_tex_slat(samples, shape_subs) + faces = torch.stack([m.faces for m in mesh]) + verts = torch.stack([m.vertices for m in mesh]) + mesh = Types.MESH(vertices=verts, faces=faces) return IO.NodeOutput(mesh) class VaeDecodeStructureTrellis2(IO.ComfyNode): @@ -210,6 +218,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): inputs=[ IO.Latent.Input("samples"), IO.Vae.Input("vae"), + IO.Combo.Input("resolution", options=["32", "64"], default="32") ], outputs=[ IO.Voxel.Output("structure_output"), @@ -217,7 +226,8 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, samples, vae): + def execute(cls, samples, vae, resolution): + resolution = int(resolution) vae = vae.first_stage_model decoder = vae.struct_dec load_device = comfy.model_management.get_torch_device() @@ -227,6 +237,11 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): samples = samples.to(load_device) decoded = decoder(samples)>0 decoder.to(offload_device) + current_res = decoded.shape[2] + + if current_res != resolution: + ratio = current_res // resolution + decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5 out = Types.VOXEL(decoded.squeeze(1).float()) return IO.NodeOutput(out)