diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index aadb562e6..0d811e354 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -209,7 +209,7 @@ class ExecutionList(TopologicalSort): self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id].add(to_node_id) - def get_output_cache(self, from_node_id, to_node_id): + def get_cache(self, from_node_id, to_node_id): if not to_node_id in self.execution_cache: return None value = self.execution_cache[to_node_id].get(from_node_id) @@ -217,7 +217,7 @@ class ExecutionList(TopologicalSort): return None #Write back to the main cache on touch. self.output_cache.set(from_node_id, value) - return value[0] + return value def cache_update(self, node_id, value): if node_id in self.execution_cache_listeners: diff --git a/execution.py b/execution.py index 5b522fcc9..8e099c336 100644 --- a/execution.py +++ b/execution.py @@ -89,6 +89,11 @@ class IsChangedCache: return self.is_changed[node_id] +class CacheEntry(NamedTuple): + ui: dict + outputs: list + + class CacheType(Enum): CLASSIC = 0 LRU = 1 @@ -160,14 +165,14 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= if execution_list is None: mark_missing() continue # This might be a lazily-evaluated input - cached_output = execution_list.get_output_cache(input_unique_id, unique_id) - if cached_output is None: + cached = execution_list.get_cache(input_unique_id, unique_id) + if cached is None or cached.outputs is None: mark_missing() continue - if output_index >= len(cached_output): + if output_index >= len(cached.outputs): mark_missing() continue - obj = cached_output[output_index] + obj = cached.outputs[output_index] input_data_all[x] = obj elif input_category is not None: input_data_all[x] = [input_data] @@ -407,10 +412,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, cached = caches.outputs.get(unique_id) if cached is not None: if server.client_id is not None: - cached_ui = cached[1] or {} + cached_ui = cached.ui or {} server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id) - if cached[1] is not None: - ui_outputs[unique_id] = cached[1] + if cached.ui is not None: + ui_outputs[unique_id] = cached.ui get_progress_state().finish_progress(unique_id) execution_list.cache_update(unique_id, cached) return (ExecutionResult.SUCCESS, None, None) @@ -442,8 +447,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, for r in result: if is_link(r): source_node, source_output = r[0], r[1] - node_output = execution_list.get_output_cache(source_node, unique_id)[source_output] - for o in node_output: + node_cached = execution_list.get_cache(source_node, unique_id) + for o in node_cached.outputs[source_output]: resolved_output.append(o) else: @@ -563,8 +568,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, pending_subgraph_results[unique_id] = cached_outputs return (ExecutionResult.PENDING, None, None) - execution_list.cache_update(unique_id, (output_data, ui_outputs.get(unique_id))) - caches.outputs.set(unique_id, (output_data, ui_outputs.get(unique_id))) + cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data) + execution_list.cache_update(unique_id, cache_entry) + caches.outputs.set(unique_id, cache_entry) except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted")