mm: don't empty_cache with mempools

Two things.

* PyTorch 2.7 crashes if you try to empty_cache in the presence of mempools.
* empty_cache never actually frees anything from a mempool, because the
mempool itself is considered a ref.

Guard the code accordingly and remove the useless empty_cache calls.

The offload stream resizer will need some fixing.
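
For reference, a minimal sketch of the guard pattern this commit applies (illustrative only, not the fork's exact API; aimdo_allocator below stands in for the comfy.memory_management.aimdo_allocator handle referenced in the diffs):

import torch

# Illustrative stand-in for the mempool-backed allocator handle; in the fork
# it is non-None whenever the Aimdo mempool allocator is active.
aimdo_allocator = None

def soft_empty_cache_sketch():
    # Only touch the CUDA caching allocator when no mempool is in play:
    # PyTorch 2.7 and earlier can crash on empty_cache() while a mempool
    # exists, and the mempool holds a reference to its blocks anyway, so
    # nothing would actually be released.
    if torch.cuda.is_available() and aimdo_allocator is None:
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()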
Rattus 2026-01-29 01:42:38 +10:00
parent f8f9a89f6e
commit 8067cb4f93
3 changed files with 10 additions and 11 deletions


@@ -1114,6 +1114,7 @@ def get_cast_buffer(offload_stream, device, size, ref):
         torch.cuda.synchronize()
         del STREAM_CAST_BUFFERS[offload_stream]
         del cast_buffer
+        #FIXME: This doesn't work in Aimdo because mempool cant clear cache
         torch.cuda.empty_cache()
     with wf_context:
         cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
@@ -1130,7 +1131,9 @@ def reset_cast_buffers():
     for offload_stream in STREAM_CAST_BUFFERS:
         offload_stream.synchronize()
     STREAM_CAST_BUFFERS.clear()
-    torch.cuda.empty_cache()
+    if comfy.memory_management.aimdo_allocator is None:
+        #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
+        torch.cuda.empty_cache()
 
 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
@@ -1686,9 +1689,11 @@ def soft_empty_cache(force=False):
     elif is_mlu():
         torch.mlu.empty_cache()
     elif torch.cuda.is_available():
-        torch.cuda.synchronize()
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
+        if comfy.memory_management.aimdo_allocator is None:
+            #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
+            torch.cuda.synchronize()
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
 
 def unload_all_models():
     free_memory(1e30, get_torch_device())


@@ -934,7 +934,6 @@ class VAE:
             do_tile = True
 
         if do_tile:
-            torch.cuda.empty_cache()
             dims = samples_in.ndim - 2
             if dims == 1 or self.extra_1d_channel is not None:
                 pixel_samples = self.decode_tiled_1d(samples_in)
@@ -1010,7 +1009,6 @@ class VAE:
             do_tile = True
 
         if do_tile:
-            torch.cuda.empty_cache()
             if self.latent_dim == 3:
                 tile = 256
                 overlap = tile // 4


@@ -527,12 +527,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
         finally:
             if allocator is not None:
                 comfy.model_management.reset_cast_buffers()
+            torch.cuda.synchronize()
-            if allocator is not None:
-                #FIXME: this is probably a little zealous
-                # Torch code comments says some stuff about not actually freeing tensors on mempool
-                #context release. Explicitly garbage collect now.
-                torch.cuda.empty_cache()
 
         if has_pending_tasks:
             pending_async_nodes[unique_id] = output_data