diff --git a/execution.py b/execution.py index 4450e217e..b2a1cd1e7 100644 --- a/execution.py +++ b/execution.py @@ -309,17 +309,16 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f return results -def merge_result_data(results, obj, is_list_overrides=[]): +def merge_result_data(results, obj): # check which outputs need concatenating output = [] output_is_list = [False] * len(results[0]) if hasattr(obj, "OUTPUT_IS_LIST"): output_is_list = obj.OUTPUT_IS_LIST - is_list_override = is_list_overrides[0] if is_list_overrides else output_is_list # merge node execution results - for i, is_list, override in zip(range(len(results[0])), output_is_list, is_list_override): - if is_list or override: + for i, is_list in zip(range(len(results[0])), output_is_list): + if is_list: value = [] for o in results: if isinstance(o[i], ExecutionBlocker): @@ -420,6 +419,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + output_is_list = [False] * len(class_def.RETURN_TYPES) + if hasattr(class_def, "OUTPUT_IS_LIST"): + output_is_list = class_def.OUTPUT_IS_LIST cached = caches.outputs.get(unique_id) if cached is not None: if server.client_id is not None: @@ -450,35 +452,32 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, elif unique_id in pending_subgraph_results: cached_results = pending_subgraph_results[unique_id] resolved_outputs = [] - is_list_overrides = [] for is_subgraph, result in cached_results: if not is_subgraph: resolved_outputs.append(result) else: resolved_output = [] - is_list_override = [] - for r in result: - if is_link(r): - source_node, source_output = r[0], r[1] - _class_type = dynprompt.get_node(source_node)['class_type'] - _class_def = nodes.NODE_CLASS_MAPPINGS[_class_type] - _source_is_list = False - if hasattr(_class_def, "OUTPUT_IS_LIST"): - _source_is_list = _class_def.OUTPUT_IS_LIST[source_output] - node_cached = execution_list.get_cache(source_node, unique_id) - if _source_is_list: - resolved_output.append(node_cached.outputs[source_output]) - is_list_override.append(_source_is_list) + for i, _result in enumerate(result): + if not output_is_list[i]: + if is_link(_result): + source_node, source_output = _result[0], _result[1] + node_cached = execution_list.get_cache(source_node, unique_id) + if node_cached.outputs[source_output]: + resolved_output.append(node_cached.outputs[source_output][0]) else: - for o in node_cached.outputs[source_output]: - resolved_output.append(o) - is_list_override.append(_source_is_list) + resolved_output.append(_result) else: - resolved_output.append(r) - is_list_override.append(False) + _resolved = [] + for output in _result: + if is_link(output): + source_node, source_output = output[0], output[1] + node_cached = execution_list.get_cache(source_node, unique_id) + _resolved.extend(node_cached.outputs[source_output]) + else: + _resolved.extend(output) + resolved_output.append(_resolved) resolved_outputs.append(tuple(resolved_output)) - is_list_overrides.append(tuple(is_list_override)) - output_data = merge_result_data(resolved_outputs, class_def, is_list_overrides) + output_data = merge_result_data(resolved_outputs, class_def) output_ui = [] del pending_subgraph_results[unique_id] has_subgraph = False @@ -590,9 +589,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: new_output_ids.append(node_id) for i in range(len(node_outputs)): - if is_link(node_outputs[i]): - from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1] - new_output_links.append((from_node_id, from_socket)) + # Consider a returned list if output_is_list on the parent node + _node_outputs = node_outputs[i] + if not output_is_list[i]: + _node_outputs = [_node_outputs] + for node_output in _node_outputs: + if is_link(node_output): + from_node_id, from_socket = node_output[0], node_output[1] + new_output_links.append((from_node_id, from_socket)) cached_outputs.append((True, node_outputs)) new_node_ids = set(new_node_ids) for cache in caches.all: diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index f73ca7e3c..5fd3a5336 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -522,6 +522,58 @@ class TestExecution: for i in range(3): assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white" + # Tests functionality of defining OUTPUT_IS_LIST for expanding nodes. + def test_output_is_list_expansion_results(self, client: ComfyClient, builder: GraphBuilder): + g = builder + list_out = g.node("TestListExpansionResult", value1=0.1) + output = g.node("SaveImage", images=list_out.out(0)) + output_constant = g.node("SaveImage", images=list_out.out(1)) + + # return list of one image (list of one link) + result = client.run(g) + images = result.get_images(output) + assert len(images) == 1, "Should have 1 image" + assert numpy.array(images[0]).min() == 25 and numpy.array(images[0]).max() == 25, "First image should be 0.1" + images_constant = result.get_images(output_constant) + assert len(images_constant) == 1, "Should have 1 image" + assert numpy.array(images_constant[0]).min() == 255 and numpy.array(images_constant[0]).max() == 255, "Image should be white" + + # test return list of two images (list of two links) + list_out.set_input("value2", 0.2) + result = client.run(g) + images = result.get_images(output) + assert len(images) == 2, "Should have 2 images" + assert numpy.array(images[0]).min() == 25 and numpy.array(images[0]).max() == 25, "First image should be 0.1" + assert numpy.array(images[1]).min() == 51 and numpy.array(images[1]).max() == 51, "Second image should be 0.2" + images_constant = result.get_images(output_constant) + assert len(images_constant) == 1, "Should have 1 image" + assert numpy.array(images_constant[0]).min() == 255 and numpy.array(images_constant[0]).max() == 255, "Image should be white" + + # test mixed links and non-link values in returned list + list_out.set_input("value3", 0.3) + result = client.run(g) + images = result.get_images(output) + assert len(images) == 3, "Should have 3 images" + assert numpy.array(images[0]).min() == 25 and numpy.array(images[0]).max() == 25, "First image should be 0.1" + assert numpy.array(images[1]).min() == 51 and numpy.array(images[1]).max() == 51, "Second image should be 0.2" + assert numpy.array(images[2]).min() == 76 and numpy.array(images[2]).max() == 76, "Third image should be 0.3" + images_constant = result.get_images(output_constant) + assert len(images_constant) == 1, "Should have 1 image" + assert numpy.array(images_constant[0]).min() == 255 and numpy.array(images_constant[0]).max() == 255, "Image should be white" + + # test returning list of a single link from an list output subnode + list_out.set_input("value4", 0.4) + result = client.run(g) + images = result.get_images(output) + assert len(images) == 4, "Should have 4 images" + assert numpy.array(images[0]).min() == 25 and numpy.array(images[0]).max() == 25, "First image should be 0.1" + assert numpy.array(images[1]).min() == 51 and numpy.array(images[1]).max() == 51, "Second image should be 0.2" + assert numpy.array(images[2]).min() == 76 and numpy.array(images[2]).max() == 76, "Third image should be 0.3" + assert numpy.array(images[3]).min() == 102 and numpy.array(images[3]).max() == 102, "Fourth image should be 0.4" + images_constant = result.get_images(output_constant) + assert len(images_constant) == 1, "Should have 1 image" + assert numpy.array(images_constant[0]).min() == 255 and numpy.array(images_constant[0]).max() == 255, "Image should be white" + def test_mixed_lazy_results(self, client: ComfyClient, builder: GraphBuilder): g = builder val_list = g.node("TestMakeListNode", value1=0.0, value2=0.5, value3=1.0) diff --git a/tests/execution/testing_nodes/testing-pack/specific_tests.py b/tests/execution/testing_nodes/testing-pack/specific_tests.py index 4f8f01ae4..de571c2c1 100644 --- a/tests/execution/testing_nodes/testing-pack/specific_tests.py +++ b/tests/execution/testing_nodes/testing-pack/specific_tests.py @@ -338,6 +338,65 @@ class TestMixedExpansionReturns: "expand": g.finalize(), } +class TestListExpansionResult: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value1": ("FLOAT",), + }, + "optional": { + "value2": ("FLOAT",), + "value3": ("FLOAT",), + "value4": ("FLOAT",), + }, + } + + RETURN_TYPES = ("IMAGE", "IMAGE") + FUNCTION = "result_as_list" + OUTPUT_IS_LIST = (True, False) + + CATEGORY = "Testing/Nodes" + + def result_as_list(self, **kwargs): + g = GraphBuilder() + values = [] + for i in range(4): + key = f"value{i+1}" + if key in kwargs: + values.append(kwargs[key]) + white = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) + if len(values) == 1: + image1 = g.node("StubConstantImage", value=values[0], height=512, width=512, batch_size=1) + return { + "result": ([image1.out(0)], white.out(0)), + "expand": g.finalize(), + } + elif len(values) == 2: + image1 = g.node("StubConstantImage", value=values[0], height=512, width=512, batch_size=1) + image2 = g.node("StubConstantImage", value=values[1], height=512, width=512, batch_size=1) + return { + "result": ([image1.out(0), image2.out(0)], white.out(0)), + "expand": g.finalize(), + } + elif len(values) == 3: + image1 = g.node("StubConstantImage", value=values[0], height=512, width=512, batch_size=1) + image2 = g.node("StubConstantImage", value=values[1], height=512, width=512, batch_size=1) + image3 = torch.ones(1, 512, 512, 3) * values[2] + return { + "result": ([image1.out(0), image2.out(0), image3], white.out(0)), + "expand": g.finalize(), + } + elif len(values) == 4: + list_out = g.node("TestMakeListNode") + for i, value in enumerate(values): + image = g.node("StubConstantImage", value=value, height=512, width=512, batch_size=1) + list_out.set_input(f"value{i+1}", image.out(0)) + return { + "result": ([list_out.out(0)], white.out(0)), + "expand": g.finalize(), + } + class TestSamplingInExpansion: @classmethod def INPUT_TYPES(cls): @@ -494,6 +553,7 @@ TEST_NODE_CLASS_MAPPINGS = { "TestCustomValidation5": TestCustomValidation5, "TestDynamicDependencyCycle": TestDynamicDependencyCycle, "TestMixedExpansionReturns": TestMixedExpansionReturns, + "TestListExpansionResult": TestListExpansionResult, "TestSamplingInExpansion": TestSamplingInExpansion, "TestSleep": TestSleep, "TestParallelSleep": TestParallelSleep, @@ -512,6 +572,7 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = { "TestCustomValidation5": "Custom Validation 5", "TestDynamicDependencyCycle": "Dynamic Dependency Cycle", "TestMixedExpansionReturns": "Mixed Expansion Returns", + "TestListExpansionResult": "Output is List Expansion Result", "TestSamplingInExpansion": "Sampling In Expansion", "TestSleep": "Test Sleep", "TestParallelSleep": "Test Parallel Sleep",