pinning updates

This commit is contained in:
Rattus 2026-05-18 02:30:03 +10:00
parent 52a68b9b1c
commit c451854f84
4 changed files with 70 additions and 54 deletions

View File

@ -502,7 +502,7 @@ PIN_PRESSURE_HYSTERESIS = 256 * 1024 * 1024
#Freeing registerables on pressure does imply a GPU sync, so go big on
#the hysteresis so each expensive sync gives us back a good chunk.
REGISTERABLE_PIN_HYSTERESIS = 768 * 1024 * 1024
REGISTERABLE_PIN_HYSTERESIS = 2048 * 1024 * 1024
def module_size(module):
module_mem = 0
@ -518,30 +518,31 @@ def mark_mmap_dirty(storage):
DIRTY_MMAPS.add(mmap_refs[0])
def free_pins(size, evict_active=False):
if size <= 0:
return
freed_total = 0
for loaded_model in reversed(current_loaded_models):
if size <= 0:
return freed_total
model = loaded_model.model
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
size -= model.partially_unload_ram(size)
if size <= 0:
break
freed = model.partially_unload_ram(size)
freed_total += freed
size -= freed
return freed_total
def ensure_pin_budget(size, evict_active=False):
if MAX_MODEL_MEMORY <= 0:
return
shortfall = TOTAL_MODEL_MEMORY + size - MAX_MODEL_MEMORY
shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available
if shortfall <= 0:
return
return True
free_pins(shortfall + PIN_PRESSURE_HYSTERESIS, evict_active=evict_active)
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
return free_pins(to_free, evict_active=evict_active) >= shortfall
def ensure_pin_registerable(size, evict_active=False):
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0 or shortfall <= 0:
return
if MAX_PINNED_MEMORY <= 0:
return False
if shortfall <= 0:
return True
shortfall += REGISTERABLE_PIN_HYSTERESIS
for loaded_model in reversed(current_loaded_models):
@ -549,7 +550,8 @@ def ensure_pin_registerable(size, evict_active=False):
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
shortfall -= model.unregister_inactive_pins(shortfall)
if shortfall <= 0:
return
return True
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
class LoadedModel:
def __init__(self, model):
@ -1225,24 +1227,23 @@ def get_pin_buffer(offload_stream):
return pin_buffer
def resize_pin_buffer(pin_buffer, size):
global TOTAL_MODEL_MEMORY, TOTAL_PINNED_MEMORY
global TOTAL_PINNED_MEMORY
old_size = pin_buffer.size
if size <= old_size:
return True
growth = size - old_size
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM, free_active=True)
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
ensure_pin_budget(growth, evict_active=True)
ensure_pin_registerable(growth, evict_active=True)
try:
pin_buffer.extend(size=size, reallocate=True)
except RuntimeError:
return False
TOTAL_MODEL_MEMORY += pin_buffer.size - old_size
TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
return True
def reset_cast_buffers():
global TOTAL_MODEL_MEMORY, TOTAL_PINNED_MEMORY
global TOTAL_PINNED_MEMORY
global LARGEST_CASTED_WEIGHT
global LARGEST_AIMDO_CASTED_WEIGHT
@ -1258,9 +1259,7 @@ def reset_cast_buffers():
DIRTY_MMAPS.clear()
for pin_buffer in STREAM_PIN_BUFFERS.values():
TOTAL_MODEL_MEMORY -= pin_buffer.size
TOTAL_PINNED_MEMORY -= pin_buffer.size
TOTAL_MODEL_MEMORY = max(0, TOTAL_MODEL_MEMORY)
TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY)
for loaded_model in current_loaded_models:
@ -1268,7 +1267,7 @@ def reset_cast_buffers():
if model is not None and model.is_dynamic():
model.model.dynamic_pins[model.load_device]["active"] = False
model.partially_unload_ram(1e30, subsets=[ "patches" ])
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024), [], [-1])
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024), [], [-1], [0])
STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear()
@ -1376,14 +1375,11 @@ def cast_to_device(tensor, device, dtype, copy=False):
PINNED_MEMORY = {}
TOTAL_MODEL_MEMORY = 0
TOTAL_PINNED_MEMORY = 0
MAX_MODEL_MEMORY = -1
MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory:
if is_nvidia() or is_amd():
ram = get_total_memory(torch.device("cpu"))
MAX_MODEL_MEMORY = min(ram - 4 * 1024 * 1024 * 1024, ram * 0.90)
if WINDOWS:
MAX_PINNED_MEMORY = ram * 0.40 # Windows limit is apparently 50%
else:

View File

@ -1128,6 +1128,10 @@ class ModelPatcher:
# Pinned memory pressure tracking is only implemented for DynamicVram loading
return 0
def loaded_ram_size(self):
# Loaded RAM pressure tracking is only implemented for DynamicVram loading
return 0
def partially_unload_ram(self, ram_to_unload):
return 0
@ -1557,8 +1561,8 @@ class ModelPatcherDynamic(ModelPatcher):
self.model.dynamic_pins = {}
if self.load_device not in self.model.dynamic_pins:
self.model.dynamic_pins[self.load_device] = {
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024), [], [-1]),
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024), [], [-1]),
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024), [], [-1], [0]),
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024), [], [-1], [0]),
"failed": False,
"active": False,
}
@ -1766,15 +1770,19 @@ class ModelPatcherDynamic(ModelPatcher):
return freed
def pinned_memory_size(self):
def loaded_ram_size(self):
return (self.model.dynamic_pins[self.load_device]["weights"][0].size +
self.model.dynamic_pins[self.load_device]["patches"][0].size)
def pinned_memory_size(self):
return (self.model.dynamic_pins[self.load_device]["weights"][3][0] +
self.model.dynamic_pins[self.load_device]["patches"][3][0])
def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]):
freed = 0
pin_state = self.model.dynamic_pins[self.load_device]
for subset in subsets:
hostbuf, stack, stack_split = pin_state[subset]
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
split = stack_split[0]
while split >= 0:
module, offset = stack[split]
@ -1788,6 +1796,7 @@ class ModelPatcherDynamic(ModelPatcher):
continue
module._pin_registered = False
comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size)
pinned_size[0] = max(0, pinned_size[0] - size)
freed += size
ram_to_unload -= size
if ram_to_unload <= 0:
@ -1798,16 +1807,16 @@ class ModelPatcherDynamic(ModelPatcher):
freed = 0
pin_state = self.model.dynamic_pins[self.load_device]
for subset in subsets:
hostbuf, stack, stack_split = pin_state[subset]
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
while len(stack) > 0:
module, offset = stack.pop()
size = module._pin.numel() * module._pin.element_size()
del module._pin
hostbuf.truncate(offset, do_unregister=module._pin_registered)
stack_split[0] = min(stack_split[0], len(stack) - 1)
comfy.model_management.TOTAL_MODEL_MEMORY = max(0, comfy.model_management.TOTAL_MODEL_MEMORY - size)
if module._pin_registered:
comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size)
pinned_size[0] = max(0, pinned_size[0] - size)
freed += size
ram_to_unload -= size
if ram_to_unload <= 0:

View File

@ -7,37 +7,44 @@ import torch
from comfy.cli_args import args
def get_pin(module, subset="weights"):
return getattr(module, "_pin", None)
pin = getattr(module, "_pin", None)
if pin is None or module._pin_registered or args.disable_pinned_memory:
return pin
_, _, stack_split, pinned_size = module._pin_state[subset]
size = pin.nbytes
comfy.model_management.ensure_pin_registerable(size)
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
comfy.model_management.discard_cuda_async_error()
return pin
module._pin_registered = True
stack_split[0] = max(stack_split[0], module._pin_stack_index)
comfy.model_management.TOTAL_PINNED_MEMORY += size
pinned_size[0] += size
return pin
def pin_memory(module, subset="weights", size=None):
pin_state = module._pin_state
if pin_state["failed"] or args.disable_pinned_memory:
if args.disable_pinned_memory:
return
hostbuf, stack, stack_split = pin_state[subset]
pin = get_pin(module, subset)
if pin is not None:
if module._pin_registered:
return
size = module._pin.nbytes
comfy.model_management.ensure_pin_registerable(size)
if torch.cuda.cudart().cudaHostRegister(module._pin.data_ptr(), size, 1) != 0:
comfy.model_management.discard_cuda_async_error()
return False
module._pin_registered = True
stack_split[0] = max(stack_split[0], module._pin_stack_index)
comfy.model_management.TOTAL_PINNED_MEMORY += size
return True
if pin is not None or pin_state["failed"]:
return
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
if size is None:
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
offset = hostbuf.size
registerable_size = size + max(0, hostbuf.size - pinned_size[0])
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
comfy.model_management.ensure_pin_budget(size)
comfy.model_management.ensure_pin_registerable(size)
if (not comfy.model_management.ensure_pin_budget(size) or
not comfy.model_management.ensure_pin_registerable(registerable_size)):
pin_state["failed"] = True
return False
try:
hostbuf.extend(size=size)
@ -51,6 +58,6 @@ def pin_memory(module, subset="weights", size=None):
module._pin_registered = True
module._pin_stack_index = len(stack) - 1
stack_split[0] = max(stack_split[0], module._pin_stack_index)
comfy.model_management.TOTAL_MODEL_MEMORY += size
comfy.model_management.TOTAL_PINNED_MEMORY += size
pinned_size[0] += size
return True

View File

@ -784,8 +784,12 @@ class PromptExecutor:
if self.cache_type == CacheType.RAM_PRESSURE:
ram_release_callback(ram_inactive_headroom)
ram_shortfall = ram_headroom - psutil.virtual_memory().available
comfy.model_management.free_pins(ram_shortfall)
ram_release_callback(ram_headroom, free_active=True)
freed = comfy.model_management.free_pins(ram_shortfall + 512 * (1024 ** 2))
if freed < ram_shortfall:
if freed > 64 * (1024 ** 2):
# AIMDO MEM_DECOMMIT can outrun psutil.available catching up.
time.sleep(0.05)
ram_release_callback(ram_headroom, free_active=True)
else:
# Only execute when the while-loop ends without break
# Send cached UI for intermediate output nodes that weren't executed