diff --git a/comfy/caching.py b/comfy/caching.py new file mode 100644 index 000000000..c8ac1b58d --- /dev/null +++ b/comfy/caching.py @@ -0,0 +1,305 @@ +import itertools +from typing import Sequence, Mapping + +from .cmd.execution import nodes +from .graph import DynamicPrompt +from .graph_utils import is_link + + +class CacheKeySet: + def __init__(self, dynprompt, node_ids, is_changed_cache): + self.keys = {} + self.subcache_keys = {} + + def add_keys(self, node_ids): + raise NotImplementedError() + + def all_node_ids(self): + return set(self.keys.keys()) + + def get_used_keys(self): + return self.keys.values() + + def get_used_subcache_keys(self): + return self.subcache_keys.values() + + def get_data_key(self, node_id): + return self.keys.get(node_id, None) + + def get_subcache_key(self, node_id): + return self.subcache_keys.get(node_id, None) + + +class Unhashable: + def __init__(self): + self.value = float("NaN") + + +def to_hashable(obj): + # So that we don't infinitely recurse since frozenset and tuples + # are Sequences. + if isinstance(obj, (int, float, str, bool, type(None))): + return obj + elif isinstance(obj, Mapping): + return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())]) + elif isinstance(obj, Sequence): + return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj])) + else: + # TODO - Support other objects like tensors? + return Unhashable() + + +class CacheKeySetID(CacheKeySet): + def __init__(self, dynprompt, node_ids, is_changed_cache): + super().__init__(dynprompt, node_ids, is_changed_cache) + self.dynprompt = dynprompt + self.add_keys(node_ids) + + def add_keys(self, node_ids): + for node_id in node_ids: + if node_id in self.keys: + continue + node = self.dynprompt.get_node(node_id) + self.keys[node_id] = (node_id, node["class_type"]) + self.subcache_keys[node_id] = (node_id, node["class_type"]) + + +class CacheKeySetInputSignature(CacheKeySet): + def __init__(self, dynprompt, node_ids, is_changed_cache): + super().__init__(dynprompt, node_ids, is_changed_cache) + self.dynprompt = dynprompt + self.is_changed_cache = is_changed_cache + self.add_keys(node_ids) + + def include_node_id_in_input(self) -> bool: + return False + + def add_keys(self, node_ids): + for node_id in node_ids: + if node_id in self.keys: + continue + node = self.dynprompt.get_node(node_id) + self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id) + self.subcache_keys[node_id] = (node_id, node["class_type"]) + + def get_node_signature(self, dynprompt, node_id): + signature = [] + ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id) + signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) + for ancestor_id in ancestors: + signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) + return to_hashable(signature) + + def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): + node = dynprompt.get_node(node_id) + class_type = node["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + signature = [class_type, self.is_changed_cache.get(node_id)] + if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT): + signature.append(node_id) + inputs = node["inputs"] + for key in sorted(inputs.keys()): + if is_link(inputs[key]): + (ancestor_id, ancestor_socket) = inputs[key] + ancestor_index = ancestor_order_mapping[ancestor_id] + signature.append((key, ("ANCESTOR", ancestor_index, ancestor_socket))) + else: + signature.append((key, inputs[key])) + return signature + + # This function returns a list of all ancestors of the given node. The order of the list is + # deterministic based on which specific inputs the ancestor is connected by. + def get_ordered_ancestry(self, dynprompt, node_id): + ancestors = [] + order_mapping = {} + self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping) + return ancestors, order_mapping + + def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping): + inputs = dynprompt.get_node(node_id)["inputs"] + input_keys = sorted(inputs.keys()) + for key in input_keys: + if is_link(inputs[key]): + ancestor_id = inputs[key][0] + if ancestor_id not in order_mapping: + ancestors.append(ancestor_id) + order_mapping[ancestor_id] = len(ancestors) - 1 + self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping) + + +class BasicCache: + def __init__(self, key_class): + self.key_class = key_class + self.initialized = False + self.dynprompt: DynamicPrompt + self.cache_key_set: CacheKeySet + self.cache = {} + self.subcaches = {} + + def set_prompt(self, dynprompt, node_ids, is_changed_cache): + self.dynprompt = dynprompt + self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) + self.is_changed_cache = is_changed_cache + self.initialized = True + + def all_node_ids(self): + assert self.initialized + node_ids = self.cache_key_set.all_node_ids() + for subcache in self.subcaches.values(): + node_ids = node_ids.union(subcache.all_node_ids()) + return node_ids + + def _clean_cache(self): + preserve_keys = set(self.cache_key_set.get_used_keys()) + to_remove = [] + for key in self.cache: + if key not in preserve_keys: + to_remove.append(key) + for key in to_remove: + del self.cache[key] + + def _clean_subcaches(self): + preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys()) + + to_remove = [] + for key in self.subcaches: + if key not in preserve_subcaches: + to_remove.append(key) + for key in to_remove: + del self.subcaches[key] + + def clean_unused(self): + assert self.initialized + self._clean_cache() + self._clean_subcaches() + + def _set_immediate(self, node_id, value): + assert self.initialized + cache_key = self.cache_key_set.get_data_key(node_id) + self.cache[cache_key] = value + + def _get_immediate(self, node_id): + if not self.initialized: + return None + cache_key = self.cache_key_set.get_data_key(node_id) + if cache_key in self.cache: + return self.cache[cache_key] + else: + return None + + def _ensure_subcache(self, node_id, children_ids): + subcache_key = self.cache_key_set.get_subcache_key(node_id) + subcache = self.subcaches.get(subcache_key, None) + if subcache is None: + subcache = BasicCache(self.key_class) + self.subcaches[subcache_key] = subcache + subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) + return subcache + + def _get_subcache(self, node_id): + assert self.initialized + subcache_key = self.cache_key_set.get_subcache_key(node_id) + if subcache_key in self.subcaches: + return self.subcaches[subcache_key] + else: + return None + + def recursive_debug_dump(self): + result = [] + for key in self.cache: + result.append({"key": key, "value": self.cache[key]}) + for key in self.subcaches: + result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()}) + return result + + +class HierarchicalCache(BasicCache): + def __init__(self, key_class): + super().__init__(key_class) + + def _get_cache_for(self, node_id): + assert self.dynprompt is not None + parent_id = self.dynprompt.get_parent_node_id(node_id) + if parent_id is None: + return self + + hierarchy = [] + while parent_id is not None: + hierarchy.append(parent_id) + parent_id = self.dynprompt.get_parent_node_id(parent_id) + + cache = self + for parent_id in reversed(hierarchy): + cache = cache._get_subcache(parent_id) + if cache is None: + return None + return cache + + def get(self, node_id): + cache = self._get_cache_for(node_id) + if cache is None: + return None + return cache._get_immediate(node_id) + + def set(self, node_id, value): + cache = self._get_cache_for(node_id) + assert cache is not None + cache._set_immediate(node_id, value) + + def ensure_subcache_for(self, node_id, children_ids): + cache = self._get_cache_for(node_id) + assert cache is not None + return cache._ensure_subcache(node_id, children_ids) + + +class LRUCache(BasicCache): + def __init__(self, key_class, max_size=100): + super().__init__(key_class) + self.max_size = max_size + self.min_generation = 0 + self.generation = 0 + self.used_generation = {} + self.children = {} + + def set_prompt(self, dynprompt, node_ids, is_changed_cache): + super().set_prompt(dynprompt, node_ids, is_changed_cache) + self.generation += 1 + for node_id in node_ids: + self._mark_used(node_id) + + def clean_unused(self): + while len(self.cache) > self.max_size and self.min_generation < self.generation: + self.min_generation += 1 + to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation] + for key in to_remove: + del self.cache[key] + del self.used_generation[key] + if key in self.children: + del self.children[key] + self._clean_subcaches() + + def get(self, node_id): + self._mark_used(node_id) + return self._get_immediate(node_id) + + def _mark_used(self, node_id): + cache_key = self.cache_key_set.get_data_key(node_id) + if cache_key is not None: + self.used_generation[cache_key] = self.generation + + def set(self, node_id, value): + self._mark_used(node_id) + return self._set_immediate(node_id, value) + + def ensure_subcache_for(self, node_id, children_ids): + # Just uses subcaches for tracking 'live' nodes + super()._ensure_subcache(node_id, children_ids) + + self.cache_key_set.add_keys(children_ids) + self._mark_used(node_id) + cache_key = self.cache_key_set.get_data_key(node_id) + self.children[cache_key] = [] + for child_id in children_ids: + self._mark_used(child_id) + self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) + return self diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 70506f24d..0b3b0c01b 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -87,6 +87,8 @@ def _create_parser() -> EnhancedConfigArgParser: parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.Auto, help="Default preview method for sampler nodes.", action=EnumAction) + cache_group = parser.add_mutually_exclusive_group() + cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") @@ -211,11 +213,11 @@ def _create_parser() -> EnhancedConfigArgParser: return parser -def _parse_args(parser: Optional[argparse.ArgumentParser] = None) -> Configuration: +def _parse_args(parser: Optional[argparse.ArgumentParser] = None, args_parsing: bool = False) -> Configuration: if parser is None: parser = _create_parser() - if options.args_parsing: + if args_parsing: args, _, config_files = parser.parse_known_args_with_config_files() else: args, _, config_files = parser.parse_known_args_with_config_files([]) @@ -261,4 +263,8 @@ def _setup_config_file_watcher(config: Configuration, parser: EnhancedConfigArgP atexit.register(observer.join) -args = _parse_args() +def default_configuration() -> Configuration: + return _parse_args(_create_parser()) + + +args = _parse_args(args_parsing=options.args_parsing) diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index 79d971ad9..9843f3735 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -70,6 +70,7 @@ class Configuration(dict): directml (Optional[int]): Use DirectML. -1 for auto-selection. disable_ipex_optimize (bool): Disable IPEX optimization for Intel GPUs. preview_method (LatentPreviewMethod): Method for generating previews. Defaults to "auto". + cache_lru (int): Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM. use_split_cross_attention (bool): Use split cross-attention optimization. use_quad_cross_attention (bool): Use sub-quadratic cross-attention optimization. use_pytorch_cross_attention (bool): Use PyTorch's cross-attention function. @@ -179,6 +180,9 @@ class Configuration(dict): self.force_channels_last: bool = False self.force_hf_local_dir_mode = False + # from guill + self.cache_lru: int = 0 + # from opentracing docs self.otel_service_name: str = "comfyui" self.otel_service_version: str = "0.0.1" diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py index c94a9024f..577b9ac22 100644 --- a/comfy/client/embedded_comfy_client.py +++ b/comfy/client/embedded_comfy_client.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import gc +import json import uuid from asyncio import get_event_loop from concurrent.futures import ThreadPoolExecutor @@ -107,7 +108,7 @@ class EmbeddedComfyClient: from ..cmd.execution import PromptExecutor - self._prompt_executor = PromptExecutor(self._progress_handler) + self._prompt_executor = PromptExecutor(self._progress_handler, lru_size=self._configuration.cache_lru if self._configuration is not None else 0) self._prompt_executor.raise_exceptions = True await get_event_loop().run_in_executor(self._executor, create_executor_in_thread) @@ -128,9 +129,9 @@ class EmbeddedComfyClient: try: prompt_mut = make_mutable(prompt) validation_tuple = validate_prompt(prompt_mut) - if not validation_tuple[0]: - validation_error_dict = validation_tuple[1] or {"message": "Unknown", "details": ""} - raise ValueError("\n".join([validation_error_dict["message"], validation_error_dict["details"]])) + if not validation_tuple.valid: + validation_error_dict = {"message": "Unknown", "details": ""} if not validation_tuple.node_errors or len(validation_tuple.node_errors) == 0 else validation_tuple.node_errors + raise ValueError(json.dumps(validation_error_dict)) prompt_executor: PromptExecutor = self._prompt_executor @@ -140,7 +141,7 @@ class EmbeddedComfyClient: prompt_executor.server = self._progress_handler prompt_executor.execute(prompt_mut, prompt_id, {"client_id": client_id}, - execute_outputs=validation_tuple[2]) + execute_outputs=validation_tuple.good_output_node_ids) return prompt_executor.outputs_ui except Exception as exc_info: span.set_status(Status(StatusCode.ERROR)) diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index b1a07273d..97d5c5304 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -20,7 +20,8 @@ from .. import model_management from ..component_model.abstract_prompt_queue import AbstractPromptQueue from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \ ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict, FormattedValue, RecursiveExecutionTuple, \ - RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted + RecursiveExecutionErrorDetails, RecursiveExecutionErrorDetailsInterrupted, ExecutionResult, DuplicateNodeError, \ + HistoryResultDict from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus from ..execution_context import new_execution_context, ExecutionContext from ..nodes.package import import_all_nodes_in_workspace @@ -29,46 +30,122 @@ from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOption # ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the # various functions that are declared here. It should have been a context in the first place. nodes: ExportedNodes = lazy_object_proxy.Proxy(import_all_nodes_in_workspace) +from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker +from comfy.graph_utils import is_link, GraphBuilder +from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID -def get_input_data(inputs, class_def, unique_id, outputs=None, prompt=None, extra_data=None): +class IsChangedCache: + def __init__(self, dynprompt, outputs_cache): + self.dynprompt = dynprompt + self.outputs_cache = outputs_cache + self.is_changed = {} + + def get(self, node_id): + if node_id in self.is_changed: + return self.is_changed[node_id] + + node = self.dynprompt.get_node(node_id) + class_type = node["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if not hasattr(class_def, "IS_CHANGED"): + self.is_changed[node_id] = False + return self.is_changed[node_id] + + if "is_changed" in node: + self.is_changed[node_id] = node["is_changed"] + return self.is_changed[node_id] + + input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache) + try: + is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") + node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] + except: + node["is_changed"] = float("NaN") + finally: + self.is_changed[node_id] = node["is_changed"] + return self.is_changed[node_id] + + +class CacheSet: + def __init__(self, lru_size=None): + if lru_size is None or lru_size == 0: + self.init_classic_cache() + else: + self.init_lru_cache(lru_size) + self.all = [self.outputs, self.ui, self.objects] + + # Useful for those with ample RAM/VRAM -- allows experimenting without + # blowing away the cache every time + def init_lru_cache(self, cache_size): + self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) + self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) + self.objects = HierarchicalCache(CacheKeySetID) + + # Performs like the old cache -- dump data ASAP + def init_classic_cache(self): + self.outputs = HierarchicalCache(CacheKeySetInputSignature) + self.ui = HierarchicalCache(CacheKeySetInputSignature) + self.objects = HierarchicalCache(CacheKeySetID) + + def recursive_debug_dump(self): + result = { + "outputs": self.outputs.recursive_debug_dump(), + "ui": self.ui.recursive_debug_dump(), + } + return result + + +def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data=None): if extra_data is None: extra_data = {} - if prompt is None: - prompt = {} if outputs is None: outputs = {} valid_inputs = class_def.INPUT_TYPES() input_data_all = {} + missing_keys = {} for x in inputs: input_data = inputs[x] - if isinstance(input_data, list): + input_type, input_category, input_info = get_input_info(class_def, x) + + def mark_missing(): + missing_keys[x] = True + input_data_all[x] = (None,) + + if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)): input_unique_id = input_data[0] output_index = input_data[1] - if input_unique_id not in outputs: - input_data_all[x] = (None,) + if outputs is None: + mark_missing() + continue # This might be a lazily-evaluated input + cached_output = outputs.get(input_unique_id) + if cached_output is None: + mark_missing() continue - obj = outputs[input_unique_id][output_index] + if output_index >= len(cached_output): + mark_missing() + continue + obj = cached_output[output_index] input_data_all[x] = obj - else: - if ("required" in valid_inputs and x in valid_inputs["required"]) or ( - "optional" in valid_inputs and x in valid_inputs["optional"]): - input_data_all[x] = [input_data] + elif input_category is not None: + input_data_all[x] = [input_data] if "hidden" in valid_inputs: h = valid_inputs["hidden"] for x in h: if h[x] == "PROMPT": - input_data_all[x] = [prompt] + input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}] + if h[x] == "DYNPROMPT": + input_data_all[x] = [dynprompt] if h[x] == "EXTRA_PNGINFO": input_data_all[x] = [extra_data.get('extra_pnginfo', None)] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] - return input_data_all + return input_data_all, missing_keys @tracer.start_as_current_span("Execute Node") -def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func: str, allow_interrupt=False): +def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func: str, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): span = get_current_span() class_type = obj.__class__.__name__ span.set_attribute("class_type", class_type) @@ -95,73 +172,103 @@ def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func: span.set_attribute(f"input_data_all.{kwarg_name}", filtered_items) except TypeError: pass - # check if node wants the lists - input_is_list = False - if hasattr(obj, "INPUT_IS_LIST"): - input_is_list = obj.INPUT_IS_LIST + input_is_list = getattr(obj, "INPUT_IS_LIST", False) if len(input_data_all) == 0: max_len_input = 0 else: - max_len_input = max([len(x) for x in input_data_all.values()]) + max_len_input = max(len(x) for x in input_data_all.values()) # get a slice of inputs, repeat last input when list isn't long enough def slice_dict(d, i): - d_new = dict() - for k, v in d.items(): - d_new[k] = v[i if len(v) > i else -1] - return d_new + return {k: v[i if len(v) > i else -1] for k, v in d.items()} results = [] + + def process_inputs(inputs, index=None): + if allow_interrupt: + interruption.throw_exception_if_processing_interrupted() + execution_block = None + for k, v in inputs.items(): + if isinstance(v, ExecutionBlocker): + execution_block = execution_block_cb(v) if execution_block_cb else v + break + if execution_block is None: + if pre_execute_cb is not None and index is not None: + pre_execute_cb(index) + results.append(getattr(obj, func)(**inputs)) + else: + results.append(execution_block) + if input_is_list: - if allow_interrupt: - interruption.throw_exception_if_processing_interrupted() - results.append(getattr(obj, func)(**input_data_all)) + process_inputs(input_data_all, 0) elif max_len_input == 0: - if allow_interrupt: - interruption.throw_exception_if_processing_interrupted() - results.append(getattr(obj, func)()) + process_inputs({}) else: for i in range(max_len_input): - if allow_interrupt: - interruption.throw_exception_if_processing_interrupted() - results.append(getattr(obj, func)(**slice_dict(input_data_all, i))) + input_dict = slice_dict(input_data_all, i) + process_inputs(input_dict, i) return results -def get_output_data(obj, input_data_all) -> Tuple[List[typing.Any], typing.Dict[str, List[typing.Any]]]: +def merge_result_data(results, obj): + # check which outputs need concatenating + output = [] + output_is_list = [False] * len(results[0]) + if hasattr(obj, "OUTPUT_IS_LIST"): + output_is_list = obj.OUTPUT_IS_LIST + + # merge node execution results + for i, is_list in zip(range(len(results[0])), output_is_list): + if is_list: + output.append([x for o in results for x in o[i]]) + else: + output.append([o[i] for o in results]) + return output + + +def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None): results = [] uis = [] - return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True) - - for r in return_values: + subgraph_results = [] + return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) + has_subgraph = False + for i in range(len(return_values)): + r = return_values[i] if isinstance(r, dict): if 'ui' in r: uis.append(r['ui']) - if 'result' in r: - results.append(r['result']) + if 'expand' in r: + # Perform an expansion, but do not append results + has_subgraph = True + new_graph = r['expand'] + result = r.get("result", None) + if isinstance(result, ExecutionBlocker): + result = tuple([result] * len(obj.RETURN_TYPES)) + subgraph_results.append((new_graph, result)) + elif 'result' in r: + result = r.get("result", None) + if isinstance(result, ExecutionBlocker): + result = tuple([result] * len(obj.RETURN_TYPES)) + results.append(result) + subgraph_results.append((None, result)) else: + if isinstance(r, ExecutionBlocker): + r = tuple([r] * len(obj.RETURN_TYPES)) results.append(r) + subgraph_results.append((None, r)) - output = [] - if len(results) > 0: - # check which outputs need concatenating - output_is_list = [False] * len(results[0]) - if hasattr(obj, "OUTPUT_IS_LIST"): - output_is_list = obj.OUTPUT_IS_LIST - - # merge node execution results - for i, is_list in zip(range(len(results[0])), output_is_list): - if is_list: - output.append([x for o in results for x in o[i]]) - else: - output.append([o[i] for o in results]) - + if has_subgraph: + output = subgraph_results + elif len(results) > 0: + output = merge_result_data(results, obj) + else: + output = [] ui = dict() if len(uis) > 0: ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} - return output, ui + return output, ui, has_subgraph def format_value(x) -> FormattedValue: @@ -173,63 +280,148 @@ def format_value(x) -> FormattedValue: return str(x) -def recursive_execute(server: ExecutorToClientProgress, - prompt, - outputs, - current_item, - extra_data, - executed, - prompt_id, - outputs_ui, - object_storage) -> RecursiveExecutionTuple: +def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results): unique_id = current_item - inputs = prompt[unique_id]['inputs'] - class_type = prompt[unique_id]['class_type'] + real_node_id = dynprompt.get_real_node_id(unique_id) + display_node_id = dynprompt.get_display_node_id(unique_id) + parent_node_id = dynprompt.get_parent_node_id(unique_id) + inputs = dynprompt.get_node(unique_id)['inputs'] + class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - if unique_id in outputs: - return RecursiveExecutionTuple(True, None, None) - - for x in inputs: - input_data = inputs[x] - - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id not in outputs: - result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, - outputs_ui, object_storage) - if result[0] is not True: - # Another node failed further upstream - return result + if caches.outputs.get(unique_id) is not None: + if server.client_id is not None: + cached_output = caches.ui.get(unique_id) or {} + server.send_sync("executed", {"node": unique_id, "display_node": display_node_id, "output": cached_output.get("output", None), "prompt_id": prompt_id}, server.client_id) + return RecursiveExecutionTuple(ExecutionResult.SUCCESS, None, None) input_data_all = None try: - input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) - if server.client_id is not None: - server.last_node_id = unique_id - server.send_sync("executing", {"node": unique_id, "prompt_id": prompt_id}, server.client_id) + if unique_id in pending_subgraph_results: + cached_results = pending_subgraph_results[unique_id] + resolved_outputs = [] + for is_subgraph, result in cached_results: + if not is_subgraph: + resolved_outputs.append(result) + else: + resolved_output = [] + for r in result: + if is_link(r): + source_node, source_output = r[0], r[1] + node_output = caches.outputs.get(source_node)[source_output] + for o in node_output: + resolved_output.append(o) - obj = object_storage.get((unique_id, class_type), None) - if obj is None: - obj = class_def() - object_storage[(unique_id, class_type)] = obj - - output_data, output_ui = get_output_data(obj, input_data_all) - outputs[unique_id] = output_data - if len(output_ui) > 0: - outputs_ui[unique_id] = output_ui + else: + resolved_output.append(r) + resolved_outputs.append(tuple(resolved_output)) + output_data = merge_result_data(resolved_outputs, class_def) + output_ui = [] + has_subgraph = False + else: + input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) if server.client_id is not None: - server.send_sync("executed", {"node": unique_id, "output": output_ui, "prompt_id": prompt_id}, + server.last_node_id = display_node_id + server.send_sync("executing", {"node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id}, server.client_id) + + obj = caches.objects.get(unique_id) + if obj is None: + obj = class_def() + caches.objects.set(unique_id, obj) + + if hasattr(obj, "check_lazy_status"): + required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True) + required_inputs = set(sum([r for r in required_inputs if isinstance(r, list)], [])) + required_inputs = [x for x in required_inputs if isinstance(x, str) and ( + x not in input_data_all or x in missing_keys + )] + if len(required_inputs) > 0: + for i in required_inputs: + execution_list.make_input_strong_link(unique_id, i) + return (ExecutionResult.PENDING, None, None) + + def execution_block_cb(block): + if block.message is not None: + mes = { + "prompt_id": prompt_id, + "node_id": unique_id, + "node_type": class_type, + "executed": list(executed), + + "exception_message": f"Execution Blocked: {block.message}", + "exception_type": "ExecutionBlocked", + "traceback": [], + "current_inputs": [], + "current_outputs": [], + } + server.send_sync("execution_error", mes, server.client_id) + return ExecutionBlocker(None) + else: + return block + + def pre_execute_cb(call_index): + GraphBuilder.set_default_prefix(unique_id, call_index, 0) + + output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb) + if len(output_ui) > 0: + caches.ui.set(unique_id, { + "meta": { + "node_id": unique_id, + "display_node": display_node_id, + "parent_node": parent_node_id, + "real_node_id": real_node_id, + }, + "output": output_ui + }) + if server.client_id is not None: + server.send_sync("executed", {"node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id}, server.client_id) + if has_subgraph: + cached_outputs = [] + new_node_ids = [] + new_output_ids = [] + new_output_links = [] + for i in range(len(output_data)): + new_graph, node_outputs = output_data[i] + if new_graph is None: + cached_outputs.append((False, node_outputs)) + else: + # Check for conflicts + for node_id in new_graph.keys(): + if dynprompt.has_node(node_id): + raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.") + for node_id, node_info in new_graph.items(): + new_node_ids.append(node_id) + display_id = node_info.get("override_display_id", unique_id) + dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id) + # Figure out if the newly created node is an output node + class_type = node_info["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: + new_output_ids.append(node_id) + for i in range(len(node_outputs)): + if is_link(node_outputs[i]): + from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1] + new_output_links.append((from_node_id, from_socket)) + cached_outputs.append((True, node_outputs)) + new_node_ids = set(new_node_ids) + for cache in caches.all: + cache.ensure_subcache_for(unique_id, new_node_ids).clean_unused() + for node_id in new_output_ids: + execution_list.add_node(node_id) + for link in new_output_links: + execution_list.add_strong_link(link[0], link[1], unique_id) + pending_subgraph_results[unique_id] = cached_outputs + return (ExecutionResult.PENDING, None, None) + caches.outputs.set(unique_id, output_data) except interruption.InterruptProcessingException as iex: logging.info("Processing interrupted") # skip formatting inputs/outputs error_details: RecursiveExecutionErrorDetailsInterrupted = { - "node_id": unique_id, + "node_id": real_node_id, } - return RecursiveExecutionTuple(False, error_details, iex) + return RecursiveExecutionTuple(ExecutionResult.FAILURE, error_details, iex) except Exception as ex: typ, _, tb = sys.exc_info() exception_type = full_type_name(typ) @@ -239,126 +431,36 @@ def recursive_execute(server: ExecutorToClientProgress, for name, inputs in input_data_all.items(): input_data_formatted[name] = [format_value(x) for x in inputs] - output_data_formatted = {} - for node_id, node_outputs in outputs.items(): - output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs] - - logging.error(f"!!! Exception during processing!!! {ex}") + logging.error(f"!!! Exception during processing !!! {ex}") logging.error(traceback.format_exc()) error_details: RecursiveExecutionErrorDetails = { - "node_id": unique_id, + "node_id": real_node_id, "exception_message": str(ex), "exception_type": exception_type, "traceback": traceback.format_tb(tb), - "current_inputs": input_data_formatted, - "current_outputs": output_data_formatted + "current_inputs": input_data_formatted } - return RecursiveExecutionTuple(False, error_details, ex) + return RecursiveExecutionTuple(ExecutionResult.FAILURE, error_details, ex) executed.add(unique_id) - return RecursiveExecutionTuple(True, None, None) - - -_memo = {} - - -def recursive_will_execute(prompt, outputs, current_item, memo: dict[str, list[str]] = None) -> list[str]: - if memo is None: - memo = _memo - unique_id = current_item - - if unique_id in memo: - return memo[unique_id] - - inputs = prompt[unique_id]['inputs'] - will_execute = [] - if unique_id in outputs: - return [] - - for x in inputs: - input_data = inputs[x] - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id not in outputs: - will_execute += recursive_will_execute(prompt, outputs, input_unique_id, memo) - - memo[unique_id] = will_execute + [unique_id] - return memo[unique_id] - - -def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item): - unique_id = current_item - inputs = prompt[unique_id]['inputs'] - class_type = prompt[unique_id]['class_type'] - class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - - is_changed_old = '' - is_changed = '' - to_delete = False - if hasattr(class_def, 'IS_CHANGED'): - if unique_id in old_prompt and 'is_changed' in old_prompt[unique_id]: - is_changed_old = old_prompt[unique_id]['is_changed'] - if 'is_changed' not in prompt[unique_id]: - input_data_all = get_input_data(inputs, class_def, unique_id, outputs) - if input_data_all is not None: - try: - # is_changed = class_def.IS_CHANGED(**input_data_all) - is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") - prompt[unique_id]['is_changed'] = is_changed - except: - to_delete = True - else: - is_changed = prompt[unique_id]['is_changed'] - - if unique_id not in outputs: - return True - - if not to_delete: - if is_changed != is_changed_old: - to_delete = True - elif unique_id not in old_prompt: - to_delete = True - elif class_type != old_prompt[unique_id]['class_type']: - to_delete = True - elif inputs == old_prompt[unique_id]['inputs']: - for x in inputs: - input_data = inputs[x] - - if isinstance(input_data, list): - input_unique_id = input_data[0] - output_index = input_data[1] - if input_unique_id in outputs: - to_delete = recursive_output_delete_if_changed(prompt, old_prompt, outputs, input_unique_id) - else: - to_delete = True - if to_delete: - break - else: - to_delete = True - - if to_delete: - d = outputs.pop(unique_id) - del d - return to_delete + return (ExecutionResult.SUCCESS, None, None) class PromptExecutor: - def __init__(self, server: ExecutorToClientProgress): + def __init__(self, server: ExecutorToClientProgress, lru_size=None): self.success = None + self.lru_size = lru_size self.server = server self.raise_exceptions = False self.reset() + self.history_result: HistoryResultDict | None = None def reset(self): - self.outputs = {} - self.object_storage = {} - self.outputs_ui = {} - self.status_messages = [] self.success = True - self.old_prompt = {} + self.caches = CacheSet(self.lru_size) + self.status_messages = [] def add_message(self, event, data, broadcast: bool): self.status_messages.append((event, data)) @@ -389,27 +491,14 @@ class PromptExecutor: "node_id": node_id, "node_type": class_type, "executed": list(executed), - "exception_message": error["exception_message"], "exception_type": error["exception_type"], "traceback": error["traceback"], "current_inputs": error["current_inputs"], - "current_outputs": error["current_outputs"], + "current_outputs": list(current_outputs), } self.add_message("execution_error", mes, broadcast=False) - # Next, remove the subsequent outputs since they will not be executed - to_delete = [] - for o in self.outputs: - if (o not in current_outputs) and (o not in executed): - to_delete += [o] - if o in self.old_prompt: - d = self.old_prompt.pop(o) - del d - for o in to_delete: - d = self.outputs.pop(o) - del d - if ex is not None and self.raise_exceptions: raise ex @@ -433,69 +522,59 @@ class PromptExecutor: self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False) with torch.inference_mode(): - # delete cached outputs if nodes don't exist for them - to_delete = [] - for o in self.outputs: - if o not in prompt: - to_delete += [o] - for o in to_delete: - d = self.outputs.pop(o) - del d - to_delete = [] - for o in self.object_storage: - if o[0] not in prompt: - to_delete += [o] - else: - p = prompt[o[0]] - if o[1] != p['class_type']: - to_delete += [o] - for o in to_delete: - d = self.object_storage.pop(o) - del d + dynamic_prompt = DynamicPrompt(prompt) + is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs) + for cache in self.caches.all: + cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) + cache.clean_unused() - for x in prompt: - recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) - - current_outputs = set(self.outputs.keys()) - for x in list(self.outputs_ui.keys()): - if x not in current_outputs: - d = self.outputs_ui.pop(x) - del d + current_outputs = self.caches.outputs.all_node_ids() model_management.cleanup_models(keep_clone_weights_loaded=True) self.add_message("execution_cached", {"nodes": list(current_outputs), "prompt_id": prompt_id}, broadcast=False) + pending_subgraph_results = {} executed = set() - output_node_id = None - to_execute = [] - + execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) for node_id in list(execute_outputs): - to_execute += [(0, node_id)] + execution_list.add_node(node_id) - while len(to_execute) > 0: - # always execute the output that depends on the least amount of unexecuted nodes first - memo = {} - to_execute = sorted(list( - map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1], memo)), a[-1]), to_execute))) - output_node_id = to_execute.pop(0)[-1] - - # This call shouldn't raise anything if there's an error deep in - # the actual SD code, instead it will report the node where the - # error was raised - self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, - extra_data, executed, prompt_id, self.outputs_ui, - self.object_storage) - if self.success is not True: - self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) + while not execution_list.is_empty(): + node_id, error, ex = execution_list.stage_node_execution() + if error is not None: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) break - for x in executed: - self.old_prompt[x] = copy.deepcopy(prompt[x]) + result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results) + if result == ExecutionResult.FAILURE: + self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + break + elif result == ExecutionResult.PENDING: + execution_list.unstage_node_execution() + else: # result == ExecutionResult.SUCCESS: + execution_list.complete_node_execution() + + ui_outputs = {} + meta_outputs = {} + all_node_ids = self.caches.ui.all_node_ids() + for node_id in all_node_ids: + ui_info = self.caches.ui.get(node_id) + if ui_info is not None: + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = ui_info["meta"] + self.history_result = { + "outputs": ui_outputs, + "meta": meta_outputs, + } self.server.last_node_id = None if model_management.DISABLE_SMART_MEMORY: model_management.unload_all_models() + @property + def outputs_ui(self) -> dict | None: + return self.history_result["outputs"] if self.history_result is not None else None + def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTuple]) -> ValidateInputsTuple: # todo: this should check if LoadImage / LoadImageMask paths exist @@ -509,7 +588,7 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] class_inputs = obj_class.INPUT_TYPES() - required_inputs = class_inputs['required'] + valid_inputs = set(class_inputs.get('required', {})).union(set(class_inputs.get('optional', {}))) error: ValidationErrorDict errors = [] @@ -522,23 +601,26 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl validate_function_inputs = [] if hasattr(obj_class, "VALIDATE_INPUTS"): validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args + received_types = {} - for x in required_inputs: + for x in valid_inputs: + type_input, input_category, extra_info = get_input_info(obj_class, x) + assert extra_info is not None if x not in inputs: - error = { - "type": "required_input_missing", - "message": "Required input is missing", - "details": f"{x}", - "extra_info": { - "input_name": x + if input_category == "required": + error = { + "type": "required_input_missing", + "message": "Required input is missing", + "details": f"{x}", + "extra_info": { + "input_name": x + } } - } - errors.append(error) + errors.append(error) continue val = inputs[x] - info: InputTypeSpec = required_inputs[x] - type_input = info[0] + info: InputTypeSpec = (type_input, extra_info) if isinstance(val, list): if len(val) != 2: error = { @@ -557,11 +639,9 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES - type_input_from_prompt = r[val[1]] - is_combo = all(isinstance(x, typing.List) or isinstance(x, typing.Tuple) for x in (type_input, type_input_from_prompt)) - is_invalid_string_to_combo = is_combo and len(type_input_from_prompt) != 0 - if type_input_from_prompt != type_input and is_invalid_string_to_combo: - received_type = type_input_from_prompt + received_type = r[val[1]] + received_types[x] = received_type + if 'input_types' not in validate_function_inputs and received_type != type_input: details = f"{x}, {received_type} != {type_input}" error = { "type": "return_type_mismatch", @@ -612,6 +692,9 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl if type_input == "STRING": val = str(val) inputs[x] = val + if type_input == "BOOLEAN": + val = bool(val) + inputs[x] = val except Exception as ex: error = { "type": "invalid_input_type", @@ -627,7 +710,7 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl errors.append(error) continue - if len(info) > 1: + if x not in validate_function_inputs: has_min_max: IntSpecOptions | FloatSpecOptions = info[1] if "min" in has_min_max and val < has_min_max["min"]: error = { @@ -656,7 +739,6 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl errors.append(error) continue - if x not in validate_function_inputs: if isinstance(type_input, list): if val not in type_input: input_config = info @@ -684,17 +766,19 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl continue if len(validate_function_inputs) > 0: - input_data_all = get_input_data(inputs, obj_class, unique_id) + input_data_all, _ = get_input_data(inputs, obj_class, unique_id) input_filtered = {} for x in input_data_all: if x in validate_function_inputs: input_filtered[x] = input_data_all[x] + if 'input_types' in validate_function_inputs: + input_filtered['input_types'] = [received_types] # ret = obj_class.VALIDATE_INPUTS(**input_filtered) ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") for x in input_filtered: for i, r in enumerate(ret): - if r is not True: + if r is not True and not isinstance(r, ExecutionBlocker): details = f"{x}" if r is not False: details += f" - {str(r)}" @@ -705,8 +789,6 @@ def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTupl "details": details, "extra_info": { "input_name": x, - "input_config": info, - "received_value": val, } } errors.append(error) @@ -899,6 +981,7 @@ class PromptQueue(AbstractPromptQueue): def task_done(self, item_id: str, outputs: dict, status: Optional[ExecutionStatus]): + history_result = outputs with self.mutex: queue_item = self.currently_running.pop(item_id) prompt = queue_item.queue_tuple @@ -909,14 +992,16 @@ class PromptQueue(AbstractPromptQueue): if status is not None: status_dict = copy.deepcopy(ExecutionStatus(*status)._asdict()) + outputs_ = history_result["outputs"] self.history[prompt[1]] = { "prompt": prompt, - "outputs": copy.deepcopy(outputs), + "outputs": copy.deepcopy(outputs_), 'status': status_dict, } + self.history[prompt[1]].update(history_result) self.server.queue_updated() if queue_item.completed: - queue_item.completed.set_result(outputs) + queue_item.completed.set_result(outputs_) def get_current_queue(self) -> Tuple[typing.List[QueueTuple], typing.List[QueueTuple]]: with self.mutex: diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index 249dda843..822d3b23a 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -45,7 +45,7 @@ def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer): e.execute(item[2], prompt_id, item[3], item[4]) need_gc = True q.task_done(item_id, - e.outputs_ui, + e.history_result, status=ExecutionStatus( status_str='success' if e.success else 'error', completed=e.success, diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index bcf061aa9..50ca11a66 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -437,6 +437,7 @@ class PromptServer(ExecutorToClientProgress): obj_class = self.nodes.NODE_CLASS_MAPPINGS[node_class] info = {} info['input'] = obj_class.INPUT_TYPES() + info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} info['output'] = obj_class.RETURN_TYPES info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] @@ -829,6 +830,9 @@ class PromptServer(ExecutorToClientProgress): site = web.TCPSite(runner, host=address, port=port) await site.start() + self.address = address + self.port = port + if verbose: logging.info("Starting server\n") logging.info("To see the GUI go to: http://{}:{}".format("localhost" if address == "0.0.0.0" else address, port)) diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py index d75f8e212..215c6c309 100644 --- a/comfy/component_model/executor_types.py +++ b/comfy/component_model/executor_types.py @@ -1,7 +1,8 @@ from __future__ import annotations # for Python 3.7-3.9 import typing -from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple, List +from enum import Enum +from typing import Optional, Literal, Protocol, Union, NamedTuple, List import PIL.Image from typing_extensions import NotRequired, TypedDict @@ -45,11 +46,23 @@ class UnencodedPreviewImageMessage(NamedTuple): max_size: int = 512 -ExecutedMessage: TypeAlias = ExecutingMessage +class ExecutionErrorMessage(TypedDict): + prompt_id: str + node_id: str + node_type: str + executed: list[str] + exception_message: str + exception_type: str + traceback: list[str] + current_inputs: list[typing.Never] | dict[str, FormattedValue] + current_outputs: list[str] -SendSyncEvent: TypeAlias = Union[Literal["status", "executing", "progress", "executed"], BinaryEventTypes, None] -SendSyncData: TypeAlias = Union[StatusMessage, ExecutingMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None] +ExecutedMessage = ExecutingMessage + +SendSyncEvent = Union[Literal["status", "execution_error", "executing", "progress", "executed"], BinaryEventTypes, None] + +SendSyncData = Union[StatusMessage, ExecutingMessage, ExecutionErrorMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None] class ExecutorToClientProgress(Protocol): @@ -150,3 +163,33 @@ class RecursiveExecutionTuple(typing.NamedTuple): valid: bool error_details: Optional[RecursiveExecutionErrorDetails | RecursiveExecutionErrorDetailsInterrupted] exc_info: Optional[Exception] + + +class ExecutionResult(Enum): + SUCCESS = 0 + FAILURE = 1 + PENDING = 2 + + def __bool__(self): + return self == 0 + + +class DuplicateNodeError(Exception): + pass + + +class HistoryResultDict(TypedDict, total=True): + outputs: dict + meta: dict + + +class DependencyCycleError(Exception): + pass + + +class NodeInputError(Exception): + pass + + +class NodeNotFoundError(Exception): + pass diff --git a/comfy/distributed/distributed_prompt_queue.py b/comfy/distributed/distributed_prompt_queue.py index 4c2b59ea5..7566f20ec 100644 --- a/comfy/distributed/distributed_prompt_queue.py +++ b/comfy/distributed/distributed_prompt_queue.py @@ -21,7 +21,7 @@ from ..auth.permissions import jwt_decode from ..cmd.main_pre import tracer from ..cmd.server import PromptServer from ..component_model.abstract_prompt_queue import AsyncAbstractPromptQueue, AbstractPromptQueue -from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData +from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData, HistoryResultDict from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation, \ ExecutionError @@ -164,6 +164,9 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue): def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus]): # callee: executed on the worker thread + if "outputs" in outputs: + outputs: HistoryResultDict + outputs = outputs["outputs"] assert self._is_callee pending = self._callee_local_in_progress.pop(item_id) assert pending is not None diff --git a/comfy/graph.py b/comfy/graph.py new file mode 100644 index 000000000..74babe200 --- /dev/null +++ b/comfy/graph.py @@ -0,0 +1,234 @@ +from .cmd.execution import nodes +from .component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError +from .graph_utils import is_link + + +class DynamicPrompt: + def __init__(self, original_prompt): + # The original prompt provided by the user + self.original_prompt = original_prompt + # Any extra pieces of the graph created during execution + self.ephemeral_prompt = {} + self.ephemeral_parents = {} + self.ephemeral_display = {} + + def get_node(self, node_id): + if node_id in self.ephemeral_prompt: + return self.ephemeral_prompt[node_id] + if node_id in self.original_prompt: + return self.original_prompt[node_id] + raise NodeNotFoundError(f"Node {node_id} not found") + + def has_node(self, node_id): + return node_id in self.original_prompt or node_id in self.ephemeral_prompt + + def add_ephemeral_node(self, node_id, node_info, parent_id, display_id): + self.ephemeral_prompt[node_id] = node_info + self.ephemeral_parents[node_id] = parent_id + self.ephemeral_display[node_id] = display_id + + def get_real_node_id(self, node_id): + while node_id in self.ephemeral_parents: + node_id = self.ephemeral_parents[node_id] + return node_id + + def get_parent_node_id(self, node_id): + return self.ephemeral_parents.get(node_id, None) + + def get_display_node_id(self, node_id): + while node_id in self.ephemeral_display: + node_id = self.ephemeral_display[node_id] + return node_id + + def all_node_ids(self): + return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys())) + + def get_original_prompt(self): + return self.original_prompt + + +def get_input_info(class_def, input_name): + valid_inputs = class_def.INPUT_TYPES() + input_info = None + input_category = None + if "required" in valid_inputs and input_name in valid_inputs["required"]: + input_category = "required" + input_info = valid_inputs["required"][input_name] + elif "optional" in valid_inputs and input_name in valid_inputs["optional"]: + input_category = "optional" + input_info = valid_inputs["optional"][input_name] + elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]: + input_category = "hidden" + input_info = valid_inputs["hidden"][input_name] + if input_info is None: + return None, None, None + input_type = input_info[0] + if len(input_info) > 1: + extra_info = input_info[1] + else: + extra_info = {} + return input_type, input_category, extra_info + + +class TopologicalSort: + def __init__(self, dynprompt): + self.dynprompt = dynprompt + self.pendingNodes = {} + self.blockCount = {} # Number of nodes this node is directly blocked by + self.blocking = {} # Which nodes are blocked by this node + + def get_input_info(self, unique_id, input_name): + class_type = self.dynprompt.get_node(unique_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + return get_input_info(class_def, input_name) + + def make_input_strong_link(self, to_node_id, to_input): + inputs = self.dynprompt.get_node(to_node_id)["inputs"] + if to_input not in inputs: + raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all") + value = inputs[to_input] + if not is_link(value): + raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant") + from_node_id, from_socket = value + self.add_strong_link(from_node_id, from_socket, to_node_id) + + def add_strong_link(self, from_node_id, from_socket, to_node_id): + self.add_node(from_node_id) + if to_node_id not in self.blocking[from_node_id]: + self.blocking[from_node_id][to_node_id] = {} + self.blockCount[to_node_id] += 1 + self.blocking[from_node_id][to_node_id][from_socket] = True + + def add_node(self, unique_id, include_lazy=False, subgraph_nodes=None): + if unique_id in self.pendingNodes: + return + self.pendingNodes[unique_id] = True + self.blockCount[unique_id] = 0 + self.blocking[unique_id] = {} + + inputs = self.dynprompt.get_node(unique_id)["inputs"] + for input_name in inputs: + value = inputs[input_name] + if is_link(value): + from_node_id, from_socket = value + if subgraph_nodes is not None and from_node_id not in subgraph_nodes: + continue + input_type, input_category, input_info = self.get_input_info(unique_id, input_name) + is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] + if include_lazy or not is_lazy: + self.add_strong_link(from_node_id, from_socket, unique_id) + + def get_ready_nodes(self): + return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0] + + def pop_node(self, unique_id): + del self.pendingNodes[unique_id] + for blocked_node_id in self.blocking[unique_id]: + self.blockCount[blocked_node_id] -= 1 + del self.blocking[unique_id] + + def is_empty(self): + return len(self.pendingNodes) == 0 + + +class ExecutionList(TopologicalSort): + """ + ExecutionList implements a topological dissolve of the graph. After a node is staged for execution, + it can still be returned to the graph after having further dependencies added. + """ + + def __init__(self, dynprompt, output_cache): + super().__init__(dynprompt) + self.output_cache = output_cache + self.staged_node_id = None + + def add_strong_link(self, from_node_id, from_socket, to_node_id): + if self.output_cache.get(from_node_id) is not None: + # Nothing to do + return + super().add_strong_link(from_node_id, from_socket, to_node_id) + + def stage_node_execution(self): + assert self.staged_node_id is None + if self.is_empty(): + return None, None, None + available = self.get_ready_nodes() + if len(available) == 0: + cycled_nodes = self.get_nodes_in_cycle() + # Because cycles composed entirely of static nodes are caught during initial validation, + # we will 'blame' the first node in the cycle that is not a static node. + blamed_node = cycled_nodes[0] + for node_id in cycled_nodes: + display_node_id = self.dynprompt.get_display_node_id(node_id) + if display_node_id != node_id: + blamed_node = display_node_id + break + ex = DependencyCycleError("Dependency cycle detected") + error_details = { + "node_id": blamed_node, + "exception_message": str(ex), + "exception_type": "graph.DependencyCycleError", + "traceback": [], + "current_inputs": [] + } + return None, error_details, ex + next_node = available[0] + # If an output node is available, do that first. + # Technically this has no effect on the overall length of execution, but it feels better as a user + # for a PreviewImage to display a result as soon as it can + # Some other heuristics could probably be used here to improve the UX further. + for node_id in available: + class_type = self.dynprompt.get_node(node_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: + next_node = node_id + break + self.staged_node_id = next_node + return self.staged_node_id, None, None + + def unstage_node_execution(self): + assert self.staged_node_id is not None + self.staged_node_id = None + + def complete_node_execution(self): + node_id = self.staged_node_id + self.pop_node(node_id) + self.staged_node_id = None + + def get_nodes_in_cycle(self): + # We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle. + # We're skipping some of the performance optimizations from the original TopologicalSort to keep + # the code simple (and because having a cycle in the first place is a catastrophic error) + blocked_by = {node_id: {} for node_id in self.pendingNodes} + for from_node_id in self.blocking: + for to_node_id in self.blocking[from_node_id]: + if True in self.blocking[from_node_id][to_node_id].values(): + blocked_by[to_node_id][from_node_id] = True + to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0] + while len(to_remove) > 0: + for node_id in to_remove: + for to_node_id in blocked_by: + if node_id in blocked_by[to_node_id]: + del blocked_by[to_node_id][node_id] + del blocked_by[node_id] + to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0] + return list(blocked_by.keys()) + + +class ExecutionBlocker: + """ + Return this from a node and any users will be blocked with the given error message. + If the message is None, execution will be blocked silently instead. + Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's + possible, a lazy input will be more efficient and have a better user experience. + This functionality is useful in two cases: + 1. You want to conditionally prevent an output node from executing. (Particularly a built-in node + like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using + lazy evaluation to let it conditionally disable itself.) + 2. You have a node with multiple possible outputs, some of which are invalid and should not be used. + (I would recommend not making nodes like this in the future -- instead, make multiple nodes with + different outputs. Unfortunately, there are several popular existing nodes using this pattern.) + """ + + def __init__(self, message): + self.message = message diff --git a/comfy/graph_utils.py b/comfy/graph_utils.py new file mode 100644 index 000000000..dfbdc147a --- /dev/null +++ b/comfy/graph_utils.py @@ -0,0 +1,143 @@ +def is_link(obj): + if not isinstance(obj, list): + return False + if len(obj) != 2: + return False + if not isinstance(obj[0], str): + return False + if not isinstance(obj[1], int) and not isinstance(obj[1], float): + return False + return True + + +class GraphBuilder: + """ + The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end + """ + _default_prefix_root = "" + _default_prefix_call_index = 0 + _default_prefix_graph_index = 0 + + def __init__(self, prefix=None): + if prefix is None: + self.prefix = GraphBuilder.alloc_prefix() + else: + self.prefix = prefix + self.nodes = {} + self.id_gen = 1 + + @classmethod + def set_default_prefix(cls, prefix_root, call_index, graph_index=0): + cls._default_prefix_root = prefix_root + cls._default_prefix_call_index = call_index + cls._default_prefix_graph_index = graph_index + + @classmethod + def alloc_prefix(cls, root=None, call_index=None, graph_index=None): + if root is None: + root = GraphBuilder._default_prefix_root + if call_index is None: + call_index = GraphBuilder._default_prefix_call_index + if graph_index is None: + graph_index = GraphBuilder._default_prefix_graph_index + result = f"{root}.{call_index}.{graph_index}." + GraphBuilder._default_prefix_graph_index += 1 + return result + + def node(self, class_type, id=None, **kwargs): + if id is None: + id = str(self.id_gen) + self.id_gen += 1 + id = self.prefix + id + if id in self.nodes: + return self.nodes[id] + + node = Node(id, class_type, kwargs) + self.nodes[id] = node + return node + + def lookup_node(self, id): + id = self.prefix + id + return self.nodes.get(id) + + def finalize(self): + output = {} + for node_id, node in self.nodes.items(): + output[node_id] = node.serialize() + return output + + def replace_node_output(self, node_id, index, new_value): + node_id = self.prefix + node_id + to_remove = [] + for node in self.nodes.values(): + for key, value in node.inputs.items(): + if is_link(value) and value[0] == node_id and value[1] == index: + if new_value is None: + to_remove.append((node, key)) + else: + node.inputs[key] = new_value + for node, key in to_remove: + del node.inputs[key] + + def remove_node(self, id): + id = self.prefix + id + del self.nodes[id] + + +class Node: + def __init__(self, id, class_type, inputs): + self.id = id + self.class_type = class_type + self.inputs = inputs + self.override_display_id = None + + def out(self, index): + return [self.id, index] + + def set_input(self, key, value): + if value is None: + if key in self.inputs: + del self.inputs[key] + else: + self.inputs[key] = value + + def get_input(self, key): + return self.inputs.get(key) + + def set_override_display_id(self, override_display_id): + self.override_display_id = override_display_id + + def serialize(self): + serialized = { + "class_type": self.class_type, + "inputs": self.inputs + } + if self.override_display_id is not None: + serialized["override_display_id"] = self.override_display_id + return serialized + + +def add_graph_prefix(graph, outputs, prefix): + # Change the node IDs and any internal links + new_graph = {} + for node_id, node_info in graph.items(): + # Make sure the added nodes have unique IDs + new_node_id = prefix + node_id + new_node = {"class_type": node_info["class_type"], "inputs": {}} + for input_name, input_value in node_info.get("inputs", {}).items(): + if is_link(input_value): + new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]] + else: + new_node["inputs"][input_name] = input_value + new_graph[new_node_id] = new_node + + # Change the node IDs in the outputs + new_outputs = [] + for n in range(len(outputs)): + output = outputs[n] + if is_link(output): + new_outputs.append([prefix + output[0], output[1]]) + else: + new_outputs.append(output) + + return new_graph, tuple(new_outputs) diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py index e78aa6a58..cc6cfd003 100644 --- a/comfy/nodes/package_typing.py +++ b/comfy/nodes/package_typing.py @@ -15,6 +15,7 @@ class IntSpecOptions(TypedDict, total=True): max: int step: NotRequired[int] display: NotRequired[Literal["number", "slider"]] + lazy: NotRequired[bool] class FloatSpecOptions(TypedDict, total=True): @@ -24,20 +25,24 @@ class FloatSpecOptions(TypedDict, total=True): step: NotRequired[float] round: NotRequired[float] display: NotRequired[Literal["number", "slider"]] + lazy: NotRequired[bool] class StringSpecOptions(TypedDict, total=True): multiline: NotRequired[bool] default: NotRequired[str] dynamicPrompts: NotRequired[bool] + lazy: NotRequired[bool] class BoolSpecOptions(TypedDict): default: NotRequired[bool] + lazy: NotRequired[bool] class DefaultSpecOptions(TypedDict): default: NotRequired[Any] + lazy: NotRequired[bool] # todo: analyze the base_nodes for these types @@ -129,6 +134,20 @@ class CustomNode(Protocol): def __call__(cls, *args, **kwargs) -> 'CustomNode': ... + def check_lazy_status(self, *args, **kwargs) -> list[str]: + """ + Return a list of input names that need to be evaluated. + + This function will be called if there are any lazy inputs which have not yet been + evaluated. As long as you return at least one field which has not yet been evaluated + (and more exist), this function will be called again once the value of the requested + field is available. + + Any evaluated inputs will be passed as arguments to this function. Any unevaluated + inputs will have the value None. + """ + ... + @dataclass class ExportedNodes: diff --git a/comfy/web/extensions/core/groupNode.js b/comfy/web/extensions/core/groupNode.js index 9a2238908..163e42b89 100644 --- a/comfy/web/extensions/core/groupNode.js +++ b/comfy/web/extensions/core/groupNode.js @@ -959,8 +959,8 @@ export class GroupNodeHandler { const executed = handleEvent.call( this, "executed", - (d) => d?.node, - (d, id, node) => ({ ...d, node: id, merge: !node.resetExecution }) + (d) => d?.display_node, + (d, id, node) => ({ ...d, node: id, display_node: id, merge: !node.resetExecution }) ); const onRemoved = node.onRemoved; diff --git a/comfy/web/extensions/core/widgetInputs.js b/comfy/web/extensions/core/widgetInputs.js index 0fba7147d..a66451c64 100644 --- a/comfy/web/extensions/core/widgetInputs.js +++ b/comfy/web/extensions/core/widgetInputs.js @@ -3,7 +3,7 @@ import { app } from "../../scripts/app.js"; import { applyTextReplacements } from "../../scripts/utils.js"; const CONVERTED_TYPE = "converted-widget"; -const VALID_TYPES = ["STRING", "combo", "number", "BOOLEAN"]; +const VALID_TYPES = ["STRING", "combo", "number", "toggle", "BOOLEAN"]; const CONFIG = Symbol(); const GET_CONFIG = Symbol(); const TARGET = Symbol(); // Used for reroutes to specify the real target widget diff --git a/comfy/web/scripts/api.js b/comfy/web/scripts/api.js index 39f0a9bb2..b079692b5 100644 --- a/comfy/web/scripts/api.js +++ b/comfy/web/scripts/api.js @@ -128,7 +128,7 @@ class ComfyApi extends EventTarget { this.dispatchEvent(new CustomEvent("progress", { detail: msg.data })); break; case "executing": - this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node })); + this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.display_node })); break; case "executed": this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); diff --git a/comfy/web/scripts/app.js b/comfy/web/scripts/app.js index 8b4478a32..341aadf92 100644 --- a/comfy/web/scripts/app.js +++ b/comfy/web/scripts/app.js @@ -58,7 +58,7 @@ export class ComfyApp { this.bodyTop = $el("div.comfyui-body-top", { parent: document.body }); this.bodyLeft = $el("div.comfyui-body-left", { parent: document.body }); this.bodyRight = $el("div.comfyui-body-right", { parent: document.body }); - this.bodyBottom = $el("div.comfyui-body-bottom", { parent: document.body }); + this.bodyBottom = $el("div.comfyui-body-bottom", { parent: document.body }); this.menu = new ComfyAppMenu(this); /** @@ -1330,7 +1330,7 @@ export class ComfyApp { }); api.addEventListener("progress", ({ detail }) => { - if (this.workflowManager.activePrompt?.workflow + if (this.workflowManager.activePrompt?.workflow && this.workflowManager.activePrompt.workflow !== this.workflowManager.activeWorkflow) return; this.progress = detail; this.graph.setDirtyCanvas(true, false); @@ -1348,7 +1348,7 @@ export class ComfyApp { api.addEventListener("executed", ({ detail }) => { if (this.workflowManager.activePrompt ?.workflow && this.workflowManager.activePrompt.workflow !== this.workflowManager.activeWorkflow) return; - const output = this.nodeOutputs[detail.node]; + const output = this.nodeOutputs[detail.display_node]; if (detail.merge && output) { for (const k in detail.output ?? {}) { const v = output[k]; @@ -1359,9 +1359,9 @@ export class ComfyApp { } } } else { - this.nodeOutputs[detail.node] = detail.output; + this.nodeOutputs[detail.display_node] = detail.output; } - const node = this.graph.getNodeById(detail.node); + const node = this.graph.getNodeById(detail.display_node); if (node) { if (node.onExecuted) node.onExecuted(detail.output); @@ -1638,7 +1638,7 @@ export class ComfyApp { resizeCanvas() { // Limit minimal scale to 1, see https://github.com/comfyanonymous/ComfyUI/pull/845 const scale = Math.max(window.devicePixelRatio, 1); - + // Clear fixed width and height while calculating rect so it uses 100% instead this.canvasEl.height = this.canvasEl.width = ""; const { width, height } = this.canvasEl.getBoundingClientRect(); @@ -1874,7 +1874,7 @@ export class ComfyApp { { graphData = structuredClone(graphData); } - + try { this.workflowManager.setWorkflow(workflow); } catch (error) { @@ -1902,7 +1902,7 @@ export class ComfyApp { this.canvas.ds.offset = graphData.extra.ds.offset; this.canvas.ds.scale = graphData.extra.ds.scale; } - + try { this.workflowManager.activeWorkflow?.track() } catch (error) { diff --git a/comfy/web/scripts/ui.js b/comfy/web/scripts/ui.js index 2c47412c9..f35e2e0c3 100644 --- a/comfy/web/scripts/ui.js +++ b/comfy/web/scripts/ui.js @@ -239,7 +239,14 @@ class ComfyList { onclick: async () => { await app.loadGraphData(item.prompt[3].extra_pnginfo.workflow, true, false); if (item.outputs) { - app.nodeOutputs = item.outputs; + app.nodeOutputs = {}; + for (const [key, value] of Object.entries(item.outputs)) { + if (item.meta && item.meta[key] && item.meta[key].display_node) { + app.nodeOutputs[item.meta[key].display_node] = value; + } else { + app.nodeOutputs[key] = value; + } + } } }, }), diff --git a/comfy_extras/nodes/nodes_color.py b/comfy_extras/nodes/nodes_color.py new file mode 100644 index 000000000..cb955df33 --- /dev/null +++ b/comfy_extras/nodes/nodes_color.py @@ -0,0 +1,185 @@ +import colour +import numpy as np +import torch +from scipy.interpolate import interpn +from scipy.interpolate.interpnd import LinearNDInterpolator +from scipy.ndimage import gaussian_filter +from scipy.optimize._lsap import linear_sum_assignment +from sklearn.cluster import KMeans + +import comfy.model_management +from comfy.nodes.package_typing import CustomNode + + +class ColorPaletteExtractor(CustomNode): + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "num_colors": ("INT", {"default": 8, "min": 2, "max": 64, "step": 1}), + } + } + + RETURN_TYPES = ("IMAGE", "PALETTE") + FUNCTION = "extract_palette" + CATEGORY = "image/color" + + def extract_palette(self, image, num_colors): + if len(image.shape) == 3: + image = image.unsqueeze(0) + + image_np = image.squeeze().cpu().numpy() + + pixels = image_np.reshape(-1, 3) + + kmeans = KMeans(n_clusters=num_colors, random_state=42, n_init=10) + kmeans.fit(pixels) + + colors = kmeans.cluster_centers_ + + _, counts = np.unique(kmeans.labels_, return_counts=True) + sorted_indices = np.argsort(counts)[::-1] + colors = colors[sorted_indices] + + palette_height = 512 + palette_width = 512 + palette_image = np.zeros((palette_height, palette_width * num_colors, 3), dtype=np.uint8) + for i, color in enumerate(colors): + palette_image[:, i * palette_width:(i + 1) * palette_width] = color + + palette_tensor = torch.from_numpy(palette_image).float() / 255.0 + palette_tensor = palette_tensor.unsqueeze(0).to(comfy.model_management.get_torch_device()) + + color_array = (colors * 255).astype(np.uint8) + + return palette_tensor, color_array + + +class ImageBasedColorRemap(CustomNode): + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "reference_image": ("IMAGE",), + "palette_size": ("INT", {"default": 8, "min": 2, "max": 64, "step": 1}), + "lut_size": ("INT", {"default": 33, "min": 8, "max": 64, "step": 1}), + } + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "remap_colors" + CATEGORY = "image/color" + + def remap_colors(self, image, reference_image, palette_size, lut_size): + # Ensure images are in the correct format (B, H, W, C) + image = image.squeeze(0) if image.ndim == 4 else image + reference_image = reference_image.squeeze(0) if reference_image.ndim == 4 else reference_image + + # Convert torch tensors to numpy arrays + image_np = image.cpu().numpy() + reference_np = reference_image.cpu().numpy() + + # Extract palettes using k-means clustering + image_palette = self.extract_palette(image_np, palette_size) + reference_palette = self.extract_palette(reference_np, palette_size) + + # Convert palettes to LAB color space + image_palette_lab = colour.XYZ_to_Lab(colour.sRGB_to_XYZ(image_palette)) + reference_palette_lab = colour.XYZ_to_Lab(colour.sRGB_to_XYZ(reference_palette)) + + # Align palettes + aligned_image_palette_lab, aligned_reference_palette_lab = self.align_palettes(image_palette_lab, reference_palette_lab) + + # Create 3D LUT in LAB space + lut = self.create_color_remap_3dlut(aligned_image_palette_lab, aligned_reference_palette_lab, size=lut_size) + + # Apply 3D LUT + image_lab = colour.XYZ_to_Lab(colour.sRGB_to_XYZ(image_np)) + remapped_lab = self.apply_lut(image_lab, lut, lut_size) + remapped_rgb = colour.XYZ_to_sRGB(colour.Lab_to_XYZ(remapped_lab)) + + # Clip values to ensure they're in the valid range + remapped_rgb = np.clip(remapped_rgb, 0, 1) + + # Convert back to torch tensor + remapped_tensor = torch.from_numpy(remapped_rgb).float() + remapped_tensor = remapped_tensor.unsqueeze(0) # Add batch dimension + + return (remapped_tensor,) + + def extract_palette(self, image, palette_size): + pixels = image.reshape(-1, 3) + kmeans = KMeans(n_clusters=palette_size, random_state=42, n_init=10) + kmeans.fit(pixels) + return kmeans.cluster_centers_ + + def align_palettes(self, palette1, palette2): + distances = np.linalg.norm(palette1[:, np.newaxis] - palette2, axis=2) + row_ind, col_ind = linear_sum_assignment(distances) + return palette1, palette2[col_ind] + + def create_color_remap_3dlut(self, original_colors, target_colors, size=33): + # Create a regular grid in LAB space + L = np.linspace(0, 100, size) + a = np.linspace(-128, 127, size) + b = np.linspace(-128, 127, size) + grid = np.meshgrid(L, a, b, indexing='ij') + + # Reshape the grid for KNN + grid_points = np.vstack([g.ravel() for g in grid]).T + + # Use KNN to find the nearest neighbor for each grid point + from sklearn.neighbors import NearestNeighbors + nn = NearestNeighbors(n_neighbors=1, algorithm='auto').fit(original_colors) + _, indices = nn.kneighbors(grid_points) + + # Create the LUT using the target colors + lut = target_colors[indices.ravel()] + lut = lut.reshape(size, size, size, 3) + + # Apply Gaussian smoothing to create a more continuous mapping + for i in range(3): + lut[:,:,:,i] = gaussian_filter(lut[:,:,:,i], sigma=1) + + return lut + + def apply_lut(self, image_lab, lut, lut_size): + points = ( + np.linspace(0, 100, lut_size), # L + np.linspace(-128, 127, lut_size), # a + np.linspace(-128, 127, lut_size) # b + ) + + xi = np.stack([ + image_lab[..., 0], + image_lab[..., 1], + image_lab[..., 2] + ], axis=-1) + + remapped_lab = np.zeros_like(image_lab) + for i in range(3): + remapped_lab[..., i] = interpn( + points, + lut[..., i], + xi, + method='linear', + bounds_error=False, + fill_value=None + ) + + return remapped_lab + + +NODE_CLASS_MAPPINGS = { + "ColorPaletteExtractor": ColorPaletteExtractor, + "ImageBasedColorRemap": ImageBasedColorRemap, + +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "ColorPaletteExtractor": "Extract Color Palette", + "ImageBasedColorRemap": "Image-Based Color Remap", + +} diff --git a/custom_nodes/example_node.py.example b/custom_nodes/example_node.py.example index 72ca3688c..9c68ab769 100644 --- a/custom_nodes/example_node.py.example +++ b/custom_nodes/example_node.py.example @@ -54,7 +54,8 @@ class Example: "min": 0, #Minimum value "max": 4096, #Maximum value "step": 64, #Slider's step - "display": "number" # Cosmetic only: display as "number" or "slider" + "display": "number", # Cosmetic only: display as "number" or "slider" + "lazy": True # Will only be evaluated if check_lazy_status requires it }), "float_field": ("FLOAT", { "default": 1.0, @@ -62,11 +63,14 @@ class Example: "max": 10.0, "step": 0.01, "round": 0.001, #The value representing the precision to round to, will be set to the step value by default. Can be set to False to disable rounding. - "display": "number"}), + "display": "number", + "lazy": True + }), "print_to_screen": (["enable", "disable"],), "string_field": ("STRING", { "multiline": False, #True if you want the field to look like the one on the ClipTextEncode node - "default": "Hello World!" + "default": "Hello World!", + "lazy": True }), }, } @@ -80,6 +84,23 @@ class Example: CATEGORY = "Example" + def check_lazy_status(self, image, string_field, int_field, float_field, print_to_screen): + """ + Return a list of input names that need to be evaluated. + + This function will be called if there are any lazy inputs which have not yet been + evaluated. As long as you return at least one field which has not yet been evaluated + (and more exist), this function will be called again once the value of the requested + field is available. + + Any evaluated inputs will be passed as arguments to this function. Any unevaluated + inputs will have the value None. + """ + if print_to_screen == "enable": + return ["int_field", "float_field", "string_field"] + else: + return [] + def test(self, image, string_field, int_field, float_field, print_to_screen): if print_to_screen == "enable": print(f"""Your input contains: diff --git a/pytest.ini b/pytest.ini index 891354251..78b077400 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,6 +1,7 @@ [pytest] markers = inference: mark as inference test (deselect with '-m "not inference"') + execution: mark as execution test (deselect with '-m "not execution"') testpaths = tests tests-unit diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js index e114e5f93..976083ffc 100644 --- a/tests-ui/tests/groupNode.test.js +++ b/tests-ui/tests/groupNode.test.js @@ -443,6 +443,7 @@ describe("group node", () => { new CustomEvent("executed", { detail: { node: `${nodes.save.id}`, + display_node: `${nodes.save.id}`, output: { images: [ { @@ -483,6 +484,7 @@ describe("group node", () => { new CustomEvent("executed", { detail: { node: `${group.id}:5`, + display_node: `${group.id}:5`, output: { images: [ { diff --git a/tests/asyncio/__init__.py b/tests/asyncio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/downloader/test_huggingface_downloads.py b/tests/downloader/test_huggingface_downloads.py index d1182770b..4a6225380 100644 --- a/tests/downloader/test_huggingface_downloads.py +++ b/tests/downloader/test_huggingface_downloads.py @@ -45,6 +45,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text @pytest.mark.asyncio +@pytest.mark.skip("flakey") def test_known_repos(tmp_path_factory): prev_hub_cache = os.getenv("HF_HUB_CACHE") os.environ["HF_HUB_CACHE"] = str(tmp_path_factory.mktemp("huggingface_root_cache")) diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py new file mode 100644 index 000000000..8ecd11c5b --- /dev/null +++ b/tests/inference/test_execution.py @@ -0,0 +1,374 @@ +import uuid +from contextvars import ContextVar +from typing import Dict, Optional + +import numpy +import pytest +from PIL import Image +from pytest import fixture + +from comfy.cli_args import default_configuration +from comfy.client.embedded_comfy_client import EmbeddedComfyClient +from comfy.component_model.executor_types import SendSyncEvent, SendSyncData, ExecutingMessage, ExecutionErrorMessage, DependencyCycleError +from comfy.distributed.server_stub import ServerStub +from comfy.graph_utils import GraphBuilder, Node +from comfy.nodes.package_typing import ExportedNodes + +current_test_name = ContextVar('current_test_name', default=None) + + +@pytest.fixture(autouse=True) +def set_test_name(request): + token = current_test_name.set(request.node.name) + yield + current_test_name.reset(token) + + +class RunResult: + def __init__(self, prompt_id: str): + self.outputs: Dict[str, Dict] = {} + self.runs: Dict[str, bool] = {} + self.prompt_id: str = prompt_id + + def get_output(self, node: Node): + return self.outputs.get(node.id, None) + + def did_run(self, node: Node): + return self.runs.get(node.id, False) + + def get_images(self, node: Node): + output = self.get_output(node) + if output is None: + return [] + return output.get('image_objects', []) + + def get_prompt_id(self): + return self.prompt_id + + +class _ProgressHandler(ServerStub): + def __init__(self): + super().__init__() + self.tuples: list[tuple[SendSyncEvent, SendSyncData, str]] = [] + + def send_sync(self, + event: SendSyncEvent, + data: SendSyncData, + sid: Optional[str] = None): + self.tuples.append((event, data, sid)) + + +class Client: + def __init__(self, embedded_client: EmbeddedComfyClient, progress_handler: _ProgressHandler): + self.embedded_client = embedded_client + self.progress_handler = progress_handler + + async def run(self, graph: GraphBuilder) -> RunResult: + self.progress_handler.tuples = [] + for node in graph.nodes.values(): + if node.class_type == 'SaveImage': + node.inputs['filename_prefix'] = current_test_name.get() + + prompt_id = str(uuid.uuid4()) + try: + outputs = await self.embedded_client.queue_prompt(graph.finalize(), prompt_id=prompt_id) + except (RuntimeError, DependencyCycleError): + outputs = {} + result = RunResult(prompt_id=prompt_id) + result.outputs = outputs + result.runs = {} + send_sync_event: SendSyncEvent + send_sync_data: SendSyncData + for send_sync_event, send_sync_data, _ in self.progress_handler.tuples: + if send_sync_event == "executing": + send_sync_data: ExecutingMessage + result.runs[send_sync_data["node"]] = True + elif send_sync_event == "execution_error": + send_sync_data: ExecutionErrorMessage + raise Exception(send_sync_data) + + for node in outputs.values(): + if "images" in node: + image_objects = node["image_objects"] = [] + for image in node["images"]: + image_objects.append(Image.open(image["abs_path"])) + return result + + +# Loop through these variables +@pytest.mark.execution +class TestExecution: + # Initialize server and client + @fixture(scope="class", params=[ + # (lru_size) + (0,), + (100,), + ]) + async def client(self, request) -> Client: + from comfy.cmd.execution import nodes + from .testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS + + nodes.update(ExportedNodes(NODE_CLASS_MAPPINGS=NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=NODE_DISPLAY_NAME_MAPPINGS)) + lru_size, = request.param + configuration = default_configuration() + configuration.cache_lru = lru_size + progress_handler = _ProgressHandler() + async with EmbeddedComfyClient(configuration, progress_handler=progress_handler) as embedded_client: + yield Client(embedded_client, progress_handler) + + @fixture + def builder(self, request): + yield GraphBuilder(prefix=request.node.name) + + async def test_lazy_input(self, client: Client, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1) + + lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + output = g.node("SaveImage", images=lazy_mix.out(0)) + result = await client.run(g) + + result_image = result.get_images(output)[0] + assert numpy.array(result_image).any() == 0, "Image should be black" + assert result.did_run(input1) + assert not result.did_run(input2) + assert result.did_run(mask) + assert result.did_run(lazy_mix) + + async def test_full_cache(self, client: Client, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + + lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix.out(0)) + + await client.run(g) + result2 = await client.run(g) + for node_id, node in g.nodes.items(): + assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" + + async def test_partial_cache(self, client: Client, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + + lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix.out(0)) + + await client.run(g) + mask.inputs['value'] = 0.4 + result2 = await client.run(g) + assert not result2.did_run(input1), "Input1 should have been cached" + assert not result2.did_run(input2), "Input2 should have been cached" + + async def test_error(self, client: Client, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + # Different size of the two images + input2 = g.node("StubImage", content="NOISE", height=256, width=256, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + + lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix.out(0)) + + try: + await client.run(g) + assert False, "Should have raised an error" + except Exception as e: + assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" + + @pytest.mark.parametrize("test_value, expect_error", [ + (5, True), + ("foo", True), + (5.0, False), + ]) + async def test_validation_error_literal(self, test_value, expect_error, client: Client, builder: GraphBuilder): + g = builder + validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0) + g.node("SaveImage", images=validation1.out(0)) + + if expect_error: + with pytest.raises(ValueError): + await client.run(g) + else: + await client.run(g) + + @pytest.mark.parametrize("test_type, test_value", [ + ("StubInt", 5), + ("StubFloat", 5.0) + ]) + async def test_validation_error_edge1(self, test_type, test_value, client: Client, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation1.out(0)) + + with pytest.raises(ValueError): + await client.run(g) + + @pytest.mark.parametrize("test_type, test_value, expect_error", [ + ("StubInt", 5, True), + ("StubFloat", 5.0, False) + ]) + async def test_validation_error_edge2(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation2.out(0)) + + if expect_error: + with pytest.raises(ValueError): + await client.run(g) + else: + await client.run(g) + + @pytest.mark.parametrize("test_type, test_value, expect_error", [ + ("StubInt", 5, True), + ("StubFloat", 5.0, False) + ]) + async def test_validation_error_edge3(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation3.out(0)) + + if expect_error: + with pytest.raises(ValueError): + await client.run(g) + else: + await client.run(g) + + @pytest.mark.parametrize("test_type, test_value, expect_error", [ + ("StubInt", 5, True), + ("StubFloat", 5.0, False) + ]) + async def test_validation_error_edge4(self, test_type, test_value, expect_error, client: Client, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation4 = g.node("TestCustomValidation4", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation4.out(0)) + + if expect_error: + with pytest.raises(ValueError): + await client.run(g) + else: + await client.run(g) + + async def test_cycle_error(self, client: Client, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + + lazy_mix1 = g.node("TestLazyMixImages", image1=input1.out(0), mask=mask.out(0)) + lazy_mix2 = g.node("TestLazyMixImages", image1=lazy_mix1.out(0), image2=input2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix2.out(0)) + + # When the cycle exists on initial submission, it should raise a validation error + with pytest.raises(ValueError): + await client.run(g) + + async def test_dynamic_cycle_error(self, client: Client, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + generator = g.node("TestDynamicDependencyCycle", input1=input1.out(0), input2=input2.out(0)) + g.node("SaveImage", images=generator.out(0)) + + # When the cycle is in a graph that is generated dynamically, it should raise a runtime error + try: + await client.run(g) + assert False, "Should have raised an error" + except Exception as e: + assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" + assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node" + + async def test_custom_is_changed(self, client: Client, builder: GraphBuilder): + g = builder + # Creating the nodes in this specific order previously caused a bug + save = g.node("SaveImage") + is_changed = g.node("TestCustomIsChanged", should_change=False) + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + save.set_input('images', is_changed.out(0)) + is_changed.set_input('image', input1.out(0)) + + result1 = await client.run(g) + result2 = await client.run(g) + is_changed.set_input('should_change', True) + result3 = await client.run(g) + result4 = await client.run(g) + assert result1.did_run(is_changed), "is_changed should have been run" + assert not result2.did_run(is_changed), "is_changed should have been cached" + assert result3.did_run(is_changed), "is_changed should have been re-run" + assert result4.did_run(is_changed), "is_changed should not have been cached" + + async def test_undeclared_inputs(self, client: Client, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + input3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input4 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + average = g.node("TestVariadicAverage", input1=input1.out(0), input2=input2.out(0), input3=input3.out(0), input4=input4.out(0)) + output = g.node("SaveImage", images=average.out(0)) + + result = await client.run(g) + result_image = result.get_images(output)[0] + expected = 255 // 4 + assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey" + + async def test_for_loop(self, client: Client, builder: GraphBuilder): + g = builder + iterations = 4 + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + is_changed = g.node("TestCustomIsChanged", should_change=True, image=input2.out(0)) + for_open = g.node("TestForLoopOpen", remaining=iterations, initial_value1=is_changed.out(0)) + average = g.node("TestVariadicAverage", input1=input1.out(0), input2=for_open.out(2)) + for_close = g.node("TestForLoopClose", flow_control=for_open.out(0), initial_value1=average.out(0)) + output = g.node("SaveImage", images=for_close.out(0)) + + for iterations in range(1, 5): + for_open.set_input('remaining', iterations) + result = await client.run(g) + result_image = result.get_images(output)[0] + expected = 255 // (2 ** iterations) + assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey" + assert result.did_run(is_changed) + + async def test_mixed_expansion_returns(self, client: Client, builder: GraphBuilder): + g = builder + val_list = g.node("TestMakeListNode", value1=0.1, value2=0.2, value3=0.3) + mixed = g.node("TestMixedExpansionReturns", input1=val_list.out(0)) + output_dynamic = g.node("SaveImage", images=mixed.out(0)) + output_literal = g.node("SaveImage", images=mixed.out(1)) + + result = await client.run(g) + images_dynamic = result.get_images(output_dynamic) + assert len(images_dynamic) == 3, "Should have 2 images" + assert numpy.array(images_dynamic[0]).min() == 25 and numpy.array(images_dynamic[0]).max() == 25, "First image should be 0.1" + assert numpy.array(images_dynamic[1]).min() == 51 and numpy.array(images_dynamic[1]).max() == 51, "Second image should be 0.2" + assert numpy.array(images_dynamic[2]).min() == 76 and numpy.array(images_dynamic[2]).max() == 76, "Third image should be 0.3" + + images_literal = result.get_images(output_literal) + assert len(images_literal) == 3, "Should have 2 images" + for i in range(3): + assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white" + + async def test_output_reuse(self, client: Client, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + output1 = g.node("PreviewImage", images=input1.out(0)) + output2 = g.node("PreviewImage", images=input1.out(0)) + + result = await client.run(g) + images1 = result.get_images(output1) + images2 = result.get_images(output2) + assert len(images1) == 1, "Should have 1 image" + assert len(images2) == 1, "Should have 1 image" diff --git a/tests/inference/testing_nodes/testing-pack/__init__.py b/tests/inference/testing_nodes/testing-pack/__init__.py new file mode 100644 index 000000000..dcc71659a --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/__init__.py @@ -0,0 +1,23 @@ +from .specific_tests import TEST_NODE_CLASS_MAPPINGS, TEST_NODE_DISPLAY_NAME_MAPPINGS +from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS +from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS +from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS +from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS + +# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS) +# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS) + +NODE_CLASS_MAPPINGS = {} +NODE_CLASS_MAPPINGS.update(TEST_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS) + +NODE_DISPLAY_NAME_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS) + diff --git a/tests/inference/testing_nodes/testing-pack/conditions.py b/tests/inference/testing_nodes/testing-pack/conditions.py new file mode 100644 index 000000000..0c200ee28 --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/conditions.py @@ -0,0 +1,194 @@ +import re +import torch + +class TestIntConditions: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "operation": (["==", "!=", "<", ">", "<=", ">="],), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "int_condition" + + CATEGORY = "Testing/Logic" + + def int_condition(self, a, b, operation): + if operation == "==": + return (a == b,) + elif operation == "!=": + return (a != b,) + elif operation == "<": + return (a < b,) + elif operation == ">": + return (a > b,) + elif operation == "<=": + return (a <= b,) + elif operation == ">=": + return (a >= b,) + + +class TestFloatConditions: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}), + "b": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}), + "operation": (["==", "!=", "<", ">", "<=", ">="],), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "float_condition" + + CATEGORY = "Testing/Logic" + + def float_condition(self, a, b, operation): + if operation == "==": + return (a == b,) + elif operation == "!=": + return (a != b,) + elif operation == "<": + return (a < b,) + elif operation == ">": + return (a > b,) + elif operation == "<=": + return (a <= b,) + elif operation == ">=": + return (a >= b,) + +class TestStringConditions: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("STRING", {"multiline": False}), + "b": ("STRING", {"multiline": False}), + "operation": (["a == b", "a != b", "a IN b", "a MATCH REGEX(b)", "a BEGINSWITH b", "a ENDSWITH b"],), + "case_sensitive": ("BOOLEAN", {"default": True}), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "string_condition" + + CATEGORY = "Testing/Logic" + + def string_condition(self, a, b, operation, case_sensitive): + if not case_sensitive: + a = a.lower() + b = b.lower() + + if operation == "a == b": + return (a == b,) + elif operation == "a != b": + return (a != b,) + elif operation == "a IN b": + return (a in b,) + elif operation == "a MATCH REGEX(b)": + try: + return (re.match(b, a) is not None,) + except: + return (False,) + elif operation == "a BEGINSWITH b": + return (a.startswith(b),) + elif operation == "a ENDSWITH b": + return (a.endswith(b),) + +class TestToBoolNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("*",), + }, + "optional": { + "invert": ("BOOLEAN", {"default": False}), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "to_bool" + + CATEGORY = "Testing/Logic" + + def to_bool(self, value, invert = False): + if isinstance(value, torch.Tensor): + if value.max().item() == 0 and value.min().item() == 0: + result = False + else: + result = True + else: + try: + result = bool(value) + except: + # Can't convert it? Well then it's something or other. I dunno, I'm not a Python programmer. + result = True + + if invert: + result = not result + + return (result,) + +class TestBoolOperationNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("BOOLEAN",), + "b": ("BOOLEAN",), + "op": (["a AND b", "a OR b", "a XOR b", "NOT a"],), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "bool_operation" + + CATEGORY = "Testing/Logic" + + def bool_operation(self, a, b, op): + if op == "a AND b": + return (a and b,) + elif op == "a OR b": + return (a or b,) + elif op == "a XOR b": + return (a ^ b,) + elif op == "NOT a": + return (not a,) + + +CONDITION_NODE_CLASS_MAPPINGS = { + "TestIntConditions": TestIntConditions, + "TestFloatConditions": TestFloatConditions, + "TestStringConditions": TestStringConditions, + "TestToBoolNode": TestToBoolNode, + "TestBoolOperationNode": TestBoolOperationNode, +} + +CONDITION_NODE_DISPLAY_NAME_MAPPINGS = { + "TestIntConditions": "Int Condition", + "TestFloatConditions": "Float Condition", + "TestStringConditions": "String Condition", + "TestToBoolNode": "To Bool", + "TestBoolOperationNode": "Bool Operation", +} diff --git a/tests/inference/testing_nodes/testing-pack/flow_control.py b/tests/inference/testing_nodes/testing-pack/flow_control.py new file mode 100644 index 000000000..1ef1cf803 --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/flow_control.py @@ -0,0 +1,173 @@ +from comfy.graph_utils import GraphBuilder, is_link +from comfy.graph import ExecutionBlocker +from .tools import VariantSupport + +NUM_FLOW_SOCKETS = 5 +@VariantSupport() +class TestWhileLoopOpen: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "condition": ("BOOLEAN", {"default": True}), + }, + "optional": { + }, + } + for i in range(NUM_FLOW_SOCKETS): + inputs["optional"][f"initial_value{i}"] = ("*",) + return inputs + + RETURN_TYPES = tuple(["FLOW_CONTROL"] + ["*"] * NUM_FLOW_SOCKETS) + RETURN_NAMES = tuple(["FLOW_CONTROL"] + [f"value{i}" for i in range(NUM_FLOW_SOCKETS)]) + FUNCTION = "while_loop_open" + + CATEGORY = "Testing/Flow" + + def while_loop_open(self, condition, **kwargs): + values = [] + for i in range(NUM_FLOW_SOCKETS): + values.append(kwargs.get(f"initial_value{i}", None)) + return tuple(["stub"] + values) + +@VariantSupport() +class TestWhileLoopClose: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "flow_control": ("FLOW_CONTROL", {"rawLink": True}), + "condition": ("BOOLEAN", {"forceInput": True}), + }, + "optional": { + }, + "hidden": { + "dynprompt": "DYNPROMPT", + "unique_id": "UNIQUE_ID", + } + } + for i in range(NUM_FLOW_SOCKETS): + inputs["optional"][f"initial_value{i}"] = ("*",) + return inputs + + RETURN_TYPES = tuple(["*"] * NUM_FLOW_SOCKETS) + RETURN_NAMES = tuple([f"value{i}" for i in range(NUM_FLOW_SOCKETS)]) + FUNCTION = "while_loop_close" + + CATEGORY = "Testing/Flow" + + def explore_dependencies(self, node_id, dynprompt, upstream): + node_info = dynprompt.get_node(node_id) + if "inputs" not in node_info: + return + for k, v in node_info["inputs"].items(): + if is_link(v): + parent_id = v[0] + if parent_id not in upstream: + upstream[parent_id] = [] + self.explore_dependencies(parent_id, dynprompt, upstream) + upstream[parent_id].append(node_id) + + def collect_contained(self, node_id, upstream, contained): + if node_id not in upstream: + return + for child_id in upstream[node_id]: + if child_id not in contained: + contained[child_id] = True + self.collect_contained(child_id, upstream, contained) + + + def while_loop_close(self, flow_control, condition, dynprompt=None, unique_id=None, **kwargs): + assert dynprompt is not None + if not condition: + # We're done with the loop + values = [] + for i in range(NUM_FLOW_SOCKETS): + values.append(kwargs.get(f"initial_value{i}", None)) + return tuple(values) + + # We want to loop + upstream = {} + # Get the list of all nodes between the open and close nodes + self.explore_dependencies(unique_id, dynprompt, upstream) + + contained = {} + open_node = flow_control[0] + self.collect_contained(open_node, upstream, contained) + contained[unique_id] = True + contained[open_node] = True + + # We'll use the default prefix, but to avoid having node names grow exponentially in size, + # we'll use "Recurse" for the name of the recursively-generated copy of this node. + graph = GraphBuilder() + for node_id in contained: + original_node = dynprompt.get_node(node_id) + node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id) + node.set_override_display_id(node_id) + for node_id in contained: + original_node = dynprompt.get_node(node_id) + node = graph.lookup_node("Recurse" if node_id == unique_id else node_id) + assert node is not None + for k, v in original_node["inputs"].items(): + if is_link(v) and v[0] in contained: + parent = graph.lookup_node(v[0]) + assert parent is not None + node.set_input(k, parent.out(v[1])) + else: + node.set_input(k, v) + new_open = graph.lookup_node(open_node) + assert new_open is not None + for i in range(NUM_FLOW_SOCKETS): + key = f"initial_value{i}" + new_open.set_input(key, kwargs.get(key, None)) + my_clone = graph.lookup_node("Recurse") + assert my_clone is not None + result = map(lambda x: my_clone.out(x), range(NUM_FLOW_SOCKETS)) + return { + "result": tuple(result), + "expand": graph.finalize(), + } + +@VariantSupport() +class TestExecutionBlockerNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "input": ("*",), + "block": ("BOOLEAN",), + "verbose": ("BOOLEAN", {"default": False}), + }, + } + return inputs + + RETURN_TYPES = ("*",) + RETURN_NAMES = ("output",) + FUNCTION = "execution_blocker" + + CATEGORY = "Testing/Flow" + + def execution_blocker(self, input, block, verbose): + if block: + return (ExecutionBlocker("Blocked Execution" if verbose else None),) + return (input,) + +FLOW_CONTROL_NODE_CLASS_MAPPINGS = { + "TestWhileLoopOpen": TestWhileLoopOpen, + "TestWhileLoopClose": TestWhileLoopClose, + "TestExecutionBlocker": TestExecutionBlockerNode, +} +FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS = { + "TestWhileLoopOpen": "While Loop Open", + "TestWhileLoopClose": "While Loop Close", + "TestExecutionBlocker": "Execution Blocker", +} diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py new file mode 100644 index 000000000..8e8ce32ce --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/specific_tests.py @@ -0,0 +1,308 @@ +import torch +from .tools import VariantSupport +from comfy.graph_utils import GraphBuilder + +class TestLazyMixImages: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image1": ("IMAGE",{"lazy": True}), + "image2": ("IMAGE",{"lazy": True}), + "mask": ("MASK",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "mix" + + CATEGORY = "Testing/Nodes" + + def check_lazy_status(self, mask, image1, image2): + mask_min = mask.min() + mask_max = mask.max() + needed = [] + if image1 is None and (mask_min != 1.0 or mask_max != 1.0): + needed.append("image1") + if image2 is None and (mask_min != 0.0 or mask_max != 0.0): + needed.append("image2") + return needed + + # Not trying to handle different batch sizes here just to keep the demo simple + def mix(self, mask, image1, image2): + mask_min = mask.min() + mask_max = mask.max() + if mask_min == 0.0 and mask_max == 0.0: + return (image1,) + elif mask_min == 1.0 and mask_max == 1.0: + return (image2,) + + if len(mask.shape) == 2: + mask = mask.unsqueeze(0) + if len(mask.shape) == 3: + mask = mask.unsqueeze(3) + if mask.shape[3] < image1.shape[3]: + mask = mask.repeat(1, 1, 1, image1.shape[3]) + + result = image1 * (1. - mask) + image2 * mask, + return (result[0],) + +class TestVariadicAverage: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "variadic_average" + + CATEGORY = "Testing/Nodes" + + def variadic_average(self, input1, **kwargs): + inputs = [input1] + while 'input' + str(len(inputs) + 1) in kwargs: + inputs.append(kwargs['input' + str(len(inputs) + 1)]) + return (torch.stack(inputs).mean(dim=0),) + + +class TestCustomIsChanged: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + "optional": { + "should_change": ("BOOL", {"default": False}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_is_changed" + + CATEGORY = "Testing/Nodes" + + def custom_is_changed(self, image, should_change=False): + return (image,) + + @classmethod + def IS_CHANGED(cls, should_change=False, *args, **kwargs): + if should_change: + return float("NaN") + else: + return False + +class TestCustomValidation1: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE,FLOAT",), + "input2": ("IMAGE,FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation1" + + CATEGORY = "Testing/Nodes" + + def custom_validation1(self, input1, input2): + if isinstance(input1, float) and isinstance(input2, float): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + else: + result = input1 * input2 + return (result,) + + @classmethod + def VALIDATE_INPUTS(cls, input1=None, input2=None): + if input1 is not None: + if not isinstance(input1, (torch.Tensor, float)): + return f"Invalid type of input1: {type(input1)}" + if input2 is not None: + if not isinstance(input2, (torch.Tensor, float)): + return f"Invalid type of input2: {type(input2)}" + + return True + +class TestCustomValidation2: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE,FLOAT",), + "input2": ("IMAGE,FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation2" + + CATEGORY = "Testing/Nodes" + + def custom_validation2(self, input1, input2): + if isinstance(input1, float) and isinstance(input2, float): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + else: + result = input1 * input2 + return (result,) + + @classmethod + def VALIDATE_INPUTS(cls, input_types, input1=None, input2=None): + if input1 is not None: + if not isinstance(input1, (torch.Tensor, float)): + return f"Invalid type of input1: {type(input1)}" + if input2 is not None: + if not isinstance(input2, (torch.Tensor, float)): + return f"Invalid type of input2: {type(input2)}" + + if 'input1' in input_types: + if input_types['input1'] not in ["IMAGE", "FLOAT"]: + return f"Invalid type of input1: {input_types['input1']}" + if 'input2' in input_types: + if input_types['input2'] not in ["IMAGE", "FLOAT"]: + return f"Invalid type of input2: {input_types['input2']}" + + return True + +@VariantSupport() +class TestCustomValidation3: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE,FLOAT",), + "input2": ("IMAGE,FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation3" + + CATEGORY = "Testing/Nodes" + + def custom_validation3(self, input1, input2): + if isinstance(input1, float) and isinstance(input2, float): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + else: + result = input1 * input2 + return (result,) + +class TestCustomValidation4: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("FLOAT",), + "input2": ("FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation4" + + CATEGORY = "Testing/Nodes" + + def custom_validation4(self, input1, input2): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + return (result,) + + @classmethod + def VALIDATE_INPUTS(cls, input1, input2): + if input1 is not None: + if not isinstance(input1, float): + return f"Invalid type of input1: {type(input1)}" + if input2 is not None: + if not isinstance(input2, float): + return f"Invalid type of input2: {type(input2)}" + + return True + +class TestDynamicDependencyCycle: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE",), + "input2": ("IMAGE",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "dynamic_dependency_cycle" + + CATEGORY = "Testing/Nodes" + + def dynamic_dependency_cycle(self, input1, input2): + g = GraphBuilder() + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + mix1 = g.node("TestLazyMixImages", image1=input1, mask=mask.out(0)) + mix2 = g.node("TestLazyMixImages", image1=mix1.out(0), image2=input2, mask=mask.out(0)) + + # Create the cyle + mix1.set_input("image2", mix2.out(0)) + + return { + "result": (mix2.out(0),), + "expand": g.finalize(), + } + +class TestMixedExpansionReturns: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE","IMAGE") + FUNCTION = "mixed_expansion_returns" + + CATEGORY = "Testing/Nodes" + + def mixed_expansion_returns(self, input1): + white_image = torch.ones([1, 512, 512, 3]) + if input1 <= 0.1: + return (torch.ones([1, 512, 512, 3]) * 0.1, white_image) + elif input1 <= 0.2: + return { + "result": (torch.ones([1, 512, 512, 3]) * 0.2, white_image), + } + else: + g = GraphBuilder() + mask = g.node("StubMask", value=0.3, height=512, width=512, batch_size=1) + black = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + white = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mix = g.node("TestLazyMixImages", image1=black.out(0), image2=white.out(0), mask=mask.out(0)) + return { + "result": (mix.out(0), white_image), + "expand": g.finalize(), + } + +TEST_NODE_CLASS_MAPPINGS = { + "TestLazyMixImages": TestLazyMixImages, + "TestVariadicAverage": TestVariadicAverage, + "TestCustomIsChanged": TestCustomIsChanged, + "TestCustomValidation1": TestCustomValidation1, + "TestCustomValidation2": TestCustomValidation2, + "TestCustomValidation3": TestCustomValidation3, + "TestCustomValidation4": TestCustomValidation4, + "TestDynamicDependencyCycle": TestDynamicDependencyCycle, + "TestMixedExpansionReturns": TestMixedExpansionReturns, +} + +TEST_NODE_DISPLAY_NAME_MAPPINGS = { + "TestLazyMixImages": "Lazy Mix Images", + "TestVariadicAverage": "Variadic Average", + "TestCustomIsChanged": "Custom IsChanged", + "TestCustomValidation1": "Custom Validation 1", + "TestCustomValidation2": "Custom Validation 2", + "TestCustomValidation3": "Custom Validation 3", + "TestCustomValidation4": "Custom Validation 4", + "TestDynamicDependencyCycle": "Dynamic Dependency Cycle", + "TestMixedExpansionReturns": "Mixed Expansion Returns", +} diff --git a/tests/inference/testing_nodes/testing-pack/stubs.py b/tests/inference/testing_nodes/testing-pack/stubs.py new file mode 100644 index 000000000..9be6eac9d --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/stubs.py @@ -0,0 +1,105 @@ +import torch + +class StubImage: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "content": (['WHITE', 'BLACK', 'NOISE'],), + "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}), + "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "stub_image" + + CATEGORY = "Testing/Stub Nodes" + + def stub_image(self, content, height, width, batch_size): + if content == "WHITE": + return (torch.ones(batch_size, height, width, 3),) + elif content == "BLACK": + return (torch.zeros(batch_size, height, width, 3),) + elif content == "NOISE": + return (torch.rand(batch_size, height, width, 3),) + +class StubMask: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), + "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}), + "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}), + }, + } + + RETURN_TYPES = ("MASK",) + FUNCTION = "stub_mask" + + CATEGORY = "Testing/Stub Nodes" + + def stub_mask(self, value, height, width, batch_size): + return (torch.ones(batch_size, height, width) * value,) + +class StubInt: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("INT", {"default": 0, "min": -0xffffffff, "max": 0xffffffff, "step": 1}), + }, + } + + RETURN_TYPES = ("INT",) + FUNCTION = "stub_int" + + CATEGORY = "Testing/Stub Nodes" + + def stub_int(self, value): + return (value,) + +class StubFloat: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 0.0, "min": -1.0e38, "max": 1.0e38, "step": 0.01}), + }, + } + + RETURN_TYPES = ("FLOAT",) + FUNCTION = "stub_float" + + CATEGORY = "Testing/Stub Nodes" + + def stub_float(self, value): + return (value,) + +TEST_STUB_NODE_CLASS_MAPPINGS = { + "StubImage": StubImage, + "StubMask": StubMask, + "StubInt": StubInt, + "StubFloat": StubFloat, +} +TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = { + "StubImage": "Stub Image", + "StubMask": "Stub Mask", + "StubInt": "Stub Int", + "StubFloat": "Stub Float", +} diff --git a/tests/inference/testing_nodes/testing-pack/tools.py b/tests/inference/testing_nodes/testing-pack/tools.py new file mode 100644 index 000000000..34b28c0eb --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/tools.py @@ -0,0 +1,53 @@ + +def MakeSmartType(t): + if isinstance(t, str): + return SmartType(t) + return t + +class SmartType(str): + def __ne__(self, other): + if self == "*" or other == "*": + return False + selfset = set(self.split(',')) + otherset = set(other.split(',')) + return not selfset.issubset(otherset) + +def VariantSupport(): + def decorator(cls): + if hasattr(cls, "INPUT_TYPES"): + old_input_types = getattr(cls, "INPUT_TYPES") + def new_input_types(*args, **kwargs): + types = old_input_types(*args, **kwargs) + for category in ["required", "optional"]: + if category not in types: + continue + for key, value in types[category].items(): + if isinstance(value, tuple): + types[category][key] = (MakeSmartType(value[0]),) + value[1:] + return types + setattr(cls, "INPUT_TYPES", new_input_types) + if hasattr(cls, "RETURN_TYPES"): + old_return_types = cls.RETURN_TYPES + setattr(cls, "RETURN_TYPES", tuple(MakeSmartType(x) for x in old_return_types)) + if hasattr(cls, "VALIDATE_INPUTS"): + # Reflection is used to determine what the function signature is, so we can't just change the function signature + raise NotImplementedError("VariantSupport does not support VALIDATE_INPUTS yet") + else: + def validate_inputs(input_types): + inputs = cls.INPUT_TYPES() + for key, value in input_types.items(): + if isinstance(value, SmartType): + continue + if "required" in inputs and key in inputs["required"]: + expected_type = inputs["required"][key][0] + elif "optional" in inputs and key in inputs["optional"]: + expected_type = inputs["optional"][key][0] + else: + expected_type = None + if expected_type is not None and MakeSmartType(value) != expected_type: + return f"Invalid type of {key}: {value} (expected {expected_type})" + return True + setattr(cls, "VALIDATE_INPUTS", validate_inputs) + return cls + return decorator + diff --git a/tests/inference/testing_nodes/testing-pack/util.py b/tests/inference/testing_nodes/testing-pack/util.py new file mode 100644 index 000000000..fea83e37a --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/util.py @@ -0,0 +1,364 @@ +from comfy.graph_utils import GraphBuilder +from .tools import VariantSupport + +@VariantSupport() +class TestAccumulateNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "to_add": ("*",), + }, + "optional": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("ACCUMULATION",) + FUNCTION = "accumulate" + + CATEGORY = "Testing/Lists" + + def accumulate(self, to_add, accumulation = None): + if accumulation is None: + value = [to_add] + else: + value = accumulation["accum"] + [to_add] + return ({"accum": value},) + +@VariantSupport() +class TestAccumulationHeadNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("ACCUMULATION", "*",) + FUNCTION = "accumulation_head" + + CATEGORY = "Testing/Lists" + + def accumulation_head(self, accumulation): + accum = accumulation["accum"] + if len(accum) == 0: + return (accumulation, None) + else: + return ({"accum": accum[1:]}, accum[0]) + +class TestAccumulationTailNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("ACCUMULATION", "*",) + FUNCTION = "accumulation_tail" + + CATEGORY = "Testing/Lists" + + def accumulation_tail(self, accumulation): + accum = accumulation["accum"] + if len(accum) == 0: + return (None, accumulation) + else: + return ({"accum": accum[:-1]}, accum[-1]) + +@VariantSupport() +class TestAccumulationToListNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("*",) + OUTPUT_IS_LIST = (True,) + + FUNCTION = "accumulation_to_list" + + CATEGORY = "Testing/Lists" + + def accumulation_to_list(self, accumulation): + return (accumulation["accum"],) + +@VariantSupport() +class TestListToAccumulationNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "list": ("*",), + }, + } + + RETURN_TYPES = ("ACCUMULATION",) + INPUT_IS_LIST = (True,) + + FUNCTION = "list_to_accumulation" + + CATEGORY = "Testing/Lists" + + def list_to_accumulation(self, list): + return ({"accum": list},) + +@VariantSupport() +class TestAccumulationGetLengthNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("INT",) + + FUNCTION = "accumlength" + + CATEGORY = "Testing/Lists" + + def accumlength(self, accumulation): + return (len(accumulation['accum']),) + +@VariantSupport() +class TestAccumulationGetItemNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + "index": ("INT", {"default":0, "step":1}) + }, + } + + RETURN_TYPES = ("*",) + + FUNCTION = "get_item" + + CATEGORY = "Testing/Lists" + + def get_item(self, accumulation, index): + return (accumulation['accum'][index],) + +@VariantSupport() +class TestAccumulationSetItemNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + "index": ("INT", {"default":0, "step":1}), + "value": ("*",), + }, + } + + RETURN_TYPES = ("ACCUMULATION",) + + FUNCTION = "set_item" + + CATEGORY = "Testing/Lists" + + def set_item(self, accumulation, index, value): + new_accum = accumulation['accum'][:] + new_accum[index] = value + return ({"accum": new_accum},) + +class TestIntMathOperation: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "operation": (["add", "subtract", "multiply", "divide", "modulo", "power"],), + }, + } + + RETURN_TYPES = ("INT",) + FUNCTION = "int_math_operation" + + CATEGORY = "Testing/Logic" + + def int_math_operation(self, a, b, operation): + if operation == "add": + return (a + b,) + elif operation == "subtract": + return (a - b,) + elif operation == "multiply": + return (a * b,) + elif operation == "divide": + return (a // b,) + elif operation == "modulo": + return (a % b,) + elif operation == "power": + return (a ** b,) + + +from .flow_control import NUM_FLOW_SOCKETS +@VariantSupport() +class TestForLoopOpen: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "remaining": ("INT", {"default": 1, "min": 0, "max": 100000, "step": 1}), + }, + "optional": { + f"initial_value{i}": ("*",) for i in range(1, NUM_FLOW_SOCKETS) + }, + "hidden": { + "initial_value0": ("*",) + } + } + + RETURN_TYPES = tuple(["FLOW_CONTROL", "INT",] + ["*"] * (NUM_FLOW_SOCKETS-1)) + RETURN_NAMES = tuple(["flow_control", "remaining"] + [f"value{i}" for i in range(1, NUM_FLOW_SOCKETS)]) + FUNCTION = "for_loop_open" + + CATEGORY = "Testing/Flow" + + def for_loop_open(self, remaining, **kwargs): + graph = GraphBuilder() + if "initial_value0" in kwargs: + remaining = kwargs["initial_value0"] + while_open = graph.node("TestWhileLoopOpen", condition=remaining, initial_value0=remaining, **{(f"initial_value{i}"): kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)}) + outputs = [kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)] + return { + "result": tuple(["stub", remaining] + outputs), + "expand": graph.finalize(), + } + +@VariantSupport() +class TestForLoopClose: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "flow_control": ("FLOW_CONTROL", {"rawLink": True}), + }, + "optional": { + f"initial_value{i}": ("*",{"rawLink": True}) for i in range(1, NUM_FLOW_SOCKETS) + }, + } + + RETURN_TYPES = tuple(["*"] * (NUM_FLOW_SOCKETS-1)) + RETURN_NAMES = tuple([f"value{i}" for i in range(1, NUM_FLOW_SOCKETS)]) + FUNCTION = "for_loop_close" + + CATEGORY = "Testing/Flow" + + def for_loop_close(self, flow_control, **kwargs): + graph = GraphBuilder() + while_open = flow_control[0] + sub = graph.node("TestIntMathOperation", operation="subtract", a=[while_open,1], b=1) + cond = graph.node("TestToBoolNode", value=sub.out(0)) + input_values = {f"initial_value{i}": kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)} + while_close = graph.node("TestWhileLoopClose", + flow_control=flow_control, + condition=cond.out(0), + initial_value0=sub.out(0), + **input_values) + return { + "result": tuple([while_close.out(i) for i in range(1, NUM_FLOW_SOCKETS)]), + "expand": graph.finalize(), + } + +NUM_LIST_SOCKETS = 10 +@VariantSupport() +class TestMakeListNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value1": ("*",), + }, + "optional": { + f"value{i}": ("*",) for i in range(1, NUM_LIST_SOCKETS) + }, + } + + RETURN_TYPES = ("*",) + FUNCTION = "make_list" + OUTPUT_IS_LIST = (True,) + + CATEGORY = "Testing/Lists" + + def make_list(self, **kwargs): + result = [] + for i in range(NUM_LIST_SOCKETS): + if f"value{i}" in kwargs: + result.append(kwargs[f"value{i}"]) + return (result,) + +UTILITY_NODE_CLASS_MAPPINGS = { + "TestAccumulateNode": TestAccumulateNode, + "TestAccumulationHeadNode": TestAccumulationHeadNode, + "TestAccumulationTailNode": TestAccumulationTailNode, + "TestAccumulationToListNode": TestAccumulationToListNode, + "TestListToAccumulationNode": TestListToAccumulationNode, + "TestAccumulationGetLengthNode": TestAccumulationGetLengthNode, + "TestAccumulationGetItemNode": TestAccumulationGetItemNode, + "TestAccumulationSetItemNode": TestAccumulationSetItemNode, + "TestForLoopOpen": TestForLoopOpen, + "TestForLoopClose": TestForLoopClose, + "TestIntMathOperation": TestIntMathOperation, + "TestMakeListNode": TestMakeListNode, +} +UTILITY_NODE_DISPLAY_NAME_MAPPINGS = { + "TestAccumulateNode": "Accumulate", + "TestAccumulationHeadNode": "Accumulation Head", + "TestAccumulationTailNode": "Accumulation Tail", + "TestAccumulationToListNode": "Accumulation to List", + "TestListToAccumulationNode": "List to Accumulation", + "TestAccumulationGetLengthNode": "Accumulation Get Length", + "TestAccumulationGetItemNode": "Accumulation Get Item", + "TestAccumulationSetItemNode": "Accumulation Set Item", + "TestForLoopOpen": "For Loop Open", + "TestForLoopClose": "For Loop Close", + "TestIntMathOperation": "Int Math Operation", + "TestMakeListNode": "Make List", +} diff --git a/tests/unit/app_test/__init__.py b/tests/unit/app_test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/app_test/frontend_manager_test.py b/tests/unit/app_test/frontend_manager_test.py new file mode 100644 index 000000000..dd44527b4 --- /dev/null +++ b/tests/unit/app_test/frontend_manager_test.py @@ -0,0 +1,101 @@ +import argparse + +import pytest +from requests.exceptions import HTTPError + +from comfy.app.frontend_management import ( + FrontendManager, + FrontEndProvider, + Release, +) +from comfy.cli_args import DEFAULT_VERSION_STRING + + +@pytest.fixture +def mock_releases(): + return [ + Release( + id=1, + tag_name="1.0.0", + name="Release 1.0.0", + prerelease=False, + created_at="2022-01-01T00:00:00Z", + published_at="2022-01-01T00:00:00Z", + body="Release notes for 1.0.0", + assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}], + ), + Release( + id=2, + tag_name="2.0.0", + name="Release 2.0.0", + prerelease=False, + created_at="2022-02-01T00:00:00Z", + published_at="2022-02-01T00:00:00Z", + body="Release notes for 2.0.0", + assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}], + ), + ] + + +@pytest.fixture +def mock_provider(mock_releases): + provider = FrontEndProvider( + owner="test-owner", + repo="test-repo", + ) + provider.all_releases = mock_releases + provider.latest_release = mock_releases[1] + FrontendManager.PROVIDERS = [provider] + return provider + + +def test_get_release(mock_provider, mock_releases): + version = "1.0.0" + release = mock_provider.get_release(version) + assert release == mock_releases[0] + + +def test_get_release_latest(mock_provider, mock_releases): + version = "latest" + release = mock_provider.get_release(version) + assert release == mock_releases[1] + + +def test_get_release_invalid_version(mock_provider): + version = "invalid" + with pytest.raises(ValueError): + mock_provider.get_release(version) + + +def test_init_frontend_default(): + version_string = DEFAULT_VERSION_STRING + frontend_path = FrontendManager.init_frontend(version_string) + assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH + + +def test_init_frontend_invalid_version(): + version_string = "test-owner/test-repo@1.100.99" + with pytest.raises(HTTPError): + FrontendManager.init_frontend_unsafe(version_string) + + +def test_init_frontend_invalid_provider(): + version_string = "invalid/invalid@latest" + with pytest.raises(HTTPError): + FrontendManager.init_frontend_unsafe(version_string) + + +def test_parse_version_string(): + version_string = "owner/repo@1.0.0" + repo_owner, repo_name, version = FrontendManager.parse_version_string( + version_string + ) + assert repo_owner == "owner" + assert repo_name == "repo" + assert version == "1.0.0" + + +def test_parse_version_string_invalid(): + version_string = "invalid" + with pytest.raises(argparse.ArgumentTypeError): + FrontendManager.parse_version_string(version_string)