From f92307cd4c13a3b1981cc818fce0a56280d3e404 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 19 Aug 2025 19:56:59 +0300 Subject: [PATCH] dev: Everything is Assets --- alembic_db/versions/0001_assets.py | 158 +++++ alembic_db/versions/e9c714da8d57_init.py | 40 -- app/api/__init__.py | 0 app/api/assets_routes.py | 110 ++++ app/assets_manager.py | 148 +++++ app/database/__init__.py | 0 app/database/db.py | 293 +++++++-- app/database/models.py | 284 ++++++-- app/database/services.py | 683 ++++++++++++++++++++ app/model_processor.py | 331 ---------- app/storage/__init__.py | 0 app/storage/hashing.py | 72 +++ comfy/asset_management.py | 110 ---- comfy/cli_args.py | 2 +- comfy/utils.py | 6 - comfy_extras/nodes_assets_test.py | 56 -- main.py | 9 +- nodes.py | 7 +- requirements.txt | 1 + server.py | 2 + tests-unit/app_test/model_manager_test.py | 62 -- tests-unit/app_test/model_processor_test.py | 253 -------- 22 files changed, 1650 insertions(+), 977 deletions(-) create mode 100644 alembic_db/versions/0001_assets.py delete mode 100644 alembic_db/versions/e9c714da8d57_init.py create mode 100644 app/api/__init__.py create mode 100644 app/api/assets_routes.py create mode 100644 app/assets_manager.py create mode 100644 app/database/__init__.py create mode 100644 app/database/services.py delete mode 100644 app/model_processor.py create mode 100644 app/storage/__init__.py create mode 100644 app/storage/hashing.py delete mode 100644 comfy/asset_management.py delete mode 100644 comfy_extras/nodes_assets_test.py delete mode 100644 tests-unit/app_test/model_manager_test.py delete mode 100644 tests-unit/app_test/model_processor_test.py diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py new file mode 100644 index 000000000..6705d8122 --- /dev/null +++ b/alembic_db/versions/0001_assets.py @@ -0,0 +1,158 @@ +"""initial assets schema + per-asset state cache + +Revision ID: 0001_assets +Revises: +Create Date: 2025-08-20 00:00:00 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "0001_assets" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ASSETS: content identity (deduplicated by hash) + op.create_table( + "assets", + sa.Column("hash", sa.String(length=128), primary_key=True), + sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"), + sa.Column("mime_type", sa.String(length=255), nullable=True), + sa.Column("refcount", sa.BigInteger(), nullable=False, server_default="0"), + sa.Column("storage_backend", sa.String(length=32), nullable=False, server_default="fs"), + sa.Column("storage_locator", sa.Text(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), + sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), + sa.CheckConstraint("refcount >= 0", name="ck_assets_refcount_nonneg"), + ) + op.create_index("ix_assets_mime_type", "assets", ["mime_type"]) + op.create_index("ix_assets_backend_locator", "assets", ["storage_backend", "storage_locator"]) + + # ASSETS_INFO: user-visible references (mutable metadata) + op.create_table( + "assets_info", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("owner_id", sa.String(length=128), nullable=True), + sa.Column("name", sa.String(length=512), nullable=False), + sa.Column("asset_hash", sa.String(length=128), 
sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False), + sa.Column("preview_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="SET NULL"), nullable=True), + sa.Column("user_metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), + sa.Column("last_access_time", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), + sqlite_autoincrement=True, + ) + op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) + op.create_index("ix_assets_info_asset_hash", "assets_info", ["asset_hash"]) + op.create_index("ix_assets_info_name", "assets_info", ["name"]) + op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"]) + op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"]) + + # TAGS: normalized tag vocabulary + op.create_table( + "tags", + sa.Column("name", sa.String(length=128), primary_key=True), + sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"), + sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"), + ) + op.create_index("ix_tags_tag_type", "tags", ["tag_type"]) + + # ASSET_INFO_TAGS: many-to-many for tags on AssetInfo + op.create_table( + "asset_info_tags", + sa.Column("asset_info_id", sa.BigInteger(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("tag_name", sa.String(length=128), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), + sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), + sa.Column("added_by", sa.String(length=128), nullable=True), + sa.Column("added_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), + sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"), + ) + op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"]) + op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"]) + + # ASSET_LOCATOR_STATE: 1:1 filesystem metadata(for fast integrity checking) for an Asset records + op.create_table( + "asset_locator_state", + sa.Column("asset_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True), + sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column("etag", sa.String(length=256), nullable=True), + sa.Column("last_modified", sa.String(length=128), nullable=True), + sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"), + ) + + # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting + op.create_table( + "asset_info_meta", + sa.Column("asset_info_id", sa.Integer(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("key", sa.String(length=256), nullable=False), + sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"), + sa.Column("val_str", sa.String(length=2048), nullable=True), + sa.Column("val_num", sa.Numeric(38, 10), nullable=True), + sa.Column("val_bool", sa.Boolean(), nullable=True), + sa.Column("val_json", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"), + ) + op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"]) + 
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"]) + op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"]) + op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"]) + + # Tags vocabulary for models + tags_table = sa.table( + "tags", + sa.column("name", sa.String()), + sa.column("tag_type", sa.String()), + ) + op.bulk_insert( + tags_table, + [ + # Core concept tags + {"name": "models", "tag_type": "system"}, + + # Canonical single-word types + {"name": "checkpoint", "tag_type": "system"}, + {"name": "lora", "tag_type": "system"}, + {"name": "vae", "tag_type": "system"}, + {"name": "text-encoder", "tag_type": "system"}, + {"name": "clip-vision", "tag_type": "system"}, + {"name": "embedding", "tag_type": "system"}, + {"name": "controlnet", "tag_type": "system"}, + {"name": "upscale", "tag_type": "system"}, + {"name": "diffusion-model", "tag_type": "system"}, + {"name": "hypernetwork", "tag_type": "system"}, + {"name": "vae_approx", "tag_type": "system"}, + # TODO: decide what to do with: style_models, diffusers, gligen, photomaker, classifiers + ], + ) + + +def downgrade() -> None: + op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta") + op.drop_table("asset_info_meta") + + op.drop_table("asset_locator_state") + + op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags") + op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags") + op.drop_table("asset_info_tags") + + op.drop_index("ix_tags_tag_type", table_name="tags") + op.drop_table("tags") + + op.drop_index("ix_assets_info_last_access_time", table_name="assets_info") + op.drop_index("ix_assets_info_created_at", table_name="assets_info") + op.drop_index("ix_assets_info_name", table_name="assets_info") + op.drop_index("ix_assets_info_asset_hash", table_name="assets_info") + op.drop_index("ix_assets_info_owner_id", table_name="assets_info") + op.drop_table("assets_info") + + op.drop_index("ix_assets_backend_locator", table_name="assets") + op.drop_index("ix_assets_mime_type", table_name="assets") + op.drop_table("assets") diff --git a/alembic_db/versions/e9c714da8d57_init.py b/alembic_db/versions/e9c714da8d57_init.py deleted file mode 100644 index 995365f90..000000000 --- a/alembic_db/versions/e9c714da8d57_init.py +++ /dev/null @@ -1,40 +0,0 @@ -"""init - -Revision ID: e9c714da8d57 -Revises: -Create Date: 2025-05-30 20:14:33.772039 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. 
-revision: str = 'e9c714da8d57' -down_revision: Union[str, None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - """Upgrade schema.""" - op.create_table('model', - sa.Column('type', sa.Text(), nullable=False), - sa.Column('path', sa.Text(), nullable=False), - sa.Column('file_name', sa.Text(), nullable=True), - sa.Column('file_size', sa.Integer(), nullable=True), - sa.Column('hash', sa.Text(), nullable=True), - sa.Column('hash_algorithm', sa.Text(), nullable=True), - sa.Column('source_url', sa.Text(), nullable=True), - sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True), - sa.PrimaryKeyConstraint('type', 'path') - ) - - -def downgrade() -> None: - """Downgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('model') - # ### end Alembic commands ### diff --git a/app/api/__init__.py b/app/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py new file mode 100644 index 000000000..aed1d3cea --- /dev/null +++ b/app/api/assets_routes.py @@ -0,0 +1,110 @@ +import json +from typing import Sequence +from aiohttp import web + +from app import assets_manager + + +ROUTES = web.RouteTableDef() + + +@ROUTES.get("/api/assets") +async def list_assets(request: web.Request) -> web.Response: + q = request.rel_url.query + + include_tags: Sequence[str] = _parse_csv_tags(q.get("include_tags")) + exclude_tags: Sequence[str] = _parse_csv_tags(q.get("exclude_tags")) + name_contains = q.get("name_contains") + + # Optional JSON metadata filter (top-level key equality only for now) + metadata_filter = None + raw_meta = q.get("metadata_filter") + if raw_meta: + try: + metadata_filter = json.loads(raw_meta) + if not isinstance(metadata_filter, dict): + metadata_filter = None + except Exception: + # Silently ignore malformed JSON for first iteration; could 400 in future + metadata_filter = None + + limit = _parse_int(q.get("limit"), default=20, lo=1, hi=100) + offset = _parse_int(q.get("offset"), default=0, lo=0, hi=10_000_000) + sort = q.get("sort", "created_at") + order = q.get("order", "desc") + + payload = await assets_manager.list_assets( + include_tags=include_tags, + exclude_tags=exclude_tags, + name_contains=name_contains, + metadata_filter=metadata_filter, + limit=limit, + offset=offset, + sort=sort, + order=order, + ) + return web.json_response(payload) + + +@ROUTES.put("/api/assets/{id}") +async def update_asset(request: web.Request) -> web.Response: + asset_info_id_raw = request.match_info.get("id") + try: + asset_info_id = int(asset_info_id_raw) + except Exception: + return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") + + try: + payload = await request.json() + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + name = payload.get("name", None) + tags = payload.get("tags", None) + user_metadata = payload.get("user_metadata", None) + + if name is None and tags is None and user_metadata is None: + return _error_response(400, "NO_FIELDS", "Provide at least one of: name, tags, user_metadata.") + + if tags is not None and (not isinstance(tags, list) or not all(isinstance(t, str) for t in tags)): + return _error_response(400, "INVALID_TAGS", "Field 'tags' must be an array of strings.") + + if user_metadata is not None and not isinstance(user_metadata, 
dict): + return _error_response(400, "INVALID_METADATA", "Field 'user_metadata' must be an object.") + + try: + result = await assets_manager.update_asset( + asset_info_id=asset_info_id, + name=name, + tags=tags, + user_metadata=user_metadata, + ) + except ValueError as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + return _error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(result, status=200) + + +def register_assets_routes(app: web.Application) -> None: + app.add_routes(ROUTES) + + +def _parse_csv_tags(raw: str | None) -> list[str]: + if not raw: + return [] + return [t.strip() for t in raw.split(",") if t.strip()] + + +def _parse_int(qval: str | None, default: int, lo: int, hi: int) -> int: + if not qval: + return default + try: + v = int(qval) + except Exception: + return default + return max(lo, min(hi, v)) + + +def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response: + return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status) diff --git a/app/assets_manager.py b/app/assets_manager.py new file mode 100644 index 000000000..1cccd6acb --- /dev/null +++ b/app/assets_manager.py @@ -0,0 +1,148 @@ +import os +from datetime import datetime, timezone +from typing import Optional, Sequence + +from .database.db import create_session +from .storage import hashing +from .database.services import ( + check_fs_asset_exists_quick, + ingest_fs_asset, + touch_asset_infos_by_fs_path, + list_asset_infos_page, + update_asset_info_full, + get_asset_tags, +) + + +def get_size_mtime_ns(path: str) -> tuple[int, int]: + st = os.stat(path, follow_symlinks=True) + return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + + +async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None: + """Adds a local asset to the DB. If already present and unchanged, does nothing. + + Notes: + - Uses absolute path as the canonical locator for the 'fs' backend. + - Computes BLAKE3 only when the fast existence check indicates it's needed. + - This function ensures the identity row and seeds mtime in asset_locator_state. 
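+
+    Example (illustrative; the path and tag set are placeholders):
+        await add_local_asset(
+            tags=["models", "checkpoint"],
+            file_name="sd_xl_base_1.0.safetensors",
+            file_path="/path/to/models/checkpoints/sd_xl_base_1.0.safetensors",
+        )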
+ """ + abs_path = os.path.abspath(file_path) + size_bytes, mtime_ns = get_size_mtime_ns(abs_path) + if not size_bytes: + return + + async with await create_session() as session: + if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns): + await touch_asset_infos_by_fs_path(session, abs_path=abs_path, ts=datetime.now(timezone.utc)) + await session.commit() + return + + asset_hash = hashing.blake3_hash_sync(abs_path) + + async with await create_session() as session: + await ingest_fs_asset( + session, + asset_hash="blake3:" + asset_hash, + abs_path=abs_path, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + mime_type=None, + info_name=file_name, + tag_origin="automatic", + tags=tags, + ) + await session.commit() + + +async def list_assets( + *, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: Optional[str] = None, + metadata_filter: Optional[dict] = None, + limit: int = 20, + offset: int = 0, + sort: str | None = "created_at", + order: str | None = "desc", +) -> dict: + sort = _safe_sort_field(sort) + order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower() + + async with await create_session() as session: + infos, tag_map, total = await list_asset_infos_page( + session, + include_tags=include_tags, + exclude_tags=exclude_tags, + name_contains=name_contains, + metadata_filter=metadata_filter, + limit=limit, + offset=offset, + sort=sort, + order=order, + ) + + assets_json = [] + for info in infos: + asset = info.asset # populated via contains_eager + tags = tag_map.get(info.id, []) + assets_json.append( + { + "id": info.id, + "name": info.name, + "asset_hash": info.asset_hash, + "size": int(asset.size_bytes) if asset else None, + "mime_type": asset.mime_type if asset else None, + "tags": tags, + "preview_url": f"/api/v1/assets/{info.id}/content", # TODO: implement actual content endpoint later + "created_at": info.created_at.isoformat() if info.created_at else None, + "updated_at": info.updated_at.isoformat() if info.updated_at else None, + "last_access_time": info.last_access_time.isoformat() if info.last_access_time else None, + } + ) + + return { + "assets": assets_json, + "total": total, + "has_more": (offset + len(assets_json)) < total, + } + + +async def update_asset( + *, + asset_info_id: int, + name: str | None = None, + tags: list[str] | None = None, + user_metadata: dict | None = None, +) -> dict: + async with await create_session() as session: + info = await update_asset_info_full( + session, + asset_info_id=asset_info_id, + name=name, + tags=tags, + user_metadata=user_metadata, + tag_origin="manual", + added_by=None, + ) + + tag_names = await get_asset_tags(session, asset_info_id=asset_info_id) + await session.commit() + + return { + "id": info.id, + "name": info.name, + "asset_hash": info.asset_hash, + "tags": tag_names, + "user_metadata": info.user_metadata or {}, + "updated_at": info.updated_at.isoformat() if info.updated_at else None, + } + + +def _safe_sort_field(requested: str | None) -> str: + if not requested: + return "created_at" + v = requested.lower() + if v in {"name", "created_at", "updated_at", "size", "last_access_time"}: + return v + return "created_at" diff --git a/app/database/__init__.py b/app/database/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app/database/db.py b/app/database/db.py index 1de8b80ed..2a619f13b 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -1,112 +1,267 @@ import 
logging import os import shutil +from contextlib import asynccontextmanager +from typing import Optional + from app.logger import log_startup_warning from utils.install_util import get_missing_requirements_message from comfy.cli_args import args -_DB_AVAILABLE = False -Session = None +LOGGER = logging.getLogger(__name__) +# Attempt imports which may not exist in some environments try: from alembic import command from alembic.config import Config from alembic.runtime.migration import MigrationContext from alembic.script import ScriptDirectory - from sqlalchemy import create_engine - from sqlalchemy.orm import sessionmaker + from sqlalchemy import create_engine, text + from sqlalchemy.engine import make_url + from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine _DB_AVAILABLE = True + ENGINE: AsyncEngine | None = None + SESSION: async_sessionmaker | None = None except ImportError as e: log_startup_warning( - f""" ------------------------------------------------------------------------- -Error importing dependencies: {e} -{get_missing_requirements_message()} -This error is happening because ComfyUI now uses a local sqlite database. ------------------------------------------------------------------------- -""".strip() + ( + "------------------------------------------------------------------------\n" + f"Error importing DB dependencies: {e}\n" + f"{get_missing_requirements_message()}\n" + "This error is happening because ComfyUI now uses a local database.\n" + "------------------------------------------------------------------------" + ).strip() ) + _DB_AVAILABLE = False + ENGINE = None + SESSION = None -def dependencies_available(): - """ - Temporary function to check if the dependencies are available - """ +def dependencies_available() -> bool: + """Check if DB dependencies are importable.""" return _DB_AVAILABLE -def can_create_session(): - """ - Temporary function to check if the database is available to create a session - During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created - """ - return dependencies_available() and Session is not None - - -def get_alembic_config(): - root_path = os.path.join(os.path.dirname(__file__), "../..") +def _root_paths(): + """Resolve alembic.ini and migrations script folder.""" + root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) config_path = os.path.abspath(os.path.join(root_path, "alembic.ini")) scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db")) - - config = Config(config_path) - config.set_main_option("script_location", scripts_path) - config.set_main_option("sqlalchemy.url", args.database_url) - - return config + return config_path, scripts_path -def get_db_path(): - url = args.database_url - if url.startswith("sqlite:///"): - return url.split("///")[1] +def _absolutize_sqlite_url(db_url: str) -> str: + """Make SQLite database path absolute. 
No-op for non-SQLite URLs.""" + try: + u = make_url(db_url) + except Exception: + return db_url + + if not u.drivername.startswith("sqlite"): + return db_url + + # Make path absolute if relative + db_path = u.database or "" + if not os.path.isabs(db_path): + db_path = os.path.abspath(os.path.join(os.getcwd(), db_path)) + u = u.set(database=db_path) + return str(u) + + +def _to_sync_driver_url(async_url: str) -> str: + """Convert an async SQLAlchemy URL to a sync URL for Alembic.""" + u = make_url(async_url) + driver = u.drivername + + if driver.startswith("sqlite+aiosqlite"): + u = u.set(drivername="sqlite") + elif driver.startswith("postgresql+asyncpg"): + u = u.set(drivername="postgresql") else: - raise ValueError(f"Unsupported database URL '{url}'.") + # Generic: strip the async driver part if present + if "+" in driver: + u = u.set(drivername=driver.split("+", 1)[0]) + + return str(u) -def init_db(): - db_url = args.database_url - logging.debug(f"Database URL: {db_url}") - db_path = get_db_path() - db_exists = os.path.exists(db_path) +def _get_sqlite_file_path(sync_url: str) -> Optional[str]: + """Return the on-disk path for a SQLite URL, else None.""" + try: + u = make_url(sync_url) + except Exception: + return None - config = get_alembic_config() + if not u.drivername.startswith("sqlite"): + return None + return u.database - # Check if we need to upgrade - engine = create_engine(db_url) - conn = engine.connect() - context = MigrationContext.configure(conn) - current_rev = context.get_current_revision() +def _get_alembic_config(sync_url: str) -> Config: + """Prepare Alembic Config with script location and DB URL.""" + config_path, scripts_path = _root_paths() + cfg = Config(config_path) + cfg.set_main_option("script_location", scripts_path) + cfg.set_main_option("sqlalchemy.url", sync_url) + return cfg - script = ScriptDirectory.from_config(config) + +async def init_db_engine() -> None: + """Initialize async engine + sessionmaker and run migrations to head. + + This must be called once on application startup before any DB usage. 
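+
+    Illustrative startup wiring (names as defined in this module):
+        await init_db_engine()
+        async with session_scope() as sess:
+            ...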
+ """ + global ENGINE, SESSION + + if not dependencies_available(): + raise RuntimeError("Database dependencies are not available.") + + if ENGINE is not None: + return + + raw_url = args.database_url + if not raw_url: + raise RuntimeError("Database URL is not configured.") + + # Absolutize SQLite path for async engine + db_url = _absolutize_sqlite_url(raw_url) + + # Prepare async engine + connect_args = {} + if db_url.startswith("sqlite"): + connect_args = { + "check_same_thread": False, + "timeout": 12, + } + + ENGINE = create_async_engine( + db_url, + connect_args=connect_args, + pool_pre_ping=True, + future=True, + ) + + # Enforce SQLite pragmas on the async engine + if db_url.startswith("sqlite"): + async with ENGINE.begin() as conn: + # WAL for concurrency and durability, Foreign Keys for referential integrity + current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar() + if str(current_mode).lower() != "wal": + new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar() + if str(new_mode).lower() != "wal": + raise RuntimeError("Failed to set SQLite journal mode to WAL.") + LOGGER.info("SQLite journal mode set to WAL.") + + await conn.execute(text("PRAGMA foreign_keys = ON;")) + await conn.execute(text("PRAGMA synchronous = NORMAL;")) + + await _run_migrations(raw_url=db_url) + + SESSION = async_sessionmaker( + bind=ENGINE, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + + +async def _run_migrations(raw_url: str) -> None: + """ + Run Alembic migrations up to head. + + We deliberately use a synchronous engine for migrations because Alembic's + programmatic API is synchronous by default and this path is robust. + """ + # Convert to sync URL and make SQLite URL an absolute one + sync_url = _to_sync_driver_url(raw_url) + sync_url = _absolutize_sqlite_url(sync_url) + + cfg = _get_alembic_config(sync_url) + + # Inspect current and target heads + engine = create_engine(sync_url, future=True) + with engine.connect() as conn: + context = MigrationContext.configure(conn) + current_rev = context.get_current_revision() + + script = ScriptDirectory.from_config(cfg) target_rev = script.get_current_head() if target_rev is None: - logging.warning("No target revision found.") - elif current_rev != target_rev: - # Backup the database pre upgrade - backup_path = db_path + ".bkp" - if db_exists: - shutil.copy(db_path, backup_path) - else: - backup_path = None + LOGGER.warning("Alembic: no target revision found.") + return + if current_rev == target_rev: + LOGGER.debug("Alembic: database already at head %s", target_rev) + return + + LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev) + + # Optional backup for SQLite file DBs + backup_path = None + sqlite_path = _get_sqlite_file_path(sync_url) + if sqlite_path and os.path.exists(sqlite_path): + backup_path = sqlite_path + ".bkp" try: - command.upgrade(config, target_rev) - logging.info(f"Database upgraded from {current_rev} to {target_rev}") - except Exception as e: - if backup_path: - # Restore the database from backup if upgrade fails - shutil.copy(backup_path, db_path) + shutil.copy(sqlite_path, backup_path) + except Exception as exc: + LOGGER.warning("Failed to create SQLite backup before migration: %s", exc) + + try: + command.upgrade(cfg, target_rev) + except Exception: + if backup_path and os.path.exists(backup_path): + LOGGER.exception("Error upgrading database, attempting restore from backup.") + try: + shutil.copy(backup_path, sqlite_path) 
# restore os.remove(backup_path) - logging.exception("Error upgrading database: ") - raise e - - global Session - Session = sessionmaker(bind=engine) + except Exception as re: + LOGGER.error("Failed to restore SQLite backup: %s", re) + else: + LOGGER.exception("Error upgrading database, backup is not available.") + raise -def create_session(): - return Session() +def get_engine(): + """Return the global async engine (initialized after init_db_engine()).""" + if ENGINE is None: + raise RuntimeError("Engine is not initialized. Call init_db_engine() first.") + return ENGINE + + +def get_session_maker(): + """Return the global async_sessionmaker (initialized after init_db_engine()).""" + if SESSION is None: + raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.") + return SESSION + + +@asynccontextmanager +async def session_scope(): + """Async context manager for a unit of work: + + async with session_scope() as sess: + ... use sess ... + """ + maker = get_session_maker() + async with maker() as sess: + try: + yield sess + await sess.commit() + except Exception: + await sess.rollback() + raise + + +async def create_session(): + """Convenience helper to acquire a single AsyncSession instance. + + Typical usage: + async with (await create_session()) as sess: + ... + """ + maker = get_session_maker() + return maker() diff --git a/app/database/models.py b/app/database/models.py index b0225c412..ca7ad67f8 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -1,59 +1,257 @@ +from datetime import datetime +from typing import Any, Optional + from sqlalchemy import ( - Column, Integer, - Text, + BigInteger, DateTime, + ForeignKey, + Index, + JSON, + String, + Text, + CheckConstraint, + Numeric, + Boolean, ) -from sqlalchemy.orm import declarative_base from sqlalchemy.sql import func - -Base = declarative_base() +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign -def to_dict(obj): +class Base(DeclarativeBase): + pass + + +def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]: fields = obj.__table__.columns.keys() - return { - field: (val.to_dict() if hasattr(val, "to_dict") else val) - for field in fields - if (val := getattr(obj, field)) - } + out: dict[str, Any] = {} + for field in fields: + val = getattr(obj, field) + if val is None and not include_none: + continue + if isinstance(val, datetime): + out[field] = val.isoformat() + else: + out[field] = val + return out -class Model(Base): - """ - sqlalchemy model representing a model file in the system. +class Asset(Base): + __tablename__ = "assets" - This class defines the database schema for storing information about model files, - including their type, path, hash, and when they were added to the system. 
+ hash: Mapped[str] = mapped_column(String(256), primary_key=True) + size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) + mime_type: Mapped[str | None] = mapped_column(String(255)) + refcount: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) + storage_backend: Mapped[str] = mapped_column(String(32), nullable=False, default="fs") + storage_locator: Mapped[str] = mapped_column(Text, nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() + ) - Attributes: - type (Text): The type of the model, this is the name of the folder in the models folder (primary key) - path (Text): The file path of the model relative to the type folder (primary key) - file_name (Text): The name of the model file - file_size (Integer): The size of the model file in bytes - hash (Text): A hash of the model file - hash_algorithm (Text): The algorithm used to generate the hash - source_url (Text): The URL of the model file - date_added (DateTime): Timestamp of when the model was added to the system - """ + infos: Mapped[list["AssetInfo"]] = relationship( + "AssetInfo", + back_populates="asset", + primaryjoin=lambda: Asset.hash == foreign(AssetInfo.asset_hash), + foreign_keys=lambda: [AssetInfo.asset_hash], + cascade="all,delete-orphan", + passive_deletes=True, + ) - __tablename__ = "model" + preview_of: Mapped[list["AssetInfo"]] = relationship( + "AssetInfo", + back_populates="preview_asset", + primaryjoin=lambda: Asset.hash == foreign(AssetInfo.preview_hash), + foreign_keys=lambda: [AssetInfo.preview_hash], + viewonly=True, + ) - type = Column(Text, primary_key=True) - path = Column(Text, primary_key=True) - file_name = Column(Text) - file_size = Column(Integer) - hash = Column(Text) - hash_algorithm = Column(Text) - source_url = Column(Text) - date_added = Column(DateTime, server_default=func.now()) + locator_state: Mapped["AssetLocatorState | None"] = relationship( + back_populates="asset", + uselist=False, + cascade="all, delete-orphan", + passive_deletes=True, + ) - def to_dict(self): - """ - Convert the model instance to a dictionary representation. 
+ __table_args__ = ( + Index("ix_assets_mime_type", "mime_type"), + Index("ix_assets_backend_locator", "storage_backend", "storage_locator"), + ) - Returns: - dict: A dictionary containing the attributes of the model - """ - dict = to_dict(self) - return dict + def to_dict(self, include_none: bool = False) -> dict[str, Any]: + return to_dict(self, include_none=include_none) + + def __repr__(self) -> str: + return f"" + + +class AssetLocatorState(Base): + __tablename__ = "asset_locator_state" + + asset_hash: Mapped[str] = mapped_column( + String(256), ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True + ) + # For fs backends: nanosecond mtime; nullable if not applicable + mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) + # For HTTP/S3/GCS/Azure, etc.: optional validators + etag: Mapped[str | None] = mapped_column(String(256), nullable=True) + last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True) + + asset: Mapped["Asset"] = relationship(back_populates="locator_state", uselist=False) + + __table_args__ = ( + CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"), + ) + + def to_dict(self, include_none: bool = False) -> dict[str, Any]: + return to_dict(self, include_none=include_none) + + def __repr__(self) -> str: + return f"" + + +class AssetInfo(Base): + __tablename__ = "assets_info" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + owner_id: Mapped[str | None] = mapped_column(String(128)) + name: Mapped[str] = mapped_column(String(512), nullable=False) + asset_hash: Mapped[str] = mapped_column( + String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False + ) + preview_hash: Mapped[str | None] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL")) + user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() + ) + last_access_time: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() + ) + + # Relationships + asset: Mapped[Asset] = relationship( + "Asset", + back_populates="infos", + foreign_keys=[asset_hash], + ) + preview_asset: Mapped[Asset | None] = relationship( + "Asset", + back_populates="preview_of", + foreign_keys=[preview_hash], + ) + + metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship( + back_populates="asset_info", + cascade="all,delete-orphan", + passive_deletes=True, + ) + + tag_links: Mapped[list["AssetInfoTag"]] = relationship( + back_populates="asset_info", + cascade="all,delete-orphan", + passive_deletes=True, + overlaps="tags,asset_infos", + ) + + tags: Mapped[list["Tag"]] = relationship( + secondary="asset_info_tags", + back_populates="asset_infos", + lazy="joined", + viewonly=True, + overlaps="tag_links,asset_info_links,asset_infos,tag", + ) + + __table_args__ = ( + Index("ix_assets_info_owner_id", "owner_id"), + Index("ix_assets_info_asset_hash", "asset_hash"), + Index("ix_assets_info_name", "name"), + Index("ix_assets_info_created_at", "created_at"), + Index("ix_assets_info_last_access_time", "last_access_time"), + {"sqlite_autoincrement": True}, + ) + + def to_dict(self, include_none: bool = False) -> dict[str, Any]: + data = to_dict(self, 
include_none=include_none) + data["tags"] = [t.name for t in self.tags] + return data + + def __repr__(self) -> str: + return f"" + + + +class AssetInfoMeta(Base): + __tablename__ = "asset_info_meta" + + asset_info_id: Mapped[int] = mapped_column( + Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + ) + key: Mapped[str] = mapped_column(String(256), primary_key=True) + ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0) + + val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True) + val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True) + val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True) + val_json: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True) + + asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries") + + __table_args__ = ( + Index("ix_asset_info_meta_key", "key"), + Index("ix_asset_info_meta_key_val_str", "key", "val_str"), + Index("ix_asset_info_meta_key_val_num", "key", "val_num"), + Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"), + ) + + +class AssetInfoTag(Base): + __tablename__ = "asset_info_tags" + + asset_info_id: Mapped[int] = mapped_column( + Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + ) + tag_name: Mapped[str] = mapped_column( + String(128), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True + ) + origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual") + added_by: Mapped[str | None] = mapped_column(String(128)) + added_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + + asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links") + tag: Mapped["Tag"] = relationship(back_populates="asset_info_links") + + __table_args__ = ( + Index("ix_asset_info_tags_tag_name", "tag_name"), + Index("ix_asset_info_tags_asset_info_id", "asset_info_id"), + ) + + +class Tag(Base): + __tablename__ = "tags" + + name: Mapped[str] = mapped_column(String(128), primary_key=True) + tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user") + + asset_info_links: Mapped[list["AssetInfoTag"]] = relationship( + back_populates="tag", + overlaps="asset_infos,tags", + ) + asset_infos: Mapped[list["AssetInfo"]] = relationship( + secondary="asset_info_tags", + back_populates="tags", + viewonly=True, + overlaps="asset_info_links,tag_links,tags,asset_info", + ) + + __table_args__ = ( + Index("ix_tags_tag_type", "tag_type"), + ) + + def __repr__(self) -> str: + return f"" diff --git a/app/database/services.py b/app/database/services.py new file mode 100644 index 000000000..c2792b4c4 --- /dev/null +++ b/app/database/services.py @@ -0,0 +1,683 @@ +import os +import logging +from collections import defaultdict +from datetime import datetime, timezone +from decimal import Decimal +from typing import Any, Sequence, Optional, Iterable + +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, delete, exists, func +from sqlalchemy.orm import contains_eager +from sqlalchemy.exc import IntegrityError + +from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta + + +async def check_fs_asset_exists_quick( + session, + *, + file_path: str, + size_bytes: Optional[int] = None, + mtime_ns: Optional[int] = None, +) -> bool: + """ + Returns 'True' if there is already an Asset present whose canonical locator matches this absolute path, + AND (if 
provided) mtime_ns matches stored locator-state, + AND (if provided) size_bytes matches verified size when known. + """ + locator = os.path.abspath(file_path) + + stmt = select(sa.literal(True)).select_from(Asset) + + conditions = [ + Asset.storage_backend == "fs", + Asset.storage_locator == locator, + ] + + # If size_bytes provided require equality when the asset has a verified (non-zero) size. + # If verified size is 0 (unknown), we don't force equality. + if size_bytes is not None: + conditions.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes))) + + # If mtime_ns provided require the locator-state to exist and match. + if mtime_ns is not None: + stmt = stmt.join(AssetLocatorState, AssetLocatorState.asset_hash == Asset.hash) + conditions.append(AssetLocatorState.mtime_ns == int(mtime_ns)) + + stmt = stmt.where(*conditions).limit(1) + + row = (await session.execute(stmt)).first() + return row is not None + + +async def ingest_fs_asset( + session: AsyncSession, + *, + asset_hash: str, + abs_path: str, + size_bytes: int, + mtime_ns: int, + mime_type: Optional[str] = None, + info_name: Optional[str] = None, + owner_id: Optional[str] = None, + preview_hash: Optional[str] = None, + user_metadata: Optional[dict] = None, + tags: Sequence[str] = (), + tag_origin: str = "manual", + added_by: Optional[str] = None, + require_existing_tags: bool = False, +) -> dict: + """ + Creates or updates Asset record for a local (fs) asset. + + Always: + - Insert Asset if missing; else update size_bytes (and updated_at) if different. + - Insert AssetLocatorState if missing; else update mtime_ns if different. + + Optionally (when info_name is provided): + - Create an AssetInfo (no refcount changes). + - Link provided tags to that AssetInfo. + * If the require_existing_tags=True, raises ValueError if any tag does not exist in `tags` table. + * If False (default), silently skips unknown tags. + + Returns flags and ids: + { + "asset_created": bool, + "asset_updated": bool, + "state_created": bool, + "state_updated": bool, + "asset_info_id": int | None, + "tags_added": list[str], + "tags_missing": list[str], # filled only when require_existing_tags=False + } + """ + locator = os.path.abspath(abs_path) + datetime_now = datetime.now(timezone.utc) + + out = { + "asset_created": False, + "asset_updated": False, + "state_created": False, + "state_updated": False, + "asset_info_id": None, + "tags_added": [], + "tags_missing": [], + } + + # ---- Step 1: INSERT Asset or UPDATE size_bytes/updated_at if exists ---- + async with session.begin_nested() as sp1: + try: + session.add( + Asset( + hash=asset_hash, + size_bytes=int(size_bytes), + mime_type=mime_type, + refcount=0, + storage_backend="fs", + storage_locator=locator, + created_at=datetime_now, + updated_at=datetime_now, + ) + ) + await session.flush() + out["asset_created"] = True + except IntegrityError: + await sp1.rollback() + # Already exists by hash -> update selected fields if different + existing = await session.get(Asset, asset_hash) + if existing is not None: + desired_size = int(size_bytes) + if existing.size_bytes != desired_size: + existing.size_bytes = desired_size + existing.updated_at = datetime_now + out["asset_updated"] = True + else: + # This should not occur. Log for visibility. 
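+                # (reaching this branch would mean the row disappeared between the
+                # failed INSERT and the lookup, e.g. due to a concurrent delete)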
+ logging.error("Asset %s not found after conflict; skipping update.", asset_hash) + except Exception: + await sp1.rollback() + logging.exception("Unexpected error inserting Asset (hash=%s, locator=%s)", asset_hash, locator) + raise + + # ---- Step 2: INSERT/UPDATE AssetLocatorState (mtime_ns) ---- + async with session.begin_nested() as sp2: + try: + session.add( + AssetLocatorState( + asset_hash=asset_hash, + mtime_ns=int(mtime_ns), + ) + ) + await session.flush() + out["state_created"] = True + except IntegrityError: + await sp2.rollback() + state = await session.get(AssetLocatorState, asset_hash) + if state is not None: + desired_mtime = int(mtime_ns) + if state.mtime_ns != desired_mtime: + state.mtime_ns = desired_mtime + out["state_updated"] = True + else: + logging.debug("Locator state missing for %s after conflict; skipping update.", asset_hash) + except Exception: + await sp2.rollback() + logging.exception("Unexpected error inserting AssetLocatorState (hash=%s)", asset_hash) + raise + + # ---- Optional: AssetInfo + tag links ---- + if info_name: + # 2a) Create AssetInfo (no refcount bump) + async with session.begin_nested() as sp3: + try: + info = AssetInfo( + owner_id=owner_id, + name=info_name, + asset_hash=asset_hash, + preview_hash=preview_hash, + created_at=datetime_now, + updated_at=datetime_now, + last_access_time=datetime_now, + ) + session.add(info) + await session.flush() # get info.id + out["asset_info_id"] = info.id + except Exception: + await sp3.rollback() + logging.exception( + "Unexpected error inserting AssetInfo (hash=%s, name=%s)", asset_hash, info_name + ) + raise + + # 2b) Link tags (if any). We DO NOT create new Tag rows here by default. + norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()] + if norm and out["asset_info_id"] is not None: + # Which tags exist? + existing_tag_names = set( + name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all() + ) + missing = [t for t in norm if t not in existing_tag_names] + if missing and require_existing_tags: + raise ValueError(f"Unknown tags: {missing}") + + # Which links already exist? 
+ existing_links = set( + tag_name + for (tag_name,) in ( + await session.execute( + select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"]) + ) + ).all() + ) + to_add = [t for t in norm if t in existing_tag_names and t not in existing_links] + if to_add: + session.add_all( + [ + AssetInfoTag( + asset_info_id=out["asset_info_id"], + tag_name=t, + origin=tag_origin, + added_by=added_by, + added_at=datetime_now, + ) + for t in to_add + ] + ) + await session.flush() + out["tags_added"] = to_add + out["tags_missing"] = missing + + # 2c) Rebuild metadata projection if provided + if user_metadata is not None and out["asset_info_id"] is not None: + await replace_asset_info_metadata_projection( + session, + asset_info_id=out["asset_info_id"], + user_metadata=user_metadata, + ) + return out + + +async def touch_asset_infos_by_fs_path( + session: AsyncSession, + *, + abs_path: str, + ts: Optional[datetime] = None, + only_if_newer: bool = True, +) -> int: + locator = os.path.abspath(abs_path) + ts = ts or datetime.now(timezone.utc) + + stmt = sa.update(AssetInfo).where( + sa.exists( + sa.select(sa.literal(1)) + .select_from(Asset) + .where( + Asset.hash == AssetInfo.asset_hash, + Asset.storage_backend == "fs", + Asset.storage_locator == locator, + ) + ) + ) + + if only_if_newer: + stmt = stmt.where( + sa.or_( + AssetInfo.last_access_time.is_(None), + AssetInfo.last_access_time < ts, + ) + ) + + stmt = stmt.values(last_access_time=ts) + + res = await session.execute(stmt) + return int(res.rowcount or 0) + + +async def list_asset_infos_page( + session: AsyncSession, + *, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 20, + offset: int = 0, + sort: str = "created_at", + order: str = "desc", +) -> tuple[list[AssetInfo], dict[int, list[str]], int]: + """ + Returns a page of AssetInfo rows with their Asset eagerly loaded (no N+1), + plus a map of asset_info_id -> [tags], and the total count. + + We purposely collect tags in a separate (single) query to avoid row explosion. 
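+
+    Example (illustrative):
+        infos, tag_map, total = await list_asset_infos_page(
+            session, include_tags=["models"], sort="size", order="desc", limit=20
+        )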
+ """ + # Clamp + if limit <= 0: + limit = 1 + if limit > 100: + limit = 100 + if offset < 0: + offset = 0 + + # Build base query + base = ( + select(AssetInfo) + .join(Asset, Asset.hash == AssetInfo.asset_hash) + .options(contains_eager(AssetInfo.asset)) + ) + + # Filters + if name_contains: + base = base.where(AssetInfo.name.ilike(f"%{name_contains}%")) + + base = _apply_tag_filters(base, include_tags, exclude_tags) + + base = _apply_metadata_filter(base, metadata_filter) + + # Sort + sort = (sort or "created_at").lower() + order = (order or "desc").lower() + sort_map = { + "name": AssetInfo.name, + "created_at": AssetInfo.created_at, + "updated_at": AssetInfo.updated_at, + "last_access_time": AssetInfo.last_access_time, + "size": Asset.size_bytes, + } + sort_col = sort_map.get(sort, AssetInfo.created_at) + sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() + + base = base.order_by(sort_exp).limit(limit).offset(offset) + + # Total count (same filters, no ordering/limit/offset) + count_stmt = ( + select(func.count()) + .select_from(AssetInfo) + .join(Asset, Asset.hash == AssetInfo.asset_hash) + ) + if name_contains: + count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%")) + count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = _apply_metadata_filter(count_stmt, metadata_filter) + + total = (await session.execute(count_stmt)).scalar_one() + + # Fetch rows + infos = (await session.execute(base)).scalars().unique().all() + + # Collect tags in bulk (single query) + id_list = [i.id for i in infos] + tag_map: dict[int, list[str]] = defaultdict(list) + if id_list: + rows = await session.execute( + select(AssetInfoTag.asset_info_id, Tag.name) + .join(Tag, Tag.name == AssetInfoTag.tag_name) + .where(AssetInfoTag.asset_info_id.in_(id_list)) + ) + for aid, tag_name in rows.all(): + tag_map[aid].append(tag_name) + + return infos, tag_map, total + + +async def set_asset_info_tags( + session: AsyncSession, + *, + asset_info_id: int, + tags: Sequence[str], + origin: str = "manual", + added_by: Optional[str] = None, +) -> dict: + """ + Replace the tag set on an AssetInfo with `tags`. Idempotent. + Creates missing tag names as 'user'. 
+ """ + desired = _normalize_tags(tags) + now = datetime.now(timezone.utc) + + # current links + current = set( + tag_name for (tag_name,) in ( + await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)) + ).all() + ) + + to_add = [t for t in desired if t not in current] + to_remove = [t for t in current if t not in desired] + + if to_add: + await _ensure_tags_exist(session, to_add, tag_type="user") + session.add_all([ + AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_by=added_by, added_at=now) + for t in to_add + ]) + await session.flush() + + if to_remove: + await session.execute( + delete(AssetInfoTag) + .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove)) + ) + await session.flush() + + return {"added": to_add, "removed": to_remove, "total": desired} + + +async def update_asset_info_full( + session: AsyncSession, + *, + asset_info_id: int, + name: Optional[str] = None, + tags: Optional[Sequence[str]] = None, + user_metadata: Optional[dict] = None, + tag_origin: str = "manual", + added_by: Optional[str] = None, +) -> AssetInfo: + """ + Update AssetInfo fields: + - name (if provided) + - user_metadata blob + rebuild projection (if provided) + - replace tags with provided set (if provided) + Returns the updated AssetInfo. + """ + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + touched = False + if name is not None and name != info.name: + info.name = name + touched = True + + if user_metadata is not None: + await replace_asset_info_metadata_projection( + session, asset_info_id=asset_info_id, user_metadata=user_metadata + ) + touched = True + + if tags is not None: + await set_asset_info_tags( + session, + asset_info_id=asset_info_id, + tags=tags, + origin=tag_origin, + added_by=added_by, + ) + touched = True + + if touched and user_metadata is None: + info.updated_at = datetime.now(timezone.utc) + await session.flush() + + return info + + +async def replace_asset_info_metadata_projection( + session: AsyncSession, + *, + asset_info_id: int, + user_metadata: dict | None, +) -> None: + """Replaces the `assets_info.user_metadata` AND rebuild the projection rows in `asset_info_meta`.""" + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + info.user_metadata = user_metadata or {} + info.updated_at = datetime.now(timezone.utc) + await session.flush() + + await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)) + await session.flush() + + if not user_metadata: + return + + rows: list[AssetInfoMeta] = [] + for k, v in user_metadata.items(): + for r in _project_kv(k, v): + rows.append( + AssetInfoMeta( + asset_info_id=asset_info_id, + key=r["key"], + ordinal=int(r["ordinal"]), + val_str=r.get("val_str"), + val_num=r.get("val_num"), + val_bool=r.get("val_bool"), + val_json=r.get("val_json"), + ) + ) + if rows: + session.add_all(rows) + await session.flush() + + +async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[Tag]: + return [ + tag_name + for (tag_name,) in ( + await session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + ] + + +def _normalize_tags(tags: Sequence[str] | None) -> list[str]: + return [t.strip().lower() for t in (tags or []) if (t or "").strip()] + + +async def _ensure_tags_exist(session: 
AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]: + wanted = _normalize_tags(list(names)) + if not wanted: + return [] + existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all() + by_name = {t.name: t for t in existing} + to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name] + if to_create: + session.add_all(to_create) + await session.flush() + by_name.update({t.name: t for t in to_create}) + return [by_name[n] for n in wanted] + + +def _apply_tag_filters( + stmt: sa.sql.Select, + include_tags: Sequence[str] | None, + exclude_tags: Sequence[str] | None, +) -> sa.sql.Select: + """include_tags: every tag must be present; exclude_tags: none may be present.""" + include_tags = _normalize_tags(include_tags) + exclude_tags = _normalize_tags(exclude_tags) + + if include_tags: + for tag_name in include_tags: + stmt = stmt.where( + exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name == tag_name) + ) + ) + + if exclude_tags: + stmt = stmt.where( + ~exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name.in_(exclude_tags)) + ) + ) + return stmt + +def _apply_metadata_filter( + stmt: sa.sql.Select, + metadata_filter: dict | None, +) -> sa.sql.Select: + """Apply metadata filters using the projection table asset_info_meta. + + Semantics: + - For scalar values: require EXISTS(asset_info_meta) with matching key + typed value. + - For None: key is missing OR key has explicit null (val_json IS NULL). + - For list values: ANY-of the list elements matches (EXISTS for any). + (Change to ALL-of by 'for each element: stmt = stmt.where(_meta_exists_clause(key, elem))') + """ + if not metadata_filter: + return stmt + + def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: + subquery = ( + select(sa.literal(1)) + .select_from(AssetInfoMeta) + .where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + *preds, + ) + .limit(1) + ) + return sa.exists(subquery) + + def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: + # Missing OR null: + if value is None: + # either: no row for key OR a row for key with explicit null + no_row_for_key = ~sa.exists( + select(sa.literal(1)) + .select_from(AssetInfoMeta) + .where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + ) + .limit(1) + ) + null_row = _exists_for_pred(key, AssetInfoMeta.val_json.is_(None)) + return sa.or_(no_row_for_key, null_row) + + # Typed scalar matches: + if isinstance(value, bool): + return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value)) + if isinstance(value, (int, float, Decimal)): + # store as Decimal for equality against NUMERIC(38,10) + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return _exists_for_pred(key, AssetInfoMeta.val_num == num) + if isinstance(value, str): + return _exists_for_pred(key, AssetInfoMeta.val_str == value) + + # Complex: compare JSON (no index, but supported) + return _exists_for_pred(key, AssetInfoMeta.val_json == value) + + for k, v in metadata_filter.items(): + if isinstance(v, list): + # ANY-of (exists for any element) + ors = [ _exists_clause_for_value(k, elem) for elem in v ] + if ors: + stmt = stmt.where(sa.or_(*ors)) + else: + stmt = stmt.where(_exists_clause_for_value(k, v)) + return stmt + + +def _is_scalar(v: Any) -> bool: + if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries + return True + if 
isinstance(v, bool): + return True + if isinstance(v, (int, float, Decimal, str)): + return True + return False + + +def _project_kv(key: str, value: Any) -> list[dict]: + """ + Turn a metadata key/value into one or more projection rows: + - scalar -> one row (ordinal=0) in the proper typed column + - list of scalars -> one row per element with ordinal=i + - dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list) + - None -> single row with val_json = None + Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...} + """ + rows: list[dict] = [] + + # None + if value is None: + rows.append({"key": key, "ordinal": 0, "val_json": None}) + return rows + + # Scalars + if _is_scalar(value): + if isinstance(value, bool): + rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) + elif isinstance(value, (int, float, Decimal)): + # store numeric; SQLAlchemy will coerce to Numeric + rows.append({"key": key, "ordinal": 0, "val_num": value}) + elif isinstance(value, str): + rows.append({"key": key, "ordinal": 0, "val_str": value}) + else: + # Fallback to json + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows + + # Lists + if isinstance(value, list): + # list of scalars? + if all(_is_scalar(x) for x in value): + for i, x in enumerate(value): + if x is None: + rows.append({"key": key, "ordinal": i, "val_json": None}) + elif isinstance(x, bool): + rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) + elif isinstance(x, (int, float, Decimal)): + rows.append({"key": key, "ordinal": i, "val_num": x}) + elif isinstance(x, str): + rows.append({"key": key, "ordinal": i, "val_str": x}) + else: + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + # list contains objects -> one val_json per element + for i, x in enumerate(value): + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + + # Dict or any other structure -> single json row + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows diff --git a/app/model_processor.py b/app/model_processor.py deleted file mode 100644 index 5018c2fe6..000000000 --- a/app/model_processor.py +++ /dev/null @@ -1,331 +0,0 @@ -import os -import logging -import time - -import requests -from tqdm import tqdm -from folder_paths import get_relative_path, get_full_path -from app.database.db import create_session, dependencies_available, can_create_session -import blake3 -import comfy.utils - - -if dependencies_available(): - from app.database.models import Model - - -class ModelProcessor: - def _validate_path(self, model_path): - try: - if not self._file_exists(model_path): - logging.error(f"Model file not found: {model_path}") - return None - - result = get_relative_path(model_path) - if not result: - logging.error( - f"Model file not in a recognized model directory: {model_path}" - ) - return None - - return result - except Exception as e: - logging.error(f"Error validating model path {model_path}: {str(e)}") - return None - - def _file_exists(self, path): - """Check if a file exists.""" - return os.path.exists(path) - - def _get_file_size(self, path): - """Get file size.""" - return os.path.getsize(path) - - def _get_hasher(self): - return blake3.blake3() - - def _hash_file(self, model_path): - try: - hasher = self._get_hasher() - with open(model_path, "rb", buffering=0) as f: - b = bytearray(128 * 1024) - mv = memoryview(b) - while n := f.readinto(mv): - hasher.update(mv[:n]) - return hasher.hexdigest() - except 
Exception as e: - logging.error(f"Error hashing file {model_path}: {str(e)}") - return None - - def _get_existing_model(self, session, model_type, model_relative_path): - return ( - session.query(Model) - .filter(Model.type == model_type) - .filter(Model.path == model_relative_path) - .first() - ) - - def _ensure_source_url(self, session, model, source_url): - if model.source_url is None: - model.source_url = source_url - session.commit() - - def _update_database( - self, - session, - model_type, - model_path, - model_relative_path, - model_hash, - model, - source_url, - ): - try: - if not model: - model = self._get_existing_model( - session, model_type, model_relative_path - ) - - if not model: - model = Model( - path=model_relative_path, - type=model_type, - file_name=os.path.basename(model_path), - ) - session.add(model) - - model.file_size = self._get_file_size(model_path) - model.hash = model_hash - if model_hash: - model.hash_algorithm = "blake3" - model.source_url = source_url - - session.commit() - return model - except Exception as e: - logging.error( - f"Error updating database for {model_relative_path}: {str(e)}" - ) - - def process_file(self, model_path, source_url=None, model_hash=None): - """ - Process a model file and update the database with metadata. - If the file already exists and matches the database, it will not be processed again. - Returns the model object or if an error occurs, returns None. - """ - try: - if not can_create_session(): - return - - result = self._validate_path(model_path) - if not result: - return - model_type, model_relative_path = result - - with create_session() as session: - session.expire_on_commit = False - - existing_model = self._get_existing_model( - session, model_type, model_relative_path - ) - if ( - existing_model - and existing_model.hash - and existing_model.file_size == self._get_file_size(model_path) - ): - # File exists with hash and same size, no need to process - self._ensure_source_url(session, existing_model, source_url) - return existing_model - - if model_hash: - model_hash = model_hash.lower() - logging.info(f"Using provided hash: {model_hash}") - else: - start_time = time.time() - logging.info(f"Hashing model {model_relative_path}") - model_hash = self._hash_file(model_path) - if not model_hash: - return - logging.info( - f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)" - ) - - return self._update_database( - session, - model_type, - model_path, - model_relative_path, - model_hash, - existing_model, - source_url, - ) - except Exception as e: - logging.error(f"Error processing model file {model_path}: {str(e)}") - return None - - def retrieve_model_by_hash(self, model_hash, model_type=None, session=None): - """ - Retrieve a model file from the database by hash and optionally by model type. - Returns the model object or None if the model doesnt exist or an error occurs. - """ - try: - if not can_create_session(): - return - - dispose_session = False - - if session is None: - session = create_session() - dispose_session = True - - model = session.query(Model).filter(Model.hash == model_hash) - if model_type is not None: - model = model.filter(Model.type == model_type) - return model.first() - except Exception as e: - logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}") - return None - finally: - if dispose_session: - session.close() - - def retrieve_hash(self, model_path, model_type=None): - """ - Retrieve the hash of a model file from the database. 
- Returns the hash or None if the model doesnt exist or an error occurs. - """ - try: - if not can_create_session(): - return - - if model_type is not None: - result = self._validate_path(model_path) - if not result: - return None - model_type, model_relative_path = result - - with create_session() as session: - model = self._get_existing_model( - session, model_type, model_relative_path - ) - if model and model.hash: - return model.hash - return None - except Exception as e: - logging.error(f"Error retrieving hash for {model_path}: {str(e)}") - return None - - def _validate_file_extension(self, file_name): - """Validate that the file extension is supported.""" - extension = os.path.splitext(file_name)[1] - if extension not in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"): - raise ValueError(f"Unsupported unsafe file for download: {file_name}") - - def _check_existing_file(self, model_type, file_name, expected_hash): - """Check if file exists and has correct hash.""" - destination_path = get_full_path(model_type, file_name, allow_missing=True) - if self._file_exists(destination_path): - model = self.process_file(destination_path) - if model and (expected_hash is None or model.hash == expected_hash): - logging.debug( - f"File {destination_path} already exists in the database and has the correct hash or no hash was provided." - ) - return destination_path - else: - raise ValueError( - f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again." - ) - return None - - def _check_existing_file_by_hash(self, hash, type, url): - """Check if a file with the given hash exists in the database and on disk.""" - hash = hash.lower() - with create_session() as session: - model = self.retrieve_model_by_hash(hash, type, session) - if model: - existing_path = get_full_path(type, model.path) - if existing_path: - logging.debug( - f"File {model.path} already exists in the database at {existing_path}" - ) - self._ensure_source_url(session, model, url) - return existing_path - else: - logging.debug( - f"File {model.path} exists in the database but not on disk" - ) - return None - - def _download_file(self, url, destination_path, hasher): - """Download a file and update the hasher with its contents.""" - response = requests.get(url, stream=True) - logging.info(f"Downloading {url} to {destination_path}") - - with open(destination_path, "wb") as f: - total_size = int(response.headers.get("content-length", 0)) - if total_size > 0: - pbar = comfy.utils.ProgressBar(total_size) - else: - pbar = None - with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar: - for chunk in response.iter_content(chunk_size=128 * 1024): - if chunk: - f.write(chunk) - hasher.update(chunk) - progress_bar.update(len(chunk)) - if pbar: - pbar.update(len(chunk)) - - def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path): - """Verify that the downloaded file has the expected hash.""" - if expected_hash is not None and calculated_hash != expected_hash: - self._remove_file(destination_path) - raise ValueError( - f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}" - ) - - def _remove_file(self, file_path): - """Remove a file from disk.""" - os.remove(file_path) - - def ensure_downloaded(self, type, url, desired_file_name, hash=None): - """ - Ensure a model file is downloaded and has the correct hash. - Returns the path to the downloaded file. 
- """ - logging.debug( - f"Ensuring {type} file is downloaded. URL='{url}' Destination='{desired_file_name}' Hash='{hash}'" - ) - - # Validate file extension - self._validate_file_extension(desired_file_name) - - # Check if file exists with correct hash - if hash: - existing_path = self._check_existing_file_by_hash(hash, type, url) - if existing_path: - return existing_path - - # Check if file exists locally - destination_path = get_full_path(type, desired_file_name, allow_missing=True) - existing_path = self._check_existing_file(type, desired_file_name, hash) - if existing_path: - return existing_path - - # Download the file - hasher = self._get_hasher() - self._download_file(url, destination_path, hasher) - - # Verify hash - calculated_hash = hasher.hexdigest() - self._verify_downloaded_hash(calculated_hash, hash, destination_path) - - # Update database - self.process_file(destination_path, url, calculated_hash) - - # TODO: Notify frontend to reload models - - return destination_path - - -model_processor = ModelProcessor() diff --git a/app/storage/__init__.py b/app/storage/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app/storage/hashing.py b/app/storage/hashing.py new file mode 100644 index 000000000..3eaed77a3 --- /dev/null +++ b/app/storage/hashing.py @@ -0,0 +1,72 @@ +import asyncio +import os +from typing import IO, Union + +from blake3 import blake3 + +DEFAULT_CHUNK = 8 * 1024 * 1024 # 8 MiB + + +def _hash_file_obj_sync(file_obj: IO[bytes], chunk_size: int) -> str: + """Hash an already-open binary file object by streaming in chunks. + - Seeks to the beginning before reading (if supported). + - Restores the original position afterward (if tell/seek are supported). + """ + if chunk_size <= 0: + chunk_size = DEFAULT_CHUNK + + orig_pos = None + if hasattr(file_obj, "tell"): + orig_pos = file_obj.tell() + + try: + if hasattr(file_obj, "seek"): + file_obj.seek(0) + + h = blake3() + while True: + chunk = file_obj.read(chunk_size) + if not chunk: + break + h.update(chunk) + return h.hexdigest() + finally: + if hasattr(file_obj, "seek") and orig_pos is not None: + file_obj.seek(orig_pos) + + +def blake3_hash_sync( + fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]], + chunk_size: int = DEFAULT_CHUNK, +) -> str: + """Returns a BLAKE3 hex digest for ``fp``, which may be: + - a filename (str/bytes) or PathLike + - an open binary file object + + If ``fp`` is a file object, it must be opened in **binary** mode and support + ``read``, ``seek``, and ``tell``. The function will seek to the start before + reading and will attempt to restore the original position afterward. + """ + if hasattr(fp, "read"): + return _hash_file_obj_sync(fp, chunk_size) + + with open(os.fspath(fp), "rb") as f: + return _hash_file_obj_sync(f, chunk_size) + + +async def blake3_hash( + fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]], + chunk_size: int = DEFAULT_CHUNK, +) -> str: + """Async wrapper for ``blake3_hash_sync``. + Uses a worker thread so the event loop remains responsive. + """ + # If it is a path, open inside the worker thread to keep I/O off the loop. 
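+    # For an already-open binary file object we also hash in a worker thread;
+    # blake3_hash_sync() seeks to the start and restores the original position,
+    # so the caller must not read from or seek the same object concurrently.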
+ if hasattr(fp, "read"): + return await asyncio.to_thread(blake3_hash_sync, fp, chunk_size) + + def _worker() -> str: + with open(os.fspath(fp), "rb") as f: + return _hash_file_obj_sync(f, chunk_size) + + return await asyncio.to_thread(_worker) diff --git a/comfy/asset_management.py b/comfy/asset_management.py deleted file mode 100644 index e47996320..000000000 --- a/comfy/asset_management.py +++ /dev/null @@ -1,110 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, TypedDict -import os -import logging -import comfy.utils - -import folder_paths - -class AssetMetadata(TypedDict): - device: Any - return_metadata: bool - subdir: str - download_url: str - -class AssetInfo: - def __init__(self, hash: str=None, name: str=None, tags: list[str]=None, metadata: AssetMetadata={}): - self.hash = hash - self.name = name - self.tags = tags - self.metadata = metadata - - -class ReturnedAssetABC(ABC): - def __init__(self, mimetype: str): - self.mimetype = mimetype - - -class ModelReturnedAsset(ReturnedAssetABC): - def __init__(self, state_dict: dict[str, str], metadata: dict[str, str]=None): - super().__init__("model") - self.state_dict = state_dict - self.metadata = metadata - - -class AssetResolverABC(ABC): - @abstractmethod - def resolve(self, asset_info: AssetInfo) -> ReturnedAssetABC: - ... - - -class LocalAssetResolver(AssetResolverABC): - def resolve(self, asset_info: AssetInfo, cache_result: bool=True) -> ReturnedAssetABC: - # currently only supports models - make sure models is in the tags - if "models" not in asset_info.tags: - return None - # TODO: if hash exists, call model processor to try to get info about model: - if asset_info.hash: - try: - from app.model_processor import model_processor - model_db = model_processor.retrieve_model_by_hash(asset_info.hash) - full_path = model_db.path - except Exception as e: - logging.error(f"Could not get model by hash with error: {e}") - # the good ol' bread and butter - folder_paths's keys as tags - folder_keys = folder_paths.folder_names_and_paths.keys() - parent_paths = [] - for tag in asset_info.tags: - if tag in folder_keys: - parent_paths.append(tag) - # if subdir metadata and name exists, use that as the model name going forward - if "subdir" in asset_info.metadata and asset_info.name: - # if no matching parent paths, then something went wrong and should return None - if len(parent_paths) == 0: - return None - relative_path = os.path.join(asset_info.metadata["subdir"], asset_info.name) - # now we have the parent keys, we can try to get the local path - chosen_parent = None - full_path = None - for parent_path in parent_paths: - full_path = folder_paths.get_full_path(parent_path, relative_path) - if full_path: - chosen_parent = parent_path - break - if full_path is not None: - logging.info(f"Resolved {asset_info.name} to {full_path} in {chosen_parent}") - # we know the path, so load the model and return it - state_dict, metadata = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=True) - # TODO: handle caching - return ModelReturnedAsset(state_dict, metadata) - # if just name exists, try to find model by name in all subdirs of parent paths - # TODO: this behavior should be configurable by user - if asset_info.name: - for parent_path in parent_paths: - filelist = folder_paths.get_filename_list(parent_path) - for file in filelist: - if os.path.basename(file) == asset_info.name: - full_path = folder_paths.get_full_path(parent_path, file) - state_dict, metadata = 
comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=True) - # TODO: handle caching - return ModelReturnedAsset(state_dict, metadata) - # TODO: if download_url metadata exists, download the model and load it; this should be configurable by user - if asset_info.metadata.get("download_url", None): - ... - return None - - -resolvers: list[AssetResolverABC] = [] -resolvers.append(LocalAssetResolver()) - - -def resolve(asset_info: AssetInfo) -> Any: - global resolvers - for resolver in resolvers: - try: - to_return = resolver.resolve(asset_info) - if to_return is not None: - return resolver.resolve(asset_info) - except Exception as e: - logging.error(f"Error resolving asset {asset_info.name} using {resolver.__class__.__name__}: {e}") - return None diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 84d017314..9ab78b99b 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -211,7 +211,7 @@ parser.add_argument( database_default_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") ) -parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") +parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.") parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.") if comfy.options.args_parsing: diff --git a/comfy/utils.py b/comfy/utils.py index af2aace0a..220492941 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -102,12 +102,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): else: sd = pl_sd - try: - from app.model_processor import model_processor - model_processor.process_file(ckpt) - except Exception as e: - logging.error(f"Error processing file {ckpt}: {e}") - return (sd, metadata) if return_metadata else sd def save_torch_file(sd, ckpt, metadata=None): diff --git a/comfy_extras/nodes_assets_test.py b/comfy_extras/nodes_assets_test.py deleted file mode 100644 index 5172cd628..000000000 --- a/comfy_extras/nodes_assets_test.py +++ /dev/null @@ -1,56 +0,0 @@ -from comfy_api.latest import io, ComfyExtension -import comfy.asset_management -import comfy.sd -import folder_paths -import logging -import os - - -class AssetTestNode(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="AssetTestNode", - is_experimental=True, - inputs=[ - io.Combo.Input("ckpt_name", folder_paths.get_filename_list("checkpoints")), - ], - outputs=[ - io.Model.Output(), - io.Clip.Output(), - io.Vae.Output(), - ], - ) - - @classmethod - def execute(cls, ckpt_name: str): - hash = None - # lets get the full path just so we can retrieve the hash from db, if exists - try: - full_path = folder_paths.get_full_path("checkpoints", ckpt_name) - if full_path is None: - raise Exception(f"Model {ckpt_name} not found") - from app.model_processor import model_processor - hash = model_processor.retrieve_hash(full_path) - except Exception as e: - logging.error(f"Could not get model by hash with error: {e}") - subdir, name = os.path.split(ckpt_name) - asset_info = comfy.asset_management.AssetInfo(hash=hash, name=name, tags=["models", "checkpoints"], metadata={"subdir": subdir}) - asset = 
comfy.asset_management.resolve(asset_info) - # /\ the stuff above should happen in execution code instead of inside the node - # \/ the stuff below should happen in the node - confirm is a model asset, do stuff to it (already loaded? or should be called to 'load'?) - if asset is None: - raise Exception(f"Model {asset_info.name} not found") - assert isinstance(asset, comfy.asset_management.ModelReturnedAsset) - out = comfy.sd.load_state_dict_guess_config(asset.state_dict, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), metadata=asset.metadata) - return io.NodeOutput(out[0], out[1], out[2]) - - -class AssetTestExtension(ComfyExtension): - @classmethod - async def get_node_list(cls): - return [AssetTestNode] - - -def comfy_entrypoint(): - return AssetTestExtension() diff --git a/main.py b/main.py index 81001c0b5..557961d40 100644 --- a/main.py +++ b/main.py @@ -278,11 +278,11 @@ def cleanup_temp(): if os.path.exists(temp_dir): shutil.rmtree(temp_dir, ignore_errors=True) -def setup_database(): +async def setup_database(): try: - from app.database.db import init_db, dependencies_available + from app.database.db import init_db_engine, dependencies_available if dependencies_available(): - init_db() + await init_db_engine() except Exception as e: logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") @@ -309,6 +309,8 @@ def start_comfyui(asyncio_loop=None): asyncio.set_event_loop(asyncio_loop) prompt_server = server.PromptServer(asyncio_loop) + asyncio_loop.run_until_complete(setup_database()) + hook_breaker_ac10a0.save_functions() asyncio_loop.run_until_complete(nodes.init_extra_nodes( init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0, @@ -317,7 +319,6 @@ def start_comfyui(asyncio_loop=None): hook_breaker_ac10a0.restore_functions() cuda_malloc_warning() - setup_database() prompt_server.add_routes() hijack_progress(prompt_server) diff --git a/nodes.py b/nodes.py index 95599b8aa..b74cfc58e 100644 --- a/nodes.py +++ b/nodes.py @@ -28,9 +28,10 @@ import comfy.sd import comfy.utils import comfy.controlnet from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator -from comfy_api.internal import register_versions, ComfyAPIWithVersion +from comfy_api.internal import async_to_sync, register_versions, ComfyAPIWithVersion from comfy_api.version_list import supported_versions from comfy_api.latest import io, ComfyExtension +from app.assets_manager import add_local_asset import comfy.clip_vision @@ -777,6 +778,9 @@ class VAELoader: else: vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) sd = comfy.utils.load_torch_file(vae_path) + async_to_sync.AsyncToSyncConverter.run_async_in_thread( + add_local_asset, tags=["models", "vae"], file_name=vae_name, file_path=vae_path + ) vae = comfy.sd.VAE(sd=sd) vae.throw_exception_if_invalid() return (vae,) @@ -2321,7 +2325,6 @@ async def init_builtin_extra_nodes(): "nodes_edit_model.py", "nodes_tcfg.py", "nodes_context_windows.py", - "nodes_assets_test.py", ] import_failed = [] diff --git a/requirements.txt b/requirements.txt index 0b0e78791..f12d2e3fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,6 +20,7 @@ tqdm psutil alembic SQLAlchemy +aiosqlite av>=14.2.0 blake3 diff --git a/server.py b/server.py index 8f9c88ebf..30c1a8fe7 100644 --- a/server.py +++ b/server.py @@ -37,6 +37,7 @@ from 
app.model_manager import ModelFileManager from app.custom_node_manager import CustomNodeManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes +from app.api.assets_routes import register_assets_routes from protocol import BinaryEventTypes async def send_socket_catch_exception(function, message): @@ -183,6 +184,7 @@ class PromptServer(): else args.front_end_root ) logging.info(f"[Prompt Server] web root: {self.web_root}") + register_assets_routes(self.app) routes = web.RouteTableDef() self.routes = routes self.last_node_id = None diff --git a/tests-unit/app_test/model_manager_test.py b/tests-unit/app_test/model_manager_test.py deleted file mode 100644 index ae59206f6..000000000 --- a/tests-unit/app_test/model_manager_test.py +++ /dev/null @@ -1,62 +0,0 @@ -import pytest -import base64 -import json -import struct -from io import BytesIO -from PIL import Image -from aiohttp import web -from unittest.mock import patch -from app.model_manager import ModelFileManager - -pytestmark = ( - pytest.mark.asyncio -) # This applies the asyncio mark to all test functions in the module - -@pytest.fixture -def model_manager(): - return ModelFileManager() - -@pytest.fixture -def app(model_manager): - app = web.Application() - routes = web.RouteTableDef() - model_manager.add_routes(routes) - app.add_routes(routes) - return app - -async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path): - img = Image.new('RGB', (100, 100), 'white') - img_byte_arr = BytesIO() - img.save(img_byte_arr, format='PNG') - img_byte_arr.seek(0) - img_b64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8') - - safetensors_file = tmp_path / "test_model.safetensors" - header_bytes = json.dumps({ - "__metadata__": { - "ssmd_cover_images": json.dumps([img_b64]) - } - }).encode('utf-8') - length_bytes = struct.pack('