diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index 369d6710b..1c5563d3f 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -65,7 +65,7 @@ def upgrade() -> None: op.create_table( "asset_info_tags", sa.Column("asset_info_id", sa.BigInteger(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), - sa.Column("tag_name", sa.String(length=128), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), + sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), sa.Column("added_by", sa.String(length=128), nullable=True), sa.Column("added_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), @@ -104,7 +104,7 @@ def upgrade() -> None: # Tags vocabulary for models tags_table = sa.table( "tags", - sa.column("name", sa.String()), + sa.column("name", sa.String(length=512)), sa.column("tag_type", sa.String()), ) op.bulk_insert( diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index 2e58532b8..8c037fd97 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -4,7 +4,7 @@ from aiohttp import web from pydantic import ValidationError from .. import assets_manager -from .schemas_in import ListAssetsQuery, UpdateAssetBody +from . import schemas_in ROUTES = web.RouteTableDef() @@ -15,7 +15,7 @@ async def list_assets(request: web.Request) -> web.Response: query_dict = dict(request.rel_url.query) try: - q = ListAssetsQuery.model_validate(query_dict) + q = schemas_in.ListAssetsQuery.model_validate(query_dict) except ValidationError as ve: return _validation_error_response("INVALID_QUERY", ve) @@ -29,7 +29,7 @@ async def list_assets(request: web.Request) -> web.Response: sort=q.sort, order=q.order, ) - return web.json_response(payload) + return web.json_response(payload.model_dump(mode="json")) @ROUTES.put("/api/assets/{id}") @@ -41,7 +41,7 @@ async def update_asset(request: web.Request) -> web.Response: return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") try: - body = UpdateAssetBody.model_validate(await request.json()) + body = schemas_in.UpdateAssetBody.model_validate(await request.json()) except ValidationError as ve: return _validation_error_response("INVALID_BODY", ve) except Exception: @@ -58,7 +58,89 @@ async def update_asset(request: web.Request) -> web.Response: return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) except Exception: return _error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(result, status=200) + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.get("/api/tags") +async def get_tags(request: web.Request) -> web.Response: + query_map = dict(request.rel_url.query) + + try: + query = schemas_in.TagsListQuery.model_validate(query_map) + except ValidationError as ve: + return web.json_response( + {"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": ve.errors()}}, + status=400, + ) + + result = await assets_manager.list_tags( + prefix=query.prefix, + limit=query.limit, + offset=query.offset, + order=query.order, + include_zero=query.include_zero, + ) + return web.json_response(result.model_dump(mode="json")) + + +@ROUTES.post("/api/assets/{id}/tags") +async def add_asset_tags(request: web.Request) -> 
web.Response: + asset_info_id_raw = request.match_info.get("id") + try: + asset_info_id = int(asset_info_id_raw) + except Exception: + return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") + + try: + payload = await request.json() + data = schemas_in.TagsAdd.model_validate(payload) + except ValidationError as ve: + return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()}) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = await assets_manager.add_tags_to_asset( + asset_info_id=asset_info_id, + tags=data.tags, + origin="manual", + added_by=None, + ) + except ValueError as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + return _error_response(500, "INTERNAL", "Unexpected server error.") + + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.delete("/api/assets/{id}/tags") +async def delete_asset_tags(request: web.Request) -> web.Response: + asset_info_id_raw = request.match_info.get("id") + try: + asset_info_id = int(asset_info_id_raw) + except Exception: + return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") + + try: + payload = await request.json() + data = schemas_in.TagsRemove.model_validate(payload) + except ValidationError as ve: + return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()}) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = await assets_manager.remove_tags_from_asset( + asset_info_id=asset_info_id, + tags=data.tags, + ) + except ValueError as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + return _error_response(500, "INTERNAL", "Unexpected server error.") + + return web.json_response(result.model_dump(mode="json"), status=200) def register_assets_routes(app: web.Application) -> None: diff --git a/app/api/schemas_in.py b/app/api/schemas_in.py index fb936a79a..4e0eb6253 100644 --- a/app/api/schemas_in.py +++ b/app/api/schemas_in.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import Any, Optional, Literal -from pydantic import BaseModel, Field, field_validator, model_validator, conint +from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator, conint class ListAssetsQuery(BaseModel): @@ -64,3 +64,48 @@ class UpdateAssetBody(BaseModel): if not isinstance(self.tags, list) or not all(isinstance(t, str) for t in self.tags): raise ValueError("Field 'tags' must be an array of strings.") return self + + +class TagsListQuery(BaseModel): + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + prefix: Optional[str] = Field(None, min_length=1, max_length=256) + limit: int = Field(100, ge=1, le=1000) + offset: int = Field(0, ge=0, le=10_000_000) + order: Literal["count_desc", "name_asc"] = "count_desc" + include_zero: bool = True + + @field_validator("prefix") + @classmethod + def normalize_prefix(cls, v: Optional[str]) -> Optional[str]: + if v is None: + return v + v = v.strip() + return v.lower() or None + + +class TagsAdd(BaseModel): + model_config = ConfigDict(extra="ignore") + tags: list[str] = Field(..., min_length=1) + + @field_validator("tags") + @classmethod + def normalize_tags(cls, v: list[str]) -> list[str]: + out = [] + for t in 
v: + if not isinstance(t, str): + raise TypeError("tags must be strings") + tnorm = t.strip().lower() + if tnorm: + out.append(tnorm) + seen = set() + deduplicated = [] + for x in out: + if x not in seen: + seen.add(x) + deduplicated.append(x) + return deduplicated + + +class TagsRemove(TagsAdd): + pass diff --git a/app/api/schemas_out.py b/app/api/schemas_out.py new file mode 100644 index 000000000..f86da3523 --- /dev/null +++ b/app/api/schemas_out.py @@ -0,0 +1,69 @@ +from datetime import datetime +from typing import Any, Optional +from pydantic import BaseModel, ConfigDict, Field, field_serializer + + +class AssetSummary(BaseModel): + id: int + name: str + asset_hash: str + size: Optional[int] = None + mime_type: Optional[str] = None + tags: list[str] = Field(default_factory=list) + preview_url: Optional[str] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + last_access_time: Optional[datetime] = None + + model_config = ConfigDict(from_attributes=True) + + @field_serializer("created_at", "updated_at", "last_access_time") + def _ser_dt(self, v: Optional[datetime], _info): + return v.isoformat() if v else None + + +class AssetsList(BaseModel): + assets: list[AssetSummary] + total: int + has_more: bool + + +class AssetUpdated(BaseModel): + id: int + name: str + asset_hash: str + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = Field(default_factory=dict) + updated_at: Optional[datetime] = None + + model_config = ConfigDict(from_attributes=True) + + @field_serializer("updated_at") + def _ser_updated(self, v: Optional[datetime], _info): + return v.isoformat() if v else None + + +class TagUsage(BaseModel): + name: str + count: int + type: str + + +class TagsList(BaseModel): + tags: list[TagUsage] = Field(default_factory=list) + total: int + has_more: bool + + +class TagsAdd(BaseModel): + model_config = ConfigDict(str_strip_whitespace=True) + added: list[str] = Field(default_factory=list) + already_present: list[str] = Field(default_factory=list) + total_tags: list[str] = Field(default_factory=list) + + +class TagsRemove(BaseModel): + model_config = ConfigDict(str_strip_whitespace=True) + removed: list[str] = Field(default_factory=list) + not_present: list[str] = Field(default_factory=list) + total_tags: list[str] = Field(default_factory=list) diff --git a/app/assets_manager.py b/app/assets_manager.py index 05031a1bf..60c3f08cd 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -14,7 +14,11 @@ from .database.services import ( list_asset_infos_page, update_asset_info_full, get_asset_tags, + list_tags_with_usage, + add_tags_to_asset_info, + remove_tags_from_asset_info, ) +from .api import schemas_out def populate_db_with_asset(tags: list[str], file_name: str, file_path: str) -> None: @@ -70,7 +74,7 @@ async def list_assets( offset: int = 0, sort: str | None = "created_at", order: str | None = "desc", -) -> dict: +) -> schemas_out.AssetsList: sort = _safe_sort_field(sort) order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower() @@ -87,30 +91,30 @@ async def list_assets( order=order, ) - assets_json = [] + summaries: list[schemas_out.AssetSummary] = [] for info in infos: - asset = info.asset # populated via contains_eager + asset = info.asset tags = tag_map.get(info.id, []) - assets_json.append( - { - "id": info.id, - "name": info.name, - "asset_hash": info.asset_hash, - "size": int(asset.size_bytes) if asset else None, - "mime_type": asset.mime_type if asset else None, - "tags": tags, 
- "preview_url": f"/api/v1/assets/{info.id}/content", # TODO: implement actual content endpoint later - "created_at": info.created_at.isoformat() if info.created_at else None, - "updated_at": info.updated_at.isoformat() if info.updated_at else None, - "last_access_time": info.last_access_time.isoformat() if info.last_access_time else None, - } + summaries.append( + schemas_out.AssetSummary( + id=info.id, + name=info.name, + asset_hash=info.asset_hash, + size=int(asset.size_bytes) if asset else None, + mime_type=asset.mime_type if asset else None, + tags=tags, + preview_url=f"/api/v1/assets/{info.id}/content", # TODO: implement actual content endpoint later + created_at=info.created_at, + updated_at=info.updated_at, + last_access_time=info.last_access_time, + ) ) - return { - "assets": assets_json, - "total": total, - "has_more": (offset + len(assets_json)) < total, - } + return schemas_out.AssetsList( + assets=summaries, + total=total, + has_more=(offset + len(summaries)) < total, + ) async def update_asset( @@ -119,7 +123,7 @@ async def update_asset( name: str | None = None, tags: list[str] | None = None, user_metadata: dict | None = None, -) -> dict: +) -> schemas_out.AssetUpdated: async with await create_session() as session: info = await update_asset_info_full( session, @@ -134,14 +138,40 @@ async def update_asset( tag_names = await get_asset_tags(session, asset_info_id=asset_info_id) await session.commit() - return { - "id": info.id, - "name": info.name, - "asset_hash": info.asset_hash, - "tags": tag_names, - "user_metadata": info.user_metadata or {}, - "updated_at": info.updated_at.isoformat() if info.updated_at else None, - } + return schemas_out.AssetUpdated( + id=info.id, + name=info.name, + asset_hash=info.asset_hash, + tags=tag_names, + user_metadata=info.user_metadata or {}, + updated_at=info.updated_at, + ) + + + +async def list_tags( + *, + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + order: str = "count_desc", + include_zero: bool = True, +) -> schemas_out.TagsList: + limit = max(1, min(1000, limit)) + offset = max(0, offset) + + async with await create_session() as session: + rows, total = await list_tags_with_usage( + session, + prefix=prefix, + limit=limit, + offset=offset, + include_zero=include_zero, + order=order, + ) + + tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows] + return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total) def _safe_sort_field(requested: str | None) -> str: @@ -156,3 +186,38 @@ def _safe_sort_field(requested: str | None) -> str: def _get_size_mtime_ns(path: str) -> tuple[int, int]: st = os.stat(path, follow_symlinks=True) return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + + +async def add_tags_to_asset( + *, + asset_info_id: int, + tags: list[str], + origin: str = "manual", + added_by: str | None = None, +) -> schemas_out.TagsAdd: + async with await create_session() as session: + data = await add_tags_to_asset_info( + session, + asset_info_id=asset_info_id, + tags=tags, + origin=origin, + added_by=added_by, + create_if_missing=True, + ) + await session.commit() + return schemas_out.TagsAdd(**data) + + +async def remove_tags_from_asset( + *, + asset_info_id: int, + tags: list[str], +) -> schemas_out.TagsRemove: + async with await create_session() as session: + data = await remove_tags_from_asset_info( + session, + asset_info_id=asset_info_id, + tags=tags, + ) + await session.commit() + return 
schemas_out.TagsRemove(**data) diff --git a/app/database/services.py b/app/database/services.py index c2792b4c4..3280fd534 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -493,7 +493,7 @@ async def replace_asset_info_metadata_projection( await session.flush() -async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[Tag]: +async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[str]: return [ tag_name for (tag_name,) in ( @@ -504,6 +504,179 @@ async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[T ] +async def list_tags_with_usage( + session, + *, + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + include_zero: bool = True, + order: str = "count_desc", # "count_desc" | "name_asc" +) -> tuple[list[tuple[str, str, int]], int]: + """ + Returns: + rows: list of (name, tag_type, count) + total: number of tags matching filter (independent of pagination) + """ + # Subquery with counts by tag_name + counts_sq = ( + select( + AssetInfoTag.tag_name.label("tag_name"), + func.count(AssetInfoTag.asset_info_id).label("cnt"), + ) + .group_by(AssetInfoTag.tag_name) + .subquery() + ) + + # Base select with LEFT JOIN so we can include zero-usage tags + q = ( + select( + Tag.name, + Tag.tag_type, + func.coalesce(counts_sq.c.cnt, 0).label("count"), + ) + .select_from(Tag) + .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) + ) + + # Prefix filter (tags are lowercase by check constraint) + if prefix: + q = q.where(Tag.name.like(prefix.strip().lower() + "%")) + + # Include_zero toggles: if False, drop zero-usage tags + if not include_zero: + q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) + + # Ordering + if order == "name_asc": + q = q.order_by(Tag.name.asc()) + else: # default "count_desc" + q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) + + # Total (without limit/offset, same filters) + total_q = select(func.count()).select_from(Tag) + if prefix: + total_q = total_q.where(Tag.name.like(prefix.strip().lower() + "%")) + if not include_zero: + # count only names that appear in counts subquery + total_q = total_q.where( + Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) + ) + + rows = (await session.execute(q.limit(limit).offset(offset))).all() + total = (await session.execute(total_q)).scalar_one() + + # Normalize counts to int for SQLite/Postgres consistency + rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] + return rows_norm, int(total or 0) + + +async def add_tags_to_asset_info( + session: AsyncSession, + *, + asset_info_id: int, + tags: Sequence[str], + origin: str = "manual", + added_by: Optional[str] = None, + create_if_missing: bool = True, +) -> dict: + """Adds tags to an AssetInfo. + If create_if_missing=True, missing tag rows are created as 'user'. + Returns: {"added": [...], "already_present": [...], "total_tags": [...]} + """ + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + norm = _normalize_tags(tags) + if not norm: + total = await get_asset_tags(session, asset_info_id=asset_info_id) + return {"added": [], "already_present": [], "total_tags": total} + + # Ensure tag rows exist if requested. 
+    if create_if_missing:
+        await _ensure_tags_exist(session, norm, tag_type="user")
+
+    # Tags currently linked to this AssetInfo.
+    existing = {
+        tname
+        for (tname,) in (
+            await session.execute(
+                sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
+            )
+        ).all()
+    }
+
+    to_add = [t for t in norm if t not in existing]
+    already = [t for t in norm if t in existing]
+
+    if to_add:
+        now = datetime.now(timezone.utc)
+        # Insert inside a SAVEPOINT so a duplicate-link race only rolls back this block,
+        # not the caller's outer transaction.
+        try:
+            async with session.begin_nested():
+                session.add_all([
+                    AssetInfoTag(
+                        asset_info_id=asset_info_id,
+                        tag_name=t,
+                        origin=origin,
+                        added_by=added_by,
+                        added_at=now,
+                    ) for t in to_add
+                ])
+                await session.flush()
+        except IntegrityError:
+            # Another writer linked the same tag concurrently; treat it as already present.
+            pass
+
+    total = await get_asset_tags(session, asset_info_id=asset_info_id)
+    return {"added": sorted(set(to_add)), "already_present": sorted(set(already)), "total_tags": total}
+
+
+async def remove_tags_from_asset_info(
+    session: AsyncSession,
+    *,
+    asset_info_id: int,
+    tags: Sequence[str],
+) -> dict:
+    """Removes tags from an AssetInfo.
+    Returns: {"removed": [...], "not_present": [...], "total_tags": [...]}
+    """
+    info = await session.get(AssetInfo, asset_info_id)
+    if not info:
+        raise ValueError(f"AssetInfo {asset_info_id} not found")
+
+    norm = _normalize_tags(tags)
+    if not norm:
+        total = await get_asset_tags(session, asset_info_id=asset_info_id)
+        return {"removed": [], "not_present": [], "total_tags": total}
+
+    existing = {
+        tname
+        for (tname,) in (
+            await session.execute(
+                sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
+            )
+        ).all()
+    }
+
+    to_remove = sorted(set(t for t in norm if t in existing))
+    not_present = sorted(set(t for t in norm if t not in existing))
+
+    if to_remove:
+        await session.execute(
+            delete(AssetInfoTag)
+            .where(
+                AssetInfoTag.asset_info_id == asset_info_id,
+                AssetInfoTag.tag_name.in_(to_remove),
+            )
+        )
+        await session.flush()
+
+    total = await get_asset_tags(session, asset_info_id=asset_info_id)
+    return {"removed": to_remove, "not_present": not_present, "total_tags": total}
+
+
 def _normalize_tags(tags: Sequence[str] | None) -> list[str]:
     return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
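
For reviewers, a minimal end-to-end sketch of how the new tag endpoints behave once the routes are registered. This is illustrative only and not part of the patch: it assumes an app factory that calls register_assets_routes() and an already-initialized assets database, and it uses aiohttp's bundled test utilities rather than any project-specific test harness.

import asyncio

from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer

from app.api.assets_routes import register_assets_routes


async def main() -> None:
    app = web.Application()
    register_assets_routes(app)  # assumption: DB/session setup happens elsewhere (e.g. on_startup hooks)

    async with TestClient(TestServer(app)) as client:
        # GET /api/tags: usage-ordered vocabulary, optionally hiding unused tags.
        resp = await client.get("/api/tags", params={"order": "count_desc", "include_zero": "false"})
        print(resp.status, await resp.json())
        # -> {"tags": [{"name": ..., "count": ..., "type": ...}], "total": ..., "has_more": ...}

        # POST /api/assets/{id}/tags: tags are lower-cased, stripped and de-duplicated server-side.
        resp = await client.post("/api/assets/1/tags", json={"tags": ["Character", "lora ", "lora"]})
        print(resp.status, await resp.json())
        # -> {"added": [...], "already_present": [...], "total_tags": [...]}

        # DELETE /api/assets/{id}/tags: a tag that is not linked is reported under "not_present".
        resp = await client.delete("/api/assets/1/tags", json={"tags": ["character", "missing-tag"]})
        print(resp.status, await resp.json())
        # -> {"removed": [...], "not_present": [...], "total_tags": [...]}


if __name__ == "__main__":
    asyncio.run(main())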