mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 01:02:56 +08:00
ltx: vae: add cache state to downsample block
This commit is contained in:
parent
f6b869d7d3
commit
1abec3e6ab
@ -737,12 +737,25 @@ class SpaceToDepthDownsample(nn.Module):
|
|||||||
causal=True,
|
causal=True,
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
)
|
)
|
||||||
|
self.temporal_cache_state = {}
|
||||||
|
|
||||||
def forward(self, x, causal: bool = True):
|
def forward(self, x, causal: bool = True):
|
||||||
if self.stride[0] == 2:
|
tid = threading.get_ident()
|
||||||
|
cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None))
|
||||||
|
if cached_input is not None:
|
||||||
|
x = torch_cat_if_needed([cached_input, x], dim=2)
|
||||||
|
cached_input = None
|
||||||
|
|
||||||
|
if self.stride[0] == 2 and pad_first:
|
||||||
x = torch.cat(
|
x = torch.cat(
|
||||||
[x[:, :, :1, :, :], x], dim=2
|
[x[:, :, :1, :, :], x], dim=2
|
||||||
) # duplicate first frames for padding
|
) # duplicate first frames for padding
|
||||||
|
pad_first = False
|
||||||
|
|
||||||
|
if x.shape[2] < self.stride[0]:
|
||||||
|
cached_input = x
|
||||||
|
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
|
||||||
|
return None
|
||||||
|
|
||||||
# skip connection
|
# skip connection
|
||||||
x_in = rearrange(
|
x_in = rearrange(
|
||||||
@ -757,15 +770,26 @@ class SpaceToDepthDownsample(nn.Module):
|
|||||||
|
|
||||||
# conv
|
# conv
|
||||||
x = self.conv(x, causal=causal)
|
x = self.conv(x, causal=causal)
|
||||||
x = rearrange(
|
if self.stride[0] == 2 and x.shape[2] == 1:
|
||||||
x,
|
if cached_x is not None:
|
||||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
x = torch_cat_if_needed([cached_x, x], dim=2)
|
||||||
p1=self.stride[0],
|
cached_x = None
|
||||||
p2=self.stride[1],
|
else:
|
||||||
p3=self.stride[2],
|
cached_x = x
|
||||||
)
|
x = None
|
||||||
|
|
||||||
x = x + x_in
|
if x is not None:
|
||||||
|
x = rearrange(
|
||||||
|
x,
|
||||||
|
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||||
|
p1=self.stride[0],
|
||||||
|
p2=self.stride[1],
|
||||||
|
p3=self.stride[2],
|
||||||
|
)
|
||||||
|
|
||||||
|
cached = add_exchange_cache(x, cached, x_in, dim=2)
|
||||||
|
|
||||||
|
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user