This commit is contained in:
Yousef Rafat 2026-02-24 14:53:54 +02:00
parent 2a27c3b417
commit a2c8a7aab5
2 changed files with 4 additions and 3 deletions

View File

@ -772,7 +772,7 @@ class Trellis2(nn.Module):
coords = transformer_options.get("coords", None)
mode = transformer_options.get("generation_mode", "structure_generation")
if coords is not None:
x = x.squeeze(0)
x = x.squeeze(-1).transpose(1, 2)
not_struct_mode = True
else:
mode = "structure_generation"
@ -824,5 +824,5 @@ class Trellis2(nn.Module):
if not_struct_mode:
out = out.feats
if mode == "shape_generation":
out = out.view(B, N, -1)
out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
return out

View File

@ -275,7 +275,8 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
decoded = structure_output.data.unsqueeze(1)
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
in_channels = 32
latent = torch.randn(1, coords.shape[0], in_channels)
# image like format
latent = torch.randn(1, in_channels, coords.shape[0], 1)
model = model.clone()
model.model_options = model.model_options.copy()
if "transformer_options" in model.model_options: