Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-15 16:50:57 +08:00)
model: Add temporal roll to main VAE encoder
If there are no attention layers, the encoder is a standard resnet and VideoConv3d is requested; in that case, substitute the temporal rolling VAE algorithm. This reduces the VAE's memory usage along the temporal dimension (which can be a huge VRAM saving).
parent 6571c912a7
commit 1d53c1f8f7
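The idea behind the temporal roll is that a causal temporal convolution only needs the trailing input frames of the previous chunk as context, so a long clip can be encoded chunk by chunk while that context is carried forward, instead of materializing activations for the full time axis at once. Below is a minimal, self-contained sketch of that principle, not the commit's CarriedConv3d / conv_carry_causal_3d implementation; carried_conv_time, the kernel size, the chunk size, and the tensor shapes are illustrative assumptions.

import torch
import torch.nn.functional as F

def carried_conv_time(conv, chunk, carry, kernel_t=3):
    # Causal temporal context: the first chunk replicates its own first frame,
    # later chunks reuse the trailing frames carried over from the previous chunk.
    if carry is None:
        context = chunk[:, :, :1].repeat(1, 1, kernel_t - 1, 1, 1)
    else:
        context = carry
    padded = torch.cat([context, chunk], dim=2)
    new_carry = padded[:, :, -(kernel_t - 1):]
    return conv(padded), new_carry

conv = torch.nn.Conv3d(3, 8, kernel_size=3, padding=(0, 1, 1))   # no temporal padding in the conv
video = torch.randn(1, 3, 17, 32, 32)                            # (B, C, T, H, W), toy shape

# Reference: encode the whole clip at once with causal replicate padding in time.
ref = conv(F.pad(video, (0, 0, 0, 0, 2, 0), mode="replicate"))

# Chunked: same weights, same result, but activations only ever cover one chunk.
outs, carry = [], None
for chunk in torch.split(video, 8, dim=2):                       # e.g. 17 frames -> 8 + 8 + 1
    y, carry = carried_conv_time(conv, chunk, carry)
    outs.append(y)
print(torch.allclose(torch.cat(outs, dim=2), ref, atol=1e-5))    # True

Both passes produce the same encoding, but the chunked pass only ever holds one chunk's activations plus a tiny carry tensor, which is where the VRAM saving in this commit comes from.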
@@ -159,17 +159,20 @@ class Downsample(nn.Module):
                                 stride=stride,
                                 padding=0)
 
-    def forward(self, x):
+    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
         if self.with_conv:
-            if x.ndim == 4:
+            if isinstance(self.conv, CarriedConv3d):
+                x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
+            elif x.ndim == 4:
                 pad = (0, 1, 0, 1)
                 mode = "constant"
                 x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
+                x = self.conv(x)
             elif x.ndim == 5:
                 pad = (1, 1, 1, 1, 2, 0)
                 mode = "replicate"
                 x = torch.nn.functional.pad(x, pad, mode=mode)
+                x = self.conv(x)
-            x = self.conv(x)
         else:
             x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
         return x
@@ -549,9 +552,14 @@ class Encoder(nn.Module):
         self.num_res_blocks = num_res_blocks
         self.resolution = resolution
         self.in_channels = in_channels
+        self.carried = False
 
         if conv3d:
-            conv_op = VideoConv3d
+            if not attn_resolutions:
+                conv_op = CarriedConv3d
+                self.carried = True
+            else:
+                conv_op = VideoConv3d
             mid_attn_conv_op = ops.Conv3d
         else:
             conv_op = ops.Conv2d
@@ -564,6 +572,7 @@ class Encoder(nn.Module):
                                stride=1,
                                padding=1)
 
+        self.time_compress = 1
         curr_res = resolution
         in_ch_mult = (1,)+tuple(ch_mult)
         self.in_ch_mult = in_ch_mult
@@ -590,10 +599,15 @@ class Encoder(nn.Module):
                 if time_compress is not None:
                     if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
                         stride = (1, 2, 2)
+                    else:
+                        self.time_compress *= 2
                 down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
                 curr_res = curr_res // 2
             self.down.append(down)
 
+        if time_compress is not None:
+            self.time_compress = time_compress
+
         # middle
         self.mid = nn.Module()
         self.mid.block_1 = ResnetBlock(in_channels=block_in,
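A quick worked example of the stride selection in this hunk, with illustrative values num_resolutions = 4 and time_compress = 4 (not defaults taken from this commit): only the deeper downsample levels reduce the time axis, and the running product matches the configured time_compress.

import math

num_resolutions, time_compress = 4, 4        # illustrative values
running = 1
for i_level in range(num_resolutions - 1):   # the last level has no downsample
    if (num_resolutions - 1 - i_level) > math.log2(time_compress):
        stride = (1, 2, 2)                   # spatial-only downsample
    else:
        stride = 2                           # downsample time and space by 2
        running *= 2
    print(i_level, stride)                   # 0 (1, 2, 2) / 1 2 / 2 2
print(running == time_compress)              # True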
@@ -619,15 +633,42 @@ class Encoder(nn.Module):
     def forward(self, x):
         # timestep embedding
         temb = None
-        # downsampling
-        h = self.conv_in(x)
-        for i_level in range(self.num_resolutions):
-            for i_block in range(self.num_res_blocks):
-                h = self.down[i_level].block[i_block](h, temb)
-                if len(self.down[i_level].attn) > 0:
-                    h = self.down[i_level].attn[i_block](h)
-            if i_level != self.num_resolutions-1:
-                h = self.down[i_level].downsample(h)
+
+        if self.carried:
+            xl = [x[:, :, :1, :, :]]
+            if x.shape[2] > self.time_compress:
+                tc = self.time_compress
+                xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2)
+            x = xl
+        else:
+            x = [x]
+        out = []
+
+        conv_carry_in = None
+
+        for i, x1 in enumerate(x):
+            conv_carry_out = []
+            if i == len(x) - 1:
+                conv_carry_out = None
+
+            # downsampling
+            x1 = [ x1 ]
+            h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
+
+            for i_level in range(self.num_resolutions):
+                for i_block in range(self.num_res_blocks):
+                    h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out)
+                    if len(self.down[i_level].attn) > 0:
+                        assert i == 0 #carried should not happen if attn exists
+                        h1 = self.down[i_level].attn[i_block](h1)
+                if i_level != self.num_resolutions-1:
+                    h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out)
+
+            out.append(h1)
+            conv_carry_in = conv_carry_out
+
+        h = torch_cat_if_needed(out, dim=2)
+        del out
 
         # middle
         h = self.mid.block_1(h, temb)
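For reference on the chunking arithmetic in this hunk: the encoder keeps the first frame on its own, slices off any remainder past the last full multiple of time_compress, and splits the rest into chunks of 2 * time_compress frames. A quick check with toy values (17 frames, time_compress = 4; the shapes are illustrative, not anything this commit prescribes):

import torch

x = torch.randn(1, 3, 17, 8, 8)       # (B, C, T, H, W) with T = 1 + 4 * time_compress
tc = 4                                 # stands in for self.time_compress

xl = [x[:, :, :1, :, :]]               # the first frame alone
xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim=2)
print([c.shape[2] for c in xl])        # [1, 8, 8]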
@@ -636,8 +677,8 @@ class Encoder(nn.Module):
 
         # end
         h = self.norm_out(h)
-        h = nonlinearity(h)
-        h = self.conv_out(h)
+        h = [ nonlinearity(h) ]
+        h = conv_carry_causal_3d(h, self.conv_out)
         return h