ltx: vae-encode: round chunk sizes more strictly

Only small powers of 2 (2 or 4) and multiples of 8 are valid chunk sizes due to cache slicing.
This commit is contained in:
Rattus 2026-03-19 18:57:39 +10:00
parent 60a3de0ef3
commit ec60b6969b

View File

@ -293,8 +293,14 @@ class Encoder(nn.Module):
outputs = [] outputs = []
samples = [sample[:, :, :1, :, :]] samples = [sample[:, :, :1, :, :]]
if sample.shape[2] > 1: if sample.shape[2] > 1:
n = max(1, max_chunk_size // (2 * frame_size)) chunk_t = max(2, max_chunk_size // frame_size)
samples += list(torch.split(sample[:, :, 1:, :, :], 2 * n, dim=2)) if chunk_t < 4:
chunk_t = 2
elif chunk_t < 8:
chunk_t = 4
else:
chunk_t = (chunk_t // 8) * 8
samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2))
for chunk_idx, chunk in enumerate(samples): for chunk_idx, chunk in enumerate(samples):
if chunk_idx == len(samples) - 1: if chunk_idx == len(samples) - 1:
mark_conv3d_ended(self) mark_conv3d_ended(self)