diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 74354ea94..2cbefefeb 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -117,7 +117,6 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") -parser.add_argument("--enable-variants", action="store_true", help="Enables '*' type nodes.") if comfy.options.args_parsing: args = parser.parse_args() diff --git a/execution.py b/execution.py index afedc0758..c8c89d01f 100644 --- a/execution.py +++ b/execution.py @@ -92,6 +92,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, prompt={}, dynpro cached_output = outputs.get(input_unique_id) if cached_output is None: continue + if output_index >= len(cached_output): + continue obj = cached_output[output_index] input_data_all[x] = obj elif input_category is not None: @@ -514,6 +516,7 @@ def validate_inputs(prompt, item, validated): validate_function_inputs = [] if hasattr(obj_class, "VALIDATE_INPUTS"): validate_function_inputs = inspect.getfullargspec(obj_class.VALIDATE_INPUTS).args + received_types = {} for x in valid_inputs: type_input, input_category, extra_info = get_input_info(obj_class, x) @@ -551,9 +554,9 @@ def validate_inputs(prompt, item, validated): o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES - is_variant = args.enable_variants and (r[val[1]] == "*" or type_input == "*") - if r[val[1]] != type_input and not is_variant: - received_type = r[val[1]] + received_type = r[val[1]] + received_types[x] = received_type + if 'input_types' not in validate_function_inputs and received_type != type_input: details = f"{x}, {received_type} != {type_input}" error = { "type": "return_type_mismatch", @@ -622,34 +625,34 @@ def validate_inputs(prompt, item, validated): errors.append(error) continue - if "min" in extra_info and val < extra_info["min"]: - error = { - "type": "value_smaller_than_min", - "message": "Value {} smaller than min of {}".format(val, extra_info["min"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } - } - errors.append(error) - continue - if "max" in extra_info and val > extra_info["max"]: - error = { - "type": "value_bigger_than_max", - "message": "Value {} bigger than max of {}".format(val, extra_info["max"]), - "details": f"{x}", - "extra_info": { - "input_name": x, - "input_config": info, - "received_value": val, - } - } - errors.append(error) - continue - if x not in validate_function_inputs: + if "min" in extra_info and val < extra_info["min"]: + error = { + "type": "value_smaller_than_min", + "message": "Value {} smaller than min of {}".format(val, extra_info["min"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue + if "max" in extra_info and val > extra_info["max"]: + error = { + "type": "value_bigger_than_max", + "message": "Value {} bigger than max of {}".format(val, extra_info["max"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue + if isinstance(type_input, list): if val not in type_input: input_config = info @@ -682,6 +685,8 @@ def validate_inputs(prompt, item, validated): for x in input_data_all: if x in validate_function_inputs: input_filtered[x] = input_data_all[x] + if 'input_types' in validate_function_inputs: + input_filtered['input_types'] = [received_types] #ret = obj_class.VALIDATE_INPUTS(**input_filtered) ret = map_node_over_list(obj_class, input_filtered, "VALIDATE_INPUTS") diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index 52e195627..6a4fa3dd1 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -12,6 +12,7 @@ import websocket #NOTE: websocket-client (https://github.com/websocket-client/we import uuid import urllib.request import urllib.parse +import urllib.error from comfy.graph_utils import GraphBuilder, Node class RunResult: @@ -125,7 +126,6 @@ class TestExecution: '--listen', args_pytest["listen"], '--port', str(args_pytest["port"]), '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', - '--enable-variants', ]) yield p.kill() @@ -237,6 +237,67 @@ class TestExecution: except Exception as e: assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" + @pytest.mark.parametrize("test_value, expect_error", [ + (5, True), + ("foo", True), + (5.0, False), + ]) + def test_validation_error_literal(self, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): + g = builder + validation1 = g.node("TestCustomValidation1", input1=test_value, input2=3.0) + g.node("SaveImage", images=validation1.out(0)) + + if expect_error: + with pytest.raises(urllib.error.HTTPError): + client.run(g) + else: + client.run(g) + + @pytest.mark.parametrize("test_type, test_value", [ + ("StubInt", 5), + ("StubFloat", 5.0) + ]) + def test_validation_error_edge1(self, test_type, test_value, client: ComfyClient, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation1 = g.node("TestCustomValidation1", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation1.out(0)) + + with pytest.raises(urllib.error.HTTPError): + client.run(g) + + @pytest.mark.parametrize("test_type, test_value, expect_error", [ + ("StubInt", 5, True), + ("StubFloat", 5.0, False) + ]) + def test_validation_error_edge2(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation2 = g.node("TestCustomValidation2", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation2.out(0)) + + if expect_error: + with pytest.raises(urllib.error.HTTPError): + client.run(g) + else: + client.run(g) + + @pytest.mark.parametrize("test_type, test_value, expect_error", [ + ("StubInt", 5, True), + ("StubFloat", 5.0, False) + ]) + def test_validation_error_edge3(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation3 = g.node("TestCustomValidation3", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation3.out(0)) + + if expect_error: + with pytest.raises(urllib.error.HTTPError): + client.run(g) + else: + client.run(g) + def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): g = builder # Creating the nodes in this specific order previously caused a bug diff --git a/tests/inference/testing_nodes/testing-pack/flow_control.py b/tests/inference/testing_nodes/testing-pack/flow_control.py index 8befdcf19..43f1ce02f 100644 --- a/tests/inference/testing_nodes/testing-pack/flow_control.py +++ b/tests/inference/testing_nodes/testing-pack/flow_control.py @@ -1,7 +1,9 @@ from comfy.graph_utils import GraphBuilder, is_link from comfy.graph import ExecutionBlocker +from .tools import VariantSupport NUM_FLOW_SOCKETS = 5 +@VariantSupport() class TestWhileLoopOpen: def __init__(self): pass @@ -31,6 +33,7 @@ class TestWhileLoopOpen: values.append(kwargs.get("initial_value%d" % i, None)) return tuple(["stub"] + values) +@VariantSupport() class TestWhileLoopClose: def __init__(self): pass @@ -131,6 +134,7 @@ class TestWhileLoopClose: "expand": graph.finalize(), } +@VariantSupport() class TestExecutionBlockerNode: def __init__(self): pass diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py index e3d864b44..8c103c18a 100644 --- a/tests/inference/testing_nodes/testing-pack/specific_tests.py +++ b/tests/inference/testing_nodes/testing-pack/specific_tests.py @@ -1,9 +1,7 @@ import torch +from .tools import VariantSupport class TestLazyMixImages: - def __init__(self): - pass - @classmethod def INPUT_TYPES(cls): return { @@ -50,9 +48,6 @@ class TestLazyMixImages: return (result[0],) class TestVariadicAverage: - def __init__(self): - pass - @classmethod def INPUT_TYPES(cls): return { @@ -74,9 +69,6 @@ class TestVariadicAverage: class TestCustomIsChanged: - def __init__(self): - pass - @classmethod def INPUT_TYPES(cls): return { @@ -103,14 +95,116 @@ class TestCustomIsChanged: else: return False +class TestCustomValidation1: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE,FLOAT",), + "input2": ("IMAGE,FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation1" + + CATEGORY = "Testing/Nodes" + + def custom_validation1(self, input1, input2): + if isinstance(input1, float) and isinstance(input2, float): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + else: + result = input1 * input2 + return (result,) + + @classmethod + def VALIDATE_INPUTS(cls, input1=None, input2=None): + if input1 is not None: + if not isinstance(input1, (torch.Tensor, float)): + return f"Invalid type of input1: {type(input1)}" + if input2 is not None: + if not isinstance(input2, (torch.Tensor, float)): + return f"Invalid type of input2: {type(input2)}" + + return True + +class TestCustomValidation2: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE,FLOAT",), + "input2": ("IMAGE,FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation2" + + CATEGORY = "Testing/Nodes" + + def custom_validation2(self, input1, input2): + if isinstance(input1, float) and isinstance(input2, float): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + else: + result = input1 * input2 + return (result,) + + @classmethod + def VALIDATE_INPUTS(cls, input_types, input1=None, input2=None): + if input1 is not None: + if not isinstance(input1, (torch.Tensor, float)): + return f"Invalid type of input1: {type(input1)}" + if input2 is not None: + if not isinstance(input2, (torch.Tensor, float)): + return f"Invalid type of input2: {type(input2)}" + + if 'input1' in input_types: + if input_types['input1'] not in ["IMAGE", "FLOAT"]: + return f"Invalid type of input1: {input_types['input1']}" + if 'input2' in input_types: + if input_types['input2'] not in ["IMAGE", "FLOAT"]: + return f"Invalid type of input2: {input_types['input2']}" + + return True + +@VariantSupport() +class TestCustomValidation3: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("IMAGE,FLOAT",), + "input2": ("IMAGE,FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation3" + + CATEGORY = "Testing/Nodes" + + def custom_validation3(self, input1, input2): + if isinstance(input1, float) and isinstance(input2, float): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + else: + result = input1 * input2 + return (result,) + TEST_NODE_CLASS_MAPPINGS = { "TestLazyMixImages": TestLazyMixImages, "TestVariadicAverage": TestVariadicAverage, "TestCustomIsChanged": TestCustomIsChanged, + "TestCustomValidation1": TestCustomValidation1, + "TestCustomValidation2": TestCustomValidation2, + "TestCustomValidation3": TestCustomValidation3, } TEST_NODE_DISPLAY_NAME_MAPPINGS = { "TestLazyMixImages": "Lazy Mix Images", "TestVariadicAverage": "Variadic Average", "TestCustomIsChanged": "Custom IsChanged", + "TestCustomValidation1": "Custom Validation 1", + "TestCustomValidation2": "Custom Validation 2", + "TestCustomValidation3": "Custom Validation 3", } diff --git a/tests/inference/testing_nodes/testing-pack/stubs.py b/tests/inference/testing_nodes/testing-pack/stubs.py index b2a5ebf3d..9be6eac9d 100644 --- a/tests/inference/testing_nodes/testing-pack/stubs.py +++ b/tests/inference/testing_nodes/testing-pack/stubs.py @@ -51,11 +51,55 @@ class StubMask: def stub_mask(self, value, height, width, batch_size): return (torch.ones(batch_size, height, width) * value,) +class StubInt: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("INT", {"default": 0, "min": -0xffffffff, "max": 0xffffffff, "step": 1}), + }, + } + + RETURN_TYPES = ("INT",) + FUNCTION = "stub_int" + + CATEGORY = "Testing/Stub Nodes" + + def stub_int(self, value): + return (value,) + +class StubFloat: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 0.0, "min": -1.0e38, "max": 1.0e38, "step": 0.01}), + }, + } + + RETURN_TYPES = ("FLOAT",) + FUNCTION = "stub_float" + + CATEGORY = "Testing/Stub Nodes" + + def stub_float(self, value): + return (value,) + TEST_STUB_NODE_CLASS_MAPPINGS = { "StubImage": StubImage, "StubMask": StubMask, + "StubInt": StubInt, + "StubFloat": StubFloat, } TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = { "StubImage": "Stub Image", "StubMask": "Stub Mask", + "StubInt": "Stub Int", + "StubFloat": "Stub Float", } diff --git a/tests/inference/testing_nodes/testing-pack/tools.py b/tests/inference/testing_nodes/testing-pack/tools.py new file mode 100644 index 000000000..6c8d5eaa0 --- /dev/null +++ b/tests/inference/testing_nodes/testing-pack/tools.py @@ -0,0 +1,48 @@ + +class SmartType(str): + def __ne__(self, other): + if self == "*" or other == "*": + return False + selfset = set(self.split(',')) + otherset = set(other.split(',')) + return not selfset.issubset(otherset) + +def VariantSupport(): + def decorator(cls): + if hasattr(cls, "INPUT_TYPES"): + old_input_types = getattr(cls, "INPUT_TYPES") + def new_input_types(*args, **kwargs): + types = old_input_types(*args, **kwargs) + for category in ["required", "optional"]: + if category not in types: + continue + for key, value in types[category].items(): + if isinstance(value, tuple): + types[category][key] = (SmartType(value[0]),) + value[1:] + return types + setattr(cls, "INPUT_TYPES", new_input_types) + if hasattr(cls, "RETURN_TYPES"): + old_return_types = cls.RETURN_TYPES + setattr(cls, "RETURN_TYPES", tuple(SmartType(x) for x in old_return_types)) + if hasattr(cls, "VALIDATE_INPUTS"): + # Reflection is used to determine what the function signature is, so we can't just change the function signature + raise NotImplementedError("VariantSupport does not support VALIDATE_INPUTS yet") + else: + def validate_inputs(input_types): + inputs = cls.INPUT_TYPES() + for key, value in input_types.items(): + if isinstance(value, SmartType): + continue + if "required" in inputs and key in inputs["required"]: + expected_type = inputs["required"][key][0] + elif "optional" in inputs and key in inputs["optional"]: + expected_type = inputs["optional"][key][0] + else: + expected_type = None + if expected_type is not None and SmartType(value) != expected_type: + return f"Invalid type of {key}: {value} (expected {expected_type})" + return True + setattr(cls, "VALIDATE_INPUTS", validate_inputs) + return cls + return decorator + diff --git a/tests/inference/testing_nodes/testing-pack/util.py b/tests/inference/testing_nodes/testing-pack/util.py index 16209d3fc..8e2065c7b 100644 --- a/tests/inference/testing_nodes/testing-pack/util.py +++ b/tests/inference/testing_nodes/testing-pack/util.py @@ -1,5 +1,7 @@ from comfy.graph_utils import GraphBuilder +from .tools import VariantSupport +@VariantSupport() class TestAccumulateNode: def __init__(self): pass @@ -27,6 +29,7 @@ class TestAccumulateNode: value = accumulation["accum"] + [to_add] return ({"accum": value},) +@VariantSupport() class TestAccumulationHeadNode: def __init__(self): pass @@ -75,6 +78,7 @@ class TestAccumulationTailNode: else: return ({"accum": accum[:-1]}, accum[-1]) +@VariantSupport() class TestAccumulationToListNode: def __init__(self): pass @@ -97,6 +101,7 @@ class TestAccumulationToListNode: def accumulation_to_list(self, accumulation): return (accumulation["accum"],) +@VariantSupport() class TestListToAccumulationNode: def __init__(self): pass @@ -119,6 +124,7 @@ class TestListToAccumulationNode: def list_to_accumulation(self, list): return ({"accum": list},) +@VariantSupport() class TestAccumulationGetLengthNode: def __init__(self): pass @@ -140,6 +146,7 @@ class TestAccumulationGetLengthNode: def accumlength(self, accumulation): return (len(accumulation['accum']),) +@VariantSupport() class TestAccumulationGetItemNode: def __init__(self): pass @@ -162,6 +169,7 @@ class TestAccumulationGetItemNode: def get_item(self, accumulation, index): return (accumulation['accum'][index],) +@VariantSupport() class TestAccumulationSetItemNode: def __init__(self): pass @@ -222,6 +230,7 @@ class TestIntMathOperation: from .flow_control import NUM_FLOW_SOCKETS +@VariantSupport() class TestForLoopOpen: def __init__(self): pass @@ -257,6 +266,7 @@ class TestForLoopOpen: "expand": graph.finalize(), } +@VariantSupport() class TestForLoopClose: def __init__(self): pass @@ -295,6 +305,7 @@ class TestForLoopClose: } NUM_LIST_SOCKETS = 10 +@VariantSupport() class TestMakeListNode: def __init__(self): pass