This commit is contained in:
Matt Miller 2026-05-13 19:14:40 +03:00 committed by GitHub
commit 4cf4c34d43
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 116 additions and 2 deletions

View File

@ -213,6 +213,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
owner_id=USER_MANAGER.get_request_user_id(request), owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags, include_tags=q.include_tags,
exclude_tags=q.exclude_tags, exclude_tags=q.exclude_tags,
job_ids=q.job_ids,
name_contains=q.name_contains, name_contains=q.name_contains,
metadata_filter=q.metadata_filter, metadata_filter=q.metadata_filter,
limit=q.limit, limit=q.limit,

View File

@ -1,4 +1,5 @@
import json import json
import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Literal from typing import Any, Literal
@ -52,6 +53,7 @@ class ParsedUpload:
class ListAssetsQuery(BaseModel): class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list) include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list) exclude_tags: list[str] = Field(default_factory=list)
job_ids: list[str] = Field(default_factory=list, max_length=100)
name_contains: str | None = None name_contains: str | None = None
# Accept either a JSON string (query param) or a dict # Accept either a JSON string (query param) or a dict
@ -65,6 +67,35 @@ class ListAssetsQuery(BaseModel):
) )
order: Literal["asc", "desc"] = "desc" order: Literal["asc", "desc"] = "desc"
@field_validator("job_ids", mode="before")
@classmethod
def _split_csv_job_ids(cls, v):
if v is None:
return []
if isinstance(v, str):
tokens = [t.strip() for t in v.split(",") if t.strip()]
elif isinstance(v, list):
tokens = []
for item in v:
if not isinstance(item, str):
raise ValueError(
f"job_ids items must be strings, got {type(item).__name__}"
)
tokens.extend([t.strip() for t in item.split(",") if t.strip()])
else:
raise ValueError("job_ids must be a string or list of strings")
seen: set[str] = set()
out: list[str] = []
for t in tokens:
try:
normalized = str(uuid.UUID(t))
except ValueError:
raise ValueError(f"invalid UUID in job_ids: {t!r}")
if normalized not in seen:
seen.add(normalized)
out.append(normalized)
return out
@field_validator("include_tags", "exclude_tags", mode="before") @field_validator("include_tags", "exclude_tags", mode="before")
@classmethod @classmethod
def _split_csv_tags(cls, v): def _split_csv_tags(cls, v):

View File

@ -263,6 +263,7 @@ def list_references_page(
name_contains: str | None = None, name_contains: str | None = None,
include_tags: Sequence[str] | None = None, include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None, exclude_tags: Sequence[str] | None = None,
job_ids: Sequence[str] | None = None,
metadata_filter: dict | None = None, metadata_filter: dict | None = None,
sort: str | None = None, sort: str | None = None,
order: str | None = None, order: str | None = None,
@ -284,6 +285,9 @@ def list_references_page(
escaped, esc = escape_sql_like_string(name_contains) escaped, esc = escape_sql_like_string(name_contains)
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
if job_ids:
base = base.where(AssetReference.job_id.in_(job_ids))
base = apply_tag_filters(base, include_tags, exclude_tags) base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter) base = apply_metadata_filter(base, metadata_filter)
@ -314,6 +318,9 @@ def list_references_page(
count_stmt = count_stmt.where( count_stmt = count_stmt.where(
AssetReference.name.ilike(f"%{escaped}%", escape=esc) AssetReference.name.ilike(f"%{escaped}%", escape=esc)
) )
if job_ids:
count_stmt = count_stmt.where(AssetReference.job_id.in_(job_ids))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter) count_stmt = apply_metadata_filter(count_stmt, metadata_filter)

View File

@ -246,6 +246,7 @@ def list_assets_page(
owner_id: str = "", owner_id: str = "",
include_tags: Sequence[str] | None = None, include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None, exclude_tags: Sequence[str] | None = None,
job_ids: Sequence[str] | None = None,
name_contains: str | None = None, name_contains: str | None = None,
metadata_filter: dict | None = None, metadata_filter: dict | None = None,
limit: int = 20, limit: int = 20,
@ -259,6 +260,7 @@ def list_assets_page(
owner_id=owner_id, owner_id=owner_id,
include_tags=include_tags, include_tags=include_tags,
exclude_tags=exclude_tags, exclude_tags=exclude_tags,
job_ids=job_ids,
name_contains=name_contains, name_contains=name_contains,
metadata_filter=metadata_filter, metadata_filter=metadata_filter,
limit=limit, limit=limit,

View File

@ -1553,8 +1553,7 @@ paths:
in: query in: query
schema: schema:
type: string type: string
x-runtime: [cloud] description: "Comma-separated UUIDs to filter assets by associated job."
description: "[cloud-only] Comma-separated UUIDs to filter assets by associated job."
- name: include_public - name: include_public
in: query in: query
schema: schema:

View File

@ -158,6 +158,80 @@ class TestListReferencesPage:
refs, _, _ = list_references_page(session, sort="name", order="asc") refs, _, _ = list_references_page(session, sort="name", order="asc")
assert refs[0].name == "large" assert refs[0].name == "large"
def test_job_ids_filter_single(self, session: Session):
asset = _make_asset(session, "hash1")
ref1 = _make_reference(session, asset, name="with_job")
ref1.job_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
_make_reference(session, asset, name="no_job")
session.commit()
refs, _, total = list_references_page(
session, job_ids=["aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"]
)
assert total == 1
assert refs[0].name == "with_job"
def test_job_ids_filter_multiple(self, session: Session):
asset = _make_asset(session, "hash1")
ref1 = _make_reference(session, asset, name="job_a")
ref1.job_id = "aaaaaaaa-1111-2222-3333-444444444444"
ref2 = _make_reference(session, asset, name="job_b")
ref2.job_id = "bbbbbbbb-1111-2222-3333-444444444444"
_make_reference(session, asset, name="no_job")
session.commit()
refs, _, total = list_references_page(
session,
job_ids=[
"aaaaaaaa-1111-2222-3333-444444444444",
"bbbbbbbb-1111-2222-3333-444444444444",
],
)
assert total == 2
names = {r.name for r in refs}
assert names == {"job_a", "job_b"}
def test_job_ids_filter_empty_returns_all(self, session: Session):
asset = _make_asset(session, "hash1")
ref1 = _make_reference(session, asset, name="with_job")
ref1.job_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
_make_reference(session, asset, name="no_job")
session.commit()
refs, _, total = list_references_page(session, job_ids=[])
assert total == 2
def test_job_ids_filter_no_match(self, session: Session):
asset = _make_asset(session, "hash1")
ref1 = _make_reference(session, asset, name="with_job")
ref1.job_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
session.commit()
refs, _, total = list_references_page(
session, job_ids=["99999999-9999-9999-9999-999999999999"]
)
assert total == 0
assert refs == []
def test_job_ids_combined_with_tags(self, session: Session):
asset = _make_asset(session, "hash1")
ref1 = _make_reference(session, asset, name="tagged_with_job")
ref1.job_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
ensure_tags_exist(session, ["wanted"])
add_tags_to_reference(session, reference_id=ref1.id, tags=["wanted"])
ref2 = _make_reference(session, asset, name="untagged_with_job")
ref2.job_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
session.commit()
refs, _, total = list_references_page(
session,
job_ids=["aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"],
include_tags=["wanted"],
)
assert total == 1
assert refs[0].name == "tagged_with_job"
class TestFetchReferenceAssetAndTags: class TestFetchReferenceAssetAndTags:
def test_returns_none_for_nonexistent(self, session: Session): def test_returns_none_for_nonexistent(self, session: Session):