mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-21 03:50:50 +08:00
caching: build headroom into the RAM cache
move the headroom logic into the RAM cache to make this a little easier to call to "free me some RAM". Rename the API to free_ram(). Split off the clean_list creation to a completely separate function to avoid any stray strong reference to the content-to-be-freed on the stack.
This commit is contained in:
parent
a17cf1c387
commit
68053b1180
@ -193,7 +193,7 @@ class BasicCache:
|
||||
self._clean_cache()
|
||||
self._clean_subcaches()
|
||||
|
||||
def poll(self, **kwargs):
|
||||
def free_ram(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def _set_immediate(self, node_id, value):
|
||||
@ -284,7 +284,7 @@ class NullCache:
|
||||
def clean_unused(self):
|
||||
pass
|
||||
|
||||
def poll(self, **kwargs):
|
||||
def free_ram(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def get(self, node_id):
|
||||
@ -366,9 +366,10 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
|
||||
|
||||
class RAMPressureCache(LRUCache):
|
||||
|
||||
def __init__(self, key_class):
|
||||
def __init__(self, key_class, min_headroom=4.0):
|
||||
super().__init__(key_class, 0)
|
||||
self.timestamps = {}
|
||||
self.min_headroom = min_headroom
|
||||
|
||||
def clean_unused(self):
|
||||
self._clean_subcaches()
|
||||
@ -381,19 +382,10 @@ class RAMPressureCache(LRUCache):
|
||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||
return super().get(node_id)
|
||||
|
||||
def poll(self, ram_headroom):
|
||||
def _ram_gb():
|
||||
return psutil.virtual_memory().available / (1024**3)
|
||||
|
||||
if _ram_gb() > ram_headroom:
|
||||
return
|
||||
gc.collect()
|
||||
if _ram_gb() > ram_headroom:
|
||||
return
|
||||
|
||||
def _build_clean_list(self):
|
||||
clean_list = []
|
||||
|
||||
for key, (outputs, _), in self.cache.items():
|
||||
for key, (_, outputs), in self.cache.items():
|
||||
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
|
||||
|
||||
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
|
||||
@ -416,8 +408,22 @@ class RAMPressureCache(LRUCache):
|
||||
#In the case where we have no information on the node ram usage at all,
|
||||
#break OOM score ties on the last touch timestamp (pure LRU)
|
||||
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
|
||||
return clean_list
|
||||
|
||||
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
|
||||
def free_ram(self, extra_ram=0):
|
||||
headroom_target = self.min_headroom + (extra_ram / (1024**3))
|
||||
def _ram_gb():
|
||||
return psutil.virtual_memory().available / (1024**3)
|
||||
|
||||
if _ram_gb() > headroom_target:
|
||||
return
|
||||
gc.collect()
|
||||
if _ram_gb() > headroom_target:
|
||||
return
|
||||
|
||||
clean_list = self._build_clean_list()
|
||||
|
||||
while _ram_gb() < headroom_target * RAM_CACHE_HYSTERESIS and clean_list:
|
||||
_, _, key = clean_list.pop()
|
||||
del self.cache[key]
|
||||
gc.collect()
|
||||
|
||||
@ -107,7 +107,7 @@ class CacheSet:
|
||||
self.init_null_cache()
|
||||
logging.info("Disabling intermediate node cache.")
|
||||
elif cache_type == CacheType.RAM_PRESSURE:
|
||||
cache_ram = cache_args.get("ram", 16.0)
|
||||
cache_ram = cache_args.get("ram", 4.0)
|
||||
self.init_ram_cache(cache_ram)
|
||||
logging.info("Using RAM pressure cache.")
|
||||
elif cache_type == CacheType.LRU:
|
||||
@ -129,7 +129,7 @@ class CacheSet:
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def init_ram_cache(self, min_headroom):
|
||||
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
|
||||
self.outputs = RAMPressureCache(CacheKeySetInputSignature, min_headroom)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def init_null_cache(self):
|
||||
@ -717,7 +717,7 @@ class PromptExecutor:
|
||||
execution_list.unstage_node_execution()
|
||||
else: # result == ExecutionResult.SUCCESS:
|
||||
execution_list.complete_node_execution()
|
||||
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
|
||||
self.caches.outputs.free_ram()
|
||||
else:
|
||||
# Only execute when the while-loop ends without break
|
||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user