optimization: fast scan: commit to the DB in chunks

This commit is contained in:
bigcat88 2025-09-16 14:21:40 +03:00
parent 24a95f5ca4
commit 77332d3054
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
2 changed files with 49 additions and 44 deletions

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import contextlib
import logging import logging
import os import os
import time import time
@ -95,45 +96,55 @@ async def schedule_scans(roots: list[schemas_in.RootType]) -> schemas_out.AssetS
async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None: async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None:
for r in roots: t_total = time.perf_counter()
try: try:
await _fast_db_consistency_pass(r) for r in roots:
except Exception as ex: try:
LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex) await _fast_db_consistency_pass(r)
except Exception as ex:
LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex)
paths: list[str] = [] paths: list[str] = []
if "models" in roots: if "models" in roots:
paths.extend(collect_models_files()) paths.extend(collect_models_files())
if "input" in roots: if "input" in roots:
paths.extend(list_tree(folder_paths.get_input_directory())) paths.extend(list_tree(folder_paths.get_input_directory()))
if "output" in roots: if "output" in roots:
paths.extend(list_tree(folder_paths.get_output_directory())) paths.extend(list_tree(folder_paths.get_output_directory()))
for p in paths: processed = 0
try: async with await create_session() as sess:
st = os.stat(p, follow_symlinks=True) for p in paths:
if not int(st.st_size or 0): try:
continue st = os.stat(p, follow_symlinks=True)
size_bytes = int(st.st_size) if not int(st.st_size or 0):
mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) continue
name, tags = get_name_and_tags_from_asset_path(p) size_bytes = int(st.st_size)
await _seed_one_async(p, size_bytes, mtime_ns, name, tags) mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
except OSError: name, tags = get_name_and_tags_from_asset_path(p)
continue
await ensure_seed_for_path(
sess,
abs_path=p,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
info_name=name,
tags=tags,
owner_id="",
)
async def _seed_one_async(p: str, size_bytes: int, mtime_ns: int, name: str, tags: list[str]) -> None: processed += 1
async with await create_session() as sess: if processed % 500 == 0:
await ensure_seed_for_path( await sess.commit()
sess, except OSError:
abs_path=p, continue
size_bytes=size_bytes, await sess.commit()
mtime_ns=mtime_ns, finally:
info_name=name, LOGGER.info(
tags=tags, "Assets scan(roots=%s) completed in %.3f s",
owner_id="", roots,
time.perf_counter() - t_total,
) )
await sess.commit()
def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse: def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse:
@ -482,20 +493,13 @@ async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None:
if any_fast_ok: if any_fast_ok:
# Remove 'missing' and delete just the stale state rows # Remove 'missing' and delete just the stale state rows
for st in missing_states: for st in missing_states:
try: with contextlib.suppress(Exception):
await sess.delete(await sess.get(AssetCacheState, st.id)) await sess.delete(await sess.get(AssetCacheState, st.id))
except Exception: with contextlib.suppress(Exception):
pass
try:
await remove_missing_tag_for_asset_id(sess, asset_id=aid) await remove_missing_tag_for_asset_id(sess, asset_id=aid)
except Exception:
pass
else: else:
# No fast-ok path: mark as missing with contextlib.suppress(Exception):
try:
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
except Exception:
pass
await sess.flush() await sess.flush()
await sess.commit() await sess.commit()

View File

@ -87,6 +87,7 @@ async def ensure_seed_for_path(
state_row.needs_verify = True state_row.needs_verify = True
if asset_row.size_bytes == 0 and size_bytes > 0: if asset_row.size_bytes == 0 and size_bytes > 0:
asset_row.size_bytes = int(size_bytes) asset_row.size_bytes = int(size_bytes)
await session.flush()
return asset_row.id return asset_row.id
asset = Asset(hash=None, size_bytes=int(size_bytes), mime_type=None, created_at=now) asset = Asset(hash=None, size_bytes=int(size_bytes), mime_type=None, created_at=now)