Optimize enrichment: shared DB session per batch, add fast scan timing logs

- Add debug timing logs for each fast scan sub-step (sync_root, collect_paths, build_asset_specs) and info-level total timing
- Refactor enrich_asset to accept a session parameter instead of creating one per file
- enrich_assets_batch now opens one session for the entire batch, committing after each asset to keep transactions short
- Simplify enrichment tests by removing create_session mocking

Amp-Thread-ID: https://ampcode.com/threads/T-019cbb0b-8563-7199-b628-33e3c4fe9f41
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Luke Mino-Altherr 2026-03-04 15:17:48 -08:00
parent 6e33c4985a
commit 58582f1faf
3 changed files with 165 additions and 129 deletions

View File

@ -1,7 +1,7 @@
import logging
import os
from pathlib import Path
from typing import Literal, TypedDict
from typing import Callable, Literal, TypedDict
import folder_paths
from app.assets.database.queries import (
@ -31,7 +31,7 @@ from app.assets.services.file_utils import (
list_files_recursively,
verify_file_unchanged,
)
from app.assets.services.hashing import compute_blake3_hash
from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash
from app.assets.services.metadata_extract import extract_file_metadata
from app.assets.services.path_utils import (
compute_relative_filename,
@ -322,7 +322,7 @@ def build_asset_specs(
asset_hash: str | None = None
if compute_hashes:
try:
digest = compute_blake3_hash(abs_p)
digest, _ = compute_blake3_hash(abs_p)
asset_hash = "blake3:" + digest
except Exception as e:
logging.warning("Failed to hash %s: %s", abs_p, e)
@ -394,20 +394,28 @@ def get_unenriched_assets_for_roots(
def enrich_asset(
session,
file_path: str,
reference_id: str,
asset_id: str,
extract_metadata: bool = True,
compute_hash: bool = False,
interrupt_check: Callable[[], bool] | None = None,
hash_checkpoints: dict[str, HashCheckpoint] | None = None,
) -> int:
"""Enrich a single asset with metadata and/or hash.
Args:
session: Database session (caller manages lifecycle)
file_path: Absolute path to the file
reference_id: ID of the reference to update
asset_id: ID of the asset to update (for mime_type and hash)
extract_metadata: If True, extract safetensors header and mime type
compute_hash: If True, compute blake3 hash
interrupt_check: Optional callable that may block (e.g. while paused)
and returns True if the operation should be cancelled
hash_checkpoints: Optional dict for saving/restoring hash progress
across interruptions, keyed by file path
Returns:
New enrichment level achieved
@ -438,7 +446,31 @@ def enrich_asset(
if compute_hash:
try:
mtime_before = get_mtime_ns(stat_p)
digest = compute_blake3_hash(file_path)
# Restore checkpoint if available and file unchanged
checkpoint = None
if hash_checkpoints is not None:
checkpoint = hash_checkpoints.get(file_path)
if checkpoint is not None and mtime_before != get_mtime_ns(stat_p):
checkpoint = None
hash_checkpoints.pop(file_path, None)
digest, new_checkpoint = compute_blake3_hash(
file_path,
interrupt_check=interrupt_check,
checkpoint=checkpoint,
)
if digest is None:
# Interrupted — save checkpoint for later resumption
if hash_checkpoints is not None and new_checkpoint is not None:
hash_checkpoints[file_path] = new_checkpoint
return new_level
# Completed — clear any saved checkpoint
if hash_checkpoints is not None:
hash_checkpoints.pop(file_path, None)
stat_after = os.stat(file_path, follow_symlinks=True)
mtime_after = get_mtime_ns(stat_after)
if mtime_before != mtime_after:
@ -451,25 +483,24 @@ def enrich_asset(
except Exception as e:
logging.warning("Failed to hash %s: %s", file_path, e)
with create_session() as sess:
if extract_metadata and metadata:
user_metadata = metadata.to_user_metadata()
set_reference_metadata(sess, reference_id, user_metadata)
if extract_metadata and metadata:
user_metadata = metadata.to_user_metadata()
set_reference_metadata(session, reference_id, user_metadata)
if full_hash:
existing = get_asset_by_hash(sess, full_hash)
if existing and existing.id != asset_id:
reassign_asset_references(sess, asset_id, existing.id, reference_id)
delete_orphaned_seed_asset(sess, asset_id)
if mime_type:
update_asset_hash_and_mime(sess, existing.id, mime_type=mime_type)
else:
update_asset_hash_and_mime(sess, asset_id, full_hash, mime_type)
elif mime_type:
update_asset_hash_and_mime(sess, asset_id, mime_type=mime_type)
if full_hash:
existing = get_asset_by_hash(session, full_hash)
if existing and existing.id != asset_id:
reassign_asset_references(session, asset_id, existing.id, reference_id)
delete_orphaned_seed_asset(session, asset_id)
if mime_type:
update_asset_hash_and_mime(session, existing.id, mime_type=mime_type)
else:
update_asset_hash_and_mime(session, asset_id, full_hash, mime_type)
elif mime_type:
update_asset_hash_and_mime(session, asset_id, mime_type=mime_type)
bulk_update_enrichment_level(sess, [reference_id], new_level)
sess.commit()
bulk_update_enrichment_level(session, [reference_id], new_level)
session.commit()
return new_level
@ -478,13 +509,23 @@ def enrich_assets_batch(
rows: list,
extract_metadata: bool = True,
compute_hash: bool = False,
interrupt_check: Callable[[], bool] | None = None,
hash_checkpoints: dict[str, HashCheckpoint] | None = None,
) -> tuple[int, list[str]]:
"""Enrich a batch of assets.
Uses a single DB session for the entire batch, committing after each
individual asset to avoid long-held transactions while eliminating
per-asset session creation overhead.
Args:
rows: List of UnenrichedReferenceRow from get_unenriched_assets_for_roots
extract_metadata: If True, extract metadata for each asset
compute_hash: If True, compute hash for each asset
interrupt_check: Optional callable that may block (e.g. while paused)
and returns True if the operation should be cancelled
hash_checkpoints: Optional dict for saving/restoring hash progress
across interruptions, keyed by file path
Returns:
Tuple of (enriched_count, failed_reference_ids)
@ -492,21 +533,28 @@ def enrich_assets_batch(
enriched = 0
failed_ids: list[str] = []
for row in rows:
try:
new_level = enrich_asset(
file_path=row.file_path,
reference_id=row.reference_id,
asset_id=row.asset_id,
extract_metadata=extract_metadata,
compute_hash=compute_hash,
)
if new_level > row.enrichment_level:
enriched += 1
else:
with create_session() as sess:
for row in rows:
if interrupt_check is not None and interrupt_check():
break
try:
new_level = enrich_asset(
sess,
file_path=row.file_path,
reference_id=row.reference_id,
asset_id=row.asset_id,
extract_metadata=extract_metadata,
compute_hash=compute_hash,
interrupt_check=interrupt_check,
hash_checkpoints=hash_checkpoints,
)
if new_level > row.enrichment_level:
enriched += 1
else:
failed_ids.append(row.reference_id)
except Exception as e:
logging.warning("Failed to enrich %s: %s", row.file_path, e)
failed_ids.append(row.reference_id)
except Exception as e:
logging.warning("Failed to enrich %s: %s", row.file_path, e)
failed_ids.append(row.reference_id)
return enriched, failed_ids

View File

@ -590,19 +590,32 @@ class _AssetSeeder:
Returns:
Tuple of (total_created, skipped_existing, total_paths)
"""
t_fast_start = time.perf_counter()
total_created = 0
skipped_existing = 0
existing_paths: set[str] = set()
t_sync = time.perf_counter()
for r in roots:
if self._check_pause_and_cancel():
return total_created, skipped_existing, 0
existing_paths.update(sync_root_safely(r))
logging.debug(
"Fast scan: sync_root phase took %.3fs (%d existing paths)",
time.perf_counter() - t_sync,
len(existing_paths),
)
if self._check_pause_and_cancel():
return total_created, skipped_existing, 0
t_collect = time.perf_counter()
paths = collect_paths_for_roots(roots)
logging.debug(
"Fast scan: collect_paths took %.3fs (%d paths found)",
time.perf_counter() - t_collect,
len(paths),
)
total_paths = len(paths)
self._update_progress(total=total_paths)
@ -612,12 +625,19 @@ class _AssetSeeder:
)
# Use stub specs (no metadata extraction, no hashing)
t_specs = time.perf_counter()
specs, tag_pool, skipped_existing = build_asset_specs(
paths,
existing_paths,
enable_metadata_extraction=False,
compute_hashes=False,
)
logging.debug(
"Fast scan: build_asset_specs took %.3fs (%d specs, %d skipped)",
time.perf_counter() - t_specs,
len(specs),
skipped_existing,
)
self._update_progress(skipped=skipped_existing)
if self._check_pause_and_cancel():
@ -663,6 +683,13 @@ class _AssetSeeder:
last_progress_time = now
self._update_progress(scanned=len(specs), created=total_created)
logging.info(
"Fast scan complete: %.3fs total (created=%d, skipped=%d, total_paths=%d)",
time.perf_counter() - t_fast_start,
total_created,
skipped_existing,
total_paths,
)
return total_created, skipped_existing, total_paths
def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]:
@ -691,6 +718,10 @@ class _AssetSeeder:
consecutive_empty = 0
max_consecutive_empty = 3
# Hash checkpoints survive across batches so interrupted hashes
# can be resumed without re-reading the entire file.
hash_checkpoints: dict[str, object] = {}
while True:
if self._check_pause_and_cancel():
logging.info("Enrich scan cancelled after %d assets", total_enriched)
@ -714,6 +745,8 @@ class _AssetSeeder:
unenriched,
extract_metadata=True,
compute_hash=self._compute_hashes,
interrupt_check=self._check_pause_and_cancel,
hash_checkpoints=hash_checkpoints,
)
total_enriched += enriched
skip_ids.update(failed_ids)

View File

@ -1,6 +1,5 @@
"""Tests for asset enrichment (mime_type and hash population)."""
from pathlib import Path
from unittest.mock import patch
from sqlalchemy.orm import Session
@ -58,23 +57,14 @@ class TestEnrichAsset:
)
session.commit()
with patch("app.assets.scanner.create_session") as mock_cs:
from contextlib import contextmanager
@contextmanager
def _create_session():
with Session(db_engine) as sess:
yield sess
mock_cs.side_effect = _create_session
new_level = enrich_asset(
file_path=str(file_path),
reference_id=ref.id,
asset_id=asset.id,
extract_metadata=True,
compute_hash=False,
)
new_level = enrich_asset(
session,
file_path=str(file_path),
reference_id=ref.id,
asset_id=asset.id,
extract_metadata=True,
compute_hash=False,
)
assert new_level == ENRICHMENT_METADATA
@ -95,23 +85,14 @@ class TestEnrichAsset:
)
session.commit()
with patch("app.assets.scanner.create_session") as mock_cs:
from contextlib import contextmanager
@contextmanager
def _create_session():
with Session(db_engine) as sess:
yield sess
mock_cs.side_effect = _create_session
new_level = enrich_asset(
file_path=str(file_path),
reference_id=ref.id,
asset_id=asset.id,
extract_metadata=True,
compute_hash=True,
)
new_level = enrich_asset(
session,
file_path=str(file_path),
reference_id=ref.id,
asset_id=asset.id,
extract_metadata=True,
compute_hash=True,
)
assert new_level == ENRICHMENT_HASHED
@ -133,23 +114,14 @@ class TestEnrichAsset:
)
session.commit()
with patch("app.assets.scanner.create_session") as mock_cs:
from contextlib import contextmanager
@contextmanager
def _create_session():
with Session(db_engine) as sess:
yield sess
mock_cs.side_effect = _create_session
enrich_asset(
file_path=str(file_path),
reference_id=ref.id,
asset_id=asset.id,
extract_metadata=True,
compute_hash=True,
)
enrich_asset(
session,
file_path=str(file_path),
reference_id=ref.id,
asset_id=asset.id,
extract_metadata=True,
compute_hash=True,
)
session.expire_all()
updated_asset = session.get(Asset, "asset-3")
@ -169,23 +141,14 @@ class TestEnrichAsset:
)
session.commit()
with patch("app.assets.scanner.create_session") as mock_cs:
from contextlib import contextmanager
@contextmanager
def _create_session():
with Session(db_engine) as sess:
yield sess
mock_cs.side_effect = _create_session
new_level = enrich_asset(
file_path=str(file_path),
reference_id=ref.id,
asset_id=asset.id,
extract_metadata=True,
compute_hash=True,
)
new_level = enrich_asset(
session,
file_path=str(file_path),
reference_id=ref.id,
asset_id=asset.id,
extract_metadata=True,
compute_hash=True,
)
assert new_level == ENRICHMENT_STUB
@ -212,31 +175,23 @@ class TestEnrichAsset:
)
session.commit()
with patch("app.assets.scanner.create_session") as mock_cs:
from contextlib import contextmanager
enrich_asset(
session,
file_path=str(file_path_1),
reference_id=ref1.id,
asset_id=asset1.id,
extract_metadata=True,
compute_hash=True,
)
@contextmanager
def _create_session():
with Session(db_engine) as sess:
yield sess
mock_cs.side_effect = _create_session
enrich_asset(
file_path=str(file_path_1),
reference_id=ref1.id,
asset_id=asset1.id,
extract_metadata=True,
compute_hash=True,
)
enrich_asset(
file_path=str(file_path_2),
reference_id=ref2.id,
asset_id=asset2.id,
extract_metadata=True,
compute_hash=True,
)
enrich_asset(
session,
file_path=str(file_path_2),
reference_id=ref2.id,
asset_id=asset2.id,
extract_metadata=True,
compute_hash=True,
)
session.expire_all()