Merge branch 'master' into fix/color-curves-shader-nested-sampler

This commit is contained in:
Terry Jia 2026-03-28 13:06:14 -04:00 committed by GitHub
commit 6df3aee425
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 61 additions and 43 deletions

View File

@ -110,11 +110,13 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
CACHE_RAM_AUTO_GB = -1.0
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")

View File

@ -141,3 +141,17 @@ def interpret_gathered_like(tensors, gathered):
return dest_views
aimdo_enabled = False
extra_ram_release_callback = None
RAM_CACHE_HEADROOM = 0
def set_ram_cache_release_state(callback, headroom):
global extra_ram_release_callback
global RAM_CACHE_HEADROOM
extra_ram_release_callback = callback
RAM_CACHE_HEADROOM = max(0, int(headroom))
def extra_ram_release(target):
if extra_ram_release_callback is None:
return 0
return extra_ram_release_callback(target)

View File

@ -669,7 +669,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i]
if shift_model.device == device:
if device is None or shift_model.device == device:
if shift_model not in keep_loaded and not shift_model.is_dead():
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False
@ -679,8 +679,8 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
i = x[-1]
memory_to_free = 1e32
pins_to_free = 1e32
if not DISABLE_SMART_MEMORY:
memory_to_free = memory_required - get_free_memory(device)
if not DISABLE_SMART_MEMORY or device is None:
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
pins_to_free = pins_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
@ -708,7 +708,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
if len(unloaded_model) > 0:
soft_empty_cache()
else:
elif device is not None:
if vram_state != VRAMState.HIGH_VRAM:
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
if mem_free_torch > mem_free_total * 0.25:

View File

@ -300,9 +300,6 @@ class ModelPatcher:
def model_mmap_residency(self, free=False):
return comfy.model_management.module_mmap_residency(self.model, free=free)
def get_ram_usage(self):
return self.model_size()
def loaded_size(self):
return self.model.model_loaded_weight_memory

View File

@ -2,6 +2,7 @@ import comfy.model_management
import comfy.memory_management
import comfy_aimdo.host_buffer
import comfy_aimdo.torch
import psutil
from comfy.cli_args import args
@ -12,6 +13,11 @@ def pin_memory(module):
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
return
#FIXME: This is a RAM cache trigger event
ram_headroom = comfy.memory_management.RAM_CACHE_HEADROOM
#we split the difference and assume half the RAM cache headroom is for us
if ram_headroom > 0 and psutil.virtual_memory().available < (ram_headroom * 0.5):
comfy.memory_management.extra_ram_release(ram_headroom)
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:

View File

@ -280,9 +280,6 @@ class CLIP:
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n
def get_ram_usage(self):
return self.patcher.get_ram_usage()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model)
@ -840,9 +837,6 @@ class VAE:
self.size = comfy.model_management.module_size(self.first_stage_model)
return self.size
def get_ram_usage(self):
return self.model_size()
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")

View File

@ -1,6 +1,5 @@
import asyncio
import bisect
import gc
import itertools
import psutil
import time
@ -475,6 +474,10 @@ class LRUCache(BasicCache):
self._mark_used(node_id)
return await self._set_immediate(node_id, value)
def set_local(self, node_id, value):
self._mark_used(node_id)
BasicCache.set_local(self, node_id, value)
async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes
await super()._ensure_subcache(node_id, children_ids)
@ -489,15 +492,10 @@ class LRUCache(BasicCache):
return self
#Iterating the cache for usage analysis might be expensive, so if we trigger make sure
#to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
#Small baseline weight used when a cache entry has no measurable CPU tensors.
#Keeps unknown-sized entries in eviction scoring without dominating tensor-backed entries.
RAM_CACHE_HYSTERESIS = 1.1
#This is kinda in GB but not really. It needs to be non-zero for the below heuristic
#and as long as Multi GB models dwarf this it will approximate OOM scoring OK
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
RAM_CACHE_DEFAULT_RAM_USAGE = 0.05
#Exponential bias towards evicting older workflows so garbage will be taken out
#in constantly changing setups.
@ -521,19 +519,17 @@ class RAMPressureCache(LRUCache):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return await super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)
def set_local(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set_local(node_id, value)
if _ram_gb() > ram_headroom:
return
gc.collect()
if _ram_gb() > ram_headroom:
def ram_release(self, target):
if psutil.virtual_memory().available >= target:
return
clean_list = []
for key, (outputs, _), in self.cache.items():
for key, cache_entry in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
@ -542,22 +538,20 @@ class RAMPressureCache(LRUCache):
if outputs is None:
return
for output in outputs:
if isinstance(output, list):
if isinstance(output, (list, tuple)):
scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
#score Tensors at a 50% discount for RAM usage as they are likely to
#be high value intermediates
ram_usage += (output.numel() * output.element_size()) * 0.5
elif hasattr(output, "get_ram_usage"):
ram_usage += output.get_ram_usage()
scan_list_for_ram_usage(outputs)
ram_usage += output.numel() * output.element_size()
scan_list_for_ram_usage(cache_entry.outputs)
oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
while psutil.virtual_memory().available < target and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
gc.collect()
self.used_generation.pop(key, None)
self.timestamps.pop(key, None)
self.children.pop(key, None)

View File

@ -724,6 +724,9 @@ class PromptExecutor:
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
self._notify_prompt_lifecycle("start", prompt_id)
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None
comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom)
try:
with torch.inference_mode():
@ -773,7 +776,10 @@ class PromptExecutor:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
if self.cache_type == CacheType.RAM_PRESSURE:
comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
comfy.memory_management.extra_ram_release(ram_headroom)
else:
# Only execute when the while-loop ends without break
# Send cached UI for intermediate output nodes that weren't executed
@ -801,6 +807,7 @@ class PromptExecutor:
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()
finally:
comfy.memory_management.set_ram_cache_release_state(None, 0)
self._notify_prompt_lifecycle("end", prompt_id)

View File

@ -275,15 +275,19 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]:
def prompt_worker(q, server_instance):
current_time: float = 0.0
cache_ram = args.cache_ram
if cache_ram < 0:
cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0))
cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0:
cache_type = execution.CacheType.LRU
elif args.cache_ram > 0:
elif cache_ram > 0:
cache_type = execution.CacheType.RAM_PRESSURE
elif args.cache_none:
cache_type = execution.CacheType.NONE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram } )
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0