diff --git a/api_server/utils/query_params.py b/api_server/utils/query_params.py new file mode 100644 index 000000000..1f0b6a80e --- /dev/null +++ b/api_server/utils/query_params.py @@ -0,0 +1,12 @@ +from collections.abc import Mapping + + +def parse_optional_int_query_param(query: Mapping[str, str], name: str) -> int | None: + value = query.get(name) + if value is None: + return None + + try: + return int(value) + except (TypeError, ValueError) as exc: + raise ValueError(f"{name} must be an integer") from exc diff --git a/server.py b/server.py index 881da8e66..565fa5dc2 100644 --- a/server.py +++ b/server.py @@ -45,6 +45,7 @@ from app.subgraph_manager import SubgraphManager from app.node_replace_manager import NodeReplaceManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes +from api_server.utils.query_params import parse_optional_int_query_param from protocol import BinaryEventTypes # Import cache control middleware @@ -887,14 +888,15 @@ class PromptServer(): @routes.get("/history") async def get_history(request): - max_items = request.rel_url.query.get("max_items", None) - if max_items is not None: - max_items = int(max_items) + query = request.rel_url.query - offset = request.rel_url.query.get("offset", None) - if offset is not None: - offset = int(offset) - else: + try: + max_items = parse_optional_int_query_param(query, "max_items") + offset = parse_optional_int_query_param(query, "offset") + except ValueError as exc: + return web.json_response({"error": str(exc)}, status=400) + + if offset is None: offset = -1 return web.json_response(self.prompt_queue.get_history(max_items=max_items, offset=offset)) diff --git a/tests-unit/server/utils/query_params_test.py b/tests-unit/server/utils/query_params_test.py new file mode 100644 index 000000000..558c6812b --- /dev/null +++ b/tests-unit/server/utils/query_params_test.py @@ -0,0 +1,39 @@ +import pytest + +from api_server.utils.query_params import parse_optional_int_query_param + + +def test_parse_optional_int_query_param_returns_none_when_missing(): + assert parse_optional_int_query_param({}, "offset") is None + + +@pytest.mark.parametrize( + ("raw_value", "expected"), + [ + ("0", 0), + ("5", 5), + ("-1", -1), + ], +) +def test_parse_optional_int_query_param_parses_integers(raw_value, expected): + query = {"offset": raw_value} + + assert parse_optional_int_query_param(query, "offset") == expected + + +@pytest.mark.parametrize( + ("name", "raw_value"), + [ + ("offset", "not-an-integer"), + ("offset", "1.5"), + ("offset", ""), + ("max_items", "not-an-integer"), + ], +) +def test_parse_optional_int_query_param_rejects_invalid_integers(name, raw_value): + query = {name: raw_value} + + with pytest.raises(ValueError) as exc_info: + parse_optional_int_query_param(query, name) + + assert str(exc_info.value) == f"{name} must be an integer" diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index f73ca7e3c..3e16ccec6 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -909,6 +909,20 @@ class TestExecution: assert len(result) <= 1, "Should return at most 1 item when offset is near end" + def test_history_api_rejects_non_integer_max_items(self, client: ComfyClient): + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.get_all_history(max_items="not-an-integer") + + assert exc_info.value.code == 400 + assert json.loads(exc_info.value.read()) == {"error": "max_items must be an integer"} + + def test_history_api_rejects_non_integer_offset(self, client: ComfyClient): + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.get_all_history(offset="not-an-integer") + + assert exc_info.value.code == 400 + assert json.loads(exc_info.value.read()) == {"error": "offset must be an integer"} + # Jobs API tests def test_jobs_api_job_structure( self, client: ComfyClient, builder: GraphBuilder