mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-05 09:17:37 +08:00
fix: address code review feedback
- Fix missing import for compute_filename_for_reference in ingest.py - Apply code review fixes across routes, queries, scanner, seeder, hashing, ingest, path_utils, main, and server - Update and add tests for sync references and seeder Amp-Thread-ID: https://ampcode.com/threads/T-019cb61a-ed54-738c-a05f-9b5242e513f3 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
3232f48a41
commit
4d4c2cedd3
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@ -39,6 +40,20 @@ from app.assets.services import (
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
USER_MANAGER: user_manager.UserManager | None = None
|
||||
_ASSETS_ENABLED = False
|
||||
|
||||
|
||||
def _require_assets_feature_enabled(handler):
|
||||
@functools.wraps(handler)
|
||||
async def wrapper(request: web.Request) -> web.Response:
|
||||
if not _ASSETS_ENABLED:
|
||||
return _build_error_response(
|
||||
503,
|
||||
"SERVICE_DISABLED",
|
||||
"Assets system is disabled. Start the server with --enable-assets to use this feature.",
|
||||
)
|
||||
return await handler(request)
|
||||
return wrapper
|
||||
|
||||
# UUID regex (canonical hyphenated form, case-insensitive)
|
||||
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
@ -64,11 +79,13 @@ def get_query_dict(request: web.Request) -> dict[str, Any]:
|
||||
# do not rely on the code in /app/assets remaining the same.
|
||||
|
||||
|
||||
def register_assets_system(
|
||||
app: web.Application, user_manager_instance: user_manager.UserManager
|
||||
def register_assets_routes(
|
||||
app: web.Application, user_manager_instance: user_manager.UserManager | None = None,
|
||||
) -> None:
|
||||
global USER_MANAGER
|
||||
USER_MANAGER = user_manager_instance
|
||||
global USER_MANAGER, _ASSETS_ENABLED
|
||||
if user_manager_instance is not None:
|
||||
USER_MANAGER = user_manager_instance
|
||||
_ASSETS_ENABLED = True
|
||||
app.add_routes(ROUTES)
|
||||
|
||||
|
||||
@ -96,6 +113,7 @@ def _validate_sort_field(requested: str | None) -> str:
|
||||
|
||||
|
||||
@ROUTES.head("/api/assets/hash/{hash}")
|
||||
@_require_assets_feature_enabled
|
||||
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||
hash_str = request.match_info.get("hash", "").strip().lower()
|
||||
if not hash_str or ":" not in hash_str:
|
||||
@ -116,6 +134,7 @@ async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets")
|
||||
@_require_assets_feature_enabled
|
||||
async def list_assets_route(request: web.Request) -> web.Response:
|
||||
"""
|
||||
GET request to list assets.
|
||||
@ -166,6 +185,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
@_require_assets_feature_enabled
|
||||
async def get_asset_route(request: web.Request) -> web.Response:
|
||||
"""
|
||||
GET request to get an asset's info as JSON.
|
||||
@ -211,6 +231,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
|
||||
|
||||
|
||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
||||
@_require_assets_feature_enabled
|
||||
async def download_asset_content(request: web.Request) -> web.Response:
|
||||
disposition = request.query.get("disposition", "attachment").lower().strip()
|
||||
if disposition not in {"inline", "attachment"}:
|
||||
@ -264,6 +285,7 @@ async def download_asset_content(request: web.Request) -> web.Response:
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/from-hash")
|
||||
@_require_assets_feature_enabled
|
||||
async def create_asset_from_hash_route(request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
@ -304,6 +326,7 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets")
|
||||
@_require_assets_feature_enabled
|
||||
async def upload_asset(request: web.Request) -> web.Response:
|
||||
"""Multipart/form-data endpoint for Asset uploads."""
|
||||
try:
|
||||
@ -408,6 +431,7 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
@_require_assets_feature_enabled
|
||||
async def update_asset_route(request: web.Request) -> web.Response:
|
||||
reference_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
@ -453,6 +477,7 @@ async def update_asset_route(request: web.Request) -> web.Response:
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
@_require_assets_feature_enabled
|
||||
async def delete_asset_route(request: web.Request) -> web.Response:
|
||||
reference_id = str(uuid.UUID(request.match_info["id"]))
|
||||
delete_content_param = request.query.get("delete_content")
|
||||
@ -484,6 +509,7 @@ async def delete_asset_route(request: web.Request) -> web.Response:
|
||||
|
||||
|
||||
@ROUTES.get("/api/tags")
|
||||
@_require_assets_feature_enabled
|
||||
async def get_tags(request: web.Request) -> web.Response:
|
||||
"""
|
||||
GET request to list all tags based on query parameters.
|
||||
@ -520,6 +546,7 @@ async def get_tags(request: web.Request) -> web.Response:
|
||||
|
||||
|
||||
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
@_require_assets_feature_enabled
|
||||
async def add_asset_tags(request: web.Request) -> web.Response:
|
||||
reference_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
@ -569,6 +596,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||
@_require_assets_feature_enabled
|
||||
async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||
reference_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
@ -613,6 +641,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/seed")
|
||||
@_require_assets_feature_enabled
|
||||
async def seed_assets(request: web.Request) -> web.Response:
|
||||
"""Trigger asset seeding for specified roots (models, input, output).
|
||||
|
||||
@ -662,6 +691,7 @@ async def seed_assets(request: web.Request) -> web.Response:
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets/seed/status")
|
||||
@_require_assets_feature_enabled
|
||||
async def get_seed_status(request: web.Request) -> web.Response:
|
||||
"""Get current scan status and progress."""
|
||||
status = asset_seeder.get_status()
|
||||
@ -683,6 +713,7 @@ async def get_seed_status(request: web.Request) -> web.Response:
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/seed/cancel")
|
||||
@_require_assets_feature_enabled
|
||||
async def cancel_seed(request: web.Request) -> web.Response:
|
||||
"""Request cancellation of in-progress scan."""
|
||||
cancelled = asset_seeder.cancel()
|
||||
@ -692,6 +723,7 @@ async def cancel_seed(request: web.Request) -> web.Response:
|
||||
|
||||
|
||||
@ROUTES.post("/api/assets/prune")
|
||||
@_require_assets_feature_enabled
|
||||
async def mark_missing_assets(request: web.Request) -> web.Response:
|
||||
"""Mark assets as missing when outside all known root prefixes.
|
||||
|
||||
|
||||
@ -57,6 +57,7 @@ from app.assets.database.queries.tags import (
|
||||
remove_missing_tag_for_asset_id,
|
||||
remove_tags_from_reference,
|
||||
set_reference_tags,
|
||||
validate_tags_exist,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -114,4 +115,5 @@ __all__ = [
|
||||
"update_reference_updated_at",
|
||||
"upsert_asset",
|
||||
"upsert_reference",
|
||||
"validate_tags_exist",
|
||||
]
|
||||
|
||||
@ -660,13 +660,16 @@ def restore_references_by_paths(session: Session, file_paths: list[str]) -> int:
|
||||
if not file_paths:
|
||||
return 0
|
||||
|
||||
result = session.execute(
|
||||
sa.update(AssetReference)
|
||||
.where(AssetReference.file_path.in_(file_paths))
|
||||
.where(AssetReference.is_missing == True) # noqa: E712
|
||||
.values(is_missing=False)
|
||||
)
|
||||
return result.rowcount
|
||||
total = 0
|
||||
for chunk in iter_chunks(file_paths, MAX_BIND_PARAMS):
|
||||
result = session.execute(
|
||||
sa.update(AssetReference)
|
||||
.where(AssetReference.file_path.in_(chunk))
|
||||
.where(AssetReference.is_missing == True) # noqa: E712
|
||||
.values(is_missing=False)
|
||||
)
|
||||
total += result.rowcount
|
||||
return total
|
||||
|
||||
|
||||
def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]:
|
||||
@ -697,11 +700,14 @@ def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int:
|
||||
"""
|
||||
if not asset_ids:
|
||||
return 0
|
||||
session.execute(
|
||||
sa.delete(AssetReference).where(AssetReference.asset_id.in_(asset_ids))
|
||||
)
|
||||
result = session.execute(sa.delete(Asset).where(Asset.id.in_(asset_ids)))
|
||||
return result.rowcount
|
||||
total = 0
|
||||
for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS):
|
||||
session.execute(
|
||||
sa.delete(AssetReference).where(AssetReference.asset_id.in_(chunk))
|
||||
)
|
||||
result = session.execute(sa.delete(Asset).where(Asset.id.in_(chunk)))
|
||||
total += result.rowcount
|
||||
return total
|
||||
|
||||
|
||||
def get_references_for_prefixes(
|
||||
|
||||
@ -37,6 +37,17 @@ class SetTagsDict(TypedDict):
|
||||
total: list[str]
|
||||
|
||||
|
||||
def validate_tags_exist(session: Session, tags: list[str]) -> None:
|
||||
"""Raise ValueError if any of the given tag names do not exist."""
|
||||
existing_tag_names = set(
|
||||
name
|
||||
for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
|
||||
)
|
||||
missing = [t for t in tags if t not in existing_tag_names]
|
||||
if missing:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
|
||||
def ensure_tags_exist(
|
||||
session: Session, names: Iterable[str], tag_type: str = "user"
|
||||
) -> None:
|
||||
|
||||
@ -44,9 +44,9 @@ from app.database.db import create_session, dependencies_available
|
||||
|
||||
class _RefInfo(TypedDict):
|
||||
ref_id: str
|
||||
fp: str
|
||||
file_path: str
|
||||
exists: bool
|
||||
fast_ok: bool
|
||||
stat_unchanged: bool
|
||||
needs_verify: bool
|
||||
|
||||
|
||||
@ -75,9 +75,7 @@ def get_prefixes_for_root(root: RootType) -> list[str]:
|
||||
def get_all_known_prefixes() -> list[str]:
|
||||
"""Get all known asset prefixes across all root types."""
|
||||
all_roots: tuple[RootType, ...] = ("models", "input", "output")
|
||||
return [
|
||||
os.path.abspath(p) for root in all_roots for p in get_prefixes_for_root(root)
|
||||
]
|
||||
return [p for root in all_roots for p in get_prefixes_for_root(root)]
|
||||
|
||||
|
||||
def collect_models_files() -> list[str]:
|
||||
@ -110,10 +108,10 @@ def sync_references_with_filesystem(
|
||||
) -> set[str] | None:
|
||||
"""Reconcile asset references with filesystem for a root.
|
||||
|
||||
- Toggle needs_verify per reference using fast mtime/size check
|
||||
- For hashed assets with at least one fast-ok ref: delete stale missing refs
|
||||
- Toggle needs_verify per reference using mtime/size stat check
|
||||
- For hashed assets with at least one stat-unchanged ref: delete stale missing refs
|
||||
- For seed assets with all refs missing: delete Asset and its references
|
||||
- Optionally add/remove 'missing' tags based on fast-ok in this root
|
||||
- Optionally add/remove 'missing' tags based on stat check in this root
|
||||
- Optionally return surviving absolute paths
|
||||
|
||||
Args:
|
||||
@ -140,10 +138,10 @@ def sync_references_with_filesystem(
|
||||
acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []}
|
||||
by_asset[row.asset_id] = acc
|
||||
|
||||
fast_ok = False
|
||||
stat_unchanged = False
|
||||
try:
|
||||
exists = True
|
||||
fast_ok = verify_file_unchanged(
|
||||
stat_unchanged = verify_file_unchanged(
|
||||
mtime_db=row.mtime_ns,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(row.file_path, follow_symlinks=True),
|
||||
@ -160,9 +158,9 @@ def sync_references_with_filesystem(
|
||||
acc["refs"].append(
|
||||
{
|
||||
"ref_id": row.reference_id,
|
||||
"fp": row.file_path,
|
||||
"file_path": row.file_path,
|
||||
"exists": exists,
|
||||
"fast_ok": fast_ok,
|
||||
"stat_unchanged": stat_unchanged,
|
||||
"needs_verify": row.needs_verify,
|
||||
}
|
||||
)
|
||||
@ -177,18 +175,18 @@ def sync_references_with_filesystem(
|
||||
for aid, acc in by_asset.items():
|
||||
a_hash = acc["hash"]
|
||||
refs = acc["refs"]
|
||||
any_fast_ok = any(r["fast_ok"] for r in refs)
|
||||
any_unchanged = any(r["stat_unchanged"] for r in refs)
|
||||
all_missing = all(not r["exists"] for r in refs)
|
||||
|
||||
for r in refs:
|
||||
if not r["exists"]:
|
||||
to_mark_missing.append(r["ref_id"])
|
||||
continue
|
||||
if r["fast_ok"]:
|
||||
if r["stat_unchanged"]:
|
||||
to_clear_missing.append(r["ref_id"])
|
||||
if r["needs_verify"]:
|
||||
to_clear_verify.append(r["ref_id"])
|
||||
if not r["fast_ok"] and not r["needs_verify"]:
|
||||
if not r["stat_unchanged"] and not r["needs_verify"]:
|
||||
to_set_verify.append(r["ref_id"])
|
||||
|
||||
if a_hash is None:
|
||||
@ -197,10 +195,10 @@ def sync_references_with_filesystem(
|
||||
else:
|
||||
for r in refs:
|
||||
if r["exists"]:
|
||||
survivors.add(os.path.abspath(r["fp"]))
|
||||
survivors.add(os.path.abspath(r["file_path"]))
|
||||
continue
|
||||
|
||||
if any_fast_ok:
|
||||
if any_unchanged:
|
||||
for r in refs:
|
||||
if not r["exists"]:
|
||||
stale_ref_ids.append(r["ref_id"])
|
||||
@ -219,7 +217,7 @@ def sync_references_with_filesystem(
|
||||
|
||||
for r in refs:
|
||||
if r["exists"]:
|
||||
survivors.add(os.path.abspath(r["fp"]))
|
||||
survivors.add(os.path.abspath(r["file_path"]))
|
||||
|
||||
delete_references_by_ids(session, stale_ref_ids)
|
||||
stale_set = set(stale_ref_ids)
|
||||
@ -349,58 +347,6 @@ def build_asset_specs(
|
||||
return specs, tag_pool, skipped
|
||||
|
||||
|
||||
def build_stub_specs(
|
||||
paths: list[str],
|
||||
existing_paths: set[str],
|
||||
) -> tuple[list[SeedAssetSpec], set[str], int]:
|
||||
"""Build minimal stub specs for fast phase scanning.
|
||||
|
||||
Only collects filesystem metadata (stat), no file content reading.
|
||||
This is the fastest possible scan to populate the asset database.
|
||||
|
||||
Args:
|
||||
paths: List of file paths to process
|
||||
existing_paths: Set of paths that already exist in the database
|
||||
|
||||
Returns:
|
||||
Tuple of (specs, tag_pool, skipped_count)
|
||||
"""
|
||||
specs: list[SeedAssetSpec] = []
|
||||
tag_pool: set[str] = set()
|
||||
skipped = 0
|
||||
|
||||
for p in paths:
|
||||
abs_p = os.path.abspath(p)
|
||||
if abs_p in existing_paths:
|
||||
skipped += 1
|
||||
continue
|
||||
try:
|
||||
stat_p = os.stat(abs_p, follow_symlinks=True)
|
||||
except OSError:
|
||||
continue
|
||||
if not stat_p.st_size:
|
||||
continue
|
||||
|
||||
name, tags = get_name_and_tags_from_asset_path(abs_p)
|
||||
rel_fname = compute_relative_filename(abs_p)
|
||||
|
||||
specs.append(
|
||||
{
|
||||
"abs_path": abs_p,
|
||||
"size_bytes": stat_p.st_size,
|
||||
"mtime_ns": get_mtime_ns(stat_p),
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": rel_fname,
|
||||
"metadata": None,
|
||||
"hash": None,
|
||||
"mime_type": None,
|
||||
}
|
||||
)
|
||||
tag_pool.update(tags)
|
||||
|
||||
return specs, tag_pool, skipped
|
||||
|
||||
|
||||
def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
|
||||
"""Insert asset specs into database, returning count of created refs."""
|
||||
@ -538,7 +484,8 @@ def enrich_asset(
|
||||
try:
|
||||
digest = compute_blake3_hash(file_path)
|
||||
full_hash = f"blake3:{digest}"
|
||||
if not extract_metadata or metadata:
|
||||
metadata_ok = not extract_metadata or metadata is not None
|
||||
if metadata_ok:
|
||||
new_level = ENRICHMENT_HASHED
|
||||
except Exception as e:
|
||||
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||
|
||||
@ -12,7 +12,7 @@ from app.assets.scanner import (
|
||||
ENRICHMENT_METADATA,
|
||||
ENRICHMENT_STUB,
|
||||
RootType,
|
||||
build_stub_specs,
|
||||
build_asset_specs,
|
||||
collect_paths_for_roots,
|
||||
enrich_assets_batch,
|
||||
get_all_known_prefixes,
|
||||
@ -68,35 +68,23 @@ class ScanStatus:
|
||||
ProgressCallback = Callable[[Progress], None]
|
||||
|
||||
|
||||
class AssetSeeder:
|
||||
"""Singleton class managing background asset scanning.
|
||||
class _AssetSeeder:
|
||||
"""Background asset scanning manager.
|
||||
|
||||
Thread-safe singleton that spawns ephemeral daemon threads for scanning.
|
||||
Spawns ephemeral daemon threads for scanning.
|
||||
Each scan creates a new thread that exits when complete.
|
||||
Use the module-level ``asset_seeder`` instance.
|
||||
"""
|
||||
|
||||
_instance: "AssetSeeder | None" = None
|
||||
_instance_lock = threading.Lock()
|
||||
|
||||
def __new__(cls) -> "AssetSeeder":
|
||||
with cls._instance_lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
self._initialized = True
|
||||
self._lock = threading.Lock()
|
||||
self._state = State.IDLE
|
||||
self._progress: Progress | None = None
|
||||
self._errors: list[str] = []
|
||||
self._thread: threading.Thread | None = None
|
||||
self._cancel_event = threading.Event()
|
||||
self._pause_event = threading.Event()
|
||||
self._pause_event.set() # Start unpaused (set = running, clear = paused)
|
||||
self._run_gate = threading.Event()
|
||||
self._run_gate.set() # Start unpaused (set = running, clear = paused)
|
||||
self._roots: tuple[RootType, ...] = ()
|
||||
self._phase: ScanPhase = ScanPhase.FULL
|
||||
self._compute_hashes: bool = False
|
||||
@ -154,10 +142,10 @@ class AssetSeeder:
|
||||
self._compute_hashes = compute_hashes
|
||||
self._progress_callback = progress_callback
|
||||
self._cancel_event.clear()
|
||||
self._pause_event.set() # Ensure unpaused when starting
|
||||
self._run_gate.set() # Ensure unpaused when starting
|
||||
self._thread = threading.Thread(
|
||||
target=self._run_scan,
|
||||
name="AssetSeeder",
|
||||
name="_AssetSeeder",
|
||||
daemon=True,
|
||||
)
|
||||
self._thread.start()
|
||||
@ -223,7 +211,7 @@ class AssetSeeder:
|
||||
logging.info("Asset seeder cancelling (was %s)", self._state.value)
|
||||
self._state = State.CANCELLING
|
||||
self._cancel_event.set()
|
||||
self._pause_event.set() # Unblock if paused so thread can exit
|
||||
self._run_gate.set() # Unblock if paused so thread can exit
|
||||
return True
|
||||
|
||||
def stop(self) -> bool:
|
||||
@ -247,7 +235,7 @@ class AssetSeeder:
|
||||
return False
|
||||
logging.info("Asset seeder pausing")
|
||||
self._state = State.PAUSED
|
||||
self._pause_event.clear()
|
||||
self._run_gate.clear()
|
||||
return True
|
||||
|
||||
def resume(self) -> bool:
|
||||
@ -263,7 +251,7 @@ class AssetSeeder:
|
||||
return False
|
||||
logging.info("Asset seeder resuming")
|
||||
self._state = State.RUNNING
|
||||
self._pause_event.set()
|
||||
self._run_gate.set()
|
||||
self._emit_event("assets.seed.resumed", {})
|
||||
return True
|
||||
|
||||
@ -356,10 +344,10 @@ class AssetSeeder:
|
||||
self._thread = None
|
||||
|
||||
def mark_missing_outside_prefixes(self) -> int:
|
||||
"""Mark cache states as missing when outside all known root prefixes.
|
||||
"""Mark references as missing when outside all known root prefixes.
|
||||
|
||||
This is a non-destructive soft-delete operation. Assets and their
|
||||
metadata are preserved, but cache states are flagged as missing.
|
||||
metadata are preserved, but references are flagged as missing.
|
||||
They can be restored if the file reappears in a future scan.
|
||||
|
||||
This operation is decoupled from scanning to prevent partial scans
|
||||
@ -369,7 +357,7 @@ class AssetSeeder:
|
||||
a full scan of all roots or during maintenance.
|
||||
|
||||
Returns:
|
||||
Number of cache states marked as missing
|
||||
Number of references marked as missing
|
||||
|
||||
Raises:
|
||||
ScanInProgressError: If a scan is currently running
|
||||
@ -389,7 +377,7 @@ class AssetSeeder:
|
||||
all_prefixes = get_all_known_prefixes()
|
||||
marked = mark_missing_outside_prefixes_safely(all_prefixes)
|
||||
if marked > 0:
|
||||
logging.info("Marked %d cache states as missing", marked)
|
||||
logging.info("Marked %d references as missing", marked)
|
||||
return marked
|
||||
finally:
|
||||
with self._lock:
|
||||
@ -409,9 +397,9 @@ class AssetSeeder:
|
||||
Returns:
|
||||
True if scan should stop, False to continue
|
||||
"""
|
||||
if not self._pause_event.is_set():
|
||||
if not self._run_gate.is_set():
|
||||
self._emit_event("assets.seed.paused", {})
|
||||
self._pause_event.wait() # Blocks if paused
|
||||
self._run_gate.wait() # Blocks if paused
|
||||
return self._is_cancelled()
|
||||
|
||||
def _emit_event(self, event_type: str, data: dict) -> None:
|
||||
@ -539,7 +527,11 @@ class AssetSeeder:
|
||||
cancelled = True
|
||||
return
|
||||
|
||||
total_enriched = self._run_enrich_phase(roots)
|
||||
enrich_cancelled, total_enriched = self._run_enrich_phase(roots)
|
||||
|
||||
if enrich_cancelled:
|
||||
cancelled = True
|
||||
return
|
||||
|
||||
self._emit_event(
|
||||
"assets.seed.enrich_complete",
|
||||
@ -613,7 +605,9 @@ class AssetSeeder:
|
||||
)
|
||||
|
||||
# Use stub specs (no metadata extraction, no hashing)
|
||||
specs, tag_pool, skipped_existing = build_stub_specs(paths, existing_paths)
|
||||
specs, tag_pool, skipped_existing = build_asset_specs(
|
||||
paths, existing_paths, enable_metadata_extraction=False, compute_hashes=False,
|
||||
)
|
||||
self._update_progress(skipped=skipped_existing)
|
||||
|
||||
if self._check_pause_and_cancel():
|
||||
@ -661,11 +655,11 @@ class AssetSeeder:
|
||||
self._update_progress(scanned=len(specs), created=total_created)
|
||||
return total_created, skipped_existing, total_paths
|
||||
|
||||
def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> int:
|
||||
def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]:
|
||||
"""Run phase 2: enrich existing records with metadata and hashes.
|
||||
|
||||
Returns:
|
||||
Total number of assets enriched
|
||||
Tuple of (cancelled, total_enriched)
|
||||
"""
|
||||
total_enriched = 0
|
||||
batch_size = 100
|
||||
@ -690,7 +684,7 @@ class AssetSeeder:
|
||||
while True:
|
||||
if self._check_pause_and_cancel():
|
||||
logging.info("Enrich scan cancelled after %d assets", total_enriched)
|
||||
break
|
||||
return True, total_enriched
|
||||
|
||||
# Fetch next batch of unenriched assets
|
||||
unenriched = get_unenriched_assets_for_roots(
|
||||
@ -737,7 +731,7 @@ class AssetSeeder:
|
||||
)
|
||||
last_progress_time = now
|
||||
|
||||
return total_enriched
|
||||
return False, total_enriched
|
||||
|
||||
|
||||
asset_seeder = AssetSeeder()
|
||||
asset_seeder = _AssetSeeder()
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import IO
|
||||
|
||||
@ -18,20 +17,6 @@ def compute_blake3_hash(
|
||||
return _hash_file_obj(f, chunk_size)
|
||||
|
||||
|
||||
async def compute_blake3_hash_async(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
if hasattr(fp, "read"):
|
||||
return await asyncio.to_thread(compute_blake3_hash, fp, chunk_size)
|
||||
|
||||
def _worker() -> str:
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj(f, chunk_size)
|
||||
|
||||
return await asyncio.to_thread(_worker)
|
||||
|
||||
|
||||
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
|
||||
if chunk_size <= 0:
|
||||
chunk_size = DEFAULT_CHUNK
|
||||
|
||||
@ -2,17 +2,16 @@ import contextlib
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Sequence
|
||||
from typing import Any, Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app.assets.services.hashing as hashing
|
||||
from app.assets.database.models import Asset, AssetReference, Tag
|
||||
from app.assets.database.queries import (
|
||||
add_tags_to_reference,
|
||||
fetch_reference_and_asset,
|
||||
get_asset_by_hash,
|
||||
get_existing_asset_ids,
|
||||
get_reference_by_file_path,
|
||||
get_reference_tags,
|
||||
get_or_create_reference,
|
||||
@ -21,11 +20,13 @@ from app.assets.database.queries import (
|
||||
set_reference_tags,
|
||||
upsert_asset,
|
||||
upsert_reference,
|
||||
validate_tags_exist,
|
||||
)
|
||||
from app.assets.helpers import normalize_tags
|
||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||
from app.assets.services.path_utils import (
|
||||
compute_filename_for_reference,
|
||||
compute_relative_filename,
|
||||
resolve_destination_from_tags,
|
||||
validate_path_within_base,
|
||||
)
|
||||
@ -55,6 +56,7 @@ def _ingest_file_from_path(
|
||||
require_existing_tags: bool = False,
|
||||
) -> IngestResult:
|
||||
locator = os.path.abspath(abs_path)
|
||||
user_metadata = user_metadata or {}
|
||||
|
||||
asset_created = False
|
||||
asset_updated = False
|
||||
@ -64,7 +66,7 @@ def _ingest_file_from_path(
|
||||
|
||||
with create_session() as session:
|
||||
if preview_id:
|
||||
if not session.get(Asset, preview_id):
|
||||
if preview_id not in get_existing_asset_ids(session, [preview_id]):
|
||||
preview_id = None
|
||||
|
||||
asset, asset_created, asset_updated = upsert_asset(
|
||||
@ -94,7 +96,7 @@ def _ingest_file_from_path(
|
||||
norm = normalize_tags(list(tags))
|
||||
if norm:
|
||||
if require_existing_tags:
|
||||
_validate_tags_exist(session, norm)
|
||||
validate_tags_exist(session, norm)
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
@ -106,7 +108,8 @@ def _ingest_file_from_path(
|
||||
_update_metadata_with_filename(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
ref=ref,
|
||||
file_path=ref.file_path,
|
||||
current_metadata=ref.user_metadata,
|
||||
user_metadata=user_metadata,
|
||||
)
|
||||
|
||||
@ -134,6 +137,8 @@ def _register_existing_asset(
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> RegisterAssetResult:
|
||||
user_metadata = user_metadata or {}
|
||||
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if not asset:
|
||||
@ -157,7 +162,7 @@ def _register_existing_asset(
|
||||
session.commit()
|
||||
return result
|
||||
|
||||
new_meta = dict(user_metadata or {})
|
||||
new_meta = dict(user_metadata)
|
||||
computed_filename = compute_filename_for_reference(session, ref)
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
@ -190,29 +195,20 @@ def _register_existing_asset(
|
||||
return result
|
||||
|
||||
|
||||
def _validate_tags_exist(session: Session, tags: list[str]) -> None:
|
||||
existing_tag_names = set(
|
||||
name
|
||||
for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
|
||||
)
|
||||
missing = [t for t in tags if t not in existing_tag_names]
|
||||
if missing:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
|
||||
def _update_metadata_with_filename(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
ref: AssetReference,
|
||||
user_metadata: UserMetadata,
|
||||
file_path: str | None,
|
||||
current_metadata: dict | None,
|
||||
user_metadata: dict[str, Any],
|
||||
) -> None:
|
||||
computed_filename = compute_filename_for_reference(session, ref)
|
||||
computed_filename = compute_relative_filename(file_path) if file_path else None
|
||||
|
||||
current_meta = ref.user_metadata or {}
|
||||
current_meta = current_metadata or {}
|
||||
new_meta = dict(current_meta)
|
||||
if user_metadata:
|
||||
for k, v in user_metadata.items():
|
||||
new_meta[k] = v
|
||||
for k, v in user_metadata.items():
|
||||
new_meta[k] = v
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
|
||||
@ -51,8 +51,9 @@ def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
raw_subdirs = tags[1:]
|
||||
else:
|
||||
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
|
||||
_sep_chars = frozenset(("/", "\\", os.sep))
|
||||
for i in raw_subdirs:
|
||||
if i in (".", ".."):
|
||||
if i in (".", "..") or _sep_chars & set(i):
|
||||
raise ValueError("invalid path component in tags")
|
||||
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
|
||||
@ -113,6 +114,8 @@ def get_asset_category_and_relative_path(
|
||||
return Path(child).is_relative_to(parent)
|
||||
|
||||
def _compute_relative(child: str, parent: str) -> str:
|
||||
# Normalize relative path, stripping any leading ".." components
|
||||
# by anchoring to root (os.sep) then computing relpath back from it.
|
||||
return os.path.relpath(
|
||||
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
|
||||
)
|
||||
|
||||
8
main.py
8
main.py
@ -259,10 +259,10 @@ def prompt_worker(q, server_instance):
|
||||
extra_data[k] = sensitive[k]
|
||||
|
||||
asset_seeder.pause()
|
||||
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
|
||||
asset_seeder.resume()
|
||||
try:
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
finally:
|
||||
asset_seeder.resume()
|
||||
need_gc = True
|
||||
|
||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||
|
||||
@ -34,7 +34,7 @@ from comfyui_version import __version__
|
||||
from app.frontend_management import FrontendManager, parse_version
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
from app.assets.seeder import asset_seeder
|
||||
from app.assets.api.routes import register_assets_system
|
||||
from app.assets.api.routes import register_assets_routes
|
||||
|
||||
from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
@ -240,7 +240,10 @@ class PromptServer():
|
||||
)
|
||||
logging.info(f"[Prompt Server] web root: {self.web_root}")
|
||||
if args.enable_assets:
|
||||
register_assets_system(self.app, self.user_manager)
|
||||
register_assets_routes(self.app, self.user_manager)
|
||||
else:
|
||||
register_assets_routes(self.app)
|
||||
asset_seeder.disable()
|
||||
routes = web.RouteTableDef()
|
||||
self.routes = routes
|
||||
self.last_node_id = None
|
||||
|
||||
350
tests-unit/assets_test/test_sync_references.py
Normal file
350
tests-unit/assets_test/test_sync_references.py
Normal file
@ -0,0 +1,350 @@
|
||||
"""Tests for sync_references_with_filesystem in scanner.py."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import (
|
||||
Asset,
|
||||
AssetReference,
|
||||
AssetReferenceTag,
|
||||
Base,
|
||||
Tag,
|
||||
)
|
||||
from app.assets.scanner import sync_references_with_filesystem
|
||||
from app.assets.services.file_utils import get_mtime_ns
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_engine():
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
Base.metadata.create_all(engine)
|
||||
return engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(db_engine):
|
||||
with Session(db_engine) as sess:
|
||||
yield sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield Path(tmpdir)
|
||||
|
||||
|
||||
def _create_file(temp_dir: Path, name: str, content: bytes = b"\x00" * 100) -> str:
|
||||
"""Create a file and return its absolute path (no symlink resolution)."""
|
||||
p = temp_dir / name
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
p.write_bytes(content)
|
||||
return os.path.abspath(str(p))
|
||||
|
||||
|
||||
def _stat_mtime_ns(path: str) -> int:
|
||||
return get_mtime_ns(os.stat(path, follow_symlinks=True))
|
||||
|
||||
|
||||
def _make_asset(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
file_path: str,
|
||||
ref_id: str,
|
||||
*,
|
||||
asset_hash: str | None = None,
|
||||
size_bytes: int = 100,
|
||||
mtime_ns: int | None = None,
|
||||
needs_verify: bool = False,
|
||||
is_missing: bool = False,
|
||||
) -> tuple[Asset, AssetReference]:
|
||||
"""Insert an Asset + AssetReference and flush."""
|
||||
asset = session.get(Asset, asset_id)
|
||||
if asset is None:
|
||||
asset = Asset(id=asset_id, hash=asset_hash, size_bytes=size_bytes)
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
|
||||
ref = AssetReference(
|
||||
id=ref_id,
|
||||
asset_id=asset_id,
|
||||
name=f"test-{ref_id}",
|
||||
owner_id="system",
|
||||
file_path=file_path,
|
||||
mtime_ns=mtime_ns,
|
||||
needs_verify=needs_verify,
|
||||
is_missing=is_missing,
|
||||
)
|
||||
session.add(ref)
|
||||
session.flush()
|
||||
return asset, ref
|
||||
|
||||
|
||||
def _ensure_missing_tag(session: Session):
|
||||
"""Ensure the 'missing' tag exists."""
|
||||
if not session.get(Tag, "missing"):
|
||||
session.add(Tag(name="missing", tag_type="system"))
|
||||
session.flush()
|
||||
|
||||
|
||||
class _VerifyCase:
|
||||
def __init__(self, id, stat_unchanged, needs_verify_before, expect_needs_verify):
|
||||
self.id = id
|
||||
self.stat_unchanged = stat_unchanged
|
||||
self.needs_verify_before = needs_verify_before
|
||||
self.expect_needs_verify = expect_needs_verify
|
||||
|
||||
|
||||
VERIFY_CASES = [
|
||||
_VerifyCase(
|
||||
id="unchanged_clears_verify",
|
||||
stat_unchanged=True,
|
||||
needs_verify_before=True,
|
||||
expect_needs_verify=False,
|
||||
),
|
||||
_VerifyCase(
|
||||
id="unchanged_keeps_clear",
|
||||
stat_unchanged=True,
|
||||
needs_verify_before=False,
|
||||
expect_needs_verify=False,
|
||||
),
|
||||
_VerifyCase(
|
||||
id="changed_sets_verify",
|
||||
stat_unchanged=False,
|
||||
needs_verify_before=False,
|
||||
expect_needs_verify=True,
|
||||
),
|
||||
_VerifyCase(
|
||||
id="changed_keeps_verify",
|
||||
stat_unchanged=False,
|
||||
needs_verify_before=True,
|
||||
expect_needs_verify=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", VERIFY_CASES, ids=lambda c: c.id)
|
||||
def test_needs_verify_toggling(session, temp_dir, case):
|
||||
"""needs_verify is set/cleared based on mtime+size match."""
|
||||
fp = _create_file(temp_dir, "model.bin")
|
||||
real_mtime = _stat_mtime_ns(fp)
|
||||
|
||||
mtime_for_db = real_mtime if case.stat_unchanged else real_mtime + 1
|
||||
_make_asset(
|
||||
session, "a1", fp, "r1",
|
||||
asset_hash="blake3:abc",
|
||||
mtime_ns=mtime_for_db,
|
||||
needs_verify=case.needs_verify_before,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
|
||||
sync_references_with_filesystem(session, "models")
|
||||
session.commit()
|
||||
|
||||
session.expire_all()
|
||||
ref = session.get(AssetReference, "r1")
|
||||
assert ref.needs_verify is case.expect_needs_verify
|
||||
|
||||
|
||||
class _MissingCase:
|
||||
def __init__(self, id, file_exists, expect_is_missing):
|
||||
self.id = id
|
||||
self.file_exists = file_exists
|
||||
self.expect_is_missing = expect_is_missing
|
||||
|
||||
|
||||
MISSING_CASES = [
|
||||
_MissingCase(id="existing_file_not_missing", file_exists=True, expect_is_missing=False),
|
||||
_MissingCase(id="missing_file_marked_missing", file_exists=False, expect_is_missing=True),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", MISSING_CASES, ids=lambda c: c.id)
|
||||
def test_is_missing_flag(session, temp_dir, case):
|
||||
"""is_missing reflects whether the file exists on disk."""
|
||||
if case.file_exists:
|
||||
fp = _create_file(temp_dir, "model.bin")
|
||||
mtime = _stat_mtime_ns(fp)
|
||||
else:
|
||||
fp = str(temp_dir / "gone.bin")
|
||||
mtime = 999
|
||||
|
||||
_make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime)
|
||||
session.commit()
|
||||
|
||||
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
|
||||
sync_references_with_filesystem(session, "models")
|
||||
session.commit()
|
||||
|
||||
session.expire_all()
|
||||
ref = session.get(AssetReference, "r1")
|
||||
assert ref.is_missing is case.expect_is_missing
|
||||
|
||||
|
||||
def test_seed_asset_all_missing_deletes_asset(session, temp_dir):
|
||||
"""Seed asset with all refs missing gets deleted entirely."""
|
||||
fp = str(temp_dir / "gone.bin")
|
||||
_make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=999)
|
||||
session.commit()
|
||||
|
||||
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
|
||||
sync_references_with_filesystem(session, "models")
|
||||
session.commit()
|
||||
|
||||
assert session.get(Asset, "seed1") is None
|
||||
assert session.get(AssetReference, "r1") is None
|
||||
|
||||
|
||||
def test_seed_asset_some_exist_returns_survivors(session, temp_dir):
|
||||
"""Seed asset with at least one existing ref survives and is returned."""
|
||||
fp = _create_file(temp_dir, "model.bin")
|
||||
mtime = _stat_mtime_ns(fp)
|
||||
_make_asset(session, "seed1", fp, "r1", asset_hash=None, mtime_ns=mtime)
|
||||
session.commit()
|
||||
|
||||
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
|
||||
survivors = sync_references_with_filesystem(
|
||||
session, "models", collect_existing_paths=True,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert session.get(Asset, "seed1") is not None
|
||||
assert os.path.abspath(fp) in survivors
|
||||
|
||||
|
||||
def test_hashed_asset_prunes_missing_refs_when_one_is_ok(session, temp_dir):
|
||||
"""Hashed asset with one stat-unchanged ref deletes missing refs."""
|
||||
fp_ok = _create_file(temp_dir, "good.bin")
|
||||
fp_gone = str(temp_dir / "gone.bin")
|
||||
mtime = _stat_mtime_ns(fp_ok)
|
||||
|
||||
_make_asset(session, "h1", fp_ok, "r_ok", asset_hash="blake3:aaa", mtime_ns=mtime)
|
||||
# Second ref on same asset, file missing
|
||||
ref_gone = AssetReference(
|
||||
id="r_gone", asset_id="h1", name="gone",
|
||||
owner_id="system", file_path=fp_gone, mtime_ns=999,
|
||||
)
|
||||
session.add(ref_gone)
|
||||
session.commit()
|
||||
|
||||
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
|
||||
sync_references_with_filesystem(session, "models")
|
||||
session.commit()
|
||||
|
||||
session.expire_all()
|
||||
assert session.get(AssetReference, "r_ok") is not None
|
||||
assert session.get(AssetReference, "r_gone") is None
|
||||
|
||||
|
||||
def test_hashed_asset_all_missing_keeps_refs(session, temp_dir):
|
||||
"""Hashed asset with all refs missing keeps refs (no pruning)."""
|
||||
fp = str(temp_dir / "gone.bin")
|
||||
_make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999)
|
||||
session.commit()
|
||||
|
||||
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
|
||||
sync_references_with_filesystem(session, "models")
|
||||
session.commit()
|
||||
|
||||
session.expire_all()
|
||||
assert session.get(AssetReference, "r1") is not None
|
||||
ref = session.get(AssetReference, "r1")
|
||||
assert ref.is_missing is True
|
||||
|
||||
|
||||
def test_missing_tag_added_when_all_refs_gone(session, temp_dir):
|
||||
"""Missing tag is added to hashed asset when all refs are missing."""
|
||||
_ensure_missing_tag(session)
|
||||
fp = str(temp_dir / "gone.bin")
|
||||
_make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999)
|
||||
session.commit()
|
||||
|
||||
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
|
||||
sync_references_with_filesystem(
|
||||
session, "models", update_missing_tags=True,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
session.expire_all()
|
||||
tag_link = session.get(AssetReferenceTag, ("r1", "missing"))
|
||||
assert tag_link is not None
|
||||
|
||||
|
||||
def test_missing_tag_removed_when_ref_ok(session, temp_dir):
|
||||
"""Missing tag is removed from hashed asset when a ref is stat-unchanged."""
|
||||
_ensure_missing_tag(session)
|
||||
fp = _create_file(temp_dir, "model.bin")
|
||||
mtime = _stat_mtime_ns(fp)
|
||||
_make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=mtime)
|
||||
# Pre-add a stale missing tag
|
||||
session.add(AssetReferenceTag(
|
||||
asset_reference_id="r1", tag_name="missing", origin="automatic",
|
||||
))
|
||||
session.commit()
|
||||
|
||||
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
|
||||
sync_references_with_filesystem(
|
||||
session, "models", update_missing_tags=True,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
session.expire_all()
|
||||
tag_link = session.get(AssetReferenceTag, ("r1", "missing"))
|
||||
assert tag_link is None
|
||||
|
||||
|
||||
def test_missing_tags_not_touched_when_flag_false(session, temp_dir):
|
||||
"""Missing tags are not modified when update_missing_tags=False."""
|
||||
_ensure_missing_tag(session)
|
||||
fp = str(temp_dir / "gone.bin")
|
||||
_make_asset(session, "h1", fp, "r1", asset_hash="blake3:aaa", mtime_ns=999)
|
||||
session.commit()
|
||||
|
||||
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
|
||||
sync_references_with_filesystem(
|
||||
session, "models", update_missing_tags=False,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
tag_link = session.get(AssetReferenceTag, ("r1", "missing"))
|
||||
assert tag_link is None # tag was never added
|
||||
|
||||
|
||||
def test_returns_none_when_collect_false(session, temp_dir):
|
||||
fp = _create_file(temp_dir, "model.bin")
|
||||
mtime = _stat_mtime_ns(fp)
|
||||
_make_asset(session, "a1", fp, "r1", asset_hash="blake3:abc", mtime_ns=mtime)
|
||||
session.commit()
|
||||
|
||||
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
|
||||
result = sync_references_with_filesystem(
|
||||
session, "models", collect_existing_paths=False,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_returns_empty_set_for_no_prefixes(session):
|
||||
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[]):
|
||||
result = sync_references_with_filesystem(
|
||||
session, "models", collect_existing_paths=True,
|
||||
)
|
||||
|
||||
assert result == set()
|
||||
|
||||
|
||||
def test_no_references_is_noop(session, temp_dir):
|
||||
"""No crash and no side effects when there are no references."""
|
||||
with patch("app.assets.scanner.get_prefixes_for_root", return_value=[str(temp_dir)]):
|
||||
survivors = sync_references_with_filesystem(
|
||||
session, "models", collect_existing_paths=True,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert survivors == set()
|
||||
@ -1,19 +1,18 @@
|
||||
"""Unit tests for the AssetSeeder background scanning class."""
|
||||
"""Unit tests for the _AssetSeeder background scanning class."""
|
||||
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.assets.seeder import AssetSeeder, Progress, ScanInProgressError, ScanPhase, State
|
||||
from app.assets.database.queries.asset_reference import UnenrichedReferenceRow
|
||||
from app.assets.seeder import _AssetSeeder, Progress, ScanInProgressError, ScanPhase, State
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fresh_seeder():
|
||||
"""Create a fresh AssetSeeder instance for testing (bypasses singleton)."""
|
||||
seeder = object.__new__(AssetSeeder)
|
||||
seeder._initialized = False
|
||||
seeder.__init__()
|
||||
"""Create a fresh _AssetSeeder instance for testing."""
|
||||
seeder = _AssetSeeder()
|
||||
yield seeder
|
||||
seeder.shutdown(timeout=1.0)
|
||||
|
||||
@ -25,7 +24,7 @@ def mock_dependencies():
|
||||
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
|
||||
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
|
||||
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
|
||||
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
|
||||
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||
@ -36,11 +35,11 @@ def mock_dependencies():
|
||||
class TestSeederStateTransitions:
|
||||
"""Test state machine transitions."""
|
||||
|
||||
def test_initial_state_is_idle(self, fresh_seeder: AssetSeeder):
|
||||
def test_initial_state_is_idle(self, fresh_seeder: _AssetSeeder):
|
||||
assert fresh_seeder.get_status().state == State.IDLE
|
||||
|
||||
def test_start_transitions_to_running(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
reached = threading.Event()
|
||||
@ -61,7 +60,7 @@ class TestSeederStateTransitions:
|
||||
barrier.set()
|
||||
|
||||
def test_start_while_running_returns_false(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
reached = threading.Event()
|
||||
@ -83,7 +82,7 @@ class TestSeederStateTransitions:
|
||||
barrier.set()
|
||||
|
||||
def test_cancel_transitions_to_cancelling(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
reached = threading.Event()
|
||||
@ -105,12 +104,12 @@ class TestSeederStateTransitions:
|
||||
|
||||
barrier.set()
|
||||
|
||||
def test_cancel_when_idle_returns_false(self, fresh_seeder: AssetSeeder):
|
||||
def test_cancel_when_idle_returns_false(self, fresh_seeder: _AssetSeeder):
|
||||
cancelled = fresh_seeder.cancel()
|
||||
assert cancelled is False
|
||||
|
||||
def test_state_returns_to_idle_after_completion(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
fresh_seeder.start(roots=("models",))
|
||||
completed = fresh_seeder.wait(timeout=5.0)
|
||||
@ -122,7 +121,7 @@ class TestSeederWait:
|
||||
"""Test wait() behavior."""
|
||||
|
||||
def test_wait_blocks_until_complete(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
fresh_seeder.start(roots=("models",))
|
||||
completed = fresh_seeder.wait(timeout=5.0)
|
||||
@ -130,7 +129,7 @@ class TestSeederWait:
|
||||
assert fresh_seeder.get_status().state == State.IDLE
|
||||
|
||||
def test_wait_returns_false_on_timeout(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
|
||||
@ -147,7 +146,7 @@ class TestSeederWait:
|
||||
|
||||
barrier.set()
|
||||
|
||||
def test_wait_when_idle_returns_true(self, fresh_seeder: AssetSeeder):
|
||||
def test_wait_when_idle_returns_true(self, fresh_seeder: _AssetSeeder):
|
||||
completed = fresh_seeder.wait(timeout=1.0)
|
||||
assert completed is True
|
||||
|
||||
@ -156,7 +155,7 @@ class TestSeederProgress:
|
||||
"""Test progress tracking."""
|
||||
|
||||
def test_get_status_returns_progress_during_scan(
|
||||
self, fresh_seeder: AssetSeeder
|
||||
self, fresh_seeder: _AssetSeeder
|
||||
):
|
||||
barrier = threading.Event()
|
||||
reached = threading.Event()
|
||||
@ -172,7 +171,7 @@ class TestSeederProgress:
|
||||
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||
patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
|
||||
patch("app.assets.seeder.build_stub_specs", side_effect=slow_build),
|
||||
patch("app.assets.seeder.build_asset_specs", side_effect=slow_build),
|
||||
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
|
||||
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||
@ -188,7 +187,7 @@ class TestSeederProgress:
|
||||
barrier.set()
|
||||
|
||||
def test_progress_callback_is_invoked(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
progress_updates: list[Progress] = []
|
||||
|
||||
@ -209,7 +208,7 @@ class TestSeederCancellation:
|
||||
"""Test cancellation behavior."""
|
||||
|
||||
def test_scan_commits_partial_progress_on_cancellation(
|
||||
self, fresh_seeder: AssetSeeder
|
||||
self, fresh_seeder: _AssetSeeder
|
||||
):
|
||||
insert_count = 0
|
||||
barrier = threading.Event()
|
||||
@ -245,7 +244,7 @@ class TestSeederCancellation:
|
||||
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||
patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
|
||||
patch(
|
||||
"app.assets.seeder.build_stub_specs", return_value=(specs, set(), 0)
|
||||
"app.assets.seeder.build_asset_specs", return_value=(specs, set(), 0)
|
||||
),
|
||||
patch("app.assets.seeder.insert_asset_specs", side_effect=slow_insert),
|
||||
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
|
||||
@ -264,7 +263,7 @@ class TestSeederCancellation:
|
||||
class TestSeederErrorHandling:
|
||||
"""Test error handling behavior."""
|
||||
|
||||
def test_database_errors_captured_in_status(self, fresh_seeder: AssetSeeder):
|
||||
def test_database_errors_captured_in_status(self, fresh_seeder: _AssetSeeder):
|
||||
with (
|
||||
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||
@ -273,7 +272,7 @@ class TestSeederErrorHandling:
|
||||
return_value=["/path/file.safetensors"],
|
||||
),
|
||||
patch(
|
||||
"app.assets.seeder.build_stub_specs",
|
||||
"app.assets.seeder.build_asset_specs",
|
||||
return_value=(
|
||||
[
|
||||
{
|
||||
@ -307,7 +306,7 @@ class TestSeederErrorHandling:
|
||||
assert "DB connection failed" in status.errors[0]
|
||||
|
||||
def test_dependencies_unavailable_captured_in_errors(
|
||||
self, fresh_seeder: AssetSeeder
|
||||
self, fresh_seeder: _AssetSeeder
|
||||
):
|
||||
with patch("app.assets.seeder.dependencies_available", return_value=False):
|
||||
fresh_seeder.start(roots=("models",))
|
||||
@ -317,7 +316,7 @@ class TestSeederErrorHandling:
|
||||
assert len(status.errors) > 0
|
||||
assert "dependencies" in status.errors[0].lower()
|
||||
|
||||
def test_thread_crash_resets_state_to_idle(self, fresh_seeder: AssetSeeder):
|
||||
def test_thread_crash_resets_state_to_idle(self, fresh_seeder: _AssetSeeder):
|
||||
with (
|
||||
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||
patch(
|
||||
@ -337,7 +336,7 @@ class TestSeederThreadSafety:
|
||||
"""Test thread safety of concurrent operations."""
|
||||
|
||||
def test_concurrent_start_calls_spawn_only_one_thread(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
|
||||
@ -364,7 +363,7 @@ class TestSeederThreadSafety:
|
||||
assert sum(results) == 1
|
||||
|
||||
def test_get_status_safe_during_scan(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
reached = threading.Event()
|
||||
@ -395,7 +394,7 @@ class TestSeederThreadSafety:
|
||||
class TestSeederMarkMissing:
|
||||
"""Test mark_missing_outside_prefixes behavior."""
|
||||
|
||||
def test_mark_missing_when_idle(self, fresh_seeder: AssetSeeder):
|
||||
def test_mark_missing_when_idle(self, fresh_seeder: _AssetSeeder):
|
||||
with (
|
||||
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||
patch(
|
||||
@ -411,7 +410,7 @@ class TestSeederMarkMissing:
|
||||
mock_mark.assert_called_once_with(["/models", "/input", "/output"])
|
||||
|
||||
def test_mark_missing_raises_when_running(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
reached = threading.Event()
|
||||
@ -433,14 +432,14 @@ class TestSeederMarkMissing:
|
||||
barrier.set()
|
||||
|
||||
def test_mark_missing_returns_zero_when_dependencies_unavailable(
|
||||
self, fresh_seeder: AssetSeeder
|
||||
self, fresh_seeder: _AssetSeeder
|
||||
):
|
||||
with patch("app.assets.seeder.dependencies_available", return_value=False):
|
||||
result = fresh_seeder.mark_missing_outside_prefixes()
|
||||
assert result == 0
|
||||
|
||||
def test_prune_first_flag_triggers_mark_missing_before_scan(
|
||||
self, fresh_seeder: AssetSeeder
|
||||
self, fresh_seeder: _AssetSeeder
|
||||
):
|
||||
call_order = []
|
||||
|
||||
@ -458,7 +457,7 @@ class TestSeederMarkMissing:
|
||||
patch("app.assets.seeder.mark_missing_outside_prefixes_safely", side_effect=track_mark),
|
||||
patch("app.assets.seeder.sync_root_safely", side_effect=track_sync),
|
||||
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
|
||||
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
|
||||
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
|
||||
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
|
||||
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||
@ -473,7 +472,7 @@ class TestSeederMarkMissing:
|
||||
class TestSeederPhases:
|
||||
"""Test phased scanning behavior."""
|
||||
|
||||
def test_start_fast_only_runs_fast_phase(self, fresh_seeder: AssetSeeder):
|
||||
def test_start_fast_only_runs_fast_phase(self, fresh_seeder: _AssetSeeder):
|
||||
"""Verify start_fast only runs the fast phase."""
|
||||
fast_called = []
|
||||
enrich_called = []
|
||||
@ -490,7 +489,7 @@ class TestSeederPhases:
|
||||
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
|
||||
patch("app.assets.seeder.build_stub_specs", side_effect=track_fast),
|
||||
patch("app.assets.seeder.build_asset_specs", side_effect=track_fast),
|
||||
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich),
|
||||
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||
@ -501,7 +500,7 @@ class TestSeederPhases:
|
||||
assert len(fast_called) == 1
|
||||
assert len(enrich_called) == 0
|
||||
|
||||
def test_start_enrich_only_runs_enrich_phase(self, fresh_seeder: AssetSeeder):
|
||||
def test_start_enrich_only_runs_enrich_phase(self, fresh_seeder: _AssetSeeder):
|
||||
"""Verify start_enrich only runs the enrich phase."""
|
||||
fast_called = []
|
||||
enrich_called = []
|
||||
@ -518,7 +517,7 @@ class TestSeederPhases:
|
||||
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
|
||||
patch("app.assets.seeder.build_stub_specs", side_effect=track_fast),
|
||||
patch("app.assets.seeder.build_asset_specs", side_effect=track_fast),
|
||||
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich),
|
||||
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||
@ -529,7 +528,7 @@ class TestSeederPhases:
|
||||
assert len(fast_called) == 0
|
||||
assert len(enrich_called) == 1
|
||||
|
||||
def test_full_scan_runs_both_phases(self, fresh_seeder: AssetSeeder):
|
||||
def test_full_scan_runs_both_phases(self, fresh_seeder: _AssetSeeder):
|
||||
"""Verify full scan runs both fast and enrich phases."""
|
||||
fast_called = []
|
||||
enrich_called = []
|
||||
@ -546,7 +545,7 @@ class TestSeederPhases:
|
||||
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||
patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
|
||||
patch("app.assets.seeder.build_stub_specs", side_effect=track_fast),
|
||||
patch("app.assets.seeder.build_asset_specs", side_effect=track_fast),
|
||||
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=track_enrich),
|
||||
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||
@ -562,7 +561,7 @@ class TestSeederPauseResume:
|
||||
"""Test pause/resume behavior."""
|
||||
|
||||
def test_pause_transitions_to_paused(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
reached = threading.Event()
|
||||
@ -584,12 +583,12 @@ class TestSeederPauseResume:
|
||||
|
||||
barrier.set()
|
||||
|
||||
def test_pause_when_idle_returns_false(self, fresh_seeder: AssetSeeder):
|
||||
def test_pause_when_idle_returns_false(self, fresh_seeder: _AssetSeeder):
|
||||
paused = fresh_seeder.pause()
|
||||
assert paused is False
|
||||
|
||||
def test_resume_returns_to_running(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
reached = threading.Event()
|
||||
@ -615,7 +614,7 @@ class TestSeederPauseResume:
|
||||
barrier.set()
|
||||
|
||||
def test_resume_when_not_paused_returns_false(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
reached = threading.Event()
|
||||
@ -637,7 +636,7 @@ class TestSeederPauseResume:
|
||||
barrier.set()
|
||||
|
||||
def test_cancel_while_paused_works(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
reached_checkpoint = threading.Event()
|
||||
@ -667,7 +666,7 @@ class TestSeederStopRestart:
|
||||
"""Test stop and restart behavior."""
|
||||
|
||||
def test_stop_is_alias_for_cancel(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
reached = threading.Event()
|
||||
@ -690,7 +689,7 @@ class TestSeederStopRestart:
|
||||
barrier.set()
|
||||
|
||||
def test_restart_cancels_and_starts_new_scan(
|
||||
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||
):
|
||||
barrier = threading.Event()
|
||||
reached = threading.Event()
|
||||
@ -717,7 +716,7 @@ class TestSeederStopRestart:
|
||||
fresh_seeder.wait(timeout=5.0)
|
||||
assert start_count == 2
|
||||
|
||||
def test_restart_preserves_previous_params(self, fresh_seeder: AssetSeeder):
|
||||
def test_restart_preserves_previous_params(self, fresh_seeder: _AssetSeeder):
|
||||
"""Verify restart uses previous params when not overridden."""
|
||||
collected_roots = []
|
||||
|
||||
@ -729,7 +728,7 @@ class TestSeederStopRestart:
|
||||
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||
patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect),
|
||||
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
|
||||
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
|
||||
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
|
||||
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||
@ -744,7 +743,7 @@ class TestSeederStopRestart:
|
||||
assert collected_roots[0] == ("input", "output")
|
||||
assert collected_roots[1] == ("input", "output")
|
||||
|
||||
def test_restart_can_override_params(self, fresh_seeder: AssetSeeder):
|
||||
def test_restart_can_override_params(self, fresh_seeder: _AssetSeeder):
|
||||
"""Verify restart can override previous params."""
|
||||
collected_roots = []
|
||||
|
||||
@ -756,7 +755,7 @@ class TestSeederStopRestart:
|
||||
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||
patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect),
|
||||
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
|
||||
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
|
||||
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
|
||||
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||
@ -770,3 +769,132 @@ class TestSeederStopRestart:
|
||||
assert len(collected_roots) == 2
|
||||
assert collected_roots[0] == ("models",)
|
||||
assert collected_roots[1] == ("input",)
|
||||
|
||||
|
||||
def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
|
||||
return UnenrichedReferenceRow(
|
||||
reference_id=ref_id, asset_id=asset_id,
|
||||
file_path=f"/fake/{ref_id}.bin", enrichment_level=0,
|
||||
)
|
||||
|
||||
|
||||
class TestEnrichPhaseDefensiveLogic:
    """Test skip_ids filtering and consecutive_empty termination."""

    def test_failed_refs_are_skipped_on_subsequent_batches(
        self, fresh_seeder: _AssetSeeder,
    ):
        """References that fail enrichment are filtered out of future batches."""
        row_a = _make_row("r1")
        row_b = _make_row("r2")
        call_count = 0

        def fake_get_unenriched(*args, **kwargs):
            # Return both rows for the first two batches, then signal "done".
            nonlocal call_count
            call_count += 1
            if call_count <= 2:
                return [row_a, row_b]
            return []

        # Records the reference ids attempted in each enrichment batch.
        enriched_refs: list[list[str]] = []

        def fake_enrich(rows, **kwargs):
            ref_ids = [r.reference_id for r in rows]
            enriched_refs.append(ref_ids)
            # r1 always fails, r2 succeeds
            failed = [r.reference_id for r in rows if r.reference_id == "r1"]
            enriched = len(rows) - len(failed)
            return enriched, failed

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched),
            patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich),
        ):
            fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH)
            fresh_seeder.wait(timeout=5.0)

        # First batch: both refs attempted
        assert "r1" in enriched_refs[0]
        assert "r2" in enriched_refs[0]
        # Second batch: r1 filtered out
        assert "r1" not in enriched_refs[1]
        assert "r2" in enriched_refs[1]

    def test_stops_after_consecutive_empty_batches(
        self, fresh_seeder: _AssetSeeder,
    ):
        """Enrich phase terminates once skip_ids filtering empties a batch.

        A permanently failing reference is marked failed in batch 1 and
        filtered out by skip_ids in batch 2, leaving an empty batch that
        breaks the loop -- so get_unenriched is called exactly twice, well
        before the consecutive-empty limit would be reached.
        """
        row = _make_row("r1")
        batch_count = 0

        def fake_get_unenriched(*args, **kwargs):
            nonlocal batch_count
            batch_count += 1
            # Always return the same row (simulating a permanently failing ref)
            return [row]

        def fake_enrich(rows, **kwargs):
            # Always fail — zero enriched, all failed
            return 0, [r.reference_id for r in rows]

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched),
            patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich),
        ):
            fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH)
            fresh_seeder.wait(timeout=5.0)

        # Batch 1: the row is attempted, enrichment fails, and the ref is
        # recorded in skip_ids. Batch 2: get_unenriched returns [row] again,
        # skip_ids filters it out, the batch is empty, and the loop breaks
        # via `if not unenriched: break` — two calls total.
        assert batch_count == 2

    def test_consecutive_empty_counter_resets_on_success(
        self, fresh_seeder: _AssetSeeder,
    ):
        """A successful batch resets the consecutive empty counter."""
        call_count = 0

        def fake_get_unenriched(*args, **kwargs):
            # Hand out a fresh, distinct row for each of the first six batches
            # so skip_ids never empties a batch; then signal "done".
            nonlocal call_count
            call_count += 1
            if call_count <= 6:
                return [_make_row(f"r{call_count}", f"a{call_count}")]
            return []

        def fake_enrich(rows, **kwargs):
            ref_id = rows[0].reference_id
            # Fail batches 1-2, succeed batch 3, fail batches 4-5, succeed batch 6
            if ref_id in ("r1", "r2", "r4", "r5"):
                return 0, [ref_id]
            return 1, []

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=fake_get_unenriched),
            patch("app.assets.seeder.enrich_assets_batch", side_effect=fake_enrich),
        ):
            fresh_seeder.start(roots=("models",), phase=ScanPhase.ENRICH)
            fresh_seeder.wait(timeout=5.0)

        # All 6 batches should run + 1 final call returning empty
        assert call_count == 7
        status = fresh_seeder.get_status()
        assert status.state == State.IDLE
Loading…
Reference in New Issue
Block a user