fix: add memory precheck before VAE decode to prevent crash

tvukovic-amd 2026-01-27 09:19:25 +01:00
parent 09725967cf
commit d80df37ac4
2 changed files with 104 additions and 19 deletions
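
In short, the new precheck estimates how much memory a full VAE decode would need, compares that with what is free on the GPU, and only forces tiled decoding when the shortfall can neither be covered by offloading loaded models nor absorbed by free system RAM. A minimal, self-contained sketch of that decision follows; the helper name and the byte figures are illustrative only and not part of the commit:

# Illustrative sketch of the precheck logic; the real implementation is
# use_tiled_vae_decode() in the diff below.
GB = 1024 ** 3

def should_tile(memory_needed, gpu_free, cpu_free, cpu_offloadable, reserve=1 * GB):
    memory_required = memory_needed + reserve        # full decode plus a safety reserve
    if gpu_free >= memory_required:
        return False                                 # fits on the GPU as-is
    memory_to_free = memory_required - gpu_free      # shortfall to cover by offloading
    if cpu_offloadable < memory_to_free:
        return True                                  # nothing left to offload -> tile
    if cpu_free < memory_to_free:
        return True                                  # system RAM can't absorb the offload -> tile
    return False                                     # offloading will make enough room

# Example: a 10 GB decode with 4 GB free on the GPU, 8 GB of offloadable models,
# but only 3 GB of free system RAM -> 7 GB would have to move, so tiling is chosen.
print(should_tile(10 * GB, 4 * GB, 3 * GB, 8 * GB))  # True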


@@ -1093,6 +1093,69 @@ def sync_stream(device, stream):
        return
    current_stream(device).wait_stream(stream)

def use_tiled_vae_decode(memory_needed, device=None):
    try:
        if device is None:
            device = get_torch_device()

        # If everything runs on the CPU, no GPU memory check is needed
        if cpu_state == CPUState.CPU or args.cpu_vae:
            return False

        inference_memory = minimum_inference_memory()
        memory_required = max(inference_memory, memory_needed + extra_reserved_memory())
        gpu_free = get_free_memory(device)
        cpu_free = psutil.virtual_memory().available

        # Does the GPU have enough space for a full decode (with reserves)?
        if gpu_free >= memory_required:
            return False

        # With --gpu-only, models can't offload to the CPU (offload device = GPU)
        if args.gpu_only:
            return True

        # How much GPU memory needs to be freed
        memory_to_free = memory_required - gpu_free

        # Calculate how much can be offloaded from the currently loaded models - only count
        # models whose offload_device is the CPU.
        # With --highvram, the UNet has offload_device=GPU, so it CAN'T be offloaded.
        loaded_model_memory = 0
        cpu_offloadable_memory = 0
        for loaded_model in current_loaded_models:
            if loaded_model.device == device:
                model_size = loaded_model.model_loaded_memory()
                loaded_model_memory += model_size
                if hasattr(loaded_model.model, 'offload_device'):
                    offload_dev = loaded_model.model.offload_device
                    if is_device_cpu(offload_dev):
                        cpu_offloadable_memory += model_size
                else:
                    cpu_offloadable_memory += model_size

        # Is there enough that can be offloaded (to the CPU)?
        if cpu_offloadable_memory < memory_to_free:
            return True  # Can't offload enough, must tile

        # Can the CPU receive the offload? (This is what prevents the 0xC0000005 crash.)
        # Smart memory ON (default) - partial offload: only memory_to_free bytes move to the CPU
        # Smart memory OFF (--disable-smart-memory) - full offload: ALL models get fully unloaded
        if DISABLE_SMART_MEMORY:
            if cpu_free < loaded_model_memory:
                return True
            else:
                return False
        else:
            # With smart memory, only a partial offload (memory_to_free bytes) moves to the CPU
            if cpu_free < memory_to_free:
                return True

        return False
    except Exception:
        # Fail safe: if the check itself fails for any reason, fall back to tiled decoding
        return True

def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
    if device is None or weight.device == device:
        if not copy:
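
For reference, this is roughly how a decode path consults the new helper. The wrapper below is a sketch only: decode_with_precheck and the vae/samples placeholders are hypothetical, use_tiled_vae_decode and memory_used_decode come from this commit's diff, and VAE.decode_tiled is assumed to be ComfyUI's existing tiled decode entry point.

import logging

import comfy.model_management as model_management

def decode_with_precheck(vae, samples):
    """Hypothetical wrapper showing where the precheck slots into a decode call."""
    # Estimate what a full (non-tiled) decode of this latent batch would need.
    memory_used = vae.memory_used_decode(samples.shape, vae.vae_dtype)

    # Ask the precheck whether a full decode is safe on this device.
    if model_management.use_tiled_vae_decode(memory_used, vae.device):
        logging.warning("Not enough memory for a full VAE decode, using tiled decoding instead.")
        return vae.decode_tiled(samples)

    return vae.decode(samples)

After this commit, VAE.decode performs the same check internally (see the second file below), so the wrapper is only meant to illustrate the call pattern.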


@@ -894,29 +894,51 @@ class VAE:
        do_tile = False
        if self.latent_dim == 2 and samples_in.ndim == 5:
            samples_in = samples_in[:, :, 0]
        try:
            memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
            model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
            free_memory = model_management.get_free_memory(self.device)
            batch_number = int(free_memory / memory_used)
            batch_number = max(1, batch_number)
            for x in range(0, samples_in.shape[0], batch_number):
                samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
                out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
                if pixel_samples is None:
                    pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                pixel_samples[x:x+batch_number] = out
        except model_management.OOM_EXCEPTION:
            logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
            #NOTE: We don't know what tensors were allocated to stack variables at the time of the
            #exception and the exception itself refs them all until we get out of this except block.
            #So we just set a flag for tiler fallback so that tensor gc can happen once the
            #exception is fully off the books.
        memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)

        # Memory check: switch to tiled decode if the GPU can't fit a full decode
        # and the loaded models can't be offloaded to the CPU, preventing a 0xC0000005 crash
        if model_management.use_tiled_vae_decode(memory_used, self.device):
            logging.warning("[VAE DECODE] Insufficient memory for regular VAE decoding, switching to tiled VAE decoding.")
            do_tile = True

        if not do_tile:
            try:
                model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
                free_memory = model_management.get_free_memory(self.device)
                batch_number = int(free_memory / memory_used)
                batch_number = max(1, batch_number)
                for x in range(0, samples_in.shape[0], batch_number):
                    samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
                    out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
                    if pixel_samples is None:
                        pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                    pixel_samples[x:x+batch_number] = out
            except model_management.OOM_EXCEPTION:
                logging.warning("[VAE DECODE] OOM! Ran out of memory during regular VAE decoding, retrying with tiled VAE decoding.")
                #NOTE: We don't know what tensors were allocated to stack variables at the time of the
                #exception and the exception itself refs them all until we get out of this except block.
                #So we just set a flag for tiler fallback so that tensor gc can happen once the
                #exception is fully off the books.
                do_tile = True
        if do_tile:
            dims = samples_in.ndim - 2
            if dims == 1:
                tile_shape = (1, samples_in.shape[1], 128)  # 1D tile estimate
            elif dims == 2:
                tile_shape = (1, samples_in.shape[1], 64, 64)  # 2D tile: 64x64
            else:
                tile = 256 // self.spacial_compression_decode()
                tile_shape = (1, samples_in.shape[1], 8, tile, tile)  # 3D tile estimate
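            # Illustrative scale only (assuming an 8x spatial compression factor, as in SD-style VAEs):
            # the 2D tile_shape (1, C, 64, 64) would cover roughly a 512x512 pixel tile, and the
            # 3D branch would pick tile = 256 // 8 = 32 latent cells per spatial side.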
            # Calculate tile memory
            tile_memory = self.memory_used_decode(tile_shape, self.vae_dtype)
            model_management.load_models_gpu([self.patcher], memory_required=tile_memory, force_full_load=self.disable_offload)

            if dims == 1 or self.extra_1d_channel is not None:
                pixel_samples = self.decode_tiled_1d(samples_in)
            elif dims == 2: