This commit is contained in:
Bowen Xue 2025-11-30 23:10:00 +00:00 committed by GitHub
commit 7c8e1419b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -714,13 +714,21 @@ class Decoder(nn.Module):
h = self.mid.block_2(h, temb, **kwargs) h = self.mid.block_2(h, temb, **kwargs)
# upsampling # upsampling
for i_level in reversed(range(self.num_resolutions)): with torch.no_grad():
for i_block in range(self.num_res_blocks+1): for i_level in reversed(range(self.num_resolutions)):
h = self.up[i_level].block[i_block](h, temb, **kwargs) for i_block in range(self.num_res_blocks + 1):
if len(self.up[i_level].attn) > 0: h_new = self.up[i_level].block[i_block](h, temb, **kwargs)
h = self.up[i_level].attn[i_block](h, **kwargs) if len(self.up[i_level].attn) > 0:
if i_level != 0: h_new = self.up[i_level].attn[i_block](h_new, **kwargs)
h = self.up[i_level].upsample(h) del h
model_management.soft_empty_cache()
h = h_new
if i_level != 0:
h_new = self.up[i_level].upsample(h)
del h
model_management.soft_empty_cache()
h = h_new
# end # end
if self.give_pre_end: if self.give_pre_end: