mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-15 22:27:34 +08:00
mm: dont clear_cache with mempools
Two things. * pyt2.7 crashes if you try and clear_cache in the presence of mempools. * mempools don't actually ever clear_cache because the mempool itself is considered a ref. Guard the code accordingly and remove useless clear_cache calls. The offload stream resizer will need some fixing.
This commit is contained in:
parent
f8f9a89f6e
commit
8067cb4f93
@ -1114,6 +1114,7 @@ def get_cast_buffer(offload_stream, device, size, ref):
|
||||
torch.cuda.synchronize()
|
||||
del STREAM_CAST_BUFFERS[offload_stream]
|
||||
del cast_buffer
|
||||
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
|
||||
torch.cuda.empty_cache()
|
||||
with wf_context:
|
||||
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)
|
||||
@ -1130,7 +1131,9 @@ def reset_cast_buffers():
|
||||
for offload_stream in STREAM_CAST_BUFFERS:
|
||||
offload_stream.synchronize()
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
torch.cuda.empty_cache()
|
||||
if comfy.memory_management.aimdo_allocator is None:
|
||||
#Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_offload_stream(device):
|
||||
stream_counter = stream_counters.get(device, 0)
|
||||
@ -1686,9 +1689,11 @@ def soft_empty_cache(force=False):
|
||||
elif is_mlu():
|
||||
torch.mlu.empty_cache()
|
||||
elif torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
if comfy.memory_management.aimdo_allocator is None:
|
||||
#Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
|
||||
@ -934,7 +934,6 @@ class VAE:
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
torch.cuda.empty_cache()
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
@ -1010,7 +1009,6 @@ class VAE:
|
||||
do_tile = True
|
||||
|
||||
if do_tile:
|
||||
torch.cuda.empty_cache()
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
|
||||
@ -527,12 +527,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
|
||||
finally:
|
||||
if allocator is not None:
|
||||
comfy.model_management.reset_cast_buffers()
|
||||
torch.cuda.synchronize()
|
||||
if allocator is not None:
|
||||
#FIXME: this is probably a little zealous
|
||||
# Torch code comments says some stuff about not actually freeing tensors on mempool
|
||||
#context release. Explicitly garbage collect now.
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if has_pending_tasks:
|
||||
pending_async_nodes[unique_id] = output_data
|
||||
|
||||
Loading…
Reference in New Issue
Block a user