ltx: vae: add cache state to downsample block

This commit is contained in:
Rattus 2026-03-19 10:29:51 +10:00
parent f6b869d7d3
commit 1abec3e6ab

View File

@ -737,12 +737,25 @@ class SpaceToDepthDownsample(nn.Module):
causal=True, causal=True,
spatial_padding_mode=spatial_padding_mode, spatial_padding_mode=spatial_padding_mode,
) )
self.temporal_cache_state = {}
def forward(self, x, causal: bool = True): def forward(self, x, causal: bool = True):
if self.stride[0] == 2: tid = threading.get_ident()
cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None))
if cached_input is not None:
x = torch_cat_if_needed([cached_input, x], dim=2)
cached_input = None
if self.stride[0] == 2 and pad_first:
x = torch.cat( x = torch.cat(
[x[:, :, :1, :, :], x], dim=2 [x[:, :, :1, :, :], x], dim=2
) # duplicate first frames for padding ) # duplicate first frames for padding
pad_first = False
if x.shape[2] < self.stride[0]:
cached_input = x
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
return None
# skip connection # skip connection
x_in = rearrange( x_in = rearrange(
@ -757,15 +770,26 @@ class SpaceToDepthDownsample(nn.Module):
# conv # conv
x = self.conv(x, causal=causal) x = self.conv(x, causal=causal)
x = rearrange( if self.stride[0] == 2 and x.shape[2] == 1:
x, if cached_x is not None:
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", x = torch_cat_if_needed([cached_x, x], dim=2)
p1=self.stride[0], cached_x = None
p2=self.stride[1], else:
p3=self.stride[2], cached_x = x
) x = None
x = x + x_in if x is not None:
x = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
cached = add_exchange_cache(x, cached, x_in, dim=2)
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
return x return x