Implement JIT pinned memory pressure

Replace the predictive pin pressure mechanism with JIT pinned memory
pressure: instead of reserving pinned host RAM up front, a pin request
that would exceed MAX_PINNED_MEMORY now reclaims pinned memory from
already loaded models at the moment of the request, plus a hysteresis
margin to avoid thrashing.
Rattus 2026-05-07 19:55:41 +10:00
parent 5262d9efce
commit 663aaf8b96
4 changed files with 66 additions and 15 deletions
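For context: the "JIT" part means pinned-host-memory pressure is resolved at the moment a pin request arrives rather than predicted ahead of time. When a request would push TOTAL_PINNED_MEMORY past MAX_PINNED_MEMORY, ensure_pin_budget() walks the loaded models from coldest to hottest and asks each to release pinned RAM until the shortfall (plus a hysteresis margin) is covered. A minimal self-contained sketch of that mechanism; ToyModel and the eviction bookkeeping are illustrative stand-ins, not ComfyUI's actual ModelPatcher (which decrements TOTAL_PINNED_MEMORY itself):

    MAX_PINNED_MEMORY = 4 * 1024 ** 3            # pinned host RAM budget
    PIN_PRESSURE_HYSTERESIS = 128 * 1024 * 1024  # over-free to avoid thrashing
    TOTAL_PINNED_MEMORY = 5 * 1024 ** 3          # pretend we are over budget

    class ToyModel:
        """Illustrative stand-in for a loaded model holding pinned buffers."""
        def __init__(self, pinned_bytes):
            self.pinned = pinned_bytes

        def partially_unload_ram(self, goal):
            # Free up to `goal` bytes and report what was actually freed.
            freed = min(self.pinned, goal)
            self.pinned -= freed
            return freed

    loaded_models = [ToyModel(3 * 1024 ** 3), ToyModel(2 * 1024 ** 3)]  # newest first

    def ensure_pin_budget(size):
        """Evict just enough pinned RAM, just in time, to admit `size` bytes."""
        global TOTAL_PINNED_MEMORY
        shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
        if shortfall <= 0:
            return                              # budget already has room
        shortfall += PIN_PRESSURE_HYSTERESIS
        for model in reversed(loaded_models):   # coldest (oldest) models first
            freed = model.partially_unload_ram(shortfall)
            TOTAL_PINNED_MEMORY -= freed
            shortfall -= freed
            if shortfall <= 0:
                break

    ensure_pin_budget(512 * 1024 ** 2)  # frees ~1.6 GiB: 1.5 GiB shortfall + hysteresis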

View File

@@ -498,6 +498,8 @@ current_loaded_models = []
 DIRTY_MMAPS = set()
 
+PIN_PRESSURE_HYSTERESIS = 128 * 1024 * 1024
+
 def module_size(module):
     module_mem = 0
     sd = module.state_dict()
@@ -511,6 +513,21 @@ def mark_mmap_dirty(storage):
     if mmap_refs is not None:
         DIRTY_MMAPS.add(mmap_refs[0])
 
+def ensure_pin_budget(size, evict_active=False):
+    if MAX_PINNED_MEMORY <= 0:
+        return
+    shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
+    if shortfall <= 0:
+        return
+    shortfall += PIN_PRESSURE_HYSTERESIS
+    for loaded_model in reversed(current_loaded_models):
+        model = loaded_model.model
+        if model is not None and model.is_dynamic() and (evict_active or not model.dynamic_pins[model.load_device]["active"]):
+            shortfall -= model.partially_unload_ram(shortfall)
+            if shortfall <= 0:
+                break
+
 class LoadedModel:
     def __init__(self, model):
@@ -1133,7 +1150,6 @@ LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
 STREAM_PIN_BUFFERS = {}
 
 DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
-DEFAULT_PIN_BUFFER_PRIME_SIZE = 1024 ** 2
 
 def get_cast_buffer(offload_stream, device, size, ref):
     global LARGEST_CASTED_WEIGHT
@@ -1177,14 +1193,29 @@ def get_aimdo_cast_buffer(offload_stream, device):
 
 def get_pin_buffer(offload_stream):
     pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
     if pin_buffer is None:
-        # A small non-zero default primes HostBuffer's larger virtual reservation.
-        pin_buffer = comfy_aimdo.host_buffer.HostBuffer(DEFAULT_PIN_BUFFER_PRIME_SIZE)
+        pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0)
         STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
     elif offload_stream is not None:
         offload_stream.synchronize()
     return pin_buffer
 
+def resize_pin_buffer(pin_buffer, size):
+    global TOTAL_PINNED_MEMORY
+    old_size = getattr(pin_buffer, "_comfy_stream_pin_size", 0)
+    if size <= old_size:
+        return True
+    growth = size - old_size
+    ensure_pin_budget(growth, evict_active=True)
+    try:
+        pin_buffer.extend(size=size, reallocate=True)
+    except RuntimeError:
+        return False
+    pin_buffer._comfy_stream_pin_size = size
+    TOTAL_PINNED_MEMORY += growth
+    return True
+
 def reset_cast_buffers():
+    global TOTAL_PINNED_MEMORY
     global LARGEST_CASTED_WEIGHT
     global LARGEST_AIMDO_CASTED_WEIGHT
@@ -1197,8 +1228,18 @@ def reset_cast_buffers():
     for mmap_obj in DIRTY_MMAPS:
         mmap_obj.bounce()
     DIRTY_MMAPS.clear()
 
+    for pin_buffer in STREAM_PIN_BUFFERS.values():
+        TOTAL_PINNED_MEMORY -= getattr(pin_buffer, "_comfy_stream_pin_size", 0)
+    if TOTAL_PINNED_MEMORY < 0:
+        TOTAL_PINNED_MEMORY = 0
+
+    for loaded_model in current_loaded_models:
+        model = loaded_model.model
+        if model is not None and model.is_dynamic():
+            model.dynamic_pins[model.load_device]["active"] = False
+
     STREAM_CAST_BUFFERS.clear()
     STREAM_AIMDO_CAST_BUFFERS.clear()
     STREAM_PIN_BUFFERS.clear()
@@ -1344,8 +1385,7 @@ def pin_memory(tensor):
         return False
 
     size = tensor.nbytes
-    if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
-        return False
+    ensure_pin_budget(size)
 
     ptr = tensor.data_ptr()
     if ptr == 0:

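A note on the accounting in resize_pin_buffer above: _comfy_stream_pin_size acts as a per-buffer high-water mark, so only growth past it is charged to TOTAL_PINNED_MEMORY, and shrink or repeat requests are free. A small sketch of that invariant, with FakeBuffer standing in for comfy_aimdo.host_buffer.HostBuffer and the RuntimeError path elided:

    TOTAL_PINNED_MEMORY = 0

    class FakeBuffer:
        """Stand-in for HostBuffer: extend() just records the requested size."""
        def __init__(self):
            self.size = 0

        def extend(self, size, reallocate=False):
            self.size = max(self.size, size)

    def resize_pin_buffer(pin_buffer, size):
        """Charge only growth past the buffer's previous high-water mark."""
        global TOTAL_PINNED_MEMORY
        old_size = getattr(pin_buffer, "_comfy_stream_pin_size", 0)
        if size <= old_size:
            return True                  # shrink/no-op: nothing new to account
        growth = size - old_size
        pin_buffer.extend(size=size, reallocate=True)
        pin_buffer._comfy_stream_pin_size = size
        TOTAL_PINNED_MEMORY += growth
        return True

    buf = FakeBuffer()
    resize_pin_buffer(buf, 100)
    resize_pin_buffer(buf, 80)           # already covered: no double charge
    resize_pin_buffer(buf, 150)
    assert TOTAL_PINNED_MEMORY == 150    # charged 100, then 50, never 80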
View File

@@ -1085,7 +1085,7 @@ class ModelPatcher:
         return 0
 
     def partially_unload_ram(self, ram_to_unload):
-        pass
+        return 0
 
     def detach(self, unpatch_all=True):
         self.eject_model()
@@ -1495,7 +1495,12 @@ class ModelPatcherDynamic(ModelPatcher):
         if not hasattr(self.model, "dynamic_pins"):
             self.model.dynamic_pins = {}
         if self.load_device not in self.model.dynamic_pins:
-            self.model.dynamic_pins[self.load_device] = {"hostbuf": comfy_aimdo.host_buffer.HostBuffer(0), "stack": [], "failed": False}
+            self.model.dynamic_pins[self.load_device] = {
+                "hostbuf": comfy_aimdo.host_buffer.HostBuffer(0),
+                "stack": [],
+                "failed": False,
+                "active": False,
+            }
 
         self.non_dynamic_delegate_model = None
         assert load_device is not None
@@ -1559,6 +1564,7 @@ class ModelPatcherDynamic(ModelPatcher):
         vbar = self._vbar_get(create=True)
         pin_state = self.model.dynamic_pins[self.load_device]
         pin_state["failed"] = False
+        pin_state["active"] = True
 
         if vbar is not None:
             vbar.prioritize()
@@ -1688,9 +1694,10 @@ class ModelPatcherDynamic(ModelPatcher):
         return self.model.dynamic_pins[self.load_device]["hostbuf"].size
 
     def partially_unload_ram(self, ram_to_unload):
+        freed = 0
         pin_state = self.model.dynamic_pins[self.load_device]
         hostbuf = pin_state["hostbuf"]
-        stack = self.model.dynamic_pins[self.load_device]["stack"]
+        stack = pin_state["stack"]
         while len(stack) > 0:
             module, offset = stack.pop()
             size = module._pin.numel() * module._pin.element_size()
@@ -1699,9 +1706,11 @@ class ModelPatcherDynamic(ModelPatcher):
             comfy.model_management.TOTAL_PINNED_MEMORY -= size
             if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
                 comfy.model_management.TOTAL_PINNED_MEMORY = 0
+            freed += size
             ram_to_unload -= size
             if ram_to_unload <= 0:
-                return
+                return freed
+        return freed
 
     def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
         #This isn't used by the core at all and can only be to load a model out of

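The contract change above is load-bearing for ensure_pin_budget(): every partially_unload_ram() implementation must now report the number of bytes it actually freed (the ModelPatcher base case reports 0), because the caller subtracts the return value from its shortfall. A minimal sketch of a conforming override; the class names and byte counts are illustrative, not ComfyUI's:

    class BasePatcher:
        def partially_unload_ram(self, ram_to_unload):
            # Non-dynamic patchers hold no evictable pinned RAM.
            return 0

    class PinningPatcher(BasePatcher):
        def __init__(self):
            # Pinned allocation sizes in bytes, popped LIFO like the pin stack.
            self.stack = [256, 512, 1024]

        def partially_unload_ram(self, ram_to_unload):
            freed = 0
            while self.stack:
                size = self.stack.pop()
                freed += size
                ram_to_unload -= size
                if ram_to_unload <= 0:
                    break
            return freed  # always report what was actually freed

    p = PinningPatcher()
    assert p.partially_unload_ram(1000) == 1024    # one pop covers the goal
    assert p.partially_unload_ram(10_000) == 768   # drains the rest: 512 + 256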
View File

@@ -138,6 +138,8 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
             if stream_pin_hostbuf is None:
                 stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
+                if stream_pin_hostbuf is None:
+                    return None
             offset = stream_pin_offset
             stream_pin_offset += buffer_size
@@ -213,8 +215,10 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
     stream_pin_hostbuf_size = getattr(stream_pin_hostbuf, "_comfy_stream_pin_size", stream_pin_hostbuf.size)
     if stream_pin_hostbuf_size < stream_pin_offset:
         stream_pin_hostbuf_size = stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM
-        stream_pin_hostbuf.extend(size=stream_pin_hostbuf_size, reallocate=True)
-        stream_pin_hostbuf._comfy_stream_pin_size = stream_pin_hostbuf_size
+        if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_hostbuf_size):
+            for xfer_source, _, _, xfer_dest in stream_pin_queue:
+                comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
+            return offload_stream
 
     stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf, size=stream_pin_offset)
     stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
     for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:

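The hunk above also gives the pinned fast path graceful degradation: if resize_pin_buffer() cannot grow the staging buffer, the already-queued weights are flushed with plain unpinned transfers instead of failing the cast. A schematic of that pattern; everything here other than the resize_pin_buffer name is a placeholder, not ComfyUI's API:

    def resize_pin_buffer_stub(buf, size):
        """Stand-in: pretend the pin can only grow up to a fixed capacity."""
        return size <= buf["capacity"]

    def flush_queue(queue, pin_buffer, needed_size):
        """Prefer pinned staging; degrade to direct copies when pinning fails."""
        if not resize_pin_buffer_stub(pin_buffer, needed_size):
            return [("direct", src, dst) for src, dst in queue]  # unpinned fallback
        return [("staged", src, dst) for src, dst in queue]      # pinned fast path

    buf = {"capacity": 64}
    queue = [("w0", "gpu:w0"), ("w1", "gpu:w1")]
    assert flush_queue(queue, buf, 128)[0][0] == "direct"  # over capacity
    assert flush_queue(queue, buf, 32)[0][0] == "staged"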
View File

@@ -16,9 +16,7 @@ def pin_memory(module):
     hostbuf = pin_state["hostbuf"]
     size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
     offset = hostbuf.size
-    if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
-        pin_state["failed"] = True
-        return False
+    comfy.model_management.ensure_pin_budget(size)
 
     try:
         hostbuf.extend(size=size)