Add job_ids filter for GET api assets (#13998)

Amp-Thread-ID: https://ampcode.com/threads/T-019e4ca5-b71a-7168-8f56-58b2325f34c3
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Simon Pinfold 2026-05-22 10:53:50 +12:00
parent 00c88a4634
commit 560e6ee5c1
7 changed files with 166 additions and 0 deletions

View File

@ -227,6 +227,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
job_ids=q.job_ids,
limit=q.limit,
offset=q.offset,
sort=sort,

View File

@ -1,4 +1,5 @@
import json
import uuid
from dataclasses import dataclass
from typing import Any, Literal
@ -53,6 +54,7 @@ class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: str | None = None
job_ids: list[str] = Field(default_factory=list, max_length=500)
# Accept either a JSON string (query param) or a dict
metadata_filter: dict[str, Any] | None = None
@ -86,6 +88,40 @@ class ListAssetsQuery(BaseModel):
return out
return v
@field_validator("job_ids", mode="before")
@classmethod
def _split_and_validate_job_ids(cls, v):
# Accept "uuid1,uuid2" or ["uuid1","uuid2"] or repeated query params.
# Each entry must parse as a UUID; canonicalized to lowercase hyphenated form.
if v is None:
return []
if isinstance(v, str):
raw = [t.strip() for t in v.split(",") if t.strip()]
elif isinstance(v, list):
raw = []
for item in v:
if not isinstance(item, str):
raise ValueError(
f"job_ids entries must be strings, got {type(item).__name__}"
)
raw.extend([t.strip() for t in item.split(",") if t.strip()])
else:
raise ValueError(
f"job_ids must be a string or list of strings, got {type(v).__name__}"
)
out: list[str] = []
seen: set[str] = set()
for s in raw:
try:
canonical = str(uuid.UUID(s))
except ValueError as e:
raise ValueError(f"job_ids must be UUIDs: {s!r}") from e
if canonical not in seen:
seen.add(canonical)
out.append(canonical)
return out
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):

View File

@ -264,6 +264,7 @@ def list_references_page(
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
metadata_filter: dict | None = None,
job_ids: Sequence[str] | None = None,
sort: str | None = None,
order: str | None = None,
after_cursor_value: object | None = None,
@ -293,6 +294,9 @@ def list_references_page(
escaped, esc = escape_sql_like_string(name_contains)
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
if job_ids:
base = base.where(AssetReference.job_id.in_(list(job_ids)))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
@ -345,6 +349,8 @@ def list_references_page(
count_stmt = count_stmt.where(
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
)
if job_ids:
count_stmt = count_stmt.where(AssetReference.job_id.in_(list(job_ids)))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)

View File

@ -264,6 +264,7 @@ def list_assets_page(
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
job_ids: Sequence[str] | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
@ -309,6 +310,7 @@ def list_assets_page(
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
job_ids=job_ids,
limit=fetch_limit,
offset=offset,
sort=sort,

View File

@ -1572,6 +1572,17 @@ paths:
type: string
enum: [asc, desc]
description: Sort direction
- name: job_ids
in: query
schema:
type: array
maxItems: 500
items:
type: string
format: uuid
style: form
explode: true
description: "Filter assets by associated job UUIDs. Accepts repeated query params (e.g. `?job_ids=a&job_ids=b`) or a single comma-separated value (`?job_ids=a,b`)."
- name: include_public
in: query
schema:

View File

@ -159,6 +159,56 @@ class TestListReferencesPage:
refs, _, _ = list_references_page(session, sort="name", order="asc")
assert refs[0].name == "large"
def test_job_ids_filter(self, session: Session):
asset = _make_asset(session, "hash1")
job_a = str(uuid.uuid4())
job_b = str(uuid.uuid4())
ref_a = _make_reference(session, asset, name="from_job_a")
ref_a.job_id = job_a
ref_b = _make_reference(session, asset, name="from_job_b")
ref_b.job_id = job_b
_make_reference(session, asset, name="no_job")
session.commit()
# Single job filter
refs, _, total = list_references_page(session, job_ids=[job_a])
assert total == 1
assert refs[0].name == "from_job_a"
# Multi-job filter (IN)
refs, _, total = list_references_page(session, job_ids=[job_a, job_b])
names = sorted(r.name for r in refs)
assert total == 2
assert names == ["from_job_a", "from_job_b"]
# Unknown job id matches nothing
refs, _, total = list_references_page(session, job_ids=[str(uuid.uuid4())])
assert total == 0
assert refs == []
# Empty/None means no filter -> all three references
refs, _, total = list_references_page(session, job_ids=[])
assert total == 3
refs, _, total = list_references_page(session, job_ids=None)
assert total == 3
def test_job_ids_combined_with_other_filters(self, session: Session):
asset = _make_asset(session, "hash1")
job_a = str(uuid.uuid4())
ref_match = _make_reference(session, asset, name="match.bin")
ref_match.job_id = job_a
ref_wrong_name = _make_reference(session, asset, name="other.bin")
ref_wrong_name.job_id = job_a
ref_wrong_job = _make_reference(session, asset, name="match.bin")
ref_wrong_job.job_id = str(uuid.uuid4())
session.commit()
refs, _, total = list_references_page(
session, job_ids=[job_a], name_contains="match"
)
assert total == 1
assert refs[0].id == ref_match.id
class TestTagRetrievalOrder:
"""End-to-end check: tags written through the public write paths come

View File

@ -0,0 +1,60 @@
"""Schema-level unit tests for ListAssetsQuery (no DB required)."""
import uuid
import pytest
from pydantic import ValidationError
from app.assets.api.schemas_in import ListAssetsQuery
class TestJobIdsValidator:
def test_csv_string_parses_and_canonicalizes(self):
a = "AAAAAAAA-BBBB-CCCC-DDDD-EEEEEEEEEEEE"
b = "11111111-2222-3333-4444-555555555555"
q = ListAssetsQuery.model_validate({"job_ids": f"{a},{b}"})
# Canonicalized to lowercase
assert q.job_ids == [a.lower(), b]
def test_repeated_query_params_as_list(self):
a = "11111111-1111-1111-1111-111111111111"
b = "22222222-2222-2222-2222-222222222222"
q = ListAssetsQuery.model_validate({"job_ids": [a, b]})
assert q.job_ids == [a, b]
def test_dedup_preserves_first_seen_order(self):
a = "11111111-1111-1111-1111-111111111111"
b = "22222222-2222-2222-2222-222222222222"
q = ListAssetsQuery.model_validate({"job_ids": [a, b, a]})
assert q.job_ids == [a, b]
def test_default_empty(self):
q = ListAssetsQuery.model_validate({})
assert q.job_ids == []
def test_invalid_uuid_rejected(self):
with pytest.raises(ValidationError) as exc:
ListAssetsQuery.model_validate({"job_ids": "not-a-uuid"})
assert "must be UUIDs" in str(exc.value)
def test_non_string_list_item_rejected(self):
with pytest.raises(ValidationError) as exc:
ListAssetsQuery.model_validate(
{"job_ids": ["11111111-1111-1111-1111-111111111111", 42]}
)
assert "must be strings" in str(exc.value)
def test_non_string_non_list_value_rejected(self):
with pytest.raises(ValidationError) as exc:
ListAssetsQuery.model_validate({"job_ids": {"bad": "shape"}})
assert "must be a string or list of strings" in str(exc.value)
def test_max_length_enforced(self):
too_many = [str(uuid.uuid4()) for _ in range(501)]
with pytest.raises(ValidationError) as exc:
ListAssetsQuery.model_validate({"job_ids": too_many})
assert exc.value.errors()[0]["type"] == "too_long"
def test_max_length_boundary_accepted(self):
at_cap = [str(uuid.uuid4()) for _ in range(500)]
q = ListAssetsQuery.model_validate({"job_ids": at_cap})
assert len(q.job_ids) == 500