mm: Add free_ram()

Add the free_ram() API and a means to install implementations of the
freeing mechanism (i.e. the RAM cache).
Rattus committed 2025-11-18 09:15:44 +10:00
parent 68053b1180
commit 62a2622591
2 changed files with 22 additions and 0 deletions

comfy/model_management.py

@@ -445,6 +445,20 @@ try:
 except:
     logging.warning("Could not pick default device.")
 
+current_ram_listeners = set()
+
+def register_ram_listener(listener):
+    current_ram_listeners.add(listener)
+
+def unregister_ram_listener(listener):
+    current_ram_listeners.discard(listener)
+
+def free_ram(extra_ram=0, state_dict={}):
+    for tensor in state_dict.values():
+        if isinstance(tensor, torch.Tensor):
+            extra_ram += tensor.numel() * tensor.element_size()
+    for listener in current_ram_listeners:
+        listener.free_ram(extra_ram)
 
 current_loaded_models = []
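
The listener contract here is minimal: anything passed to register_ram_listener() just needs a free_ram(extra_ram) method, which the code above calls with the number of bytes being requested. A sketch of what an implementation could look like follows; the SimpleRAMCache class and its evict-in-insertion-order policy are illustrative assumptions, not ComfyUI's actual cache.

import comfy.model_management

class SimpleRAMCache:
    """Illustrative listener; any object with free_ram(extra_ram) qualifies."""
    def __init__(self):
        self._entries = {}  # key -> (tensor, size in bytes)

    def put(self, key, tensor):
        self._entries[key] = (tensor, tensor.numel() * tensor.element_size())

    def free_ram(self, extra_ram):
        # Called via comfy.model_management.free_ram(); drop entries in
        # insertion order until at least extra_ram bytes are released.
        freed = 0
        for key in list(self._entries):
            if freed >= extra_ram:
                break
            _, size = self._entries.pop(key)
            freed += size

cache = SimpleRAMCache()
comfy.model_management.register_ram_listener(cache)
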

execution.py

@@ -613,13 +613,21 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
 class PromptExecutor:
     def __init__(self, server, cache_type=False, cache_args=None):
         self.caches = None
         self.cache_args = cache_args
         self.cache_type = cache_type
         self.server = server
         self.reset()
 
     def reset(self):
+        if self.caches is not None:
+            for cache in self.caches.all:
+                comfy.model_management.unregister_ram_listener(cache)
+
         self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
+
+        for cache in self.caches.all:
+            comfy.model_management.register_ram_listener(cache)
+
         self.status_messages = []
         self.success = True
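
PromptExecutor.reset() keeps the listener set in sync: the old CacheSet's caches are unregistered before the new ones are registered, so stale listeners don't accumulate across resets. Two details worth noting in free_ram() itself: the state_dict={} mutable default is safe here because the dict is only iterated, never mutated, and passing a state dict folds the total byte size of its tensors into the requested headroom. A hedged usage sketch, where the 512 MiB figure and the stand-in checkpoint are arbitrary illustrations, not taken from the commit:

import torch
import comfy.model_management

# Stand-in for a checkpoint about to be loaded.
state_dict = {"weight": torch.empty(1024, 1024)}

# Ask every registered listener to release roughly 512 MiB plus the
# byte size of the incoming tensors (4 MiB for this float32 weight).
comfy.model_management.free_ram(extra_ram=512 * 1024 * 1024,
                                state_dict=state_dict)
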