RAM-cache: remove poll and go via model_management

Go via model_management for full RAM cache
This commit is contained in:
Rattus 2026-03-25 22:54:25 +10:00
parent 7a0fb40daa
commit 19034e6514
2 changed files with 14 additions and 17 deletions

View File

@ -493,11 +493,6 @@ class LRUCache(BasicCache):
return self
#Iterating the cache for usage analysis might be expensive, so if we trigger make sure
#to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
RAM_CACHE_HYSTERESIS = 1.1
#Small baseline weight used when a cache entry has no measurable CPU tensors.
#Keeps unknown-sized entries in eviction scoring without dominating tensor-backed entries.
@ -529,19 +524,15 @@ class RAMPressureCache(LRUCache):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set_local(node_id, value)
def poll(self, ram_headroom):
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)
def ram_release(self, target):
if psutil.virtual_memory().available >= target:
return
if _ram_gb() > ram_headroom:
return
gc.collect()
if _ram_gb() > ram_headroom:
return
clean_list = []
for key, (outputs, _), in self.cache.items():
for key, cache_entry in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
@ -550,18 +541,21 @@ class RAMPressureCache(LRUCache):
if outputs is None:
return
for output in outputs:
if isinstance(output, list):
if isinstance(output, (list, tuple)):
scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
ram_usage += output.numel() * output.element_size()
scan_list_for_ram_usage(outputs)
scan_list_for_ram_usage(cache_entry.outputs)
oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
while psutil.virtual_memory().available < target and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
self.used_generation.pop(key, None)
self.timestamps.pop(key, None)
self.children.pop(key, None)
gc.collect()

View File

@ -109,6 +109,7 @@ class CacheType(Enum):
class CacheSet:
def __init__(self, cache_type=None, cache_args={}):
self.ram_release_callback = None
if cache_type == CacheType.NONE:
self.init_null_cache()
logging.info("Disabling intermediate node cache.")
@ -137,6 +138,7 @@ class CacheSet:
def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature, enable_providers=True)
self.objects = HierarchicalCache(CacheKeySetID)
self.ram_release_callback = self.outputs.ram_release
def init_null_cache(self):
self.outputs = NullCache()
@ -715,6 +717,7 @@ class PromptExecutor:
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
self._notify_prompt_lifecycle("start", prompt_id)
comfy.model_management.register_extra_ram_release_callback(self.caches.ram_release_callback)
try:
with torch.inference_mode():
@ -764,7 +767,6 @@ class PromptExecutor:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
@ -782,6 +784,7 @@ class PromptExecutor:
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()
finally:
comfy.model_management.register_extra_ram_release_callback(None)
self._notify_prompt_lifecycle("end", prompt_id)