From a0e07a1a4d8c3c53470e249847473fbfc4d068cb Mon Sep 17 00:00:00 2001 From: Simon Pinfold Date: Wed, 27 May 2026 09:09:30 +1200 Subject: [PATCH] spike: add typed asset classification filters Co-authored-by: Amp Amp-Thread-ID: https://ampcode.com/threads/T-019e5117-c707-729d-bf98-dce718fe64d5 --- ...0004_add_asset_reference_classification.py | 31 + app/assets/api/routes.py | 75 ++- app/assets/api/schemas_in.py | 37 +- app/assets/api/schemas_out.py | 15 +- app/assets/database/models.py | 4 + app/assets/database/queries/__init__.py | 2 + .../database/queries/asset_reference.py | 67 +- app/assets/database/queries/common.py | 25 +- app/assets/scanner.py | 42 +- app/assets/services/asset_management.py | 4 + app/assets/services/bulk_ingest.py | 24 + app/assets/services/ingest.py | 18 + app/assets/services/path_utils.py | 276 +++++++- app/assets/services/schemas.py | 4 + openapi.yaml | 30 +- .../assets_test/queries/test_asset_info.py | 601 ++++++++++++++++++ .../assets_test/queries/test_cache_state.py | 25 + .../assets_test/services/test_bulk_ingest.py | 38 ++ .../assets_test/services/test_path_utils.py | 231 ++++++- .../assets_test/test_sync_references.py | 86 ++- 20 files changed, 1577 insertions(+), 58 deletions(-) create mode 100644 alembic_db/versions/0004_add_asset_reference_classification.py diff --git a/alembic_db/versions/0004_add_asset_reference_classification.py b/alembic_db/versions/0004_add_asset_reference_classification.py new file mode 100644 index 000000000..8adc52614 --- /dev/null +++ b/alembic_db/versions/0004_add_asset_reference_classification.py @@ -0,0 +1,31 @@ +""" +Add persisted asset classification columns. 
+ +Revision ID: 0004_add_asset_reference_classification +Revises: 0003_add_metadata_job_id +Create Date: 2026-05-29 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "0004_add_asset_reference_classification" +down_revision = "0003_add_metadata_job_id" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.batch_alter_table("asset_references") as batch_op: + batch_op.add_column(sa.Column("asset_type", sa.String(length=32), nullable=True)) + batch_op.add_column(sa.Column("model_folder", sa.String(length=512), nullable=True)) + batch_op.create_index("ix_asset_references_asset_type", ["asset_type"]) + batch_op.create_index("ix_asset_references_model_folder", ["model_folder"]) + + +def downgrade() -> None: + with op.batch_alter_table("asset_references") as batch_op: + batch_op.drop_index("ix_asset_references_model_folder") + batch_op.drop_index("ix_asset_references_asset_type") + batch_op.drop_column("model_folder") + batch_op.drop_column("asset_type") diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 68126b6a5..15547b70c 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -10,7 +10,6 @@ from typing import Any from aiohttp import web from pydantic import ValidationError -import folder_paths from app import user_manager from app.assets.api import schemas_in, schemas_out from app.assets.services import schemas @@ -39,6 +38,10 @@ from app.assets.services import ( update_asset_metadata, upload_from_temp_path, ) +from app.assets.services.path_utils import ( + get_comfy_models_folders, + get_stored_asset_response_path_info, +) from app.assets.services.tagging import list_tag_histogram ROUTES = web.RouteTableDef() @@ -124,17 +127,36 @@ def _validate_sort_field(requested: str | None) -> str: return "created_at" -def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any] | None) -> str | None: - """Build a /api/view preview URL from asset tags and user_metadata filename.""" 
+def _get_asset_path_info( + file_path: str | None, + asset_type: str | None, + model_folder: str | None, +): + if not file_path or not asset_type: + return None + try: + return get_stored_asset_response_path_info(file_path, asset_type, model_folder) + except ValueError: + return None + + +def _build_preview_url_from_view( + asset_type: str | None, + user_metadata: dict[str, Any] | None, + fallback_tags: list[str] | None = None, +) -> str | None: + """Build a /api/view preview URL from path-derived type and filename metadata.""" if not user_metadata: return None filename = user_metadata.get("filename") if not filename: return None - if "input" in tags: + if asset_type in {"input", "output"}: + view_type = asset_type + elif fallback_tags and "input" in fallback_tags: view_type = "input" - elif "output" in tags: + elif fallback_tags and "output" in fallback_tags: view_type = "output" else: return None @@ -152,20 +174,54 @@ def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any] def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResult) -> schemas_out.Asset: """Build an Asset response from a service result.""" + path_info = _get_asset_path_info( + result.ref.file_path, + result.ref.asset_type, + result.ref.model_folder, + ) + if result.ref.preview_id: preview_detail = get_asset_detail(result.ref.preview_id) if preview_detail: - preview_url = _build_preview_url_from_view(preview_detail.tags, preview_detail.ref.user_metadata) + preview_path_info = _get_asset_path_info( + preview_detail.ref.file_path, + preview_detail.ref.asset_type, + preview_detail.ref.model_folder, + ) + preview_url = _build_preview_url_from_view( + preview_path_info.asset_type if preview_path_info else None, + preview_detail.ref.user_metadata, + fallback_tags=preview_detail.tags, + ) else: preview_url = None else: - preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata) + preview_url = _build_preview_url_from_view( + 
path_info.asset_type if path_info else None, + result.ref.user_metadata, + fallback_tags=result.tags, + ) + + asset_type = None + model_folder = None + file_path = None + display_name = None + if path_info: + asset_type = path_info.asset_type + model_folder = path_info.model_folder + file_path = path_info.file_path + display_name = path_info.display_name + return schemas_out.Asset( id=result.ref.id, name=result.ref.name, + file_path=file_path, + display_name=display_name, asset_hash=result.asset.hash if result.asset else None, size=int(result.asset.size_bytes) if result.asset else None, mime_type=result.asset.mime_type if result.asset else None, + model_folder=model_folder, + asset_type=asset_type, tags=result.tags, preview_url=preview_url, preview_id=result.ref.preview_id, @@ -213,6 +269,8 @@ async def list_assets_route(request: web.Request) -> web.Response: owner_id=USER_MANAGER.get_request_user_id(request), include_tags=q.include_tags, exclude_tags=q.exclude_tags, + asset_type=q.asset_type, + model_folder=q.model_folder, name_contains=q.name_contains, metadata_filter=q.metadata_filter, limit=q.limit, @@ -401,9 +459,10 @@ async def upload_asset(request: web.Request) -> web.Response: ) if spec.tags and spec.tags[0] == "models": + model_folder_names = {name for name, _paths in get_comfy_models_folders()} if ( len(spec.tags) < 2 - or spec.tags[1] not in folder_paths.folder_names_and_paths + or spec.tags[1] not in model_folder_names ): delete_temp_file_if_exists(parsed.tmp_path) category = spec.tags[1] if len(spec.tags) >= 2 else "" diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index 186a6ae1e..d6c46423c 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -52,6 +52,8 @@ class ParsedUpload: class ListAssetsQuery(BaseModel): include_tags: list[str] = Field(default_factory=list) exclude_tags: list[str] = Field(default_factory=list) + asset_type: Literal["model", "input", "output", "temp"] | None = None + 
model_folder: str | None = None name_contains: str | None = None # Accept either a JSON string (query param) or a dict @@ -81,6 +83,20 @@ class ListAssetsQuery(BaseModel): return out return v + @field_validator("model_folder", mode="before") + @classmethod + def _normalize_model_folder(cls, v): + if v is None: + return None + s = str(v).strip() + return s or None + + @model_validator(mode="after") + def _validate_path_filters(self): + if self.model_folder and self.asset_type != "model": + raise ValueError("model_folder can only be used with asset_type=model") + return self + @field_validator("metadata_filter", mode="before") @classmethod def _parse_metadata_json(cls, v): @@ -300,14 +316,23 @@ class UploadAssetSpec(BaseModel): else: return [] - # normalize + dedupe + # Normalize the root tag, but preserve path/destination components. Tags + # are normalized again before storage; this parser also feeds upload + # destination routing where registered model folder names are exact. + normalized_items: list[str] = [] + for index, item in enumerate(items): + stripped = str(item).strip() + if not stripped: + continue + normalized_items.append(stripped.lower() if index == 0 else stripped) + + # Dedupe exact tokens while preserving order. 
norm = [] seen = set() - for t in items: - tnorm = str(t).strip().lower() - if tnorm and tnorm not in seen: - seen.add(tnorm) - norm.append(tnorm) + for token in normalized_items: + if token not in seen: + seen.add(token) + norm.append(token) return norm @field_validator("user_metadata", mode="before") diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py index d99b1098d..9850ee2e6 100644 --- a/app/assets/api/schemas_out.py +++ b/app/assets/api/schemas_out.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any +from typing import Any, Literal from pydantic import BaseModel, ConfigDict, Field, field_serializer @@ -10,9 +10,22 @@ class Asset(BaseModel): id: str name: str + file_path: str | None = Field( + default=None, + description="Logical asset namespace path. Model assets use `models//`; other typed assets use `/`. This is not a unique identity; use `id` for stable asset-reference operations.", + ) + display_name: str | None = Field( + default=None, + description="Human-facing path below the matched asset root or model folder.", + ) asset_hash: str | None = None size: int | None = None mime_type: str | None = None + model_folder: str | None = Field( + default=None, + description="Exact, case-sensitive registered ComfyUI model folder name. 
Present only when asset_type is `model`.", + ) + asset_type: Literal["model", "input", "output", "temp"] | None = None tags: list[str] = Field(default_factory=list) preview_url: str | None = None preview_id: str | None = None # references an asset_reference id, not an asset id diff --git a/app/assets/database/models.py b/app/assets/database/models.py index a3af8a192..649abe4ce 100644 --- a/app/assets/database/models.py +++ b/app/assets/database/models.py @@ -76,6 +76,8 @@ class AssetReference(Base): # Cache state fields (from former AssetCacheState) file_path: Mapped[str | None] = mapped_column(Text, nullable=True) + asset_type: Mapped[str | None] = mapped_column(String(32), nullable=True) + model_folder: Mapped[str | None] = mapped_column(String(512), nullable=True) mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) is_missing: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) @@ -144,6 +146,8 @@ class AssetReference(Base): Index("uq_asset_references_file_path", "file_path", unique=True), Index("ix_asset_references_asset_id", "asset_id"), Index("ix_asset_references_owner_id", "owner_id"), + Index("ix_asset_references_asset_type", "asset_type"), + Index("ix_asset_references_model_folder", "model_folder"), Index("ix_asset_references_name", "name"), Index("ix_asset_references_is_missing", "is_missing"), Index("ix_asset_references_enrichment_level", "enrichment_level"), diff --git a/app/assets/database/queries/__init__.py b/app/assets/database/queries/__init__.py index 9949e84e1..b8432a6ab 100644 --- a/app/assets/database/queries/__init__.py +++ b/app/assets/database/queries/__init__.py @@ -13,6 +13,7 @@ from app.assets.database.queries.asset_reference import ( UnenrichedReferenceRow, bulk_insert_references_ignore_conflicts, bulk_update_enrichment_level, + bulk_update_reference_classification, count_active_siblings, bulk_update_is_missing, 
bulk_update_needs_verify, @@ -82,6 +83,7 @@ __all__ = [ "bulk_insert_references_ignore_conflicts", "bulk_insert_tags_and_meta", "bulk_update_enrichment_level", + "bulk_update_reference_classification", "count_active_siblings", "create_stub_asset", "bulk_update_is_missing", diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py index 8b90ae511..80ca61f50 100644 --- a/app/assets/database/queries/asset_reference.py +++ b/app/assets/database/queries/asset_reference.py @@ -24,6 +24,7 @@ from app.assets.database.models import ( ) from app.assets.database.queries.common import ( MAX_BIND_PARAMS, + apply_asset_path_filters, apply_metadata_filter, apply_tag_filters, build_prefix_like_conditions, @@ -166,6 +167,8 @@ def insert_reference( name: str, owner_id: str = "", file_path: str | None = None, + asset_type: str | None = None, + model_folder: str | None = None, mtime_ns: int | None = None, preview_id: str | None = None, ) -> AssetReference | None: @@ -178,6 +181,8 @@ def insert_reference( name=name, owner_id=owner_id, file_path=file_path, + asset_type=asset_type, + model_folder=model_folder, mtime_ns=mtime_ns, preview_id=preview_id, created_at=now, @@ -197,6 +202,8 @@ def get_or_create_reference( name: str, owner_id: str = "", file_path: str | None = None, + asset_type: str | None = None, + model_folder: str | None = None, mtime_ns: int | None = None, preview_id: str | None = None, ) -> tuple[AssetReference, bool]: @@ -214,6 +221,8 @@ def get_or_create_reference( name=name, owner_id=owner_id, file_path=file_path, + asset_type=asset_type, + model_folder=model_folder, mtime_ns=mtime_ns, preview_id=preview_id, ) @@ -263,6 +272,8 @@ def list_references_page( name_contains: str | None = None, include_tags: Sequence[str] | None = None, exclude_tags: Sequence[str] | None = None, + asset_type: str | None = None, + model_folder: str | None = None, metadata_filter: dict | None = None, sort: str | None = None, order: str | None = 
None, @@ -285,6 +296,7 @@ def list_references_page( base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) base = apply_tag_filters(base, include_tags, exclude_tags) + base = apply_asset_path_filters(base, asset_type=asset_type, model_folder=model_folder) base = apply_metadata_filter(base, metadata_filter) sort = (sort or "created_at").lower() @@ -315,6 +327,9 @@ def list_references_page( AssetReference.name.ilike(f"%{escaped}%", escape=esc) ) count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = apply_asset_path_filters( + count_stmt, asset_type=asset_type, model_folder=model_folder + ) count_stmt = apply_metadata_filter(count_stmt, metadata_filter) total = int(session.execute(count_stmt).scalar_one() or 0) @@ -571,6 +586,8 @@ class CacheStateRow(NamedTuple): file_path: str mtime_ns: int | None needs_verify: bool + asset_type: str | None + model_folder: str | None asset_id: str asset_hash: str | None size_bytes: int | None @@ -619,6 +636,8 @@ def upsert_reference( name: str, mtime_ns: int, owner_id: str = "", + asset_type: str | None = None, + model_folder: str | None = None, ) -> tuple[bool, bool]: """Upsert a reference by file_path. Returns (created, updated). 
@@ -628,6 +647,8 @@ def upsert_reference( vals = { "asset_id": asset_id, "file_path": file_path, + "asset_type": asset_type, + "model_folder": model_folder, "name": name, "owner_id": owner_id, "mtime_ns": int(mtime_ns), @@ -653,6 +674,8 @@ def upsert_reference( .where( sa.or_( AssetReference.asset_id != asset_id, + AssetReference.asset_type.is_distinct_from(asset_type), + AssetReference.model_folder.is_distinct_from(model_folder), AssetReference.mtime_ns.is_(None), AssetReference.mtime_ns != int(mtime_ns), AssetReference.is_missing == True, # noqa: E712 @@ -660,8 +683,13 @@ def upsert_reference( ) ) .values( - asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False, - deleted_at=None, updated_at=now, + asset_id=asset_id, + asset_type=asset_type, + model_folder=model_folder, + mtime_ns=int(mtime_ns), + is_missing=False, + deleted_at=None, + updated_at=now, ) ) res2 = session.execute(upd) @@ -780,6 +808,8 @@ def get_references_for_prefixes( AssetReference.file_path, AssetReference.mtime_ns, AssetReference.needs_verify, + AssetReference.asset_type, + AssetReference.model_folder, AssetReference.asset_id, Asset.hash, Asset.size_bytes, @@ -803,14 +833,39 @@ def get_references_for_prefixes( file_path=row[1], mtime_ns=row[2], needs_verify=row[3], - asset_id=row[4], - asset_hash=row[5], - size_bytes=int(row[6]) if row[6] is not None else None, + asset_type=row[4], + model_folder=row[5], + asset_id=row[6], + asset_hash=row[7], + size_bytes=int(row[8]) if row[8] is not None else None, ) for row in rows ] +def bulk_update_reference_classification( + session: Session, + updates: list[dict[str, str | None]], +) -> int: + """Update persisted asset_type/model_folder for existing references.""" + if not updates: + return 0 + + total = 0 + for row in updates: + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.id == row["reference_id"]) + .values( + asset_type=row["asset_type"], + model_folder=row["model_folder"], + updated_at=get_utc_now(), + ) + ) + 
total += result.rowcount + return total + + def bulk_update_needs_verify( session: Session, reference_ids: list[str], value: bool ) -> int: @@ -993,7 +1048,7 @@ def bulk_insert_references_ignore_conflicts( ins = sqlite.insert(AssetReference).on_conflict_do_nothing( index_elements=[AssetReference.file_path] ) - for chunk in iter_chunks(enriched_rows, calculate_rows_per_statement(14)): + for chunk in iter_chunks(enriched_rows, calculate_rows_per_statement(16)): session.execute(ins, chunk) diff --git a/app/assets/database/queries/common.py b/app/assets/database/queries/common.py index 89bb49327..2abfb1bc3 100644 --- a/app/assets/database/queries/common.py +++ b/app/assets/database/queries/common.py @@ -8,7 +8,7 @@ import sqlalchemy as sa from sqlalchemy import exists from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag -from app.assets.helpers import escape_sql_like_string, normalize_tags +from app.assets.helpers import normalize_tags MAX_BIND_PARAMS = 800 @@ -45,17 +45,34 @@ def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: def build_prefix_like_conditions( prefixes: list[str], ) -> list[sa.sql.ColumnElement]: - """Build LIKE conditions for matching file paths under directory prefixes.""" + """Build case-exact conditions for matching file paths under directory prefixes.""" conds = [] for p in prefixes: base = os.path.abspath(p) if not base.endswith(os.sep): base += os.sep - escaped, esc = escape_sql_like_string(base) - conds.append(AssetReference.file_path.like(escaped + "%", escape=esc)) + conds.append(sa.func.substr(AssetReference.file_path, 1, len(base)) == base) return conds +def apply_asset_path_filters( + stmt: sa.sql.Select, + asset_type: str | None = None, + model_folder: str | None = None, +) -> sa.sql.Select: + """Filter references using classification persisted at ingest time.""" + if asset_type is None and model_folder is None: + return stmt + if model_folder and asset_type != "model": + raise 
ValueError("model_folder can only be used with asset_type=model") + + if asset_type is not None: + stmt = stmt.where(AssetReference.asset_type == asset_type) + if model_folder is not None: + stmt = stmt.where(AssetReference.model_folder == model_folder) + return stmt + + def apply_tag_filters( stmt: sa.sql.Select, include_tags: Sequence[str] | None = None, diff --git a/app/assets/scanner.py b/app/assets/scanner.py index ebb6869af..2fcf151cc 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -9,6 +9,7 @@ from app.assets.database.queries import ( bulk_update_enrichment_level, bulk_update_is_missing, bulk_update_needs_verify, + bulk_update_reference_classification, delete_orphaned_seed_asset, delete_references_by_ids, ensure_tags_exist, @@ -36,6 +37,7 @@ from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash from app.assets.services.metadata_extract import extract_file_metadata from app.assets.services.path_utils import ( compute_relative_filename, + get_asset_path_info, get_comfy_models_folders, get_name_and_tags_from_asset_path, ) @@ -48,6 +50,8 @@ class _RefInfo(TypedDict): exists: bool stat_unchanged: bool needs_verify: bool + asset_type: str | None + model_folder: str | None class _AssetAccumulator(TypedDict): @@ -56,7 +60,7 @@ class _AssetAccumulator(TypedDict): refs: list[_RefInfo] -RootType = Literal["models", "input", "output"] +RootType = Literal["models", "input", "output", "temp"] def get_prefixes_for_root(root: RootType) -> list[str]: @@ -69,12 +73,14 @@ def get_prefixes_for_root(root: RootType) -> list[str]: return [os.path.abspath(folder_paths.get_input_directory())] if root == "output": return [os.path.abspath(folder_paths.get_output_directory())] + if root == "temp": + return [os.path.abspath(folder_paths.get_temp_directory())] return [] def get_all_known_prefixes() -> list[str]: """Get all known asset prefixes across all root types.""" - all_roots: tuple[RootType, ...] 
= ("models", "input", "output") + all_roots: tuple[RootType, ...] = ("models", "input", "output", "temp") return [p for root in all_roots for p in get_prefixes_for_root(root)] @@ -162,6 +168,8 @@ def sync_references_with_filesystem( "exists": exists, "stat_unchanged": stat_unchanged, "needs_verify": row.needs_verify, + "asset_type": row.asset_type, + "model_folder": row.model_folder, } ) @@ -170,6 +178,7 @@ def sync_references_with_filesystem( stale_ref_ids: list[str] = [] to_mark_missing: list[str] = [] to_clear_missing: list[str] = [] + classification_updates: list[dict[str, str | None]] = [] survivors: set[str] = set() for aid, acc in by_asset.items(): @@ -182,6 +191,21 @@ def sync_references_with_filesystem( if not r["exists"]: to_mark_missing.append(r["ref_id"]) continue + try: + path_info = get_asset_path_info(r["file_path"]) + asset_type = path_info.asset_type + model_folder = path_info.model_folder + except ValueError: + asset_type = None + model_folder = None + if asset_type != r["asset_type"] or model_folder != r["model_folder"]: + classification_updates.append( + { + "reference_id": r["ref_id"], + "asset_type": asset_type, + "model_folder": model_folder, + } + ) if r["stat_unchanged"]: to_clear_missing.append(r["ref_id"]) if r["needs_verify"]: @@ -226,6 +250,7 @@ def sync_references_with_filesystem( bulk_update_is_missing(session, to_clear_missing, value=False) bulk_update_needs_verify(session, to_set_verify, value=True) bulk_update_needs_verify(session, to_clear_verify, value=False) + bulk_update_reference_classification(session, classification_updates) return survivors if collect_existing_paths else None @@ -274,7 +299,18 @@ def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]: paths.extend(list_files_recursively(folder_paths.get_input_directory())) if "output" in roots: paths.extend(list_files_recursively(folder_paths.get_output_directory())) - return paths + if "temp" in roots: + 
paths.extend(list_files_recursively(folder_paths.get_temp_directory())) + + deduped: list[str] = [] + seen: set[str] = set() + for path in paths: + abs_path = os.path.abspath(path) + if abs_path in seen: + continue + seen.add(abs_path) + deduped.append(abs_path) + return deduped def build_asset_specs( diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py index 5aefd9956..886715f8c 100644 --- a/app/assets/services/asset_management.py +++ b/app/assets/services/asset_management.py @@ -246,6 +246,8 @@ def list_assets_page( owner_id: str = "", include_tags: Sequence[str] | None = None, exclude_tags: Sequence[str] | None = None, + asset_type: str | None = None, + model_folder: str | None = None, name_contains: str | None = None, metadata_filter: dict | None = None, limit: int = 20, @@ -259,6 +261,8 @@ def list_assets_page( owner_id=owner_id, include_tags=include_tags, exclude_tags=exclude_tags, + asset_type=asset_type, + model_folder=model_folder, name_contains=name_contains, metadata_filter=metadata_filter, limit=limit, diff --git a/app/assets/services/bulk_ingest.py b/app/assets/services/bulk_ingest.py index 67aad838f..cf06c9eeb 100644 --- a/app/assets/services/bulk_ingest.py +++ b/app/assets/services/bulk_ingest.py @@ -20,6 +20,7 @@ from app.assets.database.queries import ( restore_references_by_paths, ) from app.assets.helpers import get_utc_now +from app.assets.services.path_utils import get_asset_path_info if TYPE_CHECKING: from app.assets.services.metadata_extract import ExtractedMetadata @@ -56,6 +57,8 @@ class ReferenceRow(TypedDict): id: str asset_id: str file_path: str + asset_type: str | None + model_folder: str | None mtime_ns: int owner_id: str name: str @@ -125,6 +128,18 @@ def batch_insert_seed_assets( if not specs: return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0) + deduped_specs: list[SeedAssetSpec] = [] + seen_paths: set[str] = set() + for spec in specs: + absolute_path = 
os.path.abspath(spec["abs_path"]) + if absolute_path in seen_paths: + continue + seen_paths.add(absolute_path) + deduped_specs.append(spec) + specs = deduped_specs + if not specs: + return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0) + current_time = get_utc_now() asset_rows: list[AssetRow] = [] reference_rows: list[ReferenceRow] = [] @@ -140,6 +155,13 @@ def batch_insert_seed_assets( path_to_asset_id[absolute_path] = asset_id mime_type = spec.get("mime_type") + try: + path_info = get_asset_path_info(absolute_path) + asset_type = path_info.asset_type + model_folder = path_info.model_folder + except ValueError: + asset_type = None + model_folder = None asset_rows.append( { "id": asset_id, @@ -164,6 +186,8 @@ def batch_insert_seed_assets( "id": reference_id, "asset_id": asset_id, "file_path": absolute_path, + "asset_type": asset_type, + "model_folder": model_folder, "mtime_ns": spec["mtime_ns"], "owner_id": owner_id, "name": spec["info_name"], diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py index f0b070517..9d3468ed7 100644 --- a/app/assets/services/ingest.py +++ b/app/assets/services/ingest.py @@ -31,6 +31,7 @@ from app.assets.services.bulk_ingest import batch_insert_seed_assets from app.assets.services.file_utils import get_size_and_mtime_ns from app.assets.services.path_utils import ( compute_relative_filename, + get_asset_path_info, get_name_and_tags_from_asset_path, resolve_destination_from_tags, validate_path_within_base, @@ -70,6 +71,14 @@ def _ingest_file_from_path( reference_id: str | None = None with create_session() as session: + try: + path_info = get_asset_path_info(locator) + asset_type = path_info.asset_type + model_folder = path_info.model_folder + except ValueError: + asset_type = None + model_folder = None + if preview_id: if not reference_exists(session, preview_id): preview_id = None @@ -88,6 +97,8 @@ def _ingest_file_from_path( name=info_name or os.path.basename(locator), mtime_ns=mtime_ns, 
owner_id=owner_id, + asset_type=asset_type, + model_folder=model_folder, ) # Get the reference we just created/updated @@ -186,6 +197,13 @@ def ingest_existing_file( now = get_utc_now() existing_ref.mtime_ns = mtime_ns existing_ref.job_id = job_id + try: + path_info = get_asset_path_info(locator) + existing_ref.asset_type = path_info.asset_type + existing_ref.model_folder = path_info.model_folder + except ValueError: + existing_ref.asset_type = None + existing_ref.model_folder = None existing_ref.is_missing = False existing_ref.deleted_at = None existing_ref.updated_at = now diff --git a/app/assets/services/path_utils.py b/app/assets/services/path_utils.py index 892140ffb..c4ba4e21d 100644 --- a/app/assets/services/path_utils.py +++ b/app/assets/services/path_utils.py @@ -1,4 +1,5 @@ import os +from dataclasses import dataclass from pathlib import Path from typing import Literal @@ -6,7 +7,29 @@ import folder_paths from app.assets.helpers import normalize_tags -_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"}) +# These names are bootstrapped into folder_names_and_paths by core but are not +# model folders (matching /api/experiment/models' exclusion). Intentionally +# duplicated here so the assets layer stays decoupled from the legacy +# model-manager code it will eventually replace. 
+_NON_MODEL_FOLDER_NAMES = frozenset({"configs", "custom_nodes"}) + + +@dataclass(frozen=True) +class AssetPathInfo: + asset_type: Literal["input", "output", "temp", "model"] + model_folder: str | None + + +@dataclass(frozen=True) +class AssetResponsePathInfo(AssetPathInfo): + file_path: str + display_name: str | None + + +@dataclass(frozen=True) +class AssetPathContext(AssetPathInfo): + base_path: str + relative_path: str def get_comfy_models_folders() -> list[tuple[str, list[str]]]: @@ -14,7 +37,7 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]: Includes every category registered in folder_names_and_paths, regardless of whether its paths are under the main models_dir, - but excludes non-model entries like custom_nodes. + but excludes non-model entries like configs and custom_nodes. """ targets: list[tuple[str, list[str]]] = [] for name, values in folder_paths.folder_names_and_paths.items(): @@ -67,28 +90,109 @@ def validate_path_within_base(candidate: str, base: str) -> None: def compute_relative_filename(file_path: str) -> str | None: """ - Return the model's path relative to the last well-known folder (the model category), - using forward slashes, eg: + Return the path relative to the matched asset root or model folder, using + forward slashes, eg: /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors" /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors" + /.../input/sub/image.png -> "sub/image.png" - For non-model paths, returns None. + For unknown paths, returns None. 
""" try: - root_category, rel_path = get_asset_category_and_relative_path(file_path) + context = resolve_asset_path_context(file_path) except ValueError: return None - p = Path(rel_path) - parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)] + return _normalize_relative_path(context.relative_path) + + +def _normalize_relative_path(relative_path: str) -> str | None: + parts = [ + seg + for seg in Path(relative_path).parts + if seg not in (".", "..", Path(relative_path).anchor) + ] if not parts: return None - if root_category == "models": - # parts[0] is the category ("checkpoints", "vae", etc) – drop it - inside = parts[1:] if len(parts) > 1 else [parts[0]] - return "/".join(inside) - return "/".join(parts) # input/output: keep all parts + return "/".join(parts) + + +def resolve_asset_path_context(file_path: str) -> AssetPathContext: + """Resolve a path against Core's asset roots and model-folder registration. + + This is the source of truth for path-derived asset classification. For + model assets, ``model_folder`` is the exact registered folder name whose + base path contains the file, and ``relative_path`` is relative to that + matched base path. When multiple registered bases contain the file, the + deepest base wins. + """ + fp_abs = os.path.abspath(file_path) + + def _check_is_within(child: str, parent: str) -> bool: + return Path(child).is_relative_to(parent) + + def _compute_relative(child: str, parent: str) -> str: + # Normalize relative path, stripping any leading ".." components + # by anchoring to root (os.sep) then computing relpath back from it. 
+ return os.path.relpath( + os.path.join(os.sep, os.path.relpath(child, parent)), os.sep + ) + + best: tuple[int, str, str, str] | None = None + for model_folder, bases in get_comfy_models_folders(): + for base in bases: + base_abs = os.path.abspath(base) + if not _check_is_within(fp_abs, base_abs): + continue + cand = ( + len(base_abs), + model_folder, + base_abs, + _compute_relative(fp_abs, base_abs), + ) + if best is None or cand[0] > best[0]: + best = cand + + if best is not None: + _, model_folder, base_path, relative_path = best + return AssetPathContext( + asset_type="model", + model_folder=model_folder, + base_path=base_path, + relative_path=relative_path, + ) + + input_base = os.path.abspath(folder_paths.get_input_directory()) + if _check_is_within(fp_abs, input_base): + return AssetPathContext( + asset_type="input", + model_folder=None, + base_path=input_base, + relative_path=_compute_relative(fp_abs, input_base), + ) + + output_base = os.path.abspath(folder_paths.get_output_directory()) + if _check_is_within(fp_abs, output_base): + return AssetPathContext( + asset_type="output", + model_folder=None, + base_path=output_base, + relative_path=_compute_relative(fp_abs, output_base), + ) + + temp_base = os.path.abspath(folder_paths.get_temp_directory()) + if _check_is_within(fp_abs, temp_base): + return AssetPathContext( + asset_type="temp", + model_folder=None, + base_path=temp_base, + relative_path=_compute_relative(fp_abs, temp_base), + ) + + raise ValueError( + f"Path is not within input, output, temp, or configured model bases: {file_path}" + ) def get_asset_category_and_relative_path( @@ -120,42 +224,164 @@ def get_asset_category_and_relative_path( os.path.join(os.sep, os.path.relpath(child, parent)), os.sep ) - # 1) input input_base = os.path.abspath(folder_paths.get_input_directory()) if _check_is_within(fp_abs, input_base): return "input", _compute_relative(fp_abs, input_base) - # 2) output output_base = 
os.path.abspath(folder_paths.get_output_directory()) if _check_is_within(fp_abs, output_base): return "output", _compute_relative(fp_abs, output_base) - # 3) temp temp_base = os.path.abspath(folder_paths.get_temp_directory()) if _check_is_within(fp_abs, temp_base): return "temp", _compute_relative(fp_abs, temp_base) - # 4) models (check deepest matching base to avoid ambiguity) - best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket) - for bucket, bases in get_comfy_models_folders(): - for b in bases: - base_abs = os.path.abspath(b) + best: tuple[int, str, str] | None = None + for model_folder, bases in get_comfy_models_folders(): + for base in bases: + base_abs = os.path.abspath(base) if not _check_is_within(fp_abs, base_abs): continue - cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs)) + relative_path = _compute_relative(fp_abs, base_abs) + combined = os.path.join(model_folder, relative_path) + cand = (len(base_abs), base_abs, combined) if best is None or cand[0] > best[0]: best = cand if best is not None: - _, bucket, rel_inside = best - combined = os.path.join(bucket, rel_inside) - return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep) + return "models", os.path.relpath(os.path.join(os.sep, best[2]), os.sep) raise ValueError( f"Path is not within input, output, temp, or configured model bases: {file_path}" ) +def get_asset_path_info(file_path: str) -> AssetPathInfo: + """Return typed asset classification derived from the actual filesystem path. + + This intentionally reads the ComfyUI model folder registration from + ``folder_paths.folder_names_and_paths`` instead of inferring it from tags. + For model files, ``model_folder`` is the registered folder name whose base + path contains ``file_path``. + + Raises: + ValueError: path does not belong to any known root. 
+ """ + context = resolve_asset_path_context(file_path) + return AssetPathInfo( + asset_type=context.asset_type, + model_folder=context.model_folder, + ) + + +def get_asset_response_path_info(file_path: str) -> AssetResponsePathInfo: + """Return API-facing path fields derived from the actual filesystem path. + + ``file_path`` is a logical namespace key: ``models/<model_folder>/<display_name>`` + for model assets and ``<asset_type>/<display_name>`` for input/output/temp assets. + ``display_name`` is the path below the matched root or model folder. + + Raises: + ValueError: path does not belong to any known root. + """ + context = resolve_asset_path_context(file_path) + display_name = _normalize_relative_path(context.relative_path) + + if context.asset_type == "model": + logical_file_path = ( + f"models/{context.model_folder}/{display_name}" + if display_name + else f"models/{context.model_folder}" + ) + else: + logical_file_path = ( + f"{context.asset_type}/{display_name}" + if display_name + else context.asset_type + ) + + return AssetResponsePathInfo( + asset_type=context.asset_type, + model_folder=context.model_folder, + file_path=logical_file_path, + display_name=display_name, + ) + + +def get_stored_asset_response_path_info( + file_path: str, + asset_type: str | None, + model_folder: str | None, +) -> AssetResponsePathInfo: + """Return API-facing path fields from persisted classification. + + ``asset_type`` and ``model_folder`` are written at ingest time and are the + classification source of truth for API responses. The physical ``file_path`` + is still used to compute the display path below the stored root. 
+ """ + if asset_type not in {"input", "output", "temp", "model"}: + raise ValueError(f"unknown persisted asset_type: {asset_type}") + + fp_abs = os.path.abspath(file_path) + + def _check_is_within(child: str, parent: str) -> bool: + return Path(child).is_relative_to(parent) + + def _compute_relative(child: str, parent: str) -> str: + return os.path.relpath( + os.path.join(os.sep, os.path.relpath(child, parent)), os.sep + ) + + if asset_type == "model": + if not model_folder: + raise ValueError("model asset is missing persisted model_folder") + best: tuple[int, str] | None = None + for folder_name, bases in get_comfy_models_folders(): + if folder_name != model_folder: + continue + for base in bases: + base_abs = os.path.abspath(base) + if not _check_is_within(fp_abs, base_abs): + continue + relative_path = _compute_relative(fp_abs, base_abs) + cand = (len(base_abs), relative_path) + if best is None or cand[0] > best[0]: + best = cand + if best is None: + raise ValueError( + f"Path is not within persisted model folder roots: {file_path}" + ) + display_name = _normalize_relative_path(best[1]) + logical_file_path = ( + f"models/{model_folder}/{display_name}" + if display_name + else f"models/{model_folder}" + ) + return AssetResponsePathInfo( + asset_type="model", + model_folder=model_folder, + file_path=logical_file_path, + display_name=display_name, + ) + + root_by_type = { + "input": folder_paths.get_input_directory, + "output": folder_paths.get_output_directory, + "temp": folder_paths.get_temp_directory, + } + root = os.path.abspath(root_by_type[asset_type]()) + if not _check_is_within(fp_abs, root): + raise ValueError(f"Path is not within persisted asset root: {file_path}") + display_name = _normalize_relative_path(_compute_relative(fp_abs, root)) + logical_file_path = f"{asset_type}/{display_name}" if display_name else asset_type + return AssetResponsePathInfo( + asset_type=asset_type, + model_folder=None, + file_path=logical_file_path, + 
display_name=display_name, + ) + + def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: """Return (name, tags) derived from a filesystem path. diff --git a/app/assets/services/schemas.py b/app/assets/services/schemas.py index 0eb128f58..4277504f1 100644 --- a/app/assets/services/schemas.py +++ b/app/assets/services/schemas.py @@ -21,6 +21,8 @@ class ReferenceData: id: str name: str file_path: str | None + asset_type: str | None + model_folder: str | None user_metadata: UserMetadata preview_id: str | None created_at: datetime @@ -93,6 +95,8 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData: id=ref.id, name=ref.name, file_path=ref.file_path, + asset_type=ref.asset_type, + model_folder=ref.model_folder, user_metadata=ref.user_metadata, preview_id=ref.preview_id, system_metadata=ref.system_metadata, diff --git a/openapi.yaml b/openapi.yaml index 885231acc..d8e7ed88c 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -1511,7 +1511,7 @@ paths: in: query schema: type: integer - default: 50 + default: 20 - name: offset in: query schema: @@ -1535,6 +1535,17 @@ paths: style: form explode: true description: Tags that assets must not have + - name: asset_type + in: query + schema: + type: string + enum: [model, input, output, temp] + description: Filter by path-derived asset type. Model classification is based on registered ComfyUI model folder roots, not tags. + - name: model_folder + in: query + schema: + type: string + description: Filter model assets by exact registered ComfyUI model folder name. Requires asset_type=model. - name: name_contains in: query schema: @@ -6607,10 +6618,23 @@ components: id: type: string format: uuid - description: Unique identifier for the asset + description: AssetReference ID. Use this as the stable identity; logical file_path is not unique. name: type: string description: Name of the asset file + file_path: + type: string + description: Logical asset namespace path. 
Model assets use `models/<model_folder>/<display_name>`; input/output/temp assets use `<asset_type>/<display_name>`. Omitted when the asset has no resolvable filesystem path. Not unique; use `id` for stable asset-reference operations. + display_name: + type: string + description: Human-facing path below the matched asset root or model folder. Omitted when the asset has no resolvable filesystem path. + asset_type: + type: string + enum: [model, input, output, temp] + description: Path-derived asset type. Model classification is based on registered ComfyUI model folder roots, not tags. Omitted when the asset has no resolvable filesystem path. + model_folder: + type: string + description: Exact, case-sensitive registered ComfyUI model folder name for model assets. Present only when asset_type is `model`. hash: type: string nullable: true @@ -8723,4 +8747,4 @@ components: items: $ref: "#/components/schemas/TaskEntry" pagination: - $ref: "#/components/schemas/PaginationInfo" \ No newline at end of file + $ref: "#/components/schemas/PaginationInfo" diff --git a/tests-unit/assets_test/queries/test_asset_info.py b/tests-unit/assets_test/queries/test_asset_info.py index fe510e342..a1231d28d 100644 --- a/tests-unit/assets_test/queries/test_asset_info.py +++ b/tests-unit/assets_test/queries/test_asset_info.py @@ -1,8 +1,12 @@ import time import uuid +from pathlib import Path + import pytest from sqlalchemy.orm import Session +from app.assets.api import schemas_in +from app.assets.api.routes import _build_asset_response from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta from app.assets.database.queries import ( reference_exists_for_asset_id, @@ -23,6 +27,40 @@ from app.assets.database.queries import ( add_tags_to_reference, ) from app.assets.helpers import get_utc_now +from app.assets.services import schemas + + +class TestListAssetsQueryPathFilters: + def test_model_folder_requires_model_asset_type(self): + with pytest.raises(ValueError, match="model_folder can only be used"): +
schemas_in.ListAssetsQuery.model_validate({"model_folder": "checkpoints"}) + + def test_model_folder_accepts_explicit_model_asset_type(self): + query = schemas_in.ListAssetsQuery.model_validate( + {"asset_type": "model", "model_folder": "checkpoints"} + ) + + assert query.asset_type == "model" + assert query.model_folder == "checkpoints" + + def test_model_folder_rejects_non_model_asset_type(self): + with pytest.raises(ValueError, match="model_folder can only be used"): + schemas_in.ListAssetsQuery.model_validate( + {"asset_type": "input", "model_folder": "checkpoints"} + ) + + def test_query_layer_rejects_model_folder_without_model_asset_type( + self, session: Session + ): + with pytest.raises(ValueError, match="model_folder can only be used"): + list_references_page(session, model_folder="checkpoints") + + def test_upload_tags_preserve_model_folder_case_for_destination(self): + spec = schemas_in.UploadAssetSpec.model_validate( + {"tags": ['["models", "LLM", "SubDir"]']} + ) + + assert spec.tags == ["models", "LLM", "SubDir"] def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset: @@ -37,12 +75,18 @@ def _make_reference( asset: Asset, name: str = "test", owner_id: str = "", + file_path: str | None = None, + asset_type: str | None = None, + model_folder: str | None = None, ) -> AssetReference: now = get_utc_now() ref = AssetReference( owner_id=owner_id, name=name, asset_id=asset.id, + file_path=file_path, + asset_type=asset_type, + model_folder=model_folder, created_at=now, updated_at=now, last_access_time=now, @@ -52,6 +96,130 @@ def _make_reference( return ref +def _reference_data( + *, + name: str, + file_path: str | None, + asset_type: str | None = None, + model_folder: str | None = None, +) -> schemas.ReferenceData: + now = get_utc_now() + return schemas.ReferenceData( + id=str(uuid.uuid4()), + name=name, + file_path=file_path, + asset_type=asset_type, + model_folder=model_folder, + user_metadata={}, + preview_id=None, + 
created_at=now, + updated_at=now, + last_access_time=now, + ) + + +def _asset_detail_result(ref: schemas.ReferenceData) -> schemas.AssetDetailResult: + return schemas.AssetDetailResult( + ref=ref, + asset=schemas.AssetData( + hash="blake3:" + "a" * 64, + size_bytes=123, + mime_type="application/octet-stream", + ), + tags=[], + ) + + +class TestBuildAssetResponsePathFields: + def test_model_response_fields_use_persisted_classification( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ): + checkpoint_dir = tmp_path / "models" / "checkpoints" + checkpoint_dir.mkdir(parents=True) + model_path = checkpoint_dir / "sub" / "model.safetensors" + model_path.parent.mkdir() + model_path.write_text("data") + monkeypatch.setattr( + "app.assets.services.path_utils.get_comfy_models_folders", + lambda: [("checkpoints", [str(checkpoint_dir)])], + ) + + asset = _build_asset_response( + _asset_detail_result( + _reference_data( + name="model.safetensors", + file_path=str(model_path), + asset_type="model", + model_folder="checkpoints", + ) + ) + ) + + assert asset.asset_type == "model" + assert asset.model_folder == "checkpoints" + assert asset.display_name == "sub/model.safetensors" + assert asset.file_path == "models/checkpoints/sub/model.safetensors" + + def test_input_output_response_fields_use_persisted_classification( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ): + input_dir = tmp_path / "input" + output_dir = tmp_path / "output" + input_dir.mkdir() + output_dir.mkdir() + input_path = input_dir / "sub" / "image.png" + output_path = output_dir / "result.png" + input_path.parent.mkdir() + input_path.write_text("input") + output_path.write_text("output") + monkeypatch.setattr( + "app.assets.services.path_utils.folder_paths.get_input_directory", + lambda: str(input_dir), + ) + monkeypatch.setattr( + "app.assets.services.path_utils.folder_paths.get_output_directory", + lambda: str(output_dir), + ) + + input_asset = _build_asset_response( + _asset_detail_result( 
+ _reference_data( + name="image.png", + file_path=str(input_path), + asset_type="input", + ) + ) + ) + output_asset = _build_asset_response( + _asset_detail_result( + _reference_data( + name="result.png", + file_path=str(output_path), + asset_type="output", + ) + ) + ) + + assert input_asset.asset_type == "input" + assert input_asset.model_folder is None + assert input_asset.display_name == "sub/image.png" + assert input_asset.file_path == "input/sub/image.png" + assert output_asset.asset_type == "output" + assert output_asset.model_folder is None + assert output_asset.display_name == "result.png" + assert output_asset.file_path == "output/result.png" + + def test_pathless_response_omits_typed_path_fields(self): + asset = _build_asset_response( + _asset_detail_result(_reference_data(name="manual", file_path=None)) + ) + + assert asset.asset_type is None + assert asset.model_folder is None + assert asset.display_name is None + assert asset.file_path is None + + class TestReferenceExistsForAssetId: def test_returns_false_when_no_reference(self, session: Session): asset = _make_asset(session, "hash1") @@ -145,6 +313,439 @@ class TestListReferencesPage: assert total == 1 assert refs[0].name == "keep" + def test_model_folder_filter_uses_registered_paths(self, session: Session, tmp_path: Path): + checkpoints_dir = tmp_path / "models" / "checkpoints" + loras_dir = tmp_path / "models" / "loras" + checkpoints_dir.mkdir(parents=True) + loras_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + checkpoint = _make_reference( + session, + asset, + name="checkpoint", + file_path=str(checkpoints_dir / "model.safetensors"), + asset_type="model", + model_folder="checkpoints", + ) + _make_reference( + session, + asset, + name="lora", + file_path=str(loras_dir / "model.safetensors"), + asset_type="model", + model_folder="loras", + ) + session.commit() + + refs, _, total = list_references_page( + session, asset_type="model", model_folder="checkpoints" + ) + + assert 
total == 1 + assert refs[0].id == checkpoint.id + + def test_model_folder_filter_includes_all_registered_roots_for_folder( + self, session: Session, tmp_path: Path + ): + checkpoints_a = tmp_path / "root_a" / "checkpoints" + checkpoints_b = tmp_path / "root_b" / "checkpoints" + checkpoints_a.mkdir(parents=True) + checkpoints_b.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + ref_a = _make_reference( + session, + asset, + name="checkpoint-a", + file_path=str(checkpoints_a / "a.safetensors"), + asset_type="model", + model_folder="checkpoints", + ) + ref_b = _make_reference( + session, + asset, + name="checkpoint-b", + file_path=str(checkpoints_b / "b.safetensors"), + asset_type="model", + model_folder="checkpoints", + ) + session.commit() + + refs, _, total = list_references_page( + session, asset_type="model", model_folder="checkpoints" + ) + + assert total == 2 + assert {ref.id for ref in refs} == {ref_a.id, ref_b.id} + + def test_same_named_files_under_multiple_roots_both_return_in_model_folder_filter( + self, session: Session, tmp_path: Path + ): + checkpoints_a = tmp_path / "root_a" / "checkpoints" + checkpoints_b = tmp_path / "root_b" / "checkpoints" + checkpoints_a.mkdir(parents=True) + checkpoints_b.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + ref_a = _make_reference( + session, + asset, + name="checkpoint-a", + file_path=str(checkpoints_a / "duplicate.safetensors"), + asset_type="model", + model_folder="checkpoints", + ) + ref_b = _make_reference( + session, + asset, + name="checkpoint-b", + file_path=str(checkpoints_b / "duplicate.safetensors"), + asset_type="model", + model_folder="checkpoints", + ) + session.commit() + + refs, _, total = list_references_page( + session, asset_type="model", model_folder="checkpoints" + ) + + assert total == 2 + assert {ref.id for ref in refs} == {ref_a.id, ref_b.id} + + def test_arbitrary_registered_folder_filter_works( + self, session: Session, tmp_path: Path + ): + controlnet_dir = 
tmp_path / "models" / "controlnet" + controlnet_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + ref = _make_reference( + session, + asset, + name="controlnet", + file_path=str(controlnet_dir / "pose.safetensors"), + asset_type="model", + model_folder="controlnet", + ) + session.commit() + + refs, _, total = list_references_page( + session, asset_type="model", model_folder="controlnet" + ) + + assert total == 1 + assert refs[0].id == ref.id + + def test_unknown_model_folder_filter_returns_none_when_other_models_exist( + self, session: Session, tmp_path: Path + ): + checkpoints_dir = tmp_path / "models" / "checkpoints" + checkpoints_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + _make_reference( + session, + asset, + name="checkpoint", + file_path=str(checkpoints_dir / "model.safetensors"), + asset_type="model", + model_folder="checkpoints", + ) + session.commit() + + refs, _, total = list_references_page( + session, asset_type="model", model_folder="controlnet" + ) + + assert total == 0 + assert refs == [] + + def test_model_folder_filter_excludes_deeper_registered_model_folder( + self, session: Session, tmp_path: Path + ): + text_encoders_dir = tmp_path / "models" / "text_encoders" + clip_dir = text_encoders_dir / "clip" + clip_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + text_encoder = _make_reference( + session, + asset, + name="text_encoder", + file_path=str(text_encoders_dir / "t5xxl.safetensors"), + asset_type="model", + model_folder="text_encoders", + ) + _make_reference( + session, + asset, + name="clip", + file_path=str(clip_dir / "clip_l.safetensors"), + asset_type="model", + model_folder="text_encoders/clip", + ) + session.commit() + + refs, _, total = list_references_page( + session, asset_type="model", model_folder="text_encoders" + ) + + assert total == 1 + assert refs[0].id == text_encoder.id + + def test_child_model_folder_filter_returns_only_child( + self, session: Session, tmp_path: 
Path + ): + text_encoders_dir = tmp_path / "models" / "text_encoders" + clip_dir = text_encoders_dir / "clip" + clip_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + _make_reference( + session, + asset, + name="text_encoder", + file_path=str(text_encoders_dir / "t5xxl.safetensors"), + asset_type="model", + model_folder="text_encoders", + ) + clip = _make_reference( + session, + asset, + name="clip", + file_path=str(clip_dir / "clip_l.safetensors"), + asset_type="model", + model_folder="text_encoders/clip", + ) + session.commit() + + refs, _, total = list_references_page( + session, asset_type="model", model_folder="text_encoders/clip" + ) + + assert total == 1 + assert refs[0].id == clip.id + + def test_model_asset_type_filter_includes_parent_and_child_registered_roots( + self, session: Session, tmp_path: Path + ): + text_encoders_dir = tmp_path / "models" / "text_encoders" + clip_dir = text_encoders_dir / "clip" + clip_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + text_encoder = _make_reference( + session, + asset, + name="text_encoder", + file_path=str(text_encoders_dir / "t5xxl.safetensors"), + asset_type="model", + model_folder="text_encoders", + ) + clip = _make_reference( + session, + asset, + name="clip", + file_path=str(clip_dir / "clip_l.safetensors"), + asset_type="model", + model_folder="text_encoders/clip", + ) + session.commit() + + refs, _, total = list_references_page(session, asset_type="model") + + assert total == 2 + assert {ref.id for ref in refs} == {text_encoder.id, clip.id} + + def test_model_asset_type_filter_with_no_registered_paths_returns_none( + self, session: Session, tmp_path: Path + ): + asset = _make_asset(session, "hash1") + _make_reference( + session, + asset, + name="orphan", + file_path=str(tmp_path / "models" / "checkpoints" / "model.safetensors"), + ) + session.commit() + + refs, _, total = list_references_page(session, asset_type="model") + + assert total == 0 + assert refs == [] + + def 
test_model_asset_type_filter_excludes_unregistered_models_folder( + self, session: Session, tmp_path: Path + ): + checkpoints_dir = tmp_path / "models" / "checkpoints" + unregistered_dir = tmp_path / "models" / "unregistered" + checkpoints_dir.mkdir(parents=True) + unregistered_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + checkpoint = _make_reference( + session, + asset, + name="checkpoint", + file_path=str(checkpoints_dir / "model.safetensors"), + asset_type="model", + model_folder="checkpoints", + ) + _make_reference( + session, + asset, + name="unregistered", + file_path=str(unregistered_dir / "model.safetensors"), + ) + session.commit() + + refs, _, total = list_references_page(session, asset_type="model") + + assert total == 1 + assert refs[0].id == checkpoint.id + + def test_model_asset_type_filter_respects_prefix_boundaries( + self, session: Session, tmp_path: Path + ): + checkpoints_dir = tmp_path / "models" / "checkpoints" + checkpoints_extra_dir = tmp_path / "models" / "checkpoints_extra" + checkpoints_dir.mkdir(parents=True) + checkpoints_extra_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + checkpoint = _make_reference( + session, + asset, + name="checkpoint", + file_path=str(checkpoints_dir / "model.safetensors"), + asset_type="model", + model_folder="checkpoints", + ) + _make_reference( + session, + asset, + name="checkpoints_extra", + file_path=str(checkpoints_extra_dir / "model.safetensors"), + ) + session.commit() + + refs, _, total = list_references_page(session, asset_type="model") + + assert total == 1 + assert refs[0].id == checkpoint.id + + def test_model_asset_type_filter_is_case_exact( + self, session: Session, tmp_path: Path + ): + registered_dir = tmp_path / "models" / "checkpoints" + case_sibling_dir = tmp_path / "MODELS" / "checkpoints" + registered_dir.mkdir(parents=True) + case_sibling_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + checkpoint = _make_reference( + session, 
+ asset, + name="checkpoint", + file_path=str(registered_dir / "model.safetensors"), + asset_type="model", + model_folder="checkpoints", + ) + _make_reference( + session, + asset, + name="case_sibling", + file_path=str(case_sibling_dir / "model.safetensors"), + ) + session.commit() + + refs, _, total = list_references_page(session, asset_type="model") + + assert total == 1 + assert refs[0].id == checkpoint.id + + def test_model_asset_type_filter_includes_output_backed_model_folder( + self, session: Session, tmp_path: Path + ): + output_checkpoints_dir = tmp_path / "output" / "checkpoints" + output_checkpoints_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + checkpoint_ref = _make_reference( + session, + asset, + name="checkpoint", + file_path=str(output_checkpoints_dir / "saved.safetensors"), + asset_type="model", + model_folder="checkpoints", + ) + session.commit() + + refs, _, total = list_references_page(session, asset_type="model") + + assert total == 1 + assert refs[0].id == checkpoint_ref.id + + def test_asset_type_filter_uses_root_paths(self, session: Session, tmp_path: Path): + input_dir = tmp_path / "input" + output_dir = tmp_path / "output" + temp_dir = tmp_path / "temp" + for directory in (input_dir, output_dir, temp_dir): + directory.mkdir() + + asset = _make_asset(session, "hash1") + input_ref = _make_reference( + session, + asset, + name="input", + file_path=str(input_dir / "image.png"), + asset_type="input", + ) + _make_reference( + session, + asset, + name="output", + file_path=str(output_dir / "image.png"), + asset_type="output", + ) + session.commit() + + refs, _, total = list_references_page(session, asset_type="input") + + assert total == 1 + assert refs[0].id == input_ref.id + + def test_output_asset_type_filter_excludes_output_backed_model_folders( + self, session: Session, tmp_path: Path + ): + output_dir = tmp_path / "output" + output_checkpoints_dir = output_dir / "checkpoints" + output_checkpoints_dir.mkdir(parents=True) 
+ + asset = _make_asset(session, "hash1") + output_ref = _make_reference( + session, + asset, + name="output", + file_path=str(output_dir / "image.png"), + asset_type="output", + ) + _make_reference( + session, + asset, + name="checkpoint", + file_path=str(output_checkpoints_dir / "saved.safetensors"), + asset_type="model", + model_folder="checkpoints", + ) + session.commit() + + refs, _, total = list_references_page(session, asset_type="output") + + assert total == 1 + assert refs[0].id == output_ref.id + def test_sorting(self, session: Session): asset = _make_asset(session, "hash1", size=100) asset2 = _make_asset(session, "hash2", size=500) diff --git a/tests-unit/assets_test/queries/test_cache_state.py b/tests-unit/assets_test/queries/test_cache_state.py index ead60e570..fd6217700 100644 --- a/tests-unit/assets_test/queries/test_cache_state.py +++ b/tests-unit/assets_test/queries/test_cache_state.py @@ -255,6 +255,31 @@ class TestMarkReferencesMissingOutsidePrefixes: assert marked == 0 + def test_prefix_matching_is_case_exact(self, session: Session, tmp_path): + asset = _make_asset(session, "hash1") + valid_dir = tmp_path / "models" / "checkpoints" + case_sibling_dir = tmp_path / "MODELS" / "checkpoints" + valid_dir.mkdir(parents=True) + case_sibling_dir.mkdir(parents=True) + + valid_path = str(valid_dir / "file.bin") + case_sibling_path = str(case_sibling_dir / "file.bin") + + _make_reference(session, asset, valid_path, name="valid") + _make_reference(session, asset, case_sibling_path, name="case_sibling") + session.commit() + + marked = mark_references_missing_outside_prefixes(session, [str(valid_dir)]) + session.commit() + + assert marked == 1 + valid_ref = session.query(AssetReference).filter_by(file_path=valid_path).one() + case_sibling_ref = ( + session.query(AssetReference).filter_by(file_path=case_sibling_path).one() + ) + assert valid_ref.is_missing is False + assert case_sibling_ref.is_missing is True + class TestGetUnreferencedUnhashedAssetIds: def 
test_returns_unreferenced_unhashed_assets(self, session: Session): diff --git a/tests-unit/assets_test/services/test_bulk_ingest.py b/tests-unit/assets_test/services/test_bulk_ingest.py index 26e22a01d..5a3f0e1e8 100644 --- a/tests-unit/assets_test/services/test_bulk_ingest.py +++ b/tests-unit/assets_test/services/test_bulk_ingest.py @@ -64,6 +64,44 @@ class TestBatchInsertSeedAssets: assert len(assets) == 1 assert assets[0].mime_type is None + def test_duplicate_paths_in_same_batch_preserve_first_spec( + self, session: Session, temp_dir: Path + ): + file_path = temp_dir / "duplicate.safetensors" + file_path.write_bytes(b"fake safetensors content") + + specs: list[SeedAssetSpec] = [ + { + "abs_path": str(file_path), + "size_bytes": 24, + "mtime_ns": 1234567890000000000, + "info_name": "first", + "tags": ["models", "checkpoints"], + "fname": "duplicate.safetensors", + "metadata": None, + "hash": None, + "mime_type": "application/safetensors", + }, + { + "abs_path": str(file_path), + "size_bytes": 24, + "mtime_ns": 1234567890000000000, + "info_name": "second", + "tags": ["output"], + "fname": "duplicate.safetensors", + "metadata": None, + "hash": None, + "mime_type": "application/safetensors", + }, + ] + + result = batch_insert_seed_assets(session, specs=specs, owner_id="") + + assert result.inserted_refs == 1 + refs = session.query(AssetReference).all() + assert len(refs) == 1 + assert refs[0].name == "first" + def test_various_model_mime_types(self, session: Session, temp_dir: Path): """Verify various model file types get correct mime_type.""" test_cases = [ diff --git a/tests-unit/assets_test/services/test_path_utils.py b/tests-unit/assets_test/services/test_path_utils.py index 3fa905f9a..8735718b1 100644 --- a/tests-unit/assets_test/services/test_path_utils.py +++ b/tests-unit/assets_test/services/test_path_utils.py @@ -6,7 +6,14 @@ from unittest.mock import patch import pytest -from app.assets.services.path_utils import get_asset_category_and_relative_path +from 
app.assets.services.path_utils import ( + compute_relative_filename, + get_comfy_models_folders, + get_asset_category_and_relative_path, + get_asset_path_info, + get_asset_response_path_info, + resolve_asset_path_context, +) @pytest.fixture @@ -79,3 +86,225 @@ class TestGetAssetCategoryAndRelativePath: def test_unknown_path_raises(self, fake_dirs): with pytest.raises(ValueError, match="not within"): get_asset_category_and_relative_path("/some/random/path.png") + + +class TestGetAssetPathInfo: + def test_get_comfy_models_folders_excludes_core_infrastructure(self, tmp_path: Path): + controlnet_dir = tmp_path / "models" / "controlnet" + configs_dir = tmp_path / "models" / "configs" + custom_nodes_dir = tmp_path / "custom_nodes" + for directory in (controlnet_dir, configs_dir, custom_nodes_dir): + directory.mkdir(parents=True) + + with patch("app.assets.services.path_utils.folder_paths") as mock_fp: + mock_fp.folder_names_and_paths = { + "controlnet": ([str(controlnet_dir)], {".safetensors"}), + "configs": ([str(configs_dir)], {".yaml"}), + "custom_nodes": ([str(custom_nodes_dir)], set()), + } + + folders = get_comfy_models_folders() + + assert folders == [("controlnet", [str(controlnet_dir)])] + + def test_model_file_uses_registered_model_folder(self, fake_dirs): + f = fake_dirs["models"] / "subdir" / "model.safetensors" + f.parent.mkdir() + f.touch() + + info = get_asset_path_info(str(f)) + + assert info.asset_type == "model" + assert info.model_folder == "checkpoints" + + response_info = get_asset_response_path_info(str(f)) + assert response_info.file_path == "models/checkpoints/subdir/model.safetensors" + assert response_info.display_name == "subdir/model.safetensors" + + def test_arbitrary_registered_folder_is_model_folder(self, fake_dirs): + controlnet_dir = fake_dirs["models"].parent / "controlnet" + controlnet_dir.mkdir() + f = controlnet_dir / "pose.safetensors" + f.touch() + + with patch( + "app.assets.services.path_utils.get_comfy_models_folders", + 
return_value=[("controlnet", [str(controlnet_dir)])],
+        ):
+            response_info = get_asset_response_path_info(str(f))
+
+        assert response_info.asset_type == "model"
+        assert response_info.model_folder == "controlnet"
+        assert response_info.file_path == "models/controlnet/pose.safetensors"
+        assert response_info.display_name == "pose.safetensors"
+
+    def test_multiple_physical_roots_for_same_model_folder(self, fake_dirs):
+        root_a = fake_dirs["models"]
+        root_b = fake_dirs["output"] / "checkpoints"  # second root sits under output/
+        root_b.mkdir()
+        file_a = root_a / "subdir" / "model_a.safetensors"
+        file_b = root_b / "subdir" / "model_b.safetensors"
+        file_a.parent.mkdir()
+        file_b.parent.mkdir()
+        file_a.touch()
+        file_b.touch()
+
+        with patch(  # register both roots under the single folder name "checkpoints"
+            "app.assets.services.path_utils.get_comfy_models_folders",
+            return_value=[("checkpoints", [str(root_a), str(root_b)])],
+        ):
+            response_a = get_asset_response_path_info(str(file_a))
+            response_b = get_asset_response_path_info(str(file_b))
+
+        assert response_a.asset_type == response_b.asset_type == "model"
+        assert response_a.model_folder == response_b.model_folder == "checkpoints"
+        assert response_a.file_path == "models/checkpoints/subdir/model_a.safetensors"
+        assert response_b.file_path == "models/checkpoints/subdir/model_b.safetensors"
+        assert response_a.display_name == "subdir/model_a.safetensors"  # no folder prefix
+        assert response_b.display_name == "subdir/model_b.safetensors"
+
+    def test_same_named_files_under_multiple_roots_share_logical_file_path(self, fake_dirs):
+        root_a = fake_dirs["models"]
+        root_b = fake_dirs["output"] / "checkpoints"
+        root_b.mkdir()
+        file_a = root_a / "duplicate.safetensors"  # same file name in both roots
+        file_b = root_b / "duplicate.safetensors"
+        file_a.touch()
+        file_b.touch()
+
+        with patch(
+            "app.assets.services.path_utils.get_comfy_models_folders",
+            return_value=[("checkpoints", [str(root_a), str(root_b)])],
+        ):
+            response_a = get_asset_response_path_info(str(file_a))
+            response_b = get_asset_response_path_info(str(file_b))
+
+        assert response_a.file_path == response_b.file_path  # two physical files, one path
+        assert response_a.file_path == "models/checkpoints/duplicate.safetensors"
+        assert response_a.display_name == response_b.display_name == "duplicate.safetensors"
+
+    def test_input_file_has_no_model_folder(self, fake_dirs):
+        f = fake_dirs["input"] / "subdir" / "photo.png"
+        f.parent.mkdir()
+        f.touch()
+
+        info = get_asset_path_info(str(f))  # called without patching the model folders
+
+        assert info.asset_type == "input"
+        assert info.model_folder is None  # input assets are expected to carry no folder
+
+        response_info = get_asset_response_path_info(str(f))
+        assert response_info.file_path == "input/subdir/photo.png"
+        assert response_info.display_name == "subdir/photo.png"
+
+    def test_output_backed_registered_model_folder_is_model(self, fake_dirs):
+        output_checkpoints_dir = fake_dirs["output"] / "checkpoints"  # physically in output/
+        output_checkpoints_dir.mkdir()
+        f = output_checkpoints_dir / "saved.safetensors"
+        f.touch()
+
+        with patch(  # the output/ subdirectory is the only registered "checkpoints" root
+            "app.assets.services.path_utils.get_comfy_models_folders",
+            return_value=[("checkpoints", [str(output_checkpoints_dir)])],
+        ):
+            context = resolve_asset_path_context(str(f))
+            response_info = get_asset_response_path_info(str(f))
+
+        assert context.asset_type == "model"  # expected "model" despite living in output/
+        assert context.model_folder == "checkpoints"
+        assert context.relative_path == "saved.safetensors"
+
+        assert response_info.file_path == "models/checkpoints/saved.safetensors"
+        assert response_info.display_name == "saved.safetensors"
+
+    def test_registered_model_folder_can_contain_slash(self, fake_dirs):
+        nested_model_dir = fake_dirs["models"].parent / "text_encoders" / "clip"
+        nested_model_dir.mkdir(parents=True)
+        f = nested_model_dir / "clip.safetensors"
+        f.touch()
+
+        with patch(  # the registered folder name itself contains a "/"
+            "app.assets.services.path_utils.get_comfy_models_folders",
+            return_value=[("text_encoders/clip", [str(nested_model_dir)])],
+        ):
+            info = get_asset_path_info(str(f))
+            response_info = get_asset_response_path_info(str(f))
+
+        assert info.asset_type == "model"
+        assert info.model_folder == "text_encoders/clip"  # name is kept whole
+
+        assert response_info.file_path == "models/text_encoders/clip/clip.safetensors"
+        assert response_info.display_name == "clip.safetensors"
+
+    def test_slash_model_folder_relative_filename_uses_registered_base(self, fake_dirs):
+        nested_model_dir = fake_dirs["models"].parent / "text_encoders" / "clip"
+        nested_model_dir.mkdir(parents=True)
+        f = nested_model_dir / "subdir" / "clip.safetensors"
+        f.parent.mkdir()
+        f.touch()
+
+        with patch(
+            "app.assets.services.path_utils.get_comfy_models_folders",
+            return_value=[("text_encoders/clip", [str(nested_model_dir)])],
+        ):
+            assert compute_relative_filename(str(f)) == "subdir/clip.safetensors"  # no "clip/"
+
+    def test_deepest_registered_model_base_wins(self, fake_dirs):
+        parent_dir = fake_dirs["models"].parent / "text_encoders"
+        nested_model_dir = parent_dir / "clip"  # registered root nested inside another one
+        nested_model_dir.mkdir(parents=True)
+        f = nested_model_dir / "clip.safetensors"
+        f.touch()
+
+        with patch(
+            "app.assets.services.path_utils.get_comfy_models_folders",
+            return_value=[
+                ("text_encoders", [str(parent_dir)]),  # parent listed first
+                ("text_encoders/clip", [str(nested_model_dir)]),  # nested listed second
+            ],
+        ):
+            context = resolve_asset_path_context(str(f))
+
+        assert context.asset_type == "model"
+        assert context.model_folder == "text_encoders/clip"  # the nested entry is expected
+        assert context.relative_path == "clip.safetensors"
+
+    def test_deepest_registered_model_base_wins_independent_of_registration_order(
+        self, fake_dirs
+    ):
+        parent_dir = fake_dirs["models"].parent / "text_encoders"
+        nested_model_dir = parent_dir / "clip"
+        nested_model_dir.mkdir(parents=True)
+        f = nested_model_dir / "clip.safetensors"
+        f.touch()
+
+        with patch(
+            "app.assets.services.path_utils.get_comfy_models_folders",
+            return_value=[
+                ("text_encoders/clip", [str(nested_model_dir)]),  # nested listed first
+                ("text_encoders", [str(parent_dir)]),  # parent listed second
+            ],
+        ):
+            context = resolve_asset_path_context(str(f))
+
+        assert context.asset_type == "model"
+        assert context.model_folder == "text_encoders/clip"  # same result as the other order
+        assert context.relative_path == "clip.safetensors"
+
+    def test_path_under_unregistered_models_folder_is_unknown(self, fake_dirs):
+        unregistered_dir = fake_dirs["models"].parent / "unregistered"  # not registered here
+        unregistered_dir.mkdir()
+        f = unregistered_dir / "model.safetensors"
+        f.touch()
+
+        with pytest.raises(ValueError, match="not within"):  # message must match "not within"
+            resolve_asset_path_context(str(f))
+
+    def test_registered_model_folder_prefix_boundary(self, fake_dirs):
+        checkpoints_extra_dir = fake_dirs["models"].parent / "checkpoints_extra"
+        checkpoints_extra_dir.mkdir()  # directory name only starts with "checkpoints"
+        f = checkpoints_extra_dir / "model.safetensors"
+        f.touch()
+
+        with pytest.raises(ValueError, match="not within"):
+            resolve_asset_path_context(str(f))
diff --git a/tests-unit/assets_test/test_sync_references.py b/tests-unit/assets_test/test_sync_references.py
index 94cc255bc..7638890da 100644
--- a/tests-unit/assets_test/test_sync_references.py
+++ b/tests-unit/assets_test/test_sync_references.py
@@ -23,10 +23,60 @@ from app.assets.database.queries.asset_reference import (
     get_unenriched_references,
     restore_references_by_paths,
 )
-from app.assets.scanner import sync_references_with_filesystem
+from app.assets.scanner import (
+    collect_paths_for_roots,
+    get_all_known_prefixes,
+    sync_references_with_filesystem,
+)
 from app.assets.services.file_utils import get_mtime_ns
 
 
+def test_collect_paths_for_roots_deduplicates_overlapping_roots(tmp_path: Path):
+    model_file = tmp_path / "output" / "checkpoints" / "saved.safetensors"
+    model_file.parent.mkdir(parents=True)
+    model_file.write_bytes(b"model")
+
+    with (  # both patched collectors return the same single path
+        patch("app.assets.scanner.collect_models_files", return_value=[str(model_file)]),
+        patch(
+            "app.assets.scanner.list_files_recursively",
+            return_value=[str(model_file)],
+        ),
+        patch("app.assets.scanner.folder_paths") as mock_folder_paths,
+    ):
+        mock_folder_paths.get_output_directory.return_value = str(tmp_path / "output")
+
+        paths = collect_paths_for_roots(("models", "output"))
+
+    assert paths == [str(model_file)]  # the path is expected once, not twice
+
+
+def test_all_known_prefixes_include_temp_root(tmp_path: Path):
+    models_dir = tmp_path / "models" / "checkpoints"
+    input_dir = tmp_path / "input"
+    output_dir = tmp_path / "output"
+    temp_dir = tmp_path / "temp"
+    for directory in (models_dir, input_dir, output_dir, temp_dir):
+        directory.mkdir(parents=True)
+
+    with (
+        patch("app.assets.scanner.get_comfy_models_folders", return_value=[("checkpoints", [str(models_dir)])]),
+        patch("app.assets.scanner.folder_paths") as mock_folder_paths,
+    ):
+        mock_folder_paths.get_input_directory.return_value = str(input_dir)
+        mock_folder_paths.get_output_directory.return_value = str(output_dir)
+        mock_folder_paths.get_temp_directory.return_value = str(temp_dir)
+
+        prefixes = get_all_known_prefixes()
+
+    assert prefixes == [  # list equality, so the order is asserted as well
+        str(models_dir),
+        str(input_dir),
+        str(output_dir),
+        str(temp_dir),
+    ]
+
+
 @pytest.fixture
 def db_engine():
     engine = create_engine("sqlite:///:memory:")
@@ -99,6 +149,40 @@ def _ensure_missing_tag(session: Session):
     session.flush()
 
 
+def test_sync_reclassifies_existing_references_for_registered_model_roots(
+    session: Session, temp_dir: Path
+):
+    model_dir = temp_dir / "models" / "checkpoints"
+    model_path = _create_file(model_dir, "saved.safetensors")
+    _make_asset(  # helper from this test module; its body is not part of this hunk
+        session,
+        "asset-1",
+        model_path,
+        "ref-1",
+        mtime_ns=_stat_mtime_ns(model_path),
+    )
+    session.commit()
+
+    registered = [("checkpoints", [str(model_dir)])]  # shared by both patch targets below
+    with (
+        patch("app.assets.scanner.get_comfy_models_folders", return_value=registered),
+        patch(
+            "app.assets.services.path_utils.get_comfy_models_folders",
+            return_value=registered,
+        ),
+    ):
+        survivors = sync_references_with_filesystem(
+            session, "models", collect_existing_paths=True
+        )
+        session.commit()
+
+    ref = session.get(AssetReference, "ref-1")  # reload the reference created above
+    assert survivors == {model_path}
+    assert ref is not None
+    assert ref.asset_type == "model"  # presumably unset by _make_asset - TODO confirm
+    assert ref.model_folder == "checkpoints"
+
+
 class _VerifyCase:
     def __init__(self, id, stat_unchanged, needs_verify_before, expect_needs_verify):
         self.id = id