ltx: vae: Automate truncation for encoder

Other VAEs just truncate without error. Do the same.
This commit is contained in:
Rattus 2026-03-19 13:14:00 +10:00
parent e860c3de75
commit 793386d754

View File

@ -1265,9 +1265,7 @@ class VideoVAE(nn.Module):
return config
def encode(self, x):
frames_count = x.shape[2]
if ((frames_count - 1) % 8) != 0:
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :]
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)