fix: address code review feedback

- Fix missing import for compute_filename_for_reference in ingest.py
- Apply code review fixes across routes, queries, scanner, seeder,
  hashing, ingest, path_utils, main, and server
- Update and add tests for reference syncing and the asset seeder

Amp-Thread-ID: https://ampcode.com/threads/T-019cb61a-ed54-738c-a05f-9b5242e513f3
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Luke Mino-Altherr 2026-03-03 15:51:35 -08:00
parent 3232f48a41
commit 4d4c2cedd3
13 changed files with 675 additions and 218 deletions

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import functools
import json import json
import logging import logging
import os import os
@ -39,6 +40,20 @@ from app.assets.services import (
ROUTES = web.RouteTableDef() ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None USER_MANAGER: user_manager.UserManager | None = None
_ASSETS_ENABLED = False
def _require_assets_feature_enabled(handler):
    """Decorator: reject the request with 503 unless the assets feature is on.

    Wraps an aiohttp handler; when the module-level ``_ASSETS_ENABLED`` flag
    is false, short-circuits with a SERVICE_DISABLED error response instead
    of invoking the handler.
    """
    @functools.wraps(handler)
    async def wrapper(request: web.Request) -> web.Response:
        if _ASSETS_ENABLED:
            return await handler(request)
        return _build_error_response(
            503,
            "SERVICE_DISABLED",
            "Assets system is disabled. Start the server with --enable-assets to use this feature.",
        )
    return wrapper
# UUID regex (canonical hyphenated form, case-insensitive) # UUID regex (canonical hyphenated form, case-insensitive)
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
@ -64,11 +79,13 @@ def get_query_dict(request: web.Request) -> dict[str, Any]:
# do not rely on the code in /app/assets remaining the same. # do not rely on the code in /app/assets remaining the same.
def register_assets_system( def register_assets_routes(
app: web.Application, user_manager_instance: user_manager.UserManager app: web.Application, user_manager_instance: user_manager.UserManager | None = None,
) -> None: ) -> None:
global USER_MANAGER global USER_MANAGER, _ASSETS_ENABLED
USER_MANAGER = user_manager_instance if user_manager_instance is not None:
USER_MANAGER = user_manager_instance
_ASSETS_ENABLED = True
app.add_routes(ROUTES) app.add_routes(ROUTES)
@ -96,6 +113,7 @@ def _validate_sort_field(requested: str | None) -> str:
@ROUTES.head("/api/assets/hash/{hash}") @ROUTES.head("/api/assets/hash/{hash}")
@_require_assets_feature_enabled
async def head_asset_by_hash(request: web.Request) -> web.Response: async def head_asset_by_hash(request: web.Request) -> web.Response:
hash_str = request.match_info.get("hash", "").strip().lower() hash_str = request.match_info.get("hash", "").strip().lower()
if not hash_str or ":" not in hash_str: if not hash_str or ":" not in hash_str:
@ -116,6 +134,7 @@ async def head_asset_by_hash(request: web.Request) -> web.Response:
@ROUTES.get("/api/assets") @ROUTES.get("/api/assets")
@_require_assets_feature_enabled
async def list_assets_route(request: web.Request) -> web.Response: async def list_assets_route(request: web.Request) -> web.Response:
""" """
GET request to list assets. GET request to list assets.
@ -166,6 +185,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}") @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
@_require_assets_feature_enabled
async def get_asset_route(request: web.Request) -> web.Response: async def get_asset_route(request: web.Request) -> web.Response:
""" """
GET request to get an asset's info as JSON. GET request to get an asset's info as JSON.
@ -211,6 +231,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content") @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
@_require_assets_feature_enabled
async def download_asset_content(request: web.Request) -> web.Response: async def download_asset_content(request: web.Request) -> web.Response:
disposition = request.query.get("disposition", "attachment").lower().strip() disposition = request.query.get("disposition", "attachment").lower().strip()
if disposition not in {"inline", "attachment"}: if disposition not in {"inline", "attachment"}:
@ -264,6 +285,7 @@ async def download_asset_content(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets/from-hash") @ROUTES.post("/api/assets/from-hash")
@_require_assets_feature_enabled
async def create_asset_from_hash_route(request: web.Request) -> web.Response: async def create_asset_from_hash_route(request: web.Request) -> web.Response:
try: try:
payload = await request.json() payload = await request.json()
@ -304,6 +326,7 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets") @ROUTES.post("/api/assets")
@_require_assets_feature_enabled
async def upload_asset(request: web.Request) -> web.Response: async def upload_asset(request: web.Request) -> web.Response:
"""Multipart/form-data endpoint for Asset uploads.""" """Multipart/form-data endpoint for Asset uploads."""
try: try:
@ -408,6 +431,7 @@ async def upload_asset(request: web.Request) -> web.Response:
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") @ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
@_require_assets_feature_enabled
async def update_asset_route(request: web.Request) -> web.Response: async def update_asset_route(request: web.Request) -> web.Response:
reference_id = str(uuid.UUID(request.match_info["id"])) reference_id = str(uuid.UUID(request.match_info["id"]))
try: try:
@ -453,6 +477,7 @@ async def update_asset_route(request: web.Request) -> web.Response:
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
@_require_assets_feature_enabled
async def delete_asset_route(request: web.Request) -> web.Response: async def delete_asset_route(request: web.Request) -> web.Response:
reference_id = str(uuid.UUID(request.match_info["id"])) reference_id = str(uuid.UUID(request.match_info["id"]))
delete_content_param = request.query.get("delete_content") delete_content_param = request.query.get("delete_content")
@ -484,6 +509,7 @@ async def delete_asset_route(request: web.Request) -> web.Response:
@ROUTES.get("/api/tags") @ROUTES.get("/api/tags")
@_require_assets_feature_enabled
async def get_tags(request: web.Request) -> web.Response: async def get_tags(request: web.Request) -> web.Response:
""" """
GET request to list all tags based on query parameters. GET request to list all tags based on query parameters.
@ -520,6 +546,7 @@ async def get_tags(request: web.Request) -> web.Response:
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags") @ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
@_require_assets_feature_enabled
async def add_asset_tags(request: web.Request) -> web.Response: async def add_asset_tags(request: web.Request) -> web.Response:
reference_id = str(uuid.UUID(request.match_info["id"])) reference_id = str(uuid.UUID(request.match_info["id"]))
try: try:
@ -569,6 +596,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags") @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
@_require_assets_feature_enabled
async def delete_asset_tags(request: web.Request) -> web.Response: async def delete_asset_tags(request: web.Request) -> web.Response:
reference_id = str(uuid.UUID(request.match_info["id"])) reference_id = str(uuid.UUID(request.match_info["id"]))
try: try:
@ -613,6 +641,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets/seed") @ROUTES.post("/api/assets/seed")
@_require_assets_feature_enabled
async def seed_assets(request: web.Request) -> web.Response: async def seed_assets(request: web.Request) -> web.Response:
"""Trigger asset seeding for specified roots (models, input, output). """Trigger asset seeding for specified roots (models, input, output).
@ -662,6 +691,7 @@ async def seed_assets(request: web.Request) -> web.Response:
@ROUTES.get("/api/assets/seed/status") @ROUTES.get("/api/assets/seed/status")
@_require_assets_feature_enabled
async def get_seed_status(request: web.Request) -> web.Response: async def get_seed_status(request: web.Request) -> web.Response:
"""Get current scan status and progress.""" """Get current scan status and progress."""
status = asset_seeder.get_status() status = asset_seeder.get_status()
@ -683,6 +713,7 @@ async def get_seed_status(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets/seed/cancel") @ROUTES.post("/api/assets/seed/cancel")
@_require_assets_feature_enabled
async def cancel_seed(request: web.Request) -> web.Response: async def cancel_seed(request: web.Request) -> web.Response:
"""Request cancellation of in-progress scan.""" """Request cancellation of in-progress scan."""
cancelled = asset_seeder.cancel() cancelled = asset_seeder.cancel()
@ -692,6 +723,7 @@ async def cancel_seed(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets/prune") @ROUTES.post("/api/assets/prune")
@_require_assets_feature_enabled
async def mark_missing_assets(request: web.Request) -> web.Response: async def mark_missing_assets(request: web.Request) -> web.Response:
"""Mark assets as missing when outside all known root prefixes. """Mark assets as missing when outside all known root prefixes.

View File

@ -57,6 +57,7 @@ from app.assets.database.queries.tags import (
remove_missing_tag_for_asset_id, remove_missing_tag_for_asset_id,
remove_tags_from_reference, remove_tags_from_reference,
set_reference_tags, set_reference_tags,
validate_tags_exist,
) )
__all__ = [ __all__ = [
@ -114,4 +115,5 @@ __all__ = [
"update_reference_updated_at", "update_reference_updated_at",
"upsert_asset", "upsert_asset",
"upsert_reference", "upsert_reference",
"validate_tags_exist",
] ]

View File

@ -660,13 +660,16 @@ def restore_references_by_paths(session: Session, file_paths: list[str]) -> int:
if not file_paths: if not file_paths:
return 0 return 0
result = session.execute( total = 0
sa.update(AssetReference) for chunk in iter_chunks(file_paths, MAX_BIND_PARAMS):
.where(AssetReference.file_path.in_(file_paths)) result = session.execute(
.where(AssetReference.is_missing == True) # noqa: E712 sa.update(AssetReference)
.values(is_missing=False) .where(AssetReference.file_path.in_(chunk))
) .where(AssetReference.is_missing == True) # noqa: E712
return result.rowcount .values(is_missing=False)
)
total += result.rowcount
return total
def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]: def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]:
@ -697,11 +700,14 @@ def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int:
""" """
if not asset_ids: if not asset_ids:
return 0 return 0
session.execute( total = 0
sa.delete(AssetReference).where(AssetReference.asset_id.in_(asset_ids)) for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS):
) session.execute(
result = session.execute(sa.delete(Asset).where(Asset.id.in_(asset_ids))) sa.delete(AssetReference).where(AssetReference.asset_id.in_(chunk))
return result.rowcount )
result = session.execute(sa.delete(Asset).where(Asset.id.in_(chunk)))
total += result.rowcount
return total
def get_references_for_prefixes( def get_references_for_prefixes(

View File

@ -37,6 +37,17 @@ class SetTagsDict(TypedDict):
total: list[str] total: list[str]
def validate_tags_exist(session: Session, tags: list[str]) -> None:
    """Raise ValueError if any of the given tag names do not exist.

    Args:
        session: Active database session used for the lookup.
        tags: Tag names to validate against the Tag table.

    Raises:
        ValueError: If one or more names have no matching Tag row.
    """
    # Avoid issuing a pointless `IN ()` query for an empty tag list.
    if not tags:
        return
    # Single IN query; set comprehension instead of set(generator) (ruff C401).
    existing_tag_names = {
        name
        for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
    }
    missing = [t for t in tags if t not in existing_tag_names]
    if missing:
        raise ValueError(f"Unknown tags: {missing}")
def ensure_tags_exist( def ensure_tags_exist(
session: Session, names: Iterable[str], tag_type: str = "user" session: Session, names: Iterable[str], tag_type: str = "user"
) -> None: ) -> None:

View File

@ -44,9 +44,9 @@ from app.database.db import create_session, dependencies_available
class _RefInfo(TypedDict): class _RefInfo(TypedDict):
ref_id: str ref_id: str
fp: str file_path: str
exists: bool exists: bool
fast_ok: bool stat_unchanged: bool
needs_verify: bool needs_verify: bool
@ -75,9 +75,7 @@ def get_prefixes_for_root(root: RootType) -> list[str]:
def get_all_known_prefixes() -> list[str]: def get_all_known_prefixes() -> list[str]:
"""Get all known asset prefixes across all root types.""" """Get all known asset prefixes across all root types."""
all_roots: tuple[RootType, ...] = ("models", "input", "output") all_roots: tuple[RootType, ...] = ("models", "input", "output")
return [ return [p for root in all_roots for p in get_prefixes_for_root(root)]
os.path.abspath(p) for root in all_roots for p in get_prefixes_for_root(root)
]
def collect_models_files() -> list[str]: def collect_models_files() -> list[str]:
@ -110,10 +108,10 @@ def sync_references_with_filesystem(
) -> set[str] | None: ) -> set[str] | None:
"""Reconcile asset references with filesystem for a root. """Reconcile asset references with filesystem for a root.
- Toggle needs_verify per reference using fast mtime/size check - Toggle needs_verify per reference using mtime/size stat check
- For hashed assets with at least one fast-ok ref: delete stale missing refs - For hashed assets with at least one stat-unchanged ref: delete stale missing refs
- For seed assets with all refs missing: delete Asset and its references - For seed assets with all refs missing: delete Asset and its references
- Optionally add/remove 'missing' tags based on fast-ok in this root - Optionally add/remove 'missing' tags based on stat check in this root
- Optionally return surviving absolute paths - Optionally return surviving absolute paths
Args: Args:
@ -140,10 +138,10 @@ def sync_references_with_filesystem(
acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []} acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []}
by_asset[row.asset_id] = acc by_asset[row.asset_id] = acc
fast_ok = False stat_unchanged = False
try: try:
exists = True exists = True
fast_ok = verify_file_unchanged( stat_unchanged = verify_file_unchanged(
mtime_db=row.mtime_ns, mtime_db=row.mtime_ns,
size_db=acc["size_db"], size_db=acc["size_db"],
stat_result=os.stat(row.file_path, follow_symlinks=True), stat_result=os.stat(row.file_path, follow_symlinks=True),
@ -160,9 +158,9 @@ def sync_references_with_filesystem(
acc["refs"].append( acc["refs"].append(
{ {
"ref_id": row.reference_id, "ref_id": row.reference_id,
"fp": row.file_path, "file_path": row.file_path,
"exists": exists, "exists": exists,
"fast_ok": fast_ok, "stat_unchanged": stat_unchanged,
"needs_verify": row.needs_verify, "needs_verify": row.needs_verify,
} }
) )
@ -177,18 +175,18 @@ def sync_references_with_filesystem(
for aid, acc in by_asset.items(): for aid, acc in by_asset.items():
a_hash = acc["hash"] a_hash = acc["hash"]
refs = acc["refs"] refs = acc["refs"]
any_fast_ok = any(r["fast_ok"] for r in refs) any_unchanged = any(r["stat_unchanged"] for r in refs)
all_missing = all(not r["exists"] for r in refs) all_missing = all(not r["exists"] for r in refs)
for r in refs: for r in refs:
if not r["exists"]: if not r["exists"]:
to_mark_missing.append(r["ref_id"]) to_mark_missing.append(r["ref_id"])
continue continue
if r["fast_ok"]: if r["stat_unchanged"]:
to_clear_missing.append(r["ref_id"]) to_clear_missing.append(r["ref_id"])
if r["needs_verify"]: if r["needs_verify"]:
to_clear_verify.append(r["ref_id"]) to_clear_verify.append(r["ref_id"])
if not r["fast_ok"] and not r["needs_verify"]: if not r["stat_unchanged"] and not r["needs_verify"]:
to_set_verify.append(r["ref_id"]) to_set_verify.append(r["ref_id"])
if a_hash is None: if a_hash is None:
@ -197,10 +195,10 @@ def sync_references_with_filesystem(
else: else:
for r in refs: for r in refs:
if r["exists"]: if r["exists"]:
survivors.add(os.path.abspath(r["fp"])) survivors.add(os.path.abspath(r["file_path"]))
continue continue
if any_fast_ok: if any_unchanged:
for r in refs: for r in refs:
if not r["exists"]: if not r["exists"]:
stale_ref_ids.append(r["ref_id"]) stale_ref_ids.append(r["ref_id"])
@ -219,7 +217,7 @@ def sync_references_with_filesystem(
for r in refs: for r in refs:
if r["exists"]: if r["exists"]:
survivors.add(os.path.abspath(r["fp"])) survivors.add(os.path.abspath(r["file_path"]))
delete_references_by_ids(session, stale_ref_ids) delete_references_by_ids(session, stale_ref_ids)
stale_set = set(stale_ref_ids) stale_set = set(stale_ref_ids)
@ -349,58 +347,6 @@ def build_asset_specs(
return specs, tag_pool, skipped return specs, tag_pool, skipped
def build_stub_specs(
    paths: list[str],
    existing_paths: set[str],
) -> tuple[list[SeedAssetSpec], set[str], int]:
    """Build minimal stub specs for fast phase scanning.

    Only filesystem metadata (stat) is collected — file contents are never
    read — making this the cheapest possible scan to populate the asset
    database.

    Args:
        paths: File paths to process.
        existing_paths: Absolute paths already present in the database.

    Returns:
        Tuple of (specs, tag_pool, skipped_count).
    """
    stub_specs: list[SeedAssetSpec] = []
    all_tags: set[str] = set()
    already_known = 0
    for raw_path in paths:
        full_path = os.path.abspath(raw_path)
        if full_path in existing_paths:
            already_known += 1
            continue
        try:
            st = os.stat(full_path, follow_symlinks=True)
        except OSError:
            # Vanished or unreadable between listing and stat; skip silently.
            continue
        if not st.st_size:
            # Zero-byte files are never seeded.
            continue
        display_name, path_tags = get_name_and_tags_from_asset_path(full_path)
        relative_name = compute_relative_filename(full_path)
        stub_specs.append(
            {
                "abs_path": full_path,
                "size_bytes": st.st_size,
                "mtime_ns": get_mtime_ns(st),
                "info_name": display_name,
                "tags": path_tags,
                "fname": relative_name,
                "metadata": None,
                "hash": None,
                "mime_type": None,
            }
        )
        all_tags.update(path_tags)
    return stub_specs, all_tags, already_known
def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int: def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
"""Insert asset specs into database, returning count of created refs.""" """Insert asset specs into database, returning count of created refs."""
@ -538,7 +484,8 @@ def enrich_asset(
try: try:
digest = compute_blake3_hash(file_path) digest = compute_blake3_hash(file_path)
full_hash = f"blake3:{digest}" full_hash = f"blake3:{digest}"
if not extract_metadata or metadata: metadata_ok = not extract_metadata or metadata is not None
if metadata_ok:
new_level = ENRICHMENT_HASHED new_level = ENRICHMENT_HASHED
except Exception as e: except Exception as e:
logging.warning("Failed to hash %s: %s", file_path, e) logging.warning("Failed to hash %s: %s", file_path, e)

View File

@ -12,7 +12,7 @@ from app.assets.scanner import (
ENRICHMENT_METADATA, ENRICHMENT_METADATA,
ENRICHMENT_STUB, ENRICHMENT_STUB,
RootType, RootType,
build_stub_specs, build_asset_specs,
collect_paths_for_roots, collect_paths_for_roots,
enrich_assets_batch, enrich_assets_batch,
get_all_known_prefixes, get_all_known_prefixes,
@ -68,35 +68,23 @@ class ScanStatus:
ProgressCallback = Callable[[Progress], None] ProgressCallback = Callable[[Progress], None]
class AssetSeeder: class _AssetSeeder:
"""Singleton class managing background asset scanning. """Background asset scanning manager.
Thread-safe singleton that spawns ephemeral daemon threads for scanning. Spawns ephemeral daemon threads for scanning.
Each scan creates a new thread that exits when complete. Each scan creates a new thread that exits when complete.
Use the module-level ``asset_seeder`` instance.
""" """
_instance: "AssetSeeder | None" = None
_instance_lock = threading.Lock()
def __new__(cls) -> "AssetSeeder":
with cls._instance_lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self) -> None: def __init__(self) -> None:
if self._initialized:
return
self._initialized = True
self._lock = threading.Lock() self._lock = threading.Lock()
self._state = State.IDLE self._state = State.IDLE
self._progress: Progress | None = None self._progress: Progress | None = None
self._errors: list[str] = [] self._errors: list[str] = []
self._thread: threading.Thread | None = None self._thread: threading.Thread | None = None
self._cancel_event = threading.Event() self._cancel_event = threading.Event()
self._pause_event = threading.Event() self._run_gate = threading.Event()
self._pause_event.set() # Start unpaused (set = running, clear = paused) self._run_gate.set() # Start unpaused (set = running, clear = paused)
self._roots: tuple[RootType, ...] = () self._roots: tuple[RootType, ...] = ()
self._phase: ScanPhase = ScanPhase.FULL self._phase: ScanPhase = ScanPhase.FULL
self._compute_hashes: bool = False self._compute_hashes: bool = False
@ -154,10 +142,10 @@ class AssetSeeder:
self._compute_hashes = compute_hashes self._compute_hashes = compute_hashes
self._progress_callback = progress_callback self._progress_callback = progress_callback
self._cancel_event.clear() self._cancel_event.clear()
self._pause_event.set() # Ensure unpaused when starting self._run_gate.set() # Ensure unpaused when starting
self._thread = threading.Thread( self._thread = threading.Thread(
target=self._run_scan, target=self._run_scan,
name="AssetSeeder", name="_AssetSeeder",
daemon=True, daemon=True,
) )
self._thread.start() self._thread.start()
@ -223,7 +211,7 @@ class AssetSeeder:
logging.info("Asset seeder cancelling (was %s)", self._state.value) logging.info("Asset seeder cancelling (was %s)", self._state.value)
self._state = State.CANCELLING self._state = State.CANCELLING
self._cancel_event.set() self._cancel_event.set()
self._pause_event.set() # Unblock if paused so thread can exit self._run_gate.set() # Unblock if paused so thread can exit
return True return True
def stop(self) -> bool: def stop(self) -> bool:
@ -247,7 +235,7 @@ class AssetSeeder:
return False return False
logging.info("Asset seeder pausing") logging.info("Asset seeder pausing")
self._state = State.PAUSED self._state = State.PAUSED
self._pause_event.clear() self._run_gate.clear()
return True return True
def resume(self) -> bool: def resume(self) -> bool:
@ -263,7 +251,7 @@ class AssetSeeder:
return False return False
logging.info("Asset seeder resuming") logging.info("Asset seeder resuming")
self._state = State.RUNNING self._state = State.RUNNING
self._pause_event.set() self._run_gate.set()
self._emit_event("assets.seed.resumed", {}) self._emit_event("assets.seed.resumed", {})
return True return True
@ -356,10 +344,10 @@ class AssetSeeder:
self._thread = None self._thread = None
def mark_missing_outside_prefixes(self) -> int: def mark_missing_outside_prefixes(self) -> int:
"""Mark cache states as missing when outside all known root prefixes. """Mark references as missing when outside all known root prefixes.
This is a non-destructive soft-delete operation. Assets and their This is a non-destructive soft-delete operation. Assets and their
metadata are preserved, but cache states are flagged as missing. metadata are preserved, but references are flagged as missing.
They can be restored if the file reappears in a future scan. They can be restored if the file reappears in a future scan.
This operation is decoupled from scanning to prevent partial scans This operation is decoupled from scanning to prevent partial scans
@ -369,7 +357,7 @@ class AssetSeeder:
a full scan of all roots or during maintenance. a full scan of all roots or during maintenance.
Returns: Returns:
Number of cache states marked as missing Number of references marked as missing
Raises: Raises:
ScanInProgressError: If a scan is currently running ScanInProgressError: If a scan is currently running
@ -389,7 +377,7 @@ class AssetSeeder:
all_prefixes = get_all_known_prefixes() all_prefixes = get_all_known_prefixes()
marked = mark_missing_outside_prefixes_safely(all_prefixes) marked = mark_missing_outside_prefixes_safely(all_prefixes)
if marked > 0: if marked > 0:
logging.info("Marked %d cache states as missing", marked) logging.info("Marked %d references as missing", marked)
return marked return marked
finally: finally:
with self._lock: with self._lock:
@ -409,9 +397,9 @@ class AssetSeeder:
Returns: Returns:
True if scan should stop, False to continue True if scan should stop, False to continue
""" """
if not self._pause_event.is_set(): if not self._run_gate.is_set():
self._emit_event("assets.seed.paused", {}) self._emit_event("assets.seed.paused", {})
self._pause_event.wait() # Blocks if paused self._run_gate.wait() # Blocks if paused
return self._is_cancelled() return self._is_cancelled()
def _emit_event(self, event_type: str, data: dict) -> None: def _emit_event(self, event_type: str, data: dict) -> None:
@ -539,7 +527,11 @@ class AssetSeeder:
cancelled = True cancelled = True
return return
total_enriched = self._run_enrich_phase(roots) enrich_cancelled, total_enriched = self._run_enrich_phase(roots)
if enrich_cancelled:
cancelled = True
return
self._emit_event( self._emit_event(
"assets.seed.enrich_complete", "assets.seed.enrich_complete",
@ -613,7 +605,9 @@ class AssetSeeder:
) )
# Use stub specs (no metadata extraction, no hashing) # Use stub specs (no metadata extraction, no hashing)
specs, tag_pool, skipped_existing = build_stub_specs(paths, existing_paths) specs, tag_pool, skipped_existing = build_asset_specs(
paths, existing_paths, enable_metadata_extraction=False, compute_hashes=False,
)
self._update_progress(skipped=skipped_existing) self._update_progress(skipped=skipped_existing)
if self._check_pause_and_cancel(): if self._check_pause_and_cancel():
@ -661,11 +655,11 @@ class AssetSeeder:
self._update_progress(scanned=len(specs), created=total_created) self._update_progress(scanned=len(specs), created=total_created)
return total_created, skipped_existing, total_paths return total_created, skipped_existing, total_paths
def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> int: def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]:
"""Run phase 2: enrich existing records with metadata and hashes. """Run phase 2: enrich existing records with metadata and hashes.
Returns: Returns:
Total number of assets enriched Tuple of (cancelled, total_enriched)
""" """
total_enriched = 0 total_enriched = 0
batch_size = 100 batch_size = 100
@ -690,7 +684,7 @@ class AssetSeeder:
while True: while True:
if self._check_pause_and_cancel(): if self._check_pause_and_cancel():
logging.info("Enrich scan cancelled after %d assets", total_enriched) logging.info("Enrich scan cancelled after %d assets", total_enriched)
break return True, total_enriched
# Fetch next batch of unenriched assets # Fetch next batch of unenriched assets
unenriched = get_unenriched_assets_for_roots( unenriched = get_unenriched_assets_for_roots(
@ -737,7 +731,7 @@ class AssetSeeder:
) )
last_progress_time = now last_progress_time = now
return total_enriched return False, total_enriched
asset_seeder = AssetSeeder() asset_seeder = _AssetSeeder()

View File

@ -1,4 +1,3 @@
import asyncio
import os import os
from typing import IO from typing import IO
@ -18,20 +17,6 @@ def compute_blake3_hash(
return _hash_file_obj(f, chunk_size) return _hash_file_obj(f, chunk_size)
async def compute_blake3_hash_async(
    fp: str | IO[bytes],
    chunk_size: int = DEFAULT_CHUNK,
) -> str:
    """Hash a path or binary file object with blake3 off the event loop.

    Delegates the blocking work to a worker thread so the coroutine never
    stalls the loop while reading the file.
    """
    if hasattr(fp, "read"):
        # Already an open file-like object; hand it to the sync helper.
        return await asyncio.to_thread(compute_blake3_hash, fp, chunk_size)

    def _hash_path() -> str:
        with open(os.fspath(fp), "rb") as f:
            return _hash_file_obj(f, chunk_size)

    return await asyncio.to_thread(_hash_path)
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str: def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
if chunk_size <= 0: if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK chunk_size = DEFAULT_CHUNK

View File

@ -2,17 +2,16 @@ import contextlib
import logging import logging
import mimetypes import mimetypes
import os import os
from typing import Sequence from typing import Any, Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
import app.assets.services.hashing as hashing import app.assets.services.hashing as hashing
from app.assets.database.models import Asset, AssetReference, Tag
from app.assets.database.queries import ( from app.assets.database.queries import (
add_tags_to_reference, add_tags_to_reference,
fetch_reference_and_asset, fetch_reference_and_asset,
get_asset_by_hash, get_asset_by_hash,
get_existing_asset_ids,
get_reference_by_file_path, get_reference_by_file_path,
get_reference_tags, get_reference_tags,
get_or_create_reference, get_or_create_reference,
@ -21,11 +20,13 @@ from app.assets.database.queries import (
set_reference_tags, set_reference_tags,
upsert_asset, upsert_asset,
upsert_reference, upsert_reference,
validate_tags_exist,
) )
from app.assets.helpers import normalize_tags from app.assets.helpers import normalize_tags
from app.assets.services.file_utils import get_size_and_mtime_ns from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.path_utils import ( from app.assets.services.path_utils import (
compute_filename_for_reference, compute_filename_for_reference,
compute_relative_filename,
resolve_destination_from_tags, resolve_destination_from_tags,
validate_path_within_base, validate_path_within_base,
) )
@ -55,6 +56,7 @@ def _ingest_file_from_path(
require_existing_tags: bool = False, require_existing_tags: bool = False,
) -> IngestResult: ) -> IngestResult:
locator = os.path.abspath(abs_path) locator = os.path.abspath(abs_path)
user_metadata = user_metadata or {}
asset_created = False asset_created = False
asset_updated = False asset_updated = False
@ -64,7 +66,7 @@ def _ingest_file_from_path(
with create_session() as session: with create_session() as session:
if preview_id: if preview_id:
if not session.get(Asset, preview_id): if preview_id not in get_existing_asset_ids(session, [preview_id]):
preview_id = None preview_id = None
asset, asset_created, asset_updated = upsert_asset( asset, asset_created, asset_updated = upsert_asset(
@ -94,7 +96,7 @@ def _ingest_file_from_path(
norm = normalize_tags(list(tags)) norm = normalize_tags(list(tags))
if norm: if norm:
if require_existing_tags: if require_existing_tags:
_validate_tags_exist(session, norm) validate_tags_exist(session, norm)
add_tags_to_reference( add_tags_to_reference(
session, session,
reference_id=reference_id, reference_id=reference_id,
@ -106,7 +108,8 @@ def _ingest_file_from_path(
_update_metadata_with_filename( _update_metadata_with_filename(
session, session,
reference_id=reference_id, reference_id=reference_id,
ref=ref, file_path=ref.file_path,
current_metadata=ref.user_metadata,
user_metadata=user_metadata, user_metadata=user_metadata,
) )
@ -134,6 +137,8 @@ def _register_existing_asset(
tag_origin: str = "manual", tag_origin: str = "manual",
owner_id: str = "", owner_id: str = "",
) -> RegisterAssetResult: ) -> RegisterAssetResult:
user_metadata = user_metadata or {}
with create_session() as session: with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=asset_hash) asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset: if not asset:
@ -157,7 +162,7 @@ def _register_existing_asset(
session.commit() session.commit()
return result return result
new_meta = dict(user_metadata or {}) new_meta = dict(user_metadata)
computed_filename = compute_filename_for_reference(session, ref) computed_filename = compute_filename_for_reference(session, ref)
if computed_filename: if computed_filename:
new_meta["filename"] = computed_filename new_meta["filename"] = computed_filename
@ -190,29 +195,20 @@ def _register_existing_asset(
return result return result
def _validate_tags_exist(session: Session, tags: list[str]) -> None:
    """Raise ValueError when any requested tag name is absent from the DB."""
    rows = session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
    known = {row_name for (row_name,) in rows}
    unknown = [t for t in tags if t not in known]
    if unknown:
        raise ValueError(f"Unknown tags: {unknown}")
def _update_metadata_with_filename( def _update_metadata_with_filename(
session: Session, session: Session,
reference_id: str, reference_id: str,
ref: AssetReference, file_path: str | None,
user_metadata: UserMetadata, current_metadata: dict | None,
user_metadata: dict[str, Any],
) -> None: ) -> None:
computed_filename = compute_filename_for_reference(session, ref) computed_filename = compute_relative_filename(file_path) if file_path else None
current_meta = ref.user_metadata or {} current_meta = current_metadata or {}
new_meta = dict(current_meta) new_meta = dict(current_meta)
if user_metadata: for k, v in user_metadata.items():
for k, v in user_metadata.items(): new_meta[k] = v
new_meta[k] = v
if computed_filename: if computed_filename:
new_meta["filename"] = computed_filename new_meta["filename"] = computed_filename

View File

@ -51,8 +51,9 @@ def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
raw_subdirs = tags[1:] raw_subdirs = tags[1:]
else: else:
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'") raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
_sep_chars = frozenset(("/", "\\", os.sep))
for i in raw_subdirs: for i in raw_subdirs:
if i in (".", ".."): if i in (".", "..") or _sep_chars & set(i):
raise ValueError("invalid path component in tags") raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else [] return base_dir, raw_subdirs if raw_subdirs else []
@ -113,6 +114,8 @@ def get_asset_category_and_relative_path(
return Path(child).is_relative_to(parent) return Path(child).is_relative_to(parent)
def _compute_relative(child: str, parent: str) -> str: def _compute_relative(child: str, parent: str) -> str:
# Normalize relative path, stripping any leading ".." components
# by anchoring to root (os.sep) then computing relpath back from it.
return os.path.relpath( return os.path.relpath(
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
) )

View File

@ -259,10 +259,10 @@ def prompt_worker(q, server_instance):
extra_data[k] = sensitive[k] extra_data[k] = sensitive[k]
asset_seeder.pause() asset_seeder.pause()
try:
e.execute(item[2], prompt_id, extra_data, item[4]) e.execute(item[2], prompt_id, extra_data, item[4])
finally:
asset_seeder.resume() asset_seeder.resume()
need_gc = True need_gc = True
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]

View File

@ -34,7 +34,7 @@ from comfyui_version import __version__
from app.frontend_management import FrontendManager, parse_version from app.frontend_management import FrontendManager, parse_version
from comfy_api.internal import _ComfyNodeInternal from comfy_api.internal import _ComfyNodeInternal
from app.assets.seeder import asset_seeder from app.assets.seeder import asset_seeder
from app.assets.api.routes import register_assets_system from app.assets.api.routes import register_assets_routes
from app.user_manager import UserManager from app.user_manager import UserManager
from app.model_manager import ModelFileManager from app.model_manager import ModelFileManager
@ -240,7 +240,10 @@ class PromptServer():
) )
logging.info(f"[Prompt Server] web root: {self.web_root}") logging.info(f"[Prompt Server] web root: {self.web_root}")
if args.enable_assets: if args.enable_assets:
register_assets_system(self.app, self.user_manager) register_assets_routes(self.app, self.user_manager)
else:
register_assets_routes(self.app)
asset_seeder.disable()
routes = web.RouteTableDef() routes = web.RouteTableDef()
self.routes = routes self.routes = routes
self.last_node_id = None self.last_node_id = None

View File

@ -0,0 +1,350 @@
"""Tests for sync_references_with_filesystem in scanner.py."""
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from unittest.mock import patch

import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from app.assets.database.models import (
    Asset,
    AssetReference,
    AssetReferenceTag,
    Base,
    Tag,
)
from app.assets.scanner import sync_references_with_filesystem
from app.assets.services.file_utils import get_mtime_ns
@pytest.fixture
def db_engine():
    """Provide an in-memory SQLite engine with the full schema created."""
    memory_engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(memory_engine)
    return memory_engine
@pytest.fixture
def session(db_engine):
    """Yield a SQLAlchemy session bound to the in-memory engine, closing it afterwards."""
    db_session = Session(db_engine)
    try:
        yield db_session
    finally:
        db_session.close()
@pytest.fixture
def temp_dir():
    """Yield a Path to a fresh temporary directory, removed after the test."""
    holder = tempfile.TemporaryDirectory()
    try:
        yield Path(holder.name)
    finally:
        holder.cleanup()
def _create_file(temp_dir: Path, name: str, content: bytes = b"\x00" * 100) -> str:
"""Create a file and return its absolute path (no symlink resolution)."""
p = temp_dir / name
p.parent.mkdir(parents=True, exist_ok=True)
p.write_bytes(content)
return os.path.abspath(str(p))
def _stat_mtime_ns(path: str) -> int:
    """Return the file's on-disk mtime in nanoseconds, following symlinks."""
    stat_result = os.stat(path, follow_symlinks=True)
    return get_mtime_ns(stat_result)
def _make_asset(
    session: Session,
    asset_id: str,
    file_path: str,
    ref_id: str,
    *,
    asset_hash: str | None = None,
    size_bytes: int = 100,
    mtime_ns: int | None = None,
    needs_verify: bool = False,
    is_missing: bool = False,
) -> tuple[Asset, AssetReference]:
    """Insert (or reuse) an Asset, attach a new AssetReference, and flush both."""
    existing = session.get(Asset, asset_id)
    if existing is not None:
        asset = existing
    else:
        asset = Asset(id=asset_id, hash=asset_hash, size_bytes=size_bytes)
        session.add(asset)
        session.flush()
    reference = AssetReference(
        id=ref_id,
        asset_id=asset_id,
        name=f"test-{ref_id}",
        owner_id="system",
        file_path=file_path,
        mtime_ns=mtime_ns,
        needs_verify=needs_verify,
        is_missing=is_missing,
    )
    session.add(reference)
    session.flush()
    return asset, reference
def _ensure_missing_tag(session: Session):
    """Insert the system 'missing' tag if it is not already present."""
    if session.get(Tag, "missing") is None:
        session.add(Tag(name="missing", tag_type="system"))
        session.flush()
class _VerifyCase:
def __init__(self, id, stat_unchanged, needs_verify_before, expect_needs_verify):
self.id = id
self.stat_unchanged = stat_unchanged
self.needs_verify_before = needs_verify_before
self.expect_needs_verify = expect_needs_verify
VERIFY_CASES = [
_VerifyCase(
id="unchanged_clears_verify",
stat_unchanged=True,
needs_verify_before=True,
expect_needs_verify=False,
),
_VerifyCase(
id="unchanged_keeps_clear",
stat_unchanged=True,
needs_verify_before=False,
expect_needs_verify=False,
),
_VerifyCase(
id="changed_sets_verify",
stat_unchanged=False,
needs_verify_before=False,
expect_needs_verify=True,
),
_VerifyCase(
id="changed_keeps_verify",
stat_unchanged=False,
needs_verify_before=True,
expect_needs_verify=True,
),
]
@pytest.mark.parametrize("case", VERIFY_CASES, ids=lambda c: c.id)
def test_needs_verify_toggling(session, temp_dir, case):
    """needs_verify is set/cleared based on mtime+size match."""
    path = _create_file(temp_dir, "model.bin")
    disk_mtime = _stat_mtime_ns(path)
    # Store either the real mtime (stat unchanged) or a skewed one (changed).
    stored_mtime = disk_mtime if case.stat_unchanged else disk_mtime + 1
    _make_asset(
        session, "a1", path, "r1",
        asset_hash="blake3:abc",
        mtime_ns=stored_mtime,
        needs_verify=case.needs_verify_before,
    )
    session.commit()

    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        sync_references_with_filesystem(session, "models")
    session.commit()
    session.expire_all()

    assert session.get(AssetReference, "r1").needs_verify is case.expect_needs_verify
class _MissingCase:
def __init__(self, id, file_exists, expect_is_missing):
self.id = id
self.file_exists = file_exists
self.expect_is_missing = expect_is_missing
MISSING_CASES = [
_MissingCase(id="existing_file_not_missing", file_exists=True, expect_is_missing=False),
_MissingCase(id="missing_file_marked_missing", file_exists=False, expect_is_missing=True),
]
@pytest.mark.parametrize("case", MISSING_CASES, ids=lambda c: c.id)
def test_is_missing_flag(session, temp_dir, case):
    """is_missing reflects whether the file exists on disk."""
    if case.file_exists:
        path = _create_file(temp_dir, "model.bin")
        stored_mtime = _stat_mtime_ns(path)
    else:
        path = str(temp_dir / "gone.bin")
        stored_mtime = 999
    _make_asset(session, "a1", path, "r1", asset_hash="blake3:abc", mtime_ns=stored_mtime)
    session.commit()

    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        sync_references_with_filesystem(session, "models")
    session.commit()
    session.expire_all()

    assert session.get(AssetReference, "r1").is_missing is case.expect_is_missing
def test_seed_asset_all_missing_deletes_asset(session, temp_dir):
    """A seed (unhashed) asset whose every ref is gone is deleted outright."""
    missing_path = str(temp_dir / "gone.bin")
    _make_asset(session, "seed1", missing_path, "r1", asset_hash=None, mtime_ns=999)
    session.commit()

    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        sync_references_with_filesystem(session, "models")
    session.commit()

    assert session.get(Asset, "seed1") is None
    assert session.get(AssetReference, "r1") is None
def test_seed_asset_some_exist_returns_survivors(session, temp_dir):
    """A seed asset with at least one live ref survives; its path is returned."""
    path = _create_file(temp_dir, "model.bin")
    _make_asset(
        session, "seed1", path, "r1",
        asset_hash=None, mtime_ns=_stat_mtime_ns(path),
    )
    session.commit()

    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        survivors = sync_references_with_filesystem(
            session, "models", collect_existing_paths=True,
        )
    session.commit()

    assert session.get(Asset, "seed1") is not None
    assert os.path.abspath(path) in survivors
def test_hashed_asset_prunes_missing_refs_when_one_is_ok(session, temp_dir):
    """Hashed asset with one stat-unchanged ref deletes its missing refs."""
    good_path = _create_file(temp_dir, "good.bin")
    gone_path = str(temp_dir / "gone.bin")
    _make_asset(
        session, "h1", good_path, "r_ok",
        asset_hash="blake3:aaa", mtime_ns=_stat_mtime_ns(good_path),
    )
    # Attach a second reference to the same asset whose file does not exist.
    session.add(AssetReference(
        id="r_gone", asset_id="h1", name="gone",
        owner_id="system", file_path=gone_path, mtime_ns=999,
    ))
    session.commit()

    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        sync_references_with_filesystem(session, "models")
    session.commit()
    session.expire_all()

    assert session.get(AssetReference, "r_ok") is not None
    assert session.get(AssetReference, "r_gone") is None
def test_hashed_asset_all_missing_keeps_refs(session, temp_dir):
    """Hashed asset with every ref missing keeps them, flagged is_missing."""
    gone_path = str(temp_dir / "gone.bin")
    _make_asset(session, "h1", gone_path, "r1", asset_hash="blake3:aaa", mtime_ns=999)
    session.commit()

    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        sync_references_with_filesystem(session, "models")
    session.commit()
    session.expire_all()

    ref = session.get(AssetReference, "r1")
    assert ref is not None
    assert ref.is_missing is True
def test_missing_tag_added_when_all_refs_gone(session, temp_dir):
    """The 'missing' tag is attached when every ref of a hashed asset is gone."""
    _ensure_missing_tag(session)
    gone_path = str(temp_dir / "gone.bin")
    _make_asset(session, "h1", gone_path, "r1", asset_hash="blake3:aaa", mtime_ns=999)
    session.commit()

    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        sync_references_with_filesystem(session, "models", update_missing_tags=True)
    session.commit()
    session.expire_all()

    assert session.get(AssetReferenceTag, ("r1", "missing")) is not None
def test_missing_tag_removed_when_ref_ok(session, temp_dir):
    """A stale 'missing' tag is dropped once a ref stat-matches on disk."""
    _ensure_missing_tag(session)
    path = _create_file(temp_dir, "model.bin")
    _make_asset(
        session, "h1", path, "r1",
        asset_hash="blake3:aaa", mtime_ns=_stat_mtime_ns(path),
    )
    # Simulate a leftover tag from a previous scan where the file was absent.
    session.add(AssetReferenceTag(
        asset_reference_id="r1", tag_name="missing", origin="automatic",
    ))
    session.commit()

    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        sync_references_with_filesystem(session, "models", update_missing_tags=True)
    session.commit()
    session.expire_all()

    assert session.get(AssetReferenceTag, ("r1", "missing")) is None
def test_missing_tags_not_touched_when_flag_false(session, temp_dir):
    """With update_missing_tags=False the scan leaves tag links alone."""
    _ensure_missing_tag(session)
    gone_path = str(temp_dir / "gone.bin")
    _make_asset(session, "h1", gone_path, "r1", asset_hash="blake3:aaa", mtime_ns=999)
    session.commit()

    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        sync_references_with_filesystem(session, "models", update_missing_tags=False)
    session.commit()

    # The tag was never added, and the scan must not add it either.
    assert session.get(AssetReferenceTag, ("r1", "missing")) is None
def test_returns_none_when_collect_false(session, temp_dir):
    """Without collect_existing_paths the sync returns None."""
    path = _create_file(temp_dir, "model.bin")
    _make_asset(
        session, "a1", path, "r1",
        asset_hash="blake3:abc", mtime_ns=_stat_mtime_ns(path),
    )
    session.commit()

    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        outcome = sync_references_with_filesystem(
            session, "models", collect_existing_paths=False,
        )

    assert outcome is None
def test_returns_empty_set_for_no_prefixes(session):
    """With no prefixes configured, an empty survivor set is returned."""
    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[]):
        outcome = sync_references_with_filesystem(
            session, "models", collect_existing_paths=True,
        )
    assert outcome == set()
def test_no_references_is_noop(session, temp_dir):
    """No crash and no side effects when the database has no references."""
    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        outcome = sync_references_with_filesystem(
            session, "models", collect_existing_paths=True,
        )
    session.commit()
    assert outcome == set()

View File

@ -1,19 +1,18 @@
"""Unit tests for the AssetSeeder background scanning class.""" """Unit tests for the _AssetSeeder background scanning class."""
import threading import threading
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from app.assets.seeder import AssetSeeder, Progress, ScanInProgressError, ScanPhase, State from app.assets.database.queries.asset_reference import UnenrichedReferenceRow
from app.assets.seeder import _AssetSeeder, Progress, ScanInProgressError, ScanPhase, State
@pytest.fixture @pytest.fixture
def fresh_seeder(): def fresh_seeder():
"""Create a fresh AssetSeeder instance for testing (bypasses singleton).""" """Create a fresh _AssetSeeder instance for testing."""
seeder = object.__new__(AssetSeeder) seeder = _AssetSeeder()
seeder._initialized = False
seeder.__init__()
yield seeder yield seeder
seeder.shutdown(timeout=1.0) seeder.shutdown(timeout=1.0)
@ -25,7 +24,7 @@ def mock_dependencies():
patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()), patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)), patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0), patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@ -36,11 +35,11 @@ def mock_dependencies():
class TestSeederStateTransitions: class TestSeederStateTransitions:
"""Test state machine transitions.""" """Test state machine transitions."""
def test_initial_state_is_idle(self, fresh_seeder: AssetSeeder): def test_initial_state_is_idle(self, fresh_seeder: _AssetSeeder):
assert fresh_seeder.get_status().state == State.IDLE assert fresh_seeder.get_status().state == State.IDLE
def test_start_transitions_to_running( def test_start_transitions_to_running(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: _AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event() reached = threading.Event()
@ -61,7 +60,7 @@ class TestSeederStateTransitions:
barrier.set() barrier.set()
def test_start_while_running_returns_false( def test_start_while_running_returns_false(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: _AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event() reached = threading.Event()
@ -83,7 +82,7 @@ class TestSeederStateTransitions:
barrier.set() barrier.set()
def test_cancel_transitions_to_cancelling( def test_cancel_transitions_to_cancelling(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: _AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event() reached = threading.Event()
@ -105,12 +104,12 @@ class TestSeederStateTransitions:
barrier.set() barrier.set()
def test_cancel_when_idle_returns_false(self, fresh_seeder: AssetSeeder): def test_cancel_when_idle_returns_false(self, fresh_seeder: _AssetSeeder):
cancelled = fresh_seeder.cancel() cancelled = fresh_seeder.cancel()
assert cancelled is False assert cancelled is False
def test_state_returns_to_idle_after_completion( def test_state_returns_to_idle_after_completion(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: _AssetSeeder, mock_dependencies
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
completed = fresh_seeder.wait(timeout=5.0) completed = fresh_seeder.wait(timeout=5.0)
@ -122,7 +121,7 @@ class TestSeederWait:
"""Test wait() behavior.""" """Test wait() behavior."""
def test_wait_blocks_until_complete( def test_wait_blocks_until_complete(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: _AssetSeeder, mock_dependencies
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
completed = fresh_seeder.wait(timeout=5.0) completed = fresh_seeder.wait(timeout=5.0)
@ -130,7 +129,7 @@ class TestSeederWait:
assert fresh_seeder.get_status().state == State.IDLE assert fresh_seeder.get_status().state == State.IDLE
def test_wait_returns_false_on_timeout( def test_wait_returns_false_on_timeout(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: _AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
@ -147,7 +146,7 @@ class TestSeederWait:
barrier.set() barrier.set()
def test_wait_when_idle_returns_true(self, fresh_seeder: AssetSeeder): def test_wait_when_idle_returns_true(self, fresh_seeder: _AssetSeeder):
completed = fresh_seeder.wait(timeout=1.0) completed = fresh_seeder.wait(timeout=1.0)
assert completed is True assert completed is True
@ -156,7 +155,7 @@ class TestSeederProgress:
"""Test progress tracking.""" """Test progress tracking."""
def test_get_status_returns_progress_during_scan( def test_get_status_returns_progress_during_scan(
self, fresh_seeder: AssetSeeder self, fresh_seeder: _AssetSeeder
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event() reached = threading.Event()
@ -172,7 +171,7 @@ class TestSeederProgress:
patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()), patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=paths), patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
patch("app.assets.seeder.build_stub_specs", side_effect=slow_build), patch("app.assets.seeder.build_asset_specs", side_effect=slow_build),
patch("app.assets.seeder.insert_asset_specs", return_value=0), patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@ -188,7 +187,7 @@ class TestSeederProgress:
barrier.set() barrier.set()
def test_progress_callback_is_invoked( def test_progress_callback_is_invoked(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: _AssetSeeder, mock_dependencies
): ):
progress_updates: list[Progress] = [] progress_updates: list[Progress] = []
@ -209,7 +208,7 @@ class TestSeederCancellation:
"""Test cancellation behavior.""" """Test cancellation behavior."""
def test_scan_commits_partial_progress_on_cancellation( def test_scan_commits_partial_progress_on_cancellation(
self, fresh_seeder: AssetSeeder self, fresh_seeder: _AssetSeeder
): ):
insert_count = 0 insert_count = 0
barrier = threading.Event() barrier = threading.Event()
@ -245,7 +244,7 @@ class TestSeederCancellation:
patch("app.assets.seeder.sync_root_safely", return_value=set()), patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=paths), patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
patch( patch(
"app.assets.seeder.build_stub_specs", return_value=(specs, set(), 0) "app.assets.seeder.build_asset_specs", return_value=(specs, set(), 0)
), ),
patch("app.assets.seeder.insert_asset_specs", side_effect=slow_insert), patch("app.assets.seeder.insert_asset_specs", side_effect=slow_insert),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
@ -264,7 +263,7 @@ class TestSeederCancellation:
class TestSeederErrorHandling: class TestSeederErrorHandling:
"""Test error handling behavior.""" """Test error handling behavior."""
def test_database_errors_captured_in_status(self, fresh_seeder: AssetSeeder): def test_database_errors_captured_in_status(self, fresh_seeder: _AssetSeeder):
with ( with (
patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()), patch("app.assets.seeder.sync_root_safely", return_value=set()),
@ -273,7 +272,7 @@ class TestSeederErrorHandling:
return_value=["/path/file.safetensors"], return_value=["/path/file.safetensors"],
), ),
patch( patch(
"app.assets.seeder.build_stub_specs", "app.assets.seeder.build_asset_specs",
return_value=( return_value=(
[ [
{ {
@ -307,7 +306,7 @@ class TestSeederErrorHandling:
assert "DB connection failed" in status.errors[0] assert "DB connection failed" in status.errors[0]
def test_dependencies_unavailable_captured_in_errors( def test_dependencies_unavailable_captured_in_errors(
self, fresh_seeder: AssetSeeder self, fresh_seeder: _AssetSeeder
): ):
with patch("app.assets.seeder.dependencies_available", return_value=False): with patch("app.assets.seeder.dependencies_available", return_value=False):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
@ -317,7 +316,7 @@ class TestSeederErrorHandling:
assert len(status.errors) > 0 assert len(status.errors) > 0
assert "dependencies" in status.errors[0].lower() assert "dependencies" in status.errors[0].lower()
def test_thread_crash_resets_state_to_idle(self, fresh_seeder: AssetSeeder): def test_thread_crash_resets_state_to_idle(self, fresh_seeder: _AssetSeeder):
with ( with (
patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.dependencies_available", return_value=True),
patch( patch(
@ -337,7 +336,7 @@ class TestSeederThreadSafety:
"""Test thread safety of concurrent operations.""" """Test thread safety of concurrent operations."""
def test_concurrent_start_calls_spawn_only_one_thread( def test_concurrent_start_calls_spawn_only_one_thread(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: _AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
@ -364,7 +363,7 @@ class TestSeederThreadSafety:
assert sum(results) == 1 assert sum(results) == 1
def test_get_status_safe_during_scan( def test_get_status_safe_during_scan(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: _AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event() reached = threading.Event()
@ -395,7 +394,7 @@ class TestSeederThreadSafety:
class TestSeederMarkMissing: class TestSeederMarkMissing:
"""Test mark_missing_outside_prefixes behavior.""" """Test mark_missing_outside_prefixes behavior."""
def test_mark_missing_when_idle(self, fresh_seeder: AssetSeeder): def test_mark_missing_when_idle(self, fresh_seeder: _AssetSeeder):
with ( with (
patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.dependencies_available", return_value=True),
patch( patch(
@ -411,7 +410,7 @@ class TestSeederMarkMissing:
mock_mark.assert_called_once_with(["/models", "/input", "/output"]) mock_mark.assert_called_once_with(["/models", "/input", "/output"])
def test_mark_missing_raises_when_running( def test_mark_missing_raises_when_running(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: _AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event() reached = threading.Event()
@ -433,14 +432,14 @@ class TestSeederMarkMissing:
barrier.set() barrier.set()
def test_mark_missing_returns_zero_when_dependencies_unavailable( def test_mark_missing_returns_zero_when_dependencies_unavailable(
self, fresh_seeder: AssetSeeder self, fresh_seeder: _AssetSeeder
): ):
with patch("app.assets.seeder.dependencies_available", return_value=False): with patch("app.assets.seeder.dependencies_available", return_value=False):
result = fresh_seeder.mark_missing_outside_prefixes() result = fresh_seeder.mark_missing_outside_prefixes()
assert result == 0 assert result == 0
def test_prune_first_flag_triggers_mark_missing_before_scan( def test_prune_first_flag_triggers_mark_missing_before_scan(
self, fresh_seeder: AssetSeeder self, fresh_seeder: _AssetSeeder
): ):
call_order = [] call_order = []
@ -458,7 +457,7 @@ class TestSeederMarkMissing:
patch("app.assets.seeder.mark_missing_outside_prefixes_safely", side_effect=track_mark), patch("app.assets.seeder.mark_missing_outside_prefixes_safely", side_effect=track_mark),
patch("app.assets.seeder.sync_root_safely", side_effect=track_sync), patch("app.assets.seeder.sync_root_safely", side_effect=track_sync),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)), patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0), patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@ -473,7 +472,7 @@ class TestSeederMarkMissing:
class TestSeederPhases: class TestSeederPhases:
"""Test phased scanning behavior.""" """Test phased scanning behavior."""
def test_start_fast_only_runs_fast_phase(self, fresh_seeder: AssetSeeder): def test_start_fast_only_runs_fast_phase(self, fresh_seeder: _AssetSeeder):
"""Verify start_fast only runs the fast phase.""" """Verify start_fast only runs the fast phase."""
fast_called = [] fast_called = []
enrich_called = [] enrich_called = []
@ -490,7 +489,7 @@ class TestSeederPhases:
patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()), patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_stub_specs", side_effect=track_fast), patch("app.assets.seeder.build_asset_specs", side_effect=track_fast),
patch("app.assets.seeder.insert_asset_specs", return_value=0), patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich), patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@ -501,7 +500,7 @@ class TestSeederPhases:
assert len(fast_called) == 1 assert len(fast_called) == 1
assert len(enrich_called) == 0 assert len(enrich_called) == 0
def test_start_enrich_only_runs_enrich_phase(self, fresh_seeder: AssetSeeder): def test_start_enrich_only_runs_enrich_phase(self, fresh_seeder: _AssetSeeder):
"""Verify start_enrich only runs the enrich phase.""" """Verify start_enrich only runs the enrich phase."""
fast_called = [] fast_called = []
enrich_called = [] enrich_called = []
@ -518,7 +517,7 @@ class TestSeederPhases:
patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()), patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_stub_specs", side_effect=track_fast), patch("app.assets.seeder.build_asset_specs", side_effect=track_fast),
patch("app.assets.seeder.insert_asset_specs", return_value=0), patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich), patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@ -529,7 +528,7 @@ class TestSeederPhases:
assert len(fast_called) == 0 assert len(fast_called) == 0
assert len(enrich_called) == 1 assert len(enrich_called) == 1
def test_full_scan_runs_both_phases(self, fresh_seeder: AssetSeeder): def test_full_scan_runs_both_phases(self, fresh_seeder: _AssetSeeder):
"""Verify full scan runs both fast and enrich phases.""" """Verify full scan runs both fast and enrich phases."""
fast_called = [] fast_called = []
enrich_called = [] enrich_called = []
@ -546,7 +545,7 @@ class TestSeederPhases:
patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()), patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]), patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_stub_specs", side_effect=track_fast), patch("app.assets.seeder.build_asset_specs", side_effect=track_fast),
patch("app.assets.seeder.insert_asset_specs", return_value=0), patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich), patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@ -562,7 +561,7 @@ class TestSeederPauseResume:
"""Test pause/resume behavior.""" """Test pause/resume behavior."""
def test_pause_transitions_to_paused( def test_pause_transitions_to_paused(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: _AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event() reached = threading.Event()
@ -584,12 +583,12 @@ class TestSeederPauseResume:
barrier.set() barrier.set()
def test_pause_when_idle_returns_false(self, fresh_seeder: AssetSeeder): def test_pause_when_idle_returns_false(self, fresh_seeder: _AssetSeeder):
paused = fresh_seeder.pause() paused = fresh_seeder.pause()
assert paused is False assert paused is False
def test_resume_returns_to_running( def test_resume_returns_to_running(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: _AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event() reached = threading.Event()
@ -615,7 +614,7 @@ class TestSeederPauseResume:
barrier.set() barrier.set()
def test_resume_when_not_paused_returns_false( def test_resume_when_not_paused_returns_false(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: _AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event() reached = threading.Event()
@ -637,7 +636,7 @@ class TestSeederPauseResume:
barrier.set() barrier.set()
def test_cancel_while_paused_works( def test_cancel_while_paused_works(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: _AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached_checkpoint = threading.Event() reached_checkpoint = threading.Event()
@ -667,7 +666,7 @@ class TestSeederStopRestart:
"""Test stop and restart behavior.""" """Test stop and restart behavior."""
def test_stop_is_alias_for_cancel( def test_stop_is_alias_for_cancel(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: _AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event() reached = threading.Event()
@ -690,7 +689,7 @@ class TestSeederStopRestart:
barrier.set() barrier.set()
def test_restart_cancels_and_starts_new_scan( def test_restart_cancels_and_starts_new_scan(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: _AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event() reached = threading.Event()
@ -717,7 +716,7 @@ class TestSeederStopRestart:
fresh_seeder.wait(timeout=5.0) fresh_seeder.wait(timeout=5.0)
assert start_count == 2 assert start_count == 2
def test_restart_preserves_previous_params(self, fresh_seeder: AssetSeeder): def test_restart_preserves_previous_params(self, fresh_seeder: _AssetSeeder):
"""Verify restart uses previous params when not overridden.""" """Verify restart uses previous params when not overridden."""
collected_roots = [] collected_roots = []
@ -729,7 +728,7 @@ class TestSeederStopRestart:
patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()), patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect), patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect),
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)), patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0), patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@ -744,7 +743,7 @@ class TestSeederStopRestart:
assert collected_roots[0] == ("input", "output") assert collected_roots[0] == ("input", "output")
assert collected_roots[1] == ("input", "output") assert collected_roots[1] == ("input", "output")
def test_restart_can_override_params(self, fresh_seeder: AssetSeeder): def test_restart_can_override_params(self, fresh_seeder: _AssetSeeder):
"""Verify restart can override previous params.""" """Verify restart can override previous params."""
collected_roots = [] collected_roots = []
@ -756,7 +755,7 @@ class TestSeederStopRestart:
patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()), patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect), patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect),
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)), patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0), patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]), patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@ -770,3 +769,132 @@ class TestSeederStopRestart:
assert len(collected_roots) == 2 assert len(collected_roots) == 2
assert collected_roots[0] == ("models",) assert collected_roots[0] == ("models",)
assert collected_roots[1] == ("input",) assert collected_roots[1] == ("input",)
def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
    """Build a minimal unenriched reference row fixture for the given ids."""
    fields = {
        "reference_id": ref_id,
        "asset_id": asset_id,
        "file_path": f"/fake/{ref_id}.bin",
        "enrichment_level": 0,
    }
    return UnenrichedReferenceRow(**fields)
class TestEnrichPhaseDefensiveLogic:
    """Tests for the enrich phase's defensive logic.

    Covers two protections in the enrich loop:

    - ``skip_ids`` filtering: references that fail enrichment are excluded
      from subsequent batches so they are not retried forever.
    - consecutive-empty termination: the loop counts batches that make zero
      progress and resets that counter whenever a batch succeeds.
    """

    def test_failed_refs_are_skipped_on_subsequent_batches(
        self, fresh_seeder: _AssetSeeder,
    ):
        """References that fail enrichment are filtered out of future batches."""
        row_a = _make_row("r1")
        row_b = _make_row("r2")
        call_count = 0

        def fake_get_unenriched(*args, **kwargs):
            # Serve the same two rows for the first two fetches, then drain.
            nonlocal call_count
            call_count += 1
            if call_count <= 2:
                return [row_a, row_b]
            return []

        enriched_refs: list[list[str]] = []

        def fake_enrich(rows, **kwargs):
            # Record which refs were attempted; r1 always fails, r2 succeeds.
            ref_ids = [r.reference_id for r in rows]
            enriched_refs.append(ref_ids)
            failed = [r.reference_id for r in rows if r.reference_id == "r1"]
            enriched = len(rows) - len(failed)
            return enriched, failed

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched),
            patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich),
        ):
            fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH)
            fresh_seeder.wait(timeout=5.0)

        # First batch: both refs attempted.
        assert "r1" in enriched_refs[0]
        assert "r2" in enriched_refs[0]
        # Second batch: the failed r1 has been filtered out via skip_ids.
        assert "r1" not in enriched_refs[1]
        assert "r2" in enriched_refs[1]

    def test_stops_after_consecutive_empty_batches(
        self, fresh_seeder: _AssetSeeder,
    ):
        """A permanently failing ref empties the next batch and ends the loop.

        ``get_unenriched`` always returns the same row. Batch 1 attempts it
        and it fails, so its id joins ``skip_ids``. Batch 2 fetches the same
        row again, but ``skip_ids`` filters it out, leaving an empty batch,
        and the loop breaks — exactly two fetches in total.
        """
        row = _make_row("r1")
        batch_count = 0

        def fake_get_unenriched(*args, **kwargs):
            # Simulate a ref that never leaves the unenriched table.
            nonlocal batch_count
            batch_count += 1
            return [row]

        def fake_enrich(rows, **kwargs):
            # Zero enriched; every ref reported as failed.
            return 0, [r.reference_id for r in rows]

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched),
            patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich),
        ):
            fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH)
            fresh_seeder.wait(timeout=5.0)

        # Fetch 1: row attempted, fails, enters skip_ids.
        # Fetch 2: same row filtered out -> empty batch -> loop terminates.
        assert batch_count == 2

    def test_consecutive_empty_counter_resets_on_success(
        self, fresh_seeder: _AssetSeeder,
    ):
        """A successful batch resets the consecutive-empty counter."""
        call_count = 0

        def fake_get_unenriched(*args, **kwargs):
            # Serve six distinct rows (distinct assets, so none is skipped),
            # then report the table drained.
            nonlocal call_count
            call_count += 1
            if call_count <= 6:
                return [_make_row(f"r{call_count}", f"a{call_count}")]
            return []

        def fake_enrich(rows, **kwargs):
            ref_id = rows[0].reference_id
            # Fail batches 1-2, succeed batch 3, fail batches 4-5, succeed batch 6.
            # The two failure runs never reach the termination threshold because
            # each success resets the counter.
            if ref_id in ("r1", "r2", "r4", "r5"):
                return 0, [ref_id]
            return 1, []

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched),
            patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich),
        ):
            fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH)
            fresh_seeder.wait(timeout=5.0)

        # All 6 batches run, plus one final fetch that returns empty.
        assert call_count == 7
        status = fresh_seeder.get_status()
        assert status.state == State.IDLE