mirror of https://github.com/comfyanonymous/ComfyUI.git
fix: add memory precheck before VAE decode to prevent crash
commit d80df37ac4
parent 09725967cf
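At a high level, the change estimates how much memory a full VAE decode would need before running it, asks a new model_management helper whether the GPU can satisfy that (either directly or by offloading already-loaded models to system RAM), and goes straight to tiled decoding when it can't, instead of crashing mid-decode. A minimal standalone sketch of that control flow, with hypothetical callables standing in for the real ComfyUI internals:

# Illustrative sketch only: estimate_memory, can_fit_or_offload, decode_full and
# decode_tiled are hypothetical stand-ins, not ComfyUI APIs.
def decode_with_precheck(latents, estimate_memory, can_fit_or_offload, decode_full, decode_tiled):
    memory_needed = estimate_memory(latents)   # bytes a full decode would take
    if not can_fit_or_offload(memory_needed):  # precheck: pick tiling up front
        return decode_tiled(latents)
    try:
        return decode_full(latents)            # regular batched decode
    except MemoryError:                        # stand-in for the backend OOM exception
        return decode_tiled(latents)           # the existing OOM fallback still applies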
@@ -1093,6 +1093,69 @@ def sync_stream(device, stream):
         return
     current_stream(device).wait_stream(stream)
 
+def use_tiled_vae_decode(memory_needed, device=None):
+    try:
+        if device is None:
+            device = get_torch_device()
+
+        # If everything is running on the CPU, no GPU memory check is needed
+        if cpu_state == CPUState.CPU or args.cpu_vae:
+            return False
+
+        inference_memory = minimum_inference_memory()
+        memory_required = max(inference_memory, memory_needed + extra_reserved_memory())
+
+        gpu_free = get_free_memory(device)
+        cpu_free = psutil.virtual_memory().available
+
+        # Does the GPU have enough space for a full decode (with reserves)?
+        if gpu_free >= memory_required:
+            return False
+
+        # With --gpu-only, models can't offload to CPU (offload device = GPU)
+        if args.gpu_only:
+            return True
+
+        # How much GPU memory would have to be freed
+        memory_to_free = memory_required - gpu_free
+
+        # How much of the currently loaded models could be offloaded: only count
+        # models whose offload_device is CPU. With --highvram the UNet has
+        # offload_device=GPU, so it CAN'T be offloaded.
+        loaded_model_memory = 0
+        cpu_offloadable_memory = 0
+
+        for loaded_model in current_loaded_models:
+            if loaded_model.device == device:
+                model_size = loaded_model.model_loaded_memory()
+                loaded_model_memory += model_size
+                if hasattr(loaded_model.model, 'offload_device'):
+                    offload_dev = loaded_model.model.offload_device
+                    if is_device_cpu(offload_dev):
+                        cpu_offloadable_memory += model_size
+                else:
+                    cpu_offloadable_memory += model_size
+
+        # Is there enough offloadable model memory to cover what must be freed?
+        if cpu_offloadable_memory < memory_to_free:
+            return True  # Can't offload enough, must tile
+
+        # Check whether the CPU can receive the offload (this is what prevents the 0xC0000005 crash):
+        # smart memory ON (default) does a partial offload, only memory_to_free bytes move to CPU;
+        # smart memory OFF (--disable-smart-memory) does a full offload, ALL models get fully unloaded.
+        if DISABLE_SMART_MEMORY:
+            if cpu_free < loaded_model_memory:
+                return True
+            else:
+                return False
+        else:
+            # With smart memory, only a partial offload (memory_to_free bytes) moves to CPU
+            if cpu_free < memory_to_free:
+                return True
+            return False
+
+    except Exception:
+        # If the check itself fails for any reason, err on the side of tiled decoding
+        return True
+
 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
     if device is None or weight.device == device:
         if not copy:
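To make the decision arithmetic in use_tiled_vae_decode concrete, here is a small self-contained rendition of the same checks with made-up numbers (all byte figures below are hypothetical; the real function reads them from the device, psutil and the loaded-model list):

GiB = 1024 ** 3

# Hypothetical snapshot of a system under memory pressure (illustrative numbers only)
memory_required = 9 * GiB          # full decode plus reserves
gpu_free = 3 * GiB                 # free VRAM right now
cpu_free = 4 * GiB                 # free system RAM
cpu_offloadable_memory = 7 * GiB   # loaded models whose offload_device is CPU
loaded_model_memory = 10 * GiB     # everything currently resident on the GPU

memory_to_free = memory_required - gpu_free                   # 6 GiB must leave the GPU
enough_to_offload = cpu_offloadable_memory >= memory_to_free  # True: 7 GiB >= 6 GiB

# Smart memory ON: only memory_to_free must fit in RAM -> 4 GiB < 6 GiB, so tile
tile_with_smart_memory = cpu_free < memory_to_free
# Smart memory OFF: all loaded models would be unloaded -> 4 GiB < 10 GiB, so tile
tile_without_smart_memory = cpu_free < loaded_model_memory

print(memory_to_free // GiB, enough_to_offload, tile_with_smart_memory, tile_without_smart_memory)
# 6 True True True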
comfy/sd.py
@@ -894,29 +894,51 @@ class VAE:
         do_tile = False
         if self.latent_dim == 2 and samples_in.ndim == 5:
             samples_in = samples_in[:, :, 0]
-        try:
-            memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
-            model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
-            free_memory = model_management.get_free_memory(self.device)
-            batch_number = int(free_memory / memory_used)
-            batch_number = max(1, batch_number)
-
-            for x in range(0, samples_in.shape[0], batch_number):
-                samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
-                out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
-                if pixel_samples is None:
-                    pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
-                pixel_samples[x:x+batch_number] = out
-        except model_management.OOM_EXCEPTION:
-            logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
-            #NOTE: We don't know what tensors were allocated to stack variables at the time of the
-            #exception and the exception itself refs them all until we get out of this except block.
-            #So we just set a flag for tiler fallback so that tensor gc can happen once the
-            #exception is fully off the books.
-            do_tile = True
+        memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
+
+        # Memory precheck: switch to tiled decode if the GPU can't fit a full decode
+        # and the loaded models can't offload to CPU, preventing the 0xC0000005 crash
+        if model_management.use_tiled_vae_decode(memory_used, self.device):
+            logging.warning("[VAE DECODE] Insufficient memory for regular VAE decoding, switching to tiled VAE decoding.")
+            do_tile = True
+
+        if not do_tile:
+            try:
+                model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
+                free_memory = model_management.get_free_memory(self.device)
+                batch_number = int(free_memory / memory_used)
+                batch_number = max(1, batch_number)
+
+                for x in range(0, samples_in.shape[0], batch_number):
+                    samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
+                    out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
+                    if pixel_samples is None:
+                        pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
+                    pixel_samples[x:x+batch_number] = out
+
+            except model_management.OOM_EXCEPTION:
+                logging.warning("[VAE DECODE] OOM! Ran out of memory during regular VAE decoding, retrying with tiled VAE decoding.")
+                #NOTE: We don't know what tensors were allocated to stack variables at the time of the
+                #exception and the exception itself refs them all until we get out of this except block.
+                #So we just set a flag for tiler fallback so that tensor gc can happen once the
+                #exception is fully off the books.
+                do_tile = True
 
         if do_tile:
             dims = samples_in.ndim - 2
+            if dims == 1:
+                tile_shape = (1, samples_in.shape[1], 128)  # 1D tile estimate
+            elif dims == 2:
+                tile_shape = (1, samples_in.shape[1], 64, 64)  # 2D tile: 64x64
+            else:
+                tile = 256 // self.spacial_compression_decode()
+                tile_shape = (1, samples_in.shape[1], 8, tile, tile)  # 3D tile estimate
+
+            # Estimate how much memory a single tile needs and only require that much when loading the VAE
+            tile_memory = self.memory_used_decode(tile_shape, self.vae_dtype)
+
+            model_management.load_models_gpu([self.patcher], memory_required=tile_memory, force_full_load=self.disable_offload)
+
             if dims == 1 or self.extra_1d_channel is not None:
                 pixel_samples = self.decode_tiled_1d(samples_in)
             elif dims == 2:
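Why the tiled path asks for so much less memory up front: the tile_shape values above describe a single tile rather than the whole latent, so the memory_used_decode estimate handed to load_models_gpu shrinks roughly with the latent area being decoded at once. A rough, hypothetical comparison for a 2D VAE (the sizes below are illustrative and not taken from the code):

# A 1024x1024 image with the usual 8x spatial compression has a 128x128 latent;
# the 2D tile above covers only a 64x64 latent patch.
full_latent_elems = 1 * 4 * 128 * 128  # batch x channels x height x width
tile_latent_elems = 1 * 4 * 64 * 64

print(full_latent_elems / tile_latent_elems)
# 4.0 -> the per-tile estimate is about a quarter of the full-decode estimate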