execution: Convert the cache entry to NamedTuple

As commented in review.

Convert this to a named tuple and abstract away the tuple type
completely from graph.py.
Rattus 2025-10-30 20:40:53 +10:00
parent f3f526fcd3
commit d2b87e674a
2 changed files with 19 additions and 13 deletions
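
In short: the per-edge cache entry changes from a bare (outputs, ui) tuple, indexed positionally, to a CacheEntry NamedTuple read by field name. A minimal sketch of the before/after access pattern (the values are made up for illustration; CacheEntry mirrors the definition added in the diff below):

from typing import NamedTuple

class CacheEntry(NamedTuple):
    ui: dict
    outputs: list

# Before: callers had to know the tuple layout (outputs at [0], ui at [1]).
entry_old = (["node output"], {"output": "preview"})
outputs, ui = entry_old[0], entry_old[1]

# After: fields are read by name, so the layout is no longer part of the
# contract between the executor and the execution list in graph.py.
entry_new = CacheEntry(ui={"output": "preview"}, outputs=["node output"])
outputs, ui = entry_new.outputs, entry_new.ui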


@@ -209,7 +209,7 @@ class ExecutionList(TopologicalSort):
             self.execution_cache_listeners[from_node_id] = set()
         self.execution_cache_listeners[from_node_id].add(to_node_id)
 
-    def get_output_cache(self, from_node_id, to_node_id):
+    def get_cache(self, from_node_id, to_node_id):
         if not to_node_id in self.execution_cache:
             return None
         value = self.execution_cache[to_node_id].get(from_node_id)
@@ -217,7 +217,7 @@ class ExecutionList(TopologicalSort):
             return None
         #Write back to the main cache on touch.
         self.output_cache.set(from_node_id, value)
-        return value[0]
+        return value
 
     def cache_update(self, node_id, value):
         if node_id in self.execution_cache_listeners:

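Putting the two hunks above together, the renamed ExecutionList method after this commit reads roughly as follows. The `value is None` guard is inferred from the surrounding context lines, so treat this as a reassembled sketch rather than the verbatim source:

def get_cache(self, from_node_id, to_node_id):
    # Miss: nothing was staged for this consumer node at all.
    if not to_node_id in self.execution_cache:
        return None
    value = self.execution_cache[to_node_id].get(from_node_id)
    if value is None:
        return None
    # Write back to the main cache on touch.
    self.output_cache.set(from_node_id, value)
    # Previously returned value[0] (just the outputs); now the whole
    # entry, so callers pick .outputs or .ui themselves.
    return value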

@@ -89,6 +89,11 @@ class IsChangedCache:
         return self.is_changed[node_id]
 
+class CacheEntry(NamedTuple):
+    ui: dict
+    outputs: list
+
 class CacheType(Enum):
     CLASSIC = 0
     LRU = 1
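
One nicety of NamedTuple here: instances are still ordinary tuples, so any remaining positional reads or unpacking keep working while call sites migrate to the named fields. A quick illustration with hypothetical values, assuming the CacheEntry defined above (with NamedTuple imported from typing):

entry = CacheEntry(ui={"output": []}, outputs=[["IMAGE"]])
assert entry[0] is entry.ui        # positional access still works
assert entry[1] is entry.outputs
ui, outputs = entry                # so does tuple unpacking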
@@ -160,14 +165,14 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
             if execution_list is None:
                 mark_missing()
                 continue # This might be a lazily-evaluated input
-            cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
-            if cached_output is None:
+            cached = execution_list.get_cache(input_unique_id, unique_id)
+            if cached is None or cached.outputs is None:
                 mark_missing()
                 continue
-            if output_index >= len(cached_output):
+            if output_index >= len(cached.outputs):
                 mark_missing()
                 continue
-            obj = cached_output[output_index]
+            obj = cached.outputs[output_index]
             input_data_all[x] = obj
     elif input_category is not None:
         input_data_all[x] = [input_data]
@@ -407,10 +412,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     cached = caches.outputs.get(unique_id)
     if cached is not None:
         if server.client_id is not None:
-            cached_ui = cached[1] or {}
+            cached_ui = cached.ui or {}
             server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
-        if cached[1] is not None:
-            ui_outputs[unique_id] = cached[1]
+        if cached.ui is not None:
+            ui_outputs[unique_id] = cached.ui
         get_progress_state().finish_progress(unique_id)
         execution_list.cache_update(unique_id, cached)
         return (ExecutionResult.SUCCESS, None, None)
@@ -442,8 +447,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             for r in result:
                 if is_link(r):
                     source_node, source_output = r[0], r[1]
-                    node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
-                    for o in node_output:
+                    node_cached = execution_list.get_cache(source_node, unique_id)
+                    for o in node_cached.outputs[source_output]:
                         resolved_output.append(o)
                 else:
@@ -563,8 +568,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             pending_subgraph_results[unique_id] = cached_outputs
             return (ExecutionResult.PENDING, None, None)
-        execution_list.cache_update(unique_id, (output_data, ui_outputs.get(unique_id)))
-        caches.outputs.set(unique_id, (output_data, ui_outputs.get(unique_id)))
+        cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
+        execution_list.cache_update(unique_id, cache_entry)
+        caches.outputs.set(unique_id, cache_entry)
     except comfy.model_management.InterruptProcessingException as iex:
         logging.info("Processing interrupted")