Cap encode chunks at the configured frame batch size.

This commit is contained in:
Talmaj Marinc 2026-04-14 15:05:27 +02:00
parent 3e961f9960
commit e962c3f846

View File

@@ -453,13 +453,13 @@ class AutoencoderKLCogVideoX(nn.Module):
def encode(self, x):
    """Encode ``x`` with the VAE encoder in temporal chunks.

    Frames are processed in batches of ``self.num_sample_frames_batch_size``
    along dim 2, threading ``conv_cache`` through successive ``self.encoder``
    calls so the causal convolution state stays continuous across chunk
    boundaries.

    Args:
        x: input tensor; presumably ``(batch, channels, frames, height,
           width)`` — TODO confirm layout against callers.
    """
    t = x.shape[2]
    frame_batch = self.num_sample_frames_batch_size
    # Ceil division so remainder frames get their own final chunk instead
    # of inflating an earlier one; max(..., 1) guards the t == 0 case.
    num_batches = max(-(-t // frame_batch), 1)
    conv_cache = None
    enc = []
    for i in range(num_batches):
        start = i * frame_batch
        # Clamp the last chunk to the frames actually available.
        end = min((i + 1) * frame_batch, t)
        chunk, conv_cache = self.encoder(x[:, :, start:end], conv_cache=conv_cache)
        enc.append(chunk.to(x.device))
    enc = torch.cat(enc, dim=2)
    # NOTE(review): the diff hunk ends here; the method presumably continues
    # (e.g. returning `enc`) beyond the visible lines — confirm in the file.