diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 5ed968960..246bcc2c8 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -327,11 +327,14 @@ class String(ComfyTypeIO): '''String input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None, - socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None, + min_length: int=None, max_length: int=None): super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.multiline = multiline self.placeholder = placeholder self.dynamic_prompts = dynamic_prompts + self.min_length = min_length + self.max_length = max_length self.default: str def as_dict(self): @@ -339,6 +342,8 @@ class String(ComfyTypeIO): "multiline": self.multiline, "placeholder": self.placeholder, "dynamicPrompts": self.dynamic_prompts, + "minLength": self.min_length, + "maxLength": self.max_length, }) @comfytype(io_type="COMBO") @@ -1551,6 +1556,12 @@ class Schema: Use this for nodes with interactive/operable UI regions that produce intermediate outputs (e.g., Image Crop, Painter) rather than final outputs (e.g., Save Image). """ + runtime_input_validation: bool = False + """Opt this node into runtime validation of declared input bounds (STRING minLength/maxLength, + INT/FLOAT min/max, COMBO membership) against resolved values, including values that arrive via links. + + When False, only direct widget values are validated pre-execution and linked values flow through unchecked. + """ def validate(self): '''Validate the schema: @@ -2006,6 +2017,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): cls.GET_SCHEMA() return cls._ACCEPT_ALL_INPUTS + _RUNTIME_INPUT_VALIDATION = None + @final + @classproperty + def RUNTIME_INPUT_VALIDATION(cls): # noqa + if cls._RUNTIME_INPUT_VALIDATION is None: + cls.GET_SCHEMA() + return cls._RUNTIME_INPUT_VALIDATION + @final @classmethod def INPUT_TYPES(cls) -> dict[str, dict]: @@ -2050,6 +2069,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): cls._NOT_IDEMPOTENT = schema.not_idempotent if cls._ACCEPT_ALL_INPUTS is None: cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs + if cls._RUNTIME_INPUT_VALIDATION is None: + cls._RUNTIME_INPUT_VALIDATION = schema.runtime_input_validation if cls._RETURN_TYPES is None: output = [] diff --git a/execution.py b/execution.py index f37d0360d..2c54392d9 100644 --- a/execution.py +++ b/execution.py @@ -83,7 +83,7 @@ class IsChangedCache: return self.is_changed[node_id] # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED - input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None) + input_data_all, _, v3_data, _ = get_input_data(node["inputs"], class_def, node_id, None) try: is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data) is_changed = await resolve_map_node_over_list_results(is_changed) @@ -215,7 +215,52 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= if h[x] == "API_KEY_COMFY_ORG": input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] v3_data["hidden_inputs"] = hidden_inputs_v3 - return input_data_all, missing_keys, v3_data + return input_data_all, missing_keys, v3_data, valid_inputs + +def _check_resolved_input_bounds(name, val, input_type, extra_info): + """Raise ValueError if a single resolved value violates declared bounds.""" + if input_type == "STRING": + if not isinstance(val, str): + return + min_length = extra_info.get("minLength") + max_length = extra_info.get("maxLength") + if min_length is not None and len(val) < min_length: + raise ValueError(f"Input '{name}': string length {len(val)} is shorter than minLength of {min_length}") + if max_length is not None and len(val) > max_length: + raise ValueError(f"Input '{name}': string length {len(val)} is longer than maxLength of {max_length}") + elif input_type in ("INT", "FLOAT"): + if isinstance(val, bool) or not isinstance(val, (int, float)): + return + min_v = extra_info.get("min") + max_v = extra_info.get("max") + if min_v is not None and val < min_v: + raise ValueError(f"Input '{name}': value {val} is smaller than min of {min_v}") + if max_v is not None and val > max_v: + raise ValueError(f"Input '{name}': value {val} is bigger than max of {max_v}") + elif isinstance(input_type, list) or input_type == io.Combo.io_type: + combo_options = extra_info.get("options", []) if input_type == io.Combo.io_type else input_type + is_multiselect = extra_info.get("multiselect", False) + if is_multiselect and isinstance(val, list): + invalid_vals = [v for v in val if v not in combo_options] + else: + invalid_vals = [val] if val not in combo_options else [] + if invalid_vals: + raise ValueError(f"Input '{name}': value(s) {invalid_vals} not in combo options") + + +def _validate_resolved_inputs(class_def, input_data_all, valid_inputs): + """Enforce declared input bounds against resolved values, including values that arrive via links.""" + if not getattr(class_def, "RUNTIME_INPUT_VALIDATION", False): + return + + for x, values in input_data_all.items(): + input_type, _, extra_info = get_input_info(class_def, x, valid_inputs) + if input_type is None or extra_info is None: + continue + for val in values: + if val is None: + continue + _check_resolved_input_bounds(x, val, input_type, extra_info) map_node_over_list = None #Don't hook this please @@ -480,7 +525,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) - input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) + input_data_all, missing_keys, v3_data, valid_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) @@ -509,6 +554,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, execution_list.make_input_strong_link(unique_id, i) return (ExecutionResult.PENDING, None, None) + _validate_resolved_inputs(class_def, input_data_all, valid_inputs) + def execution_block_cb(block): if block.message is not None: mes = { @@ -1014,6 +1061,36 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): errors.append(error) continue + if input_type == "STRING": + min_length = extra_info.get("minLength") + max_length = extra_info.get("maxLength") + if min_length is not None and len(val) < min_length: + error = { + "type": "value_shorter_than_min_length", + "message": f"Value length {len(val)} shorter than min length of {min_length}", + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue + if max_length is not None and len(val) > max_length: + error = { + "type": "value_longer_than_max_length", + "message": f"Value length {len(val)} longer than max length of {max_length}", + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue + if isinstance(input_type, list) or input_type == io.Combo.io_type: if input_type == io.Combo.io_type: combo_options = extra_info.get("options", []) @@ -1050,7 +1127,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): continue if len(validate_function_inputs) > 0 or validate_has_kwargs: - input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id) + input_data_all, _, v3_data, _ = get_input_data(inputs, obj_class, unique_id) input_filtered = {} for x in input_data_all: if x in validate_function_inputs or validate_has_kwargs: diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index f73ca7e3c..cf87ca65a 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -1011,3 +1011,124 @@ class TestExecution: """Test getting a non-existent job returns 404""" job = client.get_job("nonexistent-job-id") assert job is None, "Non-existent job should return None" + + + @pytest.mark.parametrize("text, expect_error", [ + ("hello", False), # 5 chars, within [3, 10] + ("abc", False), # 3 chars, exact min boundary + ("abcdefghij", False), # 10 chars, exact max boundary + ("ab", True), # 2 chars, below min + ("abcdefghijk", True), # 11 chars, above max + ("", True), # 0 chars, below min + ]) + def test_string_length_widget_validation(self, text, expect_error, client: ComfyClient, builder: GraphBuilder): + """Test minLength/maxLength validation for direct widget values (validate_inputs path).""" + g = builder + node = g.node("StubStringWithLength", text=text) + g.node("SaveImage", images=node.out(0)) + if expect_error: + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.run(g) + assert exc_info.value.code == 400 + else: + client.run(g) + + + @pytest.mark.parametrize("text, expect_error", [ + ("hello", False), # within bounds + ("ab", True), # below min + ("abcdefghijk", True), # above max + ]) + def test_string_length_linked_validation(self, text, expect_error, client: ComfyClient, builder: GraphBuilder): + """Test minLength/maxLength validation for linked inputs when node opts in via RUNTIME_INPUT_VALIDATION=True.""" + g = builder + str_node = g.node("StubStringOutput", value=text) + node = g.node("StubStringWithLength", text=str_node.out(0)) + g.node("SaveImage", images=node.out(0)) + + if expect_error: + try: + client.run(g) + assert False, "Should have raised an error" + except Exception as e: + assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}" + else: + client.run(g) + + + @pytest.mark.parametrize("text", [ + "ab", # below declared minLength + "abcdefghijk", # above declared maxLength + "", # empty + "hello", # within bounds + ]) + def test_string_length_linked_skipped_without_flag(self, text, client: ComfyClient, builder: GraphBuilder): + """Without RUNTIME_INPUT_VALIDATION=True, declared bounds must NOT be enforced for linked values. + + Preserves V1 behavior: many existing workflows rely on out-of-bounds values passing + through links. Adding declared bounds without the flag must not break them. + """ + g = builder + str_node = g.node("StubStringOutput", value=text) + node = g.node("StubStringWithLengthNoFlag", text=str_node.out(0)) + g.node("SaveImage", images=node.out(0)) + client.run(g) + + + @pytest.mark.parametrize("value, expect_error", [ + (5, False), # within [1, 10] + (1, False), # exact min boundary + (10, False), # exact max boundary + (0, True), # below min + (11, True), # above max + (-7, True), # well below min + ]) + def test_int_bounds_linked_validation(self, value, expect_error, client: ComfyClient, builder: GraphBuilder): + """min/max validation for linked INT inputs when node opts in via RUNTIME_INPUT_VALIDATION=True. + + Direct widget INT values are already validated pre-execution. This test exercises the + symmetric runtime path for values arriving through a connection. + """ + g = builder + int_node = g.node("StubInt", value=value) + node = g.node("StubIntWithBounds", value=int_node.out(0)) + g.node("SaveImage", images=node.out(0)) + + if expect_error: + try: + client.run(g) + assert False, "Should have raised an error" + except Exception as e: + assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}" + else: + client.run(g) + + + @pytest.mark.parametrize("choice, expect_error", [ + ("RED", False), + ("GREEN", False), + ("BLUE", False), + ("PURPLE", True), + ("", True), + ("red", True), # case-sensitive + ]) + def test_combo_membership_linked_validation(self, choice, expect_error, client: ComfyClient, builder: GraphBuilder): + """COMBO option membership for linked values when node opts in via RUNTIME_INPUT_VALIDATION=True. + + StubComboWithOptions declares ``input_types`` in VALIDATE_INPUTS to bypass the engine's + link-type compatibility check, so we can feed a STRING into a COMBO and verify the + runtime membership check fires. + """ + g = builder + str_node = g.node("StubStringOutput", value=choice) + node = g.node("StubComboWithOptions", choice=str_node.out(0)) + g.node("SaveImage", images=node.out(0)) + + if expect_error: + try: + client.run(g) + assert False, "Should have raised an error" + except Exception as e: + assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}" + else: + client.run(g) diff --git a/tests/execution/testing_nodes/testing-pack/stubs.py b/tests/execution/testing_nodes/testing-pack/stubs.py index a1df87529..6ffb71948 100644 --- a/tests/execution/testing_nodes/testing-pack/stubs.py +++ b/tests/execution/testing_nodes/testing-pack/stubs.py @@ -113,12 +113,117 @@ class StubFloat: def stub_float(self, value): return (value,) +class StubStringOutput: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("STRING", {"default": ""}), + }, + } + + RETURN_TYPES = ("STRING",) + FUNCTION = "stub_string" + + CATEGORY = "Testing/Stub Nodes" + + def stub_string(self, value): + return (value,) + +class StubStringWithLength: + """STRING input with declared bounds AND opted in to runtime validation (RUNTIME_INPUT_VALIDATION = True).""" + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "text": ("STRING", {"default": "hello", "minLength": 3, "maxLength": 10}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "stub_string_with_length" + RUNTIME_INPUT_VALIDATION = True + + CATEGORY = "Testing/Stub Nodes" + + def stub_string_with_length(self, text): + return (torch.zeros(1, 64, 64, 3),) + +class StubStringWithLengthNoFlag: + """Same bounds as StubStringWithLength but NOT opted in - linked values must flow through unchecked.""" + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "text": ("STRING", {"default": "hello", "minLength": 3, "maxLength": 10}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "stub_string_with_length_no_flag" + + CATEGORY = "Testing/Stub Nodes" + + def stub_string_with_length_no_flag(self, text): + return (torch.zeros(1, 64, 64, 3),) + +class StubIntWithBounds: + """INT input with min/max bounds AND opted in to runtime validation.""" + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("INT", {"default": 5, "min": 1, "max": 10}), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "stub_int_with_bounds" + RUNTIME_INPUT_VALIDATION = True + + CATEGORY = "Testing/Stub Nodes" + + def stub_int_with_bounds(self, value): + return (torch.zeros(1, 64, 64, 3),) + +class StubComboWithOptions: + """COMBO input opted in to runtime validation. + + Declares ``input_types`` in VALIDATE_INPUTS to bypass the engine's link-type compatibility + check, allowing tests to link a STRING into a COMBO and exercise the runtime membership check. + """ + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "choice": (["RED", "GREEN", "BLUE"],), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "stub_combo" + RUNTIME_INPUT_VALIDATION = True + + CATEGORY = "Testing/Stub Nodes" + + @classmethod + def VALIDATE_INPUTS(cls, input_types): + return True + + def stub_combo(self, choice): + return (torch.zeros(1, 64, 64, 3),) + TEST_STUB_NODE_CLASS_MAPPINGS = { "StubImage": StubImage, "StubConstantImage": StubConstantImage, "StubMask": StubMask, "StubInt": StubInt, "StubFloat": StubFloat, + "StubStringOutput": StubStringOutput, + "StubStringWithLength": StubStringWithLength, + "StubStringWithLengthNoFlag": StubStringWithLengthNoFlag, + "StubIntWithBounds": StubIntWithBounds, + "StubComboWithOptions": StubComboWithOptions, } TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = { "StubImage": "Stub Image", @@ -126,4 +231,9 @@ TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = { "StubMask": "Stub Mask", "StubInt": "Stub Int", "StubFloat": "Stub Float", + "StubStringOutput": "Stub String Output", + "StubStringWithLength": "Stub String With Length", + "StubStringWithLengthNoFlag": "Stub String With Length (No Flag)", + "StubIntWithBounds": "Stub Int With Bounds", + "StubComboWithOptions": "Stub Combo With Options", }