mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-16 01:00:49 +08:00
model: Add temporal roll to main VAE decoder
If there are no attention layers, its a standard resnet and VideoConv3d is asked for, substitute in the temporal rolloing VAE algorithm. This reduces VAE usage by the temporal dimension (can be huge VRAM savings).
This commit is contained in:
parent
119fc04459
commit
6571c912a7
@ -126,29 +126,24 @@ class Upsample(nn.Module):
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
scale_factor = self.scale_factor
|
||||
if isinstance(scale_factor, (int, float)):
|
||||
scale_factor = (scale_factor,) * (x.ndim - 2)
|
||||
|
||||
if x.ndim == 5 and scale_factor[0] > 1.0:
|
||||
t = x.shape[2]
|
||||
if t > 1:
|
||||
a, b = x.split((1, t - 1), dim=2)
|
||||
del x
|
||||
b = interpolate_up(b, scale_factor)
|
||||
else:
|
||||
a = x
|
||||
|
||||
a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
|
||||
if t > 1:
|
||||
x = torch.cat((a, b), dim=2)
|
||||
else:
|
||||
x = a
|
||||
results = []
|
||||
if conv_carry_in is None:
|
||||
first = x[:, :, :1, :, :]
|
||||
results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2))
|
||||
x = x[:, :, 1:, :, :]
|
||||
if x.shape[2] > 0:
|
||||
results.append(interpolate_up(x, scale_factor))
|
||||
x = torch_cat_if_needed(results, dim=2)
|
||||
else:
|
||||
x = interpolate_up(x, scale_factor)
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
return x
|
||||
|
||||
|
||||
@ -664,10 +659,17 @@ class Decoder(nn.Module):
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.tanh_out = tanh_out
|
||||
self.carried = False
|
||||
|
||||
if conv3d:
|
||||
conv_op = VideoConv3d
|
||||
conv_out_op = VideoConv3d
|
||||
if not attn_resolutions and resnet_op == ResnetBlock:
|
||||
conv_op = CarriedConv3d
|
||||
conv_out_op = CarriedConv3d
|
||||
self.carried = True
|
||||
else:
|
||||
conv_op = VideoConv3d
|
||||
conv_out_op = VideoConv3d
|
||||
|
||||
mid_attn_conv_op = ops.Conv3d
|
||||
else:
|
||||
conv_op = ops.Conv2d
|
||||
@ -742,25 +744,43 @@ class Decoder(nn.Module):
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
h = conv_carry_causal_3d([z], self.conv_in)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb, **kwargs)
|
||||
h = self.mid.attn_1(h, **kwargs)
|
||||
h = self.mid.block_2(h, temb, **kwargs)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](h, temb, **kwargs)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, **kwargs)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
if self.carried:
|
||||
h = torch.split(h, 2, dim=2)
|
||||
else:
|
||||
h = [ h ]
|
||||
out = []
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h, **kwargs)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
conv_carry_in = None
|
||||
|
||||
# upsampling
|
||||
for i, h1 in enumerate(h):
|
||||
conv_carry_out = []
|
||||
if i == len(h) - 1:
|
||||
conv_carry_out = None
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
assert i == 0 #carried should not happen if attn exists
|
||||
h1 = self.up[i_level].attn[i_block](h1, **kwargs)
|
||||
if i_level != 0:
|
||||
h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out)
|
||||
|
||||
h1 = self.norm_out(h1)
|
||||
h1 = [ nonlinearity(h1) ]
|
||||
h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out)
|
||||
if self.tanh_out:
|
||||
h1 = torch.tanh(h1)
|
||||
out.append(h1)
|
||||
conv_carry_in = conv_carry_out
|
||||
|
||||
out = torch_cat_if_needed(out, dim=2)
|
||||
|
||||
return out
|
||||
|
||||
Loading…
Reference in New Issue
Block a user