package the trellis2 resolution

instead of taking it as an input from the user
This commit is contained in:
Yousef Rafat 2026-05-14 21:42:52 +03:00
parent 3d5f9aead7
commit 0bae96f2dd
2 changed files with 5 additions and 6 deletions

View File

@ -1419,6 +1419,7 @@ class Vae(nn.Module):
num_res_blocks_middle=2, num_res_blocks_middle=2,
channels=[512, 128, 32], channels=[512, 128, 32],
) )
self.register_buffer("resolution", torch.tensor(1024.0), persistent=False)
@torch.no_grad() @torch.no_grad()
def decode_shape_slat(self, slat, resolution: int): def decode_shape_slat(self, slat, resolution: int):

View File

@ -244,7 +244,6 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
inputs=[ inputs=[
IO.Latent.Input("samples"), IO.Latent.Input("samples"),
IO.Vae.Input("vae"), IO.Vae.Input("vae"),
IO.Combo.Input("resolution", options=["512", "1024"], default="1024")
], ],
outputs=[ outputs=[
IO.Mesh.Output("mesh"), IO.Mesh.Output("mesh"),
@ -253,9 +252,9 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, samples, vae, resolution): def execute(cls, samples, vae):
resolution = int(resolution) resolution = int(vae.resolution.item())
sample_tensor = samples["samples"] sample_tensor = samples["samples"]
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
coords = samples["coords"] coords = samples["coords"]
@ -306,7 +305,6 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
IO.Latent.Input("samples"), IO.Latent.Input("samples"),
IO.Vae.Input("vae"), IO.Vae.Input("vae"),
IO.AnyType.Input("shape_subs"), IO.AnyType.Input("shape_subs"),
IO.Combo.Input("resolution", options=["512", "1024"], default="1024")
], ],
outputs=[ outputs=[
IO.Mesh.Output("mesh"), IO.Mesh.Output("mesh"),
@ -314,10 +312,10 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, mesh, samples, vae, shape_subs, resolution): def execute(cls, mesh, samples, vae, shape_subs):
shape_mesh = mesh shape_mesh = mesh
sample_tensor = samples["samples"] sample_tensor = samples["samples"]
resolution = int(resolution) resolution = int(vae.resolution.item())
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
coords = samples["coords"] coords = samples["coords"]
prepare_trellis_vae_for_decode(vae, sample_tensor.shape) prepare_trellis_vae_for_decode(vae, sample_tensor.shape)