caching: build headroom into the RAM cache

move the headroom logic into the RAM cache to make this a little easier
to call to "free me some RAM".

Rename the API to free_ram().

Split off the clean_list creation to a completely separate function to
avoid any stray strong reference to the content-to-be-freed on the
stack.
This commit is contained in:
Rattus 2025-11-14 15:31:57 +10:00
parent a17cf1c387
commit 68053b1180
2 changed files with 24 additions and 18 deletions

View File

@ -193,7 +193,7 @@ class BasicCache:
self._clean_cache() self._clean_cache()
self._clean_subcaches() self._clean_subcaches()
def poll(self, **kwargs): def free_ram(self, *args, **kwargs):
pass pass
def _set_immediate(self, node_id, value): def _set_immediate(self, node_id, value):
@ -284,7 +284,7 @@ class NullCache:
def clean_unused(self): def clean_unused(self):
pass pass
def poll(self, **kwargs): def free_ram(self, *args, **kwargs):
pass pass
def get(self, node_id): def get(self, node_id):
@ -366,9 +366,10 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
class RAMPressureCache(LRUCache): class RAMPressureCache(LRUCache):
def __init__(self, key_class): def __init__(self, key_class, min_headroom=4.0):
super().__init__(key_class, 0) super().__init__(key_class, 0)
self.timestamps = {} self.timestamps = {}
self.min_headroom = min_headroom
def clean_unused(self): def clean_unused(self):
self._clean_subcaches() self._clean_subcaches()
@ -381,19 +382,10 @@ class RAMPressureCache(LRUCache):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return super().get(node_id) return super().get(node_id)
def poll(self, ram_headroom): def _build_clean_list(self):
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)
if _ram_gb() > ram_headroom:
return
gc.collect()
if _ram_gb() > ram_headroom:
return
clean_list = [] clean_list = []
for key, (outputs, _), in self.cache.items(): for key, (_, outputs), in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key]) oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
@ -416,8 +408,22 @@ class RAMPressureCache(LRUCache):
#In the case where we have no information on the node ram usage at all, #In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU) #break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key)) bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
return clean_list
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list: def free_ram(self, extra_ram=0):
headroom_target = self.min_headroom + (extra_ram / (1024**3))
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)
if _ram_gb() > headroom_target:
return
gc.collect()
if _ram_gb() > headroom_target:
return
clean_list = self._build_clean_list()
while _ram_gb() < headroom_target * RAM_CACHE_HYSTERESIS and clean_list:
_, _, key = clean_list.pop() _, _, key = clean_list.pop()
del self.cache[key] del self.cache[key]
gc.collect() gc.collect()

View File

@ -107,7 +107,7 @@ class CacheSet:
self.init_null_cache() self.init_null_cache()
logging.info("Disabling intermediate node cache.") logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.RAM_PRESSURE: elif cache_type == CacheType.RAM_PRESSURE:
cache_ram = cache_args.get("ram", 16.0) cache_ram = cache_args.get("ram", 4.0)
self.init_ram_cache(cache_ram) self.init_ram_cache(cache_ram)
logging.info("Using RAM pressure cache.") logging.info("Using RAM pressure cache.")
elif cache_type == CacheType.LRU: elif cache_type == CacheType.LRU:
@ -129,7 +129,7 @@ class CacheSet:
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def init_ram_cache(self, min_headroom): def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature) self.outputs = RAMPressureCache(CacheKeySetInputSignature, min_headroom)
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def init_null_cache(self): def init_null_cache(self):
@ -717,7 +717,7 @@ class PromptExecutor:
execution_list.unstage_node_execution() execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS: else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution() execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) self.caches.outputs.free_ram()
else: else:
# Only execute when the while-loop ends without break # Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)