Mirror of https://github.com/comfyanonymous/ComfyUI.git
comfy_execution: add cache barriers
Add a system where an input can be marked as a cache barrier, deferring its evaluation. Once the node is first picked for execution, the barrier is released and everything behind the barrier is executed at increased priority.
Commit: 783da446c1
Parent: 96ad4904fe
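The "cache-barrier" flag is read from a node input's options dict (the third value returned by get_input_info), the same place the existing "lazy" flag is checked in this diff. As a rough illustration of how a node might opt in, here is a minimal sketch of a hypothetical custom node; the class, input name, and types are assumptions for illustration (not part of this commit), and it assumes the options dict declared in INPUT_TYPES is what get_input_info surfaces, as it is for lazy inputs:

# Hypothetical example node: only the {"cache-barrier": True} option is what
# the commit below actually consumes (see the is_barrier checks in the diff).
class ExampleBarrierNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                # Whatever feeds this input is not pulled into the execution plan
                # up front; once this node is first picked for execution, the
                # barrier is released and the upstream subgraph runs at
                # increased priority.
                "image": ("IMAGE", {"cache-barrier": True}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "run"
    CATEGORY = "example"

    def run(self, image):
        return (image,)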
@@ -107,6 +107,8 @@ class TopologicalSort:
         self.blocking = {} # Which nodes are blocked by this node
         self.externalBlocks = 0
         self.unblockedEvent = asyncio.Event()
+        self.priorities = {}
+        self.barrierNodes = set()
 
     def get_input_info(self, unique_id, input_name):
         class_type = self.dynprompt.get_node(unique_id)["class_type"]
@@ -125,13 +127,37 @@ class TopologicalSort:
 
     def add_strong_link(self, from_node_id, from_socket, to_node_id):
         if not self.is_cached(from_node_id):
-            self.add_node(from_node_id)
+            self.add_node(from_node_id, priority=self.priorities.get(to_node_id, 0))
             if to_node_id not in self.blocking[from_node_id]:
                 self.blocking[from_node_id][to_node_id] = {}
                 self.blockCount[to_node_id] += 1
             self.blocking[from_node_id][to_node_id][from_socket] = True
 
-    def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
+    def is_barrier(self, node_id):
+        return node_id in self.barrierNodes
+
+    def unbarrier(self, node_id):
+        if not node_id in self.barrierNodes:
+            return
+        self.barrierNodes.remove(node_id)
+        self.priorities[node_id] = self.priorities.get(node_id, 0) + 1
+
+        links = []
+        inputs = self.dynprompt.get_node(node_id)["inputs"]
+
+        for input_name in inputs:
+            value = inputs[input_name]
+            if is_link(value):
+                from_node_id, from_socket = value
+                _, _, input_info = self.get_input_info(node_id, input_name)
+                is_barrier = input_info is not None and "cache-barrier" in input_info and input_info["cache-barrier"]
+                if is_barrier:
+                    links.append((from_node_id, from_socket, node_id))
+
+        for link in links:
+            self.add_strong_link(*link)
+
+    def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None, priority=0):
         node_ids = [node_unique_id]
         links = []
 
@@ -143,6 +169,7 @@ class TopologicalSort:
             self.pendingNodes[unique_id] = True
             self.blockCount[unique_id] = 0
             self.blocking[unique_id] = {}
+            self.priorities[unique_id] = priority
 
             inputs = self.dynprompt.get_node(unique_id)["inputs"]
             for input_name in inputs:
@@ -153,10 +180,13 @@ class TopologicalSort:
                         continue
                     _, _, input_info = self.get_input_info(unique_id, input_name)
                     is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
-                    if (include_lazy or not is_lazy):
+                    is_barrier = input_info is not None and "cache-barrier" in input_info and input_info["cache-barrier"]
+                    if (include_lazy or (not is_lazy and not is_barrier)):
                         if not self.is_cached(from_node_id):
                             node_ids.append(from_node_id)
                             links.append((from_node_id, from_socket, unique_id))
+                    if is_barrier:
+                        self.barrierNodes.add(unique_id)
 
         for link in links:
             self.add_strong_link(*link)
@@ -175,7 +205,7 @@ class TopologicalSort:
         return False
 
     def get_ready_nodes(self):
-        return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
+        return [(self.priorities.get(node_id, 0), node_id) for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
 
     def pop_node(self, unique_id):
         del self.pendingNodes[unique_id]
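With this change, get_ready_nodes returns (priority, node_id) tuples, and the node-picking logic in the following hunk keeps only the ready nodes at the highest priority level. A condensed, standalone equivalent of that selection, assuming non-negative priorities (which holds here, since priorities start at 0 and are only ever incremented); the helper name is illustrative, not from the commit:

def highest_priority_ready(ready):
    # `ready` is the list of (priority, node_id) tuples from get_ready_nodes().
    if not ready:
        return []
    top = max(priority for priority, _ in ready)
    return [node_id for priority, node_id in ready if priority == top]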
@@ -281,25 +311,34 @@ class ExecutionList(TopologicalSort):
             class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
             return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
 
-        for node_id in node_list:
+        priority_level = 0
+        priority_nodes = []
+        for (priority, node_id) in node_list:
+            if priority > priority_level:
+                priority_level = priority
+                priority_nodes = []
+            if priority == priority_level:
+                priority_nodes.append(node_id)
+
+        for node_id in priority_nodes:
             if is_output(node_id) or is_async(node_id):
                 return node_id
 
         #This should handle the VAEDecode -> preview case
-        for node_id in node_list:
+        for node_id in priority_nodes:
             for blocked_node_id in self.blocking[node_id]:
                 if is_output(blocked_node_id):
                     return node_id
 
         #This should handle the VAELoader -> VAEDecode -> preview case
-        for node_id in node_list:
+        for node_id in priority_nodes:
             for blocked_node_id in self.blocking[node_id]:
                 for blocked_node_id1 in self.blocking[blocked_node_id]:
                     if is_output(blocked_node_id1):
                         return node_id
 
         #TODO: this function should be improved
-        return node_list[0]
+        return priority_nodes[0]
 
     def unstage_node_execution(self):
         assert self.staged_node_id is not None
@@ -428,7 +428,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
 
     input_data_all = None
    try:
-        if unique_id in pending_async_nodes:
+        if execution_list.is_barrier(unique_id):
+            execution_list.unbarrier(unique_id)
+            return (ExecutionResult.PENDING, None, None)
+        elif unique_id in pending_async_nodes:
             results = []
             for r in pending_async_nodes[unique_id]:
                 if isinstance(r, asyncio.Task):
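Taken together: when a node whose input carries a cache barrier is first staged, execute() releases the barrier and returns ExecutionResult.PENDING instead of running it. unbarrier() bumps that node's priority and only then adds the upstream links it had skipped, so the deferred subgraph enters the execution plan at the higher priority and wins the get_ready_nodes selection. The barrier node itself is then presumably re-staged on a later pass, through the same PENDING handling that async nodes already use, once its deferred inputs have been produced.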