Added explicit memory management during the VAE decode process.

KBRASK 2025-04-13 12:03:50 +08:00
parent 9ee6ca99d8
commit 8e1d7087a3


@@ -714,13 +714,21 @@ class Decoder(nn.Module):
         h = self.mid.block_2(h, temb, **kwargs)
 
         # upsampling
-        for i_level in reversed(range(self.num_resolutions)):
-            for i_block in range(self.num_res_blocks+1):
-                h = self.up[i_level].block[i_block](h, temb, **kwargs)
-                if len(self.up[i_level].attn) > 0:
-                    h = self.up[i_level].attn[i_block](h, **kwargs)
-            if i_level != 0:
-                h = self.up[i_level].upsample(h)
+        with torch.no_grad():
+            for i_level in reversed(range(self.num_resolutions)):
+                for i_block in range(self.num_res_blocks + 1):
+                    h_new = self.up[i_level].block[i_block](h, temb, **kwargs)
+                    if len(self.up[i_level].attn) > 0:
+                        h_new = self.up[i_level].attn[i_block](h_new, **kwargs)
+                    del h
+                    torch.cuda.empty_cache()
+                    h = h_new
+                if i_level != 0:
+                    h_new = self.up[i_level].upsample(h)
+                    del h
+                    torch.cuda.empty_cache()
+                    h = h_new
 
         # end
         if self.give_pre_end:
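
For context, a minimal caller-side sketch of the same memory-management pattern follows, assuming a wrapper module vae whose decoder attribute is the patched Decoder above and a CUDA tensor latents; the names vae, latents, and decode_latents are placeholders for illustration and are not part of this commit.

import torch

def decode_latents(vae, latents):
    # No autograd graph is recorded, mirroring the torch.no_grad() block
    # this commit adds inside Decoder.forward.
    with torch.no_grad():
        image = vae.decoder(latents)
    # Ask PyTorch's caching allocator to release unused blocks back to the
    # CUDA driver, as the patched loop does after freeing each intermediate.
    torch.cuda.empty_cache()
    return image

# Example usage: image = decode_latents(vae, latents.to("cuda"))

Note that torch.cuda.empty_cache() does not reduce the memory PyTorch itself can allocate; it releases cached, unused blocks so other processes or fragmentation-sensitive allocations can reuse them, which is why calling it inside the decode loop trades some speed for a smaller resident footprint.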