From 8067cb4f93de50b322187aae4921526f7bacf092 Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 29 Jan 2026 01:42:38 +1000 Subject: [PATCH] mm: don't empty_cache with mempools Two things. * PyTorch 2.7 crashes if you try to empty_cache in the presence of mempools. * mempools don't actually ever release memory on empty_cache because the mempool itself is considered a ref. Guard the code accordingly and remove useless empty_cache calls. The offload stream resizer will need some fixing. --- comfy/model_management.py | 13 +++++++++---- comfy/sd.py | 2 -- execution.py | 6 +----- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 804be7768..bb9ae5852 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1114,6 +1114,7 @@ def get_cast_buffer(offload_stream, device, size, ref): torch.cuda.synchronize() del STREAM_CAST_BUFFERS[offload_stream] del cast_buffer + #FIXME: This doesn't work in Aimdo because mempool cant clear cache torch.cuda.empty_cache() with wf_context: cast_buffer = torch.empty((size), dtype=torch.int8, device=device) @@ -1130,7 +1131,9 @@ def reset_cast_buffers(): for offload_stream in STREAM_CAST_BUFFERS: offload_stream.synchronize() STREAM_CAST_BUFFERS.clear() - torch.cuda.empty_cache() + if comfy.memory_management.aimdo_allocator is None: + #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist + torch.cuda.empty_cache() def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) @@ -1686,9 +1689,11 @@ def soft_empty_cache(force=False): elif is_mlu(): torch.mlu.empty_cache() elif torch.cuda.is_available(): - torch.cuda.synchronize() - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + if comfy.memory_management.aimdo_allocator is None: + #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() def unload_all_models(): free_memory(1e30, 
get_torch_device()) diff --git a/comfy/sd.py b/comfy/sd.py index 7e67c6919..fd0ac85e7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -934,7 +934,6 @@ class VAE: do_tile = True if do_tile: - torch.cuda.empty_cache() dims = samples_in.ndim - 2 if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) @@ -1010,7 +1009,6 @@ class VAE: do_tile = True if do_tile: - torch.cuda.empty_cache() if self.latent_dim == 3: tile = 256 overlap = tile // 4 diff --git a/execution.py b/execution.py index 9607e1636..93fafc4a2 100644 --- a/execution.py +++ b/execution.py @@ -527,12 +527,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) finally: if allocator is not None: + comfy.model_management.reset_cast_buffers() torch.cuda.synchronize() - if allocator is not None: - #FIXME: this is probably a little zealous - # Torch code comments says some stuff about not actually freeing tensors on mempool - #context release. Explicitly garbage collect now. - torch.cuda.empty_cache() if has_pending_tasks: pending_async_nodes[unique_id] = output_data