dev: Everything is Assets

This commit is contained in:
bigcat88 2025-08-19 19:56:59 +03:00
parent c708d0a433
commit f92307cd4c
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
22 changed files with 1650 additions and 977 deletions


@@ -0,0 +1,158 @@
"""initial assets schema + per-asset state cache
Revision ID: 0001_assets
Revises:
Create Date: 2025-08-20 00:00:00
"""
from alembic import op
import sqlalchemy as sa
revision = "0001_assets"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# ASSETS: content identity (deduplicated by hash)
op.create_table(
"assets",
sa.Column("hash", sa.String(length=128), primary_key=True),
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column("mime_type", sa.String(length=255), nullable=True),
sa.Column("refcount", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column("storage_backend", sa.String(length=32), nullable=False, server_default="fs"),
sa.Column("storage_locator", sa.Text(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
sa.CheckConstraint("refcount >= 0", name="ck_assets_refcount_nonneg"),
)
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
op.create_index("ix_assets_backend_locator", "assets", ["storage_backend", "storage_locator"])
# ASSETS_INFO: user-visible references (mutable metadata)
op.create_table(
"assets_info",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("owner_id", sa.String(length=128), nullable=True),
sa.Column("name", sa.String(length=512), nullable=False),
sa.Column("asset_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False),
sa.Column("preview_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="SET NULL"), nullable=True),
sa.Column("user_metadata", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
sa.Column("last_access_time", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
sqlite_autoincrement=True,
)
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
op.create_index("ix_assets_info_asset_hash", "assets_info", ["asset_hash"])
op.create_index("ix_assets_info_name", "assets_info", ["name"])
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
# TAGS: normalized tag vocabulary
op.create_table(
"tags",
sa.Column("name", sa.String(length=128), primary_key=True),
sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
)
op.create_index("ix_tags_tag_type", "tags", ["tag_type"])
# ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
op.create_table(
"asset_info_tags",
sa.Column("asset_info_id", sa.BigInteger(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("tag_name", sa.String(length=128), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
sa.Column("added_by", sa.String(length=128), nullable=True),
sa.Column("added_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
)
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
    # ASSET_LOCATOR_STATE: 1:1 filesystem metadata (for fast integrity checks) for an Asset record
op.create_table(
"asset_locator_state",
sa.Column("asset_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True),
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column("etag", sa.String(length=256), nullable=True),
sa.Column("last_modified", sa.String(length=128), nullable=True),
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"),
)
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
op.create_table(
"asset_info_meta",
sa.Column("asset_info_id", sa.Integer(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("key", sa.String(length=256), nullable=False),
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
sa.Column("val_str", sa.String(length=2048), nullable=True),
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
sa.Column("val_bool", sa.Boolean(), nullable=True),
sa.Column("val_json", sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
)
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
# Tags vocabulary for models
tags_table = sa.table(
"tags",
sa.column("name", sa.String()),
sa.column("tag_type", sa.String()),
)
op.bulk_insert(
tags_table,
[
# Core concept tags
{"name": "models", "tag_type": "system"},
# Canonical single-word types
{"name": "checkpoint", "tag_type": "system"},
{"name": "lora", "tag_type": "system"},
{"name": "vae", "tag_type": "system"},
{"name": "text-encoder", "tag_type": "system"},
{"name": "clip-vision", "tag_type": "system"},
{"name": "embedding", "tag_type": "system"},
{"name": "controlnet", "tag_type": "system"},
{"name": "upscale", "tag_type": "system"},
{"name": "diffusion-model", "tag_type": "system"},
{"name": "hypernetwork", "tag_type": "system"},
{"name": "vae_approx", "tag_type": "system"},
# TODO: decide what to do with: style_models, diffusers, gligen, photomaker, classifiers
],
)
def downgrade() -> None:
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
op.drop_table("asset_info_meta")
op.drop_table("asset_locator_state")
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
op.drop_table("asset_info_tags")
op.drop_index("ix_tags_tag_type", table_name="tags")
op.drop_table("tags")
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
op.drop_index("ix_assets_info_name", table_name="assets_info")
op.drop_index("ix_assets_info_asset_hash", table_name="assets_info")
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
op.drop_table("assets_info")
op.drop_index("ix_assets_backend_locator", table_name="assets")
op.drop_index("ix_assets_mime_type", table_name="assets")
op.drop_table("assets")

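The migration above separates content identity (assets, deduplicated by hash and reference-counted) from user-visible references (assets_info), so one stored file can appear under several names and tag sets without duplicating bytes. Below is a minimal sketch of that relationship; it assumes the ORM models introduced later in this commit (app.database.models) plus an installed aiosqlite driver, and it creates a throwaway in-memory schema directly from the models instead of running this migration:

import asyncio

from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from app.database.models import Asset, AssetInfo, Base  # models added in this commit


async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)  # stand-in for the Alembic migration

    async with async_sessionmaker(engine, expire_on_commit=False)() as session:
        # One content row, identified by hash ...
        asset = Asset(
            hash="blake3:ab12",  # illustrative value
            size_bytes=4,
            storage_backend="fs",
            storage_locator="/models/checkpoints/example.safetensors",
        )
        # ... referenced by two user-visible names.
        session.add_all([
            asset,
            AssetInfo(name="example.safetensors", asset_hash=asset.hash),
            AssetInfo(name="my-favorite-checkpoint", asset_hash=asset.hash),
        ])
        await session.commit()


asyncio.run(main())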

@@ -1,40 +0,0 @@
"""init
Revision ID: e9c714da8d57
Revises:
Create Date: 2025-05-30 20:14:33.772039
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'e9c714da8d57'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
op.create_table('model',
sa.Column('type', sa.Text(), nullable=False),
sa.Column('path', sa.Text(), nullable=False),
sa.Column('file_name', sa.Text(), nullable=True),
sa.Column('file_size', sa.Integer(), nullable=True),
sa.Column('hash', sa.Text(), nullable=True),
sa.Column('hash_algorithm', sa.Text(), nullable=True),
sa.Column('source_url', sa.Text(), nullable=True),
sa.Column('date_added', sa.DateTime(), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
sa.PrimaryKeyConstraint('type', 'path')
)
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('model')
# ### end Alembic commands ###

app/api/__init__.py (new, empty file)

app/api/assets_routes.py (new file)

@@ -0,0 +1,110 @@
import json
from typing import Sequence
from aiohttp import web
from app import assets_manager
ROUTES = web.RouteTableDef()
@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
q = request.rel_url.query
include_tags: Sequence[str] = _parse_csv_tags(q.get("include_tags"))
exclude_tags: Sequence[str] = _parse_csv_tags(q.get("exclude_tags"))
name_contains = q.get("name_contains")
# Optional JSON metadata filter (top-level key equality only for now)
metadata_filter = None
raw_meta = q.get("metadata_filter")
if raw_meta:
try:
metadata_filter = json.loads(raw_meta)
if not isinstance(metadata_filter, dict):
metadata_filter = None
except Exception:
            # Silently ignore malformed JSON for the first iteration; could return 400 in the future
metadata_filter = None
limit = _parse_int(q.get("limit"), default=20, lo=1, hi=100)
offset = _parse_int(q.get("offset"), default=0, lo=0, hi=10_000_000)
sort = q.get("sort", "created_at")
order = q.get("order", "desc")
payload = await assets_manager.list_assets(
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
return web.json_response(payload)
@ROUTES.put("/api/assets/{id}")
async def update_asset(request: web.Request) -> web.Response:
asset_info_id_raw = request.match_info.get("id")
try:
asset_info_id = int(asset_info_id_raw)
except Exception:
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
try:
payload = await request.json()
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
name = payload.get("name", None)
tags = payload.get("tags", None)
user_metadata = payload.get("user_metadata", None)
if name is None and tags is None and user_metadata is None:
return _error_response(400, "NO_FIELDS", "Provide at least one of: name, tags, user_metadata.")
if tags is not None and (not isinstance(tags, list) or not all(isinstance(t, str) for t in tags)):
return _error_response(400, "INVALID_TAGS", "Field 'tags' must be an array of strings.")
if user_metadata is not None and not isinstance(user_metadata, dict):
return _error_response(400, "INVALID_METADATA", "Field 'user_metadata' must be an object.")
try:
result = await assets_manager.update_asset(
asset_info_id=asset_info_id,
name=name,
tags=tags,
user_metadata=user_metadata,
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result, status=200)
def register_assets_routes(app: web.Application) -> None:
app.add_routes(ROUTES)
def _parse_csv_tags(raw: str | None) -> list[str]:
if not raw:
return []
return [t.strip() for t in raw.split(",") if t.strip()]
def _parse_int(qval: str | None, default: int, lo: int, hi: int) -> int:
if not qval:
return default
try:
v = int(qval)
except Exception:
return default
return max(lo, min(hi, v))
def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)

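For reference, a rough client-side sketch of the two endpoints above. The server address and the metadata key are assumptions for illustration; adjust them to your setup:

import asyncio
import json

import aiohttp

BASE = "http://127.0.0.1:8188"  # assumed ComfyUI address


async def main() -> None:
    async with aiohttp.ClientSession() as http:
        # List checkpoint assets whose name contains "sdxl", newest first.
        params = {
            "include_tags": "models,checkpoint",
            "name_contains": "sdxl",
            "metadata_filter": json.dumps({"base_model": "sdxl"}),  # hypothetical metadata key
            "limit": "10",
            "sort": "created_at",
            "order": "desc",
        }
        async with http.get(f"{BASE}/api/assets", params=params) as resp:
            page = await resp.json()
        print(page["total"], [a["name"] for a in page["assets"]])

        # Rename the first result and replace its tag set.
        if page["assets"]:
            asset_id = page["assets"][0]["id"]
            body = {"name": "renamed.safetensors", "tags": ["models", "checkpoint", "favorite"]}
            async with http.put(f"{BASE}/api/assets/{asset_id}", json=body) as resp:
                print(resp.status, await resp.json())


asyncio.run(main())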
app/assets_manager.py (new file)

@@ -0,0 +1,148 @@
import os
from datetime import datetime, timezone
from typing import Optional, Sequence
from .database.db import create_session
from .storage import hashing
from .database.services import (
check_fs_asset_exists_quick,
ingest_fs_asset,
touch_asset_infos_by_fs_path,
list_asset_infos_page,
update_asset_info_full,
get_asset_tags,
)
def get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
"""Adds a local asset to the DB. If already present and unchanged, does nothing.
Notes:
- Uses absolute path as the canonical locator for the 'fs' backend.
- Computes BLAKE3 only when the fast existence check indicates it's needed.
- This function ensures the identity row and seeds mtime in asset_locator_state.
"""
abs_path = os.path.abspath(file_path)
size_bytes, mtime_ns = get_size_mtime_ns(abs_path)
if not size_bytes:
return
async with await create_session() as session:
if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns):
await touch_asset_infos_by_fs_path(session, abs_path=abs_path, ts=datetime.now(timezone.utc))
await session.commit()
return
asset_hash = hashing.blake3_hash_sync(abs_path)
async with await create_session() as session:
await ingest_fs_asset(
session,
asset_hash="blake3:" + asset_hash,
abs_path=abs_path,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=None,
info_name=file_name,
tag_origin="automatic",
tags=tags,
)
await session.commit()
async def list_assets(
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: Optional[str] = None,
metadata_filter: Optional[dict] = None,
limit: int = 20,
offset: int = 0,
sort: str | None = "created_at",
order: str | None = "desc",
) -> dict:
sort = _safe_sort_field(sort)
    order = (order or "desc").lower()
    if order not in {"asc", "desc"}:
        order = "desc"
async with await create_session() as session:
infos, tag_map, total = await list_asset_infos_page(
session,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
assets_json = []
for info in infos:
asset = info.asset # populated via contains_eager
tags = tag_map.get(info.id, [])
assets_json.append(
{
"id": info.id,
"name": info.name,
"asset_hash": info.asset_hash,
"size": int(asset.size_bytes) if asset else None,
"mime_type": asset.mime_type if asset else None,
"tags": tags,
"preview_url": f"/api/v1/assets/{info.id}/content", # TODO: implement actual content endpoint later
"created_at": info.created_at.isoformat() if info.created_at else None,
"updated_at": info.updated_at.isoformat() if info.updated_at else None,
"last_access_time": info.last_access_time.isoformat() if info.last_access_time else None,
}
)
return {
"assets": assets_json,
"total": total,
"has_more": (offset + len(assets_json)) < total,
}
async def update_asset(
*,
asset_info_id: int,
name: str | None = None,
tags: list[str] | None = None,
user_metadata: dict | None = None,
) -> dict:
async with await create_session() as session:
info = await update_asset_info_full(
session,
asset_info_id=asset_info_id,
name=name,
tags=tags,
user_metadata=user_metadata,
tag_origin="manual",
added_by=None,
)
tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
await session.commit()
return {
"id": info.id,
"name": info.name,
"asset_hash": info.asset_hash,
"tags": tag_names,
"user_metadata": info.user_metadata or {},
"updated_at": info.updated_at.isoformat() if info.updated_at else None,
}
def _safe_sort_field(requested: str | None) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"

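A rough sketch of how a caller might feed local files through this module, for example while scanning a models folder. It assumes ComfyUI's CLI args (including database_url) are already configured so init_db_engine() can run; the folder path and tags below are illustrative:

import asyncio
import os

from app import assets_manager
from app.database.db import init_db_engine


async def scan_checkpoints(folder: str) -> None:
    await init_db_engine()  # engine + migrations must be ready before any DB call
    for entry in os.scandir(folder):
        if entry.is_file() and entry.name.endswith(".safetensors"):
            # Hashes the file only when the quick (path, size, mtime) check misses.
            await assets_manager.add_local_asset(
                tags=["models", "checkpoint"],
                file_name=entry.name,
                file_path=entry.path,
            )
    page = await assets_manager.list_assets(include_tags=["checkpoint"], limit=5)
    print(page["total"], [a["name"] for a in page["assets"]])


asyncio.run(scan_checkpoints("ComfyUI/models/checkpoints"))  # illustrative path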
app/database/__init__.py (new, empty file)

app/database/db.py

@@ -1,112 +1,267 @@
import logging
import os
import shutil
from contextlib import asynccontextmanager
from typing import Optional

from app.logger import log_startup_warning
from utils.install_util import get_missing_requirements_message
from comfy.cli_args import args

LOGGER = logging.getLogger(__name__)

# Attempt imports which may not exist in some environments
try:
    from alembic import command
    from alembic.config import Config
    from alembic.runtime.migration import MigrationContext
    from alembic.script import ScriptDirectory
    from sqlalchemy import create_engine, text
    from sqlalchemy.engine import make_url
    from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine

    _DB_AVAILABLE = True
    ENGINE: AsyncEngine | None = None
    SESSION: async_sessionmaker | None = None
except ImportError as e:
    log_startup_warning(
        (
            "------------------------------------------------------------------------\n"
            f"Error importing DB dependencies: {e}\n"
            f"{get_missing_requirements_message()}\n"
            "This error is happening because ComfyUI now uses a local database.\n"
            "------------------------------------------------------------------------"
        ).strip()
    )
    _DB_AVAILABLE = False
    ENGINE = None
    SESSION = None


def dependencies_available() -> bool:
    """Check if DB dependencies are importable."""
    return _DB_AVAILABLE


def _root_paths():
    """Resolve alembic.ini and migrations script folder."""
    root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
    config_path = os.path.abspath(os.path.join(root_path, "alembic.ini"))
    scripts_path = os.path.abspath(os.path.join(root_path, "alembic_db"))
    return config_path, scripts_path


def _absolutize_sqlite_url(db_url: str) -> str:
    """Make SQLite database path absolute. No-op for non-SQLite URLs."""
    try:
        u = make_url(db_url)
    except Exception:
        return db_url
    if not u.drivername.startswith("sqlite"):
        return db_url
    # Make path absolute if relative
    db_path = u.database or ""
    if not os.path.isabs(db_path):
        db_path = os.path.abspath(os.path.join(os.getcwd(), db_path))
    u = u.set(database=db_path)
    return str(u)


def _to_sync_driver_url(async_url: str) -> str:
    """Convert an async SQLAlchemy URL to a sync URL for Alembic."""
    u = make_url(async_url)
    driver = u.drivername
    if driver.startswith("sqlite+aiosqlite"):
        u = u.set(drivername="sqlite")
    elif driver.startswith("postgresql+asyncpg"):
        u = u.set(drivername="postgresql")
    else:
        # Generic: strip the async driver part if present
        if "+" in driver:
            u = u.set(drivername=driver.split("+", 1)[0])
    return str(u)


def _get_sqlite_file_path(sync_url: str) -> Optional[str]:
    """Return the on-disk path for a SQLite URL, else None."""
    try:
        u = make_url(sync_url)
    except Exception:
        return None
    if not u.drivername.startswith("sqlite"):
        return None
    return u.database


def _get_alembic_config(sync_url: str) -> Config:
    """Prepare Alembic Config with script location and DB URL."""
    config_path, scripts_path = _root_paths()
    cfg = Config(config_path)
    cfg.set_main_option("script_location", scripts_path)
    cfg.set_main_option("sqlalchemy.url", sync_url)
    return cfg


async def init_db_engine() -> None:
    """Initialize async engine + sessionmaker and run migrations to head.

    This must be called once on application startup before any DB usage.
    """
    global ENGINE, SESSION
    if not dependencies_available():
        raise RuntimeError("Database dependencies are not available.")
    if ENGINE is not None:
        return
    raw_url = args.database_url
    if not raw_url:
        raise RuntimeError("Database URL is not configured.")
    # Absolutize SQLite path for async engine
    db_url = _absolutize_sqlite_url(raw_url)
    # Prepare async engine
    connect_args = {}
    if db_url.startswith("sqlite"):
        connect_args = {
            "check_same_thread": False,
            "timeout": 12,
        }
    ENGINE = create_async_engine(
        db_url,
        connect_args=connect_args,
        pool_pre_ping=True,
        future=True,
    )
    # Enforce SQLite pragmas on the async engine
    if db_url.startswith("sqlite"):
        async with ENGINE.begin() as conn:
            # WAL for concurrency and durability, Foreign Keys for referential integrity
            current_mode = (await conn.execute(text("PRAGMA journal_mode;"))).scalar()
            if str(current_mode).lower() != "wal":
                new_mode = (await conn.execute(text("PRAGMA journal_mode=WAL;"))).scalar()
                if str(new_mode).lower() != "wal":
                    raise RuntimeError("Failed to set SQLite journal mode to WAL.")
                LOGGER.info("SQLite journal mode set to WAL.")
            await conn.execute(text("PRAGMA foreign_keys = ON;"))
            await conn.execute(text("PRAGMA synchronous = NORMAL;"))
    await _run_migrations(raw_url=db_url)
    SESSION = async_sessionmaker(
        bind=ENGINE,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )


async def _run_migrations(raw_url: str) -> None:
    """Run Alembic migrations up to head.

    We deliberately use a synchronous engine for migrations because Alembic's
    programmatic API is synchronous by default and this path is robust.
    """
    # Convert to sync URL and make SQLite URL an absolute one
    sync_url = _to_sync_driver_url(raw_url)
    sync_url = _absolutize_sqlite_url(sync_url)
    cfg = _get_alembic_config(sync_url)

    # Inspect current and target heads
    engine = create_engine(sync_url, future=True)
    with engine.connect() as conn:
        context = MigrationContext.configure(conn)
        current_rev = context.get_current_revision()

    script = ScriptDirectory.from_config(cfg)
    target_rev = script.get_current_head()
    if target_rev is None:
        LOGGER.warning("Alembic: no target revision found.")
        return
    if current_rev == target_rev:
        LOGGER.debug("Alembic: database already at head %s", target_rev)
        return

    LOGGER.info("Alembic: upgrading database from %s to %s", current_rev, target_rev)

    # Optional backup for SQLite file DBs
    backup_path = None
    sqlite_path = _get_sqlite_file_path(sync_url)
    if sqlite_path and os.path.exists(sqlite_path):
        backup_path = sqlite_path + ".bkp"
        try:
            shutil.copy(sqlite_path, backup_path)
        except Exception as exc:
            LOGGER.warning("Failed to create SQLite backup before migration: %s", exc)

    try:
        command.upgrade(cfg, target_rev)
    except Exception:
        if backup_path and os.path.exists(backup_path):
            LOGGER.exception("Error upgrading database, attempting restore from backup.")
            try:
                shutil.copy(backup_path, sqlite_path)  # restore
                os.remove(backup_path)
            except Exception as re:
                LOGGER.error("Failed to restore SQLite backup: %s", re)
        else:
            LOGGER.exception("Error upgrading database, backup is not available.")
        raise


def get_engine():
    """Return the global async engine (initialized after init_db_engine())."""
    if ENGINE is None:
        raise RuntimeError("Engine is not initialized. Call init_db_engine() first.")
    return ENGINE


def get_session_maker():
    """Return the global async_sessionmaker (initialized after init_db_engine())."""
    if SESSION is None:
        raise RuntimeError("Session maker is not initialized. Call init_db_engine() first.")
    return SESSION


@asynccontextmanager
async def session_scope():
    """Async context manager for a unit of work:

    async with session_scope() as sess:
        ... use sess ...
    """
    maker = get_session_maker()
    async with maker() as sess:
        try:
            yield sess
            await sess.commit()
        except Exception:
            await sess.rollback()
            raise


async def create_session():
    """Convenience helper to acquire a single AsyncSession instance.

    Typical usage:
        async with (await create_session()) as sess:
            ...
    """
    maker = get_session_maker()
    return maker()

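The two access patterns exposed here differ in who commits: session_scope() commits on success and rolls back on error, while create_session() hands back a bare AsyncSession and leaves transaction control to the caller (the pattern assets_manager uses). A small illustrative sketch, assuming init_db_engine() has already run:

from sqlalchemy import select

from app.database.db import create_session, session_scope
from app.database.models import Tag


async def add_tag_managed(name: str) -> None:
    # session_scope() owns the transaction: commit on success, rollback on error.
    async with session_scope() as sess:
        sess.add(Tag(name=name.lower(), tag_type="user"))


async def count_tags_manual() -> int:
    # create_session() leaves commit/rollback to the caller; reads need neither.
    async with await create_session() as sess:
        return len((await sess.execute(select(Tag.name))).all())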
app/database/models.py

@@ -1,59 +1,257 @@
from datetime import datetime
from typing import Any, Optional

from sqlalchemy import (
    Integer,
    BigInteger,
    DateTime,
    ForeignKey,
    Index,
    JSON,
    String,
    Text,
    CheckConstraint,
    Numeric,
    Boolean,
)
from sqlalchemy.sql import func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign


class Base(DeclarativeBase):
    pass


def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
    fields = obj.__table__.columns.keys()
    out: dict[str, Any] = {}
    for field in fields:
        val = getattr(obj, field)
        if val is None and not include_none:
            continue
        if isinstance(val, datetime):
            out[field] = val.isoformat()
        else:
            out[field] = val
    return out


class Asset(Base):
    __tablename__ = "assets"

    hash: Mapped[str] = mapped_column(String(256), primary_key=True)
    size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
    mime_type: Mapped[str | None] = mapped_column(String(255))
    refcount: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
    storage_backend: Mapped[str] = mapped_column(String(32), nullable=False, default="fs")
    storage_locator: Mapped[str] = mapped_column(Text, nullable=False)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
    )

    infos: Mapped[list["AssetInfo"]] = relationship(
        "AssetInfo",
        back_populates="asset",
        primaryjoin=lambda: Asset.hash == foreign(AssetInfo.asset_hash),
        foreign_keys=lambda: [AssetInfo.asset_hash],
        cascade="all,delete-orphan",
        passive_deletes=True,
    )

    preview_of: Mapped[list["AssetInfo"]] = relationship(
        "AssetInfo",
        back_populates="preview_asset",
        primaryjoin=lambda: Asset.hash == foreign(AssetInfo.preview_hash),
        foreign_keys=lambda: [AssetInfo.preview_hash],
        viewonly=True,
    )

    locator_state: Mapped["AssetLocatorState | None"] = relationship(
        back_populates="asset",
        uselist=False,
        cascade="all, delete-orphan",
        passive_deletes=True,
    )

    __table_args__ = (
        Index("ix_assets_mime_type", "mime_type"),
        Index("ix_assets_backend_locator", "storage_backend", "storage_locator"),
    )

    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
        return to_dict(self, include_none=include_none)

    def __repr__(self) -> str:
        return f"<Asset hash={self.hash[:12]} backend={self.storage_backend}>"


class AssetLocatorState(Base):
    __tablename__ = "asset_locator_state"

    asset_hash: Mapped[str] = mapped_column(
        String(256), ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True
    )
    # For fs backends: nanosecond mtime; nullable if not applicable
    mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
    # For HTTP/S3/GCS/Azure, etc.: optional validators
    etag: Mapped[str | None] = mapped_column(String(256), nullable=True)
    last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)

    asset: Mapped["Asset"] = relationship(back_populates="locator_state", uselist=False)

    __table_args__ = (
        CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"),
    )

    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
        return to_dict(self, include_none=include_none)

    def __repr__(self) -> str:
        return f"<AssetLocatorState hash={self.asset_hash[:12]} mtime_ns={self.mtime_ns}>"


class AssetInfo(Base):
    __tablename__ = "assets_info"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    owner_id: Mapped[str | None] = mapped_column(String(128))
    name: Mapped[str] = mapped_column(String(512), nullable=False)
    asset_hash: Mapped[str] = mapped_column(
        String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False
    )
    preview_hash: Mapped[str | None] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL"))
    user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
    )
    last_access_time: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
    )

    # Relationships
    asset: Mapped[Asset] = relationship(
        "Asset",
        back_populates="infos",
        foreign_keys=[asset_hash],
    )
    preview_asset: Mapped[Asset | None] = relationship(
        "Asset",
        back_populates="preview_of",
        foreign_keys=[preview_hash],
    )
    metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship(
        back_populates="asset_info",
        cascade="all,delete-orphan",
        passive_deletes=True,
    )
    tag_links: Mapped[list["AssetInfoTag"]] = relationship(
        back_populates="asset_info",
        cascade="all,delete-orphan",
        passive_deletes=True,
        overlaps="tags,asset_infos",
    )
    tags: Mapped[list["Tag"]] = relationship(
        secondary="asset_info_tags",
        back_populates="asset_infos",
        lazy="joined",
        viewonly=True,
        overlaps="tag_links,asset_info_links,asset_infos,tag",
    )

    __table_args__ = (
        Index("ix_assets_info_owner_id", "owner_id"),
        Index("ix_assets_info_asset_hash", "asset_hash"),
        Index("ix_assets_info_name", "name"),
        Index("ix_assets_info_created_at", "created_at"),
        Index("ix_assets_info_last_access_time", "last_access_time"),
        {"sqlite_autoincrement": True},
    )

    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
        data = to_dict(self, include_none=include_none)
        data["tags"] = [t.name for t in self.tags]
        return data

    def __repr__(self) -> str:
        return f"<AssetInfo id={self.id} name={self.name!r} hash={self.asset_hash[:12]}>"


class AssetInfoMeta(Base):
    __tablename__ = "asset_info_meta"

    asset_info_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
    )
    key: Mapped[str] = mapped_column(String(256), primary_key=True)
    ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
    val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True)
    val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True)
    val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True)
    val_json: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True)

    asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries")

    __table_args__ = (
        Index("ix_asset_info_meta_key", "key"),
        Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
        Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
        Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
    )


class AssetInfoTag(Base):
    __tablename__ = "asset_info_tags"

    asset_info_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
    )
    tag_name: Mapped[str] = mapped_column(
        String(128), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
    )
    origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
    added_by: Mapped[str | None] = mapped_column(String(128))
    added_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)

    asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links")
    tag: Mapped["Tag"] = relationship(back_populates="asset_info_links")

    __table_args__ = (
        Index("ix_asset_info_tags_tag_name", "tag_name"),
        Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
    )


class Tag(Base):
    __tablename__ = "tags"

    name: Mapped[str] = mapped_column(String(128), primary_key=True)
    tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")

    asset_info_links: Mapped[list["AssetInfoTag"]] = relationship(
        back_populates="tag",
        overlaps="asset_infos,tags",
    )
    asset_infos: Mapped[list["AssetInfo"]] = relationship(
        secondary="asset_info_tags",
        back_populates="tags",
        viewonly=True,
        overlaps="asset_info_links,tag_links,tags,asset_info",
    )

    __table_args__ = (
        Index("ix_tags_tag_type", "tag_type"),
    )

    def __repr__(self) -> str:
        return f"<Tag {self.name}>"

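Note the split between the writable link table (AssetInfoTag) and the read-only convenience relationships (AssetInfo.tags, Tag.asset_infos): links are inserted explicitly, while the viewonly relationships are only for reading. A small sketch, assuming an open AsyncSession and that the referenced AssetInfo row and Tag row already exist (the id and tag name are illustrative):

from datetime import datetime, timezone

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.database.models import AssetInfo, AssetInfoTag


async def tag_asset_info(sess: AsyncSession, asset_info_id: int, tag_name: str) -> list[str]:
    # Writes go through the association row, never through AssetInfo.tags (viewonly=True).
    sess.add(AssetInfoTag(
        asset_info_id=asset_info_id,
        tag_name=tag_name,
        origin="manual",
        added_at=datetime.now(timezone.utc),
    ))
    await sess.flush()
    # Reads can use the eager, read-only relationship (lazy="joined").
    info = (await sess.execute(
        select(AssetInfo).where(AssetInfo.id == asset_info_id)
    )).unique().scalar_one()
    return [t.name for t in info.tags]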
app/database/services.py (new file)

@@ -0,0 +1,683 @@
import os
import logging
from collections import defaultdict
from datetime import datetime, timezone
from decimal import Decimal
from typing import Any, Sequence, Optional, Iterable
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete, exists, func
from sqlalchemy.orm import contains_eager
from sqlalchemy.exc import IntegrityError
from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta
async def check_fs_asset_exists_quick(
session,
*,
file_path: str,
size_bytes: Optional[int] = None,
mtime_ns: Optional[int] = None,
) -> bool:
"""
    Return True if an Asset is already present whose canonical locator matches this absolute path,
    AND (if provided) mtime_ns matches the stored locator state,
    AND (if provided) size_bytes matches the verified size when it is known.
"""
locator = os.path.abspath(file_path)
stmt = select(sa.literal(True)).select_from(Asset)
conditions = [
Asset.storage_backend == "fs",
Asset.storage_locator == locator,
]
# If size_bytes provided require equality when the asset has a verified (non-zero) size.
# If verified size is 0 (unknown), we don't force equality.
if size_bytes is not None:
conditions.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
# If mtime_ns provided require the locator-state to exist and match.
if mtime_ns is not None:
stmt = stmt.join(AssetLocatorState, AssetLocatorState.asset_hash == Asset.hash)
conditions.append(AssetLocatorState.mtime_ns == int(mtime_ns))
stmt = stmt.where(*conditions).limit(1)
row = (await session.execute(stmt)).first()
return row is not None
async def ingest_fs_asset(
session: AsyncSession,
*,
asset_hash: str,
abs_path: str,
size_bytes: int,
mtime_ns: int,
mime_type: Optional[str] = None,
info_name: Optional[str] = None,
owner_id: Optional[str] = None,
preview_hash: Optional[str] = None,
user_metadata: Optional[dict] = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
added_by: Optional[str] = None,
require_existing_tags: bool = False,
) -> dict:
"""
    Creates or updates the Asset record for a local (fs) asset.
Always:
- Insert Asset if missing; else update size_bytes (and updated_at) if different.
- Insert AssetLocatorState if missing; else update mtime_ns if different.
Optionally (when info_name is provided):
- Create an AssetInfo (no refcount changes).
- Link provided tags to that AssetInfo.
        * If require_existing_tags=True, raises ValueError if any tag does not exist in the `tags` table.
* If False (default), silently skips unknown tags.
Returns flags and ids:
{
"asset_created": bool,
"asset_updated": bool,
"state_created": bool,
"state_updated": bool,
"asset_info_id": int | None,
"tags_added": list[str],
"tags_missing": list[str], # filled only when require_existing_tags=False
}
"""
locator = os.path.abspath(abs_path)
datetime_now = datetime.now(timezone.utc)
out = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
"tags_added": [],
"tags_missing": [],
}
# ---- Step 1: INSERT Asset or UPDATE size_bytes/updated_at if exists ----
async with session.begin_nested() as sp1:
try:
session.add(
Asset(
hash=asset_hash,
size_bytes=int(size_bytes),
mime_type=mime_type,
refcount=0,
storage_backend="fs",
storage_locator=locator,
created_at=datetime_now,
updated_at=datetime_now,
)
)
await session.flush()
out["asset_created"] = True
except IntegrityError:
await sp1.rollback()
# Already exists by hash -> update selected fields if different
existing = await session.get(Asset, asset_hash)
if existing is not None:
desired_size = int(size_bytes)
if existing.size_bytes != desired_size:
existing.size_bytes = desired_size
existing.updated_at = datetime_now
out["asset_updated"] = True
else:
# This should not occur. Log for visibility.
logging.error("Asset %s not found after conflict; skipping update.", asset_hash)
except Exception:
await sp1.rollback()
logging.exception("Unexpected error inserting Asset (hash=%s, locator=%s)", asset_hash, locator)
raise
# ---- Step 2: INSERT/UPDATE AssetLocatorState (mtime_ns) ----
async with session.begin_nested() as sp2:
try:
session.add(
AssetLocatorState(
asset_hash=asset_hash,
mtime_ns=int(mtime_ns),
)
)
await session.flush()
out["state_created"] = True
except IntegrityError:
await sp2.rollback()
state = await session.get(AssetLocatorState, asset_hash)
if state is not None:
desired_mtime = int(mtime_ns)
if state.mtime_ns != desired_mtime:
state.mtime_ns = desired_mtime
out["state_updated"] = True
else:
logging.debug("Locator state missing for %s after conflict; skipping update.", asset_hash)
except Exception:
await sp2.rollback()
logging.exception("Unexpected error inserting AssetLocatorState (hash=%s)", asset_hash)
raise
    # ---- Step 3 (optional): AssetInfo + tag links ----
if info_name:
        # 3a) Create AssetInfo (no refcount bump)
async with session.begin_nested() as sp3:
try:
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_hash=asset_hash,
preview_hash=preview_hash,
created_at=datetime_now,
updated_at=datetime_now,
last_access_time=datetime_now,
)
session.add(info)
await session.flush() # get info.id
out["asset_info_id"] = info.id
except Exception:
await sp3.rollback()
logging.exception(
"Unexpected error inserting AssetInfo (hash=%s, name=%s)", asset_hash, info_name
)
raise
        # 3b) Link tags (if any). We DO NOT create new Tag rows here by default.
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
if norm and out["asset_info_id"] is not None:
# Which tags exist?
existing_tag_names = set(
name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
)
missing = [t for t in norm if t not in existing_tag_names]
if missing and require_existing_tags:
raise ValueError(f"Unknown tags: {missing}")
# Which links already exist?
existing_links = set(
tag_name
for (tag_name,) in (
await session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
)
).all()
)
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
if to_add:
session.add_all(
[
AssetInfoTag(
asset_info_id=out["asset_info_id"],
tag_name=t,
origin=tag_origin,
added_by=added_by,
added_at=datetime_now,
)
for t in to_add
]
)
await session.flush()
out["tags_added"] = to_add
out["tags_missing"] = missing
        # 3c) Rebuild the metadata projection if provided
if user_metadata is not None and out["asset_info_id"] is not None:
await replace_asset_info_metadata_projection(
session,
asset_info_id=out["asset_info_id"],
user_metadata=user_metadata,
)
return out
async def touch_asset_infos_by_fs_path(
session: AsyncSession,
*,
abs_path: str,
ts: Optional[datetime] = None,
only_if_newer: bool = True,
) -> int:
locator = os.path.abspath(abs_path)
ts = ts or datetime.now(timezone.utc)
stmt = sa.update(AssetInfo).where(
sa.exists(
sa.select(sa.literal(1))
.select_from(Asset)
.where(
Asset.hash == AssetInfo.asset_hash,
Asset.storage_backend == "fs",
Asset.storage_locator == locator,
)
)
)
if only_if_newer:
stmt = stmt.where(
sa.or_(
AssetInfo.last_access_time.is_(None),
AssetInfo.last_access_time < ts,
)
)
stmt = stmt.values(last_access_time=ts)
res = await session.execute(stmt)
return int(res.rowcount or 0)
async def list_asset_infos_page(
session: AsyncSession,
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[int, list[str]], int]:
"""
Returns a page of AssetInfo rows with their Asset eagerly loaded (no N+1),
plus a map of asset_info_id -> [tags], and the total count.
We purposely collect tags in a separate (single) query to avoid row explosion.
"""
# Clamp
if limit <= 0:
limit = 1
if limit > 100:
limit = 100
if offset < 0:
offset = 0
# Build base query
base = (
select(AssetInfo)
.join(Asset, Asset.hash == AssetInfo.asset_hash)
.options(contains_eager(AssetInfo.asset))
)
# Filters
if name_contains:
base = base.where(AssetInfo.name.ilike(f"%{name_contains}%"))
base = _apply_tag_filters(base, include_tags, exclude_tags)
base = _apply_metadata_filter(base, metadata_filter)
# Sort
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
# Total count (same filters, no ordering/limit/offset)
count_stmt = (
select(func.count())
.select_from(AssetInfo)
.join(Asset, Asset.hash == AssetInfo.asset_hash)
)
if name_contains:
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%"))
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
total = (await session.execute(count_stmt)).scalar_one()
# Fetch rows
infos = (await session.execute(base)).scalars().unique().all()
# Collect tags in bulk (single query)
id_list = [i.id for i in infos]
tag_map: dict[int, list[str]] = defaultdict(list)
if id_list:
rows = await session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
async def set_asset_info_tags(
session: AsyncSession,
*,
asset_info_id: int,
tags: Sequence[str],
origin: str = "manual",
added_by: Optional[str] = None,
) -> dict:
"""
Replace the tag set on an AssetInfo with `tags`. Idempotent.
Creates missing tag names as 'user'.
"""
desired = _normalize_tags(tags)
now = datetime.now(timezone.utc)
# current links
current = set(
tag_name for (tag_name,) in (
await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
await _ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_by=added_by, added_at=now)
for t in to_add
])
await session.flush()
if to_remove:
await session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
await session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
async def update_asset_info_full(
session: AsyncSession,
*,
asset_info_id: int,
name: Optional[str] = None,
tags: Optional[Sequence[str]] = None,
user_metadata: Optional[dict] = None,
tag_origin: str = "manual",
added_by: Optional[str] = None,
) -> AssetInfo:
"""
Update AssetInfo fields:
- name (if provided)
- user_metadata blob + rebuild projection (if provided)
- replace tags with provided set (if provided)
Returns the updated AssetInfo.
"""
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
touched = False
if name is not None and name != info.name:
info.name = name
touched = True
if user_metadata is not None:
await replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=user_metadata
)
touched = True
if tags is not None:
await set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
added_by=added_by,
)
touched = True
if touched and user_metadata is None:
info.updated_at = datetime.now(timezone.utc)
await session.flush()
return info
async def replace_asset_info_metadata_projection(
session: AsyncSession,
*,
asset_info_id: int,
user_metadata: dict | None,
) -> None:
    """Replace `assets_info.user_metadata` and rebuild the projection rows in `asset_info_meta`."""
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = datetime.now(timezone.utc)
await session.flush()
await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
await session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in _project_kv(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
await session.flush()
async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[str]:
return [
tag_name
for (tag_name,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
def _normalize_tags(tags: Sequence[str] | None) -> list[str]:
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
async def _ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
wanted = _normalize_tags(list(names))
if not wanted:
return []
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
by_name = {t.name: t for t in existing}
to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name]
if to_create:
session.add_all(to_create)
await session.flush()
by_name.update({t.name: t for t in to_create})
return [by_name[n] for n in wanted]
def _apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None,
exclude_tags: Sequence[str] | None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = _normalize_tags(include_tags)
exclude_tags = _normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def _apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None,
) -> sa.sql.Select:
"""Apply metadata filters using the projection table asset_info_meta.
Semantics:
- For scalar values: require EXISTS(asset_info_meta) with matching key + typed value.
- For None: key is missing OR key has explicit null (val_json IS NULL).
- For list values: ANY-of the list elements matches (EXISTS for any).
    (Change to ALL-of by applying stmt = stmt.where(_exists_clause_for_value(key, elem)) once per element.)
"""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
subquery = (
select(sa.literal(1))
.select_from(AssetInfoMeta)
.where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
.limit(1)
)
return sa.exists(subquery)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
# Missing OR null:
if value is None:
# either: no row for key OR a row for key with explicit null
no_row_for_key = ~sa.exists(
select(sa.literal(1))
.select_from(AssetInfoMeta)
.where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
.limit(1)
)
null_row = _exists_for_pred(key, AssetInfoMeta.val_json.is_(None))
return sa.or_(no_row_for_key, null_row)
# Typed scalar matches:
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float, Decimal)):
# store as Decimal for equality against NUMERIC(38,10)
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
# Complex: compare JSON (no index, but supported)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
# ANY-of (exists for any element)
ors = [ _exists_clause_for_value(k, elem) for elem in v ]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def _is_scalar(v: Any) -> bool:
if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def _project_kv(key: str, value: Any) -> list[dict]:
"""
Turn a metadata key/value into one or more projection rows:
- scalar -> one row (ordinal=0) in the proper typed column
- list of scalars -> one row per element with ordinal=i
- dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list)
- None -> single row with val_json = None
Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...}
"""
rows: list[dict] = []
# None
if value is None:
rows.append({"key": key, "ordinal": 0, "val_json": None})
return rows
# Scalars
if _is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
# store numeric; SQLAlchemy will coerce to Numeric
rows.append({"key": key, "ordinal": 0, "val_num": value})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
# Fallback to json
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
# Lists
if isinstance(value, list):
# list of scalars?
if all(_is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append({"key": key, "ordinal": i, "val_json": None})
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
rows.append({"key": key, "ordinal": i, "val_num": x})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
# list contains objects -> one val_json per element
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
# Dict or any other structure -> single json row
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows

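To make the projection semantics concrete, here is what _project_kv produces for a small metadata blob: scalars land in one typed column, list values fan out to one row per element via ordinal, and nested objects fall back to val_json. That per-element fan-out is what lets a filter such as {"trigger_words": "anime"} match through a single EXISTS against asset_info_meta. A quick illustrative run (the keys and values are made up):

from app.database.services import _project_kv

meta = {
    "epochs": 10,                          # number        -> val_num
    "nsfw": False,                         # boolean       -> val_bool
    "trigger_words": ["anime", "sketch"],  # list of str   -> one row per element
    "training": {"rank": 16},              # nested object -> val_json
}
for key, value in meta.items():
    for row in _project_kv(key, value):
        print(row)
# e.g. {'key': 'trigger_words', 'ordinal': 0, 'val_str': 'anime'}
#      {'key': 'trigger_words', 'ordinal': 1, 'val_str': 'sketch'}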

@@ -1,331 +0,0 @@
import os
import logging
import time
import requests
from tqdm import tqdm
from folder_paths import get_relative_path, get_full_path
from app.database.db import create_session, dependencies_available, can_create_session
import blake3
import comfy.utils
if dependencies_available():
from app.database.models import Model
class ModelProcessor:
def _validate_path(self, model_path):
try:
if not self._file_exists(model_path):
logging.error(f"Model file not found: {model_path}")
return None
result = get_relative_path(model_path)
if not result:
logging.error(
f"Model file not in a recognized model directory: {model_path}"
)
return None
return result
except Exception as e:
logging.error(f"Error validating model path {model_path}: {str(e)}")
return None
def _file_exists(self, path):
"""Check if a file exists."""
return os.path.exists(path)
def _get_file_size(self, path):
"""Get file size."""
return os.path.getsize(path)
def _get_hasher(self):
return blake3.blake3()
def _hash_file(self, model_path):
try:
hasher = self._get_hasher()
with open(model_path, "rb", buffering=0) as f:
b = bytearray(128 * 1024)
mv = memoryview(b)
while n := f.readinto(mv):
hasher.update(mv[:n])
return hasher.hexdigest()
except Exception as e:
logging.error(f"Error hashing file {model_path}: {str(e)}")
return None
def _get_existing_model(self, session, model_type, model_relative_path):
return (
session.query(Model)
.filter(Model.type == model_type)
.filter(Model.path == model_relative_path)
.first()
)
def _ensure_source_url(self, session, model, source_url):
if model.source_url is None:
model.source_url = source_url
session.commit()
def _update_database(
self,
session,
model_type,
model_path,
model_relative_path,
model_hash,
model,
source_url,
):
try:
if not model:
model = self._get_existing_model(
session, model_type, model_relative_path
)
if not model:
model = Model(
path=model_relative_path,
type=model_type,
file_name=os.path.basename(model_path),
)
session.add(model)
model.file_size = self._get_file_size(model_path)
model.hash = model_hash
if model_hash:
model.hash_algorithm = "blake3"
model.source_url = source_url
session.commit()
return model
except Exception as e:
logging.error(
f"Error updating database for {model_relative_path}: {str(e)}"
)
def process_file(self, model_path, source_url=None, model_hash=None):
"""
Process a model file and update the database with metadata.
If the file already exists and matches the database, it will not be processed again.
        Returns the model object, or None if an error occurs.
"""
try:
if not can_create_session():
return
result = self._validate_path(model_path)
if not result:
return
model_type, model_relative_path = result
with create_session() as session:
session.expire_on_commit = False
existing_model = self._get_existing_model(
session, model_type, model_relative_path
)
if (
existing_model
and existing_model.hash
and existing_model.file_size == self._get_file_size(model_path)
):
# File exists with hash and same size, no need to process
self._ensure_source_url(session, existing_model, source_url)
return existing_model
if model_hash:
model_hash = model_hash.lower()
logging.info(f"Using provided hash: {model_hash}")
else:
start_time = time.time()
logging.info(f"Hashing model {model_relative_path}")
model_hash = self._hash_file(model_path)
if not model_hash:
return
logging.info(
f"Model hash: {model_hash} (duration: {time.time() - start_time} seconds)"
)
return self._update_database(
session,
model_type,
model_path,
model_relative_path,
model_hash,
existing_model,
source_url,
)
except Exception as e:
logging.error(f"Error processing model file {model_path}: {str(e)}")
return None
def retrieve_model_by_hash(self, model_hash, model_type=None, session=None):
"""
Retrieve a model file from the database by hash and optionally by model type.
Returns the model object, or None if the model doesn't exist or an error occurs.
"""
try:
if not can_create_session():
return
dispose_session = False
if session is None:
session = create_session()
dispose_session = True
model = session.query(Model).filter(Model.hash == model_hash)
if model_type is not None:
model = model.filter(Model.type == model_type)
return model.first()
except Exception as e:
logging.error(f"Error retrieving model by hash {model_hash}: {str(e)}")
return None
finally:
if dispose_session:
session.close()
def retrieve_hash(self, model_path, model_type=None):
"""
Retrieve the hash of a model file from the database.
Returns the hash, or None if the model doesn't exist or an error occurs.
"""
try:
if not can_create_session():
return
if model_type is not None:
result = self._validate_path(model_path)
if not result:
return None
model_type, model_relative_path = result
with create_session() as session:
model = self._get_existing_model(
session, model_type, model_relative_path
)
if model and model.hash:
return model.hash
return None
except Exception as e:
logging.error(f"Error retrieving hash for {model_path}: {str(e)}")
return None
def _validate_file_extension(self, file_name):
"""Validate that the file extension is supported."""
extension = os.path.splitext(file_name)[1]
if extension not in (".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"):
raise ValueError(f"Unsupported unsafe file for download: {file_name}")
def _check_existing_file(self, model_type, file_name, expected_hash):
"""Check if file exists and has correct hash."""
destination_path = get_full_path(model_type, file_name, allow_missing=True)
if self._file_exists(destination_path):
model = self.process_file(destination_path)
if model and (expected_hash is None or model.hash == expected_hash):
logging.debug(
f"File {destination_path} already exists in the database and has the correct hash or no hash was provided."
)
return destination_path
else:
raise ValueError(
f"File {destination_path} exists with hash {model.hash if model else 'unknown'} but expected {expected_hash}. Please delete the file and try again."
)
return None
def _check_existing_file_by_hash(self, hash, type, url):
"""Check if a file with the given hash exists in the database and on disk."""
hash = hash.lower()
with create_session() as session:
model = self.retrieve_model_by_hash(hash, type, session)
if model:
existing_path = get_full_path(type, model.path)
if existing_path:
logging.debug(
f"File {model.path} already exists in the database at {existing_path}"
)
self._ensure_source_url(session, model, url)
return existing_path
else:
logging.debug(
f"File {model.path} exists in the database but not on disk"
)
return None
def _download_file(self, url, destination_path, hasher):
"""Download a file and update the hasher with its contents."""
response = requests.get(url, stream=True)
logging.info(f"Downloading {url} to {destination_path}")
with open(destination_path, "wb") as f:
total_size = int(response.headers.get("content-length", 0))
if total_size > 0:
pbar = comfy.utils.ProgressBar(total_size)
else:
pbar = None
with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
for chunk in response.iter_content(chunk_size=128 * 1024):
if chunk:
f.write(chunk)
hasher.update(chunk)
progress_bar.update(len(chunk))
if pbar:
pbar.update(len(chunk))
def _verify_downloaded_hash(self, calculated_hash, expected_hash, destination_path):
"""Verify that the downloaded file has the expected hash."""
if expected_hash is not None and calculated_hash != expected_hash:
self._remove_file(destination_path)
raise ValueError(
f"Downloaded file hash {calculated_hash} does not match expected hash {expected_hash}"
)
def _remove_file(self, file_path):
"""Remove a file from disk."""
os.remove(file_path)
def ensure_downloaded(self, type, url, desired_file_name, hash=None):
"""
Ensure a model file is downloaded and has the correct hash.
Returns the path to the downloaded file.
"""
logging.debug(
f"Ensuring {type} file is downloaded. URL='{url}' Destination='{desired_file_name}' Hash='{hash}'"
)
# Validate file extension
self._validate_file_extension(desired_file_name)
# Check if file exists with correct hash
if hash:
existing_path = self._check_existing_file_by_hash(hash, type, url)
if existing_path:
return existing_path
# Check if file exists locally
destination_path = get_full_path(type, desired_file_name, allow_missing=True)
existing_path = self._check_existing_file(type, desired_file_name, hash)
if existing_path:
return existing_path
# Download the file
hasher = self._get_hasher()
self._download_file(url, destination_path, hasher)
# Verify hash
calculated_hash = hasher.hexdigest()
self._verify_downloaded_hash(calculated_hash, hash, destination_path)
# Update database
self.process_file(destination_path, url, calculated_hash)
# TODO: Notify frontend to reload models
return destination_path
model_processor = ModelProcessor()

0
app/storage/__init__.py Normal file
View File

72
app/storage/hashing.py Normal file
View File

@ -0,0 +1,72 @@
import asyncio
import os
from typing import IO, Union
from blake3 import blake3
DEFAULT_CHUNK = 8 * 1024 * 1024 # 8 MiB
def _hash_file_obj_sync(file_obj: IO[bytes], chunk_size: int) -> str:
"""Hash an already-open binary file object by streaming in chunks.
- Seeks to the beginning before reading (if supported).
- Restores the original position afterward (if tell/seek are supported).
"""
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
orig_pos = None
if hasattr(file_obj, "tell"):
orig_pos = file_obj.tell()
try:
if hasattr(file_obj, "seek"):
file_obj.seek(0)
h = blake3()
while True:
chunk = file_obj.read(chunk_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
finally:
if hasattr(file_obj, "seek") and orig_pos is not None:
file_obj.seek(orig_pos)
def blake3_hash_sync(
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Returns a BLAKE3 hex digest for ``fp``, which may be:
- a filename (str/bytes) or PathLike
- an open binary file object
If ``fp`` is a file object, it must be opened in **binary** mode and support
``read``; when it also supports ``seek`` and ``tell``, the function seeks to the
start before reading and restores the original position afterward.
"""
if hasattr(fp, "read"):
return _hash_file_obj_sync(fp, chunk_size)
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj_sync(f, chunk_size)
async def blake3_hash(
fp: Union[str, bytes, os.PathLike[str], os.PathLike[bytes], IO[bytes]],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Async wrapper for ``blake3_hash_sync``.
Uses a worker thread so the event loop remains responsive.
"""
# If it is a path, open inside the worker thread to keep I/O off the loop.
if hasattr(fp, "read"):
return await asyncio.to_thread(blake3_hash_sync, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj_sync(f, chunk_size)
return await asyncio.to_thread(_worker)
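A minimal usage sketch for the helpers above, assuming the module is importable as app.storage.hashing (the example path is hypothetical):
import asyncio
import io

from app.storage.hashing import blake3_hash, blake3_hash_sync

# Path input: the file is opened in binary mode and streamed in 8 MiB chunks.
digest = blake3_hash_sync("models/checkpoints/example.safetensors")  # hypothetical path
print(digest)

# Binary file-object input: position is restored after hashing, so repeated calls agree.
buf = io.BytesIO(b"hello assets")
assert blake3_hash_sync(buf) == blake3_hash_sync(buf)

# Async variant: hashing is pushed to a worker thread via asyncio.to_thread,
# keeping the event loop responsive while large files are read.
print(asyncio.run(blake3_hash(buf)))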

View File

@ -1,110 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any, TypedDict
import os
import logging
import comfy.utils
import folder_paths
class AssetMetadata(TypedDict):
device: Any
return_metadata: bool
subdir: str
download_url: str
class AssetInfo:
def __init__(self, hash: str=None, name: str=None, tags: list[str]=None, metadata: AssetMetadata={}):
self.hash = hash
self.name = name
self.tags = tags
self.metadata = metadata
class ReturnedAssetABC(ABC):
def __init__(self, mimetype: str):
self.mimetype = mimetype
class ModelReturnedAsset(ReturnedAssetABC):
def __init__(self, state_dict: dict[str, str], metadata: dict[str, str]=None):
super().__init__("model")
self.state_dict = state_dict
self.metadata = metadata
class AssetResolverABC(ABC):
@abstractmethod
def resolve(self, asset_info: AssetInfo) -> ReturnedAssetABC:
...
class LocalAssetResolver(AssetResolverABC):
def resolve(self, asset_info: AssetInfo, cache_result: bool=True) -> ReturnedAssetABC:
# currently only supports models - make sure models is in the tags
if "models" not in asset_info.tags:
return None
# TODO: if hash exists, call model processor to try to get info about model:
if asset_info.hash:
try:
from app.model_processor import model_processor
model_db = model_processor.retrieve_model_by_hash(asset_info.hash)
full_path = model_db.path
except Exception as e:
logging.error(f"Could not get model by hash with error: {e}")
# the good ol' bread and butter - folder_paths's keys as tags
folder_keys = folder_paths.folder_names_and_paths.keys()
parent_paths = []
for tag in asset_info.tags:
if tag in folder_keys:
parent_paths.append(tag)
# if subdir metadata and name exists, use that as the model name going forward
if "subdir" in asset_info.metadata and asset_info.name:
# if no matching parent paths, then something went wrong and should return None
if len(parent_paths) == 0:
return None
relative_path = os.path.join(asset_info.metadata["subdir"], asset_info.name)
# now we have the parent keys, we can try to get the local path
chosen_parent = None
full_path = None
for parent_path in parent_paths:
full_path = folder_paths.get_full_path(parent_path, relative_path)
if full_path:
chosen_parent = parent_path
break
if full_path is not None:
logging.info(f"Resolved {asset_info.name} to {full_path} in {chosen_parent}")
# we know the path, so load the model and return it
state_dict, metadata = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=True)
# TODO: handle caching
return ModelReturnedAsset(state_dict, metadata)
# if just name exists, try to find model by name in all subdirs of parent paths
# TODO: this behavior should be configurable by user
if asset_info.name:
for parent_path in parent_paths:
filelist = folder_paths.get_filename_list(parent_path)
for file in filelist:
if os.path.basename(file) == asset_info.name:
full_path = folder_paths.get_full_path(parent_path, file)
state_dict, metadata = comfy.utils.load_torch_file(full_path, safe_load=True, device=asset_info.metadata.get("device", None), return_metadata=True)
# TODO: handle caching
return ModelReturnedAsset(state_dict, metadata)
# TODO: if download_url metadata exists, download the model and load it; this should be configurable by user
if asset_info.metadata.get("download_url", None):
...
return None
resolvers: list[AssetResolverABC] = []
resolvers.append(LocalAssetResolver())
def resolve(asset_info: AssetInfo) -> Any:
global resolvers
for resolver in resolvers:
try:
to_return = resolver.resolve(asset_info)
if to_return is not None:
return to_return
except Exception as e:
logging.error(f"Error resolving asset {asset_info.name} using {resolver.__class__.__name__}: {e}")
return None
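The prototype above (removed by this commit) was extensible by appending resolvers; a sketch of that old contract, with a hypothetical no-op resolver, for orientation only:
import logging
import comfy.asset_management as am  # the module deleted above

class NullAssetResolver(am.AssetResolverABC):
    """Hypothetical resolver: never matches, just shows the expected interface."""
    def resolve(self, asset_info: am.AssetInfo) -> am.ReturnedAssetABC:
        logging.debug("NullAssetResolver skipping %s", asset_info.name)
        return None  # returning None lets am.resolve() try the next resolver

# Resolvers are consulted in registration order; the first non-None result wins.
am.resolvers.append(NullAssetResolver())
asset = am.resolve(am.AssetInfo(hash=None, name="example.safetensors",
                                tags=["models", "checkpoints"], metadata={"subdir": ""}))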

View File

@ -211,7 +211,7 @@ parser.add_argument(
database_default_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
-parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
+parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.")
parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")
if comfy.options.args_parsing:
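The new default URL targets SQLAlchemy's async engine backed by the aiosqlite driver added to requirements.txt below; a minimal sketch of consuming such a URL (init_db_engine's actual implementation is not shown in this diff, so this only illustrates the pattern):
import asyncio
from sqlalchemy.ext.asyncio import create_async_engine

async def make_engine(database_url: str = "sqlite+aiosqlite:///user/comfyui.db"):
    engine = create_async_engine(database_url)
    async with engine.begin() as conn:
        # run_sync bridges to synchronous metadata helpers, e.g. Base.metadata.create_all
        await conn.run_sync(lambda sync_conn: None)
    return engine

# asyncio.run(make_engine())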

View File

@ -102,12 +102,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
        else:
            sd = pl_sd
-    try:
-        from app.model_processor import model_processor
-        model_processor.process_file(ckpt)
-    except Exception as e:
-        logging.error(f"Error processing file {ckpt}: {e}")
    return (sd, metadata) if return_metadata else sd
def save_torch_file(sd, ckpt, metadata=None):

View File

@ -1,56 +0,0 @@
from comfy_api.latest import io, ComfyExtension
import comfy.asset_management
import comfy.sd
import folder_paths
import logging
import os
class AssetTestNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="AssetTestNode",
is_experimental=True,
inputs=[
io.Combo.Input("ckpt_name", folder_paths.get_filename_list("checkpoints")),
],
outputs=[
io.Model.Output(),
io.Clip.Output(),
io.Vae.Output(),
],
)
@classmethod
def execute(cls, ckpt_name: str):
hash = None
# let's get the full path just so we can retrieve the hash from the db, if it exists
try:
full_path = folder_paths.get_full_path("checkpoints", ckpt_name)
if full_path is None:
raise Exception(f"Model {ckpt_name} not found")
from app.model_processor import model_processor
hash = model_processor.retrieve_hash(full_path)
except Exception as e:
logging.error(f"Could not get model by hash with error: {e}")
subdir, name = os.path.split(ckpt_name)
asset_info = comfy.asset_management.AssetInfo(hash=hash, name=name, tags=["models", "checkpoints"], metadata={"subdir": subdir})
asset = comfy.asset_management.resolve(asset_info)
# /\ the stuff above should happen in execution code instead of inside the node
# \/ the stuff below should happen in the node - confirm is a model asset, do stuff to it (already loaded? or should be called to 'load'?)
if asset is None:
raise Exception(f"Model {asset_info.name} not found")
assert isinstance(asset, comfy.asset_management.ModelReturnedAsset)
out = comfy.sd.load_state_dict_guess_config(asset.state_dict, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"), metadata=asset.metadata)
return io.NodeOutput(out[0], out[1], out[2])
class AssetTestExtension(ComfyExtension):
@classmethod
async def get_node_list(cls):
return [AssetTestNode]
def comfy_entrypoint():
return AssetTestExtension()

View File

@ -278,11 +278,11 @@ def cleanup_temp():
    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir, ignore_errors=True)
-def setup_database():
+async def setup_database():
    try:
-        from app.database.db import init_db, dependencies_available
+        from app.database.db import init_db_engine, dependencies_available
        if dependencies_available():
-            init_db()
+            await init_db_engine()
    except Exception as e:
        logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
@ -309,6 +309,8 @@ def start_comfyui(asyncio_loop=None):
    asyncio.set_event_loop(asyncio_loop)
    prompt_server = server.PromptServer(asyncio_loop)
+    asyncio_loop.run_until_complete(setup_database())
    hook_breaker_ac10a0.save_functions()
    asyncio_loop.run_until_complete(nodes.init_extra_nodes(
        init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
@ -317,7 +319,6 @@ def start_comfyui(asyncio_loop=None):
    hook_breaker_ac10a0.restore_functions()
    cuda_malloc_warning()
-    setup_database()
    prompt_server.add_routes()
    hijack_progress(prompt_server)

View File

@ -28,9 +28,10 @@ import comfy.sd
import comfy.utils
import comfy.controlnet
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
-from comfy_api.internal import register_versions, ComfyAPIWithVersion
+from comfy_api.internal import async_to_sync, register_versions, ComfyAPIWithVersion
from comfy_api.version_list import supported_versions
from comfy_api.latest import io, ComfyExtension
+from app.assets_manager import add_local_asset
import comfy.clip_vision
@ -777,6 +778,9 @@ class VAELoader:
        else:
            vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
            sd = comfy.utils.load_torch_file(vae_path)
+        async_to_sync.AsyncToSyncConverter.run_async_in_thread(
+            add_local_asset, tags=["models", "vae"], file_name=vae_name, file_path=vae_path
+        )
        vae = comfy.sd.VAE(sd=sd)
        vae.throw_exception_if_invalid()
        return (vae,)
@ -2321,7 +2325,6 @@ async def init_builtin_extra_nodes():
        "nodes_edit_model.py",
        "nodes_tcfg.py",
        "nodes_context_windows.py",
-        "nodes_assets_test.py",
    ]
    import_failed = []
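VAELoader.load_vae is synchronous, so the async add_local_asset call above is bridged through comfy_api.internal.async_to_sync. A generic stand-in using only the standard library (not the actual AsyncToSyncConverter implementation, which is not shown here) would look roughly like:
import asyncio
import threading

def run_async_in_thread(coro_fn, /, *args, **kwargs):
    """Run an async callable to completion from sync code, even when an event
    loop is already running in the calling thread (common inside a server)."""
    outcome = {}

    def _worker():
        try:
            outcome["value"] = asyncio.run(coro_fn(*args, **kwargs))
        except BaseException as exc:
            outcome["error"] = exc

    t = threading.Thread(target=_worker, daemon=True)
    t.start()
    t.join()
    if "error" in outcome:
        raise outcome["error"]
    return outcome.get("value")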

View File

@ -20,6 +20,7 @@ tqdm
psutil
alembic
SQLAlchemy
+aiosqlite
av>=14.2.0
blake3

View File

@ -37,6 +37,7 @@ from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
+from app.api.assets_routes import register_assets_routes
from protocol import BinaryEventTypes
async def send_socket_catch_exception(function, message):
@ -183,6 +184,7 @@ class PromptServer():
            else args.front_end_root
        )
        logging.info(f"[Prompt Server] web root: {self.web_root}")
+        register_assets_routes(self.app)
        routes = web.RouteTableDef()
        self.routes = routes
        self.last_node_id = None
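register_assets_routes comes from the new app.api.assets_routes module, whose body is not part of this hunk; a hypothetical minimal shape (route path and handler are assumptions, not from the diff):
from aiohttp import web

def register_assets_routes(app: web.Application) -> None:
    routes = web.RouteTableDef()

    @routes.get("/api/assets")  # assumed endpoint, for illustration only
    async def list_assets(request: web.Request) -> web.Response:
        return web.json_response({"assets": []})

    app.add_routes(routes)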

View File

@ -1,62 +0,0 @@
import pytest
import base64
import json
import struct
from io import BytesIO
from PIL import Image
from aiohttp import web
from unittest.mock import patch
from app.model_manager import ModelFileManager
pytestmark = (
pytest.mark.asyncio
) # This applies the asyncio mark to all test functions in the module
@pytest.fixture
def model_manager():
return ModelFileManager()
@pytest.fixture
def app(model_manager):
app = web.Application()
routes = web.RouteTableDef()
model_manager.add_routes(routes)
app.add_routes(routes)
return app
async def test_get_model_preview_safetensors(aiohttp_client, app, tmp_path):
img = Image.new('RGB', (100, 100), 'white')
img_byte_arr = BytesIO()
img.save(img_byte_arr, format='PNG')
img_byte_arr.seek(0)
img_b64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
safetensors_file = tmp_path / "test_model.safetensors"
header_bytes = json.dumps({
"__metadata__": {
"ssmd_cover_images": json.dumps([img_b64])
}
}).encode('utf-8')
length_bytes = struct.pack('<Q', len(header_bytes))
with open(safetensors_file, 'wb') as f:
f.write(length_bytes)
f.write(header_bytes)
with patch('folder_paths.folder_names_and_paths', {
'test_folder': ([str(tmp_path)], None)
}):
client = await aiohttp_client(app)
response = await client.get('/experiment/models/preview/test_folder/0/test_model.safetensors')
# Verify response
assert response.status == 200
assert response.content_type == 'image/webp'
# Verify the response contains valid image data
img_bytes = BytesIO(await response.read())
img = Image.open(img_bytes)
assert img.format
assert img.format.lower() == 'webp'
# Clean up
img.close()

View File

@ -1,253 +0,0 @@
import pytest
from unittest.mock import patch, MagicMock
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from app.model_processor import ModelProcessor
from app.database.models import Model, Base
import os
# Test data constants
TEST_MODEL_TYPE = "checkpoints"
TEST_URL = "http://example.com/model.safetensors"
TEST_FILE_NAME = "model.safetensors"
TEST_EXPECTED_HASH = "abc123"
TEST_DESTINATION_PATH = "/path/to/model.safetensors"
def create_test_model(session, file_name, model_type, hash_value, file_size=1000, source_url=None):
"""Helper to create a test model in the database."""
model = Model(path=file_name, type=model_type, hash=hash_value, file_size=file_size, source_url=source_url)
session.add(model)
session.commit()
return model
def setup_mock_hash_calculation(model_processor, hash_value):
"""Helper to setup hash calculation mocks."""
mock_hash = MagicMock()
mock_hash.hexdigest.return_value = hash_value
return patch.object(model_processor, "_get_hasher", return_value=mock_hash)
def verify_model_in_db(session, file_name, expected_hash=None, expected_type=None):
"""Helper to verify model exists in database with correct attributes."""
db_model = session.query(Model).filter_by(path=file_name).first()
assert db_model is not None
if expected_hash:
assert db_model.hash == expected_hash
if expected_type:
assert db_model.type == expected_type
return db_model
@pytest.fixture
def db_engine():
# Configure in-memory database
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
yield engine
Base.metadata.drop_all(engine)
@pytest.fixture
def db_session(db_engine):
Session = sessionmaker(bind=db_engine)
session = Session()
yield session
session.close()
@pytest.fixture
def mock_get_relative_path():
with patch("app.model_processor.get_relative_path") as mock:
mock.side_effect = lambda path: (TEST_MODEL_TYPE, os.path.basename(path))
yield mock
@pytest.fixture
def mock_get_full_path():
with patch("app.model_processor.get_full_path") as mock:
mock.return_value = TEST_DESTINATION_PATH
yield mock
@pytest.fixture
def model_processor(db_session, mock_get_relative_path, mock_get_full_path):
with patch("app.model_processor.create_session", return_value=db_session):
with patch("app.model_processor.can_create_session", return_value=True):
processor = ModelProcessor()
# Setup test state
processor.removed_files = []
processor.downloaded_files = []
processor.file_exists = {}
def mock_download_file(url, destination_path, hasher):
processor.downloaded_files.append((url, destination_path))
processor.file_exists[destination_path] = True
# Simulate writing some data to the file
test_data = b"test data"
hasher.update(test_data)
def mock_remove_file(file_path):
processor.removed_files.append(file_path)
if file_path in processor.file_exists:
del processor.file_exists[file_path]
# Setup common patches
file_exists_patch = patch.object(
processor,
"_file_exists",
side_effect=lambda path: processor.file_exists.get(path, False),
)
file_size_patch = patch.object(
processor,
"_get_file_size",
side_effect=lambda path: (
1000 if processor.file_exists.get(path, False) else 0
),
)
download_file_patch = patch.object(
processor, "_download_file", side_effect=mock_download_file
)
remove_file_patch = patch.object(
processor, "_remove_file", side_effect=mock_remove_file
)
with (
file_exists_patch,
file_size_patch,
download_file_patch,
remove_file_patch,
):
yield processor
def test_ensure_downloaded_invalid_extension(model_processor):
# Ensure that an unsupported file extension raises an error to prevent unsafe file downloads
with pytest.raises(ValueError, match="Unsupported unsafe file for download"):
model_processor.ensure_downloaded(TEST_MODEL_TYPE, TEST_URL, "model.exe")
def test_ensure_downloaded_existing_file_with_hash(model_processor, db_session):
# Ensure that a file with the same hash but from a different source is not downloaded again
SOURCE_URL = "https://example.com/other.sft"
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH, source_url=SOURCE_URL)
model_processor.file_exists[TEST_DESTINATION_PATH] = True
result = model_processor.ensure_downloaded(
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
)
assert result == TEST_DESTINATION_PATH
model = verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
assert model.source_url == SOURCE_URL # Ensure the source URL is not overwritten
def test_ensure_downloaded_existing_file_hash_mismatch(model_processor, db_session):
# Ensure that a file with a different hash raises an error
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, "different_hash")
model_processor.file_exists[TEST_DESTINATION_PATH] = True
with pytest.raises(ValueError, match="File .* exists with hash .* but expected .*"):
model_processor.ensure_downloaded(
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
)
def test_ensure_downloaded_new_file(model_processor, db_session):
# Ensure that a new file is downloaded
model_processor.file_exists[TEST_DESTINATION_PATH] = False
with setup_mock_hash_calculation(model_processor, TEST_EXPECTED_HASH):
result = model_processor.ensure_downloaded(
TEST_MODEL_TYPE, TEST_URL, TEST_FILE_NAME, TEST_EXPECTED_HASH
)
assert result == TEST_DESTINATION_PATH
assert len(model_processor.downloaded_files) == 1
assert model_processor.downloaded_files[0] == (TEST_URL, TEST_DESTINATION_PATH)
assert model_processor.file_exists[TEST_DESTINATION_PATH]
verify_model_in_db(db_session, TEST_FILE_NAME, TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
def test_ensure_downloaded_hash_mismatch(model_processor, db_session):
# Ensure that download that results in a different hash raises an error
model_processor.file_exists[TEST_DESTINATION_PATH] = False
with setup_mock_hash_calculation(model_processor, "different_hash"):
with pytest.raises(
ValueError,
match="Downloaded file hash .* does not match expected hash .*",
):
model_processor.ensure_downloaded(
TEST_MODEL_TYPE,
TEST_URL,
TEST_FILE_NAME,
TEST_EXPECTED_HASH,
)
assert len(model_processor.removed_files) == 1
assert model_processor.removed_files[0] == TEST_DESTINATION_PATH
assert TEST_DESTINATION_PATH not in model_processor.file_exists
assert db_session.query(Model).filter_by(path=TEST_FILE_NAME).first() is None
def test_process_file_without_hash(model_processor, db_session):
# Test processing file without provided hash
model_processor.file_exists[TEST_DESTINATION_PATH] = True
with patch.object(model_processor, "_hash_file", return_value=TEST_EXPECTED_HASH):
result = model_processor.process_file(TEST_DESTINATION_PATH)
assert result is not None
assert result.hash == TEST_EXPECTED_HASH
def test_retrieve_model_by_hash(model_processor, db_session):
# Test retrieving model by hash
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH)
assert result is not None
assert result.hash == TEST_EXPECTED_HASH
def test_retrieve_model_by_hash_and_type(model_processor, db_session):
# Test retrieving model by hash and type
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
result = model_processor.retrieve_model_by_hash(TEST_EXPECTED_HASH, TEST_MODEL_TYPE)
assert result is not None
assert result.hash == TEST_EXPECTED_HASH
assert result.type == TEST_MODEL_TYPE
def test_retrieve_hash(model_processor, db_session):
# Test retrieving hash for existing model
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
with patch.object(
model_processor,
"_validate_path",
return_value=(TEST_MODEL_TYPE, TEST_FILE_NAME),
):
result = model_processor.retrieve_hash(TEST_DESTINATION_PATH, TEST_MODEL_TYPE)
assert result == TEST_EXPECTED_HASH
def test_validate_file_extension_valid_extensions(model_processor):
# Test all valid file extensions
valid_extensions = [".safetensors", ".sft", ".txt", ".csv", ".json", ".yaml"]
for ext in valid_extensions:
model_processor._validate_file_extension(f"test{ext}") # Should not raise
def test_process_file_existing_without_source_url(model_processor, db_session):
# Test processing an existing file that needs its source URL updated
model_processor.file_exists[TEST_DESTINATION_PATH] = True
create_test_model(db_session, TEST_FILE_NAME, TEST_MODEL_TYPE, TEST_EXPECTED_HASH)
result = model_processor.process_file(TEST_DESTINATION_PATH, source_url=TEST_URL)
assert result is not None
assert result.hash == TEST_EXPECTED_HASH
assert result.source_url == TEST_URL
db_model = db_session.query(Model).filter_by(path=TEST_FILE_NAME).first()
assert db_model.source_url == TEST_URL