From 6d09dd70f8e6400ab9952bfc2bb98ba10c360395 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sat, 24 Feb 2024 23:17:01 -0800 Subject: [PATCH] Make custom VALIDATE_INPUTS skip normal validation Additionally, if `VALIDATE_INPUTS` takes an argument named `input_types`, that variable will be a dictionary of the socket type of all incoming connections. If that argument exists, normal socket type validation will not occur. This removes the last hurdle for enabling variant types entirely from custom nodes, so I've removed that command-line option. I've added appropriate unit tests for these changes. --- comfy/cli_args.py | 1 - execution.py | 65 +++++----- tests/inference/test_execution.py | 63 +++++++++- .../testing-pack/flow_control.py | 4 + .../testing-pack/specific_tests.py | 112 ++++++++++++++++-- .../testing_nodes/testing-pack/stubs.py | 44 +++++++ .../testing_nodes/testing-pack/tools.py | 48 ++++++++ .../testing_nodes/testing-pack/util.py | 11 ++ 8 files changed, 307 insertions(+), 41 deletions(-) create mode 100644 tests/inference/testing_nodes/testing-pack/tools.py diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 74354ea94..2cbefefeb 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -117,7 +117,6 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") -parser.add_argument("--enable-variants", action="store_true", help="Enables '*' type nodes.") if comfy.options.args_parsing: args = parser.parse_args() diff --git a/execution.py b/execution.py index afedc0758..c8c89d01f 100644 --- a/execution.py +++ b/execution.py @@ -92,6 +92,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynpro cached_output = outputs.get(input_unique_id) if cached_output is None: continue + if output_index >= len(cached_output): + continue obj = cached_output[output_index] input_data_all[x] = obj elif input_category is not None: @@ -514,6 +516,7 @@ def validate_inputs(prompt, item, validated): validate_function_inputs = [] if hasattr(obj_class, "VALIDATE_INPUTS"): validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args + received_types = {} for x in valid_inputs: type_input, input_category, extra_info = get_input_info(obj_class, x) @@ -551,9 +554,9 @@ def validate_inputs(prompt, item, validated): o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES - is_variant = args.enable_variants and (r[val[1]] == "*" or type_input == "*") - if r[val[1]] != type_input and not is_variant: - received_type = r[val[1]] + received_type = r[val[1]] + received_types[x] = received_type + if 'input_types' not in validate_function_inputs and received_type != type_input: details = f"{x}, {received_type} != {type_input}" error = { "type": "return_type_mismatch", @@ -622,34 +625,34 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if "min" in extra_info and val < extra_info["min"]: - error = { - "type": "value_smaller_than_min", - "message": "Value {} smaller than min of {}".format(val, extra_info["min"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } - } - errors.append(error) - continue - if "max" in extra_info and val > extra_info["max"]: - error = { - "type": "value_bigger_than_max", - "message": "Value {} bigger than max of {}".format(val, extra_info["max"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } - } - errors.append(error) - continue - if x not in validate_function_inputs: + if "min" in extra_info and val < extra_info["min"]: + error = { + "type": "value_smaller_than_min", + "message": "Value {} smaller than min of {}".format(val, extra_info["min"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue + if "max" in extra_info and val > extra_info["max"]: + error = { + "type": "value_bigger_than_max", + "message": "Value {} bigger than max of {}".format(val, extra_info["max"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue + if isinstance(type_input, list): if val not in type_input: input_config = info @@ -682,6 +685,8 @@ def validate_inputs(prompt, item, validated): for x in input_data_all: if x in validate_function_inputs: input_filtered[x] = input_data_all[x] + if 'input_types' in validate_function_inputs: + input_filtered['input_types'] = [received_types] #ret = obj_class.VALIDATE_INPUTS(**input_filtered) ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index 52e195627..6a4fa3dd1 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -12,6 +12,7 @@ import websocket #NOTE: websocket-client (https://github.com/websocket-client/we import uuid import urllib.request import urllib.parse +import urllib.error from comfy.graph_utils import GraphBuilder, Node class RunResult: @@ -125,7 +126,6 @@ class TestExecution: '--listen', args_pytest["listen"], '--port', str(args_pytest["port"]), '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', - '--enable-variants', ]) yield p.kill() @@ -237,6 +237,67 @@ class TestExecution: except Exception as e: assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" + @pytest.mark.parametrize("test_value, expect_error", [ + (5, True), + ("foo", True), + (5.0, False), + ]) + def test_validation_error_literal(self, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): + g = builder + validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0) + g.node("SaveImage", images=validation1.out(0)) + + if expect_error: + with pytest.raises(urllib.error.HTTPError): + client.run(g) + else: + client.run(g) + + @pytest.mark.parametrize("test_type, test_value", [ + ("StubInt", 5), + ("StubFloat", 5.0) + ]) + def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation1.out(0)) + + with pytest.raises(urllib.error.HTTPError): + client.run(g) + + @pytest.mark.parametrize("test_type, test_value, expect_error", [ + ("StubInt", 5, True), + ("StubFloat", 5.0, False) + ]) + def test_validation_error_edge2(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation2.out(0)) + + if expect_error: + with pytest.raises(urllib.error.HTTPError): + client.run(g) + else: + client.run(g) + + @pytest.mark.parametrize("test_type, test_value, expect_error", [ + ("StubInt", 5, True), + ("StubFloat", 5.0, False) + ]) + def test_validation_error_edge3(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation3.out(0)) + + if expect_error: + with pytest.raises(urllib.error.HTTPError): + client.run(g) + else: + client.run(g) + def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): g = builder # Creating the nodes in this specific order previously caused a bug diff --git a/tests/inference/testing_nodes/testing-pack/flow_control.py b/tests/inference/testing_nodes/testing-pack/flow_control.py index 8befdcf19..43f1ce02f 100644 --- a/tests/inference/testing_nodes/testing-pack/flow_control.py +++ b/tests/inference/testing_nodes/testing-pack/flow_control.py @@ -1,7 +1,9 @@ from comfy.graph_utils import GraphBuilder, is_link from comfy.graph import ExecutionBlocker +from .tools import VariantSupport NUM_FLOW_SOCKETS = 5 +@VariantSupport() class TestWhileLoopOpen: def __init__(self): pass @@ -31,6 +33,7 @@ class TestWhileLoopOpen: values.append(kwargs.get("initial_value%d" % i, None)) return tuple(["stub"] + values) +@VariantSupport() class TestWhileLoopClose: def __init__(self): pass @@ -131,6 +134,7 @@ class TestWhileLoopClose: "expand": graph.finalize(), } +@VariantSupport() class TestExecutionBlockerNode: def __init__(self): pass diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py index e3d864b44..8c103c18a 100644 --- a/tests/inference/testing_nodes/testing-pack/specific_tests.py +++ b/tests/inference/testing_nodes/testing-pack/specific_tests.py @@ -1,9 +1,7 @@ import torch +from .tools import VariantSupport class TestLazyMixImages: - def __init__(self): - pass - @classmethod def INPUT_TYPES(cls): return { @@ -50,9 +48,6 @@ class TestLazyMixImages: return (result[0],) class TestVariadicAverage: - def __init__(self): - pass - @classmethod def INPUT_TYPES(cls): return { @@ -74,9 +69,6 @@ class TestVariadicAverage: class TestCustomIsChanged: - def __init__(self): - pass - @classmethod def INPUT_TYPES(cls): return { @@ -103,14 +95,116 @@ class TestCustomIsChanged: else: return False +class TestCustomValidation1: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE,FLOAT",), + "input2": ("IMAGE,FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation1" + + CATEGORY = "Testing/Nodes" + + def custom_validation1(self, input1, input2): + if isinstance(input1, float) and isinstance(input2, float): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + else: + result = input1 * input2 + return (result,) + + @classmethod + def VALIDATE_INPUTS(cls, input1=None, input2=None): + if input1 is not None: + if not isinstance(input1, (torch.Tensor, float)): + return f"Invalid type of input1: {type(input1)}" + if input2 is not None: + if not isinstance(input2, (torch.Tensor, float)): + return f"Invalid type of input2: {type(input2)}" + + return True + +class TestCustomValidation2: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE,FLOAT",), + "input2": ("IMAGE,FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation2" + + CATEGORY = "Testing/Nodes" + + def custom_validation2(self, input1, input2): + if isinstance(input1, float) and isinstance(input2, float): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + else: + result = input1 * input2 + return (result,) + + @classmethod + def VALIDATE_INPUTS(cls, input_types, input1=None, input2=None): + if input1 is not None: + if not isinstance(input1, (torch.Tensor, float)): + return f"Invalid type of input1: {type(input1)}" + if input2 is not None: + if not isinstance(input2, (torch.Tensor, float)): + return f"Invalid type of input2: {type(input2)}" + + if 'input1' in input_types: + if input_types['input1'] not in ["IMAGE", "FLOAT"]: + return f"Invalid type of input1: {input_types['input1']}" + if 'input2' in input_types: + if input_types['input2'] not in ["IMAGE", "FLOAT"]: + return f"Invalid type of input2: {input_types['input2']}" + + return True + +@VariantSupport() +class TestCustomValidation3: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE,FLOAT",), + "input2": ("IMAGE,FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation3" + + CATEGORY = "Testing/Nodes" + + def custom_validation3(self, input1, input2): + if isinstance(input1, float) and isinstance(input2, float): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + else: + result = input1 * input2 + return (result,) + TEST_NODE_CLASS_MAPPINGS = { "TestLazyMixImages": TestLazyMixImages, "TestVariadicAverage": TestVariadicAverage, "TestCustomIsChanged": TestCustomIsChanged, + "TestCustomValidation1": TestCustomValidation1, + "TestCustomValidation2": TestCustomValidation2, + "TestCustomValidation3": TestCustomValidation3, } TEST_NODE_DISPLAY_NAME_MAPPINGS = { "TestLazyMixImages": "Lazy Mix Images", "TestVariadicAverage": "Variadic Average", "TestCustomIsChanged": "Custom IsChanged", + "TestCustomValidation1": "Custom Validation 1", + "TestCustomValidation2": "Custom Validation 2", + "TestCustomValidation3": "Custom Validation 3", } diff --git a/tests/inference/testing_nodes/testing-pack/stubs.py b/tests/inference/testing_nodes/testing-pack/stubs.py index b2a5ebf3d..9be6eac9d 100644 --- a/tests/inference/testing_nodes/testing-pack/stubs.py +++ b/tests/inference/testing_nodes/testing-pack/stubs.py @@ -51,11 +51,55 @@ class StubMask: def stub_mask(self, value, height, width, batch_size): return (torch.ones(batch_size, height, width) * value,) +class StubInt: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("INT", {"default": 0, "min": -0xffffffff, "max": 0xffffffff, "step": 1}), + }, + } + + RETURN_TYPES = ("INT",) + FUNCTION = "stub_int" + + CATEGORY = "Testing/Stub Nodes" + + def stub_int(self, value): + return (value,) + +class StubFloat: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 0.0, "min": -1.0e38, "max": 1.0e38, "step": 0.01}), + }, + } + + RETURN_TYPES = ("FLOAT",) + FUNCTION = "stub_float" + + CATEGORY = "Testing/Stub Nodes" + + def stub_float(self, value): + return (value,) + TEST_STUB_NODE_CLASS_MAPPINGS = { "StubImage": StubImage, "StubMask": StubMask, + "StubInt": StubInt, + "StubFloat": StubFloat, } TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = { "StubImage": "Stub Image", "StubMask": "Stub Mask", + "StubInt": "Stub Int", + "StubFloat": "Stub Float", } diff --git a/tests/inference/testing_nodes/testing-pack/tools.py b/tests/inference/testing_nodes/testing-pack/tools.py new file mode 100644 index 000000000..6c8d5eaa0 --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/tools.py @@ -0,0 +1,48 @@ + +class SmartType(str): + def __ne__(self, other): + if self == "*" or other == "*": + return False + selfset = set(self.split(',')) + otherset = set(other.split(',')) + return not selfset.issubset(otherset) + +def VariantSupport(): + def decorator(cls): + if hasattr(cls, "INPUT_TYPES"): + old_input_types = getattr(cls, "INPUT_TYPES") + def new_input_types(*args, **kwargs): + types = old_input_types(*args, **kwargs) + for category in ["required", "optional"]: + if category not in types: + continue + for key, value in types[category].items(): + if isinstance(value, tuple): + types[category][key] = (SmartType(value[0]),) + value[1:] + return types + setattr(cls, "INPUT_TYPES", new_input_types) + if hasattr(cls, "RETURN_TYPES"): + old_return_types = cls.RETURN_TYPES + setattr(cls, "RETURN_TYPES", tuple(SmartType(x) for x in old_return_types)) + if hasattr(cls, "VALIDATE_INPUTS"): + # Reflection is used to determine what the function signature is, so we can't just change the function signature + raise NotImplementedError("VariantSupport does not support VALIDATE_INPUTS yet") + else: + def validate_inputs(input_types): + inputs = cls.INPUT_TYPES() + for key, value in input_types.items(): + if isinstance(value, SmartType): + continue + if "required" in inputs and key in inputs["required"]: + expected_type = inputs["required"][key][0] + elif "optional" in inputs and key in inputs["optional"]: + expected_type = inputs["optional"][key][0] + else: + expected_type = None + if expected_type is not None and SmartType(value) != expected_type: + return f"Invalid type of {key}: {value} (expected {expected_type})" + return True + setattr(cls, "VALIDATE_INPUTS", validate_inputs) + return cls + return decorator + diff --git a/tests/inference/testing_nodes/testing-pack/util.py b/tests/inference/testing_nodes/testing-pack/util.py index 16209d3fc..8e2065c7b 100644 --- a/tests/inference/testing_nodes/testing-pack/util.py +++ b/tests/inference/testing_nodes/testing-pack/util.py @@ -1,5 +1,7 @@ from comfy.graph_utils import GraphBuilder +from .tools import VariantSupport +@VariantSupport() class TestAccumulateNode: def __init__(self): pass @@ -27,6 +29,7 @@ class TestAccumulateNode: value = accumulation["accum"] + [to_add] return ({"accum": value},) +@VariantSupport() class TestAccumulationHeadNode: def __init__(self): pass @@ -75,6 +78,7 @@ class TestAccumulationTailNode: else: return ({"accum": accum[:-1]}, accum[-1]) +@VariantSupport() class TestAccumulationToListNode: def __init__(self): pass @@ -97,6 +101,7 @@ class TestAccumulationToListNode: def accumulation_to_list(self, accumulation): return (accumulation["accum"],) +@VariantSupport() class TestListToAccumulationNode: def __init__(self): pass @@ -119,6 +124,7 @@ class TestListToAccumulationNode: def list_to_accumulation(self, list): return ({"accum": list},) +@VariantSupport() class TestAccumulationGetLengthNode: def __init__(self): pass @@ -140,6 +146,7 @@ class TestAccumulationGetLengthNode: def accumlength(self, accumulation): return (len(accumulation['accum']),) +@VariantSupport() class TestAccumulationGetItemNode: def __init__(self): pass @@ -162,6 +169,7 @@ class TestAccumulationGetItemNode: def get_item(self, accumulation, index): return (accumulation['accum'][index],) +@VariantSupport() class TestAccumulationSetItemNode: def __init__(self): pass @@ -222,6 +230,7 @@ class TestIntMathOperation: from .flow_control import NUM_FLOW_SOCKETS +@VariantSupport() class TestForLoopOpen: def __init__(self): pass @@ -257,6 +266,7 @@ class TestForLoopOpen: "expand": graph.finalize(), } +@VariantSupport() class TestForLoopClose: def __init__(self): pass @@ -295,6 +305,7 @@ class TestForLoopClose: } NUM_LIST_SOCKETS = 10 +@VariantSupport() class TestMakeListNode: def __init__(self): pass