refactor(6): fully batched initial scan

bigcat88 2025-09-17 20:15:50 +03:00
parent f9602457d6
commit 1a37d1476d
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
8 changed files with 255 additions and 220 deletions
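
In short: the initial scan no longer seeds assets one path at a time. The scanner builds a list of plain-dict specs and hands them to a single batched helper, which claims cache-state rows, drops losing seed Assets, and bulk-inserts infos, tags, and metadata in chunked statements. A rough caller-side sketch of the new flow (names taken from this diff; "discovered_paths" is a stand-in for the scanned file list, and the snippet assumes the scanner module's own imports):

    specs = []
    for ap in discovered_paths:  # stand-in for the absolute paths gathered via list_tree(...)
        st = os.stat(ap, follow_symlinks=True)
        name, tags = get_name_and_tags_from_asset_path(ap)
        specs.append({
            "abs_path": ap,
            "size_bytes": st.st_size,
            "mtime_ns": getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)),
            "info_name": name,
            "tags": tags,
            "fname": compute_relative_filename(ap),
        })
    async with await create_session() as sess:
        result = await seed_from_paths_batch(sess, specs=specs, owner_id="")
        await sess.commit()
    # result -> {"inserted_infos": ..., "won_states": ..., "lost_states": ...}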

View File

@@ -97,7 +97,7 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
    root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
    p = Path(some_path)
    parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
-    return p.name, normalize_tags([root_category, *parent_parts])
+    return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))

def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:

View File

@@ -12,6 +12,7 @@ import folder_paths
from ._assets_helpers import (
    collect_models_files,
+    compute_relative_filename,
    get_comfy_models_folders,
    get_name_and_tags_from_asset_path,
    list_tree,
@@ -26,9 +27,8 @@ from .database.helpers import (
    ensure_tags_exist,
    escape_like_prefix,
    fast_asset_file_check,
-    insert_meta_from_batch,
-    insert_tags_from_batch,
    remove_missing_tag_for_asset_id,
+    seed_from_paths_batch,
)
from .database.models import Asset, AssetCacheState, AssetInfo
from .database.services import (
@@ -37,7 +37,6 @@ from .database.services import (
    list_cache_states_with_asset_under_prefixes,
    list_unhashed_candidates_under_prefixes,
    list_verify_candidates_under_prefixes,
-    seed_from_path,
)

LOGGER = logging.getLogger(__name__)
@@ -121,7 +120,7 @@ async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None:
        if "output" in roots:
            paths.extend(list_tree(folder_paths.get_output_directory()))
-        new_specs: list[tuple[str, int, int, str, list[str]]] = []
+        specs: list[dict] = []
        tag_pool: set[str] = set()
        for p in paths:
            ap = os.path.abspath(p)
@@ -129,54 +128,33 @@ async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None:
                skipped_existing += 1
                continue
            try:
-                st = os.stat(p, follow_symlinks=True)
+                st = os.stat(ap, follow_symlinks=True)
            except OSError:
                continue
-            if not int(st.st_size or 0):
+            if not st.st_size:
                continue
            name, tags = get_name_and_tags_from_asset_path(ap)
-            new_specs.append((
-                ap,
-                int(st.st_size),
-                getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)),
-                name,
-                tags,
-            ))
+            specs.append(
+                {
+                    "abs_path": ap,
+                    "size_bytes": st.st_size,
+                    "mtime_ns": getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)),
+                    "info_name": name,
+                    "tags": tags,
+                    "fname": compute_relative_filename(ap),
+                }
+            )
            for t in tags:
                tag_pool.add(t)
+        if not specs:
+            return
        async with await create_session() as sess:
            if tag_pool:
                await ensure_tags_exist(sess, tag_pool, tag_type="user")
-            pending_tag_links: list[dict] = []
-            pending_meta_rows: list[dict] = []
-            for ap, sz, mt, name, tags in new_specs:
-                await seed_from_path(
-                    sess,
-                    abs_path=ap,
-                    size_bytes=sz,
-                    mtime_ns=mt,
-                    info_name=name,
-                    tags=tags,
-                    owner_id="",
-                    collected_tag_rows=pending_tag_links,
-                    collected_meta_rows=pending_meta_rows,
-                )
-                created += 1
-                if created % 500 == 0:
-                    if pending_tag_links:
-                        await insert_tags_from_batch(sess, tag_rows=pending_tag_links)
-                        pending_tag_links.clear()
-                    if pending_meta_rows:
-                        await insert_meta_from_batch(sess, rows=pending_meta_rows)
-                        pending_meta_rows.clear()
-                    await sess.commit()
-            if pending_tag_links:
-                await insert_tags_from_batch(sess, tag_rows=pending_tag_links)
-            if pending_meta_rows:
-                await insert_meta_from_batch(sess, rows=pending_meta_rows)
+            result = await seed_from_paths_batch(sess, specs=specs, owner_id="")
+            created += result["inserted_infos"]
            await sess.commit()
    finally:
        LOGGER.info(

View File

@@ -1,13 +1,12 @@
+from .bulk_ops import seed_from_paths_batch
from .escape_like import escape_like_prefix
from .fast_check import fast_asset_file_check
from .filters import apply_metadata_filter, apply_tag_filters
-from .meta import insert_meta_from_batch
from .ownership import visible_owner_clause
from .projection import is_scalar, project_kv
from .tags import (
    add_missing_tag_for_asset_id,
    ensure_tags_exist,
-    insert_tags_from_batch,
    remove_missing_tag_for_asset_id,
)
@@ -21,7 +20,6 @@ __all__ = [
    "ensure_tags_exist",
    "add_missing_tag_for_asset_id",
    "remove_missing_tag_for_asset_id",
-    "insert_meta_from_batch",
-    "insert_tags_from_batch",
+    "seed_from_paths_batch",
    "visible_owner_clause",
]

View File

@@ -0,0 +1,231 @@
import os
import uuid
from typing import Iterable, Sequence

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.ext.asyncio import AsyncSession

from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoMeta, AssetInfoTag
from ..timeutil import utcnow

MAX_BIND_PARAMS = 800


async def seed_from_paths_batch(
    session: AsyncSession,
    *,
    specs: Sequence[dict],
    owner_id: str = "",
) -> dict:
    """Each spec is a dict with keys:
    - abs_path: str
    - size_bytes: int
    - mtime_ns: int
    - info_name: str
    - tags: list[str]
    - fname: Optional[str]
    """
    if not specs:
        return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
    now = utcnow()
    dialect = session.bind.dialect.name
    if dialect not in ("sqlite", "postgresql"):
        raise NotImplementedError(f"Unsupported database dialect: {dialect}")
    asset_rows: list[dict] = []
    state_rows: list[dict] = []
    path_to_asset: dict[str, str] = {}
    asset_to_info: dict[str, dict] = {}  # asset_id -> prepared info row
    path_list: list[str] = []
    for sp in specs:
        ap = os.path.abspath(sp["abs_path"])
        aid = str(uuid.uuid4())
        iid = str(uuid.uuid4())
        path_list.append(ap)
        path_to_asset[ap] = aid
        asset_rows.append(
            {
                "id": aid,
                "hash": None,
                "size_bytes": sp["size_bytes"],
                "mime_type": None,
                "created_at": now,
            }
        )
        state_rows.append(
            {
                "asset_id": aid,
                "file_path": ap,
                "mtime_ns": sp["mtime_ns"],
            }
        )
        asset_to_info[aid] = {
            "id": iid,
            "owner_id": owner_id,
            "name": sp["info_name"],
            "asset_id": aid,
            "preview_id": None,
            "user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
            "created_at": now,
            "updated_at": now,
            "last_access_time": now,
            "_tags": sp["tags"],
            "_filename": sp["fname"],
        }
    # insert all seed Assets (hash=NULL)
    ins_asset = d_sqlite.insert(Asset) if dialect == "sqlite" else d_pg.insert(Asset)
    for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
        await session.execute(ins_asset, chunk)
    # try to claim AssetCacheState (file_path)
    winners_by_path: set[str] = set()
    if dialect == "sqlite":
        ins_state = (
            d_sqlite.insert(AssetCacheState)
            .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
            .returning(AssetCacheState.file_path)
        )
    else:
        ins_state = (
            d_pg.insert(AssetCacheState)
            .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
            .returning(AssetCacheState.file_path)
        )
    for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
        winners_by_path.update((await session.execute(ins_state, chunk)).scalars().all())
    all_paths_set = set(path_list)
    losers_by_path = all_paths_set - winners_by_path
    lost_assets = [path_to_asset[p] for p in losers_by_path]
    if lost_assets:  # losers get their Asset removed
        for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
            await session.execute(sa.delete(Asset).where(Asset.id.in_(id_chunk)))
    if not winners_by_path:
        return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
    # insert AssetInfo only for winners
    winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
    if dialect == "sqlite":
        ins_info = (
            d_sqlite.insert(AssetInfo)
            .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
            .returning(AssetInfo.id)
        )
    else:
        ins_info = (
            d_pg.insert(AssetInfo)
            .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
            .returning(AssetInfo.id)
        )
    inserted_info_ids: set[str] = set()
    for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
        inserted_info_ids.update((await session.execute(ins_info, chunk)).scalars().all())
    # build and insert tag + meta rows for the AssetInfo
    tag_rows: list[dict] = []
    meta_rows: list[dict] = []
    if inserted_info_ids:
        for row in winner_info_rows:
            iid = row["id"]
            if iid not in inserted_info_ids:
                continue
            for t in row["_tags"]:
                tag_rows.append({
                    "asset_info_id": iid,
                    "tag_name": t,
                    "origin": "automatic",
                    "added_at": now,
                })
            if row["_filename"]:
                meta_rows.append(
                    {
                        "asset_info_id": iid,
                        "key": "filename",
                        "ordinal": 0,
                        "val_str": row["_filename"],
                        "val_num": None,
                        "val_bool": None,
                        "val_json": None,
                    }
                )
    await bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
    return {
        "inserted_infos": len(inserted_info_ids),
        "won_states": len(winners_by_path),
        "lost_states": len(losers_by_path),
    }


async def bulk_insert_tags_and_meta(
    session: AsyncSession,
    *,
    tag_rows: list[dict],
    meta_rows: list[dict],
    max_bind_params: int,
) -> None:
    """Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
    - tag_rows keys: asset_info_id, tag_name, origin, added_at
    - meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
    """
    dialect = session.bind.dialect.name
    if tag_rows:
        if dialect == "sqlite":
            ins_links = (
                d_sqlite.insert(AssetInfoTag)
                .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
            )
        elif dialect == "postgresql":
            ins_links = (
                d_pg.insert(AssetInfoTag)
                .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
            )
        else:
            raise NotImplementedError(f"Unsupported database dialect: {dialect}")
        for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
            await session.execute(ins_links, chunk)
    if meta_rows:
        if dialect == "sqlite":
            ins_meta = (
                d_sqlite.insert(AssetInfoMeta)
                .on_conflict_do_nothing(
                    index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
                )
            )
        elif dialect == "postgresql":
            ins_meta = (
                d_pg.insert(AssetInfoMeta)
                .on_conflict_do_nothing(
                    index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
                )
            )
        else:
            raise NotImplementedError(f"Unsupported database dialect: {dialect}")
        for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
            await session.execute(ins_meta, chunk)


def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
    if not rows:
        return []
    rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
    for i in range(0, len(rows), rows_per_stmt):
        yield rows[i:i + rows_per_stmt]


def _iter_chunks(seq, n: int):
    for i in range(0, len(seq), n):
        yield seq[i:i + n]


def _rows_per_stmt(cols: int) -> int:
    return max(1, MAX_BIND_PARAMS // max(1, cols))
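
The chunk sizes above come from dividing the bind-parameter budget by the number of columns per row; a quick sanity check of that arithmetic (illustrative, not part of the commit):

    MAX_BIND_PARAMS = 800
    assert max(1, MAX_BIND_PARAMS // 5) == 160  # Asset rows: 5 columns per row
    assert max(1, MAX_BIND_PARAMS // 3) == 266  # AssetCacheState rows: 3 columns per row
    assert max(1, MAX_BIND_PARAMS // 9) == 88   # AssetInfo rows: 9 columns per row
    assert max(1, MAX_BIND_PARAMS // 7) == 114  # asset_info_meta rows: 7 columns per row

So a scan of, say, 10,000 new files issues on the order of 10000 / 160 ≈ 63 INSERT statements for the seed Assets instead of 10,000 separate ones.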

View File

@@ -1,30 +0,0 @@
from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.ext.asyncio import AsyncSession

from ..models import AssetInfoMeta


async def insert_meta_from_batch(session: AsyncSession, *, rows: list[dict]) -> None:
    """Bulk insert rows into asset_info_meta with ON CONFLICT DO NOTHING.
    Each row should contain: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
    """
    if session.bind.dialect.name == "sqlite":
        ins = (
            d_sqlite.insert(AssetInfoMeta)
            .values(rows)
            .on_conflict_do_nothing(
                index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
            )
        )
    elif session.bind.dialect.name == "postgresql":
        ins = (
            d_pg.insert(AssetInfoMeta)
            .values(rows)
            .on_conflict_do_nothing(
                index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
            )
        )
    else:
        raise NotImplementedError(f"Unsupported database dialect: {session.bind.dialect.name}")
    await session.execute(ins)

View File

@@ -88,21 +88,3 @@ async def remove_missing_tag_for_asset_id(
            AssetInfoTag.tag_name == "missing",
        )
    )
-
-
-async def insert_tags_from_batch(session: AsyncSession, *, tag_rows: list[dict]) -> None:
-    if session.bind.dialect.name == "sqlite":
-        ins_links = (
-            d_sqlite.insert(AssetInfoTag)
-            .values(tag_rows)
-            .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
-        )
-    elif session.bind.dialect.name == "postgresql":
-        ins_links = (
-            d_pg.insert(AssetInfoTag)
-            .values(tag_rows)
-            .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
-        )
-    else:
-        raise NotImplementedError(f"Unsupported database dialect: {session.bind.dialect.name}")
-    await session.execute(ins_links)

View File

@@ -6,7 +6,6 @@ from .content import (
    list_unhashed_candidates_under_prefixes,
    list_verify_candidates_under_prefixes,
    redirect_all_references_then_delete_asset,
-    seed_from_path,
    touch_asset_infos_by_fs_path,
)
from .info import (
@@ -49,7 +48,7 @@ __all__ = [
    "get_asset_tags", "list_tags_with_usage", "set_asset_info_preview",
    "fetch_asset_info_and_asset", "fetch_asset_info_asset_and_tags",
    # content
-    "check_fs_asset_exists_quick", "seed_from_path",
+    "check_fs_asset_exists_quick",
    "redirect_all_references_then_delete_asset",
    "compute_hash_and_dedup_for_cache_state",
    "list_unhashed_candidates_under_prefixes", "list_verify_candidates_under_prefixes",

View File

@@ -1,7 +1,6 @@
import contextlib
import logging
import os
-import uuid
from datetime import datetime
from typing import Any, Optional, Sequence, Union
@@ -13,7 +12,7 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import noload
-from ..._assets_helpers import compute_relative_filename, normalize_tags
+from ..._assets_helpers import compute_relative_filename
from ...storage import hashing as hashing_mod
from ..helpers import (
    ensure_tags_exist,
@@ -58,128 +57,6 @@ async def check_fs_asset_exists_quick(
    return (await session.execute(stmt)).first() is not None
-
-
-async def seed_from_path(
-    session: AsyncSession,
-    *,
-    abs_path: str,
-    size_bytes: int,
-    mtime_ns: int,
-    info_name: str,
-    tags: Sequence[str],
-    owner_id: str = "",
-    collected_tag_rows: list[dict],
-    collected_meta_rows: list[dict],
-) -> None:
-    """Creates Asset(hash=NULL), AssetCacheState(file_path), and AssetInfo exist for the path."""
-    locator = os.path.abspath(abs_path)
-    now = utcnow()
-    dialect = session.bind.dialect.name
-    new_asset_id = str(uuid.uuid4())
-    new_info_id = str(uuid.uuid4())
-    # 1) Insert Asset (hash=NULL) no conflict expected
-    asset_vals = {
-        "id": new_asset_id,
-        "hash": None,
-        "size_bytes": size_bytes,
-        "mime_type": None,
-        "created_at": now,
-    }
-    if dialect == "sqlite":
-        await session.execute(d_sqlite.insert(Asset).values(**asset_vals))
-    elif dialect == "postgresql":
-        await session.execute(d_pg.insert(Asset).values(**asset_vals))
-    else:
-        raise NotImplementedError(f"Unsupported database dialect: {dialect}")
-    # 2) Try to claim file_path in AssetCacheState. Our concurrency gate.
-    acs_vals = {
-        "asset_id": new_asset_id,
-        "file_path": locator,
-        "mtime_ns": mtime_ns,
-    }
-    if dialect == "sqlite":
-        ins_state = (
-            d_sqlite.insert(AssetCacheState)
-            .values(**acs_vals)
-            .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
-        )
-        state_inserted = int((await session.execute(ins_state)).rowcount or 0) > 0
-    else:
-        ins_state = (
-            d_pg.insert(AssetCacheState)
-            .values(**acs_vals)
-            .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
-            .returning(AssetCacheState.id)
-        )
-        state_inserted = (await session.execute(ins_state)).scalar_one_or_none() is not None
-    if not state_inserted:
-        # Lost the race - clean up our orphan seed Asset and exit
-        with contextlib.suppress(Exception):
-            await session.execute(sa.delete(Asset).where(Asset.id == new_asset_id))
-        return
-    # 3) Create AssetInfo (unique(asset_id, owner_id, name)).
-    fname = compute_relative_filename(locator)
-    info_vals = {
-        "id": new_info_id,
-        "owner_id": owner_id,
-        "name": info_name,
-        "asset_id": new_asset_id,
-        "preview_id": None,
-        "user_metadata": {"filename": fname} if fname else None,
-        "created_at": now,
-        "updated_at": now,
-        "last_access_time": now,
-    }
-    if dialect == "sqlite":
-        ins_info = (
-            d_sqlite.insert(AssetInfo)
-            .values(**info_vals)
-            .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
-        )
-        info_inserted = int((await session.execute(ins_info)).rowcount or 0) > 0
-    else:
-        ins_info = (
-            d_pg.insert(AssetInfo)
-            .values(**info_vals)
-            .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
-            .returning(AssetInfo.id)
-        )
-        info_inserted = (await session.execute(ins_info)).scalar_one_or_none() is not None
-    # 4) If we actually inserted AssetInfo, attach tags and filename.
-    if info_inserted:
-        want = normalize_tags(tags)
-        if want:
-            tag_rows = [
-                {
-                    "asset_info_id": new_info_id,
-                    "tag_name": t,
-                    "origin": "automatic",
-                    "added_at": now,
-                }
-                for t in want
-            ]
-            collected_tag_rows.extend(tag_rows)
-        if fname:  # simple filename projection with single row
-            collected_meta_rows.append(
-                {
-                    "asset_info_id": new_info_id,
-                    "key": "filename",
-                    "ordinal": 0,
-                    "val_str": fname,
-                    "val_num": None,
-                    "val_bool": None,
-                    "val_json": None,
-                }
-            )
async def redirect_all_references_then_delete_asset(
    session: AsyncSession,
    *,