Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-11 05:52:33 +08:00)
Implement RAM Pressure cache
Implement a cache that is sensitive to RAM pressure. When RAM headroom drops below a configured threshold, RAM-expensive nodes are evicted from the cache. Models and tensors are measured directly for RAM usage, and an OOM score is computed from each node's usage. Note that, due to indirection through shared objects (such as a model patcher), multiple nodes can each account the same RAM as their individual usage. The intent is to free whole chains of nodes, particularly model loaders and their associated LoRAs, since they all score similarly and therefore sort close to each other. The result is a bias towards unloading model nodes mid-flow while keeping cheaper results such as text encodings and VAE outputs.
commit f3f526fcd3 (parent 0c95f22907)
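To make the eviction ordering concrete, here is a small standalone sketch of the scoring (illustrative only: the node names and RAM figures are invented, and the committed implementation in comfy_execution/caching.py below measures tensors and models directly):

# Illustrative sketch of the eviction ordering only; not the committed code.
import bisect

OLD_WORKFLOW_MULTIPLIER = 1.3   # mirrors RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER

def oom_score(ram_usage_gb, generations_since_use):
    # Bigger results and staler workflows score higher and are evicted first.
    return (OLD_WORKFLOW_MULTIPLIER ** generations_since_use) * ram_usage_gb

entries = {
    "checkpoint_loader": (12.0, 0),   # model weights, current workflow
    "lora_loader":       (12.0, 0),   # shares the same model patcher RAM
    "clip_text_encode":  (0.02, 0),   # small conditioning tensors
    "vae_decode":        (0.05, 1),   # image tensor from an older workflow
}

ordered = []
for node_id, (ram_gb, age) in entries.items():
    bisect.insort(ordered, (oom_score(ram_gb, age), node_id))

# Highest scores are popped first, so the model loader and its LoRA are evicted
# as a chain before cheap results like text encodings and decoded images; the
# real cache stops evicting once RAM headroom recovers past the threshold.
while ordered:
    print(ordered.pop())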
comfy/cli_args.py

@@ -105,6 +105,7 @@ cache_group = parser.add_mutually_exclusive_group()
 cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
 cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
 cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
+cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold, the cache removes large items to free RAM. Default 4GB.")

 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
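The nargs='?' / const / default combination gives --cache-ram three modes. A minimal standalone sketch of the assumed argparse behaviour (not ComfyUI code):

# Assumed argparse semantics of the new flag (standalone sketch).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0)

print(parser.parse_args([]).cache_ram)                    # 0   -> RAM pressure cache disabled
print(parser.parse_args(["--cache-ram"]).cache_ram)       # 4.0 -> flag given, default 4GB headroom
print(parser.parse_args(["--cache-ram", "8"]).cache_ram)  # 8.0 -> explicit headroom threshold in GB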
comfy_execution/caching.py

@@ -1,4 +1,9 @@
+import bisect
+import gc
 import itertools
+import psutil
+import time
+import torch
 from typing import Sequence, Mapping, Dict
 from comfy_execution.graph import DynamicPrompt
 from abc import ABC, abstractmethod
@@ -188,6 +193,9 @@ class BasicCache:
         self._clean_cache()
         self._clean_subcaches()

+    def poll(self, **kwargs):
+        pass
+
     def _set_immediate(self, node_id, value):
         assert self.initialized
         cache_key = self.cache_key_set.get_data_key(node_id)
@@ -276,6 +284,9 @@ class NullCache:
     def clean_unused(self):
         pass

+    def poll(self, **kwargs):
+        pass
+
     def get(self, node_id):
         return None

@@ -336,3 +347,75 @@ class LRUCache(BasicCache):
             self._mark_used(child_id)
             self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
         return self
+
+
+#Iterating the cache for usage analysis might be expensive, so if we trigger eviction,
+#make sure to take a chunk out to give breathing space on high-node / low-RAM-per-node flows.
+
+RAM_CACHE_HYSTERESIS = 1.1
+
+#This is roughly in GB, but not exactly. It needs to be non-zero for the heuristic below,
+#and as long as multi-GB models dwarf it, the OOM scoring is approximated well enough.
+
+RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
+
+#Exponential bias towards evicting older workflows so garbage gets taken out
+#in constantly changing setups.
+
+RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
+
+class RAMPressureCache(LRUCache):
+
+    def __init__(self, key_class):
+        super().__init__(key_class, 0)
+        self.timestamps = {}
+
+    def clean_unused(self):
+        self._clean_subcaches()
+
+    def set(self, node_id, value):
+        self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
+        super().set(node_id, value)
+
+    def get(self, node_id):
+        self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
+        return super().get(node_id)
+
+    def poll(self, ram_headroom):
+        def _ram_gb():
+            return psutil.virtual_memory().available / (1024**3)
+
+        if _ram_gb() > ram_headroom:
+            return
+        gc.collect()
+        if _ram_gb() > ram_headroom:
+            return
+
+        clean_list = []
+
+        for key, (outputs, _), in self.cache.items():
+            oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
+
+            ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
+            def scan_list_for_ram_usage(outputs):
+                nonlocal ram_usage
+                for output in outputs:
+                    if isinstance(output, list):
+                        scan_list_for_ram_usage(output)
+                    elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
+                        #score Tensors at a 50% discount for RAM usage as they are likely to
+                        #be high-value intermediates
+                        ram_usage += (output.numel() * output.element_size()) * 0.5
+                    elif hasattr(output, "get_ram_usage"):
+                        ram_usage += output.get_ram_usage()
+            scan_list_for_ram_usage(outputs)
+
+            oom_score *= ram_usage
+            #In the case where we have no information on the node's RAM usage at all,
+            #break OOM score ties on the last-touch timestamp (pure LRU).
+            bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
+
+        while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
+            _, _, key = clean_list.pop()
+            del self.cache[key]
+            gc.collect()
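Because any output that exposes get_ram_usage() is charged in full to every node that returned it, nodes sharing one model object (e.g. a checkpoint loader and its LoRA loaders) each account the whole model and end up with similar OOM scores. A standalone sketch of that accounting (FakeModelPatcher and node_ram_usage are illustrative stand-ins, not ComfyUI APIs):

# Standalone sketch of the RAM accounting; not ComfyUI code.
import torch

class FakeModelPatcher:
    """Stand-in for a shared model object that can report its RAM footprint."""
    def __init__(self, n_bytes):
        self.n_bytes = n_bytes
    def get_ram_usage(self):
        return self.n_bytes

def node_ram_usage(outputs):
    """Mirrors the scan in poll(): CPU tensors at a 50% discount, get_ram_usage() in full."""
    usage = 0.0
    for output in outputs:
        if isinstance(output, list):
            usage += node_ram_usage(output)
        elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
            usage += output.numel() * output.element_size() * 0.5
        elif hasattr(output, "get_ram_usage"):
            usage += output.get_ram_usage()
    return usage

shared_model = FakeModelPatcher(8 * 1024**3)       # ~8GB of weights
loader_outputs = [[shared_model]]                  # checkpoint loader result
lora_outputs   = [[shared_model]]                  # LoRA loader result, same object
encode_outputs = [[torch.zeros(77, 768)]]          # small CPU conditioning tensor

for name, outs in (("loader", loader_outputs), ("lora", lora_outputs), ("encode", encode_outputs)):
    print(name, node_ram_usage(outs) / 1024**3, "GB")
# Both model nodes are charged the full 8GB, so their OOM scores stay close and they
# tend to be evicted together; the text encoding stays cheap enough to keep around.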
execution.py (22 lines changed)
@@ -21,6 +21,7 @@ from comfy_execution.caching import (
     NullCache,
     HierarchicalCache,
     LRUCache,
+    RAMPressureCache,
 )
 from comfy_execution.graph import (
     DynamicPrompt,

@@ -92,16 +93,20 @@ class CacheType(Enum):
     CLASSIC = 0
     LRU = 1
     NONE = 2
+    RAM_PRESSURE = 3


 class CacheSet:
-    def __init__(self, cache_type=None, cache_size=None):
+    def __init__(self, cache_type=None, cache_args={}):
         if cache_type == CacheType.NONE:
             self.init_null_cache()
             logging.info("Disabling intermediate node cache.")
+        elif cache_type == CacheType.RAM_PRESSURE:
+            cache_ram = cache_args.get("ram", 16.0)
+            self.init_ram_cache(cache_ram)
+            logging.info("Using RAM pressure cache.")
         elif cache_type == CacheType.LRU:
-            if cache_size is None:
-                cache_size = 0
+            cache_size = cache_args.get("lru", 0)
             self.init_lru_cache(cache_size)
             logging.info("Using LRU cache")
         else:

@@ -118,6 +123,10 @@ class CacheSet:
         self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
         self.objects = HierarchicalCache(CacheKeySetID)

+    def init_ram_cache(self, min_headroom):
+        self.outputs = RAMPressureCache(CacheKeySetInputSignature)
+        self.objects = HierarchicalCache(CacheKeySetID)
+
     def init_null_cache(self):
         self.outputs = NullCache()
         self.objects = NullCache()

@@ -600,14 +609,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     return (ExecutionResult.SUCCESS, None, None)

 class PromptExecutor:
-    def __init__(self, server, cache_type=False, cache_size=None):
-        self.cache_size = cache_size
+    def __init__(self, server, cache_type=False, cache_args=None):
+        self.cache_args = cache_args
         self.cache_type = cache_type
         self.server = server
         self.reset()

     def reset(self):
-        self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
+        self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
         self.status_messages = []
         self.success = True

@@ -705,6 +714,7 @@ class PromptExecutor:
                     execution_list.unstage_node_execution()
                 else: # result == ExecutionResult.SUCCESS:
                     execution_list.complete_node_execution()
+                    self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
             else:
                 # Only execute when the while-loop ends without break
                 self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
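Calling poll(ram_headroom=...) unconditionally after each node is safe for every cache type because the base caches define poll(**kwargs) as a no-op; only RAMPressureCache overrides it. A toy sketch of that dispatch (the classes here are illustrative, not ComfyUI's):

# Toy sketch of the poll() dispatch; not ComfyUI code.
class ToyCache:
    def poll(self, **kwargs):
        pass  # classic / LRU / null caches ignore pressure polling

class ToyRAMPressureCache(ToyCache):
    def poll(self, ram_headroom):
        print(f"checking available RAM against a {ram_headroom} GB headroom threshold")

for outputs_cache in (ToyCache(), ToyRAMPressureCache()):
    # mirrors self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
    outputs_cache.poll(ram_headroom=4.0)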
main.py (4 lines changed)
@@ -172,10 +172,12 @@ def prompt_worker(q, server_instance):
     cache_type = execution.CacheType.CLASSIC
     if args.cache_lru > 0:
         cache_type = execution.CacheType.LRU
+    elif args.cache_ram > 0:
+        cache_type = execution.CacheType.RAM_PRESSURE
     elif args.cache_none:
         cache_type = execution.CacheType.NONE

-    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
+    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
     last_gc_collect = 0
     need_gc = False
     gc_collect_interval = 10.0