diff --git a/comfy/model_management.py b/comfy/model_management.py index 804be7768..bb9ae5852 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1114,6 +1114,7 @@ def get_cast_buffer(offload_stream, device, size, ref): torch.cuda.synchronize() del STREAM_CAST_BUFFERS[offload_stream] del cast_buffer + #FIXME: This doesn't work in Aimdo because mempool can't clear cache torch.cuda.empty_cache() with wf_context: cast_buffer = torch.empty((size), dtype=torch.int8, device=device) @@ -1130,7 +1131,9 @@ def reset_cast_buffers(): for offload_stream in STREAM_CAST_BUFFERS: offload_stream.synchronize() STREAM_CAST_BUFFERS.clear() - torch.cuda.empty_cache() + if comfy.memory_management.aimdo_allocator is None: + #PyTorch 2.7 and earlier crashes if you try to empty_cache when mempools exist + torch.cuda.empty_cache() def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) @@ -1686,9 +1689,11 @@ def soft_empty_cache(force=False): elif is_mlu(): torch.mlu.empty_cache() elif torch.cuda.is_available(): - torch.cuda.synchronize() - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + if comfy.memory_management.aimdo_allocator is None: + #PyTorch 2.7 and earlier crashes if you try to empty_cache when mempools exist + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() def unload_all_models(): free_memory(1e30, get_torch_device()) diff --git a/comfy/sd.py b/comfy/sd.py index 7e67c6919..fd0ac85e7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -934,7 +934,6 @@ class VAE: do_tile = True if do_tile: - torch.cuda.empty_cache() dims = samples_in.ndim - 2 if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) @@ -1010,7 +1009,6 @@ class VAE: do_tile = True if do_tile: - torch.cuda.empty_cache() if self.latent_dim == 3: tile = 256 overlap = tile // 4 diff --git a/execution.py b/execution.py index 9607e1636..93fafc4a2 100644 --- a/execution.py +++ b/execution.py @@ -527,12
+527,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) finally: if allocator is not None: + comfy.model_management.reset_cast_buffers() torch.cuda.synchronize() - if allocator is not None: - #FIXME: this is probably a little zealous - # Torch code comments says some stuff about not actually freeing tensors on mempool - #context release. Explicitly garbage collect now. - torch.cuda.empty_cache() if has_pending_tasks: pending_async_nodes[unique_id] = output_data