This commit is contained in:
Alexander Piskun 2026-05-14 20:56:22 +03:00 committed by GitHub
commit 7ce06b48fa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 334 additions and 5 deletions

View File

@ -327,11 +327,14 @@ class String(ComfyTypeIO):
'''String input.''' '''String input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None, multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None,
min_length: int=None, max_length: int=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.multiline = multiline self.multiline = multiline
self.placeholder = placeholder self.placeholder = placeholder
self.dynamic_prompts = dynamic_prompts self.dynamic_prompts = dynamic_prompts
self.min_length = min_length
self.max_length = max_length
self.default: str self.default: str
def as_dict(self): def as_dict(self):
@ -339,6 +342,8 @@ class String(ComfyTypeIO):
"multiline": self.multiline, "multiline": self.multiline,
"placeholder": self.placeholder, "placeholder": self.placeholder,
"dynamicPrompts": self.dynamic_prompts, "dynamicPrompts": self.dynamic_prompts,
"minLength": self.min_length,
"maxLength": self.max_length,
}) })
@comfytype(io_type="COMBO") @comfytype(io_type="COMBO")
@ -1551,6 +1556,12 @@ class Schema:
Use this for nodes with interactive/operable UI regions that produce intermediate outputs Use this for nodes with interactive/operable UI regions that produce intermediate outputs
(e.g., Image Crop, Painter) rather than final outputs (e.g., Save Image). (e.g., Image Crop, Painter) rather than final outputs (e.g., Save Image).
""" """
runtime_input_validation: bool = False
"""Opt this node into runtime validation of declared input bounds (STRING minLength/maxLength,
INT/FLOAT min/max, COMBO membership) against resolved values, including values that arrive via links.
When False, only direct widget values are validated pre-execution and linked values flow through unchecked.
"""
def validate(self): def validate(self):
'''Validate the schema: '''Validate the schema:
@ -2006,6 +2017,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls.GET_SCHEMA() cls.GET_SCHEMA()
return cls._ACCEPT_ALL_INPUTS return cls._ACCEPT_ALL_INPUTS
_RUNTIME_INPUT_VALIDATION = None
@final
@classproperty
def RUNTIME_INPUT_VALIDATION(cls): # noqa
if cls._RUNTIME_INPUT_VALIDATION is None:
cls.GET_SCHEMA()
return cls._RUNTIME_INPUT_VALIDATION
@final @final
@classmethod @classmethod
def INPUT_TYPES(cls) -> dict[str, dict]: def INPUT_TYPES(cls) -> dict[str, dict]:
@ -2050,6 +2069,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls._NOT_IDEMPOTENT = schema.not_idempotent cls._NOT_IDEMPOTENT = schema.not_idempotent
if cls._ACCEPT_ALL_INPUTS is None: if cls._ACCEPT_ALL_INPUTS is None:
cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs
if cls._RUNTIME_INPUT_VALIDATION is None:
cls._RUNTIME_INPUT_VALIDATION = schema.runtime_input_validation
if cls._RETURN_TYPES is None: if cls._RETURN_TYPES is None:
output = [] output = []

View File

@ -83,7 +83,7 @@ class IsChangedCache:
return self.is_changed[node_id] return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None) input_data_all, _, v3_data, _ = get_input_data(node["inputs"], class_def, node_id, None)
try: try:
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data) is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data)
is_changed = await resolve_map_node_over_list_results(is_changed) is_changed = await resolve_map_node_over_list_results(is_changed)
@ -215,7 +215,52 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
if h[x] == "API_KEY_COMFY_ORG": if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
v3_data["hidden_inputs"] = hidden_inputs_v3 v3_data["hidden_inputs"] = hidden_inputs_v3
return input_data_all, missing_keys, v3_data return input_data_all, missing_keys, v3_data, valid_inputs
def _check_resolved_input_bounds(name, val, input_type, extra_info):
"""Raise ValueError if a single resolved value violates declared bounds."""
if input_type == "STRING":
if not isinstance(val, str):
return
min_length = extra_info.get("minLength")
max_length = extra_info.get("maxLength")
if min_length is not None and len(val) < min_length:
raise ValueError(f"Input '{name}': string length {len(val)} is shorter than minLength of {min_length}")
if max_length is not None and len(val) > max_length:
raise ValueError(f"Input '{name}': string length {len(val)} is longer than maxLength of {max_length}")
elif input_type in ("INT", "FLOAT"):
if isinstance(val, bool) or not isinstance(val, (int, float)):
return
min_v = extra_info.get("min")
max_v = extra_info.get("max")
if min_v is not None and val < min_v:
raise ValueError(f"Input '{name}': value {val} is smaller than min of {min_v}")
if max_v is not None and val > max_v:
raise ValueError(f"Input '{name}': value {val} is bigger than max of {max_v}")
elif isinstance(input_type, list) or input_type == io.Combo.io_type:
combo_options = extra_info.get("options", []) if input_type == io.Combo.io_type else input_type
is_multiselect = extra_info.get("multiselect", False)
if is_multiselect and isinstance(val, list):
invalid_vals = [v for v in val if v not in combo_options]
else:
invalid_vals = [val] if val not in combo_options else []
if invalid_vals:
raise ValueError(f"Input '{name}': value(s) {invalid_vals} not in combo options")
def _validate_resolved_inputs(class_def, input_data_all, valid_inputs):
"""Enforce declared input bounds against resolved values, including values that arrive via links."""
if not getattr(class_def, "RUNTIME_INPUT_VALIDATION", False):
return
for x, values in input_data_all.items():
input_type, _, extra_info = get_input_info(class_def, x, valid_inputs)
if input_type is None or extra_info is None:
continue
for val in values:
if val is None:
continue
_check_resolved_input_bounds(x, val, input_type, extra_info)
map_node_over_list = None #Don't hook this please map_node_over_list = None #Don't hook this please
@ -480,7 +525,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
has_subgraph = False has_subgraph = False
else: else:
get_progress_state().start_progress(unique_id) get_progress_state().start_progress(unique_id)
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) input_data_all, missing_keys, v3_data, valid_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None: if server.client_id is not None:
server.last_node_id = display_node_id server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@ -509,6 +554,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
execution_list.make_input_strong_link(unique_id, i) execution_list.make_input_strong_link(unique_id, i)
return (ExecutionResult.PENDING, None, None) return (ExecutionResult.PENDING, None, None)
_validate_resolved_inputs(class_def, input_data_all, valid_inputs)
def execution_block_cb(block): def execution_block_cb(block):
if block.message is not None: if block.message is not None:
mes = { mes = {
@ -1014,6 +1061,36 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
errors.append(error) errors.append(error)
continue continue
if input_type == "STRING":
min_length = extra_info.get("minLength")
max_length = extra_info.get("maxLength")
if min_length is not None and len(val) < min_length:
error = {
"type": "value_shorter_than_min_length",
"message": f"Value length {len(val)} shorter than min length of {min_length}",
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if max_length is not None and len(val) > max_length:
error = {
"type": "value_longer_than_max_length",
"message": f"Value length {len(val)} longer than max length of {max_length}",
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if isinstance(input_type, list) or input_type == io.Combo.io_type: if isinstance(input_type, list) or input_type == io.Combo.io_type:
if input_type == io.Combo.io_type: if input_type == io.Combo.io_type:
combo_options = extra_info.get("options", []) combo_options = extra_info.get("options", [])
@ -1050,7 +1127,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
continue continue
if len(validate_function_inputs) > 0 or validate_has_kwargs: if len(validate_function_inputs) > 0 or validate_has_kwargs:
input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id) input_data_all, _, v3_data, _ = get_input_data(inputs, obj_class, unique_id)
input_filtered = {} input_filtered = {}
for x in input_data_all: for x in input_data_all:
if x in validate_function_inputs or validate_has_kwargs: if x in validate_function_inputs or validate_has_kwargs:

View File

@ -1011,3 +1011,124 @@ class TestExecution:
"""Test getting a non-existent job returns 404""" """Test getting a non-existent job returns 404"""
job = client.get_job("nonexistent-job-id") job = client.get_job("nonexistent-job-id")
assert job is None, "Non-existent job should return None" assert job is None, "Non-existent job should return None"
@pytest.mark.parametrize("text, expect_error", [
("hello", False), # 5 chars, within [3, 10]
("abc", False), # 3 chars, exact min boundary
("abcdefghij", False), # 10 chars, exact max boundary
("ab", True), # 2 chars, below min
("abcdefghijk", True), # 11 chars, above max
("", True), # 0 chars, below min
])
def test_string_length_widget_validation(self, text, expect_error, client: ComfyClient, builder: GraphBuilder):
"""Test minLength/maxLength validation for direct widget values (validate_inputs path)."""
g = builder
node = g.node("StubStringWithLength", text=text)
g.node("SaveImage", images=node.out(0))
if expect_error:
with pytest.raises(urllib.error.HTTPError) as exc_info:
client.run(g)
assert exc_info.value.code == 400
else:
client.run(g)
@pytest.mark.parametrize("text, expect_error", [
("hello", False), # within bounds
("ab", True), # below min
("abcdefghijk", True), # above max
])
def test_string_length_linked_validation(self, text, expect_error, client: ComfyClient, builder: GraphBuilder):
"""Test minLength/maxLength validation for linked inputs when node opts in via RUNTIME_INPUT_VALIDATION=True."""
g = builder
str_node = g.node("StubStringOutput", value=text)
node = g.node("StubStringWithLength", text=str_node.out(0))
g.node("SaveImage", images=node.out(0))
if expect_error:
try:
client.run(g)
assert False, "Should have raised an error"
except Exception as e:
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
else:
client.run(g)
@pytest.mark.parametrize("text", [
"ab", # below declared minLength
"abcdefghijk", # above declared maxLength
"", # empty
"hello", # within bounds
])
def test_string_length_linked_skipped_without_flag(self, text, client: ComfyClient, builder: GraphBuilder):
"""Without RUNTIME_INPUT_VALIDATION=True, declared bounds must NOT be enforced for linked values.
Preserves V1 behavior: many existing workflows rely on out-of-bounds values passing
through links. Adding declared bounds without the flag must not break them.
"""
g = builder
str_node = g.node("StubStringOutput", value=text)
node = g.node("StubStringWithLengthNoFlag", text=str_node.out(0))
g.node("SaveImage", images=node.out(0))
client.run(g)
@pytest.mark.parametrize("value, expect_error", [
(5, False), # within [1, 10]
(1, False), # exact min boundary
(10, False), # exact max boundary
(0, True), # below min
(11, True), # above max
(-7, True), # well below min
])
def test_int_bounds_linked_validation(self, value, expect_error, client: ComfyClient, builder: GraphBuilder):
"""min/max validation for linked INT inputs when node opts in via RUNTIME_INPUT_VALIDATION=True.
Direct widget INT values are already validated pre-execution. This test exercises the
symmetric runtime path for values arriving through a connection.
"""
g = builder
int_node = g.node("StubInt", value=value)
node = g.node("StubIntWithBounds", value=int_node.out(0))
g.node("SaveImage", images=node.out(0))
if expect_error:
try:
client.run(g)
assert False, "Should have raised an error"
except Exception as e:
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
else:
client.run(g)
@pytest.mark.parametrize("choice, expect_error", [
("RED", False),
("GREEN", False),
("BLUE", False),
("PURPLE", True),
("", True),
("red", True), # case-sensitive
])
def test_combo_membership_linked_validation(self, choice, expect_error, client: ComfyClient, builder: GraphBuilder):
"""COMBO option membership for linked values when node opts in via RUNTIME_INPUT_VALIDATION=True.
StubComboWithOptions declares ``input_types`` in VALIDATE_INPUTS to bypass the engine's
link-type compatibility check, so we can feed a STRING into a COMBO and verify the
runtime membership check fires.
"""
g = builder
str_node = g.node("StubStringOutput", value=choice)
node = g.node("StubComboWithOptions", choice=str_node.out(0))
g.node("SaveImage", images=node.out(0))
if expect_error:
try:
client.run(g)
assert False, "Should have raised an error"
except Exception as e:
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
else:
client.run(g)

View File

@ -113,12 +113,117 @@ class StubFloat:
def stub_float(self, value): def stub_float(self, value):
return (value,) return (value,)
class StubStringOutput:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("STRING", {"default": ""}),
},
}
RETURN_TYPES = ("STRING",)
FUNCTION = "stub_string"
CATEGORY = "Testing/Stub Nodes"
def stub_string(self, value):
return (value,)
class StubStringWithLength:
"""STRING input with declared bounds AND opted in to runtime validation (RUNTIME_INPUT_VALIDATION = True)."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"text": ("STRING", {"default": "hello", "minLength": 3, "maxLength": 10}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_string_with_length"
RUNTIME_INPUT_VALIDATION = True
CATEGORY = "Testing/Stub Nodes"
def stub_string_with_length(self, text):
return (torch.zeros(1, 64, 64, 3),)
class StubStringWithLengthNoFlag:
"""Same bounds as StubStringWithLength but NOT opted in - linked values must flow through unchecked."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"text": ("STRING", {"default": "hello", "minLength": 3, "maxLength": 10}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_string_with_length_no_flag"
CATEGORY = "Testing/Stub Nodes"
def stub_string_with_length_no_flag(self, text):
return (torch.zeros(1, 64, 64, 3),)
class StubIntWithBounds:
"""INT input with min/max bounds AND opted in to runtime validation."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("INT", {"default": 5, "min": 1, "max": 10}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_int_with_bounds"
RUNTIME_INPUT_VALIDATION = True
CATEGORY = "Testing/Stub Nodes"
def stub_int_with_bounds(self, value):
return (torch.zeros(1, 64, 64, 3),)
class StubComboWithOptions:
"""COMBO input opted in to runtime validation.
Declares ``input_types`` in VALIDATE_INPUTS to bypass the engine's link-type compatibility
check, allowing tests to link a STRING into a COMBO and exercise the runtime membership check.
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"choice": (["RED", "GREEN", "BLUE"],),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_combo"
RUNTIME_INPUT_VALIDATION = True
CATEGORY = "Testing/Stub Nodes"
@classmethod
def VALIDATE_INPUTS(cls, input_types):
return True
def stub_combo(self, choice):
return (torch.zeros(1, 64, 64, 3),)
TEST_STUB_NODE_CLASS_MAPPINGS = { TEST_STUB_NODE_CLASS_MAPPINGS = {
"StubImage": StubImage, "StubImage": StubImage,
"StubConstantImage": StubConstantImage, "StubConstantImage": StubConstantImage,
"StubMask": StubMask, "StubMask": StubMask,
"StubInt": StubInt, "StubInt": StubInt,
"StubFloat": StubFloat, "StubFloat": StubFloat,
"StubStringOutput": StubStringOutput,
"StubStringWithLength": StubStringWithLength,
"StubStringWithLengthNoFlag": StubStringWithLengthNoFlag,
"StubIntWithBounds": StubIntWithBounds,
"StubComboWithOptions": StubComboWithOptions,
} }
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = { TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
"StubImage": "Stub Image", "StubImage": "Stub Image",
@ -126,4 +231,9 @@ TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
"StubMask": "Stub Mask", "StubMask": "Stub Mask",
"StubInt": "Stub Int", "StubInt": "Stub Int",
"StubFloat": "Stub Float", "StubFloat": "Stub Float",
"StubStringOutput": "Stub String Output",
"StubStringWithLength": "Stub String With Length",
"StubStringWithLengthNoFlag": "Stub String With Length (No Flag)",
"StubIntWithBounds": "Stub Int With Bounds",
"StubComboWithOptions": "Stub Combo With Options",
} }