Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-08 13:20:50 +08:00)

Merge branch 'execution_model_inversion' of github.com:guill/ComfyUI into pr-execution

Commit 87ab9d42d0
comfy/caching.py (new file, 305 lines)
@@ -0,0 +1,305 @@
import itertools
from typing import Sequence, Mapping

from .cmd.execution import nodes
from .graph import DynamicPrompt
from .graph_utils import is_link


class CacheKeySet:
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        self.keys = {}
        self.subcache_keys = {}

    def add_keys(self, node_ids):
        raise NotImplementedError()

    def all_node_ids(self):
        return set(self.keys.keys())

    def get_used_keys(self):
        return self.keys.values()

    def get_used_subcache_keys(self):
        return self.subcache_keys.values()

    def get_data_key(self, node_id):
        return self.keys.get(node_id, None)

    def get_subcache_key(self, node_id):
        return self.subcache_keys.get(node_id, None)


class Unhashable:
    def __init__(self):
        self.value = float("NaN")


def to_hashable(obj):
    # So that we don't infinitely recurse since frozenset and tuples
    # are Sequences.
    if isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    elif isinstance(obj, Mapping):
        return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
    elif isinstance(obj, Sequence):
        return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
    else:
        # TODO - Support other objects like tensors?
        return Unhashable()

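A minimal sketch (not part of this commit) of the cache-key behavior to_hashable provides; the values are made up, and the import assumes the module lands at comfy.caching, the path this diff's own imports use:

    from comfy.caching import to_hashable

    # Mappings freeze to order-independent frozensets, so two prompts with the
    # same inputs produce the same cache key regardless of dict ordering.
    a = to_hashable({"seed": 42, "image": ["4", 0]})
    b = to_hashable({"image": ["4", 0], "seed": 42})
    assert a == b and hash(a) == hash(b)

    # Sequences keep positional information via (index, item) pairs.
    assert to_hashable([1, 2]) != to_hashable([2, 1])

    # Unsupported objects (e.g. tensors, per the TODO) become unique Unhashable
    # sentinels, so they never spuriously match an earlier cache entry.
    assert to_hashable(object()) != to_hashable(object())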
class CacheKeySetID(CacheKeySet):
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        super().__init__(dynprompt, node_ids, is_changed_cache)
        self.dynprompt = dynprompt
        self.add_keys(node_ids)

    def add_keys(self, node_ids):
        for node_id in node_ids:
            if node_id in self.keys:
                continue
            node = self.dynprompt.get_node(node_id)
            self.keys[node_id] = (node_id, node["class_type"])
            self.subcache_keys[node_id] = (node_id, node["class_type"])


class CacheKeySetInputSignature(CacheKeySet):
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        super().__init__(dynprompt, node_ids, is_changed_cache)
        self.dynprompt = dynprompt
        self.is_changed_cache = is_changed_cache
        self.add_keys(node_ids)

    def include_node_id_in_input(self) -> bool:
        return False

    def add_keys(self, node_ids):
        for node_id in node_ids:
            if node_id in self.keys:
                continue
            node = self.dynprompt.get_node(node_id)
            self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
            self.subcache_keys[node_id] = (node_id, node["class_type"])

    def get_node_signature(self, dynprompt, node_id):
        signature = []
        ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
        signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
        for ancestor_id in ancestors:
            signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
        return to_hashable(signature)

    def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
        node = dynprompt.get_node(node_id)
        class_type = node["class_type"]
        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
        signature = [class_type, self.is_changed_cache.get(node_id)]
        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
            signature.append(node_id)
        inputs = node["inputs"]
        for key in sorted(inputs.keys()):
            if is_link(inputs[key]):
                (ancestor_id, ancestor_socket) = inputs[key]
                ancestor_index = ancestor_order_mapping[ancestor_id]
                signature.append((key, ("ANCESTOR", ancestor_index, ancestor_socket)))
            else:
                signature.append((key, inputs[key]))
        return signature

    # This function returns a list of all ancestors of the given node. The order of the list is
    # deterministic based on which specific inputs the ancestor is connected by.
    def get_ordered_ancestry(self, dynprompt, node_id):
        ancestors = []
        order_mapping = {}
        self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
        return ancestors, order_mapping

    def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
        inputs = dynprompt.get_node(node_id)["inputs"]
        input_keys = sorted(inputs.keys())
        for key in input_keys:
            if is_link(inputs[key]):
                ancestor_id = inputs[key][0]
                if ancestor_id not in order_mapping:
                    ancestors.append(ancestor_id)
                    order_mapping[ancestor_id] = len(ancestors) - 1
                    self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)

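A self-contained sketch (not part of this commit) of the deterministic ancestry walk get_ordered_ancestry performs, re-implemented standalone so it runs without a node registry; the three-node prompt is hypothetical:

    def is_link(obj):  # mirrors comfy.graph_utils.is_link for this sketch
        return isinstance(obj, list) and len(obj) == 2 and isinstance(obj[0], str)

    prompt = {
        "3": {"inputs": {"image": ["1", 0], "mask": ["2", 0]}},
        "2": {"inputs": {"image": ["1", 0]}},
        "1": {"inputs": {}},
    }

    def ordered_ancestry(node_id, ancestors, order_mapping):
        # Depth-first over inputs in sorted key order, exactly like
        # CacheKeySetInputSignature.get_ordered_ancestry_internal above.
        for key in sorted(prompt[node_id]["inputs"].keys()):
            value = prompt[node_id]["inputs"][key]
            if is_link(value) and value[0] not in order_mapping:
                ancestors.append(value[0])
                order_mapping[value[0]] = len(ancestors) - 1
                ordered_ancestry(value[0], ancestors, order_mapping)

    ancestors, order_mapping = [], {}
    ordered_ancestry("3", ancestors, order_mapping)
    # "image" sorts before "mask", so node "1" is ancestor 0 and node "2" is
    # ancestor 1. Links are later rewritten as ("ANCESTOR", index, socket), so
    # two graphs that differ only in node numbering key to the same cache entry.
    assert ancestors == ["1", "2"] and order_mapping == {"1": 0, "2": 1}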
class BasicCache:
    def __init__(self, key_class):
        self.key_class = key_class
        self.initialized = False
        self.dynprompt: DynamicPrompt
        self.cache_key_set: CacheKeySet
        self.cache = {}
        self.subcaches = {}

    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
        self.dynprompt = dynprompt
        self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
        self.is_changed_cache = is_changed_cache
        self.initialized = True

    def all_node_ids(self):
        assert self.initialized
        node_ids = self.cache_key_set.all_node_ids()
        for subcache in self.subcaches.values():
            node_ids = node_ids.union(subcache.all_node_ids())
        return node_ids

    def _clean_cache(self):
        preserve_keys = set(self.cache_key_set.get_used_keys())
        to_remove = []
        for key in self.cache:
            if key not in preserve_keys:
                to_remove.append(key)
        for key in to_remove:
            del self.cache[key]

    def _clean_subcaches(self):
        preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())

        to_remove = []
        for key in self.subcaches:
            if key not in preserve_subcaches:
                to_remove.append(key)
        for key in to_remove:
            del self.subcaches[key]

    def clean_unused(self):
        assert self.initialized
        self._clean_cache()
        self._clean_subcaches()

    def _set_immediate(self, node_id, value):
        assert self.initialized
        cache_key = self.cache_key_set.get_data_key(node_id)
        self.cache[cache_key] = value

    def _get_immediate(self, node_id):
        if not self.initialized:
            return None
        cache_key = self.cache_key_set.get_data_key(node_id)
        if cache_key in self.cache:
            return self.cache[cache_key]
        else:
            return None

    def _ensure_subcache(self, node_id, children_ids):
        subcache_key = self.cache_key_set.get_subcache_key(node_id)
        subcache = self.subcaches.get(subcache_key, None)
        if subcache is None:
            subcache = BasicCache(self.key_class)
            self.subcaches[subcache_key] = subcache
        subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
        return subcache

    def _get_subcache(self, node_id):
        assert self.initialized
        subcache_key = self.cache_key_set.get_subcache_key(node_id)
        if subcache_key in self.subcaches:
            return self.subcaches[subcache_key]
        else:
            return None

    def recursive_debug_dump(self):
        result = []
        for key in self.cache:
            result.append({"key": key, "value": self.cache[key]})
        for key in self.subcaches:
            result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
        return result


class HierarchicalCache(BasicCache):
    def __init__(self, key_class):
        super().__init__(key_class)

    def _get_cache_for(self, node_id):
        assert self.dynprompt is not None
        parent_id = self.dynprompt.get_parent_node_id(node_id)
        if parent_id is None:
            return self

        hierarchy = []
        while parent_id is not None:
            hierarchy.append(parent_id)
            parent_id = self.dynprompt.get_parent_node_id(parent_id)

        cache = self
        for parent_id in reversed(hierarchy):
            cache = cache._get_subcache(parent_id)
            if cache is None:
                return None
        return cache

    def get(self, node_id):
        cache = self._get_cache_for(node_id)
        if cache is None:
            return None
        return cache._get_immediate(node_id)

    def set(self, node_id, value):
        cache = self._get_cache_for(node_id)
        assert cache is not None
        cache._set_immediate(node_id, value)

    def ensure_subcache_for(self, node_id, children_ids):
        cache = self._get_cache_for(node_id)
        assert cache is not None
        return cache._ensure_subcache(node_id, children_ids)

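How the subcache hierarchy resolves lookups for ephemeral (expanded) nodes, in a sketch that is not part of this commit; node names are hypothetical, and the import paths are the ones this diff itself uses:

    from comfy.caching import HierarchicalCache, CacheKeySetID
    from comfy.graph import DynamicPrompt

    dynprompt = DynamicPrompt({"1": {"class_type": "MyNode", "inputs": {}}})
    cache = HierarchicalCache(CacheKeySetID)
    cache.set_prompt(dynprompt, ["1"], is_changed_cache=None)  # ID keys ignore is_changed
    cache.set("1", ("some output",))

    # During expansion, node "1" spawns an ephemeral child "1.inner"; its results
    # live in a subcache owned by "1" and disappear when "1" is cleaned.
    dynprompt.add_ephemeral_node("1.inner", {"class_type": "Inner", "inputs": {}}, "1", "1")
    cache.ensure_subcache_for("1", ["1.inner"])
    cache.set("1.inner", ("inner output",))
    assert cache.get("1.inner") == ("inner output",)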
class LRUCache(BasicCache):
    def __init__(self, key_class, max_size=100):
        super().__init__(key_class)
        self.max_size = max_size
        self.min_generation = 0
        self.generation = 0
        self.used_generation = {}
        self.children = {}

    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
        super().set_prompt(dynprompt, node_ids, is_changed_cache)
        self.generation += 1
        for node_id in node_ids:
            self._mark_used(node_id)

    def clean_unused(self):
        while len(self.cache) > self.max_size and self.min_generation < self.generation:
            self.min_generation += 1
            to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
            for key in to_remove:
                del self.cache[key]
                del self.used_generation[key]
                if key in self.children:
                    del self.children[key]
        self._clean_subcaches()

    def get(self, node_id):
        self._mark_used(node_id)
        return self._get_immediate(node_id)

    def _mark_used(self, node_id):
        cache_key = self.cache_key_set.get_data_key(node_id)
        if cache_key is not None:
            self.used_generation[cache_key] = self.generation

    def set(self, node_id, value):
        self._mark_used(node_id)
        return self._set_immediate(node_id, value)

    def ensure_subcache_for(self, node_id, children_ids):
        # Just uses subcaches for tracking 'live' nodes
        super()._ensure_subcache(node_id, children_ids)

        self.cache_key_set.add_keys(children_ids)
        self._mark_used(node_id)
        cache_key = self.cache_key_set.get_data_key(node_id)
        self.children[cache_key] = []
        for child_id in children_ids:
            self._mark_used(child_id)
            self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
        return self

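A sketch (not part of this commit) of the generation-based eviction: each set_prompt bumps the generation, and clean_unused raises min_generation until the cache fits under max_size, dropping entries whose last use is older. Hypothetical node ids:

    from comfy.caching import LRUCache, CacheKeySetID
    from comfy.graph import DynamicPrompt

    prompt = {str(i): {"class_type": "N", "inputs": {}} for i in range(4)}
    lru = LRUCache(CacheKeySetID, max_size=2)

    # Generation 1: nodes 0 and 1 run and are cached.
    lru.set_prompt(DynamicPrompt(prompt), ["0", "1"], None)
    lru.set("0", ("a",))
    lru.set("1", ("b",))
    lru.clean_unused()  # 2 entries, at max_size, nothing evicted

    # Generation 2: a different pair runs; 0 and 1 go stale.
    lru.set_prompt(DynamicPrompt(prompt), ["2", "3"], None)
    lru.set("2", ("c",))
    lru.set("3", ("d",))
    lru.clean_unused()  # 4 entries > max_size: the generation-1 entries are dropped
    assert lru.get("2") is not None and lru.get("3") is not None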
@@ -87,6 +87,8 @@ def _create_parser() -> EnhancedConfigArgParser:
     parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.Auto,
                         help="Default preview method for sampler nodes.", action=EnumAction)

+    cache_group = parser.add_mutually_exclusive_group()
+    cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
     attn_group = parser.add_mutually_exclusive_group()
     attn_group.add_argument("--use-split-cross-attention", action="store_true",
                             help="Use the split cross attention optimization. Ignored when xformers is used.")
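Opting in from the command line is then one switch (sketch; the exact entrypoint script depends on how this fork is launched):

    python main.py --cache-lru 50    # keep up to 50 node results, evicting the least recently used
    python main.py                   # default 0 preserves the classic discard-ASAP cache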
@@ -211,11 +213,11 @@ def _create_parser() -> EnhancedConfigArgParser:
     return parser


-def _parse_args(parser: Optional[argparse.ArgumentParser] = None) -> Configuration:
+def _parse_args(parser: Optional[argparse.ArgumentParser] = None, args_parsing: bool = False) -> Configuration:
     if parser is None:
         parser = _create_parser()

-    if options.args_parsing:
+    if args_parsing:
         args, _, config_files = parser.parse_known_args_with_config_files()
     else:
         args, _, config_files = parser.parse_known_args_with_config_files([])
@@ -261,4 +263,8 @@ def _setup_config_file_watcher(config: Configuration, parser: EnhancedConfigArgP
     atexit.register(observer.join)


-args = _parse_args()
+def default_configuration() -> Configuration:
+    return _parse_args(_create_parser())
+
+
+args = _parse_args(args_parsing=options.args_parsing)
@@ -70,6 +70,7 @@ class Configuration(dict):
         directml (Optional[int]): Use DirectML. -1 for auto-selection.
         disable_ipex_optimize (bool): Disable IPEX optimization for Intel GPUs.
         preview_method (LatentPreviewMethod): Method for generating previews. Defaults to "auto".
+        cache_lru (int): Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.
         use_split_cross_attention (bool): Use split cross-attention optimization.
         use_quad_cross_attention (bool): Use sub-quadratic cross-attention optimization.
         use_pytorch_cross_attention (bool): Use PyTorch's cross-attention function.
@@ -179,6 +180,9 @@ class Configuration(dict):
         self.force_channels_last: bool = False
         self.force_hf_local_dir_mode = False

+        # from guill
+        self.cache_lru: int = 0
+
         # from opentracing docs
         self.otel_service_name: str = "comfyui"
         self.otel_service_version: str = "0.0.1"
@@ -2,6 +2,7 @@ from __future__ import annotations

 import asyncio
 import gc
+import json
 import uuid
 from asyncio import get_event_loop
 from concurrent.futures import ThreadPoolExecutor
@@ -107,7 +108,7 @@ class EmbeddedComfyClient:

             from ..cmd.execution import PromptExecutor

-            self._prompt_executor = PromptExecutor(self._progress_handler)
+            self._prompt_executor = PromptExecutor(self._progress_handler, lru_size=self._configuration.cache_lru if self._configuration is not None else 0)
             self._prompt_executor.raise_exceptions = True

         await get_event_loop().run_in_executor(self._executor, create_executor_in_thread)
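This is the full path of the knob: Configuration.cache_lru flows into PromptExecutor(lru_size=...). A sketch, not part of this commit, using the import path shown in the diff; the helper function is hypothetical:

    from comfy.cmd.execution import PromptExecutor  # module path used elsewhere in this diff

    def build_executor(progress_handler, configuration=None):
        # Mirrors the wiring in EmbeddedComfyClient above: no configuration means
        # lru_size=0, which CacheSet treats as the classic cache behavior.
        lru_size = configuration.cache_lru if configuration is not None else 0
        return PromptExecutor(progress_handler, lru_size=lru_size)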
@@ -128,9 +129,9 @@ class EmbeddedComfyClient:
         try:
             prompt_mut = make_mutable(prompt)
             validation_tuple = validate_prompt(prompt_mut)
-            if not validation_tuple[0]:
-                validation_error_dict = validation_tuple[1] or {"message": "Unknown", "details": ""}
-                raise ValueError("\n".join([validation_error_dict["message"], validation_error_dict["details"]]))
+            if not validation_tuple.valid:
+                validation_error_dict = {"message": "Unknown", "details": ""} if not validation_tuple.node_errors or len(validation_tuple.node_errors) == 0 else validation_tuple.node_errors
+                raise ValueError(json.dumps(validation_error_dict))

             prompt_executor: PromptExecutor = self._prompt_executor
@@ -140,7 +141,7 @@ class EmbeddedComfyClient:
             prompt_executor.server = self._progress_handler

             prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id},
-                                    execute_outputs=validation_tuple[2])
+                                    execute_outputs=validation_tuple.good_output_node_ids)
             return prompt_executor.outputs_ui
         except Exception as exc_info:
             span.set_status(Status(StatusCode.ERROR))
@@ -20,7 +20,8 @@ from .. import model_management
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
     ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \
-    RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted
+    RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted, ExecutionResult, DuplicateNodeError, \
+    HistoryResultDict
 from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
 from ..execution_context import new_execution_context, ExecutionContext
 from ..nodes.package import import_all_nodes_in_workspace
@@ -29,46 +30,122 @@ from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOption
 # ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
 # various functions that are declared here. It should have been a context in the first place.
 nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace)
+from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
+from comfy.graph_utils import is_link, GraphBuilder
+from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID


-def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extra_data=None):
+class IsChangedCache:
+    def __init__(self, dynprompt, outputs_cache):
+        self.dynprompt = dynprompt
+        self.outputs_cache = outputs_cache
+        self.is_changed = {}
+
+    def get(self, node_id):
+        if node_id in self.is_changed:
+            return self.is_changed[node_id]
+
+        node = self.dynprompt.get_node(node_id)
+        class_type = node["class_type"]
+        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+        if not hasattr(class_def, "IS_CHANGED"):
+            self.is_changed[node_id] = False
+            return self.is_changed[node_id]
+
+        if "is_changed" in node:
+            self.is_changed[node_id] = node["is_changed"]
+            return self.is_changed[node_id]
+
+        input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
+        try:
+            is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
+            node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
+        except:
+            node["is_changed"] = float("NaN")
+        finally:
+            self.is_changed[node_id] = node["is_changed"]
+        return self.is_changed[node_id]
+
+
+class CacheSet:
+    def __init__(self, lru_size=None):
+        if lru_size is None or lru_size == 0:
+            self.init_classic_cache()
+        else:
+            self.init_lru_cache(lru_size)
+        self.all = [self.outputs, self.ui, self.objects]
+
+    # Useful for those with ample RAM/VRAM -- allows experimenting without
+    # blowing away the cache every time
+    def init_lru_cache(self, cache_size):
+        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.objects = HierarchicalCache(CacheKeySetID)
+
+    # Performs like the old cache -- dump data ASAP
+    def init_classic_cache(self):
+        self.outputs = HierarchicalCache(CacheKeySetInputSignature)
+        self.ui = HierarchicalCache(CacheKeySetInputSignature)
+        self.objects = HierarchicalCache(CacheKeySetID)
+
+    def recursive_debug_dump(self):
+        result = {
+            "outputs": self.outputs.recursive_debug_dump(),
+            "ui": self.ui.recursive_debug_dump(),
+        }
+        return result
+
+
+def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data=None):
     if extra_data is None:
         extra_data = {}
-    if prompt is None:
-        prompt = {}
-    if outputs is None:
-        outputs = {}
     valid_inputs = class_def.INPUT_TYPES()
     input_data_all = {}
+    missing_keys = {}
     for x in inputs:
         input_data = inputs[x]
-        if isinstance(input_data, list):
+        input_type, input_category, input_info = get_input_info(class_def, x)

+        def mark_missing():
+            missing_keys[x] = True
+            input_data_all[x] = (None,)

+        if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
             input_unique_id = input_data[0]
             output_index = input_data[1]
-            if input_unique_id not in outputs:
-                input_data_all[x] = (None,)
+            if outputs is None:
+                mark_missing()
                 continue  # This might be a lazily-evaluated input
+            cached_output = outputs.get(input_unique_id)
+            if cached_output is None:
+                mark_missing()
+                continue
-            obj = outputs[input_unique_id][output_index]
+            if output_index >= len(cached_output):
+                mark_missing()
+                continue
+            obj = cached_output[output_index]
             input_data_all[x] = obj
-        else:
-            if ("required" in valid_inputs and x in valid_inputs["required"]) or (
-                    "optional" in valid_inputs and x in valid_inputs["optional"]):
-                input_data_all[x] = [input_data]
+        elif input_category is not None:
+            input_data_all[x] = [input_data]

     if "hidden" in valid_inputs:
         h = valid_inputs["hidden"]
         for x in h:
             if h[x] == "PROMPT":
-                input_data_all[x] = [prompt]
+                input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}]
+            if h[x] == "DYNPROMPT":
+                input_data_all[x] = [dynprompt]
             if h[x] == "EXTRA_PNGINFO":
                 input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
             if h[x] == "UNIQUE_ID":
                 input_data_all[x] = [unique_id]
-    return input_data_all
+    return input_data_all, missing_keys

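The new (input_data_all, missing_keys) return shape lets the executor distinguish "input genuinely absent" from "linked input not computed yet". A sketch, not part of this commit, with a stub node class, assuming get_input_data is imported from comfy.cmd.execution, the module this hunk modifies:

    class StubNode:
        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"image": ("IMAGE",), "strength": ("FLOAT", {"default": 1.0})}}

    inputs = {"image": ["4", 0], "strength": 0.5}  # "image" links to node "4", socket 0
    # Node "4" has not produced anything yet (e.g. it is lazily deferred):
    input_data_all, missing_keys = get_input_data(inputs, StubNode, "5", outputs=None)
    assert missing_keys == {"image": True}
    assert input_data_all["image"] == (None,)   # placeholder from mark_missing()
    assert input_data_all["strength"] == [0.5]  # literal widget value, wrapped in a list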
 @tracer.start_as_current_span("Execute Node")
-def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func: str, allow_interrupt=False):
+def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func: str, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
     span = get_current_span()
     class_type = obj.__class__.__name__
     span.set_attribute("class_type", class_type)
@@ -95,73 +172,103 @@ def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func:
             span.set_attribute(f"input_data_all.{kwarg_name}", filtered_items)
         except TypeError:
             pass

     # check if node wants the lists
-    input_is_list = False
-    if hasattr(obj, "INPUT_IS_LIST"):
-        input_is_list = obj.INPUT_IS_LIST
+    input_is_list = getattr(obj, "INPUT_IS_LIST", False)

     if len(input_data_all) == 0:
         max_len_input = 0
     else:
-        max_len_input = max([len(x) for x in input_data_all.values()])
+        max_len_input = max(len(x) for x in input_data_all.values())

     # get a slice of inputs, repeat last input when list isn't long enough
     def slice_dict(d, i):
-        d_new = dict()
-        for k, v in d.items():
-            d_new[k] = v[i if len(v) > i else -1]
-        return d_new
+        return {k: v[i if len(v) > i else -1] for k, v in d.items()}

     results = []

+    def process_inputs(inputs, index=None):
+        if allow_interrupt:
+            interruption.throw_exception_if_processing_interrupted()
+        execution_block = None
+        for k, v in inputs.items():
+            if isinstance(v, ExecutionBlocker):
+                execution_block = execution_block_cb(v) if execution_block_cb else v
+                break
+        if execution_block is None:
+            if pre_execute_cb is not None and index is not None:
+                pre_execute_cb(index)
+            results.append(getattr(obj, func)(**inputs))
+        else:
+            results.append(execution_block)
+
     if input_is_list:
-        if allow_interrupt:
-            interruption.throw_exception_if_processing_interrupted()
-        results.append(getattr(obj, func)(**input_data_all))
+        process_inputs(input_data_all, 0)
     elif max_len_input == 0:
-        if allow_interrupt:
-            interruption.throw_exception_if_processing_interrupted()
-        results.append(getattr(obj, func)())
+        process_inputs({})
     else:
         for i in range(max_len_input):
-            if allow_interrupt:
-                interruption.throw_exception_if_processing_interrupted()
-            results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
+            input_dict = slice_dict(input_data_all, i)
+            process_inputs(input_dict, i)
     return results


-def get_output_data(obj, input_data_all) -> Tuple[List[typing.Any], typing.Dict[str, List[typing.Any]]]:
+def merge_result_data(results, obj):
+    # check which outputs need concatenating
+    output = []
+    output_is_list = [False] * len(results[0])
+    if hasattr(obj, "OUTPUT_IS_LIST"):
+        output_is_list = obj.OUTPUT_IS_LIST
+
+    # merge node execution results
+    for i, is_list in zip(range(len(results[0])), output_is_list):
+        if is_list:
+            output.append([x for o in results for x in o[i]])
+        else:
+            output.append([o[i] for o in results])
+    return output
+
+
+def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
     results = []
     uis = []
-    return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
-
-    for r in return_values:
+    subgraph_results = []
+    return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
+    has_subgraph = False
+    for i in range(len(return_values)):
+        r = return_values[i]
         if isinstance(r, dict):
             if 'ui' in r:
                 uis.append(r['ui'])
-            if 'result' in r:
-                results.append(r['result'])
+            if 'expand' in r:
+                # Perform an expansion, but do not append results
+                has_subgraph = True
+                new_graph = r['expand']
+                result = r.get("result", None)
+                if isinstance(result, ExecutionBlocker):
+                    result = tuple([result] * len(obj.RETURN_TYPES))
+                subgraph_results.append((new_graph, result))
+            elif 'result' in r:
+                result = r.get("result", None)
+                if isinstance(result, ExecutionBlocker):
+                    result = tuple([result] * len(obj.RETURN_TYPES))
+                results.append(result)
+                subgraph_results.append((None, result))
         else:
+            if isinstance(r, ExecutionBlocker):
+                r = tuple([r] * len(obj.RETURN_TYPES))
             results.append(r)
+            subgraph_results.append((None, r))

-    output = []
-    if len(results) > 0:
-        # check which outputs need concatenating
-        output_is_list = [False] * len(results[0])
-        if hasattr(obj, "OUTPUT_IS_LIST"):
-            output_is_list = obj.OUTPUT_IS_LIST
-
-        # merge node execution results
-        for i, is_list in zip(range(len(results[0])), output_is_list):
-            if is_list:
-                output.append([x for o in results for x in o[i]])
-            else:
-                output.append([o[i] for o in results])
-
+    if has_subgraph:
+        output = subgraph_results
+    elif len(results) > 0:
+        output = merge_result_data(results, obj)
+    else:
+        output = []
     ui = dict()
     if len(uis) > 0:
         ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
-    return output, ui
+    return output, ui, has_subgraph


 def format_value(x) -> FormattedValue:
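The blocking semantics in process_inputs, sketched standalone (not part of this commit): any ExecutionBlocker among a slice's inputs suppresses the node function for that slice and propagates the blocker instead. The node class is hypothetical; ExecutionBlocker comes from comfy.graph, as imported above:

    from comfy.graph import ExecutionBlocker
    from comfy.cmd.execution import map_node_over_list

    class AddNode:
        RETURN_TYPES = ("INT",)
        FUNCTION = "add"

        def add(self, a, b):
            return (a + b,)

    blocker = ExecutionBlocker("upstream muted this branch")
    input_data_all = {"a": [1, 2], "b": [10, blocker]}
    results = map_node_over_list(AddNode(), input_data_all, "add")
    # Slice 0 runs normally; slice 1 yields the blocker in place of a result.
    assert results[0] == (11,)
    assert isinstance(results[1], ExecutionBlocker)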
@@ -173,63 +280,148 @@ def format_value(x) -> FormattedValue:
     return str(x)


-def recursive_execute(server: ExecutorToClientProgress,
-                      prompt,
-                      outputs,
-                      current_item,
-                      extra_data,
-                      executed,
-                      prompt_id,
-                      outputs_ui,
-                      object_storage) -> RecursiveExecutionTuple:
+def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
     unique_id = current_item
-    inputs = prompt[unique_id]['inputs']
-    class_type = prompt[unique_id]['class_type']
+    real_node_id = dynprompt.get_real_node_id(unique_id)
+    display_node_id = dynprompt.get_display_node_id(unique_id)
+    parent_node_id = dynprompt.get_parent_node_id(unique_id)
+    inputs = dynprompt.get_node(unique_id)['inputs']
+    class_type = dynprompt.get_node(unique_id)['class_type']
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-    if unique_id in outputs:
-        return RecursiveExecutionTuple(True, None, None)
-
-    for x in inputs:
-        input_data = inputs[x]
-
-        if isinstance(input_data, list):
-            input_unique_id = input_data[0]
-            output_index = input_data[1]
-            if input_unique_id not in outputs:
-                result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id,
-                                           outputs_ui, object_storage)
-                if result[0] is not True:
-                    # Another node failed further upstream
-                    return result
+    if caches.outputs.get(unique_id) is not None:
+        if server.client_id is not None:
+            cached_output = caches.ui.get(unique_id) or {}
+            server.send_sync("executed", {"node": unique_id, "display_node": display_node_id, "output": cached_output.get("output", None), "prompt_id": prompt_id}, server.client_id)
+        return RecursiveExecutionTuple(ExecutionResult.SUCCESS, None, None)

     input_data_all = None
     try:
-        input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data)
-        if server.client_id is not None:
-            server.last_node_id = unique_id
-            server.send_sync("executing", {"node": unique_id, "prompt_id": prompt_id}, server.client_id)
-
-        obj = object_storage.get((unique_id, class_type), None)
-        if obj is None:
-            obj = class_def()
-            object_storage[(unique_id, class_type)] = obj
-
-        output_data, output_ui = get_output_data(obj, input_data_all)
-        outputs[unique_id] = output_data
-        if len(output_ui) > 0:
-            outputs_ui[unique_id] = output_ui
-            if server.client_id is not None:
-                server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id},
-                                 server.client_id)
+        if unique_id in pending_subgraph_results:
+            cached_results = pending_subgraph_results[unique_id]
+            resolved_outputs = []
+            for is_subgraph, result in cached_results:
+                if not is_subgraph:
+                    resolved_outputs.append(result)
+                else:
+                    resolved_output = []
+                    for r in result:
+                        if is_link(r):
+                            source_node, source_output = r[0], r[1]
+                            node_output = caches.outputs.get(source_node)[source_output]
+                            for o in node_output:
+                                resolved_output.append(o)
+                        else:
+                            resolved_output.append(r)
+                    resolved_outputs.append(tuple(resolved_output))
+            output_data = merge_result_data(resolved_outputs, class_def)
+            output_ui = []
+            has_subgraph = False
+        else:
+            input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
+            if server.client_id is not None:
+                server.last_node_id = display_node_id
+                server.send_sync("executing", {"node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id}, server.client_id)
+
+            obj = caches.objects.get(unique_id)
+            if obj is None:
+                obj = class_def()
+                caches.objects.set(unique_id, obj)
+
+            if hasattr(obj, "check_lazy_status"):
+                required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
+                required_inputs = set(sum([r for r in required_inputs if isinstance(r, list)], []))
+                required_inputs = [x for x in required_inputs if isinstance(x, str) and (
+                    x not in input_data_all or x in missing_keys
+                )]
+                if len(required_inputs) > 0:
+                    for i in required_inputs:
+                        execution_list.make_input_strong_link(unique_id, i)
+                    return (ExecutionResult.PENDING, None, None)
+
+            def execution_block_cb(block):
+                if block.message is not None:
+                    mes = {
+                        "prompt_id": prompt_id,
+                        "node_id": unique_id,
+                        "node_type": class_type,
+                        "executed": list(executed),
+
+                        "exception_message": f"Execution Blocked: {block.message}",
+                        "exception_type": "ExecutionBlocked",
+                        "traceback": [],
+                        "current_inputs": [],
+                        "current_outputs": [],
+                    }
+                    server.send_sync("execution_error", mes, server.client_id)
+                    return ExecutionBlocker(None)
+                else:
+                    return block
+
+            def pre_execute_cb(call_index):
+                GraphBuilder.set_default_prefix(unique_id, call_index, 0)
+
+            output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
+        if len(output_ui) > 0:
+            caches.ui.set(unique_id, {
+                "meta": {
+                    "node_id": unique_id,
+                    "display_node": display_node_id,
+                    "parent_node": parent_node_id,
+                    "real_node_id": real_node_id,
+                },
+                "output": output_ui
+            })
+            if server.client_id is not None:
+                server.send_sync("executed", {"node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id},
+                                 server.client_id)
+        if has_subgraph:
+            cached_outputs = []
+            new_node_ids = []
+            new_output_ids = []
+            new_output_links = []
+            for i in range(len(output_data)):
+                new_graph, node_outputs = output_data[i]
+                if new_graph is None:
+                    cached_outputs.append((False, node_outputs))
+                else:
+                    # Check for conflicts
+                    for node_id in new_graph.keys():
+                        if dynprompt.has_node(node_id):
+                            raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.")
+                    for node_id, node_info in new_graph.items():
+                        new_node_ids.append(node_id)
+                        display_id = node_info.get("override_display_id", unique_id)
+                        dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id)
+                        # Figure out if the newly created node is an output node
+                        class_type = node_info["class_type"]
+                        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+                        if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
+                            new_output_ids.append(node_id)
+                    for i in range(len(node_outputs)):
+                        if is_link(node_outputs[i]):
+                            from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
+                            new_output_links.append((from_node_id, from_socket))
+                    cached_outputs.append((True, node_outputs))
+            new_node_ids = set(new_node_ids)
+            for cache in caches.all:
+                cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused()
+            for node_id in new_output_ids:
+                execution_list.add_node(node_id)
+            for link in new_output_links:
+                execution_list.add_strong_link(link[0], link[1], unique_id)
+            pending_subgraph_results[unique_id] = cached_outputs
+            return (ExecutionResult.PENDING, None, None)
+        caches.outputs.set(unique_id, output_data)
     except interruption.InterruptProcessingException as iex:
         logging.info("Processing interrupted")

         # skip formatting inputs/outputs
         error_details: RecursiveExecutionErrorDetailsInterrupted = {
-            "node_id": unique_id,
+            "node_id": real_node_id,
         }

-        return RecursiveExecutionTuple(False, error_details, iex)
+        return RecursiveExecutionTuple(ExecutionResult.FAILURE, error_details, iex)
     except Exception as ex:
         typ, _, tb = sys.exc_info()
         exception_type = full_type_name(typ)
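The 'expand' result key handled above is the node-authored subgraph mechanism. A sketch (not part of this commit) of a node returning an expansion via GraphBuilder, whose node()/out()/finalize() API is assumed from the upstream pull request; the node types are hypothetical:

    from comfy.graph_utils import GraphBuilder  # import as used in this diff

    class RepeatTwice:
        RETURN_TYPES = ("IMAGE",)
        FUNCTION = "expand"

        def expand(self, image):
            # Build an ephemeral two-node chain; GraphBuilder generates ids that
            # are unique per (node, call index) thanks to pre_execute_cb above.
            graph = GraphBuilder()
            first = graph.node("MyEffect", image=image)        # hypothetical node type
            second = graph.node("MyEffect", image=first.out(0))
            return {
                "result": (second.out(0),),  # a link into the new subgraph
                "expand": graph.finalize(),  # nodes to splice into the prompt
            }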
@@ -239,126 +431,36 @@ def recursive_execute(server: ExecutorToClientProgress,
         for name, inputs in input_data_all.items():
             input_data_formatted[name] = [format_value(x) for x in inputs]

-        output_data_formatted = {}
-        for node_id, node_outputs in outputs.items():
-            output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs]
-
-        logging.error(f"!!! Exception during processing!!! {ex}")
+        logging.error(f"!!! Exception during processing !!! {ex}")
         logging.error(traceback.format_exc())

         error_details: RecursiveExecutionErrorDetails = {
-            "node_id": unique_id,
+            "node_id": real_node_id,
             "exception_message": str(ex),
             "exception_type": exception_type,
             "traceback": traceback.format_tb(tb),
-            "current_inputs": input_data_formatted,
-            "current_outputs": output_data_formatted
+            "current_inputs": input_data_formatted
         }
-        return RecursiveExecutionTuple(False, error_details, ex)
+        return RecursiveExecutionTuple(ExecutionResult.FAILURE, error_details, ex)

     executed.add(unique_id)

-    return RecursiveExecutionTuple(True, None, None)
-
-
-_memo = {}
-
-
-def recursive_will_execute(prompt, outputs, current_item, memo: dict[str, list[str]] = None) -> list[str]:
-    if memo is None:
-        memo = _memo
-    unique_id = current_item
-
-    if unique_id in memo:
-        return memo[unique_id]
-
-    inputs = prompt[unique_id]['inputs']
-    will_execute = []
-    if unique_id in outputs:
-        return []
-
-    for x in inputs:
-        input_data = inputs[x]
-        if isinstance(input_data, list):
-            input_unique_id = input_data[0]
-            output_index = input_data[1]
-            if input_unique_id not in outputs:
-                will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo)
-
-    memo[unique_id] = will_execute + [unique_id]
-    return memo[unique_id]
-
-
-def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item):
-    unique_id = current_item
-    inputs = prompt[unique_id]['inputs']
-    class_type = prompt[unique_id]['class_type']
-    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-
-    is_changed_old = ''
-    is_changed = ''
-    to_delete = False
-    if hasattr(class_def, 'IS_CHANGED'):
-        if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]:
-            is_changed_old = old_prompt[unique_id]['is_changed']
-        if 'is_changed' not in prompt[unique_id]:
-            input_data_all = get_input_data(inputs, class_def, unique_id, outputs)
-            if input_data_all is not None:
-                try:
-                    # is_changed = class_def.IS_CHANGED(**input_data_all)
-                    is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
-                    prompt[unique_id]['is_changed'] = is_changed
-                except:
-                    to_delete = True
-        else:
-            is_changed = prompt[unique_id]['is_changed']
-
-    if unique_id not in outputs:
-        return True
-
-    if not to_delete:
-        if is_changed != is_changed_old:
-            to_delete = True
-        elif unique_id not in old_prompt:
-            to_delete = True
-        elif class_type != old_prompt[unique_id]['class_type']:
-            to_delete = True
-        elif inputs == old_prompt[unique_id]['inputs']:
-            for x in inputs:
-                input_data = inputs[x]
-
-                if isinstance(input_data, list):
-                    input_unique_id = input_data[0]
-                    output_index = input_data[1]
-                    if input_unique_id in outputs:
-                        to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id)
-                    else:
-                        to_delete = True
-                    if to_delete:
-                        break
-        else:
-            to_delete = True
-
-    if to_delete:
-        d = outputs.pop(unique_id)
-        del d
-    return to_delete
+    return (ExecutionResult.SUCCESS, None, None)


 class PromptExecutor:
-    def __init__(self, server: ExecutorToClientProgress):
+    def __init__(self, server: ExecutorToClientProgress, lru_size=None):
         self.success = None
+        self.lru_size = lru_size
         self.server = server
         self.raise_exceptions = False
         self.reset()
+        self.history_result: HistoryResultDict | None = None

     def reset(self):
-        self.outputs = {}
-        self.object_storage = {}
-        self.outputs_ui = {}
-        self.status_messages = []
         self.success = True
-        self.old_prompt = {}
+        self.caches = CacheSet(self.lru_size)
+        self.status_messages = []

     def add_message(self, event, data, broadcast: bool):
         self.status_messages.append((event, data))
@@ -389,27 +491,14 @@ class PromptExecutor:
                 "node_id": node_id,
                 "node_type": class_type,
                 "executed": list(executed),

                 "exception_message": error["exception_message"],
                 "exception_type": error["exception_type"],
                 "traceback": error["traceback"],
                 "current_inputs": error["current_inputs"],
-                "current_outputs": error["current_outputs"],
+                "current_outputs": list(current_outputs),
             }
             self.add_message("execution_error", mes, broadcast=False)

-        # Next, remove the subsequent outputs since they will not be executed
-        to_delete = []
-        for o in self.outputs:
-            if (o not in current_outputs) and (o not in executed):
-                to_delete += [o]
-                if o in self.old_prompt:
-                    d = self.old_prompt.pop(o)
-                    del d
-        for o in to_delete:
-            d = self.outputs.pop(o)
-            del d

         if ex is not None and self.raise_exceptions:
             raise ex
@@ -433,69 +522,59 @@ class PromptExecutor:
         self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)

         with torch.inference_mode():
-            # delete cached outputs if nodes don't exist for them
-            to_delete = []
-            for o in self.outputs:
-                if o not in prompt:
-                    to_delete += [o]
-            for o in to_delete:
-                d = self.outputs.pop(o)
-                del d
-            to_delete = []
-            for o in self.object_storage:
-                if o[0] not in prompt:
-                    to_delete += [o]
-                else:
-                    p = prompt[o[0]]
-                    if o[1] != p['class_type']:
-                        to_delete += [o]
-            for o in to_delete:
-                d = self.object_storage.pop(o)
-                del d
+            dynamic_prompt = DynamicPrompt(prompt)
+            is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
+            for cache in self.caches.all:
+                cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
+                cache.clean_unused()

-            for x in prompt:
-                recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x)
-
-            current_outputs = set(self.outputs.keys())
-            for x in list(self.outputs_ui.keys()):
-                if x not in current_outputs:
-                    d = self.outputs_ui.pop(x)
-                    del d
+            current_outputs = self.caches.outputs.all_node_ids()

             model_management.cleanup_models(keep_clone_weights_loaded=True)
             self.add_message("execution_cached",
                              {"nodes": list(current_outputs), "prompt_id": prompt_id},
                              broadcast=False)
+            pending_subgraph_results = {}
             executed = set()
-            output_node_id = None
-            to_execute = []
-
+            execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
             for node_id in list(execute_outputs):
-                to_execute += [(0, node_id)]
+                execution_list.add_node(node_id)

-            while len(to_execute) > 0:
-                # always execute the output that depends on the least amount of unexecuted nodes first
-                memo = {}
-                to_execute = sorted(list(
-                    map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute)))
-                output_node_id = to_execute.pop(0)[-1]
-
-                # This call shouldn't raise anything if there's an error deep in
-                # the actual SD code, instead it will report the node where the
-                # error was raised
-                self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id,
-                                                            extra_data, executed, prompt_id, self.outputs_ui,
-                                                            self.object_storage)
-                if self.success is not True:
-                    self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
+            while not execution_list.is_empty():
+                node_id, error, ex = execution_list.stage_node_execution()
+                if error is not None:
+                    self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
                     break

-            for x in executed:
-                self.old_prompt[x] = copy.deepcopy(prompt[x])
+                result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
+                if result == ExecutionResult.FAILURE:
+                    self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
+                    break
+                elif result == ExecutionResult.PENDING:
+                    execution_list.unstage_node_execution()
+                else:  # result == ExecutionResult.SUCCESS:
+                    execution_list.complete_node_execution()
+
+            ui_outputs = {}
+            meta_outputs = {}
+            all_node_ids = self.caches.ui.all_node_ids()
+            for node_id in all_node_ids:
+                ui_info = self.caches.ui.get(node_id)
+                if ui_info is not None:
+                    ui_outputs[node_id] = ui_info["output"]
+                    meta_outputs[node_id] = ui_info["meta"]
+            self.history_result = {
+                "outputs": ui_outputs,
+                "meta": meta_outputs,
+            }
             self.server.last_node_id = None
             if model_management.DISABLE_SMART_MEMORY:
                 model_management.unload_all_models()

+    @property
+    def outputs_ui(self) -> dict | None:
+        return self.history_result["outputs"] if self.history_result is not None else None
+

 def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTuple]) -> ValidateInputsTuple:
     # todo: this should check if LoadImage / LoadImageMask paths exist
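The PENDING branch in the loop above is hit for two reasons shown in execute(): an expansion was queued, or check_lazy_status asked for more inputs. A sketch (not part of this commit) of a lazy node; the {"lazy": True} input option follows the upstream convention and is assumed here:

    class LazySwitch:
        RETURN_TYPES = ("IMAGE",)
        FUNCTION = "pick"

        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {
                "which": ("BOOLEAN",),
                "on_true": ("IMAGE", {"lazy": True}),
                "on_false": ("IMAGE", {"lazy": True}),
            }}

        def check_lazy_status(self, which, on_true, on_false):
            # Ask the executor to resolve only the branch that will be used; the
            # executor turns each returned name into a strong link and reports
            # PENDING, so the node is re-staged once that input is ready.
            return ["on_true" if which else "on_false"]

        def pick(self, which, on_true, on_false):
            return (on_true if which else on_false,)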
@@ -509,7 +588,7 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl
     obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]

     class_inputs = obj_class.INPUT_TYPES()
-    required_inputs = class_inputs['required']
+    valid_inputs = set(class_inputs.get('required', {})).union(set(class_inputs.get('optional', {})))

     error: ValidationErrorDict
     errors = []
@@ -522,23 +601,26 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl
     validate_function_inputs = []
     if hasattr(obj_class, "VALIDATE_INPUTS"):
         validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args
+    received_types = {}

-    for x in required_inputs:
+    for x in valid_inputs:
+        type_input, input_category, extra_info = get_input_info(obj_class, x)
+        assert extra_info is not None
         if x not in inputs:
-            error = {
-                "type": "required_input_missing",
-                "message": "Required input is missing",
-                "details": f"{x}",
-                "extra_info": {
-                    "input_name": x
-                }
-            }
-            errors.append(error)
+            if input_category == "required":
+                error = {
+                    "type": "required_input_missing",
+                    "message": "Required input is missing",
+                    "details": f"{x}",
+                    "extra_info": {
+                        "input_name": x
+                    }
+                }
+                errors.append(error)
             continue

         val = inputs[x]
-        info: InputTypeSpec = required_inputs[x]
-        type_input = info[0]
+        info: InputTypeSpec = (type_input, extra_info)
         if isinstance(val, list):
             if len(val) != 2:
                 error = {
@@ -557,11 +639,9 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl
             o_id = val[0]
             o_class_type = prompt[o_id]['class_type']
             r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
-            type_input_from_prompt = r[val[1]]
-            is_combo = all(isinstance(x, typing.List) or isinstance(x, typing.Tuple) for x in (type_input, type_input_from_prompt))
-            is_invalid_string_to_combo = is_combo and len(type_input_from_prompt) != 0
-            if type_input_from_prompt != type_input and is_invalid_string_to_combo:
-                received_type = type_input_from_prompt
+            received_type = r[val[1]]
+            received_types[x] = received_type
+            if 'input_types' not in validate_function_inputs and received_type != type_input:
                 details = f"{x}, {received_type} != {type_input}"
                 error = {
                     "type": "return_type_mismatch",
@@ -612,6 +692,9 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl
                 if type_input == "STRING":
                     val = str(val)
                     inputs[x] = val
+                if type_input == "BOOLEAN":
+                    val = bool(val)
+                    inputs[x] = val
             except Exception as ex:
                 error = {
                     "type": "invalid_input_type",
@@ -627,7 +710,7 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl
                 errors.append(error)
                 continue

-        if len(info) > 1:
+        if x not in validate_function_inputs:
             has_min_max: IntSpecOptions | FloatSpecOptions = info[1]
             if "min" in has_min_max and val < has_min_max["min"]:
                 error = {
@@ -656,7 +739,6 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl
                     errors.append(error)
                     continue

-        if x not in validate_function_inputs:
             if isinstance(type_input, list):
                 if val not in type_input:
                     input_config = info
@@ -684,17 +766,19 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl
             continue

     if len(validate_function_inputs) > 0:
-        input_data_all = get_input_data(inputs, obj_class, unique_id)
+        input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
         input_filtered = {}
         for x in input_data_all:
             if x in validate_function_inputs:
                 input_filtered[x] = input_data_all[x]
+        if 'input_types' in validate_function_inputs:
+            input_filtered['input_types'] = [received_types]

         # ret = obj_class.VALIDATE_INPUTS(**input_filtered)
         ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS")
         for x in input_filtered:
             for i, r in enumerate(ret):
-                if r is not True:
+                if r is not True and not isinstance(r, ExecutionBlocker):
                     details = f"{x}"
                     if r is not False:
                         details += f" - {str(r)}"
@@ -705,8 +789,6 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl
                         "details": details,
                         "extra_info": {
                             "input_name": x,
-                            "input_config": info,
-                            "received_value": val,
                         }
                     }
                     errors.append(error)
@@ -899,6 +981,7 @@ class PromptQueue(AbstractPromptQueue):

     def task_done(self, item_id: str, outputs: dict,
                   status: Optional[ExecutionStatus]):
+        history_result = outputs
         with self.mutex:
             queue_item = self.currently_running.pop(item_id)
             prompt = queue_item.queue_tuple
@@ -909,14 +992,16 @@ class PromptQueue(AbstractPromptQueue):
             if status is not None:
                 status_dict = copy.deepcopy(ExecutionStatus(*status)._asdict())

+            outputs_ = history_result["outputs"]
             self.history[prompt[1]] = {
                 "prompt": prompt,
-                "outputs": copy.deepcopy(outputs),
+                "outputs": copy.deepcopy(outputs_),
                 'status': status_dict,
             }
+            self.history[prompt[1]].update(history_result)
             self.server.queue_updated()
             if queue_item.completed:
-                queue_item.completed.set_result(outputs)
+                queue_item.completed.set_result(outputs_)

     def get_current_queue(self) -> Tuple[typing.List[QueueTuple], typing.List[QueueTuple]]:
         with self.mutex:
@@ -45,7 +45,7 @@ def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
         e.execute(item[2], prompt_id, item[3], item[4])
         need_gc = True
         q.task_done(item_id,
-                    e.outputs_ui,
+                    e.history_result,
                     status=ExecutionStatus(
                         status_str='success' if e.success else 'error',
                         completed=e.success,
@@ -437,6 +437,7 @@ class PromptServer(ExecutorToClientProgress):
         obj_class = self.nodes.NODE_CLASS_MAPPINGS[node_class]
         info = {}
         info['input'] = obj_class.INPUT_TYPES()
+        info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
         info['output'] = obj_class.RETURN_TYPES
         info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
         info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']
@@ -829,6 +830,9 @@ class PromptServer(ExecutorToClientProgress):
         site = web.TCPSite(runner, host=address, port=port)
         await site.start()

+        self.address = address
+        self.port = port
+
         if verbose:
             logging.info("Starting server\n")
             logging.info("To see the GUI go to: http://{}:{}".format("localhost" if address == "0.0.0.0" else address, port))
@@ -1,7 +1,8 @@
 from __future__ import annotations  # for Python 3.7-3.9

 import typing
-from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple, List
+from enum import Enum
+from typing import Optional, Literal, Protocol, Union, NamedTuple, List

 import PIL.Image
 from typing_extensions import NotRequired, TypedDict
@@ -45,11 +46,23 @@ class UnencodedPreviewImageMessage(NamedTuple):
     max_size: int = 512


-ExecutedMessage: TypeAlias = ExecutingMessage
+class ExecutionErrorMessage(TypedDict):
+    prompt_id: str
+    node_id: str
+    node_type: str
+    executed: list[str]
+    exception_message: str
+    exception_type: str
+    traceback: list[str]
+    current_inputs: list[typing.Never] | dict[str, FormattedValue]
+    current_outputs: list[str]

-SendSyncEvent: TypeAlias = Union[Literal["status", "executing", "progress", "executed"], BinaryEventTypes, None]

-SendSyncData: TypeAlias = Union[StatusMessage, ExecutingMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]
+ExecutedMessage = ExecutingMessage
+
+SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed"], BinaryEventTypes, None]
+
+SendSyncData = Union[StatusMessage, ExecutingMessage, ExecutionErrorMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]


 class ExecutorToClientProgress(Protocol):
@@ -150,3 +163,33 @@ class RecursiveExecutionTuple(typing.NamedTuple):
     valid: bool
     error_details: Optional[RecursiveExecutionErrorDetails | RecursiveExecutionErrorDetailsInterrupted]
     exc_info: Optional[Exception]
+
+
+class ExecutionResult(Enum):
+    SUCCESS = 0
+    FAILURE = 1
+    PENDING = 2
+
+    def __bool__(self):
+        # Compare against the member: a plain Enum member never equals the int 0,
+        # so `self == 0` would make every result falsy, including SUCCESS.
+        return self == ExecutionResult.SUCCESS
+
+
+class DuplicateNodeError(Exception):
+    pass
+
+
+class HistoryResultDict(TypedDict, total=True):
+    outputs: dict
+    meta: dict
+
+
+class DependencyCycleError(Exception):
+    pass
+
+
+class NodeInputError(Exception):
+    pass
+
+
+class NodeNotFoundError(Exception):
+    pass
@@ -21,7 +21,7 @@ from ..auth.permissions import jwt_decode
 from ..cmd.main_pre import tracer
 from ..cmd.server import PromptServer
 from ..component_model.abstract_prompt_queue import AsyncAbstractPromptQueue, AbstractPromptQueue
-from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
+from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData, HistoryResultDict
 from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation, \
     ExecutionError
@@ -164,6 +164,9 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):

     def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus]):
         # callee: executed on the worker thread
+        if "outputs" in outputs:
+            outputs: HistoryResultDict
+            outputs = outputs["outputs"]
         assert self._is_callee
         pending = self._callee_local_in_progress.pop(item_id)
         assert pending is not None
234
comfy/graph.py
Normal file
@ -0,0 +1,234 @@
from .cmd.execution import nodes
from .component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError
from .graph_utils import is_link


class DynamicPrompt:
    def __init__(self, original_prompt):
        # The original prompt provided by the user
        self.original_prompt = original_prompt
        # Any extra pieces of the graph created during execution
        self.ephemeral_prompt = {}
        self.ephemeral_parents = {}
        self.ephemeral_display = {}

    def get_node(self, node_id):
        if node_id in self.ephemeral_prompt:
            return self.ephemeral_prompt[node_id]
        if node_id in self.original_prompt:
            return self.original_prompt[node_id]
        raise NodeNotFoundError(f"Node {node_id} not found")

    def has_node(self, node_id):
        return node_id in self.original_prompt or node_id in self.ephemeral_prompt

    def add_ephemeral_node(self, node_id, node_info, parent_id, display_id):
        self.ephemeral_prompt[node_id] = node_info
        self.ephemeral_parents[node_id] = parent_id
        self.ephemeral_display[node_id] = display_id

    def get_real_node_id(self, node_id):
        while node_id in self.ephemeral_parents:
            node_id = self.ephemeral_parents[node_id]
        return node_id

    def get_parent_node_id(self, node_id):
        return self.ephemeral_parents.get(node_id, None)

    def get_display_node_id(self, node_id):
        while node_id in self.ephemeral_display:
            node_id = self.ephemeral_display[node_id]
        return node_id

    def all_node_ids(self):
        return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))

    def get_original_prompt(self):
        return self.original_prompt


def get_input_info(class_def, input_name):
    valid_inputs = class_def.INPUT_TYPES()
    input_info = None
    input_category = None
    if "required" in valid_inputs and input_name in valid_inputs["required"]:
        input_category = "required"
        input_info = valid_inputs["required"][input_name]
    elif "optional" in valid_inputs and input_name in valid_inputs["optional"]:
        input_category = "optional"
        input_info = valid_inputs["optional"][input_name]
    elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]:
        input_category = "hidden"
        input_info = valid_inputs["hidden"][input_name]
    if input_info is None:
        return None, None, None
    input_type = input_info[0]
    if len(input_info) > 1:
        extra_info = input_info[1]
    else:
        extra_info = {}
    return input_type, input_category, extra_info


class TopologicalSort:
    def __init__(self, dynprompt):
        self.dynprompt = dynprompt
        self.pendingNodes = {}
        self.blockCount = {}  # Number of nodes this node is directly blocked by
        self.blocking = {}  # Which nodes are blocked by this node

    def get_input_info(self, unique_id, input_name):
        class_type = self.dynprompt.get_node(unique_id)["class_type"]
        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
        return get_input_info(class_def, input_name)

    def make_input_strong_link(self, to_node_id, to_input):
        inputs = self.dynprompt.get_node(to_node_id)["inputs"]
        if to_input not in inputs:
            raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all")
        value = inputs[to_input]
        if not is_link(value):
            raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant")
        from_node_id, from_socket = value
        self.add_strong_link(from_node_id, from_socket, to_node_id)

    def add_strong_link(self, from_node_id, from_socket, to_node_id):
        self.add_node(from_node_id)
        if to_node_id not in self.blocking[from_node_id]:
            self.blocking[from_node_id][to_node_id] = {}
            self.blockCount[to_node_id] += 1
        self.blocking[from_node_id][to_node_id][from_socket] = True

    def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None):
        if unique_id in self.pendingNodes:
            return
        self.pendingNodes[unique_id] = True
        self.blockCount[unique_id] = 0
        self.blocking[unique_id] = {}

        inputs = self.dynprompt.get_node(unique_id)["inputs"]
        for input_name in inputs:
            value = inputs[input_name]
            if is_link(value):
                from_node_id, from_socket = value
                if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
                    continue
                input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
                is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
                if include_lazy or not is_lazy:
                    self.add_strong_link(from_node_id, from_socket, unique_id)

    def get_ready_nodes(self):
        return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]

    def pop_node(self, unique_id):
        del self.pendingNodes[unique_id]
        for blocked_node_id in self.blocking[unique_id]:
            self.blockCount[blocked_node_id] -= 1
        del self.blocking[unique_id]

    def is_empty(self):
        return len(self.pendingNodes) == 0


class ExecutionList(TopologicalSort):
    """
    ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
    it can still be returned to the graph after having further dependencies added.
    """

    def __init__(self, dynprompt, output_cache):
        super().__init__(dynprompt)
        self.output_cache = output_cache
        self.staged_node_id = None

    def add_strong_link(self, from_node_id, from_socket, to_node_id):
        if self.output_cache.get(from_node_id) is not None:
            # Nothing to do
            return
        super().add_strong_link(from_node_id, from_socket, to_node_id)

    def stage_node_execution(self):
        assert self.staged_node_id is None
        if self.is_empty():
            return None, None, None
        available = self.get_ready_nodes()
        if len(available) == 0:
            cycled_nodes = self.get_nodes_in_cycle()
            # Because cycles composed entirely of static nodes are caught during initial validation,
            # we will 'blame' the first node in the cycle that is not a static node.
            blamed_node = cycled_nodes[0]
            for node_id in cycled_nodes:
                display_node_id = self.dynprompt.get_display_node_id(node_id)
                if display_node_id != node_id:
                    blamed_node = display_node_id
                    break
            ex = DependencyCycleError("Dependency cycle detected")
            error_details = {
                "node_id": blamed_node,
                "exception_message": str(ex),
                "exception_type": "graph.DependencyCycleError",
                "traceback": [],
                "current_inputs": []
            }
            return None, error_details, ex
        next_node = available[0]
        # If an output node is available, do that first.
        # Technically this has no effect on the overall length of execution, but it feels better as a user
        # for a PreviewImage to display a result as soon as it can
        # Some other heuristics could probably be used here to improve the UX further.
        for node_id in available:
            class_type = self.dynprompt.get_node(node_id)["class_type"]
            class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
            if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
                next_node = node_id
                break
        self.staged_node_id = next_node
        return self.staged_node_id, None, None

    def unstage_node_execution(self):
        assert self.staged_node_id is not None
        self.staged_node_id = None

    def complete_node_execution(self):
        node_id = self.staged_node_id
        self.pop_node(node_id)
        self.staged_node_id = None

    def get_nodes_in_cycle(self):
        # We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle.
        # We're skipping some of the performance optimizations from the original TopologicalSort to keep
        # the code simple (and because having a cycle in the first place is a catastrophic error)
        blocked_by = {node_id: {} for node_id in self.pendingNodes}
        for from_node_id in self.blocking:
            for to_node_id in self.blocking[from_node_id]:
                if True in self.blocking[from_node_id][to_node_id].values():
                    blocked_by[to_node_id][from_node_id] = True
        to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
        while len(to_remove) > 0:
            for node_id in to_remove:
                for to_node_id in blocked_by:
                    if node_id in blocked_by[to_node_id]:
                        del blocked_by[to_node_id][node_id]
                del blocked_by[node_id]
            to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0]
        return list(blocked_by.keys())


class ExecutionBlocker:
    """
    Return this from a node and any users will be blocked with the given error message.
    If the message is None, execution will be blocked silently instead.
    Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's
    possible, a lazy input will be more efficient and have a better user experience.
    This functionality is useful in two cases:
    1. You want to conditionally prevent an output node from executing. (Particularly a built-in node
       like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using
       lazy evaluation to let it conditionally disable itself.)
    2. You have a node with multiple possible outputs, some of which are invalid and should not be used.
       (I would recommend not making nodes like this in the future -- instead, make multiple nodes with
       different outputs. Unfortunately, there are several popular existing nodes using this pattern.)
    """

    def __init__(self, message):
        self.message = message
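As a minimal sketch of how a custom node might return `ExecutionBlocker` (the node name and wiring here are hypothetical, not part of this diff):

```python
from comfy.graph import ExecutionBlocker

class GateImage:
    """Hypothetical node: passes the image through, or blocks downstream nodes."""
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"image": ("IMAGE",), "enabled": ("BOOLEAN", {"default": True})}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "gate"
    CATEGORY = "Testing"

    def gate(self, image, enabled):
        if not enabled:
            # Downstream users of this output are silently skipped (message=None).
            return (ExecutionBlocker(None),)
        return (image,)
```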
143
comfy/graph_utils.py
Normal file
@ -0,0 +1,143 @@
def is_link(obj):
    if not isinstance(obj, list):
        return False
    if len(obj) != 2:
        return False
    if not isinstance(obj[0], str):
        return False
    if not isinstance(obj[1], int) and not isinstance(obj[1], float):
        return False
    return True


class GraphBuilder:
    """
    The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end
    """
    _default_prefix_root = ""
    _default_prefix_call_index = 0
    _default_prefix_graph_index = 0

    def __init__(self, prefix=None):
        if prefix is None:
            self.prefix = GraphBuilder.alloc_prefix()
        else:
            self.prefix = prefix
        self.nodes = {}
        self.id_gen = 1

    @classmethod
    def set_default_prefix(cls, prefix_root, call_index, graph_index=0):
        cls._default_prefix_root = prefix_root
        cls._default_prefix_call_index = call_index
        cls._default_prefix_graph_index = graph_index

    @classmethod
    def alloc_prefix(cls, root=None, call_index=None, graph_index=None):
        if root is None:
            root = GraphBuilder._default_prefix_root
        if call_index is None:
            call_index = GraphBuilder._default_prefix_call_index
        if graph_index is None:
            graph_index = GraphBuilder._default_prefix_graph_index
        result = f"{root}.{call_index}.{graph_index}."
        GraphBuilder._default_prefix_graph_index += 1
        return result

    def node(self, class_type, id=None, **kwargs):
        if id is None:
            id = str(self.id_gen)
            self.id_gen += 1
        id = self.prefix + id
        if id in self.nodes:
            return self.nodes[id]

        node = Node(id, class_type, kwargs)
        self.nodes[id] = node
        return node

    def lookup_node(self, id):
        id = self.prefix + id
        return self.nodes.get(id)

    def finalize(self):
        output = {}
        for node_id, node in self.nodes.items():
            output[node_id] = node.serialize()
        return output

    def replace_node_output(self, node_id, index, new_value):
        node_id = self.prefix + node_id
        to_remove = []
        for node in self.nodes.values():
            for key, value in node.inputs.items():
                if is_link(value) and value[0] == node_id and value[1] == index:
                    if new_value is None:
                        to_remove.append((node, key))
                    else:
                        node.inputs[key] = new_value
        for node, key in to_remove:
            del node.inputs[key]

    def remove_node(self, id):
        id = self.prefix + id
        del self.nodes[id]


class Node:
    def __init__(self, id, class_type, inputs):
        self.id = id
        self.class_type = class_type
        self.inputs = inputs
        self.override_display_id = None

    def out(self, index):
        return [self.id, index]

    def set_input(self, key, value):
        if value is None:
            if key in self.inputs:
                del self.inputs[key]
        else:
            self.inputs[key] = value

    def get_input(self, key):
        return self.inputs.get(key)

    def set_override_display_id(self, override_display_id):
        self.override_display_id = override_display_id

    def serialize(self):
        serialized = {
            "class_type": self.class_type,
            "inputs": self.inputs
        }
        if self.override_display_id is not None:
            serialized["override_display_id"] = self.override_display_id
        return serialized


def add_graph_prefix(graph, outputs, prefix):
    # Change the node IDs and any internal links
    new_graph = {}
    for node_id, node_info in graph.items():
        # Make sure the added nodes have unique IDs
        new_node_id = prefix + node_id
        new_node = {"class_type": node_info["class_type"], "inputs": {}}
        for input_name, input_value in node_info.get("inputs", {}).items():
            if is_link(input_value):
                new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]]
            else:
                new_node["inputs"][input_name] = input_value
        new_graph[new_node_id] = new_node

    # Change the node IDs in the outputs
    new_outputs = []
    for n in range(len(outputs)):
        output = outputs[n]
        if is_link(output):
            new_outputs.append([prefix + output[0], output[1]])
        else:
            new_outputs.append(output)

    return new_graph, tuple(new_outputs)
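A minimal sketch of building a graph with this API; the node class names and input values are illustrative:

```python
from comfy.graph_utils import GraphBuilder

# Build a tiny two-node graph in the prompt format the back-end expects.
g = GraphBuilder(prefix="example.0.0.")
load = g.node("LoadImage", image="example.png")
save = g.node("SaveImage", images=load.out(0))  # out(0) produces a ["node_id", 0] link

prompt = g.finalize()  # {"example.0.0.1": {"class_type": "LoadImage", "inputs": {...}}, ...}
```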
@ -15,6 +15,7 @@ class IntSpecOptions(TypedDict, total=True):
    max: int
    step: NotRequired[int]
    display: NotRequired[Literal["number", "slider"]]
    lazy: NotRequired[bool]


class FloatSpecOptions(TypedDict, total=True):
@ -24,20 +25,24 @@ class FloatSpecOptions(TypedDict, total=True):
    step: NotRequired[float]
    round: NotRequired[float]
    display: NotRequired[Literal["number", "slider"]]
    lazy: NotRequired[bool]


class StringSpecOptions(TypedDict, total=True):
    multiline: NotRequired[bool]
    default: NotRequired[str]
    dynamicPrompts: NotRequired[bool]
    lazy: NotRequired[bool]


class BoolSpecOptions(TypedDict):
    default: NotRequired[bool]
    lazy: NotRequired[bool]


class DefaultSpecOptions(TypedDict):
    default: NotRequired[Any]
    lazy: NotRequired[bool]


# todo: analyze the base_nodes for these types
@ -129,6 +134,20 @@ class CustomNode(Protocol):
    def __call__(cls, *args, **kwargs) -> 'CustomNode':
        ...

    def check_lazy_status(self, *args, **kwargs) -> list[str]:
        """
        Return a list of input names that need to be evaluated.

        This function will be called if there are any lazy inputs which have not yet been
        evaluated. As long as you return at least one field which has not yet been evaluated
        (and more exist), this function will be called again once the value of the requested
        field is available.

        Any evaluated inputs will be passed as arguments to this function. Any unevaluated
        inputs will have the value None.
        """
        ...


@dataclass
class ExportedNodes:
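A minimal sketch of a node implementing this protocol with two lazy inputs (the node is hypothetical; the `Example` node diff further below shows the same pattern in the repository's own example file):

```python
class PickImage:
    """Hypothetical node: returns one of two lazily-evaluated images."""
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "use_first": ("BOOLEAN", {"default": True}),
                "image1": ("IMAGE", {"lazy": True}),
                "image2": ("IMAGE", {"lazy": True}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "pick"

    def check_lazy_status(self, use_first, image1, image2):
        # Only request evaluation of the branch that will actually be used;
        # unevaluated lazy inputs arrive as None.
        chosen = image1 if use_first else image2
        return ["image1" if use_first else "image2"] if chosen is None else []

    def pick(self, use_first, image1, image2):
        return (image1 if use_first else image2,)
```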
@ -959,8 +959,8 @@ export class GroupNodeHandler {
        const executed = handleEvent.call(
            this,
            "executed",
            (d) => d?.node,
            (d, id, node) => ({ ...d, node: id, merge: !node.resetExecution })
            (d) => d?.display_node,
            (d, id, node) => ({ ...d, node: id, display_node: id, merge: !node.resetExecution })
        );

        const onRemoved = node.onRemoved;

@ -3,7 +3,7 @@ import { app } from "../../scripts/app.js";
import { applyTextReplacements } from "../../scripts/utils.js";

const CONVERTED_TYPE = "converted-widget";
const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"];
const VALID_TYPES = ["STRING", "combo", "number", "toggle", "BOOLEAN"];
const CONFIG = Symbol();
const GET_CONFIG = Symbol();
const TARGET = Symbol(); // Used for reroutes to specify the real target widget

@ -128,7 +128,7 @@ class ComfyApi extends EventTarget {
                    this.dispatchEvent(new CustomEvent("progress", { detail: msg.data }));
                    break;
                case "executing":
                    this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node }));
                    this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.display_node }));
                    break;
                case "executed":
                    this.dispatchEvent(new CustomEvent("executed", { detail: msg.data }));

@ -58,7 +58,7 @@ export class ComfyApp {
        this.bodyTop = $el("div.comfyui-body-top", { parent: document.body });
        this.bodyLeft = $el("div.comfyui-body-left", { parent: document.body });
        this.bodyRight = $el("div.comfyui-body-right", { parent: document.body });
        this.bodyBottom = $el("div.comfyui-body-bottom", { parent: document.body });
        this.bodyBottom = $el("div.comfyui-body-bottom", { parent: document.body });
        this.menu = new ComfyAppMenu(this);

    /**

@ -1330,7 +1330,7 @@ export class ComfyApp {
        });

        api.addEventListener("progress", ({ detail }) => {
            if (this.workflowManager.activePrompt?.workflow
            if (this.workflowManager.activePrompt?.workflow
                && this.workflowManager.activePrompt.workflow !== this.workflowManager.activeWorkflow) return;
            this.progress = detail;
            this.graph.setDirtyCanvas(true, false);

@ -1348,7 +1348,7 @@ export class ComfyApp {
        api.addEventListener("executed", ({ detail }) => {
            if (this.workflowManager.activePrompt?.workflow
                && this.workflowManager.activePrompt.workflow !== this.workflowManager.activeWorkflow) return;
            const output = this.nodeOutputs[detail.node];
            const output = this.nodeOutputs[detail.display_node];
            if (detail.merge && output) {
                for (const k in detail.output ?? {}) {
                    const v = output[k];

@ -1359,9 +1359,9 @@ export class ComfyApp {
                    }
                }
            } else {
                this.nodeOutputs[detail.node] = detail.output;
                this.nodeOutputs[detail.display_node] = detail.output;
            }
            const node = this.graph.getNodeById(detail.node);
            const node = this.graph.getNodeById(detail.display_node);
            if (node) {
                if (node.onExecuted)
                    node.onExecuted(detail.output);

@ -1638,7 +1638,7 @@ export class ComfyApp {
    resizeCanvas() {
        // Limit minimal scale to 1, see https://github.com/comfyanonymous/ComfyUI/pull/845
        const scale = Math.max(window.devicePixelRatio, 1);

        // Clear fixed width and height while calculating rect so it uses 100% instead
        this.canvasEl.height = this.canvasEl.width = "";
        const { width, height } = this.canvasEl.getBoundingClientRect();

@ -1874,7 +1874,7 @@ export class ComfyApp {
        {
            graphData = structuredClone(graphData);
        }

        try {
            this.workflowManager.setWorkflow(workflow);
        } catch (error) {

@ -1902,7 +1902,7 @@ export class ComfyApp {
            this.canvas.ds.offset = graphData.extra.ds.offset;
            this.canvas.ds.scale = graphData.extra.ds.scale;
        }

        try {
            this.workflowManager.activeWorkflow?.track()
        } catch (error) {

@ -239,7 +239,14 @@ class ComfyList {
                onclick: async () => {
                    await app.loadGraphData(item.prompt[3].extra_pnginfo.workflow, true, false);
                    if (item.outputs) {
                        app.nodeOutputs = item.outputs;
                        app.nodeOutputs = {};
                        for (const [key, value] of Object.entries(item.outputs)) {
                            if (item.meta && item.meta[key] && item.meta[key].display_node) {
                                app.nodeOutputs[item.meta[key].display_node] = value;
                            } else {
                                app.nodeOutputs[key] = value;
                            }
                        }
                    }
                },
            }),
185
comfy_extras/nodes/nodes_color.py
Normal file
@ -0,0 +1,185 @@
import colour
import numpy as np
import torch
from scipy.interpolate import interpn
from scipy.interpolate.interpnd import LinearNDInterpolator
from scipy.optimize import linear_sum_assignment
from scipy.ndimage import gaussian_filter
from sklearn.cluster import KMeans

import comfy.model_management
from comfy.nodes.package_typing import CustomNode


class ColorPaletteExtractor(CustomNode):
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "num_colors": ("INT", {"default": 8, "min": 2, "max": 64, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE", "PALETTE")
    FUNCTION = "extract_palette"
    CATEGORY = "image/color"

    def extract_palette(self, image, num_colors):
        if len(image.shape) == 3:
            image = image.unsqueeze(0)

        image_np = image.squeeze().cpu().numpy()

        pixels = image_np.reshape(-1, 3)

        kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10)
        kmeans.fit(pixels)

        colors = kmeans.cluster_centers_

        _, counts = np.unique(kmeans.labels_, return_counts=True)
        sorted_indices = np.argsort(counts)[::-1]
        colors = colors[sorted_indices]

        palette_height = 512
        palette_width = 512
        palette_image = np.zeros((palette_height, palette_width * num_colors, 3), dtype=np.uint8)
        for i, color in enumerate(colors):
            # cluster centers are floats in [0, 1]; scale to 0-255 before writing into the uint8 canvas
            palette_image[:, i * palette_width:(i + 1) * palette_width] = (color * 255).astype(np.uint8)

        palette_tensor = torch.from_numpy(palette_image).float() / 255.0
        palette_tensor = palette_tensor.unsqueeze(0).to(comfy.model_management.get_torch_device())

        color_array = (colors * 255).astype(np.uint8)

        return palette_tensor, color_array


class ImageBasedColorRemap(CustomNode):
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "reference_image": ("IMAGE",),
                "palette_size": ("INT", {"default": 8, "min": 2, "max": 64, "step": 1}),
                "lut_size": ("INT", {"default": 33, "min": 8, "max": 64, "step": 1}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "remap_colors"
    CATEGORY = "image/color"

    def remap_colors(self, image, reference_image, palette_size, lut_size):
        # Ensure images are in the correct format (B, H, W, C)
        image = image.squeeze(0) if image.ndim == 4 else image
        reference_image = reference_image.squeeze(0) if reference_image.ndim == 4 else reference_image

        # Convert torch tensors to numpy arrays
        image_np = image.cpu().numpy()
        reference_np = reference_image.cpu().numpy()

        # Extract palettes using k-means clustering
        image_palette = self.extract_palette(image_np, palette_size)
        reference_palette = self.extract_palette(reference_np, palette_size)

        # Convert palettes to LAB color space
        image_palette_lab = colour.XYZ_to_Lab(colour.sRGB_to_XYZ(image_palette))
        reference_palette_lab = colour.XYZ_to_Lab(colour.sRGB_to_XYZ(reference_palette))

        # Align palettes
        aligned_image_palette_lab, aligned_reference_palette_lab = self.align_palettes(image_palette_lab, reference_palette_lab)

        # Create 3D LUT in LAB space
        lut = self.create_color_remap_3dlut(aligned_image_palette_lab, aligned_reference_palette_lab, size=lut_size)

        # Apply 3D LUT
        image_lab = colour.XYZ_to_Lab(colour.sRGB_to_XYZ(image_np))
        remapped_lab = self.apply_lut(image_lab, lut, lut_size)
        remapped_rgb = colour.XYZ_to_sRGB(colour.Lab_to_XYZ(remapped_lab))

        # Clip values to ensure they're in the valid range
        remapped_rgb = np.clip(remapped_rgb, 0, 1)

        # Convert back to torch tensor
        remapped_tensor = torch.from_numpy(remapped_rgb).float()
        remapped_tensor = remapped_tensor.unsqueeze(0)  # Add batch dimension

        return (remapped_tensor,)

    def extract_palette(self, image, palette_size):
        pixels = image.reshape(-1, 3)
        kmeans = KMeans(n_clusters=palette_size, random_state=42, n_init=10)
        kmeans.fit(pixels)
        return kmeans.cluster_centers_

    def align_palettes(self, palette1, palette2):
        distances = np.linalg.norm(palette1[:, np.newaxis] - palette2, axis=2)
        row_ind, col_ind = linear_sum_assignment(distances)
        return palette1, palette2[col_ind]

    def create_color_remap_3dlut(self, original_colors, target_colors, size=33):
        # Create a regular grid in LAB space
        L = np.linspace(0, 100, size)
        a = np.linspace(-128, 127, size)
        b = np.linspace(-128, 127, size)
        grid = np.meshgrid(L, a, b, indexing='ij')

        # Reshape the grid for KNN
        grid_points = np.vstack([g.ravel() for g in grid]).T

        # Use KNN to find the nearest neighbor for each grid point
        from sklearn.neighbors import NearestNeighbors
        nn = NearestNeighbors(n_neighbors=1, algorithm='auto').fit(original_colors)
        _, indices = nn.kneighbors(grid_points)

        # Create the LUT using the target colors
        lut = target_colors[indices.ravel()]
        lut = lut.reshape(size, size, size, 3)

        # Apply Gaussian smoothing to create a more continuous mapping
        for i in range(3):
            lut[:, :, :, i] = gaussian_filter(lut[:, :, :, i], sigma=1)

        return lut

    def apply_lut(self, image_lab, lut, lut_size):
        points = (
            np.linspace(0, 100, lut_size),  # L
            np.linspace(-128, 127, lut_size),  # a
            np.linspace(-128, 127, lut_size)  # b
        )

        xi = np.stack([
            image_lab[..., 0],
            image_lab[..., 1],
            image_lab[..., 2]
        ], axis=-1)

        remapped_lab = np.zeros_like(image_lab)
        for i in range(3):
            remapped_lab[..., i] = interpn(
                points,
                lut[..., i],
                xi,
                method='linear',
                bounds_error=False,
                fill_value=None
            )

        return remapped_lab


NODE_CLASS_MAPPINGS = {
    "ColorPaletteExtractor": ColorPaletteExtractor,
    "ImageBasedColorRemap": ImageBasedColorRemap,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ColorPaletteExtractor": "Extract Color Palette",
    "ImageBasedColorRemap": "Image-Based Color Remap",
}
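A minimal sketch of wiring these two nodes with the GraphBuilder introduced earlier in this merge; `LoadImage` here stands in for any image source and the file names are illustrative:

```python
from comfy.graph_utils import GraphBuilder

g = GraphBuilder(prefix="color_demo.0.0.")
src = g.node("LoadImage", image="source.png")
ref = g.node("LoadImage", image="reference.png")
# Remap the source image's palette toward the reference image's palette.
remap = g.node("ImageBasedColorRemap", image=src.out(0), reference_image=ref.out(0),
               palette_size=8, lut_size=33)
g.node("SaveImage", images=remap.out(0))
prompt = g.finalize()
```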
@ -54,7 +54,8 @@ class Example:
                    "min": 0, #Minimum value
                    "max": 4096, #Maximum value
                    "step": 64, #Slider's step
                    "display": "number" # Cosmetic only: display as "number" or "slider"
                    "display": "number", # Cosmetic only: display as "number" or "slider"
                    "lazy": True # Will only be evaluated if check_lazy_status requires it
                }),
                "float_field": ("FLOAT", {
                    "default": 1.0,
@ -62,11 +63,14 @@ class Example:
                    "max": 10.0,
                    "step": 0.01,
                    "round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding.
                    "display": "number"}),
                    "display": "number",
                    "lazy": True
                }),
                "print_to_screen": (["enable", "disable"],),
                "string_field": ("STRING", {
                    "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node
                    "default": "Hello World!"
                    "default": "Hello World!",
                    "lazy": True
                }),
            },
        }
@ -80,6 +84,23 @@ class Example:

    CATEGORY = "Example"

    def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen):
        """
        Return a list of input names that need to be evaluated.

        This function will be called if there are any lazy inputs which have not yet been
        evaluated. As long as you return at least one field which has not yet been evaluated
        (and more exist), this function will be called again once the value of the requested
        field is available.

        Any evaluated inputs will be passed as arguments to this function. Any unevaluated
        inputs will have the value None.
        """
        if print_to_screen == "enable":
            return ["int_field", "float_field", "string_field"]
        else:
            return []

    def test(self, image, string_field, int_field, float_field, print_to_screen):
        if print_to_screen == "enable":
            print(f"""Your input contains:
@ -1,6 +1,7 @@
[pytest]
markers =
    inference: mark as inference test (deselect with '-m "not inference"')
    execution: mark as execution test (deselect with '-m "not execution"')
testpaths =
    tests
    tests-unit
@ -443,6 +443,7 @@ describe("group node", () => {
        new CustomEvent("executed", {
            detail: {
                node: `${nodes.save.id}`,
                display_node: `${nodes.save.id}`,
                output: {
                    images: [
                        {

@ -483,6 +484,7 @@ describe("group node", () => {
        new CustomEvent("executed", {
            detail: {
                node: `${group.id}:5`,
                display_node: `${group.id}:5`,
                output: {
                    images: [
                        {
0
tests/asyncio/__init__.py
Normal file
@ -45,6 +45,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text


@pytest.mark.asyncio
@pytest.mark.skip("flakey")
def test_known_repos(tmp_path_factory):
    prev_hub_cache = os.getenv("HF_HUB_CACHE")
    os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache"))
374
tests/inference/test_execution.py
Normal file
@ -0,0 +1,374 @@
import uuid
from contextvars import ContextVar
from typing import Dict, Optional

import numpy
import pytest
from PIL import Image
from pytest import fixture

from comfy.cli_args import default_configuration
from comfy.client.embedded_comfy_client import EmbeddedComfyClient
from comfy.component_model.executor_types import SendSyncEvent, SendSyncData, ExecutingMessage, ExecutionErrorMessage, DependencyCycleError
from comfy.distributed.server_stub import ServerStub
from comfy.graph_utils import GraphBuilder, Node
from comfy.nodes.package_typing import ExportedNodes

current_test_name = ContextVar('current_test_name', default=None)


@pytest.fixture(autouse=True)
def set_test_name(request):
    token = current_test_name.set(request.node.name)
    yield
    current_test_name.reset(token)


class RunResult:
    def __init__(self, prompt_id: str):
        self.outputs: Dict[str, Dict] = {}
        self.runs: Dict[str, bool] = {}
        self.prompt_id: str = prompt_id

    def get_output(self, node: Node):
        return self.outputs.get(node.id, None)

    def did_run(self, node: Node):
        return self.runs.get(node.id, False)

    def get_images(self, node: Node):
        output = self.get_output(node)
        if output is None:
            return []
        return output.get('image_objects', [])

    def get_prompt_id(self):
        return self.prompt_id


class _ProgressHandler(ServerStub):
    def __init__(self):
        super().__init__()
        self.tuples: list[tuple[SendSyncEvent, SendSyncData, str]] = []

    def send_sync(self,
                  event: SendSyncEvent,
                  data: SendSyncData,
                  sid: Optional[str] = None):
        self.tuples.append((event, data, sid))


class Client:
    def __init__(self, embedded_client: EmbeddedComfyClient, progress_handler: _ProgressHandler):
        self.embedded_client = embedded_client
        self.progress_handler = progress_handler

    async def run(self, graph: GraphBuilder) -> RunResult:
        self.progress_handler.tuples = []
        for node in graph.nodes.values():
            if node.class_type == 'SaveImage':
                node.inputs['filename_prefix'] = current_test_name.get()

        prompt_id = str(uuid.uuid4())
        try:
            outputs = await self.embedded_client.queue_prompt(graph.finalize(), prompt_id=prompt_id)
        except (RuntimeError, DependencyCycleError):
            outputs = {}
        result = RunResult(prompt_id=prompt_id)
        result.outputs = outputs
        result.runs = {}
        send_sync_event: SendSyncEvent
        send_sync_data: SendSyncData
        for send_sync_event, send_sync_data, _ in self.progress_handler.tuples:
            if send_sync_event == "executing":
                send_sync_data: ExecutingMessage
                result.runs[send_sync_data["node"]] = True
            elif send_sync_event == "execution_error":
                send_sync_data: ExecutionErrorMessage
                raise Exception(send_sync_data)

        for node in outputs.values():
            if "images" in node:
                image_objects = node["image_objects"] = []
                for image in node["images"]:
                    image_objects.append(Image.open(image["abs_path"]))
        return result


# Loop through these variables
@pytest.mark.execution
class TestExecution:
    # Initialize server and client
    @fixture(scope="class", params=[
        # (lru_size)
        (0,),
        (100,),
    ])
    async def client(self, request) -> Client:
        from comfy.cmd.execution import nodes
        from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS

        nodes.update(ExportedNodes(NODE_CLASS_MAPPINGS=NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=NODE_DISPLAY_NAME_MAPPINGS))
        lru_size, = request.param
        configuration = default_configuration()
        configuration.cache_lru = lru_size
        progress_handler = _ProgressHandler()
        async with EmbeddedComfyClient(configuration, progress_handler=progress_handler) as embedded_client:
            yield Client(embedded_client, progress_handler)

    @fixture
    def builder(self, request):
        yield GraphBuilder(prefix=request.node.name)

    async def test_lazy_input(self, client: Client, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)

        lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
        output = g.node("SaveImage", images=lazy_mix.out(0))
        result = await client.run(g)

        result_image = result.get_images(output)[0]
        assert not numpy.array(result_image).any(), "Image should be black"
        assert result.did_run(input1)
        assert not result.did_run(input2)
        assert result.did_run(mask)
        assert result.did_run(lazy_mix)

    async def test_full_cache(self, client: Client, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)

        lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
        g.node("SaveImage", images=lazy_mix.out(0))

        await client.run(g)
        result2 = await client.run(g)
        for node_id, node in g.nodes.items():
            assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"

    async def test_partial_cache(self, client: Client, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)

        lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
        g.node("SaveImage", images=lazy_mix.out(0))

        await client.run(g)
        mask.inputs['value'] = 0.4
        result2 = await client.run(g)
        assert not result2.did_run(input1), "Input1 should have been cached"
        assert not result2.did_run(input2), "Input2 should have been cached"

    async def test_error(self, client: Client, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        # Different size of the two images
        input2 = g.node("StubImage", content="NOISE", height=256, width=256, batch_size=1)
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)

        lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
        g.node("SaveImage", images=lazy_mix.out(0))

        try:
            await client.run(g)
            assert False, "Should have raised an error"
        except Exception as e:
            assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"

    @pytest.mark.parametrize("test_value, expect_error", [
        (5, True),
        ("foo", True),
        (5.0, False),
    ])
    async def test_validation_error_literal(self, test_value, expect_error, client: Client, builder: GraphBuilder):
        g = builder
        validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0)
        g.node("SaveImage", images=validation1.out(0))

        if expect_error:
            with pytest.raises(ValueError):
                await client.run(g)
        else:
            await client.run(g)

    @pytest.mark.parametrize("test_type, test_value", [
        ("StubInt", 5),
        ("StubFloat", 5.0)
    ])
    async def test_validation_error_edge1(self, test_type, test_value, client: Client, builder: GraphBuilder):
        g = builder
        stub = g.node(test_type, value=test_value)
        validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0)
        g.node("SaveImage", images=validation1.out(0))

        with pytest.raises(ValueError):
            await client.run(g)

    @pytest.mark.parametrize("test_type, test_value, expect_error", [
        ("StubInt", 5, True),
        ("StubFloat", 5.0, False)
    ])
    async def test_validation_error_edge2(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder):
        g = builder
        stub = g.node(test_type, value=test_value)
        validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0)
        g.node("SaveImage", images=validation2.out(0))

        if expect_error:
            with pytest.raises(ValueError):
                await client.run(g)
        else:
            await client.run(g)

    @pytest.mark.parametrize("test_type, test_value, expect_error", [
        ("StubInt", 5, True),
        ("StubFloat", 5.0, False)
    ])
    async def test_validation_error_edge3(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder):
        g = builder
        stub = g.node(test_type, value=test_value)
        validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0)
        g.node("SaveImage", images=validation3.out(0))

        if expect_error:
            with pytest.raises(ValueError):
                await client.run(g)
        else:
            await client.run(g)

    @pytest.mark.parametrize("test_type, test_value, expect_error", [
        ("StubInt", 5, True),
        ("StubFloat", 5.0, False)
    ])
    async def test_validation_error_edge4(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder):
        g = builder
        stub = g.node(test_type, value=test_value)
        validation4 = g.node("TestCustomValidation4", input1=stub.out(0), input2=3.0)
        g.node("SaveImage", images=validation4.out(0))

        if expect_error:
            with pytest.raises(ValueError):
                await client.run(g)
        else:
            await client.run(g)

    async def test_cycle_error(self, client: Client, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)

        lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), mask=mask.out(0))
        lazy_mix2 = g.node("TestLazyMixImages", image1=lazy_mix1.out(0), image2=input2.out(0), mask=mask.out(0))
        g.node("SaveImage", images=lazy_mix2.out(0))

        # When the cycle exists on initial submission, it should raise a validation error
        with pytest.raises(ValueError):
            await client.run(g)

    async def test_dynamic_cycle_error(self, client: Client, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        generator = g.node("TestDynamicDependencyCycle", input1=input1.out(0), input2=input2.out(0))
        g.node("SaveImage", images=generator.out(0))

        # When the cycle is in a graph that is generated dynamically, it should raise a runtime error
        try:
            await client.run(g)
            assert False, "Should have raised an error"
        except Exception as e:
            assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
            assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node"

    async def test_custom_is_changed(self, client: Client, builder: GraphBuilder):
        g = builder
        # Creating the nodes in this specific order previously caused a bug
        save = g.node("SaveImage")
        is_changed = g.node("TestCustomIsChanged", should_change=False)
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        save.set_input('images', is_changed.out(0))
        is_changed.set_input('image', input1.out(0))

        result1 = await client.run(g)
        result2 = await client.run(g)
        is_changed.set_input('should_change', True)
        result3 = await client.run(g)
        result4 = await client.run(g)
        assert result1.did_run(is_changed), "is_changed should have been run"
        assert not result2.did_run(is_changed), "is_changed should have been cached"
        assert result3.did_run(is_changed), "is_changed should have been re-run"
        assert result4.did_run(is_changed), "is_changed should not have been cached"

    async def test_undeclared_inputs(self, client: Client, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        input3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input4 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        average = g.node("TestVariadicAverage", input1=input1.out(0), input2=input2.out(0), input3=input3.out(0), input4=input4.out(0))
        output = g.node("SaveImage", images=average.out(0))

        result = await client.run(g)
        result_image = result.get_images(output)[0]
        expected = 255 // 4
        assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"

    async def test_for_loop(self, client: Client, builder: GraphBuilder):
        g = builder
        iterations = 4
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        is_changed = g.node("TestCustomIsChanged", should_change=True, image=input2.out(0))
        for_open = g.node("TestForLoopOpen", remaining=iterations, initial_value1=is_changed.out(0))
        average = g.node("TestVariadicAverage", input1=input1.out(0), input2=for_open.out(2))
        for_close = g.node("TestForLoopClose", flow_control=for_open.out(0), initial_value1=average.out(0))
        output = g.node("SaveImage", images=for_close.out(0))

        for iterations in range(1, 5):
            for_open.set_input('remaining', iterations)
            result = await client.run(g)
            result_image = result.get_images(output)[0]
            expected = 255 // (2 ** iterations)
            assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
            assert result.did_run(is_changed)

    async def test_mixed_expansion_returns(self, client: Client, builder: GraphBuilder):
        g = builder
        val_list = g.node("TestMakeListNode", value1=0.1, value2=0.2, value3=0.3)
        mixed = g.node("TestMixedExpansionReturns", input1=val_list.out(0))
        output_dynamic = g.node("SaveImage", images=mixed.out(0))
        output_literal = g.node("SaveImage", images=mixed.out(1))

        result = await client.run(g)
        images_dynamic = result.get_images(output_dynamic)
        assert len(images_dynamic) == 3, "Should have 3 images"
        assert numpy.array(images_dynamic[0]).min() == 25 and numpy.array(images_dynamic[0]).max() == 25, "First image should be 0.1"
        assert numpy.array(images_dynamic[1]).min() == 51 and numpy.array(images_dynamic[1]).max() == 51, "Second image should be 0.2"
        assert numpy.array(images_dynamic[2]).min() == 76 and numpy.array(images_dynamic[2]).max() == 76, "Third image should be 0.3"

        images_literal = result.get_images(output_literal)
        assert len(images_literal) == 3, "Should have 3 images"
        for i in range(3):
            assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white"

    async def test_output_reuse(self, client: Client, builder: GraphBuilder):
        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        output1 = g.node("PreviewImage", images=input1.out(0))
        output2 = g.node("PreviewImage", images=input1.out(0))

        result = await client.run(g)
        images1 = result.get_images(output1)
        images2 = result.get_images(output2)
        assert len(images1) == 1, "Should have 1 image"
        assert len(images2) == 1, "Should have 1 image"
23
tests/inference/testing_nodes/testing-pack/__init__.py
Normal file
@ -0,0 +1,23 @@
from .specific_tests import TEST_NODE_CLASS_MAPPINGS, TEST_NODE_DISPLAY_NAME_MAPPINGS
from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS
from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS
from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS
from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS

# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS)
# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS)

NODE_CLASS_MAPPINGS = {}
NODE_CLASS_MAPPINGS.update(TEST_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS)
NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS)

NODE_DISPLAY_NAME_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS)
NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS)
194
tests/inference/testing_nodes/testing-pack/conditions.py
Normal file
@ -0,0 +1,194 @@
import re
import torch


class TestIntConditions:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
                "b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
                "operation": (["==", "!=", "<", ">", "<=", ">="],),
            },
        }

    RETURN_TYPES = ("BOOLEAN",)
    FUNCTION = "int_condition"

    CATEGORY = "Testing/Logic"

    def int_condition(self, a, b, operation):
        if operation == "==":
            return (a == b,)
        elif operation == "!=":
            return (a != b,)
        elif operation == "<":
            return (a < b,)
        elif operation == ">":
            return (a > b,)
        elif operation == "<=":
            return (a <= b,)
        elif operation == ">=":
            return (a >= b,)


class TestFloatConditions:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "a": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}),
                "b": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}),
                "operation": (["==", "!=", "<", ">", "<=", ">="],),
            },
        }

    RETURN_TYPES = ("BOOLEAN",)
    FUNCTION = "float_condition"

    CATEGORY = "Testing/Logic"

    def float_condition(self, a, b, operation):
        if operation == "==":
            return (a == b,)
        elif operation == "!=":
            return (a != b,)
        elif operation == "<":
            return (a < b,)
        elif operation == ">":
            return (a > b,)
        elif operation == "<=":
            return (a <= b,)
        elif operation == ">=":
            return (a >= b,)


class TestStringConditions:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "a": ("STRING", {"multiline": False}),
                "b": ("STRING", {"multiline": False}),
                "operation": (["a == b", "a != b", "a IN b", "a MATCH REGEX(b)", "a BEGINSWITH b", "a ENDSWITH b"],),
                "case_sensitive": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("BOOLEAN",)
    FUNCTION = "string_condition"

    CATEGORY = "Testing/Logic"

    def string_condition(self, a, b, operation, case_sensitive):
        if not case_sensitive:
            a = a.lower()
            b = b.lower()

        if operation == "a == b":
            return (a == b,)
        elif operation == "a != b":
            return (a != b,)
        elif operation == "a IN b":
            return (a in b,)
        elif operation == "a MATCH REGEX(b)":
            try:
                return (re.match(b, a) is not None,)
            except:
                return (False,)
        elif operation == "a BEGINSWITH b":
            return (a.startswith(b),)
        elif operation == "a ENDSWITH b":
            return (a.endswith(b),)


class TestToBoolNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("*",),
            },
            "optional": {
                "invert": ("BOOLEAN", {"default": False}),
            },
        }

    RETURN_TYPES = ("BOOLEAN",)
    FUNCTION = "to_bool"

    CATEGORY = "Testing/Logic"

    def to_bool(self, value, invert=False):
        if isinstance(value, torch.Tensor):
            if value.max().item() == 0 and value.min().item() == 0:
                result = False
            else:
                result = True
        else:
            try:
                result = bool(value)
            except:
                # Can't convert it? Well then it's something or other. I dunno, I'm not a Python programmer.
                result = True

        if invert:
            result = not result

        return (result,)


class TestBoolOperationNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "a": ("BOOLEAN",),
                "b": ("BOOLEAN",),
                "op": (["a AND b", "a OR b", "a XOR b", "NOT a"],),
            },
        }

    RETURN_TYPES = ("BOOLEAN",)
    FUNCTION = "bool_operation"

    CATEGORY = "Testing/Logic"

    def bool_operation(self, a, b, op):
        if op == "a AND b":
            return (a and b,)
        elif op == "a OR b":
            return (a or b,)
        elif op == "a XOR b":
            return (a ^ b,)
        elif op == "NOT a":
            return (not a,)


CONDITION_NODE_CLASS_MAPPINGS = {
    "TestIntConditions": TestIntConditions,
    "TestFloatConditions": TestFloatConditions,
    "TestStringConditions": TestStringConditions,
    "TestToBoolNode": TestToBoolNode,
    "TestBoolOperationNode": TestBoolOperationNode,
}

CONDITION_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestIntConditions": "Int Condition",
    "TestFloatConditions": "Float Condition",
    "TestStringConditions": "String Condition",
    "TestToBoolNode": "To Bool",
    "TestBoolOperationNode": "Bool Operation",
}
173
tests/inference/testing_nodes/testing-pack/flow_control.py
Normal file
@ -0,0 +1,173 @@
from comfy.graph_utils import GraphBuilder, is_link
from comfy.graph import ExecutionBlocker
from .tools import VariantSupport

NUM_FLOW_SOCKETS = 5

@VariantSupport()
class TestWhileLoopOpen:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        inputs = {
            "required": {
                "condition": ("BOOLEAN", {"default": True}),
            },
            "optional": {
            },
        }
        for i in range(NUM_FLOW_SOCKETS):
            inputs["optional"][f"initial_value{i}"] = ("*",)
        return inputs

    RETURN_TYPES = tuple(["FLOW_CONTROL"] + ["*"] * NUM_FLOW_SOCKETS)
    RETURN_NAMES = tuple(["FLOW_CONTROL"] + [f"value{i}" for i in range(NUM_FLOW_SOCKETS)])
    FUNCTION = "while_loop_open"

    CATEGORY = "Testing/Flow"

    def while_loop_open(self, condition, **kwargs):
        values = []
        for i in range(NUM_FLOW_SOCKETS):
            values.append(kwargs.get(f"initial_value{i}", None))
        return tuple(["stub"] + values)


@VariantSupport()
class TestWhileLoopClose:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        inputs = {
            "required": {
                "flow_control": ("FLOW_CONTROL", {"rawLink": True}),
                "condition": ("BOOLEAN", {"forceInput": True}),
            },
            "optional": {
            },
            "hidden": {
                "dynprompt": "DYNPROMPT",
                "unique_id": "UNIQUE_ID",
            }
        }
        for i in range(NUM_FLOW_SOCKETS):
            inputs["optional"][f"initial_value{i}"] = ("*",)
        return inputs

    RETURN_TYPES = tuple(["*"] * NUM_FLOW_SOCKETS)
    RETURN_NAMES = tuple([f"value{i}" for i in range(NUM_FLOW_SOCKETS)])
    FUNCTION = "while_loop_close"

    CATEGORY = "Testing/Flow"

    def explore_dependencies(self, node_id, dynprompt, upstream):
        # Walk input links upward, recording each parent -> children edge.
        node_info = dynprompt.get_node(node_id)
        if "inputs" not in node_info:
            return
        for k, v in node_info["inputs"].items():
            if is_link(v):
                parent_id = v[0]
                if parent_id not in upstream:
                    upstream[parent_id] = []
                    self.explore_dependencies(parent_id, dynprompt, upstream)
                upstream[parent_id].append(node_id)

    def collect_contained(self, node_id, upstream, contained):
        # Flood-fill downstream from node_id using the edges gathered above.
        if node_id not in upstream:
            return
        for child_id in upstream[node_id]:
            if child_id not in contained:
                contained[child_id] = True
                self.collect_contained(child_id, upstream, contained)
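The two helpers above gather every node sitting between the loop-open and loop-close nodes: explore_dependencies walks input links upward from the close node and records parent-to-children edges, then collect_contained flood-fills downstream from the open node. The following standalone restatement runs the same logic against a toy prompt (the prompt shape, class_type strings, and local is_link are illustrative assumptions; links are [parent_id, output_slot] pairs as in the diff):

# A toy three-node prompt: open -> work -> close, plus a direct open -> close link.
toy_prompt = {
    "open": {"class_type": "TestWhileLoopOpen", "inputs": {}},
    "work": {"class_type": "SomeWorkNode", "inputs": {"image": ["open", 1]}},
    "close": {"class_type": "TestWhileLoopClose",
              "inputs": {"flow_control": ["open", 0], "condition": ["work", 0]}},
}

def is_link(v):
    return isinstance(v, list) and len(v) == 2

def explore_dependencies(node_id, prompt, upstream):
    for v in prompt[node_id].get("inputs", {}).values():
        if is_link(v):
            parent_id = v[0]
            if parent_id not in upstream:
                upstream[parent_id] = []
                explore_dependencies(parent_id, prompt, upstream)
            upstream[parent_id].append(node_id)

def collect_contained(node_id, upstream, contained):
    for child_id in upstream.get(node_id, []):
        if child_id not in contained:
            contained[child_id] = True
            collect_contained(child_id, upstream, contained)

upstream = {}
explore_dependencies("close", toy_prompt, upstream)
contained = {}
collect_contained("open", upstream, contained)
print(upstream)   # {'open': ['close', 'work'], 'work': ['close']}
print(contained)  # {'close': True, 'work': True}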

    def while_loop_close(self, flow_control, condition, dynprompt=None, unique_id=None, **kwargs):
        assert dynprompt is not None
        if not condition:
            # We're done with the loop
            values = []
            for i in range(NUM_FLOW_SOCKETS):
                values.append(kwargs.get(f"initial_value{i}", None))
            return tuple(values)

        # We want to loop
        upstream = {}
        # Get the list of all nodes between the open and close nodes
        self.explore_dependencies(unique_id, dynprompt, upstream)

        contained = {}
        open_node = flow_control[0]
        self.collect_contained(open_node, upstream, contained)
        contained[unique_id] = True
        contained[open_node] = True

        # We'll use the default prefix, but to avoid having node names grow exponentially in size,
        # we'll use "Recurse" for the name of the recursively-generated copy of this node.
        graph = GraphBuilder()
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id)
            node.set_override_display_id(node_id)
        for node_id in contained:
            original_node = dynprompt.get_node(node_id)
            node = graph.lookup_node("Recurse" if node_id == unique_id else node_id)
            assert node is not None
            for k, v in original_node["inputs"].items():
                if is_link(v) and v[0] in contained:
                    parent = graph.lookup_node(v[0])
                    assert parent is not None
                    node.set_input(k, parent.out(v[1]))
                else:
                    node.set_input(k, v)
        new_open = graph.lookup_node(open_node)
        assert new_open is not None
        for i in range(NUM_FLOW_SOCKETS):
            key = f"initial_value{i}"
            new_open.set_input(key, kwargs.get(key, None))
        my_clone = graph.lookup_node("Recurse")
        assert my_clone is not None
        result = map(lambda x: my_clone.out(x), range(NUM_FLOW_SOCKETS))
        return {
            "result": tuple(result),
            "expand": graph.finalize(),
        }
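The return shape at the end of while_loop_close is the node-expansion contract used throughout this PR: "expand" carries a finalized subgraph to splice into the prompt, and each entry of "result" is either a concrete value or a link into that subgraph. A minimal sketch of the contract, using only GraphBuilder calls already seen above (the node itself is hypothetical; StubImage is defined later in this pack):

from comfy.graph_utils import GraphBuilder

class ExpandToWhiteImage:
    # Hypothetical illustration of the expansion contract, not part of the PR.
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "expand"
    CATEGORY = "Testing/Flow"

    def expand(self):
        g = GraphBuilder()
        img = g.node("StubImage", content="WHITE", height=64, width=64, batch_size=1)
        # The result entry is a link into the expanded subgraph, not a tensor;
        # the executor resolves it once the subgraph has run.
        return {
            "result": (img.out(0),),
            "expand": g.finalize(),
        }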


@VariantSupport()
class TestExecutionBlockerNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        inputs = {
            "required": {
                "input": ("*",),
                "block": ("BOOLEAN",),
                "verbose": ("BOOLEAN", {"default": False}),
            },
        }
        return inputs

    RETURN_TYPES = ("*",)
    RETURN_NAMES = ("output",)
    FUNCTION = "execution_blocker"

    CATEGORY = "Testing/Flow"

    def execution_blocker(self, input, block, verbose):
        if block:
            return (ExecutionBlocker("Blocked Execution" if verbose else None),)
        return (input,)
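The verbose flag decides whether the blocker carries a message. A hedged reading of the semantics, since ExecutionBlocker's internals live outside this diff: a None message skips dependent nodes silently, while a string surfaces as a user-visible error.

from comfy.graph import ExecutionBlocker

# Presumed semantics: message=None blocks downstream execution quietly;
# a string message is reported to the user when the block is hit.
silent = ExecutionBlocker(None)
noisy = ExecutionBlocker("Blocked Execution")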


FLOW_CONTROL_NODE_CLASS_MAPPINGS = {
    "TestWhileLoopOpen": TestWhileLoopOpen,
    "TestWhileLoopClose": TestWhileLoopClose,
    "TestExecutionBlocker": TestExecutionBlockerNode,
}
FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestWhileLoopOpen": "While Loop Open",
    "TestWhileLoopClose": "While Loop Close",
    "TestExecutionBlocker": "Execution Blocker",
}
308
tests/inference/testing_nodes/testing-pack/specific_tests.py
Normal file
@ -0,0 +1,308 @@
import torch
from .tools import VariantSupport
from comfy.graph_utils import GraphBuilder


class TestLazyMixImages:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image1": ("IMAGE", {"lazy": True}),
                "image2": ("IMAGE", {"lazy": True}),
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "mix"

    CATEGORY = "Testing/Nodes"

    def check_lazy_status(self, mask, image1, image2):
        mask_min = mask.min()
        mask_max = mask.max()
        needed = []
        if image1 is None and (mask_min != 1.0 or mask_max != 1.0):
            needed.append("image1")
        if image2 is None and (mask_min != 0.0 or mask_max != 0.0):
            needed.append("image2")
        return needed

    # Not trying to handle different batch sizes here just to keep the demo simple
    def mix(self, mask, image1, image2):
        mask_min = mask.min()
        mask_max = mask.max()
        if mask_min == 0.0 and mask_max == 0.0:
            return (image1,)
        elif mask_min == 1.0 and mask_max == 1.0:
            return (image2,)

        if len(mask.shape) == 2:
            mask = mask.unsqueeze(0)
        if len(mask.shape) == 3:
            mask = mask.unsqueeze(3)
        if mask.shape[3] < image1.shape[3]:
            mask = mask.repeat(1, 1, 1, image1.shape[3])

        result = image1 * (1. - mask) + image2 * mask
        return (result,)
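check_lazy_status is what makes the two image inputs above genuinely lazy: it reports which of the still-None inputs are needed given the mask, so a uniform mask means one of the images is never computed at all. The blend itself is ordinary broadcasting, as a toy run shows:

import torch

image1 = torch.zeros(1, 4, 4, 3)      # black, ComfyUI layout [B, H, W, C]
image2 = torch.ones(1, 4, 4, 3)       # white
mask = torch.full((1, 4, 4, 1), 0.5)  # already unsqueezed so it broadcasts

result = image1 * (1. - mask) + image2 * mask
print(result.shape)          # torch.Size([1, 4, 4, 3])
print(result.mean().item())  # 0.5, an even blend of black and white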


class TestVariadicAverage:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "variadic_average"

    CATEGORY = "Testing/Nodes"

    def variadic_average(self, input1, **kwargs):
        # Accept any number of extra sockets named input2, input3, ...;
        # stop at the first missing index.
        inputs = [input1]
        while 'input' + str(len(inputs) + 1) in kwargs:
            inputs.append(kwargs['input' + str(len(inputs) + 1)])
        return (torch.stack(inputs).mean(dim=0),)
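The averaging itself is a stack along a new leading dimension followed by a mean over it:

import torch

a = torch.zeros(1, 2, 2, 3)
b = torch.ones(1, 2, 2, 3)

# torch.stack adds a new dim 0, so mean(dim=0) averages elementwise across
# the inputs while preserving the image shape.
avg = torch.stack([a, b]).mean(dim=0)
print(avg.shape)          # torch.Size([1, 2, 2, 3])
print(avg.mean().item())  # 0.5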


class TestCustomIsChanged:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
            },
            "optional": {
                "should_change": ("BOOL", {"default": False}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_is_changed"

    CATEGORY = "Testing/Nodes"

    def custom_is_changed(self, image, should_change=False):
        return (image,)

    @classmethod
    def IS_CHANGED(cls, should_change=False, *args, **kwargs):
        if should_change:
            return float("NaN")
        else:
            return False
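Returning float("NaN") from IS_CHANGED is the usual trick for forcing re-execution: NaN compares unequal to everything, itself included, so the cached signature can never match the new one.

nan = float("NaN")
print(nan == nan)  # False, so a NaN signature never matches the cache
print(nan != nan)  # True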


class TestCustomValidation1:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": ("IMAGE,FLOAT",),
                "input2": ("IMAGE,FLOAT",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation1"

    CATEGORY = "Testing/Nodes"

    def custom_validation1(self, input1, input2):
        if isinstance(input1, float) and isinstance(input2, float):
            result = torch.ones([1, 512, 512, 3]) * input1 * input2
        else:
            result = input1 * input2
        return (result,)

    @classmethod
    def VALIDATE_INPUTS(cls, input1=None, input2=None):
        if input1 is not None:
            if not isinstance(input1, (torch.Tensor, float)):
                return f"Invalid type of input1: {type(input1)}"
        if input2 is not None:
            if not isinstance(input2, (torch.Tensor, float)):
                return f"Invalid type of input2: {type(input2)}"

        return True


class TestCustomValidation2:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": ("IMAGE,FLOAT",),
                "input2": ("IMAGE,FLOAT",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation2"

    CATEGORY = "Testing/Nodes"

    def custom_validation2(self, input1, input2):
        if isinstance(input1, float) and isinstance(input2, float):
            result = torch.ones([1, 512, 512, 3]) * input1 * input2
        else:
            result = input1 * input2
        return (result,)

    @classmethod
    def VALIDATE_INPUTS(cls, input_types, input1=None, input2=None):
        if input1 is not None:
            if not isinstance(input1, (torch.Tensor, float)):
                return f"Invalid type of input1: {type(input1)}"
        if input2 is not None:
            if not isinstance(input2, (torch.Tensor, float)):
                return f"Invalid type of input2: {type(input2)}"

        if 'input1' in input_types:
            if input_types['input1'] not in ["IMAGE", "FLOAT"]:
                return f"Invalid type of input1: {input_types['input1']}"
        if 'input2' in input_types:
            if input_types['input2'] not in ["IMAGE", "FLOAT"]:
                return f"Invalid type of input2: {input_types['input2']}"

        return True
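Compared to TestCustomValidation1, the extra input_types parameter also hands the node the declared types of linked inputs, so a bad connection can be rejected before execution, not just a bad literal. A direct call mirroring what the executor would plausibly pass (the exact calling convention is the executor's concern; this is illustrative only):

ok = TestCustomValidation2.VALIDATE_INPUTS(
    input_types={"input1": "IMAGE", "input2": "FLOAT"},
    input1=None,
    input2=0.5,
)
print(ok)  # True

bad = TestCustomValidation2.VALIDATE_INPUTS(
    input_types={"input1": "LATENT"}, input1=None, input2=None,
)
print(bad)  # Invalid type of input1: LATENT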


@VariantSupport()
class TestCustomValidation3:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": ("IMAGE,FLOAT",),
                "input2": ("IMAGE,FLOAT",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation3"

    CATEGORY = "Testing/Nodes"

    def custom_validation3(self, input1, input2):
        if isinstance(input1, float) and isinstance(input2, float):
            result = torch.ones([1, 512, 512, 3]) * input1 * input2
        else:
            result = input1 * input2
        return (result,)


class TestCustomValidation4:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": ("FLOAT",),
                "input2": ("FLOAT",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation4"

    CATEGORY = "Testing/Nodes"

    def custom_validation4(self, input1, input2):
        result = torch.ones([1, 512, 512, 3]) * input1 * input2
        return (result,)

    @classmethod
    def VALIDATE_INPUTS(cls, input1, input2):
        if input1 is not None:
            if not isinstance(input1, float):
                return f"Invalid type of input1: {type(input1)}"
        if input2 is not None:
            if not isinstance(input2, float):
                return f"Invalid type of input2: {type(input2)}"

        return True


class TestDynamicDependencyCycle:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": ("IMAGE",),
                "input2": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "dynamic_dependency_cycle"

    CATEGORY = "Testing/Nodes"

    def dynamic_dependency_cycle(self, input1, input2):
        g = GraphBuilder()
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
        mix1 = g.node("TestLazyMixImages", image1=input1, mask=mask.out(0))
        mix2 = g.node("TestLazyMixImages", image1=mix1.out(0), image2=input2, mask=mask.out(0))

        # Create the cycle: mix1 depends on mix2, which depends on mix1.
        mix1.set_input("image2", mix2.out(0))

        return {
            "result": (mix2.out(0),),
            "expand": g.finalize(),
        }


class TestMixedExpansionReturns:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": ("FLOAT",),
            },
        }

    RETURN_TYPES = ("IMAGE", "IMAGE")
    FUNCTION = "mixed_expansion_returns"

    CATEGORY = "Testing/Nodes"

    def mixed_expansion_returns(self, input1):
        white_image = torch.ones([1, 512, 512, 3])
        if input1 <= 0.1:
            return (torch.ones([1, 512, 512, 3]) * 0.1, white_image)
        elif input1 <= 0.2:
            return {
                "result": (torch.ones([1, 512, 512, 3]) * 0.2, white_image),
            }
        else:
            g = GraphBuilder()
            mask = g.node("StubMask", value=0.3, height=512, width=512, batch_size=1)
            black = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
            white = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
            mix = g.node("TestLazyMixImages", image1=black.out(0), image2=white.out(0), mask=mask.out(0))
            return {
                "result": (mix.out(0), white_image),
                "expand": g.finalize(),
            }


TEST_NODE_CLASS_MAPPINGS = {
    "TestLazyMixImages": TestLazyMixImages,
    "TestVariadicAverage": TestVariadicAverage,
    "TestCustomIsChanged": TestCustomIsChanged,
    "TestCustomValidation1": TestCustomValidation1,
    "TestCustomValidation2": TestCustomValidation2,
    "TestCustomValidation3": TestCustomValidation3,
    "TestCustomValidation4": TestCustomValidation4,
    "TestDynamicDependencyCycle": TestDynamicDependencyCycle,
    "TestMixedExpansionReturns": TestMixedExpansionReturns,
}

TEST_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestLazyMixImages": "Lazy Mix Images",
    "TestVariadicAverage": "Variadic Average",
    "TestCustomIsChanged": "Custom IsChanged",
    "TestCustomValidation1": "Custom Validation 1",
    "TestCustomValidation2": "Custom Validation 2",
    "TestCustomValidation3": "Custom Validation 3",
    "TestCustomValidation4": "Custom Validation 4",
    "TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
    "TestMixedExpansionReturns": "Mixed Expansion Returns",
}
105
tests/inference/testing_nodes/testing-pack/stubs.py
Normal file
@ -0,0 +1,105 @@
import torch


class StubImage:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "content": (['WHITE', 'BLACK', 'NOISE'],),
                "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
                "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "stub_image"

    CATEGORY = "Testing/Stub Nodes"

    def stub_image(self, content, height, width, batch_size):
        if content == "WHITE":
            return (torch.ones(batch_size, height, width, 3),)
        elif content == "BLACK":
            return (torch.zeros(batch_size, height, width, 3),)
        elif content == "NOISE":
            return (torch.rand(batch_size, height, width, 3),)


class StubMask:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
                "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
            },
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "stub_mask"

    CATEGORY = "Testing/Stub Nodes"

    def stub_mask(self, value, height, width, batch_size):
        return (torch.ones(batch_size, height, width) * value,)
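Note the shape conventions the stubs establish: images are [batch, height, width, channels] floats and masks are [batch, height, width]. That missing channel dimension is exactly why TestLazyMixImages unsqueezes the mask before blending:

import torch

mask = torch.ones(1, 8, 8) * 0.5  # [B, H, W], as StubMask produces
image = torch.ones(1, 8, 8, 3)    # [B, H, W, C], as StubImage produces

# Unsqueeze to [B, H, W, 1] so the mask broadcasts across the channel dim.
blended = image * mask.unsqueeze(3)
print(blended.shape)  # torch.Size([1, 8, 8, 3])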


class StubInt:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("INT", {"default": 0, "min": -0xffffffff, "max": 0xffffffff, "step": 1}),
            },
        }

    RETURN_TYPES = ("INT",)
    FUNCTION = "stub_int"

    CATEGORY = "Testing/Stub Nodes"

    def stub_int(self, value):
        return (value,)


class StubFloat:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value": ("FLOAT", {"default": 0.0, "min": -1.0e38, "max": 1.0e38, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("FLOAT",)
    FUNCTION = "stub_float"

    CATEGORY = "Testing/Stub Nodes"

    def stub_float(self, value):
        return (value,)


TEST_STUB_NODE_CLASS_MAPPINGS = {
    "StubImage": StubImage,
    "StubMask": StubMask,
    "StubInt": StubInt,
    "StubFloat": StubFloat,
}
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
    "StubImage": "Stub Image",
    "StubMask": "Stub Mask",
    "StubInt": "Stub Int",
    "StubFloat": "Stub Float",
}
53
tests/inference/testing_nodes/testing-pack/tools.py
Normal file
@ -0,0 +1,53 @@

def MakeSmartType(t):
    if isinstance(t, str):
        return SmartType(t)
    return t


class SmartType(str):
    def __ne__(self, other):
        # "*" is the wildcard: it is never "not equal" to anything.
        if self == "*" or other == "*":
            return False
        # A comma-separated type matches when every type it names is
        # accepted by the other side.
        selfset = set(self.split(','))
        otherset = set(other.split(','))
        return not selfset.issubset(otherset)
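Only __ne__ is overridden (equality stays plain string equality), presumably because the executor's type check is spelled with !=. The result is a deliberately asymmetric notion of compatibility: "*" matches everything, and a comma-separated type fits any socket that accepts a superset of its members:

print(SmartType("*") != "IMAGE")            # False, wildcard is always compatible
print(SmartType("IMAGE") != "IMAGE,FLOAT")  # False, IMAGE fits an IMAGE,FLOAT socket
print(SmartType("IMAGE,FLOAT") != "IMAGE")  # True, but not the other way around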


def VariantSupport():
    def decorator(cls):
        if hasattr(cls, "INPUT_TYPES"):
            old_input_types = getattr(cls, "INPUT_TYPES")
            def new_input_types(*args, **kwargs):
                types = old_input_types(*args, **kwargs)
                for category in ["required", "optional"]:
                    if category not in types:
                        continue
                    for key, value in types[category].items():
                        if isinstance(value, tuple):
                            types[category][key] = (MakeSmartType(value[0]),) + value[1:]
                return types
            setattr(cls, "INPUT_TYPES", new_input_types)
        if hasattr(cls, "RETURN_TYPES"):
            old_return_types = cls.RETURN_TYPES
            setattr(cls, "RETURN_TYPES", tuple(MakeSmartType(x) for x in old_return_types))
        if hasattr(cls, "VALIDATE_INPUTS"):
            # Reflection is used to determine what the function signature is, so we can't just change the function signature
            raise NotImplementedError("VariantSupport does not support VALIDATE_INPUTS yet")
        else:
            def validate_inputs(input_types):
                inputs = cls.INPUT_TYPES()
                for key, value in input_types.items():
                    if isinstance(value, SmartType):
                        continue
                    if "required" in inputs and key in inputs["required"]:
                        expected_type = inputs["required"][key][0]
                    elif "optional" in inputs and key in inputs["optional"]:
                        expected_type = inputs["optional"][key][0]
                    else:
                        expected_type = None
                    if expected_type is not None and MakeSmartType(value) != expected_type:
                        return f"Invalid type of {key}: {value} (expected {expected_type})"
                return True
            setattr(cls, "VALIDATE_INPUTS", validate_inputs)
        return cls
    return decorator
364
tests/inference/testing_nodes/testing-pack/util.py
Normal file
@ -0,0 +1,364 @@
from comfy.graph_utils import GraphBuilder
from .tools import VariantSupport

@VariantSupport()
class TestAccumulateNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "to_add": ("*",),
            },
            "optional": {
                "accumulation": ("ACCUMULATION",),
            },
        }

    RETURN_TYPES = ("ACCUMULATION",)
    FUNCTION = "accumulate"

    CATEGORY = "Testing/Lists"

    def accumulate(self, to_add, accumulation=None):
        if accumulation is None:
            value = [to_add]
        else:
            value = accumulation["accum"] + [to_add]
        return ({"accum": value},)


@VariantSupport()
class TestAccumulationHeadNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "accumulation": ("ACCUMULATION",),
            },
        }

    RETURN_TYPES = ("ACCUMULATION", "*",)
    FUNCTION = "accumulation_head"

    CATEGORY = "Testing/Lists"

    def accumulation_head(self, accumulation):
        accum = accumulation["accum"]
        if len(accum) == 0:
            return (accumulation, None)
        else:
            return ({"accum": accum[1:]}, accum[0])


class TestAccumulationTailNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "accumulation": ("ACCUMULATION",),
            },
        }

    RETURN_TYPES = ("ACCUMULATION", "*",)
    FUNCTION = "accumulation_tail"

    CATEGORY = "Testing/Lists"

    def accumulation_tail(self, accumulation):
        accum = accumulation["accum"]
        if len(accum) == 0:
            return (None, accumulation)
        else:
            return ({"accum": accum[:-1]}, accum[-1])


@VariantSupport()
class TestAccumulationToListNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "accumulation": ("ACCUMULATION",),
            },
        }

    RETURN_TYPES = ("*",)
    OUTPUT_IS_LIST = (True,)

    FUNCTION = "accumulation_to_list"

    CATEGORY = "Testing/Lists"

    def accumulation_to_list(self, accumulation):
        return (accumulation["accum"],)


@VariantSupport()
class TestListToAccumulationNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "list": ("*",),
            },
        }

    RETURN_TYPES = ("ACCUMULATION",)
    INPUT_IS_LIST = (True,)

    FUNCTION = "list_to_accumulation"

    CATEGORY = "Testing/Lists"

    def list_to_accumulation(self, list):
        return ({"accum": list},)


@VariantSupport()
class TestAccumulationGetLengthNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "accumulation": ("ACCUMULATION",),
            },
        }

    RETURN_TYPES = ("INT",)

    FUNCTION = "accumlength"

    CATEGORY = "Testing/Lists"

    def accumlength(self, accumulation):
        return (len(accumulation['accum']),)


@VariantSupport()
class TestAccumulationGetItemNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "accumulation": ("ACCUMULATION",),
                "index": ("INT", {"default": 0, "step": 1})
            },
        }

    RETURN_TYPES = ("*",)

    FUNCTION = "get_item"

    CATEGORY = "Testing/Lists"

    def get_item(self, accumulation, index):
        return (accumulation['accum'][index],)


@VariantSupport()
class TestAccumulationSetItemNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "accumulation": ("ACCUMULATION",),
                "index": ("INT", {"default": 0, "step": 1}),
                "value": ("*",),
            },
        }

    RETURN_TYPES = ("ACCUMULATION",)

    FUNCTION = "set_item"

    CATEGORY = "Testing/Lists"

    def set_item(self, accumulation, index, value):
        # Copy before mutating so the original accumulation stays intact.
        new_accum = accumulation['accum'][:]
        new_accum[index] = value
        return ({"accum": new_accum},)


class TestIntMathOperation:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
                "b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}),
                "operation": (["add", "subtract", "multiply", "divide", "modulo", "power"],),
            },
        }

    RETURN_TYPES = ("INT",)
    FUNCTION = "int_math_operation"

    CATEGORY = "Testing/Logic"

    def int_math_operation(self, a, b, operation):
        if operation == "add":
            return (a + b,)
        elif operation == "subtract":
            return (a - b,)
        elif operation == "multiply":
            return (a * b,)
        elif operation == "divide":
            return (a // b,)
        elif operation == "modulo":
            return (a % b,)
        elif operation == "power":
            return (a ** b,)


from .flow_control import NUM_FLOW_SOCKETS

@VariantSupport()
class TestForLoopOpen:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "remaining": ("INT", {"default": 1, "min": 0, "max": 100000, "step": 1}),
            },
            "optional": {
                f"initial_value{i}": ("*",) for i in range(1, NUM_FLOW_SOCKETS)
            },
            "hidden": {
                "initial_value0": ("*",)
            }
        }

    RETURN_TYPES = tuple(["FLOW_CONTROL", "INT",] + ["*"] * (NUM_FLOW_SOCKETS-1))
    RETURN_NAMES = tuple(["flow_control", "remaining"] + [f"value{i}" for i in range(1, NUM_FLOW_SOCKETS)])
    FUNCTION = "for_loop_open"

    CATEGORY = "Testing/Flow"

    def for_loop_open(self, remaining, **kwargs):
        graph = GraphBuilder()
        if "initial_value0" in kwargs:
            remaining = kwargs["initial_value0"]
        while_open = graph.node("TestWhileLoopOpen", condition=remaining, initial_value0=remaining, **{(f"initial_value{i}"): kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)})
        outputs = [kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)]
        return {
            "result": tuple(["stub", remaining] + outputs),
            "expand": graph.finalize(),
        }
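The hidden initial_value0 socket is the loop counter's round trip: TestForLoopClose (below) feeds the decremented count back in as initial_value0, so on each re-expansion the remaining value comes from the previous iteration rather than from the widget.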


@VariantSupport()
class TestForLoopClose:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "flow_control": ("FLOW_CONTROL", {"rawLink": True}),
            },
            "optional": {
                f"initial_value{i}": ("*", {"rawLink": True}) for i in range(1, NUM_FLOW_SOCKETS)
            },
        }

    RETURN_TYPES = tuple(["*"] * (NUM_FLOW_SOCKETS-1))
    RETURN_NAMES = tuple([f"value{i}" for i in range(1, NUM_FLOW_SOCKETS)])
    FUNCTION = "for_loop_close"

    CATEGORY = "Testing/Flow"

    def for_loop_close(self, flow_control, **kwargs):
        graph = GraphBuilder()
        while_open = flow_control[0]
        # Decrement the counter and convert it into the while-loop's condition.
        sub = graph.node("TestIntMathOperation", operation="subtract", a=[while_open, 1], b=1)
        cond = graph.node("TestToBoolNode", value=sub.out(0))
        input_values = {f"initial_value{i}": kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)}
        while_close = graph.node("TestWhileLoopClose",
                flow_control=flow_control,
                condition=cond.out(0),
                initial_value0=sub.out(0),
                **input_values)
        return {
            "result": tuple([while_close.out(i) for i in range(1, NUM_FLOW_SOCKETS)]),
            "expand": graph.finalize(),
        }
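In other words, the for loop never iterates itself: each close expands into a TestIntMathOperation that decrements the counter, a TestToBoolNode that turns the count into a condition, and a TestWhileLoopClose wired back to the original open node, desugaring the for loop into the while loop at execution time.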


NUM_LIST_SOCKETS = 10

@VariantSupport()
class TestMakeListNode:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "value1": ("*",),
            },
            "optional": {
                f"value{i}": ("*",) for i in range(1, NUM_LIST_SOCKETS)
            },
        }

    RETURN_TYPES = ("*",)
    FUNCTION = "make_list"
    OUTPUT_IS_LIST = (True,)

    CATEGORY = "Testing/Lists"

    def make_list(self, **kwargs):
        result = []
        for i in range(NUM_LIST_SOCKETS):
            if f"value{i}" in kwargs:
                result.append(kwargs[f"value{i}"])
        return (result,)
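With OUTPUT_IS_LIST = (True,), the single tuple entry is interpreted as a list-valued output; downstream nodes that do not declare INPUT_IS_LIST are then executed once per element, which is what makes this node useful for fanning values out to the list-handling nodes above.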


UTILITY_NODE_CLASS_MAPPINGS = {
    "TestAccumulateNode": TestAccumulateNode,
    "TestAccumulationHeadNode": TestAccumulationHeadNode,
    "TestAccumulationTailNode": TestAccumulationTailNode,
    "TestAccumulationToListNode": TestAccumulationToListNode,
    "TestListToAccumulationNode": TestListToAccumulationNode,
    "TestAccumulationGetLengthNode": TestAccumulationGetLengthNode,
    "TestAccumulationGetItemNode": TestAccumulationGetItemNode,
    "TestAccumulationSetItemNode": TestAccumulationSetItemNode,
    "TestForLoopOpen": TestForLoopOpen,
    "TestForLoopClose": TestForLoopClose,
    "TestIntMathOperation": TestIntMathOperation,
    "TestMakeListNode": TestMakeListNode,
}
UTILITY_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestAccumulateNode": "Accumulate",
    "TestAccumulationHeadNode": "Accumulation Head",
    "TestAccumulationTailNode": "Accumulation Tail",
    "TestAccumulationToListNode": "Accumulation to List",
    "TestListToAccumulationNode": "List to Accumulation",
    "TestAccumulationGetLengthNode": "Accumulation Get Length",
    "TestAccumulationGetItemNode": "Accumulation Get Item",
    "TestAccumulationSetItemNode": "Accumulation Set Item",
    "TestForLoopOpen": "For Loop Open",
    "TestForLoopClose": "For Loop Close",
    "TestIntMathOperation": "Int Math Operation",
    "TestMakeListNode": "Make List",
}
0
tests/unit/app_test/__init__.py
Normal file
101
tests/unit/app_test/frontend_manager_test.py
Normal file
@ -0,0 +1,101 @@
import argparse

import pytest
from requests.exceptions import HTTPError

from comfy.app.frontend_management import (
    FrontendManager,
    FrontEndProvider,
    Release,
)
from comfy.cli_args import DEFAULT_VERSION_STRING


@pytest.fixture
def mock_releases():
    return [
        Release(
            id=1,
            tag_name="1.0.0",
            name="Release 1.0.0",
            prerelease=False,
            created_at="2022-01-01T00:00:00Z",
            published_at="2022-01-01T00:00:00Z",
            body="Release notes for 1.0.0",
            assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
        ),
        Release(
            id=2,
            tag_name="2.0.0",
            name="Release 2.0.0",
            prerelease=False,
            created_at="2022-02-01T00:00:00Z",
            published_at="2022-02-01T00:00:00Z",
            body="Release notes for 2.0.0",
            assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
        ),
    ]


@pytest.fixture
def mock_provider(mock_releases):
    provider = FrontEndProvider(
        owner="test-owner",
        repo="test-repo",
    )
    provider.all_releases = mock_releases
    provider.latest_release = mock_releases[1]
    FrontendManager.PROVIDERS = [provider]
    return provider


def test_get_release(mock_provider, mock_releases):
    version = "1.0.0"
    release = mock_provider.get_release(version)
    assert release == mock_releases[0]


def test_get_release_latest(mock_provider, mock_releases):
    version = "latest"
    release = mock_provider.get_release(version)
    assert release == mock_releases[1]


def test_get_release_invalid_version(mock_provider):
    version = "invalid"
    with pytest.raises(ValueError):
        mock_provider.get_release(version)


def test_init_frontend_default():
    version_string = DEFAULT_VERSION_STRING
    frontend_path = FrontendManager.init_frontend(version_string)
    assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH


def test_init_frontend_invalid_version():
    version_string = "test-owner/test-repo@1.100.99"
    with pytest.raises(HTTPError):
        FrontendManager.init_frontend_unsafe(version_string)


def test_init_frontend_invalid_provider():
    version_string = "invalid/invalid@latest"
    with pytest.raises(HTTPError):
        FrontendManager.init_frontend_unsafe(version_string)


def test_parse_version_string():
    version_string = "owner/repo@1.0.0"
    repo_owner, repo_name, version = FrontendManager.parse_version_string(
        version_string
    )
    assert repo_owner == "owner"
    assert repo_name == "repo"
    assert version == "1.0.0"


def test_parse_version_string_invalid():
    version_string = "invalid"
    with pytest.raises(argparse.ArgumentTypeError):
        FrontendManager.parse_version_string(version_string)