diff --git a/comfy/memory_management.py b/comfy/memory_management.py
index f9078fe7c..48e3c11da 100644
--- a/comfy/memory_management.py
+++ b/comfy/memory_management.py
@@ -141,3 +141,17 @@ def interpret_gathered_like(tensors, gathered):
     return dest_views
 
 aimdo_enabled = False
+
+extra_ram_release_callback = None
+RAM_CACHE_HEADROOM = 0
+
+def set_ram_cache_release_state(callback, headroom):
+    global extra_ram_release_callback
+    global RAM_CACHE_HEADROOM
+    extra_ram_release_callback = callback
+    RAM_CACHE_HEADROOM = max(0, int(headroom))
+
+def extra_ram_release(target):
+    if extra_ram_release_callback is None:
+        return 0
+    return extra_ram_release_callback(target)
diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py
index f6fb806c4..6f142282d 100644
--- a/comfy/pinned_memory.py
+++ b/comfy/pinned_memory.py
@@ -2,6 +2,7 @@ import comfy.model_management
 import comfy.memory_management
 import comfy_aimdo.host_buffer
 import comfy_aimdo.torch
+import psutil
 
 from comfy.cli_args import args
 
@@ -12,6 +13,11 @@ def pin_memory(module):
     if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
         return
     #FIXME: This is a RAM cache trigger event
+    ram_headroom = comfy.memory_management.RAM_CACHE_HEADROOM
+    #we split the difference and assume half the RAM cache headroom is for us
+    if ram_headroom > 0 and psutil.virtual_memory().available < (ram_headroom * 0.5):
+        comfy.memory_management.extra_ram_release(ram_headroom)
+
     size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
 
     if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
diff --git a/execution.py b/execution.py
index 50e32e287..ff4a480dc 100644
--- a/execution.py
+++ b/execution.py
@@ -109,7 +109,6 @@ class CacheType(Enum):
 
 class CacheSet:
     def __init__(self, cache_type=None, cache_args={}):
-        self.ram_release_callback = None
         if cache_type == CacheType.NONE:
            self.init_null_cache()
            logging.info("Disabling intermediate node cache.")
@@ -138,7 +137,6 @@ class CacheSet:
     def init_ram_cache(self, min_headroom):
         self.outputs = RAMPressureCache(CacheKeySetInputSignature, enable_providers=True)
         self.objects = HierarchicalCache(CacheKeySetID)
-        self.ram_release_callback = self.outputs.ram_release
 
     def init_null_cache(self):
         self.outputs = NullCache()
@@ -717,7 +715,9 @@ class PromptExecutor:
         self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
         self._notify_prompt_lifecycle("start", prompt_id)
 
-        comfy.model_management.register_extra_ram_release_callback(self.caches.ram_release_callback)
+        ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
+        ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None
+        comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom)
 
         try:
             with torch.inference_mode():
@@ -767,6 +767,10 @@
                         execution_list.unstage_node_execution()
                     else: # result == ExecutionResult.SUCCESS:
                         execution_list.complete_node_execution()
+
+                    if self.cache_type == CacheType.RAM_PRESSURE:
+                        comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
+                        comfy.memory_management.extra_ram_release(ram_headroom)
                 else:
                     # Only execute when the while-loop ends without break
                     self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
@@ -784,7 +788,7 @@
             if comfy.model_management.DISABLE_SMART_MEMORY:
                 comfy.model_management.unload_all_models()
         finally:
-            comfy.model_management.register_extra_ram_release_callback(None)
+            comfy.memory_management.set_ram_cache_release_state(None, 0)
             self._notify_prompt_lifecycle("end", prompt_id)
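
For reference, a minimal sketch (not part of the patch) of how the new hooks in comfy/memory_management.py are intended to compose. `evict_cache_entries` is a hypothetical stand-in for `RAMPressureCache.ram_release`, and the 4 GiB headroom stands in for the `cache_args["ram"] * (1024 ** 3)` computation in the executor hunk above:

```python
import psutil

import comfy.memory_management

# Hypothetical stand-in for RAMPressureCache.ram_release: evicts cached node
# outputs until roughly `target` bytes have been freed, returning the amount.
def evict_cache_entries(target):
    freed = 0
    # ... drop cache entries under RAM pressure here ...
    return freed

# The executor arms the hooks once per prompt: a release callback plus the
# RAM headroom in bytes (assumed value; the executor derives it from
# cache_args["ram"]).
headroom = 4 * (1024 ** 3)
comfy.memory_management.set_ram_cache_release_state(evict_cache_entries, headroom)

# Any RAM consumer (e.g. pin_memory) can then ask the cache to shrink when
# available RAM dips below its share of the headroom; half, per the
# "split the difference" comment in pinned_memory.py.
if psutil.virtual_memory().available < comfy.memory_management.RAM_CACHE_HEADROOM * 0.5:
    comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)

# The finally: block in execution.py disarms the hooks when the prompt ends.
comfy.memory_management.set_ram_cache_release_state(None, 0)
```

Because `extra_ram_release` returns 0 when no callback is registered, callers such as `pin_memory` stay safe for cache types other than RAM_PRESSURE without any extra checks.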