From 64c2541b0545de224e3dda660d41467be02bc4a2 Mon Sep 17 00:00:00 2001
From: Rattus
Date: Tue, 13 Jan 2026 15:49:07 +1000
Subject: [PATCH] execution: add aimdo primary pytorch cache integration

We need to do general pytorch cache defragmentation at an appropriate
level for aimdo. Do it here on a per-node basis, which has a reasonable
chance of purging stale shapes out of the pytorch caching allocator and
saving VRAM without costing too much garbage-collector thrash.

This looks like a lot of GC, but aimdo never fails from pytorch and saves
the pytorch allocator from ever needing to defrag on demand; it just needs
an oil change every now and then, so we have to do it.

Doing it here also means the pytorch temporaries are cleared from task
manager VRAM usage, so user anxiety can go down a little when they see
their VRAM drop back at the end of a workflow in line with inference usage
(rather than assuming VRAM leaks).
---
 comfy/memory_management.py |  6 ++++++
 execution.py               | 20 +++++++++++++++++++-
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/comfy/memory_management.py b/comfy/memory_management.py
index f8bca5263..88b6da1e3 100644
--- a/comfy/memory_management.py
+++ b/comfy/memory_management.py
@@ -1,6 +1,10 @@
 import torch
 from comfy.quant_ops import QuantizedTensor
 
+import comfy_aimdo.torch
+
+import logging
+
 def vram_aligned_size(tensor):
     if isinstance(tensor, list):
         return sum([vram_aligned_size(t) for t in tensor])
@@ -49,3 +53,5 @@ def interpret_gathered_like(tensors, gathered):
         dest_views.append(actuals["data"])
 
     return dest_views
+
+aimdo_allocator = comfy_aimdo.torch.CUDAPluggableAllocator()
diff --git a/execution.py b/execution.py
index 648f204ec..fe162db26 100644
--- a/execution.py
+++ b/execution.py
@@ -1,3 +1,4 @@
+import gc
 import copy
 import heapq
 import inspect
@@ -9,9 +10,12 @@ import traceback
 from enum import Enum
 from typing import List, Literal, NamedTuple, Optional, Union
 import asyncio
+from contextlib import nullcontext
 
 import torch
 
+import comfy.pinned_memory
+import comfy.memory_management
 import comfy.model_management
 from latent_preview import set_preview_method
 import nodes
@@ -515,7 +519,21 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
         def pre_execute_cb(call_index):
             # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
             GraphBuilder.set_default_prefix(unique_id, call_index, 0)
-        output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
+
+        # Do comfy_aimdo mempool chunking here at the per-node level. Multi-model workflows
+        # cause all sorts of incompatible memory shapes to fragment the pytorch allocator,
+        # and we want to cull those after each model run.
+        allocator = comfy.memory_management.aimdo_allocator
+        with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
+            output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
+            torch.cuda.synchronize()
+            if allocator is not None:
+                # FIXME: this is probably a little zealous.
+                # Torch's code comments say tensors are not actually freed on mempool
+                # context release, so explicitly garbage collect now.
+                gc.collect()
+                torch.cuda.empty_cache()
+
         if has_pending_tasks:
             pending_async_nodes[unique_id] = output_data
             unblock = execution_list.add_external_block(unique_id)
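
For reference, here is a minimal sketch of the per-node pool pattern the execution.py hunk
introduces, assuming only stock PyTorch (the torch.cuda.MemPool / torch.cuda.use_mem_pool
APIs already used in the patch) and the default caching allocator in place of the
comfy_aimdo pluggable allocator; the helper name run_node_in_private_pool and the node_fn
callable are made up for illustration:

    import gc
    from contextlib import nullcontext

    import torch


    def run_node_in_private_pool(node_fn, *args, **kwargs):
        # Hypothetical helper: run one node's work inside its own CUDA memory pool
        # so its temporary allocations do not intermingle with, and fragment, the
        # caching-allocator state used by the next model.
        use_pool = torch.cuda.is_available()
        ctx = torch.cuda.use_mem_pool(torch.cuda.MemPool()) if use_pool else nullcontext()

        with ctx:
            result = node_fn(*args, **kwargs)
            if use_pool:
                # Let in-flight kernels finish before the pool context ends.
                torch.cuda.synchronize()

        if use_pool:
            # Leaving the pool context may not release its memory by itself (see the
            # FIXME above), so mirror the patch: collect dead tensors, then empty the cache.
            gc.collect()
            torch.cuda.empty_cache()
        return result

In the patch itself this wrapping happens around get_output_data() for each node, with the
pool backed by comfy_aimdo's CUDAPluggableAllocator rather than the default allocator.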