diff --git a/comfy/memory_management.py b/comfy/memory_management.py index f8bca5263..88b6da1e3 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -1,6 +1,10 @@ import torch from comfy.quant_ops import QuantizedTensor +import comfy_aimdo.torch + +import logging + def vram_aligned_size(tensor): if isinstance(tensor, list): return sum([vram_aligned_size(t) for t in tensor]) @@ -49,3 +53,5 @@ def interpret_gathered_like(tensors, gathered): dest_views.append(actuals["data"]) return dest_views + +aimdo_allocator = comfy_aimdo.torch.CUDAPluggableAllocator() diff --git a/execution.py b/execution.py index 648f204ec..fe162db26 100644 --- a/execution.py +++ b/execution.py @@ -1,3 +1,4 @@ +import gc import copy import heapq import inspect @@ -9,9 +10,12 @@ import traceback from enum import Enum from typing import List, Literal, NamedTuple, Optional, Union import asyncio +from contextlib import nullcontext import torch +import comfy.pinned_memory +import comfy.memory_management import comfy.model_management from latent_preview import set_preview_method import nodes @@ -515,7 +519,21 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + + #Do comfy_aimdo mempool chunking here on the per-node level. Multi-model workflows + #will cause all sorts of incompatible memory shapes to fragment the pytorch alloc + #that we just want to cull out each model run. + allocator = comfy.memory_management.aimdo_allocator + with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())): + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + torch.cuda.synchronize() + if allocator is not None: + #FIXME: this is probably a little zealous + # Torch code comments says some stuff about not actually freeing tensors on mempool + #context release. Explicitly garbage collect now. + gc.collect() + torch.cuda.empty_cache() + if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id)