diff --git a/comfy/model_management.py b/comfy/model_management.py
index 527197447..888cea5c3 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1120,7 +1120,8 @@ def get_cast_buffer(offload_stream, device, size, ref):
 def reset_cast_buffers():
     global LARGEST_CASTED_WEIGHT
     LARGEST_CASTED_WEIGHT = (None, 0)
-    torch.cuda.synchronize()
+    for offload_stream in STREAM_CAST_BUFFERS:
+        offload_stream.synchronize()
     STREAM_CAST_BUFFERS.clear()
     torch.cuda.empty_cache()
 
diff --git a/execution.py b/execution.py
index a25bd36cd..30bf50b9d 100644
--- a/execution.py
+++ b/execution.py
@@ -523,8 +523,11 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             #that we just want to cull out each model run.
             allocator = comfy.memory_management.aimdo_allocator
             with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
-                output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
-                torch.cuda.synchronize()
+                try:
+                    output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
+                finally:
+                    if allocator is not None:
+                        torch.cuda.synchronize()
             if allocator is not None:
                 #FIXME: this is probably a little zealous
                 # Torch code comments says some stuff about not actually freeing tensors on mempool
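
For context, the first hunk replaces a device-wide `torch.cuda.synchronize()` with a per-stream synchronize over the offload streams tracked in `STREAM_CAST_BUFFERS`, so only the streams that actually own cast buffers are drained before those buffers are dropped. The sketch below is not code from this repo (`copy_then_release` and its tensor argument are made up); it is a minimal illustration of why stream-scoped synchronization is sufficient before a staging buffer is reused or freed:

```python
import torch

def copy_then_release(host_tensor):
    # CPU-only fallback so the sketch stays runnable without a GPU.
    if not torch.cuda.is_available():
        return host_tensor
    offload_stream = torch.cuda.Stream()
    with torch.cuda.stream(offload_stream):
        # The non_blocking copy is only safe to consume once this stream drains.
        dev = host_tensor.pin_memory().to("cuda", non_blocking=True)
    # Wait only on the stream that issued the copy; other streams (e.g. the
    # compute stream) keep running, unlike a device-wide torch.cuda.synchronize().
    offload_stream.synchronize()
    return dev
```

The second hunk is the error-handling counterpart: the post-node `torch.cuda.synchronize()` moves into a `finally:` block so the device is drained before the custom MemPool is torn down even if the node raises, and it is skipped entirely when no aimdo allocator is in use.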