Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-21)
caching: build headroom into the RAM cache
Move the headroom logic into the RAM cache so that callers can simply ask it to "free me some RAM". Rename the API to free_ram(). Split the clean_list construction into a completely separate function so that no stray strong reference to the content-to-be-freed is left on the stack.
commit 68053b1180
parent a17cf1c387
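The stray-reference pitfall the message alludes to is a CPython frame-lifetime detail: a for-loop target stays bound after the loop ends, so a freeing function that builds its eviction list inline keeps the last cached output alive through the very gc.collect() that is supposed to reclaim it. A minimal sketch of the problem and the fix, with hypothetical names rather than code from this commit:

    import gc

    # Pitfall: after the scan loop, `outputs` (and `_`) remain bound to the
    # last value iterated, so that cached object still has a live reference
    # on this stack frame and the collect below cannot reclaim it, even
    # once it has been deleted from the cache dict.
    def free_inline(cache):
        clean_list = []
        for key, (_, outputs) in cache.items():
            clean_list.append(key)
        for key in clean_list:
            del cache[key]
        gc.collect()  # the last `outputs` object survives this collect

    # Fix: do the scan in its own frame. When _build_clean_list() returns,
    # its locals are torn down and only the keys escape, so the deletions
    # in free_split() really do drop the final reference.
    def _build_clean_list(cache):
        clean_list = []
        for key, (_, outputs) in cache.items():
            clean_list.append(key)
        return clean_list

    def free_split(cache):
        for key in _build_clean_list(cache):
            del cache[key]
        gc.collect()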
@@ -193,7 +193,7 @@ class BasicCache:
         self._clean_cache()
         self._clean_subcaches()
 
-    def poll(self, **kwargs):
+    def free_ram(self, *args, **kwargs):
         pass
 
     def _set_immediate(self, node_id, value):
@@ -284,7 +284,7 @@ class NullCache:
     def clean_unused(self):
         pass
 
-    def poll(self, **kwargs):
+    def free_ram(self, *args, **kwargs):
        pass
 
     def get(self, node_id):
@@ -366,9 +366,10 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
 
 class RAMPressureCache(LRUCache):
 
-    def __init__(self, key_class):
+    def __init__(self, key_class, min_headroom=4.0):
         super().__init__(key_class, 0)
         self.timestamps = {}
+        self.min_headroom = min_headroom
 
     def clean_unused(self):
         self._clean_subcaches()
@@ -381,19 +382,10 @@ class RAMPressureCache(LRUCache):
         self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
         return super().get(node_id)
 
-    def poll(self, ram_headroom):
-        def _ram_gb():
-            return psutil.virtual_memory().available / (1024**3)
-
-        if _ram_gb() > ram_headroom:
-            return
-        gc.collect()
-        if _ram_gb() > ram_headroom:
-            return
-
+    def _build_clean_list(self):
         clean_list = []
-        for key, (outputs, _), in self.cache.items():
+        for key, (_, outputs), in self.cache.items():
             oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
 
             ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
@@ -416,8 +408,22 @@ class RAMPressureCache(LRUCache):
             #In the case where we have no information on the node ram usage at all,
             #break OOM score ties on the last touch timestamp (pure LRU)
             bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
+        return clean_list
 
-        while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
+    def free_ram(self, extra_ram=0):
+        headroom_target = self.min_headroom + (extra_ram / (1024**3))
+        def _ram_gb():
+            return psutil.virtual_memory().available / (1024**3)
+
+        if _ram_gb() > headroom_target:
+            return
+        gc.collect()
+        if _ram_gb() > headroom_target:
+            return
+
+        clean_list = self._build_clean_list()
+
+        while _ram_gb() < headroom_target * RAM_CACHE_HYSTERESIS and clean_list:
             _, _, key = clean_list.pop()
             del self.cache[key]
         gc.collect()

@@ -107,7 +107,7 @@ class CacheSet:
             self.init_null_cache()
             logging.info("Disabling intermediate node cache.")
         elif cache_type == CacheType.RAM_PRESSURE:
-            cache_ram = cache_args.get("ram", 16.0)
+            cache_ram = cache_args.get("ram", 4.0)
             self.init_ram_cache(cache_ram)
             logging.info("Using RAM pressure cache.")
         elif cache_type == CacheType.LRU:
@@ -129,7 +129,7 @@ class CacheSet:
         self.objects = HierarchicalCache(CacheKeySetID)
 
     def init_ram_cache(self, min_headroom):
-        self.outputs = RAMPressureCache(CacheKeySetInputSignature)
+        self.outputs = RAMPressureCache(CacheKeySetInputSignature, min_headroom)
         self.objects = HierarchicalCache(CacheKeySetID)
 
     def init_null_cache(self):
@@ -717,7 +717,7 @@ class PromptExecutor:
                 execution_list.unstage_node_execution()
             else: # result == ExecutionResult.SUCCESS:
                 execution_list.complete_node_execution()
-                self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
+                self.caches.outputs.free_ram()
         else:
             # Only execute when the while-loop ends without break
             self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
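At the call site the executor no longer forwards a headroom figure; the cache carries its own min_headroom, and a caller that knows it is about to allocate can pass the extra need in bytes. A sketch of the resulting calling convention (the 4.0 GiB default and the `caches` wiring are taken from the hunks above; the `--cache-ram` flag name feeding cache_args["ram"] is an assumption):

    # Free down to the cache's own min_headroom (4.0 GiB by default).
    caches.outputs.free_ram()

    # About to allocate ~2 GiB: raise the target to min_headroom + 2 GiB.
    # free_ram() converts extra_ram from bytes to GiB internally.
    caches.outputs.free_ram(extra_ram=2 * 1024**3)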