fix: address code review feedback

- Fix missing import for compute_filename_for_reference in ingest.py
- Apply code review fixes across routes, queries, scanner, seeder,
  hashing, ingest, path_utils, main, and server
- Update and add tests for sync references and seeder

Amp-Thread-ID: https://ampcode.com/threads/T-019cb61a-ed54-738c-a05f-9b5242e513f3
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Luke Mino-Altherr 2026-03-03 15:51:35 -08:00
parent 3232f48a41
commit 4d4c2cedd3
13 changed files with 675 additions and 218 deletions

View File

@ -1,4 +1,5 @@
import asyncio
import functools
import json
import logging
import os
@ -39,6 +40,20 @@ from app.assets.services import (
ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
_ASSETS_ENABLED = False
def _require_assets_feature_enabled(handler):
    """Decorator: answer 503 for every request while the assets feature is off.

    Wraps an aiohttp handler; delegates to it only when the module-level
    ``_ASSETS_ENABLED`` flag has been set by route registration.
    """
    @functools.wraps(handler)
    async def guarded(request: web.Request) -> web.Response:
        if _ASSETS_ENABLED:
            return await handler(request)
        return _build_error_response(
            503,
            "SERVICE_DISABLED",
            "Assets system is disabled. Start the server with --enable-assets to use this feature.",
        )

    return guarded
# UUID regex (canonical hyphenated form, case-insensitive)
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
@ -64,11 +79,13 @@ def get_query_dict(request: web.Request) -> dict[str, Any]:
# do not rely on the code in /app/assets remaining the same.
def register_assets_system(
app: web.Application, user_manager_instance: user_manager.UserManager
def register_assets_routes(
app: web.Application, user_manager_instance: user_manager.UserManager | None = None,
) -> None:
global USER_MANAGER
USER_MANAGER = user_manager_instance
global USER_MANAGER, _ASSETS_ENABLED
if user_manager_instance is not None:
USER_MANAGER = user_manager_instance
_ASSETS_ENABLED = True
app.add_routes(ROUTES)
@ -96,6 +113,7 @@ def _validate_sort_field(requested: str | None) -> str:
@ROUTES.head("/api/assets/hash/{hash}")
@_require_assets_feature_enabled
async def head_asset_by_hash(request: web.Request) -> web.Response:
hash_str = request.match_info.get("hash", "").strip().lower()
if not hash_str or ":" not in hash_str:
@ -116,6 +134,7 @@ async def head_asset_by_hash(request: web.Request) -> web.Response:
@ROUTES.get("/api/assets")
@_require_assets_feature_enabled
async def list_assets_route(request: web.Request) -> web.Response:
"""
GET request to list assets.
@ -166,6 +185,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
@_require_assets_feature_enabled
async def get_asset_route(request: web.Request) -> web.Response:
"""
GET request to get an asset's info as JSON.
@ -211,6 +231,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
@_require_assets_feature_enabled
async def download_asset_content(request: web.Request) -> web.Response:
disposition = request.query.get("disposition", "attachment").lower().strip()
if disposition not in {"inline", "attachment"}:
@ -264,6 +285,7 @@ async def download_asset_content(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets/from-hash")
@_require_assets_feature_enabled
async def create_asset_from_hash_route(request: web.Request) -> web.Response:
try:
payload = await request.json()
@ -304,6 +326,7 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets")
@_require_assets_feature_enabled
async def upload_asset(request: web.Request) -> web.Response:
"""Multipart/form-data endpoint for Asset uploads."""
try:
@ -408,6 +431,7 @@ async def upload_asset(request: web.Request) -> web.Response:
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
@_require_assets_feature_enabled
async def update_asset_route(request: web.Request) -> web.Response:
reference_id = str(uuid.UUID(request.match_info["id"]))
try:
@ -453,6 +477,7 @@ async def update_asset_route(request: web.Request) -> web.Response:
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
@_require_assets_feature_enabled
async def delete_asset_route(request: web.Request) -> web.Response:
reference_id = str(uuid.UUID(request.match_info["id"]))
delete_content_param = request.query.get("delete_content")
@ -484,6 +509,7 @@ async def delete_asset_route(request: web.Request) -> web.Response:
@ROUTES.get("/api/tags")
@_require_assets_feature_enabled
async def get_tags(request: web.Request) -> web.Response:
"""
GET request to list all tags based on query parameters.
@ -520,6 +546,7 @@ async def get_tags(request: web.Request) -> web.Response:
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
@_require_assets_feature_enabled
async def add_asset_tags(request: web.Request) -> web.Response:
reference_id = str(uuid.UUID(request.match_info["id"]))
try:
@ -569,6 +596,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
@_require_assets_feature_enabled
async def delete_asset_tags(request: web.Request) -> web.Response:
reference_id = str(uuid.UUID(request.match_info["id"]))
try:
@ -613,6 +641,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets/seed")
@_require_assets_feature_enabled
async def seed_assets(request: web.Request) -> web.Response:
"""Trigger asset seeding for specified roots (models, input, output).
@ -662,6 +691,7 @@ async def seed_assets(request: web.Request) -> web.Response:
@ROUTES.get("/api/assets/seed/status")
@_require_assets_feature_enabled
async def get_seed_status(request: web.Request) -> web.Response:
"""Get current scan status and progress."""
status = asset_seeder.get_status()
@ -683,6 +713,7 @@ async def get_seed_status(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets/seed/cancel")
@_require_assets_feature_enabled
async def cancel_seed(request: web.Request) -> web.Response:
"""Request cancellation of in-progress scan."""
cancelled = asset_seeder.cancel()
@ -692,6 +723,7 @@ async def cancel_seed(request: web.Request) -> web.Response:
@ROUTES.post("/api/assets/prune")
@_require_assets_feature_enabled
async def mark_missing_assets(request: web.Request) -> web.Response:
"""Mark assets as missing when outside all known root prefixes.

View File

@ -57,6 +57,7 @@ from app.assets.database.queries.tags import (
remove_missing_tag_for_asset_id,
remove_tags_from_reference,
set_reference_tags,
validate_tags_exist,
)
__all__ = [
@ -114,4 +115,5 @@ __all__ = [
"update_reference_updated_at",
"upsert_asset",
"upsert_reference",
"validate_tags_exist",
]

View File

@ -660,13 +660,16 @@ def restore_references_by_paths(session: Session, file_paths: list[str]) -> int:
if not file_paths:
return 0
result = session.execute(
sa.update(AssetReference)
.where(AssetReference.file_path.in_(file_paths))
.where(AssetReference.is_missing == True) # noqa: E712
.values(is_missing=False)
)
return result.rowcount
total = 0
for chunk in iter_chunks(file_paths, MAX_BIND_PARAMS):
result = session.execute(
sa.update(AssetReference)
.where(AssetReference.file_path.in_(chunk))
.where(AssetReference.is_missing == True) # noqa: E712
.values(is_missing=False)
)
total += result.rowcount
return total
def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]:
@ -697,11 +700,14 @@ def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int:
"""
if not asset_ids:
return 0
session.execute(
sa.delete(AssetReference).where(AssetReference.asset_id.in_(asset_ids))
)
result = session.execute(sa.delete(Asset).where(Asset.id.in_(asset_ids)))
return result.rowcount
total = 0
for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS):
session.execute(
sa.delete(AssetReference).where(AssetReference.asset_id.in_(chunk))
)
result = session.execute(sa.delete(Asset).where(Asset.id.in_(chunk)))
total += result.rowcount
return total
def get_references_for_prefixes(

View File

@ -37,6 +37,17 @@ class SetTagsDict(TypedDict):
total: list[str]
def validate_tags_exist(session: Session, tags: list[str]) -> None:
    """Raise ValueError if any of the given tag names do not exist.

    Args:
        session: Active database session.
        tags: Tag names to check against the ``Tag`` table.

    Raises:
        ValueError: If one or more names have no matching ``Tag`` row.
    """
    if not tags:
        # Nothing to validate; also avoids issuing a query with an empty IN().
        return
    existing_tag_names = {
        name
        for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
    }
    # Preserve the caller's ordering in the error message for easier debugging.
    missing = [t for t in tags if t not in existing_tag_names]
    if missing:
        raise ValueError(f"Unknown tags: {missing}")
def ensure_tags_exist(
session: Session, names: Iterable[str], tag_type: str = "user"
) -> None:

View File

@ -44,9 +44,9 @@ from app.database.db import create_session, dependencies_available
class _RefInfo(TypedDict):
ref_id: str
fp: str
file_path: str
exists: bool
fast_ok: bool
stat_unchanged: bool
needs_verify: bool
@ -75,9 +75,7 @@ def get_prefixes_for_root(root: RootType) -> list[str]:
def get_all_known_prefixes() -> list[str]:
"""Get all known asset prefixes across all root types."""
all_roots: tuple[RootType, ...] = ("models", "input", "output")
return [
os.path.abspath(p) for root in all_roots for p in get_prefixes_for_root(root)
]
return [p for root in all_roots for p in get_prefixes_for_root(root)]
def collect_models_files() -> list[str]:
@ -110,10 +108,10 @@ def sync_references_with_filesystem(
) -> set[str] | None:
"""Reconcile asset references with filesystem for a root.
- Toggle needs_verify per reference using fast mtime/size check
- For hashed assets with at least one fast-ok ref: delete stale missing refs
- Toggle needs_verify per reference using mtime/size stat check
- For hashed assets with at least one stat-unchanged ref: delete stale missing refs
- For seed assets with all refs missing: delete Asset and its references
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally add/remove 'missing' tags based on stat check in this root
- Optionally return surviving absolute paths
Args:
@ -140,10 +138,10 @@ def sync_references_with_filesystem(
acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []}
by_asset[row.asset_id] = acc
fast_ok = False
stat_unchanged = False
try:
exists = True
fast_ok = verify_file_unchanged(
stat_unchanged = verify_file_unchanged(
mtime_db=row.mtime_ns,
size_db=acc["size_db"],
stat_result=os.stat(row.file_path, follow_symlinks=True),
@ -160,9 +158,9 @@ def sync_references_with_filesystem(
acc["refs"].append(
{
"ref_id": row.reference_id,
"fp": row.file_path,
"file_path": row.file_path,
"exists": exists,
"fast_ok": fast_ok,
"stat_unchanged": stat_unchanged,
"needs_verify": row.needs_verify,
}
)
@ -177,18 +175,18 @@ def sync_references_with_filesystem(
for aid, acc in by_asset.items():
a_hash = acc["hash"]
refs = acc["refs"]
any_fast_ok = any(r["fast_ok"] for r in refs)
any_unchanged = any(r["stat_unchanged"] for r in refs)
all_missing = all(not r["exists"] for r in refs)
for r in refs:
if not r["exists"]:
to_mark_missing.append(r["ref_id"])
continue
if r["fast_ok"]:
if r["stat_unchanged"]:
to_clear_missing.append(r["ref_id"])
if r["needs_verify"]:
to_clear_verify.append(r["ref_id"])
if not r["fast_ok"] and not r["needs_verify"]:
if not r["stat_unchanged"] and not r["needs_verify"]:
to_set_verify.append(r["ref_id"])
if a_hash is None:
@ -197,10 +195,10 @@ def sync_references_with_filesystem(
else:
for r in refs:
if r["exists"]:
survivors.add(os.path.abspath(r["fp"]))
survivors.add(os.path.abspath(r["file_path"]))
continue
if any_fast_ok:
if any_unchanged:
for r in refs:
if not r["exists"]:
stale_ref_ids.append(r["ref_id"])
@ -219,7 +217,7 @@ def sync_references_with_filesystem(
for r in refs:
if r["exists"]:
survivors.add(os.path.abspath(r["fp"]))
survivors.add(os.path.abspath(r["file_path"]))
delete_references_by_ids(session, stale_ref_ids)
stale_set = set(stale_ref_ids)
@ -349,58 +347,6 @@ def build_asset_specs(
return specs, tag_pool, skipped
def build_stub_specs(
    paths: list[str],
    existing_paths: set[str],
) -> tuple[list[SeedAssetSpec], set[str], int]:
    """Build minimal stub specs for fast phase scanning.

    Only collects filesystem metadata (stat) — no file content is read —
    making this the fastest possible scan to populate the asset database.

    Args:
        paths: List of file paths to process.
        existing_paths: Set of paths that already exist in the database.

    Returns:
        Tuple of (specs, tag_pool, skipped_count).
    """
    stub_specs: list[SeedAssetSpec] = []
    all_tags: set[str] = set()
    already_known = 0
    for candidate in paths:
        absolute = os.path.abspath(candidate)
        if absolute in existing_paths:
            already_known += 1
            continue
        try:
            st = os.stat(absolute, follow_symlinks=True)
        except OSError:
            # Unreadable or vanished file: skip silently, same as the scan.
            continue
        if not st.st_size:
            # Empty files are never seeded.
            continue
        display_name, path_tags = get_name_and_tags_from_asset_path(absolute)
        stub_specs.append(
            {
                "abs_path": absolute,
                "size_bytes": st.st_size,
                "mtime_ns": get_mtime_ns(st),
                "info_name": display_name,
                "tags": path_tags,
                "fname": compute_relative_filename(absolute),
                "metadata": None,
                "hash": None,
                "mime_type": None,
            }
        )
        all_tags.update(path_tags)
    return stub_specs, all_tags, already_known
def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
"""Insert asset specs into database, returning count of created refs."""
@ -538,7 +484,8 @@ def enrich_asset(
try:
digest = compute_blake3_hash(file_path)
full_hash = f"blake3:{digest}"
if not extract_metadata or metadata:
metadata_ok = not extract_metadata or metadata is not None
if metadata_ok:
new_level = ENRICHMENT_HASHED
except Exception as e:
logging.warning("Failed to hash %s: %s", file_path, e)

View File

@ -12,7 +12,7 @@ from app.assets.scanner import (
ENRICHMENT_METADATA,
ENRICHMENT_STUB,
RootType,
build_stub_specs,
build_asset_specs,
collect_paths_for_roots,
enrich_assets_batch,
get_all_known_prefixes,
@ -68,35 +68,23 @@ class ScanStatus:
ProgressCallback = Callable[[Progress], None]
class AssetSeeder:
"""Singleton class managing background asset scanning.
class _AssetSeeder:
"""Background asset scanning manager.
Thread-safe singleton that spawns ephemeral daemon threads for scanning.
Spawns ephemeral daemon threads for scanning.
Each scan creates a new thread that exits when complete.
Use the module-level ``asset_seeder`` instance.
"""
_instance: "AssetSeeder | None" = None
_instance_lock = threading.Lock()
def __new__(cls) -> "AssetSeeder":
with cls._instance_lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self) -> None:
if self._initialized:
return
self._initialized = True
self._lock = threading.Lock()
self._state = State.IDLE
self._progress: Progress | None = None
self._errors: list[str] = []
self._thread: threading.Thread | None = None
self._cancel_event = threading.Event()
self._pause_event = threading.Event()
self._pause_event.set() # Start unpaused (set = running, clear = paused)
self._run_gate = threading.Event()
self._run_gate.set() # Start unpaused (set = running, clear = paused)
self._roots: tuple[RootType, ...] = ()
self._phase: ScanPhase = ScanPhase.FULL
self._compute_hashes: bool = False
@ -154,10 +142,10 @@ class AssetSeeder:
self._compute_hashes = compute_hashes
self._progress_callback = progress_callback
self._cancel_event.clear()
self._pause_event.set() # Ensure unpaused when starting
self._run_gate.set() # Ensure unpaused when starting
self._thread = threading.Thread(
target=self._run_scan,
name="AssetSeeder",
name="_AssetSeeder",
daemon=True,
)
self._thread.start()
@ -223,7 +211,7 @@ class AssetSeeder:
logging.info("Asset seeder cancelling (was %s)", self._state.value)
self._state = State.CANCELLING
self._cancel_event.set()
self._pause_event.set() # Unblock if paused so thread can exit
self._run_gate.set() # Unblock if paused so thread can exit
return True
def stop(self) -> bool:
@ -247,7 +235,7 @@ class AssetSeeder:
return False
logging.info("Asset seeder pausing")
self._state = State.PAUSED
self._pause_event.clear()
self._run_gate.clear()
return True
def resume(self) -> bool:
@ -263,7 +251,7 @@ class AssetSeeder:
return False
logging.info("Asset seeder resuming")
self._state = State.RUNNING
self._pause_event.set()
self._run_gate.set()
self._emit_event("assets.seed.resumed", {})
return True
@ -356,10 +344,10 @@ class AssetSeeder:
self._thread = None
def mark_missing_outside_prefixes(self) -> int:
"""Mark cache states as missing when outside all known root prefixes.
"""Mark references as missing when outside all known root prefixes.
This is a non-destructive soft-delete operation. Assets and their
metadata are preserved, but cache states are flagged as missing.
metadata are preserved, but references are flagged as missing.
They can be restored if the file reappears in a future scan.
This operation is decoupled from scanning to prevent partial scans
@ -369,7 +357,7 @@ class AssetSeeder:
a full scan of all roots or during maintenance.
Returns:
Number of cache states marked as missing
Number of references marked as missing
Raises:
ScanInProgressError: If a scan is currently running
@ -389,7 +377,7 @@ class AssetSeeder:
all_prefixes = get_all_known_prefixes()
marked = mark_missing_outside_prefixes_safely(all_prefixes)
if marked > 0:
logging.info("Marked %d cache states as missing", marked)
logging.info("Marked %d references as missing", marked)
return marked
finally:
with self._lock:
@ -409,9 +397,9 @@ class AssetSeeder:
Returns:
True if scan should stop, False to continue
"""
if not self._pause_event.is_set():
if not self._run_gate.is_set():
self._emit_event("assets.seed.paused", {})
self._pause_event.wait() # Blocks if paused
self._run_gate.wait() # Blocks if paused
return self._is_cancelled()
def _emit_event(self, event_type: str, data: dict) -> None:
@ -539,7 +527,11 @@ class AssetSeeder:
cancelled = True
return
total_enriched = self._run_enrich_phase(roots)
enrich_cancelled, total_enriched = self._run_enrich_phase(roots)
if enrich_cancelled:
cancelled = True
return
self._emit_event(
"assets.seed.enrich_complete",
@ -613,7 +605,9 @@ class AssetSeeder:
)
# Use stub specs (no metadata extraction, no hashing)
specs, tag_pool, skipped_existing = build_stub_specs(paths, existing_paths)
specs, tag_pool, skipped_existing = build_asset_specs(
paths, existing_paths, enable_metadata_extraction=False, compute_hashes=False,
)
self._update_progress(skipped=skipped_existing)
if self._check_pause_and_cancel():
@ -661,11 +655,11 @@ class AssetSeeder:
self._update_progress(scanned=len(specs), created=total_created)
return total_created, skipped_existing, total_paths
def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> int:
def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]:
"""Run phase 2: enrich existing records with metadata and hashes.
Returns:
Total number of assets enriched
Tuple of (cancelled, total_enriched)
"""
total_enriched = 0
batch_size = 100
@ -690,7 +684,7 @@ class AssetSeeder:
while True:
if self._check_pause_and_cancel():
logging.info("Enrich scan cancelled after %d assets", total_enriched)
break
return True, total_enriched
# Fetch next batch of unenriched assets
unenriched = get_unenriched_assets_for_roots(
@ -737,7 +731,7 @@ class AssetSeeder:
)
last_progress_time = now
return total_enriched
return False, total_enriched
asset_seeder = AssetSeeder()
asset_seeder = _AssetSeeder()

View File

@ -1,4 +1,3 @@
import asyncio
import os
from typing import IO
@ -18,20 +17,6 @@ def compute_blake3_hash(
return _hash_file_obj(f, chunk_size)
async def compute_blake3_hash_async(
    fp: str | IO[bytes],
    chunk_size: int = DEFAULT_CHUNK,
) -> str:
    """Hash a path or binary file object without blocking the event loop.

    The blocking work is delegated to a worker thread via
    ``asyncio.to_thread`` in both branches.
    """
    if not hasattr(fp, "read"):
        # Path-like input: open and hash inside the worker thread so file
        # I/O never runs on the event loop.
        def _open_and_hash() -> str:
            with open(os.fspath(fp), "rb") as handle:
                return _hash_file_obj(handle, chunk_size)

        return await asyncio.to_thread(_open_and_hash)
    # Already a readable object: hash it directly in a worker thread.
    return await asyncio.to_thread(compute_blake3_hash, fp, chunk_size)
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK

View File

@ -2,17 +2,16 @@ import contextlib
import logging
import mimetypes
import os
from typing import Sequence
from typing import Any, Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
import app.assets.services.hashing as hashing
from app.assets.database.models import Asset, AssetReference, Tag
from app.assets.database.queries import (
add_tags_to_reference,
fetch_reference_and_asset,
get_asset_by_hash,
get_existing_asset_ids,
get_reference_by_file_path,
get_reference_tags,
get_or_create_reference,
@ -21,11 +20,13 @@ from app.assets.database.queries import (
set_reference_tags,
upsert_asset,
upsert_reference,
validate_tags_exist,
)
from app.assets.helpers import normalize_tags
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.path_utils import (
compute_filename_for_reference,
compute_relative_filename,
resolve_destination_from_tags,
validate_path_within_base,
)
@ -55,6 +56,7 @@ def _ingest_file_from_path(
require_existing_tags: bool = False,
) -> IngestResult:
locator = os.path.abspath(abs_path)
user_metadata = user_metadata or {}
asset_created = False
asset_updated = False
@ -64,7 +66,7 @@ def _ingest_file_from_path(
with create_session() as session:
if preview_id:
if not session.get(Asset, preview_id):
if preview_id not in get_existing_asset_ids(session, [preview_id]):
preview_id = None
asset, asset_created, asset_updated = upsert_asset(
@ -94,7 +96,7 @@ def _ingest_file_from_path(
norm = normalize_tags(list(tags))
if norm:
if require_existing_tags:
_validate_tags_exist(session, norm)
validate_tags_exist(session, norm)
add_tags_to_reference(
session,
reference_id=reference_id,
@ -106,7 +108,8 @@ def _ingest_file_from_path(
_update_metadata_with_filename(
session,
reference_id=reference_id,
ref=ref,
file_path=ref.file_path,
current_metadata=ref.user_metadata,
user_metadata=user_metadata,
)
@ -134,6 +137,8 @@ def _register_existing_asset(
tag_origin: str = "manual",
owner_id: str = "",
) -> RegisterAssetResult:
user_metadata = user_metadata or {}
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
@ -157,7 +162,7 @@ def _register_existing_asset(
session.commit()
return result
new_meta = dict(user_metadata or {})
new_meta = dict(user_metadata)
computed_filename = compute_filename_for_reference(session, ref)
if computed_filename:
new_meta["filename"] = computed_filename
@ -190,29 +195,20 @@ def _register_existing_asset(
return result
def _validate_tags_exist(session: Session, tags: list[str]) -> None:
existing_tag_names = set(
name
for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
)
missing = [t for t in tags if t not in existing_tag_names]
if missing:
raise ValueError(f"Unknown tags: {missing}")
def _update_metadata_with_filename(
session: Session,
reference_id: str,
ref: AssetReference,
user_metadata: UserMetadata,
file_path: str | None,
current_metadata: dict | None,
user_metadata: dict[str, Any],
) -> None:
computed_filename = compute_filename_for_reference(session, ref)
computed_filename = compute_relative_filename(file_path) if file_path else None
current_meta = ref.user_metadata or {}
current_meta = current_metadata or {}
new_meta = dict(current_meta)
if user_metadata:
for k, v in user_metadata.items():
new_meta[k] = v
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename

View File

@ -51,8 +51,9 @@ def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
raw_subdirs = tags[1:]
else:
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
_sep_chars = frozenset(("/", "\\", os.sep))
for i in raw_subdirs:
if i in (".", ".."):
if i in (".", "..") or _sep_chars & set(i):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
@ -113,6 +114,8 @@ def get_asset_category_and_relative_path(
return Path(child).is_relative_to(parent)
def _compute_relative(child: str, parent: str) -> str:
# Normalize relative path, stripping any leading ".." components
# by anchoring to root (os.sep) then computing relpath back from it.
return os.path.relpath(
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
)

View File

@ -259,10 +259,10 @@ def prompt_worker(q, server_instance):
extra_data[k] = sensitive[k]
asset_seeder.pause()
e.execute(item[2], prompt_id, extra_data, item[4])
asset_seeder.resume()
try:
e.execute(item[2], prompt_id, extra_data, item[4])
finally:
asset_seeder.resume()
need_gc = True
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]

View File

@ -34,7 +34,7 @@ from comfyui_version import __version__
from app.frontend_management import FrontendManager, parse_version
from comfy_api.internal import _ComfyNodeInternal
from app.assets.seeder import asset_seeder
from app.assets.api.routes import register_assets_system
from app.assets.api.routes import register_assets_routes
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
@ -240,7 +240,10 @@ class PromptServer():
)
logging.info(f"[Prompt Server] web root: {self.web_root}")
if args.enable_assets:
register_assets_system(self.app, self.user_manager)
register_assets_routes(self.app, self.user_manager)
else:
register_assets_routes(self.app)
asset_seeder.disable()
routes = web.RouteTableDef()
self.routes = routes
self.last_node_id = None

View File

@ -0,0 +1,350 @@
"""Tests for sync_references_with_filesystem in scanner.py."""
import os
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from app.assets.database.models import (
Asset,
AssetReference,
AssetReferenceTag,
Base,
Tag,
)
from app.assets.scanner import sync_references_with_filesystem
from app.assets.services.file_utils import get_mtime_ns
@pytest.fixture
def db_engine():
    # Fresh in-memory SQLite engine per test; schema is created directly
    # from the ORM metadata so no migrations are needed.
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    return engine
@pytest.fixture
def session(db_engine):
    # Context-managed Session: closed automatically when the test finishes.
    with Session(db_engine) as sess:
        yield sess
@pytest.fixture
def temp_dir():
    # Temporary directory removed (with contents) after each test.
    with tempfile.TemporaryDirectory() as tmpdir:
        yield Path(tmpdir)
def _create_file(temp_dir: Path, name: str, content: bytes = b"\x00" * 100) -> str:
"""Create a file and return its absolute path (no symlink resolution)."""
p = temp_dir / name
p.parent.mkdir(parents=True, exist_ok=True)
p.write_bytes(content)
return os.path.abspath(str(p))
def _stat_mtime_ns(path: str) -> int:
    """Return the stored-format mtime (ns) for *path*, following symlinks."""
    stat_result = os.stat(path, follow_symlinks=True)
    return get_mtime_ns(stat_result)
def _make_asset(
    session: Session,
    asset_id: str,
    file_path: str,
    ref_id: str,
    *,
    asset_hash: str | None = None,
    size_bytes: int = 100,
    mtime_ns: int | None = None,
    needs_verify: bool = False,
    is_missing: bool = False,
) -> tuple[Asset, AssetReference]:
    """Insert an Asset + AssetReference and flush.

    An existing Asset row with the same id is reused, so several
    references can be attached to one asset across multiple calls.
    """
    existing = session.get(Asset, asset_id)
    if existing is None:
        existing = Asset(id=asset_id, hash=asset_hash, size_bytes=size_bytes)
        session.add(existing)
        session.flush()
    reference = AssetReference(
        id=ref_id,
        asset_id=asset_id,
        name=f"test-{ref_id}",
        owner_id="system",
        file_path=file_path,
        mtime_ns=mtime_ns,
        needs_verify=needs_verify,
        is_missing=is_missing,
    )
    session.add(reference)
    session.flush()
    return existing, reference
def _ensure_missing_tag(session: Session):
    """Ensure the 'missing' tag exists."""
    if session.get(Tag, "missing") is None:
        session.add(Tag(name="missing", tag_type="system"))
        session.flush()
class _VerifyCase:
def __init__(self, id, stat_unchanged, needs_verify_before, expect_needs_verify):
self.id = id
self.stat_unchanged = stat_unchanged
self.needs_verify_before = needs_verify_before
self.expect_needs_verify = expect_needs_verify
# Truth table for needs_verify reconciliation: after a sync pass the flag
# should be set exactly when the on-disk stat no longer matches the stored
# mtime/size, regardless of the flag's prior value.
VERIFY_CASES = [
    _VerifyCase(
        id="unchanged_clears_verify",
        stat_unchanged=True,
        needs_verify_before=True,
        expect_needs_verify=False,
    ),
    _VerifyCase(
        id="unchanged_keeps_clear",
        stat_unchanged=True,
        needs_verify_before=False,
        expect_needs_verify=False,
    ),
    _VerifyCase(
        id="changed_sets_verify",
        stat_unchanged=False,
        needs_verify_before=False,
        expect_needs_verify=True,
    ),
    _VerifyCase(
        id="changed_keeps_verify",
        stat_unchanged=False,
        needs_verify_before=True,
        expect_needs_verify=True,
    ),
]
@pytest.mark.parametrize("case", VERIFY_CASES, ids=lambda c: c.id)
def test_needs_verify_toggling(session, temp_dir, case):
    """needs_verify is set/cleared based on mtime+size match."""
    fp = _create_file(temp_dir, "model.bin")
    real_mtime = _stat_mtime_ns(fp)
    # Store a deliberately wrong mtime when the case simulates a changed file.
    mtime_for_db = real_mtime if case.stat_unchanged else real_mtime + 1
    _make_asset(
        session, "a1", fp, "r1",
        asset_hash="blake3:abc",
        mtime_ns=mtime_for_db,
        needs_verify=case.needs_verify_before,
    )
    session.commit()
    # Point the scanner's root prefixes at the temp dir so it sees our ref.
    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        sync_references_with_filesystem(session, "models")
    session.commit()
    # expire_all so the re-fetch reflects the sync's UPDATEs, not stale ORM state.
    session.expire_all()
    ref = session.get(AssetReference, "r1")
    assert ref.needs_verify is case.expect_needs_verify
class _MissingCase:
def __init__(self, id, file_exists, expect_is_missing):
self.id = id
self.file_exists = file_exists
self.expect_is_missing = expect_is_missing
# After a sync pass, is_missing should mirror on-disk existence.
MISSING_CASES = [
    _MissingCase(id="existing_file_not_missing", file_exists=True, expect_is_missing=False),
    _MissingCase(id="missing_file_marked_missing", file_exists=False, expect_is_missing=True),
]
@pytest.mark.parametrize("case", MISSING_CASES, ids=lambda c: c.id)
def test_is_missing_flag(session, temp_dir, case):
    """is_missing reflects whether the file exists on disk."""
    if case.file_exists:
        fp = _create_file(temp_dir, "model.bin")
        mtime = _stat_mtime_ns(fp)
    else:
        # Path inside the scanned root that was never created.
        fp = str(temp_dir / "gone.bin")
        mtime = 999
    _make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime)
    session.commit()
    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        sync_references_with_filesystem(session, "models")
    session.commit()
    session.expire_all()
    ref = session.get(AssetReference, "r1")
    assert ref.is_missing is case.expect_is_missing
def test_seed_asset_all_missing_deletes_asset(session, temp_dir):
    """Seed asset with all refs missing gets deleted entirely."""
    gone_path = str(temp_dir / "gone.bin")
    # hash=None marks this as a seed (unhashed) asset.
    _make_asset(session, "seed1", gone_path, "r1", asset_hash=None, mtime_ns=999)
    session.commit()
    prefixes = [str(temp_dir)]
    with patch("app.assets.scanner.get_prefixes_for_root", return_value=prefixes):
        sync_references_with_filesystem(session, "models")
    session.commit()
    # Unhashed assets with no surviving file are removed outright.
    assert session.get(Asset, "seed1") is None
    assert session.get(AssetReference, "r1") is None
def test_seed_asset_some_exist_returns_survivors(session, temp_dir):
    """Seed asset with at least one existing ref survives and is returned."""
    fp = _create_file(temp_dir, "model.bin")
    mtime = _stat_mtime_ns(fp)
    # hash=None marks this as a seed (unhashed) asset.
    _make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=mtime)
    session.commit()
    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        survivors = sync_references_with_filesystem(
            session, "models", collect_existing_paths=True,
        )
    session.commit()
    assert session.get(Asset, "seed1") is not None
    # collect_existing_paths=True makes the call return absolute surviving paths.
    assert os.path.abspath(fp) in survivors
def test_hashed_asset_prunes_missing_refs_when_one_is_ok(session, temp_dir):
    """Hashed asset with one stat-unchanged ref deletes missing refs."""
    fp_ok = _create_file(temp_dir, "good.bin")
    fp_gone = str(temp_dir / "gone.bin")
    mtime = _stat_mtime_ns(fp_ok)
    _make_asset(session, "h1", fp_ok, "r_ok", asset_hash="blake3:aaa", mtime_ns=mtime)
    # Second ref on same asset, file missing
    ref_gone = AssetReference(
        id="r_gone", asset_id="h1", name="gone",
        owner_id="system", file_path=fp_gone, mtime_ns=999,
    )
    session.add(ref_gone)
    session.commit()
    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        sync_references_with_filesystem(session, "models")
    session.commit()
    session.expire_all()
    # The surviving stat-unchanged ref is kept; the stale missing one is pruned.
    assert session.get(AssetReference, "r_ok") is not None
    assert session.get(AssetReference, "r_gone") is None
def test_hashed_asset_all_missing_keeps_refs(session, temp_dir):
    """Hashed asset with all refs missing keeps refs (no pruning)."""
    fp = str(temp_dir / "gone.bin")
    _make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999)
    session.commit()
    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        sync_references_with_filesystem(session, "models")
    session.commit()
    session.expire_all()
    # Fetch the ref once (the original queried it twice): assert it survived,
    # then that it was flagged missing rather than deleted.
    ref = session.get(AssetReference, "r1")
    assert ref is not None
    assert ref.is_missing is True
def test_missing_tag_added_when_all_refs_gone(session, temp_dir):
    """All refs gone on a hashed asset => the 'missing' tag gets attached."""
    _ensure_missing_tag(session)
    dangling = str(temp_dir / "gone.bin")
    _make_asset(session, "h1", dangling, "r1", asset_hash="blake3:aaa", mtime_ns=999)
    session.commit()
    prefixes = [str(temp_dir)]
    with patch("app.assets.scanner.get_prefixes_for_root", return_value=prefixes):
        sync_references_with_filesystem(session, "models", update_missing_tags=True)
    session.commit()
    session.expire_all()
    # The "missing" tag link must now exist for r1.
    assert session.get(AssetReferenceTag, ("r1", "missing")) is not None
def test_missing_tag_removed_when_ref_ok(session, temp_dir):
    """A stat-unchanged ref on a hashed asset clears a stale 'missing' tag."""
    _ensure_missing_tag(session)
    on_disk = _create_file(temp_dir, "model.bin")
    _make_asset(
        session, "h1", on_disk, "r1",
        asset_hash="blake3:aaa", mtime_ns=_stat_mtime_ns(on_disk),
    )
    # Seed a stale automatic "missing" tag that the sync should remove.
    stale_tag = AssetReferenceTag(
        asset_reference_id="r1", tag_name="missing", origin="automatic",
    )
    session.add(stale_tag)
    session.commit()
    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        sync_references_with_filesystem(session, "models", update_missing_tags=True)
    session.commit()
    session.expire_all()
    assert session.get(AssetReferenceTag, ("r1", "missing")) is None
def test_missing_tags_not_touched_when_flag_false(session, temp_dir):
    """With update_missing_tags=False the sync never adds a 'missing' tag."""
    _ensure_missing_tag(session)
    dangling = str(temp_dir / "gone.bin")
    _make_asset(session, "h1", dangling, "r1", asset_hash="blake3:aaa", mtime_ns=999)
    session.commit()
    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        sync_references_with_filesystem(session, "models", update_missing_tags=False)
    session.commit()
    # No tag link should ever have been created.
    assert session.get(AssetReferenceTag, ("r1", "missing")) is None
def test_returns_none_when_collect_false(session, temp_dir):
    """collect_existing_paths=False makes the sync return None."""
    tracked = _create_file(temp_dir, "model.bin")
    _make_asset(
        session, "a1", tracked, "r1",
        asset_hash="blake3:abc", mtime_ns=_stat_mtime_ns(tracked),
    )
    session.commit()
    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
        outcome = sync_references_with_filesystem(
            session, "models", collect_existing_paths=False,
        )
    assert outcome is None
def test_returns_empty_set_for_no_prefixes(session):
    """No configured prefixes yields an empty survivor set."""
    with patch("app.assets.scanner.get_prefixes_for_root", return_value=[]):
        outcome = sync_references_with_filesystem(
            session, "models", collect_existing_paths=True,
        )
    assert outcome == set()
def test_no_references_is_noop(session, temp_dir):
    """No crash and no side effects when there are no references."""
    prefixes = [str(temp_dir)]
    with patch("app.assets.scanner.get_prefixes_for_root", return_value=prefixes):
        collected = sync_references_with_filesystem(
            session, "models", collect_existing_paths=True,
        )
    session.commit()
    assert collected == set()

View File

@ -1,19 +1,18 @@
"""Unit tests for the AssetSeeder background scanning class."""
"""Unit tests for the _AssetSeeder background scanning class."""
import threading
from unittest.mock import patch
import pytest
from app.assets.seeder import AssetSeeder, Progress, ScanInProgressError, ScanPhase, State
from app.assets.database.queries.asset_reference import UnenrichedReferenceRow
from app.assets.seeder import _AssetSeeder, Progress, ScanInProgressError, ScanPhase, State
@pytest.fixture
def fresh_seeder():
"""Create a fresh AssetSeeder instance for testing (bypasses singleton)."""
seeder = object.__new__(AssetSeeder)
seeder._initialized = False
seeder.__init__()
"""Create a fresh _AssetSeeder instance for testing."""
seeder = _AssetSeeder()
yield seeder
seeder.shutdown(timeout=1.0)
@ -25,7 +24,7 @@ def mock_dependencies():
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@ -36,11 +35,11 @@ def mock_dependencies():
class TestSeederStateTransitions:
"""Test state machine transitions."""
def test_initial_state_is_idle(self, fresh_seeder: AssetSeeder):
def test_initial_state_is_idle(self, fresh_seeder: _AssetSeeder):
assert fresh_seeder.get_status().state == State.IDLE
def test_start_transitions_to_running(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@ -61,7 +60,7 @@ class TestSeederStateTransitions:
barrier.set()
def test_start_while_running_returns_false(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@ -83,7 +82,7 @@ class TestSeederStateTransitions:
barrier.set()
def test_cancel_transitions_to_cancelling(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@ -105,12 +104,12 @@ class TestSeederStateTransitions:
barrier.set()
def test_cancel_when_idle_returns_false(self, fresh_seeder: AssetSeeder):
def test_cancel_when_idle_returns_false(self, fresh_seeder: _AssetSeeder):
cancelled = fresh_seeder.cancel()
assert cancelled is False
def test_state_returns_to_idle_after_completion(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
fresh_seeder.start(roots=("models",))
completed = fresh_seeder.wait(timeout=5.0)
@ -122,7 +121,7 @@ class TestSeederWait:
"""Test wait() behavior."""
def test_wait_blocks_until_complete(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
fresh_seeder.start(roots=("models",))
completed = fresh_seeder.wait(timeout=5.0)
@ -130,7 +129,7 @@ class TestSeederWait:
assert fresh_seeder.get_status().state == State.IDLE
def test_wait_returns_false_on_timeout(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
@ -147,7 +146,7 @@ class TestSeederWait:
barrier.set()
def test_wait_when_idle_returns_true(self, fresh_seeder: AssetSeeder):
def test_wait_when_idle_returns_true(self, fresh_seeder: _AssetSeeder):
completed = fresh_seeder.wait(timeout=1.0)
assert completed is True
@ -156,7 +155,7 @@ class TestSeederProgress:
"""Test progress tracking."""
def test_get_status_returns_progress_during_scan(
self, fresh_seeder: AssetSeeder
self, fresh_seeder: _AssetSeeder
):
barrier = threading.Event()
reached = threading.Event()
@ -172,7 +171,7 @@ class TestSeederProgress:
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
patch("app.assets.seeder.build_stub_specs", side_effect=slow_build),
patch("app.assets.seeder.build_asset_specs", side_effect=slow_build),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@ -188,7 +187,7 @@ class TestSeederProgress:
barrier.set()
def test_progress_callback_is_invoked(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
progress_updates: list[Progress] = []
@ -209,7 +208,7 @@ class TestSeederCancellation:
"""Test cancellation behavior."""
def test_scan_commits_partial_progress_on_cancellation(
self, fresh_seeder: AssetSeeder
self, fresh_seeder: _AssetSeeder
):
insert_count = 0
barrier = threading.Event()
@ -245,7 +244,7 @@ class TestSeederCancellation:
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
patch(
"app.assets.seeder.build_stub_specs", return_value=(specs, set(), 0)
"app.assets.seeder.build_asset_specs", return_value=(specs, set(), 0)
),
patch("app.assets.seeder.insert_asset_specs", side_effect=slow_insert),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
@ -264,7 +263,7 @@ class TestSeederCancellation:
class TestSeederErrorHandling:
"""Test error handling behavior."""
def test_database_errors_captured_in_status(self, fresh_seeder: AssetSeeder):
def test_database_errors_captured_in_status(self, fresh_seeder: _AssetSeeder):
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
@ -273,7 +272,7 @@ class TestSeederErrorHandling:
return_value=["/path/file.safetensors"],
),
patch(
"app.assets.seeder.build_stub_specs",
"app.assets.seeder.build_asset_specs",
return_value=(
[
{
@ -307,7 +306,7 @@ class TestSeederErrorHandling:
assert "DB connection failed" in status.errors[0]
def test_dependencies_unavailable_captured_in_errors(
self, fresh_seeder: AssetSeeder
self, fresh_seeder: _AssetSeeder
):
with patch("app.assets.seeder.dependencies_available", return_value=False):
fresh_seeder.start(roots=("models",))
@ -317,7 +316,7 @@ class TestSeederErrorHandling:
assert len(status.errors) > 0
assert "dependencies" in status.errors[0].lower()
def test_thread_crash_resets_state_to_idle(self, fresh_seeder: AssetSeeder):
def test_thread_crash_resets_state_to_idle(self, fresh_seeder: _AssetSeeder):
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch(
@ -337,7 +336,7 @@ class TestSeederThreadSafety:
"""Test thread safety of concurrent operations."""
def test_concurrent_start_calls_spawn_only_one_thread(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
@ -364,7 +363,7 @@ class TestSeederThreadSafety:
assert sum(results) == 1
def test_get_status_safe_during_scan(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@ -395,7 +394,7 @@ class TestSeederThreadSafety:
class TestSeederMarkMissing:
"""Test mark_missing_outside_prefixes behavior."""
def test_mark_missing_when_idle(self, fresh_seeder: AssetSeeder):
def test_mark_missing_when_idle(self, fresh_seeder: _AssetSeeder):
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch(
@ -411,7 +410,7 @@ class TestSeederMarkMissing:
mock_mark.assert_called_once_with(["/models", "/input", "/output"])
def test_mark_missing_raises_when_running(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@ -433,14 +432,14 @@ class TestSeederMarkMissing:
barrier.set()
def test_mark_missing_returns_zero_when_dependencies_unavailable(
self, fresh_seeder: AssetSeeder
self, fresh_seeder: _AssetSeeder
):
with patch("app.assets.seeder.dependencies_available", return_value=False):
result = fresh_seeder.mark_missing_outside_prefixes()
assert result == 0
def test_prune_first_flag_triggers_mark_missing_before_scan(
self, fresh_seeder: AssetSeeder
self, fresh_seeder: _AssetSeeder
):
call_order = []
@ -458,7 +457,7 @@ class TestSeederMarkMissing:
patch("app.assets.seeder.mark_missing_outside_prefixes_safely", side_effect=track_mark),
patch("app.assets.seeder.sync_root_safely", side_effect=track_sync),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@ -473,7 +472,7 @@ class TestSeederMarkMissing:
class TestSeederPhases:
"""Test phased scanning behavior."""
def test_start_fast_only_runs_fast_phase(self, fresh_seeder: AssetSeeder):
def test_start_fast_only_runs_fast_phase(self, fresh_seeder: _AssetSeeder):
"""Verify start_fast only runs the fast phase."""
fast_called = []
enrich_called = []
@ -490,7 +489,7 @@ class TestSeederPhases:
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_stub_specs", side_effect=track_fast),
patch("app.assets.seeder.build_asset_specs", side_effect=track_fast),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@ -501,7 +500,7 @@ class TestSeederPhases:
assert len(fast_called) == 1
assert len(enrich_called) == 0
def test_start_enrich_only_runs_enrich_phase(self, fresh_seeder: AssetSeeder):
def test_start_enrich_only_runs_enrich_phase(self, fresh_seeder: _AssetSeeder):
"""Verify start_enrich only runs the enrich phase."""
fast_called = []
enrich_called = []
@ -518,7 +517,7 @@ class TestSeederPhases:
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_stub_specs", side_effect=track_fast),
patch("app.assets.seeder.build_asset_specs", side_effect=track_fast),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@ -529,7 +528,7 @@ class TestSeederPhases:
assert len(fast_called) == 0
assert len(enrich_called) == 1
def test_full_scan_runs_both_phases(self, fresh_seeder: AssetSeeder):
def test_full_scan_runs_both_phases(self, fresh_seeder: _AssetSeeder):
"""Verify full scan runs both fast and enrich phases."""
fast_called = []
enrich_called = []
@ -546,7 +545,7 @@ class TestSeederPhases:
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder.build_stub_specs", side_effect=track_fast),
patch("app.assets.seeder.build_asset_specs", side_effect=track_fast),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@ -562,7 +561,7 @@ class TestSeederPauseResume:
"""Test pause/resume behavior."""
def test_pause_transitions_to_paused(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@ -584,12 +583,12 @@ class TestSeederPauseResume:
barrier.set()
def test_pause_when_idle_returns_false(self, fresh_seeder: AssetSeeder):
def test_pause_when_idle_returns_false(self, fresh_seeder: _AssetSeeder):
paused = fresh_seeder.pause()
assert paused is False
def test_resume_returns_to_running(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@ -615,7 +614,7 @@ class TestSeederPauseResume:
barrier.set()
def test_resume_when_not_paused_returns_false(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@ -637,7 +636,7 @@ class TestSeederPauseResume:
barrier.set()
def test_cancel_while_paused_works(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached_checkpoint = threading.Event()
@ -667,7 +666,7 @@ class TestSeederStopRestart:
"""Test stop and restart behavior."""
def test_stop_is_alias_for_cancel(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@ -690,7 +689,7 @@ class TestSeederStopRestart:
barrier.set()
def test_restart_cancels_and_starts_new_scan(
self, fresh_seeder: AssetSeeder, mock_dependencies
self, fresh_seeder: _AssetSeeder, mock_dependencies
):
barrier = threading.Event()
reached = threading.Event()
@ -717,7 +716,7 @@ class TestSeederStopRestart:
fresh_seeder.wait(timeout=5.0)
assert start_count == 2
def test_restart_preserves_previous_params(self, fresh_seeder: AssetSeeder):
def test_restart_preserves_previous_params(self, fresh_seeder: _AssetSeeder):
"""Verify restart uses previous params when not overridden."""
collected_roots = []
@ -729,7 +728,7 @@ class TestSeederStopRestart:
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect),
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@ -744,7 +743,7 @@ class TestSeederStopRestart:
assert collected_roots[0] == ("input", "output")
assert collected_roots[1] == ("input", "output")
def test_restart_can_override_params(self, fresh_seeder: AssetSeeder):
def test_restart_can_override_params(self, fresh_seeder: _AssetSeeder):
"""Verify restart can override previous params."""
collected_roots = []
@ -756,7 +755,7 @@ class TestSeederStopRestart:
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect),
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
@ -770,3 +769,132 @@ class TestSeederStopRestart:
assert len(collected_roots) == 2
assert collected_roots[0] == ("models",)
assert collected_roots[1] == ("input",)
def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
    """Build a minimal unenriched-reference row for the enrich-phase tests."""
    return UnenrichedReferenceRow(
        reference_id=ref_id,
        asset_id=asset_id,
        file_path=f"/fake/{ref_id}.bin",
        enrichment_level=0,
    )
class TestEnrichPhaseDefensiveLogic:
    """Test skip_ids filtering and consecutive_empty termination.

    These tests pin two defensive behaviors of the enrich phase:
    references whose enrichment fails are excluded (via skip_ids) from
    later batches, and the phase stops instead of looping forever when
    batches repeatedly make no progress.
    """
    def test_failed_refs_are_skipped_on_subsequent_batches(
        self, fresh_seeder: _AssetSeeder,
    ):
        """References that fail enrichment are filtered out of future batches."""
        row_a = _make_row("r1")
        row_b = _make_row("r2")
        call_count = 0
        def fake_get_unenriched(*args, **kwargs):
            # Serve both rows for the first two batches, then nothing.
            nonlocal call_count
            call_count += 1
            if call_count <= 2:
                return [row_a, row_b]
            return []
        # Records which reference ids each batch attempted to enrich.
        enriched_refs: list[list[str]] = []
        def fake_enrich(rows, **kwargs):
            ref_ids = [r.reference_id for r in rows]
            enriched_refs.append(ref_ids)
            # r1 always fails, r2 succeeds
            failed = [r.reference_id for r in rows if r.reference_id == "r1"]
            enriched = len(rows) - len(failed)
            return enriched, failed
        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched),
            patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich),
        ):
            fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH)
            fresh_seeder.wait(timeout=5.0)
        # First batch: both refs attempted
        assert "r1" in enriched_refs[0]
        assert "r2" in enriched_refs[0]
        # Second batch: r1 filtered out
        assert "r1" not in enriched_refs[1]
        assert "r2" in enriched_refs[1]
    def test_stops_after_consecutive_empty_batches(
        self, fresh_seeder: _AssetSeeder,
    ):
        """Enrich phase terminates when skip_ids filtering empties a batch.

        A permanently failing ref is served on every call; after it fails
        once, skip_ids removes it from the next batch, which therefore
        comes back empty and ends the phase.
        """
        row = _make_row("r1")
        batch_count = 0
        def fake_get_unenriched(*args, **kwargs):
            nonlocal batch_count
            batch_count += 1
            # Always return the same row (simulating a permanently failing ref)
            return [row]
        def fake_enrich(rows, **kwargs):
            # Always fail — zero enriched, all failed
            return 0, [r.reference_id for r in rows]
        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched),
            patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich),
        ):
            fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH)
            fresh_seeder.wait(timeout=5.0)
        # Batch 1: row is attempted and fails, so r1 lands in skip_ids.
        # Batch 2: get_unenriched returns [row] again, but skip_ids filters
        # it to an empty list, which terminates the loop.  Hence exactly two
        # calls to get_unenriched — the consecutive-empty counter never has
        # to reach its limit in this scenario.
        assert batch_count == 2
    def test_consecutive_empty_counter_resets_on_success(
        self, fresh_seeder: _AssetSeeder,
    ):
        """A successful batch resets the consecutive empty counter."""
        call_count = 0
        def fake_get_unenriched(*args, **kwargs):
            # Serve one distinct row per batch for six batches, then nothing.
            nonlocal call_count
            call_count += 1
            if call_count <= 6:
                return [_make_row(f"r{call_count}", f"a{call_count}")]
            return []
        def fake_enrich(rows, **kwargs):
            ref_id = rows[0].reference_id
            # Fail batches 1-2, succeed batch 3, fail batches 4-5, succeed batch 6
            if ref_id in ("r1", "r2", "r4", "r5"):
                return 0, [ref_id]
            return 1, []
        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched),
            patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich),
        ):
            fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH)
            fresh_seeder.wait(timeout=5.0)
        # All 6 batches should run + 1 final call returning empty
        assert call_count == 7
        status = fresh_seeder.get_status()
        assert status.state == State.IDLE