From 17ad7e393f0ab45b28de083eff35d5c3c384c861 Mon Sep 17 00:00:00 2001 From: Luke Mino-Altherr Date: Tue, 3 Feb 2026 10:25:24 -0800 Subject: [PATCH] refactor(assets): split queries.py into modular query modules Split the ~1000 line app/assets/database/queries.py into focused modules: - queries/asset.py - Asset entity queries (asset_exists_by_hash, get_asset_by_hash) - queries/asset_info.py - AssetInfo queries (~15 functions) - queries/cache_state.py - AssetCacheState queries (list_cache_states_by_asset_id, pick_best_live_path, prune_orphaned_assets, fast_db_consistency_pass) - queries/tags.py - Tag queries (8 functions including ensure_tags_exist, add/remove tag functions, list_tags_with_usage) - queries/__init__.py - Re-exports all public functions for backward compatibility Also adds comprehensive unit tests using in-memory SQLite: - tests-unit/assets_test/queries/conftest.py - Session fixture - tests-unit/assets_test/queries/test_asset.py - 5 tests - tests-unit/assets_test/queries/test_asset_info.py - 23 tests - tests-unit/assets_test/queries/test_cache_state.py - 8 tests - tests-unit/assets_test/queries/test_metadata.py - 12 tests for _apply_metadata_filter - tests-unit/assets_test/queries/test_tags.py - 23 tests All 71 unit tests pass. Existing integration tests unaffected. Amp-Thread-ID: https://ampcode.com/threads/T-019c24bb-475b-7442-9ff9-8288edea3345 Co-authored-by: Amp --- app/assets/database/queries/__init__.py | 73 ++++ app/assets/database/queries/asset.py | 31 ++ .../{queries.py => queries/asset_info.py} | 331 ++---------------- app/assets/database/queries/cache_state.py | 212 +++++++++++ app/assets/database/queries/tags.py | 280 +++++++++++++++ app/assets/database/tags.py | 62 ---- app/assets/scanner.py | 190 ++-------- tests-unit/assets_test/queries/conftest.py | 14 + tests-unit/assets_test/queries/test_asset.py | 39 +++ .../assets_test/queries/test_asset_info.py | 268 ++++++++++++++ .../assets_test/queries/test_cache_state.py | 128 +++++++ .../assets_test/queries/test_metadata.py | 180 ++++++++++ tests-unit/assets_test/queries/test_tags.py | 297 ++++++++++++++++ 13 files changed, 1569 insertions(+), 536 deletions(-) create mode 100644 app/assets/database/queries/__init__.py create mode 100644 app/assets/database/queries/asset.py rename app/assets/database/{queries.py => queries/asset_info.py} (68%) create mode 100644 app/assets/database/queries/cache_state.py create mode 100644 app/assets/database/queries/tags.py delete mode 100644 app/assets/database/tags.py create mode 100644 tests-unit/assets_test/queries/conftest.py create mode 100644 tests-unit/assets_test/queries/test_asset.py create mode 100644 tests-unit/assets_test/queries/test_asset_info.py create mode 100644 tests-unit/assets_test/queries/test_cache_state.py create mode 100644 tests-unit/assets_test/queries/test_metadata.py create mode 100644 tests-unit/assets_test/queries/test_tags.py diff --git a/app/assets/database/queries/__init__.py b/app/assets/database/queries/__init__.py new file mode 100644 index 000000000..874ae3bf2 --- /dev/null +++ b/app/assets/database/queries/__init__.py @@ -0,0 +1,73 @@ +# Re-export public API from query modules +# Maintains backward compatibility with old flat queries.py imports + +from app.assets.database.queries.asset import ( + asset_exists_by_hash, + get_asset_by_hash, +) + +from app.assets.database.queries.asset_info import ( + asset_info_exists_for_asset_id, + get_asset_info_by_id, + list_asset_infos_page, + fetch_asset_info_asset_and_tags, + fetch_asset_info_and_asset, + 
touch_asset_info_by_id, + create_asset_info_for_existing_asset, + replace_asset_info_metadata_projection, + ingest_fs_asset, + update_asset_info_full, + delete_asset_info_by_id, + set_asset_info_preview, +) + +from app.assets.database.queries.cache_state import ( + list_cache_states_by_asset_id, + pick_best_live_path, + prune_orphaned_assets, + fast_db_consistency_pass, +) + +from app.assets.database.queries.tags import ( + ensure_tags_exist, + get_asset_tags, + set_asset_info_tags, + add_tags_to_asset_info, + remove_tags_from_asset_info, + add_missing_tag_for_asset_id, + remove_missing_tag_for_asset_id, + list_tags_with_usage, +) + +__all__ = [ + # asset.py + "asset_exists_by_hash", + "get_asset_by_hash", + # asset_info.py + "asset_info_exists_for_asset_id", + "get_asset_info_by_id", + "list_asset_infos_page", + "fetch_asset_info_asset_and_tags", + "fetch_asset_info_and_asset", + "touch_asset_info_by_id", + "create_asset_info_for_existing_asset", + "replace_asset_info_metadata_projection", + "ingest_fs_asset", + "update_asset_info_full", + "delete_asset_info_by_id", + "set_asset_info_preview", + # cache_state.py + "list_cache_states_by_asset_id", + "pick_best_live_path", + "prune_orphaned_assets", + "fast_db_consistency_pass", + # tags.py + "ensure_tags_exist", + "get_asset_tags", + "set_asset_info_tags", + "add_tags_to_asset_info", + "remove_tags_from_asset_info", + "add_missing_tag_for_asset_id", + "remove_missing_tag_for_asset_id", + "list_tags_with_usage", +] diff --git a/app/assets/database/queries/asset.py b/app/assets/database/queries/asset.py new file mode 100644 index 000000000..5d00991a0 --- /dev/null +++ b/app/assets/database/queries/asset.py @@ -0,0 +1,31 @@ +import sqlalchemy as sa +from sqlalchemy import select +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset + + +def asset_exists_by_hash( + session: Session, + *, + asset_hash: str, +) -> bool: + """ + Check if an asset with a given hash exists in database. 
+ """ + row = ( + session.execute( + select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1) + ) + ).first() + return row is not None + + +def get_asset_by_hash( + session: Session, + *, + asset_hash: str, +) -> Asset | None: + return ( + session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + ).scalars().first() diff --git a/app/assets/database/queries.py b/app/assets/database/queries/asset_info.py similarity index 68% rename from app/assets/database/queries.py rename to app/assets/database/queries/asset_info.py index d6b33ec7b..f238138b3 100644 --- a/app/assets/database/queries.py +++ b/app/assets/database/queries/asset_info.py @@ -1,21 +1,27 @@ import os import logging -import sqlalchemy as sa from collections import defaultdict from datetime import datetime -from typing import Iterable, Any -from sqlalchemy import select, delete, exists, func +from typing import Any, Sequence + +import sqlalchemy as sa +from sqlalchemy import select, delete, exists from sqlalchemy.dialects import sqlite from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session, contains_eager, noload -from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag + +from app.assets.database.models import ( + Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag +) from app.assets.helpers import ( compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow ) -from typing import Sequence +from app.assets.database.queries.asset import get_asset_by_hash +from app.assets.database.queries.cache_state import list_cache_states_by_asset_id, pick_best_live_path +from app.assets.database.queries.tags import ensure_tags_exist, set_asset_info_tags, remove_missing_tag_for_asset_id -def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: +def _visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: """Build owner visibility predicate for reads. Owner-less rows are visible to everyone.""" owner_id = (owner_id or "").strip() if owner_id == "": @@ -23,23 +29,7 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: return AssetInfo.owner_id.in_(["", owner_id]) -def pick_best_live_path(states: Sequence[AssetCacheState]) -> str: - """ - Return the best on-disk path among cache states: - 1) Prefer a path that exists with needs_verify == False (already verified). - 2) Otherwise, pick the first path that exists. - 3) Otherwise return empty string. - """ - alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)] - if not alive: - return "" - for s in alive: - if not getattr(s, "needs_verify", False): - return s.file_path - return alive[0].file_path - - -def apply_tag_filters( +def _apply_tag_filters( stmt: sa.sql.Select, include_tags: Sequence[str] | None = None, exclude_tags: Sequence[str] | None = None, @@ -67,7 +57,7 @@ def apply_tag_filters( return stmt -def apply_metadata_filter( +def _apply_metadata_filter( stmt: sa.sql.Select, metadata_filter: dict | None = None, ) -> sa.sql.Select: @@ -119,22 +109,6 @@ def apply_metadata_filter( return stmt -def asset_exists_by_hash( - session: Session, - *, - asset_hash: str, -) -> bool: - """ - Check if an asset with a given hash exists in database. 
- """ - row = ( - session.execute( - select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1) - ) - ).first() - return row is not None - - def asset_info_exists_for_asset_id( session: Session, *, @@ -149,16 +123,6 @@ def asset_info_exists_for_asset_id( return (session.execute(q)).first() is not None -def get_asset_by_hash( - session: Session, - *, - asset_hash: str, -) -> Asset | None: - return ( - session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) - ).scalars().first() - - def get_asset_info_by_id( session: Session, *, @@ -183,15 +147,15 @@ def list_asset_infos_page( select(AssetInfo) .join(Asset, Asset.id == AssetInfo.asset_id) .options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags)) - .where(visible_owner_clause(owner_id)) + .where(_visible_owner_clause(owner_id)) ) if name_contains: escaped, esc = escape_like_prefix(name_contains) base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) - base = apply_tag_filters(base, include_tags, exclude_tags) - base = apply_metadata_filter(base, metadata_filter) + base = _apply_tag_filters(base, include_tags, exclude_tags) + base = _apply_metadata_filter(base, metadata_filter) sort = (sort or "created_at").lower() order = (order or "desc").lower() @@ -211,13 +175,13 @@ def list_asset_infos_page( select(sa.func.count()) .select_from(AssetInfo) .join(Asset, Asset.id == AssetInfo.asset_id) - .where(visible_owner_clause(owner_id)) + .where(_visible_owner_clause(owner_id)) ) if name_contains: escaped, esc = escape_like_prefix(name_contains) count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) - count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) - count_stmt = apply_metadata_filter(count_stmt, metadata_filter) + count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = _apply_metadata_filter(count_stmt, metadata_filter) total = int((session.execute(count_stmt)).scalar_one() or 0) @@ -250,7 +214,7 @@ def fetch_asset_info_asset_and_tags( .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) .where( AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), + _visible_owner_clause(owner_id), ) .options(noload(AssetInfo.tags)) .order_by(Tag.name.asc()) @@ -281,7 +245,7 @@ def fetch_asset_info_and_asset( .join(Asset, Asset.id == AssetInfo.asset_id) .where( AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), + _visible_owner_clause(owner_id), ) .limit(1) .options(noload(AssetInfo.tags)) @@ -292,17 +256,6 @@ def fetch_asset_info_and_asset( return None return pair[0], pair[1] -def list_cache_states_by_asset_id( - session: Session, *, asset_id: str -) -> Sequence[AssetCacheState]: - return ( - session.execute( - select(AssetCacheState) - .where(AssetCacheState.asset_id == asset_id) - .order_by(AssetCacheState.id.asc()) - ) - ).scalars().all() - def touch_asset_info_by_id( session: Session, @@ -366,7 +319,6 @@ def create_asset_info_for_existing_asset( raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.") return existing - # metadata["filename"] hack new_meta = dict(user_metadata or {}) computed_filename = None try: @@ -394,42 +346,6 @@ def create_asset_info_for_existing_asset( return info -def set_asset_info_tags( - session: Session, - *, - asset_info_id: str, - tags: Sequence[str], - origin: str = "manual", -) -> dict: - desired = normalize_tags(tags) - - current = set( - tag_name for (tag_name,) in ( - 
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)) - ).all() - ) - - to_add = [t for t in desired if t not in current] - to_remove = [t for t in current if t not in desired] - - if to_add: - ensure_tags_exist(session, to_add, tag_type="user") - session.add_all([ - AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow()) - for t in to_add - ]) - session.flush() - - if to_remove: - session.execute( - delete(AssetInfoTag) - .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove)) - ) - session.flush() - - return {"added": to_add, "removed": to_remove, "total": desired} - - def replace_asset_info_metadata_projection( session: Session, *, @@ -507,7 +423,6 @@ def ingest_fs_asset( "asset_info_id": None, } - # 1) Asset by hash asset = ( session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) ).scalars().first() @@ -543,7 +458,6 @@ def ingest_fs_asset( if changed: out["asset_updated"] = True - # 2) AssetCacheState upsert by file_path (unique) vals = { "asset_id": asset.id, "file_path": locator, @@ -575,7 +489,6 @@ def ingest_fs_asset( if int(res2.rowcount or 0) > 0: out["state_updated"] = True - # 3) Optional AssetInfo + tags + metadata if info_name: try: with session.begin_nested(): @@ -652,7 +565,6 @@ def ingest_fs_asset( ) session.flush() - # metadata["filename"] hack if out["asset_info_id"] is not None: primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id)) computed_filename = compute_relative_filename(primary_path) if primary_path else None @@ -752,207 +664,11 @@ def delete_asset_info_by_id( ) -> bool: stmt = sa.delete(AssetInfo).where( AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), + _visible_owner_clause(owner_id), ) return int((session.execute(stmt)).rowcount or 0) > 0 -def list_tags_with_usage( - session: Session, - prefix: str | None = None, - limit: int = 100, - offset: int = 0, - include_zero: bool = True, - order: str = "count_desc", - owner_id: str = "", -) -> tuple[list[tuple[str, str, int]], int]: - counts_sq = ( - select( - AssetInfoTag.tag_name.label("tag_name"), - func.count(AssetInfoTag.asset_info_id).label("cnt"), - ) - .select_from(AssetInfoTag) - .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id) - .where(visible_owner_clause(owner_id)) - .group_by(AssetInfoTag.tag_name) - .subquery() - ) - - q = ( - select( - Tag.name, - Tag.tag_type, - func.coalesce(counts_sq.c.cnt, 0).label("count"), - ) - .select_from(Tag) - .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) - ) - - if prefix: - escaped, esc = escape_like_prefix(prefix.strip().lower()) - q = q.where(Tag.name.like(escaped + "%", escape=esc)) - - if not include_zero: - q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) - - if order == "name_asc": - q = q.order_by(Tag.name.asc()) - else: - q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) - - total_q = select(func.count()).select_from(Tag) - if prefix: - escaped, esc = escape_like_prefix(prefix.strip().lower()) - total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) - if not include_zero: - total_q = total_q.where( - Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) - ) - - rows = (session.execute(q.limit(limit).offset(offset))).all() - total = (session.execute(total_q)).scalar_one() - - rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] - return rows_norm, int(total or 0) - - -def 
ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None: - wanted = normalize_tags(list(names)) - if not wanted: - return - rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] - ins = ( - sqlite.insert(Tag) - .values(rows) - .on_conflict_do_nothing(index_elements=[Tag.name]) - ) - session.execute(ins) - - -def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]: - return [ - tag_name for (tag_name,) in ( - session.execute( - select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - ] - - -def add_tags_to_asset_info( - session: Session, - *, - asset_info_id: str, - tags: Sequence[str], - origin: str = "manual", - create_if_missing: bool = True, - asset_info_row: Any = None, -) -> dict: - if not asset_info_row: - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - norm = normalize_tags(tags) - if not norm: - total = get_asset_tags(session, asset_info_id=asset_info_id) - return {"added": [], "already_present": [], "total_tags": total} - - if create_if_missing: - ensure_tags_exist(session, norm, tag_type="user") - - current = { - tag_name - for (tag_name,) in ( - session.execute( - sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - } - - want = set(norm) - to_add = sorted(want - current) - - if to_add: - with session.begin_nested() as nested: - try: - session.add_all( - [ - AssetInfoTag( - asset_info_id=asset_info_id, - tag_name=t, - origin=origin, - added_at=utcnow(), - ) - for t in to_add - ] - ) - session.flush() - except IntegrityError: - nested.rollback() - - after = set(get_asset_tags(session, asset_info_id=asset_info_id)) - return { - "added": sorted(((after - current) & want)), - "already_present": sorted(want & current), - "total_tags": sorted(after), - } - - -def remove_tags_from_asset_info( - session: Session, - *, - asset_info_id: str, - tags: Sequence[str], -) -> dict: - info = session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - norm = normalize_tags(tags) - if not norm: - total = get_asset_tags(session, asset_info_id=asset_info_id) - return {"removed": [], "not_present": [], "total_tags": total} - - existing = { - tag_name - for (tag_name,) in ( - session.execute( - sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - } - - to_remove = sorted(set(t for t in norm if t in existing)) - not_present = sorted(set(t for t in norm if t not in existing)) - - if to_remove: - session.execute( - delete(AssetInfoTag) - .where( - AssetInfoTag.asset_info_id == asset_info_id, - AssetInfoTag.tag_name.in_(to_remove), - ) - ) - session.flush() - - total = get_asset_tags(session, asset_info_id=asset_info_id) - return {"removed": to_remove, "not_present": not_present, "total_tags": total} - - -def remove_missing_tag_for_asset_id( - session: Session, - *, - asset_id: str, -) -> None: - session.execute( - sa.delete(AssetInfoTag).where( - AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), - AssetInfoTag.tag_name == "missing", - ) - ) - - def set_asset_info_preview( session: Session, *, @@ -967,7 +683,6 @@ def set_asset_info_preview( if preview_asset_id is None: info.preview_id = None else: - # validate preview asset exists if not session.get(Asset, preview_asset_id): raise ValueError(f"Preview Asset {preview_asset_id} 
not found") info.preview_id = preview_asset_id diff --git a/app/assets/database/queries/cache_state.py b/app/assets/database/queries/cache_state.py new file mode 100644 index 000000000..1da3c4430 --- /dev/null +++ b/app/assets/database/queries/cache_state.py @@ -0,0 +1,212 @@ +import os +from typing import Sequence + +import sqlalchemy as sa +from sqlalchemy import select +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetCacheState, AssetInfo +from app.assets.helpers import escape_like_prefix + + +def list_cache_states_by_asset_id( + session: Session, *, asset_id: str +) -> Sequence[AssetCacheState]: + return ( + session.execute( + select(AssetCacheState) + .where(AssetCacheState.asset_id == asset_id) + .order_by(AssetCacheState.id.asc()) + ) + ).scalars().all() + + +def pick_best_live_path(states: Sequence[AssetCacheState]) -> str: + """ + Return the best on-disk path among cache states: + 1) Prefer a path that exists with needs_verify == False (already verified). + 2) Otherwise, pick the first path that exists. + 3) Otherwise return empty string. + """ + alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)] + if not alive: + return "" + for s in alive: + if not getattr(s, "needs_verify", False): + return s.file_path + return alive[0].file_path + + +def prune_orphaned_assets(session: Session, roots: tuple[str, ...], prefixes_for_root_fn) -> int: + """Prune cache states outside configured prefixes, then delete orphaned seed assets. + + Args: + session: Database session + roots: Tuple of root types to prune + prefixes_for_root_fn: Function to get prefixes for a root type + + Returns: + Number of orphaned assets deleted + """ + all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root_fn(r)] + if not all_prefixes: + return 0 + + def make_prefix_condition(prefix: str): + base = prefix if prefix.endswith(os.sep) else prefix + os.sep + escaped, esc = escape_like_prefix(base) + return AssetCacheState.file_path.like(escaped + "%", escape=esc) + + matches_valid_prefix = sa.or_(*[make_prefix_condition(p) for p in all_prefixes]) + + orphan_subq = ( + sa.select(Asset.id) + .outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id) + .where(Asset.hash.is_(None), AssetCacheState.id.is_(None)) + ).scalar_subquery() + + session.execute(sa.delete(AssetCacheState).where(~matches_valid_prefix)) + session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq))) + result = session.execute(sa.delete(Asset).where(Asset.id.in_(orphan_subq))) + return result.rowcount + + +def fast_db_consistency_pass( + session: Session, + root: str, + *, + prefixes_for_root_fn, + escape_like_prefix_fn, + fast_asset_file_check_fn, + add_missing_tag_fn, + remove_missing_tag_fn, + collect_existing_paths: bool = False, + update_missing_tags: bool = False, +) -> set[str] | None: + """Fast DB+FS pass for a root: + - Toggle needs_verify per state using fast check + - For hashed assets with at least one fast-ok state in this root: delete stale missing states + - For seed assets with all states missing: delete Asset and its AssetInfos + - Optionally add/remove 'missing' tags based on fast-ok in this root + - Optionally return surviving absolute paths + """ + import contextlib + + prefixes = prefixes_for_root_fn(root) + if not prefixes: + return set() if collect_existing_paths else None + + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + escaped, esc = 
escape_like_prefix_fn(base) + conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) + + rows = ( + session.execute( + sa.select( + AssetCacheState.id, + AssetCacheState.file_path, + AssetCacheState.mtime_ns, + AssetCacheState.needs_verify, + AssetCacheState.asset_id, + Asset.hash, + Asset.size_bytes, + ) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(sa.or_(*conds)) + .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc()) + ) + ).all() + + by_asset: dict[str, dict] = {} + for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows: + acc = by_asset.get(aid) + if acc is None: + acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []} + by_asset[aid] = acc + + fast_ok = False + try: + exists = True + fast_ok = fast_asset_file_check_fn( + mtime_db=mtime_db, + size_db=acc["size_db"], + stat_result=os.stat(fp, follow_symlinks=True), + ) + except FileNotFoundError: + exists = False + except OSError: + exists = False + + acc["states"].append({ + "sid": sid, + "fp": fp, + "exists": exists, + "fast_ok": fast_ok, + "needs_verify": bool(needs_verify), + }) + + to_set_verify: list[int] = [] + to_clear_verify: list[int] = [] + stale_state_ids: list[int] = [] + survivors: set[str] = set() + + for aid, acc in by_asset.items(): + a_hash = acc["hash"] + states = acc["states"] + any_fast_ok = any(s["fast_ok"] for s in states) + all_missing = all(not s["exists"] for s in states) + + for s in states: + if not s["exists"]: + continue + if s["fast_ok"] and s["needs_verify"]: + to_clear_verify.append(s["sid"]) + if not s["fast_ok"] and not s["needs_verify"]: + to_set_verify.append(s["sid"]) + + if a_hash is None: + if states and all_missing: + session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == aid)) + asset = session.get(Asset, aid) + if asset: + session.delete(asset) + else: + for s in states: + if s["exists"]: + survivors.add(os.path.abspath(s["fp"])) + continue + + if any_fast_ok: + for s in states: + if not s["exists"]: + stale_state_ids.append(s["sid"]) + if update_missing_tags: + with contextlib.suppress(Exception): + remove_missing_tag_fn(session, asset_id=aid) + elif update_missing_tags: + with contextlib.suppress(Exception): + add_missing_tag_fn(session, asset_id=aid, origin="automatic") + + for s in states: + if s["exists"]: + survivors.add(os.path.abspath(s["fp"])) + + if stale_state_ids: + session.execute(sa.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids))) + if to_set_verify: + session.execute( + sa.update(AssetCacheState) + .where(AssetCacheState.id.in_(to_set_verify)) + .values(needs_verify=True) + ) + if to_clear_verify: + session.execute( + sa.update(AssetCacheState) + .where(AssetCacheState.id.in_(to_clear_verify)) + .values(needs_verify=False) + ) + return survivors if collect_existing_paths else None diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py new file mode 100644 index 000000000..09de20672 --- /dev/null +++ b/app/assets/database/queries/tags.py @@ -0,0 +1,280 @@ +from typing import Iterable, Sequence + +import sqlalchemy as sa +from sqlalchemy import select, delete, func +from sqlalchemy.dialects import sqlite +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from app.assets.database.models import AssetInfo, AssetInfoTag, Tag +from app.assets.helpers import escape_like_prefix, normalize_tags, utcnow + + +def _visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: + """Build owner visibility predicate for reads. 
Owner-less rows are visible to everyone.""" + owner_id = (owner_id or "").strip() + if owner_id == "": + return AssetInfo.owner_id == "" + return AssetInfo.owner_id.in_(["", owner_id]) + + +def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None: + wanted = normalize_tags(list(names)) + if not wanted: + return + rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] + ins = ( + sqlite.insert(Tag) + .values(rows) + .on_conflict_do_nothing(index_elements=[Tag.name]) + ) + session.execute(ins) + + +def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]: + return [ + tag_name for (tag_name,) in ( + session.execute( + select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + ] + + +def set_asset_info_tags( + session: Session, + *, + asset_info_id: str, + tags: Sequence[str], + origin: str = "manual", +) -> dict: + desired = normalize_tags(tags) + + current = set( + tag_name for (tag_name,) in ( + session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)) + ).all() + ) + + to_add = [t for t in desired if t not in current] + to_remove = [t for t in current if t not in desired] + + if to_add: + ensure_tags_exist(session, to_add, tag_type="user") + session.add_all([ + AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow()) + for t in to_add + ]) + session.flush() + + if to_remove: + session.execute( + delete(AssetInfoTag) + .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove)) + ) + session.flush() + + return {"added": to_add, "removed": to_remove, "total": desired} + + +def add_tags_to_asset_info( + session: Session, + *, + asset_info_id: str, + tags: Sequence[str], + origin: str = "manual", + create_if_missing: bool = True, + asset_info_row = None, +) -> dict: + if not asset_info_row: + info = session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = get_asset_tags(session, asset_info_id=asset_info_id) + return {"added": [], "already_present": [], "total_tags": total} + + if create_if_missing: + ensure_tags_exist(session, norm, tag_type="user") + + current = { + tag_name + for (tag_name,) in ( + session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + } + + want = set(norm) + to_add = sorted(want - current) + + if to_add: + with session.begin_nested() as nested: + try: + session.add_all( + [ + AssetInfoTag( + asset_info_id=asset_info_id, + tag_name=t, + origin=origin, + added_at=utcnow(), + ) + for t in to_add + ] + ) + session.flush() + except IntegrityError: + nested.rollback() + + after = set(get_asset_tags(session, asset_info_id=asset_info_id)) + return { + "added": sorted(((after - current) & want)), + "already_present": sorted(want & current), + "total_tags": sorted(after), + } + + +def remove_tags_from_asset_info( + session: Session, + *, + asset_info_id: str, + tags: Sequence[str], +) -> dict: + info = session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = get_asset_tags(session, asset_info_id=asset_info_id) + return {"removed": [], "not_present": [], "total_tags": total} + + existing = { + tag_name + for (tag_name,) in ( + session.execute( + 
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + } + + to_remove = sorted(set(t for t in norm if t in existing)) + not_present = sorted(set(t for t in norm if t not in existing)) + + if to_remove: + session.execute( + delete(AssetInfoTag) + .where( + AssetInfoTag.asset_info_id == asset_info_id, + AssetInfoTag.tag_name.in_(to_remove), + ) + ) + session.flush() + + total = get_asset_tags(session, asset_info_id=asset_info_id) + return {"removed": to_remove, "not_present": not_present, "total_tags": total} + + +def add_missing_tag_for_asset_id( + session: Session, + *, + asset_id: str, + origin: str = "automatic", +) -> None: + select_rows = ( + sa.select( + AssetInfo.id.label("asset_info_id"), + sa.literal("missing").label("tag_name"), + sa.literal(origin).label("origin"), + sa.literal(utcnow()).label("added_at"), + ) + .where(AssetInfo.asset_id == asset_id) + .where( + sa.not_( + sa.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing")) + ) + ) + ) + session.execute( + sqlite.insert(AssetInfoTag) + .from_select( + ["asset_info_id", "tag_name", "origin", "added_at"], + select_rows, + ) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + + +def remove_missing_tag_for_asset_id( + session: Session, + *, + asset_id: str, +) -> None: + session.execute( + sa.delete(AssetInfoTag).where( + AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), + AssetInfoTag.tag_name == "missing", + ) + ) + + +def list_tags_with_usage( + session: Session, + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + include_zero: bool = True, + order: str = "count_desc", + owner_id: str = "", +) -> tuple[list[tuple[str, str, int]], int]: + counts_sq = ( + select( + AssetInfoTag.tag_name.label("tag_name"), + func.count(AssetInfoTag.asset_info_id).label("cnt"), + ) + .select_from(AssetInfoTag) + .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id) + .where(_visible_owner_clause(owner_id)) + .group_by(AssetInfoTag.tag_name) + .subquery() + ) + + q = ( + select( + Tag.name, + Tag.tag_type, + func.coalesce(counts_sq.c.cnt, 0).label("count"), + ) + .select_from(Tag) + .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) + ) + + if prefix: + escaped, esc = escape_like_prefix(prefix.strip().lower()) + q = q.where(Tag.name.like(escaped + "%", escape=esc)) + + if not include_zero: + q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) + + if order == "name_asc": + q = q.order_by(Tag.name.asc()) + else: + q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) + + total_q = select(func.count()).select_from(Tag) + if prefix: + escaped, esc = escape_like_prefix(prefix.strip().lower()) + total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) + if not include_zero: + total_q = total_q.where( + Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) + ) + + rows = (session.execute(q.limit(limit).offset(offset))).all() + total = (session.execute(total_q)).scalar_one() + + rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] + return rows_norm, int(total or 0) diff --git a/app/assets/database/tags.py b/app/assets/database/tags.py deleted file mode 100644 index 3ab6497c2..000000000 --- a/app/assets/database/tags.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Iterable - -import sqlalchemy -from sqlalchemy.orm import Session -from sqlalchemy.dialects 
import sqlite - -from app.assets.helpers import normalize_tags, utcnow -from app.assets.database.models import Tag, AssetInfoTag, AssetInfo - - -def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None: - wanted = normalize_tags(list(names)) - if not wanted: - return - rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] - ins = ( - sqlite.insert(Tag) - .values(rows) - .on_conflict_do_nothing(index_elements=[Tag.name]) - ) - return session.execute(ins) - -def add_missing_tag_for_asset_id( - session: Session, - *, - asset_id: str, - origin: str = "automatic", -) -> None: - select_rows = ( - sqlalchemy.select( - AssetInfo.id.label("asset_info_id"), - sqlalchemy.literal("missing").label("tag_name"), - sqlalchemy.literal(origin).label("origin"), - sqlalchemy.literal(utcnow()).label("added_at"), - ) - .where(AssetInfo.asset_id == asset_id) - .where( - sqlalchemy.not_( - sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing")) - ) - ) - ) - session.execute( - sqlite.insert(AssetInfoTag) - .from_select( - ["asset_info_id", "tag_name", "origin", "added_at"], - select_rows, - ) - .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) - ) - -def remove_missing_tag_for_asset_id( - session: Session, - *, - asset_id: str, -) -> None: - session.execute( - sqlalchemy.delete(AssetInfoTag).where( - AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), - AssetInfoTag.tag_name == "missing", - ) - ) diff --git a/app/assets/scanner.py b/app/assets/scanner.py index 0172a5c2f..3bba7848c 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -1,19 +1,22 @@ -import contextlib import time import logging import os -import sqlalchemy import folder_paths from app.database.db import create_session, dependencies_available from app.assets.helpers import ( collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path, - list_tree,prefixes_for_root, escape_like_prefix, + list_tree, prefixes_for_root, escape_like_prefix, RootType ) -from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id +from app.assets.database.queries import ( + add_missing_tag_for_asset_id, + ensure_tags_exist, + remove_missing_tag_for_asset_id, + prune_orphaned_assets, + fast_db_consistency_pass, +) from app.assets.database.bulk_ops import seed_from_paths_batch -from app.assets.database.models import Asset, AssetCacheState, AssetInfo def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None: @@ -33,14 +36,28 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No existing_paths: set[str] = set() for r in roots: try: - survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True) + with create_session() as sess: + survivors: set[str] = fast_db_consistency_pass( + sess, + r, + prefixes_for_root_fn=prefixes_for_root, + escape_like_prefix_fn=escape_like_prefix, + fast_asset_file_check_fn=fast_asset_file_check, + add_missing_tag_fn=add_missing_tag_for_asset_id, + remove_missing_tag_fn=remove_missing_tag_for_asset_id, + collect_existing_paths=True, + update_missing_tags=True, + ) + sess.commit() if survivors: existing_paths.update(survivors) except Exception as e: logging.exception("fast DB scan failed for %s: %s", r, e) try: - orphans_pruned = 
_prune_orphaned_assets(roots) + with create_session() as sess: + orphans_pruned = prune_orphaned_assets(sess, roots, prefixes_for_root) + sess.commit() except Exception as e: logging.exception("orphan pruning failed: %s", e) @@ -101,163 +118,4 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No ) -def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int: - """Prune cache states outside configured prefixes, then delete orphaned seed assets.""" - all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)] - if not all_prefixes: - return 0 - def make_prefix_condition(prefix: str): - base = prefix if prefix.endswith(os.sep) else prefix + os.sep - escaped, esc = escape_like_prefix(base) - return AssetCacheState.file_path.like(escaped + "%", escape=esc) - - matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes]) - - orphan_subq = ( - sqlalchemy.select(Asset.id) - .outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id) - .where(Asset.hash.is_(None), AssetCacheState.id.is_(None)) - ).scalar_subquery() - - with create_session() as sess: - sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix)) - sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq))) - result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq))) - sess.commit() - return result.rowcount - - -def _fast_db_consistency_pass( - root: RootType, - *, - collect_existing_paths: bool = False, - update_missing_tags: bool = False, -) -> set[str] | None: - """Fast DB+FS pass for a root: - - Toggle needs_verify per state using fast check - - For hashed assets with at least one fast-ok state in this root: delete stale missing states - - For seed assets with all states missing: delete Asset and its AssetInfos - - Optionally add/remove 'missing' tags based on fast-ok in this root - - Optionally return surviving absolute paths - """ - prefixes = prefixes_for_root(root) - if not prefixes: - return set() if collect_existing_paths else None - - conds = [] - for p in prefixes: - base = os.path.abspath(p) - if not base.endswith(os.sep): - base += os.sep - escaped, esc = escape_like_prefix(base) - conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) - - with create_session() as sess: - rows = ( - sess.execute( - sqlalchemy.select( - AssetCacheState.id, - AssetCacheState.file_path, - AssetCacheState.mtime_ns, - AssetCacheState.needs_verify, - AssetCacheState.asset_id, - Asset.hash, - Asset.size_bytes, - ) - .join(Asset, Asset.id == AssetCacheState.asset_id) - .where(sqlalchemy.or_(*conds)) - .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc()) - ) - ).all() - - by_asset: dict[str, dict] = {} - for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows: - acc = by_asset.get(aid) - if acc is None: - acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []} - by_asset[aid] = acc - - fast_ok = False - try: - exists = True - fast_ok = fast_asset_file_check( - mtime_db=mtime_db, - size_db=acc["size_db"], - stat_result=os.stat(fp, follow_symlinks=True), - ) - except FileNotFoundError: - exists = False - except OSError: - exists = False - - acc["states"].append({ - "sid": sid, - "fp": fp, - "exists": exists, - "fast_ok": fast_ok, - "needs_verify": bool(needs_verify), - }) - - to_set_verify: list[int] = [] - to_clear_verify: list[int] = [] - stale_state_ids: list[int] = [] - survivors: set[str] = set() - - for aid, acc in by_asset.items(): - 
a_hash = acc["hash"] - states = acc["states"] - any_fast_ok = any(s["fast_ok"] for s in states) - all_missing = all(not s["exists"] for s in states) - - for s in states: - if not s["exists"]: - continue - if s["fast_ok"] and s["needs_verify"]: - to_clear_verify.append(s["sid"]) - if not s["fast_ok"] and not s["needs_verify"]: - to_set_verify.append(s["sid"]) - - if a_hash is None: - if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists - sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid)) - asset = sess.get(Asset, aid) - if asset: - sess.delete(asset) - else: - for s in states: - if s["exists"]: - survivors.add(os.path.abspath(s["fp"])) - continue - - if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records - for s in states: - if not s["exists"]: - stale_state_ids.append(s["sid"]) - if update_missing_tags: - with contextlib.suppress(Exception): - remove_missing_tag_for_asset_id(sess, asset_id=aid) - elif update_missing_tags: - with contextlib.suppress(Exception): - add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") - - for s in states: - if s["exists"]: - survivors.add(os.path.abspath(s["fp"])) - - if stale_state_ids: - sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids))) - if to_set_verify: - sess.execute( - sqlalchemy.update(AssetCacheState) - .where(AssetCacheState.id.in_(to_set_verify)) - .values(needs_verify=True) - ) - if to_clear_verify: - sess.execute( - sqlalchemy.update(AssetCacheState) - .where(AssetCacheState.id.in_(to_clear_verify)) - .values(needs_verify=False) - ) - sess.commit() - return survivors if collect_existing_paths else None diff --git a/tests-unit/assets_test/queries/conftest.py b/tests-unit/assets_test/queries/conftest.py new file mode 100644 index 000000000..6e05031db --- /dev/null +++ b/tests-unit/assets_test/queries/conftest.py @@ -0,0 +1,14 @@ +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import Session + +from app.assets.database.models import Base + + +@pytest.fixture +def session(): + """In-memory SQLite session for fast unit tests.""" + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + with Session(engine) as sess: + yield sess diff --git a/tests-unit/assets_test/queries/test_asset.py b/tests-unit/assets_test/queries/test_asset.py new file mode 100644 index 000000000..432910435 --- /dev/null +++ b/tests-unit/assets_test/queries/test_asset.py @@ -0,0 +1,39 @@ +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset +from app.assets.database.queries import asset_exists_by_hash, get_asset_by_hash + + +class TestAssetExistsByHash: + def test_returns_false_for_nonexistent(self, session: Session): + assert asset_exists_by_hash(session, asset_hash="nonexistent") is False + + def test_returns_true_for_existing(self, session: Session): + asset = Asset(hash="blake3:abc123", size_bytes=100) + session.add(asset) + session.commit() + + assert asset_exists_by_hash(session, asset_hash="blake3:abc123") is True + + def test_does_not_match_null_hash(self, session: Session): + asset = Asset(hash=None, size_bytes=100) + session.add(asset) + session.commit() + + assert asset_exists_by_hash(session, asset_hash="") is False + + +class TestGetAssetByHash: + def test_returns_none_for_nonexistent(self, session: Session): + assert get_asset_by_hash(session, asset_hash="nonexistent") is None + + def 
test_returns_asset_for_existing(self, session: Session): + asset = Asset(hash="blake3:def456", size_bytes=200, mime_type="image/png") + session.add(asset) + session.commit() + + result = get_asset_by_hash(session, asset_hash="blake3:def456") + assert result is not None + assert result.id == asset.id + assert result.size_bytes == 200 + assert result.mime_type == "image/png" diff --git a/tests-unit/assets_test/queries/test_asset_info.py b/tests-unit/assets_test/queries/test_asset_info.py new file mode 100644 index 000000000..18699da5c --- /dev/null +++ b/tests-unit/assets_test/queries/test_asset_info.py @@ -0,0 +1,268 @@ +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag +from app.assets.database.queries import ( + asset_info_exists_for_asset_id, + get_asset_info_by_id, + list_asset_infos_page, + fetch_asset_info_asset_and_tags, + fetch_asset_info_and_asset, + touch_asset_info_by_id, + delete_asset_info_by_id, + set_asset_info_preview, + ensure_tags_exist, + add_tags_to_asset_info, +) +from app.assets.helpers import utcnow + + +def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset: + asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream") + session.add(asset) + session.flush() + return asset + + +def _make_asset_info( + session: Session, + asset: Asset, + name: str = "test", + owner_id: str = "", +) -> AssetInfo: + now = utcnow() + info = AssetInfo( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(info) + session.flush() + return info + + +class TestAssetInfoExistsForAssetId: + def test_returns_false_when_no_info(self, session: Session): + asset = _make_asset(session, "hash1") + assert asset_info_exists_for_asset_id(session, asset_id=asset.id) is False + + def test_returns_true_when_info_exists(self, session: Session): + asset = _make_asset(session, "hash1") + _make_asset_info(session, asset) + assert asset_info_exists_for_asset_id(session, asset_id=asset.id) is True + + +class TestGetAssetInfoById: + def test_returns_none_for_nonexistent(self, session: Session): + assert get_asset_info_by_id(session, asset_info_id="nonexistent") is None + + def test_returns_info(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset, name="myfile.txt") + + result = get_asset_info_by_id(session, asset_info_id=info.id) + assert result is not None + assert result.name == "myfile.txt" + + +class TestListAssetInfosPage: + def test_empty_db(self, session: Session): + infos, tag_map, total = list_asset_infos_page(session) + assert infos == [] + assert tag_map == {} + assert total == 0 + + def test_returns_infos_with_tags(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset, name="test.bin") + ensure_tags_exist(session, ["alpha", "beta"]) + add_tags_to_asset_info(session, asset_info_id=info.id, tags=["alpha", "beta"]) + session.commit() + + infos, tag_map, total = list_asset_infos_page(session) + assert len(infos) == 1 + assert infos[0].id == info.id + assert set(tag_map[info.id]) == {"alpha", "beta"} + assert total == 1 + + def test_name_contains_filter(self, session: Session): + asset = _make_asset(session, "hash1") + _make_asset_info(session, asset, name="model_v1.safetensors") + _make_asset_info(session, asset, name="config.json") + session.commit() + + infos, _, total = 
list_asset_infos_page(session, name_contains="model") + assert total == 1 + assert infos[0].name == "model_v1.safetensors" + + def test_owner_visibility(self, session: Session): + asset = _make_asset(session, "hash1") + _make_asset_info(session, asset, name="public", owner_id="") + _make_asset_info(session, asset, name="private", owner_id="user1") + session.commit() + + # Empty owner sees only public + infos, _, total = list_asset_infos_page(session, owner_id="") + assert total == 1 + assert infos[0].name == "public" + + # Owner sees both + infos, _, total = list_asset_infos_page(session, owner_id="user1") + assert total == 2 + + def test_include_tags_filter(self, session: Session): + asset = _make_asset(session, "hash1") + info1 = _make_asset_info(session, asset, name="tagged") + info2 = _make_asset_info(session, asset, name="untagged") + ensure_tags_exist(session, ["wanted"]) + add_tags_to_asset_info(session, asset_info_id=info1.id, tags=["wanted"]) + session.commit() + + infos, _, total = list_asset_infos_page(session, include_tags=["wanted"]) + assert total == 1 + assert infos[0].name == "tagged" + + def test_exclude_tags_filter(self, session: Session): + asset = _make_asset(session, "hash1") + info1 = _make_asset_info(session, asset, name="keep") + info2 = _make_asset_info(session, asset, name="exclude") + ensure_tags_exist(session, ["bad"]) + add_tags_to_asset_info(session, asset_info_id=info2.id, tags=["bad"]) + session.commit() + + infos, _, total = list_asset_infos_page(session, exclude_tags=["bad"]) + assert total == 1 + assert infos[0].name == "keep" + + def test_sorting(self, session: Session): + asset = _make_asset(session, "hash1", size=100) + asset2 = _make_asset(session, "hash2", size=500) + _make_asset_info(session, asset, name="small") + _make_asset_info(session, asset2, name="large") + session.commit() + + infos, _, _ = list_asset_infos_page(session, sort="size", order="desc") + assert infos[0].name == "large" + + infos, _, _ = list_asset_infos_page(session, sort="name", order="asc") + assert infos[0].name == "large" + + +class TestFetchAssetInfoAssetAndTags: + def test_returns_none_for_nonexistent(self, session: Session): + result = fetch_asset_info_asset_and_tags(session, "nonexistent") + assert result is None + + def test_returns_tuple(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset, name="test.bin") + ensure_tags_exist(session, ["tag1"]) + add_tags_to_asset_info(session, asset_info_id=info.id, tags=["tag1"]) + session.commit() + + result = fetch_asset_info_asset_and_tags(session, info.id) + assert result is not None + ret_info, ret_asset, ret_tags = result + assert ret_info.id == info.id + assert ret_asset.id == asset.id + assert ret_tags == ["tag1"] + + +class TestFetchAssetInfoAndAsset: + def test_returns_none_for_nonexistent(self, session: Session): + result = fetch_asset_info_and_asset(session, asset_info_id="nonexistent") + assert result is None + + def test_returns_tuple(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset) + session.commit() + + result = fetch_asset_info_and_asset(session, asset_info_id=info.id) + assert result is not None + ret_info, ret_asset = result + assert ret_info.id == info.id + assert ret_asset.id == asset.id + + +class TestTouchAssetInfoById: + def test_updates_last_access_time(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset) + original_time = info.last_access_time + 
session.commit() + + import time + time.sleep(0.01) + + touch_asset_info_by_id(session, asset_info_id=info.id) + session.commit() + + session.refresh(info) + assert info.last_access_time > original_time + + +class TestDeleteAssetInfoById: + def test_deletes_existing(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset) + session.commit() + + result = delete_asset_info_by_id(session, asset_info_id=info.id, owner_id="") + assert result is True + assert get_asset_info_by_id(session, asset_info_id=info.id) is None + + def test_returns_false_for_nonexistent(self, session: Session): + result = delete_asset_info_by_id(session, asset_info_id="nonexistent", owner_id="") + assert result is False + + def test_respects_owner_visibility(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset, owner_id="user1") + session.commit() + + result = delete_asset_info_by_id(session, asset_info_id=info.id, owner_id="user2") + assert result is False + assert get_asset_info_by_id(session, asset_info_id=info.id) is not None + + +class TestSetAssetInfoPreview: + def test_sets_preview(self, session: Session): + asset = _make_asset(session, "hash1") + preview_asset = _make_asset(session, "preview_hash") + info = _make_asset_info(session, asset) + session.commit() + + set_asset_info_preview(session, asset_info_id=info.id, preview_asset_id=preview_asset.id) + session.commit() + + session.refresh(info) + assert info.preview_id == preview_asset.id + + def test_clears_preview(self, session: Session): + asset = _make_asset(session, "hash1") + preview_asset = _make_asset(session, "preview_hash") + info = _make_asset_info(session, asset) + info.preview_id = preview_asset.id + session.commit() + + set_asset_info_preview(session, asset_info_id=info.id, preview_asset_id=None) + session.commit() + + session.refresh(info) + assert info.preview_id is None + + def test_raises_for_nonexistent_info(self, session: Session): + with pytest.raises(ValueError, match="not found"): + set_asset_info_preview(session, asset_info_id="nonexistent", preview_asset_id=None) + + def test_raises_for_nonexistent_preview(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset) + session.commit() + + with pytest.raises(ValueError, match="Preview Asset"): + set_asset_info_preview(session, asset_info_id=info.id, preview_asset_id="nonexistent") diff --git a/tests-unit/assets_test/queries/test_cache_state.py b/tests-unit/assets_test/queries/test_cache_state.py new file mode 100644 index 000000000..f5543523f --- /dev/null +++ b/tests-unit/assets_test/queries/test_cache_state.py @@ -0,0 +1,128 @@ +"""Tests for cache_state query functions.""" +import os +import tempfile +from unittest.mock import patch + +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetCacheState, AssetInfo +from app.assets.database.queries import ( + list_cache_states_by_asset_id, + pick_best_live_path, +) +from app.assets.helpers import utcnow + + +def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset: + asset = Asset(hash=hash_val, size_bytes=size) + session.add(asset) + session.flush() + return asset + + +def _make_cache_state( + session: Session, + asset: Asset, + file_path: str, + mtime_ns: int | None = None, + needs_verify: bool = False, +) -> AssetCacheState: + state = AssetCacheState( + asset_id=asset.id, + file_path=file_path, + 
mtime_ns=mtime_ns, + needs_verify=needs_verify, + ) + session.add(state) + session.flush() + return state + + +class TestListCacheStatesByAssetId: + def test_returns_empty_for_no_states(self, session: Session): + asset = _make_asset(session, "hash1") + states = list_cache_states_by_asset_id(session, asset_id=asset.id) + assert list(states) == [] + + def test_returns_states_for_asset(self, session: Session): + asset = _make_asset(session, "hash1") + _make_cache_state(session, asset, "/path/a.bin") + _make_cache_state(session, asset, "/path/b.bin") + session.commit() + + states = list_cache_states_by_asset_id(session, asset_id=asset.id) + paths = [s.file_path for s in states] + assert set(paths) == {"/path/a.bin", "/path/b.bin"} + + def test_does_not_return_other_assets_states(self, session: Session): + asset1 = _make_asset(session, "hash1") + asset2 = _make_asset(session, "hash2") + _make_cache_state(session, asset1, "/path/asset1.bin") + _make_cache_state(session, asset2, "/path/asset2.bin") + session.commit() + + states = list_cache_states_by_asset_id(session, asset_id=asset1.id) + paths = [s.file_path for s in states] + assert paths == ["/path/asset1.bin"] + + +class TestPickBestLivePath: + def test_returns_empty_for_empty_list(self): + result = pick_best_live_path([]) + assert result == "" + + def test_returns_empty_when_no_files_exist(self, session: Session): + asset = _make_asset(session, "hash1") + state = _make_cache_state(session, asset, "/nonexistent/path.bin") + session.commit() + + result = pick_best_live_path([state]) + assert result == "" + + def test_prefers_verified_path(self, session: Session, tmp_path): + """needs_verify=False should be preferred.""" + asset = _make_asset(session, "hash1") + + verified_file = tmp_path / "verified.bin" + verified_file.write_bytes(b"data") + + unverified_file = tmp_path / "unverified.bin" + unverified_file.write_bytes(b"data") + + state_verified = _make_cache_state( + session, asset, str(verified_file), needs_verify=False + ) + state_unverified = _make_cache_state( + session, asset, str(unverified_file), needs_verify=True + ) + session.commit() + + states = [state_unverified, state_verified] + result = pick_best_live_path(states) + assert result == str(verified_file) + + def test_falls_back_to_existing_unverified(self, session: Session, tmp_path): + """If all states need verification, return first existing path.""" + asset = _make_asset(session, "hash1") + + existing_file = tmp_path / "exists.bin" + existing_file.write_bytes(b"data") + + state = _make_cache_state(session, asset, str(existing_file), needs_verify=True) + session.commit() + + result = pick_best_live_path([state]) + assert result == str(existing_file) + + +class TestPickBestLivePathWithMocking: + def test_handles_missing_file_path_attr(self): + """Gracefully handle states with None file_path.""" + + class MockState: + file_path = None + needs_verify = False + + result = pick_best_live_path([MockState()]) + assert result == "" diff --git a/tests-unit/assets_test/queries/test_metadata.py b/tests-unit/assets_test/queries/test_metadata.py new file mode 100644 index 000000000..9e0fcfc65 --- /dev/null +++ b/tests-unit/assets_test/queries/test_metadata.py @@ -0,0 +1,180 @@ +"""Tests for metadata filtering logic in asset_info queries.""" +import pytest +from decimal import Decimal +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta +from app.assets.database.queries import list_asset_infos_page +from app.assets.helpers import 
utcnow, project_kv + + +def _make_asset(session: Session, hash_val: str) -> Asset: + asset = Asset(hash=hash_val, size_bytes=1024) + session.add(asset) + session.flush() + return asset + + +def _make_asset_info( + session: Session, + asset: Asset, + name: str, + metadata: dict | None = None, +) -> AssetInfo: + now = utcnow() + info = AssetInfo( + owner_id="", + name=name, + asset_id=asset.id, + user_metadata=metadata, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(info) + session.flush() + + if metadata: + for key, val in metadata.items(): + for row in project_kv(key, val): + meta_row = AssetInfoMeta( + asset_info_id=info.id, + key=row["key"], + ordinal=row.get("ordinal", 0), + val_str=row.get("val_str"), + val_num=row.get("val_num"), + val_bool=row.get("val_bool"), + val_json=row.get("val_json"), + ) + session.add(meta_row) + session.flush() + + return info + + +class TestMetadataFilterString: + def test_filter_by_string_value(self, session: Session): + asset = _make_asset(session, "hash1") + _make_asset_info(session, asset, "match", {"category": "models"}) + _make_asset_info(session, asset, "nomatch", {"category": "images"}) + session.commit() + + infos, _, total = list_asset_infos_page(session, metadata_filter={"category": "models"}) + assert total == 1 + assert infos[0].name == "match" + + def test_filter_by_string_no_match(self, session: Session): + asset = _make_asset(session, "hash1") + _make_asset_info(session, asset, "item", {"category": "models"}) + session.commit() + + infos, _, total = list_asset_infos_page(session, metadata_filter={"category": "other"}) + assert total == 0 + + +class TestMetadataFilterNumeric: + def test_filter_by_int_value(self, session: Session): + asset = _make_asset(session, "hash1") + _make_asset_info(session, asset, "epoch5", {"epoch": 5}) + _make_asset_info(session, asset, "epoch10", {"epoch": 10}) + session.commit() + + infos, _, total = list_asset_infos_page(session, metadata_filter={"epoch": 5}) + assert total == 1 + assert infos[0].name == "epoch5" + + def test_filter_by_float_value(self, session: Session): + asset = _make_asset(session, "hash1") + _make_asset_info(session, asset, "high", {"score": 0.95}) + _make_asset_info(session, asset, "low", {"score": 0.5}) + session.commit() + + infos, _, total = list_asset_infos_page(session, metadata_filter={"score": 0.95}) + assert total == 1 + assert infos[0].name == "high" + + +class TestMetadataFilterBoolean: + def test_filter_by_true(self, session: Session): + asset = _make_asset(session, "hash1") + _make_asset_info(session, asset, "active", {"enabled": True}) + _make_asset_info(session, asset, "inactive", {"enabled": False}) + session.commit() + + infos, _, total = list_asset_infos_page(session, metadata_filter={"enabled": True}) + assert total == 1 + assert infos[0].name == "active" + + def test_filter_by_false(self, session: Session): + asset = _make_asset(session, "hash1") + _make_asset_info(session, asset, "active", {"enabled": True}) + _make_asset_info(session, asset, "inactive", {"enabled": False}) + session.commit() + + infos, _, total = list_asset_infos_page(session, metadata_filter={"enabled": False}) + assert total == 1 + assert infos[0].name == "inactive" + + +class TestMetadataFilterNull: + def test_filter_by_null_matches_missing_key(self, session: Session): + asset = _make_asset(session, "hash1") + _make_asset_info(session, asset, "has_key", {"optional": "value"}) + _make_asset_info(session, asset, "missing_key", {}) + session.commit() + + infos, _, total = 
list_asset_infos_page(session, metadata_filter={"optional": None}) + assert total == 1 + assert infos[0].name == "missing_key" + + def test_filter_by_null_matches_explicit_null(self, session: Session): + asset = _make_asset(session, "hash1") + _make_asset_info(session, asset, "explicit_null", {"nullable": None}) + _make_asset_info(session, asset, "has_value", {"nullable": "present"}) + session.commit() + + infos, _, total = list_asset_infos_page(session, metadata_filter={"nullable": None}) + assert total == 1 + assert infos[0].name == "explicit_null" + + +class TestMetadataFilterList: + def test_filter_by_list_or(self, session: Session): + """List values should match ANY of the values (OR).""" + asset = _make_asset(session, "hash1") + _make_asset_info(session, asset, "cat_a", {"category": "a"}) + _make_asset_info(session, asset, "cat_b", {"category": "b"}) + _make_asset_info(session, asset, "cat_c", {"category": "c"}) + session.commit() + + infos, _, total = list_asset_infos_page(session, metadata_filter={"category": ["a", "b"]}) + assert total == 2 + names = {i.name for i in infos} + assert names == {"cat_a", "cat_b"} + + +class TestMetadataFilterMultipleKeys: + def test_multiple_keys_and(self, session: Session): + """Multiple keys should ALL match (AND).""" + asset = _make_asset(session, "hash1") + _make_asset_info(session, asset, "match", {"type": "model", "version": 2}) + _make_asset_info(session, asset, "wrong_type", {"type": "config", "version": 2}) + _make_asset_info(session, asset, "wrong_version", {"type": "model", "version": 1}) + session.commit() + + infos, _, total = list_asset_infos_page( + session, metadata_filter={"type": "model", "version": 2} + ) + assert total == 1 + assert infos[0].name == "match" + + +class TestMetadataFilterEmptyDict: + def test_empty_filter_returns_all(self, session: Session): + asset = _make_asset(session, "hash1") + _make_asset_info(session, asset, "a", {"key": "val"}) + _make_asset_info(session, asset, "b", {}) + session.commit() + + infos, _, total = list_asset_infos_page(session, metadata_filter={}) + assert total == 2 diff --git a/tests-unit/assets_test/queries/test_tags.py b/tests-unit/assets_test/queries/test_tags.py new file mode 100644 index 000000000..faf371b40 --- /dev/null +++ b/tests-unit/assets_test/queries/test_tags.py @@ -0,0 +1,297 @@ +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetInfo, AssetInfoTag, Tag +from app.assets.database.queries import ( + ensure_tags_exist, + get_asset_tags, + set_asset_info_tags, + add_tags_to_asset_info, + remove_tags_from_asset_info, + add_missing_tag_for_asset_id, + remove_missing_tag_for_asset_id, + list_tags_with_usage, +) +from app.assets.helpers import utcnow + + +def _make_asset(session: Session, hash_val: str | None = None) -> Asset: + asset = Asset(hash=hash_val, size_bytes=1024) + session.add(asset) + session.flush() + return asset + + +def _make_asset_info(session: Session, asset: Asset, name: str = "test", owner_id: str = "") -> AssetInfo: + now = utcnow() + info = AssetInfo( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(info) + session.flush() + return info + + +class TestEnsureTagsExist: + def test_creates_new_tags(self, session: Session): + ensure_tags_exist(session, ["alpha", "beta"], tag_type="user") + session.commit() + + tags = session.query(Tag).all() + assert {t.name for t in tags} == {"alpha", "beta"} + + def test_is_idempotent(self, 
session: Session): + ensure_tags_exist(session, ["alpha"], tag_type="user") + ensure_tags_exist(session, ["alpha"], tag_type="user") + session.commit() + + assert session.query(Tag).count() == 1 + + def test_normalizes_tags(self, session: Session): + ensure_tags_exist(session, [" ALPHA ", "Beta", "alpha"]) + session.commit() + + tags = session.query(Tag).all() + assert {t.name for t in tags} == {"alpha", "beta"} + + def test_empty_list_is_noop(self, session: Session): + ensure_tags_exist(session, []) + session.commit() + assert session.query(Tag).count() == 0 + + def test_tag_type_is_set(self, session: Session): + ensure_tags_exist(session, ["system-tag"], tag_type="system") + session.commit() + + tag = session.query(Tag).filter_by(name="system-tag").one() + assert tag.tag_type == "system" + + +class TestGetAssetTags: + def test_returns_empty_for_no_tags(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset) + + tags = get_asset_tags(session, asset_info_id=info.id) + assert tags == [] + + def test_returns_tags_for_asset(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset) + + ensure_tags_exist(session, ["tag1", "tag2"]) + session.add_all([ + AssetInfoTag(asset_info_id=info.id, tag_name="tag1", origin="manual", added_at=utcnow()), + AssetInfoTag(asset_info_id=info.id, tag_name="tag2", origin="manual", added_at=utcnow()), + ]) + session.flush() + + tags = get_asset_tags(session, asset_info_id=info.id) + assert set(tags) == {"tag1", "tag2"} + + +class TestSetAssetInfoTags: + def test_adds_new_tags(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset) + + result = set_asset_info_tags(session, asset_info_id=info.id, tags=["a", "b"]) + session.commit() + + assert set(result["added"]) == {"a", "b"} + assert result["removed"] == [] + assert set(result["total"]) == {"a", "b"} + + def test_removes_old_tags(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset) + + set_asset_info_tags(session, asset_info_id=info.id, tags=["a", "b", "c"]) + result = set_asset_info_tags(session, asset_info_id=info.id, tags=["a"]) + session.commit() + + assert result["added"] == [] + assert set(result["removed"]) == {"b", "c"} + assert result["total"] == ["a"] + + def test_replaces_tags(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset) + + set_asset_info_tags(session, asset_info_id=info.id, tags=["a", "b"]) + result = set_asset_info_tags(session, asset_info_id=info.id, tags=["b", "c"]) + session.commit() + + assert result["added"] == ["c"] + assert result["removed"] == ["a"] + assert set(result["total"]) == {"b", "c"} + + +class TestAddTagsToAssetInfo: + def test_adds_tags(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset) + + result = add_tags_to_asset_info(session, asset_info_id=info.id, tags=["x", "y"]) + session.commit() + + assert set(result["added"]) == {"x", "y"} + assert result["already_present"] == [] + + def test_reports_already_present(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset) + + add_tags_to_asset_info(session, asset_info_id=info.id, tags=["x"]) + result = add_tags_to_asset_info(session, asset_info_id=info.id, tags=["x", "y"]) + session.commit() + + assert result["added"] == ["y"] + assert result["already_present"] == ["x"] + 
+ def test_raises_for_missing_asset_info(self, session: Session): + with pytest.raises(ValueError, match="not found"): + add_tags_to_asset_info(session, asset_info_id="nonexistent", tags=["x"]) + + +class TestRemoveTagsFromAssetInfo: + def test_removes_tags(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset) + + add_tags_to_asset_info(session, asset_info_id=info.id, tags=["a", "b", "c"]) + result = remove_tags_from_asset_info(session, asset_info_id=info.id, tags=["a", "b"]) + session.commit() + + assert set(result["removed"]) == {"a", "b"} + assert result["not_present"] == [] + assert result["total_tags"] == ["c"] + + def test_reports_not_present(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset) + + add_tags_to_asset_info(session, asset_info_id=info.id, tags=["a"]) + result = remove_tags_from_asset_info(session, asset_info_id=info.id, tags=["a", "x"]) + session.commit() + + assert result["removed"] == ["a"] + assert result["not_present"] == ["x"] + + def test_raises_for_missing_asset_info(self, session: Session): + with pytest.raises(ValueError, match="not found"): + remove_tags_from_asset_info(session, asset_info_id="nonexistent", tags=["x"]) + + +class TestMissingTagFunctions: + def test_add_missing_tag_for_asset_id(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset) + ensure_tags_exist(session, ["missing"], tag_type="system") + + add_missing_tag_for_asset_id(session, asset_id=asset.id) + session.commit() + + tags = get_asset_tags(session, asset_info_id=info.id) + assert "missing" in tags + + def test_add_missing_tag_is_idempotent(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset) + ensure_tags_exist(session, ["missing"], tag_type="system") + + add_missing_tag_for_asset_id(session, asset_id=asset.id) + add_missing_tag_for_asset_id(session, asset_id=asset.id) + session.commit() + + links = session.query(AssetInfoTag).filter_by(asset_info_id=info.id, tag_name="missing").all() + assert len(links) == 1 + + def test_remove_missing_tag_for_asset_id(self, session: Session): + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset) + ensure_tags_exist(session, ["missing"], tag_type="system") + add_missing_tag_for_asset_id(session, asset_id=asset.id) + + remove_missing_tag_for_asset_id(session, asset_id=asset.id) + session.commit() + + tags = get_asset_tags(session, asset_info_id=info.id) + assert "missing" not in tags + + +class TestListTagsWithUsage: + def test_returns_tags_with_counts(self, session: Session): + ensure_tags_exist(session, ["used", "unused"]) + + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset) + add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"]) + session.commit() + + rows, total = list_tags_with_usage(session) + + tag_dict = {name: count for name, _, count in rows} + assert tag_dict["used"] == 1 + assert tag_dict["unused"] == 0 + assert total == 2 + + def test_exclude_zero_counts(self, session: Session): + ensure_tags_exist(session, ["used", "unused"]) + + asset = _make_asset(session, "hash1") + info = _make_asset_info(session, asset) + add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"]) + session.commit() + + rows, total = list_tags_with_usage(session, include_zero=False) + + tag_names = {name for name, _, _ in rows} + assert "used" in tag_names + assert "unused" not 
in tag_names + + def test_prefix_filter(self, session: Session): + ensure_tags_exist(session, ["alpha", "beta", "alphabet"]) + session.commit() + + rows, total = list_tags_with_usage(session, prefix="alph") + + tag_names = {name for name, _, _ in rows} + assert tag_names == {"alpha", "alphabet"} + + def test_order_by_name(self, session: Session): + ensure_tags_exist(session, ["zebra", "alpha", "middle"]) + session.commit() + + rows, _ = list_tags_with_usage(session, order="name_asc") + + names = [name for name, _, _ in rows] + assert names == ["alpha", "middle", "zebra"] + + def test_owner_visibility(self, session: Session): + ensure_tags_exist(session, ["shared-tag", "owner-tag"]) + + asset = _make_asset(session, "hash1") + shared_info = _make_asset_info(session, asset, name="shared", owner_id="") + owner_info = _make_asset_info(session, asset, name="owned", owner_id="user1") + + add_tags_to_asset_info(session, asset_info_id=shared_info.id, tags=["shared-tag"]) + add_tags_to_asset_info(session, asset_info_id=owner_info.id, tags=["owner-tag"]) + session.commit() + + # Empty owner sees only shared + rows, _ = list_tags_with_usage(session, owner_id="", include_zero=False) + tag_dict = {name: count for name, _, count in rows} + assert tag_dict.get("shared-tag", 0) == 1 + assert tag_dict.get("owner-tag", 0) == 0 + + # User1 sees both + rows, _ = list_tags_with_usage(session, owner_id="user1", include_zero=False) + tag_dict = {name: count for name, _, count in rows} + assert tag_dict.get("shared-tag", 0) == 1 + assert tag_dict.get("owner-tag", 0) == 1
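
The tests above double as usage documentation for the re-exported
app.assets.database.queries package. A minimal sketch of the same calls outside
pytest follows; it is illustrative only -- the engine/create_all setup and the
`Base` import are assumptions (the real fixture lives in
tests-unit/assets_test/queries/conftest.py and may wire things differently),
while every query call below uses only signatures that appear verbatim in the
tests in this patch.

    from sqlalchemy import create_engine
    from sqlalchemy.orm import Session

    # Assumption: the models module exposes a declarative Base; the conftest
    # fixture may create tables another way.
    from app.assets.database.models import Asset, AssetInfo, Base
    from app.assets.database.queries import (
        add_tags_to_asset_info,
        ensure_tags_exist,
        list_asset_infos_page,
        list_tags_with_usage,
    )
    from app.assets.helpers import utcnow

    engine = create_engine("sqlite+pysqlite:///:memory:")
    Base.metadata.create_all(engine)  # assumption: all models share this metadata

    with Session(engine) as session:
        # Minimal rows, mirroring the _make_asset/_make_asset_info test helpers.
        asset = Asset(hash="deadbeef", size_bytes=1024)
        session.add(asset)
        session.flush()

        now = utcnow()
        info = AssetInfo(
            owner_id="",
            name="example",
            asset_id=asset.id,
            created_at=now,
            updated_at=now,
            last_access_time=now,
        )
        session.add(info)
        session.flush()

        # Tag creation is idempotent; attaching reports added vs. already-present.
        ensure_tags_exist(session, ["checkpoint", "sdxl"], tag_type="user")
        result = add_tags_to_asset_info(
            session, asset_info_id=info.id, tags=["checkpoint", "sdxl"]
        )
        session.commit()  # the query helpers leave committing to the caller

        # Usage counts come back as (rows, total); each row unpacks into a 3-tuple.
        rows, total = list_tags_with_usage(session, include_zero=False)
        counts = {name: count for name, _, count in rows}

        # Metadata filters AND across keys and OR within list values.
        infos, _, total = list_asset_infos_page(
            session, metadata_filter={"category": ["a", "b"]}
        )

Leaving commit() to the caller is also why the tests commit explicitly after
each helper call before asserting on refreshed rows.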