diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 68126b6a5..fa6502908 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -213,6 +213,7 @@ async def list_assets_route(request: web.Request) -> web.Response: owner_id=USER_MANAGER.get_request_user_id(request), include_tags=q.include_tags, exclude_tags=q.exclude_tags, + job_ids=q.job_ids, name_contains=q.name_contains, metadata_filter=q.metadata_filter, limit=q.limit, diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index 186a6ae1e..6c5fdb07c 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -1,4 +1,5 @@ import json +import uuid from dataclasses import dataclass from typing import Any, Literal @@ -52,6 +53,7 @@ class ParsedUpload: class ListAssetsQuery(BaseModel): include_tags: list[str] = Field(default_factory=list) exclude_tags: list[str] = Field(default_factory=list) + job_ids: list[str] = Field(default_factory=list, max_length=100) name_contains: str | None = None # Accept either a JSON string (query param) or a dict @@ -65,6 +67,35 @@ class ListAssetsQuery(BaseModel): ) order: Literal["asc", "desc"] = "desc" + @field_validator("job_ids", mode="before") + @classmethod + def _split_csv_job_ids(cls, v): + if v is None: + return [] + if isinstance(v, str): + tokens = [t.strip() for t in v.split(",") if t.strip()] + elif isinstance(v, list): + tokens = [] + for item in v: + if not isinstance(item, str): + raise ValueError( + f"job_ids items must be strings, got {type(item).__name__}" + ) + tokens.extend([t.strip() for t in item.split(",") if t.strip()]) + else: + raise ValueError("job_ids must be a string or list of strings") + seen: set[str] = set() + out: list[str] = [] + for t in tokens: + try: + normalized = str(uuid.UUID(t)) + except ValueError: + raise ValueError(f"invalid UUID in job_ids: {t!r}") + if normalized not in seen: + seen.add(normalized) + out.append(normalized) + return out + @field_validator("include_tags", "exclude_tags", mode="before") @classmethod def _split_csv_tags(cls, v): diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py index 8b90ae511..de409c2f9 100644 --- a/app/assets/database/queries/asset_reference.py +++ b/app/assets/database/queries/asset_reference.py @@ -263,6 +263,7 @@ def list_references_page( name_contains: str | None = None, include_tags: Sequence[str] | None = None, exclude_tags: Sequence[str] | None = None, + job_ids: Sequence[str] | None = None, metadata_filter: dict | None = None, sort: str | None = None, order: str | None = None, @@ -284,6 +285,9 @@ def list_references_page( escaped, esc = escape_sql_like_string(name_contains) base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) + if job_ids: + base = base.where(AssetReference.job_id.in_(job_ids)) + base = apply_tag_filters(base, include_tags, exclude_tags) base = apply_metadata_filter(base, metadata_filter) @@ -314,6 +318,9 @@ def list_references_page( count_stmt = count_stmt.where( AssetReference.name.ilike(f"%{escaped}%", escape=esc) ) + if job_ids: + count_stmt = count_stmt.where(AssetReference.job_id.in_(job_ids)) + count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) count_stmt = apply_metadata_filter(count_stmt, metadata_filter) diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py index 5aefd9956..e5ab24f17 100644 --- a/app/assets/services/asset_management.py +++ b/app/assets/services/asset_management.py @@ -246,6 +246,7 @@ def list_assets_page( owner_id: str = "", include_tags: Sequence[str] | None = None, exclude_tags: Sequence[str] | None = None, + job_ids: Sequence[str] | None = None, name_contains: str | None = None, metadata_filter: dict | None = None, limit: int = 20, @@ -259,6 +260,7 @@ def list_assets_page( owner_id=owner_id, include_tags=include_tags, exclude_tags=exclude_tags, + job_ids=job_ids, name_contains=name_contains, metadata_filter=metadata_filter, limit=limit, diff --git a/openapi.yaml b/openapi.yaml index 96be4c1d5..307b48333 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -1553,8 +1553,7 @@ paths: in: query schema: type: string - x-runtime: [cloud] - description: "[cloud-only] Comma-separated UUIDs to filter assets by associated job." + description: "Comma-separated UUIDs to filter assets by associated job." - name: include_public in: query schema: diff --git a/tests-unit/assets_test/queries/test_asset_info.py b/tests-unit/assets_test/queries/test_asset_info.py index fe510e342..a73dbdc1d 100644 --- a/tests-unit/assets_test/queries/test_asset_info.py +++ b/tests-unit/assets_test/queries/test_asset_info.py @@ -158,6 +158,80 @@ class TestListReferencesPage: refs, _, _ = list_references_page(session, sort="name", order="asc") assert refs[0].name == "large" + def test_job_ids_filter_single(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, name="with_job") + ref1.job_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + _make_reference(session, asset, name="no_job") + session.commit() + + refs, _, total = list_references_page( + session, job_ids=["aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"] + ) + assert total == 1 + assert refs[0].name == "with_job" + + def test_job_ids_filter_multiple(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, name="job_a") + ref1.job_id = "aaaaaaaa-1111-2222-3333-444444444444" + ref2 = _make_reference(session, asset, name="job_b") + ref2.job_id = "bbbbbbbb-1111-2222-3333-444444444444" + _make_reference(session, asset, name="no_job") + session.commit() + + refs, _, total = list_references_page( + session, + job_ids=[ + "aaaaaaaa-1111-2222-3333-444444444444", + "bbbbbbbb-1111-2222-3333-444444444444", + ], + ) + assert total == 2 + names = {r.name for r in refs} + assert names == {"job_a", "job_b"} + + def test_job_ids_filter_empty_returns_all(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, name="with_job") + ref1.job_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + _make_reference(session, asset, name="no_job") + session.commit() + + refs, _, total = list_references_page(session, job_ids=[]) + assert total == 2 + + def test_job_ids_filter_no_match(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, name="with_job") + ref1.job_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + session.commit() + + refs, _, total = list_references_page( + session, job_ids=["99999999-9999-9999-9999-999999999999"] + ) + assert total == 0 + assert refs == [] + + def test_job_ids_combined_with_tags(self, session: Session): + asset = _make_asset(session, "hash1") + ref1 = _make_reference(session, asset, name="tagged_with_job") + ref1.job_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + ensure_tags_exist(session, ["wanted"]) + add_tags_to_reference(session, reference_id=ref1.id, tags=["wanted"]) + + ref2 = _make_reference(session, asset, name="untagged_with_job") + ref2.job_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + session.commit() + + refs, _, total = list_references_page( + session, + job_ids=["aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"], + include_tags=["wanted"], + ) + assert total == 1 + assert refs[0].name == "tagged_with_job" + class TestFetchReferenceAssetAndTags: def test_returns_none_for_nonexistent(self, session: Session):