diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 1e61ec755..1a15cafd0 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -293,8 +293,14 @@ class Encoder(nn.Module): outputs = [] samples = [sample[:, :, :1, :, :]] if sample.shape[2] > 1: - n = max(1, max_chunk_size // (2 * frame_size)) - samples += list(torch.split(sample[:, :, 1:, :, :], 2 * n, dim=2)) + chunk_t = max(2, max_chunk_size // frame_size) + if chunk_t < 4: + chunk_t = 2 + elif chunk_t < 8: + chunk_t = 4 + else: + chunk_t = (chunk_t // 8) * 8 + samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2)) for chunk_idx, chunk in enumerate(samples): if chunk_idx == len(samples) - 1: mark_conv3d_ended(self)