refactor(assets): split queries.py into modular query modules

Split the ~1000 line app/assets/database/queries.py into focused modules:

- queries/asset.py - Asset entity queries (asset_exists_by_hash, get_asset_by_hash)
- queries/asset_info.py - AssetInfo queries (~15 functions)
- queries/cache_state.py - AssetCacheState queries (list_cache_states_by_asset_id,
  pick_best_live_path, prune_orphaned_assets, fast_db_consistency_pass)
- queries/tags.py - Tag queries (8 functions including ensure_tags_exist,
  add/remove tag functions, list_tags_with_usage)
- queries/__init__.py - Re-exports all public functions for backward compatibility

Also adds comprehensive unit tests using in-memory SQLite:
- tests-unit/assets_test/queries/conftest.py - Session fixture
- tests-unit/assets_test/queries/test_asset.py - 5 tests
- tests-unit/assets_test/queries/test_asset_info.py - 23 tests
- tests-unit/assets_test/queries/test_cache_state.py - 8 tests
- tests-unit/assets_test/queries/test_metadata.py - 12 tests for _apply_metadata_filter
- tests-unit/assets_test/queries/test_tags.py - 23 tests

All 71 unit tests pass. Existing integration tests unaffected.

Amp-Thread-ID: https://ampcode.com/threads/T-019c24bb-475b-7442-9ff9-8288edea3345
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Luke Mino-Altherr 2026-02-03 10:25:24 -08:00
parent ab1050bec3
commit 17ad7e393f
13 changed files with 1569 additions and 536 deletions

View File

@ -0,0 +1,73 @@
# Re-export public API from query modules
# Maintains backward compatibility with old flat queries.py imports
from app.assets.database.queries.asset import (
asset_exists_by_hash,
get_asset_by_hash,
)
from app.assets.database.queries.asset_info import (
asset_info_exists_for_asset_id,
get_asset_info_by_id,
list_asset_infos_page,
fetch_asset_info_asset_and_tags,
fetch_asset_info_and_asset,
touch_asset_info_by_id,
create_asset_info_for_existing_asset,
replace_asset_info_metadata_projection,
ingest_fs_asset,
update_asset_info_full,
delete_asset_info_by_id,
set_asset_info_preview,
)
from app.assets.database.queries.cache_state import (
list_cache_states_by_asset_id,
pick_best_live_path,
prune_orphaned_assets,
fast_db_consistency_pass,
)
from app.assets.database.queries.tags import (
ensure_tags_exist,
get_asset_tags,
set_asset_info_tags,
add_tags_to_asset_info,
remove_tags_from_asset_info,
add_missing_tag_for_asset_id,
remove_missing_tag_for_asset_id,
list_tags_with_usage,
)
__all__ = [
# asset.py
"asset_exists_by_hash",
"get_asset_by_hash",
# asset_info.py
"asset_info_exists_for_asset_id",
"get_asset_info_by_id",
"list_asset_infos_page",
"fetch_asset_info_asset_and_tags",
"fetch_asset_info_and_asset",
"touch_asset_info_by_id",
"create_asset_info_for_existing_asset",
"replace_asset_info_metadata_projection",
"ingest_fs_asset",
"update_asset_info_full",
"delete_asset_info_by_id",
"set_asset_info_preview",
# cache_state.py
"list_cache_states_by_asset_id",
"pick_best_live_path",
"prune_orphaned_assets",
"fast_db_consistency_pass",
# tags.py
"ensure_tags_exist",
"get_asset_tags",
"set_asset_info_tags",
"add_tags_to_asset_info",
"remove_tags_from_asset_info",
"add_missing_tag_for_asset_id",
"remove_missing_tag_for_asset_id",
"list_tags_with_usage",
]

View File

@ -0,0 +1,31 @@
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.assets.database.models import Asset
def asset_exists_by_hash(
session: Session,
*,
asset_hash: str,
) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
row = (
session.execute(
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
)
).first()
return row is not None
def get_asset_by_hash(
session: Session,
*,
asset_hash: str,
) -> Asset | None:
return (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()

View File

@ -1,21 +1,27 @@
import os
import logging
import sqlalchemy as sa
from collections import defaultdict
from datetime import datetime
from typing import Iterable, Any
from sqlalchemy import select, delete, exists, func
from typing import Any, Sequence
import sqlalchemy as sa
from sqlalchemy import select, delete, exists
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.database.models import (
Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
)
from app.assets.helpers import (
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
)
from typing import Sequence
from app.assets.database.queries.asset import get_asset_by_hash
from app.assets.database.queries.cache_state import list_cache_states_by_asset_id, pick_best_live_path
from app.assets.database.queries.tags import ensure_tags_exist, set_asset_info_tags, remove_missing_tag_for_asset_id
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
def _visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
@ -23,23 +29,7 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
return AssetInfo.owner_id.in_(["", owner_id])
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
"""
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
def apply_tag_filters(
def _apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
@ -67,7 +57,7 @@ def apply_tag_filters(
return stmt
def apply_metadata_filter(
def _apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
@ -119,22 +109,6 @@ def apply_metadata_filter(
return stmt
def asset_exists_by_hash(
session: Session,
*,
asset_hash: str,
) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
row = (
session.execute(
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
)
).first()
return row is not None
def asset_info_exists_for_asset_id(
session: Session,
*,
@ -149,16 +123,6 @@ def asset_info_exists_for_asset_id(
return (session.execute(q)).first() is not None
def get_asset_by_hash(
session: Session,
*,
asset_hash: str,
) -> Asset | None:
return (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
def get_asset_info_by_id(
session: Session,
*,
@ -183,15 +147,15 @@ def list_asset_infos_page(
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(visible_owner_clause(owner_id))
.where(_visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
base = _apply_tag_filters(base, include_tags, exclude_tags)
base = _apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
@ -211,13 +175,13 @@ def list_asset_infos_page(
select(sa.func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(visible_owner_clause(owner_id))
.where(_visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
total = int((session.execute(count_stmt)).scalar_one() or 0)
@ -250,7 +214,7 @@ def fetch_asset_info_asset_and_tags(
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
_visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
@ -281,7 +245,7 @@ def fetch_asset_info_and_asset(
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
_visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
@ -292,17 +256,6 @@ def fetch_asset_info_and_asset(
return None
return pair[0], pair[1]
def list_cache_states_by_asset_id(
session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
return (
session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
def touch_asset_info_by_id(
session: Session,
@ -366,7 +319,6 @@ def create_asset_info_for_existing_asset(
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
return existing
# metadata["filename"] hack
new_meta = dict(user_metadata or {})
computed_filename = None
try:
@ -394,42 +346,6 @@ def create_asset_info_for_existing_asset(
return info
def set_asset_info_tags(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> dict:
desired = normalize_tags(tags)
current = set(
tag_name for (tag_name,) in (
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
for t in to_add
])
session.flush()
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
def replace_asset_info_metadata_projection(
session: Session,
*,
@ -507,7 +423,6 @@ def ingest_fs_asset(
"asset_info_id": None,
}
# 1) Asset by hash
asset = (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
@ -543,7 +458,6 @@ def ingest_fs_asset(
if changed:
out["asset_updated"] = True
# 2) AssetCacheState upsert by file_path (unique)
vals = {
"asset_id": asset.id,
"file_path": locator,
@ -575,7 +489,6 @@ def ingest_fs_asset(
if int(res2.rowcount or 0) > 0:
out["state_updated"] = True
# 3) Optional AssetInfo + tags + metadata
if info_name:
try:
with session.begin_nested():
@ -652,7 +565,6 @@ def ingest_fs_asset(
)
session.flush()
# metadata["filename"] hack
if out["asset_info_id"] is not None:
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
computed_filename = compute_relative_filename(primary_path) if primary_path else None
@ -752,207 +664,11 @@ def delete_asset_info_by_id(
) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
_visible_owner_clause(owner_id),
)
return int((session.execute(stmt)).rowcount or 0) > 0
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (session.execute(q.limit(limit).offset(offset))).all()
total = (session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
return [
tag_name for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
def add_tags_to_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: Any = None,
) -> dict:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=utcnow(),
)
for t in to_add
]
)
session.flush()
except IntegrityError:
nested.rollback()
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
def remove_tags_from_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
) -> dict:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)
def set_asset_info_preview(
session: Session,
*,
@ -967,7 +683,6 @@ def set_asset_info_preview(
if preview_asset_id is None:
info.preview_id = None
else:
# validate preview asset exists
if not session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id

View File

@ -0,0 +1,212 @@
import os
from typing import Sequence
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
from app.assets.helpers import escape_like_prefix
def list_cache_states_by_asset_id(
session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
return (
session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
"""
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
def prune_orphaned_assets(session: Session, roots: tuple[str, ...], prefixes_for_root_fn) -> int:
"""Prune cache states outside configured prefixes, then delete orphaned seed assets.
Args:
session: Database session
roots: Tuple of root types to prune
prefixes_for_root_fn: Function to get prefixes for a root type
Returns:
Number of orphaned assets deleted
"""
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root_fn(r)]
if not all_prefixes:
return 0
def make_prefix_condition(prefix: str):
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
escaped, esc = escape_like_prefix(base)
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
matches_valid_prefix = sa.or_(*[make_prefix_condition(p) for p in all_prefixes])
orphan_subq = (
sa.select(Asset.id)
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
).scalar_subquery()
session.execute(sa.delete(AssetCacheState).where(~matches_valid_prefix))
session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
result = session.execute(sa.delete(Asset).where(Asset.id.in_(orphan_subq)))
return result.rowcount
def fast_db_consistency_pass(
session: Session,
root: str,
*,
prefixes_for_root_fn,
escape_like_prefix_fn,
fast_asset_file_check_fn,
add_missing_tag_fn,
remove_missing_tag_fn,
collect_existing_paths: bool = False,
update_missing_tags: bool = False,
) -> set[str] | None:
"""Fast DB+FS pass for a root:
- Toggle needs_verify per state using fast check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
- For seed assets with all states missing: delete Asset and its AssetInfos
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally return surviving absolute paths
"""
import contextlib
prefixes = prefixes_for_root_fn(root)
if not prefixes:
return set() if collect_existing_paths else None
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix_fn(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
rows = (
session.execute(
sa.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sa.or_(*conds))
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
)
).all()
by_asset: dict[str, dict] = {}
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
acc = by_asset.get(aid)
if acc is None:
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
by_asset[aid] = acc
fast_ok = False
try:
exists = True
fast_ok = fast_asset_file_check_fn(
mtime_db=mtime_db,
size_db=acc["size_db"],
stat_result=os.stat(fp, follow_symlinks=True),
)
except FileNotFoundError:
exists = False
except OSError:
exists = False
acc["states"].append({
"sid": sid,
"fp": fp,
"exists": exists,
"fast_ok": fast_ok,
"needs_verify": bool(needs_verify),
})
to_set_verify: list[int] = []
to_clear_verify: list[int] = []
stale_state_ids: list[int] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
for s in states:
if not s["exists"]:
continue
if s["fast_ok"] and s["needs_verify"]:
to_clear_verify.append(s["sid"])
if not s["fast_ok"] and not s["needs_verify"]:
to_set_verify.append(s["sid"])
if a_hash is None:
if states and all_missing:
session.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == aid))
asset = session.get(Asset, aid)
if asset:
session.delete(asset)
else:
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
continue
if any_fast_ok:
for s in states:
if not s["exists"]:
stale_state_ids.append(s["sid"])
if update_missing_tags:
with contextlib.suppress(Exception):
remove_missing_tag_fn(session, asset_id=aid)
elif update_missing_tags:
with contextlib.suppress(Exception):
add_missing_tag_fn(session, asset_id=aid, origin="automatic")
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
if stale_state_ids:
session.execute(sa.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
if to_set_verify:
session.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_set_verify))
.values(needs_verify=True)
)
if to_clear_verify:
session.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_clear_verify))
.values(needs_verify=False)
)
return survivors if collect_existing_paths else None

View File

@ -0,0 +1,280 @@
from typing import Iterable, Sequence
import sqlalchemy as sa
from sqlalchemy import select, delete, func
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.assets.database.models import AssetInfo, AssetInfoTag, Tag
from app.assets.helpers import escape_like_prefix, normalize_tags, utcnow
def _visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
return [
tag_name for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
def set_asset_info_tags(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> dict:
desired = normalize_tags(tags)
current = set(
tag_name for (tag_name,) in (
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
for t in to_add
])
session.flush()
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
def add_tags_to_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row = None,
) -> dict:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=utcnow(),
)
for t in to_add
]
)
session.flush()
except IntegrityError:
nested.rollback()
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
def remove_tags_from_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
) -> dict:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def add_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sa.select(
AssetInfo.id.label("asset_info_id"),
sa.literal("missing").label("tag_name"),
sa.literal(origin).label("origin"),
sa.literal(utcnow()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(
sa.not_(
sa.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
)
)
)
session.execute(
sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(_visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (session.execute(q.limit(limit).offset(offset))).all()
total = (session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)

View File

@ -1,62 +0,0 @@
from typing import Iterable
import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import normalize_tags, utcnow
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
return session.execute(ins)
def add_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sqlalchemy.select(
AssetInfo.id.label("asset_info_id"),
sqlalchemy.literal("missing").label("tag_name"),
sqlalchemy.literal(origin).label("origin"),
sqlalchemy.literal(utcnow()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(
sqlalchemy.not_(
sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
)
)
)
session.execute(
sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sqlalchemy.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)

View File

@ -1,19 +1,22 @@
import contextlib
import time
import logging
import os
import sqlalchemy
import folder_paths
from app.database.db import create_session, dependencies_available
from app.assets.helpers import (
collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path,
list_tree,prefixes_for_root, escape_like_prefix,
list_tree, prefixes_for_root, escape_like_prefix,
RootType
)
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
from app.assets.database.queries import (
add_missing_tag_for_asset_id,
ensure_tags_exist,
remove_missing_tag_for_asset_id,
prune_orphaned_assets,
fast_db_consistency_pass,
)
from app.assets.database.bulk_ops import seed_from_paths_batch
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
@ -33,14 +36,28 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
existing_paths: set[str] = set()
for r in roots:
try:
survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
with create_session() as sess:
survivors: set[str] = fast_db_consistency_pass(
sess,
r,
prefixes_for_root_fn=prefixes_for_root,
escape_like_prefix_fn=escape_like_prefix,
fast_asset_file_check_fn=fast_asset_file_check,
add_missing_tag_fn=add_missing_tag_for_asset_id,
remove_missing_tag_fn=remove_missing_tag_for_asset_id,
collect_existing_paths=True,
update_missing_tags=True,
)
sess.commit()
if survivors:
existing_paths.update(survivors)
except Exception as e:
logging.exception("fast DB scan failed for %s: %s", r, e)
try:
orphans_pruned = _prune_orphaned_assets(roots)
with create_session() as sess:
orphans_pruned = prune_orphaned_assets(sess, roots, prefixes_for_root)
sess.commit()
except Exception as e:
logging.exception("orphan pruning failed: %s", e)
@ -101,163 +118,4 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
)
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
if not all_prefixes:
return 0
def make_prefix_condition(prefix: str):
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
escaped, esc = escape_like_prefix(base)
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
orphan_subq = (
sqlalchemy.select(Asset.id)
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
).scalar_subquery()
with create_session() as sess:
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
sess.commit()
return result.rowcount
def _fast_db_consistency_pass(
root: RootType,
*,
collect_existing_paths: bool = False,
update_missing_tags: bool = False,
) -> set[str] | None:
"""Fast DB+FS pass for a root:
- Toggle needs_verify per state using fast check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
- For seed assets with all states missing: delete Asset and its AssetInfos
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally return surviving absolute paths
"""
prefixes = prefixes_for_root(root)
if not prefixes:
return set() if collect_existing_paths else None
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
with create_session() as sess:
rows = (
sess.execute(
sqlalchemy.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sqlalchemy.or_(*conds))
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
)
).all()
by_asset: dict[str, dict] = {}
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
acc = by_asset.get(aid)
if acc is None:
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
by_asset[aid] = acc
fast_ok = False
try:
exists = True
fast_ok = fast_asset_file_check(
mtime_db=mtime_db,
size_db=acc["size_db"],
stat_result=os.stat(fp, follow_symlinks=True),
)
except FileNotFoundError:
exists = False
except OSError:
exists = False
acc["states"].append({
"sid": sid,
"fp": fp,
"exists": exists,
"fast_ok": fast_ok,
"needs_verify": bool(needs_verify),
})
to_set_verify: list[int] = []
to_clear_verify: list[int] = []
stale_state_ids: list[int] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
for s in states:
if not s["exists"]:
continue
if s["fast_ok"] and s["needs_verify"]:
to_clear_verify.append(s["sid"])
if not s["fast_ok"] and not s["needs_verify"]:
to_set_verify.append(s["sid"])
if a_hash is None:
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
asset = sess.get(Asset, aid)
if asset:
sess.delete(asset)
else:
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
continue
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
for s in states:
if not s["exists"]:
stale_state_ids.append(s["sid"])
if update_missing_tags:
with contextlib.suppress(Exception):
remove_missing_tag_for_asset_id(sess, asset_id=aid)
elif update_missing_tags:
with contextlib.suppress(Exception):
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
if stale_state_ids:
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
if to_set_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_set_verify))
.values(needs_verify=True)
)
if to_clear_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_clear_verify))
.values(needs_verify=False)
)
sess.commit()
return survivors if collect_existing_paths else None

View File

@ -0,0 +1,14 @@
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from app.assets.database.models import Base
@pytest.fixture
def session():
"""In-memory SQLite session for fast unit tests."""
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with Session(engine) as sess:
yield sess

View File

@ -0,0 +1,39 @@
from sqlalchemy.orm import Session
from app.assets.database.models import Asset
from app.assets.database.queries import asset_exists_by_hash, get_asset_by_hash
class TestAssetExistsByHash:
def test_returns_false_for_nonexistent(self, session: Session):
assert asset_exists_by_hash(session, asset_hash="nonexistent") is False
def test_returns_true_for_existing(self, session: Session):
asset = Asset(hash="blake3:abc123", size_bytes=100)
session.add(asset)
session.commit()
assert asset_exists_by_hash(session, asset_hash="blake3:abc123") is True
def test_does_not_match_null_hash(self, session: Session):
asset = Asset(hash=None, size_bytes=100)
session.add(asset)
session.commit()
assert asset_exists_by_hash(session, asset_hash="") is False
class TestGetAssetByHash:
def test_returns_none_for_nonexistent(self, session: Session):
assert get_asset_by_hash(session, asset_hash="nonexistent") is None
def test_returns_asset_for_existing(self, session: Session):
asset = Asset(hash="blake3:def456", size_bytes=200, mime_type="image/png")
session.add(asset)
session.commit()
result = get_asset_by_hash(session, asset_hash="blake3:def456")
assert result is not None
assert result.id == asset.id
assert result.size_bytes == 200
assert result.mime_type == "image/png"

View File

@ -0,0 +1,268 @@
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.database.queries import (
asset_info_exists_for_asset_id,
get_asset_info_by_id,
list_asset_infos_page,
fetch_asset_info_asset_and_tags,
fetch_asset_info_and_asset,
touch_asset_info_by_id,
delete_asset_info_by_id,
set_asset_info_preview,
ensure_tags_exist,
add_tags_to_asset_info,
)
from app.assets.helpers import utcnow
def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream")
session.add(asset)
session.flush()
return asset
def _make_asset_info(
session: Session,
asset: Asset,
name: str = "test",
owner_id: str = "",
) -> AssetInfo:
now = utcnow()
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
return info
class TestAssetInfoExistsForAssetId:
def test_returns_false_when_no_info(self, session: Session):
asset = _make_asset(session, "hash1")
assert asset_info_exists_for_asset_id(session, asset_id=asset.id) is False
def test_returns_true_when_info_exists(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset)
assert asset_info_exists_for_asset_id(session, asset_id=asset.id) is True
class TestGetAssetInfoById:
def test_returns_none_for_nonexistent(self, session: Session):
assert get_asset_info_by_id(session, asset_info_id="nonexistent") is None
def test_returns_info(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset, name="myfile.txt")
result = get_asset_info_by_id(session, asset_info_id=info.id)
assert result is not None
assert result.name == "myfile.txt"
class TestListAssetInfosPage:
def test_empty_db(self, session: Session):
infos, tag_map, total = list_asset_infos_page(session)
assert infos == []
assert tag_map == {}
assert total == 0
def test_returns_infos_with_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset, name="test.bin")
ensure_tags_exist(session, ["alpha", "beta"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["alpha", "beta"])
session.commit()
infos, tag_map, total = list_asset_infos_page(session)
assert len(infos) == 1
assert infos[0].id == info.id
assert set(tag_map[info.id]) == {"alpha", "beta"}
assert total == 1
def test_name_contains_filter(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, name="model_v1.safetensors")
_make_asset_info(session, asset, name="config.json")
session.commit()
infos, _, total = list_asset_infos_page(session, name_contains="model")
assert total == 1
assert infos[0].name == "model_v1.safetensors"
def test_owner_visibility(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, name="public", owner_id="")
_make_asset_info(session, asset, name="private", owner_id="user1")
session.commit()
# Empty owner sees only public
infos, _, total = list_asset_infos_page(session, owner_id="")
assert total == 1
assert infos[0].name == "public"
# Owner sees both
infos, _, total = list_asset_infos_page(session, owner_id="user1")
assert total == 2
def test_include_tags_filter(self, session: Session):
asset = _make_asset(session, "hash1")
info1 = _make_asset_info(session, asset, name="tagged")
info2 = _make_asset_info(session, asset, name="untagged")
ensure_tags_exist(session, ["wanted"])
add_tags_to_asset_info(session, asset_info_id=info1.id, tags=["wanted"])
session.commit()
infos, _, total = list_asset_infos_page(session, include_tags=["wanted"])
assert total == 1
assert infos[0].name == "tagged"
def test_exclude_tags_filter(self, session: Session):
asset = _make_asset(session, "hash1")
info1 = _make_asset_info(session, asset, name="keep")
info2 = _make_asset_info(session, asset, name="exclude")
ensure_tags_exist(session, ["bad"])
add_tags_to_asset_info(session, asset_info_id=info2.id, tags=["bad"])
session.commit()
infos, _, total = list_asset_infos_page(session, exclude_tags=["bad"])
assert total == 1
assert infos[0].name == "keep"
def test_sorting(self, session: Session):
asset = _make_asset(session, "hash1", size=100)
asset2 = _make_asset(session, "hash2", size=500)
_make_asset_info(session, asset, name="small")
_make_asset_info(session, asset2, name="large")
session.commit()
infos, _, _ = list_asset_infos_page(session, sort="size", order="desc")
assert infos[0].name == "large"
infos, _, _ = list_asset_infos_page(session, sort="name", order="asc")
assert infos[0].name == "large"
class TestFetchAssetInfoAssetAndTags:
def test_returns_none_for_nonexistent(self, session: Session):
result = fetch_asset_info_asset_and_tags(session, "nonexistent")
assert result is None
def test_returns_tuple(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset, name="test.bin")
ensure_tags_exist(session, ["tag1"])
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["tag1"])
session.commit()
result = fetch_asset_info_asset_and_tags(session, info.id)
assert result is not None
ret_info, ret_asset, ret_tags = result
assert ret_info.id == info.id
assert ret_asset.id == asset.id
assert ret_tags == ["tag1"]
class TestFetchAssetInfoAndAsset:
def test_returns_none_for_nonexistent(self, session: Session):
result = fetch_asset_info_and_asset(session, asset_info_id="nonexistent")
assert result is None
def test_returns_tuple(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
result = fetch_asset_info_and_asset(session, asset_info_id=info.id)
assert result is not None
ret_info, ret_asset = result
assert ret_info.id == info.id
assert ret_asset.id == asset.id
class TestTouchAssetInfoById:
def test_updates_last_access_time(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
original_time = info.last_access_time
session.commit()
import time
time.sleep(0.01)
touch_asset_info_by_id(session, asset_info_id=info.id)
session.commit()
session.refresh(info)
assert info.last_access_time > original_time
class TestDeleteAssetInfoById:
def test_deletes_existing(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
result = delete_asset_info_by_id(session, asset_info_id=info.id, owner_id="")
assert result is True
assert get_asset_info_by_id(session, asset_info_id=info.id) is None
def test_returns_false_for_nonexistent(self, session: Session):
result = delete_asset_info_by_id(session, asset_info_id="nonexistent", owner_id="")
assert result is False
def test_respects_owner_visibility(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset, owner_id="user1")
session.commit()
result = delete_asset_info_by_id(session, asset_info_id=info.id, owner_id="user2")
assert result is False
assert get_asset_info_by_id(session, asset_info_id=info.id) is not None
class TestSetAssetInfoPreview:
def test_sets_preview(self, session: Session):
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
info = _make_asset_info(session, asset)
session.commit()
set_asset_info_preview(session, asset_info_id=info.id, preview_asset_id=preview_asset.id)
session.commit()
session.refresh(info)
assert info.preview_id == preview_asset.id
def test_clears_preview(self, session: Session):
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
info = _make_asset_info(session, asset)
info.preview_id = preview_asset.id
session.commit()
set_asset_info_preview(session, asset_info_id=info.id, preview_asset_id=None)
session.commit()
session.refresh(info)
assert info.preview_id is None
def test_raises_for_nonexistent_info(self, session: Session):
with pytest.raises(ValueError, match="not found"):
set_asset_info_preview(session, asset_info_id="nonexistent", preview_asset_id=None)
def test_raises_for_nonexistent_preview(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
session.commit()
with pytest.raises(ValueError, match="Preview Asset"):
set_asset_info_preview(session, asset_info_id=info.id, preview_asset_id="nonexistent")

View File

@ -0,0 +1,128 @@
"""Tests for cache_state query functions."""
import os
import tempfile
from unittest.mock import patch
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
from app.assets.database.queries import (
list_cache_states_by_asset_id,
pick_best_live_path,
)
from app.assets.helpers import utcnow
def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
asset = Asset(hash=hash_val, size_bytes=size)
session.add(asset)
session.flush()
return asset
def _make_cache_state(
session: Session,
asset: Asset,
file_path: str,
mtime_ns: int | None = None,
needs_verify: bool = False,
) -> AssetCacheState:
state = AssetCacheState(
asset_id=asset.id,
file_path=file_path,
mtime_ns=mtime_ns,
needs_verify=needs_verify,
)
session.add(state)
session.flush()
return state
class TestListCacheStatesByAssetId:
def test_returns_empty_for_no_states(self, session: Session):
asset = _make_asset(session, "hash1")
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
assert list(states) == []
def test_returns_states_for_asset(self, session: Session):
asset = _make_asset(session, "hash1")
_make_cache_state(session, asset, "/path/a.bin")
_make_cache_state(session, asset, "/path/b.bin")
session.commit()
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
paths = [s.file_path for s in states]
assert set(paths) == {"/path/a.bin", "/path/b.bin"}
def test_does_not_return_other_assets_states(self, session: Session):
asset1 = _make_asset(session, "hash1")
asset2 = _make_asset(session, "hash2")
_make_cache_state(session, asset1, "/path/asset1.bin")
_make_cache_state(session, asset2, "/path/asset2.bin")
session.commit()
states = list_cache_states_by_asset_id(session, asset_id=asset1.id)
paths = [s.file_path for s in states]
assert paths == ["/path/asset1.bin"]
class TestPickBestLivePath:
def test_returns_empty_for_empty_list(self):
result = pick_best_live_path([])
assert result == ""
def test_returns_empty_when_no_files_exist(self, session: Session):
asset = _make_asset(session, "hash1")
state = _make_cache_state(session, asset, "/nonexistent/path.bin")
session.commit()
result = pick_best_live_path([state])
assert result == ""
def test_prefers_verified_path(self, session: Session, tmp_path):
"""needs_verify=False should be preferred."""
asset = _make_asset(session, "hash1")
verified_file = tmp_path / "verified.bin"
verified_file.write_bytes(b"data")
unverified_file = tmp_path / "unverified.bin"
unverified_file.write_bytes(b"data")
state_verified = _make_cache_state(
session, asset, str(verified_file), needs_verify=False
)
state_unverified = _make_cache_state(
session, asset, str(unverified_file), needs_verify=True
)
session.commit()
states = [state_unverified, state_verified]
result = pick_best_live_path(states)
assert result == str(verified_file)
def test_falls_back_to_existing_unverified(self, session: Session, tmp_path):
"""If all states need verification, return first existing path."""
asset = _make_asset(session, "hash1")
existing_file = tmp_path / "exists.bin"
existing_file.write_bytes(b"data")
state = _make_cache_state(session, asset, str(existing_file), needs_verify=True)
session.commit()
result = pick_best_live_path([state])
assert result == str(existing_file)
class TestPickBestLivePathWithMocking:
def test_handles_missing_file_path_attr(self):
"""Gracefully handle states with None file_path."""
class MockState:
file_path = None
needs_verify = False
result = pick_best_live_path([MockState()])
assert result == ""

View File

@ -0,0 +1,180 @@
"""Tests for metadata filtering logic in asset_info queries."""
import pytest
from decimal import Decimal
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta
from app.assets.database.queries import list_asset_infos_page
from app.assets.helpers import utcnow, project_kv
def _make_asset(session: Session, hash_val: str) -> Asset:
asset = Asset(hash=hash_val, size_bytes=1024)
session.add(asset)
session.flush()
return asset
def _make_asset_info(
session: Session,
asset: Asset,
name: str,
metadata: dict | None = None,
) -> AssetInfo:
now = utcnow()
info = AssetInfo(
owner_id="",
name=name,
asset_id=asset.id,
user_metadata=metadata,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
if metadata:
for key, val in metadata.items():
for row in project_kv(key, val):
meta_row = AssetInfoMeta(
asset_info_id=info.id,
key=row["key"],
ordinal=row.get("ordinal", 0),
val_str=row.get("val_str"),
val_num=row.get("val_num"),
val_bool=row.get("val_bool"),
val_json=row.get("val_json"),
)
session.add(meta_row)
session.flush()
return info
class TestMetadataFilterString:
def test_filter_by_string_value(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "match", {"category": "models"})
_make_asset_info(session, asset, "nomatch", {"category": "images"})
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={"category": "models"})
assert total == 1
assert infos[0].name == "match"
def test_filter_by_string_no_match(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "item", {"category": "models"})
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={"category": "other"})
assert total == 0
class TestMetadataFilterNumeric:
def test_filter_by_int_value(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "epoch5", {"epoch": 5})
_make_asset_info(session, asset, "epoch10", {"epoch": 10})
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={"epoch": 5})
assert total == 1
assert infos[0].name == "epoch5"
def test_filter_by_float_value(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "high", {"score": 0.95})
_make_asset_info(session, asset, "low", {"score": 0.5})
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={"score": 0.95})
assert total == 1
assert infos[0].name == "high"
class TestMetadataFilterBoolean:
def test_filter_by_true(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "active", {"enabled": True})
_make_asset_info(session, asset, "inactive", {"enabled": False})
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={"enabled": True})
assert total == 1
assert infos[0].name == "active"
def test_filter_by_false(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "active", {"enabled": True})
_make_asset_info(session, asset, "inactive", {"enabled": False})
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={"enabled": False})
assert total == 1
assert infos[0].name == "inactive"
class TestMetadataFilterNull:
def test_filter_by_null_matches_missing_key(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "has_key", {"optional": "value"})
_make_asset_info(session, asset, "missing_key", {})
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={"optional": None})
assert total == 1
assert infos[0].name == "missing_key"
def test_filter_by_null_matches_explicit_null(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "explicit_null", {"nullable": None})
_make_asset_info(session, asset, "has_value", {"nullable": "present"})
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={"nullable": None})
assert total == 1
assert infos[0].name == "explicit_null"
class TestMetadataFilterList:
def test_filter_by_list_or(self, session: Session):
"""List values should match ANY of the values (OR)."""
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "cat_a", {"category": "a"})
_make_asset_info(session, asset, "cat_b", {"category": "b"})
_make_asset_info(session, asset, "cat_c", {"category": "c"})
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={"category": ["a", "b"]})
assert total == 2
names = {i.name for i in infos}
assert names == {"cat_a", "cat_b"}
class TestMetadataFilterMultipleKeys:
def test_multiple_keys_and(self, session: Session):
"""Multiple keys should ALL match (AND)."""
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "match", {"type": "model", "version": 2})
_make_asset_info(session, asset, "wrong_type", {"type": "config", "version": 2})
_make_asset_info(session, asset, "wrong_version", {"type": "model", "version": 1})
session.commit()
infos, _, total = list_asset_infos_page(
session, metadata_filter={"type": "model", "version": 2}
)
assert total == 1
assert infos[0].name == "match"
class TestMetadataFilterEmptyDict:
def test_empty_filter_returns_all(self, session: Session):
asset = _make_asset(session, "hash1")
_make_asset_info(session, asset, "a", {"key": "val"})
_make_asset_info(session, asset, "b", {})
session.commit()
infos, _, total = list_asset_infos_page(session, metadata_filter={})
assert total == 2

View File

@ -0,0 +1,297 @@
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetInfo, AssetInfoTag, Tag
from app.assets.database.queries import (
ensure_tags_exist,
get_asset_tags,
set_asset_info_tags,
add_tags_to_asset_info,
remove_tags_from_asset_info,
add_missing_tag_for_asset_id,
remove_missing_tag_for_asset_id,
list_tags_with_usage,
)
from app.assets.helpers import utcnow
def _make_asset(session: Session, hash_val: str | None = None) -> Asset:
asset = Asset(hash=hash_val, size_bytes=1024)
session.add(asset)
session.flush()
return asset
def _make_asset_info(session: Session, asset: Asset, name: str = "test", owner_id: str = "") -> AssetInfo:
now = utcnow()
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
return info
class TestEnsureTagsExist:
def test_creates_new_tags(self, session: Session):
ensure_tags_exist(session, ["alpha", "beta"], tag_type="user")
session.commit()
tags = session.query(Tag).all()
assert {t.name for t in tags} == {"alpha", "beta"}
def test_is_idempotent(self, session: Session):
ensure_tags_exist(session, ["alpha"], tag_type="user")
ensure_tags_exist(session, ["alpha"], tag_type="user")
session.commit()
assert session.query(Tag).count() == 1
def test_normalizes_tags(self, session: Session):
ensure_tags_exist(session, [" ALPHA ", "Beta", "alpha"])
session.commit()
tags = session.query(Tag).all()
assert {t.name for t in tags} == {"alpha", "beta"}
def test_empty_list_is_noop(self, session: Session):
ensure_tags_exist(session, [])
session.commit()
assert session.query(Tag).count() == 0
def test_tag_type_is_set(self, session: Session):
ensure_tags_exist(session, ["system-tag"], tag_type="system")
session.commit()
tag = session.query(Tag).filter_by(name="system-tag").one()
assert tag.tag_type == "system"
class TestGetAssetTags:
def test_returns_empty_for_no_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
tags = get_asset_tags(session, asset_info_id=info.id)
assert tags == []
def test_returns_tags_for_asset(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["tag1", "tag2"])
session.add_all([
AssetInfoTag(asset_info_id=info.id, tag_name="tag1", origin="manual", added_at=utcnow()),
AssetInfoTag(asset_info_id=info.id, tag_name="tag2", origin="manual", added_at=utcnow()),
])
session.flush()
tags = get_asset_tags(session, asset_info_id=info.id)
assert set(tags) == {"tag1", "tag2"}
class TestSetAssetInfoTags:
def test_adds_new_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
result = set_asset_info_tags(session, asset_info_id=info.id, tags=["a", "b"])
session.commit()
assert set(result["added"]) == {"a", "b"}
assert result["removed"] == []
assert set(result["total"]) == {"a", "b"}
def test_removes_old_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
set_asset_info_tags(session, asset_info_id=info.id, tags=["a", "b", "c"])
result = set_asset_info_tags(session, asset_info_id=info.id, tags=["a"])
session.commit()
assert result["added"] == []
assert set(result["removed"]) == {"b", "c"}
assert result["total"] == ["a"]
def test_replaces_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
set_asset_info_tags(session, asset_info_id=info.id, tags=["a", "b"])
result = set_asset_info_tags(session, asset_info_id=info.id, tags=["b", "c"])
session.commit()
assert result["added"] == ["c"]
assert result["removed"] == ["a"]
assert set(result["total"]) == {"b", "c"}
class TestAddTagsToAssetInfo:
def test_adds_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
result = add_tags_to_asset_info(session, asset_info_id=info.id, tags=["x", "y"])
session.commit()
assert set(result["added"]) == {"x", "y"}
assert result["already_present"] == []
def test_reports_already_present(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["x"])
result = add_tags_to_asset_info(session, asset_info_id=info.id, tags=["x", "y"])
session.commit()
assert result["added"] == ["y"]
assert result["already_present"] == ["x"]
def test_raises_for_missing_asset_info(self, session: Session):
with pytest.raises(ValueError, match="not found"):
add_tags_to_asset_info(session, asset_info_id="nonexistent", tags=["x"])
class TestRemoveTagsFromAssetInfo:
def test_removes_tags(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["a", "b", "c"])
result = remove_tags_from_asset_info(session, asset_info_id=info.id, tags=["a", "b"])
session.commit()
assert set(result["removed"]) == {"a", "b"}
assert result["not_present"] == []
assert result["total_tags"] == ["c"]
def test_reports_not_present(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["a"])
result = remove_tags_from_asset_info(session, asset_info_id=info.id, tags=["a", "x"])
session.commit()
assert result["removed"] == ["a"]
assert result["not_present"] == ["x"]
def test_raises_for_missing_asset_info(self, session: Session):
with pytest.raises(ValueError, match="not found"):
remove_tags_from_asset_info(session, asset_info_id="nonexistent", tags=["x"])
class TestMissingTagFunctions:
def test_add_missing_tag_for_asset_id(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["missing"], tag_type="system")
add_missing_tag_for_asset_id(session, asset_id=asset.id)
session.commit()
tags = get_asset_tags(session, asset_info_id=info.id)
assert "missing" in tags
def test_add_missing_tag_is_idempotent(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["missing"], tag_type="system")
add_missing_tag_for_asset_id(session, asset_id=asset.id)
add_missing_tag_for_asset_id(session, asset_id=asset.id)
session.commit()
links = session.query(AssetInfoTag).filter_by(asset_info_id=info.id, tag_name="missing").all()
assert len(links) == 1
def test_remove_missing_tag_for_asset_id(self, session: Session):
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
ensure_tags_exist(session, ["missing"], tag_type="system")
add_missing_tag_for_asset_id(session, asset_id=asset.id)
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
session.commit()
tags = get_asset_tags(session, asset_info_id=info.id)
assert "missing" not in tags
class TestListTagsWithUsage:
def test_returns_tags_with_counts(self, session: Session):
ensure_tags_exist(session, ["used", "unused"])
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"])
session.commit()
rows, total = list_tags_with_usage(session)
tag_dict = {name: count for name, _, count in rows}
assert tag_dict["used"] == 1
assert tag_dict["unused"] == 0
assert total == 2
def test_exclude_zero_counts(self, session: Session):
ensure_tags_exist(session, ["used", "unused"])
asset = _make_asset(session, "hash1")
info = _make_asset_info(session, asset)
add_tags_to_asset_info(session, asset_info_id=info.id, tags=["used"])
session.commit()
rows, total = list_tags_with_usage(session, include_zero=False)
tag_names = {name for name, _, _ in rows}
assert "used" in tag_names
assert "unused" not in tag_names
def test_prefix_filter(self, session: Session):
ensure_tags_exist(session, ["alpha", "beta", "alphabet"])
session.commit()
rows, total = list_tags_with_usage(session, prefix="alph")
tag_names = {name for name, _, _ in rows}
assert tag_names == {"alpha", "alphabet"}
def test_order_by_name(self, session: Session):
ensure_tags_exist(session, ["zebra", "alpha", "middle"])
session.commit()
rows, _ = list_tags_with_usage(session, order="name_asc")
names = [name for name, _, _ in rows]
assert names == ["alpha", "middle", "zebra"]
def test_owner_visibility(self, session: Session):
ensure_tags_exist(session, ["shared-tag", "owner-tag"])
asset = _make_asset(session, "hash1")
shared_info = _make_asset_info(session, asset, name="shared", owner_id="")
owner_info = _make_asset_info(session, asset, name="owned", owner_id="user1")
add_tags_to_asset_info(session, asset_info_id=shared_info.id, tags=["shared-tag"])
add_tags_to_asset_info(session, asset_info_id=owner_info.id, tags=["owner-tag"])
session.commit()
# Empty owner sees only shared
rows, _ = list_tags_with_usage(session, owner_id="", include_zero=False)
tag_dict = {name: count for name, _, count in rows}
assert tag_dict.get("shared-tag", 0) == 1
assert tag_dict.get("owner-tag", 0) == 0
# User1 sees both
rows, _ = list_tags_with_usage(session, owner_id="user1", include_zero=False)
tag_dict = {name: count for name, _, count in rows}
assert tag_dict.get("shared-tag", 0) == 1
assert tag_dict.get("owner-tag", 0) == 1