spike: add typed asset classification filters
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled

Co-authored-by: Amp <amp@ampcode.com>
Amp-Thread-ID: https://ampcode.com/threads/T-019e5117-c707-729d-bf98-dce718fe64d5
This commit is contained in:
Simon Pinfold 2026-05-27 09:09:30 +12:00
parent 1579bbb52d
commit af6d4e8172
16 changed files with 1225 additions and 84 deletions

View File

@ -10,7 +10,6 @@ from typing import Any
from aiohttp import web from aiohttp import web
from pydantic import ValidationError from pydantic import ValidationError
import folder_paths
from app import user_manager from app import user_manager
from app.assets.api import schemas_in, schemas_out from app.assets.api import schemas_in, schemas_out
from app.assets.services import schemas from app.assets.services import schemas
@ -39,6 +38,10 @@ from app.assets.services import (
update_asset_metadata, update_asset_metadata,
upload_from_temp_path, upload_from_temp_path,
) )
from app.assets.services.path_utils import (
get_asset_response_path_info,
get_comfy_models_folders,
)
from app.assets.services.tagging import list_tag_histogram from app.assets.services.tagging import list_tag_histogram
ROUTES = web.RouteTableDef() ROUTES = web.RouteTableDef()
@ -124,17 +127,32 @@ def _validate_sort_field(requested: str | None) -> str:
return "created_at" return "created_at"
def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any] | None) -> str | None: def _get_asset_path_info(file_path: str | None):
"""Build a /api/view preview URL from asset tags and user_metadata filename.""" if not file_path:
return None
try:
return get_asset_response_path_info(file_path)
except ValueError:
return None
def _build_preview_url_from_view(
asset_type: str | None,
user_metadata: dict[str, Any] | None,
fallback_tags: list[str] | None = None,
) -> str | None:
"""Build a /api/view preview URL from path-derived type and filename metadata."""
if not user_metadata: if not user_metadata:
return None return None
filename = user_metadata.get("filename") filename = user_metadata.get("filename")
if not filename: if not filename:
return None return None
if "input" in tags: if asset_type in {"input", "output"}:
view_type = asset_type
elif fallback_tags and "input" in fallback_tags:
view_type = "input" view_type = "input"
elif "output" in tags: elif fallback_tags and "output" in fallback_tags:
view_type = "output" view_type = "output"
else: else:
return None return None
@ -150,23 +168,82 @@ def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any]
return url return url
def _order_tags_for_legacy_path_compat(
tags: list[str],
asset_type: str | None,
model_folder: str | None,
) -> list[str]:
# BEGIN removable compatibility shim: path-like tag presentation
#
# Tags are stored and queried as flat unordered labels. Do not use this
# ordering as a path contract; use asset_type/model_folder/file_path instead.
# This only nudges response presentation for old callers that may have looked
# at tag[0] (and tag[1] for model folder) while the asset API migrates away
# from tag-as-path semantics. Remove this whole block once callers no longer
# depend on path-looking tag order.
if not tags or not asset_type:
return tags
priority: list[str] = []
root_tag = "models" if asset_type == "model" else asset_type
priority.append(root_tag)
if asset_type == "model" and model_folder:
priority.append(model_folder.lower())
remaining = list(tags)
ordered: list[str] = []
for tag in priority:
if tag in remaining:
ordered.append(tag)
remaining.remove(tag)
return ordered + remaining
# END removable compatibility shim: path-like tag presentation
def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResult) -> schemas_out.Asset: def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResult) -> schemas_out.Asset:
"""Build an Asset response from a service result.""" """Build an Asset response from a service result."""
path_info = _get_asset_path_info(result.ref.file_path)
if result.ref.preview_id: if result.ref.preview_id:
preview_detail = get_asset_detail(result.ref.preview_id) preview_detail = get_asset_detail(result.ref.preview_id)
if preview_detail: if preview_detail:
preview_url = _build_preview_url_from_view(preview_detail.tags, preview_detail.ref.user_metadata) preview_path_info = _get_asset_path_info(preview_detail.ref.file_path)
preview_url = _build_preview_url_from_view(
preview_path_info.asset_type if preview_path_info else None,
preview_detail.ref.user_metadata,
fallback_tags=preview_detail.tags,
)
else: else:
preview_url = None preview_url = None
else: else:
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata) preview_url = _build_preview_url_from_view(
path_info.asset_type if path_info else None,
result.ref.user_metadata,
fallback_tags=result.tags,
)
asset_type = None
model_folder = None
file_path = None
display_name = None
if path_info:
asset_type = path_info.asset_type
model_folder = path_info.model_folder
file_path = path_info.file_path
display_name = path_info.display_name
return schemas_out.Asset( return schemas_out.Asset(
id=result.ref.id, id=result.ref.id,
name=result.ref.name, name=result.ref.name,
file_path=file_path,
display_name=display_name,
asset_hash=result.asset.hash if result.asset else None, asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes) if result.asset else None, size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None, mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags, model_folder=model_folder,
asset_type=asset_type,
tags=_order_tags_for_legacy_path_compat(result.tags, asset_type, model_folder),
preview_url=preview_url, preview_url=preview_url,
preview_id=result.ref.preview_id, preview_id=result.ref.preview_id,
user_metadata=result.ref.user_metadata or {}, user_metadata=result.ref.user_metadata or {},
@ -213,6 +290,8 @@ async def list_assets_route(request: web.Request) -> web.Response:
owner_id=USER_MANAGER.get_request_user_id(request), owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags, include_tags=q.include_tags,
exclude_tags=q.exclude_tags, exclude_tags=q.exclude_tags,
asset_type=q.asset_type,
model_folder=q.model_folder,
name_contains=q.name_contains, name_contains=q.name_contains,
metadata_filter=q.metadata_filter, metadata_filter=q.metadata_filter,
limit=q.limit, limit=q.limit,
@ -401,9 +480,10 @@ async def upload_asset(request: web.Request) -> web.Response:
) )
if spec.tags and spec.tags[0] == "models": if spec.tags and spec.tags[0] == "models":
model_folder_names = {name for name, _paths in get_comfy_models_folders()}
if ( if (
len(spec.tags) < 2 len(spec.tags) < 2
or spec.tags[1] not in folder_paths.folder_names_and_paths or spec.tags[1] not in model_folder_names
): ):
delete_temp_file_if_exists(parsed.tmp_path) delete_temp_file_if_exists(parsed.tmp_path)
category = spec.tags[1] if len(spec.tags) >= 2 else "" category = spec.tags[1] if len(spec.tags) >= 2 else ""

View File

@ -52,6 +52,8 @@ class ParsedUpload:
class ListAssetsQuery(BaseModel): class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list) include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list) exclude_tags: list[str] = Field(default_factory=list)
asset_type: Literal["model", "input", "output", "temp"] | None = None
model_folder: str | None = None
name_contains: str | None = None name_contains: str | None = None
# Accept either a JSON string (query param) or a dict # Accept either a JSON string (query param) or a dict
@ -81,6 +83,20 @@ class ListAssetsQuery(BaseModel):
return out return out
return v return v
@field_validator("model_folder", mode="before")
@classmethod
def _normalize_model_folder(cls, v):
if v is None:
return None
s = str(v).strip()
return s or None
@model_validator(mode="after")
def _validate_path_filters(self):
if self.model_folder and self.asset_type != "model":
raise ValueError("model_folder can only be used with asset_type=model")
return self
@field_validator("metadata_filter", mode="before") @field_validator("metadata_filter", mode="before")
@classmethod @classmethod
def _parse_metadata_json(cls, v): def _parse_metadata_json(cls, v):
@ -300,14 +316,23 @@ class UploadAssetSpec(BaseModel):
else: else:
return [] return []
# normalize + dedupe # Normalize the root tag, but preserve path/destination components. Tags
# are normalized again before storage; this parser also feeds upload
# destination routing where registered model folder names are exact.
normalized_items: list[str] = []
for index, item in enumerate(items):
stripped = str(item).strip()
if not stripped:
continue
normalized_items.append(stripped.lower() if index == 0 else stripped)
# Dedupe exact tokens while preserving order.
norm = [] norm = []
seen = set() seen = set()
for t in items: for token in normalized_items:
tnorm = str(t).strip().lower() if token not in seen:
if tnorm and tnorm not in seen: seen.add(token)
seen.add(tnorm) norm.append(token)
norm.append(tnorm)
return norm return norm
@field_validator("user_metadata", mode="before") @field_validator("user_metadata", mode="before")

View File

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, field_serializer from pydantic import BaseModel, ConfigDict, Field, field_serializer
@ -10,9 +10,22 @@ class Asset(BaseModel):
id: str id: str
name: str name: str
file_path: str | None = Field(
default=None,
description="Logical asset namespace path. Model assets use `models/<model_folder>/<relative_path>`; other typed assets use `<asset_type>/<relative_path>`. This is not a unique identity; use `id` for stable asset-reference operations.",
)
display_name: str | None = Field(
default=None,
description="Human-facing path below the matched asset root or model folder.",
)
asset_hash: str | None = None asset_hash: str | None = None
size: int | None = None size: int | None = None
mime_type: str | None = None mime_type: str | None = None
model_folder: str | None = Field(
default=None,
description="Exact, case-sensitive registered ComfyUI model folder name. Present only when asset_type is `model`.",
)
asset_type: Literal["model", "input", "output", "temp"] | None = None
tags: list[str] = Field(default_factory=list) tags: list[str] = Field(default_factory=list)
preview_url: str | None = None preview_url: str | None = None
preview_id: str | None = None # references an asset_reference id, not an asset id preview_id: str | None = None # references an asset_reference id, not an asset id

View File

@ -24,6 +24,7 @@ from app.assets.database.models import (
) )
from app.assets.database.queries.common import ( from app.assets.database.queries.common import (
MAX_BIND_PARAMS, MAX_BIND_PARAMS,
apply_asset_path_filters,
apply_metadata_filter, apply_metadata_filter,
apply_tag_filters, apply_tag_filters,
build_prefix_like_conditions, build_prefix_like_conditions,
@ -263,6 +264,8 @@ def list_references_page(
name_contains: str | None = None, name_contains: str | None = None,
include_tags: Sequence[str] | None = None, include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None, exclude_tags: Sequence[str] | None = None,
asset_type: str | None = None,
model_folder: str | None = None,
metadata_filter: dict | None = None, metadata_filter: dict | None = None,
sort: str | None = None, sort: str | None = None,
order: str | None = None, order: str | None = None,
@ -285,6 +288,7 @@ def list_references_page(
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
base = apply_tag_filters(base, include_tags, exclude_tags) base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_asset_path_filters(base, asset_type=asset_type, model_folder=model_folder)
base = apply_metadata_filter(base, metadata_filter) base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower() sort = (sort or "created_at").lower()
@ -315,6 +319,9 @@ def list_references_page(
AssetReference.name.ilike(f"%{escaped}%", escape=esc) AssetReference.name.ilike(f"%{escaped}%", escape=esc)
) )
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_asset_path_filters(
count_stmt, asset_type=asset_type, model_folder=model_folder
)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter) count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int(session.execute(count_stmt).scalar_one() or 0) total = int(session.execute(count_stmt).scalar_one() or 0)

View File

@ -2,16 +2,24 @@
import os import os
from decimal import Decimal from decimal import Decimal
from pathlib import Path
from typing import Iterable, Sequence from typing import Iterable, Sequence
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import exists from sqlalchemy import exists
import folder_paths
from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag
from app.assets.helpers import escape_sql_like_string, normalize_tags from app.assets.helpers import normalize_tags
MAX_BIND_PARAMS = 800 MAX_BIND_PARAMS = 800
# Mirrors app.model_manager.MODEL_FOLDER_BLACKLIST: these names are bootstrapped
# into folder_names_and_paths by core, but /api/experiment/models does not expose
# them as model folders. Keep this local to avoid reintroducing the asset query
# package initialization cycle from importing service/model-manager code here.
_NON_MODEL_FOLDER_NAMES = frozenset({"configs", "custom_nodes"})
def calculate_rows_per_statement(cols: int) -> int: def calculate_rows_per_statement(cols: int) -> int:
"""Calculate how many rows can fit in one statement given column count.""" """Calculate how many rows can fit in one statement given column count."""
@ -45,17 +53,97 @@ def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
def build_prefix_like_conditions( def build_prefix_like_conditions(
prefixes: list[str], prefixes: list[str],
) -> list[sa.sql.ColumnElement]: ) -> list[sa.sql.ColumnElement]:
"""Build LIKE conditions for matching file paths under directory prefixes.""" """Build case-exact conditions for matching file paths under directory prefixes."""
conds = [] conds = []
for p in prefixes: for p in prefixes:
base = os.path.abspath(p) base = os.path.abspath(p)
if not base.endswith(os.sep): if not base.endswith(os.sep):
base += os.sep base += os.sep
escaped, esc = escape_sql_like_string(base) conds.append(sa.func.substr(AssetReference.file_path, 1, len(base)) == base)
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
return conds return conds
def _get_comfy_model_folders() -> list[tuple[str, list[str]]]:
"""Return registered model folder names and roots without importing services.
This intentionally stays local to the query layer to avoid importing
``app.assets.services`` from ``app.assets.database.queries.common``, which
creates a package initialization cycle.
"""
targets: list[tuple[str, list[str]]] = []
for name, values in folder_paths.folder_names_and_paths.items():
if name in _NON_MODEL_FOLDER_NAMES:
continue
paths, _exts = values[0], values[1]
if paths:
targets.append((name, paths))
return targets
def apply_asset_path_filters(
stmt: sa.sql.Select,
asset_type: str | None = None,
model_folder: str | None = None,
) -> sa.sql.Select:
"""Filter references using their real filesystem/root registration context."""
if asset_type is None and model_folder is None:
return stmt
if model_folder and asset_type != "model":
raise ValueError("model_folder can only be used with asset_type=model")
registered_model_folders = _get_comfy_model_folders()
prefixes: list[str] = []
exclude_prefixes: list[str] = []
if model_folder:
for folder_name, paths in registered_model_folders:
if folder_name == model_folder:
prefixes.extend(paths)
break
if not prefixes:
return stmt.where(sa.false())
target_bases = [os.path.abspath(path) for path in prefixes]
for folder_name, paths in registered_model_folders:
if folder_name == model_folder:
continue
for path in paths:
path_abs = os.path.abspath(path)
if any(
Path(path_abs).is_relative_to(target_base)
and path_abs != target_base
for target_base in target_bases
):
exclude_prefixes.append(path)
elif asset_type == "model":
for _folder_name, paths in registered_model_folders:
prefixes.extend(paths)
elif asset_type == "input":
prefixes = [folder_paths.get_input_directory()]
elif asset_type == "output":
prefixes = [folder_paths.get_output_directory()]
elif asset_type == "temp":
prefixes = [folder_paths.get_temp_directory()]
conditions = build_prefix_like_conditions(prefixes)
if not conditions:
return stmt.where(sa.false())
clause = sa.or_(*conditions)
if asset_type in {"input", "output", "temp"}:
model_prefixes = [
path for _folder_name, paths in registered_model_folders for path in paths
]
model_conditions = build_prefix_like_conditions(model_prefixes)
if model_conditions:
clause = sa.and_(clause, sa.not_(sa.or_(*model_conditions)))
elif exclude_prefixes:
exclude_conditions = build_prefix_like_conditions(exclude_prefixes)
if exclude_conditions:
clause = sa.and_(clause, sa.not_(sa.or_(*exclude_conditions)))
return stmt.where(clause)
def apply_tag_filters( def apply_tag_filters(
stmt: sa.sql.Select, stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None, include_tags: Sequence[str] | None = None,

View File

@ -56,7 +56,7 @@ class _AssetAccumulator(TypedDict):
refs: list[_RefInfo] refs: list[_RefInfo]
RootType = Literal["models", "input", "output"] RootType = Literal["models", "input", "output", "temp"]
def get_prefixes_for_root(root: RootType) -> list[str]: def get_prefixes_for_root(root: RootType) -> list[str]:
@ -69,12 +69,14 @@ def get_prefixes_for_root(root: RootType) -> list[str]:
return [os.path.abspath(folder_paths.get_input_directory())] return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output": if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())] return [os.path.abspath(folder_paths.get_output_directory())]
if root == "temp":
return [os.path.abspath(folder_paths.get_temp_directory())]
return [] return []
def get_all_known_prefixes() -> list[str]: def get_all_known_prefixes() -> list[str]:
"""Get all known asset prefixes across all root types.""" """Get all known asset prefixes across all root types."""
all_roots: tuple[RootType, ...] = ("models", "input", "output") all_roots: tuple[RootType, ...] = ("models", "input", "output", "temp")
return [p for root in all_roots for p in get_prefixes_for_root(root)] return [p for root in all_roots for p in get_prefixes_for_root(root)]
@ -274,7 +276,18 @@ def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]:
paths.extend(list_files_recursively(folder_paths.get_input_directory())) paths.extend(list_files_recursively(folder_paths.get_input_directory()))
if "output" in roots: if "output" in roots:
paths.extend(list_files_recursively(folder_paths.get_output_directory())) paths.extend(list_files_recursively(folder_paths.get_output_directory()))
return paths if "temp" in roots:
paths.extend(list_files_recursively(folder_paths.get_temp_directory()))
deduped: list[str] = []
seen: set[str] = set()
for path in paths:
abs_path = os.path.abspath(path)
if abs_path in seen:
continue
seen.add(abs_path)
deduped.append(abs_path)
return deduped
def build_asset_specs( def build_asset_specs(

View File

@ -246,6 +246,8 @@ def list_assets_page(
owner_id: str = "", owner_id: str = "",
include_tags: Sequence[str] | None = None, include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None, exclude_tags: Sequence[str] | None = None,
asset_type: str | None = None,
model_folder: str | None = None,
name_contains: str | None = None, name_contains: str | None = None,
metadata_filter: dict | None = None, metadata_filter: dict | None = None,
limit: int = 20, limit: int = 20,
@ -259,6 +261,8 @@ def list_assets_page(
owner_id=owner_id, owner_id=owner_id,
include_tags=include_tags, include_tags=include_tags,
exclude_tags=exclude_tags, exclude_tags=exclude_tags,
asset_type=asset_type,
model_folder=model_folder,
name_contains=name_contains, name_contains=name_contains,
metadata_filter=metadata_filter, metadata_filter=metadata_filter,
limit=limit, limit=limit,

View File

@ -125,6 +125,18 @@ def batch_insert_seed_assets(
if not specs: if not specs:
return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0) return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0)
deduped_specs: list[SeedAssetSpec] = []
seen_paths: set[str] = set()
for spec in specs:
absolute_path = os.path.abspath(spec["abs_path"])
if absolute_path in seen_paths:
continue
seen_paths.add(absolute_path)
deduped_specs.append(spec)
specs = deduped_specs
if not specs:
return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0)
current_time = get_utc_now() current_time = get_utc_now()
asset_rows: list[AssetRow] = [] asset_rows: list[AssetRow] = []
reference_rows: list[ReferenceRow] = [] reference_rows: list[ReferenceRow] = []

View File

@ -1,4 +1,5 @@
import os import os
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Literal from typing import Literal
@ -6,7 +7,28 @@ import folder_paths
from app.assets.helpers import normalize_tags from app.assets.helpers import normalize_tags
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"}) # Mirrors app.model_manager.MODEL_FOLDER_BLACKLIST: these names are bootstrapped
# into folder_names_and_paths by core, but /api/experiment/models does not expose
# them as model folders, so typed asset classification should not either.
_NON_MODEL_FOLDER_NAMES = frozenset({"configs", "custom_nodes"})
@dataclass(frozen=True)
class AssetPathInfo:
asset_type: Literal["input", "output", "temp", "model"]
model_folder: str | None
@dataclass(frozen=True)
class AssetResponsePathInfo(AssetPathInfo):
file_path: str
display_name: str | None
@dataclass(frozen=True)
class AssetPathContext(AssetPathInfo):
base_path: str
relative_path: str
def get_comfy_models_folders() -> list[tuple[str, list[str]]]: def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
@ -14,7 +36,7 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
Includes every category registered in folder_names_and_paths, Includes every category registered in folder_names_and_paths,
regardless of whether its paths are under the main models_dir, regardless of whether its paths are under the main models_dir,
but excludes non-model entries like custom_nodes. but excludes non-model entries like configs and custom_nodes.
""" """
targets: list[tuple[str, list[str]]] = [] targets: list[tuple[str, list[str]]] = []
for name, values in folder_paths.folder_names_and_paths.items(): for name, values in folder_paths.folder_names_and_paths.items():
@ -67,28 +89,109 @@ def validate_path_within_base(candidate: str, base: str) -> None:
def compute_relative_filename(file_path: str) -> str | None: def compute_relative_filename(file_path: str) -> str | None:
""" """
Return the model's path relative to the last well-known folder (the model category), Return the path relative to the matched asset root or model folder, using
using forward slashes, eg: forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors" /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors" /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
/.../input/sub/image.png -> "sub/image.png"
For non-model paths, returns None. For unknown paths, returns None.
""" """
try: try:
root_category, rel_path = get_asset_category_and_relative_path(file_path) context = resolve_asset_path_context(file_path)
except ValueError: except ValueError:
return None return None
p = Path(rel_path) return _normalize_relative_path(context.relative_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
def _normalize_relative_path(relative_path: str) -> str | None:
parts = [
seg
for seg in Path(relative_path).parts
if seg not in (".", "..", Path(relative_path).anchor)
]
if not parts: if not parts:
return None return None
if root_category == "models": return "/".join(parts)
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside) def resolve_asset_path_context(file_path: str) -> AssetPathContext:
return "/".join(parts) # input/output: keep all parts """Resolve a path against Core's asset roots and model-folder registration.
This is the source of truth for path-derived asset classification. For
model assets, ``model_folder`` is the exact registered folder name whose
base path contains the file, and ``relative_path`` is relative to that
matched base path. When multiple registered bases contain the file, the
deepest base wins.
"""
fp_abs = os.path.abspath(file_path)
def _check_is_within(child: str, parent: str) -> bool:
return Path(child).is_relative_to(parent)
def _compute_relative(child: str, parent: str) -> str:
# Normalize relative path, stripping any leading ".." components
# by anchoring to root (os.sep) then computing relpath back from it.
return os.path.relpath(
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
)
best: tuple[int, str, str, str] | None = None
for model_folder, bases in get_comfy_models_folders():
for base in bases:
base_abs = os.path.abspath(base)
if not _check_is_within(fp_abs, base_abs):
continue
cand = (
len(base_abs),
model_folder,
base_abs,
_compute_relative(fp_abs, base_abs),
)
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, model_folder, base_path, relative_path = best
return AssetPathContext(
asset_type="model",
model_folder=model_folder,
base_path=base_path,
relative_path=relative_path,
)
input_base = os.path.abspath(folder_paths.get_input_directory())
if _check_is_within(fp_abs, input_base):
return AssetPathContext(
asset_type="input",
model_folder=None,
base_path=input_base,
relative_path=_compute_relative(fp_abs, input_base),
)
output_base = os.path.abspath(folder_paths.get_output_directory())
if _check_is_within(fp_abs, output_base):
return AssetPathContext(
asset_type="output",
model_folder=None,
base_path=output_base,
relative_path=_compute_relative(fp_abs, output_base),
)
temp_base = os.path.abspath(folder_paths.get_temp_directory())
if _check_is_within(fp_abs, temp_base):
return AssetPathContext(
asset_type="temp",
model_folder=None,
base_path=temp_base,
relative_path=_compute_relative(fp_abs, temp_base),
)
raise ValueError(
f"Path is not within input, output, temp, or configured model bases: {file_path}"
)
def get_asset_category_and_relative_path( def get_asset_category_and_relative_path(
@ -108,51 +211,62 @@ def get_asset_category_and_relative_path(
Raises: Raises:
ValueError: path does not belong to any known root. ValueError: path does not belong to any known root.
""" """
fp_abs = os.path.abspath(file_path) context = resolve_asset_path_context(file_path)
if context.asset_type == "model":
combined = os.path.join(context.model_folder or "", context.relative_path)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
return context.asset_type, context.relative_path
def _check_is_within(child: str, parent: str) -> bool:
return Path(child).is_relative_to(parent)
def _compute_relative(child: str, parent: str) -> str: def get_asset_path_info(file_path: str) -> AssetPathInfo:
# Normalize relative path, stripping any leading ".." components """Return typed asset classification derived from the actual filesystem path.
# by anchoring to root (os.sep) then computing relpath back from it.
return os.path.relpath( This intentionally reads the ComfyUI model folder registration from
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep ``folder_paths.folder_names_and_paths`` instead of inferring it from tags.
For model files, ``model_folder`` is the registered folder name whose base
path contains ``file_path``.
Raises:
ValueError: path does not belong to any known root.
"""
context = resolve_asset_path_context(file_path)
return AssetPathInfo(
asset_type=context.asset_type,
model_folder=context.model_folder,
)
def get_asset_response_path_info(file_path: str) -> AssetResponsePathInfo:
"""Return API-facing path fields derived from the actual filesystem path.
``file_path`` is a logical namespace key: ``models/<model_folder>/<relative>``
for model assets and ``<asset_type>/<relative>`` for input/output/temp assets.
``display_name`` is the path below the matched root or model folder.
Raises:
ValueError: path does not belong to any known root.
"""
context = resolve_asset_path_context(file_path)
display_name = _normalize_relative_path(context.relative_path)
if context.asset_type == "model":
logical_file_path = (
f"models/{context.model_folder}/{display_name}"
if display_name
else f"models/{context.model_folder}"
)
else:
logical_file_path = (
f"{context.asset_type}/{display_name}"
if display_name
else context.asset_type
) )
# 1) input return AssetResponsePathInfo(
input_base = os.path.abspath(folder_paths.get_input_directory()) asset_type=context.asset_type,
if _check_is_within(fp_abs, input_base): model_folder=context.model_folder,
return "input", _compute_relative(fp_abs, input_base) file_path=logical_file_path,
display_name=display_name,
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _check_is_within(fp_abs, output_base):
return "output", _compute_relative(fp_abs, output_base)
# 3) temp
temp_base = os.path.abspath(folder_paths.get_temp_directory())
if _check_is_within(fp_abs, temp_base):
return "temp", _compute_relative(fp_abs, temp_base)
# 4) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _check_is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(
f"Path is not within input, output, temp, or configured model bases: {file_path}"
) )

View File

@ -14,6 +14,9 @@ from io import BytesIO
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
MODEL_FOLDER_BLACKLIST = frozenset({"configs", "custom_nodes"})
class ModelFileManager: class ModelFileManager:
def __init__(self) -> None: def __init__(self) -> None:
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {} self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
@ -32,10 +35,9 @@ class ModelFileManager:
@routes.get("/experiment/models") @routes.get("/experiment/models")
async def get_model_folders(request): async def get_model_folders(request):
model_types = list(folder_paths.folder_names_and_paths.keys()) model_types = list(folder_paths.folder_names_and_paths.keys())
folder_black_list = ["configs", "custom_nodes"]
output_folders: list[dict] = [] output_folders: list[dict] = []
for folder in model_types: for folder in model_types:
if folder in folder_black_list: if folder in MODEL_FOLDER_BLACKLIST:
continue continue
output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)}) output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
return web.json_response(output_folders) return web.json_response(output_folders)

View File

@ -1511,7 +1511,7 @@ paths:
in: query in: query
schema: schema:
type: integer type: integer
default: 50 default: 20
- name: offset - name: offset
in: query in: query
schema: schema:
@ -1535,6 +1535,17 @@ paths:
style: form style: form
explode: true explode: true
description: Tags that assets must not have description: Tags that assets must not have
- name: asset_type
in: query
schema:
type: string
enum: [model, input, output, temp]
description: Filter by path-derived asset type. Model classification is based on registered ComfyUI model folder roots, not tags.
- name: model_folder
in: query
schema:
type: string
description: Filter model assets by exact registered ComfyUI model folder name. Requires asset_type=model.
- name: name_contains - name: name_contains
in: query in: query
schema: schema:
@ -6607,10 +6618,23 @@ components:
id: id:
type: string type: string
format: uuid format: uuid
description: Unique identifier for the asset description: AssetReference ID. Use this as the stable identity; logical file_path is not unique.
name: name:
type: string type: string
description: Name of the asset file description: Name of the asset file
file_path:
type: string
description: Logical asset namespace path. Model assets use `models/<model_folder>/<relative_path>`; input/output/temp assets use `<asset_type>/<relative_path>`. Omitted when the asset has no resolvable filesystem path. Not unique; use `id` for stable asset-reference operations.
display_name:
type: string
description: Human-facing path below the matched asset root or model folder. Omitted when the asset has no resolvable filesystem path.
asset_type:
type: string
enum: [model, input, output, temp]
description: Path-derived asset type. Model classification is based on registered ComfyUI model folder roots, not tags. Omitted when the asset has no resolvable filesystem path.
model_folder:
type: string
description: Exact, case-sensitive registered ComfyUI model folder name for model assets. Present only when asset_type is `model`.
hash: hash:
type: string type: string
nullable: true nullable: true
@ -8723,4 +8747,4 @@ components:
items: items:
$ref: "#/components/schemas/TaskEntry" $ref: "#/components/schemas/TaskEntry"
pagination: pagination:
$ref: "#/components/schemas/PaginationInfo" $ref: "#/components/schemas/PaginationInfo"

View File

@ -1,8 +1,13 @@
import time import time
import uuid import uuid
from pathlib import Path
from unittest.mock import patch
import pytest import pytest
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.assets.api import schemas_in
from app.assets.api.routes import _order_tags_for_legacy_path_compat
from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta
from app.assets.database.queries import ( from app.assets.database.queries import (
reference_exists_for_asset_id, reference_exists_for_asset_id,
@ -25,6 +30,55 @@ from app.assets.database.queries import (
from app.assets.helpers import get_utc_now from app.assets.helpers import get_utc_now
class TestListAssetsQueryPathFilters:
def test_model_folder_requires_model_asset_type(self):
with pytest.raises(ValueError, match="model_folder can only be used"):
schemas_in.ListAssetsQuery.model_validate({"model_folder": "checkpoints"})
def test_model_folder_accepts_explicit_model_asset_type(self):
query = schemas_in.ListAssetsQuery.model_validate(
{"asset_type": "model", "model_folder": "checkpoints"}
)
assert query.asset_type == "model"
assert query.model_folder == "checkpoints"
def test_model_folder_rejects_non_model_asset_type(self):
with pytest.raises(ValueError, match="model_folder can only be used"):
schemas_in.ListAssetsQuery.model_validate(
{"asset_type": "input", "model_folder": "checkpoints"}
)
def test_query_layer_rejects_model_folder_without_model_asset_type(
self, session: Session
):
with pytest.raises(ValueError, match="model_folder can only be used"):
list_references_page(session, model_folder="checkpoints")
def test_upload_tags_preserve_model_folder_case_for_destination(self):
spec = schemas_in.UploadAssetSpec.model_validate(
{"tags": ['["models", "LLM", "SubDir"]']}
)
assert spec.tags == ["models", "LLM", "SubDir"]
class TestLegacyPathTagOrdering:
def test_model_response_moves_models_and_model_folder_first(self):
assert _order_tags_for_legacy_path_compat(
["blah", "checkpoints", "deep", "foo", "models"],
asset_type="model",
model_folder="checkpoints",
) == ["models", "checkpoints", "blah", "deep", "foo"]
def test_non_model_response_moves_root_tag_first(self):
assert _order_tags_for_legacy_path_compat(
["subdir", "output"],
asset_type="output",
model_folder=None,
) == ["output", "subdir"]
def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset: def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream") asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream")
session.add(asset) session.add(asset)
@ -145,6 +199,369 @@ class TestListReferencesPage:
assert total == 1 assert total == 1
assert refs[0].name == "keep" assert refs[0].name == "keep"
def test_model_folder_filter_uses_registered_paths(self, session: Session, tmp_path: Path):
checkpoints_dir = tmp_path / "models" / "checkpoints"
loras_dir = tmp_path / "models" / "loras"
checkpoints_dir.mkdir(parents=True)
loras_dir.mkdir(parents=True)
asset = _make_asset(session, "hash1")
checkpoint = _make_reference(session, asset, name="checkpoint")
checkpoint.file_path = str(checkpoints_dir / "model.safetensors")
lora = _make_reference(session, asset, name="lora")
lora.file_path = str(loras_dir / "model.safetensors")
session.commit()
with patch(
"app.assets.database.queries.common._get_comfy_model_folders",
return_value=[
("checkpoints", [str(checkpoints_dir)]),
("loras", [str(loras_dir)]),
],
):
refs, _, total = list_references_page(
session, asset_type="model", model_folder="checkpoints"
)
assert total == 1
assert refs[0].id == checkpoint.id
def test_model_folder_filter_includes_all_registered_roots_for_folder(
self, session: Session, tmp_path: Path
):
checkpoints_a = tmp_path / "root_a" / "checkpoints"
checkpoints_b = tmp_path / "root_b" / "checkpoints"
checkpoints_a.mkdir(parents=True)
checkpoints_b.mkdir(parents=True)
asset = _make_asset(session, "hash1")
ref_a = _make_reference(session, asset, name="checkpoint-a")
ref_a.file_path = str(checkpoints_a / "a.safetensors")
ref_b = _make_reference(session, asset, name="checkpoint-b")
ref_b.file_path = str(checkpoints_b / "b.safetensors")
session.commit()
with patch(
"app.assets.database.queries.common._get_comfy_model_folders",
return_value=[("checkpoints", [str(checkpoints_a), str(checkpoints_b)])],
):
refs, _, total = list_references_page(
session, asset_type="model", model_folder="checkpoints"
)
assert total == 2
assert {ref.id for ref in refs} == {ref_a.id, ref_b.id}
def test_same_named_files_under_multiple_roots_both_return_in_model_folder_filter(
self, session: Session, tmp_path: Path
):
checkpoints_a = tmp_path / "root_a" / "checkpoints"
checkpoints_b = tmp_path / "root_b" / "checkpoints"
checkpoints_a.mkdir(parents=True)
checkpoints_b.mkdir(parents=True)
asset = _make_asset(session, "hash1")
ref_a = _make_reference(session, asset, name="checkpoint-a")
ref_a.file_path = str(checkpoints_a / "duplicate.safetensors")
ref_b = _make_reference(session, asset, name="checkpoint-b")
ref_b.file_path = str(checkpoints_b / "duplicate.safetensors")
session.commit()
with patch(
"app.assets.database.queries.common._get_comfy_model_folders",
return_value=[("checkpoints", [str(checkpoints_a), str(checkpoints_b)])],
):
refs, _, total = list_references_page(
session, asset_type="model", model_folder="checkpoints"
)
assert total == 2
assert {ref.id for ref in refs} == {ref_a.id, ref_b.id}
def test_arbitrary_registered_folder_filter_works(
self, session: Session, tmp_path: Path
):
controlnet_dir = tmp_path / "models" / "controlnet"
controlnet_dir.mkdir(parents=True)
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset, name="controlnet")
ref.file_path = str(controlnet_dir / "pose.safetensors")
session.commit()
with patch(
"app.assets.database.queries.common._get_comfy_model_folders",
return_value=[("controlnet", [str(controlnet_dir)])],
):
refs, _, total = list_references_page(
session, asset_type="model", model_folder="controlnet"
)
assert total == 1
assert refs[0].id == ref.id
def test_unknown_model_folder_filter_returns_none_when_other_models_exist(
self, session: Session, tmp_path: Path
):
checkpoints_dir = tmp_path / "models" / "checkpoints"
checkpoints_dir.mkdir(parents=True)
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset, name="checkpoint")
ref.file_path = str(checkpoints_dir / "model.safetensors")
session.commit()
with patch(
"app.assets.database.queries.common._get_comfy_model_folders",
return_value=[("checkpoints", [str(checkpoints_dir)])],
):
refs, _, total = list_references_page(
session, asset_type="model", model_folder="controlnet"
)
assert total == 0
assert refs == []
def test_model_folder_filter_excludes_deeper_registered_model_folder(
self, session: Session, tmp_path: Path
):
text_encoders_dir = tmp_path / "models" / "text_encoders"
clip_dir = text_encoders_dir / "clip"
clip_dir.mkdir(parents=True)
asset = _make_asset(session, "hash1")
text_encoder = _make_reference(session, asset, name="text_encoder")
text_encoder.file_path = str(text_encoders_dir / "t5xxl.safetensors")
clip = _make_reference(session, asset, name="clip")
clip.file_path = str(clip_dir / "clip_l.safetensors")
session.commit()
with patch(
"app.assets.database.queries.common._get_comfy_model_folders",
return_value=[
("text_encoders", [str(text_encoders_dir)]),
("text_encoders/clip", [str(clip_dir)]),
],
):
refs, _, total = list_references_page(
session, asset_type="model", model_folder="text_encoders"
)
assert total == 1
assert refs[0].id == text_encoder.id
def test_child_model_folder_filter_returns_only_child(
self, session: Session, tmp_path: Path
):
text_encoders_dir = tmp_path / "models" / "text_encoders"
clip_dir = text_encoders_dir / "clip"
clip_dir.mkdir(parents=True)
asset = _make_asset(session, "hash1")
text_encoder = _make_reference(session, asset, name="text_encoder")
text_encoder.file_path = str(text_encoders_dir / "t5xxl.safetensors")
clip = _make_reference(session, asset, name="clip")
clip.file_path = str(clip_dir / "clip_l.safetensors")
session.commit()
with patch(
"app.assets.database.queries.common._get_comfy_model_folders",
return_value=[
("text_encoders", [str(text_encoders_dir)]),
("text_encoders/clip", [str(clip_dir)]),
],
):
refs, _, total = list_references_page(
session, asset_type="model", model_folder="text_encoders/clip"
)
assert total == 1
assert refs[0].id == clip.id
def test_model_asset_type_filter_includes_parent_and_child_registered_roots(
self, session: Session, tmp_path: Path
):
text_encoders_dir = tmp_path / "models" / "text_encoders"
clip_dir = text_encoders_dir / "clip"
clip_dir.mkdir(parents=True)
asset = _make_asset(session, "hash1")
text_encoder = _make_reference(session, asset, name="text_encoder")
text_encoder.file_path = str(text_encoders_dir / "t5xxl.safetensors")
clip = _make_reference(session, asset, name="clip")
clip.file_path = str(clip_dir / "clip_l.safetensors")
session.commit()
with patch(
"app.assets.database.queries.common._get_comfy_model_folders",
return_value=[
("text_encoders", [str(text_encoders_dir)]),
("text_encoders/clip", [str(clip_dir)]),
],
):
refs, _, total = list_references_page(session, asset_type="model")
assert total == 2
assert {ref.id for ref in refs} == {text_encoder.id, clip.id}
def test_model_asset_type_filter_with_no_registered_paths_returns_none(
self, session: Session, tmp_path: Path
):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset, name="orphan")
ref.file_path = str(tmp_path / "models" / "checkpoints" / "model.safetensors")
session.commit()
with patch(
"app.assets.database.queries.common._get_comfy_model_folders",
return_value=[],
):
refs, _, total = list_references_page(session, asset_type="model")
assert total == 0
assert refs == []
def test_model_asset_type_filter_excludes_unregistered_models_folder(
self, session: Session, tmp_path: Path
):
checkpoints_dir = tmp_path / "models" / "checkpoints"
unregistered_dir = tmp_path / "models" / "unregistered"
checkpoints_dir.mkdir(parents=True)
unregistered_dir.mkdir(parents=True)
asset = _make_asset(session, "hash1")
checkpoint = _make_reference(session, asset, name="checkpoint")
checkpoint.file_path = str(checkpoints_dir / "model.safetensors")
unregistered = _make_reference(session, asset, name="unregistered")
unregistered.file_path = str(unregistered_dir / "model.safetensors")
session.commit()
with patch(
"app.assets.database.queries.common._get_comfy_model_folders",
return_value=[("checkpoints", [str(checkpoints_dir)])],
):
refs, _, total = list_references_page(session, asset_type="model")
assert total == 1
assert refs[0].id == checkpoint.id
def test_model_asset_type_filter_respects_prefix_boundaries(
self, session: Session, tmp_path: Path
):
checkpoints_dir = tmp_path / "models" / "checkpoints"
checkpoints_extra_dir = tmp_path / "models" / "checkpoints_extra"
checkpoints_dir.mkdir(parents=True)
checkpoints_extra_dir.mkdir(parents=True)
asset = _make_asset(session, "hash1")
checkpoint = _make_reference(session, asset, name="checkpoint")
checkpoint.file_path = str(checkpoints_dir / "model.safetensors")
checkpoints_extra = _make_reference(session, asset, name="checkpoints_extra")
checkpoints_extra.file_path = str(checkpoints_extra_dir / "model.safetensors")
session.commit()
with patch(
"app.assets.database.queries.common._get_comfy_model_folders",
return_value=[("checkpoints", [str(checkpoints_dir)])],
):
refs, _, total = list_references_page(session, asset_type="model")
assert total == 1
assert refs[0].id == checkpoint.id
def test_model_asset_type_filter_is_case_exact(
self, session: Session, tmp_path: Path
):
registered_dir = tmp_path / "models" / "checkpoints"
case_sibling_dir = tmp_path / "MODELS" / "checkpoints"
registered_dir.mkdir(parents=True)
case_sibling_dir.mkdir(parents=True)
asset = _make_asset(session, "hash1")
checkpoint = _make_reference(session, asset, name="checkpoint")
checkpoint.file_path = str(registered_dir / "model.safetensors")
case_sibling = _make_reference(session, asset, name="case_sibling")
case_sibling.file_path = str(case_sibling_dir / "model.safetensors")
session.commit()
with patch(
"app.assets.database.queries.common._get_comfy_model_folders",
return_value=[("checkpoints", [str(registered_dir)])],
):
refs, _, total = list_references_page(session, asset_type="model")
assert total == 1
assert refs[0].id == checkpoint.id
def test_model_asset_type_filter_includes_output_backed_model_folder(
self, session: Session, tmp_path: Path
):
output_checkpoints_dir = tmp_path / "output" / "checkpoints"
output_checkpoints_dir.mkdir(parents=True)
asset = _make_asset(session, "hash1")
checkpoint_ref = _make_reference(session, asset, name="checkpoint")
checkpoint_ref.file_path = str(output_checkpoints_dir / "saved.safetensors")
session.commit()
with patch(
"app.assets.database.queries.common._get_comfy_model_folders",
return_value=[("checkpoints", [str(output_checkpoints_dir)])],
):
refs, _, total = list_references_page(session, asset_type="model")
assert total == 1
assert refs[0].id == checkpoint_ref.id
def test_asset_type_filter_uses_root_paths(self, session: Session, tmp_path: Path):
input_dir = tmp_path / "input"
output_dir = tmp_path / "output"
temp_dir = tmp_path / "temp"
for directory in (input_dir, output_dir, temp_dir):
directory.mkdir()
asset = _make_asset(session, "hash1")
input_ref = _make_reference(session, asset, name="input")
input_ref.file_path = str(input_dir / "image.png")
output_ref = _make_reference(session, asset, name="output")
output_ref.file_path = str(output_dir / "image.png")
session.commit()
with patch("app.assets.database.queries.common.folder_paths") as mock_fp:
mock_fp.get_input_directory.return_value = str(input_dir)
mock_fp.get_output_directory.return_value = str(output_dir)
mock_fp.get_temp_directory.return_value = str(temp_dir)
refs, _, total = list_references_page(session, asset_type="input")
assert total == 1
assert refs[0].id == input_ref.id
def test_output_asset_type_filter_excludes_output_backed_model_folders(
self, session: Session, tmp_path: Path
):
output_dir = tmp_path / "output"
output_checkpoints_dir = output_dir / "checkpoints"
output_checkpoints_dir.mkdir(parents=True)
asset = _make_asset(session, "hash1")
output_ref = _make_reference(session, asset, name="output")
output_ref.file_path = str(output_dir / "image.png")
checkpoint_ref = _make_reference(session, asset, name="checkpoint")
checkpoint_ref.file_path = str(output_checkpoints_dir / "saved.safetensors")
session.commit()
with patch("app.assets.database.queries.common.folder_paths") as mock_fp:
mock_fp.get_output_directory.return_value = str(output_dir)
with patch(
"app.assets.database.queries.common._get_comfy_model_folders",
return_value=[("checkpoints", [str(output_checkpoints_dir)])],
):
refs, _, total = list_references_page(session, asset_type="output")
assert total == 1
assert refs[0].id == output_ref.id
def test_sorting(self, session: Session): def test_sorting(self, session: Session):
asset = _make_asset(session, "hash1", size=100) asset = _make_asset(session, "hash1", size=100)
asset2 = _make_asset(session, "hash2", size=500) asset2 = _make_asset(session, "hash2", size=500)

View File

@ -255,6 +255,31 @@ class TestMarkReferencesMissingOutsidePrefixes:
assert marked == 0 assert marked == 0
def test_prefix_matching_is_case_exact(self, session: Session, tmp_path):
asset = _make_asset(session, "hash1")
valid_dir = tmp_path / "models" / "checkpoints"
case_sibling_dir = tmp_path / "MODELS" / "checkpoints"
valid_dir.mkdir(parents=True)
case_sibling_dir.mkdir(parents=True)
valid_path = str(valid_dir / "file.bin")
case_sibling_path = str(case_sibling_dir / "file.bin")
_make_reference(session, asset, valid_path, name="valid")
_make_reference(session, asset, case_sibling_path, name="case_sibling")
session.commit()
marked = mark_references_missing_outside_prefixes(session, [str(valid_dir)])
session.commit()
assert marked == 1
valid_ref = session.query(AssetReference).filter_by(file_path=valid_path).one()
case_sibling_ref = (
session.query(AssetReference).filter_by(file_path=case_sibling_path).one()
)
assert valid_ref.is_missing is False
assert case_sibling_ref.is_missing is True
class TestGetUnreferencedUnhashedAssetIds: class TestGetUnreferencedUnhashedAssetIds:
def test_returns_unreferenced_unhashed_assets(self, session: Session): def test_returns_unreferenced_unhashed_assets(self, session: Session):

View File

@ -64,6 +64,44 @@ class TestBatchInsertSeedAssets:
assert len(assets) == 1 assert len(assets) == 1
assert assets[0].mime_type is None assert assets[0].mime_type is None
def test_duplicate_paths_in_same_batch_preserve_first_spec(
self, session: Session, temp_dir: Path
):
file_path = temp_dir / "duplicate.safetensors"
file_path.write_bytes(b"fake safetensors content")
specs: list[SeedAssetSpec] = [
{
"abs_path": str(file_path),
"size_bytes": 24,
"mtime_ns": 1234567890000000000,
"info_name": "first",
"tags": ["models", "checkpoints"],
"fname": "duplicate.safetensors",
"metadata": None,
"hash": None,
"mime_type": "application/safetensors",
},
{
"abs_path": str(file_path),
"size_bytes": 24,
"mtime_ns": 1234567890000000000,
"info_name": "second",
"tags": ["output"],
"fname": "duplicate.safetensors",
"metadata": None,
"hash": None,
"mime_type": "application/safetensors",
},
]
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
assert result.inserted_refs == 1
refs = session.query(AssetReference).all()
assert len(refs) == 1
assert refs[0].name == "first"
def test_various_model_mime_types(self, session: Session, temp_dir: Path): def test_various_model_mime_types(self, session: Session, temp_dir: Path):
"""Verify various model file types get correct mime_type.""" """Verify various model file types get correct mime_type."""
test_cases = [ test_cases = [

View File

@ -6,7 +6,14 @@ from unittest.mock import patch
import pytest import pytest
from app.assets.services.path_utils import get_asset_category_and_relative_path from app.assets.services.path_utils import (
compute_relative_filename,
get_comfy_models_folders,
get_asset_category_and_relative_path,
get_asset_path_info,
get_asset_response_path_info,
resolve_asset_path_context,
)
@pytest.fixture @pytest.fixture
@ -79,3 +86,225 @@ class TestGetAssetCategoryAndRelativePath:
def test_unknown_path_raises(self, fake_dirs): def test_unknown_path_raises(self, fake_dirs):
with pytest.raises(ValueError, match="not within"): with pytest.raises(ValueError, match="not within"):
get_asset_category_and_relative_path("/some/random/path.png") get_asset_category_and_relative_path("/some/random/path.png")
class TestGetAssetPathInfo:
def test_get_comfy_models_folders_excludes_core_infrastructure(self, tmp_path: Path):
controlnet_dir = tmp_path / "models" / "controlnet"
configs_dir = tmp_path / "models" / "configs"
custom_nodes_dir = tmp_path / "custom_nodes"
for directory in (controlnet_dir, configs_dir, custom_nodes_dir):
directory.mkdir(parents=True)
with patch("app.assets.services.path_utils.folder_paths") as mock_fp:
mock_fp.folder_names_and_paths = {
"controlnet": ([str(controlnet_dir)], {".safetensors"}),
"configs": ([str(configs_dir)], {".yaml"}),
"custom_nodes": ([str(custom_nodes_dir)], set()),
}
folders = get_comfy_models_folders()
assert folders == [("controlnet", [str(controlnet_dir)])]
def test_model_file_uses_registered_model_folder(self, fake_dirs):
f = fake_dirs["models"] / "subdir" / "model.safetensors"
f.parent.mkdir()
f.touch()
info = get_asset_path_info(str(f))
assert info.asset_type == "model"
assert info.model_folder == "checkpoints"
response_info = get_asset_response_path_info(str(f))
assert response_info.file_path == "models/checkpoints/subdir/model.safetensors"
assert response_info.display_name == "subdir/model.safetensors"
def test_arbitrary_registered_folder_is_model_folder(self, fake_dirs):
controlnet_dir = fake_dirs["models"].parent / "controlnet"
controlnet_dir.mkdir()
f = controlnet_dir / "pose.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("controlnet", [str(controlnet_dir)])],
):
response_info = get_asset_response_path_info(str(f))
assert response_info.asset_type == "model"
assert response_info.model_folder == "controlnet"
assert response_info.file_path == "models/controlnet/pose.safetensors"
assert response_info.display_name == "pose.safetensors"
def test_multiple_physical_roots_for_same_model_folder(self, fake_dirs):
root_a = fake_dirs["models"]
root_b = fake_dirs["output"] / "checkpoints"
root_b.mkdir()
file_a = root_a / "subdir" / "model_a.safetensors"
file_b = root_b / "subdir" / "model_b.safetensors"
file_a.parent.mkdir()
file_b.parent.mkdir()
file_a.touch()
file_b.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("checkpoints", [str(root_a), str(root_b)])],
):
response_a = get_asset_response_path_info(str(file_a))
response_b = get_asset_response_path_info(str(file_b))
assert response_a.asset_type == response_b.asset_type == "model"
assert response_a.model_folder == response_b.model_folder == "checkpoints"
assert response_a.file_path == "models/checkpoints/subdir/model_a.safetensors"
assert response_b.file_path == "models/checkpoints/subdir/model_b.safetensors"
assert response_a.display_name == "subdir/model_a.safetensors"
assert response_b.display_name == "subdir/model_b.safetensors"
def test_same_named_files_under_multiple_roots_share_logical_file_path(self, fake_dirs):
root_a = fake_dirs["models"]
root_b = fake_dirs["output"] / "checkpoints"
root_b.mkdir()
file_a = root_a / "duplicate.safetensors"
file_b = root_b / "duplicate.safetensors"
file_a.touch()
file_b.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("checkpoints", [str(root_a), str(root_b)])],
):
response_a = get_asset_response_path_info(str(file_a))
response_b = get_asset_response_path_info(str(file_b))
assert response_a.file_path == response_b.file_path
assert response_a.file_path == "models/checkpoints/duplicate.safetensors"
assert response_a.display_name == response_b.display_name == "duplicate.safetensors"
def test_input_file_has_no_model_folder(self, fake_dirs):
f = fake_dirs["input"] / "subdir" / "photo.png"
f.parent.mkdir()
f.touch()
info = get_asset_path_info(str(f))
assert info.asset_type == "input"
assert info.model_folder is None
response_info = get_asset_response_path_info(str(f))
assert response_info.file_path == "input/subdir/photo.png"
assert response_info.display_name == "subdir/photo.png"
def test_output_backed_registered_model_folder_is_model(self, fake_dirs):
output_checkpoints_dir = fake_dirs["output"] / "checkpoints"
output_checkpoints_dir.mkdir()
f = output_checkpoints_dir / "saved.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("checkpoints", [str(output_checkpoints_dir)])],
):
context = resolve_asset_path_context(str(f))
response_info = get_asset_response_path_info(str(f))
assert context.asset_type == "model"
assert context.model_folder == "checkpoints"
assert context.relative_path == "saved.safetensors"
assert response_info.file_path == "models/checkpoints/saved.safetensors"
assert response_info.display_name == "saved.safetensors"
def test_registered_model_folder_can_contain_slash(self, fake_dirs):
nested_model_dir = fake_dirs["models"].parent / "text_encoders" / "clip"
nested_model_dir.mkdir(parents=True)
f = nested_model_dir / "clip.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("text_encoders/clip", [str(nested_model_dir)])],
):
info = get_asset_path_info(str(f))
response_info = get_asset_response_path_info(str(f))
assert info.asset_type == "model"
assert info.model_folder == "text_encoders/clip"
assert response_info.file_path == "models/text_encoders/clip/clip.safetensors"
assert response_info.display_name == "clip.safetensors"
def test_slash_model_folder_relative_filename_uses_registered_base(self, fake_dirs):
nested_model_dir = fake_dirs["models"].parent / "text_encoders" / "clip"
nested_model_dir.mkdir(parents=True)
f = nested_model_dir / "subdir" / "clip.safetensors"
f.parent.mkdir()
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[("text_encoders/clip", [str(nested_model_dir)])],
):
assert compute_relative_filename(str(f)) == "subdir/clip.safetensors"
def test_deepest_registered_model_base_wins(self, fake_dirs):
parent_dir = fake_dirs["models"].parent / "text_encoders"
nested_model_dir = parent_dir / "clip"
nested_model_dir.mkdir(parents=True)
f = nested_model_dir / "clip.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[
("text_encoders", [str(parent_dir)]),
("text_encoders/clip", [str(nested_model_dir)]),
],
):
context = resolve_asset_path_context(str(f))
assert context.asset_type == "model"
assert context.model_folder == "text_encoders/clip"
assert context.relative_path == "clip.safetensors"
def test_deepest_registered_model_base_wins_independent_of_registration_order(
self, fake_dirs
):
parent_dir = fake_dirs["models"].parent / "text_encoders"
nested_model_dir = parent_dir / "clip"
nested_model_dir.mkdir(parents=True)
f = nested_model_dir / "clip.safetensors"
f.touch()
with patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[
("text_encoders/clip", [str(nested_model_dir)]),
("text_encoders", [str(parent_dir)]),
],
):
context = resolve_asset_path_context(str(f))
assert context.asset_type == "model"
assert context.model_folder == "text_encoders/clip"
assert context.relative_path == "clip.safetensors"
def test_path_under_unregistered_models_folder_is_unknown(self, fake_dirs):
unregistered_dir = fake_dirs["models"].parent / "unregistered"
unregistered_dir.mkdir()
f = unregistered_dir / "model.safetensors"
f.touch()
with pytest.raises(ValueError, match="not within"):
resolve_asset_path_context(str(f))
def test_registered_model_folder_prefix_boundary(self, fake_dirs):
checkpoints_extra_dir = fake_dirs["models"].parent / "checkpoints_extra"
checkpoints_extra_dir.mkdir()
f = checkpoints_extra_dir / "model.safetensors"
f.touch()
with pytest.raises(ValueError, match="not within"):
resolve_asset_path_context(str(f))

View File

@ -23,10 +23,60 @@ from app.assets.database.queries.asset_reference import (
get_unenriched_references, get_unenriched_references,
restore_references_by_paths, restore_references_by_paths,
) )
from app.assets.scanner import sync_references_with_filesystem from app.assets.scanner import (
collect_paths_for_roots,
get_all_known_prefixes,
sync_references_with_filesystem,
)
from app.assets.services.file_utils import get_mtime_ns from app.assets.services.file_utils import get_mtime_ns
def test_collect_paths_for_roots_deduplicates_overlapping_roots(tmp_path: Path):
model_file = tmp_path / "output" / "checkpoints" / "saved.safetensors"
model_file.parent.mkdir(parents=True)
model_file.write_bytes(b"model")
with (
patch("app.assets.scanner.collect_models_files", return_value=[str(model_file)]),
patch(
"app.assets.scanner.list_files_recursively",
return_value=[str(model_file)],
),
patch("app.assets.scanner.folder_paths") as mock_folder_paths,
):
mock_folder_paths.get_output_directory.return_value = str(tmp_path / "output")
paths = collect_paths_for_roots(("models", "output"))
assert paths == [str(model_file)]
def test_all_known_prefixes_include_temp_root(tmp_path: Path):
models_dir = tmp_path / "models" / "checkpoints"
input_dir = tmp_path / "input"
output_dir = tmp_path / "output"
temp_dir = tmp_path / "temp"
for directory in (models_dir, input_dir, output_dir, temp_dir):
directory.mkdir(parents=True)
with (
patch("app.assets.scanner.get_comfy_models_folders", return_value=[("checkpoints", [str(models_dir)])]),
patch("app.assets.scanner.folder_paths") as mock_folder_paths,
):
mock_folder_paths.get_input_directory.return_value = str(input_dir)
mock_folder_paths.get_output_directory.return_value = str(output_dir)
mock_folder_paths.get_temp_directory.return_value = str(temp_dir)
prefixes = get_all_known_prefixes()
assert prefixes == [
str(models_dir),
str(input_dir),
str(output_dir),
str(temp_dir),
]
@pytest.fixture @pytest.fixture
def db_engine(): def db_engine():
engine = create_engine("sqlite:///:memory:") engine = create_engine("sqlite:///:memory:")