diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py
index c80874aa2..681af2635 100644
--- a/alembic_db/versions/0001_assets.py
+++ b/alembic_db/versions/0001_assets.py
@@ -1,5 +1,4 @@
-# File: /alembic_db/versions/0001_assets.py
-"""initial assets schema + per-asset state cache
+"""initial assets schema
 
 Revision ID: 0001_assets
 Revises:
@@ -69,15 +68,18 @@ def upgrade() -> None:
     op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
     op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
 
-    # ASSET_CACHE_STATE: 1:1 local cache metadata for an Asset
+    # ASSET_CACHE_STATE: N:1 local cache metadata rows per Asset
     op.create_table(
         "asset_cache_state",
-        sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True),
+        sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
+        sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False),
         sa.Column("file_path", sa.Text(), nullable=False),  # absolute local path to cached file
         sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
         sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
+        sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
     )
     op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
+    op.create_index("ix_asset_cache_state_asset_hash", "asset_cache_state", ["asset_hash"])
 
     # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
     op.create_table(
@@ -144,7 +146,7 @@ def upgrade() -> None:
             {"name": "photomaker", "tag_type": "system"},
             {"name": "classifiers", "tag_type": "system"},
 
-            # Extra basic tags (used for vae_approx, ...)
+            # Extra basic tags
             {"name": "encoder", "tag_type": "system"},
             {"name": "decoder", "tag_type": "system"},
         ],
@@ -162,6 +164,7 @@ def downgrade() -> None:
     op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
     op.drop_table("asset_info_meta")
 
+    op.drop_index("ix_asset_cache_state_asset_hash", table_name="asset_cache_state")
    op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
    op.drop_table("asset_cache_state")
 
diff --git a/app/_assets_helpers.py b/app/_assets_helpers.py
index ddc43f1ea..8fb88cd34 100644
--- a/app/_assets_helpers.py
+++ b/app/_assets_helpers.py
@@ -147,7 +147,7 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
     return AssetInfo.owner_id.in_(["", owner_id])
 
 
-def compute_model_relative_filename(file_path: str) -> str | None:
+def compute_model_relative_filename(file_path: str) -> Optional[str]:
     """
     Return the model's path relative to the last well-known folder
     (the model category), using forward slashes, eg:
diff --git a/app/assets_fetcher.py b/app/assets_fetcher.py
index ea1c8ed00..36fa64ca9 100644
--- a/app/assets_fetcher.py
+++ b/app/assets_fetcher.py
@@ -8,7 +8,7 @@ import aiohttp
 
 from .storage.hashing import blake3_hash_sync
 from .database.db import create_session
-from .database.services import ingest_fs_asset, get_cache_state_by_asset_hash
+from .database.services import ingest_fs_asset, list_cache_states_by_asset_hash
 from .resolvers import resolve_asset
 from ._assets_helpers import resolve_destination_from_tags, ensure_within_base
 
@@ -26,20 +26,25 @@ async def ensure_asset_cached(
     tags_hint: Optional[list[str]] = None,
 ) -> str:
     """
-    Ensure there is a verified local file for `asset_hash` in the correct Comfy folder.
-    Policy:
-      - Resolver must provide valid tags (root and, for models, category).
-      - If target path already exists:
-          * if hash matches -> reuse & ingest
-          * else -> remove and overwrite with the correct content
+    Ensure there is a verified local file for asset_hash in the correct Comfy folder.
+
+    Fast path:
+      - If any cache_state row has a file_path that exists, return it immediately.
+        Rows are tried oldest ID first, for stability.
+
+    Slow path:
+      - Resolve remote location + placement tags.
+      - Download to the correct folder, verify hash, move into place.
+      - Ingest identity + cache state so future fast passes can skip hashing.
     """
     lock = _FETCH_LOCKS.setdefault(asset_hash, asyncio.Lock())
     async with lock:
-        # 1) If we already have a state -> trust the path
+        # 1) If we already have any cache_state path present on disk, use it (oldest-first)
         async with await create_session() as sess:
-            state = await get_cache_state_by_asset_hash(sess, asset_hash=asset_hash)
-            if state and os.path.isfile(state.file_path):
-                return state.file_path
+            states = await list_cache_states_by_asset_hash(sess, asset_hash=asset_hash)
+            for s in states:
+                if s and s.file_path and os.path.isfile(s.file_path):
+                    return s.file_path
 
         # 2) Resolve remote location + placement hints (must include valid tags)
         res = await resolve_asset(asset_hash)
@@ -107,7 +112,7 @@ async def ensure_asset_cached(
            finally:
                raise ValueError(f"Hash mismatch: expected {asset_hash}, got {canonical}")
 
-        # 7) Atomically move into place (we already removed an invalid file if it existed)
+        # 7) Atomically move into place
         if os.path.exists(final_path):
             os.remove(final_path)
         os.replace(tmp_path, final_path)
diff --git a/app/assets_scanner.py b/app/assets_scanner.py
index 86e8b23cd..42cf123d2 100644
--- a/app/assets_scanner.py
+++ b/app/assets_scanner.py
@@ -1,4 +1,5 @@
 import asyncio
+import contextlib
 import logging
 import os
 import uuid
@@ -106,7 +107,7 @@ async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusRes
 
 
 async def fast_reconcile_and_kickoff(
-    roots: Sequence[str] | None = None,
+    roots: Optional[Sequence[str]] = None,
     *,
     progress_cb: Optional[Callable[[str, str, int, bool, dict], None]] = None,
 ) -> schemas_out.AssetScanStatusResponse:
@@ -216,18 +217,18 @@ async def _fast_reconcile_into_queue(
     """
     if root == "models":
         files = _collect_models_files()
-        preset_discovered = len(files)
+        preset_discovered = _count_nonzero_in_list(files)
         files_iter = asyncio.Queue()
         for p in files:
             await files_iter.put(p)
         await files_iter.put(None)  # sentinel for our local draining loop
     elif root == "input":
         base = folder_paths.get_input_directory()
-        preset_discovered = _count_files_in_tree(os.path.abspath(base))
+        preset_discovered = _count_files_in_tree(os.path.abspath(base), only_nonzero=True)
         files_iter = await _queue_tree_files(base)
     elif root == "output":
         base = folder_paths.get_output_directory()
-        preset_discovered = _count_files_in_tree(os.path.abspath(base))
+        preset_discovered = _count_files_in_tree(os.path.abspath(base), only_nonzero=True)
         files_iter = await _queue_tree_files(base)
     else:
         raise RuntimeError(f"Unsupported root: {root}")
@@ -378,26 +379,41 @@ def _collect_models_files() -> list[str]:
             allowed = False
             for b in bases:
                 base_abs = os.path.abspath(b)
-                try:
+                with contextlib.suppress(Exception):
                     if os.path.commonpath([abs_path, base_abs]) == base_abs:
                         allowed = True
                         break
-                except Exception:
-                    pass
             if allowed:
                 out.append(abs_path)
     return out
 
 
-def _count_files_in_tree(base_abs: str) -> int:
+def _count_files_in_tree(base_abs: str, *, only_nonzero: bool = False) -> int:
     if not os.path.isdir(base_abs):
         return 0
     total = 0
-    for _dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
-        total += len(filenames)
+    for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
+        if not only_nonzero:
+            total += len(filenames)
+        else:
+            for name in filenames:
+                with contextlib.suppress(OSError):
+                    st = os.stat(os.path.join(dirpath, name), follow_symlinks=True)
+                    if st.st_size:
+                        total += 1
     return total
 
 
+def _count_nonzero_in_list(paths: list[str]) -> int:
+    cnt = 0
+    for p in paths:
+        with contextlib.suppress(OSError):
+            st = os.stat(p, follow_symlinks=True)
+            if st.st_size:
+                cnt += 1
+    return cnt
+
+
 async def _queue_tree_files(base_dir: str) -> asyncio.Queue:
     """
     Walk base_dir in a worker thread and return a queue prefilled with all paths,
@@ -455,7 +471,7 @@ def _console_cb(root: str, phase: str, total_processed: int, finished: bool, e:
                 e["discovered"],
                 e["queued"],
             )
-        elif e.get("checked", 0) % 500 == 0:  # do not spam with fast progress
+        elif e.get("checked", 0) % 1000 == 0:  # do not spam with fast progress
             logging.info(
                 "[assets][%s] fast progress: processed=%s/%s",
                 root,
@@ -464,12 +480,13 @@
             )
     elif phase == "slow":
         if finished:
-            logging.info(
-                "[assets][%s] slow done: %s/%s",
-                root,
-                e.get("slow_queue_finished", 0),
-                e.get("slow_queue_total", 0),
-            )
+            if e.get("slow_queue_finished", 0) or e.get("slow_queue_total", 0):
+                logging.info(
+                    "[assets][%s] slow done: %s/%s",
+                    root,
+                    e.get("slow_queue_finished", 0),
+                    e.get("slow_queue_total", 0),
+                )
         elif e.get('slow_queue_finished', 0) % 3 == 0:
             logging.info(
                 "[assets][%s] slow progress: %s/%s",
diff --git a/app/database/models.py b/app/database/models.py
index 47f8bbaf3..203867468 100644
--- a/app/database/models.py
+++ b/app/database/models.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from datetime import datetime
 from typing import Any, Optional
 import uuid
@@ -66,9 +68,8 @@ class Asset(Base):
         viewonly=True,
     )
 
-    cache_state: Mapped["AssetCacheState | None"] = relationship(
+    cache_states: Mapped[list["AssetCacheState"]] = relationship(
         back_populates="asset",
-        uselist=False,
         cascade="all, delete-orphan",
         passive_deletes=True,
     )
@@ -93,24 +94,25 @@ class Asset(Base):
 class AssetCacheState(Base):
     __tablename__ = "asset_cache_state"
 
-    asset_hash: Mapped[str] = mapped_column(
-        String(256), ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True
-    )
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False)
     file_path: Mapped[str] = mapped_column(Text, nullable=False)
     mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
 
-    asset: Mapped["Asset"] = relationship(back_populates="cache_state", uselist=False)
+    asset: Mapped["Asset"] = relationship(back_populates="cache_states")
 
     __table_args__ = (
         Index("ix_asset_cache_state_file_path", "file_path"),
+        Index("ix_asset_cache_state_asset_hash", "asset_hash"),
         CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
+        UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
     )
 
     def to_dict(self, include_none: bool = False) -> dict[str, Any]:
         return to_dict(self, include_none=include_none)
 
     def __repr__(self) -> str:
-        return f"<AssetCacheState asset_hash={self.asset_hash} file_path={self.file_path}>"
+        return f"<AssetCacheState id={self.id} asset_hash={self.asset_hash} file_path={self.file_path}>"
 
 
 class AssetLocation(Base):
diff --git a/app/database/services.py b/app/database/services.py
index af8861001..94a9b7016 100644
--- a/app/database/services.py
+++ b/app/database/services.py
@@ -4,7 +4,7 @@ import logging
 from collections import defaultdict
 from datetime import datetime
 from decimal import Decimal
-from typing import Any, Sequence, Optional, Iterable
+from typing import Any, Sequence, Optional, Iterable, Union
 
 import sqlalchemy as sa
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -82,14 +82,14 @@ async def ingest_fs_asset(
     require_existing_tags: bool = False,
 ) -> dict:
     """
-    Upsert Asset identity row + cache state pointing at local file.
+    Upsert Asset identity row + cache state(s) pointing at local file.
 
     Always:
       - Insert Asset if missing;
-      - Insert AssetCacheState if missing; else update mtime_ns if different.
+      - Insert AssetCacheState if missing; else update mtime_ns and asset_hash if different.
 
     Optionally (when info_name is provided):
-      - Create an AssetInfo.
+      - Create or update an AssetInfo on (asset_hash, owner_id, name).
       - Link provided tags to that AssetInfo.
         * If the require_existing_tags=True, raises ValueError if any tag does not exist in `tags` table.
         * If False (default), create unknown tags.
@@ -157,11 +157,16 @@ async def ingest_fs_asset(
             out["state_created"] = True
 
     if not out["state_created"]:
-        state = await session.get(AssetCacheState, asset_hash)
+        # most likely a unique(file_path) conflict; update that row
+        state = (
+            await session.execute(
+                select(AssetCacheState).where(AssetCacheState.file_path == locator).limit(1)
+            )
+        ).scalars().first()
         if state is not None:
             changed = False
-            if state.file_path != locator:
-                state.file_path = locator
+            if state.asset_hash != asset_hash:
+                state.asset_hash = asset_hash
                 changed = True
             if state.mtime_ns != int(mtime_ns):
                 state.mtime_ns = int(mtime_ns)
@@ -260,7 +265,15 @@ async def ingest_fs_asset(
     #     )
 
     # start of adding metadata["filename"]
     if out["asset_info_id"] is not None:
-        computed_filename = compute_model_relative_filename(abs_path)
+        primary_path = (
+            await session.execute(
+                select(AssetCacheState.file_path)
+                .where(AssetCacheState.asset_hash == asset_hash)
+                .order_by(AssetCacheState.id.asc())
+                .limit(1)
+            )
+        ).scalars().first()
+        computed_filename = compute_model_relative_filename(primary_path) if primary_path else None
         # Start from current metadata on this AssetInfo, if any
         current_meta = existing_info.user_metadata or {}
@@ -366,7 +379,6 @@ async def list_asset_infos_page(
     base = _apply_tag_filters(base, include_tags, exclude_tags)
     base = _apply_metadata_filter(base, metadata_filter)
 
-    # Sort
     sort = (sort or "created_at").lower()
     order = (order or "desc").lower()
     sort_map = {
@@ -381,7 +393,6 @@ async def list_asset_infos_page(
 
     base = base.order_by(sort_exp).limit(limit).offset(offset)
 
-    # Total count (same filters, no ordering/limit/offset)
     count_stmt = (
         select(func.count())
         .select_from(AssetInfo)
@@ -395,10 +406,9 @@
 
     total = int((await session.execute(count_stmt)).scalar_one() or 0)
 
-    # Fetch rows
     infos = (await session.execute(base)).scalars().unique().all()
 
-    # Collect tags in bulk (single query)
+    # Collect tags in bulk
     id_list: list[str] = [i.id for i in infos]
     tag_map: dict[str, list[str]] = defaultdict(list)
     if id_list:
@@ -470,12 +480,33 @@ async def fetch_asset_info_asset_and_tags(
 
 
 async def get_cache_state_by_asset_hash(session: AsyncSession, *, asset_hash: str) -> Optional[AssetCacheState]:
-    return await session.get(AssetCacheState, asset_hash)
"""Return the oldest cache row for this asset.""" + return ( + await session.execute( + select(AssetCacheState) + .where(AssetCacheState.asset_hash == asset_hash) + .order_by(AssetCacheState.id.asc()) + .limit(1) + ) + ).scalars().first() + + +async def list_cache_states_by_asset_hash( + session: AsyncSession, *, asset_hash: str +) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]: + """Return all cache rows for this asset ordered by oldest first.""" + return ( + await session.execute( + select(AssetCacheState) + .where(AssetCacheState.asset_hash == asset_hash) + .order_by(AssetCacheState.id.asc()) + ) + ).scalars().all() async def list_asset_locations( session: AsyncSession, *, asset_hash: str, provider: Optional[str] = None -) -> list[AssetLocation] | Sequence[AssetLocation]: +) -> Union[list[AssetLocation], Sequence[AssetLocation]]: stmt = select(AssetLocation).where(AssetLocation.asset_hash == asset_hash) if provider: stmt = stmt.where(AssetLocation.provider == provider) @@ -815,7 +846,6 @@ async def list_tags_with_usage( if not include_zero: q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) - # Ordering if order == "name_asc": q = q.order_by(Tag.name.asc()) else: # default "count_desc" @@ -990,6 +1020,7 @@ def _apply_tag_filters( ) return stmt + def _apply_metadata_filter( stmt: sa.sql.Select, metadata_filter: Optional[dict], @@ -1050,7 +1081,7 @@ def _apply_metadata_filter( for k, v in metadata_filter.items(): if isinstance(v, list): # ANY-of (exists for any element) - ors = [ _exists_clause_for_value(k, elem) for elem in v ] + ors = [_exists_clause_for_value(k, elem) for elem in v] if ors: stmt = stmt.where(sa.or_(*ors)) else: @@ -1079,12 +1110,10 @@ def _project_kv(key: str, value: Any) -> list[dict]: """ rows: list[dict] = [] - # None if value is None: rows.append({"key": key, "ordinal": 0, "val_json": None}) return rows - # Scalars if _is_scalar(value): if isinstance(value, bool): rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) @@ -1099,9 +1128,7 @@ def _project_kv(key: str, value: Any) -> list[dict]: rows.append({"key": key, "ordinal": 0, "val_json": value}) return rows - # Lists if isinstance(value, list): - # list of scalars? if all(_is_scalar(x) for x in value): for i, x in enumerate(value): if x is None: