refactor: move bulk_ops to queries and scanner service

- Delete bulk_ops.py, moving logic to appropriate layers
- Add bulk insert query functions:
  - queries/asset.bulk_insert_assets
  - queries/cache_state.bulk_insert_cache_states_ignore_conflicts
  - queries/cache_state.get_cache_states_by_paths_and_asset_ids
  - queries/asset_info.bulk_insert_asset_infos_ignore_conflicts
  - queries/asset_info.get_asset_info_ids_by_ids
  - queries/tags.bulk_insert_tags_and_meta
- Move seed_from_paths_batch orchestration to scanner._seed_from_paths_batch

Amp-Thread-ID: https://ampcode.com/threads/T-019c24fd-157d-776a-ad24-4f19cf5d3afe
Co-authored-by: Amp <amp@ampcode.com>
Author: Luke Mino-Altherr 2026-02-03 11:50:39 -08:00
parent 64d2f51dfc
commit c3105b1174
7 changed files with 343 additions and 205 deletions

View File: app/assets/database/bulk_ops.py (deleted)

@@ -1,202 +0,0 @@
import os
import uuid
import sqlalchemy
from typing import Iterable
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import utcnow
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
MAX_BIND_PARAMS = 800
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
if not rows:
return []
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i:i + rows_per_stmt]
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i:i + n]
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))
def seed_from_paths_batch(
session: Session,
specs: list[dict],
owner_id: str = "",
) -> dict:
"""Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
"""
if not specs:
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
now = utcnow()
asset_rows: list[dict] = []
state_rows: list[dict] = []
path_to_asset: dict[str, str] = {}
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
path_list: list[str] = []
for sp in specs:
ap = os.path.abspath(sp["abs_path"])
aid = str(uuid.uuid4())
iid = str(uuid.uuid4())
path_list.append(ap)
path_to_asset[ap] = aid
asset_rows.append(
{
"id": aid,
"hash": None,
"size_bytes": sp["size_bytes"],
"mime_type": None,
"created_at": now,
}
)
state_rows.append(
{
"asset_id": aid,
"file_path": ap,
"mtime_ns": sp["mtime_ns"],
}
)
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
}
# insert all seed Assets (hash=NULL)
ins_asset = sqlite.insert(Asset)
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
session.execute(ins_asset, chunk)
# try to claim AssetCacheState (file_path)
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
ins_state = (
sqlite.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
session.execute(ins_state, chunk)
# Query to find which of our paths won (were actually inserted)
winners_by_path: set[str] = set()
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetCacheState.file_path)
.where(AssetCacheState.file_path.in_(chunk))
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
)
winners_by_path.update(result.scalars().all())
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
lost_assets = [path_to_asset[p] for p in losers_by_path]
if lost_assets: # losers get their Asset removed
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))
if not winners_by_path:
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
# insert AssetInfo only for winners
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
ins_info = (
sqlite.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
)
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
session.execute(ins_info, chunk)
# Query to find which info rows were actually inserted (by matching our generated IDs)
all_info_ids = [row["id"] for row in winner_info_rows]
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
)
inserted_info_ids.update(result.scalars().all())
# build and insert tag + meta rows for the AssetInfo
tag_rows: list[dict] = []
meta_rows: list[dict] = []
if inserted_info_ids:
for row in winner_info_rows:
iid = row["id"]
if iid not in inserted_info_ids:
continue
for t in row["_tags"]:
tag_rows.append({
"asset_info_id": iid,
"tag_name": t,
"origin": "automatic",
"added_at": now,
})
if row["_filename"]:
meta_rows.append(
{
"asset_info_id": iid,
"key": "filename",
"ordinal": 0,
"val_str": row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
return {
"inserted_infos": len(inserted_info_ids),
"won_states": len(winners_by_path),
"lost_states": len(losers_by_path),
}
def bulk_insert_tags_and_meta(
session: Session,
tag_rows: list[dict],
meta_rows: list[dict],
max_bind_params: int,
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
- tag_rows keys: asset_info_id, tag_name, origin, added_at
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
if tag_rows:
ins_links = (
sqlite.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
session.execute(ins_links, chunk)
if meta_rows:
ins_meta = (
sqlite.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
session.execute(ins_meta, chunk)
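For reference, the chunking arithmetic this module used (and which the query modules now duplicate): each row of a multi-row INSERT consumes one bind parameter per column, so the code budgets MAX_BIND_PARAMS = 800 per statement and derives rows-per-statement from the column count. A minimal standalone sketch of that arithmetic; the assumption that 800 leaves headroom under older SQLite builds' 999-bound-variable ceiling is ours, not stated in the diff:

MAX_BIND_PARAMS = 800  # headroom under the historical SQLite limit of 999 bound variables

def rows_per_stmt(cols: int) -> int:
    # Cap rows per INSERT so rows * cols never exceeds the budget,
    # while always allowing at least one row so progress is made.
    return max(1, MAX_BIND_PARAMS // max(1, cols))

# A 9-column AssetInfo row fits 88 rows per statement: 88 * 9 = 792 <= 800.
assert rows_per_stmt(9) == 88
# A 3-column AssetCacheState row fits 266 rows: 266 * 3 = 798 <= 800.
assert rows_per_stmt(3) == 266
# Pathological column counts still make progress, one row at a time.
assert rows_per_stmt(10_000) == 1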

View File: app/assets/database/queries/__init__.py

@@ -5,6 +5,7 @@ from app.assets.database.queries.asset import (
asset_exists_by_hash,
get_asset_by_hash,
upsert_asset,
bulk_insert_assets,
)
from app.assets.database.queries.asset_info import (
@@ -20,6 +21,8 @@ from app.assets.database.queries.asset_info import (
replace_asset_info_metadata_projection,
delete_asset_info_by_id,
set_asset_info_preview,
bulk_insert_asset_infos_ignore_conflicts,
get_asset_info_ids_by_ids,
)
from app.assets.database.queries.cache_state import (
@@ -33,6 +36,8 @@ from app.assets.database.queries.cache_state import (
bulk_set_needs_verify,
delete_cache_states_by_ids,
delete_orphaned_seed_asset,
bulk_insert_cache_states_ignore_conflicts,
get_cache_states_by_paths_and_asset_ids,
)
from app.assets.database.queries.tags import (
@@ -44,6 +49,7 @@ from app.assets.database.queries.tags import (
add_missing_tag_for_asset_id,
remove_missing_tag_for_asset_id,
list_tags_with_usage,
bulk_insert_tags_and_meta,
)
__all__ = [
@@ -51,6 +57,7 @@ __all__ = [
"asset_exists_by_hash",
"get_asset_by_hash",
"upsert_asset",
"bulk_insert_assets",
# asset_info.py
"asset_info_exists_for_asset_id",
"get_asset_info_by_id",
@@ -64,6 +71,8 @@ __all__ = [
"replace_asset_info_metadata_projection",
"delete_asset_info_by_id",
"set_asset_info_preview",
"bulk_insert_asset_infos_ignore_conflicts",
"get_asset_info_ids_by_ids",
# cache_state.py
"CacheStateRow",
"list_cache_states_by_asset_id",
@@ -75,6 +84,8 @@ __all__ = [
"bulk_set_needs_verify",
"delete_cache_states_by_ids",
"delete_orphaned_seed_asset",
"bulk_insert_cache_states_ignore_conflicts",
"get_cache_states_by_paths_and_asset_ids",
# tags.py
"ensure_tags_exist",
"get_asset_tags",
@@ -84,4 +95,5 @@ __all__ = [
"add_missing_tag_for_asset_id",
"remove_missing_tag_for_asset_id",
"list_tags_with_usage",
"bulk_insert_tags_and_meta",
]
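With these re-exports in place, callers import every bulk helper from the package facade rather than from the individual query modules, which is how the scanner's import block below consumes them. A small hedged sketch; seed_and_tag and its arguments are illustrative, not part of the commit:

from sqlalchemy.orm import Session
from app.assets.database.queries import (
    bulk_insert_assets,
    bulk_insert_tags_and_meta,
)

def seed_and_tag(session: Session, asset_rows: list[dict],
                 tag_rows: list[dict], meta_rows: list[dict]) -> None:
    # Both helpers chunk internally, so callers pass full row lists.
    bulk_insert_assets(session, asset_rows)
    bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows)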

View File: app/assets/database/queries/asset.py

@@ -1,3 +1,5 @@
from typing import Iterable
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import sqlite
@@ -5,6 +7,17 @@ from sqlalchemy.orm import Session
from app.assets.database.models import Asset
MAX_BIND_PARAMS = 800
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i : i + n]
def asset_exists_by_hash(
session: Session,
@@ -68,3 +81,15 @@ def upsert_asset(
updated = True
return asset, created, updated
def bulk_insert_assets(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert Asset rows. Each dict should have: id, hash, size_bytes, mime_type, created_at."""
if not rows:
return
ins = sqlite.insert(Asset)
for chunk in _iter_chunks(rows, _rows_per_stmt(5)):
session.execute(ins, chunk)
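A usage sketch for bulk_insert_assets, mirroring how the scanner seeds unhashed assets; the seed_unhashed_assets wrapper and its row values are illustrative:

import uuid
from datetime import datetime, timezone

from sqlalchemy.orm import Session
from app.assets.database.queries import bulk_insert_assets

def seed_unhashed_assets(session: Session, sizes: list[int]) -> list[str]:
    # Five-column seed rows; hash stays NULL until the file is hashed later.
    now = datetime.now(timezone.utc)
    rows = [
        {"id": str(uuid.uuid4()), "hash": None, "size_bytes": s,
         "mime_type": None, "created_at": now}
        for s in sizes
    ]
    bulk_insert_assets(session, rows)  # 5 columns -> up to 160 rows per INSERT
    return [r["id"] for r in rows]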

View File: app/assets/database/queries/asset_info.py

@@ -11,6 +11,7 @@ from typing import Sequence
import sqlalchemy as sa
from sqlalchemy import select, delete, exists
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, contains_eager, noload
@@ -19,6 +20,17 @@ from app.assets.database.models import (
)
from app.assets.helpers import escape_like_prefix, normalize_tags, project_kv, utcnow
MAX_BIND_PARAMS = 800
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i : i + n]
def _visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
@@ -410,3 +422,38 @@ def set_asset_info_preview(
info.updated_at = utcnow()
session.flush()
def bulk_insert_asset_infos_ignore_conflicts(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert AssetInfo rows with ON CONFLICT DO NOTHING.
Each dict should have: id, owner_id, name, asset_id, preview_id,
user_metadata, created_at, updated_at, last_access_time
"""
if not rows:
return
ins = sqlite.insert(AssetInfo).on_conflict_do_nothing(
index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]
)
for chunk in _iter_chunks(rows, _rows_per_stmt(9)):
session.execute(ins, chunk)
def get_asset_info_ids_by_ids(
session: Session,
info_ids: list[str],
) -> set[str]:
"""Query to find which AssetInfo IDs exist in the database."""
if not info_ids:
return set()
found: set[str] = set()
for chunk in _iter_chunks(info_ids, MAX_BIND_PARAMS):
result = session.execute(
select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
)
found.update(result.scalars().all())
return found
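These two functions pair into an insert-then-probe pattern: insert with ON CONFLICT DO NOTHING, then select the candidate IDs back to learn which rows actually landed. A sketch of the pattern, assuming (as the scanner guarantees) that row IDs are freshly generated UUIDs, so an ID is present exactly when our insert won:

from sqlalchemy.orm import Session
from app.assets.database.queries import (
    bulk_insert_asset_infos_ignore_conflicts,
    get_asset_info_ids_by_ids,
)

def insert_and_partition(session: Session, rows: list[dict]) -> tuple[set[str], set[str]]:
    candidate_ids = [row["id"] for row in rows]
    bulk_insert_asset_infos_ignore_conflicts(session, rows)
    inserted = get_asset_info_ids_by_ids(session, candidate_ids)
    # Rows whose (asset_id, owner_id, name) already existed were silently skipped.
    skipped = set(candidate_ids) - inserted
    return inserted, skipped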

View File: app/assets/database/queries/cache_state.py

@@ -9,6 +9,8 @@ from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
from app.assets.helpers import escape_like_prefix
MAX_BIND_PARAMS = 800
__all__ = [
"CacheStateRow",
"list_cache_states_by_asset_id",
@@ -20,9 +22,20 @@ __all__ = [
"bulk_set_needs_verify",
"delete_cache_states_by_ids",
"delete_orphaned_seed_asset",
"bulk_insert_cache_states_ignore_conflicts",
"get_cache_states_by_paths_and_asset_ids",
]
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i : i + n]
class CacheStateRow(NamedTuple):
"""Row from cache state query with joined asset data."""
@@ -233,3 +246,50 @@ def delete_orphaned_seed_asset(session: Session, asset_id: str) -> bool:
session.delete(asset)
return True
return False
def bulk_insert_cache_states_ignore_conflicts(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert cache state rows with ON CONFLICT DO NOTHING on file_path.
Each dict should have: asset_id, file_path, mtime_ns
"""
if not rows:
return
ins = sqlite.insert(AssetCacheState).on_conflict_do_nothing(
index_elements=[AssetCacheState.file_path]
)
for chunk in _iter_chunks(rows, _rows_per_stmt(3)):
session.execute(ins, chunk)
def get_cache_states_by_paths_and_asset_ids(
session: Session,
path_to_asset: dict[str, str],
) -> set[str]:
"""Query cache states to find paths where our asset_id won the insert.
Args:
path_to_asset: Mapping of file_path -> asset_id we tried to insert
Returns:
Set of file_paths where our asset_id is present
"""
if not path_to_asset:
return set()
paths = list(path_to_asset.keys())
winners: set[str] = set()
for chunk in _iter_chunks(paths, MAX_BIND_PARAMS):
result = session.execute(
select(AssetCacheState.file_path).where(
AssetCacheState.file_path.in_(chunk),
AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]),
)
)
winners.update(result.scalars().all())
return winners
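Together, the two helpers above implement a first-writer-wins claim on file paths. A sketch of the claim-and-partition flow; the claim_paths wrapper and its mtimes argument are illustrative:

from sqlalchemy.orm import Session
from app.assets.database.queries import (
    bulk_insert_cache_states_ignore_conflicts,
    get_cache_states_by_paths_and_asset_ids,
)

def claim_paths(session: Session, path_to_asset: dict[str, str],
                mtimes: dict[str, int]) -> tuple[set[str], set[str]]:
    rows = [
        {"asset_id": aid, "file_path": path, "mtime_ns": mtimes.get(path, 0)}
        for path, aid in path_to_asset.items()
    ]
    bulk_insert_cache_states_ignore_conflicts(session, rows)
    won = get_cache_states_by_paths_and_asset_ids(session, path_to_asset)
    lost = set(path_to_asset) - won  # paths already claimed by another asset
    return won, lost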

View File: app/assets/database/queries/tags.py

@@ -6,9 +6,23 @@ from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
- from app.assets.database.models import AssetInfo, AssetInfoTag, Tag
+ from app.assets.database.models import AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import escape_like_prefix, normalize_tags, utcnow
MAX_BIND_PARAMS = 800
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))
def _chunk_rows(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]:
if not rows:
return []
rows_per_stmt = max(1, MAX_BIND_PARAMS // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i : i + rows_per_stmt]
def _visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
@@ -273,3 +287,30 @@ def list_tags_with_usage(
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
def bulk_insert_tags_and_meta(
session: Session,
tag_rows: list[dict],
meta_rows: list[dict],
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
Args:
session: Database session
tag_rows: List of dicts with keys: asset_info_id, tag_name, origin, added_at
meta_rows: List of dicts with keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
if tag_rows:
ins_tags = sqlite.insert(AssetInfoTag).on_conflict_do_nothing(
index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]
)
for chunk in _chunk_rows(tag_rows, cols_per_row=4):
session.execute(ins_tags, chunk)
if meta_rows:
ins_meta = sqlite.insert(AssetInfoMeta).on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
for chunk in _chunk_rows(meta_rows, cols_per_row=7):
session.execute(ins_meta, chunk)
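A usage sketch for the combined tag/meta insert; the tag_one_info wrapper and its argument values are illustrative:

from sqlalchemy.orm import Session
from app.assets.database.queries import bulk_insert_tags_and_meta
from app.assets.helpers import utcnow

def tag_one_info(session: Session, info_id: str, tags: list[str], filename: str) -> None:
    now = utcnow()
    tag_rows = [
        {"asset_info_id": info_id, "tag_name": t, "origin": "automatic", "added_at": now}
        for t in tags
    ]
    meta_rows = [
        {"asset_info_id": info_id, "key": "filename", "ordinal": 0,
         "val_str": filename, "val_num": None, "val_bool": None, "val_json": None}
    ]
    bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows)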

View File: scanner service module

@@ -2,9 +2,11 @@ import contextlib
import logging
import os
import time
import uuid
import folder_paths
from app.assets.database.bulk_ops import seed_from_paths_batch
from sqlalchemy.orm import Session
from app.assets.database.queries import (
add_missing_tag_for_asset_id,
ensure_tags_exist,
@@ -16,6 +18,12 @@ from app.assets.database.queries import (
bulk_set_needs_verify,
delete_cache_states_by_ids,
delete_orphaned_seed_asset,
bulk_insert_assets,
bulk_insert_cache_states_ignore_conflicts,
get_cache_states_by_paths_and_asset_ids,
bulk_insert_asset_infos_ignore_conflicts,
get_asset_info_ids_by_ids,
bulk_insert_tags_and_meta,
)
from app.assets.helpers import (
collect_models_files,
@@ -25,10 +33,157 @@ from app.assets.helpers import (
list_tree,
prefixes_for_root,
RootType,
utcnow,
)
from app.database.db import create_session, dependencies_available
def _seed_from_paths_batch(
session: Session,
specs: list[dict],
owner_id: str = "",
) -> dict:
"""Seed assets from filesystem specs in batch.
Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
This function orchestrates:
1. Insert seed Assets (hash=NULL)
2. Claim cache states with ON CONFLICT DO NOTHING
3. Query to find winners (paths where our asset_id was inserted)
4. Delete Assets for losers (path already claimed by another asset)
5. Insert AssetInfo for winners
6. Query which AssetInfo rows were actually inserted
7. Insert tags and metadata for those AssetInfos
Returns:
dict with keys: inserted_infos, won_states, lost_states
"""
if not specs:
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
now = utcnow()
asset_rows: list[dict] = []
state_rows: list[dict] = []
path_to_asset: dict[str, str] = {}
asset_to_info: dict[str, dict] = {}
path_list: list[str] = []
for sp in specs:
ap = os.path.abspath(sp["abs_path"])
aid = str(uuid.uuid4())
iid = str(uuid.uuid4())
path_list.append(ap)
path_to_asset[ap] = aid
asset_rows.append({
"id": aid,
"hash": None,
"size_bytes": sp["size_bytes"],
"mime_type": None,
"created_at": now,
})
state_rows.append({
"asset_id": aid,
"file_path": ap,
"mtime_ns": sp["mtime_ns"],
})
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
}
# 1. Insert all seed Assets (hash=NULL)
bulk_insert_assets(session, asset_rows)
# 2. Try to claim cache states (file_path unique)
bulk_insert_cache_states_ignore_conflicts(session, state_rows)
# 3. Query to find which paths we won
winners_by_path = get_cache_states_by_paths_and_asset_ids(session, path_to_asset)
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
lost_assets = [path_to_asset[p] for p in losers_by_path]
# 4. Delete Assets for losers
if lost_assets:
delete_assets_by_ids(session, lost_assets)
if not winners_by_path:
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
# 5. Insert AssetInfo for winners
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
db_info_rows = [
{
"id": row["id"],
"owner_id": row["owner_id"],
"name": row["name"],
"asset_id": row["asset_id"],
"preview_id": row["preview_id"],
"user_metadata": row["user_metadata"],
"created_at": row["created_at"],
"updated_at": row["updated_at"],
"last_access_time": row["last_access_time"],
}
for row in winner_info_rows
]
bulk_insert_asset_infos_ignore_conflicts(session, db_info_rows)
# 6. Find which info rows were actually inserted
all_info_ids = [row["id"] for row in winner_info_rows]
inserted_info_ids = get_asset_info_ids_by_ids(session, all_info_ids)
# 7. Build and insert tag + meta rows
tag_rows: list[dict] = []
meta_rows: list[dict] = []
if inserted_info_ids:
for row in winner_info_rows:
iid = row["id"]
if iid not in inserted_info_ids:
continue
for t in row["_tags"]:
tag_rows.append({
"asset_info_id": iid,
"tag_name": t,
"origin": "automatic",
"added_at": now,
})
if row["_filename"]:
meta_rows.append({
"asset_info_id": iid,
"key": "filename",
"ordinal": 0,
"val_str": row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
})
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows)
return {
"inserted_infos": len(inserted_info_ids),
"won_states": len(winners_by_path),
"lost_states": len(losers_by_path),
}
def prune_orphaned_assets(session, valid_prefixes: list[str]) -> int:
"""Prune cache states outside valid prefixes, then delete orphaned seed assets.
@@ -229,7 +384,7 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
with create_session() as sess:
if tag_pool:
ensure_tags_exist(sess, tag_pool, tag_type="user")
- result = seed_from_paths_batch(sess, specs=specs, owner_id="")
+ result = _seed_from_paths_batch(sess, specs=specs, owner_id="")
created += result["inserted_infos"]
sess.commit()
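For reference, a sketch of the spec shape that seed_assets feeds into _seed_from_paths_batch, assuming it runs inside the scanner module since the function is private; the path and size values are illustrative:

from app.database.db import create_session

specs = [
    {
        "abs_path": "/models/checkpoints/example.safetensors",
        "size_bytes": 1024,
        "mtime_ns": 1_700_000_000_000_000_000,
        "info_name": "example.safetensors",
        "tags": ["models", "checkpoints"],
        "fname": "example.safetensors",
    }
]
with create_session() as sess:
    result = _seed_from_paths_batch(sess, specs=specs, owner_id="")
    sess.commit()
# On a first run this reports one winner:
# {"inserted_infos": 1, "won_states": 1, "lost_states": 0}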