mm: Use Aimdo raw allocator for cast buffers

pytorch manages allocation of growing buffers on streams poorly. Pyt
has no windows support for the expandable segments allocator (which is
the right tool for this job), while also segmenting the memory by
stream such that it can be generally re-used. So kick the problem to
aimdo which can just grow a virtual region thats freed per stream.
This commit is contained in:
Rattus 2026-02-16 18:49:56 +10:00
parent 24de8dc01b
commit 9cfd71e821
2 changed files with 26 additions and 5 deletions

View File

@ -31,6 +31,7 @@ from contextlib import nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.vram_buffer
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@ -1181,6 +1182,10 @@ stream_counters = {}
STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0)
STREAM_AIMDO_CAST_BUFFERS = {}
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
def get_cast_buffer(offload_stream, device, size, ref):
global LARGEST_CASTED_WEIGHT
@ -1214,13 +1219,26 @@ def get_cast_buffer(offload_stream, device, size, ref):
return cast_buffer
def get_aimdo_cast_buffer(offload_stream, device):
cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None)
if cast_buffer is None:
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
return cast_buffer
def reset_cast_buffers():
global LARGEST_CASTED_WEIGHT
global LARGEST_AIMDO_CASTED_WEIGHT
LARGEST_CASTED_WEIGHT = (None, 0)
for offload_stream in STREAM_CAST_BUFFERS:
offload_stream.synchronize()
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
if offload_stream is not None:
offload_stream.synchronize()
synchronize()
STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear()
soft_empty_cache()
def get_offload_stream(device):

View File

@ -138,10 +138,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
offload_stream = comfy.model_management.get_offload_stream(device)
if xfer_dest is None and offload_stream is not None:
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
if xfer_dest is None:
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
if cast_buffer.size() < dest_size and s is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
offload_stream = comfy.model_management.get_offload_stream(device)
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(dest_size), device)
if dest_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (s, dest_size)
if xfer_dest is None:
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
offload_stream = None