From bb9ed04758369aa4c452d788d64ea7fed6a0dc68 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Fri, 12 Sep 2025 18:14:52 +0300 Subject: [PATCH] global refactoring; add support for Assets without the computed hash --- alembic_db/versions/0001_assets.py | 50 +- app/__init__.py | 6 +- app/_assets_helpers.py | 75 +- app/api/assets_routes.py | 27 +- app/api/schemas_in.py | 59 +- app/api/schemas_out.py | 12 +- app/assets_manager.py | 188 ++--- app/assets_scanner.py | 707 ++++++++---------- app/database/_helpers.py | 186 ----- app/database/db.py | 10 +- app/database/helpers/__init__.py | 23 + app/database/helpers/filters.py | 87 +++ app/database/helpers/ownership.py | 12 + app/database/helpers/projection.py | 64 ++ app/database/helpers/tags.py | 102 +++ app/database/models.py | 96 +-- app/database/services.py | 1116 ---------------------------- app/database/services/__init__.py | 56 ++ app/database/services/content.py | 746 +++++++++++++++++++ app/database/services/info.py | 579 +++++++++++++++ app/database/services/queries.py | 59 ++ comfy/cli_args.py | 1 - main.py | 4 +- server.py | 3 +- tests-assets/test_crud.py | 19 +- tests-assets/test_tags.py | 21 +- tests-assets/test_uploads.py | 2 +- 27 files changed, 2380 insertions(+), 1930 deletions(-) delete mode 100644 app/database/_helpers.py create mode 100644 app/database/helpers/__init__.py create mode 100644 app/database/helpers/filters.py create mode 100644 app/database/helpers/ownership.py create mode 100644 app/database/helpers/projection.py create mode 100644 app/database/helpers/tags.py delete mode 100644 app/database/services.py create mode 100644 app/database/services/__init__.py create mode 100644 app/database/services/content.py create mode 100644 app/database/services/info.py create mode 100644 app/database/services/queries.py diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index bc98b5acf..1f5fb4622 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -16,33 +16,44 @@ depends_on = None def upgrade() -> None: - # ASSETS: content identity (deduplicated by hash) + # ASSETS: content identity op.create_table( "assets", - sa.Column("hash", sa.String(length=256), primary_key=True), + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("hash", sa.String(length=256), nullable=True), sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"), sa.Column("mime_type", sa.String(length=255), nullable=True), sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), ) + if op.get_bind().dialect.name == "postgresql": + op.create_index( + "uq_assets_hash_not_null", + "assets", + ["hash"], + unique=True, + postgresql_where=sa.text("hash IS NOT NULL"), + ) + else: + op.create_index("uq_assets_hash", "assets", ["hash"], unique=True) op.create_index("ix_assets_mime_type", "assets", ["mime_type"]) - # ASSETS_INFO: user-visible references (mutable metadata) + # ASSETS_INFO: user-visible references op.create_table( "assets_info", sa.Column("id", sa.String(length=36), primary_key=True), sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), sa.Column("name", sa.String(length=512), nullable=False), - sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False), - sa.Column("preview_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="SET NULL"), nullable=True), + 
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False), + sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True), sa.Column("user_metadata", sa.JSON(), nullable=True), sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), - sa.UniqueConstraint("asset_hash", "owner_id", "name", name="uq_assets_info_hash_owner_name"), + sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), ) op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) - op.create_index("ix_assets_info_asset_hash", "assets_info", ["asset_hash"]) + op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"]) op.create_index("ix_assets_info_name", "assets_info", ["name"]) op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"]) op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"]) @@ -69,18 +80,19 @@ def upgrade() -> None: op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"]) op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"]) - # ASSET_CACHE_STATE: N:1 local cache metadata rows per Asset + # ASSET_CACHE_STATE: N:1 local cache rows per Asset op.create_table( "asset_cache_state", sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), - sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False), sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")), sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), ) op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"]) - op.create_index("ix_asset_cache_state_asset_hash", "asset_cache_state", ["asset_hash"]) + op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"]) # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting op.create_table( @@ -99,7 +111,7 @@ def upgrade() -> None: op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"]) op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"]) - # Tags vocabulary for models + # Tags vocabulary tags_table = sa.table( "tags", sa.column("name", sa.String(length=512)), @@ -108,12 +120,10 @@ def upgrade() -> None: op.bulk_insert( tags_table, [ - # Root folder tags {"name": "models", "tag_type": "system"}, {"name": "input", "tag_type": "system"}, {"name": "output", "tag_type": "system"}, - # Core tags {"name": "configs", "tag_type": "system"}, {"name": "checkpoints", "tag_type": "system"}, {"name": "loras", "tag_type": "system"}, @@ -132,12 +142,11 @@ def upgrade() -> None: {"name": "photomaker", "tag_type": "system"}, {"name": "classifiers", "tag_type": "system"}, - # Extra basic tags {"name": "encoder", "tag_type": "system"}, {"name": "decoder", "tag_type": "system"}, - # Special tags {"name": "missing", 
"tag_type": "system"}, + {"name": "rescan", "tag_type": "system"}, ], ) @@ -149,8 +158,9 @@ def downgrade() -> None: op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta") op.drop_table("asset_info_meta") - op.drop_index("ix_asset_cache_state_asset_hash", table_name="asset_cache_state") + op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state") op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state") + op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state") op.drop_table("asset_cache_state") op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags") @@ -160,14 +170,18 @@ def downgrade() -> None: op.drop_index("ix_tags_tag_type", table_name="tags") op.drop_table("tags") - op.drop_constraint("uq_assets_info_hash_owner_name", table_name="assets_info") + op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info") op.drop_index("ix_assets_info_owner_name", table_name="assets_info") op.drop_index("ix_assets_info_last_access_time", table_name="assets_info") op.drop_index("ix_assets_info_created_at", table_name="assets_info") op.drop_index("ix_assets_info_name", table_name="assets_info") - op.drop_index("ix_assets_info_asset_hash", table_name="assets_info") + op.drop_index("ix_assets_info_asset_id", table_name="assets_info") op.drop_index("ix_assets_info_owner_id", table_name="assets_info") op.drop_table("assets_info") + if op.get_bind().dialect.name == "postgresql": + op.drop_index("uq_assets_hash_not_null", table_name="assets") + else: + op.drop_index("uq_assets_hash", table_name="assets") op.drop_index("ix_assets_mime_type", table_name="assets") op.drop_table("assets") diff --git a/app/__init__.py b/app/__init__.py index 5fade97a4..e8538bd29 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -1,5 +1,5 @@ +from .assets_scanner import sync_seed_assets from .database.db import init_db_engine -from .assets_scanner import start_background_assets_scan +from .api.assets_routes import register_assets_system - -__all__ = ["init_db_engine", "start_background_assets_scan"] +__all__ = ["init_db_engine", "sync_seed_assets", "register_assets_system"] diff --git a/app/_assets_helpers.py b/app/_assets_helpers.py index 8fb88cd34..e0b982c98 100644 --- a/app/_assets_helpers.py +++ b/app/_assets_helpers.py @@ -1,12 +1,13 @@ +import contextlib import os +import uuid +from datetime import datetime, timezone from pathlib import Path -from typing import Optional, Literal, Sequence - -import sqlalchemy as sa +from typing import Literal, Optional, Sequence import folder_paths -from .database.models import AssetInfo +from .api import schemas_in def get_comfy_models_folders() -> list[tuple[str, list[str]]]: @@ -139,14 +140,6 @@ def ensure_within_base(candidate: str, base: str) -> None: raise ValueError("invalid destination path") -def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: - """Build owner visibility predicate for reads.""" - owner_id = (owner_id or "").strip() - if owner_id == "": - return AssetInfo.owner_id == "" - return AssetInfo.owner_id.in_(["", owner_id]) - - def compute_model_relative_filename(file_path: str) -> Optional[str]: """ Return the model's path relative to the last well-known folder (the model category), @@ -172,3 +165,61 @@ def compute_model_relative_filename(file_path: str) -> Optional[str]: return None inside = parts[1:] if len(parts) > 1 else [parts[0]] return "/".join(inside) # normalize to POSIX style for portability + + +def list_tree(base_dir: 
str) -> list[str]: + out: list[str] = [] + base_abs = os.path.abspath(base_dir) + if not os.path.isdir(base_abs): + return out + for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): + for name in filenames: + out.append(os.path.abspath(os.path.join(dirpath, name))) + return out + + +def prefixes_for_root(root: schemas_in.RootType) -> list[str]: + if root == "models": + bases: list[str] = [] + for _bucket, paths in get_comfy_models_folders(): + bases.extend(paths) + return [os.path.abspath(p) for p in bases] + if root == "input": + return [os.path.abspath(folder_paths.get_input_directory())] + if root == "output": + return [os.path.abspath(folder_paths.get_output_directory())] + return [] + + +def ts_to_iso(ts: Optional[float]) -> Optional[str]: + if ts is None: + return None + try: + return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat() + except Exception: + return None + + +def new_scan_id(root: schemas_in.RootType) -> str: + return f"scan-{root}-{uuid.uuid4().hex[:8]}" + + +def collect_models_files() -> list[str]: + out: list[str] = [] + for folder_name, bases in get_comfy_models_folders(): + rel_files = folder_paths.get_filename_list(folder_name) or [] + for rel_path in rel_files: + abs_path = folder_paths.get_full_path(folder_name, rel_path) + if not abs_path: + continue + abs_path = os.path.abspath(abs_path) + allowed = False + for b in bases: + base_abs = os.path.abspath(b) + with contextlib.suppress(Exception): + if os.path.commonpath([abs_path, base_abs]) == base_abs: + allowed = True + break + if allowed: + out.append(abs_path) + return out diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index e9a4ff97a..384c9f6c0 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -1,7 +1,7 @@ import contextlib import os -import uuid import urllib.parse +import uuid from typing import Optional from aiohttp import web @@ -12,7 +12,6 @@ import folder_paths from .. import assets_manager, assets_scanner, user_manager from . 
import schemas_in, schemas_out - ROUTES = web.RouteTableDef() UserManager: Optional[user_manager.UserManager] = None @@ -272,6 +271,7 @@ async def upload_asset(request: web.Request) -> web.Response: temp_path=tmp_path, client_filename=file_client_name, owner_id=owner_id, + expected_asset_hash=spec.hash, ) status = 201 if created.created_new else 200 return web.json_response(created.model_dump(mode="json"), status=status) @@ -332,6 +332,29 @@ async def update_asset(request: web.Request) -> web.Response: return web.json_response(result.model_dump(mode="json"), status=200) +@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}/preview") +async def set_asset_preview(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + body = schemas_in.SetPreviewBody.model_validate(await request.json()) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = await assets_manager.set_asset_preview( + asset_info_id=asset_info_id, + preview_asset_id=body.preview_id, + owner_id=UserManager.get_request_user_id(request), + ) + except (PermissionError, ValueError) as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + return _error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(result.model_dump(mode="json"), status=200) + + @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") async def delete_asset(request: web.Request) -> web.Response: asset_info_id = str(uuid.UUID(request.match_info["id"])) diff --git a/app/api/schemas_in.py b/app/api/schemas_in.py index 412b72e3a..bc521b313 100644 --- a/app/api/schemas_in.py +++ b/app/api/schemas_in.py @@ -1,7 +1,15 @@ import json +import uuid +from typing import Any, Literal, Optional -from typing import Any, Optional, Literal -from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator, conint +from pydantic import ( + BaseModel, + ConfigDict, + Field, + conint, + field_validator, + model_validator, +) class ListAssetsQuery(BaseModel): @@ -148,30 +156,12 @@ class TagsRemove(TagsAdd): pass -class ScheduleAssetScanBody(BaseModel): - roots: list[Literal["models","input","output"]] = Field(default_factory=list) +RootType = Literal["models", "input", "output"] +ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output") - @field_validator("roots", mode="before") - @classmethod - def _normalize_roots(cls, v): - if v is None: - return [] - if isinstance(v, str): - items = [x.strip().lower() for x in v.split(",")] - elif isinstance(v, list): - items = [] - for x in v: - if isinstance(x, str): - items.extend([p.strip().lower() for p in x.split(",")]) - else: - return [] - out = [] - seen = set() - for r in items: - if r in {"models","input","output"} and r not in seen: - out.append(r) - seen.add(r) - return out + +class ScheduleAssetScanBody(BaseModel): + roots: list[RootType] = Field(..., min_length=1) class UploadAssetSpec(BaseModel): @@ -281,3 +271,22 @@ class UploadAssetSpec(BaseModel): if len(self.tags) < 2: raise ValueError("models uploads require a category tag as the second tag") return self + + +class SetPreviewBody(BaseModel): + """Set or clear the preview for an AssetInfo. 
Provide an Asset.id or null.""" + preview_id: Optional[str] = None + + @field_validator("preview_id", mode="before") + @classmethod + def _norm_uuid(cls, v): + if v is None: + return None + s = str(v).strip() + if not s: + return None + try: + uuid.UUID(s) + except Exception: + raise ValueError("preview_id must be a UUID") + return s diff --git a/app/api/schemas_out.py b/app/api/schemas_out.py index 8bb34096b..cc7e9572b 100644 --- a/app/api/schemas_out.py +++ b/app/api/schemas_out.py @@ -1,12 +1,13 @@ from datetime import datetime from typing import Any, Literal, Optional + from pydantic import BaseModel, ConfigDict, Field, field_serializer class AssetSummary(BaseModel): id: str name: str - asset_hash: str + asset_hash: Optional[str] size: Optional[int] = None mime_type: Optional[str] = None tags: list[str] = Field(default_factory=list) @@ -31,7 +32,7 @@ class AssetsList(BaseModel): class AssetUpdated(BaseModel): id: str name: str - asset_hash: str + asset_hash: Optional[str] tags: list[str] = Field(default_factory=list) user_metadata: dict[str, Any] = Field(default_factory=dict) updated_at: Optional[datetime] = None @@ -46,12 +47,12 @@ class AssetUpdated(BaseModel): class AssetDetail(BaseModel): id: str name: str - asset_hash: str + asset_hash: Optional[str] size: Optional[int] = None mime_type: Optional[str] = None tags: list[str] = Field(default_factory=list) user_metadata: dict[str, Any] = Field(default_factory=dict) - preview_hash: Optional[str] = None + preview_id: Optional[str] = None created_at: Optional[datetime] = None last_access_time: Optional[datetime] = None @@ -95,7 +96,6 @@ class TagsRemove(BaseModel): class AssetScanError(BaseModel): path: str message: str - phase: Literal["fast", "slow"] at: Optional[str] = Field(None, description="ISO timestamp") @@ -108,8 +108,6 @@ class AssetScanStatus(BaseModel): finished_at: Optional[str] = None discovered: int = 0 processed: int = 0 - slow_queue_total: int = 0 - slow_queue_finished: int = 0 file_errors: list[AssetScanError] = Field(default_factory=list) diff --git a/app/assets_manager.py b/app/assets_manager.py index 423a860e0..9d2424ce6 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -4,38 +4,39 @@ import mimetypes import os from typing import Optional, Sequence -from comfy.cli_args import args from comfy_api.internal import async_to_sync -from .database.db import create_session -from .storage import hashing -from .database.services import ( - check_fs_asset_exists_quick, - ingest_fs_asset, - touch_asset_infos_by_fs_path, - list_asset_infos_page, - update_asset_info_full, - get_asset_tags, - list_tags_with_usage, - add_tags_to_asset_info, - remove_tags_from_asset_info, - fetch_asset_info_and_asset, - touch_asset_info_by_id, - delete_asset_info_by_id, - asset_exists_by_hash, - get_asset_by_hash, - create_asset_info_for_existing_asset, - fetch_asset_info_asset_and_tags, - get_asset_info_by_id, - list_cache_states_by_asset_hash, - asset_info_exists_for_hash, -) -from .api import schemas_in, schemas_out from ._assets_helpers import ( - get_name_and_tags_from_asset_path, ensure_within_base, + get_name_and_tags_from_asset_path, resolve_destination_from_tags, ) +from .api import schemas_in, schemas_out +from .database.db import create_session +from .database.models import Asset +from .database.services import ( + add_tags_to_asset_info, + asset_exists_by_hash, + asset_info_exists_for_asset_id, + check_fs_asset_exists_quick, + create_asset_info_for_existing_asset, + delete_asset_info_by_id, + fetch_asset_info_and_asset, + 
fetch_asset_info_asset_and_tags, + get_asset_by_hash, + get_asset_info_by_id, + get_asset_tags, + ingest_fs_asset, + list_asset_infos_page, + list_cache_states_by_asset_id, + list_tags_with_usage, + remove_tags_from_asset_info, + set_asset_info_preview, + touch_asset_info_by_id, + touch_asset_infos_by_fs_path, + update_asset_info_full, +) +from .storage import hashing async def asset_exists(*, asset_hash: str) -> bool: @@ -44,29 +45,21 @@ async def asset_exists(*, asset_hash: str) -> bool: def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None: - if not args.enable_model_processing: - if tags is None: - tags = [] - try: - asset_name, path_tags = get_name_and_tags_from_asset_path(file_path) - async_to_sync.AsyncToSyncConverter.run_async_in_thread( - add_local_asset, - tags=list(dict.fromkeys([*path_tags, *tags])), - file_name=asset_name, - file_path=file_path, - ) - except ValueError as e: - logging.warning("Skipping non-asset path %s: %s", file_path, e) + if tags is None: + tags = [] + try: + asset_name, path_tags = get_name_and_tags_from_asset_path(file_path) + async_to_sync.AsyncToSyncConverter.run_async_in_thread( + add_local_asset, + tags=list(dict.fromkeys([*path_tags, *tags])), + file_name=asset_name, + file_path=file_path, + ) + except ValueError as e: + logging.warning("Skipping non-asset path %s: %s", file_path, e) async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None: - """Adds a local asset to the DB. If already present and unchanged, does nothing. - - Notes: - - Uses absolute path as the canonical locator for the cache backend. - - Computes BLAKE3 only when the fast existence check indicates it's needed. - - This function ensures the identity row and seeds mtime in asset_cache_state. - """ abs_path = os.path.abspath(file_path) size_bytes, mtime_ns = _get_size_mtime_ns(abs_path) if not size_bytes: @@ -132,7 +125,7 @@ async def list_assets( schemas_out.AssetSummary( id=info.id, name=info.name, - asset_hash=info.asset_hash, + asset_hash=asset.hash if asset else None, size=int(asset.size_bytes) if asset else None, mime_type=asset.mime_type if asset else None, tags=tags, @@ -156,16 +149,17 @@ async def get_asset(*, asset_info_id: str, owner_id: str = "") -> schemas_out.As if not res: raise ValueError(f"AssetInfo {asset_info_id} not found") info, asset, tag_names = res + preview_id = info.preview_id return schemas_out.AssetDetail( id=info.id, name=info.name, - asset_hash=info.asset_hash, + asset_hash=asset.hash if asset else None, size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, mime_type=asset.mime_type if asset else None, tags=tag_names, - preview_hash=info.preview_hash, user_metadata=info.user_metadata or {}, + preview_id=preview_id, created_at=info.created_at, last_access_time=info.last_access_time, ) @@ -176,20 +170,13 @@ async def resolve_asset_content_for_download( asset_info_id: str, owner_id: str = "", ) -> tuple[str, str, str]: - """ - Returns (abs_path, content_type, download_name) for the given AssetInfo id and touches last_access_time. - Also touches last_access_time (only_if_newer). 
- Raises: - ValueError if AssetInfo cannot be found - FileNotFoundError if file for Asset cannot be found - """ async with await create_session() as session: pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id) if not pair: raise ValueError(f"AssetInfo {asset_info_id} not found") info, asset = pair - states = await list_cache_states_by_asset_hash(session, asset_hash=info.asset_hash) + states = await list_cache_states_by_asset_id(session, asset_id=asset.id) abs_path = "" for s in states: if s and s.file_path and os.path.isfile(s.file_path): @@ -214,16 +201,6 @@ async def upload_asset_from_temp_path( owner_id: str = "", expected_asset_hash: Optional[str] = None, ) -> schemas_out.AssetCreated: - """ - Finalize an uploaded temp file: - - compute blake3 hash - - if expected_asset_hash provided, verify equality (400 on mismatch at caller) - - if an Asset with the same hash exists: discard temp, create AssetInfo only (no write) - - else resolve destination from tags and atomically move into place - - ingest into DB (assets, locator state, asset_info + tags) - Returns a populated AssetCreated payload. - """ - try: digest = await hashing.blake3_hash(temp_path) except Exception as e: @@ -233,7 +210,6 @@ async def upload_asset_from_temp_path( if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower(): raise ValueError("HASH_MISMATCH") - # Fast path: content already known --> no writes, just create a reference async with await create_session() as session: existing = await get_asset_by_hash(session, asset_hash=asset_hash) if existing is not None: @@ -257,43 +233,37 @@ async def upload_asset_from_temp_path( return schemas_out.AssetCreated( id=info.id, name=info.name, - asset_hash=info.asset_hash, + asset_hash=existing.hash, size=int(existing.size_bytes) if existing.size_bytes is not None else None, mime_type=existing.mime_type, tags=tag_names, user_metadata=info.user_metadata or {}, - preview_hash=info.preview_hash, + preview_id=info.preview_id, created_at=info.created_at, last_access_time=info.last_access_time, created_new=False, ) - # Resolve destination (only for truly new content) base_dir, subdirs = resolve_destination_from_tags(spec.tags) dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir os.makedirs(dest_dir, exist_ok=True) - # Decide filename desired_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest) dest_abs = os.path.abspath(os.path.join(dest_dir, desired_name)) ensure_within_base(dest_abs, base_dir) - # Content type based on final name content_type = mimetypes.guess_type(desired_name, strict=False)[0] or "application/octet-stream" - # Atomic move into place try: os.replace(temp_path, dest_abs) except Exception as e: raise RuntimeError(f"failed to move uploaded file into place: {e}") - # Stat final file try: size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs) except OSError as e: raise RuntimeError(f"failed to stat destination file: {e}") - # Ingest + build response async with await create_session() as session: result = await ingest_fs_asset( session, @@ -304,7 +274,7 @@ async def upload_asset_from_temp_path( mime_type=content_type, info_name=os.path.basename(dest_abs), owner_id=owner_id, - preview_hash=None, + preview_id=None, user_metadata=spec.user_metadata or {}, tags=spec.tags, tag_origin="manual", @@ -324,12 +294,12 @@ async def upload_asset_from_temp_path( return schemas_out.AssetCreated( id=info.id, name=info.name, - asset_hash=info.asset_hash, + asset_hash=asset.hash, 
size=int(asset.size_bytes), mime_type=asset.mime_type, tags=tag_names, user_metadata=info.user_metadata or {}, - preview_hash=info.preview_hash, + preview_id=info.preview_id, created_at=info.created_at, last_access_time=info.last_access_time, created_new=result["asset_created"], @@ -367,38 +337,74 @@ async def update_asset( return schemas_out.AssetUpdated( id=info.id, name=info.name, - asset_hash=info.asset_hash, + asset_hash=info.asset.hash if info.asset else None, tags=tag_names, user_metadata=info.user_metadata or {}, updated_at=info.updated_at, ) -async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool: - """Delete single AssetInfo. If this was the last reference to Asset and delete_content_if_orphan=True (default), - delete the Asset row as well and remove all cached files recorded for that asset_hash. - """ +async def set_asset_preview( + *, + asset_info_id: str, + preview_asset_id: Optional[str], + owner_id: str = "", +) -> schemas_out.AssetDetail: async with await create_session() as session: info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id) - asset_hash = info_row.asset_hash if info_row else None + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + + await set_asset_info_preview( + session, + asset_info_id=asset_info_id, + preview_asset_id=preview_asset_id, + ) + + res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) + if not res: + raise RuntimeError("State changed during preview update") + info, asset, tags = res + await session.commit() + + return schemas_out.AssetDetail( + id=info.id, + name=info.name, + asset_hash=asset.hash if asset else None, + size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, + mime_type=asset.mime_type if asset else None, + tags=tags, + user_metadata=info.user_metadata or {}, + preview_id=info.preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + ) + + +async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool: + async with await create_session() as session: + info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id) + asset_id = info_row.asset_id if info_row else None deleted = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id) if not deleted: await session.commit() return False - if not delete_content_if_orphan or not asset_hash: + if not delete_content_if_orphan or not asset_id: await session.commit() return True - still_exists = await asset_info_exists_for_hash(session, asset_hash=asset_hash) + still_exists = await asset_info_exists_for_asset_id(session, asset_id=asset_id) if still_exists: await session.commit() return True - states = await list_cache_states_by_asset_hash(session, asset_hash=asset_hash) + states = await list_cache_states_by_asset_id(session, asset_id=asset_id) file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)] - asset_row = await get_asset_by_hash(session, asset_hash=asset_hash) + asset_row = await session.get(Asset, asset_id) if asset_row is not None: await session.delete(asset_row) @@ -439,12 +445,12 @@ async def create_asset_from_hash( return schemas_out.AssetCreated( id=info.id, name=info.name, - asset_hash=info.asset_hash, + asset_hash=asset.hash, 
size=int(asset.size_bytes), mime_type=asset.mime_type, tags=tag_names, user_metadata=info.user_metadata or {}, - preview_hash=info.preview_hash, + preview_id=info.preview_id, created_at=info.created_at, last_access_time=info.last_access_time, created_new=False, diff --git a/app/assets_scanner.py b/app/assets_scanner.py index a77f87771..6cca5b165 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -1,52 +1,55 @@ import asyncio -import contextlib import logging import os -import uuid import time from dataclasses import dataclass, field -from datetime import datetime, timezone -from typing import Callable, Literal, Optional, Sequence +from typing import Literal, Optional + +import sqlalchemy as sa import folder_paths -from . import assets_manager -from .api import schemas_out -from ._assets_helpers import get_comfy_models_folders +from ._assets_helpers import ( + collect_models_files, + get_comfy_models_folders, + get_name_and_tags_from_asset_path, + list_tree, + new_scan_id, + prefixes_for_root, + ts_to_iso, +) +from .api import schemas_in, schemas_out from .database.db import create_session +from .database.helpers import ( + add_missing_tag_for_asset_id, + remove_missing_tag_for_asset_id, +) +from .database.models import Asset, AssetCacheState, AssetInfo from .database.services import ( - check_fs_asset_exists_quick, + compute_hash_and_dedup_for_cache_state, + ensure_seed_for_path, + list_cache_states_by_asset_id, list_cache_states_with_asset_under_prefixes, - add_missing_tag_for_asset_hash, - remove_missing_tag_for_asset_hash, + list_unhashed_candidates_under_prefixes, + list_verify_candidates_under_prefixes, ) LOGGER = logging.getLogger(__name__) -RootType = Literal["models", "input", "output"] -ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output") - SLOW_HASH_CONCURRENCY = 1 @dataclass class ScanProgress: scan_id: str - root: RootType + root: schemas_in.RootType status: Literal["scheduled", "running", "completed", "failed", "cancelled"] = "scheduled" scheduled_at: float = field(default_factory=lambda: time.time()) started_at: Optional[float] = None finished_at: Optional[float] = None - discovered: int = 0 processed: int = 0 - slow_queue_total: int = 0 - slow_queue_finished: int = 0 - file_errors: list[dict] = field(default_factory=list) # {"path","message","phase","at"} - - # Internal diagnostics for logs - _fast_total_seen: int = 0 - _fast_clean: int = 0 + file_errors: list[dict] = field(default_factory=list) @dataclass @@ -56,18 +59,14 @@ class SlowQueueState: closed: bool = False -RUNNING_TASKS: dict[RootType, asyncio.Task] = {} -PROGRESS_BY_ROOT: dict[RootType, ScanProgress] = {} -SLOW_STATE_BY_ROOT: dict[RootType, SlowQueueState] = {} - - -async def start_background_assets_scan(): - await fast_reconcile_and_kickoff(progress_cb=_console_cb) +RUNNING_TASKS: dict[schemas_in.RootType, asyncio.Task] = {} +PROGRESS_BY_ROOT: dict[schemas_in.RootType, ScanProgress] = {} +SLOW_STATE_BY_ROOT: dict[schemas_in.RootType, SlowQueueState] = {} def current_statuses() -> schemas_out.AssetScanStatusResponse: scans = [] - for root in ALLOWED_ROOTS: + for root in schemas_in.ALLOWED_ROOTS: prog = PROGRESS_BY_ROOT.get(root) if not prog: continue @@ -75,83 +74,65 @@ def current_statuses() -> schemas_out.AssetScanStatusResponse: return schemas_out.AssetScanStatusResponse(scans=scans) -async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusResponse: - """Schedule scans for the provided roots; returns progress snapshots. 
- - Rules: - - Only roots in {models, input, output} are accepted. - - If a root is already scanning, we do NOT enqueue another one. Status returned as-is. - - Otherwise a new task is created and started immediately. - - Files with zero size are skipped. - """ - normalized: list[RootType] = [] - seen = set() - for r in roots or []: - rr = r.strip().lower() - if rr in ALLOWED_ROOTS and rr not in seen: - normalized.append(rr) # type: ignore - seen.add(rr) - if not normalized: - normalized = list(ALLOWED_ROOTS) # schedule all by default - +async def schedule_scans(roots: list[schemas_in.RootType]) -> schemas_out.AssetScanStatusResponse: results: list[ScanProgress] = [] - for root in normalized: + for root in roots: if root in RUNNING_TASKS and not RUNNING_TASKS[root].done(): results.append(PROGRESS_BY_ROOT[root]) continue - prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled") + prog = ScanProgress(scan_id=new_scan_id(root), root=root, status="scheduled") PROGRESS_BY_ROOT[root] = prog - SLOW_STATE_BY_ROOT[root] = SlowQueueState(queue=asyncio.Queue()) + state = SlowQueueState(queue=asyncio.Queue()) + SLOW_STATE_BY_ROOT[root] = state RUNNING_TASKS[root] = asyncio.create_task( - _pipeline_for_root(root, prog, progress_cb=None), + _run_hash_verify_pipeline(root, prog, state), name=f"asset-scan:{root}", ) results.append(prog) return _status_response_for(results) -async def fast_reconcile_and_kickoff( - roots: Optional[Sequence[str]] = None, - *, - progress_cb: Optional[Callable[[str, str, int, bool, dict], None]] = None, -) -> schemas_out.AssetScanStatusResponse: - """ - Startup helper: do the fast pass now (so we know queue size), - start slow hashing in the background, return immediately. - """ - normalized = [*ALLOWED_ROOTS] if not roots else [r for r in roots if r in ALLOWED_ROOTS] - snaps: list[ScanProgress] = [] - - for root in normalized: - if root in RUNNING_TASKS and not RUNNING_TASKS[root].done(): - snaps.append(PROGRESS_BY_ROOT[root]) - continue - - prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled") - PROGRESS_BY_ROOT[root] = prog - state = SlowQueueState(queue=asyncio.Queue()) - SLOW_STATE_BY_ROOT[root] = state - - prog.status = "running" - prog.started_at = time.time() +async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None: + for r in roots: try: - await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb) - except Exception as e: - _append_error(prog, phase="fast", path="", message=str(e)) - prog.status = "failed" - prog.finished_at = time.time() - LOGGER.exception("Fast reconcile failed for %s", root) - snaps.append(prog) + await _fast_db_consistency_pass(r) + except Exception as ex: + LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex) + + paths: list[str] = [] + if "models" in roots: + paths.extend(collect_models_files()) + if "input" in roots: + paths.extend(list_tree(folder_paths.get_input_directory())) + if "output" in roots: + paths.extend(list_tree(folder_paths.get_output_directory())) + + for p in paths: + try: + st = os.stat(p, follow_symlinks=True) + if not int(st.st_size or 0): + continue + size_bytes = int(st.st_size) + mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + name, tags = get_name_and_tags_from_asset_path(p) + await _seed_one_async(p, size_bytes, mtime_ns, name, tags) + except OSError: continue - _start_slow_workers(root, prog, state, progress_cb=progress_cb) - RUNNING_TASKS[root] = asyncio.create_task( - _await_workers_then_finish(root, 
prog, state, progress_cb=progress_cb), - name=f"asset-hash:{root}", + +async def _seed_one_async(p: str, size_bytes: int, mtime_ns: int, name: str, tags: list[str]) -> None: + async with await create_session() as sess: + await ensure_seed_for_path( + sess, + abs_path=p, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + info_name=name, + tags=tags, + owner_id="", ) - snaps.append(prog) - return _status_response_for(snaps) + await sess.commit() def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse: @@ -163,18 +144,15 @@ def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.A scan_id=progress.scan_id, root=progress.root, status=progress.status, - scheduled_at=_ts_to_iso(progress.scheduled_at), - started_at=_ts_to_iso(progress.started_at), - finished_at=_ts_to_iso(progress.finished_at), + scheduled_at=ts_to_iso(progress.scheduled_at), + started_at=ts_to_iso(progress.started_at), + finished_at=ts_to_iso(progress.finished_at), discovered=progress.discovered, processed=progress.processed, - slow_queue_total=progress.slow_queue_total, - slow_queue_finished=progress.slow_queue_finished, file_errors=[ schemas_out.AssetScanError( path=e.get("path", ""), message=e.get("message", ""), - phase=e.get("phase", "slow"), at=e.get("at"), ) for e in (progress.file_errors or []) @@ -182,27 +160,100 @@ def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.A ) -async def _pipeline_for_root( - root: RootType, - prog: ScanProgress, - progress_cb: Optional[Callable[[str, str, int, bool, dict], None]], -) -> None: - state = SLOW_STATE_BY_ROOT.get(root) or SlowQueueState(queue=asyncio.Queue()) - SLOW_STATE_BY_ROOT[root] = state +async def _refresh_verify_flags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None: + """Fast pass to mark verify candidates by comparing stored mtime_ns with on-disk mtime.""" + prefixes = prefixes_for_root(root) + if not prefixes: + return + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + conds.append(AssetCacheState.file_path.like(base + "%")) + + async with await create_session() as sess: + rows = ( + await sess.execute( + sa.select( + AssetCacheState.id, + AssetCacheState.mtime_ns, + AssetCacheState.needs_verify, + Asset.hash, + AssetCacheState.file_path, + ) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(sa.or_(*conds)) + ) + ).all() + + to_set = [] + to_clear = [] + for sid, mtime_db, needs_verify, a_hash, fp in rows: + try: + st = os.stat(fp, follow_symlinks=True) + except OSError: + # Missing files are handled by missing-tag reconciliation later. 
+ continue + + actual_mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + if a_hash is not None: + if mtime_db is None or int(mtime_db) != int(actual_mtime_ns): + if not needs_verify: + to_set.append(sid) + else: + if needs_verify: + to_clear.append(sid) + + if to_set: + await sess.execute( + sa.update(AssetCacheState) + .where(AssetCacheState.id.in_(to_set)) + .values(needs_verify=True) + ) + if to_clear: + await sess.execute( + sa.update(AssetCacheState) + .where(AssetCacheState.id.in_(to_clear)) + .values(needs_verify=False) + ) + await sess.commit() + + +async def _run_hash_verify_pipeline(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None: prog.status = "running" prog.started_at = time.time() - try: - await _reconcile_missing_tags_for_root(root, prog) - await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb) - _start_slow_workers(root, prog, state, progress_cb=progress_cb) - await _await_workers_then_finish(root, prog, state, progress_cb=progress_cb) + prefixes = prefixes_for_root(root) + + await _refresh_verify_flags_for_root(root, prog) + + # collect candidates from DB + async with await create_session() as sess: + verify_ids = await list_verify_candidates_under_prefixes(sess, prefixes=prefixes) + unhashed_ids = await list_unhashed_candidates_under_prefixes(sess, prefixes=prefixes) + # dedupe: prioritize verification first + seen = set() + ordered: list[int] = [] + for lst in (verify_ids, unhashed_ids): + for sid in lst: + if sid not in seen: + seen.add(sid); ordered.append(sid) + + prog.discovered = len(ordered) + + # queue up work + for sid in ordered: + await state.queue.put(sid) + state.closed = True + _start_state_workers(root, prog, state) + await _await_state_workers_then_finish(root, prog, state) except asyncio.CancelledError: prog.status = "cancelled" raise except Exception as exc: - _append_error(prog, phase="slow", path="", message=str(exc)) + _append_error(prog, path="", message=str(exc)) prog.status = "failed" prog.finished_at = time.time() LOGGER.exception("Asset scan failed for %s", root) @@ -210,110 +261,13 @@ async def _pipeline_for_root( RUNNING_TASKS.pop(root, None) -async def _fast_reconcile_into_queue( - root: RootType, - prog: ScanProgress, - state: SlowQueueState, - *, - progress_cb: Optional[Callable[[str, str, int, bool, dict], None]], -) -> None: +async def _reconcile_missing_tags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None: """ - Enumerate files, set 'discovered' to total files seen, increment 'processed' for fast-matched files, - and queue the rest for slow hashing. - """ - if root == "models": - files = _collect_models_files() - preset_discovered = _count_nonzero_in_list(files) - files_iter = asyncio.Queue() - for p in files: - await files_iter.put(p) - await files_iter.put(None) # sentinel for our local draining loop - elif root == "input": - base = folder_paths.get_input_directory() - preset_discovered = _count_files_in_tree(os.path.abspath(base), only_nonzero=True) - files_iter = await _queue_tree_files(base) - elif root == "output": - base = folder_paths.get_output_directory() - preset_discovered = _count_files_in_tree(os.path.abspath(base), only_nonzero=True) - files_iter = await _queue_tree_files(base) - else: - raise RuntimeError(f"Unsupported root: {root}") + Detect missing files quickly and toggle 'missing' tag per asset_id. 
- prog.discovered = int(preset_discovered or 0) - - queued = 0 - checked = 0 - clean = 0 - - async with await create_session() as sess: - while True: - item = await files_iter.get() - files_iter.task_done() - if item is None: - break - - abs_path = item - checked += 1 - - # Stat; skip empty/unreadable - try: - st = os.stat(abs_path, follow_symlinks=True) - if not st.st_size: - continue - size_bytes = int(st.st_size) - mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) - except OSError as e: - _append_error(prog, phase="fast", path=abs_path, message=str(e)) - continue - - try: - known = await check_fs_asset_exists_quick( - sess, - file_path=abs_path, - size_bytes=size_bytes, - mtime_ns=mtime_ns, - ) - except Exception as e: - _append_error(prog, phase="fast", path=abs_path, message=str(e)) - known = False - - if known: - clean += 1 - prog.processed += 1 - else: - await state.queue.put(abs_path) - queued += 1 - prog.slow_queue_total += 1 - - if progress_cb: - progress_cb(root, "fast", prog.processed, False, { - "checked": checked, - "clean": clean, - "queued": queued, - "discovered": prog.discovered, - }) - - prog._fast_total_seen = checked - prog._fast_clean = clean - - if progress_cb: - progress_cb(root, "fast", prog.processed, True, { - "checked": checked, - "clean": clean, - "queued": queued, - "discovered": prog.discovered, - }) - state.closed = True - - -async def _reconcile_missing_tags_for_root(root: RootType, prog: ScanProgress) -> None: - """ - Logic for detecting missing Assets files: - - Clear 'missing' only if at least one cached path passes fast check: - exists AND mtime_ns matches AND size matches. - - Otherwise set 'missing'. - Files that exist but fail fast check will be slow-hashed by the normal pipeline, - and ingest_fs_asset will clear 'missing' if they truly match. + Rules: + - Only hashed assets (assets.hash != NULL) participate in missing tagging. + - We consider ALL cache states of the asset (across roots) before tagging. 
""" if root == "models": bases: list[str] = [] @@ -326,232 +280,217 @@ async def _reconcile_missing_tags_for_root(root: RootType, prog: ScanProgress) - try: async with await create_session() as sess: + # state + hash + size for the current root rows = await list_cache_states_with_asset_under_prefixes(sess, prefixes=bases) - by_hash: dict[str, dict[str, bool]] = {} # {hash: {"any_fast_ok": bool}} - for state, size_db in rows: - h = state.asset_hash - acc = by_hash.get(h) + # Track fast_ok within the scanned root and whether the asset is hashed + by_asset: dict[str, dict[str, bool]] = {} + for state, a_hash, size_db in rows: + aid = state.asset_id + acc = by_asset.get(aid) if acc is None: - acc = {"any_fast_ok": False} - by_hash[h] = acc + acc = {"any_fast_ok_here": False, "hashed": (a_hash is not None), "size_db": int(size_db or 0)} + by_asset[aid] = acc try: st = os.stat(state.file_path, follow_symlinks=True) actual_mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) fast_ok = False - if state.mtime_ns is not None and int(state.mtime_ns) == int(actual_mtime_ns): - if int(size_db) > 0 and int(st.st_size) == int(size_db): - fast_ok = True + if acc["hashed"]: + if state.mtime_ns is not None and int(state.mtime_ns) == int(actual_mtime_ns): + if int(acc["size_db"]) > 0 and int(st.st_size) == int(acc["size_db"]): + fast_ok = True if fast_ok: - acc["any_fast_ok"] = True + acc["any_fast_ok_here"] = True except FileNotFoundError: - pass # not fast_ok + pass except OSError as e: - _append_error(prog, phase="fast", path=state.file_path, message=str(e)) + _append_error(prog, path=state.file_path, message=str(e)) - for h, acc in by_hash.items(): + # Decide per asset, considering ALL its states (not just this root) + for aid, acc in by_asset.items(): try: - if acc["any_fast_ok"]: - await remove_missing_tag_for_asset_hash(sess, asset_hash=h) + if not acc["hashed"]: + # Never tag seed assets as missing + continue + + any_fast_ok_global = acc["any_fast_ok_here"] + if not any_fast_ok_global: + # Check other states outside this root + others = await list_cache_states_by_asset_id(sess, asset_id=aid) + for st in others: + try: + s = os.stat(st.file_path, follow_symlinks=True) + actual_mtime_ns = getattr(s, "st_mtime_ns", int(s.st_mtime * 1_000_000_000)) + if st.mtime_ns is not None and int(st.mtime_ns) == int(actual_mtime_ns): + if acc["size_db"] > 0 and int(s.st_size) == acc["size_db"]: + any_fast_ok_global = True + break + except OSError: + continue + + if any_fast_ok_global: + await remove_missing_tag_for_asset_id(sess, asset_id=aid) else: - await add_missing_tag_for_asset_hash(sess, asset_hash=h, origin="automatic") + await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") except Exception as ex: - _append_error(prog, phase="fast", path="", message=f"reconcile {h[:18]}: {ex}") + _append_error(prog, path="", message=f"reconcile {aid[:8]}: {ex}") await sess.commit() except Exception as e: - _append_error(prog, phase="fast", path="", message=f"reconcile failed: {e}") + _append_error(prog, path="", message=f"reconcile failed: {e}") -def _start_slow_workers( - root: RootType, - prog: ScanProgress, - state: SlowQueueState, - *, - progress_cb: Optional[Callable[[str, str, int, bool, dict], None]], -) -> None: +def _start_state_workers(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None: if state.workers: return - async def _worker(_worker_id: int): + async def _worker(_wid: int): while True: - item = await state.queue.get() + sid = await 
state.queue.get() try: - if item is None: + if sid is None: return try: - await asyncio.to_thread(assets_manager.populate_db_with_asset, item) - except Exception as e: - _append_error(prog, phase="slow", path=item, message=str(e)) + async with await create_session() as sess: + # Optional: fetch path for better error messages + st = await sess.get(AssetCacheState, sid) + try: + await compute_hash_and_dedup_for_cache_state(sess, state_id=sid) + await sess.commit() + except Exception as e: + path = st.file_path if st else f"state:{sid}" + _append_error(prog, path=path, message=str(e)) + raise + except Exception: + pass finally: - # Slow queue finished for this item; also counts toward overall processed - prog.slow_queue_finished += 1 prog.processed += 1 - if progress_cb: - progress_cb(root, "slow", prog.processed, False, { - "slow_queue_finished": prog.slow_queue_finished, - "slow_queue_total": prog.slow_queue_total, - }) finally: state.queue.task_done() - state.workers = [asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}") for i in range(SLOW_HASH_CONCURRENCY)] + state.workers = [ + asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}") + for i in range(SLOW_HASH_CONCURRENCY) + ] - async def _close_when_empty(): - # When the fast phase closed the queue, push sentinels to end workers + async def _close_when_ready(): while not state.closed: await asyncio.sleep(0.05) for _ in range(SLOW_HASH_CONCURRENCY): await state.queue.put(None) - asyncio.create_task(_close_when_empty()) + asyncio.create_task(_close_when_ready()) -async def _await_workers_then_finish( - root: RootType, - prog: ScanProgress, - state: SlowQueueState, - *, - progress_cb: Optional[Callable[[str, str, int, bool, dict], None]], -) -> None: +async def _await_state_workers_then_finish(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None: if state.workers: await asyncio.gather(*state.workers, return_exceptions=True) await _reconcile_missing_tags_for_root(root, prog) prog.finished_at = time.time() prog.status = "completed" - if progress_cb: - progress_cb(root, "slow", prog.processed, True, { - "slow_queue_finished": prog.slow_queue_finished, - "slow_queue_total": prog.slow_queue_total, - }) -def _collect_models_files() -> list[str]: - """Collect absolute file paths from configured model buckets under models_dir.""" - out: list[str] = [] - for folder_name, bases in get_comfy_models_folders(): - rel_files = folder_paths.get_filename_list(folder_name) or [] - for rel_path in rel_files: - abs_path = folder_paths.get_full_path(folder_name, rel_path) - if not abs_path: - continue - abs_path = os.path.abspath(abs_path) - # ensure within allowed bases - allowed = False - for b in bases: - base_abs = os.path.abspath(b) - with contextlib.suppress(Exception): - if os.path.commonpath([abs_path, base_abs]) == base_abs: - allowed = True - break - if allowed: - out.append(abs_path) - return out - - -def _count_files_in_tree(base_abs: str, *, only_nonzero: bool = False) -> int: - if not os.path.isdir(base_abs): - return 0 - total = 0 - for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): - if not only_nonzero: - total += len(filenames) - else: - for name in filenames: - with contextlib.suppress(OSError): - st = os.stat(os.path.join(dirpath, name), follow_symlinks=True) - if st.st_size: - total += 1 - return total - - -def _count_nonzero_in_list(paths: list[str]) -> int: - cnt = 0 - for p in paths: - with contextlib.suppress(OSError): - st = os.stat(p, 
follow_symlinks=True) - if st.st_size: - cnt += 1 - return cnt - - -async def _queue_tree_files(base_dir: str) -> asyncio.Queue: - """ - Walk base_dir in a worker thread and return a queue prefilled with all paths, - terminated by a single None sentinel for the draining loop in fast reconcile. - """ - q: asyncio.Queue = asyncio.Queue() - base_abs = os.path.abspath(base_dir) - if not os.path.isdir(base_abs): - await q.put(None) - return q - - def _walk_list(): - paths: list[str] = [] - for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): - for name in filenames: - paths.append(os.path.abspath(os.path.join(dirpath, name))) - return paths - - for p in await asyncio.to_thread(_walk_list): - await q.put(p) - await q.put(None) - return q - - -def _append_error(prog: ScanProgress, *, phase: Literal["fast", "slow"], path: str, message: str) -> None: +def _append_error(prog: ScanProgress, *, path: str, message: str) -> None: prog.file_errors.append({ "path": path, "message": message, - "phase": phase, - "at": _ts_to_iso(time.time()), + "at": ts_to_iso(time.time()), }) -def _ts_to_iso(ts: Optional[float]) -> Optional[str]: - if ts is None: - return None - # interpret ts as seconds since epoch UTC and return naive UTC (consistent with other models) - try: - return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat() - except Exception: - return None +async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None: + """ + Quick pass over asset_cache_state for `root`: + - If file missing and Asset.hash is NULL and the Asset has no other states, delete the Asset and its infos. + - If file missing and Asset.hash is NOT NULL: + * If at least one state for this Asset is fast-ok, delete the missing state. + * If none are fast-ok, add 'missing' tag to all AssetInfos for this Asset. + - If at least one state becomes fast-ok for a hashed Asset, remove the 'missing' tag. 
+ """ + prefixes = prefixes_for_root(root) + if not prefixes: + return + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + conds.append(AssetCacheState.file_path.like(base + "%")) -def _new_scan_id(root: RootType) -> str: - return f"scan-{root}-{uuid.uuid4().hex[:8]}" + async with await create_session() as sess: + if not conds: + return - -def _console_cb(root: str, phase: str, total_processed: int, finished: bool, e: dict): - if phase == "fast": - if finished: - logging.info( - "[assets][%s] fast done: processed=%s/%s queued=%s", - root, - total_processed, - e["discovered"], - e["queued"], - ) - elif e.get("checked", 0) % 1000 == 0: # do not spam with fast progress - logging.info( - "[assets][%s] fast progress: processed=%s/%s", - root, - total_processed, - e["discovered"], - ) - elif phase == "slow": - if finished: - if e.get("slow_queue_finished", 0) or e.get("slow_queue_total", 0): - logging.info( - "[assets][%s] slow done: %s/%s", - root, - e.get("slow_queue_finished", 0), - e.get("slow_queue_total", 0), - ) - elif e.get('slow_queue_finished', 0) % 3 == 0: - logging.info( - "[assets][%s] slow progress: %s/%s", - root, - e.get("slow_queue_finished", 0), - e.get("slow_queue_total", 0), + rows = ( + await sess.execute( + sa.select(AssetCacheState, Asset.hash, Asset.size_bytes) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(sa.or_(*conds)) + .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc()) ) + ).all() + + # Group by asset_id with status per state + by_asset: dict[str, dict] = {} + for st, a_hash, a_size in rows: + aid = st.asset_id + acc = by_asset.get(aid) + if acc is None: + acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []} + by_asset[aid] = acc + + exists = False + fast_ok = False + try: + s = os.stat(st.file_path, follow_symlinks=True) + exists = True + actual_mtime_ns = getattr(s, "st_mtime_ns", int(s.st_mtime * 1_000_000_000)) + if st.mtime_ns is not None and int(st.mtime_ns) == int(actual_mtime_ns): + if acc["size_db"] == 0 or int(s.st_size) == acc["size_db"]: + fast_ok = True + except FileNotFoundError: + exists = False + except OSError as ex: + exists = False + LOGGER.debug("fast pass stat error for %s: %s", st.file_path, ex) + + acc["states"].append({"obj": st, "exists": exists, "fast_ok": fast_ok}) + + # Apply actions + for aid, acc in by_asset.items(): + a_hash = acc["hash"] + states = acc["states"] + any_fast_ok = any(s["fast_ok"] for s in states) + all_missing = all(not s["exists"] for s in states) + missing_states = [s["obj"] for s in states if not s["exists"]] + + if a_hash is None: + # Seed asset: if all states gone (and in practice there is only one), remove the whole Asset + if states and all_missing: + await sess.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == aid)) + asset = await sess.get(Asset, aid) + if asset: + await sess.delete(asset) + # else leave it for the slow scan to verify/rehash + else: + if any_fast_ok: + # Remove 'missing' and delete just the stale state rows + for st in missing_states: + try: + await sess.delete(await sess.get(AssetCacheState, st.id)) + except Exception: + pass + try: + await remove_missing_tag_for_asset_id(sess, asset_id=aid) + except Exception: + pass + else: + # No fast-ok path: mark as missing + try: + await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") + except Exception: + pass + + await sess.flush() + await sess.commit() diff --git a/app/database/_helpers.py b/app/database/_helpers.py 
deleted file mode 100644 index a031e861c..000000000 --- a/app/database/_helpers.py +++ /dev/null @@ -1,186 +0,0 @@ -from decimal import Decimal -from typing import Any, Sequence, Optional, Iterable - -import sqlalchemy as sa -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select, exists - -from .models import AssetInfo, AssetInfoTag, Tag, AssetInfoMeta -from .._assets_helpers import normalize_tags - - -async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]: - wanted = normalize_tags(list(names)) - if not wanted: - return [] - existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all() - by_name = {t.name: t for t in existing} - to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name] - if to_create: - session.add_all(to_create) - await session.flush() - by_name.update({t.name: t for t in to_create}) - return [by_name[n] for n in wanted] - - -def apply_tag_filters( - stmt: sa.sql.Select, - include_tags: Optional[Sequence[str]], - exclude_tags: Optional[Sequence[str]], -) -> sa.sql.Select: - """include_tags: every tag must be present; exclude_tags: none may be present.""" - include_tags = normalize_tags(include_tags) - exclude_tags = normalize_tags(exclude_tags) - - if include_tags: - for tag_name in include_tags: - stmt = stmt.where( - exists().where( - (AssetInfoTag.asset_info_id == AssetInfo.id) - & (AssetInfoTag.tag_name == tag_name) - ) - ) - - if exclude_tags: - stmt = stmt.where( - ~exists().where( - (AssetInfoTag.asset_info_id == AssetInfo.id) - & (AssetInfoTag.tag_name.in_(exclude_tags)) - ) - ) - return stmt - - -def apply_metadata_filter( - stmt: sa.sql.Select, - metadata_filter: Optional[dict], -) -> sa.sql.Select: - """Apply metadata filters using the projection table asset_info_meta. - - Semantics: - - For scalar values: require EXISTS(asset_info_meta) with matching key + typed value. - - For None: key is missing OR key has explicit null (val_json IS NULL). - - For list values: ANY-of the list elements matches (EXISTS for any). 
- (Change to ALL-of by 'for each element: stmt = stmt.where(_meta_exists_clause(key, elem))') - """ - if not metadata_filter: - return stmt - - def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: - return sa.exists().where( - AssetInfoMeta.asset_info_id == AssetInfo.id, - AssetInfoMeta.key == key, - *preds, - ) - - def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: - # Missing OR null: - if value is None: - # either: no row for key OR a row for key with explicit null - no_row_for_key = sa.not_( - sa.exists().where( - AssetInfoMeta.asset_info_id == AssetInfo.id, - AssetInfoMeta.key == key, - ) - ) - null_row = _exists_for_pred( - key, - AssetInfoMeta.val_json.is_(None), - AssetInfoMeta.val_str.is_(None), - AssetInfoMeta.val_num.is_(None), - AssetInfoMeta.val_bool.is_(None), - ) - return sa.or_(no_row_for_key, null_row) - - # Typed scalar matches: - if isinstance(value, bool): - return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value)) - if isinstance(value, (int, float, Decimal)): - # store as Decimal for equality against NUMERIC(38,10) - num = value if isinstance(value, Decimal) else Decimal(str(value)) - return _exists_for_pred(key, AssetInfoMeta.val_num == num) - if isinstance(value, str): - return _exists_for_pred(key, AssetInfoMeta.val_str == value) - - # Complex: compare JSON (no index, but supported) - return _exists_for_pred(key, AssetInfoMeta.val_json == value) - - for k, v in metadata_filter.items(): - if isinstance(v, list): - # ANY-of (exists for any element) - ors = [_exists_clause_for_value(k, elem) for elem in v] - if ors: - stmt = stmt.where(sa.or_(*ors)) - else: - stmt = stmt.where(_exists_clause_for_value(k, v)) - return stmt - - -def is_scalar(v: Any) -> bool: - if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries - return True - if isinstance(v, bool): - return True - if isinstance(v, (int, float, Decimal, str)): - return True - return False - - -def project_kv(key: str, value: Any) -> list[dict]: - """ - Turn a metadata key/value into one or more projection rows: - - scalar -> one row (ordinal=0) in the proper typed column - - list of scalars -> one row per element with ordinal=i - - dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list) - - None -> single row with all value columns NULL - Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...} - """ - rows: list[dict] = [] - - def _null_row(ordinal: int) -> dict: - return { - "key": key, "ordinal": ordinal, - "val_str": None, "val_num": None, "val_bool": None, "val_json": None - } - - if value is None: - rows.append(_null_row(0)) - return rows - - if is_scalar(value): - if isinstance(value, bool): - rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) - elif isinstance(value, (int, float, Decimal)): - # store numeric; SQLAlchemy will coerce to Numeric - num = value if isinstance(value, Decimal) else Decimal(str(value)) - rows.append({"key": key, "ordinal": 0, "val_num": num}) - elif isinstance(value, str): - rows.append({"key": key, "ordinal": 0, "val_str": value}) - else: - # Fallback to json - rows.append({"key": key, "ordinal": 0, "val_json": value}) - return rows - - if isinstance(value, list): - if all(is_scalar(x) for x in value): - for i, x in enumerate(value): - if x is None: - rows.append(_null_row(i)) - elif isinstance(x, bool): - rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) - elif isinstance(x, (int, float, Decimal)): - num 
= x if isinstance(x, Decimal) else Decimal(str(x)) - rows.append({"key": key, "ordinal": i, "val_num": num}) - elif isinstance(x, str): - rows.append({"key": key, "ordinal": i, "val_str": x}) - else: - rows.append({"key": key, "ordinal": i, "val_json": x}) - return rows - # list contains objects -> one val_json per element - for i, x in enumerate(value): - rows.append({"key": key, "ordinal": i, "val_json": x}) - return rows - - # Dict or any other structure -> single json row - rows.append({"key": key, "ordinal": 0, "val_json": value}) - return rows diff --git a/app/database/db.py b/app/database/db.py index 8280272b0..82c9cc737 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -4,14 +4,20 @@ import shutil from contextlib import asynccontextmanager from typing import Optional -from comfy.cli_args import args from alembic import command from alembic.config import Config from alembic.runtime.migration import MigrationContext from alembic.script import ScriptDirectory from sqlalchemy import create_engine, text from sqlalchemy.engine import make_url -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) + +from comfy.cli_args import args LOGGER = logging.getLogger(__name__) ENGINE: Optional[AsyncEngine] = None diff --git a/app/database/helpers/__init__.py b/app/database/helpers/__init__.py new file mode 100644 index 000000000..19d7507fa --- /dev/null +++ b/app/database/helpers/__init__.py @@ -0,0 +1,23 @@ +from .filters import apply_metadata_filter, apply_tag_filters +from .ownership import visible_owner_clause +from .projection import is_scalar, project_kv +from .tags import ( + add_missing_tag_for_asset_hash, + add_missing_tag_for_asset_id, + ensure_tags_exist, + remove_missing_tag_for_asset_hash, + remove_missing_tag_for_asset_id, +) + +__all__ = [ + "apply_tag_filters", + "apply_metadata_filter", + "is_scalar", + "project_kv", + "ensure_tags_exist", + "add_missing_tag_for_asset_id", + "add_missing_tag_for_asset_hash", + "remove_missing_tag_for_asset_id", + "remove_missing_tag_for_asset_hash", + "visible_owner_clause", +] diff --git a/app/database/helpers/filters.py b/app/database/helpers/filters.py new file mode 100644 index 000000000..0b6d85b8d --- /dev/null +++ b/app/database/helpers/filters.py @@ -0,0 +1,87 @@ +from typing import Optional, Sequence + +import sqlalchemy as sa +from sqlalchemy import exists + +from ..._assets_helpers import normalize_tags +from ..models import AssetInfo, AssetInfoMeta, AssetInfoTag + + +def apply_tag_filters( + stmt: sa.sql.Select, + include_tags: Optional[Sequence[str]], + exclude_tags: Optional[Sequence[str]], +) -> sa.sql.Select: + """include_tags: every tag must be present; exclude_tags: none may be present.""" + include_tags = normalize_tags(include_tags) + exclude_tags = normalize_tags(exclude_tags) + + if include_tags: + for tag_name in include_tags: + stmt = stmt.where( + exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name == tag_name) + ) + ) + + if exclude_tags: + stmt = stmt.where( + ~exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name.in_(exclude_tags)) + ) + ) + return stmt + + +def apply_metadata_filter( + stmt: sa.sql.Select, + metadata_filter: Optional[dict], +) -> sa.sql.Select: + """Apply filters using asset_info_meta projection table.""" + if not metadata_filter: + return stmt + + def 
_exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
+        return sa.exists().where(
+            AssetInfoMeta.asset_info_id == AssetInfo.id,
+            AssetInfoMeta.key == key,
+            *preds,
+        )
+
+    def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
+        if value is None:
+            no_row_for_key = sa.not_(
+                sa.exists().where(
+                    AssetInfoMeta.asset_info_id == AssetInfo.id,
+                    AssetInfoMeta.key == key,
+                )
+            )
+            null_row = _exists_for_pred(
+                key,
+                AssetInfoMeta.val_json.is_(None),
+                AssetInfoMeta.val_str.is_(None),
+                AssetInfoMeta.val_num.is_(None),
+                AssetInfoMeta.val_bool.is_(None),
+            )
+            return sa.or_(no_row_for_key, null_row)
+
+        if isinstance(value, bool):
+            return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
+        from decimal import Decimal
+        if isinstance(value, (int, float, Decimal)):  # Decimal values match the numeric column too
+            num = value if isinstance(value, Decimal) else Decimal(str(value))
+            return _exists_for_pred(key, AssetInfoMeta.val_num == num)
+        if isinstance(value, str):
+            return _exists_for_pred(key, AssetInfoMeta.val_str == value)
+        return _exists_for_pred(key, AssetInfoMeta.val_json == value)
+
+    for k, v in metadata_filter.items():
+        if isinstance(v, list):
+            ors = [_exists_clause_for_value(k, elem) for elem in v]
+            if ors:
+                stmt = stmt.where(sa.or_(*ors))
+        else:
+            stmt = stmt.where(_exists_clause_for_value(k, v))
+    return stmt
diff --git a/app/database/helpers/ownership.py b/app/database/helpers/ownership.py
new file mode 100644
index 000000000..c00731608
--- /dev/null
+++ b/app/database/helpers/ownership.py
@@ -0,0 +1,12 @@
+import sqlalchemy as sa
+
+from ..models import AssetInfo
+
+
+def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
+    """Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
+
+    owner_id = (owner_id or "").strip()
+    if owner_id == "":
+        return AssetInfo.owner_id == ""
+    return AssetInfo.owner_id.in_(["", owner_id])
diff --git a/app/database/helpers/projection.py b/app/database/helpers/projection.py
new file mode 100644
index 000000000..687802d18
--- /dev/null
+++ b/app/database/helpers/projection.py
@@ -0,0 +1,64 @@
+from decimal import Decimal
+
+
+def is_scalar(v):
+    if v is None:
+        return True
+    if isinstance(v, bool):
+        return True
+    if isinstance(v, (int, float, Decimal, str)):
+        return True
+    return False
+
+
+def project_kv(key: str, value):
+    """
+    Turn a metadata key/value into typed projection rows.
+ Returns list[dict] with keys: + key, ordinal, and one of val_str / val_num / val_bool / val_json (others None) + """ + rows: list[dict] = [] + + def _null_row(ordinal: int) -> dict: + return { + "key": key, "ordinal": ordinal, + "val_str": None, "val_num": None, "val_bool": None, "val_json": None + } + + if value is None: + rows.append(_null_row(0)) + return rows + + if is_scalar(value): + if isinstance(value, bool): + rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) + elif isinstance(value, (int, float, Decimal)): + num = value if isinstance(value, Decimal) else Decimal(str(value)) + rows.append({"key": key, "ordinal": 0, "val_num": num}) + elif isinstance(value, str): + rows.append({"key": key, "ordinal": 0, "val_str": value}) + else: + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows + + if isinstance(value, list): + if all(is_scalar(x) for x in value): + for i, x in enumerate(value): + if x is None: + rows.append(_null_row(i)) + elif isinstance(x, bool): + rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) + elif isinstance(x, (int, float, Decimal)): + num = x if isinstance(x, Decimal) else Decimal(str(x)) + rows.append({"key": key, "ordinal": i, "val_num": num}) + elif isinstance(x, str): + rows.append({"key": key, "ordinal": i, "val_str": x}) + else: + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + for i, x in enumerate(value): + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows diff --git a/app/database/helpers/tags.py b/app/database/helpers/tags.py new file mode 100644 index 000000000..479343096 --- /dev/null +++ b/app/database/helpers/tags.py @@ -0,0 +1,102 @@ +from typing import Iterable + +from sqlalchemy import delete, select +from sqlalchemy.ext.asyncio import AsyncSession + +from ..._assets_helpers import normalize_tags +from ..models import Asset, AssetInfo, AssetInfoTag, Tag +from ..timeutil import utcnow + + +async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]: + wanted = normalize_tags(list(names)) + if not wanted: + return [] + existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all() + by_name = {t.name: t for t in existing} + to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name] + if to_create: + session.add_all(to_create) + await session.flush() + by_name.update({t.name: t for t in to_create}) + return [by_name[n] for n in wanted] + + +async def add_missing_tag_for_asset_id( + session: AsyncSession, + *, + asset_id: str, + origin: str = "automatic", +) -> int: + """Ensure every AssetInfo for asset_id has 'missing' tag.""" + ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all() + if not ids: + return 0 + + existing = { + asset_info_id + for (asset_info_id,) in ( + await session.execute( + select(AssetInfoTag.asset_info_id).where( + AssetInfoTag.asset_info_id.in_(ids), + AssetInfoTag.tag_name == "missing", + ) + ) + ).all() + } + to_add = [i for i in ids if i not in existing] + if not to_add: + return 0 + + now = utcnow() + session.add_all( + [ + AssetInfoTag(asset_info_id=i, tag_name="missing", origin=origin, added_at=now) + for i in to_add + ] + ) + await session.flush() + return len(to_add) + + +async def add_missing_tag_for_asset_hash( + session: AsyncSession, + *, + asset_hash: str, + origin: str = "automatic", +) -> int: + 
asset = (await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))).scalars().first() + if not asset: + return 0 + return await add_missing_tag_for_asset_id(session, asset_id=asset.id, origin=origin) + + +async def remove_missing_tag_for_asset_id( + session: AsyncSession, + *, + asset_id: str, +) -> int: + """Remove the 'missing' tag from all AssetInfos for asset_id.""" + ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all() + if not ids: + return 0 + + res = await session.execute( + delete(AssetInfoTag).where( + AssetInfoTag.asset_info_id.in_(ids), + AssetInfoTag.tag_name == "missing", + ) + ) + await session.flush() + return int(res.rowcount or 0) + + +async def remove_missing_tag_for_asset_hash( + session: AsyncSession, + *, + asset_hash: str, +) -> int: + asset = (await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))).scalars().first() + if not asset: + return 0 + return await remove_missing_tag_for_asset_id(session, asset_id=asset.id) diff --git a/app/database/models.py b/app/database/models.py index 55fc08e51..6a6798bcf 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -1,27 +1,26 @@ +import uuid from datetime import datetime from typing import Any, Optional -import uuid from sqlalchemy import ( - Integer, + JSON, BigInteger, + Boolean, + CheckConstraint, DateTime, ForeignKey, Index, - UniqueConstraint, - JSON, + Integer, + Numeric, String, Text, - CheckConstraint, - Numeric, - Boolean, + UniqueConstraint, ) from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign +from sqlalchemy.orm import DeclarativeBase, Mapped, foreign, mapped_column, relationship from .timeutil import utcnow - JSONB_V = JSON(none_as_null=True).with_variant(JSONB(none_as_null=True), 'postgresql') @@ -46,7 +45,8 @@ def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]: class Asset(Base): __tablename__ = "assets" - hash: Mapped[str] = mapped_column(String(256), primary_key=True) + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + hash: Mapped[Optional[str]] = mapped_column(String(256), nullable=True) size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) mime_type: Mapped[Optional[str]] = mapped_column(String(255)) created_at: Mapped[datetime] = mapped_column( @@ -56,8 +56,8 @@ class Asset(Base): infos: Mapped[list["AssetInfo"]] = relationship( "AssetInfo", back_populates="asset", - primaryjoin=lambda: Asset.hash == foreign(AssetInfo.asset_hash), - foreign_keys=lambda: [AssetInfo.asset_hash], + primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id), + foreign_keys=lambda: [AssetInfo.asset_id], cascade="all,delete-orphan", passive_deletes=True, ) @@ -65,8 +65,8 @@ class Asset(Base): preview_of: Mapped[list["AssetInfo"]] = relationship( "AssetInfo", back_populates="preview_asset", - primaryjoin=lambda: Asset.hash == foreign(AssetInfo.preview_hash), - foreign_keys=lambda: [AssetInfo.preview_hash], + primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id), + foreign_keys=lambda: [AssetInfo.preview_id], viewonly=True, ) @@ -76,36 +76,32 @@ class Asset(Base): passive_deletes=True, ) - locations: Mapped[list["AssetLocation"]] = relationship( - back_populates="asset", - cascade="all, delete-orphan", - passive_deletes=True, - ) - __table_args__ = ( Index("ix_assets_mime_type", "mime_type"), + CheckConstraint("size_bytes >= 0", 
name="ck_assets_size_nonneg"), ) def to_dict(self, include_none: bool = False) -> dict[str, Any]: return to_dict(self, include_none=include_none) def __repr__(self) -> str: - return f"" + return f"" class AssetCacheState(Base): __tablename__ = "asset_cache_state" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False) + asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False) file_path: Mapped[str] = mapped_column(Text, nullable=False) mtime_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True) + needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) asset: Mapped["Asset"] = relationship(back_populates="cache_states") __table_args__ = ( Index("ix_asset_cache_state_file_path", "file_path"), - Index("ix_asset_cache_state_asset_hash", "asset_hash"), + Index("ix_asset_cache_state_asset_id", "asset_id"), CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), ) @@ -114,27 +110,7 @@ class AssetCacheState(Base): return to_dict(self, include_none=include_none) def __repr__(self) -> str: - return f"" - - -class AssetLocation(Base): - __tablename__ = "asset_locations" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False) - provider: Mapped[str] = mapped_column(String(32), nullable=False) # "gcs" - locator: Mapped[str] = mapped_column(Text, nullable=False) # "gs://bucket/object" - expected_size_bytes: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True) - etag: Mapped[Optional[str]] = mapped_column(String(256), nullable=True) - last_modified: Mapped[Optional[str]] = mapped_column(String(128), nullable=True) - - asset: Mapped["Asset"] = relationship(back_populates="locations") - - __table_args__ = ( - UniqueConstraint("asset_hash", "provider", "locator", name="uq_asset_locations_triplet"), - Index("ix_asset_locations_hash", "asset_hash"), - Index("ix_asset_locations_provider", "provider"), - ) + return f"" class AssetInfo(Base): @@ -143,31 +119,23 @@ class AssetInfo(Base): id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") name: Mapped[str] = mapped_column(String(512), nullable=False) - asset_hash: Mapped[str] = mapped_column( - String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False - ) - preview_hash: Mapped[Optional[str]] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL")) + asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False) + preview_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL")) user_metadata: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON(none_as_null=True)) - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=False), nullable=False, default=utcnow - ) - updated_at: Mapped[datetime] = mapped_column( - DateTime(timezone=False), nullable=False, default=utcnow - ) - last_access_time: Mapped[datetime] = mapped_column( - DateTime(timezone=False), nullable=False, default=utcnow - ) + created_at: Mapped[datetime] = 
mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) - # Relationships asset: Mapped[Asset] = relationship( "Asset", back_populates="infos", - foreign_keys=[asset_hash], + foreign_keys=[asset_id], + lazy="selectin", ) preview_asset: Mapped[Optional[Asset]] = relationship( "Asset", back_populates="preview_of", - foreign_keys=[preview_hash], + foreign_keys=[preview_id], ) metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship( @@ -186,16 +154,16 @@ class AssetInfo(Base): tags: Mapped[list["Tag"]] = relationship( secondary="asset_info_tags", back_populates="asset_infos", - lazy="joined", + lazy="selectin", viewonly=True, overlaps="tag_links,asset_info_links,asset_infos,tag", ) __table_args__ = ( - UniqueConstraint("asset_hash", "owner_id", "name", name="uq_assets_info_hash_owner_name"), + UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), Index("ix_assets_info_owner_name", "owner_id", "name"), Index("ix_assets_info_owner_id", "owner_id"), - Index("ix_assets_info_asset_hash", "asset_hash"), + Index("ix_assets_info_asset_id", "asset_id"), Index("ix_assets_info_name", "name"), Index("ix_assets_info_created_at", "created_at"), Index("ix_assets_info_last_access_time", "last_access_time"), @@ -207,7 +175,7 @@ class AssetInfo(Base): return data def __repr__(self) -> str: - return f"" + return f"" class AssetInfoMeta(Base): diff --git a/app/database/services.py b/app/database/services.py deleted file mode 100644 index 842103e9e..000000000 --- a/app/database/services.py +++ /dev/null @@ -1,1116 +0,0 @@ -import contextlib -import os -import logging -from collections import defaultdict -from datetime import datetime -from typing import Any, Sequence, Optional, Union - -import sqlalchemy as sa -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select, delete, func -from sqlalchemy.orm import contains_eager, noload -from sqlalchemy.exc import IntegrityError -from sqlalchemy.dialects import sqlite as d_sqlite -from sqlalchemy.dialects import postgresql as d_pg - -from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation -from .timeutil import utcnow -from .._assets_helpers import normalize_tags, visible_owner_clause, compute_model_relative_filename -from . 
import _helpers - - -async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool: - row = ( - await session.execute( - select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1) - ) - ).first() - return row is not None - - -async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Optional[Asset]: - return await session.get(Asset, asset_hash) - - -async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) -> Optional[AssetInfo]: - return await session.get(AssetInfo, asset_info_id) - - -async def asset_info_exists_for_hash(session: AsyncSession, *, asset_hash: str) -> bool: - return ( - await session.execute( - sa.select(sa.literal(True)) - .select_from(AssetInfo) - .where(AssetInfo.asset_hash == asset_hash) - .limit(1) - ) - ).first() is not None - - -async def check_fs_asset_exists_quick( - session, - *, - file_path: str, - size_bytes: Optional[int] = None, - mtime_ns: Optional[int] = None, -) -> bool: - """ - Returns 'True' if there is already AssetCacheState record that matches this absolute path, - AND (if provided) mtime_ns matches stored locator-state, - AND (if provided) size_bytes matches verified size when known. - """ - locator = os.path.abspath(file_path) - - stmt = select(sa.literal(True)).select_from(AssetCacheState).join( - Asset, Asset.hash == AssetCacheState.asset_hash - ).where(AssetCacheState.file_path == locator).limit(1) - - conds = [] - if mtime_ns is not None: - conds.append(AssetCacheState.mtime_ns == int(mtime_ns)) - if size_bytes is not None: - conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes))) - - if conds: - stmt = stmt.where(*conds) - - row = (await session.execute(stmt)).first() - return row is not None - - -async def ingest_fs_asset( - session: AsyncSession, - *, - asset_hash: str, - abs_path: str, - size_bytes: int, - mtime_ns: int, - mime_type: Optional[str] = None, - info_name: Optional[str] = None, - owner_id: str = "", - preview_hash: Optional[str] = None, - user_metadata: Optional[dict] = None, - tags: Sequence[str] = (), - tag_origin: str = "manual", - require_existing_tags: bool = False, -) -> dict: - """ - Upsert Asset identity row + cache state(s) pointing at local file. - - Always: - - Insert Asset if missing; - - Insert AssetCacheState if missing; else update mtime_ns and asset_hash if different. - - Optionally (when info_name is provided): - - Create or update an AssetInfo on (asset_hash, owner_id, name). - - Link provided tags to that AssetInfo. - * If the require_existing_tags=True, raises ValueError if any tag does not exist in `tags` table. - * If False (default), create unknown tags. 
- - Returns flags and ids: - { - "asset_created": bool, - "asset_updated": bool, - "state_created": bool, - "state_updated": bool, - "asset_info_id": str | None, - } - """ - locator = os.path.abspath(abs_path) - datetime_now = utcnow() - - out: dict[str, Any] = { - "asset_created": False, - "asset_updated": False, - "state_created": False, - "state_updated": False, - "asset_info_id": None, - } - - # ---- Step 1: INSERT Asset or UPDATE size_bytes/updated_at if exists ---- - with contextlib.suppress(IntegrityError): - async with session.begin_nested(): - session.add( - Asset( - hash=asset_hash, - size_bytes=int(size_bytes), - mime_type=mime_type, - created_at=datetime_now, - ) - ) - await session.flush() - out["asset_created"] = True - - if not out["asset_created"]: - existing = await session.get(Asset, asset_hash) - if existing is not None: - changed = False - if existing.size_bytes != size_bytes: - existing.size_bytes = size_bytes - changed = True - if mime_type and existing.mime_type != mime_type: - existing.mime_type = mime_type - changed = True - if changed: - out["asset_updated"] = True - else: - logging.error("Asset %s not found after PK conflict; skipping update.", asset_hash) - - # ---- Step 2: UPSERT AssetCacheState (mtime_ns, file_path) ---- - dialect = session.bind.dialect.name # "sqlite" or "postgresql" - vals = { - "asset_hash": asset_hash, - "file_path": locator, - "mtime_ns": int(mtime_ns), - } - # 2-step idempotent write so we can set flags deterministically: - # INSERT ... ON CONFLICT(file_path) DO NOTHING - # if conflicted, UPDATE only when values actually change - if dialect == "sqlite": - ins = ( - d_sqlite.insert(AssetCacheState) - .values(**vals) - .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) - ) - elif dialect == "postgresql": - ins = ( - d_pg.insert(AssetCacheState) - .values(**vals) - .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) - ) - else: - raise NotImplementedError(f"Unsupported database dialect: {dialect}") - res = await session.execute(ins) - if int(res.rowcount or 0) > 0: - out["state_created"] = True - else: - upd = ( - sa.update(AssetCacheState) - .where(AssetCacheState.file_path == locator) - .where( - sa.or_( - AssetCacheState.asset_hash != asset_hash, - AssetCacheState.mtime_ns.is_(None), - AssetCacheState.mtime_ns != int(mtime_ns), - ) - ) - .values(asset_hash=asset_hash, mtime_ns=int(mtime_ns)) - ) - res2 = await session.execute(upd) - if int(res2.rowcount or 0) > 0: - out["state_updated"] = True - - # ---- Optional: AssetInfo + tag links ---- - if info_name: - # 2a) Upsert AssetInfo idempotently on (asset_hash, owner_id, name) - with contextlib.suppress(IntegrityError): - async with session.begin_nested(): - info = AssetInfo( - owner_id=owner_id, - name=info_name, - asset_hash=asset_hash, - preview_hash=preview_hash, - created_at=datetime_now, - updated_at=datetime_now, - last_access_time=datetime_now, - ) - session.add(info) - await session.flush() # get info.id (UUID) - out["asset_info_id"] = info.id - - existing_info = ( - await session.execute( - select(AssetInfo) - .where( - AssetInfo.asset_hash == asset_hash, - AssetInfo.name == info_name, - (AssetInfo.owner_id == owner_id), - ) - .limit(1) - ) - ).unique().scalar_one_or_none() - if not existing_info: - raise RuntimeError("Failed to update or insert AssetInfo.") - - if preview_hash is not None and existing_info.preview_hash != preview_hash: - existing_info.preview_hash = preview_hash - existing_info.updated_at = datetime_now - if 
existing_info.last_access_time < datetime_now: - existing_info.last_access_time = datetime_now - await session.flush() - out["asset_info_id"] = existing_info.id - - # 2b) Link tags (if any). We DO NOT create new Tag rows here by default. - norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()] - if norm and out["asset_info_id"] is not None: - if not require_existing_tags: - await _helpers.ensure_tags_exist(session, norm, tag_type="user") - - # Which tags exist? - existing_tag_names = set( - name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all() - ) - missing = [t for t in norm if t not in existing_tag_names] - if missing and require_existing_tags: - raise ValueError(f"Unknown tags: {missing}") - - # Which links already exist? - existing_links = set( - tag_name - for (tag_name,) in ( - await session.execute( - select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"]) - ) - ).all() - ) - to_add = [t for t in norm if t in existing_tag_names and t not in existing_links] - if to_add: - session.add_all( - [ - AssetInfoTag( - asset_info_id=out["asset_info_id"], - tag_name=t, - origin=tag_origin, - added_at=datetime_now, - ) - for t in to_add - ] - ) - await session.flush() - - # 2c) Rebuild metadata projection if provided - # Uncomment next code, and remove code after it, once the hack with "metadata[filename" is not needed anymore - # if user_metadata is not None and out["asset_info_id"] is not None: - # await replace_asset_info_metadata_projection( - # session, - # asset_info_id=out["asset_info_id"], - # user_metadata=user_metadata, - # ) - # start of adding metadata["filename"] - if out["asset_info_id"] is not None: - primary_path = ( - await session.execute( - select(AssetCacheState.file_path) - .where(AssetCacheState.asset_hash == asset_hash) - .order_by(AssetCacheState.id.asc()) - .limit(1) - ) - ).scalars().first() - computed_filename = compute_model_relative_filename(primary_path) if primary_path else None - - # Start from current metadata on this AssetInfo, if any - current_meta = existing_info.user_metadata or {} - new_meta = dict(current_meta) - - # Merge caller-provided metadata, if any (caller keys override current) - if user_metadata is not None: - for k, v in user_metadata.items(): - new_meta[k] = v - - # Enforce correct model-relative filename when known - if computed_filename: - new_meta["filename"] = computed_filename - - # Only write when there is a change - if new_meta != current_meta: - await replace_asset_info_metadata_projection( - session, - asset_info_id=out["asset_info_id"], - user_metadata=new_meta, - ) - # end of adding metadata["filename"] - try: - await remove_missing_tag_for_asset_hash(session, asset_hash=asset_hash) - except Exception: - logging.exception("Failed to clear 'missing' tag for %s", asset_hash) - return out - - -async def touch_asset_infos_by_fs_path( - session: AsyncSession, - *, - file_path: str, - ts: Optional[datetime] = None, - only_if_newer: bool = True, -) -> int: - locator = os.path.abspath(file_path) - ts = ts or utcnow() - - stmt = sa.update(AssetInfo).where( - sa.exists( - sa.select(sa.literal(1)) - .select_from(AssetCacheState) - .where( - AssetCacheState.asset_hash == AssetInfo.asset_hash, - AssetCacheState.file_path == locator, - ) - ) - ) - - if only_if_newer: - stmt = stmt.where( - sa.or_( - AssetInfo.last_access_time.is_(None), - AssetInfo.last_access_time < ts, - ) - ) - - stmt = stmt.values(last_access_time=ts) - - res = await session.execute(stmt) - 
return int(res.rowcount or 0) - - -async def touch_asset_info_by_id( - session: AsyncSession, - *, - asset_info_id: str, - ts: Optional[datetime] = None, - only_if_newer: bool = True, -) -> int: - ts = ts or utcnow() - stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id) - if only_if_newer: - stmt = stmt.where( - sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) - ) - stmt = stmt.values(last_access_time=ts) - res = await session.execute(stmt) - return int(res.rowcount or 0) - - -async def list_asset_infos_page( - session: AsyncSession, - *, - owner_id: str = "", - include_tags: Optional[Sequence[str]] = None, - exclude_tags: Optional[Sequence[str]] = None, - name_contains: Optional[str] = None, - metadata_filter: Optional[dict] = None, - limit: int = 20, - offset: int = 0, - sort: str = "created_at", - order: str = "desc", -) -> tuple[list[AssetInfo], dict[str, list[str]], int]: - """Return page of AssetInfo rows in the viewers visibility.""" - base = ( - select(AssetInfo) - .join(Asset, Asset.hash == AssetInfo.asset_hash) - .options(contains_eager(AssetInfo.asset)) - .where(visible_owner_clause(owner_id)) - ) - - if name_contains: - base = base.where(AssetInfo.name.ilike(f"%{name_contains}%")) - - base = _helpers.apply_tag_filters(base, include_tags, exclude_tags) - base = _helpers.apply_metadata_filter(base, metadata_filter) - - sort = (sort or "created_at").lower() - order = (order or "desc").lower() - sort_map = { - "name": AssetInfo.name, - "created_at": AssetInfo.created_at, - "updated_at": AssetInfo.updated_at, - "last_access_time": AssetInfo.last_access_time, - "size": Asset.size_bytes, - } - sort_col = sort_map.get(sort, AssetInfo.created_at) - sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() - - base = base.order_by(sort_exp).limit(limit).offset(offset) - - count_stmt = ( - select(func.count()) - .select_from(AssetInfo) - .join(Asset, Asset.hash == AssetInfo.asset_hash) - .where(visible_owner_clause(owner_id)) - ) - if name_contains: - count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%")) - count_stmt = _helpers.apply_tag_filters(count_stmt, include_tags, exclude_tags) - count_stmt = _helpers.apply_metadata_filter(count_stmt, metadata_filter) - - total = int((await session.execute(count_stmt)).scalar_one() or 0) - - infos = (await session.execute(base)).scalars().unique().all() - - # Collect tags in bulk - id_list: list[str] = [i.id for i in infos] - tag_map: dict[str, list[str]] = defaultdict(list) - if id_list: - rows = await session.execute( - select(AssetInfoTag.asset_info_id, Tag.name) - .join(Tag, Tag.name == AssetInfoTag.tag_name) - .where(AssetInfoTag.asset_info_id.in_(id_list)) - ) - for aid, tag_name in rows.all(): - tag_map[aid].append(tag_name) - - return infos, tag_map, total - - -async def fetch_asset_info_and_asset( - session: AsyncSession, - *, - asset_info_id: str, - owner_id: str = "", -) -> Optional[tuple[AssetInfo, Asset]]: - stmt = ( - select(AssetInfo, Asset) - .join(Asset, Asset.hash == AssetInfo.asset_hash) - .where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - ) - .limit(1) - ) - row = await session.execute(stmt) - pair = row.first() - if not pair: - return None - return pair[0], pair[1] - - -async def fetch_asset_info_asset_and_tags( - session: AsyncSession, - *, - asset_info_id: str, - owner_id: str = "", -) -> Optional[tuple[AssetInfo, Asset, list[str]]]: - stmt = ( - select(AssetInfo, Asset, Tag.name) - .join(Asset, Asset.hash == AssetInfo.asset_hash) 
- .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True) - .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) - .where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - ) - .options(noload(AssetInfo.tags)) - .order_by(Tag.name.asc()) - ) - - rows = (await session.execute(stmt)).all() - if not rows: - return None - - # First row contains the mapped entities; tags may repeat across rows - first_info, first_asset, _ = rows[0] - tags: list[str] = [] - seen: set[str] = set() - for _info, _asset, tag_name in rows: - if tag_name and tag_name not in seen: - seen.add(tag_name) - tags.append(tag_name) - return first_info, first_asset, tags - - -async def get_cache_state_by_asset_hash(session: AsyncSession, *, asset_hash: str) -> Optional[AssetCacheState]: - """Return the oldest cache row for this asset.""" - return ( - await session.execute( - select(AssetCacheState) - .where(AssetCacheState.asset_hash == asset_hash) - .order_by(AssetCacheState.id.asc()) - .limit(1) - ) - ).scalars().first() - - -async def list_cache_states_by_asset_hash( - session: AsyncSession, *, asset_hash: str -) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]: - """Return all cache rows for this asset ordered by oldest first.""" - return ( - await session.execute( - select(AssetCacheState) - .where(AssetCacheState.asset_hash == asset_hash) - .order_by(AssetCacheState.id.asc()) - ) - ).scalars().all() - - -async def list_asset_locations( - session: AsyncSession, *, asset_hash: str, provider: Optional[str] = None -) -> Union[list[AssetLocation], Sequence[AssetLocation]]: - stmt = select(AssetLocation).where(AssetLocation.asset_hash == asset_hash) - if provider: - stmt = stmt.where(AssetLocation.provider == provider) - return (await session.execute(stmt)).scalars().all() - - -async def upsert_asset_location( - session: AsyncSession, - *, - asset_hash: str, - provider: str, - locator: str, - expected_size_bytes: Optional[int] = None, - etag: Optional[str] = None, - last_modified: Optional[str] = None, -) -> AssetLocation: - loc = ( - await session.execute( - select(AssetLocation).where( - AssetLocation.asset_hash == asset_hash, - AssetLocation.provider == provider, - AssetLocation.locator == locator, - ).limit(1) - ) - ).scalars().first() - if loc: - changed = False - if expected_size_bytes is not None and loc.expected_size_bytes != expected_size_bytes: - loc.expected_size_bytes = expected_size_bytes - changed = True - if etag is not None and loc.etag != etag: - loc.etag = etag - changed = True - if last_modified is not None and loc.last_modified != last_modified: - loc.last_modified = last_modified - changed = True - if changed: - await session.flush() - return loc - - loc = AssetLocation( - asset_hash=asset_hash, - provider=provider, - locator=locator, - expected_size_bytes=expected_size_bytes, - etag=etag, - last_modified=last_modified, - ) - session.add(loc) - await session.flush() - return loc - - -async def create_asset_info_for_existing_asset( - session: AsyncSession, - *, - asset_hash: str, - name: str, - user_metadata: Optional[dict] = None, - tags: Optional[Sequence[str]] = None, - tag_origin: str = "manual", - owner_id: str = "", -) -> AssetInfo: - """Create a new AssetInfo referencing an existing Asset. 
If row already exists, return it unchanged.""" - now = utcnow() - info = AssetInfo( - owner_id=owner_id, - name=name, - asset_hash=asset_hash, - preview_hash=None, - created_at=now, - updated_at=now, - last_access_time=now, - ) - try: - async with session.begin_nested(): - session.add(info) - await session.flush() # get info.id - except IntegrityError: - existing = ( - await session.execute( - select(AssetInfo) - .where( - AssetInfo.asset_hash == asset_hash, - AssetInfo.name == name, - AssetInfo.owner_id == owner_id, - ) - .limit(1) - ) - ).scalars().first() - if not existing: - raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.") - return existing - - # Uncomment next code, and remove code after it, once the hack with "metadata[filename" is not needed anymore - # if user_metadata is not None: - # await replace_asset_info_metadata_projection( - # session, asset_info_id=info.id, user_metadata=user_metadata - # ) - - # start of adding metadata["filename"] - new_meta = dict(user_metadata or {}) - - computed_filename = None - try: - state = await get_cache_state_by_asset_hash(session, asset_hash=asset_hash) - if state and state.file_path: - computed_filename = compute_model_relative_filename(state.file_path) - except Exception: - computed_filename = None - - if computed_filename: - new_meta["filename"] = computed_filename - - if new_meta: - await replace_asset_info_metadata_projection( - session, - asset_info_id=info.id, - user_metadata=new_meta, - ) - # end of adding metadata["filename"] - - if tags is not None: - await set_asset_info_tags( - session, - asset_info_id=info.id, - tags=tags, - origin=tag_origin, - ) - return info - - -async def set_asset_info_tags( - session: AsyncSession, - *, - asset_info_id: str, - tags: Sequence[str], - origin: str = "manual", -) -> dict: - """ - Replace the tag set on an AssetInfo with `tags`. Idempotent. - Creates missing tag names as 'user'. - """ - desired = normalize_tags(tags) - - # current links - current = set( - tag_name for (tag_name,) in ( - await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)) - ).all() - ) - - to_add = [t for t in desired if t not in current] - to_remove = [t for t in current if t not in desired] - - if to_add: - await _helpers.ensure_tags_exist(session, to_add, tag_type="user") - session.add_all([ - AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow()) - for t in to_add - ]) - await session.flush() - - if to_remove: - await session.execute( - delete(AssetInfoTag) - .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove)) - ) - await session.flush() - - return {"added": to_add, "removed": to_remove, "total": desired} - - -async def update_asset_info_full( - session: AsyncSession, - *, - asset_info_id: str, - name: Optional[str] = None, - tags: Optional[Sequence[str]] = None, - user_metadata: Optional[dict] = None, - tag_origin: str = "manual", - asset_info_row: Any = None, -) -> AssetInfo: - """ - Update AssetInfo fields: - - name (if provided) - - user_metadata blob + rebuild projection (if provided) - - replace tags with provided set (if provided) - Returns the updated AssetInfo. 
- """ - if not asset_info_row: - info = await session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - else: - info = asset_info_row - - touched = False - if name is not None and name != info.name: - info.name = name - touched = True - - # Uncomment next code, and remove code after it, once the hack with "metadata[filename" is not needed anymore - # if user_metadata is not None: - # await replace_asset_info_metadata_projection( - # session, asset_info_id=asset_info_id, user_metadata=user_metadata - # ) - # touched = True - - # start of adding metadata["filename"] - computed_filename = None - try: - state = await get_cache_state_by_asset_hash(session, asset_hash=info.asset_hash) - if state and state.file_path: - computed_filename = compute_model_relative_filename(state.file_path) - except Exception: - computed_filename = None - - if user_metadata is not None: - new_meta = dict(user_metadata) - if computed_filename: - new_meta["filename"] = computed_filename - await replace_asset_info_metadata_projection( - session, asset_info_id=asset_info_id, user_metadata=new_meta - ) - touched = True - else: - if computed_filename: - current_meta = info.user_metadata or {} - if current_meta.get("filename") != computed_filename: - new_meta = dict(current_meta) - new_meta["filename"] = computed_filename - await replace_asset_info_metadata_projection( - session, asset_info_id=asset_info_id, user_metadata=new_meta - ) - touched = True - # end of adding metadata["filename"] - - if tags is not None: - await set_asset_info_tags( - session, - asset_info_id=asset_info_id, - tags=tags, - origin=tag_origin, - ) - touched = True - - if touched and user_metadata is None: - info.updated_at = utcnow() - await session.flush() - - return info - - -async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool: - """Delete the user-visible AssetInfo row. 
Cascades clear tags and metadata.""" - res = await session.execute(delete(AssetInfo).where( - AssetInfo.id == asset_info_id, - visible_owner_clause(owner_id), - )) - return bool(res.rowcount) - - -async def replace_asset_info_metadata_projection( - session: AsyncSession, - *, - asset_info_id: str, - user_metadata: Optional[dict], -) -> None: - """Replaces the `assets_info.user_metadata` AND rebuild the projection rows in `asset_info_meta`.""" - info = await session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - info.user_metadata = user_metadata or {} - info.updated_at = utcnow() - await session.flush() - - await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)) - await session.flush() - - if not user_metadata: - return - - rows: list[AssetInfoMeta] = [] - for k, v in user_metadata.items(): - for r in _helpers.project_kv(k, v): - rows.append( - AssetInfoMeta( - asset_info_id=asset_info_id, - key=r["key"], - ordinal=int(r["ordinal"]), - val_str=r.get("val_str"), - val_num=r.get("val_num"), - val_bool=r.get("val_bool"), - val_json=r.get("val_json"), - ) - ) - if rows: - session.add_all(rows) - await session.flush() - - -async def get_asset_tags(session: AsyncSession, *, asset_info_id: str) -> list[str]: - return [ - tag_name - for (tag_name,) in ( - await session.execute( - sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - ] - - -async def list_tags_with_usage( - session: AsyncSession, - *, - prefix: Optional[str] = None, - limit: int = 100, - offset: int = 0, - include_zero: bool = True, - order: str = "count_desc", # "count_desc" | "name_asc" - owner_id: str = "", -) -> tuple[list[tuple[str, str, int]], int]: - # Subquery with counts by tag_name and owner_id - counts_sq = ( - select( - AssetInfoTag.tag_name.label("tag_name"), - func.count(AssetInfoTag.asset_info_id).label("cnt"), - ) - .select_from(AssetInfoTag) - .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id) - .where(visible_owner_clause(owner_id)) - .group_by(AssetInfoTag.tag_name) - .subquery() - ) - - # Base select with LEFT JOIN so we can include zero-usage tags - q = ( - select( - Tag.name, - Tag.tag_type, - func.coalesce(counts_sq.c.cnt, 0).label("count"), - ) - .select_from(Tag) - .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) - ) - - # Prefix filter (tags are lowercase by check constraint) - if prefix: - q = q.where(Tag.name.like(prefix.strip().lower() + "%")) - - # Include_zero toggles: if False, drop zero-usage tags - if not include_zero: - q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) - - if order == "name_asc": - q = q.order_by(Tag.name.asc()) - else: # default "count_desc" - q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) - - # Total (without limit/offset, same filters) - total_q = select(func.count()).select_from(Tag) - if prefix: - total_q = total_q.where(Tag.name.like(prefix.strip().lower() + "%")) - if not include_zero: - # count only names that appear in counts subquery - total_q = total_q.where( - Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) - ) - - rows = (await session.execute(q.limit(limit).offset(offset))).all() - total = (await session.execute(total_q)).scalar_one() - - # Normalize counts to int for SQLite/Postgres consistency - rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] - return rows_norm, int(total or 0) - - -async def 
add_tags_to_asset_info( - session: AsyncSession, - *, - asset_info_id: str, - tags: Sequence[str], - origin: str = "manual", - create_if_missing: bool = True, - asset_info_row: Any = None, -) -> dict: - """Adds tags to an AssetInfo. - If create_if_missing=True, missing tag rows are created as 'user'. - Returns: {"added": [...], "already_present": [...], "total_tags": [...]} - """ - if not asset_info_row: - info = await session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - norm = normalize_tags(tags) - if not norm: - total = await get_asset_tags(session, asset_info_id=asset_info_id) - return {"added": [], "already_present": [], "total_tags": total} - - # Ensure tag rows exist if requested. - if create_if_missing: - await _helpers.ensure_tags_exist(session, norm, tag_type="user") - - # Snapshot current links - current = { - tag_name - for (tag_name,) in ( - await session.execute( - sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - } - - want = set(norm) - to_add = sorted(want - current) - - if to_add: - async with session.begin_nested() as nested: - try: - session.add_all( - [ - AssetInfoTag( - asset_info_id=asset_info_id, - tag_name=t, - origin=origin, - added_at=utcnow(), - ) - for t in to_add - ] - ) - await session.flush() - except IntegrityError: - await nested.rollback() - - after = set(await get_asset_tags(session, asset_info_id=asset_info_id)) - return { - "added": sorted(((after - current) & want)), - "already_present": sorted(want & current), - "total_tags": sorted(after), - } - - -async def remove_tags_from_asset_info( - session: AsyncSession, - *, - asset_info_id: str, - tags: Sequence[str], -) -> dict: - """Removes tags from an AssetInfo. - Returns: {"removed": [...], "not_present": [...], "total_tags": [...]} - """ - info = await session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") - - norm = normalize_tags(tags) - if not norm: - total = await get_asset_tags(session, asset_info_id=asset_info_id) - return {"removed": [], "not_present": [], "total_tags": total} - - existing = { - tag_name - for (tag_name,) in ( - await session.execute( - sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) - ) - ).all() - } - - to_remove = sorted(set(t for t in norm if t in existing)) - not_present = sorted(set(t for t in norm if t not in existing)) - - if to_remove: - await session.execute( - delete(AssetInfoTag) - .where( - AssetInfoTag.asset_info_id == asset_info_id, - AssetInfoTag.tag_name.in_(to_remove), - ) - ) - await session.flush() - - total = await get_asset_tags(session, asset_info_id=asset_info_id) - return {"removed": to_remove, "not_present": not_present, "total_tags": total} - - -async def add_missing_tag_for_asset_hash( - session: AsyncSession, - *, - asset_hash: str, - origin: str = "automatic", -) -> int: - """Ensure every AssetInfo referencing asset_hash has the 'missing' tag. - Returns number of AssetInfos newly tagged. 
- """ - ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_hash == asset_hash))).scalars().all() - if not ids: - return 0 - - existing = { - asset_info_id - for (asset_info_id,) in ( - await session.execute( - select(AssetInfoTag.asset_info_id).where( - AssetInfoTag.asset_info_id.in_(ids), - AssetInfoTag.tag_name == "missing", - ) - ) - ).all() - } - to_add = [i for i in ids if i not in existing] - if not to_add: - return 0 - - now = utcnow() - session.add_all( - [ - AssetInfoTag(asset_info_id=i, tag_name="missing", origin=origin, added_at=now) - for i in to_add - ] - ) - await session.flush() - return len(to_add) - - -async def remove_missing_tag_for_asset_hash( - session: AsyncSession, - *, - asset_hash: str, -) -> int: - """Remove the 'missing' tag from every AssetInfo referencing asset_hash. - Returns number of link rows removed. - """ - ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_hash == asset_hash))).scalars().all() - if not ids: - return 0 - - res = await session.execute( - delete(AssetInfoTag).where( - AssetInfoTag.asset_info_id.in_(ids), - AssetInfoTag.tag_name == "missing", - ) - ) - await session.flush() - return int(res.rowcount or 0) - - -async def list_cache_states_with_asset_under_prefixes( - session: AsyncSession, - *, - prefixes: Sequence[str], -) -> list[tuple[AssetCacheState, int]]: - """Return (AssetCacheState, size_bytes) tuples for rows whose file_path starts with any of the absolute prefixes.""" - if not prefixes: - return [] - - conds = [] - for p in prefixes: - if not p: - continue - base = os.path.abspath(p) - if not base.endswith(os.sep): - base = base + os.sep - conds.append(AssetCacheState.file_path.like(base + "%")) - - if not conds: - return [] - - rows = ( - await session.execute( - select(AssetCacheState, Asset.size_bytes) - .join(Asset, Asset.hash == AssetCacheState.asset_hash) - .where(sa.or_(*conds)) - .order_by(AssetCacheState.id.asc()) - ) - ).all() - return [(r[0], int(r[1] or 0)) for r in rows] diff --git a/app/database/services/__init__.py b/app/database/services/__init__.py new file mode 100644 index 000000000..aed8815a6 --- /dev/null +++ b/app/database/services/__init__.py @@ -0,0 +1,56 @@ +from .content import ( + check_fs_asset_exists_quick, + compute_hash_and_dedup_for_cache_state, + ensure_seed_for_path, + ingest_fs_asset, + list_cache_states_with_asset_under_prefixes, + list_unhashed_candidates_under_prefixes, + list_verify_candidates_under_prefixes, + redirect_all_references_then_delete_asset, + touch_asset_infos_by_fs_path, +) +from .info import ( + add_tags_to_asset_info, + create_asset_info_for_existing_asset, + delete_asset_info_by_id, + fetch_asset_info_and_asset, + fetch_asset_info_asset_and_tags, + get_asset_tags, + list_asset_infos_page, + list_tags_with_usage, + remove_tags_from_asset_info, + replace_asset_info_metadata_projection, + set_asset_info_preview, + set_asset_info_tags, + touch_asset_info_by_id, + update_asset_info_full, +) +from .queries import ( + asset_exists_by_hash, + asset_info_exists_for_asset_id, + get_asset_by_hash, + get_asset_info_by_id, + get_cache_state_by_asset_id, + list_cache_states_by_asset_id, +) + +__all__ = [ + # queries + "asset_exists_by_hash", "get_asset_by_hash", "get_asset_info_by_id", "asset_info_exists_for_asset_id", + "get_cache_state_by_asset_id", + "list_cache_states_by_asset_id", + # info + "list_asset_infos_page", "create_asset_info_for_existing_asset", "set_asset_info_tags", + "update_asset_info_full", 
"replace_asset_info_metadata_projection", + "touch_asset_info_by_id", "delete_asset_info_by_id", + "add_tags_to_asset_info", "remove_tags_from_asset_info", + "get_asset_tags", "list_tags_with_usage", "set_asset_info_preview", + "fetch_asset_info_and_asset", "fetch_asset_info_asset_and_tags", + # content + "check_fs_asset_exists_quick", "ensure_seed_for_path", + "redirect_all_references_then_delete_asset", + "compute_hash_and_dedup_for_cache_state", + "list_unhashed_candidates_under_prefixes", "list_verify_candidates_under_prefixes", + "ingest_fs_asset", "touch_asset_infos_by_fs_path", + "list_cache_states_with_asset_under_prefixes", +] diff --git a/app/database/services/content.py b/app/database/services/content.py new file mode 100644 index 000000000..6cf440342 --- /dev/null +++ b/app/database/services/content.py @@ -0,0 +1,746 @@ +import logging +import os +from datetime import datetime +from typing import Any, Optional, Sequence, Union + +import sqlalchemy as sa +from sqlalchemy import select +from sqlalchemy.dialects import postgresql as d_pg +from sqlalchemy.dialects import sqlite as d_sqlite +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import noload + +from ..._assets_helpers import compute_model_relative_filename, normalize_tags +from ...storage import hashing as hashing_mod +from ..helpers import ( + ensure_tags_exist, + remove_missing_tag_for_asset_id, +) +from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, Tag +from ..timeutil import utcnow +from .info import replace_asset_info_metadata_projection + + +async def check_fs_asset_exists_quick( + session: AsyncSession, + *, + file_path: str, + size_bytes: Optional[int] = None, + mtime_ns: Optional[int] = None, +) -> bool: + """Return True if a cache row exists for this absolute path and (optionally) mtime/size match.""" + locator = os.path.abspath(file_path) + + stmt = ( + sa.select(sa.literal(True)) + .select_from(AssetCacheState) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(AssetCacheState.file_path == locator) + .limit(1) + ) + + conds = [] + if mtime_ns is not None: + conds.append(AssetCacheState.mtime_ns == int(mtime_ns)) + if size_bytes is not None: + conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes))) + + if conds: + stmt = stmt.where(*conds) + + row = (await session.execute(stmt)).first() + return row is not None + + +async def ensure_seed_for_path( + session: AsyncSession, + *, + abs_path: str, + size_bytes: int, + mtime_ns: int, + info_name: str, + tags: Sequence[str], + owner_id: str = "", +) -> str: + """Ensure: Asset(hash=NULL), AssetCacheState(file_path), and AssetInfo exist for the path. 
Returns asset_id.""" + locator = os.path.abspath(abs_path) + now = utcnow() + + state = ( + await session.execute( + sa.select(AssetCacheState, Asset) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(AssetCacheState.file_path == locator) + .limit(1) + ) + ).first() + if state: + state_row: AssetCacheState = state[0] + asset_row: Asset = state[1] + changed = state_row.mtime_ns is None or int(state_row.mtime_ns) != int(mtime_ns) + if changed: + state_row.mtime_ns = int(mtime_ns) + state_row.needs_verify = True + if asset_row.size_bytes == 0 and size_bytes > 0: + asset_row.size_bytes = int(size_bytes) + return asset_row.id + + # Create new asset (hash=NULL) + asset = Asset(hash=None, size_bytes=int(size_bytes), mime_type=None, created_at=now) + session.add(asset) + await session.flush() # to get id + + cs = AssetCacheState(asset_id=asset.id, file_path=locator, mtime_ns=int(mtime_ns), needs_verify=False) + session.add(cs) + + info = AssetInfo( + owner_id=owner_id, + name=info_name, + asset_id=asset.id, + preview_id=None, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(info) + await session.flush() + + # Attach tags + want = normalize_tags(tags) + if want: + await ensure_tags_exist(session, want, tag_type="user") + session.add_all([ + AssetInfoTag(asset_info_id=info.id, tag_name=t, origin="automatic", added_at=now) + for t in want + ]) + + await session.flush() + return asset.id + + +async def redirect_all_references_then_delete_asset( + session: AsyncSession, + *, + duplicate_asset_id: str, + canonical_asset_id: str, +) -> None: + """ + Safely migrate all references from duplicate_asset_id to canonical_asset_id. + + - If an AssetInfo for (owner_id, name) already exists on the canonical asset, + merge tags, metadata, times, and preview, then delete the duplicate AssetInfo. + - Otherwise, simply repoint the AssetInfo.asset_id. + - Always retarget AssetCacheState rows. + - Finally delete the duplicate Asset row. + """ + if duplicate_asset_id == canonical_asset_id: + return + + # 1) Migrate AssetInfo rows one-by-one to avoid UNIQUE conflicts. 
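+    #    An AssetInfo on the canonical asset with the same (owner_id, name) is a
+    #    collision: such rows are merged below; all other rows are simply repointed.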
+ dup_infos = ( + await session.execute( + select(AssetInfo).options(noload(AssetInfo.tags)).where(AssetInfo.asset_id == duplicate_asset_id) + ) + ).unique().scalars().all() + + for info in dup_infos: + # Try to find an existing collision on canonical + existing = ( + await session.execute( + select(AssetInfo) + .options(noload(AssetInfo.tags)) + .where( + AssetInfo.asset_id == canonical_asset_id, + AssetInfo.owner_id == info.owner_id, + AssetInfo.name == info.name, + ) + .limit(1) + ) + ).unique().scalars().first() + + if existing: + # Merge metadata (prefer existing keys, fill gaps from duplicate) + merged_meta = dict(existing.user_metadata or {}) + other_meta = info.user_metadata or {} + for k, v in other_meta.items(): + if k not in merged_meta: + merged_meta[k] = v + if merged_meta != (existing.user_metadata or {}): + await replace_asset_info_metadata_projection( + session, + asset_info_id=existing.id, + user_metadata=merged_meta, + ) + + # Merge tags (union) + existing_tags = { + t for (t,) in ( + await session.execute( + select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == existing.id) + ) + ).all() + } + from_tags = { + t for (t,) in ( + await session.execute( + select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == info.id) + ) + ).all() + } + to_add = sorted(from_tags - existing_tags) + if to_add: + await ensure_tags_exist(session, to_add, tag_type="user") + now = utcnow() + session.add_all([ + AssetInfoTag(asset_info_id=existing.id, tag_name=t, origin="automatic", added_at=now) + for t in to_add + ]) + await session.flush() + + # Merge preview and times + if existing.preview_id is None and info.preview_id is not None: + existing.preview_id = info.preview_id + if info.last_access_time and ( + existing.last_access_time is None or info.last_access_time > existing.last_access_time + ): + existing.last_access_time = info.last_access_time + existing.updated_at = utcnow() + await session.flush() + + # Delete the duplicate AssetInfo (cascades will clean its tags/meta) + await session.delete(info) + await session.flush() + else: + # Simple retarget + info.asset_id = canonical_asset_id + info.updated_at = utcnow() + await session.flush() + + # 2) Repoint cache states and previews + await session.execute( + sa.update(AssetCacheState) + .where(AssetCacheState.asset_id == duplicate_asset_id) + .values(asset_id=canonical_asset_id) + ) + await session.execute( + sa.update(AssetInfo) + .where(AssetInfo.preview_id == duplicate_asset_id) + .values(preview_id=canonical_asset_id) + ) + + # 3) Remove duplicate Asset + dup = await session.get(Asset, duplicate_asset_id) + if dup: + await session.delete(dup) + await session.flush() + + +async def compute_hash_and_dedup_for_cache_state( + session: AsyncSession, + *, + state_id: int, +) -> Optional[str]: + """ + Compute hash for the given cache state, deduplicate, and settle verify cases. + + Returns the asset_id that this state ends up pointing to, or None if file disappeared. + """ + state = await session.get(AssetCacheState, state_id) + if not state: + return None + + path = state.file_path + try: + if not os.path.isfile(path): + # File vanished: drop the state. If the Asset was a seed (hash NULL) + # and has no other states, drop the Asset too. 
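+            # Hashed Assets are left in place here; only unhashed seed rows with
+            # no remaining cache states are removed.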
+ asset = await session.get(Asset, state.asset_id) + await session.delete(state) + await session.flush() + + if asset and asset.hash is None: + remaining = ( + await session.execute( + sa.select(sa.func.count()) + .select_from(AssetCacheState) + .where(AssetCacheState.asset_id == asset.id) + ) + ).scalar_one() + if int(remaining or 0) == 0: + await session.delete(asset) + await session.flush() + return None + + digest = await hashing_mod.blake3_hash(path) + new_hash = f"blake3:{digest}" + + st = os.stat(path, follow_symlinks=True) + new_size = int(st.st_size) + mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + + # Current asset of this state + this_asset = await session.get(Asset, state.asset_id) + + # If the state got orphaned somehow (race), just reattach appropriately. + if not this_asset: + canonical = ( + await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1)) + ).scalars().first() + if canonical: + state.asset_id = canonical.id + else: + now = utcnow() + new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now) + session.add(new_asset) + await session.flush() + state.asset_id = new_asset.id + state.mtime_ns = mtime_ns + state.needs_verify = False + try: + await remove_missing_tag_for_asset_id(session, asset_id=state.asset_id) + except Exception: + pass + await session.flush() + return state.asset_id + + # 1) Seed asset case (hash is NULL): claim or merge into canonical + if this_asset.hash is None: + canonical = ( + await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1)) + ).scalars().first() + + if canonical and canonical.id != this_asset.id: + # Merge seed asset into canonical (safe, collision-aware) + await redirect_all_references_then_delete_asset( + session, + duplicate_asset_id=this_asset.id, + canonical_asset_id=canonical.id, + ) + # Refresh state after the merge + state = await session.get(AssetCacheState, state_id) + if state: + state.mtime_ns = mtime_ns + state.needs_verify = False + try: + await remove_missing_tag_for_asset_id(session, asset_id=canonical.id) + except Exception: + pass + await session.flush() + return canonical.id + + # No canonical: try to claim the hash; handle races with a SAVEPOINT + try: + async with session.begin_nested(): + this_asset.hash = new_hash + if int(this_asset.size_bytes or 0) == 0 and new_size > 0: + this_asset.size_bytes = new_size + await session.flush() + except IntegrityError: + # Someone else claimed it concurrently; fetch canonical and merge + canonical = ( + await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1)) + ).scalars().first() + if canonical and canonical.id != this_asset.id: + await redirect_all_references_then_delete_asset( + session, + duplicate_asset_id=this_asset.id, + canonical_asset_id=canonical.id, + ) + state = await session.get(AssetCacheState, state_id) + if state: + state.mtime_ns = mtime_ns + state.needs_verify = False + try: + await remove_missing_tag_for_asset_id(session, asset_id=canonical.id) + except Exception: + pass + await session.flush() + return canonical.id + # If we got here, the integrity error was not about hash uniqueness + raise + + # Claimed successfully + state.mtime_ns = mtime_ns + state.needs_verify = False + try: + await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id) + except Exception: + pass + await session.flush() + return this_asset.id + + # 2) Verify case for hashed assets + if this_asset.hash == new_hash: + # Content unchanged; tidy up sizes/mtime + 
if int(this_asset.size_bytes or 0) == 0 and new_size > 0: + this_asset.size_bytes = new_size + state.mtime_ns = mtime_ns + state.needs_verify = False + try: + await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id) + except Exception: + pass + await session.flush() + return this_asset.id + + # Content changed on this path only: retarget THIS state, do not move AssetInfo rows + canonical = ( + await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1)) + ).scalars().first() + if canonical: + target_id = canonical.id + else: + now = utcnow() + new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now) + session.add(new_asset) + await session.flush() + target_id = new_asset.id + + state.asset_id = target_id + state.mtime_ns = mtime_ns + state.needs_verify = False + try: + await remove_missing_tag_for_asset_id(session, asset_id=target_id) + except Exception: + pass + await session.flush() + return target_id + + except Exception: + # Propagate; caller records the error and continues the worker. + raise + + +async def list_unhashed_candidates_under_prefixes( + session: AsyncSession, *, prefixes: Sequence[str] +) -> list[int]: + if not prefixes: + return [] + + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + conds.append(AssetCacheState.file_path.like(base + "%")) + + rows = ( + await session.execute( + sa.select(AssetCacheState.id) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(Asset.hash.is_(None)) + .where(sa.or_(*conds)) + .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc()) + ) + ).scalars().all() + seen = set() + result: list[int] = [] + for sid in rows: + st = await session.get(AssetCacheState, sid) + if st and st.asset_id not in seen: + seen.add(st.asset_id) + result.append(sid) + return result + + +async def list_verify_candidates_under_prefixes( + session: AsyncSession, *, prefixes: Sequence[str] +) -> Union[list[int], Sequence[int]]: + if not prefixes: + return [] + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + conds.append(AssetCacheState.file_path.like(base + "%")) + + return ( + await session.execute( + sa.select(AssetCacheState.id) + .where(AssetCacheState.needs_verify.is_(True)) + .where(sa.or_(*conds)) + .order_by(AssetCacheState.id.asc()) + ) + ).scalars().all() + + +async def ingest_fs_asset( + session: AsyncSession, + *, + asset_hash: str, + abs_path: str, + size_bytes: int, + mtime_ns: int, + mime_type: Optional[str] = None, + info_name: Optional[str] = None, + owner_id: str = "", + preview_id: Optional[str] = None, + user_metadata: Optional[dict] = None, + tags: Sequence[str] = (), + tag_origin: str = "manual", + require_existing_tags: bool = False, +) -> dict: + """ + Idempotently upsert: + - Asset by content hash (create if missing) + - AssetCacheState(file_path) pointing to asset_id + - Optionally AssetInfo + tag links and metadata projection + Returns flags and ids. 
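+    The returned dict reports what was created or updated, e.g.:
+        {"asset_created": True, "asset_updated": False, "state_created": True,
+         "state_updated": False, "asset_info_id": "<uuid or None>"}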
+ """ + locator = os.path.abspath(abs_path) + now = utcnow() + + if preview_id: + if not await session.get(Asset, preview_id): + preview_id = None + + out: dict[str, Any] = { + "asset_created": False, + "asset_updated": False, + "state_created": False, + "state_updated": False, + "asset_info_id": None, + } + + # 1) Asset by hash + asset = ( + await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + ).scalars().first() + if not asset: + async with session.begin_nested(): + asset = Asset(hash=asset_hash, size_bytes=int(size_bytes), mime_type=mime_type, created_at=now) + session.add(asset) + await session.flush() + out["asset_created"] = True + else: + changed = False + if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: + asset.size_bytes = int(size_bytes) + changed = True + if mime_type and asset.mime_type != mime_type: + asset.mime_type = mime_type + changed = True + if changed: + out["asset_updated"] = True + + # 2) AssetCacheState upsert by file_path (unique) + vals = { + "asset_id": asset.id, + "file_path": locator, + "mtime_ns": int(mtime_ns), + } + dialect = session.bind.dialect.name + if dialect == "sqlite": + ins = ( + d_sqlite.insert(AssetCacheState) + .values(**vals) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + ) + elif dialect == "postgresql": + ins = ( + d_pg.insert(AssetCacheState) + .values(**vals) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + ) + else: + raise NotImplementedError(f"Unsupported database dialect: {dialect}") + + res = await session.execute(ins) + if int(res.rowcount or 0) > 0: + out["state_created"] = True + else: + upd = ( + sa.update(AssetCacheState) + .where(AssetCacheState.file_path == locator) + .where( + sa.or_( + AssetCacheState.asset_id != asset.id, + AssetCacheState.mtime_ns.is_(None), + AssetCacheState.mtime_ns != int(mtime_ns), + ) + ) + .values(asset_id=asset.id, mtime_ns=int(mtime_ns)) + ) + res2 = await session.execute(upd) + if int(res2.rowcount or 0) > 0: + out["state_updated"] = True + + # 3) Optional AssetInfo + tags + metadata + if info_name: + # upsert by (asset_id, owner_id, name) + try: + async with session.begin_nested(): + info = AssetInfo( + owner_id=owner_id, + name=info_name, + asset_id=asset.id, + preview_id=preview_id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(info) + await session.flush() + out["asset_info_id"] = info.id + except IntegrityError: + pass + + existing_info = ( + await session.execute( + select(AssetInfo) + .where( + AssetInfo.asset_id == asset.id, + AssetInfo.name == info_name, + (AssetInfo.owner_id == owner_id), + ) + .limit(1) + ) + ).unique().scalar_one_or_none() + if not existing_info: + raise RuntimeError("Failed to update or insert AssetInfo.") + + if preview_id and existing_info.preview_id != preview_id: + existing_info.preview_id = preview_id + + existing_info.updated_at = now + if existing_info.last_access_time < now: + existing_info.last_access_time = now + await session.flush() + out["asset_info_id"] = existing_info.id + + norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()] + if norm and out["asset_info_id"] is not None: + if not require_existing_tags: + await ensure_tags_exist(session, norm, tag_type="user") + + existing_tag_names = set( + name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all() + ) + missing = [t for t in norm if t not in existing_tag_names] + if missing and require_existing_tags: + raise ValueError(f"Unknown 
tags: {missing}") + + existing_links = set( + tag_name + for (tag_name,) in ( + await session.execute( + select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"]) + ) + ).all() + ) + to_add = [t for t in norm if t in existing_tag_names and t not in existing_links] + if to_add: + session.add_all( + [ + AssetInfoTag( + asset_info_id=out["asset_info_id"], + tag_name=t, + origin=tag_origin, + added_at=now, + ) + for t in to_add + ] + ) + await session.flush() + + # metadata["filename"] hack + if out["asset_info_id"] is not None: + primary_path = ( + await session.execute( + select(AssetCacheState.file_path) + .where(AssetCacheState.asset_id == asset.id) + .order_by(AssetCacheState.id.asc()) + .limit(1) + ) + ).scalars().first() + computed_filename = compute_model_relative_filename(primary_path) if primary_path else None + + current_meta = existing_info.user_metadata or {} + new_meta = dict(current_meta) + if user_metadata is not None: + for k, v in user_metadata.items(): + new_meta[k] = v + if computed_filename: + new_meta["filename"] = computed_filename + + if new_meta != current_meta: + await replace_asset_info_metadata_projection( + session, + asset_info_id=out["asset_info_id"], + user_metadata=new_meta, + ) + + try: + await remove_missing_tag_for_asset_id(session, asset_id=asset.id) + except Exception: + logging.exception("Failed to clear 'missing' tag for asset %s", asset.id) + return out + + +async def touch_asset_infos_by_fs_path( + session: AsyncSession, + *, + file_path: str, + ts: Optional[datetime] = None, + only_if_newer: bool = True, +) -> int: + locator = os.path.abspath(file_path) + ts = ts or utcnow() + + stmt = sa.update(AssetInfo).where( + sa.exists( + sa.select(sa.literal(1)) + .select_from(AssetCacheState) + .where( + AssetCacheState.asset_id == AssetInfo.asset_id, + AssetCacheState.file_path == locator, + ) + ) + ) + + if only_if_newer: + stmt = stmt.where( + sa.or_( + AssetInfo.last_access_time.is_(None), + AssetInfo.last_access_time < ts, + ) + ) + + stmt = stmt.values(last_access_time=ts) + + res = await session.execute(stmt) + return int(res.rowcount or 0) + + +async def list_cache_states_with_asset_under_prefixes( + session: AsyncSession, + *, + prefixes: Sequence[str], +) -> list[tuple[AssetCacheState, Optional[str], int]]: + """Return (AssetCacheState, asset_hash, size_bytes) for rows under any prefix.""" + if not prefixes: + return [] + + conds = [] + for p in prefixes: + if not p: + continue + base = os.path.abspath(p) + if not base.endswith(os.sep): + base = base + os.sep + conds.append(AssetCacheState.file_path.like(base + "%")) + + if not conds: + return [] + + rows = ( + await session.execute( + select(AssetCacheState, Asset.hash, Asset.size_bytes) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(sa.or_(*conds)) + .order_by(AssetCacheState.id.asc()) + ) + ).all() + return [(r[0], r[1], int(r[2] or 0)) for r in rows] diff --git a/app/database/services/info.py b/app/database/services/info.py new file mode 100644 index 000000000..e3da1bc8e --- /dev/null +++ b/app/database/services/info.py @@ -0,0 +1,579 @@ +from collections import defaultdict +from datetime import datetime +from typing import Any, Optional, Sequence + +import sqlalchemy as sa +from sqlalchemy import delete, func, select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import contains_eager, noload + +from ..._assets_helpers import compute_model_relative_filename, normalize_tags +from 
..helpers import ( + apply_metadata_filter, + apply_tag_filters, + ensure_tags_exist, + project_kv, + visible_owner_clause, +) +from ..models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag +from ..timeutil import utcnow +from .queries import get_asset_by_hash, get_cache_state_by_asset_id + + +async def list_asset_infos_page( + session: AsyncSession, + *, + owner_id: str = "", + include_tags: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + name_contains: Optional[str] = None, + metadata_filter: Optional[dict] = None, + limit: int = 20, + offset: int = 0, + sort: str = "created_at", + order: str = "desc", +) -> tuple[list[AssetInfo], dict[str, list[str]], int]: + base = ( + select(AssetInfo) + .join(Asset, Asset.id == AssetInfo.asset_id) + .options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags)) + .where(visible_owner_clause(owner_id)) + ) + + if name_contains: + base = base.where(AssetInfo.name.ilike(f"%{name_contains}%")) + + base = apply_tag_filters(base, include_tags, exclude_tags) + base = apply_metadata_filter(base, metadata_filter) + + sort = (sort or "created_at").lower() + order = (order or "desc").lower() + sort_map = { + "name": AssetInfo.name, + "created_at": AssetInfo.created_at, + "updated_at": AssetInfo.updated_at, + "last_access_time": AssetInfo.last_access_time, + "size": Asset.size_bytes, + } + sort_col = sort_map.get(sort, AssetInfo.created_at) + sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() + + base = base.order_by(sort_exp).limit(limit).offset(offset) + + count_stmt = ( + select(func.count()) + .select_from(AssetInfo) + .join(Asset, Asset.id == AssetInfo.asset_id) + .where(visible_owner_clause(owner_id)) + ) + if name_contains: + count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%")) + count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = apply_metadata_filter(count_stmt, metadata_filter) + + total = int((await session.execute(count_stmt)).scalar_one() or 0) + + infos = (await session.execute(base)).unique().scalars().all() + + id_list: list[str] = [i.id for i in infos] + tag_map: dict[str, list[str]] = defaultdict(list) + if id_list: + rows = await session.execute( + select(AssetInfoTag.asset_info_id, Tag.name) + .join(Tag, Tag.name == AssetInfoTag.tag_name) + .where(AssetInfoTag.asset_info_id.in_(id_list)) + ) + for aid, tag_name in rows.all(): + tag_map[aid].append(tag_name) + + return infos, tag_map, total + + +async def fetch_asset_info_and_asset( + session: AsyncSession, + *, + asset_info_id: str, + owner_id: str = "", +) -> Optional[tuple[AssetInfo, Asset]]: + stmt = ( + select(AssetInfo, Asset) + .join(Asset, Asset.id == AssetInfo.asset_id) + .where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) + .limit(1) + .options(noload(AssetInfo.tags)) + ) + row = await session.execute(stmt) + pair = row.first() + if not pair: + return None + return pair[0], pair[1] + + +async def fetch_asset_info_asset_and_tags( + session: AsyncSession, + *, + asset_info_id: str, + owner_id: str = "", +) -> Optional[tuple[AssetInfo, Asset, list[str]]]: + stmt = ( + select(AssetInfo, Asset, Tag.name) + .join(Asset, Asset.id == AssetInfo.asset_id) + .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True) + .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) + .where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) + .options(noload(AssetInfo.tags)) + .order_by(Tag.name.asc()) + ) + + rows = 
(await session.execute(stmt)).all() + if not rows: + return None + + first_info, first_asset, _ = rows[0] + tags: list[str] = [] + seen: set[str] = set() + for _info, _asset, tag_name in rows: + if tag_name and tag_name not in seen: + seen.add(tag_name) + tags.append(tag_name) + return first_info, first_asset, tags + + +async def create_asset_info_for_existing_asset( + session: AsyncSession, + *, + asset_hash: str, + name: str, + user_metadata: Optional[dict] = None, + tags: Optional[Sequence[str]] = None, + tag_origin: str = "manual", + owner_id: str = "", +) -> AssetInfo: + """Create or return an existing AssetInfo for an Asset identified by asset_hash.""" + now = utcnow() + asset = await get_asset_by_hash(session, asset_hash=asset_hash) + if not asset: + raise ValueError(f"Unknown asset hash {asset_hash}") + + info = AssetInfo( + owner_id=owner_id, + name=name, + asset_id=asset.id, + preview_id=None, + created_at=now, + updated_at=now, + last_access_time=now, + ) + try: + async with session.begin_nested(): + session.add(info) + await session.flush() + except IntegrityError: + existing = ( + await session.execute( + select(AssetInfo) + .options(noload(AssetInfo.tags)) + .where( + AssetInfo.asset_id == asset.id, + AssetInfo.name == name, + AssetInfo.owner_id == owner_id, + ) + .limit(1) + ) + ).unique().scalars().first() + if not existing: + raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.") + return existing + + # metadata["filename"] hack + new_meta = dict(user_metadata or {}) + computed_filename = None + try: + state = await get_cache_state_by_asset_id(session, asset_id=asset.id) + if state and state.file_path: + computed_filename = compute_model_relative_filename(state.file_path) + except Exception: + computed_filename = None + if computed_filename: + new_meta["filename"] = computed_filename + if new_meta: + await replace_asset_info_metadata_projection( + session, + asset_info_id=info.id, + user_metadata=new_meta, + ) + + if tags is not None: + await set_asset_info_tags( + session, + asset_info_id=info.id, + tags=tags, + origin=tag_origin, + ) + return info + + +async def set_asset_info_tags( + session: AsyncSession, + *, + asset_info_id: str, + tags: Sequence[str], + origin: str = "manual", +) -> dict: + desired = normalize_tags(tags) + + current = set( + tag_name for (tag_name,) in ( + await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)) + ).all() + ) + + to_add = [t for t in desired if t not in current] + to_remove = [t for t in current if t not in desired] + + if to_add: + await ensure_tags_exist(session, to_add, tag_type="user") + session.add_all([ + AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow()) + for t in to_add + ]) + await session.flush() + + if to_remove: + await session.execute( + delete(AssetInfoTag) + .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove)) + ) + await session.flush() + + return {"added": to_add, "removed": to_remove, "total": desired} + + +async def update_asset_info_full( + session: AsyncSession, + *, + asset_info_id: str, + name: Optional[str] = None, + tags: Optional[Sequence[str]] = None, + user_metadata: Optional[dict] = None, + tag_origin: str = "manual", + asset_info_row: Any = None, +) -> AssetInfo: + if not asset_info_row: + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + else: + info = asset_info_row + + touched 
= False + if name is not None and name != info.name: + info.name = name + touched = True + + computed_filename = None + try: + state = await get_cache_state_by_asset_id(session, asset_id=info.asset_id) + if state and state.file_path: + computed_filename = compute_model_relative_filename(state.file_path) + except Exception: + computed_filename = None + + if user_metadata is not None: + new_meta = dict(user_metadata) + if computed_filename: + new_meta["filename"] = computed_filename + await replace_asset_info_metadata_projection( + session, asset_info_id=asset_info_id, user_metadata=new_meta + ) + touched = True + else: + if computed_filename: + current_meta = info.user_metadata or {} + if current_meta.get("filename") != computed_filename: + new_meta = dict(current_meta) + new_meta["filename"] = computed_filename + await replace_asset_info_metadata_projection( + session, asset_info_id=asset_info_id, user_metadata=new_meta + ) + touched = True + + if tags is not None: + await set_asset_info_tags( + session, + asset_info_id=asset_info_id, + tags=tags, + origin=tag_origin, + ) + touched = True + + if touched and user_metadata is None: + info.updated_at = utcnow() + await session.flush() + + return info + + +async def replace_asset_info_metadata_projection( + session: AsyncSession, + *, + asset_info_id: str, + user_metadata: Optional[dict], +) -> None: + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + info.user_metadata = user_metadata or {} + info.updated_at = utcnow() + await session.flush() + + await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)) + await session.flush() + + if not user_metadata: + return + + rows: list[AssetInfoMeta] = [] + for k, v in user_metadata.items(): + for r in project_kv(k, v): + rows.append( + AssetInfoMeta( + asset_info_id=asset_info_id, + key=r["key"], + ordinal=int(r["ordinal"]), + val_str=r.get("val_str"), + val_num=r.get("val_num"), + val_bool=r.get("val_bool"), + val_json=r.get("val_json"), + ) + ) + if rows: + session.add_all(rows) + await session.flush() + + +async def touch_asset_info_by_id( + session: AsyncSession, + *, + asset_info_id: str, + ts: Optional[datetime] = None, + only_if_newer: bool = True, +) -> int: + ts = ts or utcnow() + stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id) + if only_if_newer: + stmt = stmt.where( + sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) + ) + stmt = stmt.values(last_access_time=ts) + res = await session.execute(stmt) + return int(res.rowcount or 0) + + +async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool: + res = await session.execute(delete(AssetInfo).where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + )) + return bool(res.rowcount) + + +async def add_tags_to_asset_info( + session: AsyncSession, + *, + asset_info_id: str, + tags: Sequence[str], + origin: str = "manual", + create_if_missing: bool = True, + asset_info_row: Any = None, +) -> dict: + if not asset_info_row: + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = await get_asset_tags(session, asset_info_id=asset_info_id) + return {"added": [], "already_present": [], "total_tags": total} + + if create_if_missing: + await ensure_tags_exist(session, norm, tag_type="user") + + current = { + 
tag_name + for (tag_name,) in ( + await session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + } + + want = set(norm) + to_add = sorted(want - current) + + if to_add: + async with session.begin_nested() as nested: + try: + session.add_all( + [ + AssetInfoTag( + asset_info_id=asset_info_id, + tag_name=t, + origin=origin, + added_at=utcnow(), + ) + for t in to_add + ] + ) + await session.flush() + except IntegrityError: + await nested.rollback() + + after = set(await get_asset_tags(session, asset_info_id=asset_info_id)) + return { + "added": sorted(((after - current) & want)), + "already_present": sorted(want & current), + "total_tags": sorted(after), + } + + +async def remove_tags_from_asset_info( + session: AsyncSession, + *, + asset_info_id: str, + tags: Sequence[str], +) -> dict: + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = await get_asset_tags(session, asset_info_id=asset_info_id) + return {"removed": [], "not_present": [], "total_tags": total} + + existing = { + tag_name + for (tag_name,) in ( + await session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + } + + to_remove = sorted(set(t for t in norm if t in existing)) + not_present = sorted(set(t for t in norm if t not in existing)) + + if to_remove: + await session.execute( + delete(AssetInfoTag) + .where( + AssetInfoTag.asset_info_id == asset_info_id, + AssetInfoTag.tag_name.in_(to_remove), + ) + ) + await session.flush() + + total = await get_asset_tags(session, asset_info_id=asset_info_id) + return {"removed": to_remove, "not_present": not_present, "total_tags": total} + + +async def list_tags_with_usage( + session: AsyncSession, + *, + prefix: Optional[str] = None, + limit: int = 100, + offset: int = 0, + include_zero: bool = True, + order: str = "count_desc", + owner_id: str = "", +) -> tuple[list[tuple[str, str, int]], int]: + counts_sq = ( + select( + AssetInfoTag.tag_name.label("tag_name"), + func.count(AssetInfoTag.asset_info_id).label("cnt"), + ) + .select_from(AssetInfoTag) + .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id) + .where(visible_owner_clause(owner_id)) + .group_by(AssetInfoTag.tag_name) + .subquery() + ) + + q = ( + select( + Tag.name, + Tag.tag_type, + func.coalesce(counts_sq.c.cnt, 0).label("count"), + ) + .select_from(Tag) + .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) + ) + + if prefix: + q = q.where(Tag.name.like(prefix.strip().lower() + "%")) + + if not include_zero: + q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) + + if order == "name_asc": + q = q.order_by(Tag.name.asc()) + else: + q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) + + total_q = select(func.count()).select_from(Tag) + if prefix: + total_q = total_q.where(Tag.name.like(prefix.strip().lower() + "%")) + if not include_zero: + total_q = total_q.where( + Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) + ) + + rows = (await session.execute(q.limit(limit).offset(offset))).all() + total = (await session.execute(total_q)).scalar_one() + + rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] + return rows_norm, int(total or 0) + + +async def get_asset_tags(session: AsyncSession, *, asset_info_id: str) -> list[str]: + return [ + tag_name + for (tag_name,) in ( + await 
session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + ] + + +async def set_asset_info_preview( + session: AsyncSession, + *, + asset_info_id: str, + preview_asset_id: Optional[str], +) -> None: + """Set or clear preview_id and bump updated_at. Raises on unknown IDs.""" + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + if preview_asset_id is None: + info.preview_id = None + else: + # validate preview asset exists + if not await session.get(Asset, preview_asset_id): + raise ValueError(f"Preview Asset {preview_asset_id} not found") + info.preview_id = preview_asset_id + + info.updated_at = utcnow() + await session.flush() diff --git a/app/database/services/queries.py b/app/database/services/queries.py new file mode 100644 index 000000000..81649b7f4 --- /dev/null +++ b/app/database/services/queries.py @@ -0,0 +1,59 @@ +from typing import Optional, Sequence, Union + +import sqlalchemy as sa +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from ..models import Asset, AssetCacheState, AssetInfo + + +async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool: + row = ( + await session.execute( + select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1) + ) + ).first() + return row is not None + + +async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Optional[Asset]: + return ( + await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + ).scalars().first() + + +async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) -> Optional[AssetInfo]: + return await session.get(AssetInfo, asset_info_id) + + +async def asset_info_exists_for_asset_id(session: AsyncSession, *, asset_id: str) -> bool: + q = ( + select(sa.literal(True)) + .select_from(AssetInfo) + .where(AssetInfo.asset_id == asset_id) + .limit(1) + ) + return (await session.execute(q)).first() is not None + + +async def get_cache_state_by_asset_id(session: AsyncSession, *, asset_id: str) -> Optional[AssetCacheState]: + return ( + await session.execute( + select(AssetCacheState) + .where(AssetCacheState.asset_id == asset_id) + .order_by(AssetCacheState.id.asc()) + .limit(1) + ) + ).scalars().first() + + +async def list_cache_states_by_asset_id( + session: AsyncSession, *, asset_id: str +) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]: + return ( + await session.execute( + select(AssetCacheState) + .where(AssetCacheState.asset_id == asset_id) + .order_by(AssetCacheState.id.asc()) + ) + ).scalars().all() diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 5e301b505..d814e453a 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -212,7 +212,6 @@ database_default_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") ) parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. 
for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.") -parser.add_argument("--enable-model-processing", action="store_true", help="Enable automatic processing of the model file, such as calculating hashes and populating the database.") parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.") if comfy.options.args_parsing: diff --git a/main.py b/main.py index 3485a7c76..db0ee04f5 100644 --- a/main.py +++ b/main.py @@ -279,11 +279,11 @@ def cleanup_temp(): shutil.rmtree(temp_dir, ignore_errors=True) async def setup_database(): - from app import init_db_engine, start_background_assets_scan + from app import init_db_engine, sync_seed_assets await init_db_engine() if not args.disable_assets_autoscan: - await start_background_assets_scan() + await sync_seed_assets(["models", "input", "output"]) def start_comfyui(asyncio_loop=None): diff --git a/server.py b/server.py index d3a0f8628..ddd188ebc 100644 --- a/server.py +++ b/server.py @@ -37,7 +37,7 @@ from app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes -from app.api.assets_routes import register_assets_system +from app import sync_seed_assets, register_assets_system from protocol import BinaryEventTypes async def send_socket_catch_exception(function, message): @@ -629,6 +629,7 @@ class PromptServer(): @routes.get("/object_info") async def get_object_info(request): + await sync_seed_assets(["models"]) with folder_paths.cache_helper: out = {} for x in nodes.NODE_CLASS_MAPPINGS: diff --git a/tests-assets/test_crud.py b/tests-assets/test_crud.py index 99ea329c5..1e5928150 100644 --- a/tests-assets/test_crud.py +++ b/tests-assets/test_crud.py @@ -118,6 +118,16 @@ async def test_head_asset_by_hash(http: aiohttp.ClientSession, api_base: str, se assert rh2.status == 404 +@pytest.mark.asyncio +async def test_head_asset_bad_hash_returns_400_and_no_body(http: aiohttp.ClientSession, api_base: str): + # Invalid format; handler returns a JSON error, but HEAD responses must not carry a payload. + # aiohttp exposes an empty body for HEAD, so validate status and that there is no payload. 
+ async with http.head(f"{api_base}/api/assets/hash/not_a_hash") as rh: + assert rh.status == 400 + body = await rh.read() + assert body == b"" + + @pytest.mark.asyncio async def test_delete_nonexistent_returns_404(http: aiohttp.ClientSession, api_base: str): bogus = str(uuid.uuid4()) @@ -166,12 +176,3 @@ async def test_update_requires_at_least_one_field(http: aiohttp.ClientSession, a body = await r.json() assert r.status == 400 assert body["error"]["code"] == "INVALID_BODY" - - -@pytest.mark.asyncio -async def test_head_asset_bad_hash(http: aiohttp.ClientSession, api_base: str): - # Invalid format - async with http.head(f"{api_base}/api/assets/hash/not_a_hash") as rh3: - jb = await rh3.json() - assert rh3.status == 400 - assert jb is None # HEAD request should not include "body" in response diff --git a/tests-assets/test_tags.py b/tests-assets/test_tags.py index bba91581f..aede764da 100644 --- a/tests-assets/test_tags.py +++ b/tests-assets/test_tags.py @@ -66,23 +66,32 @@ async def test_add_and_remove_tags(http: aiohttp.ClientSession, api_base: str, s async with http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add) as r1: b1 = await r1.json() assert r1.status == 200, b1 - # normalized and deduplicated - assert "newtag" in b1["added"] or "beta" in b1["added"] or "unit-tests" not in b1["added"] + # normalized, deduplicated; 'unit-tests' was already present from the seed + assert set(b1["added"]) == {"newtag", "beta"} + assert set(b1["already_present"]) == {"unit-tests"} + assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"] async with http.get(f"{api_base}/api/assets/{aid}") as rg: g = await rg.json() assert rg.status == 200 tags_now = set(g["tags"]) - assert "newtag" in tags_now - assert "beta" in tags_now + assert {"newtag", "beta"}.issubset(tags_now) # Remove a tag and a non-existent tag payload_del = {"tags": ["newtag", "does-not-exist"]} async with http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del) as r2: b2 = await r2.json() assert r2.status == 200 - assert "newtag" in b2["removed"] - assert "does-not-exist" in b2["not_present"] + assert set(b2["removed"]) == {"newtag"} + assert set(b2["not_present"]) == {"does-not-exist"} + + # Verify remaining tags after deletion + async with http.get(f"{api_base}/api/assets/{aid}") as rg2: + g2 = await rg2.json() + assert rg2.status == 200 + tags_later = set(g2["tags"]) + assert "newtag" not in tags_later + assert "beta" in tags_later # still present @pytest.mark.asyncio diff --git a/tests-assets/test_uploads.py b/tests-assets/test_uploads.py index 1d8df4e40..3bfb62ca4 100644 --- a/tests-assets/test_uploads.py +++ b/tests-assets/test_uploads.py @@ -206,7 +206,7 @@ async def test_upload_models_unknown_category(http: aiohttp.ClientSession, api_b body = await r.json() assert r.status == 400 assert body["error"]["code"] == "INVALID_BODY" - assert "unknown models category" in body["error"]["message"] or "unknown model category" in body["error"]["message"] + assert body["error"]["message"].startswith("unknown models category") @pytest.mark.asyncio