Work on texture generation

This commit is contained in:
Yousef Rafat 2026-03-24 02:44:03 +02:00
parent def8947e75
commit 56e52e5d03
2 changed files with 26 additions and 22 deletions

View File

@ -754,6 +754,7 @@ class Trellis2(nn.Module):
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
}
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
self.shape2txt = None
if init_txt_model:
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args)
@ -835,11 +836,12 @@ class Trellis2(nn.Module):
slat = transformer_options.get("shape_slat")
if slat is None:
raise ValueError("shape_slat can't be None")
slat.feats = slat.feats.repeat(B, 1)
x_st = sparse_cat([x_st, slat])
base_slat_feats = slat.feats[:N]
slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device)
x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1))
out = self.shape2txt(x_st, t_eval, c_eval)
else: # structure
#timestep = timestep_reshift(timestep)
orig_bsz = x.shape[0]
if shape_rule:
x = x[0].unsqueeze(0)
@ -850,8 +852,7 @@ class Trellis2(nn.Module):
if not_struct_mode:
out = out.feats
if not_struct_mode:
out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
if rule and orig_bsz > 1:
out = out.repeat(orig_bsz, 1, 1, 1)
out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
if rule and orig_bsz > 1:
out = out.repeat(orig_bsz, 1, 1, 1)
return out

View File

@ -95,7 +95,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
IO.AnyType.Input("shape_subs"),
],
outputs=[
IO.Mesh.Output("mesh"),
IO.Voxel.Output("voxel"),
]
)
@ -116,11 +116,9 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
samples = SparseTensor(feats = samples, coords=coords)
samples = samples * std + mean
mesh = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5
faces = torch.stack([m.faces for m in mesh])
verts = torch.stack([m.vertices for m in mesh])
mesh = Types.MESH(vertices=verts, faces=faces)
return IO.NodeOutput(mesh)
voxel = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5
voxel = Types.VOXEL(voxel)
return IO.NodeOutput(voxel)
class VaeDecodeStructureTrellis2(IO.ComfyNode):
@classmethod
@ -377,7 +375,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
node_id="EmptyTextureLatentTrellis2",
category="latent/3d",
inputs=[
IO.Voxel.Input("structure_output"),
IO.Voxel.Input("structure_or_coords"),
IO.Latent.Input("shape_latent"),
IO.Model.Input("model")
],
@ -388,16 +386,21 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
)
@classmethod
def execute(cls, structure_output, shape_latent, model):
# TODO
decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
in_channels = 32
def execute(cls, structure_or_coords, shape_latent, model):
channels = 32
if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4:
decoded = structure_or_coords.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2:
coords = structure_or_coords.int()
shape_latent = shape_latent["samples"]
if shape_latent.ndim == 4:
shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels)
shape_latent = shape_norm(shape_latent, coords)
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
latent = torch.randn(1, channels, coords.shape[0], 1)
model = model.clone()
model.model_options = model.model_options.copy()
if "transformer_options" in model.model_options:
@ -406,9 +409,9 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
model.model_options["transformer_options"] = {}
model.model_options["transformer_options"]["coords"] = coords
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
model.model_options["transformer_options"]["generation_mode"] = "texture_generation"
model.model_options["transformer_options"]["shape_slat"] = shape_latent
return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model)
return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model)
class EmptyStructureLatentTrellis2(IO.ComfyNode):