Work on texture (txt) generation

This commit is contained in:
Yousef Rafat 2026-03-24 02:44:03 +02:00
parent def8947e75
commit 56e52e5d03
2 changed files with 26 additions and 22 deletions

View File

@ -754,6 +754,7 @@ class Trellis2(nn.Module):
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
} }
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
self.shape2txt = None
if init_txt_model: if init_txt_model:
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args) self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args)
@ -835,11 +836,12 @@ class Trellis2(nn.Module):
slat = transformer_options.get("shape_slat") slat = transformer_options.get("shape_slat")
if slat is None: if slat is None:
raise ValueError("shape_slat can't be None") raise ValueError("shape_slat can't be None")
slat.feats = slat.feats.repeat(B, 1)
x_st = sparse_cat([x_st, slat]) base_slat_feats = slat.feats[:N]
slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device)
x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1))
out = self.shape2txt(x_st, t_eval, c_eval) out = self.shape2txt(x_st, t_eval, c_eval)
else: # structure else: # structure
#timestep = timestep_reshift(timestep)
orig_bsz = x.shape[0] orig_bsz = x.shape[0]
if shape_rule: if shape_rule:
x = x[0].unsqueeze(0) x = x[0].unsqueeze(0)
@ -850,8 +852,7 @@ class Trellis2(nn.Module):
if not_struct_mode: if not_struct_mode:
out = out.feats out = out.feats
if not_struct_mode: out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) if rule and orig_bsz > 1:
if rule and orig_bsz > 1: out = out.repeat(orig_bsz, 1, 1, 1)
out = out.repeat(orig_bsz, 1, 1, 1)
return out return out

View File

@ -95,7 +95,7 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
IO.AnyType.Input("shape_subs"), IO.AnyType.Input("shape_subs"),
], ],
outputs=[ outputs=[
IO.Mesh.Output("mesh"), IO.Voxel.Output("voxel"),
] ]
) )
@ -116,11 +116,9 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
samples = SparseTensor(feats = samples, coords=coords) samples = SparseTensor(feats = samples, coords=coords)
samples = samples * std + mean samples = samples * std + mean
mesh = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5 voxel = vae.decode_tex_slat(samples, shape_subs) * 0.5 + 0.5
faces = torch.stack([m.faces for m in mesh]) voxel = Types.VOXEL(voxel)
verts = torch.stack([m.vertices for m in mesh]) return IO.NodeOutput(voxel)
mesh = Types.MESH(vertices=verts, faces=faces)
return IO.NodeOutput(mesh)
class VaeDecodeStructureTrellis2(IO.ComfyNode): class VaeDecodeStructureTrellis2(IO.ComfyNode):
@classmethod @classmethod
@ -377,7 +375,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
node_id="EmptyTextureLatentTrellis2", node_id="EmptyTextureLatentTrellis2",
category="latent/3d", category="latent/3d",
inputs=[ inputs=[
IO.Voxel.Input("structure_output"), IO.Voxel.Input("structure_or_coords"),
IO.Latent.Input("shape_latent"), IO.Latent.Input("shape_latent"),
IO.Model.Input("model") IO.Model.Input("model")
], ],
@ -388,16 +386,21 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, structure_output, shape_latent, model): def execute(cls, structure_or_coords, shape_latent, model):
# TODO channels = 32
decoded = structure_output.data.unsqueeze(1) if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4:
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() decoded = structure_or_coords.data.unsqueeze(1)
in_channels = 32 coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2:
coords = structure_or_coords.int()
shape_latent = shape_latent["samples"] shape_latent = shape_latent["samples"]
if shape_latent.ndim == 4:
shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels)
shape_latent = shape_norm(shape_latent, coords) shape_latent = shape_norm(shape_latent, coords)
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) latent = torch.randn(1, channels, coords.shape[0], 1)
model = model.clone() model = model.clone()
model.model_options = model.model_options.copy() model.model_options = model.model_options.copy()
if "transformer_options" in model.model_options: if "transformer_options" in model.model_options:
@ -406,9 +409,9 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
model.model_options["transformer_options"] = {} model.model_options["transformer_options"] = {}
model.model_options["transformer_options"]["coords"] = coords model.model_options["transformer_options"]["coords"] = coords
model.model_options["transformer_options"]["generation_mode"] = "shape_generation" model.model_options["transformer_options"]["generation_mode"] = "texture_generation"
model.model_options["transformer_options"]["shape_slat"] = shape_latent model.model_options["transformer_options"]["shape_slat"] = shape_latent
return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model) return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model)
class EmptyStructureLatentTrellis2(IO.ComfyNode): class EmptyStructureLatentTrellis2(IO.ComfyNode):