mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-08 08:12:34 +08:00
ltx: vae: Add time stride awareness to causal_conv_3d
This commit is contained in:
parent
1abec3e6ab
commit
e860c3de75
@ -23,6 +23,11 @@ class CausalConv3d(nn.Module):
|
|||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
if isinstance(stride, int):
|
||||||
|
self.time_stride = stride
|
||||||
|
else:
|
||||||
|
self.time_stride = stride[0]
|
||||||
|
|
||||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||||
self.time_kernel_size = kernel_size[0]
|
self.time_kernel_size = kernel_size[0]
|
||||||
|
|
||||||
@ -58,18 +63,23 @@ class CausalConv3d(nn.Module):
|
|||||||
pieces = [ cached, x ]
|
pieces = [ cached, x ]
|
||||||
if is_end and not causal:
|
if is_end and not causal:
|
||||||
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
|
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
|
||||||
|
input_length = sum([piece.shape[2] for piece in pieces])
|
||||||
|
cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride)
|
||||||
|
|
||||||
needs_caching = not is_end
|
needs_caching = not is_end
|
||||||
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
|
if needs_caching and cache_length == 0:
|
||||||
|
self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False)
|
||||||
needs_caching = False
|
needs_caching = False
|
||||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
if needs_caching and x.shape[2] >= cache_length:
|
||||||
|
needs_caching = False
|
||||||
|
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
|
||||||
|
|
||||||
x = torch.cat(pieces, dim=2)
|
x = torch.cat(pieces, dim=2)
|
||||||
del pieces
|
del pieces
|
||||||
del cached
|
del cached
|
||||||
|
|
||||||
if needs_caching:
|
if needs_caching:
|
||||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
|
||||||
elif is_end:
|
elif is_end:
|
||||||
self.temporal_cache_state[tid] = (None, True)
|
self.temporal_cache_state[tid] = (None, True)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user