From 68053b1180a31bcca9cab18fb042e8d3117a5c4c Mon Sep 17 00:00:00 2001 From: Rattus Date: Fri, 14 Nov 2025 15:31:57 +1000 Subject: [PATCH] caching: build headroom into the RAM cache Move the headroom logic into the RAM cache so that it is easier for callers to ask it to "free me some RAM". Rename the API to free_ram(). Split off the clean_list creation into a completely separate function to avoid any stray strong reference to the content-to-be-freed on the stack. --- comfy_execution/caching.py | 36 +++++++++++++++++++++--------------- execution.py | 6 +++--- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 326a279fc..43f882469 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -193,7 +193,7 @@ class BasicCache: self._clean_cache() self._clean_subcaches() - def poll(self, **kwargs): + def free_ram(self, *args, **kwargs): pass def _set_immediate(self, node_id, value): @@ -284,7 +284,7 @@ class NullCache: def clean_unused(self): pass - def poll(self, **kwargs): + def free_ram(self, *args, **kwargs): pass def get(self, node_id): @@ -366,9 +366,10 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3 class RAMPressureCache(LRUCache): - def __init__(self, key_class): + def __init__(self, key_class, min_headroom=4.0): super().__init__(key_class, 0) self.timestamps = {} + self.min_headroom = min_headroom def clean_unused(self): self._clean_subcaches() @@ -381,19 +382,10 @@ class RAMPressureCache(LRUCache): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() return super().get(node_id) - def poll(self, ram_headroom): - def _ram_gb(): - return psutil.virtual_memory().available / (1024**3) - - if _ram_gb() > ram_headroom: - return - gc.collect() - if _ram_gb() > ram_headroom: - return - + def _build_clean_list(self): clean_list = [] - for key, (outputs, _), in self.cache.items(): + for key, (_, outputs), in self.cache.items(): oom_score = 
RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key]) ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE @@ -416,8 +408,22 @@ class RAMPressureCache(LRUCache): #In the case where we have no information on the node ram usage at all, #break OOM score ties on the last touch timestamp (pure LRU) bisect.insort(clean_list, (oom_score, self.timestamps[key], key)) + return clean_list - while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list: + def free_ram(self, extra_ram=0): + headroom_target = self.min_headroom + (extra_ram / (1024**3)) + def _ram_gb(): + return psutil.virtual_memory().available / (1024**3) + + if _ram_gb() > headroom_target: + return + gc.collect() + if _ram_gb() > headroom_target: + return + + clean_list = self._build_clean_list() + + while _ram_gb() < headroom_target * RAM_CACHE_HYSTERESIS and clean_list: _, _, key = clean_list.pop() del self.cache[key] gc.collect() diff --git a/execution.py b/execution.py index 17c77beab..44e3bb65c 100644 --- a/execution.py +++ b/execution.py @@ -107,7 +107,7 @@ class CacheSet: self.init_null_cache() logging.info("Disabling intermediate node cache.") elif cache_type == CacheType.RAM_PRESSURE: - cache_ram = cache_args.get("ram", 16.0) + cache_ram = cache_args.get("ram", 4.0) self.init_ram_cache(cache_ram) logging.info("Using RAM pressure cache.") elif cache_type == CacheType.LRU: @@ -129,7 +129,7 @@ class CacheSet: self.objects = HierarchicalCache(CacheKeySetID) def init_ram_cache(self, min_headroom): - self.outputs = RAMPressureCache(CacheKeySetInputSignature) + self.outputs = RAMPressureCache(CacheKeySetInputSignature, min_headroom) self.objects = HierarchicalCache(CacheKeySetID) def init_null_cache(self): @@ -717,7 +717,7 @@ class PromptExecutor: execution_list.unstage_node_execution() else: # result == ExecutionResult.SUCCESS: execution_list.complete_node_execution() - self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) + self.caches.outputs.free_ram() else: # Only 
execute when the while-loop ends without break self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)