Add stream host pin buffer for AIMDO casts

Introduce per-offload-stream HostBuffer reuse for pinned staging,
include it in cast buffer reset synchronization.

Defer actual casts that go via this pin path to a separate pass
such that the buffer can be allocated monolithically (to avoid
cudaHostRegister thrash).
This commit is contained in:
Rattus 2026-05-07 14:04:48 +10:00
parent cb44c74b12
commit 87f2f43bb7
2 changed files with 62 additions and 12 deletions

View File

@ -31,6 +31,7 @@ from contextlib import nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.host_buffer
import comfy_aimdo.vram_buffer
class VRAMState(Enum):
@ -1180,8 +1181,10 @@ STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0)
STREAM_AIMDO_CAST_BUFFERS = {}
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
STREAM_PIN_BUFFERS = {}
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
DEFAULT_PIN_BUFFER_PRIME_SIZE = 1024 ** 2
def get_cast_buffer(offload_stream, device, size, ref):
global LARGEST_CASTED_WEIGHT
@ -1220,21 +1223,32 @@ def get_aimdo_cast_buffer(offload_stream, device):
if cast_buffer is None:
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
return cast_buffer
def get_pin_buffer(offload_stream):
pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
if pin_buffer is None:
# A small non-zero default primes HostBuffer's larger virtual reservation.
pin_buffer = comfy_aimdo.host_buffer.HostBuffer(DEFAULT_PIN_BUFFER_PRIME_SIZE)
STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
elif offload_stream is not None:
offload_stream.synchronize()
return pin_buffer
def reset_cast_buffers():
global LARGEST_CASTED_WEIGHT
global LARGEST_AIMDO_CASTED_WEIGHT
LARGEST_CASTED_WEIGHT = (None, 0)
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS):
if offload_stream is not None:
offload_stream.synchronize()
synchronize()
STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear()
STREAM_PIN_BUFFERS.clear()
soft_empty_cache()
def get_offload_stream(device):

View File

@ -75,6 +75,8 @@ except:
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
@ -91,6 +93,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
offload_stream = None
cast_buffer = None
cast_buffer_offset = 0
stream_pin_hostbuf = None
stream_pin_offset = 0
stream_pin_queue = []
def ensure_offload_stream(module, required_size, check_largest):
nonlocal offload_stream
@ -124,6 +129,20 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
cast_buffer_offset += buffer_size
return buffer
def get_stream_pin_buffer_offset(buffer_size):
nonlocal stream_pin_hostbuf
nonlocal stream_pin_offset
if buffer_size == 0 or offload_stream is None:
return None
if stream_pin_hostbuf is None:
stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
offset = stream_pin_offset
stream_pin_offset += buffer_size
return offset
for s in comfy_modules:
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
@ -162,17 +181,21 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
if xfer_dest is None:
xfer_dest = get_cast_buffer(dest_size)
if signature is None and pin is None:
comfy.pinned_memory.pin_memory(s)
pin = comfy.pinned_memory.get_pin(s)
else:
pin = None
if pin is None:
if signature is None:
comfy.pinned_memory.pin_memory(s)
pin = comfy.pinned_memory.get_pin(s)
if pin is not None:
comfy.model_management.cast_to_gathered(xfer_source, pin)
xfer_source = [ pin ]
if pin is None:
pin_offset = get_stream_pin_buffer_offset(dest_size)
if pin_offset is not None:
stream_pin_queue.append((xfer_source, pin_offset, dest_size, xfer_dest))
xfer_source = None
if pin is not None:
comfy.model_management.cast_to_gathered(xfer_source, pin)
xfer_source = [ pin ]
#send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
if xfer_source is not None:
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
@ -186,6 +209,19 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
prefetch["needs_cast"] = needs_cast
s._prefetch = prefetch
if stream_pin_offset > 0:
stream_pin_hostbuf_size = getattr(stream_pin_hostbuf, "_comfy_stream_pin_size", stream_pin_hostbuf.size)
if stream_pin_hostbuf_size < stream_pin_offset:
stream_pin_hostbuf_size = stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM
stream_pin_hostbuf.extend(size=stream_pin_hostbuf_size, reallocate=True)
stream_pin_hostbuf._comfy_stream_pin_size = stream_pin_hostbuf_size
stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf, size=stream_pin_offset)
stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
comfy.model_management.cast_to_gathered(xfer_source, pin)
comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
return offload_stream