This commit is contained in:
Yousef Rafat 2026-03-20 02:36:01 +02:00
parent 2d904b28da
commit 5d2548822c
2 changed files with 4 additions and 3 deletions

View File

@ -756,6 +756,7 @@ class Trellis2(nn.Module):
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
if init_txt_model: if init_txt_model:
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args)
args.pop("out_channels") args.pop("out_channels")
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
self.guidance_interval = [0.6, 1.0] self.guidance_interval = [0.6, 1.0]

View File

@ -54,7 +54,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
inputs=[ inputs=[
IO.Latent.Input("samples"), IO.Latent.Input("samples"),
IO.Vae.Input("vae"), IO.Vae.Input("vae"),
IO.Combo.Input("resolution", options=["512", "1024"], default="512") IO.Combo.Input("resolution", options=["512", "1024"], default="1024")
], ],
outputs=[ outputs=[
IO.Mesh.Output("mesh"), IO.Mesh.Output("mesh"),
@ -116,7 +116,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
samples = SparseTensor(feats = samples, coords=coords) samples = SparseTensor(feats = samples, coords=coords)
samples = samples * std + mean samples = samples * std + mean
mesh = vae.decode_tex_slat(samples, shape_subs) mesh = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5
faces = torch.stack([m.faces for m in mesh]) faces = torch.stack([m.faces for m in mesh])
verts = torch.stack([m.vertices for m in mesh]) verts = torch.stack([m.vertices for m in mesh])
mesh = Types.MESH(vertices=verts, faces=faces) mesh = Types.MESH(vertices=verts, faces=faces)
@ -541,7 +541,7 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
new_verts.append(loop_v.mean(dim=0)) new_verts.append(loop_v.mean(dim=0))
for i in range(len(loop)): for i in range(len(loop)):
new_faces.append([loop[i], loop[(i + 1) % len(loop)], v_idx]) new_faces.append([loop[(i + 1) % len(loop)], loop[i], v_idx])
v_idx += 1 v_idx += 1
if new_verts: if new_verts: