mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-30 19:07:25 +08:00
feat(assets): add job_ids filter to GET /api/assets
Mirrors the existing cloud `job_ids` query param on the local Python server: clients can pass a comma-separated list (or repeated query params) of UUIDs to filter assets by their associated job. The `AssetReference.job_id` column already exists, so no migration is needed — this just plumbs the filter through schema → service → query. Marks the parameter as available in both runtimes by dropping the `[cloud-only]` description prefix and the `x-runtime: [cloud]` tag from the OpenAPI spec, per the OSS field-drift convention (absent runtime tag = populated by both local and cloud).
This commit is contained in:
parent
6887165a9d
commit
fbaae9bc42
@ -215,6 +215,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
job_ids=q.job_ids,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=sort,
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
@ -53,6 +54,7 @@ class ListAssetsQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
name_contains: str | None = None
|
||||
job_ids: list[str] = Field(default_factory=list)
|
||||
|
||||
# Accept either a JSON string (query param) or a dict
|
||||
metadata_filter: dict[str, Any] | None = None
|
||||
@ -81,6 +83,35 @@ class ListAssetsQuery(BaseModel):
|
||||
return out
|
||||
return v
|
||||
|
||||
@field_validator("job_ids", mode="before")
|
||||
@classmethod
|
||||
def _split_and_validate_job_ids(cls, v):
|
||||
# Accept "uuid1,uuid2" or ["uuid1","uuid2"] or repeated query params.
|
||||
# Each entry must parse as a UUID; canonicalized to lowercase hyphenated form.
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
raw = [t.strip() for t in v.split(",") if t.strip()]
|
||||
elif isinstance(v, list):
|
||||
raw = []
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
raw.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||
else:
|
||||
return v
|
||||
|
||||
out: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for s in raw:
|
||||
try:
|
||||
canonical = str(uuid.UUID(s))
|
||||
except (ValueError, AttributeError) as e:
|
||||
raise ValueError(f"job_ids must be UUIDs: {s!r}") from e
|
||||
if canonical not in seen:
|
||||
seen.add(canonical)
|
||||
out.append(canonical)
|
||||
return out
|
||||
|
||||
@field_validator("metadata_filter", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
|
||||
@ -264,6 +264,7 @@ def list_references_page(
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
job_ids: Sequence[str] | None = None,
|
||||
sort: str | None = None,
|
||||
order: str | None = None,
|
||||
) -> tuple[list[AssetReference], dict[str, list[str]], int]:
|
||||
@ -284,6 +285,9 @@ def list_references_page(
|
||||
escaped, esc = escape_sql_like_string(name_contains)
|
||||
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
if job_ids:
|
||||
base = base.where(AssetReference.job_id.in_(list(job_ids)))
|
||||
|
||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
@ -314,6 +318,8 @@ def list_references_page(
|
||||
count_stmt = count_stmt.where(
|
||||
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
|
||||
)
|
||||
if job_ids:
|
||||
count_stmt = count_stmt.where(AssetReference.job_id.in_(list(job_ids)))
|
||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
|
||||
@ -248,6 +248,7 @@ def list_assets_page(
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
job_ids: Sequence[str] | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
@ -261,6 +262,7 @@ def list_assets_page(
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
job_ids=job_ids,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
|
||||
@ -1560,8 +1560,7 @@ paths:
|
||||
in: query
|
||||
schema:
|
||||
type: string
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Comma-separated UUIDs to filter assets by associated job."
|
||||
description: "Comma-separated UUIDs to filter assets by associated job."
|
||||
- name: include_public
|
||||
in: query
|
||||
schema:
|
||||
|
||||
@ -158,6 +158,56 @@ class TestListReferencesPage:
|
||||
refs, _, _ = list_references_page(session, sort="name", order="asc")
|
||||
assert refs[0].name == "large"
|
||||
|
||||
def test_job_ids_filter(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
job_a = str(uuid.uuid4())
|
||||
job_b = str(uuid.uuid4())
|
||||
ref_a = _make_reference(session, asset, name="from_job_a")
|
||||
ref_a.job_id = job_a
|
||||
ref_b = _make_reference(session, asset, name="from_job_b")
|
||||
ref_b.job_id = job_b
|
||||
_make_reference(session, asset, name="no_job")
|
||||
session.commit()
|
||||
|
||||
# Single job filter
|
||||
refs, _, total = list_references_page(session, job_ids=[job_a])
|
||||
assert total == 1
|
||||
assert refs[0].name == "from_job_a"
|
||||
|
||||
# Multi-job filter (IN)
|
||||
refs, _, total = list_references_page(session, job_ids=[job_a, job_b])
|
||||
names = sorted(r.name for r in refs)
|
||||
assert total == 2
|
||||
assert names == ["from_job_a", "from_job_b"]
|
||||
|
||||
# Unknown job id matches nothing
|
||||
refs, _, total = list_references_page(session, job_ids=[str(uuid.uuid4())])
|
||||
assert total == 0
|
||||
assert refs == []
|
||||
|
||||
# Empty/None means no filter -> all three references
|
||||
refs, _, total = list_references_page(session, job_ids=[])
|
||||
assert total == 3
|
||||
refs, _, total = list_references_page(session, job_ids=None)
|
||||
assert total == 3
|
||||
|
||||
def test_job_ids_combined_with_other_filters(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
job_a = str(uuid.uuid4())
|
||||
ref_match = _make_reference(session, asset, name="match.bin")
|
||||
ref_match.job_id = job_a
|
||||
ref_wrong_name = _make_reference(session, asset, name="other.bin")
|
||||
ref_wrong_name.job_id = job_a
|
||||
ref_wrong_job = _make_reference(session, asset, name="match.bin")
|
||||
ref_wrong_job.job_id = str(uuid.uuid4())
|
||||
session.commit()
|
||||
|
||||
refs, _, total = list_references_page(
|
||||
session, job_ids=[job_a], name_contains="match"
|
||||
)
|
||||
assert total == 1
|
||||
assert refs[0].id == ref_match.id
|
||||
|
||||
|
||||
class TestFetchReferenceAssetAndTags:
|
||||
def test_returns_none_for_nonexistent(self, session: Session):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user