From 2d904b28da9631a756784b9bd54c4b46b8290522 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 11 Mar 2026 22:50:17 +0200 Subject: [PATCH] upscale node + simple node simplification --- comfy_extras/nodes_trellis2.py | 78 ++++++++++++++++++++++++++++------ 1 file changed, 65 insertions(+), 13 deletions(-) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 0b94a2d0a..86f08f8bd 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -53,7 +53,6 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): category="latent/3d", inputs=[ IO.Latent.Input("samples"), - IO.Voxel.Input("structure_output"), IO.Vae.Input("vae"), IO.Combo.Input("resolution", options=["512", "1024"], default="512") ], @@ -64,7 +63,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): ) @classmethod - def execute(cls, samples, structure_output, vae, resolution): + def execute(cls, samples, vae, resolution): resolution = int(resolution) patcher = vae.patcher @@ -72,8 +71,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): comfy.model_management.load_model_gpu(patcher) vae = vae.first_stage_model - decoded = structure_output.data.unsqueeze(1) - coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + coords = samples["coords"] samples = samples["samples"] samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) @@ -93,7 +91,6 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): category="latent/3d", inputs=[ IO.Latent.Input("samples"), - IO.Voxel.Input("structure_output"), IO.Vae.Input("vae"), IO.AnyType.Input("shape_subs"), ], @@ -103,15 +100,15 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): ) @classmethod - def execute(cls, samples, structure_output, vae, shape_subs): + def execute(cls, samples, vae, shape_subs): patcher = vae.patcher device = comfy.model_management.get_torch_device() comfy.model_management.load_model_gpu(patcher) vae = vae.first_stage_model - decoded = structure_output.data.unsqueeze(1) - coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + coords = samples["coords"] + samples = samples["samples"] samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) std = tex_slat_normalization["std"].to(samples) @@ -161,6 +158,56 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): out = Types.VOXEL(decoded.squeeze(1).float()) return IO.NodeOutput(out) +class Trellis2UpsampleCascade(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Trellis2UpsampleCascade", + category="latent/3d", + inputs=[ + IO.Latent.Input("shape_latent_512"), + IO.Vae.Input("vae"), + IO.Combo.Input("target_resolution", options=["1024", "1536"], default="1024"), + IO.Int.Input("max_tokens", default=49152, min=1024, max=100000) + ], + outputs=[ + IO.AnyType.Output("hr_coords"), + ] + ) + + @classmethod + def execute(cls, shape_latent_512, vae, target_resolution, max_tokens): + device = comfy.model_management.get_torch_device() + + feats = shape_latent_512["samples"].squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) + coords_512 = shape_latent_512["coords"].to(device) + + slat = shape_norm(feats, coords_512) + + decoder = vae.first_stage_model.shape_dec + decoder.to(device) + + slat.feats = slat.feats.to(next(decoder.parameters()).dtype) + hr_coords = decoder.upsample(slat, upsample_times=4) + decoder.cpu() + + lr_resolution = 512 + hr_resolution = int(target_resolution) + + while True: + quant_coords = torch.cat([ + hr_coords[:, :1], + ((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), + ], dim=1) + final_coords = quant_coords.unique(dim=0) + num_tokens = final_coords.shape[0] + + if num_tokens < max_tokens or hr_resolution <= 1024: + break + hr_resolution -= 128 + + return IO.NodeOutput(final_coords.cpu()) + dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) @@ -282,7 +329,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): node_id="EmptyShapeLatentTrellis2", category="latent/3d", inputs=[ - IO.Voxel.Input("structure_output"), + IO.AnyType.Input("structure_or_coords"), IO.Model.Input("model") ], outputs=[ @@ -292,9 +339,13 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): ) @classmethod - def execute(cls, structure_output, model): - decoded = structure_output.data.unsqueeze(1) - coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + def execute(cls, structure_or_coords, model): + # to accept the upscaled coords + if hasattr(structure_or_coords, "data"): + decoded = structure_or_coords.data.unsqueeze(1) + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + else: + coords = structure_or_coords in_channels = 32 # image like format latent = torch.randn(1, in_channels, coords.shape[0], 1) @@ -307,7 +358,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"]["coords"] = coords model.model_options["transformer_options"]["generation_mode"] = "shape_generation" - return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model) + return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model) class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod @@ -556,6 +607,7 @@ class Trellis2Extension(ComfyExtension): VaeDecodeTextureTrellis, VaeDecodeShapeTrellis, VaeDecodeStructureTrellis2, + Trellis2UpsampleCascade, PostProcessMesh ]