diff --git a/pytest.ini b/pytest.ini index b5a68e0f1..d34fb5190 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,6 @@ [pytest] markers = inference: mark as inference test (deselect with '-m "not inference"') + execution: mark as execution test (deselect with '-m "not execution"') testpaths = tests -addopts = -s \ No newline at end of file +addopts = -s diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py new file mode 100644 index 000000000..52e195627 --- /dev/null +++ b/tests/inference/test_execution.py @@ -0,0 +1,294 @@ +from io import BytesIO +import numpy +from PIL import Image +import pytest +from pytest import fixture +import time +import torch +from typing import Union, Dict +import json +import subprocess +import websocket #NOTE: websocket-client (https://github.com/websocket-client/websocket-client) +import uuid +import urllib.request +import urllib.parse +from comfy.graph_utils import GraphBuilder, Node + +class RunResult: + def __init__(self, prompt_id: str): + self.outputs: Dict[str,Dict] = {} + self.runs: Dict[str,bool] = {} + self.prompt_id: str = prompt_id + + def get_output(self, node: Node): + return self.outputs.get(node.id, None) + + def did_run(self, node: Node): + return self.runs.get(node.id, False) + + def get_images(self, node: Node): + output = self.get_output(node) + if output is None: + return [] + return output.get('image_objects', []) + + def get_prompt_id(self): + return self.prompt_id + +class ComfyClient: + def __init__(self): + self.test_name = "" + + def connect(self, + listen:str = '127.0.0.1', + port:Union[str,int] = 8188, + client_id: str = str(uuid.uuid4()) + ): + self.client_id = client_id + self.server_address = f"{listen}:{port}" + ws = websocket.WebSocket() + ws.connect("ws://{}/ws?clientId={}".format(self.server_address, self.client_id)) + self.ws = ws + + def queue_prompt(self, prompt): + p = {"prompt": prompt, "client_id": self.client_id} + data = json.dumps(p).encode('utf-8') + req = urllib.request.Request("http://{}/prompt".format(self.server_address), data=data) + return json.loads(urllib.request.urlopen(req).read()) + + def get_image(self, filename, subfolder, folder_type): + data = {"filename": filename, "subfolder": subfolder, "type": folder_type} + url_values = urllib.parse.urlencode(data) + with urllib.request.urlopen("http://{}/view?{}".format(self.server_address, url_values)) as response: + return response.read() + + def get_history(self, prompt_id): + with urllib.request.urlopen("http://{}/history/{}".format(self.server_address, prompt_id)) as response: + return json.loads(response.read()) + + def set_test_name(self, name): + self.test_name = name + + def run(self, graph): + prompt = graph.finalize() + for node in graph.nodes.values(): + if node.class_type == 'SaveImage': + node.inputs['filename_prefix'] = self.test_name + + prompt_id = self.queue_prompt(prompt)['prompt_id'] + result = RunResult(prompt_id) + while True: + out = self.ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data['prompt_id'] != prompt_id: + continue + if data['node'] is None: + break + result.runs[data['node']] = True + elif message['type'] == 'execution_error': + raise Exception(message['data']) + elif message['type'] == 'execution_cached': + pass # Probably want to store this off for testing + + history = self.get_history(prompt_id)[prompt_id] + for o in history['outputs']: + for node_id in history['outputs']: + node_output = history['outputs'][node_id] + result.outputs[node_id] = node_output + if 'images' in node_output: + images_output = [] + for image in node_output['images']: + image_data = self.get_image(image['filename'], image['subfolder'], image['type']) + image_obj = Image.open(BytesIO(image_data)) + images_output.append(image_obj) + node_output['image_objects'] = images_output + + return result + +# +# Loop through these variables +# +@pytest.mark.execution +class TestExecution: + # + # Initialize server and client + # + @fixture(scope="class", autouse=True) + def _server(self, args_pytest): + # Start server + p = subprocess.Popen([ + 'python','main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', + '--enable-variants', + ]) + yield + p.kill() + torch.cuda.empty_cache() + + def start_client(self, listen:str, port:int): + # Start client + comfy_client = ComfyClient() + # Connect to server (with retries) + n_tries = 5 + for i in range(n_tries): + time.sleep(4) + try: + comfy_client.connect(listen=listen, port=port) + except ConnectionRefusedError as e: + print(e) + print(f"({i+1}/{n_tries}) Retrying...") + else: + break + return comfy_client + + @fixture(scope="class", autouse=True) + def shared_client(self, args_pytest, _server): + client = self.start_client(args_pytest["listen"], args_pytest["port"]) + yield client + del client + torch.cuda.empty_cache() + + @fixture + def client(self, shared_client, request): + shared_client.set_test_name(f"execution[{request.node.name}]") + yield shared_client + + def clear_cache(self, client: ComfyClient): + g = GraphBuilder(prefix="foo") + random = g.node("StubImage", content="NOISE", height=1, width=1, batch_size=1) + g.node("PreviewImage", images=random.out(0)) + client.run(g) + + @fixture + def builder(self): + yield GraphBuilder(prefix="") + + def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1) + + lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + output = g.node("SaveImage", images=lazy_mix.out(0)) + result = client.run(g) + + result_image = result.get_images(output)[0] + assert numpy.array(result_image).any() == 0, "Image should be black" + assert result.did_run(input1) + assert not result.did_run(input2) + assert result.did_run(mask) + assert result.did_run(lazy_mix) + + def test_full_cache(self, client: ComfyClient, builder: GraphBuilder): + self.clear_cache(client) + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + + lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix.out(0)) + + result1 = client.run(g) + result2 = client.run(g) + for node_id, node in g.nodes.items(): + assert result1.did_run(node), f"Node {node_id} didn't run" + assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" + + def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder): + self.clear_cache(client) + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + + lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix.out(0)) + + result1 = client.run(g) + mask.inputs['value'] = 0.4 + result2 = client.run(g) + for node_id, node in g.nodes.items(): + assert result1.did_run(node), f"Node {node_id} didn't run" + assert not result2.did_run(input1), "Input1 should have been cached" + assert not result2.did_run(input2), "Input2 should have been cached" + assert result2.did_run(mask), "Mask should have been re-run" + assert result2.did_run(lazy_mix), "Lazy mix should have been re-run" + + def test_error(self, client: ComfyClient, builder: GraphBuilder): + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + # Different size of the two images + input2 = g.node("StubImage", content="NOISE", height=256, width=256, batch_size=1) + mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) + + lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0)) + g.node("SaveImage", images=lazy_mix.out(0)) + + try: + client.run(g) + except Exception as e: + assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" + + def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): + g = builder + # Creating the nodes in this specific order previously caused a bug + save = g.node("SaveImage") + is_changed = g.node("TestCustomIsChanged", should_change=False) + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + + save.set_input('images', is_changed.out(0)) + is_changed.set_input('image', input1.out(0)) + + result1 = client.run(g) + result2 = client.run(g) + is_changed.set_input('should_change', True) + result3 = client.run(g) + result4 = client.run(g) + assert result1.did_run(is_changed), "is_changed should have been run" + assert not result2.did_run(is_changed), "is_changed should have been cached" + assert result3.did_run(is_changed), "is_changed should have been re-run" + assert result4.did_run(is_changed), "is_changed should not have been cached" + + def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder): + self.clear_cache(client) + g = builder + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + input3 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input4 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + average = g.node("TestVariadicAverage", input1=input1.out(0), input2=input2.out(0), input3=input3.out(0), input4=input4.out(0)) + output = g.node("SaveImage", images=average.out(0)) + + result = client.run(g) + result_image = result.get_images(output)[0] + expected = 255 // 4 + assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey" + assert result.did_run(input1) + assert result.did_run(input2) + + def test_for_loop(self, client: ComfyClient, builder: GraphBuilder): + g = builder + iterations = 4 + input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) + input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + is_changed = g.node("TestCustomIsChanged", should_change=True, image=input2.out(0)) + for_open = g.node("TestForLoopOpen", remaining=iterations, initial_value1=is_changed.out(0)) + average = g.node("TestVariadicAverage", input1=input1.out(0), input2=for_open.out(2)) + for_close = g.node("TestForLoopClose", flow_control=for_open.out(0), initial_value1=average.out(0)) + output = g.node("SaveImage", images=for_close.out(0)) + + for iterations in range(1, 5): + for_open.set_input('remaining', iterations) + result = client.run(g) + result_image = result.get_images(output)[0] + expected = 255 // (2 ** iterations) + assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey" + assert result.did_run(is_changed) diff --git a/tests/inference/testing_nodes/testing-pack/__init__.py b/tests/inference/testing_nodes/testing-pack/__init__.py new file mode 100644 index 000000000..dcc71659a --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/__init__.py @@ -0,0 +1,23 @@ +from .specific_tests import TEST_NODE_CLASS_MAPPINGS, TEST_NODE_DISPLAY_NAME_MAPPINGS +from .flow_control import FLOW_CONTROL_NODE_CLASS_MAPPINGS, FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS +from .util import UTILITY_NODE_CLASS_MAPPINGS, UTILITY_NODE_DISPLAY_NAME_MAPPINGS +from .conditions import CONDITION_NODE_CLASS_MAPPINGS, CONDITION_NODE_DISPLAY_NAME_MAPPINGS +from .stubs import TEST_STUB_NODE_CLASS_MAPPINGS, TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS + +# NODE_CLASS_MAPPINGS = GENERAL_NODE_CLASS_MAPPINGS.update(COMPONENT_NODE_CLASS_MAPPINGS) +# NODE_DISPLAY_NAME_MAPPINGS = GENERAL_NODE_DISPLAY_NAME_MAPPINGS.update(COMPONENT_NODE_DISPLAY_NAME_MAPPINGS) + +NODE_CLASS_MAPPINGS = {} +NODE_CLASS_MAPPINGS.update(TEST_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(FLOW_CONTROL_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(UTILITY_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(CONDITION_NODE_CLASS_MAPPINGS) +NODE_CLASS_MAPPINGS.update(TEST_STUB_NODE_CLASS_MAPPINGS) + +NODE_DISPLAY_NAME_MAPPINGS = {} +NODE_DISPLAY_NAME_MAPPINGS.update(TEST_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(UTILITY_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(CONDITION_NODE_DISPLAY_NAME_MAPPINGS) +NODE_DISPLAY_NAME_MAPPINGS.update(TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS) + diff --git a/tests/inference/testing_nodes/testing-pack/conditions.py b/tests/inference/testing_nodes/testing-pack/conditions.py new file mode 100644 index 000000000..0c200ee28 --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/conditions.py @@ -0,0 +1,194 @@ +import re +import torch + +class TestIntConditions: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "operation": (["==", "!=", "<", ">", "<=", ">="],), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "int_condition" + + CATEGORY = "Testing/Logic" + + def int_condition(self, a, b, operation): + if operation == "==": + return (a == b,) + elif operation == "!=": + return (a != b,) + elif operation == "<": + return (a < b,) + elif operation == ">": + return (a > b,) + elif operation == "<=": + return (a <= b,) + elif operation == ">=": + return (a >= b,) + + +class TestFloatConditions: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}), + "b": ("FLOAT", {"default": 0, "min": -999999999999.0, "max": 999999999999.0, "step": 1}), + "operation": (["==", "!=", "<", ">", "<=", ">="],), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "float_condition" + + CATEGORY = "Testing/Logic" + + def float_condition(self, a, b, operation): + if operation == "==": + return (a == b,) + elif operation == "!=": + return (a != b,) + elif operation == "<": + return (a < b,) + elif operation == ">": + return (a > b,) + elif operation == "<=": + return (a <= b,) + elif operation == ">=": + return (a >= b,) + +class TestStringConditions: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("STRING", {"multiline": False}), + "b": ("STRING", {"multiline": False}), + "operation": (["a == b", "a != b", "a IN b", "a MATCH REGEX(b)", "a BEGINSWITH b", "a ENDSWITH b"],), + "case_sensitive": ("BOOLEAN", {"default": True}), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "string_condition" + + CATEGORY = "Testing/Logic" + + def string_condition(self, a, b, operation, case_sensitive): + if not case_sensitive: + a = a.lower() + b = b.lower() + + if operation == "a == b": + return (a == b,) + elif operation == "a != b": + return (a != b,) + elif operation == "a IN b": + return (a in b,) + elif operation == "a MATCH REGEX(b)": + try: + return (re.match(b, a) is not None,) + except: + return (False,) + elif operation == "a BEGINSWITH b": + return (a.startswith(b),) + elif operation == "a ENDSWITH b": + return (a.endswith(b),) + +class TestToBoolNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("*",), + }, + "optional": { + "invert": ("BOOLEAN", {"default": False}), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "to_bool" + + CATEGORY = "Testing/Logic" + + def to_bool(self, value, invert = False): + if isinstance(value, torch.Tensor): + if value.max().item() == 0 and value.min().item() == 0: + result = False + else: + result = True + else: + try: + result = bool(value) + except: + # Can't convert it? Well then it's something or other. I dunno, I'm not a Python programmer. + result = True + + if invert: + result = not result + + return (result,) + +class TestBoolOperationNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("BOOLEAN",), + "b": ("BOOLEAN",), + "op": (["a AND b", "a OR b", "a XOR b", "NOT a"],), + }, + } + + RETURN_TYPES = ("BOOLEAN",) + FUNCTION = "bool_operation" + + CATEGORY = "Testing/Logic" + + def bool_operation(self, a, b, op): + if op == "a AND b": + return (a and b,) + elif op == "a OR b": + return (a or b,) + elif op == "a XOR b": + return (a ^ b,) + elif op == "NOT a": + return (not a,) + + +CONDITION_NODE_CLASS_MAPPINGS = { + "TestIntConditions": TestIntConditions, + "TestFloatConditions": TestFloatConditions, + "TestStringConditions": TestStringConditions, + "TestToBoolNode": TestToBoolNode, + "TestBoolOperationNode": TestBoolOperationNode, +} + +CONDITION_NODE_DISPLAY_NAME_MAPPINGS = { + "TestIntConditions": "Int Condition", + "TestFloatConditions": "Float Condition", + "TestStringConditions": "String Condition", + "TestToBoolNode": "To Bool", + "TestBoolOperationNode": "Bool Operation", +} diff --git a/tests/inference/testing_nodes/testing-pack/flow_control.py b/tests/inference/testing_nodes/testing-pack/flow_control.py new file mode 100644 index 000000000..8befdcf19 --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/flow_control.py @@ -0,0 +1,169 @@ +from comfy.graph_utils import GraphBuilder, is_link +from comfy.graph import ExecutionBlocker + +NUM_FLOW_SOCKETS = 5 +class TestWhileLoopOpen: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "condition": ("BOOLEAN", {"default": True}), + }, + "optional": { + }, + } + for i in range(NUM_FLOW_SOCKETS): + inputs["optional"]["initial_value%d" % i] = ("*",) + return inputs + + RETURN_TYPES = tuple(["FLOW_CONTROL"] + ["*"] * NUM_FLOW_SOCKETS) + RETURN_NAMES = tuple(["FLOW_CONTROL"] + ["value%d" % i for i in range(NUM_FLOW_SOCKETS)]) + FUNCTION = "while_loop_open" + + CATEGORY = "Testing/Flow" + + def while_loop_open(self, condition, **kwargs): + values = [] + for i in range(NUM_FLOW_SOCKETS): + values.append(kwargs.get("initial_value%d" % i, None)) + return tuple(["stub"] + values) + +class TestWhileLoopClose: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "flow_control": ("FLOW_CONTROL", {"rawLink": True}), + "condition": ("BOOLEAN", {"forceInput": True}), + }, + "optional": { + }, + "hidden": { + "dynprompt": "DYNPROMPT", + "unique_id": "UNIQUE_ID", + } + } + for i in range(NUM_FLOW_SOCKETS): + inputs["optional"]["initial_value%d" % i] = ("*",) + return inputs + + RETURN_TYPES = tuple(["*"] * NUM_FLOW_SOCKETS) + RETURN_NAMES = tuple(["value%d" % i for i in range(NUM_FLOW_SOCKETS)]) + FUNCTION = "while_loop_close" + + CATEGORY = "Testing/Flow" + + def explore_dependencies(self, node_id, dynprompt, upstream): + node_info = dynprompt.get_node(node_id) + if "inputs" not in node_info: + return + for k, v in node_info["inputs"].items(): + if is_link(v): + parent_id = v[0] + if parent_id not in upstream: + upstream[parent_id] = [] + self.explore_dependencies(parent_id, dynprompt, upstream) + upstream[parent_id].append(node_id) + + def collect_contained(self, node_id, upstream, contained): + if node_id not in upstream: + return + for child_id in upstream[node_id]: + if child_id not in contained: + contained[child_id] = True + self.collect_contained(child_id, upstream, contained) + + + def while_loop_close(self, flow_control, condition, dynprompt=None, unique_id=None, **kwargs): + assert dynprompt is not None + if not condition: + # We're done with the loop + values = [] + for i in range(NUM_FLOW_SOCKETS): + values.append(kwargs.get("initial_value%d" % i, None)) + return tuple(values) + + # We want to loop + upstream = {} + # Get the list of all nodes between the open and close nodes + self.explore_dependencies(unique_id, dynprompt, upstream) + + contained = {} + open_node = flow_control[0] + self.collect_contained(open_node, upstream, contained) + contained[unique_id] = True + contained[open_node] = True + + # We'll use the default prefix, but to avoid having node names grow exponentially in size, + # we'll use "Recurse" for the name of the recursively-generated copy of this node. + graph = GraphBuilder() + for node_id in contained: + original_node = dynprompt.get_node(node_id) + node = graph.node(original_node["class_type"], "Recurse" if node_id == unique_id else node_id) + node.set_override_display_id(node_id) + for node_id in contained: + original_node = dynprompt.get_node(node_id) + node = graph.lookup_node("Recurse" if node_id == unique_id else node_id) + assert node is not None + for k, v in original_node["inputs"].items(): + if is_link(v) and v[0] in contained: + parent = graph.lookup_node(v[0]) + assert parent is not None + node.set_input(k, parent.out(v[1])) + else: + node.set_input(k, v) + new_open = graph.lookup_node(open_node) + assert new_open is not None + for i in range(NUM_FLOW_SOCKETS): + key = "initial_value%d" % i + new_open.set_input(key, kwargs.get(key, None)) + my_clone = graph.lookup_node("Recurse") + assert my_clone is not None + result = map(lambda x: my_clone.out(x), range(NUM_FLOW_SOCKETS)) + return { + "result": tuple(result), + "expand": graph.finalize(), + } + +class TestExecutionBlockerNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + inputs = { + "required": { + "input": ("*",), + "block": ("BOOLEAN",), + "verbose": ("BOOLEAN", {"default": False}), + }, + } + return inputs + + RETURN_TYPES = ("*",) + RETURN_NAMES = ("output",) + FUNCTION = "execution_blocker" + + CATEGORY = "Testing/Flow" + + def execution_blocker(self, input, block, verbose): + if block: + return (ExecutionBlocker("Blocked Execution" if verbose else None),) + return (input,) + +FLOW_CONTROL_NODE_CLASS_MAPPINGS = { + "TestWhileLoopOpen": TestWhileLoopOpen, + "TestWhileLoopClose": TestWhileLoopClose, + "TestExecutionBlocker": TestExecutionBlockerNode, +} +FLOW_CONTROL_NODE_DISPLAY_NAME_MAPPINGS = { + "TestWhileLoopOpen": "While Loop Open", + "TestWhileLoopClose": "While Loop Close", + "TestExecutionBlocker": "Execution Blocker", +} diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py new file mode 100644 index 000000000..e3d864b44 --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/specific_tests.py @@ -0,0 +1,116 @@ +import torch + +class TestLazyMixImages: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image1": ("IMAGE",{"lazy": True}), + "image2": ("IMAGE",{"lazy": True}), + "mask": ("MASK",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "mix" + + CATEGORY = "Testing/Nodes" + + def check_lazy_status(self, mask, image1 = None, image2 = None): + mask_min = mask.min() + mask_max = mask.max() + needed = [] + if image1 is None and (mask_min != 1.0 or mask_max != 1.0): + needed.append("image1") + if image2 is None and (mask_min != 0.0 or mask_max != 0.0): + needed.append("image2") + return needed + + # Not trying to handle different batch sizes here just to keep the demo simple + def mix(self, mask, image1 = None, image2 = None): + mask_min = mask.min() + mask_max = mask.max() + if mask_min == 0.0 and mask_max == 0.0: + return (image1,) + elif mask_min == 1.0 and mask_max == 1.0: + return (image2,) + + if len(mask.shape) == 2: + mask = mask.unsqueeze(0) + if len(mask.shape) == 3: + mask = mask.unsqueeze(3) + if mask.shape[3] < image1.shape[3]: + mask = mask.repeat(1, 1, 1, image1.shape[3]) + + result = image1 * (1. - mask) + image2 * mask, + print(result[0]) + return (result[0],) + +class TestVariadicAverage: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "variadic_average" + + CATEGORY = "Testing/Nodes" + + def variadic_average(self, input1, **kwargs): + inputs = [input1] + while 'input' + str(len(inputs) + 1) in kwargs: + inputs.append(kwargs['input' + str(len(inputs) + 1)]) + return (torch.stack(inputs).mean(dim=0),) + + +class TestCustomIsChanged: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image": ("IMAGE",), + }, + "optional": { + "should_change": ("BOOL", {"default": False}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_is_changed" + + CATEGORY = "Testing/Nodes" + + def custom_is_changed(self, image, should_change=False): + return (image,) + + @classmethod + def IS_CHANGED(cls, should_change=False, *args, **kwargs): + if should_change: + return float("NaN") + else: + return False + +TEST_NODE_CLASS_MAPPINGS = { + "TestLazyMixImages": TestLazyMixImages, + "TestVariadicAverage": TestVariadicAverage, + "TestCustomIsChanged": TestCustomIsChanged, +} + +TEST_NODE_DISPLAY_NAME_MAPPINGS = { + "TestLazyMixImages": "Lazy Mix Images", + "TestVariadicAverage": "Variadic Average", + "TestCustomIsChanged": "Custom IsChanged", +} diff --git a/tests/inference/testing_nodes/testing-pack/stubs.py b/tests/inference/testing_nodes/testing-pack/stubs.py new file mode 100644 index 000000000..b2a5ebf3d --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/stubs.py @@ -0,0 +1,61 @@ +import torch + +class StubImage: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "content": (['WHITE', 'BLACK', 'NOISE'],), + "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}), + "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "stub_image" + + CATEGORY = "Testing/Stub Nodes" + + def stub_image(self, content, height, width, batch_size): + if content == "WHITE": + return (torch.ones(batch_size, height, width, 3),) + elif content == "BLACK": + return (torch.zeros(batch_size, height, width, 3),) + elif content == "NOISE": + return (torch.rand(batch_size, height, width, 3),) + +class StubMask: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), + "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}), + "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}), + }, + } + + RETURN_TYPES = ("MASK",) + FUNCTION = "stub_mask" + + CATEGORY = "Testing/Stub Nodes" + + def stub_mask(self, value, height, width, batch_size): + return (torch.ones(batch_size, height, width) * value,) + +TEST_STUB_NODE_CLASS_MAPPINGS = { + "StubImage": StubImage, + "StubMask": StubMask, +} +TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = { + "StubImage": "Stub Image", + "StubMask": "Stub Mask", +} diff --git a/tests/inference/testing_nodes/testing-pack/util.py b/tests/inference/testing_nodes/testing-pack/util.py new file mode 100644 index 000000000..16209d3fc --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/util.py @@ -0,0 +1,353 @@ +from comfy.graph_utils import GraphBuilder + +class TestAccumulateNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "to_add": ("*",), + }, + "optional": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("ACCUMULATION",) + FUNCTION = "accumulate" + + CATEGORY = "Testing/Lists" + + def accumulate(self, to_add, accumulation = None): + if accumulation is None: + value = [to_add] + else: + value = accumulation["accum"] + [to_add] + return ({"accum": value},) + +class TestAccumulationHeadNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("ACCUMULATION", "*",) + FUNCTION = "accumulation_head" + + CATEGORY = "Testing/Lists" + + def accumulation_head(self, accumulation): + accum = accumulation["accum"] + if len(accum) == 0: + return (accumulation, None) + else: + return ({"accum": accum[1:]}, accum[0]) + +class TestAccumulationTailNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("ACCUMULATION", "*",) + FUNCTION = "accumulation_tail" + + CATEGORY = "Testing/Lists" + + def accumulation_tail(self, accumulation): + accum = accumulation["accum"] + if len(accum) == 0: + return (None, accumulation) + else: + return ({"accum": accum[:-1]}, accum[-1]) + +class TestAccumulationToListNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("*",) + OUTPUT_IS_LIST = (True,) + + FUNCTION = "accumulation_to_list" + + CATEGORY = "Testing/Lists" + + def accumulation_to_list(self, accumulation): + return (accumulation["accum"],) + +class TestListToAccumulationNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "list": ("*",), + }, + } + + RETURN_TYPES = ("ACCUMULATION",) + INPUT_IS_LIST = (True,) + + FUNCTION = "list_to_accumulation" + + CATEGORY = "Testing/Lists" + + def list_to_accumulation(self, list): + return ({"accum": list},) + +class TestAccumulationGetLengthNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + }, + } + + RETURN_TYPES = ("INT",) + + FUNCTION = "accumlength" + + CATEGORY = "Testing/Lists" + + def accumlength(self, accumulation): + return (len(accumulation['accum']),) + +class TestAccumulationGetItemNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + "index": ("INT", {"default":0, "step":1}) + }, + } + + RETURN_TYPES = ("*",) + + FUNCTION = "get_item" + + CATEGORY = "Testing/Lists" + + def get_item(self, accumulation, index): + return (accumulation['accum'][index],) + +class TestAccumulationSetItemNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "accumulation": ("ACCUMULATION",), + "index": ("INT", {"default":0, "step":1}), + "value": ("*",), + }, + } + + RETURN_TYPES = ("ACCUMULATION",) + + FUNCTION = "set_item" + + CATEGORY = "Testing/Lists" + + def set_item(self, accumulation, index, value): + new_accum = accumulation['accum'][:] + new_accum[index] = value + return ({"accum": new_accum},) + +class TestIntMathOperation: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "a": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "b": ("INT", {"default": 0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 1}), + "operation": (["add", "subtract", "multiply", "divide", "modulo", "power"],), + }, + } + + RETURN_TYPES = ("INT",) + FUNCTION = "int_math_operation" + + CATEGORY = "Testing/Logic" + + def int_math_operation(self, a, b, operation): + if operation == "add": + return (a + b,) + elif operation == "subtract": + return (a - b,) + elif operation == "multiply": + return (a * b,) + elif operation == "divide": + return (a // b,) + elif operation == "modulo": + return (a % b,) + elif operation == "power": + return (a ** b,) + + +from .flow_control import NUM_FLOW_SOCKETS +class TestForLoopOpen: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "remaining": ("INT", {"default": 1, "min": 0, "max": 100000, "step": 1}), + }, + "optional": { + "initial_value%d" % i: ("*",) for i in range(1, NUM_FLOW_SOCKETS) + }, + "hidden": { + "initial_value0": ("*",) + } + } + + RETURN_TYPES = tuple(["FLOW_CONTROL", "INT",] + ["*"] * (NUM_FLOW_SOCKETS-1)) + RETURN_NAMES = tuple(["flow_control", "remaining"] + ["value%d" % i for i in range(1, NUM_FLOW_SOCKETS)]) + FUNCTION = "for_loop_open" + + CATEGORY = "Testing/Flow" + + def for_loop_open(self, remaining, **kwargs): + graph = GraphBuilder() + if "initial_value0" in kwargs: + remaining = kwargs["initial_value0"] + while_open = graph.node("TestWhileLoopOpen", condition=remaining, initial_value0=remaining, **{("initial_value%d" % i): kwargs.get("initial_value%d" % i, None) for i in range(1, NUM_FLOW_SOCKETS)}) + outputs = [kwargs.get("initial_value%d" % i, None) for i in range(1, NUM_FLOW_SOCKETS)] + return { + "result": tuple(["stub", remaining] + outputs), + "expand": graph.finalize(), + } + +class TestForLoopClose: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "flow_control": ("FLOW_CONTROL", {"rawLink": True}), + }, + "optional": { + "initial_value%d" % i: ("*",{"rawLink": True}) for i in range(1, NUM_FLOW_SOCKETS) + }, + } + + RETURN_TYPES = tuple(["*"] * (NUM_FLOW_SOCKETS-1)) + RETURN_NAMES = tuple(["value%d" % i for i in range(1, NUM_FLOW_SOCKETS)]) + FUNCTION = "for_loop_close" + + CATEGORY = "Testing/Flow" + + def for_loop_close(self, flow_control, **kwargs): + graph = GraphBuilder() + while_open = flow_control[0] + sub = graph.node("TestIntMathOperation", operation="subtract", a=[while_open,1], b=1) + cond = graph.node("TestToBoolNode", value=sub.out(0)) + input_values = {("initial_value%d" % i): kwargs.get("initial_value%d" % i, None) for i in range(1, NUM_FLOW_SOCKETS)} + while_close = graph.node("TestWhileLoopClose", + flow_control=flow_control, + condition=cond.out(0), + initial_value0=sub.out(0), + **input_values) + return { + "result": tuple([while_close.out(i) for i in range(1, NUM_FLOW_SOCKETS)]), + "expand": graph.finalize(), + } + +NUM_LIST_SOCKETS = 10 +class TestMakeListNode: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value1": ("*",), + }, + "optional": { + "value%d" % i: ("*",) for i in range(1, NUM_LIST_SOCKETS) + }, + } + + RETURN_TYPES = ("*",) + FUNCTION = "make_list" + OUTPUT_IS_LIST = (True,) + + CATEGORY = "Testing/Lists" + + def make_list(self, **kwargs): + result = [] + for i in range(NUM_LIST_SOCKETS): + if "value%d" % i in kwargs: + result.append(kwargs["value%d" % i]) + return (result,) + +UTILITY_NODE_CLASS_MAPPINGS = { + "TestAccumulateNode": TestAccumulateNode, + "TestAccumulationHeadNode": TestAccumulationHeadNode, + "TestAccumulationTailNode": TestAccumulationTailNode, + "TestAccumulationToListNode": TestAccumulationToListNode, + "TestListToAccumulationNode": TestListToAccumulationNode, + "TestAccumulationGetLengthNode": TestAccumulationGetLengthNode, + "TestAccumulationGetItemNode": TestAccumulationGetItemNode, + "TestAccumulationSetItemNode": TestAccumulationSetItemNode, + "TestForLoopOpen": TestForLoopOpen, + "TestForLoopClose": TestForLoopClose, + "TestIntMathOperation": TestIntMathOperation, + "TestMakeListNode": TestMakeListNode, +} +UTILITY_NODE_DISPLAY_NAME_MAPPINGS = { + "TestAccumulateNode": "Accumulate", + "TestAccumulationHeadNode": "Accumulation Head", + "TestAccumulationTailNode": "Accumulation Tail", + "TestAccumulationToListNode": "Accumulation to List", + "TestListToAccumulationNode": "List to Accumulation", + "TestAccumulationGetLengthNode": "Accumulation Get Length", + "TestAccumulationGetItemNode": "Accumulation Get Item", + "TestAccumulationSetItemNode": "Accumulation Set Item", + "TestForLoopOpen": "For Loop Open", + "TestForLoopClose": "For Loop Close", + "TestIntMathOperation": "Int Math Operation", + "TestMakeListNode": "Make List", +}