upscale node + simple node simplification

This commit is contained in:
Yousef Rafat 2026-03-11 22:50:17 +02:00
parent 011f624dd5
commit 2d904b28da

View File

@ -53,7 +53,6 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
category="latent/3d",
inputs=[
IO.Latent.Input("samples"),
IO.Voxel.Input("structure_output"),
IO.Vae.Input("vae"),
IO.Combo.Input("resolution", options=["512", "1024"], default="512")
],
@ -64,7 +63,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
)
@classmethod
def execute(cls, samples, structure_output, vae, resolution):
def execute(cls, samples, vae, resolution):
resolution = int(resolution)
patcher = vae.patcher
@ -72,8 +71,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
comfy.model_management.load_model_gpu(patcher)
vae = vae.first_stage_model
decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
coords = samples["coords"]
samples = samples["samples"]
samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
@ -93,7 +91,6 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
category="latent/3d",
inputs=[
IO.Latent.Input("samples"),
IO.Voxel.Input("structure_output"),
IO.Vae.Input("vae"),
IO.AnyType.Input("shape_subs"),
],
@ -103,15 +100,15 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
)
@classmethod
def execute(cls, samples, structure_output, vae, shape_subs):
def execute(cls, samples, vae, shape_subs):
patcher = vae.patcher
device = comfy.model_management.get_torch_device()
comfy.model_management.load_model_gpu(patcher)
vae = vae.first_stage_model
decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
coords = samples["coords"]
samples = samples["samples"]
samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
std = tex_slat_normalization["std"].to(samples)
@ -161,6 +158,56 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
out = Types.VOXEL(decoded.squeeze(1).float())
return IO.NodeOutput(out)
class Trellis2UpsampleCascade(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Trellis2UpsampleCascade",
category="latent/3d",
inputs=[
IO.Latent.Input("shape_latent_512"),
IO.Vae.Input("vae"),
IO.Combo.Input("target_resolution", options=["1024", "1536"], default="1024"),
IO.Int.Input("max_tokens", default=49152, min=1024, max=100000)
],
outputs=[
IO.AnyType.Output("hr_coords"),
]
)
@classmethod
def execute(cls, shape_latent_512, vae, target_resolution, max_tokens):
device = comfy.model_management.get_torch_device()
feats = shape_latent_512["samples"].squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
coords_512 = shape_latent_512["coords"].to(device)
slat = shape_norm(feats, coords_512)
decoder = vae.first_stage_model.shape_dec
decoder.to(device)
slat.feats = slat.feats.to(next(decoder.parameters()).dtype)
hr_coords = decoder.upsample(slat, upsample_times=4)
decoder.cpu()
lr_resolution = 512
hr_resolution = int(target_resolution)
while True:
quant_coords = torch.cat([
hr_coords[:, :1],
((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
], dim=1)
final_coords = quant_coords.unique(dim=0)
num_tokens = final_coords.shape[0]
if num_tokens < max_tokens or hr_resolution <= 1024:
break
hr_resolution -= 128
return IO.NodeOutput(final_coords.cpu())
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
@ -282,7 +329,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
node_id="EmptyShapeLatentTrellis2",
category="latent/3d",
inputs=[
IO.Voxel.Input("structure_output"),
IO.AnyType.Input("structure_or_coords"),
IO.Model.Input("model")
],
outputs=[
@ -292,9 +339,13 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
)
@classmethod
def execute(cls, structure_output, model):
decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
def execute(cls, structure_or_coords, model):
# to accept the upscaled coords
if hasattr(structure_or_coords, "data"):
decoded = structure_or_coords.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
else:
coords = structure_or_coords
in_channels = 32
# image like format
latent = torch.randn(1, in_channels, coords.shape[0], 1)
@ -307,7 +358,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
model.model_options["transformer_options"]["coords"] = coords
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model)
return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model)
class EmptyTextureLatentTrellis2(IO.ComfyNode):
@classmethod
@ -556,6 +607,7 @@ class Trellis2Extension(ComfyExtension):
VaeDecodeTextureTrellis,
VaeDecodeShapeTrellis,
VaeDecodeStructureTrellis2,
Trellis2UpsampleCascade,
PostProcessMesh
]