mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-06 01:37:45 +08:00
feat(assets): async two-phase scanner and background seeder
- Rewrite scanner.py with two-phase scanning architecture (fast scan + enrich) - Add AssetSeeder for non-blocking background startup scanning - Implement pause/resume/stop/restart controls and disable/enable for --disable-assets-autoscan - Add non-destructive asset pruning with is_missing flag - Wire seeder into main.py and server.py lifecycle - Skip hidden files/directories, populate mime_type, optional blake3 hashing - Add comprehensive seeder tests Co-authored-by: Amp <amp@ampcode.com> Amp-Thread-ID: https://ampcode.com/threads/T-019c9209-37af-757a-b6e4-af59b4267362
This commit is contained in:
parent
4e0282c2a0
commit
709a721591
@ -19,7 +19,7 @@ from app.assets.api.upload import (
|
|||||||
delete_temp_file_if_exists,
|
delete_temp_file_if_exists,
|
||||||
parse_multipart_upload,
|
parse_multipart_upload,
|
||||||
)
|
)
|
||||||
from app.assets.seeder import asset_seeder
|
from app.assets.seeder import ScanInProgressError, asset_seeder
|
||||||
from app.assets.services import (
|
from app.assets.services import (
|
||||||
DependencyMissingError,
|
DependencyMissingError,
|
||||||
HashMismatchError,
|
HashMismatchError,
|
||||||
@ -717,8 +717,9 @@ async def mark_missing_assets(request: web.Request) -> web.Response:
|
|||||||
200 OK with count of marked assets
|
200 OK with count of marked assets
|
||||||
409 Conflict if a scan is currently running
|
409 Conflict if a scan is currently running
|
||||||
"""
|
"""
|
||||||
marked = asset_seeder.mark_missing_outside_prefixes()
|
try:
|
||||||
if marked == 0 and asset_seeder.get_status().state.value != "IDLE":
|
marked = asset_seeder.mark_missing_outside_prefixes()
|
||||||
|
except ScanInProgressError:
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
{"status": "scan_running", "marked": 0},
|
{"status": "scan_running", "marked": 0},
|
||||||
status=409,
|
status=409,
|
||||||
|
|||||||
@ -1,263 +1,602 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import time
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sqlalchemy
|
import time
|
||||||
|
from typing import Literal, TypedDict
|
||||||
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
from app.database.db import create_session, dependencies_available
|
from app.assets.database.queries import (
|
||||||
from app.assets.helpers import (
|
add_missing_tag_for_asset_id,
|
||||||
collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path,
|
bulk_update_enrichment_level,
|
||||||
list_tree,prefixes_for_root, escape_like_prefix,
|
bulk_update_is_missing,
|
||||||
RootType
|
bulk_update_needs_verify,
|
||||||
|
delete_orphaned_seed_asset,
|
||||||
|
delete_references_by_ids,
|
||||||
|
ensure_tags_exist,
|
||||||
|
get_asset_by_hash,
|
||||||
|
get_references_for_prefixes,
|
||||||
|
get_unenriched_references,
|
||||||
|
reassign_asset_references,
|
||||||
|
remove_missing_tag_for_asset_id,
|
||||||
|
set_reference_metadata,
|
||||||
|
update_asset_hash_and_mime,
|
||||||
)
|
)
|
||||||
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
|
from app.assets.services.bulk_ingest import (
|
||||||
from app.assets.database.bulk_ops import seed_from_paths_batch
|
SeedAssetSpec,
|
||||||
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
|
batch_insert_seed_assets,
|
||||||
|
mark_assets_missing_outside_prefixes,
|
||||||
|
)
|
||||||
|
from app.assets.services.file_utils import (
|
||||||
|
get_mtime_ns,
|
||||||
|
list_files_recursively,
|
||||||
|
verify_file_unchanged,
|
||||||
|
)
|
||||||
|
from app.assets.services.hashing import compute_blake3_hash
|
||||||
|
from app.assets.services.metadata_extract import extract_file_metadata
|
||||||
|
from app.assets.services.path_utils import (
|
||||||
|
compute_relative_filename,
|
||||||
|
get_comfy_models_folders,
|
||||||
|
get_name_and_tags_from_asset_path,
|
||||||
|
)
|
||||||
|
from app.database.db import create_session, dependencies_available
|
||||||
|
|
||||||
|
|
||||||
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
|
class _RefInfo(TypedDict):
|
||||||
|
ref_id: str
|
||||||
|
fp: str
|
||||||
|
exists: bool
|
||||||
|
fast_ok: bool
|
||||||
|
needs_verify: bool
|
||||||
|
|
||||||
|
|
||||||
|
class _AssetAccumulator(TypedDict):
|
||||||
|
hash: str | None
|
||||||
|
size_db: int
|
||||||
|
refs: list[_RefInfo]
|
||||||
|
|
||||||
|
|
||||||
|
RootType = Literal["models", "input", "output"]
|
||||||
|
|
||||||
|
|
||||||
|
def get_prefixes_for_root(root: RootType) -> list[str]:
|
||||||
|
if root == "models":
|
||||||
|
bases: list[str] = []
|
||||||
|
for _bucket, paths in get_comfy_models_folders():
|
||||||
|
bases.extend(paths)
|
||||||
|
return [os.path.abspath(p) for p in bases]
|
||||||
|
if root == "input":
|
||||||
|
return [os.path.abspath(folder_paths.get_input_directory())]
|
||||||
|
if root == "output":
|
||||||
|
return [os.path.abspath(folder_paths.get_output_directory())]
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_known_prefixes() -> list[str]:
|
||||||
|
"""Get all known asset prefixes across all root types."""
|
||||||
|
all_roots: tuple[RootType, ...] = ("models", "input", "output")
|
||||||
|
return [
|
||||||
|
os.path.abspath(p) for root in all_roots for p in get_prefixes_for_root(root)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def collect_models_files() -> list[str]:
|
||||||
|
out: list[str] = []
|
||||||
|
for folder_name, bases in get_comfy_models_folders():
|
||||||
|
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||||
|
for rel_path in rel_files:
|
||||||
|
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
||||||
|
if not abs_path:
|
||||||
|
continue
|
||||||
|
abs_path = os.path.abspath(abs_path)
|
||||||
|
allowed = False
|
||||||
|
for b in bases:
|
||||||
|
base_abs = os.path.abspath(b)
|
||||||
|
with contextlib.suppress(ValueError):
|
||||||
|
if os.path.commonpath([abs_path, base_abs]) == base_abs:
|
||||||
|
allowed = True
|
||||||
|
break
|
||||||
|
if allowed:
|
||||||
|
out.append(abs_path)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def sync_references_with_filesystem(
|
||||||
|
session,
|
||||||
|
root: RootType,
|
||||||
|
collect_existing_paths: bool = False,
|
||||||
|
update_missing_tags: bool = False,
|
||||||
|
) -> set[str] | None:
|
||||||
|
"""Reconcile asset references with filesystem for a root.
|
||||||
|
|
||||||
|
- Toggle needs_verify per reference using fast mtime/size check
|
||||||
|
- For hashed assets with at least one fast-ok ref: delete stale missing refs
|
||||||
|
- For seed assets with all refs missing: delete Asset and its references
|
||||||
|
- Optionally add/remove 'missing' tags based on fast-ok in this root
|
||||||
|
- Optionally return surviving absolute paths
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session: Database session
|
||||||
|
root: Root type to scan
|
||||||
|
collect_existing_paths: If True, return set of surviving file paths
|
||||||
|
update_missing_tags: If True, update 'missing' tags based on file status
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Set of surviving absolute paths if collect_existing_paths=True, else None
|
||||||
"""
|
"""
|
||||||
Scan the given roots and seed the assets into the database.
|
prefixes = get_prefixes_for_root(root)
|
||||||
|
if not prefixes:
|
||||||
|
return set() if collect_existing_paths else None
|
||||||
|
|
||||||
|
rows = get_references_for_prefixes(
|
||||||
|
session, prefixes, include_missing=update_missing_tags
|
||||||
|
)
|
||||||
|
|
||||||
|
by_asset: dict[str, _AssetAccumulator] = {}
|
||||||
|
for row in rows:
|
||||||
|
acc = by_asset.get(row.asset_id)
|
||||||
|
if acc is None:
|
||||||
|
acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []}
|
||||||
|
by_asset[row.asset_id] = acc
|
||||||
|
|
||||||
|
fast_ok = False
|
||||||
|
try:
|
||||||
|
exists = True
|
||||||
|
fast_ok = verify_file_unchanged(
|
||||||
|
mtime_db=row.mtime_ns,
|
||||||
|
size_db=acc["size_db"],
|
||||||
|
stat_result=os.stat(row.file_path, follow_symlinks=True),
|
||||||
|
)
|
||||||
|
except FileNotFoundError:
|
||||||
|
exists = False
|
||||||
|
except PermissionError:
|
||||||
|
exists = True
|
||||||
|
logging.debug("Permission denied accessing %s", row.file_path)
|
||||||
|
except OSError as e:
|
||||||
|
exists = False
|
||||||
|
logging.debug("OSError checking %s: %s", row.file_path, e)
|
||||||
|
|
||||||
|
acc["refs"].append(
|
||||||
|
{
|
||||||
|
"ref_id": row.reference_id,
|
||||||
|
"fp": row.file_path,
|
||||||
|
"exists": exists,
|
||||||
|
"fast_ok": fast_ok,
|
||||||
|
"needs_verify": row.needs_verify,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
to_set_verify: list[str] = []
|
||||||
|
to_clear_verify: list[str] = []
|
||||||
|
stale_ref_ids: list[str] = []
|
||||||
|
to_mark_missing: list[str] = []
|
||||||
|
to_clear_missing: list[str] = []
|
||||||
|
survivors: set[str] = set()
|
||||||
|
|
||||||
|
for aid, acc in by_asset.items():
|
||||||
|
a_hash = acc["hash"]
|
||||||
|
refs = acc["refs"]
|
||||||
|
any_fast_ok = any(r["fast_ok"] for r in refs)
|
||||||
|
all_missing = all(not r["exists"] for r in refs)
|
||||||
|
|
||||||
|
for r in refs:
|
||||||
|
if not r["exists"]:
|
||||||
|
to_mark_missing.append(r["ref_id"])
|
||||||
|
continue
|
||||||
|
if r["fast_ok"]:
|
||||||
|
to_clear_missing.append(r["ref_id"])
|
||||||
|
if r["needs_verify"]:
|
||||||
|
to_clear_verify.append(r["ref_id"])
|
||||||
|
if not r["fast_ok"] and not r["needs_verify"]:
|
||||||
|
to_set_verify.append(r["ref_id"])
|
||||||
|
|
||||||
|
if a_hash is None:
|
||||||
|
if refs and all_missing:
|
||||||
|
delete_orphaned_seed_asset(session, aid)
|
||||||
|
else:
|
||||||
|
for r in refs:
|
||||||
|
if r["exists"]:
|
||||||
|
survivors.add(os.path.abspath(r["fp"]))
|
||||||
|
continue
|
||||||
|
|
||||||
|
if any_fast_ok:
|
||||||
|
for r in refs:
|
||||||
|
if not r["exists"]:
|
||||||
|
stale_ref_ids.append(r["ref_id"])
|
||||||
|
if update_missing_tags:
|
||||||
|
try:
|
||||||
|
remove_missing_tag_for_asset_id(session, asset_id=aid)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(
|
||||||
|
"Failed to remove missing tag for asset %s: %s", aid, e
|
||||||
|
)
|
||||||
|
elif update_missing_tags:
|
||||||
|
try:
|
||||||
|
add_missing_tag_for_asset_id(session, asset_id=aid, origin="automatic")
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("Failed to add missing tag for asset %s: %s", aid, e)
|
||||||
|
|
||||||
|
for r in refs:
|
||||||
|
if r["exists"]:
|
||||||
|
survivors.add(os.path.abspath(r["fp"]))
|
||||||
|
|
||||||
|
delete_references_by_ids(session, stale_ref_ids)
|
||||||
|
stale_set = set(stale_ref_ids)
|
||||||
|
to_mark_missing = [ref_id for ref_id in to_mark_missing if ref_id not in stale_set]
|
||||||
|
bulk_update_is_missing(session, to_mark_missing, value=True)
|
||||||
|
bulk_update_is_missing(session, to_clear_missing, value=False)
|
||||||
|
bulk_update_needs_verify(session, to_set_verify, value=True)
|
||||||
|
bulk_update_needs_verify(session, to_clear_verify, value=False)
|
||||||
|
|
||||||
|
return survivors if collect_existing_paths else None
|
||||||
|
|
||||||
|
|
||||||
|
def sync_root_safely(root: RootType) -> set[str]:
|
||||||
|
"""Sync a single root's references with the filesystem.
|
||||||
|
|
||||||
|
Returns survivors (existing paths) or empty set on failure.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with create_session() as sess:
|
||||||
|
survivors = sync_references_with_filesystem(
|
||||||
|
sess,
|
||||||
|
root,
|
||||||
|
collect_existing_paths=True,
|
||||||
|
update_missing_tags=True,
|
||||||
|
)
|
||||||
|
sess.commit()
|
||||||
|
return survivors or set()
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception("fast DB scan failed for %s: %s", root, e)
|
||||||
|
return set()
|
||||||
|
|
||||||
|
|
||||||
|
def mark_missing_outside_prefixes_safely(prefixes: list[str]) -> int:
|
||||||
|
"""Mark references as missing when outside the given prefixes.
|
||||||
|
|
||||||
|
This is a non-destructive soft-delete. Returns count marked or 0 on failure.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with create_session() as sess:
|
||||||
|
count = mark_assets_missing_outside_prefixes(sess, prefixes)
|
||||||
|
sess.commit()
|
||||||
|
return count
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception("marking missing assets failed: %s", e)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]:
|
||||||
|
"""Collect all file paths for the given roots."""
|
||||||
|
paths: list[str] = []
|
||||||
|
if "models" in roots:
|
||||||
|
paths.extend(collect_models_files())
|
||||||
|
if "input" in roots:
|
||||||
|
paths.extend(list_files_recursively(folder_paths.get_input_directory()))
|
||||||
|
if "output" in roots:
|
||||||
|
paths.extend(list_files_recursively(folder_paths.get_output_directory()))
|
||||||
|
return paths
|
||||||
|
|
||||||
|
|
||||||
|
def build_asset_specs(
|
||||||
|
paths: list[str],
|
||||||
|
existing_paths: set[str],
|
||||||
|
enable_metadata_extraction: bool = True,
|
||||||
|
compute_hashes: bool = False,
|
||||||
|
) -> tuple[list[SeedAssetSpec], set[str], int]:
|
||||||
|
"""Build asset specs from paths, returning (specs, tag_pool, skipped_count).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paths: List of file paths to process
|
||||||
|
existing_paths: Set of paths that already exist in the database
|
||||||
|
enable_metadata_extraction: If True, extract tier 1 & 2 metadata
|
||||||
|
compute_hashes: If True, compute blake3 hashes (slow for large files)
|
||||||
|
"""
|
||||||
|
specs: list[SeedAssetSpec] = []
|
||||||
|
tag_pool: set[str] = set()
|
||||||
|
skipped = 0
|
||||||
|
|
||||||
|
for p in paths:
|
||||||
|
abs_p = os.path.abspath(p)
|
||||||
|
if abs_p in existing_paths:
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
stat_p = os.stat(abs_p, follow_symlinks=True)
|
||||||
|
except OSError:
|
||||||
|
continue
|
||||||
|
if not stat_p.st_size:
|
||||||
|
continue
|
||||||
|
name, tags = get_name_and_tags_from_asset_path(abs_p)
|
||||||
|
rel_fname = compute_relative_filename(abs_p)
|
||||||
|
|
||||||
|
# Extract metadata (tier 1: filesystem, tier 2: safetensors header)
|
||||||
|
metadata = None
|
||||||
|
if enable_metadata_extraction:
|
||||||
|
metadata = extract_file_metadata(
|
||||||
|
abs_p,
|
||||||
|
stat_result=stat_p,
|
||||||
|
enable_safetensors=True,
|
||||||
|
relative_filename=rel_fname,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute hash if requested
|
||||||
|
asset_hash: str | None = None
|
||||||
|
if compute_hashes:
|
||||||
|
try:
|
||||||
|
digest = compute_blake3_hash(abs_p)
|
||||||
|
asset_hash = "blake3:" + digest
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("Failed to hash %s: %s", abs_p, e)
|
||||||
|
|
||||||
|
mime_type = metadata.content_type if metadata else None
|
||||||
|
specs.append(
|
||||||
|
{
|
||||||
|
"abs_path": abs_p,
|
||||||
|
"size_bytes": stat_p.st_size,
|
||||||
|
"mtime_ns": get_mtime_ns(stat_p),
|
||||||
|
"info_name": name,
|
||||||
|
"tags": tags,
|
||||||
|
"fname": rel_fname,
|
||||||
|
"metadata": metadata,
|
||||||
|
"hash": asset_hash,
|
||||||
|
"mime_type": mime_type,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tag_pool.update(tags)
|
||||||
|
|
||||||
|
return specs, tag_pool, skipped
|
||||||
|
|
||||||
|
|
||||||
|
def build_stub_specs(
|
||||||
|
paths: list[str],
|
||||||
|
existing_paths: set[str],
|
||||||
|
) -> tuple[list[SeedAssetSpec], set[str], int]:
|
||||||
|
"""Build minimal stub specs for fast phase scanning.
|
||||||
|
|
||||||
|
Only collects filesystem metadata (stat), no file content reading.
|
||||||
|
This is the fastest possible scan to populate the asset database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paths: List of file paths to process
|
||||||
|
existing_paths: Set of paths that already exist in the database
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (specs, tag_pool, skipped_count)
|
||||||
|
"""
|
||||||
|
specs: list[SeedAssetSpec] = []
|
||||||
|
tag_pool: set[str] = set()
|
||||||
|
skipped = 0
|
||||||
|
|
||||||
|
for p in paths:
|
||||||
|
abs_p = os.path.abspath(p)
|
||||||
|
if abs_p in existing_paths:
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
stat_p = os.stat(abs_p, follow_symlinks=True)
|
||||||
|
except OSError:
|
||||||
|
continue
|
||||||
|
if not stat_p.st_size:
|
||||||
|
continue
|
||||||
|
|
||||||
|
name, tags = get_name_and_tags_from_asset_path(abs_p)
|
||||||
|
rel_fname = compute_relative_filename(abs_p)
|
||||||
|
|
||||||
|
specs.append(
|
||||||
|
{
|
||||||
|
"abs_path": abs_p,
|
||||||
|
"size_bytes": stat_p.st_size,
|
||||||
|
"mtime_ns": get_mtime_ns(stat_p),
|
||||||
|
"info_name": name,
|
||||||
|
"tags": tags,
|
||||||
|
"fname": rel_fname,
|
||||||
|
"metadata": None,
|
||||||
|
"hash": None,
|
||||||
|
"mime_type": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tag_pool.update(tags)
|
||||||
|
|
||||||
|
return specs, tag_pool, skipped
|
||||||
|
|
||||||
|
|
||||||
|
def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
|
||||||
|
"""Insert asset specs into database, returning count of created refs."""
|
||||||
|
if not specs:
|
||||||
|
return 0
|
||||||
|
with create_session() as sess:
|
||||||
|
if tag_pool:
|
||||||
|
ensure_tags_exist(sess, tag_pool, tag_type="user")
|
||||||
|
result = batch_insert_seed_assets(sess, specs=specs, owner_id="")
|
||||||
|
sess.commit()
|
||||||
|
return result.inserted_refs
|
||||||
|
|
||||||
|
|
||||||
|
def seed_assets(
|
||||||
|
roots: tuple[RootType, ...],
|
||||||
|
enable_logging: bool = False,
|
||||||
|
compute_hashes: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Scan the given roots and seed the assets into the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
roots: Tuple of root types to scan (models, input, output)
|
||||||
|
enable_logging: If True, log progress and completion messages
|
||||||
|
compute_hashes: If True, compute blake3 hashes (slow for large files)
|
||||||
|
|
||||||
|
Note: This function does not mark missing assets.
|
||||||
|
Call mark_missing_outside_prefixes_safely separately if cleanup is needed.
|
||||||
"""
|
"""
|
||||||
if not dependencies_available():
|
if not dependencies_available():
|
||||||
if enable_logging:
|
if enable_logging:
|
||||||
logging.warning("Database dependencies not available, skipping assets scan")
|
logging.warning("Database dependencies not available, skipping assets scan")
|
||||||
return
|
return
|
||||||
|
|
||||||
t_start = time.perf_counter()
|
t_start = time.perf_counter()
|
||||||
created = 0
|
|
||||||
skipped_existing = 0
|
|
||||||
orphans_pruned = 0
|
|
||||||
paths: list[str] = []
|
|
||||||
try:
|
|
||||||
existing_paths: set[str] = set()
|
|
||||||
for r in roots:
|
|
||||||
try:
|
|
||||||
survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
|
|
||||||
if survivors:
|
|
||||||
existing_paths.update(survivors)
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception("fast DB scan failed for %s: %s", r, e)
|
|
||||||
|
|
||||||
try:
|
existing_paths: set[str] = set()
|
||||||
orphans_pruned = _prune_orphaned_assets(roots)
|
for r in roots:
|
||||||
except Exception as e:
|
existing_paths.update(sync_root_safely(r))
|
||||||
logging.exception("orphan pruning failed: %s", e)
|
|
||||||
|
|
||||||
if "models" in roots:
|
paths = collect_paths_for_roots(roots)
|
||||||
paths.extend(collect_models_files())
|
specs, tag_pool, skipped_existing = build_asset_specs(
|
||||||
if "input" in roots:
|
paths, existing_paths, compute_hashes=compute_hashes
|
||||||
paths.extend(list_tree(folder_paths.get_input_directory()))
|
)
|
||||||
if "output" in roots:
|
created = insert_asset_specs(specs, tag_pool)
|
||||||
paths.extend(list_tree(folder_paths.get_output_directory()))
|
|
||||||
|
|
||||||
specs: list[dict] = []
|
if enable_logging:
|
||||||
tag_pool: set[str] = set()
|
logging.info(
|
||||||
for p in paths:
|
"Assets scan(roots=%s) completed in %.3fs "
|
||||||
abs_p = os.path.abspath(p)
|
"(created=%d, skipped_existing=%d, total_seen=%d)",
|
||||||
if abs_p in existing_paths:
|
roots,
|
||||||
skipped_existing += 1
|
time.perf_counter() - t_start,
|
||||||
continue
|
created,
|
||||||
try:
|
skipped_existing,
|
||||||
stat_p = os.stat(abs_p, follow_symlinks=False)
|
len(paths),
|
||||||
except OSError:
|
)
|
||||||
continue
|
|
||||||
# skip empty files
|
|
||||||
if not stat_p.st_size:
|
|
||||||
continue
|
|
||||||
name, tags = get_name_and_tags_from_asset_path(abs_p)
|
|
||||||
specs.append(
|
|
||||||
{
|
|
||||||
"abs_path": abs_p,
|
|
||||||
"size_bytes": stat_p.st_size,
|
|
||||||
"mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
|
|
||||||
"info_name": name,
|
|
||||||
"tags": tags,
|
|
||||||
"fname": compute_relative_filename(abs_p),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
for t in tags:
|
|
||||||
tag_pool.add(t)
|
|
||||||
# if no file specs, nothing to do
|
|
||||||
if not specs:
|
|
||||||
return
|
|
||||||
with create_session() as sess:
|
|
||||||
if tag_pool:
|
|
||||||
ensure_tags_exist(sess, tag_pool, tag_type="user")
|
|
||||||
|
|
||||||
result = seed_from_paths_batch(sess, specs=specs, owner_id="")
|
|
||||||
created += result["inserted_infos"]
|
|
||||||
sess.commit()
|
|
||||||
finally:
|
|
||||||
if enable_logging:
|
|
||||||
logging.info(
|
|
||||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
|
|
||||||
roots,
|
|
||||||
time.perf_counter() - t_start,
|
|
||||||
created,
|
|
||||||
skipped_existing,
|
|
||||||
orphans_pruned,
|
|
||||||
len(paths),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
|
# Enrichment level constants
|
||||||
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
|
ENRICHMENT_STUB = 0 # Fast scan: path, size, mtime only
|
||||||
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
|
ENRICHMENT_METADATA = 1 # Metadata extracted (safetensors header, mime type)
|
||||||
if not all_prefixes:
|
ENRICHMENT_HASHED = 2 # Hash computed (blake3)
|
||||||
return 0
|
|
||||||
|
|
||||||
def make_prefix_condition(prefix: str):
|
|
||||||
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
|
|
||||||
escaped, esc = escape_like_prefix(base)
|
|
||||||
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
|
|
||||||
|
|
||||||
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
|
|
||||||
|
|
||||||
orphan_subq = (
|
|
||||||
sqlalchemy.select(Asset.id)
|
|
||||||
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
|
|
||||||
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
|
|
||||||
).scalar_subquery()
|
|
||||||
|
|
||||||
with create_session() as sess:
|
|
||||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
|
|
||||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
|
|
||||||
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
|
|
||||||
sess.commit()
|
|
||||||
return result.rowcount
|
|
||||||
|
|
||||||
|
|
||||||
def _fast_db_consistency_pass(
|
def get_unenriched_assets_for_roots(
|
||||||
root: RootType,
|
roots: tuple[RootType, ...],
|
||||||
*,
|
max_level: int = ENRICHMENT_STUB,
|
||||||
collect_existing_paths: bool = False,
|
limit: int = 1000,
|
||||||
update_missing_tags: bool = False,
|
) -> list:
|
||||||
) -> set[str] | None:
|
"""Get assets that need enrichment for the given roots.
|
||||||
"""Fast DB+FS pass for a root:
|
|
||||||
- Toggle needs_verify per state using fast check
|
Args:
|
||||||
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
|
roots: Tuple of root types to scan
|
||||||
- For seed assets with all states missing: delete Asset and its AssetInfos
|
max_level: Maximum enrichment level to include
|
||||||
- Optionally add/remove 'missing' tags based on fast-ok in this root
|
limit: Maximum number of rows to return
|
||||||
- Optionally return surviving absolute paths
|
|
||||||
|
Returns:
|
||||||
|
List of UnenrichedReferenceRow
|
||||||
"""
|
"""
|
||||||
prefixes = prefixes_for_root(root)
|
prefixes: list[str] = []
|
||||||
if not prefixes:
|
for root in roots:
|
||||||
return set() if collect_existing_paths else None
|
prefixes.extend(get_prefixes_for_root(root))
|
||||||
|
|
||||||
conds = []
|
if not prefixes:
|
||||||
for p in prefixes:
|
return []
|
||||||
base = os.path.abspath(p)
|
|
||||||
if not base.endswith(os.sep):
|
|
||||||
base += os.sep
|
|
||||||
escaped, esc = escape_like_prefix(base)
|
|
||||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
|
||||||
|
|
||||||
with create_session() as sess:
|
with create_session() as sess:
|
||||||
rows = (
|
return get_unenriched_references(
|
||||||
sess.execute(
|
sess, prefixes, max_level=max_level, limit=limit
|
||||||
sqlalchemy.select(
|
)
|
||||||
AssetCacheState.id,
|
|
||||||
AssetCacheState.file_path,
|
|
||||||
AssetCacheState.mtime_ns,
|
|
||||||
AssetCacheState.needs_verify,
|
|
||||||
AssetCacheState.asset_id,
|
|
||||||
Asset.hash,
|
|
||||||
Asset.size_bytes,
|
|
||||||
)
|
|
||||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
|
||||||
.where(sqlalchemy.or_(*conds))
|
|
||||||
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
|
||||||
)
|
|
||||||
).all()
|
|
||||||
|
|
||||||
by_asset: dict[str, dict] = {}
|
|
||||||
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
|
|
||||||
acc = by_asset.get(aid)
|
|
||||||
if acc is None:
|
|
||||||
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
|
|
||||||
by_asset[aid] = acc
|
|
||||||
|
|
||||||
fast_ok = False
|
def enrich_asset(
|
||||||
try:
|
file_path: str,
|
||||||
exists = True
|
reference_id: str,
|
||||||
fast_ok = fast_asset_file_check(
|
asset_id: str,
|
||||||
mtime_db=mtime_db,
|
extract_metadata: bool = True,
|
||||||
size_db=acc["size_db"],
|
compute_hash: bool = False,
|
||||||
stat_result=os.stat(fp, follow_symlinks=True),
|
) -> int:
|
||||||
)
|
"""Enrich a single asset with metadata and/or hash.
|
||||||
except FileNotFoundError:
|
|
||||||
exists = False
|
|
||||||
except OSError:
|
|
||||||
exists = False
|
|
||||||
|
|
||||||
acc["states"].append({
|
Args:
|
||||||
"sid": sid,
|
file_path: Absolute path to the file
|
||||||
"fp": fp,
|
reference_id: ID of the reference to update
|
||||||
"exists": exists,
|
asset_id: ID of the asset to update (for mime_type and hash)
|
||||||
"fast_ok": fast_ok,
|
extract_metadata: If True, extract safetensors header and mime type
|
||||||
"needs_verify": bool(needs_verify),
|
compute_hash: If True, compute blake3 hash
|
||||||
})
|
|
||||||
|
|
||||||
to_set_verify: list[int] = []
|
Returns:
|
||||||
to_clear_verify: list[int] = []
|
New enrichment level achieved
|
||||||
stale_state_ids: list[int] = []
|
"""
|
||||||
survivors: set[str] = set()
|
new_level = ENRICHMENT_STUB
|
||||||
|
|
||||||
for aid, acc in by_asset.items():
|
try:
|
||||||
a_hash = acc["hash"]
|
stat_p = os.stat(file_path, follow_symlinks=True)
|
||||||
states = acc["states"]
|
except OSError:
|
||||||
any_fast_ok = any(s["fast_ok"] for s in states)
|
return new_level
|
||||||
all_missing = all(not s["exists"] for s in states)
|
|
||||||
|
|
||||||
for s in states:
|
rel_fname = compute_relative_filename(file_path)
|
||||||
if not s["exists"]:
|
mime_type: str | None = None
|
||||||
continue
|
metadata = None
|
||||||
if s["fast_ok"] and s["needs_verify"]:
|
|
||||||
to_clear_verify.append(s["sid"])
|
|
||||||
if not s["fast_ok"] and not s["needs_verify"]:
|
|
||||||
to_set_verify.append(s["sid"])
|
|
||||||
|
|
||||||
if a_hash is None:
|
if extract_metadata:
|
||||||
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
|
metadata = extract_file_metadata(
|
||||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
|
file_path,
|
||||||
asset = sess.get(Asset, aid)
|
stat_result=stat_p,
|
||||||
if asset:
|
enable_safetensors=True,
|
||||||
sess.delete(asset)
|
relative_filename=rel_fname,
|
||||||
else:
|
)
|
||||||
for s in states:
|
if metadata:
|
||||||
if s["exists"]:
|
mime_type = metadata.content_type
|
||||||
survivors.add(os.path.abspath(s["fp"]))
|
new_level = ENRICHMENT_METADATA
|
||||||
continue
|
|
||||||
|
|
||||||
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
|
full_hash: str | None = None
|
||||||
for s in states:
|
if compute_hash:
|
||||||
if not s["exists"]:
|
try:
|
||||||
stale_state_ids.append(s["sid"])
|
digest = compute_blake3_hash(file_path)
|
||||||
if update_missing_tags:
|
full_hash = f"blake3:{digest}"
|
||||||
with contextlib.suppress(Exception):
|
if not extract_metadata or metadata:
|
||||||
remove_missing_tag_for_asset_id(sess, asset_id=aid)
|
new_level = ENRICHMENT_HASHED
|
||||||
elif update_missing_tags:
|
except Exception as e:
|
||||||
with contextlib.suppress(Exception):
|
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||||
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
|
|
||||||
|
|
||||||
for s in states:
|
with create_session() as sess:
|
||||||
if s["exists"]:
|
if extract_metadata and metadata:
|
||||||
survivors.add(os.path.abspath(s["fp"]))
|
user_metadata = metadata.to_user_metadata()
|
||||||
|
set_reference_metadata(sess, reference_id, user_metadata)
|
||||||
|
|
||||||
if stale_state_ids:
|
if full_hash:
|
||||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
|
existing = get_asset_by_hash(sess, full_hash)
|
||||||
if to_set_verify:
|
if existing and existing.id != asset_id:
|
||||||
sess.execute(
|
reassign_asset_references(sess, asset_id, existing.id, reference_id)
|
||||||
sqlalchemy.update(AssetCacheState)
|
delete_orphaned_seed_asset(sess, asset_id)
|
||||||
.where(AssetCacheState.id.in_(to_set_verify))
|
if mime_type:
|
||||||
.values(needs_verify=True)
|
update_asset_hash_and_mime(sess, existing.id, mime_type=mime_type)
|
||||||
)
|
else:
|
||||||
if to_clear_verify:
|
update_asset_hash_and_mime(sess, asset_id, full_hash, mime_type)
|
||||||
sess.execute(
|
elif mime_type:
|
||||||
sqlalchemy.update(AssetCacheState)
|
update_asset_hash_and_mime(sess, asset_id, mime_type=mime_type)
|
||||||
.where(AssetCacheState.id.in_(to_clear_verify))
|
|
||||||
.values(needs_verify=False)
|
bulk_update_enrichment_level(sess, [reference_id], new_level)
|
||||||
)
|
|
||||||
sess.commit()
|
sess.commit()
|
||||||
return survivors if collect_existing_paths else None
|
|
||||||
|
return new_level
|
||||||
|
|
||||||
|
|
||||||
|
def enrich_assets_batch(
|
||||||
|
rows: list,
|
||||||
|
extract_metadata: bool = True,
|
||||||
|
compute_hash: bool = False,
|
||||||
|
) -> tuple[int, list[str]]:
|
||||||
|
"""Enrich a batch of assets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rows: List of UnenrichedReferenceRow from get_unenriched_assets_for_roots
|
||||||
|
extract_metadata: If True, extract metadata for each asset
|
||||||
|
compute_hash: If True, compute hash for each asset
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (enriched_count, failed_reference_ids)
|
||||||
|
"""
|
||||||
|
enriched = 0
|
||||||
|
failed_ids: list[str] = []
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
try:
|
||||||
|
new_level = enrich_asset(
|
||||||
|
file_path=row.file_path,
|
||||||
|
reference_id=row.reference_id,
|
||||||
|
asset_id=row.asset_id,
|
||||||
|
extract_metadata=extract_metadata,
|
||||||
|
compute_hash=compute_hash,
|
||||||
|
)
|
||||||
|
if new_level > row.enrichment_level:
|
||||||
|
enriched += 1
|
||||||
|
else:
|
||||||
|
failed_ids.append(row.reference_id)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning("Failed to enrich %s: %s", row.file_path, e)
|
||||||
|
failed_ids.append(row.reference_id)
|
||||||
|
|
||||||
|
return enriched, failed_ids
|
||||||
|
|||||||
743
app/assets/seeder.py
Normal file
743
app/assets/seeder.py
Normal file
@ -0,0 +1,743 @@
|
|||||||
|
"""Background asset seeder with thread management and cancellation support."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from app.assets.scanner import (
|
||||||
|
ENRICHMENT_METADATA,
|
||||||
|
ENRICHMENT_STUB,
|
||||||
|
RootType,
|
||||||
|
build_stub_specs,
|
||||||
|
collect_paths_for_roots,
|
||||||
|
enrich_assets_batch,
|
||||||
|
get_all_known_prefixes,
|
||||||
|
get_prefixes_for_root,
|
||||||
|
get_unenriched_assets_for_roots,
|
||||||
|
insert_asset_specs,
|
||||||
|
mark_missing_outside_prefixes_safely,
|
||||||
|
sync_root_safely,
|
||||||
|
)
|
||||||
|
from app.database.db import dependencies_available
|
||||||
|
|
||||||
|
|
||||||
|
class ScanInProgressError(Exception):
    """Signals that the requested operation conflicts with an active scan."""
|
||||||
|
|
||||||
|
|
||||||
|
class State(Enum):
    """Seeder state machine states.

    Transitions: IDLE -> RUNNING (start), RUNNING <-> PAUSED (pause/resume),
    RUNNING/PAUSED -> CANCELLING (cancel), any -> IDLE when the worker exits.
    """

    IDLE = "IDLE"  # No scan thread active; start() is permitted.
    RUNNING = "RUNNING"  # Worker thread is actively scanning.
    PAUSED = "PAUSED"  # Worker blocked on the pause event until resume/cancel.
    CANCELLING = "CANCELLING"  # Cancel requested; worker winding down to IDLE.
|
||||||
|
|
||||||
|
|
||||||
|
class ScanPhase(Enum):
    """Scan phase options for the two-phase scanner."""

    FAST = "fast"  # Phase 1: filesystem only (stub records, no file reads)
    ENRICH = "enrich"  # Phase 2: metadata extraction + optional hashing
    FULL = "full"  # Both phases sequentially (fast first, then enrich)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Progress:
    """Progress information for a scan operation."""

    scanned: int = 0  # Files processed so far in the current phase.
    total: int = 0  # Total files discovered for the current phase.
    created: int = 0  # New asset records inserted.
    skipped: int = 0  # Files skipped because a record already existed.
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ScanStatus:
    """Current status snapshot of the asset seeder."""

    state: State  # Current state-machine state.
    progress: Progress | None  # Copy of the live counters, or None when idle.
    errors: list[str] = field(default_factory=list)  # Accumulated error messages.


# Signature for progress observers: receives a snapshot Progress copy.
ProgressCallback = Callable[[Progress], None]
|
||||||
|
|
||||||
|
|
||||||
|
class AssetSeeder:
    """Singleton class managing background asset scanning.

    Thread-safe singleton that spawns ephemeral daemon threads for scanning.
    Each scan creates a new thread that exits when complete.
    """

    # Process-wide singleton instance, guarded by _instance_lock.
    _instance: "AssetSeeder | None" = None
    _instance_lock = threading.Lock()

    def __new__(cls) -> "AssetSeeder":
        # Creation is serialized; the _initialized flag lets __init__ run
        # its body exactly once even though __init__ is called on every
        # AssetSeeder() expression.
        with cls._instance_lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._initialized = False
            return cls._instance

    def __init__(self) -> None:
        if self._initialized:
            return
        self._initialized = True
        self._lock = threading.Lock()  # Guards all mutable seeder state below.
        self._state = State.IDLE
        self._progress: Progress | None = None
        self._errors: list[str] = []
        self._thread: threading.Thread | None = None
        self._cancel_event = threading.Event()
        self._pause_event = threading.Event()
        self._pause_event.set()  # Start unpaused (set = running, clear = paused)
        self._roots: tuple[RootType, ...] = ()
        self._phase: ScanPhase = ScanPhase.FULL
        # Fix: initialize here so every scan-config attribute exists before
        # the first start(); restart()'s getattr fallback becomes redundant.
        self._prune_first: bool = False
        self._compute_hashes: bool = False
        self._progress_callback: ProgressCallback | None = None
        self._disabled: bool = False

    def disable(self) -> None:
        """Disable the asset seeder, preventing any scans from starting."""
        self._disabled = True
        logging.info("Asset seeder disabled")

    def enable(self) -> None:
        """Enable the asset seeder, allowing scans to start."""
        self._disabled = False
        logging.info("Asset seeder enabled")

    def is_disabled(self) -> bool:
        """Check if the asset seeder is disabled."""
        return self._disabled
|
||||||
|
|
||||||
|
    def start(
        self,
        roots: tuple[RootType, ...] = ("models", "input", "output"),
        phase: ScanPhase = ScanPhase.FULL,
        progress_callback: ProgressCallback | None = None,
        prune_first: bool = False,
        compute_hashes: bool = False,
    ) -> bool:
        """Start a background scan for the given roots.

        Args:
            roots: Tuple of root types to scan (models, input, output)
            phase: Scan phase to run (FAST, ENRICH, or FULL for both)
            progress_callback: Optional callback called with progress updates
            prune_first: If True, prune orphaned assets before scanning
            compute_hashes: If True, compute blake3 hashes (slow)

        Returns:
            True if scan was started, False if already running
        """
        # NOTE(review): _disabled is read without the lock; a disable() racing
        # with start() may let one scan through — presumably acceptable here.
        if self._disabled:
            logging.debug("Asset seeder is disabled, skipping start")
            return False
        logging.info("Seeder start (roots=%s, phase=%s)", roots, phase.value)
        with self._lock:
            # Only one scan at a time: refuse unless fully IDLE.
            if self._state != State.IDLE:
                logging.info("Asset seeder already running, skipping start")
                return False
            self._state = State.RUNNING
            # Reset all per-scan state while still holding the lock, before
            # the worker thread can observe any of it.
            self._progress = Progress()
            self._errors = []
            self._roots = roots
            self._phase = phase
            self._prune_first = prune_first
            self._compute_hashes = compute_hashes
            self._progress_callback = progress_callback
            self._cancel_event.clear()
            self._pause_event.set()  # Ensure unpaused when starting
            # Daemon thread: must not keep the interpreter alive at shutdown.
            self._thread = threading.Thread(
                target=self._run_scan,
                name="AssetSeeder",
                daemon=True,
            )
            self._thread.start()
        return True
|
||||||
|
|
||||||
|
def start_fast(
|
||||||
|
self,
|
||||||
|
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||||
|
progress_callback: ProgressCallback | None = None,
|
||||||
|
prune_first: bool = False,
|
||||||
|
) -> bool:
|
||||||
|
"""Start a fast scan (phase 1 only) - creates stub records.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
roots: Tuple of root types to scan
|
||||||
|
progress_callback: Optional callback for progress updates
|
||||||
|
prune_first: If True, prune orphaned assets before scanning
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if scan was started, False if already running
|
||||||
|
"""
|
||||||
|
return self.start(
|
||||||
|
roots=roots,
|
||||||
|
phase=ScanPhase.FAST,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
prune_first=prune_first,
|
||||||
|
compute_hashes=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def start_enrich(
|
||||||
|
self,
|
||||||
|
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||||
|
progress_callback: ProgressCallback | None = None,
|
||||||
|
compute_hashes: bool = False,
|
||||||
|
) -> bool:
|
||||||
|
"""Start an enrichment scan (phase 2 only) - extracts metadata and hashes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
roots: Tuple of root types to scan
|
||||||
|
progress_callback: Optional callback for progress updates
|
||||||
|
compute_hashes: If True, compute blake3 hashes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if scan was started, False if already running
|
||||||
|
"""
|
||||||
|
return self.start(
|
||||||
|
roots=roots,
|
||||||
|
phase=ScanPhase.ENRICH,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
prune_first=False,
|
||||||
|
compute_hashes=compute_hashes,
|
||||||
|
)
|
||||||
|
|
||||||
|
    def cancel(self) -> bool:
        """Request cancellation of the current scan (running or paused).

        Returns:
            True if cancellation was requested, False if no scan is active
        """
        with self._lock:
            if self._state not in (State.RUNNING, State.PAUSED):
                return False
            logging.info("Asset seeder cancelling (was %s)", self._state.value)
            self._state = State.CANCELLING
            self._cancel_event.set()
            self._pause_event.set()  # Unblock if paused so thread can exit
            return True

    def stop(self) -> bool:
        """Stop the current scan (alias for cancel).

        Returns:
            True if stop was requested, False if not running
        """
        return self.cancel()
|
||||||
|
|
||||||
|
    def pause(self) -> bool:
        """Pause the current scan.

        The scan will complete its current batch before pausing (the worker
        only blocks at checkpoints in _check_pause_and_cancel).

        Returns:
            True if pause was requested, False if not running
        """
        with self._lock:
            if self._state != State.RUNNING:
                return False
            logging.info("Asset seeder pausing")
            self._state = State.PAUSED
            self._pause_event.clear()  # clear = worker blocks at next checkpoint
            return True

    def resume(self) -> bool:
        """Resume a paused scan.

        This is a noop if the scan is not in the PAUSED state.

        Returns:
            True if resumed, False if not paused
        """
        with self._lock:
            if self._state != State.PAUSED:
                return False
            logging.info("Asset seeder resuming")
            self._state = State.RUNNING
            self._pause_event.set()  # Wakes the worker blocked in wait()
            self._emit_event("assets.seed.resumed", {})
            return True
|
||||||
|
|
||||||
|
    def restart(
        self,
        roots: tuple[RootType, ...] | None = None,
        phase: ScanPhase | None = None,
        progress_callback: ProgressCallback | None = None,
        prune_first: bool | None = None,
        compute_hashes: bool | None = None,
        timeout: float = 5.0,
    ) -> bool:
        """Cancel any running scan and start a new one.

        Args:
            roots: Roots to scan (defaults to previous roots)
            phase: Scan phase (defaults to previous phase)
            progress_callback: Progress callback (defaults to previous)
            prune_first: Prune before scan (defaults to previous)
            compute_hashes: Compute hashes (defaults to previous)
            timeout: Max seconds to wait for current scan to stop

        Returns:
            True if new scan was started, False if failed to stop previous
        """
        logging.info("Asset seeder restart requested")
        # Snapshot the previous configuration under the lock so the new scan
        # can inherit any parameter the caller did not override.
        with self._lock:
            prev_roots = self._roots
            prev_phase = self._phase
            prev_callback = self._progress_callback
            prev_prune = getattr(self, "_prune_first", False)
            prev_hashes = self._compute_hashes

        # cancel() is a no-op when idle; wait() then returns immediately.
        self.cancel()
        if not self.wait(timeout=timeout):
            # Old worker did not exit in time; refuse to stack a second scan.
            return False

        cb = progress_callback if progress_callback is not None else prev_callback
        return self.start(
            roots=roots if roots is not None else prev_roots,
            phase=phase if phase is not None else prev_phase,
            progress_callback=cb,
            prune_first=prune_first if prune_first is not None else prev_prune,
            compute_hashes=(
                compute_hashes if compute_hashes is not None else prev_hashes
            ),
        )
|
||||||
|
|
||||||
|
    def wait(self, timeout: float | None = None) -> bool:
        """Wait for the current scan to complete.

        Args:
            timeout: Maximum seconds to wait, or None for no timeout

        Returns:
            True if the scan completed (or no scan thread exists),
            False if the timeout expired while the thread was still alive
        """
        # Grab the thread reference under the lock, but join OUTSIDE it so
        # the worker (which also takes the lock) cannot deadlock with us.
        with self._lock:
            thread = self._thread
        if thread is None:
            return True
        thread.join(timeout=timeout)
        return not thread.is_alive()
|
||||||
|
|
||||||
|
def get_status(self) -> ScanStatus:
|
||||||
|
"""Get the current status and progress of the seeder."""
|
||||||
|
with self._lock:
|
||||||
|
return ScanStatus(
|
||||||
|
state=self._state,
|
||||||
|
progress=Progress(
|
||||||
|
scanned=self._progress.scanned,
|
||||||
|
total=self._progress.total,
|
||||||
|
created=self._progress.created,
|
||||||
|
skipped=self._progress.skipped,
|
||||||
|
)
|
||||||
|
if self._progress
|
||||||
|
else None,
|
||||||
|
errors=list(self._errors),
|
||||||
|
)
|
||||||
|
|
||||||
|
    def shutdown(self, timeout: float = 5.0) -> None:
        """Gracefully shutdown: cancel any running scan and wait for thread.

        Args:
            timeout: Maximum seconds to wait for thread to exit
        """
        self.cancel()
        self.wait(timeout=timeout)
        # Drop the thread reference so subsequent wait() calls return
        # immediately; the daemon thread dies on its own if it overran.
        with self._lock:
            self._thread = None
|
||||||
|
|
||||||
|
    def mark_missing_outside_prefixes(self) -> int:
        """Mark cache states as missing when outside all known root prefixes.

        This is a non-destructive soft-delete operation. Assets and their
        metadata are preserved, but cache states are flagged as missing.
        They can be restored if the file reappears in a future scan.

        This operation is decoupled from scanning to prevent partial scans
        from accidentally marking assets belonging to other roots.

        Should be called explicitly when cleanup is desired, typically after
        a full scan of all roots or during maintenance.

        Returns:
            Number of cache states marked as missing

        Raises:
            ScanInProgressError: If a scan is currently running
        """
        # Claim the state machine (RUNNING) so start() refuses to launch a
        # scan while the prune is in flight; released in the finally below.
        with self._lock:
            if self._state != State.IDLE:
                raise ScanInProgressError("Cannot mark missing assets while scan is running")
            self._state = State.RUNNING

        try:
            if not dependencies_available():
                logging.warning(
                    "Database dependencies not available, skipping mark missing"
                )
                return 0

            all_prefixes = get_all_known_prefixes()
            marked = mark_missing_outside_prefixes_safely(all_prefixes)
            if marked > 0:
                logging.info("Marked %d cache states as missing", marked)
            return marked
        finally:
            # Always release back to IDLE, even on the early return or error.
            with self._lock:
                self._state = State.IDLE
|
||||||
|
|
||||||
|
    def _is_cancelled(self) -> bool:
        """Check if cancellation has been requested."""
        return self._cancel_event.is_set()

    def _check_pause_and_cancel(self) -> bool:
        """Block while paused, then check if cancelled.

        Call this at checkpoint locations in scan loops. It will:
        1. Block indefinitely while paused (until resume or cancel)
        2. Return True if cancelled, False to continue

        Returns:
            True if scan should stop, False to continue
        """
        if not self._pause_event.is_set():
            # Announce the pause only once we are actually about to block.
            self._emit_event("assets.seed.paused", {})
        self._pause_event.wait()  # Blocks if paused; cancel() also sets this
        return self._is_cancelled()
|
||||||
|
|
||||||
|
def _emit_event(self, event_type: str, data: dict) -> None:
|
||||||
|
"""Emit a WebSocket event if server is available."""
|
||||||
|
try:
|
||||||
|
from server import PromptServer
|
||||||
|
|
||||||
|
if hasattr(PromptServer, "instance") and PromptServer.instance:
|
||||||
|
PromptServer.instance.send_sync(event_type, data)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
    def _update_progress(
        self,
        scanned: int | None = None,
        total: int | None = None,
        created: int | None = None,
        skipped: int | None = None,
    ) -> None:
        """Update progress counters (thread-safe).

        Only counters passed as non-None are updated. If a callback is
        registered, it receives a *copy* of the progress and is invoked
        outside the lock so a slow callback cannot block other seeder calls.
        """
        callback: ProgressCallback | None = None
        progress: Progress | None = None

        with self._lock:
            if self._progress is None:
                return
            if scanned is not None:
                self._progress.scanned = scanned
            if total is not None:
                self._progress.total = total
            if created is not None:
                self._progress.created = created
            if skipped is not None:
                self._progress.skipped = skipped
            if self._progress_callback:
                callback = self._progress_callback
                # Snapshot under the lock so the callback never observes
                # concurrent mutation of the live counters.
                progress = Progress(
                    scanned=self._progress.scanned,
                    total=self._progress.total,
                    created=self._progress.created,
                    skipped=self._progress.skipped,
                )

        if callback and progress:
            try:
                callback(progress)
            except Exception:
                # Callbacks are advisory; swallow their failures.
                pass
|
||||||
|
|
||||||
|
    def _add_error(self, message: str) -> None:
        """Append an error message to the status list (thread-safe)."""
        with self._lock:
            self._errors.append(message)
|
||||||
|
|
||||||
|
def _log_scan_config(self, roots: tuple[RootType, ...]) -> None:
|
||||||
|
"""Log the directories that will be scanned."""
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
|
for root in roots:
|
||||||
|
if root == "models":
|
||||||
|
logging.info(
|
||||||
|
"Asset scan [models] directory: %s",
|
||||||
|
os.path.abspath(folder_paths.models_dir),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prefixes = get_prefixes_for_root(root)
|
||||||
|
if prefixes:
|
||||||
|
logging.info("Asset scan [%s] directories: %s", root, prefixes)
|
||||||
|
|
||||||
|
    def _run_scan(self) -> None:
        """Main scan loop running in background thread.

        Runs the configured phases (fast and/or enrich), emits WebSocket
        progress/lifecycle events, and always resets the state machine to
        IDLE in the finally block — including on cancellation and error.
        """
        t_start = time.perf_counter()
        roots = self._roots
        phase = self._phase
        cancelled = False
        total_created = 0
        total_enriched = 0
        skipped_existing = 0
        total_paths = 0

        try:
            if not dependencies_available():
                self._add_error("Database dependencies not available")
                self._emit_event(
                    "assets.seed.error",
                    {"message": "Database dependencies not available"},
                )
                return

            # Optional pre-scan prune: soft-flag cache states that fall
            # outside every known root prefix.
            if self._prune_first:
                all_prefixes = get_all_known_prefixes()
                marked = mark_missing_outside_prefixes_safely(all_prefixes)
                if marked > 0:
                    logging.info("Marked %d refs as missing before scan", marked)

                if self._check_pause_and_cancel():
                    logging.info("Asset scan cancelled after pruning phase")
                    cancelled = True
                    return

            self._log_scan_config(roots)

            # Phase 1: Fast scan (stub records)
            if phase in (ScanPhase.FAST, ScanPhase.FULL):
                created, skipped, paths = self._run_fast_phase(roots)
                total_created, skipped_existing, total_paths = created, skipped, paths

                # The fast phase may have returned early due to cancel; this
                # checkpoint converts that into the cancelled exit path.
                if self._check_pause_and_cancel():
                    cancelled = True
                    return

                self._emit_event(
                    "assets.seed.fast_complete",
                    {
                        "roots": list(roots),
                        "created": total_created,
                        "skipped": skipped_existing,
                        "total": total_paths,
                    },
                )

            # Phase 2: Enrichment scan (metadata + hashes)
            if phase in (ScanPhase.ENRICH, ScanPhase.FULL):
                if self._check_pause_and_cancel():
                    cancelled = True
                    return

                total_enriched = self._run_enrich_phase(roots)

                self._emit_event(
                    "assets.seed.enrich_complete",
                    {
                        "roots": list(roots),
                        "enriched": total_enriched,
                    },
                )

            elapsed = time.perf_counter() - t_start
            logging.info(
                "Scan(%s, %s) done %.3fs: created=%d enriched=%d skipped=%d",
                roots, phase.value, elapsed, total_created, total_enriched,
                skipped_existing,
            )

            self._emit_event(
                "assets.seed.completed",
                {
                    "phase": phase.value,
                    "total": total_paths,
                    "created": total_created,
                    "enriched": total_enriched,
                    "skipped": skipped_existing,
                    "elapsed": round(elapsed, 3),
                },
            )

        except Exception as e:
            self._add_error(f"Scan failed: {e}")
            logging.exception("Asset scan failed")
            self._emit_event("assets.seed.error", {"message": str(e)})
        finally:
            # All early returns funnel here: report cancellation (if any)
            # and unconditionally release the state machine back to IDLE.
            if cancelled:
                self._emit_event(
                    "assets.seed.cancelled",
                    {
                        "scanned": self._progress.scanned if self._progress else 0,
                        "total": total_paths,
                        "created": total_created,
                    },
                )
            with self._lock:
                self._state = State.IDLE
|
||||||
|
|
||||||
|
    def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
        """Run phase 1: fast scan to create stub records.

        Walks the filesystem only (no file reads), builds stub specs for
        paths not already in the DB, and inserts them in batches with
        pause/cancel checkpoints between batches.

        Returns:
            Tuple of (total_created, skipped_existing, total_paths)
        """
        total_created = 0
        skipped_existing = 0

        # Sync each root first and collect the DB's already-known paths so
        # build_stub_specs can skip them.
        existing_paths: set[str] = set()
        for r in roots:
            if self._check_pause_and_cancel():
                return total_created, skipped_existing, 0
            existing_paths.update(sync_root_safely(r))

        if self._check_pause_and_cancel():
            return total_created, skipped_existing, 0

        paths = collect_paths_for_roots(roots)
        total_paths = len(paths)
        self._update_progress(total=total_paths)

        self._emit_event(
            "assets.seed.started",
            {"roots": list(roots), "total": total_paths, "phase": "fast"},
        )

        # Use stub specs (no metadata extraction, no hashing)
        specs, tag_pool, skipped_existing = build_stub_specs(paths, existing_paths)
        self._update_progress(skipped=skipped_existing)

        if self._check_pause_and_cancel():
            return total_created, skipped_existing, total_paths

        batch_size = 500
        last_progress_time = time.perf_counter()
        progress_interval = 1.0  # Throttle websocket progress events to ~1/s.

        for i in range(0, len(specs), batch_size):
            # Checkpoint between batches: a cancel aborts cleanly with
            # whatever was inserted so far.
            if self._check_pause_and_cancel():
                logging.info(
                    "Fast scan cancelled after %d/%d files (created=%d)",
                    i,
                    len(specs),
                    total_created,
                )
                return total_created, skipped_existing, total_paths

            batch = specs[i : i + batch_size]
            # Only the tags referenced by this batch are passed along.
            batch_tags = {t for spec in batch for t in spec["tags"]}
            try:
                created = insert_asset_specs(batch, batch_tags)
                total_created += created
            except Exception as e:
                # One failed batch should not abort the scan; record and go on.
                self._add_error(f"Batch insert failed at offset {i}: {e}")
                logging.exception("Batch insert failed at offset %d", i)

            scanned = i + len(batch)
            now = time.perf_counter()
            self._update_progress(scanned=scanned, created=total_created)

            if now - last_progress_time >= progress_interval:
                self._emit_event(
                    "assets.seed.progress",
                    {
                        "phase": "fast",
                        "scanned": scanned,
                        "total": len(specs),
                        "created": total_created,
                    },
                )
                last_progress_time = now

        self._update_progress(scanned=len(specs), created=total_created)
        return total_created, skipped_existing, total_paths
|
||||||
|
|
||||||
|
    def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> int:
        """Run phase 2: enrich existing records with metadata and hashes.

        Repeatedly fetches batches of under-enriched assets and enriches
        them until no work remains, cancellation is requested, or the loop
        stops making progress.

        Returns:
            Total number of assets enriched
        """
        total_enriched = 0
        batch_size = 100
        last_progress_time = time.perf_counter()
        progress_interval = 1.0  # Throttle websocket progress events to ~1/s.

        # Get the target enrichment level based on compute_hashes: without
        # hashing we only lift assets past the stub level; with hashing we
        # also re-process metadata-level assets.
        if not self._compute_hashes:
            target_max_level = ENRICHMENT_STUB
        else:
            target_max_level = ENRICHMENT_METADATA

        self._emit_event(
            "assets.seed.started",
            {"roots": list(roots), "phase": "enrich"},
        )

        # References that failed once are skipped for the rest of this run.
        skip_ids: set[str] = set()
        consecutive_empty = 0
        max_consecutive_empty = 3  # Safety valve against a livelocked loop.

        while True:
            if self._check_pause_and_cancel():
                logging.info("Enrich scan cancelled after %d assets", total_enriched)
                break

            # Fetch next batch of unenriched assets
            unenriched = get_unenriched_assets_for_roots(
                roots,
                max_level=target_max_level,
                limit=batch_size,
            )

            # Filter out previously failed references
            if skip_ids:
                unenriched = [r for r in unenriched if r.reference_id not in skip_ids]

            # NOTE(review): if the fetch returns only already-skipped rows,
            # this breaks even though unskipped rows may exist beyond the
            # limit — presumably acceptable since those rows cannot be
            # reached without excluding skip_ids in the query; confirm.
            if not unenriched:
                break

            enriched, failed_ids = enrich_assets_batch(
                unenriched,
                extract_metadata=True,
                compute_hash=self._compute_hashes,
            )
            total_enriched += enriched
            skip_ids.update(failed_ids)

            if enriched == 0:
                consecutive_empty += 1
                if consecutive_empty >= max_consecutive_empty:
                    logging.warning(
                        "Enrich phase stopping: %d consecutive batches with no progress (%d skipped)",
                        consecutive_empty,
                        len(skip_ids),
                    )
                    break
            else:
                consecutive_empty = 0

            now = time.perf_counter()
            if now - last_progress_time >= progress_interval:
                self._emit_event(
                    "assets.seed.progress",
                    {
                        "phase": "enrich",
                        "enriched": total_enriched,
                    },
                )
                last_progress_time = now

        return total_enriched
|
||||||
|
|
||||||
|
|
||||||
|
asset_seeder = AssetSeeder()
|
||||||
17
main.py
17
main.py
@ -7,7 +7,7 @@ import folder_paths
|
|||||||
import time
|
import time
|
||||||
from comfy.cli_args import args, enables_dynamic_vram
|
from comfy.cli_args import args, enables_dynamic_vram
|
||||||
from app.logger import setup_logger
|
from app.logger import setup_logger
|
||||||
from app.assets.scanner import seed_assets
|
from app.assets.seeder import asset_seeder
|
||||||
import itertools
|
import itertools
|
||||||
import utils.extra_config
|
import utils.extra_config
|
||||||
import logging
|
import logging
|
||||||
@ -258,7 +258,11 @@ def prompt_worker(q, server_instance):
|
|||||||
for k in sensitive:
|
for k in sensitive:
|
||||||
extra_data[k] = sensitive[k]
|
extra_data[k] = sensitive[k]
|
||||||
|
|
||||||
|
asset_seeder.pause()
|
||||||
|
|
||||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||||
|
|
||||||
|
asset_seeder.resume()
|
||||||
need_gc = True
|
need_gc = True
|
||||||
|
|
||||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||||
@ -355,8 +359,10 @@ def setup_database():
|
|||||||
from app.database.db import init_db, dependencies_available
|
from app.database.db import init_db, dependencies_available
|
||||||
if dependencies_available():
|
if dependencies_available():
|
||||||
init_db()
|
init_db()
|
||||||
if not args.disable_assets_autoscan:
|
if args.disable_assets_autoscan:
|
||||||
seed_assets(["models"], enable_logging=True)
|
asset_seeder.disable()
|
||||||
|
elif asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True):
|
||||||
|
logging.info("Background asset scan initiated for models, input, output")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
||||||
|
|
||||||
@ -440,5 +446,6 @@ if __name__ == "__main__":
|
|||||||
event_loop.run_until_complete(x)
|
event_loop.run_until_complete(x)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logging.info("\nStopped server")
|
logging.info("\nStopped server")
|
||||||
|
finally:
|
||||||
cleanup_temp()
|
asset_seeder.shutdown()
|
||||||
|
cleanup_temp()
|
||||||
|
|||||||
@ -33,7 +33,7 @@ import node_helpers
|
|||||||
from comfyui_version import __version__
|
from comfyui_version import __version__
|
||||||
from app.frontend_management import FrontendManager, parse_version
|
from app.frontend_management import FrontendManager, parse_version
|
||||||
from comfy_api.internal import _ComfyNodeInternal
|
from comfy_api.internal import _ComfyNodeInternal
|
||||||
from app.assets.scanner import seed_assets
|
from app.assets.seeder import asset_seeder
|
||||||
from app.assets.api.routes import register_assets_system
|
from app.assets.api.routes import register_assets_system
|
||||||
|
|
||||||
from app.user_manager import UserManager
|
from app.user_manager import UserManager
|
||||||
@ -697,10 +697,7 @@ class PromptServer():
|
|||||||
|
|
||||||
@routes.get("/object_info")
|
@routes.get("/object_info")
|
||||||
async def get_object_info(request):
|
async def get_object_info(request):
|
||||||
try:
|
asset_seeder.start(roots=("models", "input", "output"))
|
||||||
seed_assets(["models"])
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Failed to seed assets: {e}")
|
|
||||||
with folder_paths.cache_helper:
|
with folder_paths.cache_helper:
|
||||||
out = {}
|
out = {}
|
||||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||||
|
|||||||
772
tests-unit/seeder_test/test_seeder.py
Normal file
772
tests-unit/seeder_test/test_seeder.py
Normal file
@ -0,0 +1,772 @@
|
|||||||
|
"""Unit tests for the AssetSeeder background scanning class."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.assets.seeder import AssetSeeder, Progress, ScanPhase, State
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def fresh_seeder():
    """Create a fresh AssetSeeder instance for testing (bypasses singleton).

    object.__new__ skips AssetSeeder.__new__, so the process-wide singleton
    is left untouched; resetting _initialized lets __init__ run fully on
    this private instance. Teardown cancels and joins any scan thread the
    test left running so threads don't leak between tests.
    """
    seeder = object.__new__(AssetSeeder)
    seeder._initialized = False
    seeder.__init__()
    yield seeder
    seeder.shutdown(timeout=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_dependencies():
    """Mock all external dependencies for isolated testing.

    Every scanner-layer function the seeder calls is patched to an inert
    value, so scans complete instantly without touching the filesystem or
    database.
    """
    with (
        patch("app.assets.seeder.dependencies_available", return_value=True),
        patch("app.assets.seeder.sync_root_safely", return_value=set()),
        patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
        patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
        patch("app.assets.seeder.insert_asset_specs", return_value=0),
        patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
        # Fix: enrich_assets_batch returns (enriched_count, failed_reference_ids).
        # The second element must be an iterable of ids (the seeder calls
        # skip_ids.update(failed_ids) on it), so (0, 0) would raise TypeError
        # in any test that lets the enrich loop run.
        patch("app.assets.seeder.enrich_assets_batch", return_value=(0, [])),
    ):
        yield
|
||||||
|
|
||||||
|
|
||||||
|
class TestSeederStateTransitions:
    """Exercise the IDLE/RUNNING/CANCELLING transitions of the state machine."""

    @staticmethod
    def _gated_collect(entered: threading.Event, release: threading.Event):
        """Build a collect stub that signals *entered*, then blocks on *release*."""
        def _collect(*_args):
            entered.set()
            release.wait(timeout=5.0)
            return []
        return _collect

    def test_initial_state_is_idle(self, fresh_seeder: AssetSeeder):
        assert fresh_seeder.get_status().state == State.IDLE

    def test_start_transitions_to_running(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        entered, release = threading.Event(), threading.Event()
        with patch(
            "app.assets.seeder.collect_paths_for_roots",
            side_effect=self._gated_collect(entered, release),
        ):
            assert fresh_seeder.start(roots=("models",)) is True
            assert entered.wait(timeout=2.0)
            # The worker is parked inside collect, so the scan is observably live.
            assert fresh_seeder.get_status().state == State.RUNNING

            release.set()

    def test_start_while_running_returns_false(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        entered, release = threading.Event(), threading.Event()
        with patch(
            "app.assets.seeder.collect_paths_for_roots",
            side_effect=self._gated_collect(entered, release),
        ):
            fresh_seeder.start(roots=("models",))
            assert entered.wait(timeout=2.0)

            # A second start while a scan is in flight must be rejected.
            assert fresh_seeder.start(roots=("models",)) is False

            release.set()

    def test_cancel_transitions_to_cancelling(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        entered, release = threading.Event(), threading.Event()
        with patch(
            "app.assets.seeder.collect_paths_for_roots",
            side_effect=self._gated_collect(entered, release),
        ):
            fresh_seeder.start(roots=("models",))
            assert entered.wait(timeout=2.0)

            assert fresh_seeder.cancel() is True
            # Cancellation is cooperative: state reads CANCELLING until the
            # worker notices and winds down.
            assert fresh_seeder.get_status().state == State.CANCELLING

            release.set()

    def test_cancel_when_idle_returns_false(self, fresh_seeder: AssetSeeder):
        # There is nothing to cancel before a scan has been started.
        assert fresh_seeder.cancel() is False

    def test_state_returns_to_idle_after_completion(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        fresh_seeder.start(roots=("models",))
        assert fresh_seeder.wait(timeout=5.0) is True
        assert fresh_seeder.get_status().state == State.IDLE
|
||||||
|
|
||||||
|
|
||||||
|
class TestSeederWait:
    """Behaviour of wait(): blocking, timeout, and the idle fast-path."""

    def test_wait_blocks_until_complete(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        fresh_seeder.start(roots=("models",))
        assert fresh_seeder.wait(timeout=5.0) is True
        assert fresh_seeder.get_status().state == State.IDLE

    def test_wait_returns_false_on_timeout(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        release = threading.Event()

        def stalled_collect(*_args):
            release.wait(timeout=10.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=stalled_collect
        ):
            fresh_seeder.start(roots=("models",))
            # The scan is parked on *release*, so a short wait must time out.
            assert fresh_seeder.wait(timeout=0.1) is False

            release.set()

    def test_wait_when_idle_returns_true(self, fresh_seeder: AssetSeeder):
        # No scan is running: wait() should return success immediately.
        assert fresh_seeder.wait(timeout=1.0) is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestSeederProgress:
    """Progress reporting via get_status() and the progress callback."""

    def test_get_status_returns_progress_during_scan(
        self, fresh_seeder: AssetSeeder
    ):
        release = threading.Event()
        entered = threading.Event()

        def stalled_build(*_args, **_kwargs):
            entered.set()
            release.wait(timeout=5.0)
            return ([], set(), 0)

        discovered = ["/path/file1.safetensors", "/path/file2.safetensors"]

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=discovered),
            patch("app.assets.seeder.build_stub_specs", side_effect=stalled_build),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
            patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
        ):
            fresh_seeder.start(roots=("models",))
            assert entered.wait(timeout=2.0)

            snapshot = fresh_seeder.get_status()
            assert snapshot.state == State.RUNNING
            assert snapshot.progress is not None
            # Both discovered paths should be counted in the reported total.
            assert snapshot.progress.total == 2

            release.set()

    def test_progress_callback_is_invoked(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        seen: list[Progress] = []

        with patch(
            "app.assets.seeder.collect_paths_for_roots",
            return_value=[f"/path/file{i}.safetensors" for i in range(10)],
        ):
            # The callback simply records every Progress snapshot it receives.
            fresh_seeder.start(roots=("models",), progress_callback=seen.append)
            fresh_seeder.wait(timeout=5.0)

        assert len(seen) > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestSeederCancellation:
    """Test cancellation behavior."""

    def test_scan_commits_partial_progress_on_cancellation(
        self, fresh_seeder: AssetSeeder
    ):
        """Cancelling mid-scan must keep already-inserted batches (no rollback)."""
        insert_count = 0
        barrier = threading.Event()
        first_insert_done = threading.Event()

        def slow_insert(specs, tags):
            # Counts insert batches. Signals after the first batch so the test
            # can issue cancel(), then blocks every subsequent batch on
            # `barrier` so the cancel flag is observed before more inserts run.
            nonlocal insert_count
            insert_count += 1
            if insert_count == 1:
                first_insert_done.set()
            if insert_count >= 2:
                barrier.wait(timeout=5.0)
            return len(specs)

        # 1500 stub paths/specs — enough for multiple insert batches.
        # NOTE(review): assumes the seeder inserts in batches of 500 (see the
        # assertion comment below) — confirm against app/assets/seeder.py.
        paths = [f"/path/file{i}.safetensors" for i in range(1500)]
        specs = [
            {
                "abs_path": p,
                "size_bytes": 100,
                "mtime_ns": 0,
                "info_name": f"file{i}",
                "tags": [],
                "fname": f"file{i}",
                "metadata": None,
                "hash": None,
                "mime_type": None,
            }
            for i, p in enumerate(paths)
        ]

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
            patch(
                "app.assets.seeder.build_stub_specs", return_value=(specs, set(), 0)
            ),
            patch("app.assets.seeder.insert_asset_specs", side_effect=slow_insert),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
            patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
        ):
            fresh_seeder.start(roots=("models",))
            # Wait until at least one batch has been committed before cancelling.
            assert first_insert_done.wait(timeout=2.0)

            fresh_seeder.cancel()
            barrier.set()
            fresh_seeder.wait(timeout=5.0)

        assert 1 <= insert_count < 3  # 1500 paths / 500 batch = 3; cancel stopped early
|
||||||
|
|
||||||
|
|
||||||
|
class TestSeederErrorHandling:
    """Failures must surface through get_status().errors, never propagate."""

    @staticmethod
    def _single_stub_spec():
        """One minimal asset spec tuple, shaped like build_stub_specs output."""
        return (
            [
                {
                    "abs_path": "/path/file.safetensors",
                    "size_bytes": 100,
                    "mtime_ns": 0,
                    "info_name": "file",
                    "tags": [],
                    "fname": "file",
                    "metadata": None,
                    "hash": None,
                    "mime_type": None,
                }
            ],
            set(),
            0,
        )

    def test_database_errors_captured_in_status(self, fresh_seeder: AssetSeeder):
        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch(
                "app.assets.seeder.collect_paths_for_roots",
                return_value=["/path/file.safetensors"],
            ),
            patch(
                "app.assets.seeder.build_stub_specs",
                return_value=self._single_stub_spec(),
            ),
            patch(
                "app.assets.seeder.insert_asset_specs",
                side_effect=Exception("DB connection failed"),
            ),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
            patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
        ):
            fresh_seeder.start(roots=("models",))
            fresh_seeder.wait(timeout=5.0)

        report = fresh_seeder.get_status()
        assert len(report.errors) > 0
        assert "DB connection failed" in report.errors[0]

    def test_dependencies_unavailable_captured_in_errors(
        self, fresh_seeder: AssetSeeder
    ):
        with patch("app.assets.seeder.dependencies_available", return_value=False):
            fresh_seeder.start(roots=("models",))
            fresh_seeder.wait(timeout=5.0)

        report = fresh_seeder.get_status()
        assert len(report.errors) > 0
        # The error text should mention the missing dependencies.
        assert "dependencies" in report.errors[0].lower()

    def test_thread_crash_resets_state_to_idle(self, fresh_seeder: AssetSeeder):
        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch(
                "app.assets.seeder.sync_root_safely",
                side_effect=RuntimeError("Unexpected crash"),
            ),
        ):
            fresh_seeder.start(roots=("models",))
            fresh_seeder.wait(timeout=5.0)

        # Even an unexpected crash must leave the seeder reusable (IDLE) and
        # record what went wrong.
        report = fresh_seeder.get_status()
        assert report.state == State.IDLE
        assert len(report.errors) > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestSeederThreadSafety:
    """Concurrent access to start() and get_status() must be safe."""

    def test_concurrent_start_calls_spawn_only_one_thread(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        release = threading.Event()

        def stalled_collect(*_args):
            release.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=stalled_collect
        ):
            outcomes = []

            def attempt():
                outcomes.append(fresh_seeder.start(roots=("models",)))

            workers = [threading.Thread(target=attempt) for _ in range(10)]
            for w in workers:
                w.start()
            for w in workers:
                w.join()

            release.set()

        # Exactly one of the ten racing callers may win and start the scan.
        assert sum(outcomes) == 1

    def test_get_status_safe_during_scan(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        release, entered = threading.Event(), threading.Event()

        def stalled_collect(*_args):
            entered.set()
            release.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=stalled_collect
        ):
            fresh_seeder.start(roots=("models",))
            assert entered.wait(timeout=2.0)

            # Hammer get_status() while the worker is live; every snapshot
            # must still be internally consistent.
            snapshots = [fresh_seeder.get_status() for _ in range(100)]

            release.set()

        valid_states = (State.RUNNING, State.IDLE, State.CANCELLING)
        assert all(s.state in valid_states for s in snapshots)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSeederMarkMissing:
    """Test mark_missing_outside_prefixes behavior."""

    def test_mark_missing_when_idle(self, fresh_seeder: AssetSeeder):
        """When idle, pruning delegates to the safe helper with all known prefixes."""
        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch(
                "app.assets.seeder.get_all_known_prefixes",
                return_value=["/models", "/input", "/output"],
            ),
            patch(
                "app.assets.seeder.mark_missing_outside_prefixes_safely", return_value=5
            ) as mock_mark,
        ):
            result = fresh_seeder.mark_missing_outside_prefixes()
            assert result == 5
            # The helper must receive exactly the prefixes reported as known.
            mock_mark.assert_called_once_with(["/models", "/input", "/output"])

    def test_mark_missing_returns_zero_when_running(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        """Pruning is refused while a scan is in flight."""
        # NOTE(review): server.py now wraps this call in
        # `except ScanInProgressError` and maps it to HTTP 409 — confirm the
        # seeder really returns 0 here rather than raising; if it raises,
        # this test is stale and should use pytest.raises(ScanInProgressError).
        barrier = threading.Event()
        reached = threading.Event()

        def slow_collect(*args):
            # Park the worker inside collect so the scan stays RUNNING.
            reached.set()
            barrier.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
        ):
            fresh_seeder.start(roots=("models",))
            assert reached.wait(timeout=2.0)

            result = fresh_seeder.mark_missing_outside_prefixes()
            assert result == 0

            barrier.set()

    def test_mark_missing_returns_zero_when_dependencies_unavailable(
        self, fresh_seeder: AssetSeeder
    ):
        """Without the optional dependencies, pruning is a no-op returning 0."""
        with patch("app.assets.seeder.dependencies_available", return_value=False):
            result = fresh_seeder.mark_missing_outside_prefixes()
            assert result == 0

    def test_prune_first_flag_triggers_mark_missing_before_scan(
        self, fresh_seeder: AssetSeeder
    ):
        """start(prune_first=True) must prune before any root is synced."""
        call_order = []

        def track_mark(prefixes):
            # Records the prune step so ordering can be asserted below.
            call_order.append("mark_missing")
            return 3

        def track_sync(root):
            # Records each per-root sync step.
            call_order.append(f"sync_{root}")
            return set()

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.get_all_known_prefixes", return_value=["/models"]),
            patch("app.assets.seeder.mark_missing_outside_prefixes_safely", side_effect=track_mark),
            patch("app.assets.seeder.sync_root_safely", side_effect=track_sync),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
            patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
            patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
        ):
            fresh_seeder.start(roots=("models",), prune_first=True)
            fresh_seeder.wait(timeout=5.0)

        # Prune must come first; the root sync must still have happened.
        assert call_order[0] == "mark_missing"
        assert "sync_models" in call_order
|
||||||
|
|
||||||
|
|
||||||
|
class TestSeederPhases:
    """Phase selection: fast-only, enrich-only, and the full two-phase scan."""

    def _run_with_phase_tracking(self, seeder, launch):
        """Run *launch* under patched collaborators; return (fast, enrich) call logs."""
        fast_calls: list[bool] = []
        enrich_calls: list[bool] = []

        def record_fast(*_args, **_kwargs):
            fast_calls.append(True)
            return ([], set(), 0)

        def record_enrich(*_args, **_kwargs):
            enrich_calls.append(True)
            return []

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
            patch("app.assets.seeder.build_stub_specs", side_effect=record_fast),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=record_enrich),
            patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
        ):
            launch()
            seeder.wait(timeout=5.0)

        return fast_calls, enrich_calls

    def test_start_fast_only_runs_fast_phase(self, fresh_seeder: AssetSeeder):
        """Verify start_fast only runs the fast phase."""
        fast, enrich = self._run_with_phase_tracking(
            fresh_seeder, lambda: fresh_seeder.start_fast(roots=("models",))
        )
        assert len(fast) == 1
        assert len(enrich) == 0

    def test_start_enrich_only_runs_enrich_phase(self, fresh_seeder: AssetSeeder):
        """Verify start_enrich only runs the enrich phase."""
        fast, enrich = self._run_with_phase_tracking(
            fresh_seeder, lambda: fresh_seeder.start_enrich(roots=("models",))
        )
        assert len(fast) == 0
        assert len(enrich) == 1

    def test_full_scan_runs_both_phases(self, fresh_seeder: AssetSeeder):
        """Verify full scan runs both fast and enrich phases."""
        fast, enrich = self._run_with_phase_tracking(
            fresh_seeder,
            lambda: fresh_seeder.start(roots=("models",), phase=ScanPhase.FULL),
        )
        assert len(fast) == 1
        assert len(enrich) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestSeederPauseResume:
    """pause()/resume() semantics, including interaction with cancel()."""

    @staticmethod
    def _gated_collect(entered: threading.Event, release: threading.Event):
        """Build a collect stub that signals *entered*, then blocks on *release*."""
        def _collect(*_args):
            entered.set()
            release.wait(timeout=5.0)
            return []
        return _collect

    def test_pause_transitions_to_paused(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        entered, release = threading.Event(), threading.Event()
        with patch(
            "app.assets.seeder.collect_paths_for_roots",
            side_effect=self._gated_collect(entered, release),
        ):
            fresh_seeder.start(roots=("models",))
            assert entered.wait(timeout=2.0)

            assert fresh_seeder.pause() is True
            assert fresh_seeder.get_status().state == State.PAUSED

            release.set()

    def test_pause_when_idle_returns_false(self, fresh_seeder: AssetSeeder):
        # Nothing to pause before a scan has started.
        assert fresh_seeder.pause() is False

    def test_resume_returns_to_running(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        entered, release = threading.Event(), threading.Event()
        with patch(
            "app.assets.seeder.collect_paths_for_roots",
            side_effect=self._gated_collect(entered, release),
        ):
            fresh_seeder.start(roots=("models",))
            assert entered.wait(timeout=2.0)

            fresh_seeder.pause()
            assert fresh_seeder.get_status().state == State.PAUSED

            assert fresh_seeder.resume() is True
            assert fresh_seeder.get_status().state == State.RUNNING

            release.set()

    def test_resume_when_not_paused_returns_false(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        entered, release = threading.Event(), threading.Event()
        with patch(
            "app.assets.seeder.collect_paths_for_roots",
            side_effect=self._gated_collect(entered, release),
        ):
            fresh_seeder.start(roots=("models",))
            assert entered.wait(timeout=2.0)

            # The scan is RUNNING (never paused), so resume() must refuse.
            assert fresh_seeder.resume() is False

            release.set()

    def test_cancel_while_paused_works(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        entered, release = threading.Event(), threading.Event()
        with patch(
            "app.assets.seeder.collect_paths_for_roots",
            side_effect=self._gated_collect(entered, release),
        ):
            fresh_seeder.start(roots=("models",))
            assert entered.wait(timeout=2.0)

            fresh_seeder.pause()
            assert fresh_seeder.get_status().state == State.PAUSED

            # Cancellation must be honoured even from the PAUSED state.
            assert fresh_seeder.cancel() is True

            release.set()
            fresh_seeder.wait(timeout=5.0)
            assert fresh_seeder.get_status().state == State.IDLE
|
||||||
|
|
||||||
|
class TestSeederStopRestart:
    """stop() and restart() lifecycle controls."""

    def test_stop_is_alias_for_cancel(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        entered, release = threading.Event(), threading.Event()

        def gated_collect(*_args):
            entered.set()
            release.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=gated_collect
        ):
            fresh_seeder.start(roots=("models",))
            assert entered.wait(timeout=2.0)

            assert fresh_seeder.stop() is True
            # stop() behaves exactly like cancel(): state moves to CANCELLING.
            assert fresh_seeder.get_status().state == State.CANCELLING

            release.set()

    def test_restart_cancels_and_starts_new_scan(
        self, fresh_seeder: AssetSeeder, mock_dependencies
    ):
        entered, release = threading.Event(), threading.Event()
        launches = 0

        def counting_collect(*_args):
            # Only the first scan blocks; the restarted scan runs straight through.
            nonlocal launches
            launches += 1
            if launches == 1:
                entered.set()
                release.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=counting_collect
        ):
            fresh_seeder.start(roots=("models",))
            assert entered.wait(timeout=2.0)

            release.set()
            assert fresh_seeder.restart() is True

            fresh_seeder.wait(timeout=5.0)
            assert launches == 2

    def _roots_seen_across_restart(self, seeder, first_roots, restart_kwargs):
        """Run one scan then restart; return the roots each scan collected."""
        seen = []

        def recording_collect(roots):
            seen.append(roots)
            return []

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", side_effect=recording_collect),
            patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
            patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
        ):
            seeder.start(roots=first_roots)
            seeder.wait(timeout=5.0)

            seeder.restart(**restart_kwargs)
            seeder.wait(timeout=5.0)

        return seen

    def test_restart_preserves_previous_params(self, fresh_seeder: AssetSeeder):
        """Verify restart uses previous params when not overridden."""
        seen = self._roots_seen_across_restart(
            fresh_seeder, ("input", "output"), {}
        )
        assert seen == [("input", "output"), ("input", "output")]

    def test_restart_can_override_params(self, fresh_seeder: AssetSeeder):
        """Verify restart can override previous params."""
        seen = self._roots_seen_across_restart(
            fresh_seeder, ("models",), {"roots": ("input",)}
        )
        assert seen == [("models",), ("input",)]
|
||||||
Loading…
Reference in New Issue
Block a user