import asyncio
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Literal, Optional

import sqlalchemy as sa

import folder_paths

from ._assets_helpers import (
    collect_models_files,
    get_comfy_models_folders,
    get_name_and_tags_from_asset_path,
    list_tree,
    new_scan_id,
    prefixes_for_root,
    ts_to_iso,
)
from .api import schemas_in, schemas_out
from .database.db import create_session
from .database.helpers import (
    add_missing_tag_for_asset_id,
    remove_missing_tag_for_asset_id,
)
from .database.models import Asset, AssetCacheState, AssetInfo
from .database.services import (
    compute_hash_and_dedup_for_cache_state,
    ensure_seed_for_path,
    list_cache_states_by_asset_id,
    list_cache_states_with_asset_under_prefixes,
    list_unhashed_candidates_under_prefixes,
    list_verify_candidates_under_prefixes,
)

LOGGER = logging.getLogger(__name__)

SLOW_HASH_CONCURRENCY = 1


@dataclass
class ScanProgress:
    scan_id: str
    root: schemas_in.RootType
    status: Literal["scheduled", "running", "completed", "failed", "cancelled"] = "scheduled"
    scheduled_at: float = field(default_factory=lambda: time.time())
    started_at: Optional[float] = None
    finished_at: Optional[float] = None
    discovered: int = 0
    processed: int = 0
    file_errors: list[dict] = field(default_factory=list)


@dataclass
class SlowQueueState:
    queue: asyncio.Queue
    workers: list[asyncio.Task] = field(default_factory=list)
    closed: bool = False


RUNNING_TASKS: dict[schemas_in.RootType, asyncio.Task] = {}
PROGRESS_BY_ROOT: dict[schemas_in.RootType, ScanProgress] = {}
SLOW_STATE_BY_ROOT: dict[schemas_in.RootType, SlowQueueState] = {}


def current_statuses() -> schemas_out.AssetScanStatusResponse:
    scans = []
    for root in schemas_in.ALLOWED_ROOTS:
        prog = PROGRESS_BY_ROOT.get(root)
        if not prog:
            continue
        scans.append(_scan_progress_to_scan_status_model(prog))
    return schemas_out.AssetScanStatusResponse(scans=scans)


async def schedule_scans(roots: list[schemas_in.RootType]) -> schemas_out.AssetScanStatusResponse:
    results: list[ScanProgress] = []
    for root in roots:
        if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
            results.append(PROGRESS_BY_ROOT[root])
            continue
        prog = ScanProgress(scan_id=new_scan_id(root), root=root, status="scheduled")
        PROGRESS_BY_ROOT[root] = prog
        state = SlowQueueState(queue=asyncio.Queue())
        SLOW_STATE_BY_ROOT[root] = state
        RUNNING_TASKS[root] = asyncio.create_task(
            _run_hash_verify_pipeline(root, prog, state),
            name=f"asset-scan:{root}",
        )
        results.append(prog)
    return _status_response_for(results)


async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None:
    for r in roots:
        try:
            await _fast_db_consistency_pass(r)
        except Exception as ex:
            LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex)

    paths: list[str] = []
    if "models" in roots:
        paths.extend(collect_models_files())
    if "input" in roots:
        paths.extend(list_tree(folder_paths.get_input_directory()))
    if "output" in roots:
        paths.extend(list_tree(folder_paths.get_output_directory()))

    for p in paths:
        try:
            st = os.stat(p, follow_symlinks=True)
            if not int(st.st_size or 0):
                continue
            size_bytes = int(st.st_size)
            mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
            name, tags = get_name_and_tags_from_asset_path(p)
            await _seed_one_async(p, size_bytes, mtime_ns, name, tags)
        except OSError:
            continue
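
# Seeding note: a "seed" Asset is one whose hash is still NULL. It is created from
# stat() metadata and the path-derived name/tags only; the slow pipeline later picks
# it up through list_unhashed_candidates_under_prefixes() and computes its hash.
# Each path is seeded and committed in its own short-lived session by the helper below.
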
async def _seed_one_async(p: str, size_bytes: int, mtime_ns: int, name: str, tags: list[str]) -> None:
    async with await create_session() as sess:
        await ensure_seed_for_path(
            sess,
            abs_path=p,
            size_bytes=size_bytes,
            mtime_ns=mtime_ns,
            info_name=name,
            tags=tags,
            owner_id="",
        )
        await sess.commit()


def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse:
    return schemas_out.AssetScanStatusResponse(scans=[_scan_progress_to_scan_status_model(p) for p in progresses])


def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.AssetScanStatus:
    return schemas_out.AssetScanStatus(
        scan_id=progress.scan_id,
        root=progress.root,
        status=progress.status,
        scheduled_at=ts_to_iso(progress.scheduled_at),
        started_at=ts_to_iso(progress.started_at),
        finished_at=ts_to_iso(progress.finished_at),
        discovered=progress.discovered,
        processed=progress.processed,
        file_errors=[
            schemas_out.AssetScanError(
                path=e.get("path", ""),
                message=e.get("message", ""),
                at=e.get("at"),
            )
            for e in (progress.file_errors or [])
        ],
    )


async def _refresh_verify_flags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None:
    """Fast pass to mark verify candidates by comparing stored mtime_ns with on-disk mtime."""
    prefixes = prefixes_for_root(root)
    if not prefixes:
        return
    conds = []
    for p in prefixes:
        base = os.path.abspath(p)
        if not base.endswith(os.sep):
            base += os.sep
        conds.append(AssetCacheState.file_path.like(base + "%"))

    async with await create_session() as sess:
        rows = (
            await sess.execute(
                sa.select(
                    AssetCacheState.id,
                    AssetCacheState.mtime_ns,
                    AssetCacheState.needs_verify,
                    Asset.hash,
                    AssetCacheState.file_path,
                )
                .join(Asset, Asset.id == AssetCacheState.asset_id)
                .where(sa.or_(*conds))
            )
        ).all()

        to_set = []
        to_clear = []
        for sid, mtime_db, needs_verify, a_hash, fp in rows:
            try:
                st = os.stat(fp, follow_symlinks=True)
            except OSError:
                # Missing files are handled by missing-tag reconciliation later.
                continue
            actual_mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
            if a_hash is not None:
                if mtime_db is None or int(mtime_db) != int(actual_mtime_ns):
                    if not needs_verify:
                        to_set.append(sid)
            else:
                if needs_verify:
                    to_clear.append(sid)

        if to_set:
            await sess.execute(
                sa.update(AssetCacheState)
                .where(AssetCacheState.id.in_(to_set))
                .values(needs_verify=True)
            )
        if to_clear:
            await sess.execute(
                sa.update(AssetCacheState)
                .where(AssetCacheState.id.in_(to_clear))
                .values(needs_verify=False)
            )
        await sess.commit()


async def _run_hash_verify_pipeline(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
    prog.status = "running"
    prog.started_at = time.time()
    try:
        prefixes = prefixes_for_root(root)

        await _refresh_verify_flags_for_root(root, prog)

        # collect candidates from DB
        async with await create_session() as sess:
            verify_ids = await list_verify_candidates_under_prefixes(sess, prefixes=prefixes)
            unhashed_ids = await list_unhashed_candidates_under_prefixes(sess, prefixes=prefixes)

        # dedupe: prioritize verification first
        seen = set()
        ordered: list[int] = []
        for lst in (verify_ids, unhashed_ids):
            for sid in lst:
                if sid not in seen:
                    seen.add(sid)
                    ordered.append(sid)

        prog.discovered = len(ordered)

        # queue up work
        for sid in ordered:
            await state.queue.put(sid)
        state.closed = True

        _start_state_workers(root, prog, state)
        await _await_state_workers_then_finish(root, prog, state)
    except asyncio.CancelledError:
        prog.status = "cancelled"
        raise
    except Exception as exc:
        _append_error(prog, path="", message=str(exc))
        prog.status = "failed"
        prog.finished_at = time.time()
        LOGGER.exception("Asset scan failed for %s", root)
    finally:
        RUNNING_TASKS.pop(root, None)
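
# Both reconciliation passes below rely on the same cheap "fast-ok" check for a cache
# state: os.stat() succeeds, the stored mtime_ns equals the on-disk st_mtime_ns, and
# the on-disk size matches the Asset's recorded size_bytes (where one is recorded).
# A state that fails this check is either queued for re-verification by the slow
# pipeline or considered for the 'missing' tag.
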
async def _reconcile_missing_tags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None:
    """
    Detect missing files quickly and toggle the 'missing' tag per asset_id.

    Rules:
    - Only hashed assets (assets.hash != NULL) participate in missing tagging.
    - We consider ALL cache states of the asset (across roots) before tagging.
    """
    if root == "models":
        bases: list[str] = []
        for _bucket, paths in get_comfy_models_folders():
            bases.extend(paths)
    elif root == "input":
        bases = [folder_paths.get_input_directory()]
    else:
        bases = [folder_paths.get_output_directory()]

    try:
        async with await create_session() as sess:
            # state + hash + size for the current root
            rows = await list_cache_states_with_asset_under_prefixes(sess, prefixes=bases)

            # Track fast_ok within the scanned root and whether the asset is hashed
            by_asset: dict[str, dict] = {}
            for state, a_hash, size_db in rows:
                aid = state.asset_id
                acc = by_asset.get(aid)
                if acc is None:
                    acc = {"any_fast_ok_here": False, "hashed": (a_hash is not None), "size_db": int(size_db or 0)}
                    by_asset[aid] = acc
                try:
                    st = os.stat(state.file_path, follow_symlinks=True)
                    actual_mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
                    fast_ok = False
                    if acc["hashed"]:
                        if state.mtime_ns is not None and int(state.mtime_ns) == int(actual_mtime_ns):
                            if int(acc["size_db"]) > 0 and int(st.st_size) == int(acc["size_db"]):
                                fast_ok = True
                    if fast_ok:
                        acc["any_fast_ok_here"] = True
                except FileNotFoundError:
                    pass
                except OSError as e:
                    _append_error(prog, path=state.file_path, message=str(e))

            # Decide per asset, considering ALL its states (not just this root)
            for aid, acc in by_asset.items():
                try:
                    if not acc["hashed"]:
                        # Never tag seed assets as missing
                        continue
                    any_fast_ok_global = acc["any_fast_ok_here"]
                    if not any_fast_ok_global:
                        # Check other states outside this root
                        others = await list_cache_states_by_asset_id(sess, asset_id=aid)
                        for st in others:
                            try:
                                s = os.stat(st.file_path, follow_symlinks=True)
                                actual_mtime_ns = getattr(s, "st_mtime_ns", int(s.st_mtime * 1_000_000_000))
                                if st.mtime_ns is not None and int(st.mtime_ns) == int(actual_mtime_ns):
                                    if acc["size_db"] > 0 and int(s.st_size) == acc["size_db"]:
                                        any_fast_ok_global = True
                                        break
                            except OSError:
                                continue
                    if any_fast_ok_global:
                        await remove_missing_tag_for_asset_id(sess, asset_id=aid)
                    else:
                        await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
                except Exception as ex:
                    _append_error(prog, path="", message=f"reconcile {aid[:8]}: {ex}")
            await sess.commit()
    except Exception as e:
        _append_error(prog, path="", message=f"reconcile failed: {e}")


def _start_state_workers(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
    if state.workers:
        return

    async def _worker(_wid: int):
        while True:
            sid = await state.queue.get()
            try:
                if sid is None:
                    return
                try:
                    async with await create_session() as sess:
                        # Optional: fetch path for better error messages
                        st = await sess.get(AssetCacheState, sid)
                        try:
                            await compute_hash_and_dedup_for_cache_state(sess, state_id=sid)
                            await sess.commit()
                        except Exception as e:
                            path = st.file_path if st else f"state:{sid}"
                            _append_error(prog, path=path, message=str(e))
                            raise
                except Exception:
                    pass
                finally:
                    prog.processed += 1
            finally:
                state.queue.task_done()

    state.workers = [
        asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}")
        for i in range(SLOW_HASH_CONCURRENCY)
    ]

    async def _close_when_ready():
        while not state.closed:
            await asyncio.sleep(0.05)
        for _ in range(SLOW_HASH_CONCURRENCY):
            await state.queue.put(None)

    asyncio.create_task(_close_when_ready())
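
# Shutdown uses a sentinel pattern: the producer sets state.closed once every candidate
# has been queued, and _close_when_ready() then enqueues one None per worker, so each
# worker drains the remaining real work and exits its loop on the sentinel.
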
async def _await_state_workers_then_finish(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
    if state.workers:
        await asyncio.gather(*state.workers, return_exceptions=True)
    await _reconcile_missing_tags_for_root(root, prog)
    prog.finished_at = time.time()
    prog.status = "completed"


def _append_error(prog: ScanProgress, *, path: str, message: str) -> None:
    prog.file_errors.append({
        "path": path,
        "message": message,
        "at": ts_to_iso(time.time()),
    })


async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None:
    """
    Quick pass over asset_cache_state for `root`:
    - If file missing and Asset.hash is NULL and the Asset has no other states,
      delete the Asset and its infos.
    - If file missing and Asset.hash is NOT NULL:
        * If at least one state for this Asset is fast-ok, delete the missing state.
        * If none are fast-ok, add 'missing' tag to all AssetInfos for this Asset.
    - If at least one state becomes fast-ok for a hashed Asset, remove the 'missing' tag.
    """
    prefixes = prefixes_for_root(root)
    if not prefixes:
        return
    conds = []
    for p in prefixes:
        base = os.path.abspath(p)
        if not base.endswith(os.sep):
            base += os.sep
        conds.append(AssetCacheState.file_path.like(base + "%"))

    async with await create_session() as sess:
        if not conds:
            return
        rows = (
            await sess.execute(
                sa.select(AssetCacheState, Asset.hash, Asset.size_bytes)
                .join(Asset, Asset.id == AssetCacheState.asset_id)
                .where(sa.or_(*conds))
                .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
            )
        ).all()

        # Group by asset_id with status per state
        by_asset: dict[str, dict] = {}
        for st, a_hash, a_size in rows:
            aid = st.asset_id
            acc = by_asset.get(aid)
            if acc is None:
                acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
                by_asset[aid] = acc
            exists = False
            fast_ok = False
            try:
                s = os.stat(st.file_path, follow_symlinks=True)
                exists = True
                actual_mtime_ns = getattr(s, "st_mtime_ns", int(s.st_mtime * 1_000_000_000))
                if st.mtime_ns is not None and int(st.mtime_ns) == int(actual_mtime_ns):
                    if acc["size_db"] == 0 or int(s.st_size) == acc["size_db"]:
                        fast_ok = True
            except FileNotFoundError:
                exists = False
            except OSError as ex:
                exists = False
                LOGGER.debug("fast pass stat error for %s: %s", st.file_path, ex)
            acc["states"].append({"obj": st, "exists": exists, "fast_ok": fast_ok})

        # Apply actions
        for aid, acc in by_asset.items():
            a_hash = acc["hash"]
            states = acc["states"]
            any_fast_ok = any(s["fast_ok"] for s in states)
            all_missing = all(not s["exists"] for s in states)
            missing_states = [s["obj"] for s in states if not s["exists"]]

            if a_hash is None:
                # Seed asset: if all states are gone (in practice there is only one), remove the whole Asset
                if states and all_missing:
                    await sess.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == aid))
                    asset = await sess.get(Asset, aid)
                    if asset:
                        await sess.delete(asset)
                # else leave it for the slow scan to verify/rehash
            else:
                if any_fast_ok:
                    # Remove 'missing' and delete just the stale state rows
                    for st in missing_states:
                        try:
                            await sess.delete(await sess.get(AssetCacheState, st.id))
                        except Exception:
                            pass
                    try:
                        await remove_missing_tag_for_asset_id(sess, asset_id=aid)
                    except Exception:
                        pass
                else:
                    # No fast-ok path: mark as missing
                    try:
                        await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
                    except Exception:
                        pass

        await sess.flush()
        await sess.commit()
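
# Illustrative only: a minimal sketch of how a caller might drive this module. The
# handler names are hypothetical; sync_seed_assets(), schedule_scans(), and
# current_statuses() are the real entry points defined above.
#
#   async def handle_scan_request(roots: list[schemas_in.RootType]):
#       await sync_seed_assets(roots)       # cheap stat()-based seeding + consistency pass
#       return await schedule_scans(roots)  # start the slow hash/verify pipeline per root
#
#   def handle_scan_status() -> schemas_out.AssetScanStatusResponse:
#       return current_statuses()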