From d80df37ac45075be0d1576e4c7a885de2f42dc6a Mon Sep 17 00:00:00 2001
From: tvukovic-amd
Date: Tue, 27 Jan 2026 09:19:25 +0100
Subject: [PATCH] fix: add memory precheck before VAE decode to prevent crash

---
 comfy/model_management.py | 63 +++++++++++++++++++++++++++++++++++++++
 comfy/sd.py               | 61 ++++++++++++++++++++++++++------------
 2 files changed, 105 insertions(+), 19 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 9d39be7b2..453623d00 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1093,6 +1093,69 @@ def sync_stream(device, stream):
         return
     current_stream(device).wait_stream(stream)
 
+def use_tiled_vae_decode(memory_needed, device=None):
+    try:
+        if device is None:
+            device = get_torch_device()
+
+        # If everything runs on the CPU, no GPU memory check is needed
+        if cpu_state == CPUState.CPU or args.cpu_vae:
+            return False
+
+        inference_memory = minimum_inference_memory()
+        memory_required = max(inference_memory, memory_needed + extra_reserved_memory())
+
+        gpu_free = get_free_memory(device)
+        cpu_free = psutil.virtual_memory().available
+
+        # Does the GPU have enough free memory for a full decode (including reserves)?
+        if gpu_free >= memory_required:
+            return False
+
+        # With --gpu-only, models can't offload to the CPU (offload device = GPU)
+        if args.gpu_only:
+            return True
+
+        # How much GPU memory has to be freed for a full decode
+        memory_to_free = memory_required - gpu_free
+
+        # How much can be offloaded from the currently loaded models: only count models whose offload_device is the CPU.
+        # With --highvram, the UNet has offload_device=GPU, so it can't be offloaded.
+        loaded_model_memory = 0
+        cpu_offloadable_memory = 0
+
+        for loaded_model in current_loaded_models:
+            if loaded_model.device == device:
+                model_size = loaded_model.model_loaded_memory()
+                loaded_model_memory += model_size
+                if hasattr(loaded_model.model, 'offload_device'):
+                    offload_dev = loaded_model.model.offload_device
+                    if is_device_cpu(offload_dev):
+                        cpu_offloadable_memory += model_size
+                else:
+                    cpu_offloadable_memory += model_size
+
+        # Is there enough offloadable model memory to cover the shortfall?
+        if cpu_offloadable_memory < memory_to_free:
+            return True  # Can't offload enough, must tile
+
+        # Can system RAM actually receive the offload? (this is what prevents the 0xC0000005 crash)
+        # Smart memory ON (default): partial offload, only memory_to_free bytes move to the CPU.
+        # Smart memory OFF (--disable-smart-memory): full offload, ALL loaded models get fully unloaded.
+        if DISABLE_SMART_MEMORY:
+            if cpu_free < loaded_model_memory:
+                return True
+            else:
+                return False
+        else:
+            # With smart memory, only a partial offload (memory_to_free bytes) moves to the CPU
+            if cpu_free < memory_to_free:
+                return True
+            return False
+
+    except Exception:
+        return True  # If the precheck itself fails, fall back to the safe tiled path
+
 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
     if device is None or weight.device == device:
         if not copy:
diff --git a/comfy/sd.py b/comfy/sd.py
index f627f7d55..1b0322385 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -894,29 +894,52 @@ class VAE:
         do_tile = False
         if self.latent_dim == 2 and samples_in.ndim == 5:
             samples_in = samples_in[:, :, 0]
-        try:
-            memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
-            model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
-            free_memory = model_management.get_free_memory(self.device)
-            batch_number = int(free_memory / memory_used)
-            batch_number = max(1, batch_number)
-
-            for x in range(0, samples_in.shape[0], batch_number):
-                samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
-                out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
-                if pixel_samples is None:
-                    pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
-                pixel_samples[x:x+batch_number] = out
-        except model_management.OOM_EXCEPTION:
-            logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
-            #NOTE: We don't know what tensors were allocated to stack variables at the time of the
-            #exception and the exception itself refs them all until we get out of this except block.
-            #So we just set a flag for tiler fallback so that tensor gc can happen once the
-            #exception is fully off the books.
+
+        memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
+
+        # Memory precheck: switch to tiled decoding when the GPU can't fit a full decode
+        # and models can't be offloaded to the CPU, preventing the 0xC0000005 crash
+        if model_management.use_tiled_vae_decode(memory_used, self.device):
+            logging.warning("[VAE DECODE] Insufficient memory for regular VAE decoding, switching to tiled VAE decoding.")
             do_tile = True
+
+        if not do_tile:
+            try:
+                model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+                free_memory = model_management.get_free_memory(self.device)
+                batch_number = int(free_memory / memory_used)
+                batch_number = max(1, batch_number)
+
+                for x in range(0, samples_in.shape[0], batch_number):
+                    samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
+                    out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
+                    if pixel_samples is None:
+                        pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
+                    pixel_samples[x:x+batch_number] = out
+
+            except model_management.OOM_EXCEPTION:
+                logging.warning("[VAE DECODE] Ran out of memory during regular VAE decoding, retrying with tiled VAE decoding.")
+                #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+                #exception and the exception itself refs them all until we get out of this except block.
+                #So we just set a flag for tiler fallback so that tensor gc can happen once the
+                #exception is fully off the books.
+                do_tile = True
 
         if do_tile:
             dims = samples_in.ndim - 2
+            if dims == 1:
+                tile_shape = (1, samples_in.shape[1], 128)  # 1D tile estimate
+            elif dims == 2:
+                tile_shape = (1, samples_in.shape[1], 64, 64)  # 2D tile: 64x64
+            else:
+                tile = 256 // self.spacial_compression_decode()
+                tile_shape = (1, samples_in.shape[1], 8, tile, tile)  # 3D tile estimate
+
+            # Estimate the memory needed to decode a single tile
+            tile_memory = self.memory_used_decode(tile_shape, self.vae_dtype)
+
+            model_management.load_models_gpu([self.patcher], memory_required=tile_memory, force_full_load=self.disable_offload)
+
             if dims == 1 or self.extra_1d_channel is not None:
                 pixel_samples = self.decode_tiled_1d(samples_in)
             elif dims == 2: