mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-12 17:27:26 +08:00
Add test + changed how expansion nodes outputs are resolved
This commit is contained in:
parent
9a797e1ec4
commit
c2d60e0641
60
execution.py
60
execution.py
@ -309,17 +309,16 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def merge_result_data(results, obj, is_list_overrides=[]):
|
def merge_result_data(results, obj):
|
||||||
# check which outputs need concatenating
|
# check which outputs need concatenating
|
||||||
output = []
|
output = []
|
||||||
output_is_list = [False] * len(results[0])
|
output_is_list = [False] * len(results[0])
|
||||||
if hasattr(obj, "OUTPUT_IS_LIST"):
|
if hasattr(obj, "OUTPUT_IS_LIST"):
|
||||||
output_is_list = obj.OUTPUT_IS_LIST
|
output_is_list = obj.OUTPUT_IS_LIST
|
||||||
is_list_override = is_list_overrides[0] if is_list_overrides else output_is_list
|
|
||||||
|
|
||||||
# merge node execution results
|
# merge node execution results
|
||||||
for i, is_list, override in zip(range(len(results[0])), output_is_list, is_list_override):
|
for i, is_list in zip(range(len(results[0])), output_is_list):
|
||||||
if is_list or override:
|
if is_list:
|
||||||
value = []
|
value = []
|
||||||
for o in results:
|
for o in results:
|
||||||
if isinstance(o[i], ExecutionBlocker):
|
if isinstance(o[i], ExecutionBlocker):
|
||||||
@ -420,6 +419,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
inputs = dynprompt.get_node(unique_id)['inputs']
|
inputs = dynprompt.get_node(unique_id)['inputs']
|
||||||
class_type = dynprompt.get_node(unique_id)['class_type']
|
class_type = dynprompt.get_node(unique_id)['class_type']
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
|
output_is_list = [False] * len(class_def.RETURN_TYPES)
|
||||||
|
if hasattr(class_def, "OUTPUT_IS_LIST"):
|
||||||
|
output_is_list = class_def.OUTPUT_IS_LIST
|
||||||
cached = caches.outputs.get(unique_id)
|
cached = caches.outputs.get(unique_id)
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
@ -450,35 +452,32 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
elif unique_id in pending_subgraph_results:
|
elif unique_id in pending_subgraph_results:
|
||||||
cached_results = pending_subgraph_results[unique_id]
|
cached_results = pending_subgraph_results[unique_id]
|
||||||
resolved_outputs = []
|
resolved_outputs = []
|
||||||
is_list_overrides = []
|
|
||||||
for is_subgraph, result in cached_results:
|
for is_subgraph, result in cached_results:
|
||||||
if not is_subgraph:
|
if not is_subgraph:
|
||||||
resolved_outputs.append(result)
|
resolved_outputs.append(result)
|
||||||
else:
|
else:
|
||||||
resolved_output = []
|
resolved_output = []
|
||||||
is_list_override = []
|
for i, _result in enumerate(result):
|
||||||
for r in result:
|
if not output_is_list[i]:
|
||||||
if is_link(r):
|
if is_link(_result):
|
||||||
source_node, source_output = r[0], r[1]
|
source_node, source_output = _result[0], _result[1]
|
||||||
_class_type = dynprompt.get_node(source_node)['class_type']
|
node_cached = execution_list.get_cache(source_node, unique_id)
|
||||||
_class_def = nodes.NODE_CLASS_MAPPINGS[_class_type]
|
if node_cached.outputs[source_output]:
|
||||||
_source_is_list = False
|
resolved_output.append(node_cached.outputs[source_output][0])
|
||||||
if hasattr(_class_def, "OUTPUT_IS_LIST"):
|
|
||||||
_source_is_list = _class_def.OUTPUT_IS_LIST[source_output]
|
|
||||||
node_cached = execution_list.get_cache(source_node, unique_id)
|
|
||||||
if _source_is_list:
|
|
||||||
resolved_output.append(node_cached.outputs[source_output])
|
|
||||||
is_list_override.append(_source_is_list)
|
|
||||||
else:
|
else:
|
||||||
for o in node_cached.outputs[source_output]:
|
resolved_output.append(_result)
|
||||||
resolved_output.append(o)
|
|
||||||
is_list_override.append(_source_is_list)
|
|
||||||
else:
|
else:
|
||||||
resolved_output.append(r)
|
_resolved = []
|
||||||
is_list_override.append(False)
|
for output in _result:
|
||||||
|
if is_link(output):
|
||||||
|
source_node, source_output = output[0], output[1]
|
||||||
|
node_cached = execution_list.get_cache(source_node, unique_id)
|
||||||
|
_resolved.extend(node_cached.outputs[source_output])
|
||||||
|
else:
|
||||||
|
_resolved.extend(output)
|
||||||
|
resolved_output.append(_resolved)
|
||||||
resolved_outputs.append(tuple(resolved_output))
|
resolved_outputs.append(tuple(resolved_output))
|
||||||
is_list_overrides.append(tuple(is_list_override))
|
output_data = merge_result_data(resolved_outputs, class_def)
|
||||||
output_data = merge_result_data(resolved_outputs, class_def, is_list_overrides)
|
|
||||||
output_ui = []
|
output_ui = []
|
||||||
del pending_subgraph_results[unique_id]
|
del pending_subgraph_results[unique_id]
|
||||||
has_subgraph = False
|
has_subgraph = False
|
||||||
@ -590,9 +589,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||||
new_output_ids.append(node_id)
|
new_output_ids.append(node_id)
|
||||||
for i in range(len(node_outputs)):
|
for i in range(len(node_outputs)):
|
||||||
if is_link(node_outputs[i]):
|
# Consider a returned list if output_is_list on the parent node
|
||||||
from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1]
|
_node_outputs = node_outputs[i]
|
||||||
new_output_links.append((from_node_id, from_socket))
|
if not output_is_list[i]:
|
||||||
|
_node_outputs = [_node_outputs]
|
||||||
|
for node_output in _node_outputs:
|
||||||
|
if is_link(node_output):
|
||||||
|
from_node_id, from_socket = node_output[0], node_output[1]
|
||||||
|
new_output_links.append((from_node_id, from_socket))
|
||||||
cached_outputs.append((True, node_outputs))
|
cached_outputs.append((True, node_outputs))
|
||||||
new_node_ids = set(new_node_ids)
|
new_node_ids = set(new_node_ids)
|
||||||
for cache in caches.all:
|
for cache in caches.all:
|
||||||
|
|||||||
@ -522,6 +522,58 @@ class TestExecution:
|
|||||||
for i in range(3):
|
for i in range(3):
|
||||||
assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white"
|
assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white"
|
||||||
|
|
||||||
|
# Tests functionality of defining OUTPUT_IS_LIST for expanding nodes.
|
||||||
|
def test_output_is_list_expansion_results(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
list_out = g.node("TestListExpansionResult", value1=0.1)
|
||||||
|
output = g.node("SaveImage", images=list_out.out(0))
|
||||||
|
output_constant = g.node("SaveImage", images=list_out.out(1))
|
||||||
|
|
||||||
|
# return list of one image (list of one link)
|
||||||
|
result = client.run(g)
|
||||||
|
images = result.get_images(output)
|
||||||
|
assert len(images) == 1, "Should have 1 image"
|
||||||
|
assert numpy.array(images[0]).min() == 25 and numpy.array(images[0]).max() == 25, "First image should be 0.1"
|
||||||
|
images_constant = result.get_images(output_constant)
|
||||||
|
assert len(images_constant) == 1, "Should have 1 image"
|
||||||
|
assert numpy.array(images_constant[0]).min() == 255 and numpy.array(images_constant[0]).max() == 255, "Image should be white"
|
||||||
|
|
||||||
|
# test return list of two images (list of two links)
|
||||||
|
list_out.set_input("value2", 0.2)
|
||||||
|
result = client.run(g)
|
||||||
|
images = result.get_images(output)
|
||||||
|
assert len(images) == 2, "Should have 2 images"
|
||||||
|
assert numpy.array(images[0]).min() == 25 and numpy.array(images[0]).max() == 25, "First image should be 0.1"
|
||||||
|
assert numpy.array(images[1]).min() == 51 and numpy.array(images[1]).max() == 51, "Second image should be 0.2"
|
||||||
|
images_constant = result.get_images(output_constant)
|
||||||
|
assert len(images_constant) == 1, "Should have 1 image"
|
||||||
|
assert numpy.array(images_constant[0]).min() == 255 and numpy.array(images_constant[0]).max() == 255, "Image should be white"
|
||||||
|
|
||||||
|
# test mixed links and non-link values in returned list
|
||||||
|
list_out.set_input("value3", 0.3)
|
||||||
|
result = client.run(g)
|
||||||
|
images = result.get_images(output)
|
||||||
|
assert len(images) == 3, "Should have 3 images"
|
||||||
|
assert numpy.array(images[0]).min() == 25 and numpy.array(images[0]).max() == 25, "First image should be 0.1"
|
||||||
|
assert numpy.array(images[1]).min() == 51 and numpy.array(images[1]).max() == 51, "Second image should be 0.2"
|
||||||
|
assert numpy.array(images[2]).min() == 76 and numpy.array(images[2]).max() == 76, "Third image should be 0.3"
|
||||||
|
images_constant = result.get_images(output_constant)
|
||||||
|
assert len(images_constant) == 1, "Should have 1 image"
|
||||||
|
assert numpy.array(images_constant[0]).min() == 255 and numpy.array(images_constant[0]).max() == 255, "Image should be white"
|
||||||
|
|
||||||
|
# test returning list of a single link from an list output subnode
|
||||||
|
list_out.set_input("value4", 0.4)
|
||||||
|
result = client.run(g)
|
||||||
|
images = result.get_images(output)
|
||||||
|
assert len(images) == 4, "Should have 4 images"
|
||||||
|
assert numpy.array(images[0]).min() == 25 and numpy.array(images[0]).max() == 25, "First image should be 0.1"
|
||||||
|
assert numpy.array(images[1]).min() == 51 and numpy.array(images[1]).max() == 51, "Second image should be 0.2"
|
||||||
|
assert numpy.array(images[2]).min() == 76 and numpy.array(images[2]).max() == 76, "Third image should be 0.3"
|
||||||
|
assert numpy.array(images[3]).min() == 102 and numpy.array(images[3]).max() == 102, "Fourth image should be 0.4"
|
||||||
|
images_constant = result.get_images(output_constant)
|
||||||
|
assert len(images_constant) == 1, "Should have 1 image"
|
||||||
|
assert numpy.array(images_constant[0]).min() == 255 and numpy.array(images_constant[0]).max() == 255, "Image should be white"
|
||||||
|
|
||||||
def test_mixed_lazy_results(self, client: ComfyClient, builder: GraphBuilder):
|
def test_mixed_lazy_results(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
g = builder
|
g = builder
|
||||||
val_list = g.node("TestMakeListNode", value1=0.0, value2=0.5, value3=1.0)
|
val_list = g.node("TestMakeListNode", value1=0.0, value2=0.5, value3=1.0)
|
||||||
|
|||||||
@ -338,6 +338,65 @@ class TestMixedExpansionReturns:
|
|||||||
"expand": g.finalize(),
|
"expand": g.finalize(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class TestListExpansionResult:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"value1": ("FLOAT",),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"value2": ("FLOAT",),
|
||||||
|
"value3": ("FLOAT",),
|
||||||
|
"value4": ("FLOAT",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE", "IMAGE")
|
||||||
|
FUNCTION = "result_as_list"
|
||||||
|
OUTPUT_IS_LIST = (True, False)
|
||||||
|
|
||||||
|
CATEGORY = "Testing/Nodes"
|
||||||
|
|
||||||
|
def result_as_list(self, **kwargs):
|
||||||
|
g = GraphBuilder()
|
||||||
|
values = []
|
||||||
|
for i in range(4):
|
||||||
|
key = f"value{i+1}"
|
||||||
|
if key in kwargs:
|
||||||
|
values.append(kwargs[key])
|
||||||
|
white = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||||
|
if len(values) == 1:
|
||||||
|
image1 = g.node("StubConstantImage", value=values[0], height=512, width=512, batch_size=1)
|
||||||
|
return {
|
||||||
|
"result": ([image1.out(0)], white.out(0)),
|
||||||
|
"expand": g.finalize(),
|
||||||
|
}
|
||||||
|
elif len(values) == 2:
|
||||||
|
image1 = g.node("StubConstantImage", value=values[0], height=512, width=512, batch_size=1)
|
||||||
|
image2 = g.node("StubConstantImage", value=values[1], height=512, width=512, batch_size=1)
|
||||||
|
return {
|
||||||
|
"result": ([image1.out(0), image2.out(0)], white.out(0)),
|
||||||
|
"expand": g.finalize(),
|
||||||
|
}
|
||||||
|
elif len(values) == 3:
|
||||||
|
image1 = g.node("StubConstantImage", value=values[0], height=512, width=512, batch_size=1)
|
||||||
|
image2 = g.node("StubConstantImage", value=values[1], height=512, width=512, batch_size=1)
|
||||||
|
image3 = torch.ones(1, 512, 512, 3) * values[2]
|
||||||
|
return {
|
||||||
|
"result": ([image1.out(0), image2.out(0), image3], white.out(0)),
|
||||||
|
"expand": g.finalize(),
|
||||||
|
}
|
||||||
|
elif len(values) == 4:
|
||||||
|
list_out = g.node("TestMakeListNode")
|
||||||
|
for i, value in enumerate(values):
|
||||||
|
image = g.node("StubConstantImage", value=value, height=512, width=512, batch_size=1)
|
||||||
|
list_out.set_input(f"value{i+1}", image.out(0))
|
||||||
|
return {
|
||||||
|
"result": ([list_out.out(0)], white.out(0)),
|
||||||
|
"expand": g.finalize(),
|
||||||
|
}
|
||||||
|
|
||||||
class TestSamplingInExpansion:
|
class TestSamplingInExpansion:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def INPUT_TYPES(cls):
|
||||||
@ -494,6 +553,7 @@ TEST_NODE_CLASS_MAPPINGS = {
|
|||||||
"TestCustomValidation5": TestCustomValidation5,
|
"TestCustomValidation5": TestCustomValidation5,
|
||||||
"TestDynamicDependencyCycle": TestDynamicDependencyCycle,
|
"TestDynamicDependencyCycle": TestDynamicDependencyCycle,
|
||||||
"TestMixedExpansionReturns": TestMixedExpansionReturns,
|
"TestMixedExpansionReturns": TestMixedExpansionReturns,
|
||||||
|
"TestListExpansionResult": TestListExpansionResult,
|
||||||
"TestSamplingInExpansion": TestSamplingInExpansion,
|
"TestSamplingInExpansion": TestSamplingInExpansion,
|
||||||
"TestSleep": TestSleep,
|
"TestSleep": TestSleep,
|
||||||
"TestParallelSleep": TestParallelSleep,
|
"TestParallelSleep": TestParallelSleep,
|
||||||
@ -512,6 +572,7 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"TestCustomValidation5": "Custom Validation 5",
|
"TestCustomValidation5": "Custom Validation 5",
|
||||||
"TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
|
"TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
|
||||||
"TestMixedExpansionReturns": "Mixed Expansion Returns",
|
"TestMixedExpansionReturns": "Mixed Expansion Returns",
|
||||||
|
"TestListExpansionResult": "Output is List Expansion Result",
|
||||||
"TestSamplingInExpansion": "Sampling In Expansion",
|
"TestSamplingInExpansion": "Sampling In Expansion",
|
||||||
"TestSleep": "Test Sleep",
|
"TestSleep": "Test Sleep",
|
||||||
"TestParallelSleep": "Test Parallel Sleep",
|
"TestParallelSleep": "Test Parallel Sleep",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user