diff --git a/comfy/model_management.py b/comfy/model_management.py
index aeddbaefe..6222c19ae 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -445,6 +445,20 @@ try:
 except:
     logging.warning("Could not pick default device.")

+current_ram_listeners = set()
+
+def register_ram_listener(listener):
+    current_ram_listeners.add(listener)
+
+def unregister_ram_listener(listener):
+    current_ram_listeners.discard(listener)
+
+def free_ram(extra_ram=0, state_dict={}):
+    for tensor in state_dict.values():
+        if isinstance(tensor, torch.Tensor):
+            extra_ram += tensor.numel() * tensor.element_size()
+    for listener in current_ram_listeners:
+        listener.free_ram(extra_ram)

 current_loaded_models = []

diff --git a/execution.py b/execution.py
index 44e3bb65c..dd5fc8baf 100644
--- a/execution.py
+++ b/execution.py
@@ -613,13 +613,21 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,

 class PromptExecutor:
     def __init__(self, server, cache_type=False, cache_args=None):
+        self.caches = None
         self.cache_args = cache_args
         self.cache_type = cache_type
         self.server = server
         self.reset()

     def reset(self):
+        if self.caches is not None:
+            for cache in self.caches.all:
+                comfy.model_management.unregister_ram_listener(cache)
+
         self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
+
+        for cache in self.caches.all:
+            comfy.model_management.register_ram_listener(cache)
         self.status_messages = []
         self.success = True
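
For context: `comfy.model_management.free_ram()` sums the byte size of any tensors in `state_dict`, adds `extra_ram`, and forwards the total to every registered listener's `free_ram()` method. Initializing `self.caches = None` in `__init__` lets the first `reset()` call skip unregistering listeners that do not exist yet. Below is a minimal sketch of a conforming listener, assuming only the interface visible in this diff; `ExampleCache`, its `_entries` storage, and its eviction policy are hypothetical illustrations, while the real listeners are the executor caches in execution.py.

```python
import comfy.model_management as mm

class ExampleCache:
    def __init__(self):
        self._entries = {}  # hypothetical storage keyed by node id

    def free_ram(self, extra_ram):
        # extra_ram is the byte count the caller expects to need (e.g. the
        # summed size of an incoming state dict). A real cache would evict
        # selectively; this placeholder simply drops everything when asked.
        if extra_ram > 0:
            self._entries.clear()

cache = ExampleCache()
mm.register_ram_listener(cache)
# Before loading a large checkpoint, a caller can ask listeners to make room:
# mm.free_ram(state_dict=checkpoint_state_dict)
mm.unregister_ram_listener(cache)
```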