mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-27 23:00:20 +08:00
Pass None for uneval inputs instead of default
This applies to `VALIDATE_INPUTS`, `check_lazy_status`, and lazy values in evaluation functions.
This commit is contained in:
parent
ecbef304ed
commit
1f06588905
19
execution.py
19
execution.py
@ -41,7 +41,7 @@ class IsChangedCache:
|
|||||||
if "is_changed" in node:
|
if "is_changed" in node:
|
||||||
self.is_changed[node_id] = node["is_changed"]
|
self.is_changed[node_id] = node["is_changed"]
|
||||||
else:
|
else:
|
||||||
input_data_all = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
|
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
|
||||||
try:
|
try:
|
||||||
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||||
@ -84,18 +84,25 @@ class CacheSet:
|
|||||||
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
|
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
|
||||||
valid_inputs = class_def.INPUT_TYPES()
|
valid_inputs = class_def.INPUT_TYPES()
|
||||||
input_data_all = {}
|
input_data_all = {}
|
||||||
|
missing_keys = {}
|
||||||
for x in inputs:
|
for x in inputs:
|
||||||
input_data = inputs[x]
|
input_data = inputs[x]
|
||||||
input_type, input_category, input_info = get_input_info(class_def, x)
|
input_type, input_category, input_info = get_input_info(class_def, x)
|
||||||
|
def mark_missing():
|
||||||
|
missing_keys[x] = True
|
||||||
|
input_data_all[x] = (None,)
|
||||||
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
|
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
|
||||||
input_unique_id = input_data[0]
|
input_unique_id = input_data[0]
|
||||||
output_index = input_data[1]
|
output_index = input_data[1]
|
||||||
if outputs is None:
|
if outputs is None:
|
||||||
|
mark_missing()
|
||||||
continue # This might be a lazily-evaluated input
|
continue # This might be a lazily-evaluated input
|
||||||
cached_output = outputs.get(input_unique_id)
|
cached_output = outputs.get(input_unique_id)
|
||||||
if cached_output is None:
|
if cached_output is None:
|
||||||
|
mark_missing()
|
||||||
continue
|
continue
|
||||||
if output_index >= len(cached_output):
|
if output_index >= len(cached_output):
|
||||||
|
mark_missing()
|
||||||
continue
|
continue
|
||||||
obj = cached_output[output_index]
|
obj = cached_output[output_index]
|
||||||
input_data_all[x] = obj
|
input_data_all[x] = obj
|
||||||
@ -113,7 +120,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
|||||||
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
|
input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
|
||||||
if h[x] == "UNIQUE_ID":
|
if h[x] == "UNIQUE_ID":
|
||||||
input_data_all[x] = [unique_id]
|
input_data_all[x] = [unique_id]
|
||||||
return input_data_all
|
return input_data_all, missing_keys
|
||||||
|
|
||||||
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||||
# check if node wants the lists
|
# check if node wants the lists
|
||||||
@ -275,7 +282,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
|||||||
output_ui = []
|
output_ui = []
|
||||||
has_subgraph = False
|
has_subgraph = False
|
||||||
else:
|
else:
|
||||||
input_data_all = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
server.last_node_id = display_node_id
|
server.last_node_id = display_node_id
|
||||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||||
@ -288,7 +295,9 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
|
|||||||
if hasattr(obj, "check_lazy_status"):
|
if hasattr(obj, "check_lazy_status"):
|
||||||
required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
|
||||||
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
|
||||||
required_inputs = [x for x in required_inputs if isinstance(x,str) and x not in input_data_all]
|
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
|
||||||
|
x not in input_data_all or x in missing_keys
|
||||||
|
)]
|
||||||
if len(required_inputs) > 0:
|
if len(required_inputs) > 0:
|
||||||
for i in required_inputs:
|
for i in required_inputs:
|
||||||
execution_list.make_input_strong_link(unique_id, i)
|
execution_list.make_input_strong_link(unique_id, i)
|
||||||
@ -685,7 +694,7 @@ def validate_inputs(prompt, item, validated):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if len(validate_function_inputs) > 0:
|
if len(validate_function_inputs) > 0:
|
||||||
input_data_all = get_input_data(inputs, obj_class, unique_id)
|
input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
|
||||||
input_filtered = {}
|
input_filtered = {}
|
||||||
for x in input_data_all:
|
for x in input_data_all:
|
||||||
if x in validate_function_inputs:
|
if x in validate_function_inputs:
|
||||||
|
|||||||
@ -299,6 +299,22 @@ class TestExecution:
|
|||||||
else:
|
else:
|
||||||
client.run(g)
|
client.run(g)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("test_type, test_value, expect_error", [
|
||||||
|
("StubInt", 5, True),
|
||||||
|
("StubFloat", 5.0, False)
|
||||||
|
])
|
||||||
|
def test_validation_error_edge4(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
|
||||||
|
g = builder
|
||||||
|
stub = g.node(test_type, value=test_value)
|
||||||
|
validation4 = g.node("TestCustomValidation4", input1=stub.out(0), input2=3.0)
|
||||||
|
g.node("SaveImage", images=validation4.out(0))
|
||||||
|
|
||||||
|
if expect_error:
|
||||||
|
with pytest.raises(urllib.error.HTTPError):
|
||||||
|
client.run(g)
|
||||||
|
else:
|
||||||
|
client.run(g)
|
||||||
|
|
||||||
def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
|
def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
|
||||||
g = builder
|
g = builder
|
||||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||||
|
|||||||
@ -18,7 +18,7 @@ class TestLazyMixImages:
|
|||||||
|
|
||||||
CATEGORY = "Testing/Nodes"
|
CATEGORY = "Testing/Nodes"
|
||||||
|
|
||||||
def check_lazy_status(self, mask, image1 = None, image2 = None):
|
def check_lazy_status(self, mask, image1, image2):
|
||||||
mask_min = mask.min()
|
mask_min = mask.min()
|
||||||
mask_max = mask.max()
|
mask_max = mask.max()
|
||||||
needed = []
|
needed = []
|
||||||
@ -29,7 +29,7 @@ class TestLazyMixImages:
|
|||||||
return needed
|
return needed
|
||||||
|
|
||||||
# Not trying to handle different batch sizes here just to keep the demo simple
|
# Not trying to handle different batch sizes here just to keep the demo simple
|
||||||
def mix(self, mask, image1 = None, image2 = None):
|
def mix(self, mask, image1, image2):
|
||||||
mask_min = mask.min()
|
mask_min = mask.min()
|
||||||
mask_max = mask.max()
|
mask_max = mask.max()
|
||||||
if mask_min == 0.0 and mask_max == 0.0:
|
if mask_min == 0.0 and mask_max == 0.0:
|
||||||
@ -45,7 +45,6 @@ class TestLazyMixImages:
|
|||||||
mask = mask.repeat(1, 1, 1, image1.shape[3])
|
mask = mask.repeat(1, 1, 1, image1.shape[3])
|
||||||
|
|
||||||
result = image1 * (1. - mask) + image2 * mask,
|
result = image1 * (1. - mask) + image2 * mask,
|
||||||
print(result[0])
|
|
||||||
return (result[0],)
|
return (result[0],)
|
||||||
|
|
||||||
class TestVariadicAverage:
|
class TestVariadicAverage:
|
||||||
@ -192,6 +191,36 @@ class TestCustomValidation3:
|
|||||||
result = input1 * input2
|
result = input1 * input2
|
||||||
return (result,)
|
return (result,)
|
||||||
|
|
||||||
|
class TestCustomValidation4:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls):
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"input1": ("FLOAT",),
|
||||||
|
"input2": ("FLOAT",),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
FUNCTION = "custom_validation4"
|
||||||
|
|
||||||
|
CATEGORY = "Testing/Nodes"
|
||||||
|
|
||||||
|
def custom_validation4(self, input1, input2):
|
||||||
|
result = torch.ones([1, 512, 512, 3]) * input1 * input2
|
||||||
|
return (result,)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def VALIDATE_INPUTS(cls, input1, input2):
|
||||||
|
if input1 is not None:
|
||||||
|
if not isinstance(input1, float):
|
||||||
|
return f"Invalid type of input1: {type(input1)}"
|
||||||
|
if input2 is not None:
|
||||||
|
if not isinstance(input2, float):
|
||||||
|
return f"Invalid type of input2: {type(input2)}"
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
class TestDynamicDependencyCycle:
|
class TestDynamicDependencyCycle:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def INPUT_TYPES(cls):
|
||||||
@ -228,6 +257,7 @@ TEST_NODE_CLASS_MAPPINGS = {
|
|||||||
"TestCustomValidation1": TestCustomValidation1,
|
"TestCustomValidation1": TestCustomValidation1,
|
||||||
"TestCustomValidation2": TestCustomValidation2,
|
"TestCustomValidation2": TestCustomValidation2,
|
||||||
"TestCustomValidation3": TestCustomValidation3,
|
"TestCustomValidation3": TestCustomValidation3,
|
||||||
|
"TestCustomValidation4": TestCustomValidation4,
|
||||||
"TestDynamicDependencyCycle": TestDynamicDependencyCycle,
|
"TestDynamicDependencyCycle": TestDynamicDependencyCycle,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -238,5 +268,6 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"TestCustomValidation1": "Custom Validation 1",
|
"TestCustomValidation1": "Custom Validation 1",
|
||||||
"TestCustomValidation2": "Custom Validation 2",
|
"TestCustomValidation2": "Custom Validation 2",
|
||||||
"TestCustomValidation3": "Custom Validation 3",
|
"TestCustomValidation3": "Custom Validation 3",
|
||||||
|
"TestCustomValidation4": "Custom Validation 4",
|
||||||
"TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
|
"TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user