Mirror of https://github.com/comfyanonymous/ComfyUI.git
fix: add memory precheck before VAE decode to prevent crash
This commit is contained in:
parent 09725967cf
commit d80df37ac4
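
In short, the commit computes how much memory a full VAE decode would need before running it, and switches to tiled decoding up front when neither the GPU nor a CPU offload can safely absorb that amount. The sketch below is a simplified outline of the control flow inside VAE.decode() after this commit, not the literal implementation; the actual diffs follow.

# Simplified pseudocode of the new decode path (illustration only):
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)

if model_management.use_tiled_vae_decode(memory_used, self.device):
    do_tile = True          # precheck: a full decode would not fit and offloading can't fix it
else:
    try:
        ...                 # regular batched decode, as before
    except model_management.OOM_EXCEPTION:
        do_tile = True      # existing fallback: a real OOM still triggers tiling

if do_tile:
    ...                     # tiled decode path (1D / 2D / 3D tiling depending on the latent)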
comfy/model_management.py
@@ -1093,6 +1093,69 @@ def sync_stream(device, stream):
        return
    current_stream(device).wait_stream(stream)

def use_tiled_vae_decode(memory_needed, device=None):
    try:
        if device is None:
            device = get_torch_device()

        # If everything runs on the CPU, there is no GPU memory to check.
        if cpu_state == CPUState.CPU or args.cpu_vae:
            return False

        inference_memory = minimum_inference_memory()
        memory_required = max(inference_memory, memory_needed + extra_reserved_memory())

        gpu_free = get_free_memory(device)
        cpu_free = psutil.virtual_memory().available

        # Does the GPU have enough space for a full decode (including reserves)?
        if gpu_free >= memory_required:
            return False

        # With --gpu-only, models can't offload to CPU (offload device = GPU).
        if args.gpu_only:
            return True

        # How much VRAM would have to be freed for the full decode to fit.
        memory_to_free = memory_required - gpu_free

        # How much of the currently loaded models could be offloaded: only count
        # models whose offload_device is CPU. With --highvram, the UNet has
        # offload_device=GPU, so it CAN'T be offloaded.
        loaded_model_memory = 0
        cpu_offloadable_memory = 0

        for loaded_model in current_loaded_models:
            if loaded_model.device == device:
                model_size = loaded_model.model_loaded_memory()
                loaded_model_memory += model_size
                if hasattr(loaded_model.model, 'offload_device'):
                    offload_dev = loaded_model.model.offload_device
                    if is_device_cpu(offload_dev):
                        cpu_offloadable_memory += model_size
                else:
                    cpu_offloadable_memory += model_size

        # Is there enough offloadable memory to free what the decode needs?
        if cpu_offloadable_memory < memory_to_free:
            return True  # Can't offload enough, must tile.

        # Can the CPU actually receive the offload? This is what prevents the
        # 0xC0000005 crash.
        # Smart memory ON (default): partial offload, only memory_to_free bytes move to CPU.
        # Smart memory OFF (--disable-smart-memory): full offload, ALL models get fully unloaded.
        if DISABLE_SMART_MEMORY:
            if cpu_free < loaded_model_memory:
                return True
            else:
                return False
        else:
            # With smart memory, only a partial offload (memory_to_free bytes) moves to CPU.
            if cpu_free < memory_to_free:
                return True
            return False

    except Exception:
        # On any unexpected error, fall back to the safer tiled path.
        return True

def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
    if device is None or weight.device == device:
        if not copy:
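
To make the checks in use_tiled_vae_decode concrete, here is a small walkthrough with made-up numbers. The byte values below are hypothetical and only illustrate how the branches interact; they are not measurements from ComfyUI.

# Hypothetical values, for illustration only.
GiB = 1024 ** 3

memory_required = 9 * GiB   # max(minimum_inference_memory(), memory_needed + extra_reserved_memory())
gpu_free        = 6 * GiB   # get_free_memory(device)
cpu_free        = 4 * GiB   # psutil.virtual_memory().available

memory_to_free = memory_required - gpu_free   # 3 GiB must leave VRAM before a full decode fits

loaded_model_memory    = 8 * GiB   # everything resident on this GPU
cpu_offloadable_memory = 7 * GiB   # the subset whose offload_device is CPU

# 7 GiB >= 3 GiB, so enough *could* be offloaded; the deciding question is whether
# system RAM can receive it:
#   smart memory ON  (default): only ~memory_to_free (3 GiB) moves to RAM -> 4 GiB free suffices -> full decode
#   smart memory OFF (--disable-smart-memory): all 8 GiB get unloaded -> 4 GiB free is too little -> tile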
comfy/sd.py
@@ -894,29 +894,51 @@ class VAE:
        do_tile = False
        if self.latent_dim == 2 and samples_in.ndim == 5:
            samples_in = samples_in[:, :, 0]

        memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)

        # Memory check: switch to tiled decode if the GPU can't fit a full decode
        # and models can't offload to CPU, preventing the 0xC0000005 crash.
        if model_management.use_tiled_vae_decode(memory_used, self.device):
            logging.warning("[VAE DECODE] Insufficient memory for regular VAE decoding, switching to tiled VAE decoding.")
            do_tile = True

        if not do_tile:
            try:
                model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
                free_memory = model_management.get_free_memory(self.device)
                batch_number = int(free_memory / memory_used)
                batch_number = max(1, batch_number)

                for x in range(0, samples_in.shape[0], batch_number):
                    samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
                    out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
                    if pixel_samples is None:
                        pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
                    pixel_samples[x:x+batch_number] = out

            except model_management.OOM_EXCEPTION:
                logging.warning("[VAE DECODE] OOM! Ran out of memory during regular VAE decoding, retrying with tiled VAE decoding.")
                # NOTE: We don't know what tensors were allocated to stack variables at the time of the
                # exception and the exception itself refs them all until we get out of this except block.
                # So we just set a flag for tiler fallback so that tensor gc can happen once the
                # exception is fully off the books.
                do_tile = True

        if do_tile:
            dims = samples_in.ndim - 2
            if dims == 1:
                tile_shape = (1, samples_in.shape[1], 128)  # 1D tile estimate
            elif dims == 2:
                tile_shape = (1, samples_in.shape[1], 64, 64)  # 2D tile: 64x64
            else:
                tile = 256 // self.spacial_compression_decode()
                tile_shape = (1, samples_in.shape[1], 8, tile, tile)  # 3D tile estimate

            # Calculate tile memory.
            tile_memory = self.memory_used_decode(tile_shape, self.vae_dtype)

            model_management.load_models_gpu([self.patcher], memory_required=tile_memory, force_full_load=self.disable_offload)

            if dims == 1 or self.extra_1d_channel is not None:
                pixel_samples = self.decode_tiled_1d(samples_in)
            elif dims == 2:
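
The tiled path can ask load_models_gpu for far less memory because memory_used_decode grows with the number of latent elements, so estimating against a 64x64 tile instead of the full latent shrinks the requirement roughly in proportion to the spatial area. A rough, hypothetical illustration (the estimator below is a stand-in with made-up constants, not ComfyUI's actual formula):

import math

# Stand-in estimator: assume decode memory scales with the number of latent
# elements times a constant blow-up factor (hypothetical constants).
def approx_decode_memory(shape, bytes_per_element=2, blowup=2000):
    return math.prod(shape) * bytes_per_element * blowup

full_latent = (1, 4, 128, 128)   # e.g. a 1024x1024 image at 8x spatial compression
tile_latent = (1, 4, 64, 64)     # the 2D tile shape used above

ratio = approx_decode_memory(full_latent) / approx_decode_memory(tile_latent)
print(ratio)   # 4.0 -> a full decode needs about 4x the memory of one 64x64 tile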