This commit is contained in:
Bowen Xue 2025-11-30 23:10:00 +00:00 committed by GitHub
commit 7c8e1419b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -714,13 +714,21 @@ class Decoder(nn.Module):
h = self.mid.block_2(h, temb, **kwargs)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb, **kwargs)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, **kwargs)
if i_level != 0:
h = self.up[i_level].upsample(h)
with torch.no_grad():
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h_new = self.up[i_level].block[i_block](h, temb, **kwargs)
if len(self.up[i_level].attn) > 0:
h_new = self.up[i_level].attn[i_block](h_new, **kwargs)
del h
model_management.soft_empty_cache()
h = h_new
if i_level != 0:
h_new = self.up[i_level].upsample(h)
del h
model_management.soft_empty_cache()
h = h_new
# end
if self.give_pre_end: