diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 0d811e354..908e62e10 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -107,6 +107,8 @@ class TopologicalSort: self.blocking = {} # Which nodes are blocked by this node self.externalBlocks = 0 self.unblockedEvent = asyncio.Event() + self.priorities = {} + self.barrierNodes = set() def get_input_info(self, unique_id, input_name): class_type = self.dynprompt.get_node(unique_id)["class_type"] @@ -125,13 +127,37 @@ class TopologicalSort: def add_strong_link(self, from_node_id, from_socket, to_node_id): if not self.is_cached(from_node_id): - self.add_node(from_node_id) + self.add_node(from_node_id, priority=self.priorities.get(to_node_id, 0)) if to_node_id not in self.blocking[from_node_id]: self.blocking[from_node_id][to_node_id] = {} self.blockCount[to_node_id] += 1 self.blocking[from_node_id][to_node_id][from_socket] = True - def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None): + def is_barrier(self, node_id): + return node_id in self.barrierNodes + + def unbarrier(self, node_id): + if not node_id in self.barrierNodes: + return + self.barrierNodes.remove(node_id) + self.priorities[node_id] = self.priorities.get(node_id, 0) + 1 + + links = [] + inputs = self.dynprompt.get_node(node_id)["inputs"] + + for input_name in inputs: + value = inputs[input_name] + if is_link(value): + from_node_id, from_socket = value + _, _, input_info = self.get_input_info(node_id, input_name) + is_barrier = input_info is not None and "cache-barrier" in input_info and input_info["cache-barrier"] + if is_barrier: + links.append((from_node_id, from_socket, node_id)) + + for link in links: + self.add_strong_link(*link) + + def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None, priority=0): node_ids = [node_unique_id] links = [] @@ -143,6 +169,7 @@ class TopologicalSort: self.pendingNodes[unique_id] = True self.blockCount[unique_id] = 0 self.blocking[unique_id] = {} + self.priorities[unique_id] = priority inputs = self.dynprompt.get_node(unique_id)["inputs"] for input_name in inputs: @@ -153,10 +180,13 @@ class TopologicalSort: continue _, _, input_info = self.get_input_info(unique_id, input_name) is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] - if (include_lazy or not is_lazy): + is_barrier = input_info is not None and "cache-barrier" in input_info and input_info["cache-barrier"] + if (include_lazy or (not is_lazy and not is_barrier)): if not self.is_cached(from_node_id): node_ids.append(from_node_id) links.append((from_node_id, from_socket, unique_id)) + if is_barrier: + self.barrierNodes.add(unique_id) for link in links: self.add_strong_link(*link) @@ -175,7 +205,7 @@ class TopologicalSort: return False def get_ready_nodes(self): - return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0] + return [(self.priorities.get(node_id, 0), node_id) for node_id in self.pendingNodes if self.blockCount[node_id] == 0] def pop_node(self, unique_id): del self.pendingNodes[unique_id] @@ -281,25 +311,34 @@ class ExecutionList(TopologicalSort): class_def = nodes.NODE_CLASS_MAPPINGS[class_type] return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION)) - for node_id in node_list: + priority_level = 0 + priority_nodes = [] + for (priority, node_id) in node_list: + if priority > priority_level: + priority_level = priority + priority_nodes = [] + if priority == priority_level: + priority_nodes.append(node_id) + + for node_id in priority_nodes: if is_output(node_id) or is_async(node_id): return node_id #This should handle the VAEDecode -> preview case - for node_id in node_list: + for node_id in priority_nodes: for blocked_node_id in self.blocking[node_id]: if is_output(blocked_node_id): return node_id #This should handle the VAELoader -> VAEDecode -> preview case - for node_id in node_list: + for node_id in priority_nodes: for blocked_node_id in self.blocking[node_id]: for blocked_node_id1 in self.blocking[blocked_node_id]: if is_output(blocked_node_id1): return node_id #TODO: this function should be improved - return node_list[0] + return priority_nodes[0] def unstage_node_execution(self): assert self.staged_node_id is not None diff --git a/execution.py b/execution.py index 14b582be5..96835f80c 100644 --- a/execution.py +++ b/execution.py @@ -428,7 +428,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, input_data_all = None try: - if unique_id in pending_async_nodes: + if execution_list.is_barrier(unique_id): + execution_list.unbarrier(unique_id) + return (ExecutionResult.PENDING, None, None) + elif unique_id in pending_async_nodes: results = [] for r in pending_async_nodes[unique_id]: if isinstance(r, asyncio.Task):