mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 17:42:58 +08:00
dev: Everything is Assets
This commit is contained in:
parent
c708d0a433
commit
f92307cd4c
158
alembic_db/versions/0001_assets.py
Normal file
158
alembic_db/versions/0001_assets.py
Normal file
@ -0,0 +1,158 @@
|
||||
"""initial assets schema + per-asset state cache
|
||||
|
||||
Revision ID: 0001_assets
|
||||
Revises:
|
||||
Create Date: 2025-08-20 00:00:00
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "0001_assets"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ASSETS: content identity (deduplicated by hash)
|
||||
op.create_table(
|
||||
"assets",
|
||||
sa.Column("hash", sa.String(length=128), primary_key=True),
|
||||
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column("mime_type", sa.String(length=255), nullable=True),
|
||||
sa.Column("refcount", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column("storage_backend", sa.String(length=32), nullable=False, server_default="fs"),
|
||||
sa.Column("storage_locator", sa.Text(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
|
||||
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
sa.CheckConstraint("refcount >= 0", name="ck_assets_refcount_nonneg"),
|
||||
)
|
||||
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
|
||||
op.create_index("ix_assets_backend_locator", "assets", ["storage_backend", "storage_locator"])
|
||||
|
||||
# ASSETS_INFO: user-visible references (mutable metadata)
|
||||
op.create_table(
|
||||
"assets_info",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("owner_id", sa.String(length=128), nullable=True),
|
||||
sa.Column("name", sa.String(length=512), nullable=False),
|
||||
sa.Column("asset_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("preview_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="SET NULL"), nullable=True),
|
||||
sa.Column("user_metadata", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
|
||||
sa.Column("last_access_time", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
|
||||
sqlite_autoincrement=True,
|
||||
)
|
||||
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
|
||||
op.create_index("ix_assets_info_asset_hash", "assets_info", ["asset_hash"])
|
||||
op.create_index("ix_assets_info_name", "assets_info", ["name"])
|
||||
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
|
||||
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
|
||||
|
||||
# TAGS: normalized tag vocabulary
|
||||
op.create_table(
|
||||
"tags",
|
||||
sa.Column("name", sa.String(length=128), primary_key=True),
|
||||
sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
|
||||
sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
|
||||
)
|
||||
op.create_index("ix_tags_tag_type", "tags", ["tag_type"])
|
||||
|
||||
# ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
|
||||
op.create_table(
|
||||
"asset_info_tags",
|
||||
sa.Column("asset_info_id", sa.BigInteger(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("tag_name", sa.String(length=128), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
|
||||
sa.Column("added_by", sa.String(length=128), nullable=True),
|
||||
sa.Column("added_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
|
||||
)
|
||||
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
|
||||
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
|
||||
|
||||
# ASSET_LOCATOR_STATE: 1:1 filesystem metadata(for fast integrity checking) for an Asset records
|
||||
op.create_table(
|
||||
"asset_locator_state",
|
||||
sa.Column("asset_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True),
|
||||
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
|
||||
sa.Column("etag", sa.String(length=256), nullable=True),
|
||||
sa.Column("last_modified", sa.String(length=128), nullable=True),
|
||||
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"),
|
||||
)
|
||||
|
||||
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
|
||||
op.create_table(
|
||||
"asset_info_meta",
|
||||
sa.Column("asset_info_id", sa.Integer(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("key", sa.String(length=256), nullable=False),
|
||||
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("val_str", sa.String(length=2048), nullable=True),
|
||||
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
|
||||
sa.Column("val_bool", sa.Boolean(), nullable=True),
|
||||
sa.Column("val_json", sa.JSON(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
|
||||
)
|
||||
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
|
||||
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
|
||||
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
|
||||
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
|
||||
|
||||
# Tags vocabulary for models
|
||||
tags_table = sa.table(
|
||||
"tags",
|
||||
sa.column("name", sa.String()),
|
||||
sa.column("tag_type", sa.String()),
|
||||
)
|
||||
op.bulk_insert(
|
||||
tags_table,
|
||||
[
|
||||
# Core concept tags
|
||||
{"name": "models", "tag_type": "system"},
|
||||
|
||||
# Canonical single-word types
|
||||
{"name": "checkpoint", "tag_type": "system"},
|
||||
{"name": "lora", "tag_type": "system"},
|
||||
{"name": "vae", "tag_type": "system"},
|
||||
{"name": "text-encoder", "tag_type": "system"},
|
||||
{"name": "clip-vision", "tag_type": "system"},
|
||||
{"name": "embedding", "tag_type": "system"},
|
||||
{"name": "controlnet", "tag_type": "system"},
|
||||
{"name": "upscale", "tag_type": "system"},
|
||||
{"name": "diffusion-model", "tag_type": "system"},
|
||||
{"name": "hypernetwork", "tag_type": "system"},
|
||||
{"name": "vae_approx", "tag_type": "system"},
|
||||
# TODO: decide what to do with: style_models, diffusers, gligen, photomaker, classifiers
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
|
||||
op.drop_table("asset_info_meta")
|
||||
|
||||
op.drop_table("asset_locator_state")
|
||||
|
||||
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
|
||||
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
|
||||
op.drop_table("asset_info_tags")
|
||||
|
||||
op.drop_index("ix_tags_tag_type", table_name="tags")
|
||||
op.drop_table("tags")
|
||||
|
||||
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_asset_hash", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
|
||||
op.drop_table("assets_info")
|
||||
|
||||
op.drop_index("ix_assets_backend_locator", table_name="assets")
|
||||
op.drop_index("ix_assets_mime_type", table_name="assets")
|
||||
op.drop_table("assets")
|
||||
@ -1,40 +0,0 @@
|
||||
"""init
|
||||
|
||||
Revision ID: e9c714da8d57
|
||||
Revises:
|
||||
Create Date: 2025-05-30 20:14:33.772039
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'e9c714da8d57'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
op.create_table('model',
|
||||
sa.Column('type', sa.Text(), nullable=False),
|
||||
sa.Column('path', sa.Text(), nullable=False),
|
||||
sa.Column('file_name', sa.Text(), nullable=True),
|
||||
sa.Column('file_size', sa.Integer(), nullable=True),
|
||||
sa.Column('hash', sa.Text(), nullable=True),
|
||||
sa.Column('hash_algorithm', sa.Text(), nullable=True),
|
||||
sa.Column('source_url', sa.Text(), nullable=True),
|
||||
sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.PrimaryKeyConstraint('type', 'path')
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('model')
|
||||
# ### end Alembic commands ###
|
||||
0
app/api/__init__.py
Normal file
0
app/api/__init__.py
Normal file
110
app/api/assets_routes.py
Normal file
110
app/api/assets_routes.py
Normal file
@ -0,0 +1,110 @@
|
||||
import json
|
||||
from typing import Sequence
|
||||
from aiohttp import web
|
||||
|
||||
from app import assets_manager
|
||||
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets")
|
||||
async def list_assets(request: web.Request) -> web.Response:
|
||||
q = request.rel_url.query
|
||||
|
||||
include_tags: Sequence[str] = _parse_csv_tags(q.get("include_tags"))
|
||||
exclude_tags: Sequence[str] = _parse_csv_tags(q.get("exclude_tags"))
|
||||
name_contains = q.get("name_contains")
|
||||
|
||||
# Optional JSON metadata filter (top-level key equality only for now)
|
||||
metadata_filter = None
|
||||
raw_meta = q.get("metadata_filter")
|
||||
if raw_meta:
|
||||
try:
|
||||
metadata_filter = json.loads(raw_meta)
|
||||
if not isinstance(metadata_filter, dict):
|
||||
metadata_filter = None
|
||||
except Exception:
|
||||
# Silently ignore malformed JSON for first iteration; could 400 in future
|
||||
metadata_filter = None
|
||||
|
||||
limit = _parse_int(q.get("limit"), default=20, lo=1, hi=100)
|
||||
offset = _parse_int(q.get("offset"), default=0, lo=0, hi=10_000_000)
|
||||
sort = q.get("sort", "created_at")
|
||||
order = q.get("order", "desc")
|
||||
|
||||
payload = await assets_manager.list_assets(
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
return web.json_response(payload)
|
||||
|
||||
|
||||
@ROUTES.put("/api/assets/{id}")
|
||||
async def update_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id_raw = request.match_info.get("id")
|
||||
try:
|
||||
asset_info_id = int(asset_info_id_raw)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
|
||||
|
||||
try:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
name = payload.get("name", None)
|
||||
tags = payload.get("tags", None)
|
||||
user_metadata = payload.get("user_metadata", None)
|
||||
|
||||
if name is None and tags is None and user_metadata is None:
|
||||
return _error_response(400, "NO_FIELDS", "Provide at least one of: name, tags, user_metadata.")
|
||||
|
||||
if tags is not None and (not isinstance(tags, list) or not all(isinstance(t, str) for t in tags)):
|
||||
return _error_response(400, "INVALID_TAGS", "Field 'tags' must be an array of strings.")
|
||||
|
||||
if user_metadata is not None and not isinstance(user_metadata, dict):
|
||||
return _error_response(400, "INVALID_METADATA", "Field 'user_metadata' must be an object.")
|
||||
|
||||
try:
|
||||
result = await assets_manager.update_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
name=name,
|
||||
tags=tags,
|
||||
user_metadata=user_metadata,
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result, status=200)
|
||||
|
||||
|
||||
def register_assets_routes(app: web.Application) -> None:
|
||||
app.add_routes(ROUTES)
|
||||
|
||||
|
||||
def _parse_csv_tags(raw: str | None) -> list[str]:
|
||||
if not raw:
|
||||
return []
|
||||
return [t.strip() for t in raw.split(",") if t.strip()]
|
||||
|
||||
|
||||
def _parse_int(qval: str | None, default: int, lo: int, hi: int) -> int:
|
||||
if not qval:
|
||||
return default
|
||||
try:
|
||||
v = int(qval)
|
||||
except Exception:
|
||||
return default
|
||||
return max(lo, min(hi, v))
|
||||
|
||||
|
||||
def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
|
||||
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
|
||||
148
app/assets_manager.py
Normal file
148
app/assets_manager.py
Normal file
@ -0,0 +1,148 @@
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from .database.db import create_session
|
||||
from .storage import hashing
|
||||
from .database.services import (
|
||||
check_fs_asset_exists_quick,
|
||||
ingest_fs_asset,
|
||||
touch_asset_infos_by_fs_path,
|
||||
list_asset_infos_page,
|
||||
update_asset_info_full,
|
||||
get_asset_tags,
|
||||
)
|
||||
|
||||
|
||||
def get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
|
||||
async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
|
||||
"""Adds a local asset to the DB. If already present and unchanged, does nothing.
|
||||
|
||||
Notes:
|
||||
- Uses absolute path as the canonical locator for the 'fs' backend.
|
||||
- Computes BLAKE3 only when the fast existence check indicates it's needed.
|
||||
- This function ensures the identity row and seeds mtime in asset_locator_state.
|
||||
"""
|
||||
abs_path = os.path.abspath(file_path)
|
||||
size_bytes, mtime_ns = get_size_mtime_ns(abs_path)
|
||||
if not size_bytes:
|
||||
return
|
||||
|
||||
async with await create_session() as session:
|
||||
if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns):
|
||||
await touch_asset_infos_by_fs_path(session, abs_path=abs_path, ts=datetime.now(timezone.utc))
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
asset_hash = hashing.blake3_hash_sync(abs_path)
|
||||
|
||||
async with await create_session() as session:
|
||||
await ingest_fs_asset(
|
||||
session,
|
||||
asset_hash="blake3:" + asset_hash,
|
||||
abs_path=abs_path,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=None,
|
||||
info_name=file_name,
|
||||
tag_origin="automatic",
|
||||
tags=tags,
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def list_assets(
|
||||
*,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: Optional[str] = None,
|
||||
metadata_filter: Optional[dict] = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str | None = "created_at",
|
||||
order: str | None = "desc",
|
||||
) -> dict:
|
||||
sort = _safe_sort_field(sort)
|
||||
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
||||
|
||||
async with await create_session() as session:
|
||||
infos, tag_map, total = await list_asset_infos_page(
|
||||
session,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
assets_json = []
|
||||
for info in infos:
|
||||
asset = info.asset # populated via contains_eager
|
||||
tags = tag_map.get(info.id, [])
|
||||
assets_json.append(
|
||||
{
|
||||
"id": info.id,
|
||||
"name": info.name,
|
||||
"asset_hash": info.asset_hash,
|
||||
"size": int(asset.size_bytes) if asset else None,
|
||||
"mime_type": asset.mime_type if asset else None,
|
||||
"tags": tags,
|
||||
"preview_url": f"/api/v1/assets/{info.id}/content", # TODO: implement actual content endpoint later
|
||||
"created_at": info.created_at.isoformat() if info.created_at else None,
|
||||
"updated_at": info.updated_at.isoformat() if info.updated_at else None,
|
||||
"last_access_time": info.last_access_time.isoformat() if info.last_access_time else None,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"assets": assets_json,
|
||||
"total": total,
|
||||
"has_more": (offset + len(assets_json)) < total,
|
||||
}
|
||||
|
||||
|
||||
async def update_asset(
|
||||
*,
|
||||
asset_info_id: int,
|
||||
name: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
) -> dict:
|
||||
async with await create_session() as session:
|
||||
info = await update_asset_info_full(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
name=name,
|
||||
tags=tags,
|
||||
user_metadata=user_metadata,
|
||||
tag_origin="manual",
|
||||
added_by=None,
|
||||
)
|
||||
|
||||
tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
await session.commit()
|
||||
|
||||
return {
|
||||
"id": info.id,
|
||||
"name": info.name,
|
||||
"asset_hash": info.asset_hash,
|
||||
"tags": tag_names,
|
||||
"user_metadata": info.user_metadata or {},
|
||||
"updated_at": info.updated_at.isoformat() if info.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _safe_sort_field(requested: str | None) -> str:
|
||||
if not requested:
|
||||
return "created_at"
|
||||
v = requested.lower()
|
||||
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
|
||||
return v
|
||||
return "created_at"
|
||||
0
app/database/__init__.py
Normal file
0
app/database/__init__.py
Normal file
@ -1,112 +1,267 @@
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
from app.logger import log_startup_warning
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
from comfy.cli_args import args
|
||||
|
||||
_DB_AVAILABLE = False
|
||||
Session = None
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
# Attempt imports which may not exist in some environments
|
||||
try:
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
_DB_AVAILABLE = True
|
||||
ENGINE: AsyncEngine | None = None
|
||||
SESSION: async_sessionmaker | None = None
|
||||
except ImportError as e:
|
||||
log_startup_warning(
|
||||
f"""
|
||||
------------------------------------------------------------------------
|
||||
Error importing dependencies: {e}
|
||||
{get_missing_requirements_message()}
|
||||
This error is happening because ComfyUI now uses a local sqlite database.
|
||||
------------------------------------------------------------------------
|
||||
""".strip()
|
||||
(
|
||||
"------------------------------------------------------------------------\n"
|
||||
f"Error importing DB dependencies: {e}\n"
|
||||
f"{get_missing_requirements_message()}\n"
|
||||
"This error is happening because ComfyUI now uses a local database.\n"
|
||||
"------------------------------------------------------------------------"
|
||||
).strip()
|
||||
)
|
||||
_DB_AVAILABLE = False
|
||||
ENGINE = None
|
||||
SESSION = None
|
||||
|
||||
|
||||
def dependencies_available():
|
||||
"""
|
||||
Temporary function to check if the dependencies are available
|
||||
"""
|
||||
def dependencies_available() -> bool:
|
||||
"""Check if DB dependencies are importable."""
|
||||
return _DB_AVAILABLE
|
||||
|
||||
|
||||
def can_create_session():
|
||||
"""
|
||||
Temporary function to check if the database is available to create a session
|
||||
During initial release there may be environmental issues (or missing dependencies) that prevent the database from being created
|
||||
"""
|
||||
return dependencies_available() and Session is not None
|
||||
|
||||
|
||||
def get_alembic_config():
|
||||
root_path = os.path.join(os.path.dirname(__file__), "../..")
|
||||
def _root_paths():
|
||||
"""Resolve alembic.ini and migrations script folder."""
|
||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
|
||||
config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
|
||||
scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
|
||||
|
||||
config = Config(config_path)
|
||||
config.set_main_option("script_location", scripts_path)
|
||||
config.set_main_option("sqlalchemy.url", args.database_url)
|
||||
|
||||
return config
|
||||
return config_path, scripts_path
|
||||
|
||||
|
||||
def get_db_path():
|
||||
url = args.database_url
|
||||
if url.startswith("sqlite:///"):
|
||||
return url.split("///")[1]
|
||||
def _absolutize_sqlite_url(db_url: str) -> str:
|
||||
"""Make SQLite database path absolute. No-op for non-SQLite URLs."""
|
||||
try:
|
||||
u = make_url(db_url)
|
||||
except Exception:
|
||||
return db_url
|
||||
|
||||
if not u.drivername.startswith("sqlite"):
|
||||
return db_url
|
||||
|
||||
# Make path absolute if relative
|
||||
db_path = u.database or ""
|
||||
if not os.path.isabs(db_path):
|
||||
db_path = os.path.abspath(os.path.join(os.getcwd(), db_path))
|
||||
u = u.set(database=db_path)
|
||||
return str(u)
|
||||
|
||||
|
||||
def _to_sync_driver_url(async_url: str) -> str:
|
||||
"""Convert an async SQLAlchemy URL to a sync URL for Alembic."""
|
||||
u = make_url(async_url)
|
||||
driver = u.drivername
|
||||
|
||||
if driver.startswith("sqlite+aiosqlite"):
|
||||
u = u.set(drivername="sqlite")
|
||||
elif driver.startswith("postgresql+asyncpg"):
|
||||
u = u.set(drivername="postgresql")
|
||||
else:
|
||||
raise ValueError(f"Unsupported database URL '{url}'.")
|
||||
# Generic: strip the async driver part if present
|
||||
if "+" in driver:
|
||||
u = u.set(drivername=driver.split("+", 1)[0])
|
||||
|
||||
return str(u)
|
||||
|
||||
|
||||
def init_db():
|
||||
db_url = args.database_url
|
||||
logging.debug(f"Database URL: {db_url}")
|
||||
db_path = get_db_path()
|
||||
db_exists = os.path.exists(db_path)
|
||||
def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
|
||||
"""Return the on-disk path for a SQLite URL, else None."""
|
||||
try:
|
||||
u = make_url(sync_url)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
config = get_alembic_config()
|
||||
if not u.drivername.startswith("sqlite"):
|
||||
return None
|
||||
return u.database
|
||||
|
||||
# Check if we need to upgrade
|
||||
engine = create_engine(db_url)
|
||||
conn = engine.connect()
|
||||
|
||||
context = MigrationContext.configure(conn)
|
||||
current_rev = context.get_current_revision()
|
||||
def _get_alembic_config(sync_url: str) -> Config:
|
||||
"""Prepare Alembic Config with script location and DB URL."""
|
||||
config_path, scripts_path = _root_paths()
|
||||
cfg = Config(config_path)
|
||||
cfg.set_main_option("script_location", scripts_path)
|
||||
cfg.set_main_option("sqlalchemy.url", sync_url)
|
||||
return cfg
|
||||
|
||||
script = ScriptDirectory.from_config(config)
|
||||
|
||||
async def init_db_engine() -> None:
|
||||
"""Initialize async engine + sessionmaker and run migrations to head.
|
||||
|
||||
This must be called once on application startup before any DB usage.
|
||||
"""
|
||||
global ENGINE, SESSION
|
||||
|
||||
if not dependencies_available():
|
||||
raise RuntimeError("Database dependencies are not available.")
|
||||
|
||||
if ENGINE is not None:
|
||||
return
|
||||
|
||||
raw_url = args.database_url
|
||||
if not raw_url:
|
||||
raise RuntimeError("Database URL is not configured.")
|
||||
|
||||
# Absolutize SQLite path for async engine
|
||||
db_url = _absolutize_sqlite_url(raw_url)
|
||||
|
||||
# Prepare async engine
|
||||
connect_args = {}
|
||||
if db_url.startswith("sqlite"):
|
||||
connect_args = {
|
||||
"check_same_thread": False,
|
||||
"timeout": 12,
|
||||
}
|
||||
|
||||
ENGINE = create_async_engine(
|
||||
db_url,
|
||||
connect_args=connect_args,
|
||||
pool_pre_ping=True,
|
||||
future=True,
|
||||
)
|
||||
|
||||
# Enforce SQLite pragmas on the async engine
|
||||
if db_url.startswith("sqlite"):
|
||||
async with ENGINE.begin() as conn:
|
||||
# WAL for concurrency and durability, Foreign Keys for referential integrity
|
||||
current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
|
||||
if str(current_mode).lower() != "wal":
|
||||
new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
|
||||
if str(new_mode).lower() != "wal":
|
||||
raise RuntimeError("Failed to set SQLite journal mode to WAL.")
|
||||
LOGGER.info("SQLite journal mode set to WAL.")
|
||||
|
||||
await conn.execute(text("PRAGMA foreign_keys = ON;"))
|
||||
await conn.execute(text("PRAGMA synchronous = NORMAL;"))
|
||||
|
||||
await _run_migrations(raw_url=db_url)
|
||||
|
||||
SESSION = async_sessionmaker(
|
||||
bind=ENGINE,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
|
||||
async def _run_migrations(raw_url: str) -> None:
|
||||
"""
|
||||
Run Alembic migrations up to head.
|
||||
|
||||
We deliberately use a synchronous engine for migrations because Alembic's
|
||||
programmatic API is synchronous by default and this path is robust.
|
||||
"""
|
||||
# Convert to sync URL and make SQLite URL an absolute one
|
||||
sync_url = _to_sync_driver_url(raw_url)
|
||||
sync_url = _absolutize_sqlite_url(sync_url)
|
||||
|
||||
cfg = _get_alembic_config(sync_url)
|
||||
|
||||
# Inspect current and target heads
|
||||
engine = create_engine(sync_url, future=True)
|
||||
with engine.connect() as conn:
|
||||
context = MigrationContext.configure(conn)
|
||||
current_rev = context.get_current_revision()
|
||||
|
||||
script = ScriptDirectory.from_config(cfg)
|
||||
target_rev = script.get_current_head()
|
||||
|
||||
if target_rev is None:
|
||||
logging.warning("No target revision found.")
|
||||
elif current_rev != target_rev:
|
||||
# Backup the database pre upgrade
|
||||
backup_path = db_path + ".bkp"
|
||||
if db_exists:
|
||||
shutil.copy(db_path, backup_path)
|
||||
else:
|
||||
backup_path = None
|
||||
LOGGER.warning("Alembic: no target revision found.")
|
||||
return
|
||||
|
||||
if current_rev == target_rev:
|
||||
LOGGER.debug("Alembic: database already at head %s", target_rev)
|
||||
return
|
||||
|
||||
LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev)
|
||||
|
||||
# Optional backup for SQLite file DBs
|
||||
backup_path = None
|
||||
sqlite_path = _get_sqlite_file_path(sync_url)
|
||||
if sqlite_path and os.path.exists(sqlite_path):
|
||||
backup_path = sqlite_path + ".bkp"
|
||||
try:
|
||||
command.upgrade(config, target_rev)
|
||||
logging.info(f"Database upgraded from {current_rev} to {target_rev}")
|
||||
except Exception as e:
|
||||
if backup_path:
|
||||
# Restore the database from backup if upgrade fails
|
||||
shutil.copy(backup_path, db_path)
|
||||
shutil.copy(sqlite_path, backup_path)
|
||||
except Exception as exc:
|
||||
LOGGER.warning("Failed to create SQLite backup before migration: %s", exc)
|
||||
|
||||
try:
|
||||
command.upgrade(cfg, target_rev)
|
||||
except Exception:
|
||||
if backup_path and os.path.exists(backup_path):
|
||||
LOGGER.exception("Error upgrading database, attempting restore from backup.")
|
||||
try:
|
||||
shutil.copy(backup_path, sqlite_path) # restore
|
||||
os.remove(backup_path)
|
||||
logging.exception("Error upgrading database: ")
|
||||
raise e
|
||||
|
||||
global Session
|
||||
Session = sessionmaker(bind=engine)
|
||||
except Exception as re:
|
||||
LOGGER.error("Failed to restore SQLite backup: %s", re)
|
||||
else:
|
||||
LOGGER.exception("Error upgrading database, backup is not available.")
|
||||
raise
|
||||
|
||||
|
||||
def create_session():
|
||||
return Session()
|
||||
def get_engine():
|
||||
"""Return the global async engine (initialized after init_db_engine())."""
|
||||
if ENGINE is None:
|
||||
raise RuntimeError("Engine is not initialized. Call init_db_engine() first.")
|
||||
return ENGINE
|
||||
|
||||
|
||||
def get_session_maker():
|
||||
"""Return the global async_sessionmaker (initialized after init_db_engine())."""
|
||||
if SESSION is None:
|
||||
raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.")
|
||||
return SESSION
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def session_scope():
|
||||
"""Async context manager for a unit of work:
|
||||
|
||||
async with session_scope() as sess:
|
||||
... use sess ...
|
||||
"""
|
||||
maker = get_session_maker()
|
||||
async with maker() as sess:
|
||||
try:
|
||||
yield sess
|
||||
await sess.commit()
|
||||
except Exception:
|
||||
await sess.rollback()
|
||||
raise
|
||||
|
||||
|
||||
async def create_session():
|
||||
"""Convenience helper to acquire a single AsyncSession instance.
|
||||
|
||||
Typical usage:
|
||||
async with (await create_session()) as sess:
|
||||
...
|
||||
"""
|
||||
maker = get_session_maker()
|
||||
return maker()
|
||||
|
||||
@ -1,59 +1,257 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Integer,
|
||||
Text,
|
||||
BigInteger,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
JSON,
|
||||
String,
|
||||
Text,
|
||||
CheckConstraint,
|
||||
Numeric,
|
||||
Boolean,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
Base = declarative_base()
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign
|
||||
|
||||
|
||||
def to_dict(obj):
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
||||
fields = obj.__table__.columns.keys()
|
||||
return {
|
||||
field: (val.to_dict() if hasattr(val, "to_dict") else val)
|
||||
for field in fields
|
||||
if (val := getattr(obj, field))
|
||||
}
|
||||
out: dict[str, Any] = {}
|
||||
for field in fields:
|
||||
val = getattr(obj, field)
|
||||
if val is None and not include_none:
|
||||
continue
|
||||
if isinstance(val, datetime):
|
||||
out[field] = val.isoformat()
|
||||
else:
|
||||
out[field] = val
|
||||
return out
|
||||
|
||||
|
||||
class Model(Base):
|
||||
"""
|
||||
sqlalchemy model representing a model file in the system.
|
||||
class Asset(Base):
|
||||
__tablename__ = "assets"
|
||||
|
||||
This class defines the database schema for storing information about model files,
|
||||
including their type, path, hash, and when they were added to the system.
|
||||
hash: Mapped[str] = mapped_column(String(256), primary_key=True)
|
||||
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
mime_type: Mapped[str | None] = mapped_column(String(255))
|
||||
refcount: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
storage_backend: Mapped[str] = mapped_column(String(32), nullable=False, default="fs")
|
||||
storage_locator: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
|
||||
)
|
||||
|
||||
Attributes:
|
||||
type (Text): The type of the model, this is the name of the folder in the models folder (primary key)
|
||||
path (Text): The file path of the model relative to the type folder (primary key)
|
||||
file_name (Text): The name of the model file
|
||||
file_size (Integer): The size of the model file in bytes
|
||||
hash (Text): A hash of the model file
|
||||
hash_algorithm (Text): The algorithm used to generate the hash
|
||||
source_url (Text): The URL of the model file
|
||||
date_added (DateTime): Timestamp of when the model was added to the system
|
||||
"""
|
||||
infos: Mapped[list["AssetInfo"]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="asset",
|
||||
primaryjoin=lambda: Asset.hash == foreign(AssetInfo.asset_hash),
|
||||
foreign_keys=lambda: [AssetInfo.asset_hash],
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
__tablename__ = "model"
|
||||
preview_of: Mapped[list["AssetInfo"]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="preview_asset",
|
||||
primaryjoin=lambda: Asset.hash == foreign(AssetInfo.preview_hash),
|
||||
foreign_keys=lambda: [AssetInfo.preview_hash],
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
type = Column(Text, primary_key=True)
|
||||
path = Column(Text, primary_key=True)
|
||||
file_name = Column(Text)
|
||||
file_size = Column(Integer)
|
||||
hash = Column(Text)
|
||||
hash_algorithm = Column(Text)
|
||||
source_url = Column(Text)
|
||||
date_added = Column(DateTime, server_default=func.now())
|
||||
locator_state: Mapped["AssetLocatorState | None"] = relationship(
|
||||
back_populates="asset",
|
||||
uselist=False,
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Convert the model instance to a dictionary representation.
|
||||
__table_args__ = (
|
||||
Index("ix_assets_mime_type", "mime_type"),
|
||||
Index("ix_assets_backend_locator", "storage_backend", "storage_locator"),
|
||||
)
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the attributes of the model
|
||||
"""
|
||||
dict = to_dict(self)
|
||||
return dict
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Asset hash={self.hash[:12]} backend={self.storage_backend}>"
|
||||
|
||||
|
||||
class AssetLocatorState(Base):
|
||||
__tablename__ = "asset_locator_state"
|
||||
|
||||
asset_hash: Mapped[str] = mapped_column(
|
||||
String(256), ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
# For fs backends: nanosecond mtime; nullable if not applicable
|
||||
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
# For HTTP/S3/GCS/Azure, etc.: optional validators
|
||||
etag: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
|
||||
asset: Mapped["Asset"] = relationship(back_populates="locator_state", uselist=False)
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetLocatorState hash={self.asset_hash[:12]} mtime_ns={self.mtime_ns}>"
|
||||
|
||||
|
||||
class AssetInfo(Base):
|
||||
__tablename__ = "assets_info"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
owner_id: Mapped[str | None] = mapped_column(String(128))
|
||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
asset_hash: Mapped[str] = mapped_column(
|
||||
String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False
|
||||
)
|
||||
preview_hash: Mapped[str | None] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL"))
|
||||
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
|
||||
)
|
||||
last_access_time: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
|
||||
)
|
||||
|
||||
# Relationships
|
||||
asset: Mapped[Asset] = relationship(
|
||||
"Asset",
|
||||
back_populates="infos",
|
||||
foreign_keys=[asset_hash],
|
||||
)
|
||||
preview_asset: Mapped[Asset | None] = relationship(
|
||||
"Asset",
|
||||
back_populates="preview_of",
|
||||
foreign_keys=[preview_hash],
|
||||
)
|
||||
|
||||
metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
tag_links: Mapped[list["AssetInfoTag"]] = relationship(
|
||||
back_populates="asset_info",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
overlaps="tags,asset_infos",
|
||||
)
|
||||
|
||||
tags: Mapped[list["Tag"]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="asset_infos",
|
||||
lazy="joined",
|
||||
viewonly=True,
|
||||
overlaps="tag_links,asset_info_links,asset_infos,tag",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_assets_info_owner_id", "owner_id"),
|
||||
Index("ix_assets_info_asset_hash", "asset_hash"),
|
||||
Index("ix_assets_info_name", "name"),
|
||||
Index("ix_assets_info_created_at", "created_at"),
|
||||
Index("ix_assets_info_last_access_time", "last_access_time"),
|
||||
{"sqlite_autoincrement": True},
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
data = to_dict(self, include_none=include_none)
|
||||
data["tags"] = [t.name for t in self.tags]
|
||||
return data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetInfo id={self.id} name={self.name!r} hash={self.asset_hash[:12]}>"
|
||||
|
||||
|
||||
|
||||
class AssetInfoMeta(Base):
|
||||
__tablename__ = "asset_info_meta"
|
||||
|
||||
asset_info_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
key: Mapped[str] = mapped_column(String(256), primary_key=True)
|
||||
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
|
||||
|
||||
val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
|
||||
val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True)
|
||||
val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
|
||||
val_json: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True)
|
||||
|
||||
asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_meta_key", "key"),
|
||||
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
|
||||
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
|
||||
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
|
||||
)
|
||||
|
||||
|
||||
class AssetInfoTag(Base):
|
||||
__tablename__ = "asset_info_tags"
|
||||
|
||||
asset_info_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
tag_name: Mapped[str] = mapped_column(
|
||||
String(128), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
|
||||
)
|
||||
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
|
||||
added_by: Mapped[str | None] = mapped_column(String(128))
|
||||
added_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links")
|
||||
tag: Mapped["Tag"] = relationship(back_populates="asset_info_links")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_tags_tag_name", "tag_name"),
|
||||
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
|
||||
)
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tags"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(128), primary_key=True)
|
||||
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
|
||||
|
||||
asset_info_links: Mapped[list["AssetInfoTag"]] = relationship(
|
||||
back_populates="tag",
|
||||
overlaps="asset_infos,tags",
|
||||
)
|
||||
asset_infos: Mapped[list["AssetInfo"]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="tags",
|
||||
viewonly=True,
|
||||
overlaps="asset_info_links,tag_links,tags,asset_info",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_tags_tag_type", "tag_type"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Tag {self.name}>"
|
||||
|
||||
683
app/database/services.py
Normal file
683
app/database/services.py
Normal file
@ -0,0 +1,683 @@
|
||||
import os
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from decimal import Decimal
|
||||
from typing import Any, Sequence, Optional, Iterable
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, delete, exists, func
|
||||
from sqlalchemy.orm import contains_eager
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta
|
||||
|
||||
|
||||
async def check_fs_asset_exists_quick(
|
||||
session,
|
||||
*,
|
||||
file_path: str,
|
||||
size_bytes: Optional[int] = None,
|
||||
mtime_ns: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns 'True' if there is already an Asset present whose canonical locator matches this absolute path,
|
||||
AND (if provided) mtime_ns matches stored locator-state,
|
||||
AND (if provided) size_bytes matches verified size when known.
|
||||
"""
|
||||
locator = os.path.abspath(file_path)
|
||||
|
||||
stmt = select(sa.literal(True)).select_from(Asset)
|
||||
|
||||
conditions = [
|
||||
Asset.storage_backend == "fs",
|
||||
Asset.storage_locator == locator,
|
||||
]
|
||||
|
||||
# If size_bytes provided require equality when the asset has a verified (non-zero) size.
|
||||
# If verified size is 0 (unknown), we don't force equality.
|
||||
if size_bytes is not None:
|
||||
conditions.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
|
||||
|
||||
# If mtime_ns provided require the locator-state to exist and match.
|
||||
if mtime_ns is not None:
|
||||
stmt = stmt.join(AssetLocatorState, AssetLocatorState.asset_hash == Asset.hash)
|
||||
conditions.append(AssetLocatorState.mtime_ns == int(mtime_ns))
|
||||
|
||||
stmt = stmt.where(*conditions).limit(1)
|
||||
|
||||
row = (await session.execute(stmt)).first()
|
||||
return row is not None
|
||||
|
||||
|
||||
async def ingest_fs_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_hash: str,
|
||||
abs_path: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: Optional[str] = None,
|
||||
info_name: Optional[str] = None,
|
||||
owner_id: Optional[str] = None,
|
||||
preview_hash: Optional[str] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
added_by: Optional[str] = None,
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Creates or updates Asset record for a local (fs) asset.
|
||||
|
||||
Always:
|
||||
- Insert Asset if missing; else update size_bytes (and updated_at) if different.
|
||||
- Insert AssetLocatorState if missing; else update mtime_ns if different.
|
||||
|
||||
Optionally (when info_name is provided):
|
||||
- Create an AssetInfo (no refcount changes).
|
||||
- Link provided tags to that AssetInfo.
|
||||
* If the require_existing_tags=True, raises ValueError if any tag does not exist in `tags` table.
|
||||
* If False (default), silently skips unknown tags.
|
||||
|
||||
Returns flags and ids:
|
||||
{
|
||||
"asset_created": bool,
|
||||
"asset_updated": bool,
|
||||
"state_created": bool,
|
||||
"state_updated": bool,
|
||||
"asset_info_id": int | None,
|
||||
"tags_added": list[str],
|
||||
"tags_missing": list[str], # filled only when require_existing_tags=False
|
||||
}
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
datetime_now = datetime.now(timezone.utc)
|
||||
|
||||
out = {
|
||||
"asset_created": False,
|
||||
"asset_updated": False,
|
||||
"state_created": False,
|
||||
"state_updated": False,
|
||||
"asset_info_id": None,
|
||||
"tags_added": [],
|
||||
"tags_missing": [],
|
||||
}
|
||||
|
||||
# ---- Step 1: INSERT Asset or UPDATE size_bytes/updated_at if exists ----
|
||||
async with session.begin_nested() as sp1:
|
||||
try:
|
||||
session.add(
|
||||
Asset(
|
||||
hash=asset_hash,
|
||||
size_bytes=int(size_bytes),
|
||||
mime_type=mime_type,
|
||||
refcount=0,
|
||||
storage_backend="fs",
|
||||
storage_locator=locator,
|
||||
created_at=datetime_now,
|
||||
updated_at=datetime_now,
|
||||
)
|
||||
)
|
||||
await session.flush()
|
||||
out["asset_created"] = True
|
||||
except IntegrityError:
|
||||
await sp1.rollback()
|
||||
# Already exists by hash -> update selected fields if different
|
||||
existing = await session.get(Asset, asset_hash)
|
||||
if existing is not None:
|
||||
desired_size = int(size_bytes)
|
||||
if existing.size_bytes != desired_size:
|
||||
existing.size_bytes = desired_size
|
||||
existing.updated_at = datetime_now
|
||||
out["asset_updated"] = True
|
||||
else:
|
||||
# This should not occur. Log for visibility.
|
||||
logging.error("Asset %s not found after conflict; skipping update.", asset_hash)
|
||||
except Exception:
|
||||
await sp1.rollback()
|
||||
logging.exception("Unexpected error inserting Asset (hash=%s, locator=%s)", asset_hash, locator)
|
||||
raise
|
||||
|
||||
# ---- Step 2: INSERT/UPDATE AssetLocatorState (mtime_ns) ----
|
||||
async with session.begin_nested() as sp2:
|
||||
try:
|
||||
session.add(
|
||||
AssetLocatorState(
|
||||
asset_hash=asset_hash,
|
||||
mtime_ns=int(mtime_ns),
|
||||
)
|
||||
)
|
||||
await session.flush()
|
||||
out["state_created"] = True
|
||||
except IntegrityError:
|
||||
await sp2.rollback()
|
||||
state = await session.get(AssetLocatorState, asset_hash)
|
||||
if state is not None:
|
||||
desired_mtime = int(mtime_ns)
|
||||
if state.mtime_ns != desired_mtime:
|
||||
state.mtime_ns = desired_mtime
|
||||
out["state_updated"] = True
|
||||
else:
|
||||
logging.debug("Locator state missing for %s after conflict; skipping update.", asset_hash)
|
||||
except Exception:
|
||||
await sp2.rollback()
|
||||
logging.exception("Unexpected error inserting AssetLocatorState (hash=%s)", asset_hash)
|
||||
raise
|
||||
|
||||
# ---- Optional: AssetInfo + tag links ----
|
||||
if info_name:
|
||||
# 2a) Create AssetInfo (no refcount bump)
|
||||
async with session.begin_nested() as sp3:
|
||||
try:
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=info_name,
|
||||
asset_hash=asset_hash,
|
||||
preview_hash=preview_hash,
|
||||
created_at=datetime_now,
|
||||
updated_at=datetime_now,
|
||||
last_access_time=datetime_now,
|
||||
)
|
||||
session.add(info)
|
||||
await session.flush() # get info.id
|
||||
out["asset_info_id"] = info.id
|
||||
except Exception:
|
||||
await sp3.rollback()
|
||||
logging.exception(
|
||||
"Unexpected error inserting AssetInfo (hash=%s, name=%s)", asset_hash, info_name
|
||||
)
|
||||
raise
|
||||
|
||||
# 2b) Link tags (if any). We DO NOT create new Tag rows here by default.
|
||||
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
if norm and out["asset_info_id"] is not None:
|
||||
# Which tags exist?
|
||||
existing_tag_names = set(
|
||||
name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||
)
|
||||
missing = [t for t in norm if t not in existing_tag_names]
|
||||
if missing and require_existing_tags:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
# Which links already exist?
|
||||
existing_links = set(
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
await session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||
)
|
||||
).all()
|
||||
)
|
||||
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||
if to_add:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=out["asset_info_id"],
|
||||
tag_name=t,
|
||||
origin=tag_origin,
|
||||
added_by=added_by,
|
||||
added_at=datetime_now,
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
await session.flush()
|
||||
out["tags_added"] = to_add
|
||||
out["tags_missing"] = missing
|
||||
|
||||
# 2c) Rebuild metadata projection if provided
|
||||
if user_metadata is not None and out["asset_info_id"] is not None:
|
||||
await replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=out["asset_info_id"],
|
||||
user_metadata=user_metadata,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
async def touch_asset_infos_by_fs_path(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
abs_path: str,
|
||||
ts: Optional[datetime] = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> int:
|
||||
locator = os.path.abspath(abs_path)
|
||||
ts = ts or datetime.now(timezone.utc)
|
||||
|
||||
stmt = sa.update(AssetInfo).where(
|
||||
sa.exists(
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(Asset)
|
||||
.where(
|
||||
Asset.hash == AssetInfo.asset_hash,
|
||||
Asset.storage_backend == "fs",
|
||||
Asset.storage_locator == locator,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(
|
||||
AssetInfo.last_access_time.is_(None),
|
||||
AssetInfo.last_access_time < ts,
|
||||
)
|
||||
)
|
||||
|
||||
stmt = stmt.values(last_access_time=ts)
|
||||
|
||||
res = await session.execute(stmt)
|
||||
return int(res.rowcount or 0)
|
||||
|
||||
|
||||
async def list_asset_infos_page(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> tuple[list[AssetInfo], dict[int, list[str]], int]:
|
||||
"""
|
||||
Returns a page of AssetInfo rows with their Asset eagerly loaded (no N+1),
|
||||
plus a map of asset_info_id -> [tags], and the total count.
|
||||
|
||||
We purposely collect tags in a separate (single) query to avoid row explosion.
|
||||
"""
|
||||
# Clamp
|
||||
if limit <= 0:
|
||||
limit = 1
|
||||
if limit > 100:
|
||||
limit = 100
|
||||
if offset < 0:
|
||||
offset = 0
|
||||
|
||||
# Build base query
|
||||
base = (
|
||||
select(AssetInfo)
|
||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||
.options(contains_eager(AssetInfo.asset))
|
||||
)
|
||||
|
||||
# Filters
|
||||
if name_contains:
|
||||
base = base.where(AssetInfo.name.ilike(f"%{name_contains}%"))
|
||||
|
||||
base = _apply_tag_filters(base, include_tags, exclude_tags)
|
||||
|
||||
base = _apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
# Sort
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
sort_map = {
|
||||
"name": AssetInfo.name,
|
||||
"created_at": AssetInfo.created_at,
|
||||
"updated_at": AssetInfo.updated_at,
|
||||
"last_access_time": AssetInfo.last_access_time,
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetInfo.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
|
||||
# Total count (same filters, no ordering/limit/offset)
|
||||
count_stmt = (
|
||||
select(func.count())
|
||||
.select_from(AssetInfo)
|
||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||
)
|
||||
if name_contains:
|
||||
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%"))
|
||||
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = (await session.execute(count_stmt)).scalar_one()
|
||||
|
||||
# Fetch rows
|
||||
infos = (await session.execute(base)).scalars().unique().all()
|
||||
|
||||
# Collect tags in bulk (single query)
|
||||
id_list = [i.id for i in infos]
|
||||
tag_map: dict[int, list[str]] = defaultdict(list)
|
||||
if id_list:
|
||||
rows = await session.execute(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
async def set_asset_info_tags(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: int,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
added_by: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Replace the tag set on an AssetInfo with `tags`. Idempotent.
|
||||
Creates missing tag names as 'user'.
|
||||
"""
|
||||
desired = _normalize_tags(tags)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# current links
|
||||
current = set(
|
||||
tag_name for (tag_name,) in (
|
||||
await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||
).all()
|
||||
)
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
await _ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_by=added_by, added_at=now)
|
||||
for t in to_add
|
||||
])
|
||||
await session.flush()
|
||||
|
||||
if to_remove:
|
||||
await session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||
)
|
||||
await session.flush()
|
||||
|
||||
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||
|
||||
|
||||
async def update_asset_info_full(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: int,
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
tag_origin: str = "manual",
|
||||
added_by: Optional[str] = None,
|
||||
) -> AssetInfo:
|
||||
"""
|
||||
Update AssetInfo fields:
|
||||
- name (if provided)
|
||||
- user_metadata blob + rebuild projection (if provided)
|
||||
- replace tags with provided set (if provided)
|
||||
Returns the updated AssetInfo.
|
||||
"""
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
touched = False
|
||||
if name is not None and name != info.name:
|
||||
info.name = name
|
||||
touched = True
|
||||
|
||||
if user_metadata is not None:
|
||||
await replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=user_metadata
|
||||
)
|
||||
touched = True
|
||||
|
||||
if tags is not None:
|
||||
await set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
added_by=added_by,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
info.updated_at = datetime.now(timezone.utc)
|
||||
await session.flush()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
async def replace_asset_info_metadata_projection(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: int,
|
||||
user_metadata: dict | None,
|
||||
) -> None:
|
||||
"""Replaces the `assets_info.user_metadata` AND rebuild the projection rows in `asset_info_meta`."""
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info.user_metadata = user_metadata or {}
|
||||
info.updated_at = datetime.now(timezone.utc)
|
||||
await session.flush()
|
||||
|
||||
await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||
await session.flush()
|
||||
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetInfoMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in _project_kv(k, v):
|
||||
rows.append(
|
||||
AssetInfoMeta(
|
||||
asset_info_id=asset_info_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[Tag]:
|
||||
return [
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
await session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
def _normalize_tags(tags: Sequence[str] | None) -> list[str]:
|
||||
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
|
||||
|
||||
async def _ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
|
||||
wanted = _normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return []
|
||||
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
|
||||
by_name = {t.name: t for t in existing}
|
||||
to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name]
|
||||
if to_create:
|
||||
session.add_all(to_create)
|
||||
await session.flush()
|
||||
by_name.update({t.name: t for t in to_create})
|
||||
return [by_name[n] for n in wanted]
|
||||
|
||||
|
||||
def _apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None,
|
||||
exclude_tags: Sequence[str] | None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = _normalize_tags(include_tags)
|
||||
exclude_tags = _normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
def _apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply metadata filters using the projection table asset_info_meta.
|
||||
|
||||
Semantics:
|
||||
- For scalar values: require EXISTS(asset_info_meta) with matching key + typed value.
|
||||
- For None: key is missing OR key has explicit null (val_json IS NULL).
|
||||
- For list values: ANY-of the list elements matches (EXISTS for any).
|
||||
(Change to ALL-of by 'for each element: stmt = stmt.where(_meta_exists_clause(key, elem))')
|
||||
"""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
subquery = (
|
||||
select(sa.literal(1))
|
||||
.select_from(AssetInfoMeta)
|
||||
.where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
return sa.exists(subquery)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
# Missing OR null:
|
||||
if value is None:
|
||||
# either: no row for key OR a row for key with explicit null
|
||||
no_row_for_key = ~sa.exists(
|
||||
select(sa.literal(1))
|
||||
.select_from(AssetInfoMeta)
|
||||
.where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
null_row = _exists_for_pred(key, AssetInfoMeta.val_json.is_(None))
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
# Typed scalar matches:
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float, Decimal)):
|
||||
# store as Decimal for equality against NUMERIC(38,10)
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||
|
||||
# Complex: compare JSON (no index, but supported)
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
# ANY-of (exists for any element)
|
||||
ors = [ _exists_clause_for_value(k, elem) for elem in v ]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
|
||||
def _is_scalar(v: Any) -> bool:
|
||||
if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _project_kv(key: str, value: Any) -> list[dict]:
|
||||
"""
|
||||
Turn a metadata key/value into one or more projection rows:
|
||||
- scalar -> one row (ordinal=0) in the proper typed column
|
||||
- list of scalars -> one row per element with ordinal=i
|
||||
- dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list)
|
||||
- None -> single row with val_json = None
|
||||
Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...}
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
# None
|
||||
if value is None:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": None})
|
||||
return rows
|
||||
|
||||
# Scalars
|
||||
if _is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
# store numeric; SQLAlchemy will coerce to Numeric
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": value})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
# Fallback to json
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
# Lists
|
||||
if isinstance(value, list):
|
||||
# list of scalars?
|
||||
if all(_is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": None})
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
rows.append({"key": key, "ordinal": i, "val_num": x})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
# list contains objects -> one val_json per element
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
# Dict or any other structure -> single json row
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
@ -1,331 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
from folder_paths import get_relative_path, get_full_path
|
||||
from app.database.db import create_session, dependencies_available, can_create_session
|
||||
import blake3
|
||||
import comfy.utils
|
||||
|
||||
|
||||
if dependencies_available():
|
||||
from app.database.models import Model
|
||||
|
||||
|
||||
class ModelProcessor:
|
||||
def _validate_path(self, model_path):
|
||||
try:
|
||||
if not self._file_exists(model_path):
|
||||
logging.error(f"Model file not found: {model_path}")
|
||||
return None
|
||||
|
||||
result = get_relative_path(model_path)
|
||||
if not result:
|
||||
logging.error(
|
||||
f"Model file not in a recognized model directory: {model_path}"
|
||||
)
|
||||
return None
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.error(f"Error validating model path {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _file_exists(self, path):
|
||||
"""Check if a file exists."""
|
||||
return os.path.exists(path)
|
||||
|
||||
def _get_file_size(self, path):
|
||||
"""Get file size."""
|
||||
return os.path.getsize(path)
|
||||
|
||||
def _get_hasher(self):
|
||||
return blake3.blake3()
|
||||
|
||||
def _hash_file(self, model_path):
|
||||
try:
|
||||
hasher = self._get_hasher()
|
||||
with open(model_path, "rb", buffering=0) as f:
|
||||
b = bytearray(128 * 1024)
|
||||
mv = memoryview(b)
|
||||
while n := f.readinto(mv):
|
||||
hasher.update(mv[:n])
|
||||
return hasher.hexdigest()
|
||||
except Exception as e:
|
||||
logging.error(f"Error hashing file {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _get_existing_model(self, session, model_type, model_relative_path):
|
||||
return (
|
||||
session.query(Model)
|
||||
.filter(Model.type == model_type)
|
||||
.filter(Model.path == model_relative_path)
|
||||
.first()
|
||||
)
|
||||
|
||||
def _ensure_source_url(self, session, model, source_url):
|
||||
if model.source_url is None:
|
||||
model.source_url = source_url
|
||||
session.commit()
|
||||
|
||||
def _update_database(
|
||||
self,
|
||||
session,
|
||||
model_type,
|
||||
model_path,
|
||||
model_relative_path,
|
||||
model_hash,
|
||||
model,
|
||||
source_url,
|
||||
):
|
||||
try:
|
||||
if not model:
|
||||
model = self._get_existing_model(
|
||||
session, model_type, model_relative_path
|
||||
)
|
||||
|
||||
if not model:
|
||||
model = Model(
|
||||
path=model_relative_path,
|
||||
type=model_type,
|
||||
file_name=os.path.basename(model_path),
|
||||
)
|
||||
session.add(model)
|
||||
|
||||
model.file_size = self._get_file_size(model_path)
|
||||
model.hash = model_hash
|
||||
if model_hash:
|
||||
model.hash_algorithm = "blake3"
|
||||
model.source_url = source_url
|
||||
|
||||
session.commit()
|
||||
return model
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error updating database for {model_relative_path}: {str(e)}"
|
||||
)
|
||||
|
||||
def process_file(self, model_path, source_url=None, model_hash=None):
|
||||
"""
|
||||
Process a model file and update the database with metadata.
|
||||
If the file already exists and matches the database, it will not be processed again.
|
||||
Returns the model object or if an error occurs, returns None.
|
||||
"""
|
||||
try:
|
||||
if not can_create_session():
|
||||
return
|
||||
|
||||
result = self._validate_path(model_path)
|
||||
if not result:
|
||||
return
|
||||
model_type, model_relative_path = result
|
||||
|
||||
with create_session() as session:
|
||||
session.expire_on_commit = False
|
||||
|
||||
existing_model = self._get_existing_model(
|
||||
session, model_type, model_relative_path
|
||||
)
|
||||
if (
|
||||
existing_model
|
||||
and existing_model.hash
|
||||
and existing_model.file_size == self._get_file_size(model_path)
|
||||
):
|
||||
# File exists with hash and same size, no need to process
|
||||
self._ensure_source_url(session, existing_model, source_url)
|
||||
return existing_model
|
||||
|
||||
if model_hash:
|
||||
model_hash = model_hash.lower()
|
||||
logging.info(f"Using provided hash: {model_hash}")
|
||||
else:
|
||||
start_time = time.time()
|
||||
logging.info(f"Hashing model {model_relative_path}")
|
||||
model_hash = self._hash_file(model_path)
|
||||
if not model_hash:
|
||||
return
|
||||
logging.info(
|
||||
f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)"
|
||||
)
|
||||
|
||||
return self._update_database(
|
||||
session,
|
||||
model_type,
|
||||
model_path,
|
||||
model_relative_path,
|
||||
model_hash,
|
||||
existing_model,
|
||||
source_url,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing model file {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def retrieve_model_by_hash(self, model_hash, model_type=None, session=None):
|
||||
"""
|
||||
Retrieve a model file from the database by hash and optionally by model type.
|
||||
Returns the model object or None if the model doesnt exist or an error occurs.
|
||||
"""
|
||||
try:
|
||||
if not can_create_session():
|
||||
return
|
||||
|
||||
dispose_session = False
|
||||
|
||||
if session is None:
|
||||
session = create_session()
|
||||
dispose_session = True
|
||||
|
||||
model = session.query(Model).filter(Model.hash == model_hash)
|
||||
if model_type is not None:
|
||||
model = model.filter(Model.type == model_type)
|
||||
return model.first()
|
||||
except Exception as e:
|
||||
logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}")
|
||||
return None
|
||||
finally:
|
||||
if dispose_session:
|
||||
session.close()
|
||||
|
||||
def retrieve_hash(self, model_path, model_type=None):
|
||||
"""
|
||||
Retrieve the hash of a model file from the database.
|
||||
Returns the hash or None if the model doesnt exist or an error occurs.
|
||||
"""
|
||||
try:
|
||||
if not can_create_session():
|
||||
return
|
||||
|
||||
if model_type is not None:
|
||||
result = self._validate_path(model_path)
|
||||
if not result:
|
||||
return None
|
||||
model_type, model_relative_path = result
|
||||
|
||||
with create_session() as session:
|
||||
model = self._get_existing_model(
|
||||
session, model_type, model_relative_path
|
||||
)
|
||||
if model and model.hash:
|
||||
return model.hash
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Error retrieving hash for {model_path}: {str(e)}")
|
||||
return None
|
||||
|
||||
def _validate_file_extension(self, file_name):
|
||||
"""Validate that the file extension is supported."""
|
||||
extension = os.path.splitext(file_name)[1]
|
||||
if extension not in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"):
|
||||
raise ValueError(f"Unsupported unsafe file for download: {file_name}")
|
||||
|
||||
def _check_existing_file(self, model_type, file_name, expected_hash):
|
||||
"""Check if file exists and has correct hash."""
|
||||
destination_path = get_full_path(model_type, file_name, allow_missing=True)
|
||||
if self._file_exists(destination_path):
|
||||
model = self.process_file(destination_path)
|
||||
if model and (expected_hash is None or model.hash == expected_hash):
|
||||
logging.debug(
|
||||
f"File {destination_path} already exists in the database and has the correct hash or no hash was provided."
|
||||
)
|
||||
return destination_path
|
||||
else:
|
||||
raise ValueError(
|
||||
f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again."
|
||||
)
|
||||
return None
|
||||
|
||||
def _check_existing_file_by_hash(self, hash, type, url):
|
||||
"""Check if a file with the given hash exists in the database and on disk."""
|
||||
hash = hash.lower()
|
||||
with create_session() as session:
|
||||
model = self.retrieve_model_by_hash(hash, type, session)
|
||||
if model:
|
||||
existing_path = get_full_path(type, model.path)
|
||||
if existing_path:
|
||||
logging.debug(
|
||||
f"File {model.path} already exists in the database at {existing_path}"
|
||||
)
|
||||
self._ensure_source_url(session, model, url)
|
||||
return existing_path
|
||||
else:
|
||||
logging.debug(
|
||||
f"File {model.path} exists in the database but not on disk"
|
||||
)
|
||||
return None
|
||||
|
||||
def _download_file(self, url, destination_path, hasher):
|
||||
"""Download a file and update the hasher with its contents."""
|
||||
response = requests.get(url, stream=True)
|
||||
logging.info(f"Downloading {url} to {destination_path}")
|
||||
|
||||
with open(destination_path, "wb") as f:
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
if total_size > 0:
|
||||
pbar = comfy.utils.ProgressBar(total_size)
|
||||
else:
|
||||
pbar = None
|
||||
with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
|
||||
for chunk in response.iter_content(chunk_size=128 * 1024):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
hasher.update(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
if pbar:
|
||||
pbar.update(len(chunk))
|
||||
|
||||
def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path):
|
||||
"""Verify that the downloaded file has the expected hash."""
|
||||
if expected_hash is not None and calculated_hash != expected_hash:
|
||||
self._remove_file(destination_path)
|
||||
raise ValueError(
|
||||
f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}"
|
||||
)
|
||||
|
||||
def _remove_file(self, file_path):
|
||||
"""Remove a file from disk."""
|
||||
os.remove(file_path)
|
||||
|
||||
def ensure_downloaded(self, type, url, desired_file_name, hash=None):
|
||||
"""
|
||||
Ensure a model file is downloaded and has the correct hash.
|
||||
Returns the path to the downloaded file.
|
||||
"""
|
||||
logging.debug(
|
||||
f"Ensuring {type} file is downloaded. URL='{url}' Destination='{desired_file_name}' Hash='{hash}'"
|
||||
)
|
||||
|
||||
# Validate file extension
|
||||
self._validate_file_extension(desired_file_name)
|
||||
|
||||
# Check if file exists with correct hash
|
||||
if hash:
|
||||
existing_path = self._check_existing_file_by_hash(hash, type, url)
|
||||
if existing_path:
|
||||
return existing_path
|
||||
|
||||
# Check if file exists locally
|
||||
destination_path = get_full_path(type, desired_file_name, allow_missing=True)
|
||||
existing_path = self._check_existing_file(type, desired_file_name, hash)
|
||||
if existing_path:
|
||||
return existing_path
|
||||
|
||||
# Download the file
|
||||
hasher = self._get_hasher()
|
||||
self._download_file(url, destination_path, hasher)
|
||||
|
||||
# Verify hash
|
||||
calculated_hash = hasher.hexdigest()
|
||||
self._verify_downloaded_hash(calculated_hash, hash, destination_path)
|
||||
|
||||
# Update database
|
||||
self.process_file(destination_path, url, calculated_hash)
|
||||
|
||||
# TODO: Notify frontend to reload models
|
||||
|
||||
return destination_path
|
||||
|
||||
|
||||
model_processor = ModelProcessor()
|
||||
0
app/storage/__init__.py
Normal file
0
app/storage/__init__.py
Normal file
72
app/storage/hashing.py
Normal file
72
app/storage/hashing.py
Normal file
@ -0,0 +1,72 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import IO, Union
|
||||
|
||||
from blake3 import blake3
|
||||
|
||||
DEFAULT_CHUNK = 8 * 1024 * 1024 # 8 MiB
|
||||
|
||||
|
||||
def _hash_file_obj_sync(file_obj: IO[bytes], chunk_size: int) -> str:
|
||||
"""Hash an already-open binary file object by streaming in chunks.
|
||||
- Seeks to the beginning before reading (if supported).
|
||||
- Restores the original position afterward (if tell/seek are supported).
|
||||
"""
|
||||
if chunk_size <= 0:
|
||||
chunk_size = DEFAULT_CHUNK
|
||||
|
||||
orig_pos = None
|
||||
if hasattr(file_obj, "tell"):
|
||||
orig_pos = file_obj.tell()
|
||||
|
||||
try:
|
||||
if hasattr(file_obj, "seek"):
|
||||
file_obj.seek(0)
|
||||
|
||||
h = blake3()
|
||||
while True:
|
||||
chunk = file_obj.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
finally:
|
||||
if hasattr(file_obj, "seek") and orig_pos is not None:
|
||||
file_obj.seek(orig_pos)
|
||||
|
||||
|
||||
def blake3_hash_sync(
|
||||
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""Returns a BLAKE3 hex digest for ``fp``, which may be:
|
||||
- a filename (str/bytes) or PathLike
|
||||
- an open binary file object
|
||||
|
||||
If ``fp`` is a file object, it must be opened in **binary** mode and support
|
||||
``read``, ``seek``, and ``tell``. The function will seek to the start before
|
||||
reading and will attempt to restore the original position afterward.
|
||||
"""
|
||||
if hasattr(fp, "read"):
|
||||
return _hash_file_obj_sync(fp, chunk_size)
|
||||
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj_sync(f, chunk_size)
|
||||
|
||||
|
||||
async def blake3_hash(
|
||||
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""Async wrapper for ``blake3_hash_sync``.
|
||||
Uses a worker thread so the event loop remains responsive.
|
||||
"""
|
||||
# If it is a path, open inside the worker thread to keep I/O off the loop.
|
||||
if hasattr(fp, "read"):
|
||||
return await asyncio.to_thread(blake3_hash_sync, fp, chunk_size)
|
||||
|
||||
def _worker() -> str:
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj_sync(f, chunk_size)
|
||||
|
||||
return await asyncio.to_thread(_worker)
|
||||
@ -1,110 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, TypedDict
|
||||
import os
|
||||
import logging
|
||||
import comfy.utils
|
||||
|
||||
import folder_paths
|
||||
|
||||
class AssetMetadata(TypedDict):
|
||||
device: Any
|
||||
return_metadata: bool
|
||||
subdir: str
|
||||
download_url: str
|
||||
|
||||
class AssetInfo:
|
||||
def __init__(self, hash: str=None, name: str=None, tags: list[str]=None, metadata: AssetMetadata={}):
|
||||
self.hash = hash
|
||||
self.name = name
|
||||
self.tags = tags
|
||||
self.metadata = metadata
|
||||
|
||||
|
||||
class ReturnedAssetABC(ABC):
|
||||
def __init__(self, mimetype: str):
|
||||
self.mimetype = mimetype
|
||||
|
||||
|
||||
class ModelReturnedAsset(ReturnedAssetABC):
|
||||
def __init__(self, state_dict: dict[str, str], metadata: dict[str, str]=None):
|
||||
super().__init__("model")
|
||||
self.state_dict = state_dict
|
||||
self.metadata = metadata
|
||||
|
||||
|
||||
class AssetResolverABC(ABC):
|
||||
@abstractmethod
|
||||
def resolve(self, asset_info: AssetInfo) -> ReturnedAssetABC:
|
||||
...
|
||||
|
||||
|
||||
class LocalAssetResolver(AssetResolverABC):
|
||||
def resolve(self, asset_info: AssetInfo, cache_result: bool=True) -> ReturnedAssetABC:
|
||||
# currently only supports models - make sure models is in the tags
|
||||
if "models" not in asset_info.tags:
|
||||
return None
|
||||
# TODO: if hash exists, call model processor to try to get info about model:
|
||||
if asset_info.hash:
|
||||
try:
|
||||
from app.model_processor import model_processor
|
||||
model_db = model_processor.retrieve_model_by_hash(asset_info.hash)
|
||||
full_path = model_db.path
|
||||
except Exception as e:
|
||||
logging.error(f"Could not get model by hash with error: {e}")
|
||||
# the good ol' bread and butter - folder_paths's keys as tags
|
||||
folder_keys = folder_paths.folder_names_and_paths.keys()
|
||||
parent_paths = []
|
||||
for tag in asset_info.tags:
|
||||
if tag in folder_keys:
|
||||
parent_paths.append(tag)
|
||||
# if subdir metadata and name exists, use that as the model name going forward
|
||||
if "subdir" in asset_info.metadata and asset_info.name:
|
||||
# if no matching parent paths, then something went wrong and should return None
|
||||
if len(parent_paths) == 0:
|
||||
return None
|
||||
relative_path = os.path.join(asset_info.metadata["subdir"], asset_info.name)
|
||||
# now we have the parent keys, we can try to get the local path
|
||||
chosen_parent = None
|
||||
full_path = None
|
||||
for parent_path in parent_paths:
|
||||
full_path = folder_paths.get_full_path(parent_path, relative_path)
|
||||
if full_path:
|
||||
chosen_parent = parent_path
|
||||
break
|
||||
if full_path is not None:
|
||||
logging.info(f"Resolved {asset_info.name} to {full_path} in {chosen_parent}")
|
||||
# we know the path, so load the model and return it
|
||||
state_dict, metadata = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=True)
|
||||
# TODO: handle caching
|
||||
return ModelReturnedAsset(state_dict, metadata)
|
||||
# if just name exists, try to find model by name in all subdirs of parent paths
|
||||
# TODO: this behavior should be configurable by user
|
||||
if asset_info.name:
|
||||
for parent_path in parent_paths:
|
||||
filelist = folder_paths.get_filename_list(parent_path)
|
||||
for file in filelist:
|
||||
if os.path.basename(file) == asset_info.name:
|
||||
full_path = folder_paths.get_full_path(parent_path, file)
|
||||
state_dict, metadata = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=True)
|
||||
# TODO: handle caching
|
||||
return ModelReturnedAsset(state_dict, metadata)
|
||||
# TODO: if download_url metadata exists, download the model and load it; this should be configurable by user
|
||||
if asset_info.metadata.get("download_url", None):
|
||||
...
|
||||
return None
|
||||
|
||||
|
||||
resolvers: list[AssetResolverABC] = []
|
||||
resolvers.append(LocalAssetResolver())
|
||||
|
||||
|
||||
def resolve(asset_info: AssetInfo) -> Any:
|
||||
global resolvers
|
||||
for resolver in resolvers:
|
||||
try:
|
||||
to_return = resolver.resolve(asset_info)
|
||||
if to_return is not None:
|
||||
return resolver.resolve(asset_info)
|
||||
except Exception as e:
|
||||
logging.error(f"Error resolving asset {asset_info.name} using {resolver.__class__.__name__}: {e}")
|
||||
return None
|
||||
@ -211,7 +211,7 @@ parser.add_argument(
|
||||
database_default_path = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.")
|
||||
parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
|
||||
@ -102,12 +102,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
else:
|
||||
sd = pl_sd
|
||||
|
||||
try:
|
||||
from app.model_processor import model_processor
|
||||
model_processor.process_file(ckpt)
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing file {ckpt}: {e}")
|
||||
|
||||
return (sd, metadata) if return_metadata else sd
|
||||
|
||||
def save_torch_file(sd, ckpt, metadata=None):
|
||||
|
||||
@ -1,56 +0,0 @@
|
||||
from comfy_api.latest import io, ComfyExtension
|
||||
import comfy.asset_management
|
||||
import comfy.sd
|
||||
import folder_paths
|
||||
import logging
|
||||
import os
|
||||
|
||||
|
||||
class AssetTestNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="AssetTestNode",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Combo.Input("ckpt_name", folder_paths.get_filename_list("checkpoints")),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
io.Clip.Output(),
|
||||
io.Vae.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, ckpt_name: str):
|
||||
hash = None
|
||||
# lets get the full path just so we can retrieve the hash from db, if exists
|
||||
try:
|
||||
full_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
if full_path is None:
|
||||
raise Exception(f"Model {ckpt_name} not found")
|
||||
from app.model_processor import model_processor
|
||||
hash = model_processor.retrieve_hash(full_path)
|
||||
except Exception as e:
|
||||
logging.error(f"Could not get model by hash with error: {e}")
|
||||
subdir, name = os.path.split(ckpt_name)
|
||||
asset_info = comfy.asset_management.AssetInfo(hash=hash, name=name, tags=["models", "checkpoints"], metadata={"subdir": subdir})
|
||||
asset = comfy.asset_management.resolve(asset_info)
|
||||
# /\ the stuff above should happen in execution code instead of inside the node
|
||||
# \/ the stuff below should happen in the node - confirm is a model asset, do stuff to it (already loaded? or should be called to 'load'?)
|
||||
if asset is None:
|
||||
raise Exception(f"Model {asset_info.name} not found")
|
||||
assert isinstance(asset, comfy.asset_management.ModelReturnedAsset)
|
||||
out = comfy.sd.load_state_dict_guess_config(asset.state_dict, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), metadata=asset.metadata)
|
||||
return io.NodeOutput(out[0], out[1], out[2])
|
||||
|
||||
|
||||
class AssetTestExtension(ComfyExtension):
|
||||
@classmethod
|
||||
async def get_node_list(cls):
|
||||
return [AssetTestNode]
|
||||
|
||||
|
||||
def comfy_entrypoint():
|
||||
return AssetTestExtension()
|
||||
9
main.py
9
main.py
@ -278,11 +278,11 @@ def cleanup_temp():
|
||||
if os.path.exists(temp_dir):
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
def setup_database():
|
||||
async def setup_database():
|
||||
try:
|
||||
from app.database.db import init_db, dependencies_available
|
||||
from app.database.db import init_db_engine, dependencies_available
|
||||
if dependencies_available():
|
||||
init_db()
|
||||
await init_db_engine()
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
||||
|
||||
@ -309,6 +309,8 @@ def start_comfyui(asyncio_loop=None):
|
||||
asyncio.set_event_loop(asyncio_loop)
|
||||
prompt_server = server.PromptServer(asyncio_loop)
|
||||
|
||||
asyncio_loop.run_until_complete(setup_database())
|
||||
|
||||
hook_breaker_ac10a0.save_functions()
|
||||
asyncio_loop.run_until_complete(nodes.init_extra_nodes(
|
||||
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
|
||||
@ -317,7 +319,6 @@ def start_comfyui(asyncio_loop=None):
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
cuda_malloc_warning()
|
||||
setup_database()
|
||||
|
||||
prompt_server.add_routes()
|
||||
hijack_progress(prompt_server)
|
||||
|
||||
7
nodes.py
7
nodes.py
@ -28,9 +28,10 @@ import comfy.sd
|
||||
import comfy.utils
|
||||
import comfy.controlnet
|
||||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
|
||||
from comfy_api.internal import register_versions, ComfyAPIWithVersion
|
||||
from comfy_api.internal import async_to_sync, register_versions, ComfyAPIWithVersion
|
||||
from comfy_api.version_list import supported_versions
|
||||
from comfy_api.latest import io, ComfyExtension
|
||||
from app.assets_manager import add_local_asset
|
||||
|
||||
import comfy.clip_vision
|
||||
|
||||
@ -777,6 +778,9 @@ class VAELoader:
|
||||
else:
|
||||
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||
sd = comfy.utils.load_torch_file(vae_path)
|
||||
async_to_sync.AsyncToSyncConverter.run_async_in_thread(
|
||||
add_local_asset, tags=["models", "vae"], file_name=vae_name, file_path=vae_path
|
||||
)
|
||||
vae = comfy.sd.VAE(sd=sd)
|
||||
vae.throw_exception_if_invalid()
|
||||
return (vae,)
|
||||
@ -2321,7 +2325,6 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_edit_model.py",
|
||||
"nodes_tcfg.py",
|
||||
"nodes_context_windows.py",
|
||||
"nodes_assets_test.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@ -20,6 +20,7 @@ tqdm
|
||||
psutil
|
||||
alembic
|
||||
SQLAlchemy
|
||||
aiosqlite
|
||||
av>=14.2.0
|
||||
blake3
|
||||
|
||||
|
||||
@ -37,6 +37,7 @@ from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from app.api.assets_routes import register_assets_routes
|
||||
from protocol import BinaryEventTypes
|
||||
|
||||
async def send_socket_catch_exception(function, message):
|
||||
@ -183,6 +184,7 @@ class PromptServer():
|
||||
else args.front_end_root
|
||||
)
|
||||
logging.info(f"[Prompt Server] web root: {self.web_root}")
|
||||
register_assets_routes(self.app)
|
||||
routes = web.RouteTableDef()
|
||||
self.routes = routes
|
||||
self.last_node_id = None
|
||||
|
||||
@ -1,62 +0,0 @@
|
||||
import pytest
|
||||
import base64
|
||||
import json
|
||||
import struct
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from aiohttp import web
|
||||
from unittest.mock import patch
|
||||
from app.model_manager import ModelFileManager
|
||||
|
||||
pytestmark = (
|
||||
pytest.mark.asyncio
|
||||
) # This applies the asyncio mark to all test functions in the module
|
||||
|
||||
@pytest.fixture
|
||||
def model_manager():
|
||||
return ModelFileManager()
|
||||
|
||||
@pytest.fixture
|
||||
def app(model_manager):
|
||||
app = web.Application()
|
||||
routes = web.RouteTableDef()
|
||||
model_manager.add_routes(routes)
|
||||
app.add_routes(routes)
|
||||
return app
|
||||
|
||||
async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
|
||||
img = Image.new('RGB', (100, 100), 'white')
|
||||
img_byte_arr = BytesIO()
|
||||
img.save(img_byte_arr, format='PNG')
|
||||
img_byte_arr.seek(0)
|
||||
img_b64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
|
||||
|
||||
safetensors_file = tmp_path / "test_model.safetensors"
|
||||
header_bytes = json.dumps({
|
||||
"__metadata__": {
|
||||
"ssmd_cover_images": json.dumps([img_b64])
|
||||
}
|
||||
}).encode('utf-8')
|
||||
length_bytes = struct.pack('<Q', len(header_bytes))
|
||||
with open(safetensors_file, 'wb') as f:
|
||||
f.write(length_bytes)
|
||||
f.write(header_bytes)
|
||||
|
||||
with patch('folder_paths.folder_names_and_paths', {
|
||||
'test_folder': ([str(tmp_path)], None)
|
||||
}):
|
||||
client = await aiohttp_client(app)
|
||||
response = await client.get('/experiment/models/preview/test_folder/0/test_model.safetensors')
|
||||
|
||||
# Verify response
|
||||
assert response.status == 200
|
||||
assert response.content_type == 'image/webp'
|
||||
|
||||
# Verify the response contains valid image data
|
||||
img_bytes = BytesIO(await response.read())
|
||||
img = Image.open(img_bytes)
|
||||
assert img.format
|
||||
assert img.format.lower() == 'webp'
|
||||
|
||||
# Clean up
|
||||
img.close()
|
||||
@ -1,253 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.model_processor import ModelProcessor
|
||||
from app.database.models import Model, Base
|
||||
import os
|
||||
|
||||
# Test data constants
|
||||
TEST_MODEL_TYPE = "checkpoints"
|
||||
TEST_URL = "http://example.com/model.safetensors"
|
||||
TEST_FILE_NAME = "model.safetensors"
|
||||
TEST_EXPECTED_HASH = "abc123"
|
||||
TEST_DESTINATION_PATH = "/path/to/model.safetensors"
|
||||
|
||||
|
||||
def create_test_model(session, file_name, model_type, hash_value, file_size=1000, source_url=None):
|
||||
"""Helper to create a test model in the database."""
|
||||
model = Model(path=file_name, type=model_type, hash=hash_value, file_size=file_size, source_url=source_url)
|
||||
session.add(model)
|
||||
session.commit()
|
||||
return model
|
||||
|
||||
|
||||
def setup_mock_hash_calculation(model_processor, hash_value):
|
||||
"""Helper to setup hash calculation mocks."""
|
||||
mock_hash = MagicMock()
|
||||
mock_hash.hexdigest.return_value = hash_value
|
||||
return patch.object(model_processor, "_get_hasher", return_value=mock_hash)
|
||||
|
||||
|
||||
def verify_model_in_db(session, file_name, expected_hash=None, expected_type=None):
|
||||
"""Helper to verify model exists in database with correct attributes."""
|
||||
db_model = session.query(Model).filter_by(path=file_name).first()
|
||||
assert db_model is not None
|
||||
if expected_hash:
|
||||
assert db_model.hash == expected_hash
|
||||
if expected_type:
|
||||
assert db_model.type == expected_type
|
||||
return db_model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_engine():
|
||||
# Configure in-memory database
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
Base.metadata.create_all(engine)
|
||||
yield engine
|
||||
Base.metadata.drop_all(engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_session(db_engine):
|
||||
Session = sessionmaker(bind=db_engine)
|
||||
session = Session()
|
||||
yield session
|
||||
session.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_relative_path():
|
||||
with patch("app.model_processor.get_relative_path") as mock:
|
||||
mock.side_effect = lambda path: (TEST_MODEL_TYPE, os.path.basename(path))
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_full_path():
|
||||
with patch("app.model_processor.get_full_path") as mock:
|
||||
mock.return_value = TEST_DESTINATION_PATH
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_processor(db_session, mock_get_relative_path, mock_get_full_path):
|
||||
with patch("app.model_processor.create_session", return_value=db_session):
|
||||
with patch("app.model_processor.can_create_session", return_value=True):
|
||||
processor = ModelProcessor()
|
||||
# Setup test state
|
||||
processor.removed_files = []
|
||||
processor.downloaded_files = []
|
||||
processor.file_exists = {}
|
||||
|
||||
def mock_download_file(url, destination_path, hasher):
|
||||
processor.downloaded_files.append((url, destination_path))
|
||||
processor.file_exists[destination_path] = True
|
||||
# Simulate writing some data to the file
|
||||
test_data = b"test data"
|
||||
hasher.update(test_data)
|
||||
|
||||
def mock_remove_file(file_path):
|
||||
processor.removed_files.append(file_path)
|
||||
if file_path in processor.file_exists:
|
||||
del processor.file_exists[file_path]
|
||||
|
||||
# Setup common patches
|
||||
file_exists_patch = patch.object(
|
||||
processor,
|
||||
"_file_exists",
|
||||
side_effect=lambda path: processor.file_exists.get(path, False),
|
||||
)
|
||||
file_size_patch = patch.object(
|
||||
processor,
|
||||
"_get_file_size",
|
||||
side_effect=lambda path: (
|
||||
1000 if processor.file_exists.get(path, False) else 0
|
||||
),
|
||||
)
|
||||
download_file_patch = patch.object(
|
||||
processor, "_download_file", side_effect=mock_download_file
|
||||
)
|
||||
remove_file_patch = patch.object(
|
||||
processor, "_remove_file", side_effect=mock_remove_file
|
||||
)
|
||||
|
||||
with (
|
||||
file_exists_patch,
|
||||
file_size_patch,
|
||||
download_file_patch,
|
||||
remove_file_patch,
|
||||
):
|
||||
yield processor
|
||||
|
||||
|
||||
def test_ensure_downloaded_invalid_extension(model_processor):
|
||||
# Ensure that an unsupported file extension raises an error to prevent unsafe file downloads
|
||||
with pytest.raises(ValueError, match="Unsupported unsafe file for download"):
|
||||
model_processor.ensure_downloaded(TEST_MODEL_TYPE, TEST_URL, "model.exe")
|
||||
|
||||
|
||||
def test_ensure_downloaded_existing_file_with_hash(model_processor, db_session):
|
||||
# Ensure that a file with the same hash but from a different source is not downloaded again
|
||||
SOURCE_URL = "https://example.com/other.sft"
|
||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH, source_url=SOURCE_URL)
|
||||
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
||||
|
||||
result = model_processor.ensure_downloaded(
|
||||
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
|
||||
)
|
||||
|
||||
assert result == TEST_DESTINATION_PATH
|
||||
model = verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
|
||||
assert model.source_url == SOURCE_URL # Ensure the source URL is not overwritten
|
||||
|
||||
|
||||
def test_ensure_downloaded_existing_file_hash_mismatch(model_processor, db_session):
|
||||
# Ensure that a file with a different hash raises an error
|
||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, "different_hash")
|
||||
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
||||
|
||||
with pytest.raises(ValueError, match="File .* exists with hash .* but expected .*"):
|
||||
model_processor.ensure_downloaded(
|
||||
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
|
||||
)
|
||||
|
||||
|
||||
def test_ensure_downloaded_new_file(model_processor, db_session):
|
||||
# Ensure that a new file is downloaded
|
||||
model_processor.file_exists[TEST_DESTINATION_PATH] = False
|
||||
|
||||
with setup_mock_hash_calculation(model_processor, TEST_EXPECTED_HASH):
|
||||
result = model_processor.ensure_downloaded(
|
||||
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
|
||||
)
|
||||
|
||||
assert result == TEST_DESTINATION_PATH
|
||||
assert len(model_processor.downloaded_files) == 1
|
||||
assert model_processor.downloaded_files[0] == (TEST_URL, TEST_DESTINATION_PATH)
|
||||
assert model_processor.file_exists[TEST_DESTINATION_PATH]
|
||||
verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
|
||||
|
||||
|
||||
def test_ensure_downloaded_hash_mismatch(model_processor, db_session):
|
||||
# Ensure that download that results in a different hash raises an error
|
||||
model_processor.file_exists[TEST_DESTINATION_PATH] = False
|
||||
|
||||
with setup_mock_hash_calculation(model_processor, "different_hash"):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Downloaded file hash .* does not match expected hash .*",
|
||||
):
|
||||
model_processor.ensure_downloaded(
|
||||
TEST_MODEL_TYPE,
|
||||
TEST_URL,
|
||||
TEST_FILE_NAME,
|
||||
TEST_EXPECTED_HASH,
|
||||
)
|
||||
|
||||
assert len(model_processor.removed_files) == 1
|
||||
assert model_processor.removed_files[0] == TEST_DESTINATION_PATH
|
||||
assert TEST_DESTINATION_PATH not in model_processor.file_exists
|
||||
assert db_session.query(Model).filter_by(path=TEST_FILE_NAME).first() is None
|
||||
|
||||
|
||||
def test_process_file_without_hash(model_processor, db_session):
|
||||
# Test processing file without provided hash
|
||||
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
||||
|
||||
with patch.object(model_processor, "_hash_file", return_value=TEST_EXPECTED_HASH):
|
||||
result = model_processor.process_file(TEST_DESTINATION_PATH)
|
||||
assert result is not None
|
||||
assert result.hash == TEST_EXPECTED_HASH
|
||||
|
||||
|
||||
def test_retrieve_model_by_hash(model_processor, db_session):
|
||||
# Test retrieving model by hash
|
||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
||||
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH)
|
||||
assert result is not None
|
||||
assert result.hash == TEST_EXPECTED_HASH
|
||||
|
||||
|
||||
def test_retrieve_model_by_hash_and_type(model_processor, db_session):
|
||||
# Test retrieving model by hash and type
|
||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
||||
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
|
||||
assert result is not None
|
||||
assert result.hash == TEST_EXPECTED_HASH
|
||||
assert result.type == TEST_MODEL_TYPE
|
||||
|
||||
|
||||
def test_retrieve_hash(model_processor, db_session):
|
||||
# Test retrieving hash for existing model
|
||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
||||
with patch.object(
|
||||
model_processor,
|
||||
"_validate_path",
|
||||
return_value=(TEST_MODEL_TYPE, TEST_FILE_NAME),
|
||||
):
|
||||
result = model_processor.retrieve_hash(TEST_DESTINATION_PATH, TEST_MODEL_TYPE)
|
||||
assert result == TEST_EXPECTED_HASH
|
||||
|
||||
|
||||
def test_validate_file_extension_valid_extensions(model_processor):
|
||||
# Test all valid file extensions
|
||||
valid_extensions = [".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"]
|
||||
for ext in valid_extensions:
|
||||
model_processor._validate_file_extension(f"test{ext}") # Should not raise
|
||||
|
||||
|
||||
def test_process_file_existing_without_source_url(model_processor, db_session):
|
||||
# Test processing an existing file that needs its source URL updated
|
||||
model_processor.file_exists[TEST_DESTINATION_PATH] = True
|
||||
|
||||
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
|
||||
result = model_processor.process_file(TEST_DESTINATION_PATH, source_url=TEST_URL)
|
||||
|
||||
assert result is not None
|
||||
assert result.hash == TEST_EXPECTED_HASH
|
||||
assert result.source_url == TEST_URL
|
||||
|
||||
db_model = db_session.query(Model).filter_by(path=TEST_FILE_NAME).first()
|
||||
assert db_model.source_url == TEST_URL
|
||||
Loading…
Reference in New Issue
Block a user