From 4d4c2cedd3da2aba1fe44689f0cb3d39a30c2d1d Mon Sep 17 00:00:00 2001 From: Luke Mino-Altherr Date: Tue, 3 Mar 2026 15:51:35 -0800 Subject: [PATCH] fix: address code review feedback - Fix missing import for compute_filename_for_reference in ingest.py - Apply code review fixes across routes, queries, scanner, seeder, hashing, ingest, path_utils, main, and server - Update and add tests for sync references and seeder Amp-Thread-ID: https://ampcode.com/threads/T-019cb61a-ed54-738c-a05f-9b5242e513f3 Co-authored-by: Amp --- app/assets/api/routes.py | 40 +- app/assets/database/queries/__init__.py | 2 + .../database/queries/asset_reference.py | 30 +- app/assets/database/queries/tags.py | 11 + app/assets/scanner.py | 89 +---- app/assets/seeder.py | 68 ++-- app/assets/services/hashing.py | 15 - app/assets/services/ingest.py | 42 +-- app/assets/services/path_utils.py | 5 +- main.py | 8 +- server.py | 7 +- .../assets_test/test_sync_references.py | 350 ++++++++++++++++++ tests-unit/seeder_test/test_seeder.py | 226 ++++++++--- 13 files changed, 675 insertions(+), 218 deletions(-) create mode 100644 tests-unit/assets_test/test_sync_references.py diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 55bbe7ded..99d3820d0 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -1,4 +1,5 @@ import asyncio +import functools import json import logging import os @@ -39,6 +40,20 @@ from app.assets.services import ( ROUTES = web.RouteTableDef() USER_MANAGER: user_manager.UserManager | None = None +_ASSETS_ENABLED = False + + +def _require_assets_feature_enabled(handler): + @functools.wraps(handler) + async def wrapper(request: web.Request) -> web.Response: + if not _ASSETS_ENABLED: + return _build_error_response( + 503, + "SERVICE_DISABLED", + "Assets system is disabled. 
Start the server with --enable-assets to use this feature.", + ) + return await handler(request) + return wrapper # UUID regex (canonical hyphenated form, case-insensitive) UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" @@ -64,11 +79,13 @@ def get_query_dict(request: web.Request) -> dict[str, Any]: # do not rely on the code in /app/assets remaining the same. -def register_assets_system( - app: web.Application, user_manager_instance: user_manager.UserManager +def register_assets_routes( + app: web.Application, user_manager_instance: user_manager.UserManager | None = None, ) -> None: - global USER_MANAGER - USER_MANAGER = user_manager_instance + global USER_MANAGER, _ASSETS_ENABLED + if user_manager_instance is not None: + USER_MANAGER = user_manager_instance + _ASSETS_ENABLED = True app.add_routes(ROUTES) @@ -96,6 +113,7 @@ def _validate_sort_field(requested: str | None) -> str: @ROUTES.head("/api/assets/hash/{hash}") +@_require_assets_feature_enabled async def head_asset_by_hash(request: web.Request) -> web.Response: hash_str = request.match_info.get("hash", "").strip().lower() if not hash_str or ":" not in hash_str: @@ -116,6 +134,7 @@ async def head_asset_by_hash(request: web.Request) -> web.Response: @ROUTES.get("/api/assets") +@_require_assets_feature_enabled async def list_assets_route(request: web.Request) -> web.Response: """ GET request to list assets. @@ -166,6 +185,7 @@ async def list_assets_route(request: web.Request) -> web.Response: @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}") +@_require_assets_feature_enabled async def get_asset_route(request: web.Request) -> web.Response: """ GET request to get an asset's info as JSON. 
@@ -211,6 +231,7 @@ async def get_asset_route(request: web.Request) -> web.Response: @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content") +@_require_assets_feature_enabled async def download_asset_content(request: web.Request) -> web.Response: disposition = request.query.get("disposition", "attachment").lower().strip() if disposition not in {"inline", "attachment"}: @@ -264,6 +285,7 @@ async def download_asset_content(request: web.Request) -> web.Response: @ROUTES.post("/api/assets/from-hash") +@_require_assets_feature_enabled async def create_asset_from_hash_route(request: web.Request) -> web.Response: try: payload = await request.json() @@ -304,6 +326,7 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response: @ROUTES.post("/api/assets") +@_require_assets_feature_enabled async def upload_asset(request: web.Request) -> web.Response: """Multipart/form-data endpoint for Asset uploads.""" try: @@ -408,6 +431,7 @@ async def upload_asset(request: web.Request) -> web.Response: @ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") +@_require_assets_feature_enabled async def update_asset_route(request: web.Request) -> web.Response: reference_id = str(uuid.UUID(request.match_info["id"])) try: @@ -453,6 +477,7 @@ async def update_asset_route(request: web.Request) -> web.Response: @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") +@_require_assets_feature_enabled async def delete_asset_route(request: web.Request) -> web.Response: reference_id = str(uuid.UUID(request.match_info["id"])) delete_content_param = request.query.get("delete_content") @@ -484,6 +509,7 @@ async def delete_asset_route(request: web.Request) -> web.Response: @ROUTES.get("/api/tags") +@_require_assets_feature_enabled async def get_tags(request: web.Request) -> web.Response: """ GET request to list all tags based on query parameters. 
@@ -520,6 +546,7 @@ async def get_tags(request: web.Request) -> web.Response: @ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags") +@_require_assets_feature_enabled async def add_asset_tags(request: web.Request) -> web.Response: reference_id = str(uuid.UUID(request.match_info["id"])) try: @@ -569,6 +596,7 @@ async def add_asset_tags(request: web.Request) -> web.Response: @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags") +@_require_assets_feature_enabled async def delete_asset_tags(request: web.Request) -> web.Response: reference_id = str(uuid.UUID(request.match_info["id"])) try: @@ -613,6 +641,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response: @ROUTES.post("/api/assets/seed") +@_require_assets_feature_enabled async def seed_assets(request: web.Request) -> web.Response: """Trigger asset seeding for specified roots (models, input, output). @@ -662,6 +691,7 @@ async def seed_assets(request: web.Request) -> web.Response: @ROUTES.get("/api/assets/seed/status") +@_require_assets_feature_enabled async def get_seed_status(request: web.Request) -> web.Response: """Get current scan status and progress.""" status = asset_seeder.get_status() @@ -683,6 +713,7 @@ async def get_seed_status(request: web.Request) -> web.Response: @ROUTES.post("/api/assets/seed/cancel") +@_require_assets_feature_enabled async def cancel_seed(request: web.Request) -> web.Response: """Request cancellation of in-progress scan.""" cancelled = asset_seeder.cancel() @@ -692,6 +723,7 @@ async def cancel_seed(request: web.Request) -> web.Response: @ROUTES.post("/api/assets/prune") +@_require_assets_feature_enabled async def mark_missing_assets(request: web.Request) -> web.Response: """Mark assets as missing when outside all known root prefixes. 
diff --git a/app/assets/database/queries/__init__.py b/app/assets/database/queries/__init__.py index bd65b1997..61acdb36b 100644 --- a/app/assets/database/queries/__init__.py +++ b/app/assets/database/queries/__init__.py @@ -57,6 +57,7 @@ from app.assets.database.queries.tags import ( remove_missing_tag_for_asset_id, remove_tags_from_reference, set_reference_tags, + validate_tags_exist, ) __all__ = [ @@ -114,4 +115,5 @@ __all__ = [ "update_reference_updated_at", "upsert_asset", "upsert_reference", + "validate_tags_exist", ] diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py index a7419d13d..d721913e1 100644 --- a/app/assets/database/queries/asset_reference.py +++ b/app/assets/database/queries/asset_reference.py @@ -660,13 +660,16 @@ def restore_references_by_paths(session: Session, file_paths: list[str]) -> int: if not file_paths: return 0 - result = session.execute( - sa.update(AssetReference) - .where(AssetReference.file_path.in_(file_paths)) - .where(AssetReference.is_missing == True) # noqa: E712 - .values(is_missing=False) - ) - return result.rowcount + total = 0 + for chunk in iter_chunks(file_paths, MAX_BIND_PARAMS): + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.file_path.in_(chunk)) + .where(AssetReference.is_missing == True) # noqa: E712 + .values(is_missing=False) + ) + total += result.rowcount + return total def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]: @@ -697,11 +700,14 @@ def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int: """ if not asset_ids: return 0 - session.execute( - sa.delete(AssetReference).where(AssetReference.asset_id.in_(asset_ids)) - ) - result = session.execute(sa.delete(Asset).where(Asset.id.in_(asset_ids))) - return result.rowcount + total = 0 + for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS): + session.execute( + sa.delete(AssetReference).where(AssetReference.asset_id.in_(chunk)) + ) + result 
= session.execute(sa.delete(Asset).where(Asset.id.in_(chunk))) + total += result.rowcount + return total def get_references_for_prefixes( diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py index 047c88793..551cc09fa 100644 --- a/app/assets/database/queries/tags.py +++ b/app/assets/database/queries/tags.py @@ -37,6 +37,17 @@ class SetTagsDict(TypedDict): total: list[str] +def validate_tags_exist(session: Session, tags: list[str]) -> None: + """Raise ValueError if any of the given tag names do not exist.""" + existing_tag_names = set( + name + for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all() + ) + missing = [t for t in tags if t not in existing_tag_names] + if missing: + raise ValueError(f"Unknown tags: {missing}") + + def ensure_tags_exist( session: Session, names: Iterable[str], tag_type: str = "user" ) -> None: diff --git a/app/assets/scanner.py b/app/assets/scanner.py index 260cf9711..3b9749387 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -44,9 +44,9 @@ from app.database.db import create_session, dependencies_available class _RefInfo(TypedDict): ref_id: str - fp: str + file_path: str exists: bool - fast_ok: bool + stat_unchanged: bool needs_verify: bool @@ -75,9 +75,7 @@ def get_prefixes_for_root(root: RootType) -> list[str]: def get_all_known_prefixes() -> list[str]: """Get all known asset prefixes across all root types.""" all_roots: tuple[RootType, ...] = ("models", "input", "output") - return [ - os.path.abspath(p) for root in all_roots for p in get_prefixes_for_root(root) - ] + return [p for root in all_roots for p in get_prefixes_for_root(root)] def collect_models_files() -> list[str]: @@ -110,10 +108,10 @@ def sync_references_with_filesystem( ) -> set[str] | None: """Reconcile asset references with filesystem for a root. 
- - Toggle needs_verify per reference using fast mtime/size check - - For hashed assets with at least one fast-ok ref: delete stale missing refs + - Toggle needs_verify per reference using mtime/size stat check + - For hashed assets with at least one stat-unchanged ref: delete stale missing refs - For seed assets with all refs missing: delete Asset and its references - - Optionally add/remove 'missing' tags based on fast-ok in this root + - Optionally add/remove 'missing' tags based on stat check in this root - Optionally return surviving absolute paths Args: @@ -140,10 +138,10 @@ def sync_references_with_filesystem( acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []} by_asset[row.asset_id] = acc - fast_ok = False + stat_unchanged = False try: exists = True - fast_ok = verify_file_unchanged( + stat_unchanged = verify_file_unchanged( mtime_db=row.mtime_ns, size_db=acc["size_db"], stat_result=os.stat(row.file_path, follow_symlinks=True), @@ -160,9 +158,9 @@ def sync_references_with_filesystem( acc["refs"].append( { "ref_id": row.reference_id, - "fp": row.file_path, + "file_path": row.file_path, "exists": exists, - "fast_ok": fast_ok, + "stat_unchanged": stat_unchanged, "needs_verify": row.needs_verify, } ) @@ -177,18 +175,18 @@ def sync_references_with_filesystem( for aid, acc in by_asset.items(): a_hash = acc["hash"] refs = acc["refs"] - any_fast_ok = any(r["fast_ok"] for r in refs) + any_unchanged = any(r["stat_unchanged"] for r in refs) all_missing = all(not r["exists"] for r in refs) for r in refs: if not r["exists"]: to_mark_missing.append(r["ref_id"]) continue - if r["fast_ok"]: + if r["stat_unchanged"]: to_clear_missing.append(r["ref_id"]) if r["needs_verify"]: to_clear_verify.append(r["ref_id"]) - if not r["fast_ok"] and not r["needs_verify"]: + if not r["stat_unchanged"] and not r["needs_verify"]: to_set_verify.append(r["ref_id"]) if a_hash is None: @@ -197,10 +195,10 @@ def sync_references_with_filesystem( else: for r in refs: if 
r["exists"]: - survivors.add(os.path.abspath(r["fp"])) + survivors.add(os.path.abspath(r["file_path"])) continue - if any_fast_ok: + if any_unchanged: for r in refs: if not r["exists"]: stale_ref_ids.append(r["ref_id"]) @@ -219,7 +217,7 @@ def sync_references_with_filesystem( for r in refs: if r["exists"]: - survivors.add(os.path.abspath(r["fp"])) + survivors.add(os.path.abspath(r["file_path"])) delete_references_by_ids(session, stale_ref_ids) stale_set = set(stale_ref_ids) @@ -349,58 +347,6 @@ def build_asset_specs( return specs, tag_pool, skipped -def build_stub_specs( - paths: list[str], - existing_paths: set[str], -) -> tuple[list[SeedAssetSpec], set[str], int]: - """Build minimal stub specs for fast phase scanning. - - Only collects filesystem metadata (stat), no file content reading. - This is the fastest possible scan to populate the asset database. - - Args: - paths: List of file paths to process - existing_paths: Set of paths that already exist in the database - - Returns: - Tuple of (specs, tag_pool, skipped_count) - """ - specs: list[SeedAssetSpec] = [] - tag_pool: set[str] = set() - skipped = 0 - - for p in paths: - abs_p = os.path.abspath(p) - if abs_p in existing_paths: - skipped += 1 - continue - try: - stat_p = os.stat(abs_p, follow_symlinks=True) - except OSError: - continue - if not stat_p.st_size: - continue - - name, tags = get_name_and_tags_from_asset_path(abs_p) - rel_fname = compute_relative_filename(abs_p) - - specs.append( - { - "abs_path": abs_p, - "size_bytes": stat_p.st_size, - "mtime_ns": get_mtime_ns(stat_p), - "info_name": name, - "tags": tags, - "fname": rel_fname, - "metadata": None, - "hash": None, - "mime_type": None, - } - ) - tag_pool.update(tags) - - return specs, tag_pool, skipped - def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int: """Insert asset specs into database, returning count of created refs.""" @@ -538,7 +484,8 @@ def enrich_asset( try: digest = compute_blake3_hash(file_path) full_hash = 
f"blake3:{digest}" - if not extract_metadata or metadata: + metadata_ok = not extract_metadata or metadata is not None + if metadata_ok: new_level = ENRICHMENT_HASHED except Exception as e: logging.warning("Failed to hash %s: %s", file_path, e) diff --git a/app/assets/seeder.py b/app/assets/seeder.py index 953bd5b0d..5b3fd157e 100644 --- a/app/assets/seeder.py +++ b/app/assets/seeder.py @@ -12,7 +12,7 @@ from app.assets.scanner import ( ENRICHMENT_METADATA, ENRICHMENT_STUB, RootType, - build_stub_specs, + build_asset_specs, collect_paths_for_roots, enrich_assets_batch, get_all_known_prefixes, @@ -68,35 +68,23 @@ class ScanStatus: ProgressCallback = Callable[[Progress], None] -class AssetSeeder: - """Singleton class managing background asset scanning. +class _AssetSeeder: + """Background asset scanning manager. - Thread-safe singleton that spawns ephemeral daemon threads for scanning. + Spawns ephemeral daemon threads for scanning. Each scan creates a new thread that exits when complete. + Use the module-level ``asset_seeder`` instance. """ - _instance: "AssetSeeder | None" = None - _instance_lock = threading.Lock() - - def __new__(cls) -> "AssetSeeder": - with cls._instance_lock: - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._instance._initialized = False - return cls._instance - def __init__(self) -> None: - if self._initialized: - return - self._initialized = True self._lock = threading.Lock() self._state = State.IDLE self._progress: Progress | None = None self._errors: list[str] = [] self._thread: threading.Thread | None = None self._cancel_event = threading.Event() - self._pause_event = threading.Event() - self._pause_event.set() # Start unpaused (set = running, clear = paused) + self._run_gate = threading.Event() + self._run_gate.set() # Start unpaused (set = running, clear = paused) self._roots: tuple[RootType, ...] 
= () self._phase: ScanPhase = ScanPhase.FULL self._compute_hashes: bool = False @@ -154,10 +142,10 @@ class AssetSeeder: self._compute_hashes = compute_hashes self._progress_callback = progress_callback self._cancel_event.clear() - self._pause_event.set() # Ensure unpaused when starting + self._run_gate.set() # Ensure unpaused when starting self._thread = threading.Thread( target=self._run_scan, - name="AssetSeeder", + name="_AssetSeeder", daemon=True, ) self._thread.start() @@ -223,7 +211,7 @@ class AssetSeeder: logging.info("Asset seeder cancelling (was %s)", self._state.value) self._state = State.CANCELLING self._cancel_event.set() - self._pause_event.set() # Unblock if paused so thread can exit + self._run_gate.set() # Unblock if paused so thread can exit return True def stop(self) -> bool: @@ -247,7 +235,7 @@ class AssetSeeder: return False logging.info("Asset seeder pausing") self._state = State.PAUSED - self._pause_event.clear() + self._run_gate.clear() return True def resume(self) -> bool: @@ -263,7 +251,7 @@ class AssetSeeder: return False logging.info("Asset seeder resuming") self._state = State.RUNNING - self._pause_event.set() + self._run_gate.set() self._emit_event("assets.seed.resumed", {}) return True @@ -356,10 +344,10 @@ class AssetSeeder: self._thread = None def mark_missing_outside_prefixes(self) -> int: - """Mark cache states as missing when outside all known root prefixes. + """Mark references as missing when outside all known root prefixes. This is a non-destructive soft-delete operation. Assets and their - metadata are preserved, but cache states are flagged as missing. + metadata are preserved, but references are flagged as missing. They can be restored if the file reappears in a future scan. This operation is decoupled from scanning to prevent partial scans @@ -369,7 +357,7 @@ class AssetSeeder: a full scan of all roots or during maintenance. 
Returns: - Number of cache states marked as missing + Number of references marked as missing Raises: ScanInProgressError: If a scan is currently running @@ -389,7 +377,7 @@ class AssetSeeder: all_prefixes = get_all_known_prefixes() marked = mark_missing_outside_prefixes_safely(all_prefixes) if marked > 0: - logging.info("Marked %d cache states as missing", marked) + logging.info("Marked %d references as missing", marked) return marked finally: with self._lock: @@ -409,9 +397,9 @@ class AssetSeeder: Returns: True if scan should stop, False to continue """ - if not self._pause_event.is_set(): + if not self._run_gate.is_set(): self._emit_event("assets.seed.paused", {}) - self._pause_event.wait() # Blocks if paused + self._run_gate.wait() # Blocks if paused return self._is_cancelled() def _emit_event(self, event_type: str, data: dict) -> None: @@ -539,7 +527,11 @@ class AssetSeeder: cancelled = True return - total_enriched = self._run_enrich_phase(roots) + enrich_cancelled, total_enriched = self._run_enrich_phase(roots) + + if enrich_cancelled: + cancelled = True + return self._emit_event( "assets.seed.enrich_complete", @@ -613,7 +605,9 @@ class AssetSeeder: ) # Use stub specs (no metadata extraction, no hashing) - specs, tag_pool, skipped_existing = build_stub_specs(paths, existing_paths) + specs, tag_pool, skipped_existing = build_asset_specs( + paths, existing_paths, enable_metadata_extraction=False, compute_hashes=False, + ) self._update_progress(skipped=skipped_existing) if self._check_pause_and_cancel(): @@ -661,11 +655,11 @@ class AssetSeeder: self._update_progress(scanned=len(specs), created=total_created) return total_created, skipped_existing, total_paths - def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> int: + def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]: """Run phase 2: enrich existing records with metadata and hashes. 
Returns: - Total number of assets enriched + Tuple of (cancelled, total_enriched) """ total_enriched = 0 batch_size = 100 @@ -690,7 +684,7 @@ class AssetSeeder: while True: if self._check_pause_and_cancel(): logging.info("Enrich scan cancelled after %d assets", total_enriched) - break + return True, total_enriched # Fetch next batch of unenriched assets unenriched = get_unenriched_assets_for_roots( @@ -737,7 +731,7 @@ class AssetSeeder: ) last_progress_time = now - return total_enriched + return False, total_enriched -asset_seeder = AssetSeeder() +asset_seeder = _AssetSeeder() diff --git a/app/assets/services/hashing.py b/app/assets/services/hashing.py index e92d34aaf..bcd7645c3 100644 --- a/app/assets/services/hashing.py +++ b/app/assets/services/hashing.py @@ -1,4 +1,3 @@ -import asyncio import os from typing import IO @@ -18,20 +17,6 @@ def compute_blake3_hash( return _hash_file_obj(f, chunk_size) -async def compute_blake3_hash_async( - fp: str | IO[bytes], - chunk_size: int = DEFAULT_CHUNK, -) -> str: - if hasattr(fp, "read"): - return await asyncio.to_thread(compute_blake3_hash, fp, chunk_size) - - def _worker() -> str: - with open(os.fspath(fp), "rb") as f: - return _hash_file_obj(f, chunk_size) - - return await asyncio.to_thread(_worker) - - def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str: if chunk_size <= 0: chunk_size = DEFAULT_CHUNK diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py index 2e046dbd9..c8331b31b 100644 --- a/app/assets/services/ingest.py +++ b/app/assets/services/ingest.py @@ -2,17 +2,16 @@ import contextlib import logging import mimetypes import os -from typing import Sequence +from typing import Any, Sequence -from sqlalchemy import select from sqlalchemy.orm import Session import app.assets.services.hashing as hashing -from app.assets.database.models import Asset, AssetReference, Tag from app.assets.database.queries import ( add_tags_to_reference, fetch_reference_and_asset, 
get_asset_by_hash, + get_existing_asset_ids, get_reference_by_file_path, get_reference_tags, get_or_create_reference, @@ -21,11 +20,13 @@ from app.assets.database.queries import ( set_reference_tags, upsert_asset, upsert_reference, + validate_tags_exist, ) from app.assets.helpers import normalize_tags from app.assets.services.file_utils import get_size_and_mtime_ns from app.assets.services.path_utils import ( compute_filename_for_reference, + compute_relative_filename, resolve_destination_from_tags, validate_path_within_base, ) @@ -55,6 +56,7 @@ def _ingest_file_from_path( require_existing_tags: bool = False, ) -> IngestResult: locator = os.path.abspath(abs_path) + user_metadata = user_metadata or {} asset_created = False asset_updated = False @@ -64,7 +66,7 @@ def _ingest_file_from_path( with create_session() as session: if preview_id: - if not session.get(Asset, preview_id): + if preview_id not in get_existing_asset_ids(session, [preview_id]): preview_id = None asset, asset_created, asset_updated = upsert_asset( @@ -94,7 +96,7 @@ def _ingest_file_from_path( norm = normalize_tags(list(tags)) if norm: if require_existing_tags: - _validate_tags_exist(session, norm) + validate_tags_exist(session, norm) add_tags_to_reference( session, reference_id=reference_id, @@ -106,7 +108,8 @@ def _ingest_file_from_path( _update_metadata_with_filename( session, reference_id=reference_id, - ref=ref, + file_path=ref.file_path, + current_metadata=ref.user_metadata, user_metadata=user_metadata, ) @@ -134,6 +137,8 @@ def _register_existing_asset( tag_origin: str = "manual", owner_id: str = "", ) -> RegisterAssetResult: + user_metadata = user_metadata or {} + with create_session() as session: asset = get_asset_by_hash(session, asset_hash=asset_hash) if not asset: @@ -157,7 +162,7 @@ def _register_existing_asset( session.commit() return result - new_meta = dict(user_metadata or {}) + new_meta = dict(user_metadata) computed_filename = compute_filename_for_reference(session, ref) if 
computed_filename: new_meta["filename"] = computed_filename @@ -190,29 +195,20 @@ def _register_existing_asset( return result -def _validate_tags_exist(session: Session, tags: list[str]) -> None: - existing_tag_names = set( - name - for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all() - ) - missing = [t for t in tags if t not in existing_tag_names] - if missing: - raise ValueError(f"Unknown tags: {missing}") - def _update_metadata_with_filename( session: Session, reference_id: str, - ref: AssetReference, - user_metadata: UserMetadata, + file_path: str | None, + current_metadata: dict | None, + user_metadata: dict[str, Any], ) -> None: - computed_filename = compute_filename_for_reference(session, ref) + computed_filename = compute_relative_filename(file_path) if file_path else None - current_meta = ref.user_metadata or {} + current_meta = current_metadata or {} new_meta = dict(current_meta) - if user_metadata: - for k, v in user_metadata.items(): - new_meta[k] = v + for k, v in user_metadata.items(): + new_meta[k] = v if computed_filename: new_meta["filename"] = computed_filename diff --git a/app/assets/services/path_utils.py b/app/assets/services/path_utils.py index f10229af9..eb5852aa8 100644 --- a/app/assets/services/path_utils.py +++ b/app/assets/services/path_utils.py @@ -51,8 +51,9 @@ def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: raw_subdirs = tags[1:] else: raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'") + _sep_chars = frozenset(("/", "\\", os.sep)) for i in raw_subdirs: - if i in (".", ".."): + if i in (".", "..") or _sep_chars & set(i): raise ValueError("invalid path component in tags") return base_dir, raw_subdirs if raw_subdirs else [] @@ -113,6 +114,8 @@ def get_asset_category_and_relative_path( return Path(child).is_relative_to(parent) def _compute_relative(child: str, parent: str) -> str: + # Normalize relative path, stripping any leading ".." 
components + # by anchoring to root (os.sep) then computing relpath back from it. return os.path.relpath( os.path.join(os.sep, os.path.relpath(child, parent)), os.sep ) diff --git a/main.py b/main.py index 7b22d783e..5801bbd9a 100644 --- a/main.py +++ b/main.py @@ -259,10 +259,10 @@ def prompt_worker(q, server_instance): extra_data[k] = sensitive[k] asset_seeder.pause() - - e.execute(item[2], prompt_id, extra_data, item[4]) - - asset_seeder.resume() + try: + e.execute(item[2], prompt_id, extra_data, item[4]) + finally: + asset_seeder.resume() need_gc = True remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] diff --git a/server.py b/server.py index 5cbcf0916..aaae89a06 100644 --- a/server.py +++ b/server.py @@ -34,7 +34,7 @@ from comfyui_version import __version__ from app.frontend_management import FrontendManager, parse_version from comfy_api.internal import _ComfyNodeInternal from app.assets.seeder import asset_seeder -from app.assets.api.routes import register_assets_system +from app.assets.api.routes import register_assets_routes from app.user_manager import UserManager from app.model_manager import ModelFileManager @@ -240,7 +240,10 @@ class PromptServer(): ) logging.info(f"[Prompt Server] web root: {self.web_root}") if args.enable_assets: - register_assets_system(self.app, self.user_manager) + register_assets_routes(self.app, self.user_manager) + else: + register_assets_routes(self.app) + asset_seeder.disable() routes = web.RouteTableDef() self.routes = routes self.last_node_id = None diff --git a/tests-unit/assets_test/test_sync_references.py b/tests-unit/assets_test/test_sync_references.py new file mode 100644 index 000000000..e646c1a2b --- /dev/null +++ b/tests-unit/assets_test/test_sync_references.py @@ -0,0 +1,350 @@ +"""Tests for sync_references_with_filesystem in scanner.py.""" + +import os +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm 
import Session + +from app.assets.database.models import ( + Asset, + AssetReference, + AssetReferenceTag, + Base, + Tag, +) +from app.assets.scanner import sync_references_with_filesystem +from app.assets.services.file_utils import get_mtime_ns + + +@pytest.fixture +def db_engine(): + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + return engine + + +@pytest.fixture +def session(db_engine): + with Session(db_engine) as sess: + yield sess + + +@pytest.fixture +def temp_dir(): + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +def _create_file(temp_dir: Path, name: str, content: bytes = b"\x00" * 100) -> str: + """Create a file and return its absolute path (no symlink resolution).""" + p = temp_dir / name + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(content) + return os.path.abspath(str(p)) + + +def _stat_mtime_ns(path: str) -> int: + return get_mtime_ns(os.stat(path, follow_symlinks=True)) + + +def _make_asset( + session: Session, + asset_id: str, + file_path: str, + ref_id: str, + *, + asset_hash: str | None = None, + size_bytes: int = 100, + mtime_ns: int | None = None, + needs_verify: bool = False, + is_missing: bool = False, +) -> tuple[Asset, AssetReference]: + """Insert an Asset + AssetReference and flush.""" + asset = session.get(Asset, asset_id) + if asset is None: + asset = Asset(id=asset_id, hash=asset_hash, size_bytes=size_bytes) + session.add(asset) + session.flush() + + ref = AssetReference( + id=ref_id, + asset_id=asset_id, + name=f"test-{ref_id}", + owner_id="system", + file_path=file_path, + mtime_ns=mtime_ns, + needs_verify=needs_verify, + is_missing=is_missing, + ) + session.add(ref) + session.flush() + return asset, ref + + +def _ensure_missing_tag(session: Session): + """Ensure the 'missing' tag exists.""" + if not session.get(Tag, "missing"): + session.add(Tag(name="missing", tag_type="system")) + session.flush() + + +class _VerifyCase: + def __init__(self, id, 
stat_unchanged, needs_verify_before, expect_needs_verify): + self.id = id + self.stat_unchanged = stat_unchanged + self.needs_verify_before = needs_verify_before + self.expect_needs_verify = expect_needs_verify + + +VERIFY_CASES = [ + _VerifyCase( + id="unchanged_clears_verify", + stat_unchanged=True, + needs_verify_before=True, + expect_needs_verify=False, + ), + _VerifyCase( + id="unchanged_keeps_clear", + stat_unchanged=True, + needs_verify_before=False, + expect_needs_verify=False, + ), + _VerifyCase( + id="changed_sets_verify", + stat_unchanged=False, + needs_verify_before=False, + expect_needs_verify=True, + ), + _VerifyCase( + id="changed_keeps_verify", + stat_unchanged=False, + needs_verify_before=True, + expect_needs_verify=True, + ), +] + + +@pytest.mark.parametrize("case", VERIFY_CASES, ids=lambda c: c.id) +def test_needs_verify_toggling(session, temp_dir, case): + """needs_verify is set/cleared based on mtime+size match.""" + fp = _create_file(temp_dir, "model.bin") + real_mtime = _stat_mtime_ns(fp) + + mtime_for_db = real_mtime if case.stat_unchanged else real_mtime + 1 + _make_asset( + session, "a1", fp, "r1", + asset_hash="blake3:abc", + mtime_ns=mtime_for_db, + needs_verify=case.needs_verify_before, + ) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + ref = session.get(AssetReference, "r1") + assert ref.needs_verify is case.expect_needs_verify + + +class _MissingCase: + def __init__(self, id, file_exists, expect_is_missing): + self.id = id + self.file_exists = file_exists + self.expect_is_missing = expect_is_missing + + +MISSING_CASES = [ + _MissingCase(id="existing_file_not_missing", file_exists=True, expect_is_missing=False), + _MissingCase(id="missing_file_marked_missing", file_exists=False, expect_is_missing=True), +] + + +@pytest.mark.parametrize("case", MISSING_CASES, ids=lambda c: c.id) 
+def test_is_missing_flag(session, temp_dir, case): + """is_missing reflects whether the file exists on disk.""" + if case.file_exists: + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + else: + fp = str(temp_dir / "gone.bin") + mtime = 999 + + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + ref = session.get(AssetReference, "r1") + assert ref.is_missing is case.expect_is_missing + + +def test_seed_asset_all_missing_deletes_asset(session, temp_dir): + """Seed asset with all refs missing gets deleted entirely.""" + fp = str(temp_dir / "gone.bin") + _make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + assert session.get(Asset, "seed1") is None + assert session.get(AssetReference, "r1") is None + + +def test_seed_asset_some_exist_returns_survivors(session, temp_dir): + """Seed asset with at least one existing ref survives and is returned.""" + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=mtime) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + survivors = sync_references_with_filesystem( + session, "models", collect_existing_paths=True, + ) + session.commit() + + assert session.get(Asset, "seed1") is not None + assert os.path.abspath(fp) in survivors + + +def test_hashed_asset_prunes_missing_refs_when_one_is_ok(session, temp_dir): + """Hashed asset with one stat-unchanged ref deletes missing refs.""" + fp_ok = _create_file(temp_dir, "good.bin") + fp_gone = 
str(temp_dir / "gone.bin") + mtime = _stat_mtime_ns(fp_ok) + + _make_asset(session, "h1", fp_ok, "r_ok", asset_hash="blake3:aaa", mtime_ns=mtime) + # Second ref on same asset, file missing + ref_gone = AssetReference( + id="r_gone", asset_id="h1", name="gone", + owner_id="system", file_path=fp_gone, mtime_ns=999, + ) + session.add(ref_gone) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + assert session.get(AssetReference, "r_ok") is not None + assert session.get(AssetReference, "r_gone") is None + + +def test_hashed_asset_all_missing_keeps_refs(session, temp_dir): + """Hashed asset with all refs missing keeps refs (no pruning).""" + fp = str(temp_dir / "gone.bin") + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem(session, "models") + session.commit() + + session.expire_all() + assert session.get(AssetReference, "r1") is not None + ref = session.get(AssetReference, "r1") + assert ref.is_missing is True + + +def test_missing_tag_added_when_all_refs_gone(session, temp_dir): + """Missing tag is added to hashed asset when all refs are missing.""" + _ensure_missing_tag(session) + fp = str(temp_dir / "gone.bin") + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem( + session, "models", update_missing_tags=True, + ) + session.commit() + + session.expire_all() + tag_link = session.get(AssetReferenceTag, ("r1", "missing")) + assert tag_link is not None + + +def test_missing_tag_removed_when_ref_ok(session, temp_dir): + """Missing tag is removed from hashed asset when a ref is 
stat-unchanged.""" + _ensure_missing_tag(session) + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=mtime) + # Pre-add a stale missing tag + session.add(AssetReferenceTag( + asset_reference_id="r1", tag_name="missing", origin="automatic", + )) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem( + session, "models", update_missing_tags=True, + ) + session.commit() + + session.expire_all() + tag_link = session.get(AssetReferenceTag, ("r1", "missing")) + assert tag_link is None + + +def test_missing_tags_not_touched_when_flag_false(session, temp_dir): + """Missing tags are not modified when update_missing_tags=False.""" + _ensure_missing_tag(session) + fp = str(temp_dir / "gone.bin") + _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + sync_references_with_filesystem( + session, "models", update_missing_tags=False, + ) + session.commit() + + tag_link = session.get(AssetReferenceTag, ("r1", "missing")) + assert tag_link is None # tag was never added + + +def test_returns_none_when_collect_false(session, temp_dir): + fp = _create_file(temp_dir, "model.bin") + mtime = _stat_mtime_ns(fp) + _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime) + session.commit() + + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + result = sync_references_with_filesystem( + session, "models", collect_existing_paths=False, + ) + + assert result is None + + +def test_returns_empty_set_for_no_prefixes(session): + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[]): + result = sync_references_with_filesystem( + session, "models", collect_existing_paths=True, + ) + + assert result == set() + + 
+def test_no_references_is_noop(session, temp_dir): + """No crash and no side effects when there are no references.""" + with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]): + survivors = sync_references_with_filesystem( + session, "models", collect_existing_paths=True, + ) + session.commit() + + assert survivors == set() diff --git a/tests-unit/seeder_test/test_seeder.py b/tests-unit/seeder_test/test_seeder.py index 5ac0e0c08..db3795e48 100644 --- a/tests-unit/seeder_test/test_seeder.py +++ b/tests-unit/seeder_test/test_seeder.py @@ -1,19 +1,18 @@ -"""Unit tests for the AssetSeeder background scanning class.""" +"""Unit tests for the _AssetSeeder background scanning class.""" import threading from unittest.mock import patch import pytest -from app.assets.seeder import AssetSeeder, Progress, ScanInProgressError, ScanPhase, State +from app.assets.database.queries.asset_reference import UnenrichedReferenceRow +from app.assets.seeder import _AssetSeeder, Progress, ScanInProgressError, ScanPhase, State @pytest.fixture def fresh_seeder(): - """Create a fresh AssetSeeder instance for testing (bypasses singleton).""" - seeder = object.__new__(AssetSeeder) - seeder._initialized = False - seeder.__init__() + """Create a fresh _AssetSeeder instance for testing.""" + seeder = _AssetSeeder() yield seeder seeder.shutdown(timeout=1.0) @@ -25,7 +24,7 @@ def mock_dependencies(): patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.sync_root_safely", return_value=set()), patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), - patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), patch("app.assets.seeder.insert_asset_specs", return_value=0), patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), @@ -36,11 
+35,11 @@ def mock_dependencies(): class TestSeederStateTransitions: """Test state machine transitions.""" - def test_initial_state_is_idle(self, fresh_seeder: AssetSeeder): + def test_initial_state_is_idle(self, fresh_seeder: _AssetSeeder): assert fresh_seeder.get_status().state == State.IDLE def test_start_transitions_to_running( - self, fresh_seeder: AssetSeeder, mock_dependencies + self, fresh_seeder: _AssetSeeder, mock_dependencies ): barrier = threading.Event() reached = threading.Event() @@ -61,7 +60,7 @@ class TestSeederStateTransitions: barrier.set() def test_start_while_running_returns_false( - self, fresh_seeder: AssetSeeder, mock_dependencies + self, fresh_seeder: _AssetSeeder, mock_dependencies ): barrier = threading.Event() reached = threading.Event() @@ -83,7 +82,7 @@ class TestSeederStateTransitions: barrier.set() def test_cancel_transitions_to_cancelling( - self, fresh_seeder: AssetSeeder, mock_dependencies + self, fresh_seeder: _AssetSeeder, mock_dependencies ): barrier = threading.Event() reached = threading.Event() @@ -105,12 +104,12 @@ class TestSeederStateTransitions: barrier.set() - def test_cancel_when_idle_returns_false(self, fresh_seeder: AssetSeeder): + def test_cancel_when_idle_returns_false(self, fresh_seeder: _AssetSeeder): cancelled = fresh_seeder.cancel() assert cancelled is False def test_state_returns_to_idle_after_completion( - self, fresh_seeder: AssetSeeder, mock_dependencies + self, fresh_seeder: _AssetSeeder, mock_dependencies ): fresh_seeder.start(roots=("models",)) completed = fresh_seeder.wait(timeout=5.0) @@ -122,7 +121,7 @@ class TestSeederWait: """Test wait() behavior.""" def test_wait_blocks_until_complete( - self, fresh_seeder: AssetSeeder, mock_dependencies + self, fresh_seeder: _AssetSeeder, mock_dependencies ): fresh_seeder.start(roots=("models",)) completed = fresh_seeder.wait(timeout=5.0) @@ -130,7 +129,7 @@ class TestSeederWait: assert fresh_seeder.get_status().state == State.IDLE def 
test_wait_returns_false_on_timeout( - self, fresh_seeder: AssetSeeder, mock_dependencies + self, fresh_seeder: _AssetSeeder, mock_dependencies ): barrier = threading.Event() @@ -147,7 +146,7 @@ class TestSeederWait: barrier.set() - def test_wait_when_idle_returns_true(self, fresh_seeder: AssetSeeder): + def test_wait_when_idle_returns_true(self, fresh_seeder: _AssetSeeder): completed = fresh_seeder.wait(timeout=1.0) assert completed is True @@ -156,7 +155,7 @@ class TestSeederProgress: """Test progress tracking.""" def test_get_status_returns_progress_during_scan( - self, fresh_seeder: AssetSeeder + self, fresh_seeder: _AssetSeeder ): barrier = threading.Event() reached = threading.Event() @@ -172,7 +171,7 @@ class TestSeederProgress: patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.sync_root_safely", return_value=set()), patch("app.assets.seeder.collect_paths_for_roots", return_value=paths), - patch("app.assets.seeder.build_stub_specs", side_effect=slow_build), + patch("app.assets.seeder.build_asset_specs", side_effect=slow_build), patch("app.assets.seeder.insert_asset_specs", return_value=0), patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), @@ -188,7 +187,7 @@ class TestSeederProgress: barrier.set() def test_progress_callback_is_invoked( - self, fresh_seeder: AssetSeeder, mock_dependencies + self, fresh_seeder: _AssetSeeder, mock_dependencies ): progress_updates: list[Progress] = [] @@ -209,7 +208,7 @@ class TestSeederCancellation: """Test cancellation behavior.""" def test_scan_commits_partial_progress_on_cancellation( - self, fresh_seeder: AssetSeeder + self, fresh_seeder: _AssetSeeder ): insert_count = 0 barrier = threading.Event() @@ -245,7 +244,7 @@ class TestSeederCancellation: patch("app.assets.seeder.sync_root_safely", return_value=set()), patch("app.assets.seeder.collect_paths_for_roots", return_value=paths), 
patch( - "app.assets.seeder.build_stub_specs", return_value=(specs, set(), 0) + "app.assets.seeder.build_asset_specs", return_value=(specs, set(), 0) ), patch("app.assets.seeder.insert_asset_specs", side_effect=slow_insert), patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), @@ -264,7 +263,7 @@ class TestSeederCancellation: class TestSeederErrorHandling: """Test error handling behavior.""" - def test_database_errors_captured_in_status(self, fresh_seeder: AssetSeeder): + def test_database_errors_captured_in_status(self, fresh_seeder: _AssetSeeder): with ( patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.sync_root_safely", return_value=set()), @@ -273,7 +272,7 @@ class TestSeederErrorHandling: return_value=["/path/file.safetensors"], ), patch( - "app.assets.seeder.build_stub_specs", + "app.assets.seeder.build_asset_specs", return_value=( [ { @@ -307,7 +306,7 @@ class TestSeederErrorHandling: assert "DB connection failed" in status.errors[0] def test_dependencies_unavailable_captured_in_errors( - self, fresh_seeder: AssetSeeder + self, fresh_seeder: _AssetSeeder ): with patch("app.assets.seeder.dependencies_available", return_value=False): fresh_seeder.start(roots=("models",)) @@ -317,7 +316,7 @@ class TestSeederErrorHandling: assert len(status.errors) > 0 assert "dependencies" in status.errors[0].lower() - def test_thread_crash_resets_state_to_idle(self, fresh_seeder: AssetSeeder): + def test_thread_crash_resets_state_to_idle(self, fresh_seeder: _AssetSeeder): with ( patch("app.assets.seeder.dependencies_available", return_value=True), patch( @@ -337,7 +336,7 @@ class TestSeederThreadSafety: """Test thread safety of concurrent operations.""" def test_concurrent_start_calls_spawn_only_one_thread( - self, fresh_seeder: AssetSeeder, mock_dependencies + self, fresh_seeder: _AssetSeeder, mock_dependencies ): barrier = threading.Event() @@ -364,7 +363,7 @@ class TestSeederThreadSafety: assert 
sum(results) == 1 def test_get_status_safe_during_scan( - self, fresh_seeder: AssetSeeder, mock_dependencies + self, fresh_seeder: _AssetSeeder, mock_dependencies ): barrier = threading.Event() reached = threading.Event() @@ -395,7 +394,7 @@ class TestSeederThreadSafety: class TestSeederMarkMissing: """Test mark_missing_outside_prefixes behavior.""" - def test_mark_missing_when_idle(self, fresh_seeder: AssetSeeder): + def test_mark_missing_when_idle(self, fresh_seeder: _AssetSeeder): with ( patch("app.assets.seeder.dependencies_available", return_value=True), patch( @@ -411,7 +410,7 @@ class TestSeederMarkMissing: mock_mark.assert_called_once_with(["/models", "/input", "/output"]) def test_mark_missing_raises_when_running( - self, fresh_seeder: AssetSeeder, mock_dependencies + self, fresh_seeder: _AssetSeeder, mock_dependencies ): barrier = threading.Event() reached = threading.Event() @@ -433,14 +432,14 @@ class TestSeederMarkMissing: barrier.set() def test_mark_missing_returns_zero_when_dependencies_unavailable( - self, fresh_seeder: AssetSeeder + self, fresh_seeder: _AssetSeeder ): with patch("app.assets.seeder.dependencies_available", return_value=False): result = fresh_seeder.mark_missing_outside_prefixes() assert result == 0 def test_prune_first_flag_triggers_mark_missing_before_scan( - self, fresh_seeder: AssetSeeder + self, fresh_seeder: _AssetSeeder ): call_order = [] @@ -458,7 +457,7 @@ class TestSeederMarkMissing: patch("app.assets.seeder.mark_missing_outside_prefixes_safely", side_effect=track_mark), patch("app.assets.seeder.sync_root_safely", side_effect=track_sync), patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), - patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), patch("app.assets.seeder.insert_asset_specs", return_value=0), patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), 
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), @@ -473,7 +472,7 @@ class TestSeederMarkMissing: class TestSeederPhases: """Test phased scanning behavior.""" - def test_start_fast_only_runs_fast_phase(self, fresh_seeder: AssetSeeder): + def test_start_fast_only_runs_fast_phase(self, fresh_seeder: _AssetSeeder): """Verify start_fast only runs the fast phase.""" fast_called = [] enrich_called = [] @@ -490,7 +489,7 @@ class TestSeederPhases: patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.sync_root_safely", return_value=set()), patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), - patch("app.assets.seeder.build_stub_specs", side_effect=track_fast), + patch("app.assets.seeder.build_asset_specs", side_effect=track_fast), patch("app.assets.seeder.insert_asset_specs", return_value=0), patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich), patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), @@ -501,7 +500,7 @@ class TestSeederPhases: assert len(fast_called) == 1 assert len(enrich_called) == 0 - def test_start_enrich_only_runs_enrich_phase(self, fresh_seeder: AssetSeeder): + def test_start_enrich_only_runs_enrich_phase(self, fresh_seeder: _AssetSeeder): """Verify start_enrich only runs the enrich phase.""" fast_called = [] enrich_called = [] @@ -518,7 +517,7 @@ class TestSeederPhases: patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.sync_root_safely", return_value=set()), patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), - patch("app.assets.seeder.build_stub_specs", side_effect=track_fast), + patch("app.assets.seeder.build_asset_specs", side_effect=track_fast), patch("app.assets.seeder.insert_asset_specs", return_value=0), patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich), patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), @@ 
-529,7 +528,7 @@ class TestSeederPhases: assert len(fast_called) == 0 assert len(enrich_called) == 1 - def test_full_scan_runs_both_phases(self, fresh_seeder: AssetSeeder): + def test_full_scan_runs_both_phases(self, fresh_seeder: _AssetSeeder): """Verify full scan runs both fast and enrich phases.""" fast_called = [] enrich_called = [] @@ -546,7 +545,7 @@ class TestSeederPhases: patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.sync_root_safely", return_value=set()), patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), - patch("app.assets.seeder.build_stub_specs", side_effect=track_fast), + patch("app.assets.seeder.build_asset_specs", side_effect=track_fast), patch("app.assets.seeder.insert_asset_specs", return_value=0), patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich), patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), @@ -562,7 +561,7 @@ class TestSeederPauseResume: """Test pause/resume behavior.""" def test_pause_transitions_to_paused( - self, fresh_seeder: AssetSeeder, mock_dependencies + self, fresh_seeder: _AssetSeeder, mock_dependencies ): barrier = threading.Event() reached = threading.Event() @@ -584,12 +583,12 @@ class TestSeederPauseResume: barrier.set() - def test_pause_when_idle_returns_false(self, fresh_seeder: AssetSeeder): + def test_pause_when_idle_returns_false(self, fresh_seeder: _AssetSeeder): paused = fresh_seeder.pause() assert paused is False def test_resume_returns_to_running( - self, fresh_seeder: AssetSeeder, mock_dependencies + self, fresh_seeder: _AssetSeeder, mock_dependencies ): barrier = threading.Event() reached = threading.Event() @@ -615,7 +614,7 @@ class TestSeederPauseResume: barrier.set() def test_resume_when_not_paused_returns_false( - self, fresh_seeder: AssetSeeder, mock_dependencies + self, fresh_seeder: _AssetSeeder, mock_dependencies ): barrier = threading.Event() reached = threading.Event() @@ -637,7 +636,7 
@@ class TestSeederPauseResume: barrier.set() def test_cancel_while_paused_works( - self, fresh_seeder: AssetSeeder, mock_dependencies + self, fresh_seeder: _AssetSeeder, mock_dependencies ): barrier = threading.Event() reached_checkpoint = threading.Event() @@ -667,7 +666,7 @@ class TestSeederStopRestart: """Test stop and restart behavior.""" def test_stop_is_alias_for_cancel( - self, fresh_seeder: AssetSeeder, mock_dependencies + self, fresh_seeder: _AssetSeeder, mock_dependencies ): barrier = threading.Event() reached = threading.Event() @@ -690,7 +689,7 @@ class TestSeederStopRestart: barrier.set() def test_restart_cancels_and_starts_new_scan( - self, fresh_seeder: AssetSeeder, mock_dependencies + self, fresh_seeder: _AssetSeeder, mock_dependencies ): barrier = threading.Event() reached = threading.Event() @@ -717,7 +716,7 @@ class TestSeederStopRestart: fresh_seeder.wait(timeout=5.0) assert start_count == 2 - def test_restart_preserves_previous_params(self, fresh_seeder: AssetSeeder): + def test_restart_preserves_previous_params(self, fresh_seeder: _AssetSeeder): """Verify restart uses previous params when not overridden.""" collected_roots = [] @@ -729,7 +728,7 @@ class TestSeederStopRestart: patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.sync_root_safely", return_value=set()), patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect), - patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), patch("app.assets.seeder.insert_asset_specs", return_value=0), patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), @@ -744,7 +743,7 @@ class TestSeederStopRestart: assert collected_roots[0] == ("input", "output") assert collected_roots[1] == ("input", "output") - def test_restart_can_override_params(self, fresh_seeder: 
AssetSeeder): + def test_restart_can_override_params(self, fresh_seeder: _AssetSeeder): """Verify restart can override previous params.""" collected_roots = [] @@ -756,7 +755,7 @@ class TestSeederStopRestart: patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.sync_root_safely", return_value=set()), patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect), - patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)), + patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)), patch("app.assets.seeder.insert_asset_specs", return_value=0), patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), @@ -770,3 +769,132 @@ class TestSeederStopRestart: assert len(collected_roots) == 2 assert collected_roots[0] == ("models",) assert collected_roots[1] == ("input",) + + +def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow: + return UnenrichedReferenceRow( + reference_id=ref_id, asset_id=asset_id, + file_path=f"/fake/{ref_id}.bin", enrichment_level=0, + ) + + +class TestEnrichPhaseDefensiveLogic: + """Test skip_ids filtering and consecutive_empty termination.""" + + def test_failed_refs_are_skipped_on_subsequent_batches( + self, fresh_seeder: _AssetSeeder, + ): + """References that fail enrichment are filtered out of future batches.""" + row_a = _make_row("r1") + row_b = _make_row("r2") + call_count = 0 + + def fake_get_unenriched(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count <= 2: + return [row_a, row_b] + return [] + + enriched_refs: list[list[str]] = [] + + def fake_enrich(rows, **kwargs): + ref_ids = [r.reference_id for r in rows] + enriched_refs.append(ref_ids) + # r1 always fails, r2 succeeds + failed = [r.reference_id for r in rows if r.reference_id == "r1"] + enriched = len(rows) - len(failed) + return enriched, failed + + with ( + 
patch("app.assets.seeder.dependencies_available", return_value=True),
+            patch("app.assets.seeder.sync_root_safely", return_value=set()),
+            patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
+            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
+            patch("app.assets.seeder.insert_asset_specs", return_value=0),
+            patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched),
+            patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich),
+        ):
+            fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH)
+            fresh_seeder.wait(timeout=5.0)
+
+        # First batch: both refs attempted
+        assert "r1" in enriched_refs[0]
+        assert "r2" in enriched_refs[0]
+        # Second batch: r1 filtered out
+        assert "r1" not in enriched_refs[1]
+        assert "r2" in enriched_refs[1]
+
+    def test_stops_after_consecutive_empty_batches(
+        self, fresh_seeder: _AssetSeeder,
+    ):
+        """Enrich phase terminates once skip-filtering leaves no refs to enrich."""
+        row = _make_row("r1")
+        batch_count = 0
+
+        def fake_get_unenriched(*args, **kwargs):
+            nonlocal batch_count
+            batch_count += 1
+            # Always return the same row (simulating a permanently failing ref)
+            return [row]
+
+        def fake_enrich(rows, **kwargs):
+            # Always fail — zero enriched, all failed
+            return 0, [r.reference_id for r in rows]
+
+        with (
+            patch("app.assets.seeder.dependencies_available", return_value=True),
+            patch("app.assets.seeder.sync_root_safely", return_value=set()),
+            patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
+            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
+            patch("app.assets.seeder.insert_asset_specs", return_value=0),
+            patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched),
+            patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich),
+        ):
+            fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH)
+            fresh_seeder.wait(timeout=5.0)
+
+        # Should stop 
once skip-filtering yields an empty batch:
+        # Batch 1: get_unenriched returns r1; enrichment fails, so r1 is
+        # added to skip_ids. Batch 2: get_unenriched returns [r1] again,
+        # but skip_ids filters it out, leaving an empty batch, which exits
+        # the loop via `if not unenriched: break`. Net effect: exactly two
+        # calls to get_unenriched, without waiting for 3 empty batches.
+        assert batch_count == 2
+
+    def test_consecutive_empty_counter_resets_on_success(
+        self, fresh_seeder: _AssetSeeder,
+    ):
+        """A successful batch resets the consecutive empty counter."""
+        call_count = 0
+
+        def fake_get_unenriched(*args, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            if call_count <= 6:
+                return [_make_row(f"r{call_count}", f"a{call_count}")]
+            return []
+
+        def fake_enrich(rows, **kwargs):
+            ref_id = rows[0].reference_id
+            # Fail batches 1-2, succeed batch 3, fail batches 4-5, succeed batch 6
+            if ref_id in ("r1", "r2", "r4", "r5"):
+                return 0, [ref_id]
+            return 1, []
+
+        with (
+            patch("app.assets.seeder.dependencies_available", return_value=True),
+            patch("app.assets.seeder.sync_root_safely", return_value=set()),
+            patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
+            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
+            patch("app.assets.seeder.insert_asset_specs", return_value=0),
+            patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched),
+            patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich),
+        ):
+            fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH)
+            fresh_seeder.wait(timeout=5.0)
+
+        # All 6 batches should run + 1 final call returning empty
+        assert call_count == 7
+        status = fresh_seeder.get_status()
+        assert status.state == State.IDLE