comfy-aimdo 0.2 - Improved pytorch allocator integration (#12557)

Integrate comfy-aimdo 0.2 which takes a different approach to
installing the memory allocator hook. Instead of using the complicated
and buggy pytorch MemPool+CudaPluggableAllocator, cuda is directly hooked
making the process much more transparent to both comfy and pytorch. As
far as pytorch knows, aimdo doesn't exist anymore, and just operates
behind the scenes.

Remove all the mempool setup stuff for dynamic_vram and bump the
comfy-aimdo version. Remove the allocator object from memory_management
and demote its use as an enablement check to a boolean flag.

Comfy-aimdo 0.2 also supports the pytorch cuda async allocator, so
remove the dynamic_vram based force disablement of cuda_malloc and
just go back to the old settings of allocators based on command line
input.
This commit is contained in:
rattus 2026-02-21 10:52:57 -08:00 committed by GitHub
parent 602b2505a4
commit 0bfb936ab4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 18 additions and 32 deletions

View File

@ -78,4 +78,4 @@ def interpret_gathered_like(tensors, gathered):
return dest_views return dest_views
aimdo_allocator = None aimdo_enabled = False

View File

@ -836,7 +836,7 @@ def unet_inital_load_device(parameters, dtype):
mem_dev = get_free_memory(torch_dev) mem_dev = get_free_memory(torch_dev)
mem_cpu = get_free_memory(cpu_dev) mem_cpu = get_free_memory(cpu_dev)
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None: if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled:
return torch_dev return torch_dev
else: else:
return cpu_dev return cpu_dev
@ -1121,7 +1121,6 @@ def get_cast_buffer(offload_stream, device, size, ref):
synchronize() synchronize()
del STREAM_CAST_BUFFERS[offload_stream] del STREAM_CAST_BUFFERS[offload_stream]
del cast_buffer del cast_buffer
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
soft_empty_cache() soft_empty_cache()
with wf_context: with wf_context:
cast_buffer = torch.empty((size), dtype=torch.int8, device=device) cast_buffer = torch.empty((size), dtype=torch.int8, device=device)

View File

@ -1154,7 +1154,7 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
def model_trange(*args, **kwargs): def model_trange(*args, **kwargs):
if comfy.memory_management.aimdo_allocator is None: if not comfy.memory_management.aimdo_enabled:
return trange(*args, **kwargs) return trange(*args, **kwargs)
pbar = trange(*args, **kwargs, smoothing=1.0) pbar = trange(*args, **kwargs, smoothing=1.0)

View File

@ -1,10 +1,8 @@
import os import os
import importlib.util import importlib.util
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram from comfy.cli_args import args, PerformanceFeature
import subprocess import subprocess
import comfy_aimdo.control
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names(): def get_gpu_names():
if os.name == 'nt': if os.name == 'nt':
@ -87,10 +85,6 @@ if not args.cuda_malloc:
except: except:
pass pass
if enables_dynamic_vram() and comfy_aimdo.control.init():
args.cuda_malloc = False
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ""
if args.disable_cuda_malloc: if args.disable_cuda_malloc:
args.cuda_malloc = False args.cuda_malloc = False

View File

@ -9,7 +9,6 @@ import traceback
from enum import Enum from enum import Enum
from typing import List, Literal, NamedTuple, Optional, Union from typing import List, Literal, NamedTuple, Optional, Union
import asyncio import asyncio
from contextlib import nullcontext
import torch import torch
@ -521,19 +520,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0) GraphBuilder.set_default_prefix(unique_id, call_index, 0)
#Do comfy_aimdo mempool chunking here on the per-node level. Multi-model workflows try:
#will cause all sorts of incompatible memory shapes to fragment the pytorch alloc output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
#that we just want to cull out each model run. finally:
allocator = comfy.memory_management.aimdo_allocator if comfy.memory_management.aimdo_enabled:
with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())): if args.verbose == "DEBUG":
try: comfy_aimdo.control.analyze()
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) comfy.model_management.reset_cast_buffers()
finally: comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
if allocator is not None:
if args.verbose == "DEBUG":
comfy_aimdo.model_vbar.vbars_analyze()
comfy.model_management.reset_cast_buffers()
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
if has_pending_tasks: if has_pending_tasks:
pending_async_nodes[unique_id] = output_data pending_async_nodes[unique_id] = output_data

11
main.py
View File

@ -173,6 +173,10 @@ import gc
if 'torch' in sys.modules: if 'torch' in sys.modules:
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
import comfy_aimdo.control
if enables_dynamic_vram():
comfy_aimdo.control.init()
import comfy.utils import comfy.utils
@ -188,13 +192,9 @@ import hook_breaker_ac10a0
import comfy.memory_management import comfy.memory_management
import comfy.model_patcher import comfy.model_patcher
import comfy_aimdo.control
import comfy_aimdo.torch
if enables_dynamic_vram(): if enables_dynamic_vram():
if comfy.model_management.torch_version_numeric < (2, 8): if comfy.model_management.torch_version_numeric < (2, 8):
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
comfy.memory_management.aimdo_allocator = None
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
if args.verbose == 'DEBUG': if args.verbose == 'DEBUG':
comfy_aimdo.control.set_log_debug() comfy_aimdo.control.set_log_debug()
@ -208,11 +208,10 @@ if enables_dynamic_vram():
comfy_aimdo.control.set_log_info() comfy_aimdo.control.set_log_info()
comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic
comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator() comfy.memory_management.aimdo_enabled = True
logging.info("DynamicVRAM support detected and enabled") logging.info("DynamicVRAM support detected and enabled")
else: else:
logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
comfy.memory_management.aimdo_allocator = None
def cuda_malloc_warning(): def cuda_malloc_warning():

View File

@ -22,7 +22,7 @@ alembic
SQLAlchemy SQLAlchemy
av>=14.2.0 av>=14.2.0
comfy-kitchen>=0.2.7 comfy-kitchen>=0.2.7
comfy-aimdo>=0.1.8 comfy-aimdo>=0.2.0
requests requests
#non essential dependencies: #non essential dependencies: