Update with our changes

This commit is contained in:
doctorpangloss 2024-08-16 14:31:26 -07:00
parent f04b582744
commit 24a9eb2600
4 changed files with 288 additions and 153 deletions

View File

@ -1,10 +1,10 @@
import itertools import itertools
from typing import Sequence, Mapping from typing import Sequence, Mapping
from comfy_execution.graph import DynamicPrompt
import nodes from .cmd.execution import nodes
from .graph import DynamicPrompt
from .graph_utils import is_link
from comfy_execution.graph_utils import is_link
class CacheKeySet: class CacheKeySet:
def __init__(self, dynprompt, node_ids, is_changed_cache): def __init__(self, dynprompt, node_ids, is_changed_cache):
@ -29,10 +29,12 @@ class CacheKeySet:
def get_subcache_key(self, node_id): def get_subcache_key(self, node_id):
return self.subcache_keys.get(node_id, None) return self.subcache_keys.get(node_id, None)
class Unhashable: class Unhashable:
def __init__(self): def __init__(self):
self.value = float("NaN") self.value = float("NaN")
def to_hashable(obj): def to_hashable(obj):
# So that we don't infinitely recurse since frozenset and tuples # So that we don't infinitely recurse since frozenset and tuples
# are Sequences. # are Sequences.
@ -46,6 +48,7 @@ def to_hashable(obj):
# TODO - Support other objects like tensors? # TODO - Support other objects like tensors?
return Unhashable() return Unhashable()
class CacheKeySetID(CacheKeySet): class CacheKeySetID(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache): def __init__(self, dynprompt, node_ids, is_changed_cache):
super().__init__(dynprompt, node_ids, is_changed_cache) super().__init__(dynprompt, node_ids, is_changed_cache)
@ -60,6 +63,7 @@ class CacheKeySetID(CacheKeySet):
self.keys[node_id] = (node_id, node["class_type"]) self.keys[node_id] = (node_id, node["class_type"])
self.subcache_keys[node_id] = (node_id, node["class_type"]) self.subcache_keys[node_id] = (node_id, node["class_type"])
class CacheKeySetInputSignature(CacheKeySet): class CacheKeySetInputSignature(CacheKeySet):
def __init__(self, dynprompt, node_ids, is_changed_cache): def __init__(self, dynprompt, node_ids, is_changed_cache):
super().__init__(dynprompt, node_ids, is_changed_cache) super().__init__(dynprompt, node_ids, is_changed_cache)
@ -98,7 +102,7 @@ class CacheKeySetInputSignature(CacheKeySet):
if is_link(inputs[key]): if is_link(inputs[key]):
(ancestor_id, ancestor_socket) = inputs[key] (ancestor_id, ancestor_socket) = inputs[key]
ancestor_index = ancestor_order_mapping[ancestor_id] ancestor_index = ancestor_order_mapping[ancestor_id]
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket))) signature.append((key, ("ANCESTOR", ancestor_index, ancestor_socket)))
else: else:
signature.append((key, inputs[key])) signature.append((key, inputs[key]))
return signature return signature
@ -122,6 +126,7 @@ class CacheKeySetInputSignature(CacheKeySet):
order_mapping[ancestor_id] = len(ancestors) - 1 order_mapping[ancestor_id] = len(ancestors) - 1
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping) self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
class BasicCache: class BasicCache:
def __init__(self, key_class): def __init__(self, key_class):
self.key_class = key_class self.key_class = key_class
@ -207,6 +212,7 @@ class BasicCache:
result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()}) result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
return result return result
class HierarchicalCache(BasicCache): class HierarchicalCache(BasicCache):
def __init__(self, key_class): def __init__(self, key_class):
super().__init__(key_class) super().__init__(key_class)
@ -245,6 +251,7 @@ class HierarchicalCache(BasicCache):
assert cache is not None assert cache is not None
return cache._ensure_subcache(node_id, children_ids) return cache._ensure_subcache(node_id, children_ids)
class LRUCache(BasicCache): class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100): def __init__(self, key_class, max_size=100):
super().__init__(key_class) super().__init__(key_class)
@ -296,4 +303,3 @@ class LRUCache(BasicCache):
self._mark_used(child_id) self._mark_used(child_id)
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self return self

View File

@ -1,30 +1,44 @@
import sys from __future__ import annotations
import copy import copy
import logging
import threading
import heapq import heapq
import inspect
import logging
import sys
import threading
import time import time
import traceback import traceback
from enum import Enum import typing
import inspect from os import PathLike
from typing import List, Literal, NamedTuple, Optional from typing import List, Optional, Tuple
import lazy_object_proxy
import torch import torch
import nodes from opentelemetry.trace import get_current_span, StatusCode, Status
import comfy.model_management from .main_pre import tracer
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker from .. import interruption
from comfy_execution.graph_utils import is_link, GraphBuilder from .. import model_management
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from comfy.cli_args import args from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted, ExecutionResult, DuplicateNodeError, \
HistoryResultDict
from ..component_model.files import canonicalize_path
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
from ..execution_context import new_execution_context, ExecutionContext
from ..nodes.package import import_all_nodes_in_workspace
from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions
class ExecutionResult(Enum): # ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
SUCCESS = 0 # various functions that are declared here. It should have been a context in the first place.
FAILURE = 1 nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
PENDING = 2
# order matters
from ..graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from ..graph_utils import is_link, GraphBuilder
from ..caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
class DuplicateNodeError(Exception):
pass
class IsChangedCache: class IsChangedCache:
def __init__(self, dynprompt, outputs_cache): def __init__(self, dynprompt, outputs_cache):
@ -49,19 +63,19 @@ class IsChangedCache:
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache) input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
try: try:
is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED") is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
except Exception as e: except:
logging.warning("WARNING: {}".format(e))
node["is_changed"] = float("NaN") node["is_changed"] = float("NaN")
finally: finally:
self.is_changed[node_id] = node["is_changed"] self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id] return self.is_changed[node_id]
class CacheSet: class CacheSet:
def __init__(self, lru_size=None): def __init__(self, lru_size=None):
if lru_size is None or lru_size == 0: if lru_size is None or lru_size == 0:
self.init_classic_cache() self.init_classic_cache()
else: else:
self.init_lru_cache(lru_size) self.init_lru_cache(lru_size)
self.all = [self.outputs, self.ui, self.objects] self.all = [self.outputs, self.ui, self.objects]
@ -86,22 +100,29 @@ class CacheSet:
} }
return result return result
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data=None):
if extra_data is None:
extra_data = {}
if outputs is None:
outputs = {}
valid_inputs = class_def.INPUT_TYPES() valid_inputs = class_def.INPUT_TYPES()
input_data_all = {} input_data_all = {}
missing_keys = {} missing_keys = {}
for x in inputs: for x in inputs:
input_data = inputs[x] input_data = inputs[x]
input_type, input_category, input_info = get_input_info(class_def, x) input_type, input_category, input_info = get_input_info(class_def, x)
def mark_missing(): def mark_missing():
missing_keys[x] = True missing_keys[x] = True
input_data_all[x] = (None,) input_data_all[x] = (None,)
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)): if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
input_unique_id = input_data[0] input_unique_id = input_data[0]
output_index = input_data[1] output_index = input_data[1]
if outputs is None: if outputs is None:
mark_missing() mark_missing()
continue # This might be a lazily-evaluated input continue # This might be a lazily-evaluated input
cached_output = outputs.get(input_unique_id) cached_output = outputs.get(input_unique_id)
if cached_output is None: if cached_output is None:
mark_missing() mark_missing()
@ -114,6 +135,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
elif input_category is not None: elif input_category is not None:
input_data_all[x] = [input_data] input_data_all[x] = [input_data]
# todo: this should be retrieved from the execution context
if "hidden" in valid_inputs: if "hidden" in valid_inputs:
h = valid_inputs["hidden"] h = valid_inputs["hidden"]
for x in h: for x in h:
@ -127,9 +149,35 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
input_data_all[x] = [unique_id] input_data_all[x] = [unique_id]
return input_data_all, missing_keys return input_data_all, missing_keys
map_node_over_list = None #Don't hook this please
def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): @tracer.start_as_current_span("Execute Node")
def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func: str, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
span = get_current_span()
class_type = obj.__class__.__name__
span.set_attribute("class_type", class_type)
if input_data_all is not None:
for kwarg_name, kwarg_value in input_data_all.items():
if isinstance(kwarg_value, str) or isinstance(kwarg_value, bool) or isinstance(kwarg_value, int) or isinstance(kwarg_value, float):
span.set_attribute(f"input_data_all.{kwarg_name}", kwarg_value)
else:
try:
items_to_display = []
if hasattr(kwarg_value, "shape"):
# if the object has a shape attribute (likely a NumPy array or similar), get up to the first ten elements
flat_values = kwarg_value.flatten() if hasattr(kwarg_value, "flatten") else kwarg_value
items_to_display = [flat_values[i] for i in range(min(10, flat_values.size))]
elif hasattr(kwarg_value, "__getitem__") and hasattr(kwarg_value, "__len__"):
# If the object is indexable and has a length, get the first ten items
items_to_display = [kwarg_value[i] for i in range(min(10, len(kwarg_value)))]
filtered_items = [
item for item in items_to_display if isinstance(item, (str, bool, int, float))
]
if filtered_items:
span.set_attribute(f"input_data_all.{kwarg_name}", filtered_items)
except TypeError:
pass
# check if node wants the lists # check if node wants the lists
input_is_list = getattr(obj, "INPUT_IS_LIST", False) input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@ -137,15 +185,16 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
max_len_input = 0 max_len_input = 0
else: else:
max_len_input = max(len(x) for x in input_data_all.values()) max_len_input = max(len(x) for x in input_data_all.values())
# get a slice of inputs, repeat last input when list isn't long enough # get a slice of inputs, repeat last input when list isn't long enough
def slice_dict(d, i): def slice_dict(d, i):
return {k: v[i if len(v) > i else -1] for k, v in d.items()} return {k: v[i if len(v) > i else -1] for k, v in d.items()}
results = [] results = []
def process_inputs(inputs, index=None): def process_inputs(inputs, index=None):
if allow_interrupt: if allow_interrupt:
nodes.before_node_execution() interruption.throw_exception_if_processing_interrupted()
execution_block = None execution_block = None
for k, v in inputs.items(): for k, v in inputs.items():
if isinstance(v, ExecutionBlocker): if isinstance(v, ExecutionBlocker):
@ -162,12 +211,13 @@ def _map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execut
process_inputs(input_data_all, 0) process_inputs(input_data_all, 0)
elif max_len_input == 0: elif max_len_input == 0:
process_inputs({}) process_inputs({})
else: else:
for i in range(max_len_input): for i in range(max_len_input):
input_dict = slice_dict(input_data_all, i) input_dict = slice_dict(input_data_all, i)
process_inputs(input_dict, i) process_inputs(input_dict, i)
return results return results
def merge_result_data(results, obj): def merge_result_data(results, obj):
# check which outputs need concatenating # check which outputs need concatenating
output = [] output = []
@ -183,12 +233,12 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results]) output.append([o[i] for o in results])
return output return output
def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None): def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
results = [] results = []
uis = [] uis = []
subgraph_results = [] subgraph_results = []
return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
has_subgraph = False has_subgraph = False
for i in range(len(return_values)): for i in range(len(return_values)):
r = return_values[i] r = return_values[i]
@ -214,19 +264,20 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
r = tuple([r] * len(obj.RETURN_TYPES)) r = tuple([r] * len(obj.RETURN_TYPES))
results.append(r) results.append(r)
subgraph_results.append((None, r)) subgraph_results.append((None, r))
if has_subgraph: if has_subgraph:
output = subgraph_results output = subgraph_results
elif len(results) > 0: elif len(results) > 0:
output = merge_result_data(results, obj) output = merge_result_data(results, obj)
else: else:
output = [] output = []
ui = dict() ui = dict()
if len(uis) > 0: if len(uis) > 0:
ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
return output, ui, has_subgraph return output, ui, has_subgraph
def format_value(x):
def format_value(x) -> FormattedValue:
if x is None: if x is None:
return None return None
elif isinstance(x, (int, float, bool, str)): elif isinstance(x, (int, float, bool, str)):
@ -234,6 +285,7 @@ def format_value(x):
else: else:
return str(x) return str(x)
def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results): def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
unique_id = current_item unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id) real_node_id = dynprompt.get_real_node_id(unique_id)
@ -245,8 +297,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
if caches.outputs.get(unique_id) is not None: if caches.outputs.get(unique_id) is not None:
if server.client_id is not None: if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {} cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) server.send_sync("executed", {"node": unique_id, "display_node": display_node_id, "output": cached_output.get("output", None), "prompt_id": prompt_id}, server.client_id)
return (ExecutionResult.SUCCESS, None, None) return RecursiveExecutionTuple(ExecutionResult.SUCCESS, None, None)
input_data_all = None input_data_all = None
try: try:
@ -275,7 +327,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
if server.client_id is not None: if server.client_id is not None:
server.last_node_id = display_node_id server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) server.send_sync("executing", {"node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id}, server.client_id)
obj = caches.objects.get(unique_id) obj = caches.objects.get(unique_id)
if obj is None: if obj is None:
@ -283,10 +335,10 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
caches.objects.set(unique_id, obj) caches.objects.set(unique_id, obj)
if hasattr(obj, "check_lazy_status"): if hasattr(obj, "check_lazy_status"):
required_inputs = _map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True) required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = set(sum([r for r in required_inputs if isinstance(r, list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and ( required_inputs = [x for x in required_inputs if isinstance(x, str) and (
x not in input_data_all or x in missing_keys x not in input_data_all or x in missing_keys
)] )]
if len(required_inputs) > 0: if len(required_inputs) > 0:
for i in required_inputs: for i in required_inputs:
@ -311,8 +363,10 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
return ExecutionBlocker(None) return ExecutionBlocker(None)
else: else:
return block return block
def pre_execute_cb(call_index): def pre_execute_cb(call_index):
GraphBuilder.set_default_prefix(unique_id, call_index, 0) GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
if len(output_ui) > 0: if len(output_ui) > 0:
caches.ui.set(unique_id, { caches.ui.set(unique_id, {
@ -325,7 +379,8 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
"output": output_ui "output": output_ui
}) })
if server.client_id is not None: if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) server.send_sync("executed", {"node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id},
server.client_id)
if has_subgraph: if has_subgraph:
cached_outputs = [] cached_outputs = []
new_node_ids = [] new_node_ids = []
@ -364,15 +419,15 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
pending_subgraph_results[unique_id] = cached_outputs pending_subgraph_results[unique_id] = cached_outputs
return (ExecutionResult.PENDING, None, None) return (ExecutionResult.PENDING, None, None)
caches.outputs.set(unique_id, output_data) caches.outputs.set(unique_id, output_data)
except comfy.model_management.InterruptProcessingException as iex: except interruption.InterruptProcessingException as iex:
logging.info("Processing interrupted") logging.info("Processing interrupted")
# skip formatting inputs/outputs # skip formatting inputs/outputs
error_details = { error_details: RecursiveExecutionErrorDetailsInterrupted = {
"node_id": real_node_id, "node_id": real_node_id,
} }
return (ExecutionResult.FAILURE, error_details, iex) return RecursiveExecutionTuple(ExecutionResult.FAILURE, error_details, iex)
except Exception as ex: except Exception as ex:
typ, _, tb = sys.exc_info() typ, _, tb = sys.exc_info()
exception_type = full_type_name(typ) exception_type = full_type_name(typ)
@ -382,40 +437,46 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
for name, inputs in input_data_all.items(): for name, inputs in input_data_all.items():
input_data_formatted[name] = [format_value(x) for x in inputs] input_data_formatted[name] = [format_value(x) for x in inputs]
logging.error(f"!!! Exception during processing !!! {ex}") logging.error("An error occurred while executing a workflow", exc_info=ex)
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
error_details = { error_details: RecursiveExecutionErrorDetails = {
"node_id": real_node_id, "node_id": real_node_id,
"exception_message": str(ex), "exception_message": str(ex),
"exception_type": exception_type, "exception_type": exception_type,
"traceback": traceback.format_tb(tb), "traceback": traceback.format_tb(tb),
"current_inputs": input_data_formatted "current_inputs": input_data_formatted
} }
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
return (ExecutionResult.FAILURE, error_details, ex) if isinstance(ex, model_management.OOM_EXCEPTION):
logging.error("Got an OOM, unloading all loaded models.")
model_management.unload_all_models()
return RecursiveExecutionTuple(ExecutionResult.FAILURE, error_details, ex)
executed.add(unique_id) executed.add(unique_id)
return (ExecutionResult.SUCCESS, None, None) return ExecutionResult.SUCCESS, None, None
class PromptExecutor: class PromptExecutor:
def __init__(self, server, lru_size=None): def __init__(self, server: ExecutorToClientProgress, lru_size=None):
self.success = None
self.lru_size = lru_size self.lru_size = lru_size
self.server = server self.server = server
self.raise_exceptions = False
self.reset() self.reset()
self.history_result: HistoryResultDict | None = None
def reset(self): def reset(self):
self.success = True
self.caches = CacheSet(self.lru_size) self.caches = CacheSet(self.lru_size)
self.status_messages = [] self.status_messages = []
self.success = True
def add_message(self, event, data: dict, broadcast: bool): def add_message(self, event, data: dict, broadcast: bool):
data = { data = {
**data, **data,
# todo: use a real time library
"timestamp": int(time.time() * 1000), "timestamp": int(time.time() * 1000),
} }
self.status_messages.append((event, data)) self.status_messages.append((event, data))
@ -423,12 +484,16 @@ class PromptExecutor:
self.server.send_sync(event, data, self.server.client_id) self.server.send_sync(event, data, self.server.client_id)
def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
current_span = get_current_span()
current_span.set_status(Status(StatusCode.ERROR))
current_span.record_exception(ex)
node_id = error["node_id"] node_id = error["node_id"]
class_type = prompt[node_id]["class_type"] class_type = prompt[node_id]["class_type"]
# First, send back the status to the frontend depending # First, send back the status to the frontend depending
# on the exception type # on the exception type
if isinstance(ex, comfy.model_management.InterruptProcessingException): if isinstance(ex, interruption.InterruptProcessingException):
mes = { mes = {
"prompt_id": prompt_id, "prompt_id": prompt_id,
"node_id": node_id, "node_id": node_id,
@ -449,9 +514,20 @@ class PromptExecutor:
"current_outputs": list(current_outputs), "current_outputs": list(current_outputs),
} }
self.add_message("execution_error", mes, broadcast=False) self.add_message("execution_error", mes, broadcast=False)
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): if ex is not None and self.raise_exceptions:
nodes.interrupt_processing(False) raise ex
def execute(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
with new_execution_context(ExecutionContext(self.server)):
self._execute_inner(prompt, prompt_id, extra_data, execute_outputs)
def _execute_inner(self, prompt, prompt_id, extra_data=None, execute_outputs: List[str] = None):
if execute_outputs is None:
execute_outputs = []
if extra_data is None:
extra_data = {}
interruption.interrupt_current_processing(False)
if "client_id" in extra_data: if "client_id" in extra_data:
self.server.client_id = extra_data["client_id"] self.server.client_id = extra_data["client_id"]
@ -459,7 +535,7 @@ class PromptExecutor:
self.server.client_id = None self.server.client_id = None
self.status_messages = [] self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode(): with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt) dynamic_prompt = DynamicPrompt(prompt)
@ -473,10 +549,10 @@ class PromptExecutor:
if self.caches.outputs.get(node_id) is not None: if self.caches.outputs.get(node_id) is not None:
cached_nodes.append(node_id) cached_nodes.append(node_id)
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True) model_management.cleanup_models(keep_clone_weights_loaded=True)
self.add_message("execution_cached", self.add_message("execution_cached",
{ "nodes": cached_nodes, "prompt_id": prompt_id}, { "nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False) broadcast=False)
pending_subgraph_results = {} pending_subgraph_results = {}
executed = set() executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
@ -496,7 +572,7 @@ class PromptExecutor:
break break
elif result == ExecutionResult.PENDING: elif result == ExecutionResult.PENDING:
execution_list.unstage_node_execution() execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS: else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution() execution_list.complete_node_execution()
else: else:
# Only execute when the while-loop ends without break # Only execute when the while-loop ends without break
@ -515,12 +591,17 @@ class PromptExecutor:
"meta": meta_outputs, "meta": meta_outputs,
} }
self.server.last_node_id = None self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY: if model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models() model_management.unload_all_models()
@property
def outputs_ui(self) -> dict | None:
return self.history_result["outputs"] if self.history_result is not None else None
def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTuple]) -> ValidateInputsTuple:
def validate_inputs(prompt, item, validated): # todo: this should check if LoadImage / LoadImageMask paths exist
# todo: or, nodes should provide a way to validate their values
unique_id = item unique_id = item
if unique_id in validated: if unique_id in validated:
return validated[unique_id] return validated[unique_id]
@ -530,11 +611,16 @@ def validate_inputs(prompt, item, validated):
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
class_inputs = obj_class.INPUT_TYPES() class_inputs = obj_class.INPUT_TYPES()
valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) valid_inputs = set(class_inputs.get('required', {})).union(set(class_inputs.get('optional', {})))
error: ValidationErrorDict
errors = [] errors = []
valid = True valid = True
# todo: investigate if these are at the right indent level
info: Optional[InputTypeSpec] = None
val = None
validate_function_inputs = [] validate_function_inputs = []
validate_has_kwargs = False validate_has_kwargs = False
if hasattr(obj_class, "VALIDATE_INPUTS"): if hasattr(obj_class, "VALIDATE_INPUTS"):
@ -560,7 +646,7 @@ def validate_inputs(prompt, item, validated):
continue continue
val = inputs[x] val = inputs[x]
info = (type_input, extra_info) info: InputTypeSpec = (type_input, extra_info)
if isinstance(val, list): if isinstance(val, list):
if len(val) != 2: if len(val) != 2:
error = { error = {
@ -581,7 +667,8 @@ def validate_inputs(prompt, item, validated):
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
received_type = r[val[1]] received_type = r[val[1]]
received_types[x] = received_type received_types[x] = received_type
if 'input_types' not in validate_function_inputs and received_type != type_input: any_enum = received_type == [] and (isinstance(type_input, list) or isinstance(type_input, tuple))
if 'input_types' not in validate_function_inputs and received_type != type_input and not any_enum:
details = f"{x}, {received_type} != {type_input}" details = f"{x}, {received_type} != {type_input}"
error = { error = {
"type": "return_type_mismatch", "type": "return_type_mismatch",
@ -597,8 +684,8 @@ def validate_inputs(prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
try: try:
r = validate_inputs(prompt, o_id, validated) r2 = validate_inputs(prompt, o_id, validated)
if r[0] is False: if r2[0] is False:
# `r` will be set in `validated[o_id]` already # `r` will be set in `validated[o_id]` already
valid = False valid = False
continue continue
@ -619,7 +706,7 @@ def validate_inputs(prompt, item, validated):
"linked_node": val "linked_node": val
} }
}] }]
validated[o_id] = (False, reasons, o_id) validated[o_id] = ValidateInputsTuple(False, reasons, o_id)
continue continue
else: else:
try: try:
@ -650,11 +737,12 @@ def validate_inputs(prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
if x not in validate_function_inputs and not validate_has_kwargs: if x not in validate_function_inputs:
if "min" in extra_info and val < extra_info["min"]: has_min_max: IntSpecOptions | FloatSpecOptions = info[1]
if "min" in has_min_max and val < has_min_max["min"]:
error = { error = {
"type": "value_smaller_than_min", "type": "value_smaller_than_min",
"message": "Value {} smaller than min of {}".format(val, extra_info["min"]), "message": "Value {} smaller than min of {}".format(val, has_min_max["min"]),
"details": f"{x}", "details": f"{x}",
"extra_info": { "extra_info": {
"input_name": x, "input_name": x,
@ -664,10 +752,10 @@ def validate_inputs(prompt, item, validated):
} }
errors.append(error) errors.append(error)
continue continue
if "max" in extra_info and val > extra_info["max"]: if "max" in has_min_max and val > has_min_max["max"]:
error = { error = {
"type": "value_bigger_than_max", "type": "value_bigger_than_max",
"message": "Value {} bigger than max of {}".format(val, extra_info["max"]), "message": "Value {} bigger than max of {}".format(val, has_min_max["max"]),
"details": f"{x}", "details": f"{x}",
"extra_info": { "extra_info": {
"input_name": x, "input_name": x,
@ -679,6 +767,11 @@ def validate_inputs(prompt, item, validated):
continue continue
if isinstance(type_input, list): if isinstance(type_input, list):
if "\\" in val:
# try to normalize paths for comparison purposes
val = canonicalize_path(val)
if all(isinstance(item, (str, PathLike)) for item in type_input):
type_input = [canonicalize_path(item) for item in type_input]
if val not in type_input: if val not in type_input:
input_config = info input_config = info
list_info = "" list_info = ""
@ -713,8 +806,8 @@ def validate_inputs(prompt, item, validated):
if 'input_types' in validate_function_inputs: if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types] input_filtered['input_types'] = [received_types]
#ret = obj_class.VALIDATE_INPUTS(**input_filtered) # ret = obj_class.VALIDATE_INPUTS(**input_filtered)
ret = _map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
for x in input_filtered: for x in input_filtered:
for i, r in enumerate(ret): for i, r in enumerate(ret):
if r is not True and not isinstance(r, ExecutionBlocker): if r is not True and not isinstance(r, ExecutionBlocker):
@ -734,20 +827,45 @@ def validate_inputs(prompt, item, validated):
continue continue
if len(errors) > 0 or valid is not True: if len(errors) > 0 or valid is not True:
ret = (False, errors, unique_id) ret = ValidateInputsTuple(False, errors, unique_id)
else: else:
ret = (True, [], unique_id) ret = ValidateInputsTuple(True, [], unique_id)
validated[unique_id] = ret validated[unique_id] = ret
return ret return ret
def full_type_name(klass): def full_type_name(klass):
module = klass.__module__ module = klass.__module__
if module == 'builtins': if module == 'builtins':
return klass.__qualname__ return klass.__qualname__
return module + '.' + klass.__qualname__ return module + '.' + klass.__qualname__
def validate_prompt(prompt):
@tracer.start_as_current_span("Validate Prompt")
def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
res = _validate_prompt(prompt)
if not res.valid:
span = get_current_span()
span.set_status(Status(StatusCode.ERROR))
if res.error is not None and len(res.error) > 0:
span.set_attributes({
f"error.{k}": v for k, v in res.error.items() if isinstance(v, (bool, str, bytes, int, float, list))
})
if "extra_info" in res.error and isinstance(res.error["extra_info"], dict):
extra_info: ValidationErrorExtraInfoDict = res.error["extra_info"]
span.set_attributes({
f"error.extra_info.{k}": v for k, v in extra_info.items() if isinstance(v, (str, list))
})
if len(res.node_errors) > 0:
for node_id, node_error in res.node_errors.items():
for node_error_field, node_error_value in node_error.items():
if isinstance(node_error_value, (str, bool, int, float)):
span.set_attribute(f"node_errors.{node_id}.{node_error_field}", node_error_value)
return res
def _validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
outputs = set() outputs = set()
for x in prompt: for x in prompt:
if 'class_type' not in prompt[x]: if 'class_type' not in prompt[x]:
@ -757,7 +875,7 @@ def validate_prompt(prompt):
"details": f"Node ID '#{x}'", "details": f"Node ID '#{x}'",
"extra_info": {} "extra_info": {}
} }
return (False, error, [], []) return ValidationTuple(False, error, [], [])
class_type = prompt[x]['class_type'] class_type = prompt[x]['class_type']
class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None) class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
@ -768,7 +886,7 @@ def validate_prompt(prompt):
"details": f"Node ID '#{x}'", "details": f"Node ID '#{x}'",
"extra_info": {} "extra_info": {}
} }
return (False, error, [], []) return ValidationTuple(False, error, [], [])
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True: if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
outputs.add(x) outputs.add(x)
@ -780,15 +898,15 @@ def validate_prompt(prompt):
"details": "", "details": "",
"extra_info": {} "extra_info": {}
} }
return (False, error, [], []) return ValidationTuple(False, error, [], [])
good_outputs = set() good_outputs = set()
errors = [] errors = []
node_errors = {} node_errors: typing.Dict[str, NodeErrorsDictValue] = {}
validated = {} validated: typing.Dict[str, ValidateInputsTuple] = {}
for o in outputs: for o in outputs:
valid = False valid = False
reasons = [] reasons: List[ValidationErrorDict] = []
try: try:
m = validate_inputs(prompt, o, validated) m = validate_inputs(prompt, o, validated)
valid = m[0] valid = m[0]
@ -806,7 +924,7 @@ def validate_prompt(prompt):
"traceback": traceback.format_tb(tb) "traceback": traceback.format_tb(tb)
} }
}] }]
validated[o] = (False, reasons, o) validated[o] = ValidateInputsTuple(False, reasons, o)
if valid is True: if valid is True:
good_outputs.add(o) good_outputs.add(o)
@ -839,8 +957,8 @@ def validate_prompt(prompt):
if len(good_outputs) == 0: if len(good_outputs) == 0:
errors_list = [] errors_list = []
for o, errors in errors: for o, _errors in errors:
for error in errors: for error in _errors:
errors_list.append(f"{error['message']}: {error['details']}") errors_list.append(f"{error['message']}: {error['details']}")
errors_list = "\n".join(errors_list) errors_list = "\n".join(errors_list)
@ -851,73 +969,78 @@ def validate_prompt(prompt):
"extra_info": {} "extra_info": {}
} }
return (False, error, list(good_outputs), node_errors) return ValidationTuple(False, error, list(good_outputs), node_errors)
return (True, None, list(good_outputs), node_errors) return ValidationTuple(True, None, list(good_outputs), node_errors)
MAXIMUM_HISTORY_SIZE = 10000
class PromptQueue: class PromptQueue(AbstractPromptQueue):
def __init__(self, server): def __init__(self, server: ExecutorToClientProgress):
self.server = server self.server = server
self.mutex = threading.RLock() self.mutex = threading.RLock()
self.not_empty = threading.Condition(self.mutex) self.not_empty = threading.Condition(self.mutex)
self.task_counter = 0 self.queue: typing.List[QueueItem] = []
self.queue = [] self.currently_running: typing.Dict[str, QueueItem] = {}
self.currently_running = {} # history maps the second integer prompt id in the queue tuple to a dictionary with keys "prompt" and "outputs
self.history = {} # todo: use the new History class for the sake of simplicity
self.history: typing.Dict[str, HistoryEntry] = {}
self.flags = {} self.flags = {}
server.prompt_queue = self
def put(self, item): def size(self) -> int:
return len(self.queue)
def put(self, item: QueueItem):
with self.mutex: with self.mutex:
heapq.heappush(self.queue, item) heapq.heappush(self.queue, item)
self.server.queue_updated() self.server.queue_updated()
self.not_empty.notify() self.not_empty.notify()
def get(self, timeout=None): def get(self, timeout=None) -> typing.Optional[typing.Tuple[QueueTuple, str]]:
with self.not_empty: with self.not_empty:
while len(self.queue) == 0: while len(self.queue) == 0:
self.not_empty.wait(timeout=timeout) self.not_empty.wait(timeout=timeout)
if timeout is not None and len(self.queue) == 0: if timeout is not None and len(self.queue) == 0:
return None return None
item = heapq.heappop(self.queue) item_with_future: QueueItem = heapq.heappop(self.queue)
i = self.task_counter assert item_with_future.prompt_id is not None
self.currently_running[i] = copy.deepcopy(item) assert item_with_future.prompt_id != ""
self.task_counter += 1 assert item_with_future.prompt_id not in self.currently_running
assert isinstance(item_with_future.prompt_id, str)
task_id = item_with_future.prompt_id
self.currently_running[task_id] = item_with_future
self.server.queue_updated() self.server.queue_updated()
return (item, i) return copy.deepcopy(item_with_future.queue_tuple), task_id
class ExecutionStatus(NamedTuple): def task_done(self, item_id: str, outputs: dict,
status_str: Literal['success', 'error'] status: Optional[ExecutionStatus]):
completed: bool history_result = outputs
messages: List[str]
def task_done(self, item_id, history_result,
status: Optional['PromptQueue.ExecutionStatus']):
with self.mutex: with self.mutex:
prompt = self.currently_running.pop(item_id) queue_item = self.currently_running.pop(item_id)
prompt = queue_item.queue_tuple
if len(self.history) > MAXIMUM_HISTORY_SIZE: if len(self.history) > MAXIMUM_HISTORY_SIZE:
self.history.pop(next(iter(self.history))) self.history.pop(next(iter(self.history)))
status_dict: Optional[dict] = None status_dict: Optional[dict] = None
if status is not None: if status is not None:
status_dict = copy.deepcopy(status._asdict()) status_dict = copy.deepcopy(ExecutionStatus(*status)._asdict())
outputs_ = history_result["outputs"]
self.history[prompt[1]] = { self.history[prompt[1]] = {
"prompt": prompt, "prompt": prompt,
"outputs": {}, "outputs": copy.deepcopy(outputs_),
'status': status_dict, 'status': status_dict,
} }
self.history[prompt[1]].update(history_result) self.history[prompt[1]].update(history_result)
self.server.queue_updated() self.server.queue_updated()
if queue_item.completed:
queue_item.completed.set_result(outputs_)
def get_current_queue(self): def get_current_queue(self) -> Tuple[typing.List[QueueTuple], typing.List[QueueTuple]]:
with self.mutex: with self.mutex:
out = [] out: typing.List[QueueTuple] = []
for x in self.currently_running.values(): for x in self.currently_running.values():
out += [x] out += [x.queue_tuple]
return (out, copy.deepcopy(self.queue)) return out, copy.deepcopy([item.queue_tuple for item in self.queue])
def get_tasks_remaining(self): def get_tasks_remaining(self):
with self.mutex: with self.mutex:
@ -925,17 +1048,22 @@ class PromptQueue:
def wipe_queue(self): def wipe_queue(self):
with self.mutex: with self.mutex:
for item in self.queue:
if item.completed:
item.completed.set_exception(Exception("queue cancelled"))
self.queue = [] self.queue = []
self.server.queue_updated() self.server.queue_updated()
def delete_queue_item(self, function): def delete_queue_item(self, function):
with self.mutex: with self.mutex:
for x in range(len(self.queue)): for x in range(len(self.queue)):
if function(self.queue[x]): if function(self.queue[x].queue_tuple):
if len(self.queue) == 1: if len(self.queue) == 1:
self.wipe_queue() self.wipe_queue()
else: else:
self.queue.pop(x) item = self.queue.pop(x)
if item.completed:
item.completed.set_exception(Exception("queue item deleted"))
heapq.heapify(self.queue) heapq.heapify(self.queue)
self.server.queue_updated() self.server.queue_updated()
return True return True
@ -962,9 +1090,9 @@ class PromptQueue:
def wipe_history(self): def wipe_history(self):
with self.mutex: with self.mutex:
self.history = {} self.history.clear()
def delete_history_item(self, id_to_delete): def delete_history_item(self, id_to_delete: str):
with self.mutex: with self.mutex:
self.history.pop(id_to_delete, None) self.history.pop(id_to_delete, None)

View File

@ -1,15 +1,7 @@
import nodes from .cmd.execution import nodes
from .component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError
from .graph_utils import is_link
from comfy_execution.graph_utils import is_link
class DependencyCycleError(Exception):
pass
class NodeInputError(Exception):
pass
class NodeNotFoundError(Exception):
pass
class DynamicPrompt: class DynamicPrompt:
def __init__(self, original_prompt): def __init__(self, original_prompt):
@ -54,6 +46,7 @@ class DynamicPrompt:
def get_original_prompt(self): def get_original_prompt(self):
return self.original_prompt return self.original_prompt
def get_input_info(class_def, input_name): def get_input_info(class_def, input_name):
valid_inputs = class_def.INPUT_TYPES() valid_inputs = class_def.INPUT_TYPES()
input_info = None input_info = None
@ -76,12 +69,13 @@ def get_input_info(class_def, input_name):
extra_info = {} extra_info = {}
return input_type, input_category, extra_info return input_type, input_category, extra_info
class TopologicalSort: class TopologicalSort:
def __init__(self, dynprompt): def __init__(self, dynprompt):
self.dynprompt = dynprompt self.dynprompt = dynprompt
self.pendingNodes = {} self.pendingNodes = {}
self.blockCount = {} # Number of nodes this node is directly blocked by self.blockCount = {} # Number of nodes this node is directly blocked by
self.blocking = {} # Which nodes are blocked by this node self.blocking = {} # Which nodes are blocked by this node
def get_input_info(self, unique_id, input_name): def get_input_info(self, unique_id, input_name):
class_type = self.dynprompt.get_node(unique_id)["class_type"] class_type = self.dynprompt.get_node(unique_id)["class_type"]
@ -136,11 +130,13 @@ class TopologicalSort:
def is_empty(self): def is_empty(self):
return len(self.pendingNodes) == 0 return len(self.pendingNodes) == 0
class ExecutionList(TopologicalSort): class ExecutionList(TopologicalSort):
""" """
ExecutionList implements a topological dissolve of the graph. After a node is staged for execution, ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
it can still be returned to the graph after having further dependencies added. it can still be returned to the graph after having further dependencies added.
""" """
def __init__(self, dynprompt, output_cache): def __init__(self, dynprompt, output_cache):
super().__init__(dynprompt) super().__init__(dynprompt)
self.output_cache = output_cache self.output_cache = output_cache
@ -203,7 +199,7 @@ class ExecutionList(TopologicalSort):
# We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle. # We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle.
# We're skipping some of the performance optimizations from the original TopologicalSort to keep # We're skipping some of the performance optimizations from the original TopologicalSort to keep
# the code simple (and because having a cycle in the first place is a catastrophic error) # the code simple (and because having a cycle in the first place is a catastrophic error)
blocked_by = { node_id: {} for node_id in self.pendingNodes } blocked_by = {node_id: {} for node_id in self.pendingNodes}
for from_node_id in self.blocking: for from_node_id in self.blocking:
for to_node_id in self.blocking[from_node_id]: for to_node_id in self.blocking[from_node_id]:
if True in self.blocking[from_node_id][to_node_id].values(): if True in self.blocking[from_node_id][to_node_id].values():
@ -218,6 +214,7 @@ class ExecutionList(TopologicalSort):
to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0] to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
return list(blocked_by.keys()) return list(blocked_by.keys())
class ExecutionBlocker: class ExecutionBlocker:
""" """
Return this from a node and any users will be blocked with the given error message. Return this from a node and any users will be blocked with the given error message.
@ -232,6 +229,6 @@ class ExecutionBlocker:
(I would recommend not making nodes like this in the future -- instead, make multiple nodes with (I would recommend not making nodes like this in the future -- instead, make multiple nodes with
different outputs. Unfortunately, there are several popular existing nodes using this pattern.) different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
""" """
def __init__(self, message): def __init__(self, message):
self.message = message self.message = message

View File

@ -9,13 +9,16 @@ def is_link(obj):
return False return False
return True return True
# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
class GraphBuilder: class GraphBuilder:
"""
The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
"""
_default_prefix_root = "" _default_prefix_root = ""
_default_prefix_call_index = 0 _default_prefix_call_index = 0
_default_prefix_graph_index = 0 _default_prefix_graph_index = 0
def __init__(self, prefix = None): def __init__(self, prefix=None):
if prefix is None: if prefix is None:
self.prefix = GraphBuilder.alloc_prefix() self.prefix = GraphBuilder.alloc_prefix()
else: else:
@ -24,7 +27,7 @@ class GraphBuilder:
self.id_gen = 1 self.id_gen = 1
@classmethod @classmethod
def set_default_prefix(cls, prefix_root, call_index, graph_index = 0): def set_default_prefix(cls, prefix_root, call_index, graph_index=0):
cls._default_prefix_root = prefix_root cls._default_prefix_root = prefix_root
cls._default_prefix_call_index = call_index cls._default_prefix_call_index = call_index
cls._default_prefix_graph_index = graph_index cls._default_prefix_graph_index = graph_index
@ -80,6 +83,7 @@ class GraphBuilder:
id = self.prefix + id id = self.prefix + id
del self.nodes[id] del self.nodes[id]
class Node: class Node:
def __init__(self, id, class_type, inputs): def __init__(self, id, class_type, inputs):
self.id = id self.id = id
@ -112,13 +116,14 @@ class Node:
serialized["override_display_id"] = self.override_display_id serialized["override_display_id"] = self.override_display_id
return serialized return serialized
def add_graph_prefix(graph, outputs, prefix): def add_graph_prefix(graph, outputs, prefix):
# Change the node IDs and any internal links # Change the node IDs and any internal links
new_graph = {} new_graph = {}
for node_id, node_info in graph.items(): for node_id, node_info in graph.items():
# Make sure the added nodes have unique IDs # Make sure the added nodes have unique IDs
new_node_id = prefix + node_id new_node_id = prefix + node_id
new_node = { "class_type": node_info["class_type"], "inputs": {} } new_node = {"class_type": node_info["class_type"], "inputs": {}}
for input_name, input_value in node_info.get("inputs", {}).items(): for input_name, input_value in node_info.get("inputs", {}).items():
if is_link(input_value): if is_link(input_value):
new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]] new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]]
@ -136,4 +141,3 @@ def add_graph_prefix(graph, outputs, prefix):
new_outputs.append(output) new_outputs.append(output)
return new_graph, tuple(new_outputs) return new_graph, tuple(new_outputs)