import os import sqlalchemy as sa from collections import defaultdict from datetime import datetime from sqlalchemy import select, exists, func from sqlalchemy.orm import Session, contains_eager, noload from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag from app.assets.helpers import escape_like_prefix, normalize_tags, utcnow from typing import Sequence def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: """Build owner visibility predicate for reads. Owner-less rows are visible to everyone.""" owner_id = (owner_id or "").strip() if owner_id == "": return AssetInfo.owner_id == "" return AssetInfo.owner_id.in_(["", owner_id]) def apply_tag_filters( stmt: sa.sql.Select, include_tags: Sequence[str] | None = None, exclude_tags: Sequence[str] | None = None, ) -> sa.sql.Select: """include_tags: every tag must be present; exclude_tags: none may be present.""" include_tags = normalize_tags(include_tags) exclude_tags = normalize_tags(exclude_tags) if include_tags: for tag_name in include_tags: stmt = stmt.where( exists().where( (AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == tag_name) ) ) if exclude_tags: stmt = stmt.where( ~exists().where( (AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name.in_(exclude_tags)) ) ) return stmt def apply_metadata_filter( stmt: sa.sql.Select, metadata_filter: dict | None = None, ) -> sa.sql.Select: """Apply filters using asset_info_meta projection table.""" if not metadata_filter: return stmt def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: return sa.exists().where( AssetInfoMeta.asset_info_id == AssetInfo.id, AssetInfoMeta.key == key, *preds, ) def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: if value is None: no_row_for_key = sa.not_( sa.exists().where( AssetInfoMeta.asset_info_id == AssetInfo.id, AssetInfoMeta.key == key, ) ) null_row = _exists_for_pred( key, AssetInfoMeta.val_json.is_(None), AssetInfoMeta.val_str.is_(None), AssetInfoMeta.val_num.is_(None), AssetInfoMeta.val_bool.is_(None), ) return sa.or_(no_row_for_key, null_row) if isinstance(value, bool): return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value)) if isinstance(value, (int, float)): from decimal import Decimal num = value if isinstance(value, Decimal) else Decimal(str(value)) return _exists_for_pred(key, AssetInfoMeta.val_num == num) if isinstance(value, str): return _exists_for_pred(key, AssetInfoMeta.val_str == value) return _exists_for_pred(key, AssetInfoMeta.val_json == value) for k, v in metadata_filter.items(): if isinstance(v, list): ors = [_exists_clause_for_value(k, elem) for elem in v] if ors: stmt = stmt.where(sa.or_(*ors)) else: stmt = stmt.where(_exists_clause_for_value(k, v)) return stmt def asset_exists_by_hash(session: Session, asset_hash: str) -> bool: """ Check if an asset with a given hash exists in database. """ row = ( session.execute( select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1) ) ).first() return row is not None def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None: return session.get(AssetInfo, asset_info_id) def list_asset_infos_page( session: Session, owner_id: str = "", include_tags: Sequence[str] | None = None, exclude_tags: Sequence[str] | None = None, name_contains: str | None = None, metadata_filter: dict | None = None, limit: int = 20, offset: int = 0, sort: str = "created_at", order: str = "desc", ) -> tuple[list[AssetInfo], dict[str, list[str]], int]: base = ( select(AssetInfo) .join(Asset, Asset.id == AssetInfo.asset_id) .options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags)) .where(visible_owner_clause(owner_id)) ) if name_contains: escaped, esc = escape_like_prefix(name_contains) base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) base = apply_tag_filters(base, include_tags, exclude_tags) base = apply_metadata_filter(base, metadata_filter) sort = (sort or "created_at").lower() order = (order or "desc").lower() sort_map = { "name": AssetInfo.name, "created_at": AssetInfo.created_at, "updated_at": AssetInfo.updated_at, "last_access_time": AssetInfo.last_access_time, "size": Asset.size_bytes, } sort_col = sort_map.get(sort, AssetInfo.created_at) sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() base = base.order_by(sort_exp).limit(limit).offset(offset) count_stmt = ( select(sa.func.count()) .select_from(AssetInfo) .join(Asset, Asset.id == AssetInfo.asset_id) .where(visible_owner_clause(owner_id)) ) if name_contains: escaped, esc = escape_like_prefix(name_contains) count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) count_stmt = apply_metadata_filter(count_stmt, metadata_filter) total = int((session.execute(count_stmt)).scalar_one() or 0) infos = (session.execute(base)).unique().scalars().all() id_list: list[str] = [i.id for i in infos] tag_map: dict[str, list[str]] = defaultdict(list) if id_list: rows = session.execute( select(AssetInfoTag.asset_info_id, Tag.name) .join(Tag, Tag.name == AssetInfoTag.tag_name) .where(AssetInfoTag.asset_info_id.in_(id_list)) ) for aid, tag_name in rows.all(): tag_map[aid].append(tag_name) return infos, tag_map, total def fetch_asset_info_asset_and_tags( session: Session, asset_info_id: str, owner_id: str = "", ) -> tuple[AssetInfo, Asset, list[str]] | None: stmt = ( select(AssetInfo, Asset, Tag.name) .join(Asset, Asset.id == AssetInfo.asset_id) .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True) .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) .where( AssetInfo.id == asset_info_id, visible_owner_clause(owner_id), ) .options(noload(AssetInfo.tags)) .order_by(Tag.name.asc()) ) rows = (session.execute(stmt)).all() if not rows: return None first_info, first_asset, _ = rows[0] tags: list[str] = [] seen: set[str] = set() for _info, _asset, tag_name in rows: if tag_name and tag_name not in seen: seen.add(tag_name) tags.append(tag_name) return first_info, first_asset, tags def fetch_asset_info_and_asset( session: Session, *, asset_info_id: str, owner_id: str = "", ) -> tuple[AssetInfo, Asset] | None: stmt = ( select(AssetInfo, Asset) .join(Asset, Asset.id == AssetInfo.asset_id) .where( AssetInfo.id == asset_info_id, visible_owner_clause(owner_id), ) .limit(1) .options(noload(AssetInfo.tags)) ) row = session.execute(stmt) pair = row.first() if not pair: return None return pair[0], pair[1] def list_cache_states_by_asset_id( session: Session, *, asset_id: str ) -> Sequence[AssetCacheState]: return ( session.execute( select(AssetCacheState) .where(AssetCacheState.asset_id == asset_id) .order_by(AssetCacheState.id.asc()) ) ).scalars().all() def touch_asset_info_by_id( session: Session, *, asset_info_id: str, ts: datetime | None = None, only_if_newer: bool = True, ) -> None: ts = ts or utcnow() stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id) if only_if_newer: stmt = stmt.where( sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) ) session.execute(stmt.values(last_access_time=ts)) def list_tags_with_usage( session: Session, prefix: str | None = None, limit: int = 100, offset: int = 0, include_zero: bool = True, order: str = "count_desc", owner_id: str = "", ) -> tuple[list[tuple[str, str, int]], int]: counts_sq = ( select( AssetInfoTag.tag_name.label("tag_name"), func.count(AssetInfoTag.asset_info_id).label("cnt"), ) .select_from(AssetInfoTag) .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id) .where(visible_owner_clause(owner_id)) .group_by(AssetInfoTag.tag_name) .subquery() ) q = ( select( Tag.name, Tag.tag_type, func.coalesce(counts_sq.c.cnt, 0).label("count"), ) .select_from(Tag) .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) ) if prefix: escaped, esc = escape_like_prefix(prefix.strip().lower()) q = q.where(Tag.name.like(escaped + "%", escape=esc)) if not include_zero: q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) if order == "name_asc": q = q.order_by(Tag.name.asc()) else: q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) total_q = select(func.count()).select_from(Tag) if prefix: escaped, esc = escape_like_prefix(prefix.strip().lower()) total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) if not include_zero: total_q = total_q.where( Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) ) rows = (session.execute(q.limit(limit).offset(offset))).all() total = (session.execute(total_q)).scalar_one() rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] return rows_norm, int(total or 0) def pick_best_live_path(states: Sequence[AssetCacheState]) -> str: """ Return the best on-disk path among cache states: 1) Prefer a path that exists with needs_verify == False (already verified). 2) Otherwise, pick the first path that exists. 3) Otherwise return empty string. """ alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)] if not alive: return "" for s in alive: if not getattr(s, "needs_verify", False): return s.file_path return alive[0].file_path