resolution logic

This commit is contained in:
Yousef Rafat 2026-02-27 22:22:07 +02:00
parent 39270fdca9
commit 7d444a4fcc
2 changed files with 20 additions and 5 deletions

View File

@ -812,7 +812,7 @@ class Trellis2(nn.Module):
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!") raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
out = self.shape2txt(x, timestep, context if not txt_rule else cond) out = self.shape2txt(x, timestep, context if not txt_rule else cond)
else: # structure else: # structure
timestep = timestep_reshift(timestep) #timestep = timestep_reshift(timestep)
orig_bsz = x.shape[0] orig_bsz = x.shape[0]
if shape_rule: if shape_rule:
x = x[0].unsqueeze(0) x = x[0].unsqueeze(0)

View File

@ -136,7 +136,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
IO.Latent.Input("samples"), IO.Latent.Input("samples"),
IO.Voxel.Input("structure_output"), IO.Voxel.Input("structure_output"),
IO.Vae.Input("vae"), IO.Vae.Input("vae"),
IO.Int.Input("resolution", tooltip="Shape Generation Resolution"), IO.Combo.Input("resolution", options=["512", "1024"], default="512")
], ],
outputs=[ outputs=[
IO.Mesh.Output("mesh"), IO.Mesh.Output("mesh"),
@ -147,6 +147,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
@classmethod @classmethod
def execute(cls, samples, structure_output, vae, resolution): def execute(cls, samples, structure_output, vae, resolution):
resolution = int(resolution)
patcher = vae.patcher patcher = vae.patcher
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
comfy.model_management.load_model_gpu(patcher) comfy.model_management.load_model_gpu(patcher)
@ -154,14 +155,18 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
vae = vae.first_stage_model vae = vae.first_stage_model
decoded = structure_output.data.unsqueeze(1) decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
samples = samples["samples"] samples = samples["samples"]
samples = samples.squeeze(-1).transpose(1, 2).to(device) samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
std = shape_slat_normalization["std"].to(samples) std = shape_slat_normalization["std"].to(samples)
mean = shape_slat_normalization["mean"].to(samples) mean = shape_slat_normalization["mean"].to(samples)
samples = SparseTensor(feats = samples, coords=coords) samples = SparseTensor(feats = samples, coords=coords)
samples = samples * std + mean samples = samples * std + mean
mesh, subs = vae.decode_shape_slat(samples, resolution) mesh, subs = vae.decode_shape_slat(samples, resolution)
faces = torch.stack([m.faces for m in mesh])
verts = torch.stack([m.vertices for m in mesh])
mesh = Types.MESH(vertices=verts, faces=faces)
return IO.NodeOutput(mesh, subs) return IO.NodeOutput(mesh, subs)
class VaeDecodeTextureTrellis(IO.ComfyNode): class VaeDecodeTextureTrellis(IO.ComfyNode):
@ -192,13 +197,16 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
decoded = structure_output.data.unsqueeze(1) decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
samples = samples["samples"] samples = samples["samples"]
samples = samples.squeeze(-1).transpose(1, 2).to(device) samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
std = tex_slat_normalization["std"].to(samples) std = tex_slat_normalization["std"].to(samples)
mean = tex_slat_normalization["mean"].to(samples) mean = tex_slat_normalization["mean"].to(samples)
samples = SparseTensor(feats = samples, coords=coords) samples = SparseTensor(feats = samples, coords=coords)
samples = samples * std + mean samples = samples * std + mean
mesh = vae.decode_tex_slat(samples, shape_subs) mesh = vae.decode_tex_slat(samples, shape_subs)
faces = torch.stack([m.faces for m in mesh])
verts = torch.stack([m.vertices for m in mesh])
mesh = Types.MESH(vertices=verts, faces=faces)
return IO.NodeOutput(mesh) return IO.NodeOutput(mesh)
class VaeDecodeStructureTrellis2(IO.ComfyNode): class VaeDecodeStructureTrellis2(IO.ComfyNode):
@ -210,6 +218,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
inputs=[ inputs=[
IO.Latent.Input("samples"), IO.Latent.Input("samples"),
IO.Vae.Input("vae"), IO.Vae.Input("vae"),
IO.Combo.Input("resolution", options=["32", "64"], default="32")
], ],
outputs=[ outputs=[
IO.Voxel.Output("structure_output"), IO.Voxel.Output("structure_output"),
@ -217,7 +226,8 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, samples, vae): def execute(cls, samples, vae, resolution):
resolution = int(resolution)
vae = vae.first_stage_model vae = vae.first_stage_model
decoder = vae.struct_dec decoder = vae.struct_dec
load_device = comfy.model_management.get_torch_device() load_device = comfy.model_management.get_torch_device()
@ -227,6 +237,11 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
samples = samples.to(load_device) samples = samples.to(load_device)
decoded = decoder(samples)>0 decoded = decoder(samples)>0
decoder.to(offload_device) decoder.to(offload_device)
current_res = decoded.shape[2]
if current_res != resolution:
ratio = current_res // resolution
decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
out = Types.VOXEL(decoded.squeeze(1).float()) out = Types.VOXEL(decoded.squeeze(1).float())
return IO.NodeOutput(out) return IO.NodeOutput(out)