mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-12 01:07:30 +08:00
feat(assets): add job_ids filter to GET /api/assets (#13998)
* feat(assets): add job_ids filter to GET /api/assets Mirrors the existing cloud `job_ids` query param on the local Python server: clients can pass a comma-separated list (or repeated query params) of UUIDs to filter assets by their associated job. The `AssetReference.job_id` column already exists, so no migration is needed — this just plumbs the filter through schema → service → query. Marks the parameter as available in both runtimes by dropping the `[cloud-only]` description prefix and the `x-runtime: [cloud]` tag from the OpenAPI spec, per the OSS field-drift convention (absent runtime tag = populated by both local and cloud). * fix(assets): tighten job_ids — array schema, max_length, narrow except From cursor-reviews on the parent commit: - OpenAPI: declare job_ids as `type: array, items: string format: uuid` with `style: form, explode: true` so it matches the documented contract (and matches sibling include_tags/exclude_tags shape). Description now states both accepted shapes explicitly. - Schema: cap `job_ids` at 500 entries (max_length on the Pydantic field) so a client can't splice an unbounded list into the IN clauses. - Schema: drop `AttributeError` from the except — `raw` only contains `str` items by construction, so `uuid.UUID(<str>)` raises `ValueError` exclusively; the second clause was dead code. * fix(assets): tighten job_ids validator + add schema-level tests Aligns with the parallel hardening from draft PR #13848 (now closed as a duplicate). The validator now: - Raises ValueError on non-string list items (was: silently dropped). - Raises ValueError on non-string / non-list top-level values like dict or int (was: silently passed through to Pydantic's downstream coercion). Adds tests-unit/assets_test/queries/test_list_assets_query.py covering the validator end-to-end: CSV canonicalization, dedup order, default empty, invalid UUID, non-string list item, non-string non-list value, and the max_length=500 boundary. 
* feat(prompt): enforce canonical UUID prompt_id at job creation POST /prompt previously accepted any client-supplied prompt_id verbatim, str()-coercing even non-strings, and minting the literal job id "None" for an explicit JSON null. The new GET /api/assets job_ids filter matches stored job ids as canonical UUIDs exactly, so a non-UUID id minted a job whose assets could never be filtered. - validate_job_id (comfy_execution/jobs.py): requires a string in the canonical lowercase hyphenated UUID form; raises ValueError otherwise, including parseable-but-non-canonical spellings (uppercase, braced, URN, bare hex), which would otherwise be silently rewritten and then miss every exact-match lookup downstream (history keys, websocket correlation, /interrupt, the assets job_ids filter). - POST /prompt: absent or null prompt_id means the server mints uuid4; invalid means 400 invalid_prompt_id on the standard error envelope. - openapi.yaml: document the request-side prompt_id (format uuid, nullable) on PromptRequest. - tests: unit matrix for validate_job_id; integration tests against the booted server covering rejection, acceptance, and null handling. --------- Co-authored-by: guill <jacob.e.segal@gmail.com>
This commit is contained in:
parent
6d18f4adac
commit
e5b7140dcc
@ -219,6 +219,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
job_ids=q.job_ids,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=sort,
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
@ -53,6 +54,7 @@ class ListAssetsQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
name_contains: str | None = None
|
||||
job_ids: list[str] = Field(default_factory=list, max_length=500)
|
||||
|
||||
# Accept either a JSON string (query param) or a dict
|
||||
metadata_filter: dict[str, Any] | None = None
|
||||
@ -86,6 +88,40 @@ class ListAssetsQuery(BaseModel):
|
||||
return out
|
||||
return v
|
||||
|
||||
@field_validator("job_ids", mode="before")
@classmethod
def _split_and_validate_job_ids(cls, v):
    """Normalize the raw ``job_ids`` input into a list of canonical UUID strings.

    Accepted shapes are ``None`` (treated as "no filter"), one comma-separated
    string, or a list of strings where each item may itself be comma-separated
    (the repeated-query-param form). Blank tokens are ignored, every remaining
    token must parse with ``uuid.UUID``, and duplicates are removed while
    keeping first-seen order.

    Raises:
        ValueError: for a non-string list item, for a top-level value that is
            neither a string nor a list, or for a token that is not a UUID.
    """
    if v is None:
        return []

    # Reduce both accepted shapes to "a list of comma-separated chunks".
    if isinstance(v, str):
        chunks = [v]
    elif isinstance(v, list):
        # Type-check every item before any UUID parsing so a bad item type is
        # reported ahead of a bad UUID elsewhere in the list.
        for entry in v:
            if not isinstance(entry, str):
                raise ValueError(
                    f"job_ids entries must be strings, got {type(entry).__name__}"
                )
        chunks = v
    else:
        raise ValueError(
            f"job_ids must be a string or list of strings, got {type(v).__name__}"
        )

    tokens: list[str] = []
    for chunk in chunks:
        tokens.extend(piece for piece in map(str.strip, chunk.split(",")) if piece)

    # A dict keeps insertion order, so it doubles as an ordered set here.
    canonical_ids: dict[str, None] = {}
    for token in tokens:
        try:
            parsed = uuid.UUID(token)
        except ValueError as exc:
            raise ValueError(f"job_ids must be UUIDs: {token!r}") from exc
        canonical_ids.setdefault(str(parsed), None)
    return list(canonical_ids)
|
||||
|
||||
@field_validator("metadata_filter", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
|
||||
@ -264,6 +264,7 @@ def list_references_page(
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
job_ids: Sequence[str] | None = None,
|
||||
sort: str | None = None,
|
||||
order: str | None = None,
|
||||
after_cursor_value: object | None = None,
|
||||
@ -293,6 +294,9 @@ def list_references_page(
|
||||
escaped, esc = escape_sql_like_string(name_contains)
|
||||
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
if job_ids:
|
||||
base = base.where(AssetReference.job_id.in_(list(job_ids)))
|
||||
|
||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
@ -345,6 +349,8 @@ def list_references_page(
|
||||
count_stmt = count_stmt.where(
|
||||
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
|
||||
)
|
||||
if job_ids:
|
||||
count_stmt = count_stmt.where(AssetReference.job_id.in_(list(job_ids)))
|
||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
|
||||
@ -274,6 +274,7 @@ def list_assets_page(
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
job_ids: Sequence[str] | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
@ -319,6 +320,7 @@ def list_assets_page(
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
job_ids=job_ids,
|
||||
limit=fetch_limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
|
||||
@ -3,6 +3,7 @@ Job utilities for the /api/jobs endpoint.
|
||||
Provides normalization and helper functions for job status tracking.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from comfy_api.internal import prune_dict
|
||||
@ -19,6 +20,26 @@ class JobStatus:
|
||||
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
|
||||
|
||||
|
||||
def validate_job_id(value) -> str:
    """Check that a client-supplied job (prompt) id is a canonical UUID string.

    Only the lowercase, hyphenated spelling produced by ``str(uuid.UUID(...))``
    is accepted. The id is stored and compared verbatim downstream (history
    keys, websocket events, /interrupt matching, and the assets ``job_ids``
    filter), so any other spelling is rejected rather than silently rewritten.

    Args:
        value: The candidate id as received from the client.

    Returns:
        ``value`` itself, unchanged.

    Raises:
        ValueError: if ``value`` is not a string, does not parse as a UUID, or
            parses but is not already in canonical form.
    """
    if isinstance(value, str):
        # uuid.UUID raises ValueError on its own for strings it cannot parse.
        round_tripped = str(uuid.UUID(value))
        if round_tripped == value:
            return value
        raise ValueError("job id must be a UUID in canonical lowercase hyphenated form")
    raise ValueError(f"job id must be a string, got {type(value).__name__}")
|
||||
|
||||
|
||||
# Media types that can be previewed in the frontend
|
||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'})
|
||||
|
||||
|
||||
@ -896,6 +896,11 @@ components:
|
||||
additionalProperties: true
|
||||
description: The workflow graph to execute
|
||||
type: object
|
||||
prompt_id:
|
||||
description: Optional client-supplied job id. Must be a UUID in canonical lowercase hyphenated form; it is echoed back in the response. Omitted or null means the server generates one.
|
||||
format: uuid
|
||||
nullable: true
|
||||
type: string
|
||||
workflow_id:
|
||||
description: UUID identifying the cloud workflow entity to associate with this job
|
||||
type: string
|
||||
|
||||
18
server.py
18
server.py
@ -8,7 +8,7 @@ import time
|
||||
import nodes
|
||||
import folder_paths
|
||||
import execution
|
||||
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs
|
||||
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs, validate_job_id
|
||||
import uuid
|
||||
import urllib
|
||||
import json
|
||||
@ -942,7 +942,21 @@ class PromptServer():
|
||||
|
||||
if "prompt" in json_data:
|
||||
prompt = json_data["prompt"]
|
||||
prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
|
||||
client_prompt_id = json_data.get("prompt_id")
|
||||
if client_prompt_id is None:
|
||||
# Absent or explicit null: the server mints the id.
|
||||
prompt_id = str(uuid.uuid4())
|
||||
else:
|
||||
try:
|
||||
prompt_id = validate_job_id(client_prompt_id)
|
||||
except ValueError:
|
||||
error = {
|
||||
"type": "invalid_prompt_id",
|
||||
"message": "prompt_id must be a valid UUID",
|
||||
"details": "prompt_id must be a UUID string in canonical lowercase hyphenated form; omit it to let the server generate one",
|
||||
"extra_info": {}
|
||||
}
|
||||
return web.json_response({"error": error, "node_errors": {}}, status=400)
|
||||
|
||||
partial_execution_targets = None
|
||||
if "partial_execution_targets" in json_data:
|
||||
|
||||
@ -158,6 +158,56 @@ class TestListReferencesPage:
|
||||
refs, _, _ = list_references_page(session, sort="name", order="asc")
|
||||
assert refs[0].name == "large"
|
||||
|
||||
def test_job_ids_filter(self, session: Session):
    """list_references_page narrows results to the requested job ids."""
    shared_asset = _make_asset(session, "hash1")
    first_job = str(uuid.uuid4())
    second_job = str(uuid.uuid4())
    # Two references tagged with distinct jobs, plus one with no job at all.
    for ref_name, owning_job in (("from_job_a", first_job), ("from_job_b", second_job)):
        tagged = _make_reference(session, shared_asset, name=ref_name)
        tagged.job_id = owning_job
    _make_reference(session, shared_asset, name="no_job")
    session.commit()

    # One id selects exactly its own reference.
    page, _, count = list_references_page(session, job_ids=[first_job])
    assert count == 1
    assert page[0].name == "from_job_a"

    # Several ids are combined as an IN filter.
    page, _, count = list_references_page(session, job_ids=[first_job, second_job])
    found_names = [ref.name for ref in page]
    found_names.sort()
    assert count == 2
    assert found_names == ["from_job_a", "from_job_b"]

    # An id no reference carries yields an empty page.
    never_used = str(uuid.uuid4())
    page, _, count = list_references_page(session, job_ids=[never_used])
    assert count == 0
    assert page == []

    # Both an empty list and None disable the filter: all three come back.
    for no_filter in ([], None):
        page, _, count = list_references_page(session, job_ids=no_filter)
        assert count == 3
|
||||
|
||||
def test_job_ids_combined_with_other_filters(self, session: Session):
    """job_ids is ANDed with name_contains rather than replacing it."""
    shared_asset = _make_asset(session, "hash1")
    wanted_job = str(uuid.uuid4())

    # Right job and right name: the only row expected back.
    expected = _make_reference(session, shared_asset, name="match.bin")
    expected.job_id = wanted_job
    # Right job, wrong name.
    name_miss = _make_reference(session, shared_asset, name="other.bin")
    name_miss.job_id = wanted_job
    # Right name, wrong job.
    job_miss = _make_reference(session, shared_asset, name="match.bin")
    job_miss.job_id = str(uuid.uuid4())
    session.commit()

    filters = {"job_ids": [wanted_job], "name_contains": "match"}
    page, _, count = list_references_page(session, **filters)
    assert count == 1
    assert page[0].id == expected.id
|
||||
|
||||
|
||||
class TestFetchReferenceAssetAndTags:
|
||||
def test_returns_none_for_nonexistent(self, session: Session):
|
||||
|
||||
60
tests-unit/assets_test/queries/test_list_assets_query.py
Normal file
60
tests-unit/assets_test/queries/test_list_assets_query.py
Normal file
@ -0,0 +1,60 @@
|
||||
"""Schema-level unit tests for ListAssetsQuery (no DB required)."""
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.assets.api.schemas_in import ListAssetsQuery
|
||||
|
||||
|
||||
class TestJobIdsValidator:
    """Exercises the ``job_ids`` field of ListAssetsQuery through model_validate."""

    def test_csv_string_parses_and_canonicalizes(self):
        upper_id = "AAAAAAAA-BBBB-CCCC-DDDD-EEEEEEEEEEEE"
        plain_id = "11111111-2222-3333-4444-555555555555"
        payload = {"job_ids": ",".join((upper_id, plain_id))}
        query = ListAssetsQuery.model_validate(payload)
        # The uppercase spelling comes back lowercased; order is kept.
        assert query.job_ids == [upper_id.lower(), plain_id]

    def test_repeated_query_params_as_list(self):
        first = "11111111-1111-1111-1111-111111111111"
        second = "22222222-2222-2222-2222-222222222222"
        query = ListAssetsQuery.model_validate({"job_ids": [first, second]})
        assert query.job_ids == [first, second]

    def test_dedup_preserves_first_seen_order(self):
        first = "11111111-1111-1111-1111-111111111111"
        second = "22222222-2222-2222-2222-222222222222"
        with_repeat = [first, second, first]
        query = ListAssetsQuery.model_validate({"job_ids": with_repeat})
        assert query.job_ids == [first, second]

    def test_default_empty(self):
        # No job_ids key at all: the field falls back to an empty list.
        query = ListAssetsQuery.model_validate({})
        assert query.job_ids == []

    def test_invalid_uuid_rejected(self):
        payload = {"job_ids": "not-a-uuid"}
        with pytest.raises(ValidationError) as caught:
            ListAssetsQuery.model_validate(payload)
        assert "must be UUIDs" in str(caught.value)

    def test_non_string_list_item_rejected(self):
        payload = {"job_ids": ["11111111-1111-1111-1111-111111111111", 42]}
        with pytest.raises(ValidationError) as caught:
            ListAssetsQuery.model_validate(payload)
        assert "must be strings" in str(caught.value)

    def test_non_string_non_list_value_rejected(self):
        payload = {"job_ids": {"bad": "shape"}}
        with pytest.raises(ValidationError) as caught:
            ListAssetsQuery.model_validate(payload)
        assert "must be a string or list of strings" in str(caught.value)

    def test_max_length_enforced(self):
        # One entry over the 500-item cap declared on the field.
        over_cap = [str(uuid.uuid4()) for _ in range(501)]
        with pytest.raises(ValidationError) as caught:
            ListAssetsQuery.model_validate({"job_ids": over_cap})
        first_error = caught.value.errors()[0]
        assert first_error["type"] == "too_long"

    def test_max_length_boundary_accepted(self):
        # Exactly 500 distinct ids is still allowed.
        exactly_cap = [str(uuid.uuid4()) for _ in range(500)]
        query = ListAssetsQuery.model_validate({"job_ids": exactly_cap})
        assert len(query.job_ids) == 500
|
||||
69
tests-unit/assets_test/test_prompt_id_enforcement.py
Normal file
69
tests-unit/assets_test/test_prompt_id_enforcement.py
Normal file
@ -0,0 +1,69 @@
|
||||
"""POST /prompt enforces canonical-UUID job ids at creation time.
|
||||
|
||||
Lives in assets_test because it uses this suite's booted-server fixture and
|
||||
because the invariant exists for the assets pipeline: the GET /api/assets
|
||||
``job_ids`` filter matches stored job ids exactly, so a job minted with a
|
||||
non-canonical id would produce assets the filter can never find.
|
||||
|
||||
The prompt bodies here are intentionally invalid workflows — prompt_id
|
||||
validation happens before workflow validation, so a rejected id returns
|
||||
``invalid_prompt_id`` while an accepted id falls through to the ordinary
|
||||
workflow-validation error (proving it cleared the id check).
|
||||
"""
|
||||
import requests
|
||||
|
||||
|
||||
def _post_prompt(http: requests.Session, api_base: str, body: dict) -> requests.Response:
    """POST ``body`` as JSON to ``<api_base>/prompt`` and return the raw response."""
    endpoint = api_base + "/prompt"
    response = http.post(endpoint, json=body, timeout=30)
    return response
|
||||
|
||||
|
||||
def _error_type(r: requests.Response) -> str:
    """Return the ``error.type`` field from the response's JSON body."""
    payload = r.json()
    error_block = payload["error"]
    return error_block["type"]
|
||||
|
||||
|
||||
def test_non_uuid_prompt_id_rejected(http: requests.Session, api_base: str):
    """A prompt_id that is not a UUID at all is answered with invalid_prompt_id."""
    request_body = {"prompt": {}, "prompt_id": "not-a-uuid"}
    resp = _post_prompt(http, api_base, request_body)
    assert resp.status_code == 400, resp.text
    assert _error_type(resp) == "invalid_prompt_id"
|
||||
|
||||
|
||||
def test_non_string_prompt_id_rejected(http: requests.Session, api_base: str):
    """A numeric prompt_id is a 400, not a coerced id and not a server error."""
    # The old behaviour str()-coerced 123 into the job id "123". It must now
    # be rejected cleanly instead of surfacing as a 500 from uuid.UUID.
    request_body = {"prompt": {}, "prompt_id": 123}
    resp = _post_prompt(http, api_base, request_body)
    assert resp.status_code == 400, resp.text
    assert _error_type(resp) == "invalid_prompt_id"
|
||||
|
||||
|
||||
def test_non_canonical_uuid_rejected(http: requests.Session, api_base: str):
    """An uppercase (parseable but non-canonical) UUID is refused outright."""
    # Downstream lookups match the stored id exactly, so the server rejects
    # this spelling loudly instead of quietly rewriting it to lowercase.
    uppercase_id = "AAAAAAAA-BBBB-4CCC-8DDD-EEEEEEEEEEEE"
    request_body = {"prompt": {}, "prompt_id": uppercase_id}
    resp = _post_prompt(http, api_base, request_body)
    assert resp.status_code == 400, resp.text
    assert _error_type(resp) == "invalid_prompt_id"
|
||||
|
||||
|
||||
def test_canonical_uuid_accepted(http: requests.Session, api_base: str):
    """A canonical lowercase UUID clears the id check."""
    # The workflow is deliberately empty, so the request still fails with a
    # 400 -- but from ordinary prompt validation, which shows that the id
    # check itself was passed.
    canonical_id = "aaaaaaaa-bbbb-4ccc-8ddd-eeeeeeeeeeee"
    request_body = {"prompt": {}, "prompt_id": canonical_id}
    resp = _post_prompt(http, api_base, request_body)
    assert resp.status_code == 400, resp.text
    assert _error_type(resp) != "invalid_prompt_id"
|
||||
|
||||
|
||||
def test_null_prompt_id_not_rejected(http: requests.Session, api_base: str):
    """An explicit JSON null prompt_id is not treated as an invalid id."""
    # Null means "let the server generate the id". The generated id cannot be
    # observed here because the workflow is invalid; validate_job_id has its
    # own unit tests.
    request_body = {"prompt": {}, "prompt_id": None}
    resp = _post_prompt(http, api_base, request_body)
    assert resp.status_code == 400, resp.text
    assert _error_type(resp) != "invalid_prompt_id"
|
||||
@ -1,5 +1,7 @@
|
||||
"""Unit tests for comfy_execution/jobs.py"""
|
||||
|
||||
import pytest
|
||||
|
||||
from comfy_execution.jobs import (
|
||||
JobStatus,
|
||||
is_previewable,
|
||||
@ -10,9 +12,50 @@ from comfy_execution.jobs import (
|
||||
get_outputs_summary,
|
||||
apply_sorting,
|
||||
has_3d_extension,
|
||||
validate_job_id,
|
||||
)
|
||||
|
||||
|
||||
class TestValidateJobId:
    """validate_job_id gates job creation: POST /prompt rejects whatever it raises on."""

    # The single accepted spelling; the rejected variants below derive from it.
    _CANONICAL = "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7"

    def test_canonical_form_passes_through(self):
        canonical = self._CANONICAL
        assert validate_job_id(canonical) == canonical

    @pytest.mark.parametrize(
        "variant",
        [
            _CANONICAL.upper(),  # uppercase
            "{" + _CANONICAL + "}",  # braced
            "urn:uuid:" + _CANONICAL,  # URN
            _CANONICAL.replace("-", ""),  # bare hex
            " " + _CANONICAL + " ",  # padded
        ],
    )
    def test_non_canonical_spellings_rejected(self, variant):
        # These are alternative spellings of the canonical id. Accepting one
        # would silently rewrite the client's id, while history keys,
        # websocket events and the assets job_ids filter all match the
        # stored form exactly.
        with pytest.raises(ValueError):
            validate_job_id(variant)

    @pytest.mark.parametrize(
        "bad",
        ["", "not-a-uuid", "prompt-123", "a1b2c3d4-e5f6-7a89-b0c1", "None"],
    )
    def test_non_uuid_strings_rejected(self, bad):
        with pytest.raises(ValueError):
            validate_job_id(bad)

    @pytest.mark.parametrize("bad", [123, 1.5, True, None, ["a"], {"id": "x"}])
    def test_non_strings_rejected(self, bad):
        # uuid.UUID itself raises AttributeError/TypeError for non-strings;
        # the helper must surface ValueError so callers need only one
        # except clause.
        with pytest.raises(ValueError):
            validate_job_id(bad)
|
||||
|
||||
|
||||
class TestJobStatus:
|
||||
"""Test JobStatus constants."""
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user