Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-09 13:50:49 +08:00

Improve lazy graph evaluation, add logic operators

This commit is contained in:
parent 6f2589f256
commit 97f911280e
@@ -57,6 +57,7 @@ def _execute_prompt(
    finally:
        detach(token)


async def __execute_prompt(
        prompt: dict,
        prompt_id: str,
@@ -68,6 +69,7 @@ async def __execute_prompt(
    with context_configuration(configuration):
        return await ___execute_prompt(prompt, prompt_id, client_id, span_context, progress_handler, partial_execution_targets)


async def ___execute_prompt(
        prompt: dict,
        prompt_id: str,
@@ -116,7 +118,7 @@ async def ___execute_prompt(
        raise exc_info


def _cleanup():
def _cleanup(invalidate_nodes=True):
    from ..cmd.execution import PromptExecutor
    from ..nodes_context import invalidate
    try:
@@ -132,10 +134,11 @@ def _cleanup():
        model_management.soft_empty_cache()
    except:
        pass
    try:
        invalidate()
    except:
        pass
    if invalidate_nodes:
        try:
            invalidate()
        except:
            pass


class Comfy:
@@ -180,7 +183,6 @@ class Comfy:
        self._history = History()
        self._context_stack = []


    @property
    def is_running(self) -> bool:
        return self._is_running
@@ -200,6 +202,9 @@ class Comfy:
    def history(self) -> History:
        return self._history

    async def clear_cache(self):
        await get_event_loop().run_in_executor(self._executor, _cleanup, False)

    def __exit__(self, *args):
        get_event_loop().run_in_executor(self._executor, _cleanup)
        self._executor.shutdown(wait=True)

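A minimal usage sketch of the new entry point (editor's illustration; reset_between_jobs is a hypothetical helper): clear_cache runs _cleanup on the executor thread with invalidate_nodes=False, freeing model memory and execution caches while keeping the loaded node definitions intact.

# hypothetical helper; construction of the embedded Comfy client is elided
async def reset_between_jobs(client):
    # frees models and caches but does not invalidate loaded custom nodes
    await client.clear_cache()
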
@@ -18,8 +18,9 @@ from typing import List, Optional, Tuple, Literal

# order matters
from .main_pre import tracer

import torch
from frozendict import frozendict
from comfy_execution.graph_types import FrozenTopologicalSort, Input
from opentelemetry.trace import get_current_span, StatusCode, Status

from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
@@ -30,7 +31,7 @@ from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io
from ..execution_context import current_execution_context
from ..execution_context import current_execution_context, context_set_execution_list_and_inputs
from .. import interruption
from .. import model_management
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
@@ -241,7 +242,12 @@ async def resolve_map_node_over_list_results(results):


@tracer.start_as_current_span("Execute Node")
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None, execution_list=None, executed=None):
    with context_set_execution_list_and_inputs(FrozenTopologicalSort.from_topological_sort(execution_list) if execution_list is not None else None, frozenset(executed) if executed is not None else None):
        return await __async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb, hidden_inputs)


async def __async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
    span = get_current_span()
    class_type = obj.__class__.__name__
    span.set_attribute("class_type", class_type)
@@ -368,8 +374,8 @@ def merge_result_data(results, obj):
    return output


async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
    return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None, inputs=None, execution_list=None, executed=None):
    return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs, execution_list=execution_list, executed=executed)
    has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
    if has_pending_task:
        return return_values, {}, False, has_pending_task
@@ -469,12 +475,12 @@ async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, ca
    :return:
    """
    with (context_execute_node(node_id),
            vanilla_node_execution_environment(),
            use_requests_caching()):
          vanilla_node_execution_environment(),
          use_requests_caching()):
        return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)


async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list: ExecutionList, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
    unique_id = current_item
    real_node_id = dynprompt.get_real_node_id(unique_id)
    display_node_id = dynprompt.get_display_node_id(unique_id)
@@ -543,7 +549,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
    else:
        lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
        if lazy_status_present:
            required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
            required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs, execution_list=execution_list, executed=executed)
            required_inputs = await resolve_map_node_over_list_results(required_inputs)
            required_inputs = set(sum([r for r in required_inputs if isinstance(r, list)], []))
            required_inputs = [x for x in required_inputs if isinstance(x, str) and (
@@ -577,7 +583,7 @@ async def _execute(server, dynprompt, caches: CacheSet, current_item: str, extra
        # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
        GraphBuilder.set_default_prefix(unique_id, call_index, 0)

        output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
        output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs, inputs=inputs, execution_list=execution_list, executed=executed)
    if has_pending_tasks:
        pending_async_nodes[unique_id] = output_data
        unblock = execution_list.add_external_block(unique_id)

@@ -1,8 +1,8 @@
from typing import NamedTuple

from torch import Tensor
from .tensor_types import ImageBatch, MaskBatch


class RgbMaskTuple(NamedTuple):
    rgb: Tensor
    mask: Tensor
class ImageMaskTuple(NamedTuple):
    image: ImageBatch
    mask: MaskBatch

@@ -5,6 +5,7 @@ from contextvars import ContextVar
from dataclasses import dataclass, replace
from typing import Optional, Final

from comfy_execution.graph_types import FrozenTopologicalSort
from .cli_args import cli_args_configuration
from .cli_args_types import Configuration
from .component_model import cvpickle
@@ -17,15 +18,26 @@ from .progress_types import AbstractProgressRegistry, ProgressRegistryStub

@dataclass(frozen=True)
class ExecutionContext:
    # at any time
    server: ExecutorToClientProgress
    folder_names_and_paths: FolderNames
    custom_nodes: ExportedNodes
    configuration: Optional[Configuration] = None

    # during prompt execution
    progress_registry: Optional[AbstractProgressRegistry] = None

    # during node execution
    node_id: Optional[str] = None
    task_id: Optional[str] = None
    list_index: Optional[int] = None
    inference_mode: bool = True
    progress_registry: Optional[AbstractProgressRegistry] = None
    configuration: Optional[Configuration] = None
    execution_list: Optional[FrozenTopologicalSort] = None
    executed: Optional[frozenset] = None

    @property
    def inputs(self) -> dict:
        return self.execution_list.dynprompt.get_node(self.node_id)['inputs']

    def __iter__(self):
        """
@@ -123,3 +135,11 @@ def context_set_node_and_prompt(prompt_id: str, node_id: str, list_index: Option
    new_ctx = replace(current_ctx, task_id=prompt_id, node_id=node_id, list_index=list_index)
    with _new_execution_context(new_ctx):
        yield new_ctx


@contextmanager
def context_set_execution_list_and_inputs(execution_list: FrozenTopologicalSort, executed: frozenset):
    current_ctx = current_execution_context()
    new_ctx = replace(current_ctx, execution_list=execution_list, executed=executed)
    with _new_execution_context(new_ctx):
        yield new_ctx

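A minimal sketch of how the new snapshot flows (editor's illustration; my_execution_list and executed_ids are hypothetical stand-ins for the executor's live state): the executor freezes its scheduler, enters the context manager, and any code running under the same context can read the immutable snapshot back.

# hypothetical sketch; my_execution_list and executed_ids stand in for live executor state
from comfy_execution.graph_types import FrozenTopologicalSort

frozen = FrozenTopologicalSort.from_topological_sort(my_execution_list)
with context_set_execution_list_and_inputs(frozen, frozenset(executed_ids)):
    ctx = current_execution_context()
    pending = set(ctx.execution_list.pendingNodes.keys())  # scheduler state, read-only
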
29 comfy/lazy_helpers.py Normal file
@@ -0,0 +1,29 @@
from .execution_context import current_execution_context
from typing import Any, Generator, Sequence


def is_input_unscheduled_unexecuted(*arg_names: Sequence[str]) -> Generator[bool, Any, None]:
    """
    Yields True for each given argument that, in the context of the executing node, is neither scheduled nor executed.
    This will be true for inputs that are marked as lazy, and it is more robust than value checks against nodes that return None.
    :param arg_names: each arg to evaluate
    :return: True for each arg that is neither scheduled nor executed
    """
    context = current_execution_context()
    if context is None or context.execution_list is None:
        raise LookupError("Not executing a node")
    assert context.execution_list is not None
    dynprompt = context.execution_list.dynprompt
    executed = context.executed
    execution_list = context.execution_list
    inputs = context.inputs
    unscheduled_unexecuted = dynprompt.all_node_ids() - executed - set(execution_list.pendingNodes.keys())
    for arg_name in arg_names:
        if arg_name not in inputs:
            raise ValueError(f"Input {arg_name} not found")
        input_ = inputs[arg_name]
        if isinstance(input_, (list, tuple)) and len(input_) == 2:
            node_id, *_ = input_
            yield node_id in unscheduled_unexecuted
        else:
            yield False

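In a prompt graph a wired input is encoded as a [node_id, output_index] pair, so the helper answers "was my upstream producer ever scheduled?" from the graph wiring alone, without touching values. A hypothetical node consumes it the same way the LazySwitch node added later in this commit does:

# hypothetical node sketch mirroring LazySwitch (added later in this commit)
from comfy.lazy_helpers import is_input_unscheduled_unexecuted

class MaybeExpensive:
    def check_lazy_status(self, gate, payload=None):
        try:
            payload_pending, = is_input_unscheduled_unexecuted("payload")
        except LookupError:
            # outside an execution context, fall back to a value check
            payload_pending = payload is None
        return ["payload"] if gate and payload_pending else []
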
@@ -30,6 +30,7 @@ from ..cli_args import args
from ..cmd import folder_paths, latent_preview
from ..comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
from ..component_model.deprecation import _deprecate_method
from ..component_model.images_types import ImageMaskTuple
from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch, RGBAImageBatch, Latent
from ..execution_context import current_execution_context
from ..images import open_image
@@ -1791,7 +1792,7 @@ class LoadImage:
    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "load_image"

    def load_image(self, image: str) -> tuple[RGBImageBatch, MaskBatch]:
    def load_image(self, image: str) -> ImageMaskTuple:
        image_path = folder_paths.get_annotated_filepath(image)
        output_images = []
        output_masks = []
@@ -1829,7 +1830,9 @@ class LoadImage:
                mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
                mask = 1. - torch.from_numpy(mask)
            else:
                mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
                # unlike upstream, the mask is now the size of the original image, even when there was no alpha channel
                # this mask is fully opaque
                mask = torch.zeros((h, w), dtype=torch.float32, device="cpu")
            output_images.append(image)
            output_masks.append(mask.unsqueeze(0))

@@ -1840,7 +1843,7 @@ class LoadImage:
            output_image = output_images[0]
            output_mask = output_masks[0]

        return (output_image, output_mask)
        return ImageMaskTuple(output_image, output_mask)

    @classmethod
    def VALIDATE_INPUTS(s, image):

@@ -33,7 +33,7 @@ import numpy as np
import torch
from torch import Tensor

from .component_model.images_types import RgbMaskTuple
from .component_model.images_types import ImageMaskTuple


read_exr = lambda fp: cv2.imread(fp, cv2.IMREAD_UNCHANGED).astype(np.float32)
@@ -50,7 +50,7 @@ def mut_linear_to_srgb(np_array) -> None:
    np_array[~less] = np.power(np_array[~less], 1 / 2.4) * 1.055 - 0.055


def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple:
def load_exr(file_path: str, srgb: bool) -> ImageMaskTuple:
    image = read_exr(file_path)
    rgb = np.flip(image[:, :, :3], 2).copy()
    if srgb:
@@ -62,7 +62,7 @@ def load_exr(file_path: str, srgb: bool) -> RgbMaskTuple:
    if image.shape[2] > 3:
        mask[0] = torch.from_numpy(np.clip(image[:, :, 3], 0, 1))

    return RgbMaskTuple(rgb, mask)
    return ImageMaskTuple(rgb, mask)


def load_exr_latent(file_path: str) -> Tuple[Tensor]:

@@ -5,7 +5,7 @@ import inspect
from typing import Optional, Type, Literal, Union, NamedTuple

from comfy.nodes.package_typing import CustomNode, InputTypes
from comfy_execution.graph_utils import is_link, ExecutionBlocker
from comfy_execution.graph_utils import ExecutionBlocker
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
from comfy.component_model.executor_types import DependencyCycleError, NodeInputError, NodeNotFoundError, \
    DependencyExecutionErrorMessage, ComboOptions
@@ -249,7 +249,7 @@ class ExecutionList(TopologicalSort):
        # for a PreviewImage to display a result as soon as it can
        # Some other heuristics could probably be used here to improve the UX further.
        def is_output(node_id):
            class_type = self.dynprompt.get_node(node_id)["class_type"]
            class_type = self.dynprompt.get_node(node_id)['class_type']
            class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
            if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
                return True
@@ -258,7 +258,7 @@ class ExecutionList(TopologicalSort):
        # If an available node is async, do that first.
        # This will execute the asynchronous function earlier, reducing the overall time.
        def is_async(node_id):
            class_type = self.dynprompt.get_node(node_id)["class_type"]
            class_type = self.dynprompt.get_node(node_id)['class_type']
            class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
            return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))

81 comfy_execution/graph_types.py Normal file
@@ -0,0 +1,81 @@
from __future__ import annotations

import typing
from dataclasses import dataclass
from typing import Mapping

from frozendict import frozendict, deepfreeze

from comfy.component_model.executor_types import NodeNotFoundError

if typing.TYPE_CHECKING:
    from .graph import TopologicalSort, DynamicPrompt

NodeOutputByIdAndIndex = tuple[str, int]
InputValue = typing.Union[NodeOutputByIdAndIndex, bool, bytes, int, float, str, typing.IO, typing.BinaryIO]
Input = Mapping[str, InputValue]


@dataclass(frozen=True)
class FrozenTopologicalSort:
    dynprompt: "FrozenPrompt"
    pendingNodes: frozendict
    blockCount: frozendict
    blocking: frozendict
    externalBlocks: int

    @classmethod
    def from_topological_sort(cls, ts: "TopologicalSort") -> FrozenTopologicalSort:
        return cls(
            dynprompt=FrozenPrompt.from_dynamic_prompt(ts.dynprompt),
            pendingNodes=deepfreeze(ts.pendingNodes),
            blockCount=deepfreeze(ts.blockCount),
            blocking=deepfreeze(ts.blocking),
            externalBlocks=ts.externalBlocks,
        )


@dataclass(frozen=True)
class FrozenPrompt:
    original_prompt: frozendict
    ephemeral_prompt: frozendict
    ephemeral_parents: frozendict
    ephemeral_display: frozendict

    @classmethod
    def from_dynamic_prompt(cls, dynprompt: "DynamicPrompt") -> FrozenPrompt:
        return cls(
            original_prompt=deepfreeze(dynprompt.original_prompt),
            ephemeral_prompt=deepfreeze(dynprompt.ephemeral_prompt),
            ephemeral_parents=deepfreeze(dynprompt.ephemeral_parents),
            ephemeral_display=deepfreeze(dynprompt.ephemeral_display),
        )

    def get_node(self, node_id):
        if node_id in self.ephemeral_prompt:
            return self.ephemeral_prompt[node_id]
        if node_id in self.original_prompt:
            return self.original_prompt[node_id]
        raise NodeNotFoundError(f"Node {node_id} not found")

    def has_node(self, node_id):
        return node_id in self.original_prompt or node_id in self.ephemeral_prompt

    def get_real_node_id(self, node_id):
        while node_id in self.ephemeral_parents:
            node_id = self.ephemeral_parents[node_id]
        return node_id

    def get_parent_node_id(self, node_id):
        return self.ephemeral_parents.get(node_id, None)

    def get_display_node_id(self, node_id):
        while node_id in self.ephemeral_display:
            node_id = self.ephemeral_display[node_id]
        return node_id

    def all_node_ids(self):
        return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys()))

    def get_original_prompt(self):
        return self.original_prompt

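Why deepfreeze: it recursively converts dicts, lists, and sets into frozendicts, tuples, and frozensets, so the snapshot handed to the contextvar cannot drift while the live TopologicalSort keeps mutating. A small illustration (editor's sketch):

from frozendict import frozendict, deepfreeze

live = {"5": {"inputs": {"switch": ["3", 0]}}}
snapshot = deepfreeze(live)  # nested dicts become frozendicts, lists become tuples
live["5"]["inputs"]["switch"][1] = 1  # mutate the live state...
assert snapshot["5"]["inputs"]["switch"] == ("3", 0)  # ...the snapshot is unchanged
assert isinstance(snapshot, frozendict)
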
@@ -42,7 +42,7 @@ class CurrentNodeContext:
        self._cm = core_execution_context.context_set_node_and_prompt(
            prompt_id=prompt_id,
            node_id=node_id,
            list_index=list_index
            list_index=list_index,
        )

    def __enter__(self):

179 comfy_extras/nodes/nodes_logic.py Normal file
@@ -0,0 +1,179 @@
import inspect
import operator
from typing import OrderedDict, Callable, Any

from comfy.comfy_types import IO
from comfy.lazy_helpers import is_input_unscheduled_unexecuted
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode, InputTypes


def takes_n_args(obj, n):
    if not callable(obj):
        return False

    try:
        sig = inspect.signature(obj)
    except (TypeError, ValueError):
        return False

    params = sig.parameters.values()

    named_param_count = sum(
        1 for p in params
        if p.kind not in (inspect.Parameter.VAR_POSITIONAL,
                          inspect.Parameter.VAR_KEYWORD)
    )

    return named_param_count == n


_BINARY_OPS: dict[str, Callable[[Any, Any], Any]] = OrderedDict({
    **{op: getattr(operator, op) for op in dir(operator) if takes_n_args(getattr(operator, op), 2)},
    "and": lambda a, b: a and b,
    "or": lambda a, b: a or b,
})
_UNARY_OPS: dict[str, Callable[[Any], Any]] = {
    **{op: getattr(operator, op) for op in dir(operator) if takes_n_args(getattr(operator, op), 1)},
    "not": lambda a: not a
}


class IsNone(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {},
            "optional": {
                "value": (IO.ANY, {}),
            }
        }

    RETURN_TYPES = (IO.BOOLEAN,)
    FUNCTION = "execute"
    CATEGORY = "logic"

    def execute(self, value=None):
        return (value is None,)


class LazySwitch(CustomNode):
    """
    sherlocked from KJ nodes with fixes
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "switch": ("BOOLEAN",),
            },
            "optional": {
                "on_false": (IO.ANY, {"lazy": True}),
                "on_true": (IO.ANY, {"lazy": True}),
            },
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "execute"
    CATEGORY = "logic"
    DESCRIPTION = "Controls flow of execution based on a boolean switch."

    def check_lazy_status(self, switch, on_false=None, on_true=None):
        try:
            on_false_not_evaluated, on_true_not_evaluated = is_input_unscheduled_unexecuted("on_false", "on_true")
        except LookupError:
            on_false_not_evaluated, on_true_not_evaluated = on_false is None, on_true is None
        if switch and on_true_not_evaluated:
            return ["on_true"]
        if not switch and on_false_not_evaluated:
            return ["on_false"]
        return []

    def execute(self, switch, on_false=None, on_true=None):
        value = on_true if switch else on_false
        return (value,)


class UnaryOperation(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {},
            "optional": {
                "value": (IO.ANY, {}),
                "op": (list(_UNARY_OPS.keys()), {"default": "not"})
            }
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "execute"
    CATEGORY = "logic"

    def execute(self, value, op):
        return _UNARY_OPS[op](value),


class BooleanUnaryOperation(UnaryOperation):
    RETURN_TYPES = (IO.BOOLEAN,)

    def execute(self, value, op):
        return bool(super().execute(value, op)[0]),


class BinaryOperation(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {},
            "optional": OrderedDict({
                "lhs": (IO.ANY, {"lazy": True}),
                "op": (list(_BINARY_OPS.keys()), {"default": "eq"}),
                "rhs": (IO.ANY, {"lazy": True}),
            })
        }

    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "execute"
    CATEGORY = "logic"
    DESCRIPTION = ""

    def check_lazy_status(self, lhs=None, op=None, rhs=None) -> list[str]:
        try:
            lhs_not_evaluated, rhs_not_evaluated = is_input_unscheduled_unexecuted("lhs", "rhs")
        except LookupError:
            lhs_not_evaluated, rhs_not_evaluated = lhs is None, rhs is None
        lhs_evaluated, rhs_evaluated = not lhs_not_evaluated, not rhs_not_evaluated
        match op:
            case "and":
                if lhs_not_evaluated:
                    return ["lhs"]
                if lhs_evaluated and lhs is not False and rhs_not_evaluated:
                    return ["rhs"]
                return []
            case "or":
                if lhs_not_evaluated:
                    return ["lhs"]
                if lhs_evaluated and lhs is not True and rhs_not_evaluated:
                    return ["rhs"]
                return []
            case _:
                to_eval = []
                if lhs_not_evaluated:
                    to_eval.append("lhs")
                if rhs_not_evaluated:
                    to_eval.append("rhs")
                return to_eval

    def execute(self, lhs, op, rhs):
        return _BINARY_OPS[op](lhs, rhs),


class BooleanBinaryOperation(BinaryOperation):
    RETURN_TYPES = (IO.BOOLEAN,)

    def execute(self, lhs, op, rhs):
        return bool(super().execute(lhs, op, rhs)[0]),


export_custom_nodes()

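How the operator tables fill in (editor's sketch): takes_n_args filters the standard operator module by introspected arity, which works because CPython publishes text signatures for these C functions; anything without an inspectable signature is skipped. A quick preview of what qualifies:

import inspect
import operator

def takes_n_args(obj, n):
    # same filter as above: count named parameters, ignore *args/**kwargs
    if not callable(obj):
        return False
    try:
        sig = inspect.signature(obj)
    except (TypeError, ValueError):
        return False
    named = [p for p in sig.parameters.values()
             if p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)]
    return len(named) == n

binary = [name for name in dir(operator) if takes_n_args(getattr(operator, name), 2)]
unary = [name for name in dir(operator) if takes_n_args(getattr(operator, name), 1)]
assert "eq" in binary and "add" in binary
assert "not_" in unary and "neg" in unary
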
@@ -34,7 +34,8 @@ from torch import Tensor

from comfy.cmd import folder_paths
from comfy.comfy_types import IO
from comfy.component_model.tensor_types import RGBAImageBatch, RGBImageBatch
from comfy.component_model.images_types import ImageMaskTuple
from comfy.component_model.tensor_types import RGBAImageBatch, RGBImageBatch, MaskBatch, ImageBatch
from comfy.digest import digest
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode, InputTypes, FunctionReturnsUIVariables, SaveNodeResult, \
@@ -809,18 +810,20 @@ class ImageRequestParameter(CustomNode):
            },
            "optional": {
                **_open_api_common_schema,
                "default_if_empty": ("IMAGE",)
                "default_if_empty": ("IMAGE",),
                "alpha_is_transparency": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "execute"
    CATEGORY = "api/openapi"

    def execute(self, value: str = "", default_if_empty=None, *args, **kwargs) -> ValidatedNodeResult:
    def execute(self, value: str = "", default_if_empty=None, alpha_is_transparency=False, *args, **kwargs) -> ImageMaskTuple:
        if value.strip() == "":
            return (default_if_empty,)
        output_images = []
        output_masks = []
        f: OpenFile
        fsspec_kwargs = {}
        if value.startswith('http'):
@@ -832,31 +835,41 @@ class ImageRequestParameter(CustomNode):
            })
        # todo: additional security is needed here to prevent users from accessing local paths
        # however this generally needs to be done with user accounts on all OSes
        with fsspec.open(value, mode="rb", **fsspec_kwargs) as f:
            # from LoadImage
            img = Image.open(f)
            for i in ImageSequence.Iterator(img):
                prev_value = None
                try:
                    i = ImageOps.exif_transpose(i)
                except OSError:
                    prev_value = ImageFile.LOAD_TRUNCATED_IMAGES
                    ImageFile.LOAD_TRUNCATED_IMAGES = True
                    i = ImageOps.exif_transpose(i)
                finally:
                    if prev_value is not None:
                        ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
                if i.mode == 'I':
                    i = i.point(lambda i: i * (1 / 255))
                image = i.convert("RGB")
                image = np.array(image).astype(np.float32) / 255.0
                image = torch.from_numpy(image)[None,]
                output_images.append(image)
        if len(output_images) > 1:
            output_image = torch.cat(output_images, dim=0)
        else:
            output_image = output_images[0]
        return (output_image,)
        with fsspec.open_files(value, mode="rb", **fsspec_kwargs) as files:
            for f in files:
                # from LoadImage
                img = Image.open(f)
                for i in ImageSequence.Iterator(img):
                    prev_value = None
                    try:
                        i = ImageOps.exif_transpose(i)
                    except OSError:
                        prev_value = ImageFile.LOAD_TRUNCATED_IMAGES
                        ImageFile.LOAD_TRUNCATED_IMAGES = True
                        i = ImageOps.exif_transpose(i)
                    finally:
                        if prev_value is not None:
                            ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
                    if i.mode == 'I':
                        i = i.point(lambda i: i * (1 / 255))
                    image = i.convert("RGBA" if alpha_is_transparency else "RGB")
                    image = np.array(image).astype(np.float32) / 255.0
                    image = torch.from_numpy(image)[None,]
                    if 'A' in i.getbands():
                        mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
                        mask = 1. - torch.from_numpy(mask)
                    elif i.mode == 'P' and 'transparency' in i.info:
                        mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
                        mask = 1. - torch.from_numpy(mask)
                    else:
                        mask = torch.zeros((image.shape[1], image.shape[2]), dtype=torch.float32, device="cpu")
                    output_images.append(image)
                    output_masks.append(mask.unsqueeze(0))

        output_images_batched: ImageBatch = torch.cat(output_images, dim=0)
        output_masks_batched: MaskBatch = torch.cat(output_masks, dim=0)

        return ImageMaskTuple(output_images_batched, output_masks_batched)


export_custom_nodes()

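The switch from fsspec.open to fsspec.open_files is what adds glob support: open_files expands patterns and returns an OpenFiles collection that opens every match when entered as a single context manager, the exact shape the rewritten loop relies on. A small sketch with a hypothetical path:

import fsspec

# hypothetical glob; each f in files is an opened, readable file-like object
with fsspec.open_files("file:///tmp/inputs/*.png", mode="rb") as files:
    for f in files:
        signature = f.read(8)  # e.g. the PNG magic bytes
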
@@ -3,7 +3,7 @@ import logging
import time
import urllib.request
import uuid
from typing import Dict, Optional
from typing import Dict, Optional, AsyncGenerator

import numpy
import pytest
@@ -128,7 +128,7 @@ class TestExecution:
        (0, True),
        (100, True),
    ])
    async def client(self, request) -> ComfyClient:
    async def client(self, request):
        from ..inference.testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS

        # ??? todo: we have to deal with this
@@ -784,3 +784,163 @@ class TestExecution:
        result = client.get_all_history(max_items=5, offset=len(all_history) - 1)

        assert len(result) <= 1, "Should return at most 1 item when offset is near end"

    async def test_lazy_switch_true_branch(self, client: ComfyClient, builder: GraphBuilder):
        await client.embedded_client.clear_cache()
        g = builder
        # Create a "True" boolean value
        true_int = g.node("StubInt", value=1)
        true_bool = g.node("TestIntConditions", a=true_int.out(0), b=1, operation="==")  # 1 == 1 -> True

        # Create nodes for branches
        node_true = g.node("StubImage", content="WHITE", height=32, width=32, batch_size=1)
        node_false = g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)

        # Create lazy switch
        # Note: LazySwitch is imported at the top of the file
        switch = g.node("LazySwitch", switch=true_bool.out(0), on_true=node_true.out(0), on_false=node_false.out(0))
        output = g.node("SaveImage", images=switch.out(0))

        result = await client.run(g)

        # Check execution
        assert result.did_run(true_int), "True stub int should run"
        assert result.did_run(true_bool), "Boolean condition node should run"
        assert result.did_run(node_true), "on_true node should run"
        assert not result.did_run(node_false), "on_false node should NOT run"
        assert result.did_run(switch), "LazySwitch node should run"
        assert result.did_run(output), "SaveImage node should run"

        # Check output
        result_image = result.get_images(output)[0]
        assert numpy.array(result_image).mean() == 255, "Image should be white"

    async def test_lazy_switch_false_branch(self, client: ComfyClient, builder: GraphBuilder):
        await client.embedded_client.clear_cache()
        g = builder
        # Create a "False" boolean value
        false_int = g.node("StubInt", value=0)
        false_bool = g.node("TestIntConditions", a=false_int.out(0), b=1, operation="==")  # 0 == 1 -> False

        # Create nodes for branches
        node_true = g.node("StubImage", content="WHITE", height=32, width=32, batch_size=1)
        node_false = g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)

        # Create lazy switch
        switch = g.node("LazySwitch", switch=false_bool.out(0), on_true=node_true.out(0), on_false=node_false.out(0))
        output = g.node("SaveImage", images=switch.out(0))

        result = await client.run(g)

        # Check execution
        assert result.did_run(false_int), "False stub int should run"
        assert result.did_run(false_bool), "Boolean condition node should run"
        assert not result.did_run(node_true), "on_true node should NOT run"
        assert result.did_run(node_false), "on_false node should run"
        assert result.did_run(switch), "LazySwitch node should run"
        assert result.did_run(output), "SaveImage node should run"

        # Check output
        result_image = result.get_images(output)[0]
        assert numpy.array(result_image).mean() == 0, "Image should be black"

    async def test_lazy_binary_op_and_short_circuit(self, client: ComfyClient, builder: GraphBuilder):
        await client.embedded_client.clear_cache()
        g = builder
        # Create a "False" boolean value
        false_int = g.node("StubInt", value=0)
        lhs_bool = g.node("TestIntConditions", a=false_int.out(0), b=1, operation="==")  # 0 == 1 -> False

        # Create a "True" boolean value for RHS (this node should not run)
        true_int_rhs = g.node("StubInt", value=1)
        rhs_bool = g.node("TestIntConditions", a=true_int_rhs.out(0), b=1, operation="==")  # 1 == 1 -> True

        # Create binary op
        # Note: BinaryOperation is imported at the top of the file
        binary_op = g.node("BinaryOperation", lhs=lhs_bool.out(0), op="and", rhs=rhs_bool.out(0))

        # Create lazy switch to check result
        node_true = g.node("StubImage", content="WHITE", height=32, width=32, batch_size=1)
        node_false = g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)
        switch = g.node("LazySwitch", switch=binary_op.out(0), on_true=node_true.out(0), on_false=node_false.out(0))
        output = g.node("SaveImage", images=switch.out(0))

        result = await client.run(g)

        # Check execution
        assert result.did_run(false_int), "LHS int node should run"
        assert result.did_run(lhs_bool), "LHS bool node should run"
        assert not result.did_run(true_int_rhs), "RHS int node should NOT run (short-circuit)"
        assert not result.did_run(rhs_bool), "RHS bool node should NOT run (short-circuit)"
        assert result.did_run(binary_op), "BinaryOp should run"
        assert not result.did_run(node_true), "on_true node should NOT run"
        assert result.did_run(node_false), "on_false node should run"

        # Check output
        result_image = result.get_images(output)[0]
        assert numpy.array(result_image).mean() == 0, "Image should be black (result of 'and' was False)"

    async def test_lazy_binary_op_or_short_circuit(self, client: ComfyClient, builder: GraphBuilder):
        await client.embedded_client.clear_cache()
        g = builder
        # Create a "True" boolean value
        true_int = g.node("StubInt", value=1)
        lhs_bool = g.node("TestIntConditions", a=true_int.out(0), b=1, operation="==")  # 1 == 1 -> True

        # Create a "False" boolean value for RHS (this node should not run)
        false_int_rhs = g.node("StubInt", value=0)
        rhs_bool = g.node("TestIntConditions", a=false_int_rhs.out(0), b=1, operation="==")  # 0 == 1 -> False

        # Create binary op
        binary_op = g.node("BinaryOperation", lhs=lhs_bool.out(0), op="or", rhs=rhs_bool.out(0))

        # Create lazy switch to check result
        node_true = g.node("StubImage", content="WHITE", height=32, width=32, batch_size=1)
        node_false = g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)
        switch = g.node("LazySwitch", switch=binary_op.out(0), on_true=node_true.out(0), on_false=node_false.out(0))
        output = g.node("SaveImage", images=switch.out(0))

        result = await client.run(g)

        # Check execution
        assert result.did_run(true_int), "LHS int node should run"
        assert result.did_run(lhs_bool), "LHS bool node should run"
        assert not result.did_run(false_int_rhs), "RHS int node should NOT run (short-circuit)"
        assert not result.did_run(rhs_bool), "RHS bool node should NOT run (short-circuit)"
        assert result.did_run(binary_op), "BinaryOp should run"
        assert result.did_run(node_true), "on_true node should run"
        assert not result.did_run(node_false), "on_false node should NOT run"

        # Check output
        result_image = result.get_images(output)[0]
        assert numpy.array(result_image).mean() == 255, "Image should be white (result of 'or' was True)"

    async def test_lazy_switch_with_none_input(self, client: ComfyClient, builder: GraphBuilder):
        await client.embedded_client.clear_cache()
        g = builder
        # Create a "False" boolean value
        false_int = g.node("StubInt", value=0)
        false_bool = g.node("TestIntConditions", a=false_int.out(0), b=1, operation="==")  # 0 == 1 -> False

        # Create nodes for branches
        # This node will return None as its value is empty and default_if_empty is not set
        node_true_image = g.node("ImageRequestParameter", value="", description="1")
        node_false_image = g.node("StubImage", content="BLACK", height=32, width=32, batch_size=1)

        # Create lazy switch
        switch = g.node("LazySwitch", switch=false_bool.out(0), on_true=node_true_image.out(0), on_false=node_false_image.out(0))
        output = g.node("SaveImage", images=switch.out(0))

        result = await client.run(g)

        # Check execution
        assert result.did_run(false_int), "False stub int should run"
        assert result.did_run(false_bool), "Boolean condition node should run"
        assert not result.did_run(node_true_image), "on_true (ImageRequestParameter) node should NOT run"
        assert result.did_run(node_false_image), "on_false node should run"
        assert result.did_run(switch), "LazySwitch node should run"
        assert result.did_run(output), "SaveImage node should run"

        # Check output
        result_image = result.get_images(output)[0]
        assert numpy.array(result_image).mean() == 0, "Image should be black"

@@ -1,4 +1,7 @@
from __future__ import annotations

import platform

# noqa: E402
from comfy.cmd.main_pre import args
import os
@@ -280,7 +283,6 @@ def test_posix_join_curly_brackets():
    assert joined_path == "a_{test}/b/c"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
def test_file_request_parameter(use_temporary_input_directory):
    _image_1x1_px = np.array([[[255, 0, 0]]], dtype=np.uint8)
    image_path = os.path.join(use_temporary_input_directory, "test_image.png")
@@ -288,7 +290,7 @@ def test_file_request_parameter(use_temporary_input_directory):
    image.save(image_path)

    n = ImageRequestParameter()
    loaded_image, = n.execute(value=image_path)
    loaded_image, *_ = n.execute(value=image_path)
    assert loaded_image.shape == (1, 1, 1, 3)
    from comfy.nodes.base_nodes import LoadImage

@@ -299,13 +301,174 @@ def test_file_request_parameter(use_temporary_input_directory):
    assert torch.allclose(loaded_image, load_image_node_rgb)


def test_file_request_parameter2(use_temporary_input_directory):
    n = ImageRequestParameter()

    # Test 1: Load a single RGB image
    _image_1x1_px_rgb = np.array([[[255, 0, 0]]], dtype=np.uint8)  # 1x1 RGB
    image_path_rgb = os.path.join(use_temporary_input_directory, "test_image_rgb.png")
    image_rgb = Image.fromarray(_image_1x1_px_rgb, 'RGB')
    image_rgb.save(image_path_rgb)

    loaded_image_rgb, loaded_mask_rgb = n.execute(value=image_path_rgb, alpha_is_transparency=True)

    # Node converts RGB to RGBA
    assert loaded_image_rgb.shape == (1, 1, 1, 4)  # B, H, W, C
    # Check RGB values
    assert torch.allclose(loaded_image_rgb[0, 0, 0, :3], torch.tensor([1.0, 0.0, 0.0]))
    # Check added Alpha channel
    assert torch.allclose(loaded_image_rgb[0, 0, 0, 3], torch.tensor(1.0))
    # Check mask (should be all 0s for 1.0 alpha)
    assert loaded_mask_rgb.shape == (1, 1, 1)  # B, H, W
    assert torch.all(loaded_mask_rgb == 0.0)

    # Test 2: Load a single RGBA image with transparency
    _image_1x1_px_rgba = np.array([[[255, 0, 0, 128]]], dtype=np.uint8)  # 1x1 RGBA
    image_path_rgba = os.path.join(use_temporary_input_directory, "test_image_rgba.png")
    image_rgba = Image.fromarray(_image_1x1_px_rgba, 'RGBA')
    image_rgba.save(image_path_rgba)

    loaded_image_rgba, loaded_mask_rgba = n.execute(value=image_path_rgba, alpha_is_transparency=True)

    # Node should load RGBA as is
    assert loaded_image_rgba.shape == (1, 1, 1, 4)  # B, H, W, C
    # Check RGBA values
    assert torch.allclose(loaded_image_rgba[0, 0, 0, :], torch.tensor([1.0, 0.0, 0.0, 128 / 255.0]))
    # Check mask (should be 1.0 - alpha)
    assert loaded_mask_rgba.shape == (1, 1, 1)  # B, H, W
    assert torch.allclose(loaded_mask_rgba[0, 0, 0], torch.tensor(1.0 - 128 / 255.0))

    # Test 3: Load a single RGB image with alpha_is_transparency=False
    loaded_image_rgb_no_alpha, loaded_mask_rgb_no_alpha = n.execute(value=image_path_rgb, alpha_is_transparency=False)

    # Node converts to RGB
    assert loaded_image_rgb_no_alpha.shape == (1, 1, 1, 3)  # B, H, W, C
    # Check RGB values
    assert torch.allclose(loaded_image_rgb_no_alpha[0, 0, 0, :], torch.tensor([1.0, 0.0, 0.0]))
    # Check mask
    assert loaded_mask_rgb_no_alpha.shape == (1, 1, 1)  # B, H, W

    # Test 4: Load a single RGBA image with alpha_is_transparency=False
    loaded_image_rgba_no_alpha, loaded_mask_rgba_no_alpha = n.execute(value=image_path_rgba, alpha_is_transparency=False)

    # Node should load RGBA as RGB (dropping alpha)
    assert loaded_image_rgba_no_alpha.shape == (1, 1, 1, 3)  # B, H, W, C
    # Check RGB values (straight, not pre-multiplied)
    assert torch.allclose(loaded_image_rgba_no_alpha[0, 0, 0, :], torch.tensor([1.0, 0.0, 0.0]))
    assert loaded_mask_rgba_no_alpha.shape == (1, 1, 1)  # B, H, W


def test_file_request_parameter_glob(use_temporary_input_directory):
    # 1. Create dummy images (2x2)
    # Image 1 (RGB)
    img_rgb_data = np.array([
        [[255, 0, 0], [0, 255, 0]],
        [[0, 0, 255], [255, 255, 255]]
    ], dtype=np.uint8)
    img_rgb = Image.fromarray(img_rgb_data, 'RGB')
    path_rgb = os.path.join(use_temporary_input_directory, "img_rgb.png")
    img_rgb.save(path_rgb)

    # Image 2 (RGBA with transparency)
    img_rgba_data_rgb = np.array([
        [[10, 20, 30], [40, 50, 60]],
        [[70, 80, 90], [100, 110, 120]]
    ], dtype=np.uint8)
    img_rgba_data_a = np.array([
        [255, 128],  # alpha 1.0, 0.5
        [0, 255]  # alpha 0.0, 1.0
    ], dtype=np.uint8)
    img_rgba_data = np.dstack((img_rgba_data_rgb, img_rgba_data_a))
    img_rgba = Image.fromarray(img_rgba_data, 'RGBA')
    path_rgba = os.path.join(use_temporary_input_directory, "img_rgba.png")
    img_rgba.save(path_rgba)

    # 2. Construct fsspec URL
    prefix = 'file:///' if platform.system() == "Windows" else 'file://'
    # Add glob pattern
    url = f"{prefix}{use_temporary_input_directory}/*.png"

    n = ImageRequestParameter()

    # 3. Test with alpha_is_transparency=True (default)
    loaded_images, loaded_masks = n.execute(value=url, alpha_is_transparency=True)

    # 4. Verify results (True) - Order independent
    assert loaded_images.shape == (2, 2, 2, 4)  # B=2, H=2, W=2, C=RGBA
    assert loaded_masks.shape == (2, 2, 2)  # B=2, H=2, W=2

    # Find which image is which by checking the mask sum
    mask_sums = torch.sum(loaded_masks, dim=(1, 2))
    expected_rgba_mask_sum = (1.0 - 128 / 255.0) + 1.0  # From alpha 128 and 0

    # Argmin should find the all-zero mask (from the RGB image)
    rgb_img_index = torch.argmin(mask_sums)
    # Argmax should find the mask with transparency
    rgba_img_index = torch.argmax(mask_sums)

    assert rgb_img_index != rgba_img_index
    assert torch.allclose(mask_sums[rgb_img_index], torch.tensor(0.0))
    assert torch.allclose(mask_sums[rgba_img_index], torch.tensor(expected_rgba_mask_sum))

    # Check RGB image tensor (which was converted to RGBA)
    rgb_image_tensor = loaded_images[rgb_img_index]
    assert torch.allclose(rgb_image_tensor[0, 0, :3], torch.tensor([1.0, 0.0, 0.0]))  # Red pixel
    assert torch.allclose(rgb_image_tensor[0, 0, 3], torch.tensor(1.0))  # Added alpha
    assert torch.allclose(rgb_image_tensor[1, 1, :3], torch.tensor([1.0, 1.0, 1.0]))  # White pixel
    assert torch.allclose(rgb_image_tensor[1, 1, 3], torch.tensor(1.0))  # Added alpha

    # Check RGBA image tensor
    rgba_image_tensor = loaded_images[rgba_img_index]
    # Pixel [0, 1] (alpha 128)
    assert torch.allclose(rgba_image_tensor[0, 1, :3], torch.tensor([40 / 255.0, 50 / 255.0, 60 / 255.0]))
    assert torch.allclose(rgba_image_tensor[0, 1, 3], torch.tensor(128 / 255.0))  # Original alpha
    # Pixel [1, 0] (alpha 0)
    assert torch.allclose(rgba_image_tensor[1, 0, :3], torch.tensor([70 / 255.0, 80 / 255.0, 90 / 255.0]))
    assert torch.allclose(rgba_image_tensor[1, 0, 3], torch.tensor(0.0))  # Original alpha

    # 5. Test with alpha_is_transparency=False
    loaded_images_no_alpha, loaded_masks_no_alpha = n.execute(value=url, alpha_is_transparency=False)

    # 6. Verify results (False)
    assert loaded_images_no_alpha.shape == (2, 2, 2, 3)  # B=2, H=2, W=2, C=RGB
    assert loaded_masks_no_alpha.shape == (2, 2, 2)  # B=2, H, W (empty)

    # Find which image is which by checking pixel sum (RGB image has brighter pixels)
    img_sums = torch.sum(loaded_images_no_alpha, dim=(1, 2, 3))
    rgb_img_index = torch.argmax(img_sums)
    rgba_img_index = torch.argmin(img_sums)

    assert rgb_img_index != rgba_img_index

    # Check RGB image tensor
    rgb_image_tensor_no_alpha = loaded_images_no_alpha[rgb_img_index]
    assert torch.allclose(rgb_image_tensor_no_alpha[0, 0, :], torch.tensor([1.0, 0.0, 0.0]))  # Red
    assert torch.allclose(rgb_image_tensor_no_alpha[1, 1, :], torch.tensor([1.0, 1.0, 1.0]))  # White

    # Check RGBA image tensor (which was converted to RGB, dropping alpha - straight matte)
    rgba_image_tensor_no_alpha = loaded_images_no_alpha[rgba_img_index]
    # Pixel [0, 1] (RGB [40, 50, 60], alpha 128)
    # RGB channels are passed through unaltered
    expected_rgb_0_1 = torch.tensor([40 / 255.0, 50 / 255.0, 60 / 255.0])
    assert torch.allclose(rgba_image_tensor_no_alpha[0, 1, :], expected_rgb_0_1)
    # Pixel [1, 0] (RGB [70, 80, 90], alpha 0)
    # RGB channels are passed through unaltered
    expected_rgb_1_0 = torch.tensor([70 / 255.0, 80 / 255.0, 90 / 255.0])
    assert torch.allclose(rgba_image_tensor_no_alpha[1, 0, :], expected_rgb_1_0)


def test_file_request_to_http_url_no_exceptions():
    n = ImageRequestParameter()
    loaded_image, = n.execute(value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a6/A_rainbow_at_sunset_after_rain_in_Gaziantep%2C_Turkey.IMG_2448.jpg/484px-A_rainbow_at_sunset_after_rain_in_Gaziantep%2C_Turkey.IMG_2448.jpg")
    _, height, width, channels = loaded_image.shape
    loaded_image, loaded_mask = n.execute(value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a6/A_rainbow_at_sunset_after_rain_in_Gaziantep%2C_Turkey.IMG_2448.jpg/484px-A_rainbow_at_sunset_after_rain_in_Gaziantep%2C_Turkey.IMG_2448.jpg")
    # This is an RGB jpg, so it stays RGB (alpha_is_transparency defaults to False)
    b, height, width, channels = loaded_image.shape
    assert b == 1
    assert width == 484
    assert height == 480
    assert channels == 3
    # Mask should be all zeros
    assert loaded_mask.shape == (1, 480, 484)
    assert torch.all(loaded_mask == 0.0)


@pytest.mark.parametrize("format,bits,supports_16bit", [
@@ -676,4 +839,4 @@ def test_numeric_exif(format, use_temporary_output_directory):
    assert tag_id is not None
    if tag_id in exif_data:
        # Convert both to strings for comparison since formats might store numbers differently
        assert str(exif_data[tag_id]) == expected_value
        assert str(exif_data[tag_id]) == expected_value

158 tests/unit/test_operator_nodes.py Normal file
@@ -0,0 +1,158 @@
import pytest
from comfy_extras.nodes.nodes_logic import LazySwitch, UnaryOperation, BooleanUnaryOperation, BinaryOperation, \
    BooleanBinaryOperation


def test_lazy_switch():
    n = LazySwitch()

    # Test True branch
    res_true_check = n.check_lazy_status(switch=True)
    assert res_true_check == ["on_true"]
    res_true, = n.execute(switch=True, on_false="val_false", on_true="val_true")
    assert res_true == "val_true"

    # Test False branch
    res_false_check = n.check_lazy_status(switch=False)
    assert res_false_check == ["on_false"]
    res_false, = n.execute(switch=False, on_false="val_false", on_true="val_true")
    assert res_false == "val_false"

    # Test with None
    res_none, = n.execute(switch=False, on_false=None, on_true="val_true")
    assert res_none is None


def test_unary_operation():
    n = UnaryOperation()

    # Test 'not'
    res_not_true, = n.execute(value=True, op="not")
    assert res_not_true is False
    res_not_false, = n.execute(value=False, op="not")
    assert res_not_false is True
    res_not_str, = n.execute(value="hello", op="not")
    assert res_not_str is False
    res_not_empty_str, = n.execute(value="", op="not")
    assert res_not_empty_str is True

    # Test 'inv' (invert, ~)
    res_inv, = n.execute(value=5, op="inv")
    assert res_inv == ~5  # -6

    # Test 'neg' (-)
    res_neg, = n.execute(value=10, op="neg")
    assert res_neg == -10
    res_neg_neg, = n.execute(value=-10, op="neg")
    assert res_neg_neg == 10

    # Test 'pos' (+)
    res_pos, = n.execute(value=-5, op="pos")
    assert res_pos == -5
    res_pos_pos, = n.execute(value=5, op="pos")
    assert res_pos_pos == 5


def test_boolean_unary_operation():
    n = BooleanUnaryOperation()

    # Test 'not'
    res_not_true, = n.execute(value=True, op="not")
    assert res_not_true is False
    res_not_false, = n.execute(value=False, op="not")
    assert res_not_false is True

    # Test truthiness
    res_not_int_1, = n.execute(value=1, op="not")
    assert res_not_int_1 is False
    res_not_int_0, = n.execute(value=0, op="not")
    assert res_not_int_0 is True
    res_not_str, = n.execute(value="hello", op="not")
    assert res_not_str is False
    res_not_empty_str, = n.execute(value="", op="not")
    assert res_not_empty_str is True


def test_binary_operation():
    n = BinaryOperation()

    # Test ops
    res_eq, = n.execute(lhs=5, op="eq", rhs=5)
    assert res_eq is True
    res_ne, = n.execute(lhs=5, op="ne", rhs=6)
    assert res_ne is True
    res_lt, = n.execute(lhs=5, op="lt", rhs=6)
    assert res_lt is True
    res_gt, = n.execute(lhs=6, op="gt", rhs=5)
    assert res_gt is True
    res_le, = n.execute(lhs=5, op="le", rhs=5)
    assert res_le is True
    res_ge, = n.execute(lhs=5, op="ge", rhs=5)
    assert res_ge is True
    res_add, = n.execute(lhs=5, op="add", rhs=3)
    assert res_add == 8
    res_sub, = n.execute(lhs=5, op="sub", rhs=3)
    assert res_sub == 2
    res_mul, = n.execute(lhs=5, op="mul", rhs=3)
    assert res_mul == 15
    res_div, = n.execute(lhs=10, op="truediv", rhs=4)
    assert res_div == 2.5
    res_floor_div, = n.execute(lhs=10, op="floordiv", rhs=3)
    assert res_floor_div == 3

    # Test logical 'and'
    res_and_tt, = n.execute(lhs=True, op="and", rhs=True)
    assert res_and_tt is True
    res_and_tf, = n.execute(lhs=True, op="and", rhs=False)
    assert res_and_tf is False
    res_and_ff, = n.execute(lhs=False, op="and", rhs=False)
    assert res_and_ff is False

    # Test logical 'or'
    res_or_tf, = n.execute(lhs=True, op="or", rhs=False)
    assert res_or_tf is True
    res_or_ft, = n.execute(lhs=False, op="or", rhs=True)
    assert res_or_ft is True
    res_or_ff, = n.execute(lhs=False, op="or", rhs=False)
    assert res_or_ff is False


def test_binary_operation_lazy_check():
    n = BinaryOperation()

    # Test standard ops
    assert n.check_lazy_status(op="eq") == ["lhs", "rhs"]
    # outside an execution context, inputs that already carry values count as evaluated
    assert n.check_lazy_status(op="add", lhs=1, rhs=2) == []

    # Test 'and'
    assert n.check_lazy_status(op="and", lhs=None) == ["lhs"]
    assert n.check_lazy_status(op="and", lhs=True, rhs=None) == ["rhs"]
    assert n.check_lazy_status(op="and", lhs=False, rhs=None) == []
    assert n.check_lazy_status(op="and", lhs=False, rhs=True) == []

    # Test 'or'
    assert n.check_lazy_status(op="or", lhs=None) == ["lhs"]
    assert n.check_lazy_status(op="or", lhs=True, rhs=None) == []
    assert n.check_lazy_status(op="or", lhs=False, rhs=None) == ["rhs"]
    assert n.check_lazy_status(op="or", lhs=True, rhs=False) == []


def test_boolean_binary_operation():
    n = BooleanBinaryOperation()

    # Test 'eq'
    res_eq, = n.execute(lhs=5, op="eq", rhs=5)
    assert res_eq is True
    res_ne, = n.execute(lhs=5, op="eq", rhs=6)
    assert res_ne is False

    # Test truthiness
    res_and_truthy, = n.execute(lhs="hello", op="and", rhs=1)
    assert res_and_truthy is True
    res_and_falsy, = n.execute(lhs="hello", op="and", rhs=0)
    assert res_and_falsy is False
    res_or_falsy, = n.execute(lhs="", op="or", rhs=0)
    assert res_or_falsy is False
    res_or_truthy, = n.execute(lhs="", op="or", rhs="test")
    assert res_or_truthy is True