diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index a565ec37e..fb5276f94 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -772,7 +772,7 @@ class Trellis2(nn.Module): coords = transformer_options.get("coords", None) mode = transformer_options.get("generation_mode", "structure_generation") if coords is not None: - x = x.squeeze(0) + x = x.squeeze(-1).transpose(1, 2) not_struct_mode = True else: mode = "structure_generation" @@ -824,5 +824,5 @@ class Trellis2(nn.Module): if not_struct_mode: out = out.feats if mode == "shape_generation": - out = out.view(B, N, -1) + out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 1b43f7f62..f40ff5161 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -275,7 +275,8 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 - latent = torch.randn(1, coords.shape[0], in_channels) + # image like format + latent = torch.randn(1, in_channels, coords.shape[0], 1) model = model.clone() model.model_options = model.model_options.copy() if "transformer_options" in model.model_options: