diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 13612175e..dbaadf723 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -110,11 +110,13 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.") +CACHE_RAM_AUTO_GB = -1.0 + cache_group = parser.add_mutually_exclusive_group() cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") -cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB") +cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") diff --git a/comfy/memory_management.py b/comfy/memory_management.py index f9078fe7c..48e3c11da 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -141,3 +141,17 @@ def interpret_gathered_like(tensors, gathered): return dest_views aimdo_enabled = False + +extra_ram_release_callback = None +RAM_CACHE_HEADROOM = 0 + +def set_ram_cache_release_state(callback, headroom): + global extra_ram_release_callback + global RAM_CACHE_HEADROOM + extra_ram_release_callback = callback + RAM_CACHE_HEADROOM = max(0, int(headroom)) + +def extra_ram_release(target): + if extra_ram_release_callback is None: + return 0 + return extra_ram_release_callback(target) diff --git a/comfy/model_management.py b/comfy/model_management.py index 9617d8388..ce079cf2f 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -669,7 +669,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins for i in range(len(current_loaded_models) -1, -1, -1): shift_model = current_loaded_models[i] - if shift_model.device == device: + if device is None or shift_model.device == device: if shift_model not in keep_loaded and not shift_model.is_dead(): can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) shift_model.currently_used = False @@ -679,8 +679,8 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins i = x[-1] memory_to_free = 1e32 pins_to_free = 1e32 - if not DISABLE_SMART_MEMORY: - memory_to_free = memory_required - get_free_memory(device) + if not DISABLE_SMART_MEMORY or device is None: + memory_to_free = 0 if device is None else memory_required - get_free_memory(device) pins_to_free = pins_required - get_free_ram() if current_loaded_models[i].model.is_dynamic() and for_dynamic: #don't actually unload dynamic models for the sake of other dynamic models @@ -708,7 +708,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins if len(unloaded_model) > 0: soft_empty_cache() - else: + elif device is not None: if vram_state != VRAMState.HIGH_VRAM: mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True) if mem_free_torch > mem_free_total * 0.25: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c26d37db2..6deb71e12 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -300,9 +300,6 @@ class ModelPatcher: def model_mmap_residency(self, free=False): return comfy.model_management.module_mmap_residency(self.model, free=free) - def get_ram_usage(self): - return self.model_size() - def loaded_size(self): return self.model.model_loaded_weight_memory diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index f6fb806c4..6f142282d 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -2,6 +2,7 @@ import comfy.model_management import comfy.memory_management import comfy_aimdo.host_buffer import comfy_aimdo.torch +import psutil from comfy.cli_args import args @@ -12,6 +13,11 @@ def pin_memory(module): if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: return #FIXME: This is a RAM cache trigger event + ram_headroom = comfy.memory_management.RAM_CACHE_HEADROOM + #we split the difference and assume half the RAM cache headroom is for us + if ram_headroom > 0 and psutil.virtual_memory().available < (ram_headroom * 0.5): + comfy.memory_management.extra_ram_release(ram_headroom) + size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY: diff --git a/comfy/sd.py b/comfy/sd.py index e2645438c..e1a2840d2 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -280,9 +280,6 @@ class CLIP: n.apply_hooks_to_conds = self.apply_hooks_to_conds return n - def get_ram_usage(self): - return self.patcher.get_ram_usage() - def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): return self.patcher.add_patches(patches, strength_patch, strength_model) @@ -840,9 +837,6 @@ class VAE: self.size = comfy.model_management.module_size(self.first_stage_model) return self.size - def get_ram_usage(self): - return self.model_size() - def throw_exception_if_invalid(self): if self.first_stage_model is None: raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.") diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 78212bde3..f9c913bdb 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -1,6 +1,5 @@ import asyncio import bisect -import gc import itertools import psutil import time @@ -475,6 +474,10 @@ class LRUCache(BasicCache): self._mark_used(node_id) return await self._set_immediate(node_id, value) + def set_local(self, node_id, value): + self._mark_used(node_id) + BasicCache.set_local(self, node_id, value) + async def ensure_subcache_for(self, node_id, children_ids): # Just uses subcaches for tracking 'live' nodes await super()._ensure_subcache(node_id, children_ids) @@ -489,15 +492,10 @@ class LRUCache(BasicCache): return self -#Iterating the cache for usage analysis might be expensive, so if we trigger make sure -#to take a chunk out to give breathing space on high-node / low-ram-per-node flows. +#Small baseline weight used when a cache entry has no measurable CPU tensors. +#Keeps unknown-sized entries in eviction scoring without dominating tensor-backed entries. -RAM_CACHE_HYSTERESIS = 1.1 - -#This is kinda in GB but not really. It needs to be non-zero for the below heuristic -#and as long as Multi GB models dwarf this it will approximate OOM scoring OK - -RAM_CACHE_DEFAULT_RAM_USAGE = 0.1 +RAM_CACHE_DEFAULT_RAM_USAGE = 0.05 #Exponential bias towards evicting older workflows so garbage will be taken out #in constantly changing setups. @@ -521,19 +519,17 @@ class RAMPressureCache(LRUCache): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() return await super().get(node_id) - def poll(self, ram_headroom): - def _ram_gb(): - return psutil.virtual_memory().available / (1024**3) + def set_local(self, node_id, value): + self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() + super().set_local(node_id, value) - if _ram_gb() > ram_headroom: - return - gc.collect() - if _ram_gb() > ram_headroom: + def ram_release(self, target): + if psutil.virtual_memory().available >= target: return clean_list = [] - for key, (outputs, _), in self.cache.items(): + for key, cache_entry in self.cache.items(): oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key]) ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE @@ -542,22 +538,20 @@ class RAMPressureCache(LRUCache): if outputs is None: return for output in outputs: - if isinstance(output, list): + if isinstance(output, (list, tuple)): scan_list_for_ram_usage(output) elif isinstance(output, torch.Tensor) and output.device.type == 'cpu': - #score Tensors at a 50% discount for RAM usage as they are likely to - #be high value intermediates - ram_usage += (output.numel() * output.element_size()) * 0.5 - elif hasattr(output, "get_ram_usage"): - ram_usage += output.get_ram_usage() - scan_list_for_ram_usage(outputs) + ram_usage += output.numel() * output.element_size() + scan_list_for_ram_usage(cache_entry.outputs) oom_score *= ram_usage #In the case where we have no information on the node ram usage at all, #break OOM score ties on the last touch timestamp (pure LRU) bisect.insort(clean_list, (oom_score, self.timestamps[key], key)) - while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list: + while psutil.virtual_memory().available < target and clean_list: _, _, key = clean_list.pop() del self.cache[key] - gc.collect() + self.used_generation.pop(key, None) + self.timestamps.pop(key, None) + self.children.pop(key, None) diff --git a/execution.py b/execution.py index 43c3c648d..5e02dffb2 100644 --- a/execution.py +++ b/execution.py @@ -724,6 +724,9 @@ class PromptExecutor: self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) self._notify_prompt_lifecycle("start", prompt_id) + ram_headroom = int(self.cache_args["ram"] * (1024 ** 3)) + ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None + comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom) try: with torch.inference_mode(): @@ -773,7 +776,10 @@ class PromptExecutor: execution_list.unstage_node_execution() else: # result == ExecutionResult.SUCCESS: execution_list.complete_node_execution() - self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) + + if self.cache_type == CacheType.RAM_PRESSURE: + comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom) + comfy.memory_management.extra_ram_release(ram_headroom) else: # Only execute when the while-loop ends without break # Send cached UI for intermediate output nodes that weren't executed @@ -801,6 +807,7 @@ class PromptExecutor: if comfy.model_management.DISABLE_SMART_MEMORY: comfy.model_management.unload_all_models() finally: + comfy.memory_management.set_ram_cache_release_state(None, 0) self._notify_prompt_lifecycle("end", prompt_id) diff --git a/main.py b/main.py index 058e8e2de..12b04719d 100644 --- a/main.py +++ b/main.py @@ -275,15 +275,19 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]: def prompt_worker(q, server_instance): current_time: float = 0.0 + cache_ram = args.cache_ram + if cache_ram < 0: + cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0)) + cache_type = execution.CacheType.CLASSIC if args.cache_lru > 0: cache_type = execution.CacheType.LRU - elif args.cache_ram > 0: + elif cache_ram > 0: cache_type = execution.CacheType.RAM_PRESSURE elif args.cache_none: cache_type = execution.CacheType.NONE - e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } ) + e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram } ) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0