execution: add aimdo primary pytorch cache integration

We need to do general pytorch cache defragmentation at an appropriate level for
aimdo. Do it here on a per-node basis, which has a reasonable chance of
purging stale shapes out of the pytorch caching allocator and saving VRAM
without costing too much garbage-collector thrash.

This looks like a lot of GC, but aimdo never fails an allocation back to
pytorch and saves the pytorch allocator from ever needing to defrag on demand;
it just needs an oil change every now and then, so we have to do it. Doing it
here also means the pytorch temporaries are cleared from task-manager VRAM
usage, so user anxiety can go down a little when they see their VRAM drop back
at the end of workflows in line with inference usage (rather than assuming the
whole of VRAM has leaked).
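
A minimal sketch of the effect described above (assuming a CUDA device; not part of this commit): reserved VRAM, which is what task manager reports, only drops back once empty_cache() returns the allocator's cached blocks to the driver.

import torch

x = torch.empty(256, 1024, 1024, device="cuda")  # ~1 GiB allocation
del x                                            # tensor freed, block stays cached
print(torch.cuda.memory_reserved())              # still ~1 GiB held by the cache
torch.cuda.empty_cache()                         # return unused blocks to the driver
print(torch.cuda.memory_reserved())              # drops back toward zero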
Rattus 2026-01-13 15:49:07 +10:00
parent 3597b27515
commit 64c2541b05
2 changed files with 25 additions and 1 deletion

comfy/memory_management.py

@@ -1,6 +1,10 @@
import torch
from comfy.quant_ops import QuantizedTensor
import comfy_aimdo.torch
import logging
def vram_aligned_size(tensor):
    if isinstance(tensor, list):
        return sum([vram_aligned_size(t) for t in tensor])
@@ -49,3 +53,5 @@ def interpret_gathered_like(tensors, gathered):
        dest_views.append(actuals["data"])
    return dest_views

aimdo_allocator = comfy_aimdo.torch.CUDAPluggableAllocator()
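
comfy_aimdo.torch.CUDAPluggableAllocator() is opaque in this diff; for orientation, a hedged sketch of the stock torch pluggable-allocator API it presumably wraps (the library path and symbol names below are hypothetical):

import torch

# torch.cuda.memory.CUDAPluggableAllocator loads malloc/free entry points from
# a shared library; .allocator() exposes the handle that MemPool accepts.
alloc = torch.cuda.memory.CUDAPluggableAllocator(
    "libaimdo_alloc.so",  # hypothetical path to the allocator library
    "aimdo_malloc",       # hypothetical symbol: void* f(size_t size, int device, cudaStream_t stream)
    "aimdo_free",         # hypothetical symbol: void f(void* ptr, size_t size, int device, cudaStream_t stream)
)
pool = torch.cuda.MemPool(alloc.allocator())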

execution.py

@@ -1,3 +1,4 @@
import gc
import copy
import heapq
import inspect
@@ -9,9 +10,12 @@ import traceback
from enum import Enum
from typing import List, Literal, NamedTuple, Optional, Union
import asyncio
from contextlib import nullcontext
import torch
import comfy.pinned_memory
import comfy.memory_management
import comfy.model_management
from latent_preview import set_preview_method
import nodes
@@ -515,7 +519,21 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
    def pre_execute_cb(call_index):
        # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
        GraphBuilder.set_default_prefix(unique_id, call_index, 0)
    output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
    # Do comfy_aimdo mempool chunking here at the per-node level. Multi-model workflows
    # cause all sorts of incompatible memory shapes that fragment the pytorch allocator;
    # we just want to cull those out after each model run.
    allocator = comfy.memory_management.aimdo_allocator
    with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
        output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
    torch.cuda.synchronize()
    if allocator is not None:
        # FIXME: this is probably a little zealous.
        # Torch code comments say some stuff about not actually freeing tensors on mempool
        # context release. Explicitly garbage collect now.
        gc.collect()
        torch.cuda.empty_cache()
    if has_pending_tasks:
        pending_async_nodes[unique_id] = output_data
        unblock = execution_list.add_external_block(unique_id)
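
For reference, a self-contained sketch of the per-node pattern the hunk above implements, assuming torch >= 2.5 (where torch.cuda.MemPool and torch.cuda.use_mem_pool exist); run_in_node_pool and node_fn are illustrative names, not ComfyUI or aimdo API:

import gc
from contextlib import nullcontext
import torch

def run_in_node_pool(allocator, node_fn, *args, **kwargs):
    # Route every CUDA allocation the node makes into its own mempool so one
    # model's mix of tensor shapes cannot fragment the global caching allocator.
    ctx = (nullcontext() if allocator is None
           else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())))
    with ctx:
        result = node_fn(*args, **kwargs)
    torch.cuda.synchronize()      # let in-flight kernels finish before culling
    if allocator is not None:
        gc.collect()              # drop Python refs still pinning pool tensors
        torch.cuda.empty_cache()  # hand cached blocks back to the driver
    return result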