model: Add temporal roll to main VAE encoder

If there are no attention layers, its a standard resnet and VideoConv3d
is asked for, substitute in the temporal rolling VAE algorithm. This
reduces VAE usage by the temporal dimension (can be huge VRAM savings).
This commit is contained in:
Rattus 2025-11-30 08:56:07 +10:00
parent 6571c912a7
commit 1d53c1f8f7

View File

@ -159,17 +159,20 @@ class Downsample(nn.Module):
stride=stride, stride=stride,
padding=0) padding=0)
def forward(self, x): def forward(self, x, conv_carry_in=None, conv_carry_out=None):
if self.with_conv: if self.with_conv:
if x.ndim == 4: if isinstance(self.conv, CarriedConv3d):
x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
elif x.ndim == 4:
pad = (0, 1, 0, 1) pad = (0, 1, 0, 1)
mode = "constant" mode = "constant"
x = torch.nn.functional.pad(x, pad, mode=mode, value=0) x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
x = self.conv(x)
elif x.ndim == 5: elif x.ndim == 5:
pad = (1, 1, 1, 1, 2, 0) pad = (1, 1, 1, 1, 2, 0)
mode = "replicate" mode = "replicate"
x = torch.nn.functional.pad(x, pad, mode=mode) x = torch.nn.functional.pad(x, pad, mode=mode)
x = self.conv(x) x = self.conv(x)
else: else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x return x
@ -549,9 +552,14 @@ class Encoder(nn.Module):
self.num_res_blocks = num_res_blocks self.num_res_blocks = num_res_blocks
self.resolution = resolution self.resolution = resolution
self.in_channels = in_channels self.in_channels = in_channels
self.carried = False
if conv3d: if conv3d:
conv_op = VideoConv3d if not attn_resolutions:
conv_op = CarriedConv3d
self.carried = True
else:
conv_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d mid_attn_conv_op = ops.Conv3d
else: else:
conv_op = ops.Conv2d conv_op = ops.Conv2d
@ -564,6 +572,7 @@ class Encoder(nn.Module):
stride=1, stride=1,
padding=1) padding=1)
self.time_compress = 1
curr_res = resolution curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult) in_ch_mult = (1,)+tuple(ch_mult)
self.in_ch_mult = in_ch_mult self.in_ch_mult = in_ch_mult
@ -590,10 +599,15 @@ class Encoder(nn.Module):
if time_compress is not None: if time_compress is not None:
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress): if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
stride = (1, 2, 2) stride = (1, 2, 2)
else:
self.time_compress *= 2
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op) down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
curr_res = curr_res // 2 curr_res = curr_res // 2
self.down.append(down) self.down.append(down)
if time_compress is not None:
self.time_compress = time_compress
# middle # middle
self.mid = nn.Module() self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, self.mid.block_1 = ResnetBlock(in_channels=block_in,
@ -619,15 +633,42 @@ class Encoder(nn.Module):
def forward(self, x): def forward(self, x):
# timestep embedding # timestep embedding
temb = None temb = None
# downsampling
h = self.conv_in(x) if self.carried:
for i_level in range(self.num_resolutions): xl = [x[:, :, :1, :, :]]
for i_block in range(self.num_res_blocks): if x.shape[2] > self.time_compress:
h = self.down[i_level].block[i_block](h, temb) tc = self.time_compress
if len(self.down[i_level].attn) > 0: xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2)
h = self.down[i_level].attn[i_block](h) x = xl
if i_level != self.num_resolutions-1: else:
h = self.down[i_level].downsample(h) x = [x]
out = []
conv_carry_in = None
for i, x1 in enumerate(x):
conv_carry_out = []
if i == len(x) - 1:
conv_carry_out = None
# downsampling
x1 = [ x1 ]
h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out)
if len(self.down[i_level].attn) > 0:
assert i == 0 #carried should not happen if attn exists
h1 = self.down[i_level].attn[i_block](h1)
if i_level != self.num_resolutions-1:
h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out)
out.append(h1)
conv_carry_in = conv_carry_out
h = torch_cat_if_needed(out, dim=2)
del out
# middle # middle
h = self.mid.block_1(h, temb) h = self.mid.block_1(h, temb)
@ -636,8 +677,8 @@ class Encoder(nn.Module):
# end # end
h = self.norm_out(h) h = self.norm_out(h)
h = nonlinearity(h) h = [ nonlinearity(h) ]
h = self.conv_out(h) h = conv_carry_causal_3d(h, self.conv_out)
return h return h