mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-24 05:10:18 +08:00
Move commit registration
This commit is contained in:
parent
8284ea2fca
commit
fb1feed1a2
305
comfy/caching.py
305
comfy/caching.py
@ -1,305 +0,0 @@
|
|||||||
import itertools
|
|
||||||
from typing import Sequence, Mapping
|
|
||||||
|
|
||||||
from .cmd.execution import nodes
|
|
||||||
from .graph import DynamicPrompt
|
|
||||||
from .graph_utils import is_link
|
|
||||||
|
|
||||||
|
|
||||||
class CacheKeySet:
|
|
||||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
|
||||||
self.keys = {}
|
|
||||||
self.subcache_keys = {}
|
|
||||||
|
|
||||||
def add_keys(self, node_ids):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def all_node_ids(self):
|
|
||||||
return set(self.keys.keys())
|
|
||||||
|
|
||||||
def get_used_keys(self):
|
|
||||||
return self.keys.values()
|
|
||||||
|
|
||||||
def get_used_subcache_keys(self):
|
|
||||||
return self.subcache_keys.values()
|
|
||||||
|
|
||||||
def get_data_key(self, node_id):
|
|
||||||
return self.keys.get(node_id, None)
|
|
||||||
|
|
||||||
def get_subcache_key(self, node_id):
|
|
||||||
return self.subcache_keys.get(node_id, None)
|
|
||||||
|
|
||||||
|
|
||||||
class Unhashable:
|
|
||||||
def __init__(self):
|
|
||||||
self.value = float("NaN")
|
|
||||||
|
|
||||||
|
|
||||||
def to_hashable(obj):
|
|
||||||
# So that we don't infinitely recurse since frozenset and tuples
|
|
||||||
# are Sequences.
|
|
||||||
if isinstance(obj, (int, float, str, bool, type(None))):
|
|
||||||
return obj
|
|
||||||
elif isinstance(obj, Mapping):
|
|
||||||
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
|
|
||||||
elif isinstance(obj, Sequence):
|
|
||||||
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
|
|
||||||
else:
|
|
||||||
# TODO - Support other objects like tensors?
|
|
||||||
return Unhashable()
|
|
||||||
|
|
||||||
|
|
||||||
class CacheKeySetID(CacheKeySet):
|
|
||||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
|
||||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
|
||||||
self.dynprompt = dynprompt
|
|
||||||
self.add_keys(node_ids)
|
|
||||||
|
|
||||||
def add_keys(self, node_ids):
|
|
||||||
for node_id in node_ids:
|
|
||||||
if node_id in self.keys:
|
|
||||||
continue
|
|
||||||
node = self.dynprompt.get_node(node_id)
|
|
||||||
self.keys[node_id] = (node_id, node["class_type"])
|
|
||||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
|
||||||
|
|
||||||
|
|
||||||
class CacheKeySetInputSignature(CacheKeySet):
|
|
||||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
|
||||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
|
||||||
self.dynprompt = dynprompt
|
|
||||||
self.is_changed_cache = is_changed_cache
|
|
||||||
self.add_keys(node_ids)
|
|
||||||
|
|
||||||
def include_node_id_in_input(self) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def add_keys(self, node_ids):
|
|
||||||
for node_id in node_ids:
|
|
||||||
if node_id in self.keys:
|
|
||||||
continue
|
|
||||||
node = self.dynprompt.get_node(node_id)
|
|
||||||
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
|
|
||||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
|
||||||
|
|
||||||
def get_node_signature(self, dynprompt, node_id):
|
|
||||||
signature = []
|
|
||||||
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
|
|
||||||
signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
|
||||||
for ancestor_id in ancestors:
|
|
||||||
signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
|
|
||||||
return to_hashable(signature)
|
|
||||||
|
|
||||||
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
|
||||||
node = dynprompt.get_node(node_id)
|
|
||||||
class_type = node["class_type"]
|
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
|
||||||
signature = [class_type, self.is_changed_cache.get(node_id)]
|
|
||||||
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
|
|
||||||
signature.append(node_id)
|
|
||||||
inputs = node["inputs"]
|
|
||||||
for key in sorted(inputs.keys()):
|
|
||||||
if is_link(inputs[key]):
|
|
||||||
(ancestor_id, ancestor_socket) = inputs[key]
|
|
||||||
ancestor_index = ancestor_order_mapping[ancestor_id]
|
|
||||||
signature.append((key, ("ANCESTOR", ancestor_index, ancestor_socket)))
|
|
||||||
else:
|
|
||||||
signature.append((key, inputs[key]))
|
|
||||||
return signature
|
|
||||||
|
|
||||||
# This function returns a list of all ancestors of the given node. The order of the list is
|
|
||||||
# deterministic based on which specific inputs the ancestor is connected by.
|
|
||||||
def get_ordered_ancestry(self, dynprompt, node_id):
|
|
||||||
ancestors = []
|
|
||||||
order_mapping = {}
|
|
||||||
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
|
|
||||||
return ancestors, order_mapping
|
|
||||||
|
|
||||||
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
|
|
||||||
inputs = dynprompt.get_node(node_id)["inputs"]
|
|
||||||
input_keys = sorted(inputs.keys())
|
|
||||||
for key in input_keys:
|
|
||||||
if is_link(inputs[key]):
|
|
||||||
ancestor_id = inputs[key][0]
|
|
||||||
if ancestor_id not in order_mapping:
|
|
||||||
ancestors.append(ancestor_id)
|
|
||||||
order_mapping[ancestor_id] = len(ancestors) - 1
|
|
||||||
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
|
|
||||||
|
|
||||||
|
|
||||||
class BasicCache:
|
|
||||||
def __init__(self, key_class):
|
|
||||||
self.key_class = key_class
|
|
||||||
self.initialized = False
|
|
||||||
self.dynprompt: DynamicPrompt
|
|
||||||
self.cache_key_set: CacheKeySet
|
|
||||||
self.cache = {}
|
|
||||||
self.subcaches = {}
|
|
||||||
|
|
||||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
|
||||||
self.dynprompt = dynprompt
|
|
||||||
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
|
||||||
self.is_changed_cache = is_changed_cache
|
|
||||||
self.initialized = True
|
|
||||||
|
|
||||||
def all_node_ids(self):
|
|
||||||
assert self.initialized
|
|
||||||
node_ids = self.cache_key_set.all_node_ids()
|
|
||||||
for subcache in self.subcaches.values():
|
|
||||||
node_ids = node_ids.union(subcache.all_node_ids())
|
|
||||||
return node_ids
|
|
||||||
|
|
||||||
def _clean_cache(self):
|
|
||||||
preserve_keys = set(self.cache_key_set.get_used_keys())
|
|
||||||
to_remove = []
|
|
||||||
for key in self.cache:
|
|
||||||
if key not in preserve_keys:
|
|
||||||
to_remove.append(key)
|
|
||||||
for key in to_remove:
|
|
||||||
del self.cache[key]
|
|
||||||
|
|
||||||
def _clean_subcaches(self):
|
|
||||||
preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
|
|
||||||
|
|
||||||
to_remove = []
|
|
||||||
for key in self.subcaches:
|
|
||||||
if key not in preserve_subcaches:
|
|
||||||
to_remove.append(key)
|
|
||||||
for key in to_remove:
|
|
||||||
del self.subcaches[key]
|
|
||||||
|
|
||||||
def clean_unused(self):
|
|
||||||
assert self.initialized
|
|
||||||
self._clean_cache()
|
|
||||||
self._clean_subcaches()
|
|
||||||
|
|
||||||
def _set_immediate(self, node_id, value):
|
|
||||||
assert self.initialized
|
|
||||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
|
||||||
self.cache[cache_key] = value
|
|
||||||
|
|
||||||
def _get_immediate(self, node_id):
|
|
||||||
if not self.initialized:
|
|
||||||
return None
|
|
||||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
|
||||||
if cache_key in self.cache:
|
|
||||||
return self.cache[cache_key]
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _ensure_subcache(self, node_id, children_ids):
|
|
||||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
|
||||||
subcache = self.subcaches.get(subcache_key, None)
|
|
||||||
if subcache is None:
|
|
||||||
subcache = BasicCache(self.key_class)
|
|
||||||
self.subcaches[subcache_key] = subcache
|
|
||||||
subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
|
||||||
return subcache
|
|
||||||
|
|
||||||
def _get_subcache(self, node_id):
|
|
||||||
assert self.initialized
|
|
||||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
|
||||||
if subcache_key in self.subcaches:
|
|
||||||
return self.subcaches[subcache_key]
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def recursive_debug_dump(self):
|
|
||||||
result = []
|
|
||||||
for key in self.cache:
|
|
||||||
result.append({"key": key, "value": self.cache[key]})
|
|
||||||
for key in self.subcaches:
|
|
||||||
result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class HierarchicalCache(BasicCache):
|
|
||||||
def __init__(self, key_class):
|
|
||||||
super().__init__(key_class)
|
|
||||||
|
|
||||||
def _get_cache_for(self, node_id):
|
|
||||||
assert self.dynprompt is not None
|
|
||||||
parent_id = self.dynprompt.get_parent_node_id(node_id)
|
|
||||||
if parent_id is None:
|
|
||||||
return self
|
|
||||||
|
|
||||||
hierarchy = []
|
|
||||||
while parent_id is not None:
|
|
||||||
hierarchy.append(parent_id)
|
|
||||||
parent_id = self.dynprompt.get_parent_node_id(parent_id)
|
|
||||||
|
|
||||||
cache = self
|
|
||||||
for parent_id in reversed(hierarchy):
|
|
||||||
cache = cache._get_subcache(parent_id)
|
|
||||||
if cache is None:
|
|
||||||
return None
|
|
||||||
return cache
|
|
||||||
|
|
||||||
def get(self, node_id):
|
|
||||||
cache = self._get_cache_for(node_id)
|
|
||||||
if cache is None:
|
|
||||||
return None
|
|
||||||
return cache._get_immediate(node_id)
|
|
||||||
|
|
||||||
def set(self, node_id, value):
|
|
||||||
cache = self._get_cache_for(node_id)
|
|
||||||
assert cache is not None
|
|
||||||
cache._set_immediate(node_id, value)
|
|
||||||
|
|
||||||
def ensure_subcache_for(self, node_id, children_ids):
|
|
||||||
cache = self._get_cache_for(node_id)
|
|
||||||
assert cache is not None
|
|
||||||
return cache._ensure_subcache(node_id, children_ids)
|
|
||||||
|
|
||||||
|
|
||||||
class LRUCache(BasicCache):
|
|
||||||
def __init__(self, key_class, max_size=100):
|
|
||||||
super().__init__(key_class)
|
|
||||||
self.max_size = max_size
|
|
||||||
self.min_generation = 0
|
|
||||||
self.generation = 0
|
|
||||||
self.used_generation = {}
|
|
||||||
self.children = {}
|
|
||||||
|
|
||||||
def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
|
||||||
super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
|
||||||
self.generation += 1
|
|
||||||
for node_id in node_ids:
|
|
||||||
self._mark_used(node_id)
|
|
||||||
|
|
||||||
def clean_unused(self):
|
|
||||||
while len(self.cache) > self.max_size and self.min_generation < self.generation:
|
|
||||||
self.min_generation += 1
|
|
||||||
to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
|
|
||||||
for key in to_remove:
|
|
||||||
del self.cache[key]
|
|
||||||
del self.used_generation[key]
|
|
||||||
if key in self.children:
|
|
||||||
del self.children[key]
|
|
||||||
self._clean_subcaches()
|
|
||||||
|
|
||||||
def get(self, node_id):
|
|
||||||
self._mark_used(node_id)
|
|
||||||
return self._get_immediate(node_id)
|
|
||||||
|
|
||||||
def _mark_used(self, node_id):
|
|
||||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
|
||||||
if cache_key is not None:
|
|
||||||
self.used_generation[cache_key] = self.generation
|
|
||||||
|
|
||||||
def set(self, node_id, value):
|
|
||||||
self._mark_used(node_id)
|
|
||||||
return self._set_immediate(node_id, value)
|
|
||||||
|
|
||||||
def ensure_subcache_for(self, node_id, children_ids):
|
|
||||||
# Just uses subcaches for tracking 'live' nodes
|
|
||||||
super()._ensure_subcache(node_id, children_ids)
|
|
||||||
|
|
||||||
self.cache_key_set.add_keys(children_ids)
|
|
||||||
self._mark_used(node_id)
|
|
||||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
|
||||||
self.children[cache_key] = []
|
|
||||||
for child_id in children_ids:
|
|
||||||
self._mark_used(child_id)
|
|
||||||
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
|
|
||||||
return self
|
|
||||||
File diff suppressed because it is too large
Load Diff
234
comfy/graph.py
234
comfy/graph.py
@ -1,234 +0,0 @@
|
|||||||
from .cmd.execution import nodes
|
|
||||||
from .component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError
|
|
||||||
from .graph_utils import is_link
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicPrompt:
|
|
||||||
def __init__(self, original_prompt):
|
|
||||||
# The original prompt provided by the user
|
|
||||||
self.original_prompt = original_prompt
|
|
||||||
# Any extra pieces of the graph created during execution
|
|
||||||
self.ephemeral_prompt = {}
|
|
||||||
self.ephemeral_parents = {}
|
|
||||||
self.ephemeral_display = {}
|
|
||||||
|
|
||||||
def get_node(self, node_id):
|
|
||||||
if node_id in self.ephemeral_prompt:
|
|
||||||
return self.ephemeral_prompt[node_id]
|
|
||||||
if node_id in self.original_prompt:
|
|
||||||
return self.original_prompt[node_id]
|
|
||||||
raise NodeNotFoundError(f"Node {node_id} not found")
|
|
||||||
|
|
||||||
def has_node(self, node_id):
|
|
||||||
return node_id in self.original_prompt or node_id in self.ephemeral_prompt
|
|
||||||
|
|
||||||
def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
|
|
||||||
self.ephemeral_prompt[node_id] = node_info
|
|
||||||
self.ephemeral_parents[node_id] = parent_id
|
|
||||||
self.ephemeral_display[node_id] = display_id
|
|
||||||
|
|
||||||
def get_real_node_id(self, node_id):
|
|
||||||
while node_id in self.ephemeral_parents:
|
|
||||||
node_id = self.ephemeral_parents[node_id]
|
|
||||||
return node_id
|
|
||||||
|
|
||||||
def get_parent_node_id(self, node_id):
|
|
||||||
return self.ephemeral_parents.get(node_id, None)
|
|
||||||
|
|
||||||
def get_display_node_id(self, node_id):
|
|
||||||
while node_id in self.ephemeral_display:
|
|
||||||
node_id = self.ephemeral_display[node_id]
|
|
||||||
return node_id
|
|
||||||
|
|
||||||
def all_node_ids(self):
|
|
||||||
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
|
|
||||||
|
|
||||||
def get_original_prompt(self):
|
|
||||||
return self.original_prompt
|
|
||||||
|
|
||||||
|
|
||||||
def get_input_info(class_def, input_name):
|
|
||||||
valid_inputs = class_def.INPUT_TYPES()
|
|
||||||
input_info = None
|
|
||||||
input_category = None
|
|
||||||
if "required" in valid_inputs and input_name in valid_inputs["required"]:
|
|
||||||
input_category = "required"
|
|
||||||
input_info = valid_inputs["required"][input_name]
|
|
||||||
elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
|
|
||||||
input_category = "optional"
|
|
||||||
input_info = valid_inputs["optional"][input_name]
|
|
||||||
elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
|
|
||||||
input_category = "hidden"
|
|
||||||
input_info = valid_inputs["hidden"][input_name]
|
|
||||||
if input_info is None:
|
|
||||||
return None, None, None
|
|
||||||
input_type = input_info[0]
|
|
||||||
if len(input_info) > 1:
|
|
||||||
extra_info = input_info[1]
|
|
||||||
else:
|
|
||||||
extra_info = {}
|
|
||||||
return input_type, input_category, extra_info
|
|
||||||
|
|
||||||
|
|
||||||
class TopologicalSort:
|
|
||||||
def __init__(self, dynprompt):
|
|
||||||
self.dynprompt = dynprompt
|
|
||||||
self.pendingNodes = {}
|
|
||||||
self.blockCount = {} # Number of nodes this node is directly blocked by
|
|
||||||
self.blocking = {} # Which nodes are blocked by this node
|
|
||||||
|
|
||||||
def get_input_info(self, unique_id, input_name):
|
|
||||||
class_type = self.dynprompt.get_node(unique_id)["class_type"]
|
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
|
||||||
return get_input_info(class_def, input_name)
|
|
||||||
|
|
||||||
def make_input_strong_link(self, to_node_id, to_input):
|
|
||||||
inputs = self.dynprompt.get_node(to_node_id)["inputs"]
|
|
||||||
if to_input not in inputs:
|
|
||||||
raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all")
|
|
||||||
value = inputs[to_input]
|
|
||||||
if not is_link(value):
|
|
||||||
raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant")
|
|
||||||
from_node_id, from_socket = value
|
|
||||||
self.add_strong_link(from_node_id, from_socket, to_node_id)
|
|
||||||
|
|
||||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
|
||||||
self.add_node(from_node_id)
|
|
||||||
if to_node_id not in self.blocking[from_node_id]:
|
|
||||||
self.blocking[from_node_id][to_node_id] = {}
|
|
||||||
self.blockCount[to_node_id] += 1
|
|
||||||
self.blocking[from_node_id][to_node_id][from_socket] = True
|
|
||||||
|
|
||||||
def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
|
|
||||||
if unique_id in self.pendingNodes:
|
|
||||||
return
|
|
||||||
self.pendingNodes[unique_id] = True
|
|
||||||
self.blockCount[unique_id] = 0
|
|
||||||
self.blocking[unique_id] = {}
|
|
||||||
|
|
||||||
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
|
||||||
for input_name in inputs:
|
|
||||||
value = inputs[input_name]
|
|
||||||
if is_link(value):
|
|
||||||
from_node_id, from_socket = value
|
|
||||||
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
|
||||||
continue
|
|
||||||
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
|
||||||
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
|
||||||
if include_lazy or not is_lazy:
|
|
||||||
self.add_strong_link(from_node_id, from_socket, unique_id)
|
|
||||||
|
|
||||||
def get_ready_nodes(self):
|
|
||||||
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
|
||||||
|
|
||||||
def pop_node(self, unique_id):
|
|
||||||
del self.pendingNodes[unique_id]
|
|
||||||
for blocked_node_id in self.blocking[unique_id]:
|
|
||||||
self.blockCount[blocked_node_id] -= 1
|
|
||||||
del self.blocking[unique_id]
|
|
||||||
|
|
||||||
def is_empty(self):
|
|
||||||
return len(self.pendingNodes) == 0
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionList(TopologicalSort):
|
|
||||||
"""
|
|
||||||
ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
|
|
||||||
it can still be returned to the graph after having further dependencies added.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, dynprompt, output_cache):
|
|
||||||
super().__init__(dynprompt)
|
|
||||||
self.output_cache = output_cache
|
|
||||||
self.staged_node_id = None
|
|
||||||
|
|
||||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
|
||||||
if self.output_cache.get(from_node_id) is not None:
|
|
||||||
# Nothing to do
|
|
||||||
return
|
|
||||||
super().add_strong_link(from_node_id, from_socket, to_node_id)
|
|
||||||
|
|
||||||
def stage_node_execution(self):
|
|
||||||
assert self.staged_node_id is None
|
|
||||||
if self.is_empty():
|
|
||||||
return None, None, None
|
|
||||||
available = self.get_ready_nodes()
|
|
||||||
if len(available) == 0:
|
|
||||||
cycled_nodes = self.get_nodes_in_cycle()
|
|
||||||
# Because cycles composed entirely of static nodes are caught during initial validation,
|
|
||||||
# we will 'blame' the first node in the cycle that is not a static node.
|
|
||||||
blamed_node = cycled_nodes[0]
|
|
||||||
for node_id in cycled_nodes:
|
|
||||||
display_node_id = self.dynprompt.get_display_node_id(node_id)
|
|
||||||
if display_node_id != node_id:
|
|
||||||
blamed_node = display_node_id
|
|
||||||
break
|
|
||||||
ex = DependencyCycleError("Dependency cycle detected")
|
|
||||||
error_details = {
|
|
||||||
"node_id": blamed_node,
|
|
||||||
"exception_message": str(ex),
|
|
||||||
"exception_type": "graph.DependencyCycleError",
|
|
||||||
"traceback": [],
|
|
||||||
"current_inputs": []
|
|
||||||
}
|
|
||||||
return None, error_details, ex
|
|
||||||
next_node = available[0]
|
|
||||||
# If an output node is available, do that first.
|
|
||||||
# Technically this has no effect on the overall length of execution, but it feels better as a user
|
|
||||||
# for a PreviewImage to display a result as soon as it can
|
|
||||||
# Some other heuristics could probably be used here to improve the UX further.
|
|
||||||
for node_id in available:
|
|
||||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
|
||||||
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
|
||||||
next_node = node_id
|
|
||||||
break
|
|
||||||
self.staged_node_id = next_node
|
|
||||||
return self.staged_node_id, None, None
|
|
||||||
|
|
||||||
def unstage_node_execution(self):
|
|
||||||
assert self.staged_node_id is not None
|
|
||||||
self.staged_node_id = None
|
|
||||||
|
|
||||||
def complete_node_execution(self):
|
|
||||||
node_id = self.staged_node_id
|
|
||||||
self.pop_node(node_id)
|
|
||||||
self.staged_node_id = None
|
|
||||||
|
|
||||||
def get_nodes_in_cycle(self):
|
|
||||||
# We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle.
|
|
||||||
# We're skipping some of the performance optimizations from the original TopologicalSort to keep
|
|
||||||
# the code simple (and because having a cycle in the first place is a catastrophic error)
|
|
||||||
blocked_by = {node_id: {} for node_id in self.pendingNodes}
|
|
||||||
for from_node_id in self.blocking:
|
|
||||||
for to_node_id in self.blocking[from_node_id]:
|
|
||||||
if True in self.blocking[from_node_id][to_node_id].values():
|
|
||||||
blocked_by[to_node_id][from_node_id] = True
|
|
||||||
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
|
|
||||||
while len(to_remove) > 0:
|
|
||||||
for node_id in to_remove:
|
|
||||||
for to_node_id in blocked_by:
|
|
||||||
if node_id in blocked_by[to_node_id]:
|
|
||||||
del blocked_by[to_node_id][node_id]
|
|
||||||
del blocked_by[node_id]
|
|
||||||
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
|
|
||||||
return list(blocked_by.keys())
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionBlocker:
|
|
||||||
"""
|
|
||||||
Return this from a node and any users will be blocked with the given error message.
|
|
||||||
If the message is None, execution will be blocked silently instead.
|
|
||||||
Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
|
|
||||||
possible, a lazy input will be more efficient and have a better user experience.
|
|
||||||
This functionality is useful in two cases:
|
|
||||||
1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
|
|
||||||
like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
|
|
||||||
lazy evaluation to let it conditionally disable itself.)
|
|
||||||
2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
|
|
||||||
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with
|
|
||||||
different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, message):
|
|
||||||
self.message = message
|
|
||||||
@ -1,143 +0,0 @@
|
|||||||
def is_link(obj):
|
|
||||||
if not isinstance(obj, list):
|
|
||||||
return False
|
|
||||||
if len(obj) != 2:
|
|
||||||
return False
|
|
||||||
if not isinstance(obj[0], str):
|
|
||||||
return False
|
|
||||||
if not isinstance(obj[1], int) and not isinstance(obj[1], float):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class GraphBuilder:
|
|
||||||
"""
|
|
||||||
The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
|
|
||||||
"""
|
|
||||||
_default_prefix_root = ""
|
|
||||||
_default_prefix_call_index = 0
|
|
||||||
_default_prefix_graph_index = 0
|
|
||||||
|
|
||||||
def __init__(self, prefix=None):
|
|
||||||
if prefix is None:
|
|
||||||
self.prefix = GraphBuilder.alloc_prefix()
|
|
||||||
else:
|
|
||||||
self.prefix = prefix
|
|
||||||
self.nodes = {}
|
|
||||||
self.id_gen = 1
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def set_default_prefix(cls, prefix_root, call_index, graph_index=0):
|
|
||||||
cls._default_prefix_root = prefix_root
|
|
||||||
cls._default_prefix_call_index = call_index
|
|
||||||
cls._default_prefix_graph_index = graph_index
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def alloc_prefix(cls, root=None, call_index=None, graph_index=None):
|
|
||||||
if root is None:
|
|
||||||
root = GraphBuilder._default_prefix_root
|
|
||||||
if call_index is None:
|
|
||||||
call_index = GraphBuilder._default_prefix_call_index
|
|
||||||
if graph_index is None:
|
|
||||||
graph_index = GraphBuilder._default_prefix_graph_index
|
|
||||||
result = f"{root}.{call_index}.{graph_index}."
|
|
||||||
GraphBuilder._default_prefix_graph_index += 1
|
|
||||||
return result
|
|
||||||
|
|
||||||
def node(self, class_type, id=None, **kwargs):
|
|
||||||
if id is None:
|
|
||||||
id = str(self.id_gen)
|
|
||||||
self.id_gen += 1
|
|
||||||
id = self.prefix + id
|
|
||||||
if id in self.nodes:
|
|
||||||
return self.nodes[id]
|
|
||||||
|
|
||||||
node = Node(id, class_type, kwargs)
|
|
||||||
self.nodes[id] = node
|
|
||||||
return node
|
|
||||||
|
|
||||||
def lookup_node(self, id):
|
|
||||||
id = self.prefix + id
|
|
||||||
return self.nodes.get(id)
|
|
||||||
|
|
||||||
def finalize(self):
|
|
||||||
output = {}
|
|
||||||
for node_id, node in self.nodes.items():
|
|
||||||
output[node_id] = node.serialize()
|
|
||||||
return output
|
|
||||||
|
|
||||||
def replace_node_output(self, node_id, index, new_value):
|
|
||||||
node_id = self.prefix + node_id
|
|
||||||
to_remove = []
|
|
||||||
for node in self.nodes.values():
|
|
||||||
for key, value in node.inputs.items():
|
|
||||||
if is_link(value) and value[0] == node_id and value[1] == index:
|
|
||||||
if new_value is None:
|
|
||||||
to_remove.append((node, key))
|
|
||||||
else:
|
|
||||||
node.inputs[key] = new_value
|
|
||||||
for node, key in to_remove:
|
|
||||||
del node.inputs[key]
|
|
||||||
|
|
||||||
def remove_node(self, id):
|
|
||||||
id = self.prefix + id
|
|
||||||
del self.nodes[id]
|
|
||||||
|
|
||||||
|
|
||||||
class Node:
|
|
||||||
def __init__(self, id, class_type, inputs):
|
|
||||||
self.id = id
|
|
||||||
self.class_type = class_type
|
|
||||||
self.inputs = inputs
|
|
||||||
self.override_display_id = None
|
|
||||||
|
|
||||||
def out(self, index):
|
|
||||||
return [self.id, index]
|
|
||||||
|
|
||||||
def set_input(self, key, value):
|
|
||||||
if value is None:
|
|
||||||
if key in self.inputs:
|
|
||||||
del self.inputs[key]
|
|
||||||
else:
|
|
||||||
self.inputs[key] = value
|
|
||||||
|
|
||||||
def get_input(self, key):
|
|
||||||
return self.inputs.get(key)
|
|
||||||
|
|
||||||
def set_override_display_id(self, override_display_id):
|
|
||||||
self.override_display_id = override_display_id
|
|
||||||
|
|
||||||
def serialize(self):
|
|
||||||
serialized = {
|
|
||||||
"class_type": self.class_type,
|
|
||||||
"inputs": self.inputs
|
|
||||||
}
|
|
||||||
if self.override_display_id is not None:
|
|
||||||
serialized["override_display_id"] = self.override_display_id
|
|
||||||
return serialized
|
|
||||||
|
|
||||||
|
|
||||||
def add_graph_prefix(graph, outputs, prefix):
|
|
||||||
# Change the node IDs and any internal links
|
|
||||||
new_graph = {}
|
|
||||||
for node_id, node_info in graph.items():
|
|
||||||
# Make sure the added nodes have unique IDs
|
|
||||||
new_node_id = prefix + node_id
|
|
||||||
new_node = {"class_type": node_info["class_type"], "inputs": {}}
|
|
||||||
for input_name, input_value in node_info.get("inputs", {}).items():
|
|
||||||
if is_link(input_value):
|
|
||||||
new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]]
|
|
||||||
else:
|
|
||||||
new_node["inputs"][input_name] = input_value
|
|
||||||
new_graph[new_node_id] = new_node
|
|
||||||
|
|
||||||
# Change the node IDs in the outputs
|
|
||||||
new_outputs = []
|
|
||||||
for n in range(len(outputs)):
|
|
||||||
output = outputs[n]
|
|
||||||
if is_link(output):
|
|
||||||
new_outputs.append([prefix + output[0], output[1]])
|
|
||||||
else:
|
|
||||||
new_outputs.append(output)
|
|
||||||
|
|
||||||
return new_graph, tuple(new_outputs)
|
|
||||||
Loading…
Reference in New Issue
Block a user