From 9cfd71e8212fc17a03be3e73f268ce3e25b3305b Mon Sep 17 00:00:00 2001 From: Rattus Date: Mon, 16 Feb 2026 18:49:56 +1000 Subject: [PATCH] mm: Use Aimdo raw allocator for cast buffers PyTorch manages allocation of growing buffers on streams poorly. PyTorch has no Windows support for the expandable segments allocator (which is the right tool for this job), while also segmenting the memory by stream such that it can be generally re-used. So kick the problem to aimdo, which can just grow a virtual region that's freed per stream. --- comfy/model_management.py | 22 ++++++++++++++++++++-- comfy/ops.py | 9 ++++++--- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 3b39d6080..afda2f086 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -31,6 +31,7 @@ from contextlib import nullcontext import comfy.memory_management import comfy.utils import comfy.quant_ops +import comfy_aimdo.vram_buffer class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -1181,6 +1182,10 @@ stream_counters = {} STREAM_CAST_BUFFERS = {} LARGEST_CASTED_WEIGHT = (None, 0) +STREAM_AIMDO_CAST_BUFFERS = {} +LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) + +DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3 def get_cast_buffer(offload_stream, device, size, ref): global LARGEST_CASTED_WEIGHT @@ -1214,13 +1219,26 @@ def get_cast_buffer(offload_stream, device, size, ref): return cast_buffer +def get_aimdo_cast_buffer(offload_stream, device): + cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None) + if cast_buffer is None: + cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index) + STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer + + return cast_buffer + def reset_cast_buffers(): global LARGEST_CASTED_WEIGHT + global LARGEST_AIMDO_CASTED_WEIGHT + LARGEST_CASTED_WEIGHT = (None, 0) - for offload_stream in 
STREAM_CAST_BUFFERS: - offload_stream.synchronize() + LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) + for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS): + if offload_stream is not None: + offload_stream.synchronize() synchronize() STREAM_CAST_BUFFERS.clear() + STREAM_AIMDO_CAST_BUFFERS.clear() soft_empty_cache() def get_offload_stream(device): diff --git a/comfy/ops.py b/comfy/ops.py index 050f7cda0..febde458d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -138,10 +138,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu dest_size = comfy.memory_management.vram_aligned_size(xfer_source) offload_stream = comfy.model_management.get_offload_stream(device) if xfer_dest is None and offload_stream is not None: - xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) - if xfer_dest is None: + cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) + if cast_buffer.size() < dest_size and s is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]: offload_stream = comfy.model_management.get_offload_stream(device) - xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(dest_size), device) + if dest_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]: + comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (s, dest_size) if xfer_dest is None: xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) offload_stream = None