Optimize enrichment: shared DB session per batch, add fast scan timing logs

- Add debug timing logs for each fast scan sub-step (sync_root, collect_paths, build_asset_specs) and info-level total timing
- Refactor enrich_asset to accept a session parameter instead of creating one per file
- enrich_assets_batch now opens one session for the entire batch, committing after each asset to keep transactions short
- Simplify enrichment tests by removing create_session mocking

Amp-Thread-ID: https://ampcode.com/threads/T-019cbb0b-8563-7199-b628-33e3c4fe9f41
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Luke Mino-Altherr 2026-03-04 15:17:48 -08:00
parent 6e33c4985a
commit 58582f1faf
3 changed files with 165 additions and 129 deletions

View File

@ -1,7 +1,7 @@
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
from typing import Literal, TypedDict from typing import Callable, Literal, TypedDict
import folder_paths import folder_paths
from app.assets.database.queries import ( from app.assets.database.queries import (
@ -31,7 +31,7 @@ from app.assets.services.file_utils import (
list_files_recursively, list_files_recursively,
verify_file_unchanged, verify_file_unchanged,
) )
from app.assets.services.hashing import compute_blake3_hash from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash
from app.assets.services.metadata_extract import extract_file_metadata from app.assets.services.metadata_extract import extract_file_metadata
from app.assets.services.path_utils import ( from app.assets.services.path_utils import (
compute_relative_filename, compute_relative_filename,
@ -322,7 +322,7 @@ def build_asset_specs(
asset_hash: str | None = None asset_hash: str | None = None
if compute_hashes: if compute_hashes:
try: try:
digest = compute_blake3_hash(abs_p) digest, _ = compute_blake3_hash(abs_p)
asset_hash = "blake3:" + digest asset_hash = "blake3:" + digest
except Exception as e: except Exception as e:
logging.warning("Failed to hash %s: %s", abs_p, e) logging.warning("Failed to hash %s: %s", abs_p, e)
@ -394,20 +394,28 @@ def get_unenriched_assets_for_roots(
def enrich_asset( def enrich_asset(
session,
file_path: str, file_path: str,
reference_id: str, reference_id: str,
asset_id: str, asset_id: str,
extract_metadata: bool = True, extract_metadata: bool = True,
compute_hash: bool = False, compute_hash: bool = False,
interrupt_check: Callable[[], bool] | None = None,
hash_checkpoints: dict[str, HashCheckpoint] | None = None,
) -> int: ) -> int:
"""Enrich a single asset with metadata and/or hash. """Enrich a single asset with metadata and/or hash.
Args: Args:
session: Database session (caller manages lifecycle)
file_path: Absolute path to the file file_path: Absolute path to the file
reference_id: ID of the reference to update reference_id: ID of the reference to update
asset_id: ID of the asset to update (for mime_type and hash) asset_id: ID of the asset to update (for mime_type and hash)
extract_metadata: If True, extract safetensors header and mime type extract_metadata: If True, extract safetensors header and mime type
compute_hash: If True, compute blake3 hash compute_hash: If True, compute blake3 hash
interrupt_check: Optional callable that may block (e.g. while paused)
and returns True if the operation should be cancelled
hash_checkpoints: Optional dict for saving/restoring hash progress
across interruptions, keyed by file path
Returns: Returns:
New enrichment level achieved New enrichment level achieved
@ -438,7 +446,31 @@ def enrich_asset(
if compute_hash: if compute_hash:
try: try:
mtime_before = get_mtime_ns(stat_p) mtime_before = get_mtime_ns(stat_p)
digest = compute_blake3_hash(file_path)
# Restore checkpoint if available and file unchanged
checkpoint = None
if hash_checkpoints is not None:
checkpoint = hash_checkpoints.get(file_path)
if checkpoint is not None and mtime_before != get_mtime_ns(stat_p):
checkpoint = None
hash_checkpoints.pop(file_path, None)
digest, new_checkpoint = compute_blake3_hash(
file_path,
interrupt_check=interrupt_check,
checkpoint=checkpoint,
)
if digest is None:
# Interrupted — save checkpoint for later resumption
if hash_checkpoints is not None and new_checkpoint is not None:
hash_checkpoints[file_path] = new_checkpoint
return new_level
# Completed — clear any saved checkpoint
if hash_checkpoints is not None:
hash_checkpoints.pop(file_path, None)
stat_after = os.stat(file_path, follow_symlinks=True) stat_after = os.stat(file_path, follow_symlinks=True)
mtime_after = get_mtime_ns(stat_after) mtime_after = get_mtime_ns(stat_after)
if mtime_before != mtime_after: if mtime_before != mtime_after:
@ -451,25 +483,24 @@ def enrich_asset(
except Exception as e: except Exception as e:
logging.warning("Failed to hash %s: %s", file_path, e) logging.warning("Failed to hash %s: %s", file_path, e)
with create_session() as sess: if extract_metadata and metadata:
if extract_metadata and metadata: user_metadata = metadata.to_user_metadata()
user_metadata = metadata.to_user_metadata() set_reference_metadata(session, reference_id, user_metadata)
set_reference_metadata(sess, reference_id, user_metadata)
if full_hash: if full_hash:
existing = get_asset_by_hash(sess, full_hash) existing = get_asset_by_hash(session, full_hash)
if existing and existing.id != asset_id: if existing and existing.id != asset_id:
reassign_asset_references(sess, asset_id, existing.id, reference_id) reassign_asset_references(session, asset_id, existing.id, reference_id)
delete_orphaned_seed_asset(sess, asset_id) delete_orphaned_seed_asset(session, asset_id)
if mime_type: if mime_type:
update_asset_hash_and_mime(sess, existing.id, mime_type=mime_type) update_asset_hash_and_mime(session, existing.id, mime_type=mime_type)
else: else:
update_asset_hash_and_mime(sess, asset_id, full_hash, mime_type) update_asset_hash_and_mime(session, asset_id, full_hash, mime_type)
elif mime_type: elif mime_type:
update_asset_hash_and_mime(sess, asset_id, mime_type=mime_type) update_asset_hash_and_mime(session, asset_id, mime_type=mime_type)
bulk_update_enrichment_level(sess, [reference_id], new_level) bulk_update_enrichment_level(session, [reference_id], new_level)
sess.commit() session.commit()
return new_level return new_level
@ -478,13 +509,23 @@ def enrich_assets_batch(
rows: list, rows: list,
extract_metadata: bool = True, extract_metadata: bool = True,
compute_hash: bool = False, compute_hash: bool = False,
interrupt_check: Callable[[], bool] | None = None,
hash_checkpoints: dict[str, HashCheckpoint] | None = None,
) -> tuple[int, list[str]]: ) -> tuple[int, list[str]]:
"""Enrich a batch of assets. """Enrich a batch of assets.
Uses a single DB session for the entire batch, committing after each
individual asset to avoid long-held transactions while eliminating
per-asset session creation overhead.
Args: Args:
rows: List of UnenrichedReferenceRow from get_unenriched_assets_for_roots rows: List of UnenrichedReferenceRow from get_unenriched_assets_for_roots
extract_metadata: If True, extract metadata for each asset extract_metadata: If True, extract metadata for each asset
compute_hash: If True, compute hash for each asset compute_hash: If True, compute hash for each asset
interrupt_check: Optional callable that may block (e.g. while paused)
and returns True if the operation should be cancelled
hash_checkpoints: Optional dict for saving/restoring hash progress
across interruptions, keyed by file path
Returns: Returns:
Tuple of (enriched_count, failed_reference_ids) Tuple of (enriched_count, failed_reference_ids)
@ -492,21 +533,28 @@ def enrich_assets_batch(
enriched = 0 enriched = 0
failed_ids: list[str] = [] failed_ids: list[str] = []
for row in rows: with create_session() as sess:
try: for row in rows:
new_level = enrich_asset( if interrupt_check is not None and interrupt_check():
file_path=row.file_path, break
reference_id=row.reference_id,
asset_id=row.asset_id, try:
extract_metadata=extract_metadata, new_level = enrich_asset(
compute_hash=compute_hash, sess,
) file_path=row.file_path,
if new_level > row.enrichment_level: reference_id=row.reference_id,
enriched += 1 asset_id=row.asset_id,
else: extract_metadata=extract_metadata,
compute_hash=compute_hash,
interrupt_check=interrupt_check,
hash_checkpoints=hash_checkpoints,
)
if new_level > row.enrichment_level:
enriched += 1
else:
failed_ids.append(row.reference_id)
except Exception as e:
logging.warning("Failed to enrich %s: %s", row.file_path, e)
failed_ids.append(row.reference_id) failed_ids.append(row.reference_id)
except Exception as e:
logging.warning("Failed to enrich %s: %s", row.file_path, e)
failed_ids.append(row.reference_id)
return enriched, failed_ids return enriched, failed_ids

View File

@ -590,19 +590,32 @@ class _AssetSeeder:
Returns: Returns:
Tuple of (total_created, skipped_existing, total_paths) Tuple of (total_created, skipped_existing, total_paths)
""" """
t_fast_start = time.perf_counter()
total_created = 0 total_created = 0
skipped_existing = 0 skipped_existing = 0
existing_paths: set[str] = set() existing_paths: set[str] = set()
t_sync = time.perf_counter()
for r in roots: for r in roots:
if self._check_pause_and_cancel(): if self._check_pause_and_cancel():
return total_created, skipped_existing, 0 return total_created, skipped_existing, 0
existing_paths.update(sync_root_safely(r)) existing_paths.update(sync_root_safely(r))
logging.debug(
"Fast scan: sync_root phase took %.3fs (%d existing paths)",
time.perf_counter() - t_sync,
len(existing_paths),
)
if self._check_pause_and_cancel(): if self._check_pause_and_cancel():
return total_created, skipped_existing, 0 return total_created, skipped_existing, 0
t_collect = time.perf_counter()
paths = collect_paths_for_roots(roots) paths = collect_paths_for_roots(roots)
logging.debug(
"Fast scan: collect_paths took %.3fs (%d paths found)",
time.perf_counter() - t_collect,
len(paths),
)
total_paths = len(paths) total_paths = len(paths)
self._update_progress(total=total_paths) self._update_progress(total=total_paths)
@ -612,12 +625,19 @@ class _AssetSeeder:
) )
# Use stub specs (no metadata extraction, no hashing) # Use stub specs (no metadata extraction, no hashing)
t_specs = time.perf_counter()
specs, tag_pool, skipped_existing = build_asset_specs( specs, tag_pool, skipped_existing = build_asset_specs(
paths, paths,
existing_paths, existing_paths,
enable_metadata_extraction=False, enable_metadata_extraction=False,
compute_hashes=False, compute_hashes=False,
) )
logging.debug(
"Fast scan: build_asset_specs took %.3fs (%d specs, %d skipped)",
time.perf_counter() - t_specs,
len(specs),
skipped_existing,
)
self._update_progress(skipped=skipped_existing) self._update_progress(skipped=skipped_existing)
if self._check_pause_and_cancel(): if self._check_pause_and_cancel():
@ -663,6 +683,13 @@ class _AssetSeeder:
last_progress_time = now last_progress_time = now
self._update_progress(scanned=len(specs), created=total_created) self._update_progress(scanned=len(specs), created=total_created)
logging.info(
"Fast scan complete: %.3fs total (created=%d, skipped=%d, total_paths=%d)",
time.perf_counter() - t_fast_start,
total_created,
skipped_existing,
total_paths,
)
return total_created, skipped_existing, total_paths return total_created, skipped_existing, total_paths
def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]: def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]:
@ -691,6 +718,10 @@ class _AssetSeeder:
consecutive_empty = 0 consecutive_empty = 0
max_consecutive_empty = 3 max_consecutive_empty = 3
# Hash checkpoints survive across batches so interrupted hashes
# can be resumed without re-reading the entire file.
hash_checkpoints: dict[str, object] = {}
while True: while True:
if self._check_pause_and_cancel(): if self._check_pause_and_cancel():
logging.info("Enrich scan cancelled after %d assets", total_enriched) logging.info("Enrich scan cancelled after %d assets", total_enriched)
@ -714,6 +745,8 @@ class _AssetSeeder:
unenriched, unenriched,
extract_metadata=True, extract_metadata=True,
compute_hash=self._compute_hashes, compute_hash=self._compute_hashes,
interrupt_check=self._check_pause_and_cancel,
hash_checkpoints=hash_checkpoints,
) )
total_enriched += enriched total_enriched += enriched
skip_ids.update(failed_ids) skip_ids.update(failed_ids)

View File

@ -1,6 +1,5 @@
"""Tests for asset enrichment (mime_type and hash population).""" """Tests for asset enrichment (mime_type and hash population)."""
from pathlib import Path from pathlib import Path
from unittest.mock import patch
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -58,23 +57,14 @@ class TestEnrichAsset:
) )
session.commit() session.commit()
with patch("app.assets.scanner.create_session") as mock_cs: new_level = enrich_asset(
from contextlib import contextmanager session,
file_path=str(file_path),
@contextmanager reference_id=ref.id,
def _create_session(): asset_id=asset.id,
with Session(db_engine) as sess: extract_metadata=True,
yield sess compute_hash=False,
)
mock_cs.side_effect = _create_session
new_level = enrich_asset(
file_path=str(file_path),
reference_id=ref.id,
asset_id=asset.id,
extract_metadata=True,
compute_hash=False,
)
assert new_level == ENRICHMENT_METADATA assert new_level == ENRICHMENT_METADATA
@ -95,23 +85,14 @@ class TestEnrichAsset:
) )
session.commit() session.commit()
with patch("app.assets.scanner.create_session") as mock_cs: new_level = enrich_asset(
from contextlib import contextmanager session,
file_path=str(file_path),
@contextmanager reference_id=ref.id,
def _create_session(): asset_id=asset.id,
with Session(db_engine) as sess: extract_metadata=True,
yield sess compute_hash=True,
)
mock_cs.side_effect = _create_session
new_level = enrich_asset(
file_path=str(file_path),
reference_id=ref.id,
asset_id=asset.id,
extract_metadata=True,
compute_hash=True,
)
assert new_level == ENRICHMENT_HASHED assert new_level == ENRICHMENT_HASHED
@ -133,23 +114,14 @@ class TestEnrichAsset:
) )
session.commit() session.commit()
with patch("app.assets.scanner.create_session") as mock_cs: enrich_asset(
from contextlib import contextmanager session,
file_path=str(file_path),
@contextmanager reference_id=ref.id,
def _create_session(): asset_id=asset.id,
with Session(db_engine) as sess: extract_metadata=True,
yield sess compute_hash=True,
)
mock_cs.side_effect = _create_session
enrich_asset(
file_path=str(file_path),
reference_id=ref.id,
asset_id=asset.id,
extract_metadata=True,
compute_hash=True,
)
session.expire_all() session.expire_all()
updated_asset = session.get(Asset, "asset-3") updated_asset = session.get(Asset, "asset-3")
@ -169,23 +141,14 @@ class TestEnrichAsset:
) )
session.commit() session.commit()
with patch("app.assets.scanner.create_session") as mock_cs: new_level = enrich_asset(
from contextlib import contextmanager session,
file_path=str(file_path),
@contextmanager reference_id=ref.id,
def _create_session(): asset_id=asset.id,
with Session(db_engine) as sess: extract_metadata=True,
yield sess compute_hash=True,
)
mock_cs.side_effect = _create_session
new_level = enrich_asset(
file_path=str(file_path),
reference_id=ref.id,
asset_id=asset.id,
extract_metadata=True,
compute_hash=True,
)
assert new_level == ENRICHMENT_STUB assert new_level == ENRICHMENT_STUB
@ -212,31 +175,23 @@ class TestEnrichAsset:
) )
session.commit() session.commit()
with patch("app.assets.scanner.create_session") as mock_cs: enrich_asset(
from contextlib import contextmanager session,
file_path=str(file_path_1),
reference_id=ref1.id,
asset_id=asset1.id,
extract_metadata=True,
compute_hash=True,
)
@contextmanager enrich_asset(
def _create_session(): session,
with Session(db_engine) as sess: file_path=str(file_path_2),
yield sess reference_id=ref2.id,
asset_id=asset2.id,
mock_cs.side_effect = _create_session extract_metadata=True,
compute_hash=True,
enrich_asset( )
file_path=str(file_path_1),
reference_id=ref1.id,
asset_id=asset1.id,
extract_metadata=True,
compute_hash=True,
)
enrich_asset(
file_path=str(file_path_2),
reference_id=ref2.id,
asset_id=asset2.id,
extract_metadata=True,
compute_hash=True,
)
session.expire_all() session.expire_all()