Pass None for uneval inputs instead of default

This applies to `VALIDATE_INPUTS`, `check_lazy_status`, and lazy values
in evaluation functions.
commit 1f06588905 (parent ecbef304ed)
Author: Jacob Segal
Date:   2024-04-21 00:10:04 -07:00

3 changed files with 64 additions and 8 deletions
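In practice, the change is about what a node's Python callables receive: previously, an input whose upstream value had not been evaluated yet was simply omitted from the argument dict, so callables needed `= None` defaults to avoid a `TypeError`; now the executor passes an explicit `None`. A schematic before/after (illustrative signatures only, mirroring the `TestLazyMixImages` changes below):

```python
# Before: unevaluated inputs were omitted, so defaults were mandatory.
def check_lazy_status(self, mask, image1=None, image2=None):
    ...

# After: None is passed explicitly for every declared input, so the
# defaults can be dropped.
def check_lazy_status(self, mask, image1, image2):
    ...
```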

File 1 of 3:

@@ -41,7 +41,7 @@ class IsChangedCache:
         if "is_changed" in node:
             self.is_changed[node_id] = node["is_changed"]
         else:
-            input_data_all = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
+            input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
             try:
                 is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
                 node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
@@ -84,18 +84,25 @@ class CacheSet:
 def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
     valid_inputs = class_def.INPUT_TYPES()
     input_data_all = {}
+    missing_keys = {}
     for x in inputs:
         input_data = inputs[x]
         input_type, input_category, input_info = get_input_info(class_def, x)
+        def mark_missing():
+            missing_keys[x] = True
+            input_data_all[x] = (None,)
         if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
             input_unique_id = input_data[0]
             output_index = input_data[1]
             if outputs is None:
+                mark_missing()
                 continue # This might be a lazily-evaluated input
             cached_output = outputs.get(input_unique_id)
             if cached_output is None:
+                mark_missing()
                 continue
             if output_index >= len(cached_output):
+                mark_missing()
                 continue
             obj = cached_output[output_index]
             input_data_all[x] = obj
@@ -113,7 +120,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
             input_data_all[x] = [extra_data.get('extra_pnginfo', None)]
         if h[x] == "UNIQUE_ID":
             input_data_all[x] = [unique_id]
-    return input_data_all
+    return input_data_all, missing_keys
 
 def map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
     # check if node wants the lists
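Taken together, the two hunks above give `get_input_data` a new contract: every requested input produces an entry in `input_data_all`, with unresolved links filled in as `(None,)` placeholders and reported through the second return value. A minimal standalone sketch of that contract (a toy resolver; the real function also handles hidden inputs, `rawLink`, and dynamic prompts):

```python
def get_input_data_sketch(inputs, resolved_outputs):
    """Toy version of the new contract: one entry per input, plus a
    dict of the keys that could not be resolved yet."""
    input_data_all = {}
    missing_keys = {}
    for name, link in inputs.items():
        value = resolved_outputs.get(link)
        if value is None:
            # Same role as mark_missing() above: remember the gap, but
            # still emit a (None,) placeholder so the dict shape is uniform.
            missing_keys[name] = True
            input_data_all[name] = (None,)
        else:
            input_data_all[name] = value
    return input_data_all, missing_keys

# One resolved input, one still pending:
data, missing = get_input_data_sketch(
    {"image1": "node_a", "image2": "node_b"},
    {"node_a": ("pixels",)},  # node_b has not executed yet
)
assert data["image2"] == (None,)
assert "image2" in missing
```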
@@ -275,7 +282,7 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
             output_ui = []
             has_subgraph = False
         else:
-            input_data_all = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
+            input_data_all, missing_keys = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
             if server.client_id is not None:
                 server.last_node_id = display_node_id
                 server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@@ -288,7 +295,9 @@ def execute(server, dynprompt, caches, current_item, extra_data, executed, promp
             if hasattr(obj, "check_lazy_status"):
                 required_inputs = map_node_over_list(obj, input_data_all, "check_lazy_status", allow_interrupt=True)
                 required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
-                required_inputs = [x for x in required_inputs if isinstance(x,str) and x not in input_data_all]
+                required_inputs = [x for x in required_inputs if isinstance(x,str) and (
+                    x not in input_data_all or x in missing_keys
+                )]
                 if len(required_inputs) > 0:
                     for i in required_inputs:
                         execution_list.make_input_strong_link(unique_id, i)
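The widened condition is the crux of the commit: missing lazy inputs are now present in `input_data_all` (as `(None,)` placeholders), so the old `x not in input_data_all` test alone would treat them as already available and never schedule their upstream nodes. A small illustration of the difference:

```python
requested = ["image1"]                                     # from check_lazy_status
input_data_all = {"mask": ("pixels",), "image1": (None,)}  # placeholder present
missing_keys = {"image1": True}

old = [x for x in requested if x not in input_data_all]
assert old == []  # bug: image1 would never be evaluated

new = [x for x in requested
       if x not in input_data_all or x in missing_keys]
assert new == ["image1"]  # correctly re-requested
```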
@@ -685,7 +694,7 @@ def validate_inputs(prompt, item, validated):
                 continue
 
     if len(validate_function_inputs) > 0:
-        input_data_all = get_input_data(inputs, obj_class, unique_id)
+        input_data_all, _ = get_input_data(inputs, obj_class, unique_id)
         input_filtered = {}
         for x in input_data_all:
             if x in validate_function_inputs:
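A consequence for this path is that a custom `VALIDATE_INPUTS` can now be handed `None` where it previously saw nothing at all, so value checks should be guarded with a `None` test. Condensed, the idiom looks like this (`ExampleNode` is a hypothetical node; the new `TestCustomValidation4` below spells the pattern out in full):

```python
class ExampleNode:
    @classmethod
    def VALIDATE_INPUTS(cls, input1, input2):
        # None means "not evaluated yet", not "invalid value".
        if input1 is not None and not isinstance(input1, float):
            return f"Invalid type of input1: {type(input1)}"
        return True
```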

File 2 of 3:

@@ -299,6 +299,22 @@ class TestExecution:
         else:
             client.run(g)
 
+    @pytest.mark.parametrize("test_type, test_value, expect_error", [
+        ("StubInt", 5, True),
+        ("StubFloat", 5.0, False)
+    ])
+    def test_validation_error_edge4(self, test_type, test_value, expect_error, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        stub = g.node(test_type, value=test_value)
+        validation4 = g.node("TestCustomValidation4", input1=stub.out(0), input2=3.0)
+        g.node("SaveImage", images=validation4.out(0))
+
+        if expect_error:
+            with pytest.raises(urllib.error.HTTPError):
+                client.run(g)
+        else:
+            client.run(g)
+
     def test_cycle_error(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

File 3 of 3:

@@ -18,7 +18,7 @@ class TestLazyMixImages:
 
     CATEGORY = "Testing/Nodes"
 
-    def check_lazy_status(self, mask, image1 = None, image2 = None):
+    def check_lazy_status(self, mask, image1, image2):
         mask_min = mask.min()
         mask_max = mask.max()
         needed = []
@@ -29,7 +29,7 @@ class TestLazyMixImages:
         return needed
 
     # Not trying to handle different batch sizes here just to keep the demo simple
-    def mix(self, mask, image1 = None, image2 = None):
+    def mix(self, mask, image1, image2):
         mask_min = mask.min()
         mask_max = mask.max()
         if mask_min == 0.0 and mask_max == 0.0:
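With the new convention the executor always supplies every declared argument, so the `= None` defaults above were dead code; removing them also means any caller that still omitted an argument would now fail loudly with a `TypeError`. A hypothetical first-pass call, before either image has been computed (this assumes the testing-pack module is importable, and that the elided body requests whichever images are still `None` and needed by the mask):

```python
import torch

node = TestLazyMixImages()
mask = torch.full([1, 512, 512, 1], 0.5)  # strictly between 0 and 1

# Neither image has been evaluated yet: both arrive as explicit None.
needed = node.check_lazy_status(mask, None, None)
# Expected to be ["image1", "image2"] for this mask.
```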
@@ -45,7 +45,6 @@ class TestLazyMixImages:
         mask = mask.repeat(1, 1, 1, image1.shape[3])
 
         result = image1 * (1. - mask) + image2 * mask,
-        print(result[0])
         return (result[0],)
 
 class TestVariadicAverage:
@@ -192,6 +191,36 @@ class TestCustomValidation3:
         result = input1 * input2
         return (result,)
 
+class TestCustomValidation4:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "input1": ("FLOAT",),
+                "input2": ("FLOAT",),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "custom_validation4"
+
+    CATEGORY = "Testing/Nodes"
+
+    def custom_validation4(self, input1, input2):
+        result = torch.ones([1, 512, 512, 3]) * input1 * input2
+        return (result,)
+
+    @classmethod
+    def VALIDATE_INPUTS(cls, input1, input2):
+        if input1 is not None:
+            if not isinstance(input1, float):
+                return f"Invalid type of input1: {type(input1)}"
+        if input2 is not None:
+            if not isinstance(input2, float):
+                return f"Invalid type of input2: {type(input2)}"
+        return True
+
+
 class TestDynamicDependencyCycle:
     @classmethod
     def INPUT_TYPES(cls):
@@ -228,6 +257,7 @@ TEST_NODE_CLASS_MAPPINGS = {
     "TestCustomValidation1": TestCustomValidation1,
     "TestCustomValidation2": TestCustomValidation2,
     "TestCustomValidation3": TestCustomValidation3,
+    "TestCustomValidation4": TestCustomValidation4,
     "TestDynamicDependencyCycle": TestDynamicDependencyCycle,
 }
 
@@ -238,5 +268,6 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
     "TestCustomValidation1": "Custom Validation 1",
     "TestCustomValidation2": "Custom Validation 2",
     "TestCustomValidation3": "Custom Validation 3",
+    "TestCustomValidation4": "Custom Validation 4",
     "TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
 }
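The parametrized test drives both branches of the new validation path: `StubInt` feeds an `int` into `input1`, which `VALIDATE_INPUTS` rejects with the type message (surfacing client-side as an `HTTPError`), while `StubFloat` passes cleanly. The check can also be exercised directly (assuming the class above is importable):

```python
assert TestCustomValidation4.VALIDATE_INPUTS(5, 3.0) == "Invalid type of input1: <class 'int'>"
assert TestCustomValidation4.VALIDATE_INPUTS(5.0, 3.0) is True
```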