From 8e1d7087a3ce6fca751d58858a2405a4c3e52105 Mon Sep 17 00:00:00 2001
From: KBRASK
Date: Sun, 13 Apr 2025 12:03:50 +0800
Subject: [PATCH 1/2] Added explicit memory management during the VAE decode
 process.

---
 comfy/ldm/modules/diffusionmodules/model.py | 22 ++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 8162742cf..c59b03725 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -714,13 +714,21 @@ class Decoder(nn.Module):
         h = self.mid.block_2(h, temb, **kwargs)
 
         # upsampling
-        for i_level in reversed(range(self.num_resolutions)):
-            for i_block in range(self.num_res_blocks+1):
-                h = self.up[i_level].block[i_block](h, temb, **kwargs)
-                if len(self.up[i_level].attn) > 0:
-                    h = self.up[i_level].attn[i_block](h, **kwargs)
-                if i_level != 0:
-                    h = self.up[i_level].upsample(h)
+        with torch.no_grad():
+            for i_level in reversed(range(self.num_resolutions)):
+                for i_block in range(self.num_res_blocks + 1):
+                    h_new = self.up[i_level].block[i_block](h, temb, **kwargs)
+                    if len(self.up[i_level].attn) > 0:
+                        h_new = self.up[i_level].attn[i_block](h_new, **kwargs)
+                    del h
+                    torch.cuda.empty_cache()
+                    h = h_new
+
+                if i_level != 0:
+                    h_new = self.up[i_level].upsample(h)
+                    del h
+                    torch.cuda.empty_cache()
+                    h = h_new
 
         # end
         if self.give_pre_end:

From 0264149f4671649f5854f6669ecc538618a7c388 Mon Sep 17 00:00:00 2001
From: KBRASK
Date: Sun, 20 Apr 2025 15:47:18 +0800
Subject: [PATCH 2/2] Substitute torch.cuda.empty_cache() with
 model_management.soft_empty_cache() to support non-CUDA devices.

---
 comfy/ldm/modules/diffusionmodules/model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index c59b03725..0b2d92e94 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -721,13 +721,13 @@ class Decoder(nn.Module):
                     if len(self.up[i_level].attn) > 0:
                         h_new = self.up[i_level].attn[i_block](h_new, **kwargs)
                     del h
-                    torch.cuda.empty_cache()
+                    model_management.soft_empty_cache()
                     h = h_new
 
                 if i_level != 0:
                     h_new = self.up[i_level].upsample(h)
                     del h
-                    torch.cuda.empty_cache()
+                    model_management.soft_empty_cache()
                     h = h_new
 
         # end
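
Note on PATCH 1/2: during VAE decode the activations grow at every upsampling level, so briefly holding both a block's input and its output doubles the transient footprint exactly where tensors are largest. Rebinding through h_new, del-ing the old h, and flushing the allocator cache lets that memory be reused before the next block runs. A minimal, self-contained sketch of the same pattern follows; it is not ComfyUI code, and the toy block stack, channel count, and tensor sizes are made up for illustration:

import torch
import torch.nn as nn

# Hypothetical stand-in for the decoder's upsampling path: a stack of
# upsample+conv blocks applied under no_grad, freeing each intermediate
# as soon as its successor exists (the del h / empty_cache pattern above).
blocks = nn.ModuleList(
    nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(16, 16, kernel_size=3, padding=1),
    )
    for _ in range(4)
)

device = "cuda" if torch.cuda.is_available() else "cpu"
blocks.to(device)
if device == "cuda":
    torch.cuda.reset_peak_memory_stats()

with torch.no_grad():  # inference only, so no autograd graph pins tensors
    h = torch.randn(1, 16, 64, 64, device=device)
    for block in blocks:
        h_new = block(h)   # the old h is still referenced at this point...
        del h              # ...so drop it explicitly before rebinding
        if device == "cuda":
            torch.cuda.empty_cache()  # hand freed blocks back to the driver
        h = h_new

print(tuple(h.shape))
if device == "cuda":
    print(f"peak allocated: {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB")

The trade-off is speed: empty_cache() in a hot loop forces the caching allocator to re-request memory from the driver on the next allocation, which is part of why PATCH 2/2 routes the flush through a helper that each backend can implement (or skip) as appropriate.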
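
Note on PATCH 2/2: torch.cuda.empty_cache() silently does nothing on non-CUDA builds, so MPS and XPU users would see none of the memory relief from the first patch; model_management.soft_empty_cache() dispatches to whichever backend is active. No new import should be needed, assuming the tree this applies to still imports model_management at the top of model.py, where it is already used for attention-backend selection. Below is a hedged sketch of the kind of device dispatch such a helper performs; soft_empty_cache_sketch is a hypothetical name, and ComfyUI's real implementation differs in its details:

import torch

def soft_empty_cache_sketch() -> None:
    """Flush the allocator cache on whichever backend is present,
    rather than assuming CUDA. Illustrative only, not ComfyUI's code."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()  # also reclaim CUDA IPC handles
    elif hasattr(torch, "mps") and torch.backends.mps.is_available():
        torch.mps.empty_cache()   # Apple Silicon
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()   # Intel GPUs, newer PyTorch builds
    # plain CPU: the default allocator has no cache to flush

soft_empty_cache_sketch()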