Mirror of https://github.com/comfyanonymous/ComfyUI.git
execution: Convert the cache entry to NamedTuple
As commented in review: convert the cache entry to a named tuple and abstract the tuple type away from graph.py completely.
parent f3f526fcd3
commit d2b87e674a
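For readers skimming the diff below: the change replaces a bare (outputs, ui) tuple with a typing.NamedTuple, so call sites read fields by name instead of by index. A minimal standalone sketch of that pattern; the CacheEntry fields match the ones added to execution.py, while the example values are invented for illustration:

# Minimal sketch, not ComfyUI code: the NamedTuple pattern the commit adopts.
from typing import NamedTuple

class CacheEntry(NamedTuple):
    ui: dict
    outputs: list

# Example values are invented; real entries hold node UI metadata and output lists.
entry = CacheEntry(ui={"output": {"text": ["hello"]}}, outputs=[["hello"]])

# Old call sites indexed a bare tuple (outputs at [0], ui at [1]); new ones
# read the fields by name, so the layout is no longer baked into graph.py.
assert entry.outputs[0] == ["hello"]
assert entry.ui["output"]["text"] == ["hello"]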
graph.py

@@ -209,7 +209,7 @@ class ExecutionList(TopologicalSort):
             self.execution_cache_listeners[from_node_id] = set()
         self.execution_cache_listeners[from_node_id].add(to_node_id)
 
-    def get_output_cache(self, from_node_id, to_node_id):
+    def get_cache(self, from_node_id, to_node_id):
         if not to_node_id in self.execution_cache:
             return None
         value = self.execution_cache[to_node_id].get(from_node_id)
@@ -217,7 +217,7 @@ class ExecutionList(TopologicalSort):
             return None
         #Write back to the main cache on touch.
         self.output_cache.set(from_node_id, value)
-        return value[0]
+        return value
 
     def cache_update(self, node_id, value):
         if node_id in self.execution_cache_listeners:
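The subtle piece of the graph.py hunks is the last change: get_cache now returns the stored value untouched (return value instead of return value[0]), so the per-edge cache and the write-back into the main output cache both carry the same entry, and graph.py never needs to know which slot holds the outputs. A rough illustration of that hand-off, with the cache shapes assumed for the sketch rather than taken from ComfyUI's real cache classes:

# Rough illustration only; the dict-based caches here are stand-ins.
from typing import NamedTuple

class CacheEntry(NamedTuple):
    ui: dict
    outputs: list

per_edge_cache = {"to_node": {"from_node": CacheEntry(ui=None, outputs=[["IMAGE"]])}}
main_output_cache = {}

def get_cache(from_node_id, to_node_id):
    value = per_edge_cache.get(to_node_id, {}).get(from_node_id)
    if value is None:
        return None
    main_output_cache[from_node_id] = value   # write back on touch, entry kept whole
    return value                              # caller reads .outputs or .ui itself

entry = get_cache("from_node", "to_node")
assert entry.outputs == [["IMAGE"]]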
execution.py (28 changed lines)
@@ -89,6 +89,11 @@ class IsChangedCache:
         return self.is_changed[node_id]
 
 
+class CacheEntry(NamedTuple):
+    ui: dict
+    outputs: list
+
+
 class CacheType(Enum):
     CLASSIC = 0
     LRU = 1
@@ -160,14 +165,14 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
             if execution_list is None:
                 mark_missing()
                 continue # This might be a lazily-evaluated input
-            cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
-            if cached_output is None:
+            cached = execution_list.get_cache(input_unique_id, unique_id)
+            if cached is None or cached.outputs is None:
                 mark_missing()
                 continue
-            if output_index >= len(cached_output):
+            if output_index >= len(cached.outputs):
                 mark_missing()
                 continue
-            obj = cached_output[output_index]
+            obj = cached.outputs[output_index]
             input_data_all[x] = obj
         elif input_category is not None:
             input_data_all[x] = [input_data]
@@ -407,10 +412,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     cached = caches.outputs.get(unique_id)
     if cached is not None:
         if server.client_id is not None:
-            cached_ui = cached[1] or {}
+            cached_ui = cached.ui or {}
             server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
-        if cached[1] is not None:
-            ui_outputs[unique_id] = cached[1]
+        if cached.ui is not None:
+            ui_outputs[unique_id] = cached.ui
         get_progress_state().finish_progress(unique_id)
         execution_list.cache_update(unique_id, cached)
         return (ExecutionResult.SUCCESS, None, None)
@@ -442,8 +447,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                 for r in result:
                     if is_link(r):
                         source_node, source_output = r[0], r[1]
-                        node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
-                        for o in node_output:
+                        node_cached = execution_list.get_cache(source_node, unique_id)
+                        for o in node_cached.outputs[source_output]:
                             resolved_output.append(o)
 
                     else:
@@ -563,8 +568,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             pending_subgraph_results[unique_id] = cached_outputs
             return (ExecutionResult.PENDING, None, None)
 
-        execution_list.cache_update(unique_id, (output_data, ui_outputs.get(unique_id)))
-        caches.outputs.set(unique_id, (output_data, ui_outputs.get(unique_id)))
+        cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
+        execution_list.cache_update(unique_id, cache_entry)
+        caches.outputs.set(unique_id, cache_entry)
 
     except comfy.model_management.InterruptProcessingException as iex:
         logging.info("Processing interrupted")