Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-12-17 01:52:59 +08:00
sd: fix VAE tiled fallback VRAM leak (#10139)
When the VAE catches a VRAM OOM, it launches the fallback logic straight from the exception context. Python, however, keeps references to the entire call stack that raised the exception, including its local variables, for the sake of exception reporting and debugging. When those locals are tensors, this can pin gigabytes of VRAM and prevent the allocator from freeing them. So discard the exception context completely before going back to the VAE via the tiler, by leaving the except block with nothing but a flag. This greatly increases the reliability of the tiler fallback, especially on low-VRAM cards: with the bug, if the leak happened to hold more than the headroom needed for a single tile, the tiler fallback would itself OOM and fail the whole flow.
Parent: bb32d4ec31
Commit: 911331c06c
comfy/sd.py (16 lines changed)
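For context, here is a minimal sketch (not code from this commit) of the pattern the fix adopts: record a flag inside the except block and run the tiled fallback only after the exception, and the traceback frames it references, have been released. The decode helpers and the use of torch.cuda.OutOfMemoryError are illustrative assumptions, not ComfyUI APIs.

import torch

def decode_with_fallback(full_decode, tiled_decode, samples):
    # full_decode / tiled_decode are hypothetical callables standing in for
    # the regular and tiled VAE paths; they are not ComfyUI functions.
    do_tile = False
    result = None
    try:
        # Tensors created in here are referenced by the exception traceback
        # if an OOM is raised, for as long as the except block is executing.
        result = full_decode(samples)
    except torch.cuda.OutOfMemoryError:
        # Do not run the fallback here: the live exception still pins the
        # failed attempt's intermediate tensors in VRAM. Only record a flag.
        do_tile = True

    if do_tile:
        # The exception state has been popped, so the traceback and the
        # frames (and tensors) it referenced can be garbage collected
        # before the tiled attempt allocates new memory.
        result = tiled_decode(samples)
    return result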
@@ -652,6 +652,7 @@ class VAE:
     def decode(self, samples_in, vae_options={}):
         self.throw_exception_if_invalid()
         pixel_samples = None
+        do_tile = False
         try:
             memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
             model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@@ -667,6 +668,13 @@ class VAE:
                 pixel_samples[x:x+batch_number] = out
         except model_management.OOM_EXCEPTION:
             logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+            #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+            #exception and the exception itself refs them all until we get out of this except block.
+            #So we just set a flag for tiler fallback so that tensor gc can happen once the
+            #exception is fully off the books.
+            do_tile = True
+
+        if do_tile:
             dims = samples_in.ndim - 2
             if dims == 1 or self.extra_1d_channel is not None:
                 pixel_samples = self.decode_tiled_1d(samples_in)
@@ -713,6 +721,7 @@ class VAE:
         self.throw_exception_if_invalid()
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         pixel_samples = pixel_samples.movedim(-1, 1)
+        do_tile = False
         if self.latent_dim == 3 and pixel_samples.ndim < 5:
             if not self.not_video:
                 pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
@@ -734,6 +743,13 @@ class VAE:
 
         except model_management.OOM_EXCEPTION:
             logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
+            #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+            #exception and the exception itself refs them all until we get out of this except block.
+            #So we just set a flag for tiler fallback so that tensor gc can happen once the
+            #exception is fully off the books.
+            do_tile = True
+
+        if do_tile:
             if self.latent_dim == 3:
                 tile = 256
                 overlap = tile // 4
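As a hedged illustration of the leak the commit message describes (this snippet is not part of the repository and assumes a CUDA device is available), a tensor local to the frame that raised stays allocated for as long as the exception is being handled, and is released once the except block ends:

import torch

def allocate_and_fail():
    # Local tensor in the frame that raises; the exception's traceback keeps
    # this frame, and therefore the tensor, alive while it is being handled.
    scratch = torch.empty(512, 1024, 1024, device="cuda")  # ~2 GiB of float32
    raise RuntimeError("simulated OOM")

def mib():
    # VRAM currently held by live tensors, in MiB.
    return torch.cuda.memory_allocated() / 2**20

try:
    allocate_and_fail()
except RuntimeError:
    # Still inside the except block: the active exception's traceback
    # references allocate_and_fail's frame and its local `scratch`.
    print(f"inside except: {mib():.1f} MiB")  # roughly 2048 MiB

# After the block the exception state is cleared, the frame is released,
# and the ~2 GiB becomes reclaimable by the caching allocator.
print(f"after except: {mib():.1f} MiB")  # roughly 0 MiB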