ltx: vae: Add time stride awareness to causal_conv_3d

This commit is contained in:
Rattus 2026-03-19 11:25:42 +10:00
parent 1abec3e6ab
commit e860c3de75

View File

@@ -23,6 +23,11 @@ class CausalConv3d(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
if isinstance(stride, int):
self.time_stride = stride
else:
self.time_stride = stride[0]
kernel_size = (kernel_size, kernel_size, kernel_size) kernel_size = (kernel_size, kernel_size, kernel_size)
self.time_kernel_size = kernel_size[0] self.time_kernel_size = kernel_size[0]
@@ -58,18 +63,23 @@ class CausalConv3d(nn.Module):
pieces = [ cached, x ] pieces = [ cached, x ]
if is_end and not causal: if is_end and not causal:
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))) pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
input_length = sum([piece.shape[2] for piece in pieces])
cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride)
needs_caching = not is_end needs_caching = not is_end
if needs_caching and x.shape[2] >= self.time_kernel_size - 1: if needs_caching and cache_length == 0:
self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False)
needs_caching = False needs_caching = False
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) if needs_caching and x.shape[2] >= cache_length:
needs_caching = False
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
x = torch.cat(pieces, dim=2) x = torch.cat(pieces, dim=2)
del pieces del pieces
del cached del cached
if needs_caching: if needs_caching:
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
elif is_end: elif is_end:
self.temporal_cache_state[tid] = (None, True) self.temporal_cache_state[tid] = (None, True)