diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index 681af2635..9481100b0 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -149,6 +149,9 @@ def upgrade() -> None: # Extra basic tags {"name": "encoder", "tag_type": "system"}, {"name": "decoder", "tag_type": "system"}, + + # Special tags + {"name": "missing", "tag_type": "system"}, ], ) diff --git a/app/assets_scanner.py b/app/assets_scanner.py index 42cf123d2..33efbf047 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -14,7 +14,12 @@ from . import assets_manager from .api import schemas_out from ._assets_helpers import get_comfy_models_folders from .database.db import create_session -from .database.services import check_fs_asset_exists_quick +from .database.services import ( + check_fs_asset_exists_quick, + list_cache_states_under_prefixes, + add_missing_tag_for_asset_hash, + remove_missing_tag_for_asset_hash, +) LOGGER = logging.getLogger(__name__) @@ -239,7 +244,6 @@ async def _fast_reconcile_into_queue( checked = 0 clean = 0 - # Single session for the whole fast pass async with await create_session() as sess: while True: item = await files_iter.get() @@ -261,7 +265,6 @@ async def _fast_reconcile_into_queue( _append_error(prog, phase="fast", path=abs_path, message=str(e)) continue - # Known good -> count as processed immediately try: known = await check_fs_asset_exists_quick( sess, @@ -275,7 +278,7 @@ async def _fast_reconcile_into_queue( if known: clean += 1 - prog.processed += 1 # preserve original semantics + prog.processed += 1 else: await state.queue.put(abs_path) queued += 1 @@ -300,9 +303,56 @@ async def _fast_reconcile_into_queue( "discovered": prog.discovered, }) + await _reconcile_missing_tags_for_root(root, prog) state.closed = True +async def _reconcile_missing_tags_for_root(root: RootType, prog: ScanProgress) -> None: + """ + For every AssetCacheState under the root's base directories: + - if at least one recorded file_path exists for a hash -> remove 'missing' + - if none of the recorded file_paths exist for a hash -> add 'missing' + """ + if root == "models": + bases: list[str] = [] + for _bucket, paths in get_comfy_models_folders(): + bases.extend(paths) + elif root == "input": + bases = [folder_paths.get_input_directory()] + else: + bases = [folder_paths.get_output_directory()] + + try: + async with await create_session() as sess: + states = await list_cache_states_under_prefixes(sess, prefixes=bases) + + present: set[str] = set() + missing: set[str] = set() + + for s in states: + try: + if os.path.isfile(s.file_path): + present.add(s.asset_hash) + else: + missing.add(s.asset_hash) + except Exception as e: + _append_error(prog, phase="fast", path=s.file_path, message=f"stat error: {e}") + + only_missing = missing - present + + for h in present: + with contextlib.suppress(Exception): + await remove_missing_tag_for_asset_hash(sess, asset_hash=h) + + for h in only_missing: + with contextlib.suppress(Exception): + await add_missing_tag_for_asset_hash(sess, asset_hash=h, origin="automatic") + + await sess.commit() + except Exception as e: + _append_error(prog, phase="fast", path="", message=f"missing-tag reconcile failed: {e}") + + def _start_slow_workers( root: RootType, prog: ScanProgress, diff --git a/app/database/_helpers.py b/app/database/_helpers.py new file mode 100644 index 000000000..5ce972076 --- /dev/null +++ b/app/database/_helpers.py @@ -0,0 +1,183 @@ +from decimal import Decimal +from typing import Any, Sequence, Optional, Iterable + +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, exists + +from .models import AssetInfo, AssetInfoTag, Tag, AssetInfoMeta +from .._assets_helpers import normalize_tags + + +async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]: + wanted = normalize_tags(list(names)) + if not wanted: + return [] + existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all() + by_name = {t.name: t for t in existing} + to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name] + if to_create: + session.add_all(to_create) + await session.flush() + by_name.update({t.name: t for t in to_create}) + return [by_name[n] for n in wanted] + + +def apply_tag_filters( + stmt: sa.sql.Select, + include_tags: Optional[Sequence[str]], + exclude_tags: Optional[Sequence[str]], +) -> sa.sql.Select: + """include_tags: every tag must be present; exclude_tags: none may be present.""" + include_tags = normalize_tags(include_tags) + exclude_tags = normalize_tags(exclude_tags) + + if include_tags: + for tag_name in include_tags: + stmt = stmt.where( + exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name == tag_name) + ) + ) + + if exclude_tags: + stmt = stmt.where( + ~exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name.in_(exclude_tags)) + ) + ) + return stmt + + +def apply_metadata_filter( + stmt: sa.sql.Select, + metadata_filter: Optional[dict], +) -> sa.sql.Select: + """Apply metadata filters using the projection table asset_info_meta. + + Semantics: + - For scalar values: require EXISTS(asset_info_meta) with matching key + typed value. + - For None: key is missing OR key has explicit null (val_json IS NULL). + - For list values: ANY-of the list elements matches (EXISTS for any). + (Change to ALL-of by 'for each element: stmt = stmt.where(_meta_exists_clause(key, elem))') + """ + if not metadata_filter: + return stmt + + def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: + subquery = ( + select(sa.literal(1)) + .select_from(AssetInfoMeta) + .where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + *preds, + ) + .limit(1) + ) + return sa.exists(subquery) + + def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: + # Missing OR null: + if value is None: + # either: no row for key OR a row for key with explicit null + no_row_for_key = ~sa.exists( + select(sa.literal(1)) + .select_from(AssetInfoMeta) + .where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + ) + .limit(1) + ) + null_row = _exists_for_pred(key, AssetInfoMeta.val_json.is_(None)) + return sa.or_(no_row_for_key, null_row) + + # Typed scalar matches: + if isinstance(value, bool): + return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value)) + if isinstance(value, (int, float, Decimal)): + # store as Decimal for equality against NUMERIC(38,10) + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return _exists_for_pred(key, AssetInfoMeta.val_num == num) + if isinstance(value, str): + return _exists_for_pred(key, AssetInfoMeta.val_str == value) + + # Complex: compare JSON (no index, but supported) + return _exists_for_pred(key, AssetInfoMeta.val_json == value) + + for k, v in metadata_filter.items(): + if isinstance(v, list): + # ANY-of (exists for any element) + ors = [_exists_clause_for_value(k, elem) for elem in v] + if ors: + stmt = stmt.where(sa.or_(*ors)) + else: + stmt = stmt.where(_exists_clause_for_value(k, v)) + return stmt + + +def is_scalar(v: Any) -> bool: + if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries + return True + if isinstance(v, bool): + return True + if isinstance(v, (int, float, Decimal, str)): + return True + return False + + +def project_kv(key: str, value: Any) -> list[dict]: + """ + Turn a metadata key/value into one or more projection rows: + - scalar -> one row (ordinal=0) in the proper typed column + - list of scalars -> one row per element with ordinal=i + - dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list) + - None -> single row with val_json = None + Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...} + """ + rows: list[dict] = [] + + if value is None: + rows.append({"key": key, "ordinal": 0, "val_json": None}) + return rows + + if is_scalar(value): + if isinstance(value, bool): + rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) + elif isinstance(value, (int, float, Decimal)): + # store numeric; SQLAlchemy will coerce to Numeric + num = value if isinstance(value, Decimal) else Decimal(str(value)) + rows.append({"key": key, "ordinal": 0, "val_num": num}) + elif isinstance(value, str): + rows.append({"key": key, "ordinal": 0, "val_str": value}) + else: + # Fallback to json + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows + + if isinstance(value, list): + if all(is_scalar(x) for x in value): + for i, x in enumerate(value): + if x is None: + rows.append({"key": key, "ordinal": i, "val_json": None}) + elif isinstance(x, bool): + rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) + elif isinstance(x, (int, float, Decimal)): + num = x if isinstance(x, Decimal) else Decimal(str(x)) + rows.append({"key": key, "ordinal": i, "val_num": num}) + elif isinstance(x, str): + rows.append({"key": key, "ordinal": i, "val_str": x}) + else: + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + # list contains objects -> one val_json per element + for i, x in enumerate(value): + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + + # Dict or any other structure -> single json row + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows diff --git a/app/database/services.py b/app/database/services.py index 94a9b7016..ceed3749a 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -3,18 +3,18 @@ import os import logging from collections import defaultdict from datetime import datetime -from decimal import Decimal -from typing import Any, Sequence, Optional, Iterable, Union +from typing import Any, Sequence, Optional, Union import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select, delete, exists, func +from sqlalchemy import select, delete, func from sqlalchemy.orm import contains_eager, noload from sqlalchemy.exc import IntegrityError from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation from .timeutil import utcnow from .._assets_helpers import normalize_tags, visible_owner_clause, compute_model_relative_filename +from . import _helpers async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool: @@ -221,7 +221,7 @@ async def ingest_fs_asset( norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()] if norm and out["asset_info_id"] is not None: if not require_existing_tags: - await _ensure_tags_exist(session, norm, tag_type="user") + await _helpers.ensure_tags_exist(session, norm, tag_type="user") # Which tags exist? existing_tag_names = set( @@ -296,6 +296,10 @@ async def ingest_fs_asset( user_metadata=new_meta, ) # end of adding metadata["filename"] + try: + await remove_missing_tag_for_asset_hash(session, asset_hash=asset_hash) + except Exception: + logging.exception("Failed to clear 'missing' tag for %s", asset_hash) return out @@ -376,8 +380,8 @@ async def list_asset_infos_page( if name_contains: base = base.where(AssetInfo.name.ilike(f"%{name_contains}%")) - base = _apply_tag_filters(base, include_tags, exclude_tags) - base = _apply_metadata_filter(base, metadata_filter) + base = _helpers.apply_tag_filters(base, include_tags, exclude_tags) + base = _helpers.apply_metadata_filter(base, metadata_filter) sort = (sort or "created_at").lower() order = (order or "desc").lower() @@ -401,8 +405,8 @@ async def list_asset_infos_page( ) if name_contains: count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%")) - count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags) - count_stmt = _apply_metadata_filter(count_stmt, metadata_filter) + count_stmt = _helpers.apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = _helpers.apply_metadata_filter(count_stmt, metadata_filter) total = int((await session.execute(count_stmt)).scalar_one() or 0) @@ -646,7 +650,7 @@ async def set_asset_info_tags( to_remove = [t for t in current if t not in desired] if to_add: - await _ensure_tags_exist(session, to_add, tag_type="user") + await _helpers.ensure_tags_exist(session, to_add, tag_type="user") session.add_all([ AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow()) for t in to_add @@ -776,7 +780,7 @@ async def replace_asset_info_metadata_projection( rows: list[AssetInfoMeta] = [] for k, v in user_metadata.items(): - for r in _project_kv(k, v): + for r in _helpers.project_kv(k, v): rows.append( AssetInfoMeta( asset_info_id=asset_info_id, @@ -894,7 +898,7 @@ async def add_tags_to_asset_info( # Ensure tag rows exist if requested. if create_if_missing: - await _ensure_tags_exist(session, norm, tag_type="user") + await _helpers.ensure_tags_exist(session, norm, tag_type="user") # Snapshot current links current = { @@ -979,175 +983,93 @@ async def remove_tags_from_asset_info( return {"removed": to_remove, "not_present": not_present, "total_tags": total} -async def _ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]: - wanted = normalize_tags(list(names)) - if not wanted: +async def add_missing_tag_for_asset_hash( + session: AsyncSession, + *, + asset_hash: str, + origin: str = "automatic", +) -> int: + """Ensure every AssetInfo referencing asset_hash has the 'missing' tag. + Returns number of AssetInfos newly tagged. + """ + ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_hash == asset_hash))).scalars().all() + if not ids: + return 0 + + existing = { + asset_info_id + for (asset_info_id,) in ( + await session.execute( + select(AssetInfoTag.asset_info_id).where( + AssetInfoTag.asset_info_id.in_(ids), + AssetInfoTag.tag_name == "missing", + ) + ) + ).all() + } + to_add = [i for i in ids if i not in existing] + if not to_add: + return 0 + + now = utcnow() + session.add_all( + [ + AssetInfoTag(asset_info_id=i, tag_name="missing", origin=origin, added_at=now) + for i in to_add + ] + ) + await session.flush() + return len(to_add) + + +async def remove_missing_tag_for_asset_hash( + session: AsyncSession, + *, + asset_hash: str, +) -> int: + """Remove the 'missing' tag from every AssetInfo referencing asset_hash. + Returns number of link rows removed. + """ + ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_hash == asset_hash))).scalars().all() + if not ids: + return 0 + + res = await session.execute( + delete(AssetInfoTag).where( + AssetInfoTag.asset_info_id.in_(ids), + AssetInfoTag.tag_name == "missing", + ) + ) + await session.flush() + return int(res.rowcount or 0) + + +async def list_cache_states_under_prefixes( + session: AsyncSession, + *, + prefixes: Sequence[str], +) -> list[AssetCacheState]: + """Return AssetCacheState rows whose file_path starts with any of the given absolute prefixes.""" + if not prefixes: return [] - existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all() - by_name = {t.name: t for t in existing} - to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name] - if to_create: - session.add_all(to_create) - await session.flush() - by_name.update({t.name: t for t in to_create}) - return [by_name[n] for n in wanted] + conds = [] + for p in prefixes: + if not p: + continue + base = os.path.abspath(p) + if not base.endswith(os.sep): + base = base + os.sep + conds.append(AssetCacheState.file_path.like(base + "%")) -def _apply_tag_filters( - stmt: sa.sql.Select, - include_tags: Optional[Sequence[str]], - exclude_tags: Optional[Sequence[str]], -) -> sa.sql.Select: - """include_tags: every tag must be present; exclude_tags: none may be present.""" - include_tags = normalize_tags(include_tags) - exclude_tags = normalize_tags(exclude_tags) + if not conds: + return [] - if include_tags: - for tag_name in include_tags: - stmt = stmt.where( - exists().where( - (AssetInfoTag.asset_info_id == AssetInfo.id) - & (AssetInfoTag.tag_name == tag_name) - ) - ) - - if exclude_tags: - stmt = stmt.where( - ~exists().where( - (AssetInfoTag.asset_info_id == AssetInfo.id) - & (AssetInfoTag.tag_name.in_(exclude_tags)) - ) + rows = ( + await session.execute( + select(AssetCacheState) + .where(sa.or_(*conds)) + .order_by(AssetCacheState.id.asc()) ) - return stmt - - -def _apply_metadata_filter( - stmt: sa.sql.Select, - metadata_filter: Optional[dict], -) -> sa.sql.Select: - """Apply metadata filters using the projection table asset_info_meta. - - Semantics: - - For scalar values: require EXISTS(asset_info_meta) with matching key + typed value. - - For None: key is missing OR key has explicit null (val_json IS NULL). - - For list values: ANY-of the list elements matches (EXISTS for any). - (Change to ALL-of by 'for each element: stmt = stmt.where(_meta_exists_clause(key, elem))') - """ - if not metadata_filter: - return stmt - - def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: - subquery = ( - select(sa.literal(1)) - .select_from(AssetInfoMeta) - .where( - AssetInfoMeta.asset_info_id == AssetInfo.id, - AssetInfoMeta.key == key, - *preds, - ) - .limit(1) - ) - return sa.exists(subquery) - - def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: - # Missing OR null: - if value is None: - # either: no row for key OR a row for key with explicit null - no_row_for_key = ~sa.exists( - select(sa.literal(1)) - .select_from(AssetInfoMeta) - .where( - AssetInfoMeta.asset_info_id == AssetInfo.id, - AssetInfoMeta.key == key, - ) - .limit(1) - ) - null_row = _exists_for_pred(key, AssetInfoMeta.val_json.is_(None)) - return sa.or_(no_row_for_key, null_row) - - # Typed scalar matches: - if isinstance(value, bool): - return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value)) - if isinstance(value, (int, float, Decimal)): - # store as Decimal for equality against NUMERIC(38,10) - num = value if isinstance(value, Decimal) else Decimal(str(value)) - return _exists_for_pred(key, AssetInfoMeta.val_num == num) - if isinstance(value, str): - return _exists_for_pred(key, AssetInfoMeta.val_str == value) - - # Complex: compare JSON (no index, but supported) - return _exists_for_pred(key, AssetInfoMeta.val_json == value) - - for k, v in metadata_filter.items(): - if isinstance(v, list): - # ANY-of (exists for any element) - ors = [_exists_clause_for_value(k, elem) for elem in v] - if ors: - stmt = stmt.where(sa.or_(*ors)) - else: - stmt = stmt.where(_exists_clause_for_value(k, v)) - return stmt - - -def _is_scalar(v: Any) -> bool: - if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries - return True - if isinstance(v, bool): - return True - if isinstance(v, (int, float, Decimal, str)): - return True - return False - - -def _project_kv(key: str, value: Any) -> list[dict]: - """ - Turn a metadata key/value into one or more projection rows: - - scalar -> one row (ordinal=0) in the proper typed column - - list of scalars -> one row per element with ordinal=i - - dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list) - - None -> single row with val_json = None - Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...} - """ - rows: list[dict] = [] - - if value is None: - rows.append({"key": key, "ordinal": 0, "val_json": None}) - return rows - - if _is_scalar(value): - if isinstance(value, bool): - rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) - elif isinstance(value, (int, float, Decimal)): - # store numeric; SQLAlchemy will coerce to Numeric - num = value if isinstance(value, Decimal) else Decimal(str(value)) - rows.append({"key": key, "ordinal": 0, "val_num": num}) - elif isinstance(value, str): - rows.append({"key": key, "ordinal": 0, "val_str": value}) - else: - # Fallback to json - rows.append({"key": key, "ordinal": 0, "val_json": value}) - return rows - - if isinstance(value, list): - if all(_is_scalar(x) for x in value): - for i, x in enumerate(value): - if x is None: - rows.append({"key": key, "ordinal": i, "val_json": None}) - elif isinstance(x, bool): - rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) - elif isinstance(x, (int, float, Decimal)): - num = x if isinstance(x, Decimal) else Decimal(str(x)) - rows.append({"key": key, "ordinal": i, "val_num": num}) - elif isinstance(x, str): - rows.append({"key": key, "ordinal": i, "val_str": x}) - else: - rows.append({"key": key, "ordinal": i, "val_json": x}) - return rows - # list contains objects -> one val_json per element - for i, x in enumerate(value): - rows.append({"key": key, "ordinal": i, "val_json": x}) - return rows - - # Dict or any other structure -> single json row - rows.append({"key": key, "ordinal": 0, "val_json": value}) - return rows + ).scalars().all() + return list(rows)