diff --git a/comfy/client/embedded_comfy_client.py b/comfy/client/embedded_comfy_client.py index 3ff4dbeae..389fd2b6b 100644 --- a/comfy/client/embedded_comfy_client.py +++ b/comfy/client/embedded_comfy_client.py @@ -57,6 +57,7 @@ def _execute_prompt( finally: detach(token) + async def __execute_prompt( prompt: dict, prompt_id: str, @@ -68,6 +69,7 @@ async def __execute_prompt( with context_configuration(configuration): return await ___execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, partial_execution_targets) + async def ___execute_prompt( prompt: dict, prompt_id: str, @@ -116,7 +118,7 @@ async def ___execute_prompt( raise exc_info -def _cleanup(): +def _cleanup(invalidate_nodes=True): from ..cmd.execution import PromptExecutor from ..nodes_context import invalidate try: @@ -132,10 +134,11 @@ def _cleanup(): model_management.soft_empty_cache() except: pass - try: - invalidate() - except: - pass + if invalidate_nodes: + try: + invalidate() + except: + pass class Comfy: @@ -180,7 +183,6 @@ class Comfy: self._history = History() self._context_stack = [] - @property def is_running(self) -> bool: return self._is_running @@ -200,6 +202,9 @@ class Comfy: def history(self) -> History: return self._history + async def clear_cache(self): + await get_event_loop().run_in_executor(self._executor, _cleanup, False) + def __exit__(self, *args): get_event_loop().run_in_executor(self._executor, _cleanup) self._executor.shutdown(wait=True) diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index fb048c95d..3cfe73407 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -18,8 +18,9 @@ from typing import List, Optional, Tuple, Literal # order matters from .main_pre import tracer - import torch +from frozendict import frozendict +from comfy_execution.graph_types import FrozenTopologicalSort, Input from opentelemetry.trace import get_current_span, StatusCode, Status from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \ @@ -30,7 +31,7 @@ from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.utils import CurrentNodeContext from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.latest import io -from ..execution_context import current_execution_context +from ..execution_context import current_execution_context, context_set_execution_list_and_inputs from .. import interruption from .. import model_management from ..component_model.abstract_prompt_queue import AbstractPromptQueue @@ -241,7 +242,12 @@ async def resolve_map_node_over_list_results(results): @tracer.start_as_current_span("Execute Node") -async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): +async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None, execution_list=None, executed=None): + with context_set_execution_list_and_inputs(FrozenTopologicalSort.from_topological_sort(execution_list) if execution_list is not None else None, frozenset(executed) if executed is not None else None): + return await __async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb, hidden_inputs) + + +async def __async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): span = get_current_span() class_type = obj.__class__.__name__ span.set_attribute("class_type", class_type) @@ -368,8 +374,8 @@ def merge_result_data(results, obj): return output -async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): - return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) +async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None, inputs=None, execution_list=None, executed=None): + return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs, execution_list=execution_list, executed=executed) has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values) if has_pending_task: return return_values, {}, False, has_pending_task @@ -469,12 +475,12 @@ async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, ca :return: """ with (context_execute_node(node_id), - vanilla_node_execution_environment(), - use_requests_caching()): + vanilla_node_execution_environment(), + use_requests_caching()): return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple: +async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list: ExecutionList, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple: unique_id = current_item real_node_id = dynprompt.get_real_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id) @@ -543,7 +549,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra else: lazy_status_present = getattr(obj, "check_lazy_status", None) is not None if lazy_status_present: - required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs) + required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs, execution_list=execution_list, executed=executed) required_inputs = await resolve_map_node_over_list_results(required_inputs) required_inputs = set(sum([r for r in required_inputs if isinstance(r, list)], [])) required_inputs = [x for x in required_inputs if isinstance(x, str) and ( @@ -577,7 +583,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs, inputs=inputs, execution_list=execution_list, executed=executed) if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) diff --git a/comfy/component_model/images_types.py b/comfy/component_model/images_types.py index 8ef892059..02a94b176 100644 --- a/comfy/component_model/images_types.py +++ b/comfy/component_model/images_types.py @@ -1,8 +1,8 @@ from typing import NamedTuple -from torch import Tensor +from .tensor_types import ImageBatch, MaskBatch -class RgbMaskTuple(NamedTuple): - rgb: Tensor - mask: Tensor +class ImageMaskTuple(NamedTuple): + image: ImageBatch + mask: MaskBatch diff --git a/comfy/execution_context.py b/comfy/execution_context.py index 400764831..fef63a4a5 100644 --- a/comfy/execution_context.py +++ b/comfy/execution_context.py @@ -5,6 +5,7 @@ from contextvars import ContextVar from dataclasses import dataclass, replace from typing import Optional, Final +from comfy_execution.graph_types import FrozenTopologicalSort from .cli_args import cli_args_configuration from .cli_args_types import Configuration from .component_model import cvpickle @@ -17,15 +18,26 @@ from .progress_types import AbstractProgressRegistry, ProgressRegistryStub @dataclass(frozen=True) class ExecutionContext: + # at any time server: ExecutorToClientProgress folder_names_and_paths: FolderNames custom_nodes: ExportedNodes + configuration: Optional[Configuration] = None + + # during prompt execution + progress_registry: Optional[AbstractProgressRegistry] = None + + # during node execution node_id: Optional[str] = None task_id: Optional[str] = None list_index: Optional[int] = None inference_mode: bool = True - progress_registry: Optional[AbstractProgressRegistry] = None - configuration: Optional[Configuration] = None + execution_list: Optional[FrozenTopologicalSort] = None + executed: Optional[frozenset] = None + + @property + def inputs(self) -> dict: + return self.execution_list.dynprompt.get_node(self.node_id)['inputs'] def __iter__(self): """ @@ -123,3 +135,11 @@ def context_set_node_and_prompt(prompt_id: str, node_id: str, list_index: Option new_ctx = replace(current_ctx, task_id=prompt_id, node_id=node_id, list_index=list_index) with _new_execution_context(new_ctx): yield new_ctx + + +@contextmanager +def context_set_execution_list_and_inputs(execution_list: FrozenTopologicalSort, executed: frozenset): + current_ctx = current_execution_context() + new_ctx = replace(current_ctx, execution_list=execution_list, executed=executed) + with _new_execution_context(new_ctx): + yield new_ctx diff --git a/comfy/lazy_helpers.py b/comfy/lazy_helpers.py new file mode 100644 index 000000000..6060e3ca2 --- /dev/null +++ b/comfy/lazy_helpers.py @@ -0,0 +1,29 @@ +from .execution_context import current_execution_context +from typing import Any, Generator, Sequence + + +def is_input_unscheduled_unexecuted(*arg_names: Sequence[str]) -> Generator[bool, Any, None]: + """ + returns true if the given argument in the context of an executing node is not scheduled nor executed + this will be true for inputs that are marked as lazy, and this method is more robust against nodes that return None + :param arg_names: each arg to evaluate + :return: True for each arg that is not scheduled nor executed + """ + context = current_execution_context() + if context is None or context.execution_list is None: + raise LookupError("Not executing a node") + assert context.execution_list is not None + dynprompt = context.execution_list.dynprompt + executed = context.executed + execution_list = context.execution_list + inputs = context.inputs + unscheduled_unexecuted = dynprompt.all_node_ids() - executed - set(execution_list.pendingNodes.keys()) + for arg_name in arg_names: + if arg_name not in inputs: + raise ValueError(f"Input {arg_name} not found") + input_ = inputs[arg_name] + if isinstance(input_, list) or isinstance(input_, tuple) and len(input_) == 2: + node_id, *_ = input_ + yield node_id in unscheduled_unexecuted + else: + yield False diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index ab013a9eb..59144017b 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -30,6 +30,7 @@ from ..cli_args import args from ..cmd import folder_paths, latent_preview from ..comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator from ..component_model.deprecation import _deprecate_method +from ..component_model.images_types import ImageMaskTuple from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch, RGBAImageBatch, Latent from ..execution_context import current_execution_context from ..images import open_image @@ -1791,7 +1792,7 @@ class LoadImage: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" - def load_image(self, image: str) -> tuple[RGBImageBatch, MaskBatch]: + def load_image(self, image: str) -> ImageMaskTuple: image_path = folder_paths.get_annotated_filepath(image) output_images = [] output_masks = [] @@ -1829,7 +1830,9 @@ class LoadImage: mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 mask = 1. - torch.from_numpy(mask) else: - mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu") + # unlike upstream, the mask is now the size of the original image, even when there was no alpha channel + # this is opaque + mask = torch.zeros((h, w), dtype=torch.float32, device="cpu") output_images.append(image) output_masks.append(mask.unsqueeze(0)) @@ -1840,7 +1843,7 @@ class LoadImage: output_image = output_images[0] output_mask = output_masks[0] - return (output_image, output_mask) + return ImageMaskTuple(output_image, output_mask) @classmethod def VALIDATE_INPUTS(s, image): diff --git a/comfy/open_exr.py b/comfy/open_exr.py index 1dd711aee..97e96b8e0 100644 --- a/comfy/open_exr.py +++ b/comfy/open_exr.py @@ -33,7 +33,7 @@ import numpy as np import torch from torch import Tensor -from .component_model.images_types import RgbMaskTuple +from .component_model.images_types import ImageMaskTuple read_exr = lambda fp: cv2.imread(fp, cv2.IMREAD_UNCHANGED).astype(np.float32) @@ -50,7 +50,7 @@ def mut_linear_to_srgb(np_array) -> None: np_array[~less] = np.power(np_array[~less], 1 / 2.4) * 1.055 - 0.055 -def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple: +def load_exr(file_path: str, srgb: bool) -> ImageMaskTuple: image = read_exr(file_path) rgb = np.flip(image[:, :, :3], 2).copy() if srgb: @@ -62,7 +62,7 @@ def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple: if image.shape[2] > 3: mask[0] = torch.from_numpy(np.clip(image[:, :, 3], 0, 1)) - return RgbMaskTuple(rgb, mask) + return ImageMaskTuple(rgb, mask) def load_exr_latent(file_path: str) -> Tuple[Tensor]: diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index b3dfcf49b..088ed1482 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -5,7 +5,7 @@ import inspect from typing import Optional, Type, Literal, Union, NamedTuple from comfy.nodes.package_typing import CustomNode, InputTypes -from comfy_execution.graph_utils import is_link, ExecutionBlocker +from comfy_execution.graph_utils import ExecutionBlocker from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions from comfy.component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \ DependencyExecutionErrorMessage, ComboOptions @@ -249,7 +249,7 @@ class ExecutionList(TopologicalSort): # for a PreviewImage to display a result as soon as it can # Some other heuristics could probably be used here to improve the UX further. def is_output(node_id): - class_type = self.dynprompt.get_node(node_id)["class_type"] + class_type = self.dynprompt.get_node(node_id)['class_type'] class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type] if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: return True @@ -258,7 +258,7 @@ class ExecutionList(TopologicalSort): # If an available node is async, do that first. # This will execute the asynchronous function earlier, reducing the overall time. def is_async(node_id): - class_type = self.dynprompt.get_node(node_id)["class_type"] + class_type = self.dynprompt.get_node(node_id)['class_type'] class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type] return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION)) diff --git a/comfy_execution/graph_types.py b/comfy_execution/graph_types.py new file mode 100644 index 000000000..76bb9157c --- /dev/null +++ b/comfy_execution/graph_types.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import typing +from dataclasses import dataclass +from typing import Mapping + +from frozendict import frozendict, deepfreeze + +from comfy.component_model.executor_types import NodeNotFoundError + +if typing.TYPE_CHECKING: + from .graph import TopologicalSort, DynamicPrompt + +NodeOutputByIdAndIndex = tuple[str, int] +InputValue = typing.Union[NodeOutputByIdAndIndex, bool, bytes, int, float, str, typing.IO, typing.BinaryIO] +Input = Mapping[str, InputValue] + + +@dataclass(frozen=True) +class FrozenTopologicalSort: + dynprompt: "FrozenPrompt" + pendingNodes: frozendict + blockCount: frozendict + blocking: frozendict + externalBlocks: int + + @classmethod + def from_topological_sort(cls, ts: "TopologicalSort") -> FrozenTopologicalSort: + return cls( + dynprompt=FrozenPrompt.from_dynamic_prompt(ts.dynprompt), + pendingNodes=deepfreeze(ts.pendingNodes), + blockCount=deepfreeze(ts.blockCount), + blocking=deepfreeze(ts.blocking), + externalBlocks=ts.externalBlocks, + ) + + +@dataclass(frozen=True) +class FrozenPrompt: + original_prompt: frozendict + ephemeral_prompt: frozendict + ephemeral_parents: frozendict + ephemeral_display: frozendict + + @classmethod + def from_dynamic_prompt(cls, dynprompt: "DynamicPrompt") -> FrozenPrompt: + return cls( + original_prompt=deepfreeze(dynprompt.original_prompt), + ephemeral_prompt=deepfreeze(dynprompt.ephemeral_prompt), + ephemeral_parents=deepfreeze(dynprompt.ephemeral_parents), + ephemeral_display=deepfreeze(dynprompt.ephemeral_display), + ) + + def get_node(self, node_id): + if node_id in self.ephemeral_prompt: + return self.ephemeral_prompt[node_id] + if node_id in self.original_prompt: + return self.original_prompt[node_id] + raise NodeNotFoundError(f"Node {node_id} not found") + + def has_node(self, node_id): + return node_id in self.original_prompt or node_id in self.ephemeral_prompt + + def get_real_node_id(self, node_id): + while node_id in self.ephemeral_parents: + node_id = self.ephemeral_parents[node_id] + return node_id + + def get_parent_node_id(self, node_id): + return self.ephemeral_parents.get(node_id, None) + + def get_display_node_id(self, node_id): + while node_id in self.ephemeral_display: + node_id = self.ephemeral_display[node_id] + return node_id + + def all_node_ids(self): + return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys())) + + def get_original_prompt(self): + return self.original_prompt diff --git a/comfy_execution/utils.py b/comfy_execution/utils.py index 6f94101f9..a8cc60ab2 100644 --- a/comfy_execution/utils.py +++ b/comfy_execution/utils.py @@ -42,7 +42,7 @@ class CurrentNodeContext: self._cm = core_execution_context.context_set_node_and_prompt( prompt_id=prompt_id, node_id=node_id, - list_index=list_index + list_index=list_index, ) def __enter__(self): diff --git a/comfy_extras/nodes/nodes_logic.py b/comfy_extras/nodes/nodes_logic.py new file mode 100644 index 000000000..7c1107089 --- /dev/null +++ b/comfy_extras/nodes/nodes_logic.py @@ -0,0 +1,179 @@ +import inspect +import operator +from typing import OrderedDict, Callable, Any + +from comfy.comfy_types import IO +from comfy.lazy_helpers import is_input_unscheduled_unexecuted +from comfy.node_helpers import export_custom_nodes +from comfy.nodes.package_typing import CustomNode, InputTypes + + +def takes_n_args(obj, n): + if not callable(obj): + return False + + try: + sig = inspect.signature(obj) + except (TypeError, ValueError): + return False + + params = sig.parameters.values() + + named_param_count = sum( + 1 for p in params + if p.kind not in (inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD) + ) + + return named_param_count == n + + +_BINARY_OPS: dict[str, Callable[[Any, Any], Any]] = OrderedDict({ + **{op: getattr(operator, op) for op in dir(operator) if takes_n_args(getattr(operator, op), 2)}, + "and": lambda a, b: a and b, + "or": lambda a, b: a or b, +}) +_UNARY_OPS: dict[str, Callable[[Any], Any]] = { + **{op: getattr(operator, op) for op in dir(operator) if takes_n_args(getattr(operator, op), 1)}, + "not": lambda a: not a +} + + +class IsNone(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": {}, + "optional": { + "value": (IO.ANY, {}), + } + } + + RETURN_TYPES = (IO.BOOLEAN,) + FUNCTION = "execute" + CATEGORY = "logic" + + def execute(self, value=None): + return (value is None,) + + +class LazySwitch(CustomNode): + """ + sherlocked from KJ nodes with fixes + """ + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "switch": ("BOOLEAN",), + }, + "optional": { + "on_false": (IO.ANY, {"lazy": True}), + "on_true": (IO.ANY, {"lazy": True}), + }, + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "execute" + CATEGORY = "logic" + DESCRIPTION = "Controls flow of execution based on a boolean switch." + + def check_lazy_status(self, switch, on_false=None, on_true=None): + try: + on_false_not_evaluated, on_true_not_evaluated = is_input_unscheduled_unexecuted("on_false", "on_true") + except LookupError: + on_false_not_evaluated, on_true_not_evaluated = on_false is None, on_true is None + if switch and on_true_not_evaluated: + return ["on_true"] + if not switch and on_false_not_evaluated: + return ["on_false"] + return [] + + def execute(self, switch, on_false=None, on_true=None): + value = on_true if switch else on_false + return (value,) + + +class UnaryOperation(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": {}, + "optional": { + "value": (IO.ANY, {}), + "op": (list(_UNARY_OPS.keys()), {"default": "not"}) + } + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "execute" + CATEGORY = "logic" + + def execute(self, value, op): + return _UNARY_OPS[op](value), + + +class BooleanUnaryOperation(UnaryOperation): + RETURN_TYPES = (IO.BOOLEAN,) + + def execute(self, value, op): + return bool(super().execute(value, op)[0]), + + +class BinaryOperation(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": {}, + "optional": OrderedDict({ + "lhs": (IO.ANY, {"lazy": True}), + "op": (list(_BINARY_OPS.keys()), {"default": "eq"}), + "rhs": (IO.ANY, {"lazy": True}), + }) + } + + RETURN_TYPES = (IO.ANY,) + FUNCTION = "execute" + CATEGORY = "logic" + DESCRIPTION = "" + + def check_lazy_status(self, lhs=None, op=None, rhs=None) -> list[str]: + try: + lhs_not_evaluated, rhs_not_evaluated = is_input_unscheduled_unexecuted("lhs", "rhs") + except LookupError: + lhs_not_evaluated, rhs_not_evaluated = lhs is None, rhs is None + lhs_evaluated, rhs_evaluated = not lhs_not_evaluated, not rhs_not_evaluated + match op: + case "and": + if lhs_not_evaluated: + return ["lhs"] + if lhs_evaluated and lhs is not False and rhs_not_evaluated: + return ["rhs"] + return [] + case "or": + if lhs_not_evaluated: + return ["lhs"] + if lhs_evaluated and lhs is not True and rhs_not_evaluated: + return ["rhs"] + return [] + case _: + to_eval = [] + if lhs_not_evaluated: + to_eval.append("lhs") + if rhs_not_evaluated: + to_eval.append("rhs") + return to_eval + + def execute(self, lhs, op, rhs): + return _BINARY_OPS[op](lhs, rhs), + + +class BooleanBinaryOperation(BinaryOperation): + RETURN_TYPES = (IO.BOOLEAN,) + + def execute(self, lhs, op, rhs): + return bool(super().execute(lhs, op, rhs)[0]), + + +export_custom_nodes() diff --git a/comfy_extras/nodes/nodes_open_api.py b/comfy_extras/nodes/nodes_open_api.py index bce7242de..2ecbe56df 100644 --- a/comfy_extras/nodes/nodes_open_api.py +++ b/comfy_extras/nodes/nodes_open_api.py @@ -34,7 +34,8 @@ from torch import Tensor from comfy.cmd import folder_paths from comfy.comfy_types import IO -from comfy.component_model.tensor_types import RGBAImageBatch, RGBImageBatch +from comfy.component_model.images_types import ImageMaskTuple +from comfy.component_model.tensor_types import RGBAImageBatch, RGBImageBatch, MaskBatch, ImageBatch from comfy.digest import digest from comfy.node_helpers import export_custom_nodes from comfy.nodes.package_typing import CustomNode, InputTypes, FunctionReturnsUIVariables, SaveNodeResult, \ @@ -809,18 +810,20 @@ class ImageRequestParameter(CustomNode): }, "optional": { **_open_api_common_schema, - "default_if_empty": ("IMAGE",) + "default_if_empty": ("IMAGE",), + "alpha_is_transparency": ("BOOLEAN", {"default": False}), } } - RETURN_TYPES = ("IMAGE",) + RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "execute" CATEGORY = "api/openapi" - def execute(self, value: str = "", default_if_empty=None, *args, **kwargs) -> ValidatedNodeResult: + def execute(self, value: str = "", default_if_empty=None, alpha_is_transparency=False, *args, **kwargs) -> ImageMaskTuple: if value.strip() == "": return (default_if_empty,) output_images = [] + output_masks = [] f: OpenFile fsspec_kwargs = {} if value.startswith('http'): @@ -832,31 +835,41 @@ class ImageRequestParameter(CustomNode): }) # todo: additional security is needed here to prevent users from accessing local paths # however this generally needs to be done with user accounts on all OSes - with fsspec.open(value, mode="rb", **fsspec_kwargs) as f: - # from LoadImage - img = Image.open(f) - for i in ImageSequence.Iterator(img): - prev_value = None - try: - i = ImageOps.exif_transpose(i) - except OSError: - prev_value = ImageFile.LOAD_TRUNCATED_IMAGES - ImageFile.LOAD_TRUNCATED_IMAGES = True - i = ImageOps.exif_transpose(i) - finally: - if prev_value is not None: - ImageFile.LOAD_TRUNCATED_IMAGES = prev_value - if i.mode == 'I': - i = i.point(lambda i: i * (1 / 255)) - image = i.convert("RGB") - image = np.array(image).astype(np.float32) / 255.0 - image = torch.from_numpy(image)[None,] - output_images.append(image) - if len(output_images) > 1: - output_image = torch.cat(output_images, dim=0) - else: - output_image = output_images[0] - return (output_image,) + with fsspec.open_files(value, mode="rb", **fsspec_kwargs) as files: + for f in files: + # from LoadImage + img = Image.open(f) + for i in ImageSequence.Iterator(img): + prev_value = None + try: + i = ImageOps.exif_transpose(i) + except OSError: + prev_value = ImageFile.LOAD_TRUNCATED_IMAGES + ImageFile.LOAD_TRUNCATED_IMAGES = True + i = ImageOps.exif_transpose(i) + finally: + if prev_value is not None: + ImageFile.LOAD_TRUNCATED_IMAGES = prev_value + if i.mode == 'I': + i = i.point(lambda i: i * (1 / 255)) + image = i.convert("RGBA" if alpha_is_transparency else "RGB") + image = np.array(image).astype(np.float32) / 255.0 + image = torch.from_numpy(image)[None,] + if 'A' in i.getbands(): + mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + elif i.mode == 'P' and 'transparency' in i.info: + mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0 + mask = 1. - torch.from_numpy(mask) + else: + mask = torch.zeros((image.shape[1], image.shape[2]), dtype=torch.float32, device="cpu") + output_images.append(image) + output_masks.append(mask.unsqueeze(0)) + + output_images_batched: ImageBatch = torch.cat(output_images, dim=0) + output_masks_batched: MaskBatch = torch.cat(output_masks, dim=0) + + return ImageMaskTuple(output_images_batched, output_masks_batched) export_custom_nodes() diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index b7d7ad2ca..356997950 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -3,7 +3,7 @@ import logging import time import urllib.request import uuid -from typing import Dict, Optional +from typing import Dict, Optional, AsyncGenerator import numpy import pytest @@ -128,7 +128,7 @@ class TestExecution: (0, True), (100, True), ]) - async def client(self, request) -> ComfyClient: + async def client(self, request): from ..inference.testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS # ??? todo: we have to deal with this @@ -784,3 +784,163 @@ class TestExecution: result = client.get_all_history(max_items=5, offset=len(all_history) - 1) assert len(result) <= 1, "Should return at most 1 item when offset is near end" + + async def test_lazy_switch_true_branch(self, client: ComfyClient, builder: GraphBuilder): + await client.embedded_client.clear_cache() + g = builder + # Create a "True" boolean value + true_int = g.node("StubInt", value=1) + true_bool = g.node("TestIntConditions", a=true_int.out(0), b=1, operation="==") # 1 == 1 -> True + + # Create nodes for branches + node_true = g.node("StubImage", content="WHITE", height=32, width=32, batch_size=1) + node_false = g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1) + + # Create lazy switch + # Note: LazySwitch is imported at the top of the file + switch = g.node("LazySwitch", switch=true_bool.out(0), on_true=node_true.out(0), on_false=node_false.out(0)) + output = g.node("SaveImage", images=switch.out(0)) + + result = await client.run(g) + + # Check execution + assert result.did_run(true_int), "True stub int should run" + assert result.did_run(true_bool), "Boolean condition node should run" + assert result.did_run(node_true), "on_true node should run" + assert not result.did_run(node_false), "on_false node should NOT run" + assert result.did_run(switch), "LazySwitch node should run" + assert result.did_run(output), "SaveImage node should run" + + # Check output + result_image = result.get_images(output)[0] + assert numpy.array(result_image).mean() == 255, "Image should be white" + + async def test_lazy_switch_false_branch(self, client: ComfyClient, builder: GraphBuilder): + await client.embedded_client.clear_cache() + g = builder + # Create a "False" boolean value + false_int = g.node("StubInt", value=0) + false_bool = g.node("TestIntConditions", a=false_int.out(0), b=1, operation="==") # 0 == 1 -> False + + # Create nodes for branches + node_true = g.node("StubImage", content="WHITE", height=32, width=32, batch_size=1) + node_false = g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1) + + # Create lazy switch + switch = g.node("LazySwitch", switch=false_bool.out(0), on_true=node_true.out(0), on_false=node_false.out(0)) + output = g.node("SaveImage", images=switch.out(0)) + + result = await client.run(g) + + # Check execution + assert result.did_run(false_int), "False stub int should run" + assert result.did_run(false_bool), "Boolean condition node should run" + assert not result.did_run(node_true), "on_true node should NOT run" + assert result.did_run(node_false), "on_false node should run" + assert result.did_run(switch), "LazySwitch node should run" + assert result.did_run(output), "SaveImage node should run" + + # Check output + result_image = result.get_images(output)[0] + assert numpy.array(result_image).mean() == 0, "Image should be black" + + async def test_lazy_binary_op_and_short_circuit(self, client: ComfyClient, builder: GraphBuilder): + await client.embedded_client.clear_cache() + g = builder + # Create a "False" boolean value + false_int = g.node("StubInt", value=0) + lhs_bool = g.node("TestIntConditions", a=false_int.out(0), b=1, operation="==") # 0 == 1 -> False + + # Create a "True" boolean value for RHS (this node should not run) + true_int_rhs = g.node("StubInt", value=1) + rhs_bool = g.node("TestIntConditions", a=true_int_rhs.out(0), b=1, operation="==") # 1 == 1 -> True + + # Create binary op + # Note: BinaryOperation is imported at the top of the file + binary_op = g.node("BinaryOperation", lhs=lhs_bool.out(0), op="and", rhs=rhs_bool.out(0)) + + # Create lazy switch to check result + node_true = g.node("StubImage", content="WHITE", height=32, width=32, batch_size=1) + node_false = g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1) + switch = g.node("LazySwitch", switch=binary_op.out(0), on_true=node_true.out(0), on_false=node_false.out(0)) + output = g.node("SaveImage", images=switch.out(0)) + + result = await client.run(g) + + # Check execution + assert result.did_run(false_int), "LHS int node should run" + assert result.did_run(lhs_bool), "LHS bool node should run" + assert not result.did_run(true_int_rhs), "RHS int node should NOT run (short-circuit)" + assert not result.did_run(rhs_bool), "RHS bool node should NOT run (short-circuit)" + assert result.did_run(binary_op), "BinaryOp should run" + assert not result.did_run(node_true), "on_true node should NOT run" + assert result.did_run(node_false), "on_false node should run" + + # Check output + result_image = result.get_images(output)[0] + assert numpy.array(result_image).mean() == 0, "Image should be black (result of 'and' was False)" + + async def test_lazy_binary_op_or_short_circuit(self, client: ComfyClient, builder: GraphBuilder): + await client.embedded_client.clear_cache() + g = builder + # Create a "True" boolean value + true_int = g.node("StubInt", value=1) + lhs_bool = g.node("TestIntConditions", a=true_int.out(0), b=1, operation="==") # 1 == 1 -> True + + # Create a "False" boolean value for RHS (this node should not run) + false_int_rhs = g.node("StubInt", value=0) + rhs_bool = g.node("TestIntConditions", a=false_int_rhs.out(0), b=1, operation="==") # 0 == 1 -> False + + # Create binary op + binary_op = g.node("BinaryOperation", lhs=lhs_bool.out(0), op="or", rhs=rhs_bool.out(0)) + + # Create lazy switch to check result + node_true = g.node("StubImage", content="WHITE", height=32, width=32, batch_size=1) + node_false = g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1) + switch = g.node("LazySwitch", switch=binary_op.out(0), on_true=node_true.out(0), on_false=node_false.out(0)) + output = g.node("SaveImage", images=switch.out(0)) + + result = await client.run(g) + + # Check execution + assert result.did_run(true_int), "LHS int node should run" + assert result.did_run(lhs_bool), "LHS bool node should run" + assert not result.did_run(false_int_rhs), "RHS int node should NOT run (short-circuit)" + assert not result.did_run(rhs_bool), "RHS bool node should NOT run (short-circuit)" + assert result.did_run(binary_op), "BinaryOp should run" + assert result.did_run(node_true), "on_true node should run" + assert not result.did_run(node_false), "on_false node should NOT run" + + # Check output + result_image = result.get_images(output)[0] + assert numpy.array(result_image).mean() == 255, "Image should be white (result of 'or' was True)" + + async def test_lazy_switch_with_none_input(self, client: ComfyClient, builder: GraphBuilder): + await client.embedded_client.clear_cache() + g = builder + # Create a "False" boolean value + false_int = g.node("StubInt", value=0) + false_bool = g.node("TestIntConditions", a=false_int.out(0), b=1, operation="==") # 0 == 1 -> False + + # Create nodes for branches + # This node will return None as its value is empty and default_if_empty is not set + node_true_image = g.node("ImageRequestParameter", value="", description="1") + node_false_image = g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1) + + # Create lazy switch + switch = g.node("LazySwitch", switch=false_bool.out(0), on_true=node_true_image.out(0), on_false=node_false_image.out(0)) + output = g.node("SaveImage", images=switch.out(0)) + + result = await client.run(g) + + # Check execution + assert result.did_run(false_int), "False stub int should run" + assert result.did_run(false_bool), "Boolean condition node should run" + assert not result.did_run(node_true_image), "on_true (ImageRequestParameter) node should NOT run" + assert result.did_run(node_false_image), "on_false node should run" + assert result.did_run(switch), "LazySwitch node should run" + assert result.did_run(output), "SaveImage node should run" + + # Check output + result_image = result.get_images(output)[0] + assert numpy.array(result_image).mean() == 0, "Image should be black" diff --git a/tests/unit/test_openapi_nodes.py b/tests/unit/test_openapi_nodes.py index 79a0e6c64..c59d4ee56 100644 --- a/tests/unit/test_openapi_nodes.py +++ b/tests/unit/test_openapi_nodes.py @@ -1,4 +1,7 @@ from __future__ import annotations + +import platform + # noqa: E402 from comfy.cmd.main_pre import args import os @@ -280,7 +283,6 @@ def test_posix_join_curly_brackets(): assert joined_path == "a_{test}/b/c" -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") def test_file_request_parameter(use_temporary_input_directory): _image_1x1_px = np.array([[[255, 0, 0]]], dtype=np.uint8) image_path = os.path.join(use_temporary_input_directory, "test_image.png") @@ -288,7 +290,7 @@ def test_file_request_parameter(use_temporary_input_directory): image.save(image_path) n = ImageRequestParameter() - loaded_image, = n.execute(value=image_path) + loaded_image, *_ = n.execute(value=image_path) assert loaded_image.shape == (1, 1, 1, 3) from comfy.nodes.base_nodes import LoadImage @@ -299,13 +301,174 @@ def test_file_request_parameter(use_temporary_input_directory): assert torch.allclose(loaded_image, load_image_node_rgb) +def test_file_request_parameter2(use_temporary_input_directory): + n = ImageRequestParameter() + + # Test 1: Load a single RGB image + _image_1x1_px_rgb = np.array([[[255, 0, 0]]], dtype=np.uint8) # 1x1 RGB + image_path_rgb = os.path.join(use_temporary_input_directory, "test_image_rgb.png") + image_rgb = Image.fromarray(_image_1x1_px_rgb, 'RGB') + image_rgb.save(image_path_rgb) + + loaded_image_rgb, loaded_mask_rgb = n.execute(value=image_path_rgb, alpha_is_transparency=True) + + # Node converts RGB to RGBA + assert loaded_image_rgb.shape == (1, 1, 1, 4) # B, H, W, C + # Check RGB values + assert torch.allclose(loaded_image_rgb[0, 0, 0, :3], torch.tensor([1.0, 0.0, 0.0])) + # Check added Alpha channel + assert torch.allclose(loaded_image_rgb[0, 0, 0, 3], torch.tensor(1.0)) + # Check mask (should be all 0s for 1.0 alpha) + assert loaded_mask_rgb.shape == (1, 1, 1) # B, H, W + assert torch.all(loaded_mask_rgb == 0.0) + + # Test 2: Load a single RGBA image with transparency + _image_1x1_px_rgba = np.array([[[255, 0, 0, 128]]], dtype=np.uint8) # 1x1 RGBA + image_path_rgba = os.path.join(use_temporary_input_directory, "test_image_rgba.png") + image_rgba = Image.fromarray(_image_1x1_px_rgba, 'RGBA') + image_rgba.save(image_path_rgba) + + loaded_image_rgba, loaded_mask_rgba = n.execute(value=image_path_rgba, alpha_is_transparency=True) + + # Node should load RGBA as is + assert loaded_image_rgba.shape == (1, 1, 1, 4) # B, H, W, C + # Check RGBA values + assert torch.allclose(loaded_image_rgba[0, 0, 0, :], torch.tensor([1.0, 0.0, 0.0, 128 / 255.0])) + # Check mask (should be 1.0 - alpha) + assert loaded_mask_rgba.shape == (1, 1, 1) # B, H, W + assert torch.allclose(loaded_mask_rgba[0, 0, 0], torch.tensor(1.0 - 128 / 255.0)) + + # Test 3: Load a single RGB image with alpha_is_transparency=False + loaded_image_rgb_no_alpha, loaded_mask_rgb_no_alpha = n.execute(value=image_path_rgb, alpha_is_transparency=False) + + # Node converts to RGB + assert loaded_image_rgb_no_alpha.shape == (1, 1, 1, 3) # B, H, W, C + # Check RGB values + assert torch.allclose(loaded_image_rgb_no_alpha[0, 0, 0, :], torch.tensor([1.0, 0.0, 0.0])) + # Check mask + assert loaded_mask_rgb_no_alpha.shape == (1, 1, 1) # B, H, W + + # Test 4: Load a single RGBA image with alpha_is_transparency=False + loaded_image_rgba_no_alpha, loaded_mask_rgba_no_alpha = n.execute(value=image_path_rgba, alpha_is_transparency=False) + + # Node should load RGBA as RGB (dropping alpha) + assert loaded_image_rgba_no_alpha.shape == (1, 1, 1, 3) # B, H, W, C + # Check RGB values (straight, not pre-multiplied) + assert torch.allclose(loaded_image_rgba_no_alpha[0, 0, 0, :], torch.tensor([1.0, 0.0, 0.0])) + assert loaded_mask_rgba_no_alpha.shape == (1, 1, 1) # B, H, W + + +def test_file_request_parameter_glob(use_temporary_input_directory): + # 1. Create dummy images (2x2) + # Image 1 (RGB) + img_rgb_data = np.array([ + [[255, 0, 0], [0, 255, 0]], + [[0, 0, 255], [255, 255, 255]] + ], dtype=np.uint8) + img_rgb = Image.fromarray(img_rgb_data, 'RGB') + path_rgb = os.path.join(use_temporary_input_directory, "img_rgb.png") + img_rgb.save(path_rgb) + + # Image 2 (RGBA with transparency) + img_rgba_data_rgb = np.array([ + [[10, 20, 30], [40, 50, 60]], + [[70, 80, 90], [100, 110, 120]] + ], dtype=np.uint8) + img_rgba_data_a = np.array([ + [255, 128], # alpha 1.0, 0.5 + [0, 255] # alpha 0.0, 1.0 + ], dtype=np.uint8) + img_rgba_data = np.dstack((img_rgba_data_rgb, img_rgba_data_a)) + img_rgba = Image.fromarray(img_rgba_data, 'RGBA') + path_rgba = os.path.join(use_temporary_input_directory, "img_rgba.png") + img_rgba.save(path_rgba) + + # 2. Construct fsspec URL + prefix = 'file:///' if platform.system() == "Windows" else 'file://' + # Add glob pattern + url = f"{prefix}{use_temporary_input_directory}/*.png" + + n = ImageRequestParameter() + + # 3. Test with alpha_is_transparency=True (default) + loaded_images, loaded_masks = n.execute(value=url, alpha_is_transparency=True) + + # 4. Verify results (True) - Order independent + assert loaded_images.shape == (2, 2, 2, 4) # B=2, H=2, W=2, C=RGBA + assert loaded_masks.shape == (2, 2, 2) # B=2, H=2, W=2 + + # Find which image is which by checking the mask sum + mask_sums = torch.sum(loaded_masks, dim=(1, 2)) + expected_rgba_mask_sum = (1.0 - 128/255.0) + 1.0 # From alpha 128 and 0 + + # Argmin should find the all-zero mask (from the RGB image) + rgb_img_index = torch.argmin(mask_sums) + # Argmax should find the mask with transparency + rgba_img_index = torch.argmax(mask_sums) + + assert rgb_img_index != rgba_img_index + assert torch.allclose(mask_sums[rgb_img_index], torch.tensor(0.0)) + assert torch.allclose(mask_sums[rgba_img_index], torch.tensor(expected_rgba_mask_sum)) + + # Check RGB image tensor (which was converted to RGBA) + rgb_image_tensor = loaded_images[rgb_img_index] + assert torch.allclose(rgb_image_tensor[0, 0, :3], torch.tensor([1.0, 0.0, 0.0])) # Red pixel + assert torch.allclose(rgb_image_tensor[0, 0, 3], torch.tensor(1.0)) # Added alpha + assert torch.allclose(rgb_image_tensor[1, 1, :3], torch.tensor([1.0, 1.0, 1.0])) # White pixel + assert torch.allclose(rgb_image_tensor[1, 1, 3], torch.tensor(1.0)) # Added alpha + + # Check RGBA image tensor + rgba_image_tensor = loaded_images[rgba_img_index] + # Pixel [0, 1] (alpha 128) + assert torch.allclose(rgba_image_tensor[0, 1, :3], torch.tensor([40/255.0, 50/255.0, 60/255.0])) + assert torch.allclose(rgba_image_tensor[0, 1, 3], torch.tensor(128/255.0)) # Original alpha + # Pixel [1, 0] (alpha 0) + assert torch.allclose(rgba_image_tensor[1, 0, :3], torch.tensor([70/255.0, 80/255.0, 90/255.0])) + assert torch.allclose(rgba_image_tensor[1, 0, 3], torch.tensor(0.0)) # Original alpha + assert torch.allclose(rgba_image_tensor[1, 0, 3], torch.tensor(0.0)) # Original alpha + + # 5. Test with alpha_is_transparency=False + loaded_images_no_alpha, loaded_masks_no_alpha = n.execute(value=url, alpha_is_transparency=False) + + # 6. Verify results (False) + assert loaded_images_no_alpha.shape == (2, 2, 2, 3) # B=2, H=2, W=2, C=RGB + assert loaded_masks_no_alpha.shape == (2, 2, 2) # B=2, H, W (empty) + + # Find which image is which by checking pixel sum (RGB image has brighter pixels) + img_sums = torch.sum(loaded_images_no_alpha, dim=(1, 2, 3)) + rgb_img_index = torch.argmax(img_sums) + rgba_img_index = torch.argmin(img_sums) + + assert rgb_img_index != rgba_img_index + + # Check RGB image tensor + rgb_image_tensor_no_alpha = loaded_images_no_alpha[rgb_img_index] + assert torch.allclose(rgb_image_tensor_no_alpha[0, 0, :], torch.tensor([1.0, 0.0, 0.0])) # Red + assert torch.allclose(rgb_image_tensor_no_alpha[1, 1, :], torch.tensor([1.0, 1.0, 1.0])) # White + + # Check RGBA image tensor (which was converted to RGB, dropping alpha - straight matte) + rgba_image_tensor_no_alpha = loaded_images_no_alpha[rgba_img_index] + # Pixel [0, 1] (RGB [40, 50, 60], alpha 128) + # RGB channels are passed through unaltered + expected_rgb_0_1 = torch.tensor([40 / 255.0, 50 / 255.0, 60 / 255.0]) + assert torch.allclose(rgba_image_tensor_no_alpha[0, 1, :], expected_rgb_0_1) + # Pixel [1, 0] (RGB [70, 80, 90], alpha 0) + # RGB channels are passed through unaltered + expected_rgb_1_0 = torch.tensor([70 / 255.0, 80 / 255.0, 90 / 255.0]) + assert torch.allclose(rgba_image_tensor_no_alpha[1, 0, :], expected_rgb_1_0) + def test_file_request_to_http_url_no_exceptions(): n = ImageRequestParameter() - loaded_image, = n.execute(value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a6/A_rainbow_at_sunset_after_rain_in_Gaziantep%2C_Turkey.IMG_2448.jpg/484px-A_rainbow_at_sunset_after_rain_in_Gaziantep%2C_Turkey.IMG_2448.jpg") - _, height, width, channels = loaded_image.shape + loaded_image, loaded_mask = n.execute(value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a6/A_rainbow_at_sunset_after_rain_in_Gaziantep%2C_Turkey.IMG_2448.jpg/484px-A_rainbow_at_sunset_after_rain_in_Gaziantep%2C_Turkey.IMG_2448.jpg") + # This is an RGB jpg, so it will be converted to RGBA + b, height, width, channels = loaded_image.shape + assert b == 1 assert width == 484 assert height == 480 assert channels == 3 + # Mask should be all zeros + assert loaded_mask.shape == (1, 480, 484) + assert torch.all(loaded_mask == 0.0) @pytest.mark.parametrize("format,bits,supports_16bit", [ @@ -676,4 +839,4 @@ def test_numeric_exif(format, use_temporary_output_directory): assert tag_id is not None if tag_id in exif_data: # Convert both to strings for comparison since formats might store numbers differently - assert str(exif_data[tag_id]) == expected_value \ No newline at end of file + assert str(exif_data[tag_id]) == expected_value diff --git a/tests/unit/test_operator_nodes.py b/tests/unit/test_operator_nodes.py new file mode 100644 index 000000000..3c563e620 --- /dev/null +++ b/tests/unit/test_operator_nodes.py @@ -0,0 +1,158 @@ +import pytest +from comfy_extras.nodes.nodes_logic import LazySwitch, UnaryOperation, BooleanUnaryOperation, BinaryOperation, \ + BooleanBinaryOperation + + +def test_lazy_switch(): + n = LazySwitch() + + # Test True branch + res_true_check = n.check_lazy_status(switch=True) + assert res_true_check == ["on_true"] + res_true, = n.execute(switch=True, on_false="val_false", on_true="val_true") + assert res_true == "val_true" + + # Test False branch + res_false_check = n.check_lazy_status(switch=False) + assert res_false_check == ["on_false"] + res_false, = n.execute(switch=False, on_false="val_false", on_true="val_true") + assert res_false == "val_false" + + # Test with None + res_none, = n.execute(switch=False, on_false=None, on_true="val_true") + assert res_none is None + + +def test_unary_operation(): + n = UnaryOperation() + + # Test 'not' + res_not_true, = n.execute(value=True, op="not") + assert res_not_true is False + res_not_false, = n.execute(value=False, op="not") + assert res_not_false is True + res_not_str, = n.execute(value="hello", op="not") + assert res_not_str is False + res_not_empty_str, = n.execute(value="", op="not") + assert res_not_empty_str is True + + # Test 'inv' (invert, ~) + res_inv, = n.execute(value=5, op="inv") + assert res_inv == ~5 # -6 + + # Test 'neg' (-) + res_neg, = n.execute(value=10, op="neg") + assert res_neg == -10 + res_neg_neg, = n.execute(value=-10, op="neg") + assert res_neg_neg == 10 + + # Test 'pos' (+) + res_pos, = n.execute(value=-5, op="pos") + assert res_pos == -5 + res_pos_pos, = n.execute(value=5, op="pos") + assert res_pos_pos == 5 + + +def test_boolean_unary_operation(): + n = BooleanUnaryOperation() + + # Test 'not' + res_not_true, = n.execute(value=True, op="not") + assert res_not_true is False + res_not_false, = n.execute(value=False, op="not") + assert res_not_false is True + + # Test truthiness + res_not_int_1, = n.execute(value=1, op="not") + assert res_not_int_1 is False + res_not_int_0, = n.execute(value=0, op="not") + assert res_not_int_0 is True + res_not_str, = n.execute(value="hello", op="not") + assert res_not_str is False + res_not_empty_str, = n.execute(value="", op="not") + assert res_not_empty_str is True + + +def test_binary_operation(): + n = BinaryOperation() + + # Test ops + res_eq, = n.execute(lhs=5, op="eq", rhs=5) + assert res_eq is True + res_ne, = n.execute(lhs=5, op="ne", rhs=6) + assert res_ne is True + res_lt, = n.execute(lhs=5, op="lt", rhs=6) + assert res_lt is True + res_gt, = n.execute(lhs=6, op="gt", rhs=5) + assert res_gt is True + res_le, = n.execute(lhs=5, op="le", rhs=5) + assert res_le is True + res_ge, = n.execute(lhs=5, op="ge", rhs=5) + assert res_ge is True + res_add, = n.execute(lhs=5, op="add", rhs=3) + assert res_add == 8 + res_sub, = n.execute(lhs=5, op="sub", rhs=3) + assert res_sub == 2 + res_mul, = n.execute(lhs=5, op="mul", rhs=3) + assert res_mul == 15 + res_div, = n.execute(lhs=10, op="truediv", rhs=4) + assert res_div == 2.5 + res_floor_div, = n.execute(lhs=10, op="floordiv", rhs=3) + assert res_floor_div == 3 + + # Test logical 'and' + res_and_tt, = n.execute(lhs=True, op="and", rhs=True) + assert res_and_tt is True + res_and_tf, = n.execute(lhs=True, op="and", rhs=False) + assert res_and_tf is False + res_and_ff, = n.execute(lhs=False, op="and", rhs=False) + assert res_and_ff is False + + # Test logical 'or' + res_or_tf, = n.execute(lhs=True, op="or", rhs=False) + assert res_or_tf is True + res_or_ft, = n.execute(lhs=False, op="or", rhs=True) + assert res_or_ft is True + res_or_ff, = n.execute(lhs=False, op="or", rhs=False) + assert res_or_ff is False + + +def test_binary_operation_lazy_check(): + n = BinaryOperation() + + # Test standard ops + assert n.check_lazy_status(op="eq") == ["lhs", "rhs"] + assert n.check_lazy_status(op="add", lhs=1, rhs=2) == ["lhs", "rhs"] + + # Test 'and' + assert n.check_lazy_status(op="and", lhs=None) == ["lhs"] + assert n.check_lazy_status(op="and", lhs=True, rhs=None) == ["rhs"] + assert n.check_lazy_status(op="and", lhs=False, rhs=None) == [] + assert n.check_lazy_status(op="and", lhs=False, rhs=True) == [] + + # Test 'or' + assert n.check_lazy_status(op="or", lhs=None) == ["lhs"] + assert n.check_lazy_status(op="or", lhs=True, rhs=None) == [] + assert n.check_lazy_status(op="or", lhs=False, rhs=None) == ["rhs"] + assert n.check_lazy_status(op="or", lhs=True, rhs=False) == [] + + +def test_boolean_binary_operation(): + n = BooleanBinaryOperation() + + # Test 'eq' + res_eq, = n.execute(lhs=5, op="eq", rhs=5) + assert res_eq is True + res_ne, = n.execute(lhs=5, op="eq", rhs=6) + assert res_ne is False + + # Test truthiness + res_and_truthy, = n.execute(lhs="hello", op="and", rhs=1) + assert res_and_truthy is True + res_and_falsy, = n.execute(lhs="hello", op="and", rhs=0) + assert res_and_falsy is False + res_or_falsy, = n.execute(lhs="", op="or", rhs=0) + assert res_or_falsy is False + res_or_truthy, = n.execute(lhs="", op="or", rhs="test") + assert res_or_truthy is True +