This commit is contained in:
Yousef Rafat 2026-02-24 14:53:54 +02:00
parent 2a27c3b417
commit a2c8a7aab5
2 changed files with 4 additions and 3 deletions

View File

@ -772,7 +772,7 @@ class Trellis2(nn.Module):
coords = transformer_options.get("coords", None) coords = transformer_options.get("coords", None)
mode = transformer_options.get("generation_mode", "structure_generation") mode = transformer_options.get("generation_mode", "structure_generation")
if coords is not None: if coords is not None:
x = x.squeeze(0) x = x.squeeze(-1).transpose(1, 2)
not_struct_mode = True not_struct_mode = True
else: else:
mode = "structure_generation" mode = "structure_generation"
@ -824,5 +824,5 @@ class Trellis2(nn.Module):
if not_struct_mode: if not_struct_mode:
out = out.feats out = out.feats
if mode == "shape_generation": if mode == "shape_generation":
out = out.view(B, N, -1) out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
return out return out

View File

@ -275,7 +275,8 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
decoded = structure_output.data.unsqueeze(1) decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
in_channels = 32 in_channels = 32
latent = torch.randn(1, coords.shape[0], in_channels) # image like format
latent = torch.randn(1, in_channels, coords.shape[0], 1)
model = model.clone() model = model.clone()
model.model_options = model.model_options.copy() model.model_options = model.model_options.copy()
if "transformer_options" in model.model_options: if "transformer_options" in model.model_options: