Improve lazy graph evaluation, add logic operators

This commit is contained in:
doctorpangloss 2025-10-31 14:27:27 -07:00
parent 6f2589f256
commit 97f911280e
15 changed files with 885 additions and 68 deletions

View File

@ -57,6 +57,7 @@ def _execute_prompt(
finally:
detach(token)
async def __execute_prompt(
prompt: dict,
prompt_id: str,
@ -68,6 +69,7 @@ async def __execute_prompt(
with context_configuration(configuration):
return await ___execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, partial_execution_targets)
async def ___execute_prompt(
prompt: dict,
prompt_id: str,
@ -116,7 +118,7 @@ async def ___execute_prompt(
raise exc_info
def _cleanup():
def _cleanup(invalidate_nodes=True):
from ..cmd.execution import PromptExecutor
from ..nodes_context import invalidate
try:
@ -132,6 +134,7 @@ def _cleanup():
model_management.soft_empty_cache()
except:
pass
if invalidate_nodes:
try:
invalidate()
except:
@ -180,7 +183,6 @@ class Comfy:
self._history = History()
self._context_stack = []
@property
def is_running(self) -> bool:
return self._is_running
@ -200,6 +202,9 @@ class Comfy:
def history(self) -> History:
return self._history
async def clear_cache(self):
await get_event_loop().run_in_executor(self._executor, _cleanup, False)
def __exit__(self, *args):
get_event_loop().run_in_executor(self._executor, _cleanup)
self._executor.shutdown(wait=True)

View File

@ -18,8 +18,9 @@ from typing import List, Optional, Tuple, Literal
# order matters
from .main_pre import tracer
import torch
from frozendict import frozendict
from comfy_execution.graph_types import FrozenTopologicalSort, Input
from opentelemetry.trace import get_current_span, StatusCode, Status
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
@ -30,7 +31,7 @@ from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io
from ..execution_context import current_execution_context
from ..execution_context import current_execution_context, context_set_execution_list_and_inputs
from .. import interruption
from .. import model_management
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
@ -241,7 +242,12 @@ async def resolve_map_node_over_list_results(results):
@tracer.start_as_current_span("Execute Node")
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None, execution_list=None, executed=None):
with context_set_execution_list_and_inputs(FrozenTopologicalSort.from_topological_sort(execution_list) if execution_list is not None else None, frozenset(executed) if executed is not None else None):
return await __async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb, hidden_inputs)
async def __async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
span = get_current_span()
class_type = obj.__class__.__name__
span.set_attribute("class_type", class_type)
@ -368,8 +374,8 @@ def merge_result_data(results, obj):
return output
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None, inputs=None, execution_list=None, executed=None):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs, execution_list=execution_list, executed=executed)
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
if has_pending_task:
return return_values, {}, False, has_pending_task
@ -474,7 +480,7 @@ async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, ca
return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list: ExecutionList, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
@ -543,7 +549,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
else:
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
if lazy_status_present:
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs, execution_list=execution_list, executed=executed)
required_inputs = await resolve_map_node_over_list_results(required_inputs)
required_inputs = set(sum([r for r in required_inputs if isinstance(r, list)], []))
required_inputs = [x for x in required_inputs if isinstance(x, str) and (
@ -577,7 +583,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs, inputs=inputs, execution_list=execution_list, executed=executed)
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
unblock = execution_list.add_external_block(unique_id)

View File

@ -1,8 +1,8 @@
from typing import NamedTuple
from torch import Tensor
from .tensor_types import ImageBatch, MaskBatch
class RgbMaskTuple(NamedTuple):
rgb: Tensor
mask: Tensor
class ImageMaskTuple(NamedTuple):
image: ImageBatch
mask: MaskBatch

View File

@ -5,6 +5,7 @@ from contextvars import ContextVar
from dataclasses import dataclass, replace
from typing import Optional, Final
from comfy_execution.graph_types import FrozenTopologicalSort
from .cli_args import cli_args_configuration
from .cli_args_types import Configuration
from .component_model import cvpickle
@ -17,15 +18,26 @@ from .progress_types import AbstractProgressRegistry, ProgressRegistryStub
@dataclass(frozen=True)
class ExecutionContext:
# at any time
server: ExecutorToClientProgress
folder_names_and_paths: FolderNames
custom_nodes: ExportedNodes
configuration: Optional[Configuration] = None
# during prompt execution
progress_registry: Optional[AbstractProgressRegistry] = None
# during node execution
node_id: Optional[str] = None
task_id: Optional[str] = None
list_index: Optional[int] = None
inference_mode: bool = True
progress_registry: Optional[AbstractProgressRegistry] = None
configuration: Optional[Configuration] = None
execution_list: Optional[FrozenTopologicalSort] = None
executed: Optional[frozenset] = None
@property
def inputs(self) -> dict:
return self.execution_list.dynprompt.get_node(self.node_id)['inputs']
def __iter__(self):
"""
@ -123,3 +135,11 @@ def context_set_node_and_prompt(prompt_id: str, node_id: str, list_index: Option
new_ctx = replace(current_ctx, task_id=prompt_id, node_id=node_id, list_index=list_index)
with _new_execution_context(new_ctx):
yield new_ctx
@contextmanager
def context_set_execution_list_and_inputs(execution_list: FrozenTopologicalSort, executed: frozenset):
current_ctx = current_execution_context()
new_ctx = replace(current_ctx, execution_list=execution_list, executed=executed)
with _new_execution_context(new_ctx):
yield new_ctx

29
comfy/lazy_helpers.py Normal file
View File

@ -0,0 +1,29 @@
from .execution_context import current_execution_context
from typing import Any, Generator, Sequence
def is_input_unscheduled_unexecuted(*arg_names: Sequence[str]) -> Generator[bool, Any, None]:
"""
returns true if the given argument in the context of an executing node is not scheduled nor executed
this will be true for inputs that are marked as lazy, and this method is more robust against nodes that return None
:param arg_names: each arg to evaluate
:return: True for each arg that is not scheduled nor executed
"""
context = current_execution_context()
if context is None or context.execution_list is None:
raise LookupError("Not executing a node")
assert context.execution_list is not None
dynprompt = context.execution_list.dynprompt
executed = context.executed
execution_list = context.execution_list
inputs = context.inputs
unscheduled_unexecuted = dynprompt.all_node_ids() - executed - set(execution_list.pendingNodes.keys())
for arg_name in arg_names:
if arg_name not in inputs:
raise ValueError(f"Input {arg_name} not found")
input_ = inputs[arg_name]
if isinstance(input_, list) or isinstance(input_, tuple) and len(input_) == 2:
node_id, *_ = input_
yield node_id in unscheduled_unexecuted
else:
yield False

View File

@ -30,6 +30,7 @@ from ..cli_args import args
from ..cmd import folder_paths, latent_preview
from ..comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
from ..component_model.deprecation import _deprecate_method
from ..component_model.images_types import ImageMaskTuple
from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch, RGBAImageBatch, Latent
from ..execution_context import current_execution_context
from ..images import open_image
@ -1791,7 +1792,7 @@ class LoadImage:
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image"
def load_image(self, image: str) -> tuple[RGBImageBatch, MaskBatch]:
def load_image(self, image: str) -> ImageMaskTuple:
image_path = folder_paths.get_annotated_filepath(image)
output_images = []
output_masks = []
@ -1829,7 +1830,9 @@ class LoadImage:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
# unlike upstream, the mask is now the size of the original image, even when there was no alpha channel
# this is opaque
mask = torch.zeros((h, w), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
@ -1840,7 +1843,7 @@ class LoadImage:
output_image = output_images[0]
output_mask = output_masks[0]
return (output_image, output_mask)
return ImageMaskTuple(output_image, output_mask)
@classmethod
def VALIDATE_INPUTS(s, image):

View File

@ -33,7 +33,7 @@ import numpy as np
import torch
from torch import Tensor
from .component_model.images_types import RgbMaskTuple
from .component_model.images_types import ImageMaskTuple
read_exr = lambda fp: cv2.imread(fp, cv2.IMREAD_UNCHANGED).astype(np.float32)
@ -50,7 +50,7 @@ def mut_linear_to_srgb(np_array) -> None:
np_array[~less] = np.power(np_array[~less], 1 / 2.4) * 1.055 - 0.055
def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple:
def load_exr(file_path: str, srgb: bool) -> ImageMaskTuple:
image = read_exr(file_path)
rgb = np.flip(image[:, :, :3], 2).copy()
if srgb:
@ -62,7 +62,7 @@ def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple:
if image.shape[2] > 3:
mask[0] = torch.from_numpy(np.clip(image[:, :, 3], 0, 1))
return RgbMaskTuple(rgb, mask)
return ImageMaskTuple(rgb, mask)
def load_exr_latent(file_path: str) -> Tuple[Tensor]:

View File

@ -5,7 +5,7 @@ import inspect
from typing import Optional, Type, Literal, Union, NamedTuple
from comfy.nodes.package_typing import CustomNode, InputTypes
from comfy_execution.graph_utils import is_link, ExecutionBlocker
from comfy_execution.graph_utils import ExecutionBlocker
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
from comfy.component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \
DependencyExecutionErrorMessage, ComboOptions
@ -249,7 +249,7 @@ class ExecutionList(TopologicalSort):
# for a PreviewImage to display a result as soon as it can
# Some other heuristics could probably be used here to improve the UX further.
def is_output(node_id):
class_type = self.dynprompt.get_node(node_id)["class_type"]
class_type = self.dynprompt.get_node(node_id)['class_type']
class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
return True
@ -258,7 +258,7 @@ class ExecutionList(TopologicalSort):
# If an available node is async, do that first.
# This will execute the asynchronous function earlier, reducing the overall time.
def is_async(node_id):
class_type = self.dynprompt.get_node(node_id)["class_type"]
class_type = self.dynprompt.get_node(node_id)['class_type']
class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))

View File

@ -0,0 +1,81 @@
from __future__ import annotations
import typing
from dataclasses import dataclass
from typing import Mapping
from frozendict import frozendict, deepfreeze
from comfy.component_model.executor_types import NodeNotFoundError
if typing.TYPE_CHECKING:
from .graph import TopologicalSort, DynamicPrompt
NodeOutputByIdAndIndex = tuple[str, int]
InputValue = typing.Union[NodeOutputByIdAndIndex, bool, bytes, int, float, str, typing.IO, typing.BinaryIO]
Input = Mapping[str, InputValue]
@dataclass(frozen=True)
class FrozenTopologicalSort:
dynprompt: "FrozenPrompt"
pendingNodes: frozendict
blockCount: frozendict
blocking: frozendict
externalBlocks: int
@classmethod
def from_topological_sort(cls, ts: "TopologicalSort") -> FrozenTopologicalSort:
return cls(
dynprompt=FrozenPrompt.from_dynamic_prompt(ts.dynprompt),
pendingNodes=deepfreeze(ts.pendingNodes),
blockCount=deepfreeze(ts.blockCount),
blocking=deepfreeze(ts.blocking),
externalBlocks=ts.externalBlocks,
)
@dataclass(frozen=True)
class FrozenPrompt:
original_prompt: frozendict
ephemeral_prompt: frozendict
ephemeral_parents: frozendict
ephemeral_display: frozendict
@classmethod
def from_dynamic_prompt(cls, dynprompt: "DynamicPrompt") -> FrozenPrompt:
return cls(
original_prompt=deepfreeze(dynprompt.original_prompt),
ephemeral_prompt=deepfreeze(dynprompt.ephemeral_prompt),
ephemeral_parents=deepfreeze(dynprompt.ephemeral_parents),
ephemeral_display=deepfreeze(dynprompt.ephemeral_display),
)
def get_node(self, node_id):
if node_id in self.ephemeral_prompt:
return self.ephemeral_prompt[node_id]
if node_id in self.original_prompt:
return self.original_prompt[node_id]
raise NodeNotFoundError(f"Node {node_id} not found")
def has_node(self, node_id):
return node_id in self.original_prompt or node_id in self.ephemeral_prompt
def get_real_node_id(self, node_id):
while node_id in self.ephemeral_parents:
node_id = self.ephemeral_parents[node_id]
return node_id
def get_parent_node_id(self, node_id):
return self.ephemeral_parents.get(node_id, None)
def get_display_node_id(self, node_id):
while node_id in self.ephemeral_display:
node_id = self.ephemeral_display[node_id]
return node_id
def all_node_ids(self):
return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))
def get_original_prompt(self):
return self.original_prompt

View File

@ -42,7 +42,7 @@ class CurrentNodeContext:
self._cm = core_execution_context.context_set_node_and_prompt(
prompt_id=prompt_id,
node_id=node_id,
list_index=list_index
list_index=list_index,
)
def __enter__(self):

View File

@ -0,0 +1,179 @@
import inspect
import operator
from typing import OrderedDict, Callable, Any
from comfy.comfy_types import IO
from comfy.lazy_helpers import is_input_unscheduled_unexecuted
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode, InputTypes
def takes_n_args(obj, n):
if not callable(obj):
return False
try:
sig = inspect.signature(obj)
except (TypeError, ValueError):
return False
params = sig.parameters.values()
named_param_count = sum(
1 for p in params
if p.kind not in (inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD)
)
return named_param_count == n
_BINARY_OPS: dict[str, Callable[[Any, Any], Any]] = OrderedDict({
**{op: getattr(operator, op) for op in dir(operator) if takes_n_args(getattr(operator, op), 2)},
"and": lambda a, b: a and b,
"or": lambda a, b: a or b,
})
_UNARY_OPS: dict[str, Callable[[Any], Any]] = {
**{op: getattr(operator, op) for op in dir(operator) if takes_n_args(getattr(operator, op), 1)},
"not": lambda a: not a
}
class IsNone(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {},
"optional": {
"value": (IO.ANY, {}),
}
}
RETURN_TYPES = (IO.BOOLEAN,)
FUNCTION = "execute"
CATEGORY = "logic"
def execute(self, value=None):
return (value is None,)
class LazySwitch(CustomNode):
"""
sherlocked from KJ nodes with fixes
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"switch": ("BOOLEAN",),
},
"optional": {
"on_false": (IO.ANY, {"lazy": True}),
"on_true": (IO.ANY, {"lazy": True}),
},
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "execute"
CATEGORY = "logic"
DESCRIPTION = "Controls flow of execution based on a boolean switch."
def check_lazy_status(self, switch, on_false=None, on_true=None):
try:
on_false_not_evaluated, on_true_not_evaluated = is_input_unscheduled_unexecuted("on_false", "on_true")
except LookupError:
on_false_not_evaluated, on_true_not_evaluated = on_false is None, on_true is None
if switch and on_true_not_evaluated:
return ["on_true"]
if not switch and on_false_not_evaluated:
return ["on_false"]
return []
def execute(self, switch, on_false=None, on_true=None):
value = on_true if switch else on_false
return (value,)
class UnaryOperation(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {},
"optional": {
"value": (IO.ANY, {}),
"op": (list(_UNARY_OPS.keys()), {"default": "not"})
}
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "execute"
CATEGORY = "logic"
def execute(self, value, op):
return _UNARY_OPS[op](value),
class BooleanUnaryOperation(UnaryOperation):
RETURN_TYPES = (IO.BOOLEAN,)
def execute(self, value, op):
return bool(super().execute(value, op)[0]),
class BinaryOperation(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {},
"optional": OrderedDict({
"lhs": (IO.ANY, {"lazy": True}),
"op": (list(_BINARY_OPS.keys()), {"default": "eq"}),
"rhs": (IO.ANY, {"lazy": True}),
})
}
RETURN_TYPES = (IO.ANY,)
FUNCTION = "execute"
CATEGORY = "logic"
DESCRIPTION = ""
def check_lazy_status(self, lhs=None, op=None, rhs=None) -> list[str]:
try:
lhs_not_evaluated, rhs_not_evaluated = is_input_unscheduled_unexecuted("lhs", "rhs")
except LookupError:
lhs_not_evaluated, rhs_not_evaluated = lhs is None, rhs is None
lhs_evaluated, rhs_evaluated = not lhs_not_evaluated, not rhs_not_evaluated
match op:
case "and":
if lhs_not_evaluated:
return ["lhs"]
if lhs_evaluated and lhs is not False and rhs_not_evaluated:
return ["rhs"]
return []
case "or":
if lhs_not_evaluated:
return ["lhs"]
if lhs_evaluated and lhs is not True and rhs_not_evaluated:
return ["rhs"]
return []
case _:
to_eval = []
if lhs_not_evaluated:
to_eval.append("lhs")
if rhs_not_evaluated:
to_eval.append("rhs")
return to_eval
def execute(self, lhs, op, rhs):
return _BINARY_OPS[op](lhs, rhs),
class BooleanBinaryOperation(BinaryOperation):
RETURN_TYPES = (IO.BOOLEAN,)
def execute(self, lhs, op, rhs):
return bool(super().execute(lhs, op, rhs)[0]),
export_custom_nodes()

View File

@ -34,7 +34,8 @@ from torch import Tensor
from comfy.cmd import folder_paths
from comfy.comfy_types import IO
from comfy.component_model.tensor_types import RGBAImageBatch, RGBImageBatch
from comfy.component_model.images_types import ImageMaskTuple
from comfy.component_model.tensor_types import RGBAImageBatch, RGBImageBatch, MaskBatch, ImageBatch
from comfy.digest import digest
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode, InputTypes, FunctionReturnsUIVariables, SaveNodeResult, \
@ -809,18 +810,20 @@ class ImageRequestParameter(CustomNode):
},
"optional": {
**_open_api_common_schema,
"default_if_empty": ("IMAGE",)
"default_if_empty": ("IMAGE",),
"alpha_is_transparency": ("BOOLEAN", {"default": False}),
}
}
RETURN_TYPES = ("IMAGE",)
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "execute"
CATEGORY = "api/openapi"
def execute(self, value: str = "", default_if_empty=None, *args, **kwargs) -> ValidatedNodeResult:
def execute(self, value: str = "", default_if_empty=None, alpha_is_transparency=False, *args, **kwargs) -> ImageMaskTuple:
if value.strip() == "":
return (default_if_empty,)
output_images = []
output_masks = []
f: OpenFile
fsspec_kwargs = {}
if value.startswith('http'):
@ -832,7 +835,8 @@ class ImageRequestParameter(CustomNode):
})
# todo: additional security is needed here to prevent users from accessing local paths
# however this generally needs to be done with user accounts on all OSes
with fsspec.open(value, mode="rb", **fsspec_kwargs) as f:
with fsspec.open_files(value, mode="rb", **fsspec_kwargs) as files:
for f in files:
# from LoadImage
img = Image.open(f)
for i in ImageSequence.Iterator(img):
@ -848,15 +852,24 @@ class ImageRequestParameter(CustomNode):
ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")
image = i.convert("RGBA" if alpha_is_transparency else "RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
output_images.append(image)
if len(output_images) > 1:
output_image = torch.cat(output_images, dim=0)
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
output_image = output_images[0]
return (output_image,)
mask = torch.zeros((image.shape[1], image.shape[2]), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))
output_images_batched: ImageBatch = torch.cat(output_images, dim=0)
output_masks_batched: MaskBatch = torch.cat(output_masks, dim=0)
return ImageMaskTuple(output_images_batched, output_masks_batched)
export_custom_nodes()

View File

@ -3,7 +3,7 @@ import logging
import time
import urllib.request
import uuid
from typing import Dict, Optional
from typing import Dict, Optional, AsyncGenerator
import numpy
import pytest
@ -128,7 +128,7 @@ class TestExecution:
(0, True),
(100, True),
])
async def client(self, request) -> ComfyClient:
async def client(self, request):
from ..inference.testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
# ??? todo: we have to deal with this
@ -784,3 +784,163 @@ class TestExecution:
result = client.get_all_history(max_items=5, offset=len(all_history) - 1)
assert len(result) <= 1, "Should return at most 1 item when offset is near end"
async def test_lazy_switch_true_branch(self, client: ComfyClient, builder: GraphBuilder):
await client.embedded_client.clear_cache()
g = builder
# Create a "True" boolean value
true_int = g.node("StubInt", value=1)
true_bool = g.node("TestIntConditions", a=true_int.out(0), b=1, operation="==") # 1 == 1 -> True
# Create nodes for branches
node_true = g.node("StubImage", content="WHITE", height=32, width=32, batch_size=1)
node_false = g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)
# Create lazy switch
# Note: LazySwitch is imported at the top of the file
switch = g.node("LazySwitch", switch=true_bool.out(0), on_true=node_true.out(0), on_false=node_false.out(0))
output = g.node("SaveImage", images=switch.out(0))
result = await client.run(g)
# Check execution
assert result.did_run(true_int), "True stub int should run"
assert result.did_run(true_bool), "Boolean condition node should run"
assert result.did_run(node_true), "on_true node should run"
assert not result.did_run(node_false), "on_false node should NOT run"
assert result.did_run(switch), "LazySwitch node should run"
assert result.did_run(output), "SaveImage node should run"
# Check output
result_image = result.get_images(output)[0]
assert numpy.array(result_image).mean() == 255, "Image should be white"
async def test_lazy_switch_false_branch(self, client: ComfyClient, builder: GraphBuilder):
await client.embedded_client.clear_cache()
g = builder
# Create a "False" boolean value
false_int = g.node("StubInt", value=0)
false_bool = g.node("TestIntConditions", a=false_int.out(0), b=1, operation="==") # 0 == 1 -> False
# Create nodes for branches
node_true = g.node("StubImage", content="WHITE", height=32, width=32, batch_size=1)
node_false = g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)
# Create lazy switch
switch = g.node("LazySwitch", switch=false_bool.out(0), on_true=node_true.out(0), on_false=node_false.out(0))
output = g.node("SaveImage", images=switch.out(0))
result = await client.run(g)
# Check execution
assert result.did_run(false_int), "False stub int should run"
assert result.did_run(false_bool), "Boolean condition node should run"
assert not result.did_run(node_true), "on_true node should NOT run"
assert result.did_run(node_false), "on_false node should run"
assert result.did_run(switch), "LazySwitch node should run"
assert result.did_run(output), "SaveImage node should run"
# Check output
result_image = result.get_images(output)[0]
assert numpy.array(result_image).mean() == 0, "Image should be black"
async def test_lazy_binary_op_and_short_circuit(self, client: ComfyClient, builder: GraphBuilder):
await client.embedded_client.clear_cache()
g = builder
# Create a "False" boolean value
false_int = g.node("StubInt", value=0)
lhs_bool = g.node("TestIntConditions", a=false_int.out(0), b=1, operation="==") # 0 == 1 -> False
# Create a "True" boolean value for RHS (this node should not run)
true_int_rhs = g.node("StubInt", value=1)
rhs_bool = g.node("TestIntConditions", a=true_int_rhs.out(0), b=1, operation="==") # 1 == 1 -> True
# Create binary op
# Note: BinaryOperation is imported at the top of the file
binary_op = g.node("BinaryOperation", lhs=lhs_bool.out(0), op="and", rhs=rhs_bool.out(0))
# Create lazy switch to check result
node_true = g.node("StubImage", content="WHITE", height=32, width=32, batch_size=1)
node_false = g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)
switch = g.node("LazySwitch", switch=binary_op.out(0), on_true=node_true.out(0), on_false=node_false.out(0))
output = g.node("SaveImage", images=switch.out(0))
result = await client.run(g)
# Check execution
assert result.did_run(false_int), "LHS int node should run"
assert result.did_run(lhs_bool), "LHS bool node should run"
assert not result.did_run(true_int_rhs), "RHS int node should NOT run (short-circuit)"
assert not result.did_run(rhs_bool), "RHS bool node should NOT run (short-circuit)"
assert result.did_run(binary_op), "BinaryOp should run"
assert not result.did_run(node_true), "on_true node should NOT run"
assert result.did_run(node_false), "on_false node should run"
# Check output
result_image = result.get_images(output)[0]
assert numpy.array(result_image).mean() == 0, "Image should be black (result of 'and' was False)"
async def test_lazy_binary_op_or_short_circuit(self, client: ComfyClient, builder: GraphBuilder):
await client.embedded_client.clear_cache()
g = builder
# Create a "True" boolean value
true_int = g.node("StubInt", value=1)
lhs_bool = g.node("TestIntConditions", a=true_int.out(0), b=1, operation="==") # 1 == 1 -> True
# Create a "False" boolean value for RHS (this node should not run)
false_int_rhs = g.node("StubInt", value=0)
rhs_bool = g.node("TestIntConditions", a=false_int_rhs.out(0), b=1, operation="==") # 0 == 1 -> False
# Create binary op
binary_op = g.node("BinaryOperation", lhs=lhs_bool.out(0), op="or", rhs=rhs_bool.out(0))
# Create lazy switch to check result
node_true = g.node("StubImage", content="WHITE", height=32, width=32, batch_size=1)
node_false = g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)
switch = g.node("LazySwitch", switch=binary_op.out(0), on_true=node_true.out(0), on_false=node_false.out(0))
output = g.node("SaveImage", images=switch.out(0))
result = await client.run(g)
# Check execution
assert result.did_run(true_int), "LHS int node should run"
assert result.did_run(lhs_bool), "LHS bool node should run"
assert not result.did_run(false_int_rhs), "RHS int node should NOT run (short-circuit)"
assert not result.did_run(rhs_bool), "RHS bool node should NOT run (short-circuit)"
assert result.did_run(binary_op), "BinaryOp should run"
assert result.did_run(node_true), "on_true node should run"
assert not result.did_run(node_false), "on_false node should NOT run"
# Check output
result_image = result.get_images(output)[0]
assert numpy.array(result_image).mean() == 255, "Image should be white (result of 'or' was True)"
async def test_lazy_switch_with_none_input(self, client: ComfyClient, builder: GraphBuilder):
await client.embedded_client.clear_cache()
g = builder
# Create a "False" boolean value
false_int = g.node("StubInt", value=0)
false_bool = g.node("TestIntConditions", a=false_int.out(0), b=1, operation="==") # 0 == 1 -> False
# Create nodes for branches
# This node will return None as its value is empty and default_if_empty is not set
node_true_image = g.node("ImageRequestParameter", value="", description="1")
node_false_image = g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)
# Create lazy switch
switch = g.node("LazySwitch", switch=false_bool.out(0), on_true=node_true_image.out(0), on_false=node_false_image.out(0))
output = g.node("SaveImage", images=switch.out(0))
result = await client.run(g)
# Check execution
assert result.did_run(false_int), "False stub int should run"
assert result.did_run(false_bool), "Boolean condition node should run"
assert not result.did_run(node_true_image), "on_true (ImageRequestParameter) node should NOT run"
assert result.did_run(node_false_image), "on_false node should run"
assert result.did_run(switch), "LazySwitch node should run"
assert result.did_run(output), "SaveImage node should run"
# Check output
result_image = result.get_images(output)[0]
assert numpy.array(result_image).mean() == 0, "Image should be black"

View File

@ -1,4 +1,7 @@
from __future__ import annotations
import platform
# noqa: E402
from comfy.cmd.main_pre import args
import os
@ -280,7 +283,6 @@ def test_posix_join_curly_brackets():
assert joined_path == "a_{test}/b/c"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_file_request_parameter(use_temporary_input_directory):
_image_1x1_px = np.array([[[255, 0, 0]]], dtype=np.uint8)
image_path = os.path.join(use_temporary_input_directory, "test_image.png")
@ -288,7 +290,7 @@ def test_file_request_parameter(use_temporary_input_directory):
image.save(image_path)
n = ImageRequestParameter()
loaded_image, = n.execute(value=image_path)
loaded_image, *_ = n.execute(value=image_path)
assert loaded_image.shape == (1, 1, 1, 3)
from comfy.nodes.base_nodes import LoadImage
@ -299,13 +301,174 @@ def test_file_request_parameter(use_temporary_input_directory):
assert torch.allclose(loaded_image, load_image_node_rgb)
def test_file_request_parameter2(use_temporary_input_directory):
n = ImageRequestParameter()
# Test 1: Load a single RGB image
_image_1x1_px_rgb = np.array([[[255, 0, 0]]], dtype=np.uint8) # 1x1 RGB
image_path_rgb = os.path.join(use_temporary_input_directory, "test_image_rgb.png")
image_rgb = Image.fromarray(_image_1x1_px_rgb, 'RGB')
image_rgb.save(image_path_rgb)
loaded_image_rgb, loaded_mask_rgb = n.execute(value=image_path_rgb, alpha_is_transparency=True)
# Node converts RGB to RGBA
assert loaded_image_rgb.shape == (1, 1, 1, 4) # B, H, W, C
# Check RGB values
assert torch.allclose(loaded_image_rgb[0, 0, 0, :3], torch.tensor([1.0, 0.0, 0.0]))
# Check added Alpha channel
assert torch.allclose(loaded_image_rgb[0, 0, 0, 3], torch.tensor(1.0))
# Check mask (should be all 0s for 1.0 alpha)
assert loaded_mask_rgb.shape == (1, 1, 1) # B, H, W
assert torch.all(loaded_mask_rgb == 0.0)
# Test 2: Load a single RGBA image with transparency
_image_1x1_px_rgba = np.array([[[255, 0, 0, 128]]], dtype=np.uint8) # 1x1 RGBA
image_path_rgba = os.path.join(use_temporary_input_directory, "test_image_rgba.png")
image_rgba = Image.fromarray(_image_1x1_px_rgba, 'RGBA')
image_rgba.save(image_path_rgba)
loaded_image_rgba, loaded_mask_rgba = n.execute(value=image_path_rgba, alpha_is_transparency=True)
# Node should load RGBA as is
assert loaded_image_rgba.shape == (1, 1, 1, 4) # B, H, W, C
# Check RGBA values
assert torch.allclose(loaded_image_rgba[0, 0, 0, :], torch.tensor([1.0, 0.0, 0.0, 128 / 255.0]))
# Check mask (should be 1.0 - alpha)
assert loaded_mask_rgba.shape == (1, 1, 1) # B, H, W
assert torch.allclose(loaded_mask_rgba[0, 0, 0], torch.tensor(1.0 - 128 / 255.0))
# Test 3: Load a single RGB image with alpha_is_transparency=False
loaded_image_rgb_no_alpha, loaded_mask_rgb_no_alpha = n.execute(value=image_path_rgb, alpha_is_transparency=False)
# Node converts to RGB
assert loaded_image_rgb_no_alpha.shape == (1, 1, 1, 3) # B, H, W, C
# Check RGB values
assert torch.allclose(loaded_image_rgb_no_alpha[0, 0, 0, :], torch.tensor([1.0, 0.0, 0.0]))
# Check mask
assert loaded_mask_rgb_no_alpha.shape == (1, 1, 1) # B, H, W
# Test 4: Load a single RGBA image with alpha_is_transparency=False
loaded_image_rgba_no_alpha, loaded_mask_rgba_no_alpha = n.execute(value=image_path_rgba, alpha_is_transparency=False)
# Node should load RGBA as RGB (dropping alpha)
assert loaded_image_rgba_no_alpha.shape == (1, 1, 1, 3) # B, H, W, C
# Check RGB values (straight, not pre-multiplied)
assert torch.allclose(loaded_image_rgba_no_alpha[0, 0, 0, :], torch.tensor([1.0, 0.0, 0.0]))
assert loaded_mask_rgba_no_alpha.shape == (1, 1, 1) # B, H, W
def test_file_request_parameter_glob(use_temporary_input_directory):
# 1. Create dummy images (2x2)
# Image 1 (RGB)
img_rgb_data = np.array([
[[255, 0, 0], [0, 255, 0]],
[[0, 0, 255], [255, 255, 255]]
], dtype=np.uint8)
img_rgb = Image.fromarray(img_rgb_data, 'RGB')
path_rgb = os.path.join(use_temporary_input_directory, "img_rgb.png")
img_rgb.save(path_rgb)
# Image 2 (RGBA with transparency)
img_rgba_data_rgb = np.array([
[[10, 20, 30], [40, 50, 60]],
[[70, 80, 90], [100, 110, 120]]
], dtype=np.uint8)
img_rgba_data_a = np.array([
[255, 128], # alpha 1.0, 0.5
[0, 255] # alpha 0.0, 1.0
], dtype=np.uint8)
img_rgba_data = np.dstack((img_rgba_data_rgb, img_rgba_data_a))
img_rgba = Image.fromarray(img_rgba_data, 'RGBA')
path_rgba = os.path.join(use_temporary_input_directory, "img_rgba.png")
img_rgba.save(path_rgba)
# 2. Construct fsspec URL
prefix = 'file:///' if platform.system() == "Windows" else 'file://'
# Add glob pattern
url = f"{prefix}{use_temporary_input_directory}/*.png"
n = ImageRequestParameter()
# 3. Test with alpha_is_transparency=True (default)
loaded_images, loaded_masks = n.execute(value=url, alpha_is_transparency=True)
# 4. Verify results (True) - Order independent
assert loaded_images.shape == (2, 2, 2, 4) # B=2, H=2, W=2, C=RGBA
assert loaded_masks.shape == (2, 2, 2) # B=2, H=2, W=2
# Find which image is which by checking the mask sum
mask_sums = torch.sum(loaded_masks, dim=(1, 2))
expected_rgba_mask_sum = (1.0 - 128/255.0) + 1.0 # From alpha 128 and 0
# Argmin should find the all-zero mask (from the RGB image)
rgb_img_index = torch.argmin(mask_sums)
# Argmax should find the mask with transparency
rgba_img_index = torch.argmax(mask_sums)
assert rgb_img_index != rgba_img_index
assert torch.allclose(mask_sums[rgb_img_index], torch.tensor(0.0))
assert torch.allclose(mask_sums[rgba_img_index], torch.tensor(expected_rgba_mask_sum))
# Check RGB image tensor (which was converted to RGBA)
rgb_image_tensor = loaded_images[rgb_img_index]
assert torch.allclose(rgb_image_tensor[0, 0, :3], torch.tensor([1.0, 0.0, 0.0])) # Red pixel
assert torch.allclose(rgb_image_tensor[0, 0, 3], torch.tensor(1.0)) # Added alpha
assert torch.allclose(rgb_image_tensor[1, 1, :3], torch.tensor([1.0, 1.0, 1.0])) # White pixel
assert torch.allclose(rgb_image_tensor[1, 1, 3], torch.tensor(1.0)) # Added alpha
# Check RGBA image tensor
rgba_image_tensor = loaded_images[rgba_img_index]
# Pixel [0, 1] (alpha 128)
assert torch.allclose(rgba_image_tensor[0, 1, :3], torch.tensor([40/255.0, 50/255.0, 60/255.0]))
assert torch.allclose(rgba_image_tensor[0, 1, 3], torch.tensor(128/255.0)) # Original alpha
# Pixel [1, 0] (alpha 0)
assert torch.allclose(rgba_image_tensor[1, 0, :3], torch.tensor([70/255.0, 80/255.0, 90/255.0]))
assert torch.allclose(rgba_image_tensor[1, 0, 3], torch.tensor(0.0)) # Original alpha
assert torch.allclose(rgba_image_tensor[1, 0, 3], torch.tensor(0.0)) # Original alpha
# 5. Test with alpha_is_transparency=False
loaded_images_no_alpha, loaded_masks_no_alpha = n.execute(value=url, alpha_is_transparency=False)
# 6. Verify results (False)
assert loaded_images_no_alpha.shape == (2, 2, 2, 3) # B=2, H=2, W=2, C=RGB
assert loaded_masks_no_alpha.shape == (2, 2, 2) # B=2, H, W (empty)
# Find which image is which by checking pixel sum (RGB image has brighter pixels)
img_sums = torch.sum(loaded_images_no_alpha, dim=(1, 2, 3))
rgb_img_index = torch.argmax(img_sums)
rgba_img_index = torch.argmin(img_sums)
assert rgb_img_index != rgba_img_index
# Check RGB image tensor
rgb_image_tensor_no_alpha = loaded_images_no_alpha[rgb_img_index]
assert torch.allclose(rgb_image_tensor_no_alpha[0, 0, :], torch.tensor([1.0, 0.0, 0.0])) # Red
assert torch.allclose(rgb_image_tensor_no_alpha[1, 1, :], torch.tensor([1.0, 1.0, 1.0])) # White
# Check RGBA image tensor (which was converted to RGB, dropping alpha - straight matte)
rgba_image_tensor_no_alpha = loaded_images_no_alpha[rgba_img_index]
# Pixel [0, 1] (RGB [40, 50, 60], alpha 128)
# RGB channels are passed through unaltered
expected_rgb_0_1 = torch.tensor([40 / 255.0, 50 / 255.0, 60 / 255.0])
assert torch.allclose(rgba_image_tensor_no_alpha[0, 1, :], expected_rgb_0_1)
# Pixel [1, 0] (RGB [70, 80, 90], alpha 0)
# RGB channels are passed through unaltered
expected_rgb_1_0 = torch.tensor([70 / 255.0, 80 / 255.0, 90 / 255.0])
assert torch.allclose(rgba_image_tensor_no_alpha[1, 0, :], expected_rgb_1_0)
def test_file_request_to_http_url_no_exceptions():
n = ImageRequestParameter()
loaded_image, = n.execute(value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a6/A_rainbow_at_sunset_after_rain_in_Gaziantep%2C_Turkey.IMG_2448.jpg/484px-A_rainbow_at_sunset_after_rain_in_Gaziantep%2C_Turkey.IMG_2448.jpg")
_, height, width, channels = loaded_image.shape
loaded_image, loaded_mask = n.execute(value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a6/A_rainbow_at_sunset_after_rain_in_Gaziantep%2C_Turkey.IMG_2448.jpg/484px-A_rainbow_at_sunset_after_rain_in_Gaziantep%2C_Turkey.IMG_2448.jpg")
# This is an RGB jpg, so it will be converted to RGBA
b, height, width, channels = loaded_image.shape
assert b == 1
assert width == 484
assert height == 480
assert channels == 3
# Mask should be all zeros
assert loaded_mask.shape == (1, 480, 484)
assert torch.all(loaded_mask == 0.0)
@pytest.mark.parametrize("format,bits,supports_16bit", [

View File

@ -0,0 +1,158 @@
import pytest
from comfy_extras.nodes.nodes_logic import LazySwitch, UnaryOperation, BooleanUnaryOperation, BinaryOperation, \
BooleanBinaryOperation
def test_lazy_switch():
n = LazySwitch()
# Test True branch
res_true_check = n.check_lazy_status(switch=True)
assert res_true_check == ["on_true"]
res_true, = n.execute(switch=True, on_false="val_false", on_true="val_true")
assert res_true == "val_true"
# Test False branch
res_false_check = n.check_lazy_status(switch=False)
assert res_false_check == ["on_false"]
res_false, = n.execute(switch=False, on_false="val_false", on_true="val_true")
assert res_false == "val_false"
# Test with None
res_none, = n.execute(switch=False, on_false=None, on_true="val_true")
assert res_none is None
def test_unary_operation():
n = UnaryOperation()
# Test 'not'
res_not_true, = n.execute(value=True, op="not")
assert res_not_true is False
res_not_false, = n.execute(value=False, op="not")
assert res_not_false is True
res_not_str, = n.execute(value="hello", op="not")
assert res_not_str is False
res_not_empty_str, = n.execute(value="", op="not")
assert res_not_empty_str is True
# Test 'inv' (invert, ~)
res_inv, = n.execute(value=5, op="inv")
assert res_inv == ~5 # -6
# Test 'neg' (-)
res_neg, = n.execute(value=10, op="neg")
assert res_neg == -10
res_neg_neg, = n.execute(value=-10, op="neg")
assert res_neg_neg == 10
# Test 'pos' (+)
res_pos, = n.execute(value=-5, op="pos")
assert res_pos == -5
res_pos_pos, = n.execute(value=5, op="pos")
assert res_pos_pos == 5
def test_boolean_unary_operation():
n = BooleanUnaryOperation()
# Test 'not'
res_not_true, = n.execute(value=True, op="not")
assert res_not_true is False
res_not_false, = n.execute(value=False, op="not")
assert res_not_false is True
# Test truthiness
res_not_int_1, = n.execute(value=1, op="not")
assert res_not_int_1 is False
res_not_int_0, = n.execute(value=0, op="not")
assert res_not_int_0 is True
res_not_str, = n.execute(value="hello", op="not")
assert res_not_str is False
res_not_empty_str, = n.execute(value="", op="not")
assert res_not_empty_str is True
def test_binary_operation():
n = BinaryOperation()
# Test ops
res_eq, = n.execute(lhs=5, op="eq", rhs=5)
assert res_eq is True
res_ne, = n.execute(lhs=5, op="ne", rhs=6)
assert res_ne is True
res_lt, = n.execute(lhs=5, op="lt", rhs=6)
assert res_lt is True
res_gt, = n.execute(lhs=6, op="gt", rhs=5)
assert res_gt is True
res_le, = n.execute(lhs=5, op="le", rhs=5)
assert res_le is True
res_ge, = n.execute(lhs=5, op="ge", rhs=5)
assert res_ge is True
res_add, = n.execute(lhs=5, op="add", rhs=3)
assert res_add == 8
res_sub, = n.execute(lhs=5, op="sub", rhs=3)
assert res_sub == 2
res_mul, = n.execute(lhs=5, op="mul", rhs=3)
assert res_mul == 15
res_div, = n.execute(lhs=10, op="truediv", rhs=4)
assert res_div == 2.5
res_floor_div, = n.execute(lhs=10, op="floordiv", rhs=3)
assert res_floor_div == 3
# Test logical 'and'
res_and_tt, = n.execute(lhs=True, op="and", rhs=True)
assert res_and_tt is True
res_and_tf, = n.execute(lhs=True, op="and", rhs=False)
assert res_and_tf is False
res_and_ff, = n.execute(lhs=False, op="and", rhs=False)
assert res_and_ff is False
# Test logical 'or'
res_or_tf, = n.execute(lhs=True, op="or", rhs=False)
assert res_or_tf is True
res_or_ft, = n.execute(lhs=False, op="or", rhs=True)
assert res_or_ft is True
res_or_ff, = n.execute(lhs=False, op="or", rhs=False)
assert res_or_ff is False
def test_binary_operation_lazy_check():
n = BinaryOperation()
# Test standard ops
assert n.check_lazy_status(op="eq") == ["lhs", "rhs"]
assert n.check_lazy_status(op="add", lhs=1, rhs=2) == ["lhs", "rhs"]
# Test 'and'
assert n.check_lazy_status(op="and", lhs=None) == ["lhs"]
assert n.check_lazy_status(op="and", lhs=True, rhs=None) == ["rhs"]
assert n.check_lazy_status(op="and", lhs=False, rhs=None) == []
assert n.check_lazy_status(op="and", lhs=False, rhs=True) == []
# Test 'or'
assert n.check_lazy_status(op="or", lhs=None) == ["lhs"]
assert n.check_lazy_status(op="or", lhs=True, rhs=None) == []
assert n.check_lazy_status(op="or", lhs=False, rhs=None) == ["rhs"]
assert n.check_lazy_status(op="or", lhs=True, rhs=False) == []
def test_boolean_binary_operation():
n = BooleanBinaryOperation()
# Test 'eq'
res_eq, = n.execute(lhs=5, op="eq", rhs=5)
assert res_eq is True
res_ne, = n.execute(lhs=5, op="eq", rhs=6)
assert res_ne is False
# Test truthiness
res_and_truthy, = n.execute(lhs="hello", op="and", rhs=1)
assert res_and_truthy is True
res_and_falsy, = n.execute(lhs="hello", op="and", rhs=0)
assert res_and_falsy is False
res_or_falsy, = n.execute(lhs="", op="or", rhs=0)
assert res_or_falsy is False
res_or_truthy, = n.execute(lhs="", op="or", rhs="test")
assert res_or_truthy is True