refactor: make scanner helper functions public

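Rename _sync_root_safely, _prune_orphans_safely, _collect_paths_for_roots,
_build_asset_specs, and _insert_asset_specs in app.assets.scanner, dropping
the leading underscore: seeder.py imports all five, so they are part of the
scanner module's public API rather than private helpers.

Update the call sites in scanner.seed_assets and AssetSeeder, and the
"app.assets.seeder.*" patch targets in the seeder tests, to use the new names.
Also re-wrap a few long lines and drop the PromptServer (TYPE_CHECKING) and
MagicMock imports; no behavior change is intended.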

Amp-Thread-ID: https://ampcode.com/threads/T-019c3037-df32-7138-99d8-b4b824d896b3
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Luke Mino-Altherr 2026-02-05 19:01:46 -08:00
parent 310cfc6455
commit b4f5bb2faa
3 changed files with 70 additions and 53 deletions

View File

@ -46,6 +46,7 @@ class _AssetAccumulator(TypedDict):
size_db: int size_db: int
states: list[_StateInfo] states: list[_StateInfo]
RootType = Literal["models", "input", "output"] RootType = Literal["models", "input", "output"]
@ -200,7 +201,7 @@ def sync_cache_states_with_filesystem(
return survivors if collect_existing_paths else None return survivors if collect_existing_paths else None
def _sync_root_safely(root: RootType) -> set[str]: def sync_root_safely(root: RootType) -> set[str]:
"""Sync a single root's cache states with the filesystem. """Sync a single root's cache states with the filesystem.
Returns survivors (existing paths) or empty set on failure. Returns survivors (existing paths) or empty set on failure.
@ -220,7 +221,7 @@ def _sync_root_safely(root: RootType) -> set[str]:
return set() return set()
def _prune_orphans_safely(prefixes: list[str]) -> int: def prune_orphans_safely(prefixes: list[str]) -> int:
"""Prune orphaned assets outside the given prefixes. """Prune orphaned assets outside the given prefixes.
Returns count pruned or 0 on failure. Returns count pruned or 0 on failure.
@ -235,7 +236,7 @@ def _prune_orphans_safely(prefixes: list[str]) -> int:
return 0 return 0
def _collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]: def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]:
"""Collect all file paths for the given roots.""" """Collect all file paths for the given roots."""
paths: list[str] = [] paths: list[str] = []
if "models" in roots: if "models" in roots:
@ -247,7 +248,7 @@ def _collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]:
return paths return paths
def _build_asset_specs( def build_asset_specs(
paths: list[str], paths: list[str],
existing_paths: set[str], existing_paths: set[str],
enable_metadata_extraction: bool = True, enable_metadata_extraction: bool = True,
@ -303,7 +304,7 @@ def _build_asset_specs(
return specs, tag_pool, skipped return specs, tag_pool, skipped
def _insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int: def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
"""Insert asset specs into database, returning count of created infos.""" """Insert asset specs into database, returning count of created infos."""
if not specs: if not specs:
return 0 return 0
@ -330,11 +331,11 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
existing_paths: set[str] = set() existing_paths: set[str] = set()
for r in roots: for r in roots:
existing_paths.update(_sync_root_safely(r)) existing_paths.update(sync_root_safely(r))
paths = _collect_paths_for_roots(roots) paths = collect_paths_for_roots(roots)
specs, tag_pool, skipped_existing = _build_asset_specs(paths, existing_paths) specs, tag_pool, skipped_existing = build_asset_specs(paths, existing_paths)
created = _insert_asset_specs(specs, tag_pool) created = insert_asset_specs(specs, tag_pool)
if enable_logging: if enable_logging:
logging.info( logging.info(

View File

@ -10,18 +10,18 @@ from typing import TYPE_CHECKING, Callable
from app.assets.scanner import ( from app.assets.scanner import (
RootType, RootType,
_build_asset_specs, build_asset_specs,
_collect_paths_for_roots, collect_paths_for_roots,
_insert_asset_specs,
_prune_orphans_safely,
_sync_root_safely,
get_all_known_prefixes, get_all_known_prefixes,
get_prefixes_for_root, get_prefixes_for_root,
insert_asset_specs,
prune_orphans_safely,
sync_root_safely,
) )
from app.database.db import dependencies_available from app.database.db import dependencies_available
if TYPE_CHECKING: if TYPE_CHECKING:
from server import PromptServer pass
class State(Enum): class State(Enum):
@ -193,11 +193,13 @@ class AssetSeeder:
return 0 return 0
if not dependencies_available(): if not dependencies_available():
logging.warning("Database dependencies not available, skipping orphan pruning") logging.warning(
"Database dependencies not available, skipping orphan pruning"
)
return 0 return 0
all_prefixes = get_all_known_prefixes() all_prefixes = get_all_known_prefixes()
pruned = _prune_orphans_safely(all_prefixes) pruned = prune_orphans_safely(all_prefixes)
if pruned > 0: if pruned > 0:
logging.info("Pruned %d orphaned assets", pruned) logging.info("Pruned %d orphaned assets", pruned)
return pruned return pruned
@ -288,7 +290,7 @@ class AssetSeeder:
if self._prune_first: if self._prune_first:
all_prefixes = get_all_known_prefixes() all_prefixes = get_all_known_prefixes()
pruned = _prune_orphans_safely(all_prefixes) pruned = prune_orphans_safely(all_prefixes)
if pruned > 0: if pruned > 0:
logging.info("Pruned %d orphaned assets before scan", pruned) logging.info("Pruned %d orphaned assets before scan", pruned)
@ -305,14 +307,14 @@ class AssetSeeder:
logging.info("Asset scan cancelled during sync phase") logging.info("Asset scan cancelled during sync phase")
cancelled = True cancelled = True
return return
existing_paths.update(_sync_root_safely(r)) existing_paths.update(sync_root_safely(r))
if self._is_cancelled(): if self._is_cancelled():
logging.info("Asset scan cancelled after sync phase") logging.info("Asset scan cancelled after sync phase")
cancelled = True cancelled = True
return return
paths = _collect_paths_for_roots(roots) paths = collect_paths_for_roots(roots)
total_paths = len(paths) total_paths = len(paths)
self._update_progress(total=total_paths) self._update_progress(total=total_paths)
@ -321,7 +323,7 @@ class AssetSeeder:
{"roots": list(roots), "total": total_paths}, {"roots": list(roots), "total": total_paths},
) )
specs, tag_pool, skipped_existing = _build_asset_specs(paths, existing_paths) specs, tag_pool, skipped_existing = build_asset_specs(paths, existing_paths)
self._update_progress(skipped=skipped_existing) self._update_progress(skipped=skipped_existing)
if self._is_cancelled(): if self._is_cancelled():
@ -347,7 +349,7 @@ class AssetSeeder:
batch = specs[i : i + batch_size] batch = specs[i : i + batch_size]
batch_tags = {t for spec in batch for t in spec["tags"]} batch_tags = {t for spec in batch for t in spec["tags"]}
try: try:
created = _insert_asset_specs(batch, batch_tags) created = insert_asset_specs(batch, batch_tags)
total_created += created total_created += created
except Exception as e: except Exception as e:
self._add_error(f"Batch insert failed at offset {i}: {e}") self._add_error(f"Batch insert failed at offset {i}: {e}")
@ -360,7 +362,11 @@ class AssetSeeder:
if now - last_progress_time >= progress_interval: if now - last_progress_time >= progress_interval:
self._emit_event( self._emit_event(
"assets.seed.progress", "assets.seed.progress",
{"scanned": scanned, "total": len(specs), "created": total_created}, {
"scanned": scanned,
"total": len(specs),
"created": total_created,
},
) )
last_progress_time = now last_progress_time = now

View File

@ -2,7 +2,7 @@
import threading import threading
import time import time
from unittest.mock import MagicMock, patch from unittest.mock import patch
import pytest import pytest
@ -24,10 +24,10 @@ def mock_dependencies():
"""Mock all external dependencies for isolated testing.""" """Mock all external dependencies for isolated testing."""
with ( with (
patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder._sync_root_safely", return_value=set()), patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder._collect_paths_for_roots", return_value=[]), patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder._build_asset_specs", return_value=([], set(), 0)), patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder._insert_asset_specs", return_value=0), patch("app.assets.seeder.insert_asset_specs", return_value=0),
): ):
yield yield
@ -56,7 +56,7 @@ class TestSeederStateTransitions:
return [] return []
with patch( with patch(
"app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
time.sleep(0.05) time.sleep(0.05)
@ -76,7 +76,7 @@ class TestSeederStateTransitions:
return [] return []
with patch( with patch(
"app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
time.sleep(0.05) time.sleep(0.05)
@ -121,7 +121,7 @@ class TestSeederWait:
return [] return []
with patch( with patch(
"app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
completed = fresh_seeder.wait(timeout=0.1) completed = fresh_seeder.wait(timeout=0.1)
@ -148,7 +148,7 @@ class TestSeederProgress:
return ["/path/file1.safetensors", "/path/file2.safetensors"] return ["/path/file1.safetensors", "/path/file2.safetensors"]
with patch( with patch(
"app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
time.sleep(0.05) time.sleep(0.05)
@ -168,7 +168,7 @@ class TestSeederProgress:
progress_updates.append(p) progress_updates.append(p)
with patch( with patch(
"app.assets.seeder._collect_paths_for_roots", "app.assets.seeder.collect_paths_for_roots",
return_value=[f"/path/file{i}.safetensors" for i in range(10)], return_value=[f"/path/file{i}.safetensors" for i in range(10)],
): ):
fresh_seeder.start(roots=("models",), progress_callback=callback) fresh_seeder.start(roots=("models",), progress_callback=callback)
@ -208,10 +208,12 @@ class TestSeederCancellation:
with ( with (
patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder._sync_root_safely", return_value=set()), patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder._collect_paths_for_roots", return_value=paths), patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
patch("app.assets.seeder._build_asset_specs", return_value=(specs, set(), 0)), patch(
patch("app.assets.seeder._insert_asset_specs", side_effect=slow_insert), "app.assets.seeder.build_asset_specs", return_value=(specs, set(), 0)
),
patch("app.assets.seeder.insert_asset_specs", side_effect=slow_insert),
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
time.sleep(0.1) time.sleep(0.1)
@ -229,13 +231,13 @@ class TestSeederErrorHandling:
def test_database_errors_captured_in_status(self, fresh_seeder: AssetSeeder): def test_database_errors_captured_in_status(self, fresh_seeder: AssetSeeder):
with ( with (
patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder._sync_root_safely", return_value=set()), patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch( patch(
"app.assets.seeder._collect_paths_for_roots", "app.assets.seeder.collect_paths_for_roots",
return_value=["/path/file.safetensors"], return_value=["/path/file.safetensors"],
), ),
patch( patch(
"app.assets.seeder._build_asset_specs", "app.assets.seeder.build_asset_specs",
return_value=( return_value=(
[ [
{ {
@ -252,7 +254,7 @@ class TestSeederErrorHandling:
), ),
), ),
patch( patch(
"app.assets.seeder._insert_asset_specs", "app.assets.seeder.insert_asset_specs",
side_effect=Exception("DB connection failed"), side_effect=Exception("DB connection failed"),
), ),
): ):
@ -278,7 +280,7 @@ class TestSeederErrorHandling:
with ( with (
patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.dependencies_available", return_value=True),
patch( patch(
"app.assets.seeder._sync_root_safely", "app.assets.seeder.sync_root_safely",
side_effect=RuntimeError("Unexpected crash"), side_effect=RuntimeError("Unexpected crash"),
), ),
): ):
@ -303,7 +305,7 @@ class TestSeederThreadSafety:
return [] return []
with patch( with patch(
"app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
): ):
results = [] results = []
@ -330,7 +332,7 @@ class TestSeederThreadSafety:
return [] return []
with patch( with patch(
"app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
@ -341,7 +343,10 @@ class TestSeederThreadSafety:
barrier.set() barrier.set()
assert all(s.state in (State.RUNNING, State.IDLE, State.CANCELLING) for s in statuses) assert all(
s.state in (State.RUNNING, State.IDLE, State.CANCELLING)
for s in statuses
)
class TestSeederPruneOrphans: class TestSeederPruneOrphans:
@ -350,8 +355,13 @@ class TestSeederPruneOrphans:
def test_prune_orphans_when_idle(self, fresh_seeder: AssetSeeder): def test_prune_orphans_when_idle(self, fresh_seeder: AssetSeeder):
with ( with (
patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.get_all_known_prefixes", return_value=["/models", "/input", "/output"]), patch(
patch("app.assets.seeder._prune_orphans_safely", return_value=5) as mock_prune, "app.assets.seeder.get_all_known_prefixes",
return_value=["/models", "/input", "/output"],
),
patch(
"app.assets.seeder.prune_orphans_safely", return_value=5
) as mock_prune,
): ):
result = fresh_seeder.prune_orphans() result = fresh_seeder.prune_orphans()
assert result == 5 assert result == 5
@ -367,7 +377,7 @@ class TestSeederPruneOrphans:
return [] return []
with patch( with patch(
"app.assets.seeder._collect_paths_for_roots", side_effect=slow_collect "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
time.sleep(0.05) time.sleep(0.05)
@ -400,11 +410,11 @@ class TestSeederPruneOrphans:
with ( with (
patch("app.assets.seeder.dependencies_available", return_value=True), patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.get_all_known_prefixes", return_value=["/models"]), patch("app.assets.seeder.get_all_known_prefixes", return_value=["/models"]),
patch("app.assets.seeder._prune_orphans_safely", side_effect=track_prune), patch("app.assets.seeder.prune_orphans_safely", side_effect=track_prune),
patch("app.assets.seeder._sync_root_safely", side_effect=track_sync), patch("app.assets.seeder.sync_root_safely", side_effect=track_sync),
patch("app.assets.seeder._collect_paths_for_roots", return_value=[]), patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
patch("app.assets.seeder._build_asset_specs", return_value=([], set(), 0)), patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
patch("app.assets.seeder._insert_asset_specs", return_value=0), patch("app.assets.seeder.insert_asset_specs", return_value=0),
): ):
fresh_seeder.start(roots=("models",), prune_first=True) fresh_seeder.start(roots=("models",), prune_first=True)
fresh_seeder.wait(timeout=5.0) fresh_seeder.wait(timeout=5.0)