mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-30 19:07:25 +08:00
Add job_ids filter for GET api assets (#13998)
Amp-Thread-ID: https://ampcode.com/threads/T-019e4ca5-b71a-7168-8f56-58b2325f34c3 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
00c88a4634
commit
560e6ee5c1
@ -227,6 +227,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
|
|||||||
exclude_tags=q.exclude_tags,
|
exclude_tags=q.exclude_tags,
|
||||||
name_contains=q.name_contains,
|
name_contains=q.name_contains,
|
||||||
metadata_filter=q.metadata_filter,
|
metadata_filter=q.metadata_filter,
|
||||||
|
job_ids=q.job_ids,
|
||||||
limit=q.limit,
|
limit=q.limit,
|
||||||
offset=q.offset,
|
offset=q.offset,
|
||||||
sort=sort,
|
sort=sort,
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
@ -53,6 +54,7 @@ class ListAssetsQuery(BaseModel):
|
|||||||
include_tags: list[str] = Field(default_factory=list)
|
include_tags: list[str] = Field(default_factory=list)
|
||||||
exclude_tags: list[str] = Field(default_factory=list)
|
exclude_tags: list[str] = Field(default_factory=list)
|
||||||
name_contains: str | None = None
|
name_contains: str | None = None
|
||||||
|
job_ids: list[str] = Field(default_factory=list, max_length=500)
|
||||||
|
|
||||||
# Accept either a JSON string (query param) or a dict
|
# Accept either a JSON string (query param) or a dict
|
||||||
metadata_filter: dict[str, Any] | None = None
|
metadata_filter: dict[str, Any] | None = None
|
||||||
@ -86,6 +88,40 @@ class ListAssetsQuery(BaseModel):
|
|||||||
return out
|
return out
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@field_validator("job_ids", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _split_and_validate_job_ids(cls, v):
|
||||||
|
# Accept "uuid1,uuid2" or ["uuid1","uuid2"] or repeated query params.
|
||||||
|
# Each entry must parse as a UUID; canonicalized to lowercase hyphenated form.
|
||||||
|
if v is None:
|
||||||
|
return []
|
||||||
|
if isinstance(v, str):
|
||||||
|
raw = [t.strip() for t in v.split(",") if t.strip()]
|
||||||
|
elif isinstance(v, list):
|
||||||
|
raw = []
|
||||||
|
for item in v:
|
||||||
|
if not isinstance(item, str):
|
||||||
|
raise ValueError(
|
||||||
|
f"job_ids entries must be strings, got {type(item).__name__}"
|
||||||
|
)
|
||||||
|
raw.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"job_ids must be a string or list of strings, got {type(v).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
out: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for s in raw:
|
||||||
|
try:
|
||||||
|
canonical = str(uuid.UUID(s))
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(f"job_ids must be UUIDs: {s!r}") from e
|
||||||
|
if canonical not in seen:
|
||||||
|
seen.add(canonical)
|
||||||
|
out.append(canonical)
|
||||||
|
return out
|
||||||
|
|
||||||
@field_validator("metadata_filter", mode="before")
|
@field_validator("metadata_filter", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def _parse_metadata_json(cls, v):
|
def _parse_metadata_json(cls, v):
|
||||||
|
|||||||
@ -264,6 +264,7 @@ def list_references_page(
|
|||||||
include_tags: Sequence[str] | None = None,
|
include_tags: Sequence[str] | None = None,
|
||||||
exclude_tags: Sequence[str] | None = None,
|
exclude_tags: Sequence[str] | None = None,
|
||||||
metadata_filter: dict | None = None,
|
metadata_filter: dict | None = None,
|
||||||
|
job_ids: Sequence[str] | None = None,
|
||||||
sort: str | None = None,
|
sort: str | None = None,
|
||||||
order: str | None = None,
|
order: str | None = None,
|
||||||
after_cursor_value: object | None = None,
|
after_cursor_value: object | None = None,
|
||||||
@ -293,6 +294,9 @@ def list_references_page(
|
|||||||
escaped, esc = escape_sql_like_string(name_contains)
|
escaped, esc = escape_sql_like_string(name_contains)
|
||||||
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
||||||
|
|
||||||
|
if job_ids:
|
||||||
|
base = base.where(AssetReference.job_id.in_(list(job_ids)))
|
||||||
|
|
||||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||||
base = apply_metadata_filter(base, metadata_filter)
|
base = apply_metadata_filter(base, metadata_filter)
|
||||||
|
|
||||||
@ -345,6 +349,8 @@ def list_references_page(
|
|||||||
count_stmt = count_stmt.where(
|
count_stmt = count_stmt.where(
|
||||||
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
|
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
|
||||||
)
|
)
|
||||||
|
if job_ids:
|
||||||
|
count_stmt = count_stmt.where(AssetReference.job_id.in_(list(job_ids)))
|
||||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||||
|
|
||||||
|
|||||||
@ -264,6 +264,7 @@ def list_assets_page(
|
|||||||
exclude_tags: Sequence[str] | None = None,
|
exclude_tags: Sequence[str] | None = None,
|
||||||
name_contains: str | None = None,
|
name_contains: str | None = None,
|
||||||
metadata_filter: dict | None = None,
|
metadata_filter: dict | None = None,
|
||||||
|
job_ids: Sequence[str] | None = None,
|
||||||
limit: int = 20,
|
limit: int = 20,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
sort: str = "created_at",
|
sort: str = "created_at",
|
||||||
@ -309,6 +310,7 @@ def list_assets_page(
|
|||||||
exclude_tags=exclude_tags,
|
exclude_tags=exclude_tags,
|
||||||
name_contains=name_contains,
|
name_contains=name_contains,
|
||||||
metadata_filter=metadata_filter,
|
metadata_filter=metadata_filter,
|
||||||
|
job_ids=job_ids,
|
||||||
limit=fetch_limit,
|
limit=fetch_limit,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
sort=sort,
|
sort=sort,
|
||||||
|
|||||||
11
openapi.yaml
11
openapi.yaml
@ -1572,6 +1572,17 @@ paths:
|
|||||||
type: string
|
type: string
|
||||||
enum: [asc, desc]
|
enum: [asc, desc]
|
||||||
description: Sort direction
|
description: Sort direction
|
||||||
|
- name: job_ids
|
||||||
|
in: query
|
||||||
|
schema:
|
||||||
|
type: array
|
||||||
|
maxItems: 500
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
format: uuid
|
||||||
|
style: form
|
||||||
|
explode: true
|
||||||
|
description: "Filter assets by associated job UUIDs. Accepts repeated query params (e.g. `?job_ids=a&job_ids=b`) or a single comma-separated value (`?job_ids=a,b`)."
|
||||||
- name: include_public
|
- name: include_public
|
||||||
in: query
|
in: query
|
||||||
schema:
|
schema:
|
||||||
|
|||||||
@ -159,6 +159,56 @@ class TestListReferencesPage:
|
|||||||
refs, _, _ = list_references_page(session, sort="name", order="asc")
|
refs, _, _ = list_references_page(session, sort="name", order="asc")
|
||||||
assert refs[0].name == "large"
|
assert refs[0].name == "large"
|
||||||
|
|
||||||
|
def test_job_ids_filter(self, session: Session):
|
||||||
|
asset = _make_asset(session, "hash1")
|
||||||
|
job_a = str(uuid.uuid4())
|
||||||
|
job_b = str(uuid.uuid4())
|
||||||
|
ref_a = _make_reference(session, asset, name="from_job_a")
|
||||||
|
ref_a.job_id = job_a
|
||||||
|
ref_b = _make_reference(session, asset, name="from_job_b")
|
||||||
|
ref_b.job_id = job_b
|
||||||
|
_make_reference(session, asset, name="no_job")
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Single job filter
|
||||||
|
refs, _, total = list_references_page(session, job_ids=[job_a])
|
||||||
|
assert total == 1
|
||||||
|
assert refs[0].name == "from_job_a"
|
||||||
|
|
||||||
|
# Multi-job filter (IN)
|
||||||
|
refs, _, total = list_references_page(session, job_ids=[job_a, job_b])
|
||||||
|
names = sorted(r.name for r in refs)
|
||||||
|
assert total == 2
|
||||||
|
assert names == ["from_job_a", "from_job_b"]
|
||||||
|
|
||||||
|
# Unknown job id matches nothing
|
||||||
|
refs, _, total = list_references_page(session, job_ids=[str(uuid.uuid4())])
|
||||||
|
assert total == 0
|
||||||
|
assert refs == []
|
||||||
|
|
||||||
|
# Empty/None means no filter -> all three references
|
||||||
|
refs, _, total = list_references_page(session, job_ids=[])
|
||||||
|
assert total == 3
|
||||||
|
refs, _, total = list_references_page(session, job_ids=None)
|
||||||
|
assert total == 3
|
||||||
|
|
||||||
|
def test_job_ids_combined_with_other_filters(self, session: Session):
|
||||||
|
asset = _make_asset(session, "hash1")
|
||||||
|
job_a = str(uuid.uuid4())
|
||||||
|
ref_match = _make_reference(session, asset, name="match.bin")
|
||||||
|
ref_match.job_id = job_a
|
||||||
|
ref_wrong_name = _make_reference(session, asset, name="other.bin")
|
||||||
|
ref_wrong_name.job_id = job_a
|
||||||
|
ref_wrong_job = _make_reference(session, asset, name="match.bin")
|
||||||
|
ref_wrong_job.job_id = str(uuid.uuid4())
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
refs, _, total = list_references_page(
|
||||||
|
session, job_ids=[job_a], name_contains="match"
|
||||||
|
)
|
||||||
|
assert total == 1
|
||||||
|
assert refs[0].id == ref_match.id
|
||||||
|
|
||||||
|
|
||||||
class TestTagRetrievalOrder:
|
class TestTagRetrievalOrder:
|
||||||
"""End-to-end check: tags written through the public write paths come
|
"""End-to-end check: tags written through the public write paths come
|
||||||
|
|||||||
60
tests-unit/assets_test/queries/test_list_assets_query.py
Normal file
60
tests-unit/assets_test/queries/test_list_assets_query.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
"""Schema-level unit tests for ListAssetsQuery (no DB required)."""
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from app.assets.api.schemas_in import ListAssetsQuery
|
||||||
|
|
||||||
|
|
||||||
|
class TestJobIdsValidator:
|
||||||
|
def test_csv_string_parses_and_canonicalizes(self):
|
||||||
|
a = "AAAAAAAA-BBBB-CCCC-DDDD-EEEEEEEEEEEE"
|
||||||
|
b = "11111111-2222-3333-4444-555555555555"
|
||||||
|
q = ListAssetsQuery.model_validate({"job_ids": f"{a},{b}"})
|
||||||
|
# Canonicalized to lowercase
|
||||||
|
assert q.job_ids == [a.lower(), b]
|
||||||
|
|
||||||
|
def test_repeated_query_params_as_list(self):
|
||||||
|
a = "11111111-1111-1111-1111-111111111111"
|
||||||
|
b = "22222222-2222-2222-2222-222222222222"
|
||||||
|
q = ListAssetsQuery.model_validate({"job_ids": [a, b]})
|
||||||
|
assert q.job_ids == [a, b]
|
||||||
|
|
||||||
|
def test_dedup_preserves_first_seen_order(self):
|
||||||
|
a = "11111111-1111-1111-1111-111111111111"
|
||||||
|
b = "22222222-2222-2222-2222-222222222222"
|
||||||
|
q = ListAssetsQuery.model_validate({"job_ids": [a, b, a]})
|
||||||
|
assert q.job_ids == [a, b]
|
||||||
|
|
||||||
|
def test_default_empty(self):
|
||||||
|
q = ListAssetsQuery.model_validate({})
|
||||||
|
assert q.job_ids == []
|
||||||
|
|
||||||
|
def test_invalid_uuid_rejected(self):
|
||||||
|
with pytest.raises(ValidationError) as exc:
|
||||||
|
ListAssetsQuery.model_validate({"job_ids": "not-a-uuid"})
|
||||||
|
assert "must be UUIDs" in str(exc.value)
|
||||||
|
|
||||||
|
def test_non_string_list_item_rejected(self):
|
||||||
|
with pytest.raises(ValidationError) as exc:
|
||||||
|
ListAssetsQuery.model_validate(
|
||||||
|
{"job_ids": ["11111111-1111-1111-1111-111111111111", 42]}
|
||||||
|
)
|
||||||
|
assert "must be strings" in str(exc.value)
|
||||||
|
|
||||||
|
def test_non_string_non_list_value_rejected(self):
|
||||||
|
with pytest.raises(ValidationError) as exc:
|
||||||
|
ListAssetsQuery.model_validate({"job_ids": {"bad": "shape"}})
|
||||||
|
assert "must be a string or list of strings" in str(exc.value)
|
||||||
|
|
||||||
|
def test_max_length_enforced(self):
|
||||||
|
too_many = [str(uuid.uuid4()) for _ in range(501)]
|
||||||
|
with pytest.raises(ValidationError) as exc:
|
||||||
|
ListAssetsQuery.model_validate({"job_ids": too_many})
|
||||||
|
assert exc.value.errors()[0]["type"] == "too_long"
|
||||||
|
|
||||||
|
def test_max_length_boundary_accepted(self):
|
||||||
|
at_cap = [str(uuid.uuid4()) for _ in range(500)]
|
||||||
|
q = ListAssetsQuery.model_validate({"job_ids": at_cap})
|
||||||
|
assert len(q.job_ids) == 500
|
||||||
Loading…
Reference in New Issue
Block a user