mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-03 02:30:52 +08:00
use Pydantic for output; finished Tags endpoints
This commit is contained in:
parent
5c1b5973ac
commit
8d46bec951
@ -65,7 +65,7 @@ def upgrade() -> None:
|
|||||||
op.create_table(
|
op.create_table(
|
||||||
"asset_info_tags",
|
"asset_info_tags",
|
||||||
sa.Column("asset_info_id", sa.BigInteger(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
sa.Column("asset_info_id", sa.BigInteger(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||||
sa.Column("tag_name", sa.String(length=128), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
|
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
|
||||||
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
|
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
|
||||||
sa.Column("added_by", sa.String(length=128), nullable=True),
|
sa.Column("added_by", sa.String(length=128), nullable=True),
|
||||||
sa.Column("added_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
|
sa.Column("added_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
|
||||||
@ -104,7 +104,7 @@ def upgrade() -> None:
|
|||||||
# Tags vocabulary for models
|
# Tags vocabulary for models
|
||||||
tags_table = sa.table(
|
tags_table = sa.table(
|
||||||
"tags",
|
"tags",
|
||||||
sa.column("name", sa.String()),
|
sa.column("name", sa.String(length=512)),
|
||||||
sa.column("tag_type", sa.String()),
|
sa.column("tag_type", sa.String()),
|
||||||
)
|
)
|
||||||
op.bulk_insert(
|
op.bulk_insert(
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from aiohttp import web
|
|||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from .. import assets_manager
|
from .. import assets_manager
|
||||||
from .schemas_in import ListAssetsQuery, UpdateAssetBody
|
from . import schemas_in
|
||||||
|
|
||||||
|
|
||||||
ROUTES = web.RouteTableDef()
|
ROUTES = web.RouteTableDef()
|
||||||
@ -15,7 +15,7 @@ async def list_assets(request: web.Request) -> web.Response:
|
|||||||
query_dict = dict(request.rel_url.query)
|
query_dict = dict(request.rel_url.query)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
q = ListAssetsQuery.model_validate(query_dict)
|
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
|
||||||
except ValidationError as ve:
|
except ValidationError as ve:
|
||||||
return _validation_error_response("INVALID_QUERY", ve)
|
return _validation_error_response("INVALID_QUERY", ve)
|
||||||
|
|
||||||
@ -29,7 +29,7 @@ async def list_assets(request: web.Request) -> web.Response:
|
|||||||
sort=q.sort,
|
sort=q.sort,
|
||||||
order=q.order,
|
order=q.order,
|
||||||
)
|
)
|
||||||
return web.json_response(payload)
|
return web.json_response(payload.model_dump(mode="json"))
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.put("/api/assets/{id}")
|
@ROUTES.put("/api/assets/{id}")
|
||||||
@ -41,7 +41,7 @@ async def update_asset(request: web.Request) -> web.Response:
|
|||||||
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
|
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
body = UpdateAssetBody.model_validate(await request.json())
|
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
|
||||||
except ValidationError as ve:
|
except ValidationError as ve:
|
||||||
return _validation_error_response("INVALID_BODY", ve)
|
return _validation_error_response("INVALID_BODY", ve)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -58,7 +58,89 @@ async def update_asset(request: web.Request) -> web.Response:
|
|||||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||||
except Exception:
|
except Exception:
|
||||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
return web.json_response(result, status=200)
|
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.get("/api/tags")
|
||||||
|
async def get_tags(request: web.Request) -> web.Response:
|
||||||
|
query_map = dict(request.rel_url.query)
|
||||||
|
|
||||||
|
try:
|
||||||
|
query = schemas_in.TagsListQuery.model_validate(query_map)
|
||||||
|
except ValidationError as ve:
|
||||||
|
return web.json_response(
|
||||||
|
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": ve.errors()}},
|
||||||
|
status=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await assets_manager.list_tags(
|
||||||
|
prefix=query.prefix,
|
||||||
|
limit=query.limit,
|
||||||
|
offset=query.offset,
|
||||||
|
order=query.order,
|
||||||
|
include_zero=query.include_zero,
|
||||||
|
)
|
||||||
|
return web.json_response(result.model_dump(mode="json"))
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.post("/api/assets/{id}/tags")
|
||||||
|
async def add_asset_tags(request: web.Request) -> web.Response:
|
||||||
|
asset_info_id_raw = request.match_info.get("id")
|
||||||
|
try:
|
||||||
|
asset_info_id = int(asset_info_id_raw)
|
||||||
|
except Exception:
|
||||||
|
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
data = schemas_in.TagsAdd.model_validate(payload)
|
||||||
|
except ValidationError as ve:
|
||||||
|
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
|
||||||
|
except Exception:
|
||||||
|
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await assets_manager.add_tags_to_asset(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tags=data.tags,
|
||||||
|
origin="manual",
|
||||||
|
added_by=None,
|
||||||
|
)
|
||||||
|
except ValueError as ve:
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||||
|
except Exception:
|
||||||
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
|
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.delete("/api/assets/{id}/tags")
|
||||||
|
async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||||
|
asset_info_id_raw = request.match_info.get("id")
|
||||||
|
try:
|
||||||
|
asset_info_id = int(asset_info_id_raw)
|
||||||
|
except Exception:
|
||||||
|
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
data = schemas_in.TagsRemove.model_validate(payload)
|
||||||
|
except ValidationError as ve:
|
||||||
|
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
|
||||||
|
except Exception:
|
||||||
|
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await assets_manager.remove_tags_from_asset(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tags=data.tags,
|
||||||
|
)
|
||||||
|
except ValueError as ve:
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||||
|
except Exception:
|
||||||
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
|
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||||
|
|
||||||
|
|
||||||
def register_assets_routes(app: web.Application) -> None:
|
def register_assets_routes(app: web.Application) -> None:
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Optional, Literal
|
from typing import Any, Optional, Literal
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator, conint
|
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator, conint
|
||||||
|
|
||||||
|
|
||||||
class ListAssetsQuery(BaseModel):
|
class ListAssetsQuery(BaseModel):
|
||||||
@ -64,3 +64,48 @@ class UpdateAssetBody(BaseModel):
|
|||||||
if not isinstance(self.tags, list) or not all(isinstance(t, str) for t in self.tags):
|
if not isinstance(self.tags, list) or not all(isinstance(t, str) for t in self.tags):
|
||||||
raise ValueError("Field 'tags' must be an array of strings.")
|
raise ValueError("Field 'tags' must be an array of strings.")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class TagsListQuery(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||||
|
|
||||||
|
prefix: Optional[str] = Field(None, min_length=1, max_length=256)
|
||||||
|
limit: int = Field(100, ge=1, le=1000)
|
||||||
|
offset: int = Field(0, ge=0, le=10_000_000)
|
||||||
|
order: Literal["count_desc", "name_asc"] = "count_desc"
|
||||||
|
include_zero: bool = True
|
||||||
|
|
||||||
|
@field_validator("prefix")
|
||||||
|
@classmethod
|
||||||
|
def normalize_prefix(cls, v: Optional[str]) -> Optional[str]:
|
||||||
|
if v is None:
|
||||||
|
return v
|
||||||
|
v = v.strip()
|
||||||
|
return v.lower() or None
|
||||||
|
|
||||||
|
|
||||||
|
class TagsAdd(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="ignore")
|
||||||
|
tags: list[str] = Field(..., min_length=1)
|
||||||
|
|
||||||
|
@field_validator("tags")
|
||||||
|
@classmethod
|
||||||
|
def normalize_tags(cls, v: list[str]) -> list[str]:
|
||||||
|
out = []
|
||||||
|
for t in v:
|
||||||
|
if not isinstance(t, str):
|
||||||
|
raise TypeError("tags must be strings")
|
||||||
|
tnorm = t.strip().lower()
|
||||||
|
if tnorm:
|
||||||
|
out.append(tnorm)
|
||||||
|
seen = set()
|
||||||
|
deduplicated = []
|
||||||
|
for x in out:
|
||||||
|
if x not in seen:
|
||||||
|
seen.add(x)
|
||||||
|
deduplicated.append(x)
|
||||||
|
return deduplicated
|
||||||
|
|
||||||
|
|
||||||
|
class TagsRemove(TagsAdd):
|
||||||
|
pass
|
||||||
|
|||||||
69
app/api/schemas_out.py
Normal file
69
app/api/schemas_out.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Optional
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||||
|
|
||||||
|
|
||||||
|
class AssetSummary(BaseModel):
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
asset_hash: str
|
||||||
|
size: Optional[int] = None
|
||||||
|
mime_type: Optional[str] = None
|
||||||
|
tags: list[str] = Field(default_factory=list)
|
||||||
|
preview_url: Optional[str] = None
|
||||||
|
created_at: Optional[datetime] = None
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
|
last_access_time: Optional[datetime] = None
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
@field_serializer("created_at", "updated_at", "last_access_time")
|
||||||
|
def _ser_dt(self, v: Optional[datetime], _info):
|
||||||
|
return v.isoformat() if v else None
|
||||||
|
|
||||||
|
|
||||||
|
class AssetsList(BaseModel):
|
||||||
|
assets: list[AssetSummary]
|
||||||
|
total: int
|
||||||
|
has_more: bool
|
||||||
|
|
||||||
|
|
||||||
|
class AssetUpdated(BaseModel):
|
||||||
|
id: int
|
||||||
|
name: str
|
||||||
|
asset_hash: str
|
||||||
|
tags: list[str] = Field(default_factory=list)
|
||||||
|
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
@field_serializer("updated_at")
|
||||||
|
def _ser_updated(self, v: Optional[datetime], _info):
|
||||||
|
return v.isoformat() if v else None
|
||||||
|
|
||||||
|
|
||||||
|
class TagUsage(BaseModel):
|
||||||
|
name: str
|
||||||
|
count: int
|
||||||
|
type: str
|
||||||
|
|
||||||
|
|
||||||
|
class TagsList(BaseModel):
|
||||||
|
tags: list[TagUsage] = Field(default_factory=list)
|
||||||
|
total: int
|
||||||
|
has_more: bool
|
||||||
|
|
||||||
|
|
||||||
|
class TagsAdd(BaseModel):
|
||||||
|
model_config = ConfigDict(str_strip_whitespace=True)
|
||||||
|
added: list[str] = Field(default_factory=list)
|
||||||
|
already_present: list[str] = Field(default_factory=list)
|
||||||
|
total_tags: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class TagsRemove(BaseModel):
|
||||||
|
model_config = ConfigDict(str_strip_whitespace=True)
|
||||||
|
removed: list[str] = Field(default_factory=list)
|
||||||
|
not_present: list[str] = Field(default_factory=list)
|
||||||
|
total_tags: list[str] = Field(default_factory=list)
|
||||||
@ -14,7 +14,11 @@ from .database.services import (
|
|||||||
list_asset_infos_page,
|
list_asset_infos_page,
|
||||||
update_asset_info_full,
|
update_asset_info_full,
|
||||||
get_asset_tags,
|
get_asset_tags,
|
||||||
|
list_tags_with_usage,
|
||||||
|
add_tags_to_asset_info,
|
||||||
|
remove_tags_from_asset_info,
|
||||||
)
|
)
|
||||||
|
from .api import schemas_out
|
||||||
|
|
||||||
|
|
||||||
def populate_db_with_asset(tags: list[str], file_name: str, file_path: str) -> None:
|
def populate_db_with_asset(tags: list[str], file_name: str, file_path: str) -> None:
|
||||||
@ -70,7 +74,7 @@ async def list_assets(
|
|||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
sort: str | None = "created_at",
|
sort: str | None = "created_at",
|
||||||
order: str | None = "desc",
|
order: str | None = "desc",
|
||||||
) -> dict:
|
) -> schemas_out.AssetsList:
|
||||||
sort = _safe_sort_field(sort)
|
sort = _safe_sort_field(sort)
|
||||||
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
||||||
|
|
||||||
@ -87,30 +91,30 @@ async def list_assets(
|
|||||||
order=order,
|
order=order,
|
||||||
)
|
)
|
||||||
|
|
||||||
assets_json = []
|
summaries: list[schemas_out.AssetSummary] = []
|
||||||
for info in infos:
|
for info in infos:
|
||||||
asset = info.asset # populated via contains_eager
|
asset = info.asset
|
||||||
tags = tag_map.get(info.id, [])
|
tags = tag_map.get(info.id, [])
|
||||||
assets_json.append(
|
summaries.append(
|
||||||
{
|
schemas_out.AssetSummary(
|
||||||
"id": info.id,
|
id=info.id,
|
||||||
"name": info.name,
|
name=info.name,
|
||||||
"asset_hash": info.asset_hash,
|
asset_hash=info.asset_hash,
|
||||||
"size": int(asset.size_bytes) if asset else None,
|
size=int(asset.size_bytes) if asset else None,
|
||||||
"mime_type": asset.mime_type if asset else None,
|
mime_type=asset.mime_type if asset else None,
|
||||||
"tags": tags,
|
tags=tags,
|
||||||
"preview_url": f"/api/v1/assets/{info.id}/content", # TODO: implement actual content endpoint later
|
preview_url=f"/api/v1/assets/{info.id}/content", # TODO: implement actual content endpoint later
|
||||||
"created_at": info.created_at.isoformat() if info.created_at else None,
|
created_at=info.created_at,
|
||||||
"updated_at": info.updated_at.isoformat() if info.updated_at else None,
|
updated_at=info.updated_at,
|
||||||
"last_access_time": info.last_access_time.isoformat() if info.last_access_time else None,
|
last_access_time=info.last_access_time,
|
||||||
}
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return schemas_out.AssetsList(
|
||||||
"assets": assets_json,
|
assets=summaries,
|
||||||
"total": total,
|
total=total,
|
||||||
"has_more": (offset + len(assets_json)) < total,
|
has_more=(offset + len(summaries)) < total,
|
||||||
}
|
)
|
||||||
|
|
||||||
|
|
||||||
async def update_asset(
|
async def update_asset(
|
||||||
@ -119,7 +123,7 @@ async def update_asset(
|
|||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
tags: list[str] | None = None,
|
tags: list[str] | None = None,
|
||||||
user_metadata: dict | None = None,
|
user_metadata: dict | None = None,
|
||||||
) -> dict:
|
) -> schemas_out.AssetUpdated:
|
||||||
async with await create_session() as session:
|
async with await create_session() as session:
|
||||||
info = await update_asset_info_full(
|
info = await update_asset_info_full(
|
||||||
session,
|
session,
|
||||||
@ -134,14 +138,40 @@ async def update_asset(
|
|||||||
tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
|
tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
return {
|
return schemas_out.AssetUpdated(
|
||||||
"id": info.id,
|
id=info.id,
|
||||||
"name": info.name,
|
name=info.name,
|
||||||
"asset_hash": info.asset_hash,
|
asset_hash=info.asset_hash,
|
||||||
"tags": tag_names,
|
tags=tag_names,
|
||||||
"user_metadata": info.user_metadata or {},
|
user_metadata=info.user_metadata or {},
|
||||||
"updated_at": info.updated_at.isoformat() if info.updated_at else None,
|
updated_at=info.updated_at,
|
||||||
}
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def list_tags(
|
||||||
|
*,
|
||||||
|
prefix: str | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
order: str = "count_desc",
|
||||||
|
include_zero: bool = True,
|
||||||
|
) -> schemas_out.TagsList:
|
||||||
|
limit = max(1, min(1000, limit))
|
||||||
|
offset = max(0, offset)
|
||||||
|
|
||||||
|
async with await create_session() as session:
|
||||||
|
rows, total = await list_tags_with_usage(
|
||||||
|
session,
|
||||||
|
prefix=prefix,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
include_zero=include_zero,
|
||||||
|
order=order,
|
||||||
|
)
|
||||||
|
|
||||||
|
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
|
||||||
|
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
|
||||||
|
|
||||||
|
|
||||||
def _safe_sort_field(requested: str | None) -> str:
|
def _safe_sort_field(requested: str | None) -> str:
|
||||||
@ -156,3 +186,38 @@ def _safe_sort_field(requested: str | None) -> str:
|
|||||||
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
|
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||||
st = os.stat(path, follow_symlinks=True)
|
st = os.stat(path, follow_symlinks=True)
|
||||||
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||||
|
|
||||||
|
|
||||||
|
async def add_tags_to_asset(
|
||||||
|
*,
|
||||||
|
asset_info_id: int,
|
||||||
|
tags: list[str],
|
||||||
|
origin: str = "manual",
|
||||||
|
added_by: str | None = None,
|
||||||
|
) -> schemas_out.TagsAdd:
|
||||||
|
async with await create_session() as session:
|
||||||
|
data = await add_tags_to_asset_info(
|
||||||
|
session,
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tags=tags,
|
||||||
|
origin=origin,
|
||||||
|
added_by=added_by,
|
||||||
|
create_if_missing=True,
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
return schemas_out.TagsAdd(**data)
|
||||||
|
|
||||||
|
|
||||||
|
async def remove_tags_from_asset(
|
||||||
|
*,
|
||||||
|
asset_info_id: int,
|
||||||
|
tags: list[str],
|
||||||
|
) -> schemas_out.TagsRemove:
|
||||||
|
async with await create_session() as session:
|
||||||
|
data = await remove_tags_from_asset_info(
|
||||||
|
session,
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tags=tags,
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
return schemas_out.TagsRemove(**data)
|
||||||
|
|||||||
@ -493,7 +493,7 @@ async def replace_asset_info_metadata_projection(
|
|||||||
await session.flush()
|
await session.flush()
|
||||||
|
|
||||||
|
|
||||||
async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[Tag]:
|
async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[str]:
|
||||||
return [
|
return [
|
||||||
tag_name
|
tag_name
|
||||||
for (tag_name,) in (
|
for (tag_name,) in (
|
||||||
@ -504,6 +504,179 @@ async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[T
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def list_tags_with_usage(
|
||||||
|
session,
|
||||||
|
*,
|
||||||
|
prefix: str | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
include_zero: bool = True,
|
||||||
|
order: str = "count_desc", # "count_desc" | "name_asc"
|
||||||
|
) -> tuple[list[tuple[str, str, int]], int]:
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
rows: list of (name, tag_type, count)
|
||||||
|
total: number of tags matching filter (independent of pagination)
|
||||||
|
"""
|
||||||
|
# Subquery with counts by tag_name
|
||||||
|
counts_sq = (
|
||||||
|
select(
|
||||||
|
AssetInfoTag.tag_name.label("tag_name"),
|
||||||
|
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
||||||
|
)
|
||||||
|
.group_by(AssetInfoTag.tag_name)
|
||||||
|
.subquery()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Base select with LEFT JOIN so we can include zero-usage tags
|
||||||
|
q = (
|
||||||
|
select(
|
||||||
|
Tag.name,
|
||||||
|
Tag.tag_type,
|
||||||
|
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||||
|
)
|
||||||
|
.select_from(Tag)
|
||||||
|
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prefix filter (tags are lowercase by check constraint)
|
||||||
|
if prefix:
|
||||||
|
q = q.where(Tag.name.like(prefix.strip().lower() + "%"))
|
||||||
|
|
||||||
|
# Include_zero toggles: if False, drop zero-usage tags
|
||||||
|
if not include_zero:
|
||||||
|
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||||
|
|
||||||
|
# Ordering
|
||||||
|
if order == "name_asc":
|
||||||
|
q = q.order_by(Tag.name.asc())
|
||||||
|
else: # default "count_desc"
|
||||||
|
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||||
|
|
||||||
|
# Total (without limit/offset, same filters)
|
||||||
|
total_q = select(func.count()).select_from(Tag)
|
||||||
|
if prefix:
|
||||||
|
total_q = total_q.where(Tag.name.like(prefix.strip().lower() + "%"))
|
||||||
|
if not include_zero:
|
||||||
|
# count only names that appear in counts subquery
|
||||||
|
total_q = total_q.where(
|
||||||
|
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = (await session.execute(q.limit(limit).offset(offset))).all()
|
||||||
|
total = (await session.execute(total_q)).scalar_one()
|
||||||
|
|
||||||
|
# Normalize counts to int for SQLite/Postgres consistency
|
||||||
|
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||||
|
return rows_norm, int(total or 0)
|
||||||
|
|
||||||
|
|
||||||
|
async def add_tags_to_asset_info(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
asset_info_id: int,
|
||||||
|
tags: Sequence[str],
|
||||||
|
origin: str = "manual",
|
||||||
|
added_by: Optional[str] = None,
|
||||||
|
create_if_missing: bool = True,
|
||||||
|
) -> dict:
|
||||||
|
"""Adds tags to an AssetInfo.
|
||||||
|
If create_if_missing=True, missing tag rows are created as 'user'.
|
||||||
|
Returns: {"added": [...], "already_present": [...], "total_tags": [...]}
|
||||||
|
"""
|
||||||
|
info = await session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
|
norm = _normalize_tags(tags)
|
||||||
|
if not norm:
|
||||||
|
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||||
|
return {"added": [], "already_present": [], "total_tags": total}
|
||||||
|
|
||||||
|
# Ensure tag rows exist if requested.
|
||||||
|
if create_if_missing:
|
||||||
|
await _ensure_tags_exist(session, norm, tag_type="user")
|
||||||
|
|
||||||
|
# Current links
|
||||||
|
existing = {
|
||||||
|
tname
|
||||||
|
for (tname,) in (
|
||||||
|
await session.execute(
|
||||||
|
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
}
|
||||||
|
|
||||||
|
to_add = [t for t in norm if t not in existing]
|
||||||
|
already = [t for t in norm if t in existing]
|
||||||
|
|
||||||
|
if to_add:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
# Make insert race-safe with a nested tx; ignore dup conflicts if any.
|
||||||
|
async with session.begin_nested():
|
||||||
|
session.add_all([
|
||||||
|
AssetInfoTag(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tag_name=t,
|
||||||
|
origin=origin,
|
||||||
|
added_by=added_by,
|
||||||
|
added_at=now,
|
||||||
|
) for t in to_add
|
||||||
|
])
|
||||||
|
try:
|
||||||
|
await session.flush()
|
||||||
|
except IntegrityError:
|
||||||
|
# Another writer linked the same tag at the same time -> ok, treat as already present.
|
||||||
|
await session.rollback()
|
||||||
|
|
||||||
|
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||||
|
return {"added": sorted(set(to_add)), "already_present": sorted(set(already)), "total_tags": total}
|
||||||
|
|
||||||
|
|
||||||
|
async def remove_tags_from_asset_info(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
asset_info_id: int,
|
||||||
|
tags: Sequence[str],
|
||||||
|
) -> dict:
|
||||||
|
"""Removes tags from an AssetInfo.
|
||||||
|
Returns: {"removed": [...], "not_present": [...], "total_tags": [...]}
|
||||||
|
"""
|
||||||
|
info = await session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
|
norm = _normalize_tags(tags)
|
||||||
|
if not norm:
|
||||||
|
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||||
|
return {"removed": [], "not_present": [], "total_tags": total}
|
||||||
|
|
||||||
|
existing = {
|
||||||
|
tname
|
||||||
|
for (tname,) in (
|
||||||
|
await session.execute(
|
||||||
|
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
}
|
||||||
|
|
||||||
|
to_remove = sorted(set(t for t in norm if t in existing))
|
||||||
|
not_present = sorted(set(t for t in norm if t not in existing))
|
||||||
|
|
||||||
|
if to_remove:
|
||||||
|
await session.execute(
|
||||||
|
delete(AssetInfoTag)
|
||||||
|
.where(
|
||||||
|
AssetInfoTag.asset_info_id == asset_info_id,
|
||||||
|
AssetInfoTag.tag_name.in_(to_remove),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||||
|
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||||
|
|
||||||
|
|
||||||
def _normalize_tags(tags: Sequence[str] | None) -> list[str]:
|
def _normalize_tags(tags: Sequence[str] | None) -> list[str]:
|
||||||
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user