refactor(6): fully batched initial scan

bigcat88 2025-09-17 20:15:50 +03:00
parent f9602457d6
commit 1a37d1476d
8 changed files with 255 additions and 220 deletions

View File

@@ -97,7 +97,7 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
p = Path(some_path)
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
return p.name, normalize_tags([root_category, *parent_parts])
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
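The new return line dedups tags while keeping their original order; `dict.fromkeys` is the standard order-preserving way to do that in Python. A standalone illustration with made-up tag values (not part of the diff):

tags = ["models", "checkpoints", "sdxl", "checkpoints"]  # hypothetical input
list(dict.fromkeys(tags))  # -> ["models", "checkpoints", "sdxl"]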

View File

@@ -12,6 +12,7 @@ import folder_paths
from ._assets_helpers import (
collect_models_files,
compute_relative_filename,
get_comfy_models_folders,
get_name_and_tags_from_asset_path,
list_tree,
@@ -26,9 +27,8 @@ from .database.helpers import (
ensure_tags_exist,
escape_like_prefix,
fast_asset_file_check,
insert_meta_from_batch,
insert_tags_from_batch,
remove_missing_tag_for_asset_id,
seed_from_paths_batch,
)
from .database.models import Asset, AssetCacheState, AssetInfo
from .database.services import (
@@ -37,7 +37,6 @@ from .database.services import (
list_cache_states_with_asset_under_prefixes,
list_unhashed_candidates_under_prefixes,
list_verify_candidates_under_prefixes,
seed_from_path,
)
LOGGER = logging.getLogger(__name__)
@@ -121,7 +120,7 @@ async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None:
if "output" in roots:
paths.extend(list_tree(folder_paths.get_output_directory()))
new_specs: list[tuple[str, int, int, str, list[str]]] = []
specs: list[dict] = []
tag_pool: set[str] = set()
for p in paths:
ap = os.path.abspath(p)
@@ -129,54 +128,33 @@ async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None:
skipped_existing += 1
continue
try:
st = os.stat(p, follow_symlinks=True)
st = os.stat(ap, follow_symlinks=True)
except OSError:
continue
if not int(st.st_size or 0):
if not st.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(ap)
new_specs.append((
ap,
int(st.st_size),
getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)),
name,
tags,
))
specs.append(
{
"abs_path": ap,
"size_bytes": st.st_size,
"mtime_ns": getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)),
"info_name": name,
"tags": tags,
"fname": compute_relative_filename(ap),
}
)
for t in tags:
tag_pool.add(t)
if not specs:
return
async with await create_session() as sess:
if tag_pool:
await ensure_tags_exist(sess, tag_pool, tag_type="user")
pending_tag_links: list[dict] = []
pending_meta_rows: list[dict] = []
for ap, sz, mt, name, tags in new_specs:
await seed_from_path(
sess,
abs_path=ap,
size_bytes=sz,
mtime_ns=mt,
info_name=name,
tags=tags,
owner_id="",
collected_tag_rows=pending_tag_links,
collected_meta_rows=pending_meta_rows,
)
created += 1
if created % 500 == 0:
if pending_tag_links:
await insert_tags_from_batch(sess, tag_rows=pending_tag_links)
pending_tag_links.clear()
if pending_meta_rows:
await insert_meta_from_batch(sess, rows=pending_meta_rows)
pending_meta_rows.clear()
await sess.commit()
if pending_tag_links:
await insert_tags_from_batch(sess, tag_rows=pending_tag_links)
if pending_meta_rows:
await insert_meta_from_batch(sess, rows=pending_meta_rows)
result = await seed_from_paths_batch(sess, specs=specs, owner_id="")
created += result["inserted_infos"]
await sess.commit()
finally:
LOGGER.info(
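For orientation, a minimal sketch of how the new batched call is driven, following the call pattern in this file (the path, size, and timestamps are hypothetical; the spec keys come from the seed_from_paths_batch docstring shown further below):

specs = [
    {
        "abs_path": "/data/models/checkpoints/sdxl.safetensors",  # hypothetical path
        "size_bytes": 6_938_040_682,
        "mtime_ns": 1_726_590_000_000_000_000,
        "info_name": "sdxl.safetensors",
        "tags": ["models", "checkpoints"],
        "fname": "checkpoints/sdxl.safetensors",
    },
]
async with await create_session() as sess:
    await ensure_tags_exist(sess, {"models", "checkpoints"}, tag_type="user")
    result = await seed_from_paths_batch(sess, specs=specs, owner_id="")
    await sess.commit()
# result -> {"inserted_infos": 1, "won_states": 1, "lost_states": 0}
# when the path was not already tracked by another worker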

View File

@@ -1,13 +1,12 @@
from .bulk_ops import seed_from_paths_batch
from .escape_like import escape_like_prefix
from .fast_check import fast_asset_file_check
from .filters import apply_metadata_filter, apply_tag_filters
from .meta import insert_meta_from_batch
from .ownership import visible_owner_clause
from .projection import is_scalar, project_kv
from .tags import (
add_missing_tag_for_asset_id,
ensure_tags_exist,
insert_tags_from_batch,
remove_missing_tag_for_asset_id,
)
@@ -21,7 +20,6 @@ __all__ = [
"ensure_tags_exist",
"add_missing_tag_for_asset_id",
"remove_missing_tag_for_asset_id",
"insert_meta_from_batch",
"insert_tags_from_batch",
"seed_from_paths_batch",
"visible_owner_clause",
]

View File

@@ -0,0 +1,231 @@
import os
import uuid
from typing import Iterable, Sequence
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoMeta, AssetInfoTag
from ..timeutil import utcnow
MAX_BIND_PARAMS = 800
async def seed_from_paths_batch(
session: AsyncSession,
*,
specs: Sequence[dict],
owner_id: str = "",
) -> dict:
"""Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
"""
if not specs:
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
now = utcnow()
dialect = session.bind.dialect.name
if dialect not in ("sqlite", "postgresql"):
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
asset_rows: list[dict] = []
state_rows: list[dict] = []
path_to_asset: dict[str, str] = {}
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
path_list: list[str] = []
for sp in specs:
ap = os.path.abspath(sp["abs_path"])
aid = str(uuid.uuid4())
iid = str(uuid.uuid4())
path_list.append(ap)
path_to_asset[ap] = aid
asset_rows.append(
{
"id": aid,
"hash": None,
"size_bytes": sp["size_bytes"],
"mime_type": None,
"created_at": now,
}
)
state_rows.append(
{
"asset_id": aid,
"file_path": ap,
"mtime_ns": sp["mtime_ns"],
}
)
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
}
# insert all seed Assets (hash=NULL)
ins_asset = d_sqlite.insert(Asset) if dialect == "sqlite" else d_pg.insert(Asset)
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
await session.execute(ins_asset, chunk)
# try to claim AssetCacheState (file_path)
winners_by_path: set[str] = set()
if dialect == "sqlite":
ins_state = (
d_sqlite.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
.returning(AssetCacheState.file_path)
)
else:
ins_state = (
d_pg.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
.returning(AssetCacheState.file_path)
)
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
winners_by_path.update((await session.execute(ins_state, chunk)).scalars().all())
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
lost_assets = [path_to_asset[p] for p in losers_by_path]
if lost_assets: # losers get their Asset removed
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
await session.execute(sa.delete(Asset).where(Asset.id.in_(id_chunk)))
if not winners_by_path:
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
# insert AssetInfo only for winners
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
if dialect == "sqlite":
ins_info = (
d_sqlite.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
.returning(AssetInfo.id)
)
else:
ins_info = (
d_pg.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
.returning(AssetInfo.id)
)
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
inserted_info_ids.update((await session.execute(ins_info, chunk)).scalars().all())
# build and insert tag + meta rows for the AssetInfo
tag_rows: list[dict] = []
meta_rows: list[dict] = []
if inserted_info_ids:
for row in winner_info_rows:
iid = row["id"]
if iid not in inserted_info_ids:
continue
for t in row["_tags"]:
tag_rows.append({
"asset_info_id": iid,
"tag_name": t,
"origin": "automatic",
"added_at": now,
})
if row["_filename"]:
meta_rows.append(
{
"asset_info_id": iid,
"key": "filename",
"ordinal": 0,
"val_str": row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
await bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
return {
"inserted_infos": len(inserted_info_ids),
"won_states": len(winners_by_path),
"lost_states": len(losers_by_path),
}
async def bulk_insert_tags_and_meta(
session: AsyncSession,
*,
tag_rows: list[dict],
meta_rows: list[dict],
max_bind_params: int,
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
- tag_rows keys: asset_info_id, tag_name, origin, added_at
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
dialect = session.bind.dialect.name
if tag_rows:
if dialect == "sqlite":
ins_links = (
d_sqlite.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
elif dialect == "postgresql":
ins_links = (
d_pg.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
await session.execute(ins_links, chunk)
if meta_rows:
if dialect == "sqlite":
ins_meta = (
d_sqlite.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
elif dialect == "postgresql":
ins_meta = (
d_pg.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
await session.execute(ins_meta, chunk)
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
if not rows:
return []
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i:i + rows_per_stmt]
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i:i + n]
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))
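The chunk sizing above keeps every INSERT under the bind-parameter budget: rows_per_stmt = MAX_BIND_PARAMS // cols_per_row, so with the default of 800 an AssetInfo insert (9 columns) sends at most 88 rows per statement, a cache-state insert (3 columns) 266, and tag links (4 columns) 200. A small illustrative check of the helpers (example values only, not part of the commit):

assert _rows_per_stmt(9) == 88
assert _rows_per_stmt(3) == 266
rows = [
    {"asset_info_id": f"id-{i}", "tag_name": "models", "origin": "automatic", "added_at": None}
    for i in range(450)
]
chunks = list(_chunk_rows(rows, cols_per_row=4, max_bind_params=MAX_BIND_PARAMS))
assert [len(c) for c in chunks] == [200, 200, 50]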

View File

@@ -1,30 +0,0 @@
from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import AssetInfoMeta
async def insert_meta_from_batch(session: AsyncSession, *, rows: list[dict]) -> None:
"""Bulk insert rows into asset_info_meta with ON CONFLICT DO NOTHING.
Each row should contain: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
if session.bind.dialect.name == "sqlite":
ins = (
d_sqlite.insert(AssetInfoMeta)
.values(rows)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
elif session.bind.dialect.name == "postgresql":
ins = (
d_pg.insert(AssetInfoMeta)
.values(rows)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
else:
raise NotImplementedError(f"Unsupported database dialect: {session.bind.dialect.name}")
await session.execute(ins)

View File

@@ -88,21 +88,3 @@ async def remove_missing_tag_for_asset_id(
AssetInfoTag.tag_name == "missing",
)
)
async def insert_tags_from_batch(session: AsyncSession, *, tag_rows: list[dict]) -> None:
if session.bind.dialect.name == "sqlite":
ins_links = (
d_sqlite.insert(AssetInfoTag)
.values(tag_rows)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
elif session.bind.dialect.name == "postgresql":
ins_links = (
d_pg.insert(AssetInfoTag)
.values(tag_rows)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {session.bind.dialect.name}")
await session.execute(ins_links)

View File

@@ -6,7 +6,6 @@ from .content import (
list_unhashed_candidates_under_prefixes,
list_verify_candidates_under_prefixes,
redirect_all_references_then_delete_asset,
seed_from_path,
touch_asset_infos_by_fs_path,
)
from .info import (
@@ -49,7 +48,7 @@ __all__ = [
"get_asset_tags", "list_tags_with_usage", "set_asset_info_preview",
"fetch_asset_info_and_asset", "fetch_asset_info_asset_and_tags",
# content
"check_fs_asset_exists_quick", "seed_from_path",
"check_fs_asset_exists_quick",
"redirect_all_references_then_delete_asset",
"compute_hash_and_dedup_for_cache_state",
"list_unhashed_candidates_under_prefixes", "list_verify_candidates_under_prefixes",

View File

@@ -1,7 +1,6 @@
import contextlib
import logging
import os
import uuid
from datetime import datetime
from typing import Any, Optional, Sequence, Union
@@ -13,7 +12,7 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import noload
from ..._assets_helpers import compute_relative_filename, normalize_tags
from ..._assets_helpers import compute_relative_filename
from ...storage import hashing as hashing_mod
from ..helpers import (
ensure_tags_exist,
@@ -58,128 +57,6 @@ async def check_fs_asset_exists_quick(
return (await session.execute(stmt)).first() is not None
async def seed_from_path(
session: AsyncSession,
*,
abs_path: str,
size_bytes: int,
mtime_ns: int,
info_name: str,
tags: Sequence[str],
owner_id: str = "",
collected_tag_rows: list[dict],
collected_meta_rows: list[dict],
) -> None:
"""Creates Asset(hash=NULL), AssetCacheState(file_path), and AssetInfo exist for the path."""
locator = os.path.abspath(abs_path)
now = utcnow()
dialect = session.bind.dialect.name
new_asset_id = str(uuid.uuid4())
new_info_id = str(uuid.uuid4())
# 1) Insert Asset (hash=NULL); no conflict is expected.
asset_vals = {
"id": new_asset_id,
"hash": None,
"size_bytes": size_bytes,
"mime_type": None,
"created_at": now,
}
if dialect == "sqlite":
await session.execute(d_sqlite.insert(Asset).values(**asset_vals))
elif dialect == "postgresql":
await session.execute(d_pg.insert(Asset).values(**asset_vals))
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
# 2) Try to claim file_path in AssetCacheState. Our concurrency gate.
acs_vals = {
"asset_id": new_asset_id,
"file_path": locator,
"mtime_ns": mtime_ns,
}
if dialect == "sqlite":
ins_state = (
d_sqlite.insert(AssetCacheState)
.values(**acs_vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
state_inserted = int((await session.execute(ins_state)).rowcount or 0) > 0
else:
ins_state = (
d_pg.insert(AssetCacheState)
.values(**acs_vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
.returning(AssetCacheState.id)
)
state_inserted = (await session.execute(ins_state)).scalar_one_or_none() is not None
if not state_inserted:
# Lost the race - clean up our orphan seed Asset and exit
with contextlib.suppress(Exception):
await session.execute(sa.delete(Asset).where(Asset.id == new_asset_id))
return
# 3) Create AssetInfo (unique(asset_id, owner_id, name)).
fname = compute_relative_filename(locator)
info_vals = {
"id": new_info_id,
"owner_id": owner_id,
"name": info_name,
"asset_id": new_asset_id,
"preview_id": None,
"user_metadata": {"filename": fname} if fname else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
}
if dialect == "sqlite":
ins_info = (
d_sqlite.insert(AssetInfo)
.values(**info_vals)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
)
info_inserted = int((await session.execute(ins_info)).rowcount or 0) > 0
else:
ins_info = (
d_pg.insert(AssetInfo)
.values(**info_vals)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
.returning(AssetInfo.id)
)
info_inserted = (await session.execute(ins_info)).scalar_one_or_none() is not None
# 4) If we actually inserted AssetInfo, attach tags and filename.
if info_inserted:
want = normalize_tags(tags)
if want:
tag_rows = [
{
"asset_info_id": new_info_id,
"tag_name": t,
"origin": "automatic",
"added_at": now,
}
for t in want
]
collected_tag_rows.extend(tag_rows)
if fname:  # simple filename projection with a single row
collected_meta_rows.append(
{
"asset_info_id": new_info_id,
"key": "filename",
"ordinal": 0,
"val_str": fname,
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
async def redirect_all_references_then_delete_asset(
session: AsyncSession,
*,