From af6d4e8172ecb4cae1a4c83890081f56aa3ffe21 Mon Sep 17 00:00:00 2001 From: Simon Pinfold Date: Wed, 27 May 2026 09:09:30 +1200 Subject: [PATCH] spike: add typed asset classification filters Co-authored-by: Amp Amp-Thread-ID: https://ampcode.com/threads/T-019e5117-c707-729d-bf98-dce718fe64d5 --- app/assets/api/routes.py | 98 +++- app/assets/api/schemas_in.py | 37 +- app/assets/api/schemas_out.py | 15 +- .../database/queries/asset_reference.py | 7 + app/assets/database/queries/common.py | 96 +++- app/assets/scanner.py | 19 +- app/assets/services/asset_management.py | 4 + app/assets/services/bulk_ingest.py | 12 + app/assets/services/path_utils.py | 222 +++++++--- app/model_manager.py | 6 +- openapi.yaml | 30 +- .../assets_test/queries/test_asset_info.py | 417 ++++++++++++++++++ .../assets_test/queries/test_cache_state.py | 25 ++ .../assets_test/services/test_bulk_ingest.py | 38 ++ .../assets_test/services/test_path_utils.py | 231 +++++++++- .../assets_test/test_sync_references.py | 52 ++- 16 files changed, 1225 insertions(+), 84 deletions(-) diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 68126b6a5..6f684d682 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -10,7 +10,6 @@ from typing import Any from aiohttp import web from pydantic import ValidationError -import folder_paths from app import user_manager from app.assets.api import schemas_in, schemas_out from app.assets.services import schemas @@ -39,6 +38,10 @@ from app.assets.services import ( update_asset_metadata, upload_from_temp_path, ) +from app.assets.services.path_utils import ( + get_asset_response_path_info, + get_comfy_models_folders, +) from app.assets.services.tagging import list_tag_histogram ROUTES = web.RouteTableDef() @@ -124,17 +127,32 @@ def _validate_sort_field(requested: str | None) -> str: return "created_at" -def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any] | None) -> str | None: - """Build a /api/view preview URL from asset tags and user_metadata filename.""" +def _get_asset_path_info(file_path: str | None): + if not file_path: + return None + try: + return get_asset_response_path_info(file_path) + except ValueError: + return None + + +def _build_preview_url_from_view( + asset_type: str | None, + user_metadata: dict[str, Any] | None, + fallback_tags: list[str] | None = None, +) -> str | None: + """Build a /api/view preview URL from path-derived type and filename metadata.""" if not user_metadata: return None filename = user_metadata.get("filename") if not filename: return None - if "input" in tags: + if asset_type in {"input", "output"}: + view_type = asset_type + elif fallback_tags and "input" in fallback_tags: view_type = "input" - elif "output" in tags: + elif fallback_tags and "output" in fallback_tags: view_type = "output" else: return None @@ -150,23 +168,82 @@ def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any] return url +def _order_tags_for_legacy_path_compat( + tags: list[str], + asset_type: str | None, + model_folder: str | None, +) -> list[str]: + # BEGIN removable compatibility shim: path-like tag presentation + # + # Tags are stored and queried as flat unordered labels. Do not use this + # ordering as a path contract; use asset_type/model_folder/file_path instead. + # This only nudges response presentation for old callers that may have looked + # at tag[0] (and tag[1] for model folder) while the asset API migrates away + # from tag-as-path semantics. Remove this whole block once callers no longer + # depend on path-looking tag order. + if not tags or not asset_type: + return tags + + priority: list[str] = [] + root_tag = "models" if asset_type == "model" else asset_type + priority.append(root_tag) + if asset_type == "model" and model_folder: + priority.append(model_folder.lower()) + + remaining = list(tags) + ordered: list[str] = [] + for tag in priority: + if tag in remaining: + ordered.append(tag) + remaining.remove(tag) + + return ordered + remaining + # END removable compatibility shim: path-like tag presentation + + def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResult) -> schemas_out.Asset: """Build an Asset response from a service result.""" + path_info = _get_asset_path_info(result.ref.file_path) + if result.ref.preview_id: preview_detail = get_asset_detail(result.ref.preview_id) if preview_detail: - preview_url = _build_preview_url_from_view(preview_detail.tags, preview_detail.ref.user_metadata) + preview_path_info = _get_asset_path_info(preview_detail.ref.file_path) + preview_url = _build_preview_url_from_view( + preview_path_info.asset_type if preview_path_info else None, + preview_detail.ref.user_metadata, + fallback_tags=preview_detail.tags, + ) else: preview_url = None else: - preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata) + preview_url = _build_preview_url_from_view( + path_info.asset_type if path_info else None, + result.ref.user_metadata, + fallback_tags=result.tags, + ) + + asset_type = None + model_folder = None + file_path = None + display_name = None + if path_info: + asset_type = path_info.asset_type + model_folder = path_info.model_folder + file_path = path_info.file_path + display_name = path_info.display_name + return schemas_out.Asset( id=result.ref.id, name=result.ref.name, + file_path=file_path, + display_name=display_name, asset_hash=result.asset.hash if result.asset else None, size=int(result.asset.size_bytes) if result.asset else None, mime_type=result.asset.mime_type if result.asset else None, - tags=result.tags, + model_folder=model_folder, + asset_type=asset_type, + tags=_order_tags_for_legacy_path_compat(result.tags, asset_type, model_folder), preview_url=preview_url, preview_id=result.ref.preview_id, user_metadata=result.ref.user_metadata or {}, @@ -213,6 +290,8 @@ async def list_assets_route(request: web.Request) -> web.Response: owner_id=USER_MANAGER.get_request_user_id(request), include_tags=q.include_tags, exclude_tags=q.exclude_tags, + asset_type=q.asset_type, + model_folder=q.model_folder, name_contains=q.name_contains, metadata_filter=q.metadata_filter, limit=q.limit, @@ -401,9 +480,10 @@ async def upload_asset(request: web.Request) -> web.Response: ) if spec.tags and spec.tags[0] == "models": + model_folder_names = {name for name, _paths in get_comfy_models_folders()} if ( len(spec.tags) < 2 - or spec.tags[1] not in folder_paths.folder_names_and_paths + or spec.tags[1] not in model_folder_names ): delete_temp_file_if_exists(parsed.tmp_path) category = spec.tags[1] if len(spec.tags) >= 2 else "" diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index 186a6ae1e..d6c46423c 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -52,6 +52,8 @@ class ParsedUpload: class ListAssetsQuery(BaseModel): include_tags: list[str] = Field(default_factory=list) exclude_tags: list[str] = Field(default_factory=list) + asset_type: Literal["model", "input", "output", "temp"] | None = None + model_folder: str | None = None name_contains: str | None = None # Accept either a JSON string (query param) or a dict @@ -81,6 +83,20 @@ class ListAssetsQuery(BaseModel): return out return v + @field_validator("model_folder", mode="before") + @classmethod + def _normalize_model_folder(cls, v): + if v is None: + return None + s = str(v).strip() + return s or None + + @model_validator(mode="after") + def _validate_path_filters(self): + if self.model_folder and self.asset_type != "model": + raise ValueError("model_folder can only be used with asset_type=model") + return self + @field_validator("metadata_filter", mode="before") @classmethod def _parse_metadata_json(cls, v): @@ -300,14 +316,23 @@ class UploadAssetSpec(BaseModel): else: return [] - # normalize + dedupe + # Normalize the root tag, but preserve path/destination components. Tags + # are normalized again before storage; this parser also feeds upload + # destination routing where registered model folder names are exact. + normalized_items: list[str] = [] + for index, item in enumerate(items): + stripped = str(item).strip() + if not stripped: + continue + normalized_items.append(stripped.lower() if index == 0 else stripped) + + # Dedupe exact tokens while preserving order. norm = [] seen = set() - for t in items: - tnorm = str(t).strip().lower() - if tnorm and tnorm not in seen: - seen.add(tnorm) - norm.append(tnorm) + for token in normalized_items: + if token not in seen: + seen.add(token) + norm.append(token) return norm @field_validator("user_metadata", mode="before") diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py index d99b1098d..9850ee2e6 100644 --- a/app/assets/api/schemas_out.py +++ b/app/assets/api/schemas_out.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any +from typing import Any, Literal from pydantic import BaseModel, ConfigDict, Field, field_serializer @@ -10,9 +10,22 @@ class Asset(BaseModel): id: str name: str + file_path: str | None = Field( + default=None, + description="Logical asset namespace path. Model assets use `models//`; other typed assets use `/`. This is not a unique identity; use `id` for stable asset-reference operations.", + ) + display_name: str | None = Field( + default=None, + description="Human-facing path below the matched asset root or model folder.", + ) asset_hash: str | None = None size: int | None = None mime_type: str | None = None + model_folder: str | None = Field( + default=None, + description="Exact, case-sensitive registered ComfyUI model folder name. Present only when asset_type is `model`.", + ) + asset_type: Literal["model", "input", "output", "temp"] | None = None tags: list[str] = Field(default_factory=list) preview_url: str | None = None preview_id: str | None = None # references an asset_reference id, not an asset id diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py index 8b90ae511..e7d7e4ced 100644 --- a/app/assets/database/queries/asset_reference.py +++ b/app/assets/database/queries/asset_reference.py @@ -24,6 +24,7 @@ from app.assets.database.models import ( ) from app.assets.database.queries.common import ( MAX_BIND_PARAMS, + apply_asset_path_filters, apply_metadata_filter, apply_tag_filters, build_prefix_like_conditions, @@ -263,6 +264,8 @@ def list_references_page( name_contains: str | None = None, include_tags: Sequence[str] | None = None, exclude_tags: Sequence[str] | None = None, + asset_type: str | None = None, + model_folder: str | None = None, metadata_filter: dict | None = None, sort: str | None = None, order: str | None = None, @@ -285,6 +288,7 @@ def list_references_page( base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) base = apply_tag_filters(base, include_tags, exclude_tags) + base = apply_asset_path_filters(base, asset_type=asset_type, model_folder=model_folder) base = apply_metadata_filter(base, metadata_filter) sort = (sort or "created_at").lower() @@ -315,6 +319,9 @@ def list_references_page( AssetReference.name.ilike(f"%{escaped}%", escape=esc) ) count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = apply_asset_path_filters( + count_stmt, asset_type=asset_type, model_folder=model_folder + ) count_stmt = apply_metadata_filter(count_stmt, metadata_filter) total = int(session.execute(count_stmt).scalar_one() or 0) diff --git a/app/assets/database/queries/common.py b/app/assets/database/queries/common.py index 89bb49327..adc4651ee 100644 --- a/app/assets/database/queries/common.py +++ b/app/assets/database/queries/common.py @@ -2,16 +2,24 @@ import os from decimal import Decimal +from pathlib import Path from typing import Iterable, Sequence import sqlalchemy as sa from sqlalchemy import exists +import folder_paths from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag -from app.assets.helpers import escape_sql_like_string, normalize_tags +from app.assets.helpers import normalize_tags MAX_BIND_PARAMS = 800 +# Mirrors app.model_manager.MODEL_FOLDER_BLACKLIST: these names are bootstrapped +# into folder_names_and_paths by core, but /api/experiment/models does not expose +# them as model folders. Keep this local to avoid reintroducing the asset query +# package initialization cycle from importing service/model-manager code here. +_NON_MODEL_FOLDER_NAMES = frozenset({"configs", "custom_nodes"}) + def calculate_rows_per_statement(cols: int) -> int: """Calculate how many rows can fit in one statement given column count.""" @@ -45,17 +53,97 @@ def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: def build_prefix_like_conditions( prefixes: list[str], ) -> list[sa.sql.ColumnElement]: - """Build LIKE conditions for matching file paths under directory prefixes.""" + """Build case-exact conditions for matching file paths under directory prefixes.""" conds = [] for p in prefixes: base = os.path.abspath(p) if not base.endswith(os.sep): base += os.sep - escaped, esc = escape_sql_like_string(base) - conds.append(AssetReference.file_path.like(escaped + "%", escape=esc)) + conds.append(sa.func.substr(AssetReference.file_path, 1, len(base)) == base) return conds +def _get_comfy_model_folders() -> list[tuple[str, list[str]]]: + """Return registered model folder names and roots without importing services. + + This intentionally stays local to the query layer to avoid importing + ``app.assets.services`` from ``app.assets.database.queries.common``, which + creates a package initialization cycle. + """ + targets: list[tuple[str, list[str]]] = [] + for name, values in folder_paths.folder_names_and_paths.items(): + if name in _NON_MODEL_FOLDER_NAMES: + continue + paths, _exts = values[0], values[1] + if paths: + targets.append((name, paths)) + return targets + + +def apply_asset_path_filters( + stmt: sa.sql.Select, + asset_type: str | None = None, + model_folder: str | None = None, +) -> sa.sql.Select: + """Filter references using their real filesystem/root registration context.""" + if asset_type is None and model_folder is None: + return stmt + if model_folder and asset_type != "model": + raise ValueError("model_folder can only be used with asset_type=model") + + registered_model_folders = _get_comfy_model_folders() + prefixes: list[str] = [] + exclude_prefixes: list[str] = [] + if model_folder: + for folder_name, paths in registered_model_folders: + if folder_name == model_folder: + prefixes.extend(paths) + break + if not prefixes: + return stmt.where(sa.false()) + + target_bases = [os.path.abspath(path) for path in prefixes] + for folder_name, paths in registered_model_folders: + if folder_name == model_folder: + continue + for path in paths: + path_abs = os.path.abspath(path) + if any( + Path(path_abs).is_relative_to(target_base) + and path_abs != target_base + for target_base in target_bases + ): + exclude_prefixes.append(path) + elif asset_type == "model": + for _folder_name, paths in registered_model_folders: + prefixes.extend(paths) + elif asset_type == "input": + prefixes = [folder_paths.get_input_directory()] + elif asset_type == "output": + prefixes = [folder_paths.get_output_directory()] + elif asset_type == "temp": + prefixes = [folder_paths.get_temp_directory()] + + conditions = build_prefix_like_conditions(prefixes) + if not conditions: + return stmt.where(sa.false()) + + clause = sa.or_(*conditions) + if asset_type in {"input", "output", "temp"}: + model_prefixes = [ + path for _folder_name, paths in registered_model_folders for path in paths + ] + model_conditions = build_prefix_like_conditions(model_prefixes) + if model_conditions: + clause = sa.and_(clause, sa.not_(sa.or_(*model_conditions))) + elif exclude_prefixes: + exclude_conditions = build_prefix_like_conditions(exclude_prefixes) + if exclude_conditions: + clause = sa.and_(clause, sa.not_(sa.or_(*exclude_conditions))) + + return stmt.where(clause) + + def apply_tag_filters( stmt: sa.sql.Select, include_tags: Sequence[str] | None = None, diff --git a/app/assets/scanner.py b/app/assets/scanner.py index ebb6869af..68b9b6895 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -56,7 +56,7 @@ class _AssetAccumulator(TypedDict): refs: list[_RefInfo] -RootType = Literal["models", "input", "output"] +RootType = Literal["models", "input", "output", "temp"] def get_prefixes_for_root(root: RootType) -> list[str]: @@ -69,12 +69,14 @@ def get_prefixes_for_root(root: RootType) -> list[str]: return [os.path.abspath(folder_paths.get_input_directory())] if root == "output": return [os.path.abspath(folder_paths.get_output_directory())] + if root == "temp": + return [os.path.abspath(folder_paths.get_temp_directory())] return [] def get_all_known_prefixes() -> list[str]: """Get all known asset prefixes across all root types.""" - all_roots: tuple[RootType, ...] = ("models", "input", "output") + all_roots: tuple[RootType, ...] = ("models", "input", "output", "temp") return [p for root in all_roots for p in get_prefixes_for_root(root)] @@ -274,7 +276,18 @@ def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]: paths.extend(list_files_recursively(folder_paths.get_input_directory())) if "output" in roots: paths.extend(list_files_recursively(folder_paths.get_output_directory())) - return paths + if "temp" in roots: + paths.extend(list_files_recursively(folder_paths.get_temp_directory())) + + deduped: list[str] = [] + seen: set[str] = set() + for path in paths: + abs_path = os.path.abspath(path) + if abs_path in seen: + continue + seen.add(abs_path) + deduped.append(abs_path) + return deduped def build_asset_specs( diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py index 5aefd9956..886715f8c 100644 --- a/app/assets/services/asset_management.py +++ b/app/assets/services/asset_management.py @@ -246,6 +246,8 @@ def list_assets_page( owner_id: str = "", include_tags: Sequence[str] | None = None, exclude_tags: Sequence[str] | None = None, + asset_type: str | None = None, + model_folder: str | None = None, name_contains: str | None = None, metadata_filter: dict | None = None, limit: int = 20, @@ -259,6 +261,8 @@ def list_assets_page( owner_id=owner_id, include_tags=include_tags, exclude_tags=exclude_tags, + asset_type=asset_type, + model_folder=model_folder, name_contains=name_contains, metadata_filter=metadata_filter, limit=limit, diff --git a/app/assets/services/bulk_ingest.py b/app/assets/services/bulk_ingest.py index 67aad838f..56c348a20 100644 --- a/app/assets/services/bulk_ingest.py +++ b/app/assets/services/bulk_ingest.py @@ -125,6 +125,18 @@ def batch_insert_seed_assets( if not specs: return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0) + deduped_specs: list[SeedAssetSpec] = [] + seen_paths: set[str] = set() + for spec in specs: + absolute_path = os.path.abspath(spec["abs_path"]) + if absolute_path in seen_paths: + continue + seen_paths.add(absolute_path) + deduped_specs.append(spec) + specs = deduped_specs + if not specs: + return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0) + current_time = get_utc_now() asset_rows: list[AssetRow] = [] reference_rows: list[ReferenceRow] = [] diff --git a/app/assets/services/path_utils.py b/app/assets/services/path_utils.py index 892140ffb..3d3f75e6c 100644 --- a/app/assets/services/path_utils.py +++ b/app/assets/services/path_utils.py @@ -1,4 +1,5 @@ import os +from dataclasses import dataclass from pathlib import Path from typing import Literal @@ -6,7 +7,28 @@ import folder_paths from app.assets.helpers import normalize_tags -_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"}) +# Mirrors app.model_manager.MODEL_FOLDER_BLACKLIST: these names are bootstrapped +# into folder_names_and_paths by core, but /api/experiment/models does not expose +# them as model folders, so typed asset classification should not either. +_NON_MODEL_FOLDER_NAMES = frozenset({"configs", "custom_nodes"}) + + +@dataclass(frozen=True) +class AssetPathInfo: + asset_type: Literal["input", "output", "temp", "model"] + model_folder: str | None + + +@dataclass(frozen=True) +class AssetResponsePathInfo(AssetPathInfo): + file_path: str + display_name: str | None + + +@dataclass(frozen=True) +class AssetPathContext(AssetPathInfo): + base_path: str + relative_path: str def get_comfy_models_folders() -> list[tuple[str, list[str]]]: @@ -14,7 +36,7 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]: Includes every category registered in folder_names_and_paths, regardless of whether its paths are under the main models_dir, - but excludes non-model entries like custom_nodes. + but excludes non-model entries like configs and custom_nodes. """ targets: list[tuple[str, list[str]]] = [] for name, values in folder_paths.folder_names_and_paths.items(): @@ -67,28 +89,109 @@ def validate_path_within_base(candidate: str, base: str) -> None: def compute_relative_filename(file_path: str) -> str | None: """ - Return the model's path relative to the last well-known folder (the model category), - using forward slashes, eg: + Return the path relative to the matched asset root or model folder, using + forward slashes, eg: /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors" /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors" + /.../input/sub/image.png -> "sub/image.png" - For non-model paths, returns None. + For unknown paths, returns None. """ try: - root_category, rel_path = get_asset_category_and_relative_path(file_path) + context = resolve_asset_path_context(file_path) except ValueError: return None - p = Path(rel_path) - parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)] + return _normalize_relative_path(context.relative_path) + + +def _normalize_relative_path(relative_path: str) -> str | None: + parts = [ + seg + for seg in Path(relative_path).parts + if seg not in (".", "..", Path(relative_path).anchor) + ] if not parts: return None - if root_category == "models": - # parts[0] is the category ("checkpoints", "vae", etc) – drop it - inside = parts[1:] if len(parts) > 1 else [parts[0]] - return "/".join(inside) - return "/".join(parts) # input/output: keep all parts + return "/".join(parts) + + +def resolve_asset_path_context(file_path: str) -> AssetPathContext: + """Resolve a path against Core's asset roots and model-folder registration. + + This is the source of truth for path-derived asset classification. For + model assets, ``model_folder`` is the exact registered folder name whose + base path contains the file, and ``relative_path`` is relative to that + matched base path. When multiple registered bases contain the file, the + deepest base wins. + """ + fp_abs = os.path.abspath(file_path) + + def _check_is_within(child: str, parent: str) -> bool: + return Path(child).is_relative_to(parent) + + def _compute_relative(child: str, parent: str) -> str: + # Normalize relative path, stripping any leading ".." components + # by anchoring to root (os.sep) then computing relpath back from it. + return os.path.relpath( + os.path.join(os.sep, os.path.relpath(child, parent)), os.sep + ) + + best: tuple[int, str, str, str] | None = None + for model_folder, bases in get_comfy_models_folders(): + for base in bases: + base_abs = os.path.abspath(base) + if not _check_is_within(fp_abs, base_abs): + continue + cand = ( + len(base_abs), + model_folder, + base_abs, + _compute_relative(fp_abs, base_abs), + ) + if best is None or cand[0] > best[0]: + best = cand + + if best is not None: + _, model_folder, base_path, relative_path = best + return AssetPathContext( + asset_type="model", + model_folder=model_folder, + base_path=base_path, + relative_path=relative_path, + ) + + input_base = os.path.abspath(folder_paths.get_input_directory()) + if _check_is_within(fp_abs, input_base): + return AssetPathContext( + asset_type="input", + model_folder=None, + base_path=input_base, + relative_path=_compute_relative(fp_abs, input_base), + ) + + output_base = os.path.abspath(folder_paths.get_output_directory()) + if _check_is_within(fp_abs, output_base): + return AssetPathContext( + asset_type="output", + model_folder=None, + base_path=output_base, + relative_path=_compute_relative(fp_abs, output_base), + ) + + temp_base = os.path.abspath(folder_paths.get_temp_directory()) + if _check_is_within(fp_abs, temp_base): + return AssetPathContext( + asset_type="temp", + model_folder=None, + base_path=temp_base, + relative_path=_compute_relative(fp_abs, temp_base), + ) + + raise ValueError( + f"Path is not within input, output, temp, or configured model bases: {file_path}" + ) def get_asset_category_and_relative_path( @@ -108,51 +211,62 @@ def get_asset_category_and_relative_path( Raises: ValueError: path does not belong to any known root. """ - fp_abs = os.path.abspath(file_path) + context = resolve_asset_path_context(file_path) + if context.asset_type == "model": + combined = os.path.join(context.model_folder or "", context.relative_path) + return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep) + return context.asset_type, context.relative_path - def _check_is_within(child: str, parent: str) -> bool: - return Path(child).is_relative_to(parent) - def _compute_relative(child: str, parent: str) -> str: - # Normalize relative path, stripping any leading ".." components - # by anchoring to root (os.sep) then computing relpath back from it. - return os.path.relpath( - os.path.join(os.sep, os.path.relpath(child, parent)), os.sep +def get_asset_path_info(file_path: str) -> AssetPathInfo: + """Return typed asset classification derived from the actual filesystem path. + + This intentionally reads the ComfyUI model folder registration from + ``folder_paths.folder_names_and_paths`` instead of inferring it from tags. + For model files, ``model_folder`` is the registered folder name whose base + path contains ``file_path``. + + Raises: + ValueError: path does not belong to any known root. + """ + context = resolve_asset_path_context(file_path) + return AssetPathInfo( + asset_type=context.asset_type, + model_folder=context.model_folder, + ) + + +def get_asset_response_path_info(file_path: str) -> AssetResponsePathInfo: + """Return API-facing path fields derived from the actual filesystem path. + + ``file_path`` is a logical namespace key: ``models//`` + for model assets and ``/`` for input/output/temp assets. + ``display_name`` is the path below the matched root or model folder. + + Raises: + ValueError: path does not belong to any known root. + """ + context = resolve_asset_path_context(file_path) + display_name = _normalize_relative_path(context.relative_path) + + if context.asset_type == "model": + logical_file_path = ( + f"models/{context.model_folder}/{display_name}" + if display_name + else f"models/{context.model_folder}" + ) + else: + logical_file_path = ( + f"{context.asset_type}/{display_name}" + if display_name + else context.asset_type ) - # 1) input - input_base = os.path.abspath(folder_paths.get_input_directory()) - if _check_is_within(fp_abs, input_base): - return "input", _compute_relative(fp_abs, input_base) - - # 2) output - output_base = os.path.abspath(folder_paths.get_output_directory()) - if _check_is_within(fp_abs, output_base): - return "output", _compute_relative(fp_abs, output_base) - - # 3) temp - temp_base = os.path.abspath(folder_paths.get_temp_directory()) - if _check_is_within(fp_abs, temp_base): - return "temp", _compute_relative(fp_abs, temp_base) - - # 4) models (check deepest matching base to avoid ambiguity) - best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket) - for bucket, bases in get_comfy_models_folders(): - for b in bases: - base_abs = os.path.abspath(b) - if not _check_is_within(fp_abs, base_abs): - continue - cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs)) - if best is None or cand[0] > best[0]: - best = cand - - if best is not None: - _, bucket, rel_inside = best - combined = os.path.join(bucket, rel_inside) - return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep) - - raise ValueError( - f"Path is not within input, output, temp, or configured model bases: {file_path}" + return AssetResponsePathInfo( + asset_type=context.asset_type, + model_folder=context.model_folder, + file_path=logical_file_path, + display_name=display_name, ) diff --git a/app/model_manager.py b/app/model_manager.py index f124d1117..14f332de0 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -14,6 +14,9 @@ from io import BytesIO from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types +MODEL_FOLDER_BLACKLIST = frozenset({"configs", "custom_nodes"}) + + class ModelFileManager: def __init__(self) -> None: self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {} @@ -32,10 +35,9 @@ class ModelFileManager: @routes.get("/experiment/models") async def get_model_folders(request): model_types = list(folder_paths.folder_names_and_paths.keys()) - folder_black_list = ["configs", "custom_nodes"] output_folders: list[dict] = [] for folder in model_types: - if folder in folder_black_list: + if folder in MODEL_FOLDER_BLACKLIST: continue output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)}) return web.json_response(output_folders) diff --git a/openapi.yaml b/openapi.yaml index 885231acc..d8e7ed88c 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -1511,7 +1511,7 @@ paths: in: query schema: type: integer - default: 50 + default: 20 - name: offset in: query schema: @@ -1535,6 +1535,17 @@ paths: style: form explode: true description: Tags that assets must not have + - name: asset_type + in: query + schema: + type: string + enum: [model, input, output, temp] + description: Filter by path-derived asset type. Model classification is based on registered ComfyUI model folder roots, not tags. + - name: model_folder + in: query + schema: + type: string + description: Filter model assets by exact registered ComfyUI model folder name. Requires asset_type=model. - name: name_contains in: query schema: @@ -6607,10 +6618,23 @@ components: id: type: string format: uuid - description: Unique identifier for the asset + description: AssetReference ID. Use this as the stable identity; logical file_path is not unique. name: type: string description: Name of the asset file + file_path: + type: string + description: Logical asset namespace path. Model assets use `models//`; input/output/temp assets use `/`. Omitted when the asset has no resolvable filesystem path. Not unique; use `id` for stable asset-reference operations. + display_name: + type: string + description: Human-facing path below the matched asset root or model folder. Omitted when the asset has no resolvable filesystem path. + asset_type: + type: string + enum: [model, input, output, temp] + description: Path-derived asset type. Model classification is based on registered ComfyUI model folder roots, not tags. Omitted when the asset has no resolvable filesystem path. + model_folder: + type: string + description: Exact, case-sensitive registered ComfyUI model folder name for model assets. Present only when asset_type is `model`. hash: type: string nullable: true @@ -8723,4 +8747,4 @@ components: items: $ref: "#/components/schemas/TaskEntry" pagination: - $ref: "#/components/schemas/PaginationInfo" \ No newline at end of file + $ref: "#/components/schemas/PaginationInfo" diff --git a/tests-unit/assets_test/queries/test_asset_info.py b/tests-unit/assets_test/queries/test_asset_info.py index fe510e342..87a8071f6 100644 --- a/tests-unit/assets_test/queries/test_asset_info.py +++ b/tests-unit/assets_test/queries/test_asset_info.py @@ -1,8 +1,13 @@ import time import uuid +from pathlib import Path +from unittest.mock import patch + import pytest from sqlalchemy.orm import Session +from app.assets.api import schemas_in +from app.assets.api.routes import _order_tags_for_legacy_path_compat from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta from app.assets.database.queries import ( reference_exists_for_asset_id, @@ -25,6 +30,55 @@ from app.assets.database.queries import ( from app.assets.helpers import get_utc_now +class TestListAssetsQueryPathFilters: + def test_model_folder_requires_model_asset_type(self): + with pytest.raises(ValueError, match="model_folder can only be used"): + schemas_in.ListAssetsQuery.model_validate({"model_folder": "checkpoints"}) + + def test_model_folder_accepts_explicit_model_asset_type(self): + query = schemas_in.ListAssetsQuery.model_validate( + {"asset_type": "model", "model_folder": "checkpoints"} + ) + + assert query.asset_type == "model" + assert query.model_folder == "checkpoints" + + def test_model_folder_rejects_non_model_asset_type(self): + with pytest.raises(ValueError, match="model_folder can only be used"): + schemas_in.ListAssetsQuery.model_validate( + {"asset_type": "input", "model_folder": "checkpoints"} + ) + + def test_query_layer_rejects_model_folder_without_model_asset_type( + self, session: Session + ): + with pytest.raises(ValueError, match="model_folder can only be used"): + list_references_page(session, model_folder="checkpoints") + + def test_upload_tags_preserve_model_folder_case_for_destination(self): + spec = schemas_in.UploadAssetSpec.model_validate( + {"tags": ['["models", "LLM", "SubDir"]']} + ) + + assert spec.tags == ["models", "LLM", "SubDir"] + + +class TestLegacyPathTagOrdering: + def test_model_response_moves_models_and_model_folder_first(self): + assert _order_tags_for_legacy_path_compat( + ["blah", "checkpoints", "deep", "foo", "models"], + asset_type="model", + model_folder="checkpoints", + ) == ["models", "checkpoints", "blah", "deep", "foo"] + + def test_non_model_response_moves_root_tag_first(self): + assert _order_tags_for_legacy_path_compat( + ["subdir", "output"], + asset_type="output", + model_folder=None, + ) == ["output", "subdir"] + + def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset: asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream") session.add(asset) @@ -145,6 +199,369 @@ class TestListReferencesPage: assert total == 1 assert refs[0].name == "keep" + def test_model_folder_filter_uses_registered_paths(self, session: Session, tmp_path: Path): + checkpoints_dir = tmp_path / "models" / "checkpoints" + loras_dir = tmp_path / "models" / "loras" + checkpoints_dir.mkdir(parents=True) + loras_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + checkpoint = _make_reference(session, asset, name="checkpoint") + checkpoint.file_path = str(checkpoints_dir / "model.safetensors") + lora = _make_reference(session, asset, name="lora") + lora.file_path = str(loras_dir / "model.safetensors") + session.commit() + + with patch( + "app.assets.database.queries.common._get_comfy_model_folders", + return_value=[ + ("checkpoints", [str(checkpoints_dir)]), + ("loras", [str(loras_dir)]), + ], + ): + refs, _, total = list_references_page( + session, asset_type="model", model_folder="checkpoints" + ) + + assert total == 1 + assert refs[0].id == checkpoint.id + + def test_model_folder_filter_includes_all_registered_roots_for_folder( + self, session: Session, tmp_path: Path + ): + checkpoints_a = tmp_path / "root_a" / "checkpoints" + checkpoints_b = tmp_path / "root_b" / "checkpoints" + checkpoints_a.mkdir(parents=True) + checkpoints_b.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + ref_a = _make_reference(session, asset, name="checkpoint-a") + ref_a.file_path = str(checkpoints_a / "a.safetensors") + ref_b = _make_reference(session, asset, name="checkpoint-b") + ref_b.file_path = str(checkpoints_b / "b.safetensors") + session.commit() + + with patch( + "app.assets.database.queries.common._get_comfy_model_folders", + return_value=[("checkpoints", [str(checkpoints_a), str(checkpoints_b)])], + ): + refs, _, total = list_references_page( + session, asset_type="model", model_folder="checkpoints" + ) + + assert total == 2 + assert {ref.id for ref in refs} == {ref_a.id, ref_b.id} + + def test_same_named_files_under_multiple_roots_both_return_in_model_folder_filter( + self, session: Session, tmp_path: Path + ): + checkpoints_a = tmp_path / "root_a" / "checkpoints" + checkpoints_b = tmp_path / "root_b" / "checkpoints" + checkpoints_a.mkdir(parents=True) + checkpoints_b.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + ref_a = _make_reference(session, asset, name="checkpoint-a") + ref_a.file_path = str(checkpoints_a / "duplicate.safetensors") + ref_b = _make_reference(session, asset, name="checkpoint-b") + ref_b.file_path = str(checkpoints_b / "duplicate.safetensors") + session.commit() + + with patch( + "app.assets.database.queries.common._get_comfy_model_folders", + return_value=[("checkpoints", [str(checkpoints_a), str(checkpoints_b)])], + ): + refs, _, total = list_references_page( + session, asset_type="model", model_folder="checkpoints" + ) + + assert total == 2 + assert {ref.id for ref in refs} == {ref_a.id, ref_b.id} + + def test_arbitrary_registered_folder_filter_works( + self, session: Session, tmp_path: Path + ): + controlnet_dir = tmp_path / "models" / "controlnet" + controlnet_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, name="controlnet") + ref.file_path = str(controlnet_dir / "pose.safetensors") + session.commit() + + with patch( + "app.assets.database.queries.common._get_comfy_model_folders", + return_value=[("controlnet", [str(controlnet_dir)])], + ): + refs, _, total = list_references_page( + session, asset_type="model", model_folder="controlnet" + ) + + assert total == 1 + assert refs[0].id == ref.id + + def test_unknown_model_folder_filter_returns_none_when_other_models_exist( + self, session: Session, tmp_path: Path + ): + checkpoints_dir = tmp_path / "models" / "checkpoints" + checkpoints_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, name="checkpoint") + ref.file_path = str(checkpoints_dir / "model.safetensors") + session.commit() + + with patch( + "app.assets.database.queries.common._get_comfy_model_folders", + return_value=[("checkpoints", [str(checkpoints_dir)])], + ): + refs, _, total = list_references_page( + session, asset_type="model", model_folder="controlnet" + ) + + assert total == 0 + assert refs == [] + + def test_model_folder_filter_excludes_deeper_registered_model_folder( + self, session: Session, tmp_path: Path + ): + text_encoders_dir = tmp_path / "models" / "text_encoders" + clip_dir = text_encoders_dir / "clip" + clip_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + text_encoder = _make_reference(session, asset, name="text_encoder") + text_encoder.file_path = str(text_encoders_dir / "t5xxl.safetensors") + clip = _make_reference(session, asset, name="clip") + clip.file_path = str(clip_dir / "clip_l.safetensors") + session.commit() + + with patch( + "app.assets.database.queries.common._get_comfy_model_folders", + return_value=[ + ("text_encoders", [str(text_encoders_dir)]), + ("text_encoders/clip", [str(clip_dir)]), + ], + ): + refs, _, total = list_references_page( + session, asset_type="model", model_folder="text_encoders" + ) + + assert total == 1 + assert refs[0].id == text_encoder.id + + def test_child_model_folder_filter_returns_only_child( + self, session: Session, tmp_path: Path + ): + text_encoders_dir = tmp_path / "models" / "text_encoders" + clip_dir = text_encoders_dir / "clip" + clip_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + text_encoder = _make_reference(session, asset, name="text_encoder") + text_encoder.file_path = str(text_encoders_dir / "t5xxl.safetensors") + clip = _make_reference(session, asset, name="clip") + clip.file_path = str(clip_dir / "clip_l.safetensors") + session.commit() + + with patch( + "app.assets.database.queries.common._get_comfy_model_folders", + return_value=[ + ("text_encoders", [str(text_encoders_dir)]), + ("text_encoders/clip", [str(clip_dir)]), + ], + ): + refs, _, total = list_references_page( + session, asset_type="model", model_folder="text_encoders/clip" + ) + + assert total == 1 + assert refs[0].id == clip.id + + def test_model_asset_type_filter_includes_parent_and_child_registered_roots( + self, session: Session, tmp_path: Path + ): + text_encoders_dir = tmp_path / "models" / "text_encoders" + clip_dir = text_encoders_dir / "clip" + clip_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + text_encoder = _make_reference(session, asset, name="text_encoder") + text_encoder.file_path = str(text_encoders_dir / "t5xxl.safetensors") + clip = _make_reference(session, asset, name="clip") + clip.file_path = str(clip_dir / "clip_l.safetensors") + session.commit() + + with patch( + "app.assets.database.queries.common._get_comfy_model_folders", + return_value=[ + ("text_encoders", [str(text_encoders_dir)]), + ("text_encoders/clip", [str(clip_dir)]), + ], + ): + refs, _, total = list_references_page(session, asset_type="model") + + assert total == 2 + assert {ref.id for ref in refs} == {text_encoder.id, clip.id} + + def test_model_asset_type_filter_with_no_registered_paths_returns_none( + self, session: Session, tmp_path: Path + ): + asset = _make_asset(session, "hash1") + ref = _make_reference(session, asset, name="orphan") + ref.file_path = str(tmp_path / "models" / "checkpoints" / "model.safetensors") + session.commit() + + with patch( + "app.assets.database.queries.common._get_comfy_model_folders", + return_value=[], + ): + refs, _, total = list_references_page(session, asset_type="model") + + assert total == 0 + assert refs == [] + + def test_model_asset_type_filter_excludes_unregistered_models_folder( + self, session: Session, tmp_path: Path + ): + checkpoints_dir = tmp_path / "models" / "checkpoints" + unregistered_dir = tmp_path / "models" / "unregistered" + checkpoints_dir.mkdir(parents=True) + unregistered_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + checkpoint = _make_reference(session, asset, name="checkpoint") + checkpoint.file_path = str(checkpoints_dir / "model.safetensors") + unregistered = _make_reference(session, asset, name="unregistered") + unregistered.file_path = str(unregistered_dir / "model.safetensors") + session.commit() + + with patch( + "app.assets.database.queries.common._get_comfy_model_folders", + return_value=[("checkpoints", [str(checkpoints_dir)])], + ): + refs, _, total = list_references_page(session, asset_type="model") + + assert total == 1 + assert refs[0].id == checkpoint.id + + def test_model_asset_type_filter_respects_prefix_boundaries( + self, session: Session, tmp_path: Path + ): + checkpoints_dir = tmp_path / "models" / "checkpoints" + checkpoints_extra_dir = tmp_path / "models" / "checkpoints_extra" + checkpoints_dir.mkdir(parents=True) + checkpoints_extra_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + checkpoint = _make_reference(session, asset, name="checkpoint") + checkpoint.file_path = str(checkpoints_dir / "model.safetensors") + checkpoints_extra = _make_reference(session, asset, name="checkpoints_extra") + checkpoints_extra.file_path = str(checkpoints_extra_dir / "model.safetensors") + session.commit() + + with patch( + "app.assets.database.queries.common._get_comfy_model_folders", + return_value=[("checkpoints", [str(checkpoints_dir)])], + ): + refs, _, total = list_references_page(session, asset_type="model") + + assert total == 1 + assert refs[0].id == checkpoint.id + + def test_model_asset_type_filter_is_case_exact( + self, session: Session, tmp_path: Path + ): + registered_dir = tmp_path / "models" / "checkpoints" + case_sibling_dir = tmp_path / "MODELS" / "checkpoints" + registered_dir.mkdir(parents=True) + case_sibling_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + checkpoint = _make_reference(session, asset, name="checkpoint") + checkpoint.file_path = str(registered_dir / "model.safetensors") + case_sibling = _make_reference(session, asset, name="case_sibling") + case_sibling.file_path = str(case_sibling_dir / "model.safetensors") + session.commit() + + with patch( + "app.assets.database.queries.common._get_comfy_model_folders", + return_value=[("checkpoints", [str(registered_dir)])], + ): + refs, _, total = list_references_page(session, asset_type="model") + + assert total == 1 + assert refs[0].id == checkpoint.id + + def test_model_asset_type_filter_includes_output_backed_model_folder( + self, session: Session, tmp_path: Path + ): + output_checkpoints_dir = tmp_path / "output" / "checkpoints" + output_checkpoints_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + checkpoint_ref = _make_reference(session, asset, name="checkpoint") + checkpoint_ref.file_path = str(output_checkpoints_dir / "saved.safetensors") + session.commit() + + with patch( + "app.assets.database.queries.common._get_comfy_model_folders", + return_value=[("checkpoints", [str(output_checkpoints_dir)])], + ): + refs, _, total = list_references_page(session, asset_type="model") + + assert total == 1 + assert refs[0].id == checkpoint_ref.id + + def test_asset_type_filter_uses_root_paths(self, session: Session, tmp_path: Path): + input_dir = tmp_path / "input" + output_dir = tmp_path / "output" + temp_dir = tmp_path / "temp" + for directory in (input_dir, output_dir, temp_dir): + directory.mkdir() + + asset = _make_asset(session, "hash1") + input_ref = _make_reference(session, asset, name="input") + input_ref.file_path = str(input_dir / "image.png") + output_ref = _make_reference(session, asset, name="output") + output_ref.file_path = str(output_dir / "image.png") + session.commit() + + with patch("app.assets.database.queries.common.folder_paths") as mock_fp: + mock_fp.get_input_directory.return_value = str(input_dir) + mock_fp.get_output_directory.return_value = str(output_dir) + mock_fp.get_temp_directory.return_value = str(temp_dir) + + refs, _, total = list_references_page(session, asset_type="input") + + assert total == 1 + assert refs[0].id == input_ref.id + + def test_output_asset_type_filter_excludes_output_backed_model_folders( + self, session: Session, tmp_path: Path + ): + output_dir = tmp_path / "output" + output_checkpoints_dir = output_dir / "checkpoints" + output_checkpoints_dir.mkdir(parents=True) + + asset = _make_asset(session, "hash1") + output_ref = _make_reference(session, asset, name="output") + output_ref.file_path = str(output_dir / "image.png") + checkpoint_ref = _make_reference(session, asset, name="checkpoint") + checkpoint_ref.file_path = str(output_checkpoints_dir / "saved.safetensors") + session.commit() + + with patch("app.assets.database.queries.common.folder_paths") as mock_fp: + mock_fp.get_output_directory.return_value = str(output_dir) + with patch( + "app.assets.database.queries.common._get_comfy_model_folders", + return_value=[("checkpoints", [str(output_checkpoints_dir)])], + ): + refs, _, total = list_references_page(session, asset_type="output") + + assert total == 1 + assert refs[0].id == output_ref.id + def test_sorting(self, session: Session): asset = _make_asset(session, "hash1", size=100) asset2 = _make_asset(session, "hash2", size=500) diff --git a/tests-unit/assets_test/queries/test_cache_state.py b/tests-unit/assets_test/queries/test_cache_state.py index ead60e570..fd6217700 100644 --- a/tests-unit/assets_test/queries/test_cache_state.py +++ b/tests-unit/assets_test/queries/test_cache_state.py @@ -255,6 +255,31 @@ class TestMarkReferencesMissingOutsidePrefixes: assert marked == 0 + def test_prefix_matching_is_case_exact(self, session: Session, tmp_path): + asset = _make_asset(session, "hash1") + valid_dir = tmp_path / "models" / "checkpoints" + case_sibling_dir = tmp_path / "MODELS" / "checkpoints" + valid_dir.mkdir(parents=True) + case_sibling_dir.mkdir(parents=True) + + valid_path = str(valid_dir / "file.bin") + case_sibling_path = str(case_sibling_dir / "file.bin") + + _make_reference(session, asset, valid_path, name="valid") + _make_reference(session, asset, case_sibling_path, name="case_sibling") + session.commit() + + marked = mark_references_missing_outside_prefixes(session, [str(valid_dir)]) + session.commit() + + assert marked == 1 + valid_ref = session.query(AssetReference).filter_by(file_path=valid_path).one() + case_sibling_ref = ( + session.query(AssetReference).filter_by(file_path=case_sibling_path).one() + ) + assert valid_ref.is_missing is False + assert case_sibling_ref.is_missing is True + class TestGetUnreferencedUnhashedAssetIds: def test_returns_unreferenced_unhashed_assets(self, session: Session): diff --git a/tests-unit/assets_test/services/test_bulk_ingest.py b/tests-unit/assets_test/services/test_bulk_ingest.py index 26e22a01d..5a3f0e1e8 100644 --- a/tests-unit/assets_test/services/test_bulk_ingest.py +++ b/tests-unit/assets_test/services/test_bulk_ingest.py @@ -64,6 +64,44 @@ class TestBatchInsertSeedAssets: assert len(assets) == 1 assert assets[0].mime_type is None + def test_duplicate_paths_in_same_batch_preserve_first_spec( + self, session: Session, temp_dir: Path + ): + file_path = temp_dir / "duplicate.safetensors" + file_path.write_bytes(b"fake safetensors content") + + specs: list[SeedAssetSpec] = [ + { + "abs_path": str(file_path), + "size_bytes": 24, + "mtime_ns": 1234567890000000000, + "info_name": "first", + "tags": ["models", "checkpoints"], + "fname": "duplicate.safetensors", + "metadata": None, + "hash": None, + "mime_type": "application/safetensors", + }, + { + "abs_path": str(file_path), + "size_bytes": 24, + "mtime_ns": 1234567890000000000, + "info_name": "second", + "tags": ["output"], + "fname": "duplicate.safetensors", + "metadata": None, + "hash": None, + "mime_type": "application/safetensors", + }, + ] + + result = batch_insert_seed_assets(session, specs=specs, owner_id="") + + assert result.inserted_refs == 1 + refs = session.query(AssetReference).all() + assert len(refs) == 1 + assert refs[0].name == "first" + def test_various_model_mime_types(self, session: Session, temp_dir: Path): """Verify various model file types get correct mime_type.""" test_cases = [ diff --git a/tests-unit/assets_test/services/test_path_utils.py b/tests-unit/assets_test/services/test_path_utils.py index 3fa905f9a..8735718b1 100644 --- a/tests-unit/assets_test/services/test_path_utils.py +++ b/tests-unit/assets_test/services/test_path_utils.py @@ -6,7 +6,14 @@ from unittest.mock import patch import pytest -from app.assets.services.path_utils import get_asset_category_and_relative_path +from app.assets.services.path_utils import ( + compute_relative_filename, + get_comfy_models_folders, + get_asset_category_and_relative_path, + get_asset_path_info, + get_asset_response_path_info, + resolve_asset_path_context, +) @pytest.fixture @@ -79,3 +86,225 @@ class TestGetAssetCategoryAndRelativePath: def test_unknown_path_raises(self, fake_dirs): with pytest.raises(ValueError, match="not within"): get_asset_category_and_relative_path("/some/random/path.png") + + +class TestGetAssetPathInfo: + def test_get_comfy_models_folders_excludes_core_infrastructure(self, tmp_path: Path): + controlnet_dir = tmp_path / "models" / "controlnet" + configs_dir = tmp_path / "models" / "configs" + custom_nodes_dir = tmp_path / "custom_nodes" + for directory in (controlnet_dir, configs_dir, custom_nodes_dir): + directory.mkdir(parents=True) + + with patch("app.assets.services.path_utils.folder_paths") as mock_fp: + mock_fp.folder_names_and_paths = { + "controlnet": ([str(controlnet_dir)], {".safetensors"}), + "configs": ([str(configs_dir)], {".yaml"}), + "custom_nodes": ([str(custom_nodes_dir)], set()), + } + + folders = get_comfy_models_folders() + + assert folders == [("controlnet", [str(controlnet_dir)])] + + def test_model_file_uses_registered_model_folder(self, fake_dirs): + f = fake_dirs["models"] / "subdir" / "model.safetensors" + f.parent.mkdir() + f.touch() + + info = get_asset_path_info(str(f)) + + assert info.asset_type == "model" + assert info.model_folder == "checkpoints" + + response_info = get_asset_response_path_info(str(f)) + assert response_info.file_path == "models/checkpoints/subdir/model.safetensors" + assert response_info.display_name == "subdir/model.safetensors" + + def test_arbitrary_registered_folder_is_model_folder(self, fake_dirs): + controlnet_dir = fake_dirs["models"].parent / "controlnet" + controlnet_dir.mkdir() + f = controlnet_dir / "pose.safetensors" + f.touch() + + with patch( + "app.assets.services.path_utils.get_comfy_models_folders", + return_value=[("controlnet", [str(controlnet_dir)])], + ): + response_info = get_asset_response_path_info(str(f)) + + assert response_info.asset_type == "model" + assert response_info.model_folder == "controlnet" + assert response_info.file_path == "models/controlnet/pose.safetensors" + assert response_info.display_name == "pose.safetensors" + + def test_multiple_physical_roots_for_same_model_folder(self, fake_dirs): + root_a = fake_dirs["models"] + root_b = fake_dirs["output"] / "checkpoints" + root_b.mkdir() + file_a = root_a / "subdir" / "model_a.safetensors" + file_b = root_b / "subdir" / "model_b.safetensors" + file_a.parent.mkdir() + file_b.parent.mkdir() + file_a.touch() + file_b.touch() + + with patch( + "app.assets.services.path_utils.get_comfy_models_folders", + return_value=[("checkpoints", [str(root_a), str(root_b)])], + ): + response_a = get_asset_response_path_info(str(file_a)) + response_b = get_asset_response_path_info(str(file_b)) + + assert response_a.asset_type == response_b.asset_type == "model" + assert response_a.model_folder == response_b.model_folder == "checkpoints" + assert response_a.file_path == "models/checkpoints/subdir/model_a.safetensors" + assert response_b.file_path == "models/checkpoints/subdir/model_b.safetensors" + assert response_a.display_name == "subdir/model_a.safetensors" + assert response_b.display_name == "subdir/model_b.safetensors" + + def test_same_named_files_under_multiple_roots_share_logical_file_path(self, fake_dirs): + root_a = fake_dirs["models"] + root_b = fake_dirs["output"] / "checkpoints" + root_b.mkdir() + file_a = root_a / "duplicate.safetensors" + file_b = root_b / "duplicate.safetensors" + file_a.touch() + file_b.touch() + + with patch( + "app.assets.services.path_utils.get_comfy_models_folders", + return_value=[("checkpoints", [str(root_a), str(root_b)])], + ): + response_a = get_asset_response_path_info(str(file_a)) + response_b = get_asset_response_path_info(str(file_b)) + + assert response_a.file_path == response_b.file_path + assert response_a.file_path == "models/checkpoints/duplicate.safetensors" + assert response_a.display_name == response_b.display_name == "duplicate.safetensors" + + def test_input_file_has_no_model_folder(self, fake_dirs): + f = fake_dirs["input"] / "subdir" / "photo.png" + f.parent.mkdir() + f.touch() + + info = get_asset_path_info(str(f)) + + assert info.asset_type == "input" + assert info.model_folder is None + + response_info = get_asset_response_path_info(str(f)) + assert response_info.file_path == "input/subdir/photo.png" + assert response_info.display_name == "subdir/photo.png" + + def test_output_backed_registered_model_folder_is_model(self, fake_dirs): + output_checkpoints_dir = fake_dirs["output"] / "checkpoints" + output_checkpoints_dir.mkdir() + f = output_checkpoints_dir / "saved.safetensors" + f.touch() + + with patch( + "app.assets.services.path_utils.get_comfy_models_folders", + return_value=[("checkpoints", [str(output_checkpoints_dir)])], + ): + context = resolve_asset_path_context(str(f)) + response_info = get_asset_response_path_info(str(f)) + + assert context.asset_type == "model" + assert context.model_folder == "checkpoints" + assert context.relative_path == "saved.safetensors" + + assert response_info.file_path == "models/checkpoints/saved.safetensors" + assert response_info.display_name == "saved.safetensors" + + def test_registered_model_folder_can_contain_slash(self, fake_dirs): + nested_model_dir = fake_dirs["models"].parent / "text_encoders" / "clip" + nested_model_dir.mkdir(parents=True) + f = nested_model_dir / "clip.safetensors" + f.touch() + + with patch( + "app.assets.services.path_utils.get_comfy_models_folders", + return_value=[("text_encoders/clip", [str(nested_model_dir)])], + ): + info = get_asset_path_info(str(f)) + response_info = get_asset_response_path_info(str(f)) + + assert info.asset_type == "model" + assert info.model_folder == "text_encoders/clip" + + assert response_info.file_path == "models/text_encoders/clip/clip.safetensors" + assert response_info.display_name == "clip.safetensors" + + def test_slash_model_folder_relative_filename_uses_registered_base(self, fake_dirs): + nested_model_dir = fake_dirs["models"].parent / "text_encoders" / "clip" + nested_model_dir.mkdir(parents=True) + f = nested_model_dir / "subdir" / "clip.safetensors" + f.parent.mkdir() + f.touch() + + with patch( + "app.assets.services.path_utils.get_comfy_models_folders", + return_value=[("text_encoders/clip", [str(nested_model_dir)])], + ): + assert compute_relative_filename(str(f)) == "subdir/clip.safetensors" + + def test_deepest_registered_model_base_wins(self, fake_dirs): + parent_dir = fake_dirs["models"].parent / "text_encoders" + nested_model_dir = parent_dir / "clip" + nested_model_dir.mkdir(parents=True) + f = nested_model_dir / "clip.safetensors" + f.touch() + + with patch( + "app.assets.services.path_utils.get_comfy_models_folders", + return_value=[ + ("text_encoders", [str(parent_dir)]), + ("text_encoders/clip", [str(nested_model_dir)]), + ], + ): + context = resolve_asset_path_context(str(f)) + + assert context.asset_type == "model" + assert context.model_folder == "text_encoders/clip" + assert context.relative_path == "clip.safetensors" + + def test_deepest_registered_model_base_wins_independent_of_registration_order( + self, fake_dirs + ): + parent_dir = fake_dirs["models"].parent / "text_encoders" + nested_model_dir = parent_dir / "clip" + nested_model_dir.mkdir(parents=True) + f = nested_model_dir / "clip.safetensors" + f.touch() + + with patch( + "app.assets.services.path_utils.get_comfy_models_folders", + return_value=[ + ("text_encoders/clip", [str(nested_model_dir)]), + ("text_encoders", [str(parent_dir)]), + ], + ): + context = resolve_asset_path_context(str(f)) + + assert context.asset_type == "model" + assert context.model_folder == "text_encoders/clip" + assert context.relative_path == "clip.safetensors" + + def test_path_under_unregistered_models_folder_is_unknown(self, fake_dirs): + unregistered_dir = fake_dirs["models"].parent / "unregistered" + unregistered_dir.mkdir() + f = unregistered_dir / "model.safetensors" + f.touch() + + with pytest.raises(ValueError, match="not within"): + resolve_asset_path_context(str(f)) + + def test_registered_model_folder_prefix_boundary(self, fake_dirs): + checkpoints_extra_dir = fake_dirs["models"].parent / "checkpoints_extra" + checkpoints_extra_dir.mkdir() + f = checkpoints_extra_dir / "model.safetensors" + f.touch() + + with pytest.raises(ValueError, match="not within"): + resolve_asset_path_context(str(f)) diff --git a/tests-unit/assets_test/test_sync_references.py b/tests-unit/assets_test/test_sync_references.py index 94cc255bc..e1894e7a5 100644 --- a/tests-unit/assets_test/test_sync_references.py +++ b/tests-unit/assets_test/test_sync_references.py @@ -23,10 +23,60 @@ from app.assets.database.queries.asset_reference import ( get_unenriched_references, restore_references_by_paths, ) -from app.assets.scanner import sync_references_with_filesystem +from app.assets.scanner import ( + collect_paths_for_roots, + get_all_known_prefixes, + sync_references_with_filesystem, +) from app.assets.services.file_utils import get_mtime_ns +def test_collect_paths_for_roots_deduplicates_overlapping_roots(tmp_path: Path): + model_file = tmp_path / "output" / "checkpoints" / "saved.safetensors" + model_file.parent.mkdir(parents=True) + model_file.write_bytes(b"model") + + with ( + patch("app.assets.scanner.collect_models_files", return_value=[str(model_file)]), + patch( + "app.assets.scanner.list_files_recursively", + return_value=[str(model_file)], + ), + patch("app.assets.scanner.folder_paths") as mock_folder_paths, + ): + mock_folder_paths.get_output_directory.return_value = str(tmp_path / "output") + + paths = collect_paths_for_roots(("models", "output")) + + assert paths == [str(model_file)] + + +def test_all_known_prefixes_include_temp_root(tmp_path: Path): + models_dir = tmp_path / "models" / "checkpoints" + input_dir = tmp_path / "input" + output_dir = tmp_path / "output" + temp_dir = tmp_path / "temp" + for directory in (models_dir, input_dir, output_dir, temp_dir): + directory.mkdir(parents=True) + + with ( + patch("app.assets.scanner.get_comfy_models_folders", return_value=[("checkpoints", [str(models_dir)])]), + patch("app.assets.scanner.folder_paths") as mock_folder_paths, + ): + mock_folder_paths.get_input_directory.return_value = str(input_dir) + mock_folder_paths.get_output_directory.return_value = str(output_dir) + mock_folder_paths.get_temp_directory.return_value = str(temp_dir) + + prefixes = get_all_known_prefixes() + + assert prefixes == [ + str(models_dir), + str(input_dir), + str(output_dir), + str(temp_dir), + ] + + @pytest.fixture def db_engine(): engine = create_engine("sqlite:///:memory:")