From 1f065889053a9b1328dde614a729eae886ee60a3 Mon Sep 17 00:00:00 2001 From: Jacob Segal Date: Sun, 21 Apr 2024 00:10:04 -0700 Subject: [PATCH] Pass None for uneval inputs instead of default This applies to `VALIDATE_INPUTS`, `check_lazy_status`, and lazy values in evaluation functions. --- execution.py | 19 +++++++--- tests/inference/test_execution.py | 16 ++++++++ .../testing-pack/specific_tests.py | 37 +++++++++++++++++-- 3 files changed, 64 insertions(+), 8 deletions(-) diff --git a/execution.py b/execution.py index 6d0f40a93..2aa939c1b 100644 --- a/execution.py +++ b/execution.py @@ -41,7 +41,7 @@ class IsChangedCache: if "is_changed" in node: self.is_changed[node_id] = node["is_changed"] else: - input_data_all = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache) + input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache) try: is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED") node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] @@ -84,18 +84,25 @@ class CacheSet: def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}): valid_inputs = class_def.INPUT_TYPES() input_data_all = {} + missing_keys = {} for x in inputs: input_data = inputs[x] input_type, input_category, input_info = get_input_info(class_def, x) + def mark_missing(): + missing_keys[x] = True + input_data_all[x] = (None,) if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)): input_unique_id = input_data[0] output_index = input_data[1] if outputs is None: + mark_missing() continue # This might be a lazily-evaluated input cached_output = outputs.get(input_unique_id) if cached_output is None: + mark_missing() continue if output_index >= len(cached_output): + mark_missing() continue obj = cached_output[output_index] input_data_all[x] = obj @@ -113,7 +120,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e input_data_all[x] = [extra_data.get('extra_pnginfo', None)] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] - return input_data_all + return input_data_all, missing_keys def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): # check if node wants the lists @@ -275,7 +282,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp output_ui = [] has_subgraph = False else: - input_data_all = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) + input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) @@ -288,7 +295,9 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp if hasattr(obj, "check_lazy_status"): required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) - required_inputs = [x for x in required_inputs if isinstance(x,str) and x not in input_data_all] + required_inputs = [x for x in required_inputs if isinstance(x,str) and ( + x not in input_data_all or x in missing_keys + )] if len(required_inputs) > 0: for i in required_inputs: execution_list.make_input_strong_link(unique_id, i) @@ -685,7 +694,7 @@ def validate_inputs(prompt, item, validated): continue if len(validate_function_inputs) > 0: - input_data_all = get_input_data(inputs, obj_class, unique_id) + input_data_all, _ = get_input_data(inputs, obj_class, unique_id) input_filtered = {} for x in input_data_all: if x in validate_function_inputs: diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py index 0ae70b8ca..40de99833 100644 --- a/tests/inference/test_execution.py +++ b/tests/inference/test_execution.py @@ -299,6 +299,22 @@ class TestExecution: else: client.run(g) + @pytest.mark.parametrize("test_type, test_value, expect_error", [ + ("StubInt", 5, True), + ("StubFloat", 5.0, False) + ]) + def test_validation_error_edge4(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder): + g = builder + stub = g.node(test_type, value=test_value) + validation4 = g.node("TestCustomValidation4", input1=stub.out(0), input2=3.0) + g.node("SaveImage", images=validation4.out(0)) + + if expect_error: + with pytest.raises(urllib.error.HTTPError): + client.run(g) + else: + client.run(g) + def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/inference/testing_nodes/testing-pack/specific_tests.py index 56b8f70b2..03dbc8b42 100644 --- a/tests/inference/testing_nodes/testing-pack/specific_tests.py +++ b/tests/inference/testing_nodes/testing-pack/specific_tests.py @@ -18,7 +18,7 @@ class TestLazyMixImages: CATEGORY = "Testing/Nodes" - def check_lazy_status(self, mask, image1 = None, image2 = None): + def check_lazy_status(self, mask, image1, image2): mask_min = mask.min() mask_max = mask.max() needed = [] @@ -29,7 +29,7 @@ class TestLazyMixImages: return needed # Not trying to handle different batch sizes here just to keep the demo simple - def mix(self, mask, image1 = None, image2 = None): + def mix(self, mask, image1, image2): mask_min = mask.min() mask_max = mask.max() if mask_min == 0.0 and mask_max == 0.0: @@ -45,7 +45,6 @@ class TestLazyMixImages: mask = mask.repeat(1, 1, 1, image1.shape[3]) result = image1 * (1. - mask) + image2 * mask, - print(result[0]) return (result[0],) class TestVariadicAverage: @@ -192,6 +191,36 @@ class TestCustomValidation3: result = input1 * input2 return (result,) +class TestCustomValidation4: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "input1": ("FLOAT",), + "input2": ("FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "custom_validation4" + + CATEGORY = "Testing/Nodes" + + def custom_validation4(self, input1, input2): + result = torch.ones([1, 512, 512, 3]) * input1 * input2 + return (result,) + + @classmethod + def VALIDATE_INPUTS(cls, input1, input2): + if input1 is not None: + if not isinstance(input1, float): + return f"Invalid type of input1: {type(input1)}" + if input2 is not None: + if not isinstance(input2, float): + return f"Invalid type of input2: {type(input2)}" + + return True + class TestDynamicDependencyCycle: @classmethod def INPUT_TYPES(cls): @@ -228,6 +257,7 @@ TEST_NODE_CLASS_MAPPINGS = { "TestCustomValidation1": TestCustomValidation1, "TestCustomValidation2": TestCustomValidation2, "TestCustomValidation3": TestCustomValidation3, + "TestCustomValidation4": TestCustomValidation4, "TestDynamicDependencyCycle": TestDynamicDependencyCycle, } @@ -238,5 +268,6 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = { "TestCustomValidation1": "Custom Validation 1", "TestCustomValidation2": "Custom Validation 2", "TestCustomValidation3": "Custom Validation 3", + "TestCustomValidation4": "Custom Validation 4", "TestDynamicDependencyCycle": "Dynamic Dependency Cycle", }