resolution logit

This commit is contained in:
Yousef Rafat 2026-02-27 22:22:07 +02:00
parent 39270fdca9
commit 7d444a4fcc
2 changed files with 20 additions and 5 deletions

View File

@ -812,7 +812,7 @@ class Trellis2(nn.Module):
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
else: # structure
timestep = timestep_reshift(timestep)
#timestep = timestep_reshift(timestep)
orig_bsz = x.shape[0]
if shape_rule:
x = x[0].unsqueeze(0)

View File

@ -136,7 +136,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
IO.Latent.Input("samples"),
IO.Voxel.Input("structure_output"),
IO.Vae.Input("vae"),
IO.Int.Input("resolution", tooltip="Shape Generation Resolution"),
IO.Combo.Input("resolution", options=["512", "1024"], default="512")
],
outputs=[
IO.Mesh.Output("mesh"),
@ -147,6 +147,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
@classmethod
def execute(cls, samples, structure_output, vae, resolution):
resolution = int(resolution)
patcher = vae.patcher
device = comfy.model_management.get_torch_device()
comfy.model_management.load_model_gpu(patcher)
@ -154,14 +155,18 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
vae = vae.first_stage_model
decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
samples = samples["samples"]
samples = samples.squeeze(-1).transpose(1, 2).to(device)
samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
std = shape_slat_normalization["std"].to(samples)
mean = shape_slat_normalization["mean"].to(samples)
samples = SparseTensor(feats = samples, coords=coords)
samples = samples * std + mean
mesh, subs = vae.decode_shape_slat(samples, resolution)
faces = torch.stack([m.faces for m in mesh])
verts = torch.stack([m.vertices for m in mesh])
mesh = Types.MESH(vertices=verts, faces=faces)
return IO.NodeOutput(mesh, subs)
class VaeDecodeTextureTrellis(IO.ComfyNode):
@ -192,13 +197,16 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
samples = samples["samples"]
samples = samples.squeeze(-1).transpose(1, 2).to(device)
samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
std = tex_slat_normalization["std"].to(samples)
mean = tex_slat_normalization["mean"].to(samples)
samples = SparseTensor(feats = samples, coords=coords)
samples = samples * std + mean
mesh = vae.decode_tex_slat(samples, shape_subs)
faces = torch.stack([m.faces for m in mesh])
verts = torch.stack([m.vertices for m in mesh])
mesh = Types.MESH(vertices=verts, faces=faces)
return IO.NodeOutput(mesh)
class VaeDecodeStructureTrellis2(IO.ComfyNode):
@ -210,6 +218,7 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
IO.Combo.Input("resolution", options=["32", "64"], default="32")
],
outputs=[
IO.Voxel.Output("structure_output"),
@ -217,7 +226,8 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
)
@classmethod
def execute(cls, samples, vae):
def execute(cls, samples, vae, resolution):
resolution = int(resolution)
vae = vae.first_stage_model
decoder = vae.struct_dec
load_device = comfy.model_management.get_torch_device()
@ -227,6 +237,11 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
samples = samples.to(load_device)
decoded = decoder(samples)>0
decoder.to(offload_device)
current_res = decoded.shape[2]
if current_res != resolution:
ratio = current_res // resolution
decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
out = Types.VOXEL(decoded.squeeze(1).float())
return IO.NodeOutput(out)