Merge branch 'master' into fix/color-curves-shader-nested-sampler

This commit is contained in:
Terry Jia 2026-03-28 13:06:14 -04:00 committed by GitHub
commit 6df3aee425
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 61 additions and 43 deletions

View File

@ -110,11 +110,13 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.") parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
# Sentinel used as the argparse `const` for --cache-ram: a negative value
# means "auto-size the headroom from total system RAM" (resolved later in
# the `cache_ram < 0` branch where the executor's cache args are built).
CACHE_RAM_AUTO_GB = -1.0
cache_group = parser.add_mutually_exclusive_group() cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB") cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).")
attn_group = parser.add_mutually_exclusive_group() attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")

View File

@ -141,3 +141,17 @@ def interpret_gathered_like(tensors, gathered):
return dest_views return dest_views
aimdo_enabled = False aimdo_enabled = False
# Hook state for RAM-pressure cache release. The executor registers a
# callback plus a headroom target here; memory-management code calls
# extra_ram_release() when available RAM runs low.
extra_ram_release_callback = None
RAM_CACHE_HEADROOM = 0


def set_ram_cache_release_state(callback, headroom):
    """Register (or clear) the RAM-cache release hook.

    Args:
        callback: Callable invoked with a byte target when RAM must be
            freed, or None to disable the hook.
        headroom: RAM headroom target in bytes; truncated to int and
            clamped to >= 0 so a bogus negative value cannot disable
            downstream `> 0` checks in unexpected ways.
    """
    global extra_ram_release_callback
    global RAM_CACHE_HEADROOM
    extra_ram_release_callback = callback
    RAM_CACHE_HEADROOM = max(0, int(headroom))


def extra_ram_release(target):
    """Invoke the registered release callback with *target* bytes.

    Returns 0 when no callback is registered; otherwise passes through
    whatever the callback returns.
    """
    if extra_ram_release_callback is None:
        return 0
    return extra_ram_release_callback(target)

View File

@ -669,7 +669,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for i in range(len(current_loaded_models) -1, -1, -1): for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i] shift_model = current_loaded_models[i]
if shift_model.device == device: if device is None or shift_model.device == device:
if shift_model not in keep_loaded and not shift_model.is_dead(): if shift_model not in keep_loaded and not shift_model.is_dead():
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False shift_model.currently_used = False
@ -679,8 +679,8 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
i = x[-1] i = x[-1]
memory_to_free = 1e32 memory_to_free = 1e32
pins_to_free = 1e32 pins_to_free = 1e32
if not DISABLE_SMART_MEMORY: if not DISABLE_SMART_MEMORY or device is None:
memory_to_free = memory_required - get_free_memory(device) memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
pins_to_free = pins_required - get_free_ram() pins_to_free = pins_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic: if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models #don't actually unload dynamic models for the sake of other dynamic models
@ -708,7 +708,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
if len(unloaded_model) > 0: if len(unloaded_model) > 0:
soft_empty_cache() soft_empty_cache()
else: elif device is not None:
if vram_state != VRAMState.HIGH_VRAM: if vram_state != VRAMState.HIGH_VRAM:
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True) mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
if mem_free_torch > mem_free_total * 0.25: if mem_free_torch > mem_free_total * 0.25:

View File

@ -300,9 +300,6 @@ class ModelPatcher:
def model_mmap_residency(self, free=False): def model_mmap_residency(self, free=False):
return comfy.model_management.module_mmap_residency(self.model, free=free) return comfy.model_management.module_mmap_residency(self.model, free=free)
def get_ram_usage(self):
return self.model_size()
def loaded_size(self): def loaded_size(self):
return self.model.model_loaded_weight_memory return self.model.model_loaded_weight_memory

View File

@ -2,6 +2,7 @@ import comfy.model_management
import comfy.memory_management import comfy.memory_management
import comfy_aimdo.host_buffer import comfy_aimdo.host_buffer
import comfy_aimdo.torch import comfy_aimdo.torch
import psutil
from comfy.cli_args import args from comfy.cli_args import args
@ -12,6 +13,11 @@ def pin_memory(module):
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
return return
#FIXME: This is a RAM cache trigger event #FIXME: This is a RAM cache trigger event
ram_headroom = comfy.memory_management.RAM_CACHE_HEADROOM
#we split the difference and assume half the RAM cache headroom is for us
if ram_headroom > 0 and psutil.virtual_memory().available < (ram_headroom * 0.5):
comfy.memory_management.extra_ram_release(ram_headroom)
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY: if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:

View File

@ -280,9 +280,6 @@ class CLIP:
n.apply_hooks_to_conds = self.apply_hooks_to_conds n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n return n
def get_ram_usage(self):
return self.patcher.get_ram_usage()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model) return self.patcher.add_patches(patches, strength_patch, strength_model)
@ -840,9 +837,6 @@ class VAE:
self.size = comfy.model_management.module_size(self.first_stage_model) self.size = comfy.model_management.module_size(self.first_stage_model)
return self.size return self.size
def get_ram_usage(self):
return self.model_size()
def throw_exception_if_invalid(self): def throw_exception_if_invalid(self):
if self.first_stage_model is None: if self.first_stage_model is None:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.") raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")

View File

@ -1,6 +1,5 @@
import asyncio import asyncio
import bisect import bisect
import gc
import itertools import itertools
import psutil import psutil
import time import time
@ -475,6 +474,10 @@ class LRUCache(BasicCache):
self._mark_used(node_id) self._mark_used(node_id)
return await self._set_immediate(node_id, value) return await self._set_immediate(node_id, value)
def set_local(self, node_id, value):
    """Store *value* for *node_id*, refreshing its LRU recency first."""
    self._mark_used(node_id)
    super().set_local(node_id, value)
async def ensure_subcache_for(self, node_id, children_ids): async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes # Just uses subcaches for tracking 'live' nodes
await super()._ensure_subcache(node_id, children_ids) await super()._ensure_subcache(node_id, children_ids)
@ -489,15 +492,10 @@ class LRUCache(BasicCache):
return self return self
#Iterating the cache for usage analysis might be expensive, so if we trigger make sure #Small baseline weight used when a cache entry has no measurable CPU tensors.
#to take a chunk out to give breathing space on high-node / low-ram-per-node flows. #Keeps unknown-sized entries in eviction scoring without dominating tensor-backed entries.
RAM_CACHE_HYSTERESIS = 1.1 RAM_CACHE_DEFAULT_RAM_USAGE = 0.05
#This is kinda in GB but not really. It needs to be non-zero for the below heuristic
#and as long as Multi GB models dwarf this it will approximate OOM scoring OK
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
#Exponential bias towards evicting older workflows so garbage will be taken out #Exponential bias towards evicting older workflows so garbage will be taken out
#in constantly changing setups. #in constantly changing setups.
@ -521,19 +519,17 @@ class RAMPressureCache(LRUCache):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return await super().get(node_id) return await super().get(node_id)
def poll(self, ram_headroom): def set_local(self, node_id, value):
def _ram_gb(): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return psutil.virtual_memory().available / (1024**3) super().set_local(node_id, value)
if _ram_gb() > ram_headroom: def ram_release(self, target):
return if psutil.virtual_memory().available >= target:
gc.collect()
if _ram_gb() > ram_headroom:
return return
clean_list = [] clean_list = []
for key, (outputs, _), in self.cache.items(): for key, cache_entry in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key]) oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
@ -542,22 +538,20 @@ class RAMPressureCache(LRUCache):
if outputs is None: if outputs is None:
return return
for output in outputs: for output in outputs:
if isinstance(output, list): if isinstance(output, (list, tuple)):
scan_list_for_ram_usage(output) scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu': elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
#score Tensors at a 50% discount for RAM usage as they are likely to ram_usage += output.numel() * output.element_size()
#be high value intermediates scan_list_for_ram_usage(cache_entry.outputs)
ram_usage += (output.numel() * output.element_size()) * 0.5
elif hasattr(output, "get_ram_usage"):
ram_usage += output.get_ram_usage()
scan_list_for_ram_usage(outputs)
oom_score *= ram_usage oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all, #In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU) #break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key)) bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list: while psutil.virtual_memory().available < target and clean_list:
_, _, key = clean_list.pop() _, _, key = clean_list.pop()
del self.cache[key] del self.cache[key]
gc.collect() self.used_generation.pop(key, None)
self.timestamps.pop(key, None)
self.children.pop(key, None)

View File

@ -724,6 +724,9 @@ class PromptExecutor:
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
self._notify_prompt_lifecycle("start", prompt_id) self._notify_prompt_lifecycle("start", prompt_id)
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None
comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom)
try: try:
with torch.inference_mode(): with torch.inference_mode():
@ -773,7 +776,10 @@ class PromptExecutor:
execution_list.unstage_node_execution() execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS: else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution() execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
if self.cache_type == CacheType.RAM_PRESSURE:
comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
comfy.memory_management.extra_ram_release(ram_headroom)
else: else:
# Only execute when the while-loop ends without break # Only execute when the while-loop ends without break
# Send cached UI for intermediate output nodes that weren't executed # Send cached UI for intermediate output nodes that weren't executed
@ -801,6 +807,7 @@ class PromptExecutor:
if comfy.model_management.DISABLE_SMART_MEMORY: if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models() comfy.model_management.unload_all_models()
finally: finally:
comfy.memory_management.set_ram_cache_release_state(None, 0)
self._notify_prompt_lifecycle("end", prompt_id) self._notify_prompt_lifecycle("end", prompt_id)

View File

@ -275,15 +275,19 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]:
def prompt_worker(q, server_instance): def prompt_worker(q, server_instance):
current_time: float = 0.0 current_time: float = 0.0
cache_ram = args.cache_ram
if cache_ram < 0:
cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0))
cache_type = execution.CacheType.CLASSIC cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0: if args.cache_lru > 0:
cache_type = execution.CacheType.LRU cache_type = execution.CacheType.LRU
elif args.cache_ram > 0: elif cache_ram > 0:
cache_type = execution.CacheType.RAM_PRESSURE cache_type = execution.CacheType.RAM_PRESSURE
elif args.cache_none: elif args.cache_none:
cache_type = execution.CacheType.NONE cache_type = execution.CacheType.NONE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } ) e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram } )
last_gc_collect = 0 last_gc_collect = 0
need_gc = False need_gc = False
gc_collect_interval = 10.0 gc_collect_interval = 10.0