diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 566bc3f9c..41224ce3b 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -265,26 +265,6 @@ class HierarchicalCache(BasicCache): assert cache is not None return await cache._ensure_subcache(node_id, children_ids) -class NullCache: - - async def set_prompt(self, dynprompt, node_ids, is_changed_cache): - pass - - def all_node_ids(self): - return [] - - def clean_unused(self): - pass - - def get(self, node_id): - return None - - def set(self, node_id, value): - pass - - async def ensure_subcache_for(self, node_id, children_ids): - return self - class LRUCache(BasicCache): def __init__(self, key_class, max_size=100): super().__init__(key_class) @@ -336,3 +316,157 @@ class LRUCache(BasicCache): self._mark_used(child_id) self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) return self + + +class DependencyAwareCache(BasicCache): + """ + A cache implementation that tracks dependencies between nodes and manages + their execution and caching accordingly. It extends the BasicCache class. + Nodes are removed from this cache once all of their descendants have been + executed. + """ + + def __init__(self, key_class): + """ + Initialize the DependencyAwareCache. + + Args: + key_class: The class used for generating cache keys. + """ + super().__init__(key_class) + self.descendants = {} # Maps node_id -> set of descendant node_ids + self.ancestors = {} # Maps node_id -> set of ancestor node_ids + self.executed_nodes = set() # Tracks nodes that have been executed + + async def set_prompt(self, dynprompt, node_ids, is_changed_cache): + """ + Clear the entire cache and rebuild the dependency graph. + + Args: + dynprompt: The dynamic prompt object containing node information. + node_ids: List of node IDs to initialize the cache for. + is_changed_cache: Flag indicating if the cache has changed. + """ + # Clear all existing cache data + self.cache.clear() + self.subcaches.clear() + self.descendants.clear() + self.ancestors.clear() + self.executed_nodes.clear() + + # Call the parent method to initialize the cache with the new prompt + await super().set_prompt(dynprompt, node_ids, is_changed_cache) + + # Rebuild the dependency graph + self._build_dependency_graph(dynprompt, node_ids) + + def _build_dependency_graph(self, dynprompt, node_ids): + """ + Build the dependency graph for all nodes. + + Args: + dynprompt: The dynamic prompt object containing node information. + node_ids: List of node IDs to build the graph for. + """ + self.descendants.clear() + self.ancestors.clear() + for node_id in node_ids: + self.descendants[node_id] = set() + self.ancestors[node_id] = set() + + for node_id in node_ids: + inputs = dynprompt.get_node(node_id)["inputs"] + for input_data in inputs.values(): + if is_link(input_data): # Check if the input is a link to another node + ancestor_id = input_data[0] + self.descendants[ancestor_id].add(node_id) + self.ancestors[node_id].add(ancestor_id) + + def set(self, node_id, value): + """ + Mark a node as executed and store its value in the cache. + + Args: + node_id: The ID of the node to store. + value: The value to store for the node. + """ + self._set_immediate(node_id, value) + self.executed_nodes.add(node_id) + self._cleanup_ancestors(node_id) + + def get(self, node_id): + """ + Retrieve the cached value for a node. + + Args: + node_id: The ID of the node to retrieve. + + Returns: + The cached value for the node. + """ + return self._get_immediate(node_id) + + async def ensure_subcache_for(self, node_id, children_ids): + """ + Ensure a subcache exists for a node and update dependencies. + + Args: + node_id: The ID of the parent node. + children_ids: List of child node IDs to associate with the parent node. + + Returns: + The subcache object for the node. + """ + subcache = await super()._ensure_subcache(node_id, children_ids) + for child_id in children_ids: + self.descendants[node_id].add(child_id) + self.ancestors[child_id].add(node_id) + return subcache + + def _cleanup_ancestors(self, node_id): + """ + Check if ancestors of a node can be removed from the cache. + + Args: + node_id: The ID of the node whose ancestors are to be checked. + """ + for ancestor_id in self.ancestors.get(node_id, []): + if ancestor_id in self.executed_nodes: + # Remove ancestor if all its descendants have been executed + if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]): + self._remove_node(ancestor_id) + + def _remove_node(self, node_id): + """ + Remove a node from the cache. + + Args: + node_id: The ID of the node to remove. + """ + cache_key = self.cache_key_set.get_data_key(node_id) + if cache_key in self.cache: + del self.cache[cache_key] + subcache_key = self.cache_key_set.get_subcache_key(node_id) + if subcache_key in self.subcaches: + del self.subcaches[subcache_key] + + def clean_unused(self): + """ + Clean up unused nodes. This is a no-op for this cache implementation. + """ + pass + + def recursive_debug_dump(self): + """ + Dump the cache and dependency graph for debugging. + + Returns: + A list containing the cache state and dependency graph. + """ + result = super().recursive_debug_dump() + result.append({ + "descendants": self.descendants, + "ancestors": self.ancestors, + "executed_nodes": list(self.executed_nodes), + }) + return result diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index d5bbacde3..f4b427265 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -153,9 +153,8 @@ class TopologicalSort: continue _, _, input_info = self.get_input_info(unique_id, input_name) is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] - if (include_lazy or not is_lazy): - if not self.is_cached(from_node_id): - node_ids.append(from_node_id) + if (include_lazy or not is_lazy) and not self.is_cached(from_node_id): + node_ids.append(from_node_id) links.append((from_node_id, from_socket, unique_id)) for link in links: @@ -195,34 +194,10 @@ class ExecutionList(TopologicalSort): super().__init__(dynprompt) self.output_cache = output_cache self.staged_node_id = None - self.execution_cache = {} - self.execution_cache_listeners = {} def is_cached(self, node_id): return self.output_cache.get(node_id) is not None - def cache_link(self, from_node_id, to_node_id): - if not to_node_id in self.execution_cache: - self.execution_cache[to_node_id] = {} - self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) - if not from_node_id in self.execution_cache_listeners: - self.execution_cache_listeners[from_node_id] = set() - self.execution_cache_listeners[from_node_id].add(to_node_id) - - def get_output_cache(self, from_node_id, to_node_id): - if not to_node_id in self.execution_cache: - return None - return self.execution_cache[to_node_id].get(from_node_id) - - def cache_update(self, node_id, value): - if node_id in self.execution_cache_listeners: - for to_node_id in self.execution_cache_listeners[node_id]: - self.execution_cache[to_node_id][node_id] = value - - def add_strong_link(self, from_node_id, from_socket, to_node_id): - super().add_strong_link(from_node_id, from_socket, to_node_id) - self.cache_link(from_node_id, to_node_id) - async def stage_node_execution(self): assert self.staged_node_id is None if self.is_empty(): @@ -302,8 +277,6 @@ class ExecutionList(TopologicalSort): def complete_node_execution(self): node_id = self.staged_node_id self.pop_node(node_id) - self.execution_cache.pop(node_id, None) - self.execution_cache_listeners.pop(node_id, None) self.staged_node_id = None def get_nodes_in_cycle(self): diff --git a/execution.py b/execution.py index 78c36a4b0..1dc35738b 100644 --- a/execution.py +++ b/execution.py @@ -18,7 +18,7 @@ from comfy_execution.caching import ( BasicCache, CacheKeySetID, CacheKeySetInputSignature, - NullCache, + DependencyAwareCache, HierarchicalCache, LRUCache, ) @@ -91,13 +91,13 @@ class IsChangedCache: class CacheType(Enum): CLASSIC = 0 LRU = 1 - NONE = 2 + DEPENDENCY_AWARE = 2 class CacheSet: def __init__(self, cache_type=None, cache_size=None): - if cache_type == CacheType.NONE: - self.init_null_cache() + if cache_type == CacheType.DEPENDENCY_AWARE: + self.init_dependency_aware_cache() logging.info("Disabling intermediate node cache.") elif cache_type == CacheType.LRU: if cache_size is None: @@ -120,12 +120,11 @@ class CacheSet: self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.objects = HierarchicalCache(CacheKeySetID) - def init_null_cache(self): - self.outputs = NullCache() - #The UI cache is expected to be iterable at the end of each workflow - #so it must cache at least a full workflow. Use Heirachical - self.ui = HierarchicalCache(CacheKeySetInputSignature) - self.objects = NullCache() + # only hold cached items while the decendents have not executed + def init_dependency_aware_cache(self): + self.outputs = DependencyAwareCache(CacheKeySetInputSignature) + self.ui = DependencyAwareCache(CacheKeySetInputSignature) + self.objects = DependencyAwareCache(CacheKeySetID) def recursive_debug_dump(self): result = { @@ -136,7 +135,7 @@ class CacheSet: SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") -def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): +def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}): is_v3 = issubclass(class_def, _ComfyNodeInternal) if is_v3: valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True) @@ -154,10 +153,10 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)): input_unique_id = input_data[0] output_index = input_data[1] - if execution_list is None: + if outputs is None: mark_missing() continue # This might be a lazily-evaluated input - cached_output = execution_list.get_output_cache(input_unique_id, unique_id) + cached_output = outputs.get(input_unique_id) if cached_output is None: mark_missing() continue @@ -406,7 +405,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, cached_output = caches.ui.get(unique_id) or {} server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) get_progress_state().finish_progress(unique_id) - execution_list.cache_update(unique_id, caches.outputs.get(unique_id)) return (ExecutionResult.SUCCESS, None, None) input_data_all = None @@ -436,7 +434,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, for r in result: if is_link(r): source_node, source_output = r[0], r[1] - node_output = execution_list.get_output_cache(source_node, unique_id)[source_output] + node_output = caches.outputs.get(source_node)[source_output] for o in node_output: resolved_output.append(o) @@ -448,7 +446,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) - input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) + input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) @@ -551,15 +549,11 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, subcache.clean_unused() for node_id in new_output_ids: execution_list.add_node(node_id) - execution_list.cache_link(node_id, unique_id) for link in new_output_links: execution_list.add_strong_link(link[0], link[1], unique_id) pending_subgraph_results[unique_id] = cached_outputs return (ExecutionResult.PENDING, None, None) - caches.outputs.set(unique_id, output_data) - execution_list.cache_update(unique_id, output_data) - except comfy.model_management.InterruptProcessingException as iex: logging.info("Processing interrupted") diff --git a/main.py b/main.py index 4b4c5dcc4..35857dba8 100644 --- a/main.py +++ b/main.py @@ -173,7 +173,7 @@ def prompt_worker(q, server_instance): if args.cache_lru > 0: cache_type = execution.CacheType.LRU elif args.cache_none: - cache_type = execution.CacheType.NONE + cache_type = execution.CacheType.DEPENDENCY_AWARE e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru) last_gc_collect = 0 diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index ace0d2279..ef73ad9fd 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -152,12 +152,12 @@ class TestExecution: # Initialize server and client # @fixture(scope="class", autouse=True, params=[ - { "extra_args" : [], "should_cache_results" : True }, - { "extra_args" : ["--cache-lru", 0], "should_cache_results" : True }, - { "extra_args" : ["--cache-lru", 100], "should_cache_results" : True }, - { "extra_args" : ["--cache-none"], "should_cache_results" : False }, + # (use_lru, lru_size) + (False, 0), + (True, 0), + (True, 100), ]) - def server(self, args_pytest, request): + def _server(self, args_pytest, request): # Start server pargs = [ 'python','main.py', @@ -167,10 +167,12 @@ class TestExecution: '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml', '--cpu', ] - pargs += [ str(param) for param in request.param["extra_args"] ] + use_lru, lru_size = request.param + if use_lru: + pargs += ['--cache-lru', str(lru_size)] print("Running server with args:", pargs) # noqa: T201 p = subprocess.Popen(pargs) - yield request.param + yield p.kill() torch.cuda.empty_cache() @@ -191,7 +193,7 @@ class TestExecution: return comfy_client @fixture(scope="class", autouse=True) - def shared_client(self, args_pytest, server): + def shared_client(self, args_pytest, _server): client = self.start_client(args_pytest["listen"], args_pytest["port"]) yield client del client @@ -223,7 +225,7 @@ class TestExecution: assert result.did_run(mask) assert result.did_run(lazy_mix) - def test_full_cache(self, client: ComfyClient, builder: GraphBuilder, server): + def test_full_cache(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) @@ -235,12 +237,9 @@ class TestExecution: client.run(g) result2 = client.run(g) for node_id, node in g.nodes.items(): - if server["should_cache_results"]: - assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" - else: - assert result2.did_run(node), f"Node {node_id} was cached, but should have been run" + assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" - def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder, server): + def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) @@ -252,12 +251,8 @@ class TestExecution: client.run(g) mask.inputs['value'] = 0.4 result2 = client.run(g) - if server["should_cache_results"]: - assert not result2.did_run(input1), "Input1 should have been cached" - assert not result2.did_run(input2), "Input2 should have been cached" - else: - assert result2.did_run(input1), "Input1 should have been rerun" - assert result2.did_run(input2), "Input2 should have been rerun" + assert not result2.did_run(input1), "Input1 should have been cached" + assert not result2.did_run(input2), "Input2 should have been cached" def test_error(self, client: ComfyClient, builder: GraphBuilder): g = builder @@ -416,7 +411,7 @@ class TestExecution: input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1) client.run(g) - def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder, server): + def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): g = builder # Creating the nodes in this specific order previously caused a bug save = g.node("SaveImage") @@ -432,10 +427,7 @@ class TestExecution: result3 = client.run(g) result4 = client.run(g) assert result1.did_run(is_changed), "is_changed should have been run" - if server["should_cache_results"]: - assert not result2.did_run(is_changed), "is_changed should have been cached" - else: - assert result2.did_run(is_changed), "is_changed should have been re-run" + assert not result2.did_run(is_changed), "is_changed should have been cached" assert result3.did_run(is_changed), "is_changed should have been re-run" assert result4.did_run(is_changed), "is_changed should not have been cached" @@ -522,7 +514,7 @@ class TestExecution: assert len(images2) == 1, "Should have 1 image" # This tests that only constant outputs are used in the call to `IS_CHANGED` - def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server): + def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1) test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5) @@ -538,11 +530,7 @@ class TestExecution: images = result.get_images(output) assert len(images) == 1, "Should have 1 image" assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" - if server["should_cache_results"]: - assert not result.did_run(test_node), "The execution should have been cached" - else: - assert result.did_run(test_node), "The execution should have been re-run" - + assert not result.did_run(test_node), "The execution should have been cached" def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): # Warmup execution to ensure server is fully initialized